In [None]:
# Mount Google Drive into the Colab filesystem so the dataset and output
# paths under /content/drive used by later cells are reachable.
from google.colab import drive  
drive.mount('/content/drive')   

In [None]:
# Install/upgrade the third-party packages imported by the next cell.
# NOTE(review): `-U` with no version pins pulls whatever release is current,
# so results may differ between runs — consider pinning exact versions.
!pip install timm pytorch-lightning wandb -qqq -q -U 
# Build albumentations (and qudida) from source instead of using the wheels.
!pip install -U albumentations --no-binary qudida,albumentations

In [3]:
import os
import gc
import warnings
from typing import Dict, List, Tuple, Optional
import numpy as np
import pandas as pd
import math
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from sklearn.model_selection import StratifiedKFold
from sklearn import preprocessing
from torch.utils.data import ConcatDataset, DataLoader, Dataset
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image

In [None]:
# Experiment name: used as the logger sub-directory and the wandb group/run name.
exp_name = "beluga"

# Metadata table; a later cell reads its `whale_id` column to build integer labels.
df = pd.read_csv("/content/drive/MyDrive/beluga/metadata.csv")  #
# Root directory for logs and checkpoints written by train().
out_dir = "/content/drive/MyDrive/beluga"                       #

# Extract the image archive into the current working directory.
# NOTE(review): presumably this creates /content/images — confirm the archive layout.
!unzip "/content/drive/MyDrive/beluga/images.zip"               #
# Directory the dataset reads "<image_id>.jpg" files from.
train_imgs = "/content/images"                                  #

In [5]:
# Global experiment configuration, read directly (as a module-level dict) by the
# dataset, data module, model and train() defined in the cells below.
# Exactly one backbone preset (checkpoint_name / model_name / batch_size /
# image_size / max_epochs) should be active; the alternatives are kept commented out.
cfg = { 
    # Active preset.
    # checkpoint_name: file name used for the "last" checkpoint saved by train().
    # model_name: architecture name passed to timm.create_model.
    # image_size: target size every image is resized (and, with augmentation, cropped) to.
    "checkpoint_name": "tf_efficientnet_b2_ns_380",
    "model_name": "tf_efficientnet_b2_ns",
    "batch_size": 64,
    "image_size": (380,380),
    "max_epochs": 20,  

    # "checkpoint_name": "tf_efficientnet_b3_ns_380",
    # "model_name": "tf_efficientnet_b3_ns",
    # "batch_size": 32,
    # "image_size": (380,380),
    # "max_epochs": 30,

    # "checkpoint_name": "tf_efficientnet_b4_ns_380",
    # "model_name": "tf_efficientnet_b4_ns",
    # "batch_size": 32,
    # "image_size": (380,380),
    # "max_epochs": 25,

    # "checkpoint_name": "tf_efficientnet_b4_ns_456",
    # "model_name": "tf_efficientnet_b4_ns",
    # "batch_size": 24,
    # "image_size": (456,456),
    # "max_epochs": 20,

    # "checkpoint_name": "tf_efficientnet_b5_ns_456",
    # "model_name": "tf_efficientnet_b5_ns",
    # "batch_size": 16,
    # "image_size": (456,456),
    # "max_epochs": 22,

    # "checkpoint_name": "tf_efficientnet_b5_ns_528",
    # "model_name": "tf_efficientnet_b5_ns",
    # "batch_size": 16,
    # "image_size": (528,528),
    # "max_epochs": 20,

    # "checkpoint_name": "tf_efficientnetv2_m_in21ft1k_380",
    # "model_name": "tf_efficientnetv2_m_in21ft1k",
    # "batch_size": 32,
    # "image_size": (380,380),
    # "max_epochs": 20,


    # Optimizer / schedule.
    # lr_backbone and lr_head are the base learning rates of the two AdamW
    # parameter groups (backbone + pooling vs. neck + ArcFace head).
    # lr_decay_scale is the floor of the LR multiplier produced by WarmupCosineLambda.
    "lr_backbone": 1.6e-3,  
    "lr_head": 1.6e-2,      
    "lr_decay_scale": 1.0e-2, 
    # Backbone feature stages that are pooled and concatenated into the embedding.
    "out_indices": (3,4),
    # -1 trains on all rows with no validation; otherwise the number of
    # StratifiedKFold splits (the driver cell then trains fold 0).
    "n_splits": -1,  # -1, 5,
    # Number of identities; asserted against the label encoder in the driver cell.
    "num_classes": 788,
    # Fraction of max_epochs spent in warmup.
    "warmup_steps_ratio": 0.2,
    # -1 uses every row; otherwise only the first n_data rows are kept.
    "n_data": -1,
    # ArcFace settings: s_id is the logit scale, and the per-class margin is
    # count ** margin_power_id * margin_coef_id + margin_cons_id.
    "s_id": 21.0,               
    "margin_coef_id": 0.5,      
    "margin_power_id": -0.125,
    "margin_cons_id": 0.05,
    # Number of sub-centers per class in the ArcFace head.
    "n_center_id": 2,

    # DataLoader worker processes, and whether to log to Weights & Biases.
    "num_workers" : 2,
    "wandb" : False,
}

