<a href="https://colab.research.google.com/github/dcafarelli/CMT-ABAW2020-EXPR/blob/main/train/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import shutil
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import sklearn.metrics as sm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from fastprogress.fastprogress import master_bar, progress_bar
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# GENERAL SETTINGS

In [None]:
# CUDA FOR PYTORCH
# Use the first GPU when one is visible, otherwise run on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")
# Let cuDNN benchmark candidate convolution algorithms and keep the fastest
# one for the input shapes it encounters.
torch.backends.cudnn.benchmark = True
print(device)

In [None]:
# --------- Dataset settings ---------
# Probability that a training image goes through the resolution-lowering step
# of the multi-resolution dataset (increased during training).
downsampling_prob = 0.1
# PIL resampling filter used when shrinking and re-enlarging images.
resize_algo = Image.BILINEAR
# Resolution applied to validation images by the multi-resolution dataset.
valid_resolution = 112

# Class weights for the weighted cross-entropy loss:
# weight = max(number of occurrences) / number of occurrences of the class.
loss_weights = torch.tensor(
    [1, 22.92, 37.50, 50.66, 3.47, 5.79, 13.55], dtype=torch.float32
).to(device)
loss_weights_nt = torch.tensor([1, 3.47, 3.81], dtype=torch.float32).to(device)

batch_size = 64
batch_val_size = 32  # smaller than batch_size to avoid CUDA out-of-memory
classes = ('Neutral', 'Anger', 'Disgust', 'Fear', 'Happiness', 'Sadness', 'Surprise')
classes_nt = ('Neutral', 'Positive', 'Negative')

# --------- Training settings ---------
# Number of batches between reports / validation / checkpoints in train().
iteration_step = 3920
# Number of training samples between increases of downsampling_prob.
curr_step = 35000
# Number of batches whose gradients are accumulated before an optimizer step.
batch_accumulation = 4
batch_accumulation = 4

In [None]:
# --------- PATHS ---------

# Annotation DataFrames (image path / label) for AffWild2 + ExpW (7 classes).
train_df_path = '/annotations/aff_wild_expW_train_set.pkl'
val_df_path = '/annotations/val_set.pkl'

# Annotation DataFrames (image path / label) for the 3-class "new task".
train_df_nt_path = '/annotations/three_classes_label.pkl'
val_df_nt_path = '/annotations/three_classes_val_label.pkl'

# Pickled base network read by load_models, and the checkpoint whose
# weights are copied into it.
model_base_path_colab = '/model_checkpoint/pytorch_models/senet50_ft_pytorch.pth'
model_ckp_path = '/model_checkpoint/pytorch_models/models_ckp_78561.pth.tar'

In [None]:
def subtract_mean(x):
    """Rescale a [0, 1] image tensor to [0, 255] and subtract per-channel means.

    The tensor is modified in place and also returned. Channel ``c`` (first
    axis) has ``channel_means[c]`` subtracted. NOTE(review): with the
    cv2-based loaders in this notebook the channels are presumably in BGR
    order, matching these mean values — confirm.
    """
    channel_means = (91.4953, 103.8827, 131.0912)
    x *= 255.
    for channel, mean in enumerate(channel_means):
        x[channel] -= mean
    return x

In [None]:
#DATA AUGMENTATION

# Training pipeline: resize to 256x256, take a random 224x224 crop, jitter
# colours, mirror half of the images, convert to a tensor and apply
# subtract_mean. subtract_mean is passed directly instead of being wrapped in
# a lambda: a named module-level function can be pickled, a lambda cannot,
# which matters when DataLoader workers are started with the "spawn" method.
transformed_train = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomCrop((224, 224)),
    transforms.ColorJitter(brightness=0.4, contrast=0.3, saturation=0.25, hue=0.05),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Lambda(subtract_mean),
])

# Validation pipeline: deterministic resize straight to 224x224.
transformed_val = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Lambda(subtract_mean),
])

# "STANDARD" DATASET

Run the cells below to train the model on samples at a single (original) *resolution*.

In [None]:
class AffWild2Dataset(Dataset):
    """Single-resolution dataset of face crops and expression labels.

    Each row of the pickled annotation DataFrame holds a relative image path
    in its first column and a class index in its ``label`` column.

    Args:
        choose_set: ``'affwild2'`` selects the AffWild2 + ExpW annotations;
            any other value selects the 3-class "new task" annotations.
        flag: ``'train'`` for the training split, anything else for validation.
        transform: optional callable applied to the PIL image.
    """

    def __init__(self, choose_set, flag, transform=None):
        self.flag = flag
        self.choose_set = choose_set

        is_train = flag == 'train'
        if choose_set == 'affwild2':
            pkl_path = train_df_path if is_train else val_df_path
        else:
            pkl_path = train_df_nt_path if is_train else val_df_nt_path

        self.emotion_frame = pd.read_pickle(pkl_path)
        self.transform = transform

    def __len__(self):
        return len(self.emotion_frame)

    def __getitem__(self, index):
        img_path = self.emotion_frame.iloc[index, 0]
        # The stored paths are appended directly to the split's frame root.
        root = '/cropped_aligned_train' if self.flag == 'train' else '/cropped_aligned_val'
        fp = f'{root}{img_path}'

        # Channel order is kept as OpenCV loads it (BGR).
        bgr = cv2.imread(fp)
        if bgr is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f'Could not read image: {fp}')
        img_array = Image.fromarray(bgr)

        y_label = self.emotion_frame['label'].values[index]

        if self.transform:
            img_array = self.transform(img_array)

        return img_array, y_label

In [None]:
#DATASET CREATION
# Both splits use the 'affwild2' annotations (train_df_path / val_df_path).
train_set = AffWild2Dataset(choose_set='affwild2', flag='train', transform=transformed_train)
validation_set = AffWild2Dataset(choose_set='affwild2', flag='validation', transform=transformed_val)

# "MULTI-RESOLUTION" DATASET
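Run the cells below to train the model on samples at different resolutions (images are randomly downsampled). This section redefines `AffWild2Dataset`, `train_set` and `validation_set`, so run only one of the two dataset sections.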

In [None]:
class AffWild2Dataset(Dataset):
    """Face-expression dataset that randomly lowers image resolution.

    Rows of the pickled annotation DataFrame hold a relative image path in
    column 0 and a class index in the ``label`` column. With probability
    ``downsampling_prob`` an image is shrunk and enlarged back to its original
    size (see ``_lower_resolution``). During training that probability grows
    by 0.1 every ``curr_step`` samples, capped at 1.0. For validation it is
    fixed at 1.0 and the target resolution is ``valid_resolution``.

    NOTE(review): with ``num_workers > 0`` each DataLoader worker operates on
    its own copy of the dataset, so ``curr_index`` / ``downsampling_prob``
    advance per worker and are not shared back with the main process —
    confirm the curriculum behaves as intended.
    """

    def __init__(self, flag, choose_set, resize_algo, downsampling_prob, valid_resolution, curr_step, transform=None):
        self.flag = flag
        self.choose_set = choose_set

        is_train = flag == 'train'
        if choose_set == 'affwild2':
            pkl_path = train_df_path if is_train else val_df_path
        else:
            # Any other value selects the 3-class "new task" annotations.
            pkl_path = train_df_nt_path if is_train else val_df_nt_path
        # Validation images always go through _lower_resolution.
        self.downsampling_prob = downsampling_prob if is_train else 1.0

        self.emotion_frame = pd.read_pickle(pkl_path)
        self.transform = transform
        self.resize_algo = resize_algo
        self.valid_resolution = valid_resolution
        self._loader = self._get_loader
        self.curr_index = 0
        self.curr_step = curr_step

    @staticmethod
    def _get_loader(path):
        """Read an image with OpenCV (BGR order) and wrap it as a PIL image."""
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports a missing/unreadable file by returning None.
            raise FileNotFoundError(f'Could not read image: {path}')
        return Image.fromarray(img)

    def _lower_resolution(self, img):
        """Shrink ``img`` so its shorter side is ``res`` px, then resize it back.

        For training, ``res`` is ``2 ** int(3 + 5 * u)`` with ``u`` uniform in
        [0, 1), i.e. one of 8, 16, 32, 64 or 128. For validation it is
        ``valid_resolution``. Images already no larger than ``res`` in either
        dimension are returned unchanged. The aspect ratio is preserved.
        """
        w_i, h_i = img.size
        r = h_i / float(w_i)
        if self.flag == 'train':
            res = torch.rand(1).item()
            res = 3 + 5 * res
            res = 2 ** int(res)
        else:
            res = self.valid_resolution
        if res >= w_i or res >= h_i:
            return img
        if h_i < w_i:
            h_n = res
            w_n = h_n / float(r)
        else:
            w_n = res
            h_n = w_n * float(r)
        img2 = img.resize((int(w_n), int(h_n)), self.resize_algo)
        img2 = img2.resize((w_i, h_i), self.resize_algo)
        return img2

    def _advance_curriculum(self):
        """Count one training sample and, every ``curr_step`` samples, raise
        ``downsampling_prob`` by 0.1 (capped at 1.0)."""
        self.curr_index += 1
        if self.curr_index % self.curr_step == 0 and self.downsampling_prob < 1.0:
            # round() removes float drift (0.1 + ... gives 0.9999999999999999,
            # which previously stepped once more to 1.0999...); min() caps it.
            self.downsampling_prob = min(1.0, round(self.downsampling_prob + 0.1, 6))

    def __len__(self):
        return len(self.emotion_frame)

    def __getitem__(self, index):
        img_path = self.emotion_frame.iloc[index, 0]
        if self.flag == 'train':
            self._advance_curriculum()
            fp = f'/content/cropped_aligned_train{img_path}'  # training frames root
        else:
            fp = f'/content/cropped_aligned_val{img_path}'  # validation frames root

        img = self._loader(fp)

        if torch.rand(1).item() < self.downsampling_prob:
            img = self._lower_resolution(img)

        y_label = self.emotion_frame['label'].values[index]

        if self.transform:
            img = self.transform(img)

        return img, y_label

In [None]:
#DATASET CREATION
# choose_set='affwild2' selects the 7-class annotations. These match the
# 7-class head built by reshape("affwild2", ...), the 7 loss weights and the
# `classes` tuple passed to train(). Any other value loads the 3-class
# "new task" pickles.
train_set = AffWild2Dataset(flag='train', choose_set='affwild2', resize_algo=resize_algo,
                            downsampling_prob=downsampling_prob, valid_resolution=valid_resolution,
                            curr_step=curr_step, transform=transformed_train)

validation_set = AffWild2Dataset(flag='validation', choose_set='affwild2', resize_algo=resize_algo,
                                 downsampling_prob=downsampling_prob, valid_resolution=valid_resolution,
                                 curr_step=curr_step, transform=transformed_val)

# DATA LOADERS

In [None]:

# DATA LOADERS
# Training batches are drawn in shuffled order. No balancing sampler is used;
# class imbalance is addressed by the weighted loss instead. The last
# incomplete batch is dropped.
train_generator = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)

# Validation keeps every sample, in the loader's default (sequential) order.
validation_generator = DataLoader(
    validation_set,
    batch_size=batch_val_size,
    num_workers=8,
    pin_memory=True,
    drop_last=False,
)


# MODEL CONFIGURATION



In [None]:
# Make MainModel.py importable (replace the placeholder with its real
# directory). NOTE(review): load_models below calls torch.load on a whole
# pickled model, which presumably references classes from the `MainModel`
# module, hence this import — confirm.
sys.path.append('/path/where/MainModel.py/is_located') #append the path where MainModel.py is located
import MainModel

In [None]:
def load_models(model_base_path, device="cpu", model_ckp=None):
    """Load the pickled base network and optionally overwrite its weights.

    Args:
        model_base_path: file written with ``torch.save(model)`` (a whole
            pickled ``nn.Module``, not a state dict).
        device: device the returned model is moved to.
        model_ckp: optional checkpoint whose ``'model_state_dict'`` entry
            supplies every parameter value plus the running statistics of
            each ``nn.BatchNorm2d`` layer.

    Returns:
        The loaded model, on ``device``.

    Raises:
        FileNotFoundError: if either path does not exist.
    """
    import inspect

    # Both files are loaded to the CPU first, then the model is moved once.
    load_kwargs = {"map_location": "cpu"}
    if "weights_only" in inspect.signature(torch.load).parameters:
        # torch >= 2.6 defaults to weights_only=True, which cannot load a
        # whole pickled module. Unpickling runs arbitrary code: only load
        # trusted files.
        load_kwargs["weights_only"] = False

    if not os.path.exists(model_base_path):
        raise FileNotFoundError("Base model checkpoint not found at: {}".format(model_base_path))
    model = torch.load(model_base_path, **load_kwargs)

    if model_ckp is not None:
        if not os.path.exists(model_ckp):
            raise FileNotFoundError(f"Model checkpoint not found at: {model_ckp}")
        state = torch.load(model_ckp, **load_kwargs)['model_state_dict']

        with torch.no_grad():
            for name, param in model.named_parameters():
                param.copy_(state[name])

        # BatchNorm running statistics are buffers, not parameters, so they
        # are restored separately.
        for name, module in model.named_modules():
            if isinstance(module, nn.BatchNorm2d):
                module.momentum = 0.1
                module.running_var = state[name + '.running_var']
                module.running_mean = state[name + '.running_mean']
                module.num_batches_tracked = state[name + '.num_batches_tracked']

    return model.to(device)

In [None]:
model = load_models(model_base_path_colab, device=device, model_ckp=model_ckp_path)

In [None]:
# pytorch 1.6.0 compatibility: give every submodule an empty
# `_non_persistent_buffers_set`. NOTE(review): presumably the pickled model
# was created before this attribute existed — confirm.
for module in model.modules():
    module._non_persistent_buffers_set = set()

In [None]:
def reshape(flag, model, in_features=2048):
    """Replace ``model.classifier_1`` with a freshly initialised linear head.

    Args:
        flag: ``"affwild2"`` builds a head with one output per entry of
            ``classes``; any other value uses ``classes_nt``.
        model: network with a ``classifier_1`` attribute; modified in place.
        in_features: size of the feature vector fed to the head.

    Returns:
        The same ``model`` object.
    """
    num_classes = len(classes) if flag == "affwild2" else len(classes_nt)
    model.classifier_1 = nn.Linear(in_features, num_classes)
    return model

In [None]:
model = reshape(flag="affwild2", model=model)

In [None]:
model = model.to(device=device)

In [None]:
# Split the parameters into the new classification head (given its own,
# larger learning rate in the optimizer) and the pretrained backbone.
my_list = ['classifier_1.weight', 'classifier_1.bias']
params = [p for name, p in model.named_parameters() if name in my_list]
base_params = [p for name, p in model.named_parameters() if name not in my_list]

In [None]:
def criterion(flag):
    """Build the class-weighted cross-entropy loss for the chosen task.

    ``"affwild2"`` uses ``loss_weights``; any other flag uses
    ``loss_weights_nt``.
    """
    weights = loss_weights if flag == "affwild2" else loss_weights_nt
    return nn.CrossEntropyLoss(weight=weights)

In [None]:
#LOSS FUNCTION AND OPTIMIZER

# Class-weighted cross-entropy for the 7-class task (loss_weights).
# NOTE: this rebinds the name `criterion`, shadowing the criterion(flag)
# helper defined above.
criterion = nn.CrossEntropyLoss(weight=loss_weights)

# Fine-tuning: the pretrained backbone uses lr=1e-6, the new head lr=1e-4.
optimizer = optim.SGD(
    [{'params': base_params}, {'params': params, 'lr': 1e-4}],
    lr=1e-6,
    momentum=0.9,
    weight_decay=5e-3,
)

# Multiply every learning rate by 0.1 after each 10 scheduler steps
# (train() calls scheduler.step() once per evaluation).
scheduler = optim.lr_scheduler.StepLR(optimizer, 10, 0.1)

# TRAINING SECTION

In [None]:
#Saving function

def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """Save a training checkpoint and optionally mirror it as the best model.

    Args:
        state: checkpoint object passed to ``torch.save``.
        is_best: when true, the freshly written checkpoint file is also
            copied to ``best_model_path``.
        checkpoint_path: destination of the checkpoint (written on every call).
        best_model_path: destination of the copy made when ``is_best``.
    """
    # Create missing parent directories up front. Otherwise the first save,
    # which only happens after a long stretch of training, would fail.
    targets = [checkpoint_path] + ([best_model_path] if is_best else [])
    for target in targets:
        parent = os.path.dirname(target)
        if parent:
            os.makedirs(parent, exist_ok=True)

    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, best_model_path)

In [None]:
#Loading function

def load_ckp(checkpoint_fpath, model, optimizer, map_location=None):
    """Restore model/optimizer state and training history from a checkpoint.

    Args:
        checkpoint_fpath: checkpoint file written by ``save_ckp``.
        model: module whose state is overwritten in place.
        optimizer: optimizer whose state is overwritten in place.
        map_location: forwarded to ``torch.load``, e.g. ``"cpu"`` to resume a
            GPU checkpoint on a CPU-only machine. ``None`` keeps torch's default.

    Returns:
        ``(model, optimizer, epoch, best_stat, train_loss, val_loss,
        train_acc, val_acc)``.
    """
    import inspect

    load_kwargs = {"map_location": map_location}
    if "weights_only" in inspect.signature(torch.load).parameters:
        # Checkpoints hold more than tensors (metric history lists, presumably
        # numpy scalars from sklearn). torch >= 2.6's default
        # weights_only=True may reject those. Only load trusted files.
        load_kwargs["weights_only"] = False
    checkpoint = torch.load(checkpoint_fpath, **load_kwargs)

    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])

    return (
        model,
        optimizer,
        checkpoint['epoch'],
        checkpoint['best_stat'],
        checkpoint['train_loss'],
        checkpoint['val_loss'],
        checkpoint['train_acc'],
        checkpoint['val_acc'],
    )

In [None]:
def metrics(lab, pred):
    """Compute accuracy, macro F1 and the confusion matrix of one evaluation.

    Args:
        lab: list of CPU tensors with ground-truth labels, one per batch.
        pred: list of CPU tensors with predicted labels, aligned with ``lab``.

    Returns:
        ``(accuracy, macro_f1, confusion_matrix)``. Per-class F1 scores are
        printed as a side effect.
    """
    lab_array = np.concatenate([t.numpy() for t in lab], axis=0)
    pred_array = np.concatenate([t.numpy() for t in pred], axis=0)

    F1_score = sm.f1_score(lab_array, pred_array, average='macro', zero_division=1)
    classes_score = sm.f1_score(lab_array, pred_array, average=None, zero_division=1)
    # These are per-class F1 scores (average=None), not accuracies.
    print("F1 per class ", classes_score)
    accuracy = sm.accuracy_score(lab_array, pred_array)
    confusion_matrix = sm.confusion_matrix(lab_array, pred_array)

    return accuracy, F1_score, confusion_matrix

In [None]:
#STATISTIC COMPETITION
def stat_comp(F1_score, accuracy):
    """Return the competition score used for model selection:
    0.67 * F1 + 0.33 * accuracy."""
    return F1_score * 0.67 + accuracy * 0.33

In [None]:
def evaluate(model, data_loader=None, loss_fn=None):
    """Run one pass over validation data without gradient tracking.

    Args:
        model: network whose forward returns a pair; the second element is
            used as the class logits. The model is left in eval mode; callers
            switch it back to training.
        data_loader: iterable of ``(faces, labels)`` batches; defaults to the
            global ``validation_generator``.
        loss_fn: loss applied to ``(logits, labels)``; defaults to the global
            ``criterion``.

    Returns:
        ``(mean_loss, accuracy, macro_f1, confusion_matrix)``.

    Raises:
        ValueError: if the loader yields no samples.
    """
    if data_loader is None:
        data_loader = validation_generator
    if loss_fn is None:
        loss_fn = criterion

    running_val_loss = 0.0
    total = 0
    pred = []
    lab = []

    model.eval()
    print("Enter Evaluation. Is Training?", model.training)
    with torch.no_grad():
        for faces_val, labels_val in progress_bar(data_loader):
            faces_val = faces_val.to(device)
            labels_val = labels_val.to(device)

            _, outputs_val = model(faces_val)

            loss_val = loss_fn(outputs_val, labels_val)
            _, preds_val = torch.max(outputs_val, 1)

            # The loss is a per-batch mean; weight it by batch size.
            running_val_loss += loss_val.item() * faces_val.size(0)

            pred.append(preds_val.cpu())
            lab.append(labels_val.cpu())

            total += labels_val.size(0)

    if total == 0:
        raise ValueError("evaluate() received a data loader with no samples")

    iteration_val_loss = running_val_loss / total
    iteration_val_acc, F1_score, cm = metrics(lab, pred)

    return iteration_val_loss, iteration_val_acc, F1_score, cm

In [None]:
def train(start_epochs, n_epochs, best_stat, classes, train_generator, val_generator, train_loss, val_loss, train_acc, val_acc, model, optimizer, criterion, checkpoint_path, best_model_path):
    """Train ``model`` with gradient accumulation, periodic validation and checkpoints.

    The gradients of ``batch_accumulation`` consecutive batches are summed
    before each optimizer step. Every ``iteration_step`` batches within an
    epoch, the function:

    - reports the running training statistics;
    - validates with ``evaluate(model)`` and steps the global ``scheduler``;
    - saves a checkpoint, also copied to ``best_model_path`` when the
      competition score (``stat_comp``) exceeds ``best_stat``.

    Args:
        start_epochs, n_epochs: epochs ``start_epochs .. n_epochs`` (inclusive)
            are run.
        best_stat: best competition score seen so far.
        classes: class names, indexed by label.
        train_generator: training loader yielding ``(faces, labels)``.
        val_generator: not used here; ``evaluate(model)`` picks its own
            validation data.
        train_loss, val_loss, train_acc, val_acc: history lists, appended to
            in place (``val_acc`` receives the competition score).
        model, optimizer, criterion: objects being trained.
        checkpoint_path, best_model_path: forwarded to ``save_ckp``.

    Returns:
        The trained ``model``.
    """
    num_classes = len(classes)

    for epoch in range(start_epochs, n_epochs + 1):

        # Statistics for the current reporting window.
        running_train_loss = 0.0
        running_train_corrects = 0.0
        total_t = 0
        class_correct = [0.0] * num_classes
        class_total = [0.0] * num_classes

        # Training
        model.train()
        optimizer.zero_grad()
        for k, (faces, labels) in enumerate(progress_bar(train_generator)):
            faces = faces.to(device)
            labels = labels.to(device)

            # forward
            _, output = model(faces)
            loss = criterion(output, labels)
            # The predicted class is the index of the largest logit.
            _, preds = torch.max(output.detach(), 1)

            correct = preds == labels
            running_train_loss += loss.item() * faces.size(0)
            running_train_corrects += correct.sum().item()
            total_t += labels.size(0)

            # Per-class tallies: one device->host copy per batch, not per sample.
            for label, ok in zip(labels.tolist(), correct.tolist()):
                class_correct[label] += ok
                class_total[label] += 1

            # NOTE(review): the loss is not divided by batch_accumulation, so
            # each step applies the sum (not the mean) of the accumulated
            # gradients — confirm this matches the tuned learning rates.
            loss.backward()
            if (k + 1) % batch_accumulation != 0:
                continue
            optimizer.step()
            optimizer.zero_grad()

            if (k + 1) % iteration_step != 0:
                continue

            # --- report the training window ---
            iteration_train_loss = running_train_loss / total_t
            iteration_train_acc = (running_train_corrects / total_t) * 100
            train_loss.append(iteration_train_loss)
            train_acc.append(iteration_train_acc)
            print("Total after iter", total_t)

            print('---------------------Iteration: %d ---------------------' % k)
            print('Train Loss: {:.4f} Train Acc: {:.2f}%'.format(iteration_train_loss, iteration_train_acc))
            for i, name in enumerate(classes):
                if class_total[i]:
                    print('Train Accuracy of %5s : %2d %%' % (name, 100 * class_correct[i] / class_total[i]))
                else:
                    # Avoid dividing by zero for classes absent from the window.
                    print('Train Accuracy of %5s : no samples' % name)

            # --- validate ---
            iteration_val_loss, iteration_val_acc, F1_score, cm = evaluate(model)
            final_stat = stat_comp(F1_score, iteration_val_acc)

            val_loss.append(iteration_val_loss)
            val_acc.append(final_stat)

            print('Validation Loss: {:.4f} Validation Acc: {:.2f}'.format(iteration_val_loss, iteration_val_acc))
            print('F1_Score : {:.4f}'.format(F1_score))
            print('Final statistics: {:.4f}'.format(final_stat))
            print('_________________________________________________________')

            scheduler.step()
            print(optimizer.param_groups[0]['lr'])

            # Start a fresh reporting window; per-class tallies are reset too,
            # so they describe the same samples as the overall accuracy.
            running_train_loss = 0.0
            running_train_corrects = 0.0
            total_t = 0
            class_correct = [0.0] * num_classes
            class_total = [0.0] * num_classes
            print("Total before iter", total_t)
            model.train()
            print("Evaluation ended. Is training?", model.training)

            # --- checkpoint ---
            # Update best_stat before building the checkpoint, so the saved
            # best model records its own score rather than the previous best.
            is_best = final_stat > best_stat
            if is_best:
                print('Statistic increases ({:.6f} --> {:.6f}).  Saving model ...'.format(best_stat, final_stat))
                best_stat = final_stat

            checkpoint = {
                'iteration': k + 1,
                'epoch': epoch + 1,
                'valid_loss_min': iteration_val_loss,
                'best_stat': best_stat,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'train_loss': train_loss,
                'val_loss': val_loss,
                'train_acc': train_acc,
                'val_acc': val_acc,
            }
            save_ckp(checkpoint, is_best, checkpoint_path, best_model_path)

    return model

In [None]:
# Destinations used by save_ckp: the best-scoring copy and the latest checkpoint.
best_model_path = '/best_model/best_model.pt'
checkpoint_path = '/checkpoint/ckp_model.pt'

In [None]:
# History lists, appended to in place by train() and stored in each checkpoint.
train_loss, val_loss, train_acc, val_acc = [], [], [], []
avg_train_acc, avg_val_acc = [], []

#TRAINING AND VALIDATION
# Run epochs 1..20 starting from a best competition score of 0.0.
trained_model = train(
    1, 20, 0.0, classes, train_generator, validation_generator,
    train_loss, val_loss, train_acc, val_acc,
    model, optimizer, criterion, checkpoint_path, best_model_path,
)

In [None]:
#LOAD SAVED CHECKPOINT
# Restore model/optimizer state and the metric history from the latest checkpoint.
(model_t, optimizer, start_epoch, best_stat,
 train_loss, val_loss, train_acc, val_acc) = load_ckp(
    checkpoint_fpath=checkpoint_path, model=model, optimizer=optimizer
)