In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import os
import faiss
import copy
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image
from torchvision import transforms
import cv2

class ArcMarginProduct(nn.Module):
    r"""ArcFace head: cosine logits with an additive angular margin on the target class.

        Args:
            in_features: size of each input sample
            out_features: size of each output sample (one weight row per class)
            s: scale multiplied into the returned logits
            m: angular margin; the target logit becomes cos(theta + m)
            easy_margin: if True the margin is only applied where cos(theta) > 0
            ls_eps: label-smoothing factor applied to the one-hot targets
            device: stored on ``self.device`` for backward compatibility; the
                forward pass allocates on the device of its input instead
        """
    def __init__(self, in_features, out_features, s=30.0, 
                 m=0.50, easy_margin=False, ls_eps=0.0, device=torch.device('cuda')):
        super(ArcMarginProduct, self).__init__()
        self.device = device
        self.in_features = in_features
        self.out_features = out_features
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.weight = nn.Parameter(torch.FloatTensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Threshold / penalty used by the non-easy-margin branch below.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label):
        """Return scaled logits of shape ``cosine.size()`` with the margin applied at ``label``.

        Args:
            input: feature batch; L2-normalised here before the linear map.
            label: class indices, one per row of ``input``.
        """
        # --------------------------- cos(theta) & phi(theta) ---------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Rounding can push cos^2 marginally above 1; clamp so sqrt never
        # receives a negative value (which would produce NaN).
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=0.0))
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # --------------------------- convert label to one-hot ---------------------
        # Allocate next to the logits so the module works on whatever device
        # the input lives on, not only on the device given at construction.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / self.out_features
        # Target positions take phi, every other position keeps the plain cosine.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return output

class DenseCrossEntropy(nn.Module):
    """Cross-entropy against dense targets (one weight per class instead of an index)."""

    def forward(self, x, target):
        """Return the batch mean of ``-sum(target * log_softmax(x))`` over the last axis."""
        log_probs = F.log_softmax(x.float(), dim=-1)
        per_sample = (-log_probs * target.float()).sum(-1)
        return per_sample.mean()

