<a href="https://colab.research.google.com/github/dcafarelli/CMT-ABAW2020-EXPR/blob/main/test/multi_resolution_validation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install confplot

In [None]:
import os
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import sys
from fastprogress.fastprogress import master_bar, progress_bar
import sklearn.metrics as sm
import confplot

In [None]:
# Root folder of the cropped & aligned validation frames.
# NOTE(review): not referenced by any cell below — AffWild2ValSet builds its own image paths.
validation_dir = "/cropped_aligned_val/"

In [None]:
#CUDA FOR PYTORCH

use_cuda = torch.cuda.is_available()
# First GPU when one is visible, otherwise the CPU.
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
# Let the cuDNN auto-tuner benchmark convolution algorithms and keep the fastest
# one for this hardware.
torch.backends.cudnn.benchmark = True
print(device)

In [None]:
# Label sets: the 7 AffWild2 expression classes and the 3-class variant.
classes = ("Neutral", "Anger", "Disgust", "Fear", "Happiness", "Sadness", "Surprise")
classes_nt = ("Neutral", "Positive", "Negative")

# Validation annotation pickles.
val_df_path = "/annotations/val_set.pkl"  # AffWild2 validation set
val_df_nt_path = "/annotations/three_classes_val_label.pkl"  # 3-class validation set

# Fine-tuned weights to evaluate (7-class model).
path_best_model = "/best_performance_model/best_model-senet50-0.43.pt"

# -------- alternative checkpoints; these only work on the 3-class set --------
# path_best_model = "/best_performance_model/best_model_sen50_newtask.pt"  # standard train / base model
# path_best_model = "/best_performance_model/best_multi-res-train_newtask3920_unbalanced_0612.pt"  # multi-res train / multi-res model
# path_best_model = "/best_performance_model/best_multi-res-train_newtask_unbalanced-basemodel.pt"  # multi-res train / base model

# -------- checkpoint paths used to build the base model --------
model_base_path_colab = "/model_checkpoint/pytorch_models/senet50_ft_pytorch.pth"
model_ckp_path = "/model_checkpoint/pytorch_models/models_ckp_78561.pth.tar"

# LOAD BEST MODEL SENET50

In [None]:
# MainModel.py defines the network class of the pickled SENet50 base model; importing it is
# presumably what lets torch.load() unpickle that file — confirm against the checkpoint.
MAIN_MODEL_DIR = '/path/where/MainModel.py/is_located'  # folder that contains MainModel.py
if MAIN_MODEL_DIR not in sys.path:  # keep the cell idempotent when it is re-run
    sys.path.append(MAIN_MODEL_DIR)
import MainModel  # noqa: F401  (imported for its side effect; not referenced directly)

In [None]:
def _torch_load_trusted(path):
    """torch.load() onto the CPU with full unpickling enabled (trusted local files only)."""
    import inspect

    kwargs = {'map_location': 'cpu'}
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle a whole
    # nn.Module; torch < 1.13 does not accept the keyword at all.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        # SECURITY: full unpickling can execute arbitrary code — only load trusted files.
        kwargs['weights_only'] = False
    return torch.load(path, **kwargs)


