# HW2P2: Image Recognition and Verification

This is the second homework in 11785: Introduction to Deep Learning. The goal is image (face) verification. We first train our own CNN on a classification task with 8631 identities. We then use the trained model to extract face embeddings for pairs of images and decide whether the two faces in each pair match.

## Running the Notebook

- Clone or download the notebook and place it in your desired directory.
- Navigate to the notebook’s directory.
- Uncomment and run the cells under the Kaggle heading if you are running on a platform other than Kaggle.
- Extract the dataset and ensure it’s in the correct folder structure.
- Open the notebook in Jupyter or Colab, and run the cells sequentially.

## Data Loading Scheme

- Batch size: 256
- Context: 80
- Input classes: 8631

## Architecture and Hyperparameters

During development, multiple model architectures were tested, with several hyperparameter variations for each to 
explore different configurations and determine the best-performing approach. Some of the initial models are included 
in the code as commented-out sections, providing insight into alternative architectures and configurations attempted 
during experimentation.

Previous Architectures Explored
- A custom 5-layer CNN for feature extraction, used for the early submission
- The number of channels per layer was [64, 128, 256, 512, 1024]
- Each layer was accompanied by a BatchNorm and a ReLU activation
- Another architecture explored was ResNeXt
- ResNeXt builds on residual connections, which let the network learn identity mappings
- ResNeXt introduced cardinality, a split-transform-merge strategy that enhances feature representation

For the complete ablations, see [the public wandb table](https://wandb.ai/DL_Busters/hw2p2-ablations/table?nw=nwusermabdulba).
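
To make the cardinality idea concrete, the sketch below counts the weights of a bias-free 3x3 convolution with and without grouping. The helper name `conv_params` is made up for this illustration. The 128-channel, 32-group setting mirrors the first-stage grouped convolution in the `Bottleneck` class later in this notebook.

```python
def conv_params(in_channels, out_channels, kernel_size, groups=1):
    """Weight count of a bias-free 2D convolution split into `groups` groups."""
    assert in_channels % groups == 0 and out_channels % groups == 0
    # Each output channel only sees in_channels // groups input channels.
    return (in_channels // groups) * out_channels * kernel_size * kernel_size


dense = conv_params(128, 128, 3, groups=1)     # ordinary convolution
grouped = conv_params(128, 128, 3, groups=32)  # cardinality 32
print(dense, grouped, dense // grouped)
```

With the same channel width, the grouped convolution uses 32x fewer weights, which is what lets a ResNeXt-style block afford wider layers at a similar cost.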

### Best performing architecture & model

- I used the Squeeze-and-Excitation ResNeXt (SE-ResNeXt)
- This architecture adds a Squeeze-and-Excitation block for adaptive feature recalibration
- Global Average Pooling was used to summarize global spatial information
- The specific architecture has 4 stages, each containing a different number of bottleneck blocks
- Verification: similarity between the face embeddings of each image pair
- Optimizer: Adam
- Metrics: Retrieval Accuracy
- Initial Learning Rate: 0.1
- Scheduler: ReduceLROnPlateau
- Batch Size: 256
- Epochs: 60

## Submitted by: mabdulba


# Libraries

In [None]:
!nvidia-smi # Run this to see what GPU you have

In [None]:
!pip install wandb --quiet # Install WandB
!pip install pytorch_metric_learning --quiet #Install the Pytorch Metric Library

In [None]:
!pip install torchsummary
# !pip install torchsummaryX==1.1.0

In [None]:
import torch
from torchsummary import summary
# from torchsummaryX import summary
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import os
import gc
from tqdm import tqdm
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn import metrics as mt
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import glob
import wandb
import math
import matplotlib.pyplot as plt
from pytorch_metric_learning import samplers
import csv
from torchvision.transforms import v2
from torch.utils.data import default_collate

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device: ", DEVICE)


In [None]:
# from google.colab import drive # Link to your drive if you are not using Colab with GCP
# drive.mount('/content/drive') # Models in this HW take a long time to get trained and make sure to save it here

# Kaggle

In [None]:
# TODO: Use the same Kaggle code from HW1P2

In [None]:
# # # Reminder: Make sure you have connected your kaggle API before running this block
# !mkdir '/content/data'

# !kaggle competitions download -c 11785-hw-2-p-2-face-verification-fall-2024
# !unzip -qo '11785-hw-2-p-2-face-verification-fall-2024.zip' -d '/content/data'

# Config

In [None]:
config = {
    'batch_size': 256, # Increase this if your GPU can handle it
    'lr': 0.1,
    'epochs': 80, # 20 epochs is recommended ONLY for the early submission - you will have to train for much longer typically.
    'data_dir': "/kaggle/input/11785-hw-2-p-2-face-verification-fall-2024/11-785-f24-hw2p2-verification/cls_data", #TODO
    'data_ver_dir': "/kaggle/input/11785-hw-2-p-2-face-verification-fall-2024/11-785-f24-hw2p2-verification/ver_data", #TODO
    'checkpoint_dir': "/kaggle/working/" ,
    # Include other parameters as needed.
    'optimizer'     : 'SGD',
    'scheduler'     : 'CosineAnnealingLR',
    'patience'      : 2,
    'weight_decay'  : 1e-4,
    'momentum'      : 0.9,
    'resume_training': False,
    'model_path'     : "/kaggle/input/seresnext/pytorch/default/1/last.pth"
}

In [None]:
def build_optimizer(model, optimizer: str, lr: float, weight_decay: float=0, momentum:float=0):
    if optimizer == 'SGD':
        return torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

    elif optimizer == 'Adam':
        return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    elif optimizer == 'AdamW':
        weight_decay_ = weight_decay if weight_decay != 0 else 0.01
        return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay_)
    
    else:
        raise ValueError(f'Unknown optimizer: {optimizer}')


In [None]:
def build_scheduler(optimizer, scheduler: str, epochs: int, lr: float):
    if scheduler == 'CosineAnnealingLR':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0.00004)
    elif scheduler == 'CosineAnnealingWarmRestarts':
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, eta_min=1e-7, T_mult=2)
    elif scheduler == 'StepLR':
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
    elif scheduler == 'ReduceLROnPlateau':
        return torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,patience=config['patience'])
    else:
        raise ValueError(f'Unknown scheduler: {scheduler}')
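
As a sanity check on the `CosineAnnealingLR` branch, the sketch below evaluates the closed-form cosine schedule given in the PyTorch documentation. It uses the values from `config` and `build_scheduler` (`lr=0.1`, `T_max=80`, `eta_min=0.00004`). The function name `cosine_lr` is hypothetical. The scheduler itself updates the rate recursively, so treat this as an approximation of what it produces when stepped once per epoch.

```python
import math


def cosine_lr(epoch, base_lr=0.1, eta_min=0.00004, t_max=80):
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))


for epoch in (0, 20, 40, 60, 80):
    print(epoch, round(cosine_lr(epoch), 5))
```

The learning rate starts at 0.1, passes roughly half of that at the midpoint, and decays to `eta_min` by the final epoch.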
        

# Dataset

## Dataset Class for doing Image Verification

In [None]:
class ImagePairDataset(torch.utils.data.Dataset):

    def __init__(self, data_dir, csv_file, transform):
        self.data_dir = data_dir
        self.transform = transform
        self.pairs = []
        if csv_file.endswith('.csv'):
            with open(csv_file, 'r') as f:
                reader = csv.reader(f)
                for i, row in enumerate(reader):
                    if i == 0:
                        continue
                    else:
                        self.pairs.append(row)
        else:
            with open(csv_file, 'r') as f:
                for line in f.readlines():
                    self.pairs.append(line.strip().split(' '))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):

        img_path1, img_path2, match = self.pairs[idx]
        img1 = Image.open(os.path.join(self.data_dir, img_path1))
        img2 = Image.open(os.path.join(self.data_dir, img_path2))
        return self.transform(img1), self.transform(img2), int(match)

In [None]:
class TestImagePairDataset(torch.utils.data.Dataset):

    def __init__(self, data_dir, csv_file, transform):
        self.data_dir = data_dir
        self.transform = transform
        self.pairs = []
        if csv_file.endswith('.csv'):
            with open(csv_file, 'r') as f:
                reader = csv.reader(f)
                for i, row in enumerate(reader):
                    if i == 0:
                        continue
                    else:
                        self.pairs.append(row)
        else:
            with open(csv_file, 'r') as f:
                for line in f.readlines():
                    self.pairs.append(line.strip().split(' '))

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):

        img_path1, img_path2 = self.pairs[idx]
        img1 = Image.open(os.path.join(self.data_dir, img_path1))
        img2 = Image.open(os.path.join(self.data_dir, img_path2))
        return self.transform(img1), self.transform(img2)

## Create Dataloaders for Image Recognition

In [None]:
cutmix = v2.CutMix(num_classes=8631)
mixup = v2.MixUp(alpha=0.8, num_classes=8631)
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])


def collate_fn(batch):
    return cutmix_or_mixup(*default_collate(batch))
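
The collate function above returns soft labels of shape `(batch, num_classes)` instead of integer class indices. The pure-Python sketch below shows what MixUp does to a single pair of samples. The helpers `one_hot` and `mixup_pair` are hypothetical, and `lam` is fixed by hand here, whereas torchvision samples it from a Beta distribution.

```python
def one_hot(label, num_classes):
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def mixup_pair(x1, y1, x2, y2, lam, num_classes):
    """Blend two samples and their one-hot labels with weight lam."""
    mixed_x = [lam * a + (1 - lam) * b for a, b in zip(x1, x2)]
    t1, t2 = one_hot(y1, num_classes), one_hot(y2, num_classes)
    mixed_y = [lam * a + (1 - lam) * b for a, b in zip(t1, t2)]
    return mixed_x, mixed_y


# Toy "images" with two pixels each, classes 0 and 2 out of 3.
x, y = mixup_pair([1.0, 0.0], 0, [0.0, 1.0], 2, lam=0.75, num_classes=3)
print(x, y)
# The dominant class is what an argmax over the soft label recovers.
print(max(range(len(y)), key=lambda i: y[i]))
```

This is why the training loop later converts labels with `torch.argmax` when their shape matches the logits: the hard label it recovers is the class with the larger mixing weight.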

In [None]:
data_dir = config['data_dir']
# train_dir = os.path.join(data_dir)

train_dir = os.path.join(data_dir, 'train')
val_dir = os.path.join(data_dir, 'dev')

# train transforms
train_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(112), # Why are we resizing the Image?
    torchvision.transforms.RandomRotation(degrees=30),
    torchvision.transforms.RandomHorizontalFlip(p=0.25),
    torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0),
    torchvision.transforms.ToTensor(),
    # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5],
    #                                   std=[0.5, 0.5, 0.5])
])

# val transforms
val_transforms = torchvision.transforms.Compose([
    torchvision.transforms.Resize(112),
    torchvision.transforms.ToTensor(),
    # torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5],
    #                                   std=[0.5, 0.5, 0.5])
]
                                               )


# get datasets
train_dataset = torchvision.datasets.ImageFolder(train_dir, transform=train_transforms)
val_dataset = torchvision.datasets.ImageFolder(val_dir, transform=val_transforms)

train_loader = torch.utils.data.DataLoader(train_dataset,
                                            batch_size=config["batch_size"],
                                            shuffle=True,
                                            pin_memory=True,
                                            num_workers=4,
                                            collate_fn=collate_fn,
                                            sampler=None)
val_loader = torch.utils.data.DataLoader(val_dataset,
                                          batch_size=config["batch_size"],
                                          shuffle=False,
                                          num_workers=4)

In [None]:
data_dir = config['data_ver_dir']


# get datasets

# TODO: Add your validation pair txt file
pair_dataset = ImagePairDataset(data_dir, csv_file='/kaggle/input/11785-hw-2-p-2-face-verification-fall-2024/11-785-f24-hw2p2-verification/val_pairs.txt', transform=val_transforms)
pair_dataloader = torch.utils.data.DataLoader(pair_dataset,
                                              batch_size=config["batch_size"],
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=4)

# TODO: Add your test pair txt file
test_pair_dataset = TestImagePairDataset(data_dir, csv_file='/kaggle/input/11785-hw-2-p-2-face-verification-fall-2024/11-785-f24-hw2p2-verification/test_pairs.txt', transform=val_transforms)
test_pair_dataloader = torch.utils.data.DataLoader(test_pair_dataset,
                                              batch_size=config["batch_size"],
                                              shuffle=False,
                                              pin_memory=True,
                                              num_workers=4)

# EDA and Viz

In [None]:
# Double-check your dataset/dataloaders work as expected

print("Number of classes    : ", len(train_dataset.classes))
print("No. of train images  : ", train_dataset.__len__())
print("Shape of image       : ", train_dataset[0][0].shape)
print("Batch size           : ", config['batch_size'])
print("Train batches        : ", train_loader.__len__())
print("Val batches          : ", val_loader.__len__())

# Feel free to print more things if needed

In [None]:
# Visualize a few images in the dataset

"""
You can write your own code, and you don't need to understand the code
It is highly recommended that you visualize your data augmentation as sanity check
"""

r, c    = [5, 5]
fig, ax = plt.subplots(r, c, figsize= (15, 15))

k       = 0
dtl     = torch.utils.data.DataLoader(
    dataset     = torchvision.datasets.ImageFolder(train_dir, transform=train_transforms),  # train_transforms are applied, so the grid shows augmented images
    batch_size  = config['batch_size'],
    shuffle     = True)

for data in dtl:
    x, y = data

    for i in range(r):
        for j in range(c):
            img = x[k].numpy().transpose(1, 2, 0)
            ax[i, j].imshow(img)
            ax[i, j].axis('off')
            k+=1
    break

del dtl

# Model Architecture

FAQ:

**What's a very low early deadline architecture (mandatory early submission)**?

- The very low early deadline architecture is a 5-layer CNN.
- The first convolutional layer has 64 channels, kernel size 7, and stride 4. The next four have 128, 256, 512, and 1024 channels. Each has kernel size 3 and stride 2. Documentation to make convolutional layers: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
- Think about strided convolutions from the lecture, as convolutions with stride = 1 and downsampling. For strided convolution, what padding do you need for preserving the spatial resolution? (Hint => padding = kernel_size // 2) - Think why?
- Each convolutional layer is accompanied by a Batchnorm and ReLU layer.
- Finally, you want to average pool over the spatial dimensions to reduce them to 1 x 1. Use AdaptiveAvgPool2d. Documentation for AdaptiveAvgPool2d: https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveAvgPool2d.html
- Then, flatten these trivial 1x1 dimensions away (Flatten?).
Look through https://pytorch.org/docs/stable/nn.html


**Why does a very simple network have 4 convolutions**?

Input images are 112x112. Note that each of these convolutions downsamples. Downsampling 2x effectively doubles the receptive field, increasing the spatial region each pixel extracts features from. Downsampling 32x is standard for most image models.
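
The spatial sizes can be traced with the standard convolution output formula. The sketch below uses the kernel, stride, and padding values from the 5-layer early-submission CNN described above (`padding = kernel_size // 2`). The helper name `conv_out_size` is made up for this example.

```python
def conv_out_size(size, kernel_size, stride, padding):
    """Output height/width of a convolution (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


# (kernel_size, stride, padding) for each of the five layers.
layers = [(7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1), (3, 2, 1)]

sizes = [112]
for k, s, p in layers:
    sizes.append(conv_out_size(sizes[-1], k, s, p))

print(sizes)  # [112, 28, 14, 7, 4, 2]
```

After the last layer only a 2x2 map is left, which `AdaptiveAvgPool2d` then reduces to 1x1.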

**Why does a very simple network have high channel sizes**?

Every time you downsample 2x, you do 4x less computation (at the same channel size). To maintain the same level of computation, you double the number of channels, which increases computation by 4x, so the two effects balance out. Another intuition: as you downsample, you lose spatial information, and we want to preserve some of it in the channel dimension.
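
The balance can be checked with a rough multiply-accumulate count for a single convolution, which scales with `H * W * C_in * C_out * k^2`. The sizes below are illustrative, and `conv_macs` is a hypothetical helper rather than part of the notebook.

```python
def conv_macs(height, width, in_channels, out_channels, kernel_size=3):
    """Approximate multiply-accumulates of one convolution at a given output size."""
    return height * width * in_channels * out_channels * kernel_size ** 2


base = conv_macs(28, 28, 64, 64)           # reference layer
downsampled = conv_macs(14, 14, 64, 64)    # 2x downsample, same channels
rebalanced = conv_macs(14, 14, 128, 128)   # 2x downsample, 2x channels

print(base // downsampled)  # 4: downsampling alone cuts compute 4x
print(rebalanced == base)   # True: doubling channels restores it
```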

**What is return_feats?**

It essentially returns the second-to-last-layer features of a given image. It's a "feature encoding" of the input image, and you can use it for the verification task. You would use the outputs of the final classification layer for the classification task. You might also find that the classification outputs are sometimes better for verification too - try both.
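
One common way to use these features for verification is to compare two embeddings with cosine similarity and threshold the score. The source does not name the similarity metric, so both the cosine choice and the `threshold=0.5` value below are assumptions for illustration. The function names are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two feature vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def is_match(feat1, feat2, threshold=0.5):
    # threshold is a placeholder; in practice it would be tuned on validation pairs
    return cosine_similarity(feat1, feat2) >= threshold


print(cosine_similarity([3.0, 4.0], [6.0, 8.0]))  # same direction
print(cosine_similarity([1.0, 0.0], [0.0, 1.0]))  # orthogonal
print(is_match([3.0, 4.0], [6.0, 8.0]))
```

Cosine similarity ignores the magnitude of the embeddings, so two features that point in the same direction score 1 regardless of scale.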

In [None]:
# # TODO: Fill out the model definition below

# class Network(torch.nn.Module):

#     def __init__(self, num_classes=8631):
#         super().__init__()

#         self.backbone = torch.nn.Sequential(
#             # TODO
#             torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=4, padding=3),
#             torch.nn.BatchNorm2d(64),
#             torch.nn.ReLU(),

#             torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=1),
#             torch.nn.BatchNorm2d(128),
#             torch.nn.ReLU(),

#             torch.nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
#             torch.nn.BatchNorm2d(256),
#             torch.nn.ReLU(),
            
#             torch.nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=2, padding=1),
#             torch.nn.BatchNorm2d(512),
#             torch.nn.ReLU(),
            
#             torch.nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=2, padding=1),
#             torch.nn.BatchNorm2d(1024),
#             torch.nn.ReLU(),
            
#             torch.nn.AdaptiveAvgPool2d((1, 1))  # Reduces the spatial dimensions to 1x1
#             )

#         self.cls_layer = torch.nn.Linear(1024, num_classes)

#     def forward(self, x):
#             # TODO:
        
#         # Pass input through the backbone (CNN layers)
#         feats = self.backbone(x)
        
#         # Flatten the features before passing to the classification layer
#         feats = feats.view(feats.size(0), -1)  # Remove 1x1 spatial dimensions
        
#         # Pass the flattened features through the classification layer
#         out = self.cls_layer(feats)

#         return {"feats": feats, "out": out}

# # Initialize your model
# model = Network().to(DEVICE)
# summary(model, (3, 112, 112))

In [None]:
# # ConvNeXt Block Definition
# class ConvNeXtBlock(torch.nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(ConvNeXtBlock, self).__init__()
#         self.dw_conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels)  # Depthwise conv
#         self.bn = torch.nn.BatchNorm2d(in_channels)  # Batch Normalization
#         self.pointwise_conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)  # Pointwise conv
#         self.act = torch.nn.GELU()  # Activation function

#     def forward(self, x):
#         x = self.dw_conv(x)         # Apply depthwise convolution
#         x = self.bn(x)              # Apply BatchNorm
#         x = self.pointwise_conv(x)  # Apply pointwise convolution
#         x = self.act(x)             # Apply activation
#         return x

# # Complete ConvNeXt Model Definition
# class ConvNeXt(torch.nn.Module):
#     def __init__(self, num_classes=8631):
#         super(ConvNeXt, self).__init__()
        
#         # Stem Layer
#         self.stem = torch.nn.Sequential(
#             torch.nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=3),  # Downsample input
#             torch.nn.BatchNorm2d(64),
#             torch.nn.ReLU(inplace=True)
#         )
        
#         # ConvNeXt Stages
#         self.stage1 = ConvNeXtBlock(64, 256)  # First stage
#         self.stage2 = ConvNeXtBlock(256, 512)  # Second stage
#         self.stage3 = ConvNeXtBlock(512, 1024)  # Third stage
#         self.stage4 = ConvNeXtBlock(1024, 2058)  # Fourth stage

#         # Global Pooling and Classification Layer
#         self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))  # Global Average Pooling
#         self.flatten = torch.nn.Flatten()
#         self.cls_layer = torch.nn.Linear(2058, num_classes)  # Classification layer

#     def forward(self, x):
#         x = self.stem(x)       # Initial downsampling
#         x = self.stage1(x)     # Pass through ConvNeXt blocks
#         x = self.stage2(x)
#         x = self.stage3(x)
#         x = self.stage4(x)
#         x = self.avgpool(x)    # Global pooling to reduce spatial size to (1, 1)
#         x = self.flatten(x)    # Flatten the output
#         out = self.cls_layer(x)  # Classification layer
#         return {"feats": x, "out": out}

# # Initialize the ConvNeXt model
# model = ConvNeXt(num_classes=8631).to(DEVICE)


# # Using your input data
# input_data = torch.zeros((1, 3, 112, 112)).to(DEVICE)  

# # Ensure you pass a tensor and not a tuple
# summary(model, input_data)  # Pass the tensor directly for summary


# model = torch.nn.DataParallel(model).to(DEVICE)

## SEResNet

In [None]:
# class SqueezeExcitationBlock(torch.nn.Module):
#     def __init__(self, in_channels, reduction_ratio=16):
#         super(SqueezeExcitationBlock, self).__init__()
        
#         self.global_avg_pool = torch.nn.AdaptiveAvgPool2d(1)
#         self.fc1 = torch.nn.Linear(in_channels, in_channels // reduction_ratio)
        
#         self.relu = torch.nn.ReLU(inplace=True)
#         self.fc2 = torch.nn.Linear(in_channels // reduction_ratio, in_channels)
        
#         self.sigmoid = torch.nn.Sigmoid()

#     def forward(self, x):
#         out = self.global_avg_pool(x).squeeze(-1).squeeze(-1)
#         out = self.fc1(out)
        
#         out = self.relu(out)
#         out = self.fc2(out)
        
#         out = self.sigmoid(out)
#         out = out.unsqueeze(-1).unsqueeze(-1)
#         return x * out

# class BasicBlock(torch.nn.Module):
#     def __init__(self, in_channels, out_channels, stride=1, downsample=None):
#         super(BasicBlock, self).__init__()
        
#         self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
#         self.bn1 = torch.nn.BatchNorm2d(out_channels)
        
#         self.relu = torch.nn.ReLU(inplace=True)
#         self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        
#         self.bn2 = torch.nn.BatchNorm2d(out_channels)
#         self.se_block = SqueezeExcitationBlock(out_channels)
        
#         self.downsample = downsample

#     def forward(self, x):
#         residual = x

#         out = self.conv1(x)
#         out = self.bn1(out)
#         out = self.relu(out)

#         out = self.conv2(out)
#         out = self.bn2(out)

#         out = self.se_block(out)

#         if self.downsample is not None:
#             residual = self.downsample(x)

#         out += residual
#         out = self.relu(out)

#         return out

# class SEResNet(torch.nn.Module):
#     def __init__(self, block, layers, num_classes=8631):
#         super(SEResNet, self).__init__()
#         self.in_channels = 64
#         self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
#         self.bn1 = torch.nn.BatchNorm2d(64)
#         self.relu = torch.nn.ReLU(inplace=True)
#         self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

#         self.layer1 = self._make_layer(block, 64, layers[0])
#         self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
#         self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
#         self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

#         self.global_avg_pool = torch.nn.AdaptiveAvgPool2d(1)
#         self.fc = torch.nn.Linear(512, num_classes)

#     def _make_layer(self, block, out_channels, blocks, stride=1):
#         downsample = None
#         if stride != 1 or self.in_channels != out_channels:
#             downsample = torch.nn.Sequential(
#                 torch.nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
#                 torch.nn.BatchNorm2d(out_channels)
#             )

#         layers = [block(self.in_channels, out_channels, stride, downsample)]
#         self.in_channels = out_channels

#         for _ in range(1, blocks):
#             layers.append(block(out_channels, out_channels))

#         return torch.nn.Sequential(*layers)

#     def forward(self, x, return_feats=True):
#         x = self.conv1(x)
#         x = self.bn1(x)
#         x = self.relu(x)
#         x = self.maxpool(x)

#         x = self.layer1(x)
#         x = self.layer2(x)
#         x = self.layer3(x)
#         x = self.layer4(x)

#         x = self.global_avg_pool(x)
        
#         x_feats = x.view(x.size(0), -1)
#         x = self.fc(x_feats)
        
#         return {"feats": x_feats, "out": x}




# # Initialize your model
# model = SEResNet(BasicBlock, [4, 5, 6, 2], num_classes=8631)

# model = torch.nn.DataParallel(model).to(DEVICE)


# summary(model.module, (3, 112, 112))

## SE-ResNeXt

In [None]:
class Selayer(torch.nn.Module):

    def __init__(self, inplanes):
        super(Selayer, self).__init__()
        self.global_avgpool = torch.nn.AdaptiveAvgPool2d(1)
        self.conv1 = torch.nn.Conv2d(inplanes, inplanes // 16, kernel_size=1, stride=1)
        self.conv2 = torch.nn.Conv2d(inplanes // 16, inplanes, kernel_size=1, stride=1)
        self.relu = torch.nn.ReLU(inplace=True)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):

        out = self.global_avgpool(x)

        out = self.conv1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.sigmoid(out)

        return x * out


class Bottleneck(torch.nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, cardinality, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = torch.nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(planes * 2)

        self.conv2 = torch.nn.Conv2d(planes * 2, planes * 2, kernel_size=3, stride=stride,
                               padding=1, groups=cardinality, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(planes * 2)

        self.conv3 = torch.nn.Conv2d(planes * 2, planes * 4, kernel_size=1, bias=False)
        self.bn3 = torch.nn.BatchNorm2d(planes * 4)

        self.selayer = Selayer(planes * 4)

        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out = self.selayer(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out

In [None]:
class SeResNeXt(torch.nn.Module):

    def __init__(self, block, layers, cardinality=32, num_classes=8631):
        super(SeResNeXt, self).__init__()
        self.cardinality = cardinality
        self.inplanes = 64

        self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.relu = torch.nn.ReLU(inplace=True)
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.avgpool = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, torch.nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = torch.nn.Sequential(
                torch.nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, self.cardinality, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, self.cardinality))

        return torch.nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # x = self.avgpool(x)
        # x = x.view(x.size(0), -1)

        x = self.avgpool(x)
        x_feats = x.view(x.size(0), -1)

        x = self.fc(x_feats)
        return {"feats": x_feats, "out":x}

In [None]:
def se_resnext50(**kwargs):

    model = SeResNeXt(Bottleneck, [3, 4, 6, 3], **kwargs)
    return model


def se_resnext101(**kwargs):
    
    model = SeResNeXt(Bottleneck, [3, 4, 23, 3], **kwargs)
    return model


def se_resnext152(**kwargs):
    
    model = SeResNeXt(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model

def se_resnext_custom(**kwargs):

    model = SeResNeXt(Bottleneck, [4, 5, 6, 2], **kwargs)
    return model
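
The numbers in the factory names follow the usual ResNet depth convention: three convolutions per bottleneck block, plus the stem convolution and the final linear layer. The sketch below applies that convention to the four stage configurations above. It ignores the SE and downsample convolutions, and `se_resnext_depth` is a hypothetical helper.

```python
def se_resnext_depth(layers, convs_per_block=3):
    """Nominal depth: bottleneck convs + stem conv + final fc layer."""
    return convs_per_block * sum(layers) + 2


configs = {
    "se_resnext50": [3, 4, 6, 3],
    "se_resnext101": [3, 4, 23, 3],
    "se_resnext152": [3, 8, 36, 3],
    "se_resnext_custom": [4, 5, 6, 2],
}

for name, layers in configs.items():
    print(name, sum(layers), se_resnext_depth(layers))
```

Under this convention the custom `[4, 5, 6, 2]` configuration has 17 bottleneck blocks and a nominal depth of 53, slightly deeper than the 50-layer variant.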

In [None]:
# Initialize your model
model = se_resnext_custom(num_classes=8631)

model = torch.nn.DataParallel(model).to(DEVICE)


summary(model.module, (3, 112, 112))

In [None]:
# --------------------------------------------------- #

# Defining Loss function
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.2) # TODO: What loss do you need for a multi class classification problem and would label smoothing be beneficial here?

# --------------------------------------------------- #

# Defining Optimizer
# optimizer = torch.optim.SGD(model.parameters(), lr=config["lr"], momentum=0.9) # TODO: Feel free to pick a optimizer
optimizer = build_optimizer(model, optimizer=config["optimizer"], lr=config["lr"], weight_decay=config['weight_decay'], momentum=config['momentum'])
# --------------------------------------------------- #

# Defining Scheduler
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=config['patience'])  # TODO: Use a good scheduler such as ReduceLRonPlateau, StepLR, MultistepLR, CosineAnnealing, etc.
# scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 10, T_mult=2, eta_min=1e-6) 
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, config["epochs"], eta_min=1e-7)

scheduler = build_scheduler(optimizer, scheduler=config["scheduler"], epochs=config["epochs"], lr=config["lr"])
# --------------------------------------------------- #

# Initialising mixed-precision training. # Good news. We've already implemented FP16 (Mixed precision training) for you
# It is useful only in the case of compatible GPUs such as T4/V100
scaler = torch.cuda.amp.GradScaler()
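
To see what `label_smoothing=0.2` does to the target, the sketch below builds the smoothed distribution by hand and evaluates cross-entropy against it. It follows the convention documented for `torch.nn.CrossEntropyLoss`, where the one-hot target is mixed with a uniform distribution. The helper names and the 4-class toy setting are illustrative only.

```python
import math


def smoothed_target(label, num_classes, smoothing):
    """(1 - smoothing) * one_hot + smoothing / num_classes."""
    off = smoothing / num_classes
    return [off + (1.0 - smoothing if i == label else 0.0) for i in range(num_classes)]


def cross_entropy(logits, target):
    """Cross-entropy between softmax(logits) and a soft target distribution."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return -sum(t * (v - log_z) for t, v in zip(target, logits))


target = smoothed_target(label=1, num_classes=4, smoothing=0.2)
print(target)
print(cross_entropy([0.0, 5.0, 0.0, 0.0], target))
```

Because every wrong class keeps a small share of the target mass, the loss never reaches zero, which discourages over-confident logits.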

# Metrics

In [None]:
class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    maxk = min(max(topk), output.size()[1])
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    return [correct[:min(k, maxk)].reshape(-1).float().sum(0) * 100. / batch_size for k in topk]
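
For intuition, the same top-k metric can be written without tensors. The sketch below is a pure-Python analogue operating on nested lists rather than a replacement for the function above. `topk_accuracy` and the toy scores are made up for the example.

```python
def topk_accuracy(scores, targets, k=1):
    """Percentage of rows whose target index is among the k highest scores."""
    hits = 0
    for row, target in zip(scores, targets):
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += target in topk
    return 100.0 * hits / len(targets)


scores = [[0.1, 0.7, 0.2],
          [0.5, 0.3, 0.2],
          [0.2, 0.3, 0.5]]
targets = [1, 1, 0]

for k in (1, 2, 3):
    print(k, topk_accuracy(scores, targets, k))
```

Only the first row is correct at top-1, the second becomes correct at top-2, and the third needs all three candidates.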

In [None]:
def get_ver_metrics(labels, scores, FPRs):
    # eer and auc
    fpr, tpr, _ = mt.roc_curve(labels, scores, pos_label=1)
    roc_curve = interp1d(fpr, tpr)
    EER = 100. * brentq(lambda x : 1. - x - roc_curve(x), 0., 1.)
    AUC = 100. * mt.auc(fpr, tpr)

    # get acc
    tnr = 1. - fpr
    pos_num = labels.count(1)
    neg_num = labels.count(0)
    ACC = 100. * max(tpr * pos_num + tnr * neg_num) / len(labels)

    # TPR @ FPR
    if isinstance(FPRs, list):
        TPRs = [
            ('TPR@FPR={}'.format(FPR), 100. * roc_curve(float(FPR)))
            for FPR in FPRs
        ]
    else:
        TPRs = []

    return {
        'ACC': ACC,
        'EER': EER,
        'AUC': AUC,
        'TPRs': TPRs,
    }
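
The `ACC` value above is the best accuracy reachable by any single decision threshold on the similarity scores. The sketch below reproduces that idea with an explicit threshold sweep on toy data. It is a simplified illustration and has not been checked against the scikit-learn based function. `best_threshold_accuracy` is a hypothetical name.

```python
def best_threshold_accuracy(labels, scores):
    """Best accuracy (in %) over all thresholds, predicting a match when score >= threshold."""
    # Include +inf so that "predict no match for every pair" is also considered.
    candidates = sorted(set(scores)) + [float("inf")]
    best = 0.0
    for t in candidates:
        correct = sum((s >= t) == bool(y) for y, s in zip(labels, scores))
        best = max(best, 100.0 * correct / len(labels))
    return best


labels = [1, 1, 0, 0, 1]
scores = [0.9, 0.8, 0.7, 0.3, 0.4]
print(best_threshold_accuracy(labels, scores))  # 80.0
```

No threshold separates these toy pairs perfectly, because one non-matching pair scores higher than a matching one, so the best achievable accuracy is 4 out of 5.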

# Train and Validation Function

In [None]:
def train_epoch(model, dataloader, optimizer, lr_scheduler, scaler, device, config):

    model.train()

    # metric meters
    loss_m = AverageMeter()
    acc_m = AverageMeter()

    # Progress Bar
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, leave=False, position=0, desc='Train', ncols=5)

    for i, (images, labels) in enumerate(dataloader):

        optimizer.zero_grad() # Zero gradients

        # send to cuda
        images = images.to(device, non_blocking=True)
        if isinstance(labels, (tuple, list)):
            targets1, targets2, lam = labels
            labels = (targets1.to(device), targets2.to(device), lam)
        else:
            labels = labels.to(device, non_blocking=True)

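        # forward
        with torch.cuda.amp.autocast():  # Runs the forward pass in mixed precision. That's it!
            outputs = model(images)

            # If the labels are soft (same shape as the logits, e.g. one-hot or mixed
            # targets), collapse them to hard class indices before computing the loss.
            # NOTE: this check assumes labels is a tensor; the tuple form handled above
            # would need its own loss computation.
            if labels.shape == outputs['out'].shape:
                labels = torch.argmax(labels, dim=1)

            loss = criterion(outputs['out'], labels)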

        scaler.scale(loss).backward() # This is a replacement for loss.backward()
        scaler.step(optimizer) # This is a replacement for optimizer.step()
        scaler.update()
        # metrics
        loss_m.update(loss.item())
        if 'feats' in outputs:
            acc = accuracy(outputs['out'], labels)[0].item()
        else:
            acc = 0.0
        acc_m.update(acc)

        # Show running metrics on the tqdm progress bar to monitor training
        batch_bar.set_postfix(
            acc         = "{:.04f}% ({:.04f})".format(acc, acc_m.avg),
            loss        = "{:.04f} ({:.04f})".format(loss.item(), loss_m.avg),
            lr          = "{:.04f}".format(float(optimizer.param_groups[0]['lr'])))

        batch_bar.update() # Update tqdm bar

    # Step the scheduler once per epoch. Passing a metric (here the last batch's loss)
    # matches ReduceLROnPlateau; most other schedulers expect step() with no argument.
    if lr_scheduler is not None:
        lr_scheduler.step(loss)

    batch_bar.close()

    return acc_m.avg, loss_m.avg
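
`train_epoch` accepts labels in the tuple form `(targets1, targets2, lam)`, which is how CutMix-style augmentations commonly return mixed targets. As shown, the notebook computes the loss against a single label tensor, so the following is only a sketch of how a mixed loss is usually formed for the tuple case, written in pure Python. The helpers `cross_entropy` and `mixed_loss` are hypothetical and not part of the notebook.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample: log-sum-exp of the logits minus the target logit."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def mixed_loss(logits, target_a, target_b, lam):
    """Weight the two losses by the mixing coefficient lam in [0, 1]."""
    return (lam * cross_entropy(logits, target_a)
            + (1.0 - lam) * cross_entropy(logits, target_b))


logits = [2.0, 0.5, -1.0]
loss_a = cross_entropy(logits, 0)
loss_b = cross_entropy(logits, 1)
loss_mix = mixed_loss(logits, 0, 1, lam=0.7)
print(loss_a < loss_mix < loss_b)  # the mixed loss lies between the two single-target losses
```

With `lam=1.0` the mixed loss reduces to the ordinary cross-entropy on `target_a`.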

In [None]:
@torch.no_grad()
def valid_epoch_cls(model, dataloader, device, config):

    model.eval()
    batch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, position=0, leave=False, desc='Val Cls.', ncols=5)

    # metric meters
    loss_m = AverageMeter()
    acc_m = AverageMeter()

    for i, (images, labels) in enumerate(dataloader):

        # Move images to device
        images, labels = images.to(device), labels.to(device)

        # Get model outputs
        with torch.inference_mode():
            outputs = model(images)
            loss = criterion(outputs['out'], labels)

        # metrics
        acc = accuracy(outputs['out'], labels)[0].item()
        loss_m.update(loss.item())
        acc_m.update(acc)

        batch_bar.set_postfix(
            acc         = "{:.04f}% ({:.04f})".format(acc, acc_m.avg),
            loss        = "{:.04f} ({:.04f})".format(loss.item(), loss_m.avg))

        batch_bar.update()

    batch_bar.close()
    return acc_m.avg, loss_m.avg

In [None]:
gc.collect() # These commands help you when you face CUDA OOM error
torch.cuda.empty_cache()

# Verification Task

In [None]:
def valid_epoch_ver(model, pair_data_loader, device, config):

    model.eval()
    scores = []
    match_labels = []
    batch_bar = tqdm(total=len(pair_data_loader), dynamic_ncols=True, position=0, leave=False, desc='Val Veri.')
    for i, (images1, images2, labels) in enumerate(pair_data_loader):

        # match_labels = match_labels.to(device)
        images = torch.cat([images1, images2], dim=0).to(device)
        # Get model outputs
        with torch.inference_mode():
            outputs = model(images)

        feats = F.normalize(outputs['feats'], dim=1)
        feats1, feats2 = feats.chunk(2)
        similarity = F.cosine_similarity(feats1, feats2)
        scores.append(similarity.cpu().numpy())
        match_labels.append(labels.cpu().numpy())
        batch_bar.update()

    scores = np.concatenate(scores)
    match_labels = np.concatenate(match_labels)

    FPRs=['1e-4', '5e-4', '1e-3', '5e-3', '5e-2']
    metric_dict = get_ver_metrics(match_labels.tolist(), scores.tolist(), FPRs)
    print(metric_dict)

    return metric_dict['ACC'], metric_dict['EER']
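
The verification score is the cosine similarity between the two embeddings of a pair. Cosine similarity does not depend on vector length, so the explicit `F.normalize` call before `F.cosine_similarity` is harmless but does not change the score. A minimal pure-Python sketch of that property, with hypothetical helper names and toy vectors:

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_similarity(a, b):
    """Dot product divided by the product of the two vector lengths."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


a = [3.0, 4.0]
b = [4.0, 3.0]

raw = cosine_similarity(a, b)  # 24 / 25
normed = cosine_similarity(l2_normalize(a), l2_normalize(b))
print(round(raw, 4), round(normed, 4))  # both 0.96
```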

# WandB

In [None]:
# Avoid hardcoding the API key in the notebook. wandb.login() reads it from the
# WANDB_API_KEY environment variable or prompts for it; the key is listed in your
# wandb account, under settings (wandb.ai/settings).
wandb.login()

In [None]:
# Create your wandb run
run = wandb.init(
    name = "SEResNeXt-MH-with-CutMix-No-Normalization-80", ## Wandb creates random run names if you skip this field
    reinit = True, ### Allows reinitializing runs when you re-run this cell
    # run_id = ### Insert specific run id here if you want to resume a previous run
    # resume = "must", ### You need this to resume previous runs, but comment out reinit = True when using this
    project = "hw2p2-ablations", ### Project should be created in your wandb account
    config = config ### Wandb Config for your run
)

# Checkpointing and Loading Model

In [None]:
# The scheduler state is saved only when a scheduler is passed in
def save_model(model, optimizer, scheduler, metrics, epoch, path):
    checkpoint = {
        'model_state_dict'         : model.state_dict(),
        'optimizer_state_dict'     : optimizer.state_dict(),
        'metric'                   : metrics,
        'epoch'                    : epoch,
    }
    if scheduler is not None:
        checkpoint['scheduler_state_dict'] = scheduler.state_dict()
    torch.save(checkpoint, path)


def load_model(model, optimizer=None, scheduler=None, path='./checkpoint.pth'):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # Checkpoints saved without a scheduler have no scheduler state, so check first
    if scheduler is not None and 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    epoch = checkpoint['epoch']
    metrics = checkpoint['metric']
    return model, optimizer, scheduler, epoch, metrics

# Experiments

In [None]:
e = 0
best_valid_cls_acc = 0.0
eval_cls = True
best_valid_ret_acc = 0.0

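resume_training = config['resume_training']

if resume_training:
    # load_model returns the scheduler it was given (None here), so discard that
    # slot to avoid overwriting the scheduler built earlier.
    model, optimizer, _, epoch, metrics = load_model(model, optimizer=optimizer, path=config['model_path'])
    e = epoch + 1  # the checkpoint stores the last completed epoch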

for epoch in range(e, config['epochs']):
        # epoch
        print("\nEpoch {}/{}".format(epoch+1, config['epochs']))

        # train
        train_cls_acc, train_loss = train_epoch(model, train_loader, optimizer, scheduler, scaler, DEVICE, config)
        curr_lr = float(optimizer.param_groups[0]['lr'])
        print("\nEpoch {}/{}: \nTrain Cls. Acc {:.04f}%\t Train Cls. Loss {:.04f}\t Learning Rate {:.04f}".format(epoch + 1, config['epochs'], train_cls_acc, train_loss, curr_lr))
        metrics = {
            'train_cls_acc': train_cls_acc,
            'train_loss': train_loss,
        }
        # classification validation
        if eval_cls:
            valid_cls_acc, valid_loss = valid_epoch_cls(model, val_loader, DEVICE, config)
            print("Val Cls. Acc {:.04f}%\t Val Cls. Loss {:.04f}".format(valid_cls_acc, valid_loss))
            metrics.update({
                'valid_cls_acc': valid_cls_acc,
                'valid_loss': valid_loss,
            })

        # retrieval validation
        valid_ret_acc, eer = valid_epoch_ver(model, pair_dataloader, DEVICE, config)
        print("Val Ret. Acc {:.04f}%".format(valid_ret_acc))
        metrics.update({
            'valid_ret_acc': valid_ret_acc,
            "EER"          : eer
        })

        # save model
        save_model(model, optimizer, scheduler, metrics, epoch, os.path.join(config['checkpoint_dir'], 'last.pth'))
        wandb.save(os.path.join(config['checkpoint_dir'], 'last.pth'))
        print("Saved epoch model")

        # save best model
        if eval_cls:
            if valid_cls_acc >= best_valid_cls_acc:
                best_valid_cls_acc = valid_cls_acc
                save_model(model, optimizer, scheduler, metrics, epoch, os.path.join(config['checkpoint_dir'], 'best_cls.pth'))
                wandb.save(os.path.join(config['checkpoint_dir'], 'best_cls.pth'))
                print("Saved best classification model")

        if valid_ret_acc >= best_valid_ret_acc:
            best_valid_ret_acc = valid_ret_acc
            save_model(model, optimizer, scheduler, metrics, epoch, os.path.join(config['checkpoint_dir'], 'best_ret.pth'))
            wandb.save(os.path.join(config['checkpoint_dir'], 'best_ret.pth'))
            print("Saved best retrieval model")

        # log to tracker
        if run is not None:
            run.log(metrics)

# Testing and Kaggle Submission (Verification)

In [None]:
def test_epoch_ver(model, pair_data_loader, config):

    model.eval()
    scores = []
    batch_bar = tqdm(total=len(pair_data_loader), dynamic_ncols=True, position=0, leave=False, desc='Test Veri.')
    for i, (images1, images2) in enumerate(pair_data_loader):

        images = torch.cat([images1, images2], dim=0).to(DEVICE)
        # Get model outputs
        with torch.inference_mode():
            outputs = model(images)

        feats = F.normalize(outputs['feats'], dim=1)
        feats1, feats2 = feats.chunk(2)
        similarity = F.cosine_similarity(feats1, feats2)
        scores.extend(similarity.cpu().numpy().tolist())
        batch_bar.update()

    return scores

In [None]:
# model, optimizer, scheduler, epoch, metrics = load_model(model, optimizer=optimizer, path=config['model_path'])

In [None]:
scores = test_epoch_ver(model, test_pair_dataloader, config)

In [None]:
with open("verification_submission_slack.csv", "w+") as f:
    f.write("ID,Label\n")
    for i in range(len(scores)):
        f.write("{},{}\n".format(i, scores[i]))