In [6]:
# Authenticate with Weights & Biases only when wandb logging is enabled in cfg.
if cfg["wandb"]:
    wandb.login()

In [7]:
class BelugaDataset(Dataset):
    """Image dataset yielding normalized tensors and integer identity labels.

    Args:
        df: DataFrame with an ``image_id`` column and, optionally, an
            ``individual_id`` column of integer labels. When the label column
            is absent every sample gets the label ``-1``.
        image_dir: Directory containing ``<image_id>.jpg`` files.
        data_aug: When true, random augmentations are applied before
            normalization; otherwise images are only resized and normalized.

    Each item is a dict with keys ``"original_index"`` (the row's index label
    in ``df``), ``"image"`` (the transformed image tensor) and ``"label"``.
    """

    def __init__(self, df, image_dir, data_aug):
        super().__init__()
        self.index = df.index
        self.x_paths = np.array(df.image_id)
        self.ids = np.array(df.individual_id, dtype=int) if hasattr(df, "individual_id") else np.full(len(df), -1)
        self.image_dir = image_dir
        self.df = df
        self.data_aug = data_aug
        augments = []
        if data_aug:
            augments = [
                A.Affine(rotate=(-15, 15),
                         translate_percent=(0.0, 0.25),
                         shear=(-3, 3),
                         p=0.5),
                A.RandomResizedCrop(cfg["image_size"][0], cfg["image_size"][1],
                                    scale=(0.9, 1.0),
                                    ratio=(0.75, 1.3333333333)),
                A.ToGray(p=0.1),
                A.GaussianBlur(blur_limit=(3, 7), p=0.05),
                A.GaussNoise(p=0.05),
                A.RandomGridShuffle(grid=(2, 2), p=0.3),
                A.Posterize(p=0.2),
                A.RandomBrightnessContrast(p=0.5),
                A.CoarseDropout(p=0.05),
                A.RandomSnow(p=0.1),
                A.RandomRain(p=0.05),
                A.HorizontalFlip(p=0.5),
            ]
        # Normalization and tensor conversion are applied in both modes.
        augments.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
        augments.append(ToTensorV2())
        self.transform = A.Compose(augments)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, i: int):
        # image_size is treated as (height, width), matching RandomResizedCrop
        # above; PIL's resize expects (width, height).
        height, width = cfg["image_size"][0], cfg["image_size"][1]
        # Read from the directory this dataset was constructed with rather than
        # a notebook global, and close the file handle once it has been decoded.
        with Image.open(f"{self.image_dir}/{self.x_paths[i]}.jpg") as raw:
            image = np.array(raw.convert("RGB").resize((width, height), Image.BICUBIC))
        augmented = self.transform(image=image)["image"]
        return {
            "original_index": self.index[i],
            "image": augmented,
            "label": self.ids[i],
        }

