In [None]:
!unzip /content/drive/MyDrive/dataset.zip

In [None]:
!pip install timm wandb

In [4]:
# Never hard-code the W&B API key in a notebook: it leaks through version control
# and shared copies (a previously committed key should be revoked). wandb.init()
# reads WANDB_API_KEY from the environment, so set it from the environment or
# prompt for it once without echoing it.
import os
from getpass import getpass

if not os.environ.get("WANDB_API_KEY"):
    os.environ["WANDB_API_KEY"] = getpass("W&B API key: ")

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [1]:
import argparse
import os
import math
import random
import warnings


import cv2
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb


from tqdm import tqdm
from datetime import datetime
from operator import itemgetter
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms

warnings.filterwarnings("ignore")

In [2]:
from google.colab.patches import cv2_imshow

In [3]:
class EmotionsDataset(Dataset):
    """Emotion-recognition image dataset driven by a CSV protocol file.

    The protocol ``{protocol_path}/{prefix}_balanced_protocol.csv`` must contain a
    ``path`` column. Each file name is expected to look like
    ``<name>_<label>_<frame>.jpeg``; the sample name, integer class label and frame
    number are parsed from it. ``dataset_dir`` is stored as ``dataset_root`` but the
    image locations come from the protocol, not from that directory.

    ``__getitem__`` returns, depending on the configuration:
      * ``emb_mode=True``: ``(embedding, label, name, frame)`` using embeddings
        computed with ``model.get_embeddings`` while the protocol is loaded;
      * ``sample_mode='triplet'``: ``(anchor, positives, negatives, label)`` where
        ``positives`` holds ``count_positives`` images with the anchor's label
        (taken from other sample names when possible) and ``negatives`` holds
        ``count_negatives`` images from every other class in ``self.classes``;
      * ``sample_mode='arcface'``: ``(image, label)``;
      * any other mode: ``(path, label, name, frame)`` without reading the image.
    """

    def __init__(self,
                 dataset_dir: str,
                 img_size = (224, 224),
                 emb_mode = False,
                 model = None,
                 sample_mode: str = 'triplet',
                 count_negatives: int = 2,
                 protocol_path='/content/drive/MyDrive',
                 prefix = 'train',
                 count_positives: int = 3,
                 ):
        super(EmotionsDataset, self).__init__()

        self.dataset_root = dataset_dir
        # ToTensor -> resize -> ImageNet mean/std normalisation.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Resize(size=img_size),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.emb_mode = emb_mode
        self.model = model

        self.embeddings = list()

        self.sample_mode = sample_mode
        self.n_count = count_negatives      # negatives drawn per non-anchor class
        self.n_positives = count_positives  # positives drawn per anchor
        self.classes = {'Neutral': 0, 'Anger': 1, 'Disgust': 2, 'Fear': 3, 'Happiness': 4, 'Sadness': 5, 'Surprise': 6, 'Other': 7}

        self.protocol_path = os.path.join(protocol_path, f'{prefix}_balanced_protocol.csv')
        self.protocol = pd.read_csv(self.protocol_path)

        self.images, self.names, self.labels, self.sample_nums = self._load_list(self.dataset_root)
        self._build_label_index()

    def _build_label_index(self):
        """Group dataset indices by label (ascending index order).

        Triplet sampling then looks candidates up directly instead of rescanning
        the whole dataset for every class on every ``__getitem__`` call.
        """
        self._indices_by_label = {}
        for i, label in enumerate(self.labels):
            self._indices_by_label.setdefault(label, []).append(i)

    def _load_list(self, list_root):
        """Parse every protocol path into parallel lists.

        ``list_root`` is unused: paths are taken verbatim from the protocol.
        Returns ``(paths, names, int_labels, frame_numbers)``.
        """
        samples, sample_names, sample_labels, frame_nums = list(), list(), list(), list()

        for path in self.protocol['path'].tolist():
            sample, name, label, frame_num = self._load_samples_with_labels(path)

            samples.append(sample)
            sample_names.append(name)
            sample_labels.append(int(label))
            frame_nums.append(frame_num)

        return samples, sample_names, sample_labels, frame_nums

    def _load_samples_with_labels(self, path):
        """Split ``.../<name>_<label>_<frame>.jpeg`` into its parts.

        In ``emb_mode`` the image is also embedded and appended to
        ``self.embeddings``. Returns ``(path, name, label_str, frame_str)``.
        """
        name, class_label, frame_num = path.split('/')[-1].replace('.jpeg', '').split('_')

        if self.emb_mode:
            image = self._read_image(path)
            # NOTE(review): the image tensor stays on CPU; this assumes the model
            # runs on CPU (or moves inputs itself) — confirm for CUDA models.
            emb = self.model.get_embeddings(image.unsqueeze(0))

            self.embeddings.append(emb[0])

        return path, name, class_label, frame_num

    def _read_image(self, path):
        """Read one image with OpenCV and apply ``self.transform``."""
        image = cv2.imread(path)
        if image is None:
            # cv2.imread reports missing/unreadable files by returning None.
            raise FileNotFoundError(f'Could not read image: {path}')
        # NOTE(review): OpenCV loads BGR and the BGR->RGB conversion is disabled,
        # while the normalisation statistics are RGB-ordered. Confirm which
        # channel order the pretrained backbone expects.
        return self.transform(image)

    def extract_embeddings(self, model):
        """Recompute ``self.embeddings`` for every image with ``model``."""
        print('Extracting embeddings')
        # get_embeddings expects an image batch, not a list of paths: load and
        # embed each image individually (same as the emb_mode loading path).
        self.embeddings = [model.get_embeddings(self._read_image(path).unsqueeze(0))[0]
                           for path in self.images]

    def load_image(self, path):
        """Load one path into a tensor, or an iterable of paths into a list of tensors."""
        if isinstance(path, str):
            return self._read_image(path)
        return [self._read_image(p) for p in path]

    @staticmethod
    def _sample_paths(candidates, count, what):
        """Draw ``count`` paths from ``candidates`` with ``np.random.choice``.

        Sampling is without replacement whenever enough candidates exist and falls
        back to sampling with replacement otherwise, instead of crashing.
        """
        if not candidates:
            raise ValueError(f'No images available to sample {what}')
        replace = len(candidates) < count
        return np.random.choice(candidates, count, replace=replace).tolist()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        if self.emb_mode:
            return self.embeddings[idx], self.labels[idx], self.names[idx], self.sample_nums[idx]

        if self.sample_mode == 'triplet':
            anchor_label, anchor_name = self.labels[idx], self.names[idx]
            anchor = self.load_image(self.images[idx])

            # n_count negatives from every class other than the anchor's.
            negative_list = list()
            for label in self.classes.values():
                if label == anchor_label:
                    continue
                negative_cand = [self.images[i] for i in self._indices_by_label.get(label, [])]
                negative_list.extend(
                    self._sample_paths(negative_cand, self.n_count, f'negatives of class {label}'))

            # Prefer positives from a different sample name (presumably another
            # clip/subject); fall back to any other index with the same label.
            same_label = self._indices_by_label.get(anchor_label, [])
            positive_list = [self.images[i] for i in same_label if self.names[i] != anchor_name]
            if not positive_list:
                positive_list = [self.images[i] for i in same_label if i != idx]

            positive = self.load_image(
                self._sample_paths(positive_list, self.n_positives, f'positives of class {anchor_label}'))
            negative = self.load_image(negative_list)

            return anchor, positive, negative, anchor_label

        if self.sample_mode == 'arcface':
            return self.load_image(self.images[idx]), int(self.labels[idx])

        return self.images[idx], int(self.labels[idx]), self.names[idx], int(self.sample_nums[idx])



