In [1]:
import os
import io
import time
import timm
import umap
import click
import random
import logging
from typing import Tuple
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.preprocessing import StandardScaler

In [2]:
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchmetrics import Accuracy, ConfusionMatrix

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.samplers import MPerClassSampler
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator

In [3]:
import albumentations as albu
from albumentations.pytorch import ToTensorV2
from albumentations.augmentations import CoarseDropout

In [4]:
BATCH_SIZE = 32
SIZE = 224
BACKBONE = "resnext101_32x8d"

dataset_folder = "../data/interim/dataset_part/"
tb_log_dir = "../logs"
model_dir = "../models"
max_epochs = 30

# Functions

In [5]:
class Transforms:
    def __init__(self, segment="train"):
        if segment == "train":
            transforms = [
                albu.LongestMaxSize(max_size=224 + 5, always_apply=True, p=1),
                albu.RandomBrightnessContrast(p=0.3),
                albu.ColorJitter(hue=0.01, saturation=0.02, p=0.3),
                # geometric transformations
                albu.GridDistortion(distort_limit=0.6, p=0.3),
                albu.ShiftScaleRotate(border_mode=1, rotate_limit=3, p=0.3),
                albu.PadIfNeeded(
                    min_height=224 + 5,
                    min_width=224 + 5,
                    always_apply=True,
                    border_mode=0,
                    value=(255, 255, 255),
                ),
                albu.RandomCrop(width=224, height=224),
                albu.HorizontalFlip(p=0.5),
            ]
        else:
            transforms = [
                albu.LongestMaxSize(max_size=224, always_apply=True, p=1),
                albu.PadIfNeeded(
                    min_height=224,
                    min_width=224,
                    always_apply=True,
                    border_mode=0,
                    value=(255, 255, 255),
                ),
            ]
        transforms.extend(
            [
                albu.Normalize(),
                ToTensorV2(),
            ]
        )

        self.transforms = albu.Compose(transforms)

    def __call__(self, img, *args, **kwargs):
        return self.transforms(image=np.array(img))["image"]

In [6]:
class EmbeddingsModel(nn.Module):
    def __init__(
        self,
        num_classes: int,
        embedding_size: int = 512,
        backbone: str = "resnext101_32x8d",
    ):
        super().__init__()
        self.trunk = timm.create_model(backbone, pretrained=True)
        self.embedding_size = embedding_size
        self.trunk.fc = nn.Linear(
            in_features=self.trunk.fc.in_features,
            out_features=embedding_size,
            bias=False,
        )

        self.classifier = torch.nn.Sequential(
            nn.Linear(embedding_size, num_classes, bias=True),
        )

    def forward(self, inpt):
        # get embeddings
        emb = self.trunk(inpt)

        # get logits
        logits = self.classifier(emb)

        return logits, emb

In [32]:
def get_embeddings(trainer, loader):
    embeddings = trainer.predict(dataloaders=loader, ckpt_path="best")

    embs = []
    targets = []
    for element in embeddings:
        emb, target = element
        embs.append(emb)
        targets.append(target)

    embeddings, targets = torch.concat(embs), torch.concat(targets)
    return embeddings.cpu(), targets.cpu()


def sampling(embeddings, targets, N):
    """
    stratified sampling to speed up validation
    :param embeddings: full embeddings
    :param targets: full targets
    :param N: number of samples to save per class
    :return: subsample of embeddings and targets
    """
    embeddings = embeddings.numpy()
    targets = targets.numpy()
    df = pd.DataFrame.from_records(data=embeddings)
    df["target"] = targets
    df = df.groupby("target", group_keys=False).apply(
        lambda x: x.sample(min(len(x), N), random_state=42)
    )
    targets = df["target"].values
    df.drop(columns=["target"], inplace=True)
    embeddings = df.values

    return torch.tensor(embeddings).contiguous(), torch.tensor(targets).contiguous()


def calculate_accuracy(trainer, train_dl, val_dl):
    embeddings_train, targets_train = get_embeddings(trainer, train_dl)
    embeddings_val, targets_val = get_embeddings(trainer, val_dl)
    
    embeddings_train, targets_train = sampling(embeddings_train, targets_train, 300)
    embeddings_val, targets_val = sampling(embeddings_val, targets_val, 300)

    accuracy_calculator = AccuracyCalculator(device=torch.device("cpu"))

    accuracies = accuracy_calculator.get_accuracy(
        embeddings_train, embeddings_val, targets_train, targets_val, False
    )

    return accuracies


