# Config

In [4]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import json
import os 
import time
from tqdm import tqdm
import matplotlib.pyplot as plt 
from matplotlib.image import imread
from math import log, ceil, sqrt
import torch.nn as nn
import torchvision
import torchinfo
import torch
import wandb
from sklearn.metrics import accuracy_score, roc_auc_score
from typing import Dict
import torch.optim as optim
import torch.optim.lr_scheduler as lr_sched
from torchinfo import summary

class CONFIG:

    ROOT_DIRECTORY = os.path.join("..","data", "WLASL")
    JSON_FILE = "WLASL_v0.3.json"
    NSLT_FILE = "nslt_100.json"
    VIDEO_FOLDER = "videos"
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    #mean = [0, 0, 0]
    #std = [1,1,1]
    DEBUG = True
    
    P_OF_TRANSFORM = 0.9
    P_OF_TRANSFORM_COLOR = 0.2
    
    SHIFT_LIMIT=0.1
    SCALE_LIMIT=0.1
    ROTATE_LIMIT=10
    
    # set to small, when prototyping, or 0 when deploying to cloud or PC with loads of RAM
    DATA_LIMIT = 100
    FRAME_SIZE = 30
    
    
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    NUM_WORKERS = 0
    ROUND_NUMBER = 3
    TASK_NAME = "WLASL_RGB"

    BATCH_SIZE = 4
    PORTION_OF_DATA_FOR_TRAINING = 0.8
    PIN_MEMORY = False
    
torch.backends.cudnn.benchmark = True
print(f"Device : {CONFIG.DEVICE}")    

Device : cuda


In [19]:
wandb.finish()
wandb.login(relogin=True)

