## Summary

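This is a DenseNet-121 ArcFace-style model, trained with cross-entropy loss, that supports an image-retrieval task. Some key points:
- There is no held-out validation set: the per-epoch "validation" metrics are computed on one fold of the data, and that fold is also part of the training data
- It is trained on the full chestx dataset
- It will be used to generate features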

## Config

In [None]:
fold_id = 0

image_size = 512
seed = 42
warmup_epo = 1
init_lr = 1e-3 
batch_size = 28
n_epochs = 25
warmup_factor = 10
num_workers = 24

add_ext = False
use_amp = False
save = True
debug = False
use_neptune = False
accumulation_step = 1
device_id = 0 # your GPU_id

kernel_type = 'metric_learning_ranzcr_chestx'
data_dir = 'train'
enet_type = 'densenet121'
model_dir = f'weights/{kernel_type}'
experiment_name = f'fold{fold_id}_{kernel_type}_{enet_type}'

! mkdir $model_dir

## Imports

In [None]:
import pandas as pd
import numpy as np
from efficientnet_pytorch import model as enet
import os
import sys
import time
import cv2
import PIL.Image
import random
from sklearn.metrics import accuracy_score
from tqdm.notebook import tqdm
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from warmup_scheduler import GradualWarmupScheduler
import albumentations
import geffnet
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import gc
from sklearn.metrics import roc_auc_score
%matplotlib inline
import seaborn as sns
from pylab import rcParams
import timm
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from warnings import filterwarnings
import neptune
from sklearn.preprocessing import LabelEncoder
import math
import glob
filterwarnings("ignore")

# Make GPU `device_id` the current CUDA device, then keep a generic 'cuda'
# handle that the model and the batches are moved to.
torch.cuda.set_device(device_id)
device = torch.device('cuda')

## Utils

In [None]:
class GradualWarmupSchedulerV2(GradualWarmupScheduler):
    """GradualWarmupScheduler with a custom ``get_lr``.

    While ``last_epoch <= total_epoch`` the learning rates are ramped from the
    base values; afterwards they are taken from ``after_scheduler`` (if one
    was given) or held at ``base_lr * multiplier``.
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        super().__init__(optimizer, multiplier, total_epoch, after_scheduler)

    def get_lr(self):
        """Return one learning rate per base learning rate for the current epoch."""
        if self.last_epoch <= self.total_epoch:
            # Warmup phase: apply a single scale factor to every base LR.
            if self.multiplier == 1.0:
                # Ramp from 0 up to base_lr.
                factor = float(self.last_epoch) / self.total_epoch
            else:
                # Ramp from base_lr up to base_lr * multiplier.
                factor = (self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.
            return [base_lr * factor for base_lr in self.base_lrs]

        # Warmup is over.
        if not self.after_scheduler:
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if not self.finished:
            # One-time hand-over: the follow-up scheduler starts from the
            # fully warmed-up learning rates.
            self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
            self.finished = True
        return self.after_scheduler.get_lr()

In [None]:
def seed_everything(seed):
    """Seed the Python, NumPy and PyTorch RNGs and pin the cuDNN flags.

    Also exports ``PYTHONHASHSEED`` and prints a confirmation line.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    print(f'Setting all seeds to be {seed} to reproduce...')
# Seed every RNG once, using the notebook-level `seed` setting.
seed_everything(seed)

## Folds

In [None]:
# Load the metadata table and attach one image path per row.
# NOTE(review): paths are matched to rows purely by position (sorted glob order
# vs. CSV row order), and the assignment raises if the counts differ. Confirm
# that the CSV rows are in the same order as the sorted file paths.
my_training_set = pd.read_csv('chestx/Data_Entry_2017.csv').assign(
    file_path=sorted(glob.glob('chestx/images_*/*/*')),
)

In [None]:
# Replace the raw 'Patient ID' values with the integer codes assigned by the
# fitted LabelEncoder; these codes are the class labels used for training.
le = LabelEncoder().fit(my_training_set['Patient ID'])
my_training_set['Patient ID'] = le.transform(my_training_set['Patient ID'])

In [None]:
from sklearn.model_selection import StratifiedKFold

# Shuffled stratified split on the encoded patient id, using StratifiedKFold's
# default number of splits. Only the held-out indices of each split are used:
# they tag the corresponding rows with the fold number.
skf = StratifiedKFold(shuffle=True, random_state=42)
for fold, (_, valid_idx) in enumerate(
    skf.split(my_training_set, my_training_set['Patient ID'])
):
    my_training_set.loc[valid_idx, 'fold'] = fold

## Dataset

In [None]:
class RANZERDataset(Dataset):
    """Image dataset backed by a DataFrame with a ``file_path`` column.

    Args:
        df: table with one row per image. Must contain ``file_path``; unless
            ``mode == 'test'`` it must also contain ``Patient ID``.
        mode: ``'test'`` makes ``__getitem__`` return the image only; any other
            value makes it return ``(image, label)``.
        transform: optional callable invoked as ``transform(image=img)`` that
            returns a mapping with an ``'image'`` entry (albumentations style).
    """

    def __init__(self, df, mode, transform=None):
        # Positional index 0..len-1 so that DataLoader indices map onto rows.
        self.df = df.reset_index(drop=True)
        self.mode = mode
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        """Load, transform and return the sample at ``index``.

        Returns:
            A float tensor in channel-first layout, plus (outside 'test' mode)
            the ``Patient ID`` value as a float scalar tensor.

        Raises:
            FileNotFoundError: if the image at ``file_path`` cannot be read.
        """
        row = self.df.loc[index]
        img = cv2.imread(row.file_path)
        if img is None:
            # cv2.imread reports an unreadable or missing file by returning
            # None; fail here with the path instead of inside cvtColor.
            raise FileNotFoundError(f'Could not read image: {row.file_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.transform is not None:
            img = self.transform(image=img)['image']

        # HWC -> CHW, as float32.
        img = img.astype(np.float32).transpose(2, 0, 1)
        image = torch.tensor(img).float()

        if self.mode == 'test':
            # No label lookup: test tables need not carry 'Patient ID'.
            return image

        label = torch.tensor(row['Patient ID']).float()
        return image, label

## Transforms

In [None]:
# Training pipeline: resize, random flip / shift-scale-rotate / cutout, then
# Normalize() with the library's default arguments.
transforms_train = albumentations.Compose(
    [
        albumentations.Resize(height=image_size, width=image_size),
        albumentations.HorizontalFlip(p=0.5),
        albumentations.ShiftScaleRotate(
            shift_limit=0.0625,
            scale_limit=0.2,
            rotate_limit=20,
            p=0.5,
        ),
        albumentations.Cutout(
            num_holes=16,
            max_h_size=16,
            max_w_size=16,
            fill_value=(0., 0., 0.),
            p=0.5,
        ),
        albumentations.Normalize(),
    ]
)

# Evaluation pipeline: the same resize and normalisation, no random augmentation.
transforms_valid = albumentations.Compose(
    [
        albumentations.Resize(height=image_size, width=image_size),
        albumentations.Normalize(),
    ]
)

## Model

In [None]:
class ArcModule(nn.Module):
    """Angular-margin classification head (ArcFace style).

    Computes ``cos_th = inputs @ normalize(weight).T`` (clamped to [-1, 1]),
    replaces the target-class entry with ``cos(theta + m)`` and scales
    everything by ``s``. Where ``cos_th <= cos(pi - m)`` the target entry falls
    back to ``cos_th - sin(pi - m) * m``.

    With the default ``m=0`` the margin terms vanish (``cos_m=1``, ``sin_m=0``,
    ``mm=0``), so the output is ``s * cos_th`` up to floating-point rounding.

    Args:
        in_features: size of each input embedding.
        out_features: number of classes.
        s: scale applied to the output logits.
        m: angular margin in radians, added to the target-class angle.
    """

    # Floor under the square root: d/dx sqrt(x) is infinite at 0, which turns
    # into NaN gradients whenever a clamped cosine is exactly +/-1.
    _EPS = 1e-7

    def __init__(self, in_features, out_features, s=10, m=0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=torch.float32))
        nn.init.xavier_normal_(self.weight)

        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.th = torch.tensor(math.cos(math.pi - m))
        self.mm = torch.tensor(math.sin(math.pi - m) * m)

    def forward(self, inputs, labels):
        """Return margin-adjusted, scaled logits of shape (batch, out_features).

        Args:
            inputs: (batch, in_features) embeddings. They are not normalised
                here; only the class weights are.
            labels: class indices of shape (batch,) or (batch, 1); any numeric
                dtype, cast to long internally.
        """
        cos_th = F.linear(inputs, F.normalize(self.weight)).clamp(-1, 1)
        sin_th = torch.sqrt((1.0 - cos_th.pow(2)).clamp(min=self._EPS))

        # cos(theta + m), with the fallback below the threshold applied once.
        cos_th_m = cos_th * self.cos_m - sin_th * self.sin_m
        cos_th_m = torch.where(cos_th > self.th, cos_th_m, cos_th - self.mm)

        if labels.dim() == 1:
            labels = labels.unsqueeze(-1)
        # Stay on the device of the logits instead of forcing .cuda(), so the
        # head also works on CPU and on a non-default GPU.
        labels = labels.to(device=cos_th.device, dtype=torch.long)
        onehot = torch.zeros(cos_th.size(), device=cos_th.device)
        onehot.scatter_(1, labels, 1.0)

        # Margin-adjusted value for the target class, plain cosine elsewhere.
        outputs = onehot * cos_th_m + (1.0 - onehot) * cos_th
        return outputs * self.s

In [None]:
class MetricLearningModel(nn.Module):
    """timm backbone followed by an embedding layer and an ArcModule head.

    Args:
        channel_size: dimensionality of the output embedding.
        out_feature: number of classes for the margin head.
        dropout: probability for the Dropout2d applied to the feature map.
        backbone: timm model name. The model must expose ``.features`` and
            ``.classifier`` (both are accessed below).
        pretrained: forwarded to ``timm.create_model``.
        feature_map_size: side length of the (square) feature map produced by
            ``backbone.features``. The default of 16 is the value that was
            previously hard-coded.
            NOTE(review): 16 presumably corresponds to the 512x512 inputs used
            in this notebook; confirm before changing the input size or the
            backbone.
    """

    def __init__(self, channel_size, out_feature, dropout=0.5, backbone='densenet121', pretrained=True,
                 feature_map_size=16):
        super(MetricLearningModel, self).__init__()
        self.backbone = timm.create_model(backbone, pretrained=pretrained)
        self.channel_size = channel_size
        self.out_feature = out_feature
        self.feature_map_size = feature_map_size
        # Channel count of the backbone's final feature map.
        self.in_features = self.backbone.classifier.in_features
        self.margin = ArcModule(in_features=self.channel_size, out_features=self.out_feature)
        self.bn1 = nn.BatchNorm2d(self.in_features)
        self.dropout = nn.Dropout2d(dropout, inplace=True)
        # The whole feature map is flattened, so fc1's input size depends on
        # the spatial size of that map.
        self.fc1 = nn.Linear(self.in_features * feature_map_size * feature_map_size, self.channel_size)
        self.bn2 = nn.BatchNorm1d(self.channel_size)

    def forward(self, x, labels=None):
        """Return margin logits when ``labels`` is given, else L2-normalised embeddings."""
        features = self.backbone.features(x)
        features = self.bn1(features)
        features = self.dropout(features)
        features = features.view(features.size(0), -1)
        features = self.fc1(features)
        features = self.bn2(features)
        features = F.normalize(features)
        if labels is not None:
            return self.margin(features, labels)
        return features


## Train and Valid loop

In [None]:
def train_func(train_loader):
    """Run one training epoch and return the mean per-batch loss.

    Reads the notebook globals ``model``, ``optimizer``, ``criterion``,
    ``device``, ``use_amp`` and ``accumulation_step``.

    Gradients are accumulated over ``accumulation_step`` batches on both the
    AMP and the full-precision path; the optimizer also steps on the last
    batch so that a trailing partial window is not dropped.
    """
    model.train()
    bar = tqdm(train_loader)
    # With enabled=False the scaler and autocast are pass-throughs, so one code
    # path serves both precisions.
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    n_batches = len(train_loader)
    losses = []

    # Start from clean gradients, e.g. after an interrupted previous run.
    optimizer.zero_grad()
    for batch_idx, (images, targets) in enumerate(bar):

        images, targets = images.to(device), targets.to(device).long()

        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(images, targets)
            loss = criterion(logits, targets)

        # Average rather than sum over the accumulation window, so the
        # gradient scale does not depend on accumulation_step.
        scaler.scale(loss / accumulation_step).backward()

        window_complete = (batch_idx + 1) % accumulation_step == 0
        last_batch = (batch_idx + 1) == n_batches
        if window_complete or last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Report the unscaled loss.
        losses.append(loss.item())
        smooth_loss = np.mean(losses[-30:])

        bar.set_description(f'loss: {loss.item():.5f}, smth: {smooth_loss:.5f}')

    loss_train = np.mean(losses)
    return loss_train


def valid_func(valid_loader):
    """Evaluate ``model`` on ``valid_loader``; return ``(mean loss, accuracy)``.

    Reads the notebook globals ``model``, ``criterion`` and ``device``. The
    targets are passed to the model together with the images, so the logits
    come from the model's label-dependent branch.
    """
    model.eval()
    progress = tqdm(valid_loader)

    batch_losses = []
    batch_preds = []
    batch_targets = []

    with torch.no_grad():
        for images, targets in progress:
            images = images.to(device)
            targets = targets.to(device).long()

            logits = model(images, targets)

            batch_preds.append(torch.argmax(logits, 1).detach().cpu())
            batch_targets.append(targets.detach().cpu())

            batch_loss = criterion(logits, targets)
            batch_losses.append(batch_loss.item())

            progress.set_description(f'loss1: {batch_loss.item():.5f}')

    predicted = torch.cat(batch_preds).cpu().numpy()
    expected = torch.cat(batch_targets).cpu().numpy()
    accuracy = (predicted == expected).mean()

    return np.mean(batch_losses), accuracy

## Start Training

In [None]:
# 512-dimensional embedding; one output class per distinct encoded patient id.
model = MetricLearningModel(512, my_training_set['Patient ID'].nunique())
model.to(device)

criterion = nn.CrossEntropyLoss()
# Start at init_lr / warmup_factor; the warmup scheduler then multiplies the
# learning rate back up by the same factor, so it peaks at init_lr.
optimizer = optim.Adam(model.parameters(), lr=init_lr / warmup_factor)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, n_epochs)
# The multiplier comes from warmup_factor rather than a literal 10, so the
# divisor above and the multiplier here cannot drift apart.
scheduler_warmup = GradualWarmupSchedulerV2(
    optimizer,
    multiplier=warmup_factor,
    total_epoch=warmup_epo,
    after_scheduler=scheduler_cosine,
)

In [None]:
# Train on every row. The evaluation rows (fold == fold_id) are a subset of
# the training rows, so the reported "validation" numbers are not held out.
df_train_this = my_training_set
df_valid_this = my_training_set.loc[my_training_set['fold'] == fold_id]

dataset_train = RANZERDataset(df_train_this, 'train', transform=transforms_train)
dataset_valid = RANZERDataset(df_valid_this, 'valid', transform=transforms_valid)

# Settings shared by both loaders; only shuffling differs.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=False)
train_loader = torch.utils.data.DataLoader(dataset_train, shuffle=True, **loader_kwargs)
valid_loader = torch.utils.data.DataLoader(dataset_valid, shuffle=False, **loader_kwargs)

In [None]:
# Per-epoch history: metric name -> list with one value per epoch.
log = {}
accuracy_max = 0.

# Checkpoints go inside model_dir. The path was previously built as
# f'{model_dir}dense121_...' with no separator, which put the file next to the
# directory under a fused name.
# NOTE(review): update any code that loads the checkpoint from the old path.
os.makedirs(model_dir, exist_ok=True)
checkpoint_path = os.path.join(model_dir, 'dense121_feature_extractor.pth')

for epoch in range(1, n_epochs + 1):
    # The scheduler is stepped with an explicit 0-based epoch index.
    scheduler_warmup.step(epoch - 1)
    loss_train = train_func(train_loader)
    loss_valid, accuracy = valid_func(valid_loader)
    current_lr = optimizer.param_groups[0]["lr"]

    epoch_metrics = (
        ('loss_train', loss_train),
        ('loss_valid', loss_valid),
        ('lr', current_lr),
        ('accuracy', accuracy),
    )
    for metric_name, metric_value in epoch_metrics:
        log.setdefault(metric_name, []).append(metric_value)

    if use_neptune:
        neptune.log_metric('training_loss', loss_train)
        neptune.log_metric('validation_loss', loss_valid)
        neptune.log_metric('accuracy', accuracy)
    content = time.ctime() + ' ' + f'Fold {fold_id}, Epoch {epoch}, lr: {current_lr:.7f}, loss_train: {loss_train:.5f}, loss_valid: {loss_valid:.5f}, accuracy: {accuracy:.6f}.'
    print(content)

    # Keep only the weights of the epoch with the best accuracy so far.
    if accuracy > accuracy_max:
        print(f'accuracy ({accuracy_max:.6f} --> {accuracy:.6f}). Saving model ...')
        torch.save(model.state_dict(), checkpoint_path)
        accuracy_max = accuracy
