## Imports

In [None]:
# Auto-reload imported modules before each cell executes, so edits to the
# local `data` and `lib` modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from typing import Literal

import matplotlib.pyplot as plt
import pandas as pd
import torch
import tqdm
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoImageProcessor, AutoModelForImageClassification

import data
import lib

from lib import predict_siglip

## Model instantiation

In [None]:
# Set to True to resume from a saved checkpoint instead of starting fresh.
restore_checkpoint: bool = False

model_id = "google/siglip2-large-patch16-384"  # fixed-resolution (FixRes) variant
# The image processor carries the resize/normalize settings (size, mean, std).
model_preprocessor = AutoImageProcessor.from_pretrained(model_id)

# Resolve the device once, before branching, so the restore path and the
# fresh-start path both have it available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Stays None on a fresh start; the optimizer cell builds one in that case.
optimizer = None

# upcoming training epoch
epoch = 0

if restore_checkpoint:
    epochs = lib.model_checkpoints('./models_siglip2/checkpoint_*.pth')

    if len(epochs) == 0:
        print('no models found')
        raise ValueError('No model found')

    # NOTE(review): epochs[0] is presumably the most recent checkpoint —
    # confirm the ordering returned by lib.model_checkpoints.
    print(f'Loading model from epoch { epochs[ 0 ] }')

    # SECURITY: weights_only=False unpickles arbitrary Python objects; only
    # load checkpoint files that you produced yourself.
    # map_location remaps stored tensors onto the device available here, so a
    # checkpoint saved on GPU can still be opened on a CPU-only machine.
    checkpoint = torch.load(
        f'./models_siglip2/checkpoint_{ epochs[ 0 ] }.pth',
        weights_only=False,
        map_location=device,
    )

    # NOTE(review): the checkpoint is assumed to hold the full model and
    # optimizer objects (not state dicts) — confirm against lib.save_model.
    model = checkpoint['model']
    optimizer = checkpoint['optimizer']

    epoch = model.epoch + 1
else:
    # Pretrained encoder weights plus a NEW classification head with one
    # output per entry in data.species_labels.
    model = AutoModelForImageClassification.from_pretrained(
        model_id,
        num_labels=len(data.species_labels),
        ignore_mismatched_sizes=True,  # lets the head be created with the new size
    )

    # Training history is stored on the model object itself.
    model.tracking_loss = []
    model.tracking_loss_val = []
    model.tracking_accuracy = []
    model.tracking_val_probs = []
    # the last epoch we finished training on
    model.epoch = None

# Applies to both paths: a restored model must end up on the device too.
model.to(device)

# Module-level aliases to the history lists stored on the model.
tracking_loss = model.tracking_loss
tracking_loss_val = model.tracking_loss_val
tracking_accuracy = model.tracking_accuracy
tracking_val_probs = model.tracking_val_probs

## Training

### Data

In [None]:
# Training split.
# NOTE(review): `learning=True` presumably enables training-time behaviour
# inside the dataset (e.g. augmentation) — confirm in lib.ImageDatasetSigLip2.
train_ds = lib.ImageDatasetSigLip2(
    data.x_train,
    data.y_train,
    processor=model_preprocessor,
    learning=True,
)
# Evaluation split, built with learning=False.
val_ds = lib.ImageDatasetSigLip2(
    data.x_eval,
    data.y_eval,
    processor=model_preprocessor,
    learning=False,
)

# Only the training loader shuffles; validation keeps the dataset order.
train_loader = DataLoader(
    train_ds,
    batch_size=384,
    shuffle=True,
    num_workers=6,
)
val_loader = DataLoader(
    val_ds,
    batch_size=384,
    shuffle=False,
    num_workers=6,
)

### Freezing

In [None]:
# Which parts of the model are trained:
#   'classifier_only'        - linear probing, only the head
#   'classifier_and_encoder' - the head plus the last L encoder blocks
#   'all'                    - every parameter
unfreezing: Literal['classifier_only', 'classifier_and_encoder', 'all'] = 'classifier_only'

# Parameter groups for a stepped learning rate: the optimizer cell gives the
# head a higher LR and the encoder a lower one.
head_params = []
enc_params  = []

if unfreezing == 'classifier_only':
    # Freeze everything except the head (linear probing).
    for name, p in model.named_parameters():
        # HF classification heads are usually named "classifier".
        is_head = "classifier" in name
        p.requires_grad = is_head

        if is_head:
            head_params.append(p)