class ArcMarginProduct_subcenter(nn.Module):
    """Cosine head holding ``k`` weight rows per class; the best-matching row is kept."""

    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.k = k
        self.out_features = out_features
        self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Fill the weight uniformly in ``[-1/sqrt(in_features), 1/sqrt(in_features)]``."""
        bound = 1. / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def forward(self, features):
        """Return, per class, the maximum cosine similarity over its ``k`` rows."""
        similarities = F.linear(F.normalize(features), F.normalize(self.weight))
        # Rows are laid out so that reshaping groups k consecutive rows per class.
        grouped = similarities.view(-1, self.out_features, self.k)
        return grouped.max(dim=2).values

class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    """ArcFace loss in which every class has its own angular margin.

    Args:
        margins: per-class margins, indexed by class id (any array-like;
            converted with ``np.asarray`` on each call).
        s: scale multiplied into the logits before the cross-entropy.
    """
    def __init__(self, margins, s=30.0):
        super().__init__()
        self.crit = DenseCrossEntropy()
        self.s = s
        self.margins = margins

    def forward(self, logits, labels, out_dim):
        """Return the mean dense cross-entropy of the margin-adjusted, scaled logits.

        Args:
            logits: cosine similarities, one row per sample.
            labels: integer class ids, one per row of ``logits``.
            out_dim: number of classes used for the one-hot encoding.
        """
        # Margin constants follow the logits' device rather than assuming CUDA.
        device = logits.device
        ms = np.asarray(self.margins)[labels.detach().cpu().numpy()]

        def as_column(values):
            # One value per sample, broadcast across the class axis.
            return torch.from_numpy(np.asarray(values)).float().to(device).view(-1, 1)

        cos_m = as_column(np.cos(ms))
        sin_m = as_column(np.sin(ms))
        th = as_column(np.cos(math.pi - ms))
        mm = as_column(np.sin(math.pi - ms) * ms)

        labels = F.one_hot(labels, out_dim).float()
        cosine = logits.float()
        # Clamp guards against cos^2 exceeding 1 through rounding (sqrt -> NaN).
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(min=0.0))
        phi = cosine * cos_m - sine * sin_m
        phi = torch.where(cosine > th, phi, cosine - mm)
        # Only the target column receives the margin-adjusted value.
        output = (labels * phi) + ((1.0 - labels) * cosine)
        output *= self.s
        loss = self.crit(output, labels)
        return loss

def set_seed(seed):
    '''Sets the seed of the entire notebook so results are the same every time we run.
    This is for REPRODUCIBILITY.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), switches cuDNN
    to deterministic mode and exports ``PYTHONHASHSEED``.
    '''
    # Function-scope import: the surrounding import block does not bring in
    # `random`, yet anything drawing from the stdlib RNG must be seeded too.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # When running on the CuDNN backend, two further options must be set
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Set a fixed value for the hash seed
    os.environ['PYTHONHASHSEED'] = str(seed)


def get_similiarity(embeddings, k):
    """Search every embedding against the whole set with a GPU flat-L2 faiss index.

    Returns the ``(scores, indices)`` pair produced by ``index.search(embeddings, k)``.
    """
    print('Processing indices...')

    dim = embeddings.shape[1]
    gpu_resources = faiss.StandardGpuResources()
    # Build the flat L2 index on the CPU side, then move it to GPU 0.
    gpu_index = faiss.index_cpu_to_gpu(gpu_resources, 0, faiss.IndexFlatL2(dim))
    gpu_index.add(embeddings)

    scores, indices = gpu_index.search(embeddings, k)
    print('Finished processing indices')

    return scores, indices

def map_per_image(label, predictions): 
    """Return 1/rank of ``label`` within the first five predictions, or 0.0 if absent."""
    try:
        position = predictions[:5].index(label)
    except ValueError:
        return 0.0
    return 1 / (position + 1)

def map_per_set(labels, predictions):
    """Average ``map_per_image`` over paired labels and prediction lists."""
    per_image_scores = [
        map_per_image(true_label, ranked)
        for true_label, ranked in zip(labels, predictions)
    ]
    return np.mean(per_image_scores)

class AverageMeter(object):
    """Tracks the latest value and a running weighted average.

    With ``window_size`` set, the statistics are cleared once ``count`` has
    reached that size, just before the next update is recorded.
    """

    def __init__(self, window_size=None):
        self.window_size = window_size
        self.reset()

    def reset(self):
        """Zero every accumulated statistic."""
        self.length = 0
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        window_exhausted = self.window_size and (self.count >= self.window_size)
        if window_exhausted:
            self.reset()
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

def get_lr_groups(param_groups):
    """Return the distinct ``'lr'`` values of ``param_groups``, ascending, formatted with ``'{:2e}'``."""
    distinct_lrs = {group['lr'] for group in param_groups}
    return ["{:2e}".format(lr) for lr in sorted(distinct_lrs)]

def convert_indices_to_labels(indices, labels):
    """Return a deep copy of ``indices`` with every entry replaced by ``labels[entry]``.

    The input ``indices`` is left untouched.
    """
    mapped = copy.deepcopy(indices)
    for row in mapped:
        for position, source_index in enumerate(row):
            row[position] = labels[source_index]
    return mapped

class Multisample_Dropout(nn.Module):
    """Applies ``module`` to five differently dropped-out copies of the input and averages them."""

    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(.1)
        # Dropout rates 0.1, 0.2, ..., 0.5 - one per averaged branch.
        branches = [nn.Dropout((i+1)*.1) for i in range(5)]
        self.dropouts = nn.ModuleList(branches)

    def forward(self, x, module):
        """Return the mean of ``module(dropout_i(dropout(x)))`` over all branches."""
        x = self.dropout(x)
        branch_outputs = []
        for branch_dropout in self.dropouts:
            branch_outputs.append(module(branch_dropout(x)))
        return torch.stack(branch_outputs, dim=0).mean(dim=0)

def transforms_auto_augment(image_path, image_size):
    """Load ``image_path`` as RGB and apply the ImageNet AutoAugment policy, then ``PILToTensor``.

    ``image_size`` is accepted for signature parity with the other transform
    helpers but is not used here.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    pipeline = transforms.Compose([
        transforms.AutoAugment(transforms.AutoAugmentPolicy.IMAGENET),
        transforms.PILToTensor(),
    ])
    return pipeline(rgb_image)