In [4]:
class ArcFaceLoss(nn.Module):
    """ArcFace (additive angular margin) classification head.

    Despite its name, ``forward`` returns scaled logits rather than a scalar loss;
    combine them with ``F.cross_entropy(logits, target)`` to obtain the loss.

    :param emb_size: dimensionality of the input embeddings.
    :param num_classes: number of classes (rows of the learnable weight matrix).
    :param device: kept for backward compatibility; the weight is moved to the
        device of the input at call time.
    :param s: logit scale factor.
    :param m: additive angular margin (radians) applied to the target class.
    :param eps: keeps the cosine strictly inside (-1, 1) before ``acos``.
    """

    def __init__(self,
                 emb_size: int,
                 num_classes: int,
                 device: str = 'cuda',
                 s: float = 64.0,
                 m: float = 0.5,
                 eps: float = 1e-6,
                 **kwargs
                 ):
        super(ArcFaceLoss, self).__init__()

        self.in_features = emb_size
        self.out_features = num_classes

        self.s = s
        self.m = m

        # For theta above this value theta + m would exceed pi, where cos is no
        # longer monotonic; the margin is skipped for such entries.
        self.threshold = math.pi - m
        self.eps = eps

        self.device = device

        self.weight = nn.Parameter(torch.empty(self.out_features, self.in_features, dtype=torch.float32))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x: torch.Tensor, target: torch.LongTensor = None) -> torch.Tensor:
        """Return class logits for embeddings ``x`` of shape (batch, emb_size).

        Without a target (``None`` or empty) the plain cosine similarities are
        returned. With a target, the margin ``m`` is added to the angle of the
        target class before re-applying ``cos`` and scaling by ``s``.
        """
        cos_theta = F.linear(F.normalize(x), F.normalize(self.weight.to(x.device)))

        # `target` defaults to None, so guard before touching its attributes.
        if target is None or target.numel() == 0:
            return cos_theta

        theta = torch.acos(torch.clamp(cos_theta, -1 + self.eps, 1.0 - self.eps))

        one_hot = torch.zeros_like(cos_theta)
        one_hot.scatter_(1, target.view(-1, 1).long().to(cos_theta.device), 1)

        mask = torch.where(theta > self.threshold, torch.zeros_like(one_hot), one_hot)

        logits = torch.cos(torch.where(mask.bool(), theta + self.m, theta))

        return logits * self.s