def load_models(model_base_path, device="cpu", model_ckp=None):
    """Load the pickled base network and optionally overwrite its weights from a checkpoint.

    Args:
        model_base_path: path to a file written with ``torch.save(model)`` (a whole
            nn.Module, not a state dict).
        device: device the returned model is moved to.
        model_ckp: optional path to a checkpoint dict whose ``'model_state_dict'`` entry
            holds a tensor for every parameter name of the base model, plus the
            ``running_var`` / ``running_mean`` / ``num_batches_tracked`` buffers of every
            BatchNorm2d layer.

    Returns:
        The model, moved to ``device``.

    Raises:
        AssertionError: if a given path does not exist.
        KeyError: if the checkpoint lacks an entry for a parameter or BatchNorm buffer.
    """
    assert os.path.exists(model_base_path), "Base model checkpoint not found at: {}".format(model_base_path)
    # Everything is loaded onto the CPU first so GPU-saved files also open on CPU-only
    # machines; the model is moved to `device` once all weights are in place.
    model = _torch_load_trusted(model_base_path)
    if model_ckp is not None:
        assert os.path.exists(model_ckp), f"Model checkpoint not found at: {model_ckp}"
        state = _torch_load_trusted(model_ckp)['model_state_dict']
        with torch.no_grad():
            for name, param in model.named_parameters():
                param.copy_(state[name])
        # BatchNorm running statistics are buffers, not parameters, so they are copied separately.
        for name, module in model.named_modules():
            if isinstance(module, nn.BatchNorm2d):
                module.momentum = 0.1
                module.running_var = state[name + '.running_var']
                module.running_mean = state[name + '.running_mean']
                module.num_batches_tracked = state[name + '.num_batches_tracked']

    return model.to(device)

In [None]:
model = load_models(model_base_path_colab, device=device, model_ckp=None)

# pytorch 1.6.0 compatibility: the unpickled modules may lack this attribute,
# so give every submodule an empty set of non-persistent buffers.
for module in model.modules():
  module._non_persistent_buffers_set = set()

In [None]:
def reshape(flag, model):
  """Replace `model.classifier_1` with a fresh linear head sized for the chosen label set.

  Args:
    flag: "affwild2" -> len(classes) outputs; any other value -> len(classes_nt) outputs.
    model: object whose `classifier_1` attribute is overwritten in place.

  Returns:
    The same `model` object. The new head is randomly initialized; trained weights must
    be loaded afterwards.
  """
  n_outputs = len(classes) if flag == "affwild2" else len(classes_nt)
  # 2048 is presumably the width of the pooled SENet50 features feeding classifier_1 — TODO confirm.
  model.classifier_1 = nn.Linear(2048, n_outputs)
  return model

In [None]:
model = reshape(flag="affwild2", model=model)  # 7-class AffWild2 head

In [None]:
model = model.to(device=device)  # includes the freshly created classifier head

In [None]:
#Load state dict
# map_location lets a checkpoint saved on a GPU be restored on a CPU-only runtime as well.
ckp = torch.load(path_best_model, map_location=device)
model.load_state_dict(ckp['state_dict'])

<All keys matched successfully>

# VALIDATION AT DIFFERENT IMAGE RESOLUTIONS



In [None]:
class AffWild2ValSet(Dataset):
    """Validation frames of AffWild2 (7 classes) or of the 3-class variant, optionally down-sampled.

    Args:
        choose_set: 'affwild2' selects the annotations at `val_df_path`; any other value
            selects `val_df_nt_path`.
        transform: optional callable applied to the PIL image (after down-sampling).
        res: if given, every frame is resized to (res, res) with bilinear filtering before
            `transform`, to evaluate the model on a lower input resolution.
        annotations_path: optional explicit annotations pickle; overrides the file chosen
            by `choose_set`.
        root_dir: prefix prepended verbatim to the image path stored in the annotations.

    Each item is `(image, label)`. Column 0 of the annotations frame holds the image path
    (appended to `root_dir`) and column 'label' holds the target.
    """

    def __init__(self, choose_set, transform=None, res=None,
                 annotations_path=None, root_dir='/content/cropped_aligned_val'):
        self.choose_set = choose_set
        # Historical alias of `choose_set`.
        self.flag = choose_set
        if annotations_path is None:
            annotations_path = val_df_path if choose_set == 'affwild2' else val_df_nt_path
        # NOTE: pd.read_pickle unpickles arbitrary objects — only load trusted annotation files.
        self.emotion_frame = pd.read_pickle(annotations_path)
        self.transform = transform
        self.res = res
        self.root_dir = root_dir

    def __len__(self):
        return len(self.emotion_frame)

    def __getitem__(self, index):
        img_path = self.emotion_frame.iloc[index, 0]
        # Plain concatenation instead of os.path.join: a stored path starting with '/'
        # would make os.path.join discard root_dir.
        fp = '%s%s' % (self.root_dir, img_path)
        assert os.path.exists(fp), "Image not found at: {}".format(fp)

        # cv2 decodes to a BGR array; the channel order is left as BGR (no conversion).
        bgr = cv2.imread(fp)
        if bgr is None:  # cv2 reports unreadable/corrupt files by returning None
            raise OSError("Could not decode image at: {}".format(fp))
        val_set_face = Image.fromarray(bgr)

        if self.res is not None:
            # Down-sample to res x res to emulate a lower-resolution input.
            val_set_face = val_set_face.resize((self.res, self.res), Image.BILINEAR)

        y_label = self.emotion_frame['label'].values[index]
        if self.transform:
            val_set_face = self.transform(val_set_face)

        return val_set_face, y_label

