## Imports

In [10]:
# Automatically re-import edited modules (e.g. the local `data` and `lib`) before each cell executes.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
from typing import Literal

import matplotlib.pyplot as plt
import pandas as pd
import torch
import tqdm
from sklearn.metrics import ConfusionMatrixDisplay
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoImageProcessor, AutoModelForImageClassification

import data
import lib

## Model instantiation

In [12]:
restore_checkpoint: bool = False

model_id = "google/siglip2-large-patch16-384"  # FixRes variant
# The processor supplies the resize/normalize settings (mean/std/size) that match the pretrained encoder.
model_preprocessor = AutoImageProcessor.from_pretrained(model_id)

# Chosen once, before either branch, so it is defined (and used) on both the restore and the fresh path.
device = "cuda" if torch.cuda.is_available() else "cpu"

optimizer = None

if restore_checkpoint:
    epochs = lib.model_checkpoints('./models_siglip2/checkpoint_*.pth')

    if len(epochs) == 0:
        print('no models found')
        raise ValueError('No model found')

    # NOTE(review): assumes lib.model_checkpoints returns the newest epoch first — confirm.
    print(f'Loading model from epoch { epochs[ 0 ] }')

    # weights_only=False unpickles arbitrary Python objects (whole model + optimizer):
    # only ever load checkpoints this notebook produced itself.
    # map_location lets a checkpoint saved on GPU be restored on a CPU-only machine.
    checkpoint = torch.load(
        f'./models_siglip2/checkpoint_{ epochs[ 0 ] }.pth',
        weights_only=False,
        map_location=device,
    )

    model = checkpoint['model']
    optimizer = checkpoint['optimizer']
else:
    # Pretrained encoder weights + a NEW classification head with one output per species label.
    model = AutoModelForImageClassification.from_pretrained(
        model_id,
        num_labels=len(data.species_labels),
        ignore_mismatched_sizes=True,  # creates a fresh head of the required size
    )

    # Training history is stored on the model object so it is saved/restored with the checkpoint.
    model.tracking_loss = []
    model.tracking_loss_val = []
    model.epoch = 0

# Module.to moves parameters in place, so a restored optimizer keeps referencing the same tensors.
model.to(device)

tracking_loss = model.tracking_loss
tracking_loss_val = model.tracking_loss_val

Fetching 1 files:   0%|          | 0/1 [00:00<?, ?it/s]

Some weights of SiglipForImageClassification were not initialized from the model checkpoint at google/siglip2-large-patch16-384 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Training

### Data

In [13]:
# Both splits share the SigLIP2 processor; `learning` toggles train-time behaviour inside
# lib.ImageDatasetSigLip2 (presumably augmentation — confirm in lib).
loader_kwargs = dict(batch_size=256, num_workers=4)

train_ds = lib.ImageDatasetSigLip2(
    data.x_train, data.y_train, processor=model_preprocessor, learning=True
)
val_ds = lib.ImageDatasetSigLip2(
    data.x_eval, data.y_eval, processor=model_preprocessor, learning=False
)

# Only the training loader is shuffled; validation order stays fixed.
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

### Freezing

In [14]:
unfreezing: Literal['classifier_only', 'classifier_and_encoder', 'all'] = 'classifier_only'

# Number of trailing vision-encoder blocks trained in 'classifier_and_encoder' mode
# (start with 2–4; with enough VRAM 6–8 is possible).
n_unfrozen_blocks = 4

# Parameter groups with "stepped" learning rates: the head gets a higher LR than the encoder.
head_params = []
enc_params = []

if unfreezing not in ('classifier_only', 'classifier_and_encoder', 'all'):
    raise ValueError(f"Unknown unfreezing mode: {unfreezing}")

# Start from a fully frozen model, then selectively re-enable gradients per mode.
for p in model.parameters():
    p.requires_grad = False

# The HF classification head is registered under "classifier"; it is trained in every mode.
for name, p in model.named_parameters():
    if "classifier" in name:
        p.requires_grad = True
        head_params.append(p)

if unfreezing == 'classifier_and_encoder':
    layers = model.vision_model.encoder.layers  # ModuleList of transformer blocks
    # Explicit start index: layers[-0:] would select *all* blocks when n_unfrozen_blocks == 0.
    start = max(len(layers) - n_unfrozen_blocks, 0)
    for block in layers[start:]:
        for p in block.parameters():
            p.requires_grad = True
            enc_params.append(p)
elif unfreezing == 'all':
    # Every non-head parameter goes into the lower-LR encoder group. Previously this mode left
    # both groups empty, so the optimizer built from them had nothing to update.
    for name, p in model.named_parameters():
        if "classifier" not in name:
            p.requires_grad = True
            enc_params.append(p)