class TripletLoss(nn.Module):
    """Triplet margin loss on cosine distance.

    loss = mean(relu(d(anchor, positive) - d(anchor, negative) + margin)),
    with d(a, b) = 1 - cos(a, b) computed row-wise along dim 1.
    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, margin=1.0, **kwargs):
        super(TripletLoss, self).__init__()
        self.margin = margin

    @staticmethod
    def calc_distance(x1, x2):
        """Row-wise cosine distance ``1 - cos(x1_i, x2_i)`` along dim 1."""
        unit_a = F.normalize(x1, dim=1)
        unit_b = F.normalize(x2, dim=1)
        return 1 - torch.sum(unit_a * unit_b, dim=1)

    @torch.no_grad()
    def log_stuff(self, pos_scores, neg_scores, prefix):
        """Log the mean positive/negative distances and their gap to W&B."""
        pos_mean, neg_mean = pos_scores.mean().item(), neg_scores.mean().item()
        wandb.log({
            f"{prefix}_pos_mean": pos_mean,
            f"{prefix}_neg_mean": neg_mean,
            f"{prefix}_difference": pos_mean - neg_mean
        })

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor, prefix: str) -> torch.Tensor:
        d_pos = self.calc_distance(anchor, positive)
        d_neg = self.calc_distance(anchor, negative)

        self.log_stuff(d_pos, d_neg, prefix)

        return F.relu(d_pos - d_neg + self.margin).mean()


class ContrastiveCrossEntropy(nn.Module):
    """Log-clipped exponential contrastive objective on raw dot-product scores.

    Per row: ``log(max(exp(<v, n> + margin) - <v, p>, 1))``, averaged over the batch.
    Rows where the inner term falls below 1 contribute 0 and receive no gradient.
    Extra keyword arguments are accepted and ignored.
    """

    def __init__(self, margin=1.0, **kwargs):
        super().__init__()
        self.margin = margin

    @torch.no_grad()
    def log_stuff(self, pos_scores, neg_scores, prefix):
        """Log the mean positive/negative scores and their gap to W&B."""
        pos_mean, neg_mean = pos_scores.mean().item(), neg_scores.mean().item()
        wandb.log({
            f"{prefix}_pos_mean": pos_mean,
            f"{prefix}_neg_mean": neg_mean,
            f"{prefix}_difference": pos_mean - neg_mean
        })

    def forward(self, vac_emb: torch.Tensor, pos_emb: torch.Tensor, neg_emb: torch.Tensor, prefix: str) -> torch.Tensor:
        pos_scores = torch.sum(vac_emb * pos_emb, dim=1)
        neg_scores = torch.sum(vac_emb * neg_emb, dim=1)

        self.log_stuff(pos_scores, neg_scores, prefix)

        # Clip at 1 so the log is never negative (and never of a non-positive value).
        clipped = torch.clamp(torch.exp(neg_scores + self.margin) - pos_scores, min=1)
        return torch.log(clipped).mean()

In [5]:
class ImageEmbedder(nn.Module):
    """Pretrained backbone whose classification head is replaced by a linear
    projection to ``embedding_size``.

    ``model_path`` must point to a full pickled module (saved with
    ``torch.save(model)``, not a state dict). Its ``classifier`` attribute must be
    indexable with a first layer exposing ``in_features``.

    :param freeze: if True, only the new projection head is trainable.
    :param device: device the backbone is loaded onto.
    :param normalize: L2-normalise the output embeddings.
    """

    def __init__(self,
                 model_path: str,
                 embedding_size: int = 512,
                 freeze: bool = False,
                 device: str = 'cpu',
                 normalize: bool = False):
        super().__init__()

        # SECURITY: loading a full module unpickles arbitrary objects and can execute
        # code -- only load trusted checkpoints. weights_only=False is required
        # because torch >= 2.6 defaults to weights_only=True, which rejects pickled
        # nn.Module checkpoints.
        self.base_model = torch.load(model_path, map_location=torch.device(device), weights_only=False)

        # Swap the pretrained classification head for a projection into the embedding space.
        self.internal_embedding_size = self.base_model.classifier[0].in_features
        self.base_model.classifier = nn.Linear(in_features=self.internal_embedding_size, out_features=embedding_size)
        self.normalize = normalize

        # Freeze (or unfreeze) the whole backbone; the new head is always trainable.
        for param in self.base_model.parameters():
            param.requires_grad = not freeze
        self.base_model.classifier.requires_grad_(True)

        self.base_model.to(device)

    def embed_image(self, image):
        """Raw (unnormalised) embeddings for an image batch."""
        return self.base_model(image)

    def save(self, model_path):
        """Save the backbone state dict as ``<model_path>/model.pth``."""
        torch.save(self.base_model.state_dict(), os.path.join(model_path, 'model.pth'))

    @torch.no_grad()
    def get_embeddings(self, image):
        """Gradient-free embeddings returned as a NumPy array on the CPU.

        NOTE(review): the module's train/eval mode is not changed here; call
        ``.eval()`` first if dropout/batch-norm should be deterministic.
        """
        out = self.base_model.forward(image)

        if self.normalize:
            out = F.normalize(out, dim=1)

        return out.cpu().numpy()

    def forward(self, x):
        embedding = self.embed_image(x)

        if self.normalize:
            embedding = F.normalize(embedding, dim=-1)

        return embedding



In [6]:
class EmotionsTrainer:
    """Training/validation loop for an embedding model with W&B logging.

    ``sample_mode`` must match the dataloaders' collate function:
      * ``'triplet'``: batches are ``(anchor, positives, negatives, labels)`` where
        ``positives``/``negatives`` are lists of image batches. ``loss`` is called
        as ``loss(anchor_emb, positive_emb, negative_emb, prefix)`` for every
        (negative, positive) pair and the results are summed.
      * ``'arcface'``: batches are ``(images, labels)``. ``loss(emb, labels)`` may
        return class logits (as ``ArcFaceLoss`` does), which are reduced with
        cross-entropy, or an already-reduced scalar loss. ``optimizer_arc`` updates
        the loss module's own parameters.

    ``scheduler.step()`` is called after every optimizer step (per batch).
    """

    def __init__(
            self, model, checkpoint_dir,
            train_dataloader, dev_dataloader, test_dataloader,
            optimizer, optimizer_arc, scheduler, loss,
            device='cuda', save_best=False, sample_mode='triplet'
    ):
        self.model = model
        self.device = device
        self.checkpoint_dir = checkpoint_dir
        self.save_best = save_best

        self.train_dataloader = train_dataloader
        self.dev_dataloader = dev_dataloader
        self.test_dataloader = test_dataloader

        self.optimizer = optimizer
        self.optimizer_arc = optimizer_arc
        self.scheduler = scheduler
        self.loss = loss

        self.experiment_dir = os.path.join(self.checkpoint_dir,
                                           self.model.__class__.__name__+ f'_{datetime.today().strftime("%Y_%m_%d")}')
        self.model_dir = os.path.join(self.experiment_dir, 'model')

        self.eval_step = 0
        self.sample_mode = sample_mode

        os.makedirs(self.experiment_dir, exist_ok=True)
        os.makedirs(os.path.join(self.model_dir), exist_ok=True)

    def _triplet_batch_loss(self, anchor, positive, negative, prefix):
        """Sum of ``self.loss`` over every (negative, positive) pair of one batch."""
        anchor_emb = self.model(anchor.to(self.device))
        # Positive embeddings do not depend on the negative: compute them once
        # instead of once per negative.
        positive_embs = [self.model(pos.to(self.device)) for pos in positive]

        loss = torch.tensor(0.0, device=self.device)
        for neg in negative:
            negative_emb = self.model(neg.to(self.device))
            for positive_emb in positive_embs:
                loss = loss + self.loss(anchor_emb, positive_emb, negative_emb, prefix)
        return loss

    def _arcface_loss(self, anchor_emb, class_label):
        """Turn the ArcFace head output into a scalar loss.

        ``ArcFaceLoss.forward`` returns margin-adjusted logits, so cross-entropy is
        applied; a 0-dim output is treated as an already-reduced loss.
        """
        output = self.loss(anchor_emb, class_label)
        if output.dim() == 0:
            return output
        return F.cross_entropy(output, class_label.long())

    def _arcface_batch_loss(self, anchor, class_label):
        anchor_emb = self.model(anchor.to(self.device))
        return self._arcface_loss(anchor_emb, class_label.to(self.device))

    def _batch_loss(self, batch, prefix):
        """Compute the loss of one batch according to ``self.sample_mode``."""
        if self.sample_mode == 'triplet':
            anchor, positive, negative, _ = batch
            return self._triplet_batch_loss(anchor, positive, negative, prefix)
        if self.sample_mode == 'arcface':
            anchor, class_label = batch
            return self._arcface_batch_loss(anchor, class_label)
        raise ValueError(f'Unknown sample_mode: {self.sample_mode!r}')

    def train_epoch(self, epoch):
        """Train for one epoch and return the mean per-batch training loss."""
        self.model.train()

        epoch_loss = 0.0
        with tqdm(total=len(self.train_dataloader)) as pbar:
            for batch in self.train_dataloader:
                loss = self._batch_loss(batch, "train")

                # Plain backward(): passing `loss` as the gradient would scale every
                # gradient by the loss value itself.
                loss.backward()

                self._optimizer_step()
                if self.sample_mode == 'arcface' and self.optimizer_arc is not None:
                    self.optimizer_arc.step()
                    self.optimizer_arc.zero_grad()

                batch_loss = loss.item()
                epoch_loss += batch_loss
                wandb.log({"train_loss": batch_loss})

                pbar.set_description('Epoch {} - current loss: {:.4f}'.format(epoch, batch_loss))
                pbar.update(1)

        return epoch_loss / max(len(self.train_dataloader), 1)

    @torch.inference_mode()
    def val_epoch(self):
        """Evaluate on the dev set and return the mean per-batch loss."""
        self.model.eval()

        val_loss = 0.0
        with tqdm(total=len(self.dev_dataloader)) as pbar:
            for batch in self.dev_dataloader:
                val_loss += self._batch_loss(batch, "eval").item()
                pbar.update(1)

        return val_loss / max(len(self.dev_dataloader), 1)

    def train(self, num_epochs):
        """Run ``num_epochs`` train/validation rounds, checkpointing after each.

        Returns the best validation loss seen.
        """
        best_val_loss = float('inf')

        for epoch in range(num_epochs):
            loss = self.train_epoch(epoch)
            print(f'Epoch {epoch} - loss {loss}')
            wandb.log({"train_epoch_loss": loss})

            val_loss = self.val_epoch()
            print(f'Epoch {epoch} - validation loss {val_loss}')
            wandb.log({"val_loss": val_loss})
            # Carry the best loss forward, otherwise save_best saves every epoch.
            best_val_loss = self._write_checkpoint(val_loss, best_val_loss)

        return best_val_loss

    def _optimizer_step(self):
        self.optimizer.step()
        self.scheduler.step()

        self.optimizer.zero_grad()

    def _write_checkpoint(self, val_loss, best_val_loss):
        """Save the model and return the updated best validation loss.

        With ``save_best`` only improvements overwrite ``model_dir``; otherwise every
        call writes to ``experiment_dir/loss_<val_loss>``.
        """
        improved = val_loss < best_val_loss
        if self.save_best:
            if improved:
                self.model.save(self.model_dir)
        else:
            save_dir = os.path.join(self.experiment_dir, f'loss_{val_loss}')
            os.makedirs(save_dir, exist_ok=True)
            self.model.save(save_dir)

        return val_loss if improved else best_val_loss


In [10]:
from operator import itemgetter

In [9]:
def to_tens(x):
    """Lazily transpose a list of equal-length tensor lists.

    Yields, for each position j of ``x[0]``, ``torch.stack`` of the j-th tensor
    of every row of ``x``.
    """
    width = len(x[0])
    for position in range(width):
        yield torch.stack([row[position] for row in x])

In [8]:
import os
import random
import shutil

import numpy as np
import torch

from random import choices


def set_all_seeds(seed):
    """Seed Python, NumPy and PyTorch (CPU and current CUDA device) RNGs and
    switch cuDNN to deterministic algorithms."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True


