In [None]:
! pip install timm pigeon-jupyter

# 1. Environment Setup
- Do necessary imports
- Set up the data location
- Create a Labeled Dataset

In [None]:
import torch, sys, os, numpy as np, pandas as pd
from tqdm import tqdm
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torchvision
from fastai.vision.all import *
import timm

from sklearn.preprocessing import MinMaxScaler, OrdinalEncoder

from skimage import io
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import KFold, StratifiedKFold, train_test_split

from pigeon import annotate
from IPython import display

from matplotlib import pyplot as plt
import matplotlib.image as mpimg
import seaborn as sns

from fastai.vision.all import *

In [None]:
# Download the dataset archive from Google Drive by file id
# (presumably images.zip, which the next cell unpacks).
! gdown "1XE_nUkrtiybUAm2c0yYO-lPJbYyFUcww"

In [None]:
! unzip images.zip

In [None]:
# Base directory for the data folders. "" means "train" and "test" are resolved
# relative to the current working directory (os.path.join("", "train") == "train").
ROOT_DIR = ""

## 1(a) Create a Labeled Dataset

In [None]:
# Class names offered as label buttons while annotating.
# TODO: fill in the class names before running. With an empty list pigeon presumably
# shows no label buttons (only "skip"), so no image can be labelled and `labels_df`
# below ends up empty.
LABEL_OPTIONS = []

train_dir = os.path.join(ROOT_DIR, "train")
# Sorted so the labelling order is deterministic (os.listdir order is arbitrary).
annotations = annotate(
    [os.path.join(train_dir, name) for name in sorted(os.listdir(train_dir))],
    options=LABEL_OPTIONS,
    display_fn=lambda filename: display.display(display.Image(filename)),
)

In [None]:
# One row per labelled image: the file's basename (images live in the "train" folder)
# and its label. The columns must be a flat list; a nested list such as
# [["filename", "label"]] creates one-level MultiIndex columns instead.
# os.path.basename also handles the platform's separator, unlike split("/").
labels_df = pd.DataFrame(
    [(os.path.basename(ann[0]), ann[1]) for ann in annotations],
    columns=["filename", "label"],
)

# 2. Specify Additional Useful Functions (Utils)
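- ArcFace loss (`ArcMarginProduct`, `arcface_loss`)
- Cosine-similarity classifier head (`CosineClassifier`)
- Transformation function to convert grayscale images to RGB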

https://www.kaggle.com/code/slawekbiel/arcface-explained

In [None]:
class ArcMarginProduct(nn.Module):
    r"""ArcFace (additive angular margin) loss applied to cosine logits.

    Takes the cosine similarities produced by a cosine classifier head. For the
    target class it replaces ``cos(theta)`` with ``cos(theta + m)``, scales all
    logits by ``s`` and passes them to ``final_loss``.

    Args:
        s: scale applied to the (margin-adjusted) cosine logits.
        m: additive angular margin, in radians.
        easy_margin: if True, apply the margin only where ``cos(theta) > 0``.
            Otherwise fall back to ``cos(theta) - mm`` once ``theta + m`` would
            exceed ``pi``.
        ls_eps: label-smoothing factor for the one-hot target mask (0 disables it).
        final_loss: callable ``(logits, label) -> loss``. Defaults to a new
            ``FocalLossFlat()`` per instance.
    """
    def __init__(self, s=30.0, m=0.50, easy_margin=False, ls_eps=0.0, final_loss=None):
        super(ArcMarginProduct, self).__init__()
        self.s = s
        self.m = m
        self.ls_eps = ls_eps  # label smoothing
        self.easy_margin = easy_margin
        # Build the default lazily. A default argument of `FocalLossFlat()` would be
        # evaluated once at class-definition time and shared by every instance.
        self.final_loss = FocalLossFlat() if final_loss is None else final_loss
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # cos(theta) > cos(pi - m)  <=>  theta + m < pi, where cos(theta + m) still decreases.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, cosine, label):
        """Return ``final_loss(s * margin_logits, label)``.

        Args:
            cosine: cosine similarities indexed as (batch, class); scatter_ on dim 1
                below requires this 2-D layout.
            label: integer class index per row.
        """
        # sin(theta) from cos(theta). Clamp so that values marginally outside [-1, 1]
        # (rounding, fp16) give sine = 0 instead of NaN.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1)).to(cosine.dtype)  # .to() needed for to_fp16()
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        # One-hot target mask, created on the same device as the logits.
        one_hot = torch.zeros(cosine.size(), device=cosine.device)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        if self.ls_eps > 0:
            # Number of classes is taken from the logits themselves.
            one_hot = (1 - self.ls_eps) * one_hot + self.ls_eps / cosine.size(1)
        # Margin-adjusted logit for the target class, plain cosine everywhere else.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output *= self.s

        return self.final_loss(output, label)

In [None]:
def arcface_loss(cosine, targ, m=.4, loss_func=None):
    """ArcFace loss on cosine logits: add an angular margin to the target class.

    Args:
        cosine: cosine similarities indexed as (batch, class); scatter_ on dim 1
            below requires this 2-D layout.
        targ: integer class index per row.
        m: additive angular margin in radians.
        loss_func: callable ``(logits, targ) -> loss``; defaults to ``FocalLossFlat()``.

    Returns:
        ``loss_func(cos(theta + m * one_hot), targ)`` where ``theta = arccos(cosine)``.
    """
    if loss_func is None:
        loss_func = FocalLossFlat()
    # Keep values strictly inside (-1, 1). arccos is NaN outside [-1, 1] and its
    # gradient is infinite at the endpoints.
    cosine = cosine.clip(-1 + 1e-7, 1 - 1e-7)
    theta = cosine.arccos()
    # One-hot target mask on the same device/dtype as the logits, so this runs on CPU too.
    one_hot = torch.zeros_like(cosine)
    one_hot.scatter_(1, targ.view(-1, 1).long(), 1)
    # Add the margin to the target-class angle only, then map back to cosine space.
    return loss_func((theta + one_hot * m).cos(), targ)

In [None]:
class CosineClassifier(nn.Module):
    """Classifier head whose logits are cosine similarities.

    Input rows and weight columns are both L2-normalised before the product,
    so every output entry lies in [-1, 1].

    Args:
        emb_size: number of input features per sample (rows of ``W``).
        output_classes: number of classes (columns of ``W``).
    """

    def __init__(self, emb_size, output_classes):
        super().__init__()
        weight = torch.Tensor(emb_size, output_classes)
        nn.init.kaiming_uniform_(weight)
        self.W = nn.Parameter(weight)

    def forward(self, x):
        unit_x = F.normalize(x)              # each sample (dim=1) to unit length
        unit_w = F.normalize(self.W, dim=0)  # each class column to unit length
        return torch.matmul(unit_x, unit_w)

In [None]:
def convert_to_color(img: PILImage):
    """Return ``img`` as a PILImage, expanding 2-D (single-channel) images to 3 channels.

    Arrays that already have a channel axis are passed through unchanged.
    """
    pixels = np.array(img)
    if pixels.ndim < 3:
        # H x W -> H x W x 3 by copying the single channel three times
        pixels = np.stack([pixels] * 3, axis=-1)
    return PILImage.create(pixels)

# 2b. Set Up Configuration

In [None]:
class CFG:
    """Notebook-wide settings for data loading, the model and training."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # Data
    batch_size = 16
    img_size = 384          # size passed to aug_transforms when building the DataLoaders
    shuffle = True
    label_col = "label"
    fn_col = "filename"
    data_path = os.path.join(ROOT_DIR, "train")

    # Model
    model_name = 'resnext101_64x4d'
    # Learner root directory (previously referenced but never defined). Checkpoints
    # are written to <model_path>/<model_dir>; ROOT_DIR "" resolves to the cwd.
    model_path = ROOT_DIR
    model_dir = "fastai_models"
    pretrained = True
    num_classes = 2

    # Training
    epochs = 15
    freeze_epochs = 3
    loss_func = BaseLoss(ArcMarginProduct)   # ArcFace loss with its defaults (s=30, m=0.5)
    callbacks = [
        EarlyStoppingCallback(monitor="valid_loss", min_delta=0.001, patience=3)
    ]
    # Single-label metrics: accuracy_multi / F1ScoreMulti are meant for multi-label
    # targets, and F1ScoreMulti was the class itself rather than an instance.
    # NOTE: the Learner cell currently passes metrics=accuracy rather than this list.
    metrics = [
        accuracy, F1Score()
    ]
    num_folds = 4   # NOTE(review): not used by any cell shown here yet

# 3. Create Data Loaders
- create and check any custom transforms
- specify image augmentation
- add in any validation splits

In [None]:
# Hold out part of the labelled images for validation (train_test_split's default
# test_size of 25%). random_state makes the split reproducible across runs, and
# stratify keeps the class proportions the same in both parts.
train_df, valid_df = train_test_split(
    labels_df, random_state=42, stratify=labels_df[CFG.label_col]
)
# The `valid` column drives ImageDataLoaders.from_df(valid_col="valid"): 1 = validation row.
# .assign returns new frames, avoiding SettingWithCopy issues on the split slices.
labels_df = pd.concat([train_df.assign(valid=0), valid_df.assign(valid=1)])

In [None]:
# Build train/valid DataLoaders straight from the labelled DataFrame.
# Images are read from CFG.data_path/<filename>; rows with valid == 1 form the
# validation set. Each item goes through convert_to_color and Resize(460), then
# aug_transforms augments whole batches and outputs CFG.img_size images.
dls = ImageDataLoaders.from_df(
    labels_df,
    path=CFG.data_path,
    fn_col=CFG.fn_col,
    label_col=CFG.label_col,
    valid_col="valid",
    item_tfms=[convert_to_color, Resize(460)],
    batch_tfms=aug_transforms(size=CFG.img_size),
    bs=CFG.batch_size,
    shuffle=CFG.shuffle,
    device=CFG.device,
    num_workers=0,
)

In [None]:
# Visual sanity check: display a sample training batch with its labels.
dls.show_batch()

# 4. Train the Model
- Create the model
- Create a FastAI Learner
- Find a good learning rate for model
- Fine-tune the model on the train/validation split (`CFG.num_folds` is defined, but cross-validation folds are not implemented yet)


In [None]:
class NN_Model(nn.Module):
    """timm backbone with its final layer removed, followed by a CosineClassifier.

    Args:
        n_classes: number of output classes.
        model_name: timm architecture name (defaults to CFG.model_name).
        pretrained: whether timm loads pretrained weights.
    """

    def __init__(self,
                 n_classes,
                 model_name=CFG.model_name,
                 pretrained=True):
        super().__init__()
        base = timm.create_model(model_name, pretrained=pretrained)
        # Assumes the last child is the linear classifier (e.g. `fc` on timm ResNet-family
        # models); everything before it becomes the feature extractor.
        *body, head = base.children()
        self.backbone = nn.Sequential(*body)
        self.classifier = CosineClassifier(head.in_features, n_classes)

    def forward(self, x):
        embeddings = self.backbone(x)
        return self.classifier(embeddings)


In [None]:
model = NN_Model(CFG.num_classes, model_name=CFG.model_name, pretrained=CFG.pretrained).to(CFG.device)
# Split the parameters into two groups (backbone, head). With the default splitter
# there is a single group, so the freeze() done by fine_tune for the first
# freeze_epochs freezes nothing and the pretrained backbone is trained from epoch one.
learn = Learner(dls, model, loss_func=CFG.loss_func, metrics=accuracy, cbs=CFG.callbacks,
                path=CFG.model_path, model_dir=CFG.model_dir,
                splitter=lambda m: [params(m.backbone), params(m.classifier)]).to_fp16()

In [None]:
# Learning-rate finder, for choosing a learning rate. Note that the fine_tune cell
# below uses a fixed base_lr=1e-3 rather than the value suggested here.
learn.lr_find()

In [None]:
# fine_tune: CFG.freeze_epochs epochs with the earlier parameter groups frozen, then up
# to CFG.epochs epochs unfrozen (the EarlyStoppingCallback from CFG.callbacks may stop
# sooner). base_lr is fixed at 1e-3. SaveModelCallback tracks valid_loss and saves the
# best weights as "<model_name>_trained".
save_best = SaveModelCallback(monitor="valid_loss", min_delta=0.001, fname=CFG.model_name + "_trained")
learn.fine_tune(epochs=CFG.epochs, base_lr=1e-3, freeze_epochs=CFG.freeze_epochs, cbs=[save_best])


### Check out predictions and losses

In [None]:
# Reload the checkpoint written by SaveModelCallback during fine_tune (same fname).
learn.load(f"{CFG.model_name}_trained")

In [None]:
# Replace the ArcFace training loss with FocalLossFlat before showing results,
# presumably so predictions are decoded into class labels.
# NOTE(review): confirm that the BaseLoss-wrapped ArcMarginProduct lacks the
# activation/decodes FocalLossFlat provides. This mutates `learn` for every later
# cell as well (Interpretation, get_preds).
learn.loss_func = FocalLossFlat()
learn.show_results()

In [None]:
# Collect predictions and per-item losses (by default on the validation set)
# using the learner's current loss_func.
interp = Interpretation.from_learner(learn)

In [None]:
# Show the 9 items with the highest loss, to spot mislabelled or confusing images.
interp.plot_top_losses(9, figsize=(15,10))

# 5. Perform Inference
- Requires the `DataLoaders` (`dls`) and the trained `learn` from the previous sections.

In [None]:
# Test images in a deterministic (sorted) order; os.listdir order is arbitrary, so
# without sorting the capped subset could differ between machines and runs.
MAX_TEST_IMAGES = 1000
test_dir = os.path.join(ROOT_DIR, "test")
img_files = [os.path.join(test_dir, name) for name in sorted(os.listdir(test_dir))][:MAX_TEST_IMAGES]
test_dl = learn.dls.test_dl(img_files)

In [None]:
# Run the model over the test DataLoader. The test images have no labels, so the
# second return value (targets) is discarded.
preds, _ = learn.get_preds(dl=test_dl)

In [None]:
# Reduce per-class scores to one predicted class index per image. The guards make the
# cell safe to re-run; the original one-liner failed on a second run because it
# tried to argmax an already 1-D NumPy array.
if torch.is_tensor(preds):
    preds = preds.cpu().numpy()
if preds.ndim == 2:
    preds = preds.argmax(axis=1)

In [None]:
def plot_image_grid(images, ncols=None, cmap='gray'):
    '''Plot a list of images on a grid of subplots.

    Args:
        images: sequence of image arrays (H x W, H x W x 1 or H x W x C); ``None``
            entries leave their cell blank.
        ncols: number of grid columns. If omitted, a middle divisor of
            ``len(images)`` is used so the grid is filled exactly.
        cmap: colormap passed to ``imshow`` (matplotlib ignores it for RGB(A) data).

    Returns ``None``; draws nothing when ``images`` is empty.
    '''
    n_images = len(images)
    if n_images == 0:
        # plt.subplots rejects a grid with zero rows.
        return None
    if not ncols:
        divisors = [d for d in range(1, n_images + 1) if n_images % d == 0]
        ncols = divisors[len(divisors) // 2]
    # Ceiling division: just enough rows for every image (the old
    # `n // ncols + n % ncols` added extra rows whenever the remainder was > 1).
    nrows = -(-n_images // ncols)
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 or 1xN grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 2 * nrows), squeeze=False)
    flat_axes = axes.ravel()
    for ax, img in zip(flat_axes, images):
        # Explicit None check: np.any(img) would also skip all-black images.
        if img is None:
            continue
        if len(img.shape) > 2 and img.shape[2] == 1:
            img = img.squeeze(axis=2)
        ax.imshow(img, cmap=cmap)
    # Hide the padding cells that received no image.
    for ax in flat_axes[n_images:]:
        ax.axis('off')

In [None]:
# Up to 10 test images predicted as class index 0. Slice the indices before reading,
# so only 10 files are loaded instead of every matching test image.
# NOTE(review): the name assumes index 0 is the "military" class; confirm against
# learn.dls.vocab, since the vocab ordering decides which label gets index 0.
mil_images = [mpimg.imread(img_files[i]) for i in np.flatnonzero(preds == 0)[:10]]

plot_image_grid(mil_images)

In [None]:
# Up to 10 test images predicted as class index 1. Slice the indices before reading,
# so only 10 files are loaded instead of every matching test image.
# NOTE(review): the name assumes index 1 is the "civilian" class; confirm against
# learn.dls.vocab, since the vocab ordering decides which label gets index 1.
civ_images = [mpimg.imread(img_files[i]) for i in np.flatnonzero(preds == 1)[:10]]

plot_image_grid(civ_images)