In [8]:
class WarmupCosineLambda:
    """Learning-rate multiplier with a warmup phase followed by cosine decay.

    The multiplier starts at ``decay_scale``, rises to 1 over ``warmup_steps``
    (linearly, or geometrically when ``exponential_warmup`` is set), then
    follows a half cosine back down to ``decay_scale`` over ``cycle_steps``.
    Instances are callables suitable for ``torch.optim.lr_scheduler.LambdaLR``.
    """

    def __init__(self, warmup_steps: int, cycle_steps: int, decay_scale: float, exponential_warmup: bool = False):
        self.warmup_steps = warmup_steps
        self.cycle_steps = cycle_steps
        self.decay_scale = decay_scale
        self.exponential_warmup = exponential_warmup

    def __call__(self, epoch: int):
        """Return the multiplier for the given step index."""
        floor = self.decay_scale
        warming_up = epoch < self.warmup_steps

        # Geometric warmup bypasses the linear blend below entirely.
        if warming_up and self.exponential_warmup:
            return floor * pow(floor, -epoch / self.warmup_steps)

        if warming_up:
            progress = epoch / self.warmup_steps
        else:
            phase = math.pi * (epoch - self.warmup_steps) / self.cycle_steps
            progress = (1 + math.cos(phase)) / 2
        return floor + (1 - floor) * progress
    
def topk_average_precision(output: torch.Tensor, y: torch.Tensor, k: int):
    """Per-sample reciprocal-rank score of the true label within the top ``k``.

    Args:
        output: Score matrix with one row per sample and one column per class.
        y: True class index for each row.
        k: Number of top-scoring classes to consider.

    Returns:
        A 1-D tensor holding ``1 / rank`` of the true label when it appears in
        the top ``k`` predictions of that row, and ``0`` otherwise.
    """
    rank_weights = torch.tensor([1.0 / rank for rank in range(1, k + 1)], device=output.device)
    _, top_indices = output.topk(k)
    hits = top_indices == y[:, None].expand(top_indices.shape)
    weighted_hits = hits * rank_weights
    return weighted_hits.sum(dim=1)

def calc_map5(output: torch.Tensor, y: torch.Tensor, threshold: Optional[float]):
    """Mean top-5 average precision, optionally with a constant extra class.

    When ``threshold`` is given, a column filled with that value is appended to
    ``output`` so that the extra class (index ``output.shape[1]``) competes
    with the real classes at a fixed score.
    """
    scores = output
    if threshold is not None:
        constant_col = torch.full((scores.shape[0], 1), threshold, device=scores.device)
        scores = torch.cat([scores, constant_col], dim=1)
    per_sample = topk_average_precision(scores, y, 5)
    return per_sample.mean().detach()

def map_dict(output: torch.Tensor, y: torch.Tensor, prefix: str):
    """Build a dict of top-1 accuracy and MAP@5 at several thresholds.

    Keys are ``"{prefix}/acc"`` and ``"{prefix}/map{threshold}"`` for each
    threshold in ``None, 0.3, 0.4, 0.5, 0.6, 0.7``.
    """
    metrics = {f"{prefix}/acc": topk_average_precision(output, y, 1).mean().detach()}
    metrics.update(
        {
            f"{prefix}/map{threshold}": calc_map5(output, y, threshold)
            for threshold in (None, 0.3, 0.4, 0.5, 0.6, 0.7)
        }
    )
    return metrics

In [9]:
class GeM(nn.Module):
    """Generalized-mean pooling over the last two dimensions.

    Args:
        p: Exponent of the generalized mean, stored as a one-element parameter.
        eps: Lower clamp applied to the input before exponentiation.
        requires_grad: Whether ``p`` is trainable.
    """

    def __init__(self, p=3, eps=1e-6, requires_grad=False):
        super().__init__()
        self.eps = eps
        self.p = nn.Parameter(torch.ones(1) * p, requires_grad=requires_grad)

    def forward(self, x: torch.Tensor):
        """Return ``mean(clamp(x, eps) ** p) ** (1 / p)`` over the last two dims."""
        powered = torch.pow(torch.clamp(x, min=self.eps), self.p)
        pooled = torch.mean(powered, dim=(-2, -1))
        return torch.pow(pooled, 1.0 / self.p)