def parce_dataset(source_path: str, dest_path: str, val_size: float = 0.2) -> None:
    """Randomly move files from the class sub-directories of ``source_path`` into
    ``dest_path/train`` or ``dest_path/val``.

    Each file independently goes to ``val`` with probability ``val_size``. The
    class directory structure is flattened, so file names must be unique across
    classes; a name collision raises ``FileExistsError`` instead of silently
    overwriting. ``test``/``train``/``val``/``.DS_Store`` entries and plain files in
    ``source_path`` are skipped.
    """
    skipped = {'test', 'train', 'val', '.DS_Store'}
    class_dirs = [entry for entry in os.listdir(source_path)
                  if entry not in skipped and os.path.isdir(os.path.join(source_path, entry))]

    # shutil.move cannot create the parent directory of the target itself.
    for part in ('train', 'val'):
        os.makedirs(os.path.join(dest_path, part), exist_ok=True)

    for directory in class_dirs:
        directory_path = os.path.join(source_path, directory)

        for file in os.listdir(directory_path):
            part = 'val' if choices([0, 1], [1 - val_size, val_size])[0] else 'train'
            destination_dir = os.path.join(dest_path, part)

            src_path, move_path = os.path.join(directory_path, file), os.path.join(destination_dir, file)
            if os.path.exists(move_path):
                raise FileExistsError(f'Refusing to overwrite existing file: {move_path}')

            shutil.move(src_path, move_path)