In [None]:
#DATA TRANSFORMATION
def subtract_mean(x):
    """Rescale a [0, 1] CHW tensor to 0-255 and subtract a fixed mean from each channel.

    The first three channels are modified in place and the same tensor is returned.
    The means are listed in channel order; presumably these are BGR channel means,
    matching the cv2-decoded frames — confirm against the pretrained model.
    """
    channel_means = (91.4953, 103.8827, 131.0912)
    x *= 255.
    for channel, mean in enumerate(channel_means):
        x[channel] -= mean
    return x

In [None]:
transformed_val = transforms.Compose(
    [
        transforms.Resize((224, 224)),     # every frame is fed to the model at 224x224
        transforms.ToTensor(),             # PIL image -> CHW float tensor in [0, 1]
        transforms.Lambda(subtract_mean),  # back to 0-255, then per-channel mean removal
    ]
)

In [None]:
val_set = AffWild2ValSet(choose_set='affwild2', transform=transformed_val, res=None)  # full resolution

In [None]:
#Show dataset images
def show_images(dataset, num_image, n_images=4):
  """Plot up to `n_images` consecutive samples of `dataset`, starting at index `num_image`.

  Samples are expected to be (C, H, W) tensors in BGR channel order, which is what
  AffWild2ValSet yields with the mean-subtraction transform (TODO confirm for other
  transforms). The channel axis is reversed to RGB for display. Each image is then
  min-max rescaled to [0, 1], because mean-subtracted 0-255 values fall outside the
  float range imshow accepts; colours are therefore only approximate.

  Raises:
    IndexError: if `num_image` is not a valid index of `dataset`.
  """
  stop = min(num_image + n_images, len(dataset))
  indices = list(range(num_image, stop))
  if num_image < 0 or not indices:
    raise IndexError("num_image {} out of range for dataset of length {}".format(num_image, len(dataset)))

  _, axes = plt.subplots(1, len(indices), squeeze=False, figsize=(3 * len(indices), 3))
  for ax, idx in zip(axes[0], indices):
    face, label = dataset[idx]
    # (C, H, W) -> (H, W, C), then BGR -> RGB by reversing the channel axis.
    img = np.asarray(face.permute(1, 2, 0), dtype=np.float32)[..., ::-1]
    lo, hi = float(img.min()), float(img.max())
    img = (img - lo) / (hi - lo) if hi > lo else np.zeros_like(img)
    ax.imshow(img)
    ax.set_title("idx {} / label {}".format(idx, label))
    ax.axis("off")
  plt.tight_layout()
  plt.show()

In [None]:
show_images(dataset=val_set, num_image=40000)

In [None]:
#DATA GENERATORS
# shuffle stays at its default (False), so batches follow the annotation row order.
validation_generator = DataLoader(
    val_set,
    batch_size=32,
    num_workers=8,
    pin_memory=True,
    drop_last=False,
)