class ArcMarginProductSubcenter(nn.Module):
    """Cosine-similarity head with ``k`` sub-centers per class.

    The weight matrix holds ``out_features * k`` center vectors. The forward
    pass returns, for each class, the largest cosine similarity between the
    input feature and any of that class's ``k`` centers.
    """

    def __init__(self, in_features, out_features, k=3):
        super().__init__()
        self.weight = nn.Parameter(torch.FloatTensor(out_features*k, in_features))
        self.reset_parameters()
        self.k = k
        self.out_features = out_features

    def reset_parameters(self):
        """Fill the weights uniformly in ``[-1/sqrt(in_features), 1/sqrt(in_features)]``."""
        bound = 1.0 / math.sqrt(self.weight.size(1))
        with torch.no_grad():
            self.weight.uniform_(-bound, bound)

    def forward(self, features):
        """Return a ``(batch, out_features)`` tensor of best sub-center cosines."""
        unit_features = F.normalize(features)
        unit_centers = F.normalize(self.weight)
        similarities = F.linear(unit_features, unit_centers)
        # Consecutive groups of k rows in the weight belong to the same class.
        per_center = similarities.view(-1, self.out_features, self.k)
        return per_center.max(dim=2).values

class ArcFaceLossAdaptiveMargin(nn.modules.Module):
    """Apply a per-class additive angular margin to cosine logits.

    Despite the name, this module does not compute a loss: it returns scaled
    logits in which the target-class cosine has been replaced by
    ``cos(theta + m)``, where ``m`` is that class's margin. The result is meant
    to be fed to a cross-entropy loss.

    Args:
        margins: Array-like of per-class margins (radians), indexed by label.
        n_classes: Number of classes (width of the logits).
        s: Scale applied to the returned logits.
    """

    def __init__(self, margins, n_classes, s = 30.0):
        super().__init__()
        self.s = s
        self.margins = margins
        self.out_dim = n_classes

    def forward(self, logits, labels):
        """Return margin-adjusted, scaled logits.

        Args:
            logits: ``(batch, n_classes)`` cosine similarities.
            labels: ``(batch,)`` integer class indices.
        """
        # Follow the device of the logits instead of assuming CUDA, so the
        # module also works on CPU and other accelerators.
        device = logits.device
        ms = np.asarray(self.margins)[labels.cpu().numpy()]

        def as_column(values):
            # Per-sample constant as a (batch, 1) float tensor for broadcasting.
            return torch.from_numpy(np.asarray(values)).float().to(device).view(-1, 1)

        cos_m = as_column(np.cos(ms))
        sin_m = as_column(np.sin(ms))
        th = as_column(np.cos(math.pi - ms))
        mm = as_column(np.sin(math.pi - ms) * ms)

        one_hot = F.one_hot(labels, self.out_dim).float()
        cosine = logits.float()
        # Rounding can push |cosine| marginally above 1; without the clamp the
        # square root is NaN, and 0 * NaN below would corrupt every logit of
        # the affected row, not only the target one.
        sine = torch.sqrt(torch.clamp(1.0 - torch.pow(cosine, 2), min=0.0))
        phi = cosine * cos_m - sine * sin_m
        # Where theta + m would exceed pi, fall back to a linear penalty so the
        # adjusted logit stays monotonic in the cosine.
        phi = torch.where(cosine > th, phi, cosine - mm)
        return ((one_hot * phi) + ((1.0 - one_hot) * cosine)) * self.s