def transforms_cutout(image_path, image_size):
    """Read ``image_path`` with OpenCV and apply the flip / shift-scale-rotate / cutout pipeline.

    Returns the ``'image'`` entry of the albumentations result (after ``ToTensorV2``).
    """
    bgr = cv2.imread(image_path)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.uint8)

    # The cutout hole is at most 40% of the resized side in each direction.
    max_hole = int(image_size * 0.4)
    steps = [
        A.HorizontalFlip(p=0.5),
        A.ImageCompression(quality_lower=99, quality_upper=100),
        A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=10, border_mode=0, p=0.7),
        A.Resize(image_size, image_size),
        A.Cutout(max_h_size=max_hole, max_w_size=max_hole, num_holes=1, p=0.5),
        ToTensorV2(),
    ]
    augmented = A.Compose(steps)(image=rgb)
    return augmented['image']

def transforms_happy_whale(image_path, image_size):
    """Read ``image_path`` with OpenCV and apply the rotate / resize / colour augmentation pipeline.

    Returns the ``'image'`` entry of the albumentations result (after ``ToTensorV2``).
    """
    bgr = cv2.imread(image_path)
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.uint8)

    # At most one of these appearance changes is applied, half of the time.
    appearance = A.OneOf(
        [A.Sharpen(p=0.3), A.ToGray(p=0.3), A.CLAHE(p=0.3)],
        p=0.5,
    )
    steps = [
        A.ShiftScaleRotate(rotate_limit=15, scale_limit=0.1, border_mode=cv2.BORDER_REFLECT, p=0.5),
        A.Resize(image_size, image_size),
        appearance,
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        ToTensorV2(),
    ]
    augmented = A.Compose(steps)(image=rgb)
    return augmented['image']

def transforms_valid(image_path, image_size):
    """Load ``image_path`` as RGB and convert it with ``PILToTensor`` (no augmentation).

    ``image_size`` is accepted for signature parity but is not used here.
    """
    rgb_image = Image.open(image_path).convert('RGB')
    to_tensor = transforms.Compose([transforms.PILToTensor()])
    return to_tensor(rgb_image)

In [None]:
import pandas as pd
import glob
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import timm
import math
from transformers import (get_linear_schedule_with_warmup, 
                          get_cosine_schedule_with_warmup, 
                          get_cosine_with_hard_restarts_schedule_with_warmup,
                          get_constant_schedule_with_warmup)
from tqdm import tqdm
import faiss
import random
import gc
import transformers
from transformers import CLIPProcessor, CLIPVisionModel,  CLIPVisionConfig
from PIL import Image
from torchvision import transforms
from pytorch_metric_learning import losses
import open_clip
import sys