### Loss (optionally class-weighted)

In [15]:
# Alternative for class imbalance: weight CrossEntropy inversely to class frequency.
# Requires `import numpy as np`, an integer label array `train_labels` and a `device`:
#   counts = np.bincount(train_labels, minlength=len(data.species_labels))
#   class_weights = torch.tensor(counts.sum() / (len(counts) * np.maximum(counts, 1)),
#                                dtype=torch.float32, device=device)
#   criterion = nn.CrossEntropyLoss(weight=class_weights)

# Active choice: unweighted cross-entropy with label smoothing (softens the one-hot targets).
label_smoothing = 0.1
criterion = nn.CrossEntropyLoss(label_smoothing=label_smoothing)

### Optimizer

In [16]:
# Only build a fresh optimizer when none was restored from a checkpoint.
# The freshly initialised head learns with a 10x higher LR than the pretrained encoder.
if optimizer is None:
    if not enc_params and not head_params:
        # AdamW accepts empty groups, so without this check training would silently update nothing.
        raise ValueError("No trainable parameters: both enc_params and head_params are empty")

    optimizer = torch.optim.AdamW(
        [
            {"params": enc_params,  "lr": 1e-4, "weight_decay": 0.05},
            {"params": head_params, "lr": 1e-3, "weight_decay": 0.01},
        ]
    )

### Loop

In [None]:
from pathlib import Path

num_epochs = 1

# Run on whatever device the model currently lives on instead of hard-coding 'cuda',
# so the loop also works on CPU-only machines.
train_device = next(model.parameters()).device

# Make sure the checkpoint directory exists before the first save.
Path("./models_siglip2").mkdir(parents=True, exist_ok=True)


def _run_epoch(loader: DataLoader, train: bool) -> float:
    """Make one pass over ``loader`` and return the sample-weighted mean loss.

    With ``train=True`` the model is in train mode and the optimizer steps after every batch;
    otherwise the model is in eval mode with gradients disabled.
    Returns NaN for an empty loader instead of dividing by zero.
    """
    model.train(train)
    loss_sum = 0.0
    n_samples = 0

    with torch.set_grad_enabled(train):
        for batch in tqdm.tqdm(loader, total=len(loader)):
            pixel_values = batch["pixel_values"].to(train_device)
            labels = batch["labels"].to(train_device)

            if train:
                optimizer.zero_grad(set_to_none=True)

            logits = model(pixel_values).logits  # (B, num_labels)
            loss = criterion(logits, labels)

            if train:
                loss.backward()
                optimizer.step()

            batch_size = pixel_values.size(0)
            loss_sum += loss.item() * batch_size
            n_samples += batch_size

    return loss_sum / n_samples if n_samples else float("nan")


for epoch in range(model.epoch, model.epoch + num_epochs):
    print(f"Starting epoch {epoch}")
    print("Training: ")
    tracking_loss.append(_run_epoch(train_loader, train=True))

    # Validation
    tracking_loss_val.append(_run_epoch(val_loader, train=False))

    # Advance the counter *before* saving, so a restored checkpoint resumes at the next epoch
    # instead of re-running (and overwriting) the epoch it was saved after.
    model.epoch += 1

    # {epoch:02d} replaces a nested same-quote f-string that is a SyntaxError before Python 3.12.
    lib.save_model(model, optimizer, f"./models_siglip2/checkpoint_{epoch:02d}.pth")

Starting epoch 0
Training: 


  return self.preprocess(images, **kwargs)
  return self.preprocess(images, **kwargs)
  return self.preprocess(images, **kwargs)
  return self.preprocess(images, **kwargs)
 10%|▉         | 5/52 [00:34<04:50,  6.18s/it]

## Training progress

In [None]:
# Per-epoch mean train / validation losses appended by the training loop.
(tracking_loss, tracking_loss_val)

In [None]:
fig, ax = plt.subplots(figsize=(15, 5))
epochs_train = list(range(len(tracking_loss)))
epochs_val = list(range(len(tracking_loss_val)))
ax.plot(epochs_train, tracking_loss, label="train loss")
ax.plot(epochs_val, tracking_loss_val, label="validation loss", alpha=0.8)
ax.set_title("Train vs. validation loss per epoch")
ax.set_xlabel("Epoch (index)")
ax.set_ylabel("Loss")
ax.legend(loc="best")
# The two histories can differ in length if a run stopped between the train and validation
# passes, so tick the longer one.
ax.set_xticks(epochs_train if len(epochs_train) >= len(epochs_val) else epochs_val)
ax.grid(True)
fig.tight_layout()
plt.show()