In [10]:
class BelugaDataModule(LightningDataModule):
    """Lightning data module that builds the train/validation loaders.

    Args:
        df: DataFrame with ``image_id`` and ``individual_id`` columns.
        image_dir: Directory holding the image files, forwarded to the dataset.
        fold: Index of the StratifiedKFold split to hold out for validation,
            or ``-1`` to train on every row with no validation set.

    When ``cfg["n_data"]`` is not ``-1`` only the first ``n_data`` rows are used.
    """

    def __init__(self, df, image_dir, fold):
        super().__init__()
        self.image_dir = image_dir
        if cfg["n_data"] != -1:
            df = df.iloc[: cfg["n_data"]]
        self.all_df = df
        # Always defined so val_dataloader can tell whether a split exists.
        self.val_df = None
        if fold == -1:
            self.train_df = df
        else:
            skf = StratifiedKFold(n_splits=cfg["n_splits"], shuffle=True, random_state=0)
            train_idx, val_idx = list(skf.split(df, df.individual_id))[fold]
            self.train_df = df.iloc[train_idx].copy()
            val_df = df.iloc[val_idx].copy()
            # Identities that never occur in the training split are relabelled
            # with the extra class index cfg["num_classes"].
            new_mask = ~val_df.individual_id.isin(self.train_df.individual_id)
            # Assign the result back to the column: an in-place mask on the
            # attribute-accessed Series is not guaranteed to reach the frame
            # (it is a no-op under pandas copy-on-write).
            val_df["individual_id"] = val_df["individual_id"].mask(new_mask, cfg["num_classes"])
            self.val_df = val_df

    def get_dataset(self, df, data_aug):
        """Wrap ``df`` in a BelugaDataset, with or without augmentation."""
        return BelugaDataset(df, self.image_dir, data_aug)

    def train_dataloader(self):
        """Shuffled, augmented training loader; the last partial batch is dropped."""
        dataset = self.get_dataset(self.train_df, True)
        return DataLoader(
            dataset,
            batch_size=cfg["batch_size"],
            shuffle=True,
            num_workers = cfg["num_workers"],
            pin_memory=True,
            drop_last=True,
        )

    def val_dataloader(self):
        """Un-augmented validation loader, or ``None`` when no split was built."""
        if self.val_df is None:
            return None
        return DataLoader(
            self.get_dataset(self.val_df, False),
            batch_size=cfg["batch_size"],
            shuffle=False,
            num_workers = cfg["num_workers"],
            pin_memory=True,
        )

In [11]:
class BelugaClassifier(LightningModule):
    """Identity classifier: timm backbone, GeM pooling, BatchNorm neck, ArcFace head.

    Args:
        id_class_nums: Per-class sample counts, indexed by label. They set each
            class's ArcFace margin via
            ``count ** margin_power_id * margin_coef_id + margin_cons_id``.
    """

    def __init__(self, id_class_nums=None):
        super().__init__()
        self.save_hyperparameters(cfg, ignore=["id_class_nums"])
        self.test_results_fp = None

        stage_indices = cfg["out_indices"]
        n_classes = cfg["num_classes"]

        # Feature-extractor backbone exposing the selected intermediate stages.
        self.backbone = timm.create_model(
            cfg["model_name"],
            in_chans=3,
            pretrained=True,
            num_classes=0,
            features_only=True,
            out_indices=stage_indices,
        )
        stage_channels = self.backbone.feature_info.channels()

        # One pooling layer per stage; pooled vectors are concatenated.
        self.global_pools = torch.nn.ModuleList(
            [GeM(p=3, requires_grad=False) for _ in stage_indices]
        )
        self.mid_features = np.sum(stage_channels)
        self.neck = torch.nn.BatchNorm1d(self.mid_features)
        self.head_id = ArcMarginProductSubcenter(self.mid_features, n_classes, cfg["n_center_id"])

        # Per-class margin derived from the class sample counts.
        margins_id = np.power(id_class_nums, cfg["margin_power_id"]) * cfg["margin_coef_id"] + cfg["margin_cons_id"]
        self.margin_fn_id = ArcFaceLossAdaptiveMargin(margins_id, n_classes, cfg["s_id"])
        self.loss_fn_id = torch.nn.CrossEntropyLoss()

    def get_feat(self, x: torch.Tensor) -> torch.Tensor:
        """Return the batch-normalized embedding built from the pooled stages."""
        stage_maps = self.backbone(x)
        pooled = []
        for feature_map, pool in zip(stage_maps, self.global_pools):
            pooled.append(pool(feature_map))
        embedding = torch.cat(pooled, dim=1)
        return self.neck(embedding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class cosine logits for a batch of images."""
        return self.head_id(self.get_feat(x))

    def training_step(self, batch, batch_idx):
        """Compute the margin cross-entropy loss and log epoch-level train metrics."""
        images = batch["image"]
        targets = batch["label"]
        cosine_logits = self(images)
        loss = self.loss_fn_id(self.margin_fn_id(cosine_logits, targets), targets)
        self.log_dict({"train/loss_ids": loss.detach()}, on_step=False, on_epoch=True)
        # Metrics use the raw cosine logits, not the margin-adjusted ones.
        with torch.no_grad():
            train_metrics = map_dict(cosine_logits, targets, "train")
            self.log_dict(train_metrics, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation metrics from the average of plain and flipped predictions."""
        images = batch["image"]
        targets = batch["label"]
        plain_logits = self(images)
        # Dimension 3 is flipped for the second pass.
        flipped_logits = self(images.flip(3))
        averaged = (plain_logits + flipped_logits) / 2
        self.log_dict(map_dict(averaged, targets, "val"), on_step=False, on_epoch=True)

    def configure_optimizers(self):
        """AdamW with separate backbone/head learning rates and a warmup-cosine schedule."""
        param_groups = [
            {
                "params": [*self.backbone.parameters(), *self.global_pools.parameters()],
                "lr": cfg["lr_backbone"],
            },
            {
                "params": [*self.neck.parameters(), *self.head_id.parameters()],
                "lr": cfg["lr_head"],
            },
        ]
        optimizer = torch.optim.AdamW(param_groups)

        total_epochs = cfg["max_epochs"]
        warmup_steps = total_epochs * cfg["warmup_steps_ratio"]
        cycle_steps = total_epochs - warmup_steps
        schedule = WarmupCosineLambda(warmup_steps, cycle_steps, cfg["lr_decay_scale"])
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule)
        return [optimizer], [scheduler]

