In [1]:
import os
import cv2
import timm
import torch
import random
import numpy as np
import pandas as pd
import torch.nn as nn
import albumentations as A
import torch.nn.functional as F
import matplotlib.pyplot as plt
from copy import deepcopy
from tqdm import tqdm
from torch.optim import Adam, SGD, AdamW
from torch.utils.data import Dataset, DataLoader, SequentialSampler, RandomSampler

## Legend
## Slava's glass. Pt. 1

Slava had a glass.
But not just any glass — his glass.
It wasn’t dishwasher-safe, it had a chip on the rim and a suspicious stain from the 9th grade — but it brought luck, prestige, and the kind of quiet power you feel when you slam it down after solving a problem no one else even attempted.

Enter Andrey.

— “Lend it to me for a day. I’ll make you two just like it. Better, even. With a non-slip base and maybe a gold rim.”

— “You're serious?”

— “I give you my programmer word.”

Slava trusted him.
Slava was a fool.

The next day:

— “Where’s my glass?”

— “Gone.”

— “Gone where?”

— “To a better place. It's in a safe now.”

— “You said you’d make two more!”

— “Oh, I lied. Let's just say... you lost more than just a glass.”

That’s when Slava realized — Andrey wasn’t just talking about a drinking vessel.
He meant his glass.

That night, Slava broke into Andrey’s apartment.
The safe gleamed in the dark like a physics teacher’s conscience during finals.

Suddenly, the screen lit up with a message:

❗ Want your glass back? Solve this:

You have a neural network that performs well on some classes, but struggles with others.
Luckily, you’ve been given extra data for those underperforming classes.

Your task:
— Fine-tune the model so it handles all classes effectively;
— And make sure it doesn’t forget what it already knows —
(catastrophic forgetting is your enemy).

Slava sat down.
Opened his laptop.
Pressed play on an old mp3 in his headphones.

Al Pacino’s voice came on, raspy and fired up:

“I don’t know what to say, really. Three minutes till the biggest battle of our professional lives all comes down to today...”

Slava took a breath.
— “Alright, glass. Let’s bring you home.”

## Overview

In this competition, you will need to retrain the neural network. The initial neural network that you have been given works well on one part (**GOOD CLASSES**) of the dataset (90 percent accuracy). But on the other part of the dataset (**BAD CLASSES**), the accuracy is about 30 percent. You need to achieve as much accuracy as possible on each part of the datasets. To do this, we will provide you with a small dataset for training (10 samples per class). However, it contains only bad classes. Your task is to retrain the neural network in order to improve the quality of the bad classes and keep the quality of the good classes the same.

## Metric

The harmonic mean between the accuracy of the good classes ($acc_{good}$) and the bad classes ($acc_{bad}$),
$$ SCORE = \sqrt{(acc_{good} \times acc_{bad})} . $$

## Restriction

For generating a submission, you **should use the `make_predict` function**. You can only use data from the folder `train_images` for **fine-tuning the network**.

## Data

* `test_images` - folder with images for creating a submission file.
* `train_images` - folder with images for fine-tuning the network.
* `model.pt` - initial weight of the pretrained network.
* `sample_submission.csv` - example of submission file. \
  Columns:
  * `id` - image filename in `test_images` folder.
  * `class` - class that you predict.
* `train.csv` - dataset with classes for finetuning for bad classes. \
  Columns:
  * `path` - image filename in `train_images` folder.
  * `class` - class of image.

## Create dataframes for training and inference

In [2]:
train = pd.read_csv('train.csv')
train['path'] = [f'train_images/{x}' for x in train['path']]

In [3]:
paths_list = []
main_path = 'test_images'
for path in sorted(os.listdir(main_path)):
    paths_list += [f'{main_path}/{path}']

test = pd.DataFrame({'path': paths_list, 'class': 0})

## Create a model and dataset for training

In [4]:
class ImageDataset(Dataset):
    def __init__(self, paths, targets, transform):
        self.paths = paths
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, item):
        image = cv2.imread(self.paths[item])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        target = self.targets[item]
        image = self.transform(image=image)['image']
        image = image.astype(np.float32) / 255.0 - 0.5
        image = torch.from_numpy(image).permute(2, 0, 1)

        return image, target

In [5]:
class PetNet(nn.Module):
    def __init__(self, model_name, num_classes):
        super().__init__()
        self.model = timm.create_model(model_name, num_classes=num_classes)

    def forward(self, image):
        x = self.model(image)
        return x

In [6]:
def get_train_transforms(dim=224):
    return A.Compose([
        A.LongestMaxSize(max_size=dim, p=1.0),
        A.PadIfNeeded(dim, dim, p=1.0),
        A.HorizontalFlip(p=0.5)
    ])

In [7]:
def get_test_transforms(dim=224):
    return A.Compose([
        A.LongestMaxSize(max_size=dim, p=1.0),
        A.PadIfNeeded(dim, dim, p=1.0)
    ])

In [8]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

## Train loop

In [9]:
seed_everything(230)

dim = 224
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 32
lr = 3e-4
num_epochs = 20
alpha = 0.1
temperature = 2
clip_grad_norm = 5

In [10]:
train_dataset = ImageDataset(train['path'], train['class'], get_train_transforms(dim))
test_dataset = ImageDataset(test['path'], test['class'], get_test_transforms(dim))

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

In [11]:
model = PetNet('tiny_vit_5m_224.dist_in22k_ft_in1k', num_classes=102).to(device)
model_dict = torch.load('model.pt', map_location=device, weights_only=False)
model.load_state_dict(model_dict, strict=False)