class Runner(pl.LightningModule):
    def __init__(
        self,
        model,
        classes,
        mapper,
        lr: float = 1e-3,
        scheduler_T=1000,
        metric_coeff: float = 0.3,
    ) -> None:

        super().__init__()

        self.model = model
        self.classes = classes
        self.lr = lr
        self.scheduler_T = scheduler_T
        self.criterion = CrossEntropyLoss()
        self.metric_coeff = metric_coeff
        self.miner = miners.MultiSimilarityMiner(epsilon=0.1)
        self.metric_loss = losses.SubCenterArcFaceLoss(  # ArcFaceLoss
            num_classes=len(classes), embedding_size=model.embedding_size
        ).to(torch.device("cuda"))

        num_classes = len(self.classes)
        self.mapper = mapper

        self.metrics = torch.nn.ModuleDict(
            {
                "accuracy": Accuracy(
                    num_classes=num_classes, compute_on_step=False, average="macro"
                ),
                "confusion_matrix": ConfusionMatrix(
                    num_classes=num_classes, normalize="true", compute_on_step=False
                ),
            }
        )
        self.accuracy_calculator = AccuracyCalculator(device=torch.device("cpu"))

        self.embeddings_train = []
        self.embeddings_val = []
        self.targets_train = []
        self.targets_val = []

    def predict_step(self, batch, batch_idx, **kwargs):
        images, targets = batch
        logits, embeddings = self.model(images)
        return embeddings, targets

    def training_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx
    ) -> torch.Tensor:

        images, targets = batch

        logits, embeddings = self.model(images)

        self.embeddings_train.append(embeddings.detach().cpu())
        self.targets_train.append(targets.detach().cpu())

        # calculating classification loss
        clf_loss = self.criterion(logits, targets)

        # calculating metric loss
        hard_pairs = self.miner(embeddings, targets)
        m_loss = self.metric_loss(embeddings, targets, hard_pairs)

        loss = self.metric_coeff * clf_loss + (1 - self.metric_coeff) * m_loss

        # calculating metrics
        for i, metric in enumerate(self.metrics.values()):
            metric.update(logits.softmax(axis=1), targets)

        self.log(
            "Train/Metric Loss",
            m_loss.item(),
            on_step=True,
            batch_size=BATCH_SIZE,
        )
        self.log(
            "Train/Classification Loss",
            clf_loss.item(),
            on_step=True,
            batch_size=BATCH_SIZE,
        )
        self.log(
            "Train/LR",
            self.lr_schedulers().get_last_lr()[0],
            on_step=True,
            batch_size=BATCH_SIZE,
        )

        return loss

    def validation_step(
        self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx
    ) -> torch.Tensor:

        images, targets = batch
        logits, embeddings = self.model(images)
        self.embeddings_val.append(embeddings.detach().cpu())
        self.targets_val.append(targets.detach().cpu())

        clf_loss = self.criterion(logits, targets)

        # calculating metrics
        for i, metric in enumerate(self.metrics.values()):
            metric.update(logits.softmax(axis=1), targets)

        self.log(
            "Validation/Classification Loss",
            clf_loss.item(),
            on_step=True,
            batch_size=BATCH_SIZE,
        )
        return clf_loss

    def log_cm(self, confusion_matrix):
        print("start drawing conf matrix")
        plt.figure(figsize=(50, 50))
        sns.heatmap(
            np.around(confusion_matrix.cpu().numpy(), 3),
            annot=True,
            cmap="YlGnBu",
            xticklabels=self.classes,
            yticklabels=self.classes,
        )
        buf = io.BytesIO()
        plt.savefig(buf)
        buf.seek(0)
        image = np.array(Image.open(buf))[:, :, :3]
        buf.close()
        plt.clf()
        self.logger.experiment.add_image(
            "conf_matr", image, self.current_epoch, dataformats="HWC"
        )
        print("done!")

    def log_umap(self, embeddings, targets):
        print("run umap logger")
        sns.set(style="whitegrid", font_scale=1.3)

        print("     run StandardScaler")
        scaler = StandardScaler()
        embeddings_scaled = scaler.fit_transform(embeddings)
        print("     done!")

        print("     sampling")
        embeddings_scaled, targets = sampling(
            torch.tensor(embeddings_scaled), torch.tensor(targets), 50
        )
        print("     done!")

        print("     starting umap transforms")
        umap_obj = umap.UMAP(n_neighbors=20, min_dist=0.15)
        embedding_2d = umap_obj.fit_transform(embeddings_scaled.numpy())
        print("     done!")

        plot_df = pd.DataFrame.from_records(data=embedding_2d, columns=["x", "y"])
        plot_df["target"] = targets.numpy()
        plot_df["target"] = plot_df["target"].apply(lambda x: self.mapper[x])

        plt.figure(figsize=(14, 10))
        plt.title("UMAP")
        sns.scatterplot(x="x", y="y", data=plot_df, hue="target", palette="Paired")

        buf = io.BytesIO()
        plt.savefig(buf)
        buf.seek(0)
        image = np.array(Image.open(buf))[:, :, :3]
        buf.close()
        plt.clf()
        self.logger.experiment.add_image(
            "umap", image, self.current_epoch, dataformats="HWC"
        )
        print("done!")

    def validation_epoch_end(self, outputs) -> None:

        if len(self.embeddings_train) != 0:
            self.embeddings_train = torch.concat(self.embeddings_train)
            self.targets_train = torch.concat(self.targets_train)
            self.embeddings_val = torch.concat(self.embeddings_val)
            self.targets_val = torch.concat(self.targets_val)
            print("embedding example", self.embeddings_train[0][:10])
            print("embeddings train shape: ", self.embeddings_train.shape)
            print("embeddings val shape: ", self.embeddings_val.shape)

            self.embeddings_train, self.targets_train = sampling(
                self.embeddings_train, self.targets_train, 300
            )
            self.embeddings_val, self.targets_val = sampling(
                self.embeddings_val, self.targets_val, 300
            )

            # embeddings metrices
            print("accuracy_calculator start")
            accuracies = self.accuracy_calculator.get_accuracy(
                self.embeddings_train,
                self.embeddings_val,
                self.targets_train,
                self.targets_val,
                False,
            )
            print("done!")

            for name in accuracies:
                self.log(
                    f"Validation/{name}", accuracies[name], on_step=False, on_epoch=True
                )

            self.log_umap(
                embeddings=self.embeddings_val.numpy(),
                targets=self.targets_val.numpy(),
            )

            self.embeddings_train = []
            self.embeddings_val = []
            self.targets_train = []
            self.targets_val = []

        # classification metrices
        for name, metric in self.metrics.items():
            metric_val = metric.compute()
            self.log(f"Validation/{name}", metric_val, on_step=False, on_epoch=True)
            metric.reset()
            if name != "confusion_matrix":
                print(f"Validation {name} = {metric_val}")
            else:
                self.log_cm(metric_val)

    def configure_optimizers(self):
        params = list(filter(lambda p: p.requires_grad, self.model.parameters()))
        optimizer = torch.optim.Adam(params, lr=self.lr)

        if len([p for p in self.metric_loss.parameters()]) > 0:
            optimizer.add_param_group({"params": self.metric_loss.parameters()})

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=self.scheduler_T, eta_min=1e-7
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler},
        }