In [12]:
def train(df, fold):
    """Train a BelugaClassifier on ``df`` and save the final checkpoint.

    Args:
        df: DataFrame with ``image_id`` and integer ``individual_id`` columns.
        fold: Validation fold index passed to BelugaDataModule, or ``-1`` to
            train on all rows without validation.

    CSV logs (and wandb files, when enabled) go under ``out_dir``; the last
    checkpoint is written to ``out_dir`` under the name ``cfg["checkpoint_name"]``.
    """
    log_dir = f"{out_dir}/{exp_name}"
    # Per-class sample counts ordered by label id; they drive the adaptive margins.
    id_class_nums = df.individual_id.value_counts().sort_index().values
    model = BelugaClassifier(id_class_nums=id_class_nums)
    data_module = BelugaDataModule(df, train_imgs, fold)

    loggers = [pl_loggers.CSVLogger(log_dir)]
    if cfg["wandb"]:
        loggers.append(
            pl_loggers.WandbLogger(
                project="beluga_", group=exp_name, name=f"{exp_name}", save_dir=out_dir
            )
        )

    # Keep only the final checkpoint, saved under the configured name.
    checkpoint_callback = ModelCheckpoint(out_dir, save_last=True, save_top_k=0)
    checkpoint_callback.CHECKPOINT_NAME_LAST = cfg["checkpoint_name"]

    trainer = Trainer(
        # `gpus=1` was removed in Lightning 2.0; accelerator/devices is the
        # equivalent spelling and is also accepted by late 1.x releases.
        accelerator="gpu",
        devices=1,
        max_epochs=cfg["max_epochs"],
        logger=loggers,
        callbacks=[checkpoint_callback],
        precision=16,
    )

    try:
        trainer.fit(model, datamodule=data_module)
    finally:
        # Close the wandb run even if training fails, so a re-run of this cell
        # does not attach to a stale run.
        if cfg["wandb"]:
            wandb.finish()

In [None]:
# Map each whale_id string to a consecutive integer label and make sure the
# configured class count matches the data.
label_encoder = preprocessing.LabelEncoder().fit(df["whale_id"])
df["individual_id"] = label_encoder.transform(df["whale_id"])
assert len(label_encoder.classes_) == cfg["num_classes"]

# With n_splits == -1 train on everything (fold -1); otherwise train fold 0.
train(df, -1 if cfg["n_splits"] == -1 else 0)