elif unfreezing == 'classifier_and_encoder':
    # A) Freeze everything first.
    for p in model.parameters():
        p.requires_grad = False
    for name, p in model.named_parameters():
        if "classifier" in name:
            p.requires_grad = True  # the head stays trainable
            head_params.append(p)

    # B) Unfreeze the last L blocks of the vision encoder.
    L = 4  # start with 2-4; with enough VRAM 6-8 is possible
    layers = model.vision_model.encoder.layers   # ModuleList
    for block in layers[-L:]:
        for p in block.parameters():
            p.requires_grad = True
            enc_params.append(p)
elif unfreezing == 'all':
    # Every parameter is trainable. The groups must still be filled here:
    # otherwise the optimizer is built from two empty lists and training
    # would not update any weight.
    for name, p in model.named_parameters():
        p.requires_grad = True
        if "classifier" in name:
            head_params.append(p)
        else:
            enc_params.append(p)
else:
    raise ValueError(f"Unknown unfreezing mode: {unfreezing}")

### Loss (possibly with weights)

In [None]:
# Disabled alternative: class balancing via CrossEntropy weights derived from
# class frequencies. NOTE(review): `train_labels` is not defined in the visible
# cells and the hard-coded 2 assumes two classes — adapt before enabling.
# import numpy as np
# counts = np.bincount(train_labels, minlength=2)  # counts[0], counts[1]
# class_weights = torch.tensor((counts.sum() / (2.0 * np.maximum(counts, 1))), dtype=torch.float32, device=device)
# criterion = nn.CrossEntropyLoss(weight=class_weights)

# Active loss: unweighted cross-entropy with label smoothing of 0.1.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

### Optimizer

In [None]:
# Build a fresh optimizer only when none was restored from a checkpoint.
if optimizer is None:
    # An optimizer whose parameter groups are all empty is constructed without
    # complaint, and every step would then update nothing. Fail loudly instead.
    if not head_params and not enc_params:
        raise ValueError(
            "No parameters were collected for the optimizer: "
            "head_params and enc_params are both empty"
        )

    optimizer = torch.optim.AdamW(
        [
            # Encoder group: lower learning rate, stronger weight decay.
            {"params": enc_params,  "lr": 1e-4, "weight_decay": 0.05},
            # Head group: higher learning rate, lighter weight decay.
            {"params": head_params, "lr": 1e-3, "weight_decay": 0.01},
        ]
    )

### Cutmix + mixup

In [None]:
from torchvision.transforms import v2

# Switch read by the training loop to enable or disable batch mixing.
use_cutmix_mixup = True

# Both transforms are configured with the number of species classes.
cutmix, mixup = (
    transform_cls(num_classes=len(data.species_labels))
    for transform_cls in (v2.CutMix, v2.MixUp)
)
# Each call picks one of the two transforms at random.
cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])

### Loop

In [None]:
# Number of epochs to run in this invocation, starting from `epoch`.
num_epochs = 1

# Send batches to whichever device the model actually lives on instead of
# assuming CUDA is present.
train_device = next(model.parameters()).device

for cur_epoch in range(epoch, epoch + num_epochs):
    print(f"Starting epoch {cur_epoch}")

    model.train()

    # Sample-weighted sum of batch losses and the number of samples seen.
    loss_acc = 0
    count = 0

    for batch in tqdm.tqdm(train_loader, total=len(train_loader), desc='Training'):
        optimizer.zero_grad(set_to_none=True)

        images = batch["pixel_values"].to(train_device)
        labels = batch["labels"].to(train_device)

        if use_cutmix_mixup:
            # NOTE(review): torchvision's CutMix/MixUp expect integer class
            # labels of shape (B,) — confirm what the dataset yields.
            images, labels = cutmix_or_mixup(images, labels)

        out = model(images)              # logits: (B, num_labels)
        loss = criterion(out.logits, labels)

        # Weight the batch loss by its size so the epoch figure is a true
        # per-sample mean even when the last batch is smaller.
        c = batch['pixel_values'].size(0)
        loss_acc += loss.item() * c
        count += c

        loss.backward()
        optimizer.step()

    tracking_loss.append(loss_acc / count)

    # Validation
    model.eval()

    probs, loss_acc = predict_siglip(
        model, val_loader, accumulate_probs=True, accumulate_loss=True, desc='Validation', columns=data.species_labels, criterion=criterion
    )
    tracking_val_probs.append(probs)
    tracking_loss_val.append(loss_acc)

    # Hard predictions: the column with the highest value in each row.
    eval_predictions = probs.idxmax(axis=1)
    eval_true = data.y_eval.idxmax(axis=1)
    correct = (eval_predictions == eval_true).sum()
    accuracy = correct / len(eval_predictions)
    tracking_accuracy.append(accuracy.item())

    model.epoch = cur_epoch
    # A format spec zero-pads to two digits. The previous nested-quote
    # expression is a SyntaxError on Python versions before 3.12.
    lib.save_model(model, optimizer, f"./models_siglip2/checkpoint_{cur_epoch:02d}.pth")

    # Keep `epoch` pointing at the next epoch so re-running this cell resumes.
    epoch = cur_epoch + 1