In [None]:
import json
from tqdm import tqdm

    
# df = pd.read_csv(df_path)
class CFG:
    """Notebook-wide configuration.

    Besides plain constants, the class body reads the training CSV and filters
    it when the class is created. It relies on three names defined outside this
    class: ``df_path``, ``datasets2exclude`` and ``train_df``.
    """
    # --- backbone (passed to open_clip.create_model_and_transforms) ---
    model_name = 'ViT-B-16' 
    model_data = 'openai'
    samples_per_class = 50
    min_samples = 4
    image_size = 336 
    seed = 5
    workers = 8
    train_batch_size = 128
    valid_batch_size = 128
    emb_size = 512
    # Backbone learning rate by layer: a parameter of layer L takes the value
    # of the first key with L < int(key) (see Model.get_parameter_section).
    vit_bb_lr = {'8': 1.25e-6, '16': 2.5e-6, '20': 5e-6, '24': 10e-6} 
    vit_bb_wd = 1e-3
    # Learning rate / weight decay of the head.
    hd_lr = 3e-4
    hd_wd = 1e-5
    autocast = True
    n_warmup_steps = 30
    n_epochs = 20
    device = torch.device('cuda')
    # ArcFace scale, margin range width and minimum margin (see train()).
    s=10.
    m=.45
    m_min=.05
    # Number of batches accumulated per optimizer step.
    acc_steps = 4
    global_step = 0
    train_csv_path = df_path
    data_path = '../'
    df = pd.read_csv(train_csv_path)
    df = df[df['split']=='train'].reset_index(drop=True)
    if datasets2exclude:
        datasets2use = [i for i in df['dataset'].unique() if i not in datasets2exclude]
        # Use isin() rather than a comprehension over `df`: names bound in a
        # class body are not visible inside a comprehension's body, so the old
        # mask was built from a *global* `df` (or raised NameError) instead of
        # the filtered frame above. copy() lets the 'class' column be
        # reassigned without writing into a view of the parent frame.
        df = df[df['dataset'].isin(datasets2use)].copy()
        # Re-number the remaining classes consecutively from 0.
        new_class = {i:idx for idx, i in enumerate(df['class'].unique())}
        df['class'] = df['class'].map(new_class)

    df = df.reset_index(drop=True)
    # NOTE(review): n_classes is taken from `train_df`, not from the `df`
    # filtered above - presumably `train_df` holds one entry per class; confirm.
    n_classes = len(train_df)
    


In [None]:
# Seed the random number generators with the configured seed.
set_seed(CFG.seed)

In [None]:
# Inspection only: the result is not stored, the notebook just displays it.
open_clip.list_pretrained()

In [None]:
# Create the backbone named in CFG; only the first two returned objects are
# kept (the model and one transform pipeline), the third is discarded.
vit_backbone, model_transforms, _ = open_clip.create_model_and_transforms(CFG.model_name, pretrained=CFG.model_data)

In [None]:
# Input side length, read from the first step of the model's own transform
# pipeline. NOTE(review): assumes that step's `.size` is indexable - confirm
# for the installed open_clip version.
image_size = model_transforms.transforms[0].size[0]

In [None]:
# Normalisation statistics, read from the last step of the model's transform
# pipeline (presumably its Normalize step - confirm); reused by the
# albumentations pipelines below.
mean, std = model_transforms.transforms[-1].mean, model_transforms.transforms[-1].std

In [None]:
# Training-time augmentation. The pipeline ends with normalisation and has no
# tensor-conversion step: Products_dataset.__getitem__ converts the resulting
# array itself with torch.from_numpy(...).permute(2, 0, 1).
train_aug = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.ImageCompression(quality_lower=50, quality_upper=100),
        A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=10, border_mode=0, p=0.7),
        A.Resize(image_size, image_size),
        # Cutout hole of at most 40% of the image side in each direction.
        A.Cutout(max_h_size=int(image_size * 0.4), max_w_size=int(image_size * 0.4), num_holes=1, p=0.5),
        # With probability 0.65, apply exactly one of these colour perturbations.
        A.OneOf([
            A.ChannelShuffle(),
            A.ChannelDropout(),
            A.ColorJitter(),
#             A.ISONoise(),
            A.ToGray(),
        ], p=0.65),
        # Same mean/std as the pretrained model's own preprocessing.
        A.Normalize(mean=mean, std=std, p=1), 
    ])

# Validation-time preprocessing: deterministic resize + normalisation only.
val_aug = A.Compose([
        A.Resize(image_size, image_size),
        A.Normalize(mean=mean, std=std, p=1)
    ])

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import torch
import numpy as np

