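# SigLIP 2 linear probe for species classification

Trains only the classification head of `google/siglip2-large-patch16-384` on the species images, evaluates it on the held-out split, and writes `submission.csv`.

## Imports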

In [None]:
from sympy.polys.subresultants_qq_zz import res_q
%load_ext autoreload
%autoreload 2

In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import torch
import tqdm
from sklearn.metrics import ConfusionMatrixDisplay
from torch import nn
from torch.utils.data import DataLoader
from transformers import AutoImageProcessor, AutoModelForImageClassification

import data
import lib

## Model instantiation

In [None]:
# SigLIP 2, fixed-resolution ("FixRes") variant.
model_id = "google/siglip2-large-patch16-384"
# The processor supplies the checkpoint's preprocessing (resize/normalize, mean/std/size).
processor = AutoImageProcessor.from_pretrained(model_id)
# Pretrained encoder weights + a NEW classification head with one output per
# entry of data.species_labels. id2label/label2id record the class names in
# the model config so head indices map back to species names.
model = AutoModelForImageClassification.from_pretrained(
    model_id,
    num_labels=len(data.species_labels),
    id2label=dict(enumerate(data.species_labels)),
    label2id={label: idx for idx, label in enumerate(data.species_labels)},
    ignore_mismatched_sizes=True,  # let a head of the required size replace a mismatched one
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Training history lives on the model object; `epoch` is the index of the next
# epoch to run, so re-running the training cell continues the numbering.
model.tracking_loss = []
model.tracking_loss_val = []
model.epoch = 0

# Module-level aliases to the SAME lists: appending to them updates the model attributes.
tracking_loss = model.tracking_loss
tracking_loss_val = model.tracking_loss_val

## Training

In [None]:
# Linear probing: freeze everything except the classification head.
# Hugging Face classifiers usually name the head "classifier".
for name, param in model.named_parameters():
    param.requires_grad = "classifier" in name

# Train/eval datasets share the same image processor as the model checkpoint.
train_ds = lib.ForestDataset(data.x_train, data.y_train, processor=processor)
val_ds = lib.ForestDataset(data.x_eval, data.y_eval, processor=processor)
train_loader = DataLoader(train_ds, batch_size=256, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=256, shuffle=False, num_workers=4)

# NOTE: for imbalanced classes, consider nn.CrossEntropyLoss(weight=...) with
# inverse-frequency weights over all len(data.species_labels) classes.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Optimize ONLY the trainable (head) parameters.
head_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(head_params, lr=1e-3, weight_decay=0.01)

CHECKPOINT_DIR = "./models_siglip2"
num_epochs = 1


def _run_epoch(loader, train: bool) -> float:
    """Run one pass over `loader` and return the sample-weighted mean loss.

    With `train=True` the trainable parameters are updated via `optimizer`;
    otherwise the model is in eval mode and no gradients are tracked.
    Tensors are moved to the global `device` (not a hard-coded "cuda").
    """
    model.train(train)
    total_loss = 0.0
    total_samples = 0
    with torch.set_grad_enabled(train):
        for batch in tqdm.tqdm(loader, total=len(loader), desc="train" if train else "val"):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            loss = criterion(model(pixel_values).logits, labels)

            if train:
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

            batch_size = pixel_values.size(0)
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    # An empty loader yields NaN instead of a ZeroDivisionError.
    return total_loss / total_samples if total_samples else float("nan")


os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(model.epoch, model.epoch + num_epochs):
    print(f"Starting epoch {epoch}")
    print("Training: ")
    tracking_loss.append(_run_epoch(train_loader, train=True))

    # validation
    tracking_loss_val.append(_run_epoch(val_loader, train=False))

    # `{epoch:02d}` zero-pads like str(epoch).rjust(2, "0") but without nesting
    # the f-string's own quote character (a SyntaxError before Python 3.12).
    lib.save_model(model, f"{CHECKPOINT_DIR}/model_{epoch:02d}.pth")

    model.epoch += 1

## Training progress

In [None]:
# Per-epoch mean losses, shown as (train, validation).
(tracking_loss, tracking_loss_val)

In [None]:
# Per-epoch mean train/validation loss. Markers are required: with a single
# epoch each series is one point, which a marker-less line plot does not draw.
fig, ax = plt.subplots(figsize=(15, 5))
epochs_train = list(range(len(tracking_loss)))
epochs_val = list(range(len(tracking_loss_val)))
ax.plot(epochs_train, tracking_loss, marker="o", label="train loss")
ax.plot(epochs_val, tracking_loss_val, marker="o", label="validation loss", alpha=0.8)
ax.set_title("Training progress")
ax.set_xlabel("Epoch (index)")
ax.set_ylabel("Loss")
ax.legend(loc="best")
ax.set_xticks(epochs_train)
ax.grid(True)
fig.tight_layout()


## Validation

In [None]:
def predict(model, data_loader: DataLoader, T=1, device=None) -> pd.DataFrame:
    """Return per-class softmax probabilities for every sample in `data_loader`.

    Args:
        model: classifier whose call output exposes `.logits`.
        data_loader: iterable of batches with a "pixel_values" tensor and an
            "image_id" sequence (one id per sample).
        T: softmax temperature; logits are divided by T before the softmax
            (T=1 leaves them unchanged).
        device: device to run on. Defaults to CUDA when available, else CPU.

    Returns:
        DataFrame indexed by image id with one probability column per entry of
        `data.species_labels`. If the loader yields no batches, an empty frame
        with those columns is returned.

    Side effects: the model is moved to `device` and put in eval mode.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Eval mode + no_grad: inference only, no parameter updates.
    model.eval()
    model.to(device)

    preds_collector = []
    with torch.no_grad():
        for batch in tqdm.tqdm(data_loader, total=len(data_loader)):
            # Call the module (not .forward) so registered hooks still run.
            logits = model(batch["pixel_values"].to(device)).logits
            probs = nn.functional.softmax(logits / T, dim=1)
            # .float() lets half/bfloat16 outputs be converted to numpy.
            preds_collector.append(
                pd.DataFrame(
                    probs.float().cpu().numpy(),
                    index=batch["image_id"],
                    columns=data.species_labels,
                )
            )

    if not preds_collector:
        # pd.concat([]) would raise; return a well-formed empty frame instead.
        return pd.DataFrame(columns=data.species_labels, dtype=float)
    return pd.concat(preds_collector)

# Class probabilities for every image of the eval split.
eval_preds_df = predict(model=model, data_loader=val_loader)

In [None]:
# Share of each class among the ground-truth training labels.
print("True labels (training):")
train_label_shares = data.y_train.idxmax(axis=1).value_counts(normalize=True)
train_label_shares

In [None]:
# Share of each class among the model's top-1 predictions on the eval split.
print("Predicted labels (eval):")
eval_pred_shares = eval_preds_df.idxmax(axis=1).value_counts(normalize=True)
eval_pred_shares

In [None]:
# Share of each class among the ground-truth eval labels.
print("True labels (eval):")
eval_label_shares = data.y_eval.idxmax(axis=1).value_counts(normalize=True)
eval_label_shares

In [None]:
# Collapse probability rows / label rows to a single class name per image.
eval_true = data.y_eval.idxmax(axis="columns")
eval_predictions = eval_preds_df.idxmax(axis="columns")

In [None]:
# Fraction of eval images whose predicted class equals the true class.
is_match = eval_predictions == eval_true
correct = is_match.sum()
accuracy = is_match.mean()
float(accuracy)

### Confusion matrix

In [None]:
# Rows are true classes, columns predicted classes. Passing `labels` fixes the
# order to data.species_labels and keeps classes that never occur in the eval
# split, so the matrix layout does not depend on which classes happen to appear.
fig, ax = plt.subplots(figsize=(10, 10))
cm = ConfusionMatrixDisplay.from_predictions(
    data.y_eval.idxmax(axis=1),
    eval_preds_df.idxmax(axis=1),
    labels=list(data.species_labels),
    ax=ax,
    xticks_rotation=30,
    colorbar=True,
)
ax.set_title("Confusion matrix (eval)");

## Create submission

In [None]:
# Test images must get the same preprocessing as train/val (which pass
# `processor=processor`); otherwise the model sees differently resized/normalized
# inputs. NOTE(review): assumes ForestDataset applies `processor` with learning=False too.
test_dataset = lib.ForestDataset(data.test_features, learning=False, processor=processor)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)

In [None]:
# Class probabilities for every test image (rows = submission ids).
submission_df = predict(model=model, data_loader=test_dataloader)

In [None]:
# The submission must contain exactly the format file's ids (same order) and
# class columns (same order). Index.equals returns False on a length mismatch
# instead of raising "Lengths must match", so failures surface as clear assertions.
submission_format = pd.read_csv("data/submission_format.csv", index_col="id")

assert submission_df.index.equals(submission_format.index), (
    "submission ids differ from submission_format (values or order)"
)
assert submission_df.columns.equals(submission_format.columns), (
    "submission columns differ from submission_format (values or order)"
)

In [None]:
# The index built by predict() from batch id lists is presumably unnamed, which
# would leave the first CSV header blank; label it "id" to match submission_format.csv.
submission_df.to_csv("submission.csv", index_label="id")