teacher = deepcopy(model)
teacher.eval();

In [12]:
model.model.patch_embed.requires_grad = False
model.model.stages[0].requires_grad = False

The total loss used in **Learning without Forgetting (LwF)** is a combination of the loss on the new task and the distillation loss on the old tasks,

$$
L = \alpha L_{\text{new}} + (1-\alpha) \cdot L_{\text{KD}} .
$$

where:
- $L_{\text{new}}$: Cross-entropy loss on the new task.
- $L_{\text{KD}}$: Knowledge distillation loss (typically KL divergence) to retain old task performance.
- $\alpha$: A hyperparameter to balance the two losses.

The knowledge distillation loss is defined as:

$$
L_{\text{KD}} = \sum_{i} \text{KL} \left( \sigma\left(z^{\text{old}}_i / T\right) \, \| \, \sigma\left(z^{\text{new}}_i / T\right) \right)
$$

where:
- $z^{\text{old}}_i$: Logits from the old model for sample $i$.
- $z^{\text{new}}_i$: Logits from the new model for the same sample $i$.
- $\sigma$: Softmax function.
- $T$: Temperature parameter to soften probabilities.

In [13]:
def criterion(inputs, labels, student_logits):
    with torch.no_grad():
        teacher_logits = teacher(inputs)

    student_soft = F.log_softmax(student_logits / temperature, dim=1)
    teacher_soft = F.softmax(teacher_logits / temperature, dim=1)

    ce_loss = F.cross_entropy(student_logits, labels)
    kd_loss = F.kl_div(student_soft, teacher_soft, reduction='batchmean')

    total_loss = alpha * ce_loss + (1 - alpha) * kd_loss * temperature ** 2
    return total_loss

In [14]:
optimizer = AdamW(model.parameters(), lr=lr)
scaler = torch.amp.GradScaler(device)

In [15]:
model.train()
for epoch in range(num_epochs):
    average_loss = 0
    correct_preds = 0
    total_preds = 0

    tk0 = tqdm(enumerate(train_loader), total=len(train_loader))
    for batch_number, (inputs, labels) in tk0:
        optimizer.zero_grad()
        inputs, labels = inputs.to(device), labels.long().to(device)

        with torch.amp.autocast(device):
            y_preds = model(inputs)
            loss = criterion(inputs, labels, y_preds)
        scaler.scale(loss).backward()

        if clip_grad_norm > 0:
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        scaler.step(optimizer)
        scaler.update()

        average_loss += loss.cpu().detach().numpy()

        preds = torch.argmax(y_preds, dim=1)
        correct_preds += (preds == labels).sum().item()
        total_preds += labels.size(0)

        tk0.set_postfix(loss=average_loss / (batch_number + 1), acc=correct_preds / total_preds, stage='train', epoch=epoch)

100%|████████████████████████████████████████| 8/8 [00:02<00:00,  2.84it/s, acc=0.17, epoch=0, loss=0.842, stage=train]
100%|███████████████████████████████████████| 8/8 [00:01<00:00,  5.38it/s, acc=0.187, epoch=1, loss=0.405, stage=train]
100%|███████████████████████████████████████| 8/8 [00:01<00:00,  5.46it/s, acc=0.304, epoch=2, loss=0.352, stage=train]
100%|███████████████████████████████████████| 8/8 [00:01<00:00,  5.60it/s, acc=0.439, epoch=3, loss=0.324, stage=train]
100%|███████████████████████████████████████| 8/8 [00:01<00:00,  5.62it/s, acc=0.578, epoch=4, loss=0.307, stage=train]
100%|███████████████████████████████████████| 8/8 [00:01<00:00,  5.61it/s, acc=0.643, epoch=5, loss=0.297, stage=train]
100%|███████████████████████████████████████| 8/8 [00:01<00:00,  5.59it/s, acc=0.683, epoch=6, loss=0.289, stage=train]
100%|███████████████████████████████████████| 8/8 [00:01<00:00,  5.60it/s, acc=0.743, epoch=7, loss=0.281, stage=train]
100%|███████████████████████████████████

## Inference function

**You cannot change this function.**

In [16]:
def make_predict(state_dict, test_loader, name_csv='submission.csv', test_ids=None):
    if test_ids is None:
        test_ids = [x.split('/')[-1] for x in test['path']]
    
    preds = []
    len_loader = len(test_loader)
    tk0 = tqdm(enumerate(test_loader), total=len_loader)
    average_loss = 0
    model = timm.create_model('tiny_vit_5m_224.dist_in22k_ft_in1k', num_classes=102)
    model.cuda().eval()
    model.load_state_dict(state_dict)
    
    with torch.no_grad():
        for batch_number, (inputs, labels) in tk0:
            inputs = inputs.cuda()
            labels = labels.cuda().long()
    
            with torch.amp.autocast('cuda'):
                y_preds = model(inputs)
    
            preds += [y_preds.to('cpu').numpy()]
    
    preds = np.concatenate(preds)

    model.train()

    submission = pd.DataFrame()
    submission['id'] = test_ids
    submission['class'] = np.argmax(preds, 1)
    submission.to_csv(name_csv, index=None)

In [17]:
make_predict(model.model.state_dict(), test_loader, 'submission.csv')

100%|████████████████████████████████████████████████████████████████████████████████| 160/160 [00:18<00:00,  8.57it/s]


## Score

- Private: 0.65339
- Public: 0.64527

> Baseline:
> - Private: 0.55382
> - Public: 0.52975