def collate_fn(batch):
    """Collate triplet samples ``(anchor, positives, negatives, label)``.

    Returns ``[anchors, positives, negatives, targets]``: ``anchors`` is a stacked
    tensor, ``positives``/``negatives`` are lists whose j-th entry stacks the j-th
    positive/negative of every sample, and ``targets`` is a tensor of labels.
    """
    anchors = torch.stack([sample[0] for sample in batch])
    positives = list(to_tens([sample[1] for sample in batch]))
    negatives = list(to_tens([sample[2] for sample in batch]))
    targets = torch.tensor([sample[3] for sample in batch])

    return [anchors, positives, negatives, targets]

def collate_fn_arcface(batch):
    """Collate ``(image, label)`` samples into ``[stacked_images, label_tensor]``."""
    images = torch.stack([sample[0] for sample in batch])
    labels = torch.tensor([sample[1] for sample in batch])
    return [images, labels]


def cosine_similarity(emb1: np.ndarray, emb2: np.array) -> np.ndarray:
    """
    Row-wise dot product; equals the cosine similarity only when both inputs
    are already L2-normalised (no normalisation is performed here).
    :param emb1: (n, dim)
    :param emb2: (n, dim)
    :return: (n)
    """
    elementwise = emb1 * emb2
    return np.sum(elementwise, axis=1)