## Training progress

In [None]:
# Display the per-epoch histories: train loss, validation loss, validation accuracy.
tracking_loss, tracking_loss_val, tracking_accuracy

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))

# One x position per recorded epoch for each curve.
epochs_train = [*range(len(tracking_loss))]
epochs_val = [*range(len(tracking_loss_val))]

[line1] = ax.plot(epochs_train, tracking_loss, label="Train loss")
[line2] = ax.plot(epochs_val, tracking_loss_val, label="Validation loss")

ax.set(xlabel="Epoch (index)", ylabel="Loss")
ax.legend(handles=[line1, line2], loc="best")

# Put a tick on every training epoch.
ax.set_xticks(epochs_train)

ax.grid(True)

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))

# One x position per recorded validation accuracy.
epochs_accuracy = list(range(len(tracking_accuracy)))

line1, = ax.plot(epochs_accuracy, tracking_accuracy, label="Accuracy", color="red")
ax.set_xlabel("Epoch (index)")
ax.set_ylabel("Accuracy")

ax.legend(loc="best", handles=[line1])

# Ticks come from this plot's own x values, so the cell no longer depends on
# a variable defined in the loss-plot cell.
ax.set_xticks(epochs_accuracy)

ax.grid(True)

## Validation

In [None]:
# Validation output recorded for the most recently trained epoch.
eval_preds_df = tracking_val_probs[-1]

In [None]:
# Relative frequency of each class among the training labels.
print("True labels (training):")
(
    data.y_train
    .idxmax(axis=1)
    .value_counts(normalize=True)
)

In [None]:
# Relative frequency of each class among the predictions on the eval split.
print("Predicted labels (eval):")
(
    eval_preds_df
    .idxmax(axis=1)
    .value_counts(normalize=True)
)

In [None]:
# Relative frequency of each class among the true labels of the eval split.
print("True labels (eval):")
(
    data.y_eval
    .idxmax(axis=1)
    .value_counts(normalize=True)
)

In [None]:
# Hard labels: for every row, the column name holding the highest value.
eval_predictions, eval_true = (
    frame.idxmax(axis=1) for frame in (eval_preds_df, data.y_eval)
)

In [None]:
# Share of eval rows whose predicted label equals the true label.
correct = (eval_predictions == eval_true).sum()
accuracy = correct / eval_predictions.shape[0]
# Shown as a plain Python number.
accuracy.item()

### Confusion matrix

In [None]:
from sklearn.metrics import ConfusionMatrixDisplay

# Confusion matrix of true vs. predicted labels on the eval split.
fig, ax = plt.subplots(figsize=(10, 10))
cm = ConfusionMatrixDisplay.from_predictions(
    y_true=data.y_eval.idxmax(axis=1),
    y_pred=eval_preds_df.idxmax(axis=1),
    colorbar=True,
    xticks_rotation=30,
    ax=ax,
)

## Create submission

In [None]:
# Test split: features only (no labels are passed), built with learning=False.
test_dataset = lib.ImageDatasetSigLip2(
    data.test_features,
    processor=model_preprocessor,
    learning=False,
)
# Unshuffled, so prediction rows follow the dataset order.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=6,
)

In [None]:
# Run inference on the test set.
# NOTE(review): unlike the validation call, no `columns=` is passed here —
# confirm predict_siglip's default yields the columns the submission check expects.
submission_df = predict_siglip(model, test_dataloader)

In [None]:
# Reference file that defines the expected row ids and column names.
submission_format = pd.read_csv("data/submission_format.csv", index_col="id")

# Compare lengths first: an elementwise comparison of indexes with different
# lengths raises a ValueError rather than failing the assertion cleanly.
assert len(submission_df.index) == len(submission_format.index) and all(
    submission_df.index == submission_format.index
), "submission index does not match data/submission_format.csv"
assert len(submission_df.columns) == len(submission_format.columns) and all(
    submission_df.columns == submission_format.columns
), "submission columns do not match data/submission_format.csv"

In [None]:
# Write the predictions to submission.csv (the index is written too, as is pandas' default).
submission_df.to_csv("submission.csv")