VBox(children=(Label(value='0.020 MB of 0.020 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,█▁
train_ROC_AUC_ovo,█▁
train_ROC_AUC_ovr,█▁
train_acc,▁▁
val_ROC_AUC_ovo,█▁
val_ROC_AUC_ovr,█▁
val_acc,▁▁

0,1
loss,1.92663
train_ROC_AUC_ovo,0.44922
train_ROC_AUC_ovr,0.44855
train_acc,0.1875
val_ROC_AUC_ovo,0.61574
val_ROC_AUC_ovr,0.59143
val_acc,0.05


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: C:\Users\mlewand\.netrc


True

# Read the necessary files

In [5]:
# setup the paths
video_path = os.path.join(CONFIG.ROOT_DIRECTORY, CONFIG.VIDEO_FOLDER)
dataset_description = os.path.join(CONFIG.ROOT_DIRECTORY)

# load the filepaths for videos
video_paths = [os.path.join(video_path, file) for file in os.listdir(video_path)]

# load the dataset config json
config_json = None
with open(os.path.join(CONFIG.ROOT_DIRECTORY, CONFIG.JSON_FILE)) as f:
    config_json = json.load(f)
    
# load the dataset json
dataset_json = None
with open(os.path.join(CONFIG.ROOT_DIRECTORY, CONFIG.NSLT_FILE)) as f:
    dataset_json = json.load(f)

print(f"the dataset consists of {len(dataset_json.keys())} videos")
print(f"there are {len(video_paths)} videos in total")

the dataset consists of 2038 videos
there are 11980 videos in total


# Util functions

In [6]:
def Format(x, rnd_digits=3):
    prefix = ""
    if 1024 <= x < 1024**2:
        prefix = "K"
        x /= 1024
    elif 1024**2 <= x < 1024**3:
        prefix = "M"
        x /= 1024**2
    elif x >= 1024**3:
        prefix = "G"
        x /= 1024**3
    
    return f"{round(x,rnd_digits)}{prefix}B"


## Regular Dataset : store the data on disk and load it from there

In [7]:
import os 
import torch
import random
import numpy as np
from tqdm import tqdm
from typing import List
import cv2

import torch
import torch.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
        
class SignRecognitionDataset(Dataset):

    def __init__(self, max_start : int, max_end) -> None:
        # setup the paths
        video_path = os.path.join(CONFIG.ROOT_DIRECTORY, CONFIG.VIDEO_FOLDER)
        dataset_description = os.path.join(CONFIG.ROOT_DIRECTORY)

        # load the filepaths for videos
        self.video_paths = [os.path.join(video_path, file) for file in os.listdir(video_path)]

        # load the dataset config json
        self.config_json = None
        with open(os.path.join(CONFIG.ROOT_DIRECTORY, CONFIG.JSON_FILE)) as f:
            self.config_json = json.load(f)

        # load the dataset json
        self.dataset_json = None
        with open(os.path.join(CONFIG.ROOT_DIRECTORY, CONFIG.NSLT_FILE)) as f:
            self.dataset_json = json.load(f)
        
        self.videos_paths = []
        self.paths_not_found = []
        self.labels = []
        self.start_frames = []
        self.end_frames = []
        

        for el in tqdm(dataset_json.items()):
            video_id, properties = el[0], el[1]
            path = os.path.join(video_path, video_id + ".mp4")
            
            if not os.path.exists(path):
                self.paths_not_found.append(path)
                continue

            subset = properties["subset"]
            label, start, end = properties["action"]
            
            if start > max_start:
                continue
                
            if end > max_end:
                continue
            
            self.videos_paths.append(path)
            self.labels.append(label)
            self.start_frames.append(start)
            self.end_frames.append(end)
    
        self.videos_paths = np.array(self.video_paths)
        self.paths_not_found = np.array(self.paths_not_found)
        self.labels = np.array(self.labels)
        self.start_frames = np.array(self.start_frames)
        self.end_frames = np.array(self.end_frames)

        self.unique_labels = np.unique(self.labels)
        
    
    def preprocess_trajectory(self, traj : List[np.ndarray]):
        return traj
    
    def __len__(self):
        return len(self.videos_paths)

    def __getitem__(self, idx):
        path, label = self.videos_paths[idx], self.labels[idx]
        trajectory = SignRecognitionDataset.get_video(path)
        
        return self.preprocess_trajectory(trajectory), label
        
        
    @staticmethod
    def get_video(video_path : str) -> List[np.ndarray]:

        if not os.path.exists(video_path):
            return None

        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return None

        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if ret:
                frame_array = np.array(frame)
                frames.append(cv2.cvtColor(frame_array, cv2.COLOR_BGR2RGB))
            else:
                break

        cap.release()
        return np.array(frames)

    @staticmethod
    def rescale_video(frames : np.ndarray, desired_shape) -> np.ndarray:
        refined = []
        for img in frames: 
            y, x, c = img.shape
            cropped = img[:, (x // 2 - y//2) : (x // 2 + y//2), :]
            scaled = cv2.resize(cropped, desired_shape)
            refined.append(scaled)
    
        return np.array(refined)
        
    
ds = SignRecognitionDataset(max_start=1, max_end=150)
print(f"after filtering : size of dataset={len(ds)}")

100%|██████████| 2038/2038 [00:00<00:00, 14949.66it/s]

after filtering : size of dataset=11980





## Cached Dataset

In [8]:
from typing import Tuple
from collections import Counter
import albumentations as A
from albumentations.pytorch import ToTensorV2

class SignRecognitionDatasetCached(SignRecognitionDataset):

    def __init__(self, max_start: int, max_end, per_image_transform=None,
                 scaled_resolution : Tuple[int]= (240, 240),
                 frame_size:int=CONFIG.FRAME_SIZE,
                 data_limit : int = CONFIG.DATA_LIMIT,
                 by_size=True) -> None:
        super().__init__(max_start, max_end)
        self.scaled_resolution = scaled_resolution
        self.DATA_LIMIT = data_limit
        self.by_size = by_size
        self.cache_data()
        self.per_image_transform = per_image_transform
        self.FRAME_SIZE = frame_size
        
        self.keywords = ["image" ] + list(str(i) for i in range(frame_size-1))
        
    def preprocess_trajectory(self, traj : List[np.ndarray]):
        return SignRecognitionDataset.rescale_video(traj, self.scaled_resolution)

    def permutate(self):
        l = len(self.videos_paths)
        mask = np.arange(l)
        np.random.shuffle(mask)
        
        self.videos_paths = np.array(self.videos_paths)[mask]
        self.labels = np.array(self.labels)[mask]
        self.start_frames = np.array(self.start_frames)[mask]
        self.end_frames = np.array(self.end_frames)[mask]
        
    def sort_by_size(self):
        c = Counter(self.labels)
        _mask = sorted([ (10000 * c[l] + l, i) for i, l in enumerate(self.labels)])[::-1]
        mask_by_size = np.array([el[1] for el in _mask])

        self.video_paths        = np.array(self.video_paths)[mask_by_size]
        self.labels             = np.array(self.labels)[mask_by_size]
        self.start_frames       = np.array(self.start_frames)[mask_by_size]
        self.end_frames         = np.array(self.end_frames)[mask_by_size]

 
    def cache_data(self):
        if not self.by_size:
            self.permutate()
        else:
            self.sort_by_size()
        
        self.cached_data_x = []
        self.cached_data_y = []
        
        pbar = tqdm(range(self.DATA_LIMIT))
        pbar.set_description("Loading/scaling trajectories")
        for i in pbar:
            trajectory = SignRecognitionDataset.get_video(self.videos_paths[i])            
            trajectory = self.preprocess_trajectory(trajectory)

            self.cached_data_x.append(trajectory)
            self.cached_data_y.append(self.labels[i])
        
        self.unique_labels = np.unique(self.cached_data_y)
        self.label_to_new_id = {self.unique_labels[i] : i for i in range(len(self.unique_labels))}
        self.new_id_to_label = {v : k for k,v in self.label_to_new_id.items()}
        
        for i in range(len(self.cached_data_y)):
            self.cached_data_y[i] = self.label_to_new_id[self.cached_data_y[i]]
        
    def __len__(self):
        return len(self.cached_data_x)

    def crop_video(self, trajectory : np.array) -> np.array:
        cropped = trajectory
        
        if self.FRAME_SIZE != 0:
            frame_size = len(trajectory)
            start = 0 
            
            if frame_size > self.FRAME_SIZE:
                start = np.random.randint(0, frame_size - self.FRAME_SIZE)
            cropped = trajectory[start: (start + self.FRAME_SIZE)]
                    
            if len(cropped) < self.FRAME_SIZE:
                necessary = self.FRAME_SIZE - len(cropped)
                t, h, w, c = trajectory.shape
                cropped = np.concatenate([cropped, np.zeros((necessary, h, w, c))], axis= 0)
                
            return cropped
                        
        return trajectory


    def __getitem__(self, idx):
        trajectory, label = self.crop_video(self.cached_data_x[idx]), self.cached_data_y[idx]

        if self.per_image_transform is not None:            
            frames = {self.keywords[i] : frame for i, frame in enumerate(trajectory)}
            processing = self.per_image_transform(**frames)       

            return np.stack([processing[kw] for kw in self.keywords]), label    
        return torch.Tensor(trajectory), label
    

transform = A.Compose(
    [
        A.Normalize(mean=CONFIG.mean, std=CONFIG.std),
        A.HorizontalFlip(p=CONFIG.P_OF_TRANSFORM),
        A.ShiftScaleRotate(p=CONFIG.P_OF_TRANSFORM, shift_limit=0.1, scale_limit=0.1, rotate_limit=10),
        #A.RandomBrightnessContrast(p=0.2),
        #A.RGBShift(p=0.2),
        ToTensorV2()
    ],
    additional_targets={str(i) : "image" for i in range(CONFIG.FRAME_SIZE)}
)
    

ds = SignRecognitionDatasetCached(max_start=1, max_end=150,
                                  data_limit=CONFIG.DATA_LIMIT, per_image_transform=transform)

100%|██████████| 2038/2038 [00:00<00:00, 48506.95it/s]
Loading/scaling trajectories: 100%|██████████| 100/100 [00:22<00:00,  4.36it/s]


In [9]:
from typing import Tuple
from collections import Counter
import albumentations as A
from albumentations.pytorch import ToTensorV2

class SignRecognitionDatasetCachedMHI(SignRecognitionDataset):

    def __init__(self, max_start: int, max_end, 
                 per_image_transform=None,
                 after_MHI_transform=None,
                 scaled_resolution : Tuple[int]= (240, 240),
                 frame_size:int=CONFIG.FRAME_SIZE,
                 data_limit : int = CONFIG.DATA_LIMIT,
                 decay : float = 0.7,
                 threshold_method : str = "regular",
                 threshold_val : int = 0.3 * 255,
                 by_size=True) -> None:
        super().__init__(max_start, max_end)
        self.scaled_resolution = scaled_resolution
        self.DATA_LIMIT = data_limit
        self.by_size = by_size
        self.cache_data()
        self.per_image_transform = per_image_transform
        self.after_MHI_transform = after_MHI_transform
        self.FRAME_SIZE = frame_size
        
        self.keywords = ["image" ] + list(str(i) for i in range(frame_size-1))
        
    def preprocess_trajectory(self, traj : List[np.ndarray]):
        return SignRecognitionDataset.rescale_video(traj, self.scaled_resolution)

    def permutate(self):
        l = len(self.videos_paths)
        mask = np.arange(l)
        np.random.shuffle(mask)
        
        self.videos_paths = np.array(self.videos_paths)[mask]
        self.labels = np.array(self.labels)[mask]
        self.start_frames = np.array(self.start_frames)[mask]
        self.end_frames = np.array(self.end_frames)[mask]
        
    def sort_by_size(self):
        c = Counter(self.labels)
        _mask = sorted([ (10000 * c[l] + l, i) for i, l in enumerate(self.labels)])[::-1]
        mask_by_size = np.array([el[1] for el in _mask])

        self.video_paths        = np.array(self.video_paths)[mask_by_size]
        self.labels             = np.array(self.labels)[mask_by_size]
        self.start_frames       = np.array(self.start_frames)[mask_by_size]
        self.end_frames         = np.array(self.end_frames)[mask_by_size]

 
    def cache_data(self):
        if not self.by_size:
            self.permutate()
        else:
            self.sort_by_size()
        
        self.cached_data_x = []
        self.cached_data_y = []
        
        pbar = tqdm(range(self.DATA_LIMIT))
        pbar.set_description("Loading/scaling trajectories")
        for i in pbar:
            trajectory = SignRecognitionDataset.get_video(self.videos_paths[i])            
            trajectory = self.preprocess_trajectory(trajectory)

            self.cached_data_x.append(trajectory)
            self.cached_data_y.append(self.labels[i])
        
        self.unique_labels = np.unique(self.cached_data_y)
        self.label_to_new_id = {self.unique_labels[i] : i for i in range(len(self.unique_labels))}
        self.new_id_to_label = {v : k for k,v in self.label_to_new_id.items()}
        
        for i in range(len(self.cached_data_y)):
            self.cached_data_y[i] = self.label_to_new_id[self.cached_data_y[i]]
        
    def __len__(self):
        return len(self.cached_data_x)

    def crop_video(self, trajectory : np.array) -> np.array:
        cropped = trajectory
        
        if self.FRAME_SIZE != 0:
            frame_size = len(trajectory)
            start = 0 
            
            if frame_size > self.FRAME_SIZE:
                start = np.random.randint(0, frame_size - self.FRAME_SIZE)
            cropped = trajectory[start: (start + self.FRAME_SIZE)]
                    
            if len(cropped) < self.FRAME_SIZE:
                necessary = self.FRAME_SIZE - len(cropped)
                t, h, w, c = trajectory.shape
                cropped = np.concatenate([cropped, np.zeros((necessary, h, w, c))], axis= 0)
                
            return cropped
                        
        return trajectory


    def __getitem__(self, idx):
        trajectory, label = self.crop_video(self.cached_data_x[idx]), self.cached_data_y[idx]

        if self.per_image_transform is not None:            
            frames = {self.keywords[i] : frame for i, frame in enumerate(trajectory)}
            processing = self.per_image_transform(**frames)       

            frames = np.array([processing[kw] for kw in self.keywords])    
        else:
            frames = torch.Tensor(trajectory)         
            
        return torch.Tensor(frames), label
    

transform = A.Compose(
    [
        A.Normalize(mean=CONFIG.mean, std=CONFIG.std),
        A.HorizontalFlip(p=CONFIG.P_OF_TRANSFORM),
        A.ShiftScaleRotate(p=CONFIG.P_OF_TRANSFORM, shift_limit=CONFIG.SHIFT_LIMIT, 
                           scale_limit=CONFIG.SCALE_LIMIT, rotate_limit=CONFIG.ROTATE_LIMIT),
        #A.RandomBrightnessContrast(p=CONFIG.P_OF_TRANSFORM_COLOR),
        #A.RGBShift(p=CONFIG.P_OF_TRANSFORM_COLOR),
        ToTensorV2()
    ],
    additional_targets={str(i) : "image" for i in range(CONFIG.FRAME_SIZE)}
)

# Training 

In [10]:
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, DataLoader, random_split

ds = SignRecognitionDatasetCached(max_start=1, max_end=150, data_limit=CONFIG.DATA_LIMIT, per_image_transform=transform)

# do train/val split
dataset_size = len(ds)
train_size = int(dataset_size * CONFIG.PORTION_OF_DATA_FOR_TRAINING)
val_size = dataset_size - train_size

print(f"splitting into : {train_size} {val_size}")   

# splitting dataset 
train_dataset, val_dataset = random_split(ds, [train_size, val_size])

print(f"sizes of datasets : len(train)={len(train_dataset)} len(val)={len(val_dataset)}")

# Create Dataloaders
train_loader = DataLoader(train_dataset, batch_size=CONFIG.BATCH_SIZE, pin_memory=CONFIG.PIN_MEMORY, num_workers=CONFIG.NUM_WORKERS,  shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=CONFIG.BATCH_SIZE, pin_memory=CONFIG.PIN_MEMORY, num_workers=CONFIG.NUM_WORKERS)

100%|██████████| 2038/2038 [00:00<00:00, 43746.12it/s]
Loading/scaling trajectories: 100%|██████████| 100/100 [00:19<00:00,  5.00it/s]

splitting into : 80 20
sizes of datasets : len(train)=80 len(val)=20





In [11]:
def test(model : nn.Module, val_dataset : DataLoader, cfg : CONFIG,   run = None):
    # change the model to evaluation
    model.eval()
    
    # get the number of datapoints
    number_of_datapoints = len(val_dataset.dataset)    

    # allocate the memory for these datapoints (no need to keep appending the data, which will make it slower)
    predictions_prob = np.zeros((number_of_datapoints, len(val_dataset.dataset.dataset.unique_labels)))
    predictions = np.zeros(number_of_datapoints)
    true_values = np.zeros(number_of_datapoints) 
    
    # get the number of batches
    dataset_len = len(val_dataset)

    # create the progreess bar 
    pbar = tqdm(val_dataset)

    # variable that will track where we are in terms of all data (after iteration add batch size to it)
    c = 0
    
    for i, (x,y) in enumerate(pbar): 
        # get the predictions
        pred = model(x.to(cfg.DEVICE).float())
    
        # get the batch size
        bs = x.shape[0]

        true_values[c : (c + bs)] = y.detach().numpy()
        predictions_prob[c : (c + bs)] = torch.softmax(pred.cpu().detach(), dim=1).numpy()
        predictions[c : (c + bs)] = torch.argmax(pred, 1).cpu().detach().numpy()
        c += bs 
            
        if i % max(dataset_len//10, 1) == 0 or i == dataset_len -1:
            acc = accuracy_score(predictions[:c], true_values[:c])
            try:
                roc_auc = roc_auc_score(true_values[:c], predictions_prob[:c, :], multi_class='ovr')            
                
            # It can happen at the beginning
            except Exception as e:
                roc_auc = 0

            pbar.set_description(f"examples seen so far : {c}, accuracy = {round(acc, cfg.ROUND_NUMBER)}, AUC ROC = {round(roc_auc, CONFIG.ROUND_NUMBER)}")
    
    return {"predition_prob" : predictions_prob, "predictions" : predictions, "true" : true_values}

def report_metrics(results : Dict, epoch : int, WANDB_ON : bool = True, prefix="val", run=None) -> Dict:
    predictions = results["predictions"]
    true_values = results["true"]
    predictions_prob = results["predition_prob"]
    
    acc = accuracy_score(predictions, true_values)
    roc_auc_ovr = roc_auc_score(true_values, predictions_prob, multi_class='ovr')            
    roc_auc_ovo = roc_auc_score(true_values, predictions_prob, multi_class='ovo')  
    
    if WANDB_ON:
        wandb.log({f"{prefix}_acc": acc, f"{prefix}_ROC_AUC_ovr": roc_auc_ovr, f"{prefix}_ROC_AUC_ovo" : roc_auc_ovo})
        wandb.log({f"{prefix}_ROC_epoch={epoch}" : wandb.plot.roc_curve(true_values, predictions_prob)})
    
    return {"accuracy" : acc, "ROC_AUC_OVR" :  roc_auc_ovr, "ROC_AUC_OVO" : roc_auc_ovo}

def save_model(model : nn.Module, metrics_results : Dict, metric_keyword : str, best_metric : float, savepath : str):
    
    if metrics_results[metric_keyword] > best_metric:
        print(f"Saving metric with {metric_keyword}={metrics_results[metric_keyword]} (previous : {best_metric})")
        torch.save(model.state_dict(), savepath)
        
    return max(metrics_results[metric_keyword], best_metric)

def train(train_dataloader : torch.utils.data.DataLoader, 
          model : nn.Module, 
          optimizer : optim.Optimizer, 
          scheduler : lr_sched.LRScheduler, 
          criterion, 
          epoch : int, 
          cfg : CONFIG,
          categorical_cast : bool  = True,
          WANDB_ON : bool=True):
    model.train()
    
    running_loss = 0.0
    i = 1
    train_len = len(train_dataloader)
    
    pb = tqdm(train_dataloader)
    for inputs, labels in pb:
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        inputs = inputs.to(cfg.DEVICE).float()
        if categorical_cast:
            labels = labels.to(cfg.DEVICE).long()
        else:
            labels = labels.to(cfg.DEVICE).float()
        
        #with torch.autocast(device_type="cuda"):
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
                        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
                       
        running_loss += loss.item()
            
        if (i-1) % (train_len//10) == 0 or i == train_len:   
            lr = 0
            cnt = 0
            for param_group in optimizer.param_groups:
                learning_rate = param_group['lr']
                lr += learning_rate
                    
                cnt += 1
                
            pb.set_description(f"EPOCH : {epoch}, average loss : {running_loss / i}, lr={lr / cnt}")
        i += 1
    
    if WANDB_ON:
        wandb.log({"loss" : running_loss/len(train_dataloader)})

def run_experiment(train_dataloader : torch.utils.data.DataLoader,
                   val_dataloader : torch.utils.data.DataLoader,
                   Model : nn.Module, 
                   run_name : str, 
                   model_parameters : dict, 
                   epochs : int, 
                   learning_rate : float, 
                   optimizer : str, 
                   savepath : str,
                   cfg : CONFIG,
                   saved_path_file : str = None,
                   min_lr:float=1e-5, 
                   cosine_annealer_epochs=20,
                   scheduler_en : bool = True,
                   metric_keyword : str = "acc",
                   lr_steps : int = 1000,
                   WANDB_ON : bool = True):

    try:
        os.mkdir("models") 
    except FileExistsError:
        pass
    
    model = Model(**model_parameters).to("cuda")
    if saved_path_file is not None and os.path.exists(saved_path_file):
        model.load_state_dict(torch.load(saved_path_file))
        print("loaded state dict!")
    
    config = {"model name" : model.__class__,
              "run name" : run_name,
              "epochs" : epochs,
              "learning rate" : learning_rate,
              "optimizer" : optimizer, 
              "uses scheduler" : scheduler_en,
              "min_lr" : min_lr,
              "lr_steps" : lr_steps}
    
    config.update(model_parameters)    
            
    if WANDB_ON:
        run = wandb.init(project=cfg.TASK_NAME,
                     name=f"experiment_{run_name}",
                     notes="Model summary : \n" + str(model),
                     config=config)

    
    criterion = nn.CrossEntropyLoss()

    # Observe that all parameters are being optimized
    if optimizer.lower() == "adam":
        optimizer_ = optim.Adam(model.parameters(), lr=learning_rate)
    elif optimizer.lower() == "adamw":
        optimizer_ = optim.AdamW(model.parameters(), lr=learning_rate)
    else:
        raise Exception("specify correctly the optimizer !")

    if scheduler_en:
        scheduler = lr_sched.CosineAnnealingLR(optimizer_, cosine_annealer_epochs, eta_min=min_lr)

    best_metric = 0 
    class_labels = val_loader.dataset.dataset.label_to_new_id
    
    for epoch in range(epochs):
        train(train_dataloader, model, optimizer_, scheduler, criterion, epoch=epoch, WANDB_ON=WANDB_ON, cfg=cfg)

        test_res = test(model, train_dataloader, cfg=cfg)
        evaluation = report_metrics(test_res, epoch=epoch, #class_labels=class_labels,
                                    prefix="train", WANDB_ON=WANDB_ON)
            
        test_res = test(model, val_dataloader, cfg=cfg)
        evaluation = report_metrics(test_res, epoch=epoch, #class_labels=class_labels, 
                                    prefix="val", WANDB_ON=WANDB_ON)

        best_metric = save_model(model, evaluation, metric_keyword, best_metric, savepath)
            
        if scheduler_en:
            scheduler.step()
    
    if WANDB_ON:
        wandb.finish()


# Resnet18 + LSTM

In [12]:
class Resnet_plus_LSTM(nn.Module):
    def __init__(self, num_classes, hid_size):
        super().__init__() 
        self.backbone = torchvision.models.resnet18()
        self.backbone.fc = nn.Identity() 
        self.lstm = nn.LSTM(input_size=512, hidden_size=hid_size, num_layers=1, batch_first=True)
        self.num_classes = num_classes

        # Classifier layer
        self.fc = nn.Linear(hid_size, num_classes)

        
    def forward(self, x : torch.Tensor) -> torch.Tensor:
        # assumed that the data is in format (batch, time, channels, height, width)
        batch_size, frames, C, H, W = x.size()

        x = x.view(batch_size * frames, C, H, W)

        # Extract features for each frame using ResNet
        # TODO: calculate gradient or not for backbone ? 
        with torch.no_grad(): 
            features = self.backbone(x)            
        
        features = features.view(batch_size, frames, -1)        
        x, _ = self.lstm(features)

        # Take the output of the last time step
        x = x[:, -1, :]
        
        return self.fc(x)
    
m = Resnet_plus_LSTM(num_classes=len(ds.unique_labels), hid_size=64)
m

Resnet_plus_LSTM(
  (backbone): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True,

In [22]:
run_experiment(Model=Resnet_plus_LSTM, 
               run_name="resnet18+LSTM_hid=64", 
               model_parameters={"num_classes":len(ds.unique_labels), "hid_size":64},
               epochs=100,
               learning_rate=1e-4,
               optimizer="Adam",
               train_dataloader=train_loader,
               val_dataloader=val_loader,
               savepath=os.path.join("models", "resnet18.pth"),
               min_lr=5e-5,
               scheduler_en=True,
            cosine_annealer_epochs=50,
               cfg=CONFIG,
               metric_keyword="accuracy",
               WANDB_ON=True)

EPOCH : 0, average loss : 1.9601110994815827, lr=0.0001: 100%|██████████| 20/20 [00:10<00:00,  1.92it/s]
examples seen so far : 80, accuracy = 0.188, AUC ROC = 0.518: 100%|██████████| 20/20 [00:05<00:00,  3.85it/s]
examples seen so far : 20, accuracy = 0.05, AUC ROC = 0.468: 100%|██████████| 5/5 [00:01<00:00,  3.56it/s] 


Saving metric with accuracy=0.05 (previous : 0)


EPOCH : 1, average loss : 1.9380579948425294, lr=9.99506682107068e-05: 100%|██████████| 20/20 [00:09<00:00,  2.02it/s]
examples seen so far : 80, accuracy = 0.188, AUC ROC = 0.497: 100%|██████████| 20/20 [00:05<00:00,  3.77it/s]
examples seen so far : 20, accuracy = 0.05, AUC ROC = 0.564: 100%|██████████| 5/5 [00:01<00:00,  3.69it/s] 
EPOCH : 2, average loss : 1.9351492941379547, lr=9.980286753286194e-05: 100%|██████████| 20/20 [00:10<00:00,  1.96it/s]
examples seen so far : 80, accuracy = 0.175, AUC ROC = 0.514: 100%|██████████| 20/20 [00:05<00:00,  3.84it/s]
examples seen so far : 20, accuracy = 0.05, AUC ROC = 0.466: 100%|██████████| 5/5 [00:01<00:00,  3.51it/s] 
EPOCH : 3, average loss : 1.9284413039684296, lr=9.955718126821722e-05: 100%|██████████| 20/20 [00:09<00:00,  2.04it/s]
examples seen so far : 80, accuracy = 0.188, AUC ROC = 0.537: 100%|██████████| 20/20 [00:05<00:00,  3.67it/s]
examples seen so far : 20, accuracy = 0.1, AUC ROC = 0.398: 100%|██████████| 5/5 [00:01<00:00, 

Saving metric with accuracy=0.1 (previous : 0.05)


EPOCH : 4, average loss : 1.9258532106876374, lr=9.921457902821577e-05: 100%|██████████| 20/20 [00:09<00:00,  2.02it/s]
examples seen so far : 80, accuracy = 0.188, AUC ROC = 0.603: 100%|██████████| 20/20 [00:05<00:00,  3.88it/s]
examples seen so far : 20, accuracy = 0.05, AUC ROC = 0.432: 100%|██████████| 5/5 [00:01<00:00,  3.51it/s] 
EPOCH : 5, average loss : 1.9204324305057525, lr=9.877641290737884e-05: 100%|██████████| 20/20 [00:09<00:00,  2.04it/s]
examples seen so far : 80, accuracy = 0.188, AUC ROC = 0.571: 100%|██████████| 20/20 [00:05<00:00,  3.82it/s]
examples seen so far : 20, accuracy = 0.05, AUC ROC = 0.411: 100%|██████████| 5/5 [00:01<00:00,  3.80it/s] 
EPOCH : 6, average loss : 1.9123819291591644, lr=9.824441214720627e-05: 100%|██████████| 20/20 [00:09<00:00,  2.05it/s]
examples seen so far : 80, accuracy = 0.2, AUC ROC = 0.605: 100%|██████████| 20/20 [00:05<00:00,  3.83it/s]  
examples seen so far : 20, accuracy = 0.05, AUC ROC = 0.442: 100%|██████████| 5/5 [00:01<00:00

KeyboardInterrupt: 