In [None]:
import sys
import cv2
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torchvision import transforms
from torchvision import models
from torchvision.models.mobilenetv3 import MobileNet_V3_Large_Weights, MobileNet_V3_Small_Weights

sys.path.append('../')
from utils.gradcam import GradCAM

In [None]:
# Fine-tuned checkpoint (a state_dict) loaded by the model-setup cell.
MODEL_PATH = '../models/mobilenetv3_final.pt'
# Which MobileNetV3 variant to build ('small' or 'large').
model_size = 'small'

# Pick the fastest available backend: CUDA first, then Apple MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# Size of the replacement classifier head. Presumably the 7 RAF-DB basic
# expression classes; confirm against the training code.
NUM_CLASSES = 7
IMAGE_PATH = '../data/RAF/Image/aligned/test_0009_aligned.jpg'

# Build the requested MobileNetV3 variant. No ImageNet weights are requested:
# the strict load_state_dict below overwrites every parameter and buffer, so
# downloading pretrained weights would only add a network dependency.
_mobilenet_builders = {
    'small': models.mobilenet_v3_small,
    'large': models.mobilenet_v3_large,
}
if model_size not in _mobilenet_builders:
    raise ValueError(f"model_size must be 'small' or 'large', got {model_size!r}")
model = _mobilenet_builders[model_size](weights=None)
# Replace the last classifier layer with a NUM_CLASSES-way head so the shapes match the checkpoint.
model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, NUM_CLASSES)
model.to(device)
# weights_only=True restricts unpickling to tensors/plain containers, which is
# all a state_dict needs, so a tampered checkpoint cannot run arbitrary code.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model.eval()

cam = GradCAM(model, device)

# Preprocessing: resize to the 224x224 input size, convert to a tensor, then
# normalise with the standard ImageNet channel statistics (presumably the same
# preprocessing used at training time; confirm).
t = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# The context manager closes the file handle. convert('RGB') guarantees three
# channels, matching the 3-element mean/std (grayscale or RGBA would break Normalize).
with Image.open(IMAGE_PATH) as pil_img:
    img_tensor = t(pil_img.convert('RGB'))


In [None]:


# Re-read the test image with OpenCV for display. cv2 loads BGR, so convert to
# RGB for matplotlib and resize to the same 224x224 size the model sees.
image_path = '../data/RAF/Image/aligned/test_0009_aligned.jpg'
img = cv2.imread(image_path)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None, not by raising.
    raise FileNotFoundError(f"could not read image: {image_path}")
img_plt = cv2.resize(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), (224, 224))

heatmap, _, _ = cam.get_heatmap(img_tensor)
print(heatmap.shape)

fig, ax = plt.subplots()
ax.imshow(img_plt)
# Stretch the heatmap over the image's pixel extent so the overlay lines up even
# if get_heatmap returns a map at a different resolution (TODO confirm what
# resolution GradCAM.get_heatmap returns). For a 224x224 heatmap this equals
# imshow's default extent.
height, width = img_plt.shape[:2]
ax.imshow(heatmap, alpha=0.3, extent=(-0.5, width - 0.5, height - 0.5, -0.5))
ax.set_title('Grad-CAM overlay: test_0009_aligned')
ax.axis('off')
plt.show()