In [None]:
def metrics(lab, pred):
  """Score predictions against ground truth and print the per-class F1 scores.

  Args:
    lab: list of per-batch CPU tensors of true labels (assumed 1-D).
    pred: list of per-batch CPU tensors of predicted labels, aligned with `lab`.

  Returns:
    (accuracy, macro_F1, confusion_matrix)

  Raises:
    ValueError: if there are no batches to score.
  """
  if len(lab) == 0 or len(pred) == 0:
    raise ValueError("metrics() needs at least one batch of labels and predictions")

  lab_array = np.concatenate([t.numpy() for t in lab], axis=0)
  pred_array = np.concatenate([t.numpy() for t in pred], axis=0)

  # zero_division=1: a per-class F1 with a zero denominator is scored as 1 (not 0),
  # which can raise the macro average when a class is absent.
  F1_score = sm.f1_score(lab_array, pred_array, average='macro', zero_division=1)
  classes_score = sm.f1_score(lab_array, pred_array, average=None, zero_division=1)
  # average=None yields one F1 score per class (these are not accuracies).
  print("F1 per class ", classes_score)
  accuracy = sm.accuracy_score(lab_array, pred_array)
  confusion_matrix = sm.confusion_matrix(lab_array, pred_array)

  return accuracy, F1_score, confusion_matrix

In [None]:
#STATISTIC COMPETITION
def stat_comp(F1_score, accuracy):
  """Weighted competition score: 0.67 * F1 + 0.33 * accuracy."""
  weighted_f1 = 0.67 * F1_score
  weighted_acc = 0.33 * accuracy
  return weighted_acc + weighted_f1

In [None]:
def evaluate(model, data_loader=None):
  """Run `model` over a validation loader and score its predictions.

  Args:
    model: network called as `model(batch)`. It is assumed to return a 2-tuple whose
      second element holds per-class scores of shape (batch, n_classes) — TODO confirm
      for MainModel.
    data_loader: iterable of (faces, labels) batches. Defaults to the notebook-level
      `validation_generator`, so `evaluate(model)` behaves as before. Pass another loader
      (e.g. one built on a down-sampled dataset) to evaluate a different resolution.

  Returns:
    (accuracy, macro_F1, confusion_matrix, pred, lab), where `pred` and `lab` are lists
    of per-batch CPU tensors as consumed by `metrics`.
  """
  loader = validation_generator if data_loader is None else data_loader
  pred, lab = [], []

  model.eval()
  print("Enter Evaluation. Is Training?", model.training)
  with torch.no_grad():
    for faces_val, labels_val in progress_bar(loader):
      # Only the images go to `device`; labels are needed on the CPU for scoring anyway.
      _, outputs_val = model(faces_val.to(device))
      preds_val = outputs_val.argmax(dim=1)

      pred.append(preds_val.cpu())
      lab.append(labels_val.cpu())

  iteration_val_acc, F1_score, cm = metrics(lab, pred)

  return iteration_val_acc, F1_score, cm, pred, lab

In [None]:
iteration_val_acc, F1_score, cm, pred, lab = evaluate(model)
final_stat = stat_comp(F1_score=F1_score, accuracy=iteration_val_acc)  # weighted competition score

In [None]:
rule = '_________________________________________________________'
print(rule)
print(f'Validation Acc: {iteration_val_acc:.2f}')
print(f'F1_Score : {F1_score:.4f}')
print(f'Final statistics: {final_stat:.4f}')
print(rule)

In [None]:
# Flatten the per-batch tensors collected by evaluate() into single 1-D arrays.
y_true, y_pred = (np.concatenate([t.numpy() for t in batches], axis=0) for batches in (lab, pred))

In [None]:
columns = list(classes)  # the 7 AffWild2 label names, in class-index order

In [None]:
# Choose the axis labels from the size of the classifier head that was evaluated, so the
# 3-class names are only used for a 3-output model instead of always replacing the 7-class ones.
columns = list(classes_nt) if model.classifier_1.out_features == len(classes_nt) else list(classes)

In [None]:
#plot confusion matrix
cm_outputfile = "/content/cm_112.png"
confplot.plot_confusion_matrix_from_data(y_true, y_pred, columns, outputfile=cm_outputfile)