The purpose of this notebook is to plot Class Activation Maps (CAMs) for the different emotion classes in the dataset. CAMs visualize where in an image the model detects salient, emotion-specific features.

In [None]:
!pip install torchcam

In [None]:
import sys
sys.path.append('PATH_TO_FEC_MODEL_FILE')

In [None]:
from utils.fer_models import FEClassifier

In [None]:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): `device` is currently unused; the model (and the inputs fed to it)
# stay on the CPU.

# Build the classifier and load its trained weights.
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on a
# CPU-only runtime; weights_only=True restricts unpickling to tensor data, which is
# all a state dict needs.
model = FEClassifier(base='efficientnet')
state_dict = torch.load("/content/effnetb2_14_2.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval();  # switch dropout/batch-norm to inference behaviour; ';' hides the module repr

In [None]:
# The .pth file holds a pickled DataLoader object, not just tensors, so full
# unpickling is required (torch >= 2.6 defaults to weights_only=True, which rejects it).
# SECURITY: weights_only=False can execute arbitrary code -- only load trusted files.
# map_location="cpu" keeps any stored tensors on the CPU, where the model runs.
test_loader = torch.load('/content/test_loader.pth', map_location="cpu", weights_only=False)

In [None]:
import random

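Samples can be drawn either repeatedly from a single emotion class or as one sample per class. A sample's class is inferred from its index, which assumes the test set is ordered by class.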

In [None]:
# Which test samples to visualize.
# Per-class sample counts, listed in the order the classes appear in test_loader.dataset.
# NOTE(review): choosing a class by index range assumes the dataset is ordered by
# class with exactly these counts; the length check below guards the total.
CLASS_COUNTS = {'anger': 2557, 'disgust': 418, 'fear': 685, 'happy': 13570,
                'neutral': 7587, 'sad': 2486, 'surprise': 1438}
TARGET_CLASS = 'surprise'  # class to draw from; set to None for one sample per class
N_SAMPLES = 5              # number of samples drawn when TARGET_CLASS is set
SAMPLE_SEED = 0
random.seed(SAMPLE_SEED)

# Half-open [start, stop) index range per class, derived from the counts.
# random.randint is inclusive at both ends, so hand-written [start, end] bounds
# easily overlap the next class or run past the end of the dataset.
class_ranges = {}
offset = 0
for class_name, count in CLASS_COUNTS.items():
    class_ranges[class_name] = range(offset, offset + count)
    offset += count
assert offset == len(test_loader.dataset), (
    f"CLASS_COUNTS total {offset} != dataset size {len(test_loader.dataset)}"
)

if TARGET_CLASS is None:
    sample_indices = [random.choice(indices) for indices in class_ranges.values()]
else:
    # without replacement, so the same image is never plotted twice
    sample_indices = random.sample(class_ranges[TARGET_CLASS], N_SAMPLES)
datasample = [test_loader.dataset[idx] for idx in sample_indices]


CAMs are extracted with the torchcam library, following the usage shown in its README: https://github.com/frgfm/torch-cam

In [None]:
from torchcam.methods import SmoothGradCAMpp

In [None]:
# Classify each sample and extract a SmoothGrad-CAM++ map for its predicted class.
# CAM extraction needs gradients, so this must not run under torch.no_grad().
CAM_SEED = 0  # SmoothGradCAMpp perturbs inputs with random noise; seed for reproducible maps
torch.manual_seed(CAM_SEED)

model_device = next(model.parameters()).device  # feed inputs on the model's own device
cams, pred_labels = [], []
# Build the extractor once rather than once per sample; leaving the `with`
# block removes torchcam's hooks from the model.
with SmoothGradCAMpp(model) as cam_extractor:
    for image, _label in datasample:
        logits = model(image.unsqueeze(0).to(model_device))  # add a batch dimension
        pred_label = int(logits.argmax(dim=1).item())
        pred_labels.append(pred_label)
        # a list with one activation map per target layer
        cams.append(cam_extractor(pred_label, logits))

In [None]:
import matplotlib.pyplot as plt
from torchcam.utils import overlay_mask
from torchvision.transforms.v2.functional import to_pil_image

In [None]:
# Integer class index (model prediction / dataset label) -> emotion name.
class_names = dict(enumerate(['anger', 'disgust', 'fear', 'happy', 'neutral', 'sad', 'surprise']))

Images must be unnormalized before plotting, because the test set loaded from the `.pth` file contains normalized image tensors.

In [None]:
class UnNormalize:
    """Invert a per-channel normalization, modifying the tensor in place.

    Each channel ``c`` of the input becomes ``c * std + mean``. Channels are
    paired with the ``mean``/``std`` entries in order, and pairing stops at the
    shortest of the three sequences.
    """

    def __init__(self, mean, std):
        # per-channel statistics that were used for the original normalization
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Unnormalize ``tensor`` in place and return the same tensor object."""
        channel_stats = zip(self.mean, self.std)
        for channel, (mu, sigma) in zip(tensor, channel_stats):
            channel.mul_(sigma)
            channel.add_(mu)
        return tensor

In [None]:
# Per-channel normalization statistics (mean, std).
# Alternative set, presumably for the dataset including generated images:
#   mean (0.5474, 0.4259, 0.3695), std (0.2782, 0.2465, 0.2398)
# Set used here -- AffectNet only, no generated images:
#   mean (0.5691, 0.4458, 0.3910), std (0.2746, 0.2446, 0.2383)
unnorm = UnNormalize(mean=(0.5691, 0.4458, 0.3910), std=(0.2746, 0.2446, 0.2383))

# unnorm modifies its input in place, so clone first: dataset items may share
# storage with the underlying test set (e.g. TensorDataset returns views) and with
# each other when an index was drawn twice, so unnormalizing them in place would
# corrupt the dataset or unnormalize a shared image twice.
# NOTE: run this cell once -- re-running it unnormalizes datasample again.
datasample = [(unnorm(image.clone()), label) for image, label in datasample]

In [None]:
# Overlay each CAM on its unnormalized image, one panel per sample.
n_panels = len(datasample)
fig, axes = plt.subplots(1, n_panels, figsize=(3.2 * n_panels, 4), squeeze=False)
for ax, (image, label), cam, pred_label in zip(axes[0], datasample, cams, pred_labels):
    # clamp to [0, 1]: to_pil_image scales floats by 255 and casts to uint8,
    # so out-of-range values would not saturate cleanly
    base_image = to_pil_image(image.clamp(0, 1), mode='RGB')
    # cam[0]: map for the first target layer (presumably the only one, since none was specified)
    heatmap = to_pil_image(cam[0], mode='F')
    ax.imshow(overlay_mask(base_image, heatmap, alpha=0.5))
    # int() accepts both Python ints and 0-d tensor labels as dict keys
    ax.set_title(f"Predicted label: {class_names[pred_label]}\nTrue Label: {class_names[int(label)]}",
                 fontsize=8)
    ax.axis('off')
fig.tight_layout()
plt.show()