# Dataset

In [33]:
os.makedirs(tb_log_dir, exist_ok=True)
os.makedirs(model_dir, exist_ok=True)

tb_log_dir_to_use = Path(tb_log_dir)
model_dir_to_use = Path(model_dir)
dataset_folder_to_use = Path(dataset_folder)

In [34]:
print("Running metric learning task!")

classes_train = set(
    [p.name for p in (dataset_folder_to_use / "train").glob("*")]
)
classes_val = set(
    [p.name for p in (dataset_folder_to_use / "test").glob("*")]
)

print(f"Number of classes in train {len(classes_train)}")
print(f"Number of classes in val {len(classes_val)}")
print(f"Number of classes in train & val {len(classes_train & classes_val)}")
print(f"Number of classes in train - val {len(classes_train - classes_val)}")

Running metric learning task!
Number of classes in train 12
Number of classes in val 12
Number of classes in train & val 12
Number of classes in train - val 0


In [35]:
print("creating datasets")

train_dataset = ImageFolder(
    root=str(dataset_folder_to_use / "train"),
    transform=Transforms(),
)

val_dataset = ImageFolder(
    root=str(dataset_folder_to_use / "test"),
    transform=Transforms(segment="val"),
)

mapper = {train_dataset.class_to_idx[i]:i for i in train_dataset.class_to_idx}
print("datasets were created")

creating datasets
datasets were created


In [36]:
print("creating data loaders")
sampler = MPerClassSampler(
    train_dataset.targets,
    m=3,
    length_before_new_iter=len(train_dataset),
)

train_dl = DataLoader(
    train_dataset,
    BATCH_SIZE,
    pin_memory=False,
    sampler=sampler,
    num_workers=4,
    drop_last=True,
)

val_dl = DataLoader(
    val_dataset,
    BATCH_SIZE,
    pin_memory=False,
    shuffle=False,
    num_workers=4,
    drop_last=False,
)
print("data loaders were created")

assert val_dataset.classes == train_dataset.classes

creating data loaders
data loaders were created


# Model

In [37]:
print("creating runner")
runner = Runner(
    model=EmbeddingsModel(
        num_classes=len(classes_train),
        backbone=BACKBONE
    ),
    classes=train_dataset.classes,
    lr=1e-3,
    scheduler_T=max_epochs,  # * len(train_dl)
    mapper=mapper
)
print("runner was created")

creating runner
runner was created