# Train

In [7]:
def parse_argus(argv=None):
    """Build the experiment configuration.

    :param argv: CLI-style argument list. ``None`` parses an empty list, so inside
        a notebook (where ``sys.argv`` belongs to the kernel) only defaults are used.
    :return: ``argparse.Namespace`` with the parsed options.
    """
    def str2bool(value):
        # argparse's type=bool would turn any non-empty string (even "False") into True.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError(f'Expected a boolean value, got {value!r}')

    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset-dir', default='/content/dataset/train')
    parser.add_argument('--val-dataset-dir', default='/content/dataset/val')
    parser.add_argument('--test-dataset-dir', default='/content/dataset/test')
    parser.add_argument('--checkpoint-dir', default='/content/drive/MyDrive/experiments', help='Checkpoint directory')
    parser.add_argument('--model-path', default='/content/drive/MyDrive/enet_b0_8_best_vgaf.pt',
                        help='Model directory')
    parser.add_argument('--model-class', default='ImageEmbedder', help='')

    parser.add_argument('--epochs', type=int, default=5, help='Number of epochs for training')
    parser.add_argument('--batch-size', type=int, default=8, help='Number of examples for each iteration')
    parser.add_argument('--accumulate-batches', type=int, default=1, help='Number of batches to accumulate')
    parser.add_argument('--learning-rate', type=float, default=1e-4, help='Learning rate')
    # type=list would split each value into characters ("0.9" -> ['0', '.', '9']).
    parser.add_argument('--optim-betas', type=float, nargs=2, default=[0.9, 0.999], help='Optimizer betas')
    parser.add_argument('--weight-decay', type=float, default=0.01, help='Optimizer weight decay')
    parser.add_argument('--threshold_chooser', type=str, default='accuracy',
                        help='Could be either max_f1, eer, accuracy')

    parser.add_argument('--loss', type=str, default='TripletLoss', help='Could be either ArcFaceLoss or TripletLoss')

    parser.add_argument('--seed', type=int, default=1004, help='Random seed value')
    parser.add_argument('--checkpoint-iter', type=int, default=5000, help='Eval and checkpoint frequency.')
    parser.add_argument('--scale-scores', type=str2bool, default=True,
                        help='Scale cosine similarity to [0, 1] for a better score interpretability')
    parser.add_argument('--device', default='cuda', help='Device to use for training: cpu or cuda')

    return parser.parse_args(args=[] if argv is None else argv)