## Validation

In [None]:
def predict(model, data_loader: DataLoader, T=1, device=None):
    """Run ``model`` over ``data_loader`` and return per-class probabilities.

    Args:
        model: image-classification model whose output exposes ``.logits``.
        data_loader: yields dict batches containing "pixel_values" and "image_id".
        T: softmax temperature; logits are divided by ``T`` before the softmax.
        device: where to run inference. Defaults to CUDA when available, otherwise CPU
            (previously CUDA was always used). The model is moved there as a side effect.

    Returns:
        DataFrame indexed by image id with one probability column per ``data.species_labels``
        entry. Empty (columns only) if the loader yields no batches.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Eval mode disables dropout etc.; no parameters are updated here.
    model.eval()
    model.to(device)

    preds_collector = []
    with torch.no_grad():
        for batch in tqdm.tqdm(data_loader, total=len(data_loader)):
            # Call the module (not .forward) so registered hooks still run.
            logits = model(batch["pixel_values"].to(device)).logits
            probs = nn.functional.softmax(logits / T, dim=1)
            preds_collector.append(
                pd.DataFrame(
                    probs.cpu().numpy(),
                    index=batch["image_id"],
                    columns=data.species_labels,
                )
            )

    if not preds_collector:
        # pd.concat([]) raises; return an empty frame with the expected columns instead.
        return pd.DataFrame(columns=data.species_labels)

    return pd.concat(preds_collector)

In [None]:
# Class probabilities for every validation image, indexed by image id.
eval_preds_df = predict(model=model, data_loader=val_loader)

In [None]:
# idxmax over columns picks the column with the largest value per row, i.e. the label name
# (assumes y_train is one-hot — confirm in `data`).
train_label_share = data.y_train.idxmax(axis=1).value_counts(normalize=True)
print("True labels (training):")
train_label_share

In [None]:
# Share of each arg-max predicted class on the validation split.
predicted_label_share = eval_preds_df.idxmax(axis=1).value_counts(normalize=True)
print("Predicted labels (eval):")
predicted_label_share

In [None]:
# Share of each true class on the validation split, for comparison with the predicted shares.
eval_label_share = data.y_eval.idxmax(axis=1).value_counts(normalize=True)
print("True labels (eval):")
eval_label_share

In [None]:
# Arg-max class name per validation image (index = image id from the loader).
eval_predictions = eval_preds_df.idxmax(axis=1)
# Align ground truth to the prediction index by id, so comparisons and the confusion matrix
# do not depend on the loader yielding rows in exactly data.y_eval's order.
# NOTE(review): assumes data.y_eval is indexed by the same ids the dataset emits as "image_id".
eval_true = data.y_eval.idxmax(axis=1).reindex(eval_predictions.index)
assert eval_true.notna().all(), "some predicted image ids are missing from data.y_eval"

In [None]:
# Fraction of validation images whose arg-max prediction equals the true label.
is_correct = eval_predictions == eval_true
correct = is_correct.sum()
accuracy = correct / len(eval_predictions)
accuracy.item()

### Confusion matrix

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
# Reuse eval_true / eval_predictions so this view and the accuracy figure come from the same
# label Series; fix the axis order to data.species_labels.
cm = ConfusionMatrixDisplay.from_predictions(
    eval_true,
    eval_predictions,
    labels=data.species_labels,
    ax=ax,
    xticks_rotation=30,
    colorbar=True,
)
ax.set_title("Confusion matrix (eval split)")
plt.show()

## Create submission

In [None]:
# Pass the same processor as the train/val datasets so test images get identical
# resize/normalisation; without it they could be preprocessed differently from training.
test_dataset = lib.ImageDatasetSigLip2(data.test_features, processor=model_preprocessor, learning=False)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

In [None]:
# Class probabilities for the unlabeled test images, one column per species label.
submission_df = predict(model=model, data_loader=test_dataloader)

In [None]:
submission_format = pd.read_csv("data/submission_format.csv", index_col="id")

# Index.equals / list comparison give a clean False on a length mismatch (elementwise `==`
# raises instead), and the messages say which check failed.
assert submission_df.index.equals(submission_format.index), (
    "submission rows do not match submission_format ids/order"
)
assert list(submission_df.columns) == list(submission_format.columns), (
    "submission columns do not match submission_format"
)

In [None]:
# The prediction index is built from a plain list of ids and has no name, so label it explicitly;
# otherwise the CSV header's first cell would be blank instead of matching submission_format's "id".
submission_df.to_csv("submission.csv", index_label="id")