In [38]:
print("creating trainer!")
trainer = pl.Trainer(
    log_every_n_steps=30,
    max_epochs=max_epochs,
    gpus=-1,
    logger=pl.loggers.tensorboard.TensorBoardLogger(tb_log_dir_to_use),
    callbacks=[
        ModelCheckpoint(
            dirpath=model_dir,
            save_top_k=1,
            verbose=True,
            filename="checkpoint-{epoch:02d}",
        ),
        EarlyStopping(
            patience=10, monitor="Validation/accuracy", mode="max"
        ),

    ],
)
print("trainer was created!")

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


creating trainer!
trainer was created!


In [39]:
# # find learning rate
# print("Run learning rate finder")
# lr_finder = trainer.tuner.lr_find(runner, train_dl)

# # Pick point based on plot, or get suggestion
# new_lr = lr_finder.suggestion()

# # update hparams of the model
# runner.hparams.lr = new_lr
# print("Done!\n")

In [40]:
print("run training pipeline")
trainer.fit(runner, train_dl, val_dl)
print("done!")

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name        | Type                 | Params
-----------------------------------------------------
0 | model       | EmbeddingsModel      | 87.8 M
1 | criterion   | CrossEntropyLoss     | 0     
2 | miner       | MultiSimilarityMiner | 0     
3 | metric_loss | SubCenterArcFaceLoss | 18.4 K
4 | metrics     | ModuleDict           | 0     
-----------------------------------------------------
87.8 M    Trainable params
0         Non-trainable params
87.8 M    Total params
351.262   Total estimated model params size (MB)


run training pipeline


Sanity Checking: 0it [00:00, ?it/s]

Validation accuracy = 0.0
start drawing conf matrix




done!


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

embedding example tensor([ 0.7216,  0.0720, -0.1035, -0.4444,  0.3573, -0.1040,  0.0738,  0.0273,
        -0.0152, -0.2014])
embeddings train shape:  torch.Size([1056, 512])
embeddings val shape:  torch.Size([820, 512])
accuracy_calculator start
done!
run umap logger
     run StandardScaler
     done!
     sampling
     done!
     starting umap transforms
     done!
done!
Validation accuracy = 0.16911396384239197
start drawing conf matrix
done!


Validation: 0it [00:00, ?it/s]

embedding example tensor([-2.2894,  1.1505, -2.6921, -0.3550,  2.4992,  0.1322, -2.0354,  1.4891,
        -2.6784, -2.5955])
embeddings train shape:  torch.Size([1056, 512])
embeddings val shape:  torch.Size([756, 512])
accuracy_calculator start
done!
run umap logger
     run StandardScaler
     done!
     sampling
     done!
     starting umap transforms
     done!
done!
Validation accuracy = 0.24059827625751495
start drawing conf matrix
done!
done!


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


<Figure size 3600x3600 with 0 Axes>

<Figure size 1008x720 with 0 Axes>

<Figure size 3600x3600 with 0 Axes>

<Figure size 1008x720 with 0 Axes>

<Figure size 3600x3600 with 0 Axes>

# Test

In [41]:
train_dataset_clean = ImageFolder(
    root=str(dataset_folder_to_use / "train"),
    transform=Transforms(segment="test"),
)

train_dl_clean = DataLoader(
    train_dataset,
    BATCH_SIZE,
    shuffle=False,
    pin_memory=False,
    num_workers=4,
    drop_last=False,
)

In [44]:
accuracy = calculate_accuracy(
    trainer=trainer, 
    train_dl=train_dl_clean, 
    val_dl=val_dl)

Restoring states from the checkpoint path at /home/jupyter/train_embedder/models/checkpoint-epoch=01-v1.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /home/jupyter/train_embedder/models/checkpoint-epoch=01-v1.ckpt


Predicting: 3it [00:00, ?it/s]

Restoring states from the checkpoint path at /home/jupyter/train_embedder/models/checkpoint-epoch=01-v1.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /home/jupyter/train_embedder/models/checkpoint-epoch=01-v1.ckpt


Predicting: 3it [00:00, ?it/s]

In [45]:
accuracy

{'AMI': 0.20188743279380417,
 'NMI': 0.22100368269758192,
 'mean_average_precision': 0.13631831870650496,
 'mean_average_precision_at_r': 0.0419861583991085,
 'mean_reciprocal_rank': 0.3461434841156006,
 'precision_at_1': 0.18928901200369344,
 'r_precision': 0.1350349363728034}

# Save

In [None]:
# save the model
runner.model.eval()
b = next(iter(val_dl))
traced_model = torch.jit.trace(runner.model, b[0])

In [None]:
meta = {
    "inference_params": {
        "image_height": SIZE,
        "image_width": SIZE,
    },
}
traced_model.save(
    str(model_dir_to_use / "torchscript.pt"),
    _extra_files={f"{k}.txt": str(v) for k, v in meta.items()},
)