class Products_dataset(Dataset):
    """Image dataset yielding ``(image, label)`` pairs for one data split.

    Args:
        df: source table; only the part whose ``'split'`` equals ``mode`` is kept.
        mode: name of the split to keep.
        transform: optional callable invoked as ``transform(image=array)['image']``.
        fallback_size: side length of the all-zero placeholder image returned
            when a sample cannot be loaded. Should match the size the transform
            produces, otherwise batches cannot be collated.
    """
    def __init__(self, 
                 df, 
                 mode='train', 
                 transform=None,
                 fallback_size=224
                 ):
        self.df = df[df['split']==mode].reset_index(drop=True)
        # NOTE(review): this iterates `self.df` and reads an 'idxs' collection
        # from every entry, using the iterated key as the class id - confirm
        # the schema of `df`; a plain row-per-image DataFrame does not fit it.
        self.idx2class = {k:int(i) for i in self.df for k in self.df[i]['idxs']}
        self.idxs = list(self.idx2class.keys())
        self.augs = transform
        self.mode = mode
        self.fallback_size = fallback_size

    def __len__(self):
        return len(self.idx2class)
    
    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is a channel-first tensor.

        Loading is best-effort: any failure is printed and replaced by a zero
        image so that one bad file does not abort the run.
        """
        sample_id = self.idxs[idx]
        path = f'{sample_id}.png'
        label = self.idx2class[sample_id]

        try:
            # Context manager closes the file handle once the pixels are copied.
            with Image.open(path) as handle:
                img = np.array(handle.convert('RGB'))

            if self.augs:
                img = self.augs(image=img)['image']
            # Convert HWC -> CHW whether or not a transform was supplied, so
            # the return type matches the placeholder used on failure.
            img = torch.from_numpy(img).permute(2, 0, 1)
        except Exception as e:
            print(e, path)
            img = torch.zeros((3, self.fallback_size, self.fallback_size))

        return img, label

In [None]:
class Head(nn.Module):
    """Projection to the embedding space followed by a sub-center cosine classifier."""

    def __init__(self, hidden_size, emb_size, n_classes):
        super().__init__()

        self.emb = nn.Linear(hidden_size, emb_size, bias=False)
        self.arc = ArcMarginProduct_subcenter(emb_size, n_classes)
        self.dropout = Multisample_Dropout()

    def forward(self, x):
        """Return ``(class cosines, L2-normalised embeddings)`` for backbone features ``x``."""
        # The dropout wrapper runs the projection on several dropped-out copies.
        projected = self.dropout(x, self.emb)
        class_cosines = self.arc(projected)
        return class_cosines, F.normalize(projected)
    
class Model(nn.Module):
    """Visual tower of an open_clip model followed by the ArcFace ``Head``."""

    def __init__(self, vit_backbone):
        """Wrap ``vit_backbone.visual`` and size the head from its output width.

        The backbone is probed with a zero image on the CPU, so it must be on
        the CPU when this constructor runs.
        """
        super(Model, self).__init__()
        
        vit_backbone = vit_backbone.visual
        self.img_size = vit_backbone.image_size
        if isinstance(self.img_size, tuple):
            self.img_size = self.img_size[1]
        # Dummy forward pass to discover the feature width; no_grad avoids
        # building an autograd graph for a throw-away tensor.
        with torch.no_grad():
            hidden_size = vit_backbone(torch.zeros((1, 3, self.img_size, self.img_size))).shape[1]
        self.vit_backbone = vit_backbone
        self.head = Head(hidden_size, CFG.emb_size, CFG.n_classes)

    def forward(self, x):
        """Return whatever the head returns for the backbone features of ``x``."""
        x = self.vit_backbone(x)
        return self.head(x)

    def get_parameters(self):
        """Return optimizer parameter groups: backbone settings first, then head settings."""
        parameter_settings = [] 
        parameter_settings.extend(self.get_parameter_section(
            list(self.vit_backbone.named_parameters()), lr=CFG.vit_bb_lr, wd=CFG.vit_bb_wd))
        parameter_settings.extend(self.get_parameter_section(
            list(self.head.named_parameters()), lr=CFG.hd_lr, wd=CFG.hd_wd))
        return parameter_settings

    @staticmethod
    def _value_for_layer(setting, layer_no):
        """Resolve a scalar-or-dict hyperparameter for ``layer_no``.

        For a dict, the value of the first key (in insertion order) with
        ``layer_no < int(key)`` is used; if no key qualifies, the last entry
        is used.
        """
        if not isinstance(setting, dict):
            return setting
        chosen = None
        for threshold, value in setting.items():
            chosen = value
            if layer_no < int(threshold):
                break
        return chosen

    def get_parameter_section(self, parameters, lr=None, wd=None): 
        """Build one optimizer group per parameter.

        Args:
            parameters: iterable of ``(name, parameter)`` pairs.
            lr: a learning rate, or a dict mapping layer thresholds to rates.
            wd: a weight decay, or a dict mapping layer thresholds to decays.

        Returns:
            A list of ``{"params", "lr", "weight_decay"}`` dicts. Parameters
            whose name contains ``'bias'`` get a weight decay of 0.0.
        """
        parameter_settings = []

        # The layer number is "sticky": a name without a numeric component
        # inherits the number of the most recent numbered parameter, and
        # names seen before any number count as layer 0.
        layer_no = None
        for n, p in parameters:
            for split in n.split('.'):
                if split.isnumeric():
                    layer_no = int(split)

            if layer_no is None:
                layer_no = 0

            temp_lr = self._value_for_layer(lr, layer_no)
            temp_wd = self._value_for_layer(wd, layer_no)

            # Biases are exempt from weight decay.
            weight_decay = 0.0 if 'bias' in n else temp_wd

            parameter_settings.append(
                {"params" : p, "lr" : temp_lr, "weight_decay" : weight_decay})

        return parameter_settings


In [None]:
# Training split: augmented samples, reshuffled every epoch.
# `train_df` / `valid_df` are expected to be defined elsewhere in the notebook.
train_ds = Products_dataset(
        train_df,
        mode='train', 
        transform=train_aug,
    )
train_loader = DataLoader(train_ds, num_workers=8, batch_size=CFG.train_batch_size, shuffle=True)


# Validation split: deterministic preprocessing, fixed order.
valid_ds = Products_dataset(
        valid_df,
        mode='valid', 
        transform=val_aug,
    )
valid_loader = DataLoader(valid_ds, num_workers=8, batch_size=CFG.valid_batch_size, shuffle=False)

In [None]:
def ArcFace_criterion(logits_m, target, margins):
    """Adaptive-margin ArcFace loss of ``logits_m`` against ``target``.

    Scale and class count are taken from ``CFG.s`` and ``CFG.n_classes``.
    """
    loss_fn = ArcFaceLossAdaptiveMargin(margins=margins, s=CFG.s)
    return loss_fn(logits_m, target, CFG.n_classes)

In [None]:
def train(model, train_loader, optimizer, scaler, scheduler, epoch, pbar_draw=False):
    """Run one training epoch with gradient accumulation and AMP loss scaling.

    The optimizer, scaler and scheduler advance once every ``CFG.acc_steps``
    batches (and on the final batch); ``CFG.global_step`` counts those
    optimizer steps. Nothing is returned; progress is shown on the tqdm bar
    when ``pbar_draw`` is true.
    """
    model.train()
    running_loss = AverageMeter()

    # Per-class margins from the global `value_counts`: 1/count**0.25 is
    # min-max scaled, stretched by CFG.m and shifted by CFG.m_min, so the
    # class with the smallest count receives the largest margin.
    rarity = np.sqrt(1 / np.sqrt(value_counts))
    margins = (rarity - rarity.min()) / (rarity.max() - rarity.min()) * CFG.m + CFG.m_min

    progress = tqdm(train_loader, disable=not pbar_draw)
    for step, batch in enumerate(progress, start=1):
        images = batch[0].to(CFG.device, dtype=torch.float)
        labels = batch[1].to(CFG.device)

        with torch.cuda.amp.autocast(enabled=CFG.autocast):
            outputs, _ = model(images)

        # The loss is computed outside the autocast region.
        loss = ArcFace_criterion(outputs, labels, margins)
        running_loss.update(loss.item(), labels.size(0))
        # Scale down so that the accumulated gradient is an average.
        scaler.scale(loss / CFG.acc_steps).backward()

        if step % CFG.acc_steps == 0 or step == len(progress):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            scheduler.step()
            CFG.global_step += 1

        progress.set_postfix(
            loss=running_loss.avg,
            epoch=epoch,
            lrs=get_lr_groups(optimizer.param_groups),
            step=CFG.global_step,
        )

In [None]:
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score
from tqdm.notebook import tqdm
import numpy as np
import faiss

def map_per_image(label, predictions, k=5):
    """Score one image by the reciprocal rank of its true label.

    Parameters
    ----------
    label : string
            The true label of the image
    predictions : list
            Ranked predictions, best first; only the first ``k`` are considered
    k : int
            Number of leading predictions taken into account

    Returns
    -------
    score : double
            ``1 / rank`` of ``label`` within the first ``k`` predictions, 0.0 if absent
    """
    try:
        position = predictions[:k].index(label)
    except ValueError:
        return 0.0
    return 1 / (position + 1)

def map_per_set(labels, predictions, k=5):
    """Average the per-image score over a set of images.

    Parameters
    ----------
    labels : list
             The true labels, one per image
    predictions : list of list
             Ranked predictions per image, best first
    k : int
             Number of leading predictions considered per image

    Returns
    -------
    score : double
    """
    per_image_scores = [
        map_per_image(true_label, ranked, k=k)
        for true_label, ranked in zip(labels, predictions)
    ]
    return np.mean(per_image_scores)


def validate(model, loader, embedding_size=512):
    """Embed every sample of ``loader`` and score nearest-neighbour retrieval.

    Each embedding is searched against all embeddings with a flat L2 faiss
    index; the first hit is dropped as the query itself and the labels of the
    next five neighbours are used as ranked predictions.

    Args:
        model: module whose forward returns a sequence with the embeddings at index 1.
        loader: iterable of ``(images, labels)`` batches supporting ``len()``.
        embedding_size: dimensionality used to build the faiss index.

    Returns:
        ``(acc_1, acc_5)`` - ``map_per_set`` with ``k=1`` and ``k=5``.
    """
    model.eval()
    # Run on whatever device the model lives on; fall back to the previous
    # hard-coded choice only if the model exposes no parameters.
    try:
        device = next(model.parameters()).device
    except StopIteration:
        device = torch.device('cuda')

    embedding_chunks, label_chunks = [], []
    with torch.no_grad():
        for images, label in tqdm(loader, total=len(loader)):
            embeddings = model(images.to(device))[1]
            # Keep batches as arrays; going through Python lists is slow.
            embedding_chunks.append(embeddings.float().cpu().numpy())
            if torch.is_tensor(label):
                label_chunks.append(label.cpu().numpy())
            else:
                label_chunks.append(np.asarray(label))
    # faiss expects a contiguous float32 matrix.
    bd = np.ascontiguousarray(np.concatenate(embedding_chunks), dtype=np.float32)
    labels = np.concatenate(label_chunks)

    index = faiss.IndexFlatL2(embedding_size)
    index.add(bd)
    # Six neighbours: column 0 is assumed to be the query itself and dropped.
    # NOTE(review): exact duplicate embeddings could put another sample first.
    D, I = index.search(bd, 6)
    global_preds = labels[I[:, 1:]]
    global_labels = labels
    acc_1 = map_per_set(global_labels.tolist(), global_preds.tolist(), k=1) 
    acc_5 = map_per_set(global_labels.tolist(), global_preds.tolist(), k=5)
    print(f'Metric mAP@1 = {acc_1}, mAP@5 = {acc_5}')
    return acc_1, acc_5


In [None]:
# Per-entry 'cnt' values, consumed by train() to derive the class margins.
# NOTE(review): expects a global `df` whose entries expose 'cnt' - confirm
# which object `df` refers to at this point in the notebook.
value_counts = [df[i]['cnt'] for i in df]

In [None]:
# vit_backbone, model_transforms, _ = open_clip.create_model_and_transforms(CFG.model_name, pretrained=CFG.model_data)
# Build on the CPU (Model probes the backbone with a CPU tensor), then move
# the assembled model to the training device.
model = Model(vit_backbone.cpu()).to(CFG.device)

In [None]:
# Retrieval score of the model before any training step has been run.
validate(model, valid_loader, embedding_size=512)

In [None]:
# Per-parameter learning rates and weight decays come from Model.get_parameters().
optimizer = torch.optim.AdamW(model.get_parameters())
scaler = torch.cuda.amp.GradScaler(enabled=CFG.autocast)
# The scheduler advances once per optimizer step, i.e. once per CFG.acc_steps
# batches, so the schedule length is counted in optimizer steps.
steps_per_epoch = math.ceil(len(train_loader) / CFG.acc_steps)
num_training_steps = math.ceil(CFG.n_epochs * steps_per_epoch)
scheduler = get_cosine_schedule_with_warmup(optimizer,
                                            num_training_steps=num_training_steps,
                                            num_warmup_steps=CFG.n_warmup_steps)   
CFG.global_step = 0                   
for epoch in tqdm(range(math.ceil(CFG.n_epochs))):

    train(model, train_loader, optimizer, scaler, scheduler, epoch, pbar_draw=True)
    # validate() returns the pair (mAP@1, mAP@5).
    score = validate(model, valid_loader, embedding_size=512)
    print(f'Epoch = {epoch}, score:', score)
    # Release cached memory between epochs.
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
import optuna
from optuna.trial import TrialState

def objective(trial):
    """Optuna objective: train with sampled hyperparameters and return the validation score.

    Samples ``s``, the head learning rate, the number of epochs and the warmup
    steps, writes them into ``CFG`` (the mutation persists after the trial),
    trains from a fresh copy of the backbone, and reports the score after
    every epoch so that the study can prune.

    Returns:
        The mAP@5 obtained after the last completed epoch, as a float.

    Raises:
        optuna.exceptions.TrialPruned: when the pruner stops the trial early.
    """
    # Function-scope import so the block does not depend on other cells.
    import copy

    CFG.s = trial.suggest_int("s", 5, 90, step=5)
    CFG.hd_lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    # Model keeps a reference to `vit_backbone.visual`, so handing it the
    # shared object would let every trial continue from the weights left by
    # the previous one. Each trial therefore trains its own copy.
    # NOTE(review): the copy reflects the backbone's *current* weights; if the
    # shared backbone was already fine-tuned earlier in the notebook, reload
    # the pretrained weights before running the study.
    model = Model(copy.deepcopy(vit_backbone.cpu())).to(CFG.device)
    
    CFG.n_epochs = trial.suggest_int("n_epochs", 1, 11, step=2)
    CFG.n_warmup_steps = trial.suggest_int('n_warmup_steps', 1, 5)
    
    optimizer = torch.optim.AdamW(model.get_parameters())
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.autocast)
    steps_per_epoch = math.ceil(len(train_loader) / CFG.acc_steps)
    num_training_steps = math.ceil(CFG.n_epochs * steps_per_epoch)
    scheduler = get_cosine_schedule_with_warmup(optimizer,
                                                num_training_steps=num_training_steps,
                                                num_warmup_steps=CFG.n_warmup_steps)   
    CFG.global_step = 0                   
    for epoch in range(math.ceil(CFG.n_epochs)):

        train(model, train_loader, optimizer, scaler, scheduler, epoch)
        # validate() returns (mAP@1, mAP@5); a single-objective study needs
        # one float both for report() and as the return value.
        # NOTE(review): mAP@5 is used as the objective - switch to acc_1 if
        # mAP@1 is the metric that matters.
        acc_1, acc_5 = validate(model, valid_loader)
        score = float(acc_5)
        gc.collect()
        torch.cuda.empty_cache()

        trial.report(score, epoch)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    return score


if __name__ == "__main__":
    # Maximise the value returned by objective() over 20 trials, no time limit.
    study = optuna.create_study(direction="maximize")
    study.optimize(objective, n_trials=20, timeout=None)

    # Split the finished trials by final state for the summary below.
    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Number of pruned trials: ", len(pruned_trials))
    print("  Number of complete trials: ", len(complete_trials))

    print("Best trial:")
    trial = study.best_trial

    print("  Value: ", trial.value)

    # Hyperparameters sampled for the best trial.
    print("  Params: ")
    for key, value in trial.params.items():
        print("    {}: {}".format(key, value))