In [11]:
# Build the experiment configuration from parse_argus' defaults (no CLI arguments are parsed in the notebook).
args = parse_argus()

In [12]:
# Start a W&B run in the Emotions_Recognition project. Every parsed option
# (batch size, epochs, seed, betas, ...) is logged so a run can be reproduced
# from the dashboard; the model path is additionally logged as "architecture".
wandb.init(
    project="Emotions_Recognition",
    config={
        **vars(args),
        "architecture": args.model_path,
    },
)

[34m[1mwandb[0m: Currently logged in as: [33m412549[0m ([33mnotn3ss_team[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
if args.seed is not None:
    set_all_seeds(args.seed)

# Explicit registries instead of eval(): only known classes can be selected.
MODEL_CLASSES = {'ImageEmbedder': ImageEmbedder}
LOSS_CLASSES = {
    'ArcFaceLoss': ArcFaceLoss,
    'TripletLoss': TripletLoss,
    'ContrastiveCrossEntropy': ContrastiveCrossEntropy,
}

# ArcFace trains on (image, label) pairs; the other losses consume triplets.
# Datasets, collate function and trainer must all agree on this mode.
sample_mode = 'arcface' if args.loss == 'ArcFaceLoss' else 'triplet'
batch_collate_fn = collate_fn_arcface if sample_mode == 'arcface' else collate_fn

model = MODEL_CLASSES[args.model_class](
    model_path=args.model_path,
    device=args.device
)

train_dataset = EmotionsDataset(
    dataset_dir=args.dataset_dir,
    sample_mode=sample_mode,
)
# The protocol CSV is chosen by `prefix`, not by `dataset_dir`: without
# prefix='val' the validation set would be read from train_balanced_protocol.csv.
dev_dataset = EmotionsDataset(
    dataset_dir=args.val_dataset_dir,
    sample_mode=sample_mode,
    prefix='val',
)
# NOTE(review): the test split still uses the default prefix ('train'), i.e.
# train_balanced_protocol.csv. Pass prefix='test' once a test protocol exists.
test_dataset = EmotionsDataset(
    dataset_dir=args.test_dataset_dir,
    sample_mode=sample_mode,
)

train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=6,
                              drop_last=True, collate_fn=batch_collate_fn)
dev_dataloader = DataLoader(dev_dataset, batch_size=args.batch_size, shuffle=False, num_workers=6,
                            collate_fn=batch_collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=6,
                             collate_fn=batch_collate_fn)

# num_classes follows the dataset's label mapping (Neutral, Anger, Disgust, Fear,
# Happiness, Sadness, Surprise, Other). emb_size must match the embedder's
# output size (ImageEmbedder's default embedding_size is 512).
loss = LOSS_CLASSES[args.loss](
    emb_size=512,
    num_classes=len(train_dataset.classes),
    device=args.device
).to(args.device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.learning_rate,
    betas=args.optim_betas,
    weight_decay=args.weight_decay
)

# ArcFace has its own learnable class-centre weights, trained by a second optimizer.
if args.loss == 'ArcFaceLoss':
    optimizer_arc = torch.optim.AdamW(
        loss.parameters(),
        lr=args.learning_rate,
        betas=args.optim_betas,
        weight_decay=args.weight_decay
    )
else:
    optimizer_arc = None

# The scheduler is stepped once per optimizer step (i.e. per training batch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    len(train_dataloader) * args.epochs,
    eta_min=1e-6
)

trainer = EmotionsTrainer(
    model=model,
    checkpoint_dir=args.checkpoint_dir,
    train_dataloader=train_dataloader,
    dev_dataloader=dev_dataloader,
    test_dataloader=test_dataloader,
    optimizer=optimizer,
    optimizer_arc=optimizer_arc,
    scheduler=scheduler,
    loss=loss,
    device=args.device,
    sample_mode=sample_mode,
)

try:
    trainer.train(args.epochs)
finally:
    # Close the W&B run even if training is interrupted or crashes.
    wandb.finish()

Epoch 0 - current loss: 34.8152:   1%|          | 26/2153 [00:52<1:07:24,  1.90s/it]