In [None]:
import sys

sys.path.append('..')

from torch.utils.data import DataLoader
import pandas as pd
from pathlib import Path
import numpy as np

from src.dataset import HumanPosesDataset
from sklearn.model_selection import train_test_split

In [None]:
import os

import plotly.io as pio

# Plotly initialises its default renderer from the PLOTLY_RENDERER environment
# variable; assigning a literal here would silently override that choice.
# Honour the variable and fall back to "browser" (figures open in a web
# browser rather than inline in the notebook).
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "browser")

# Датасет

In [None]:
from torchvision import transforms

IMG_SIZE = 224      # spatial size fed to the networks
RESIZE_SIZE = 256   # shorter side before the centre crop at evaluation time

# ImageNet per-channel statistics used by Normalize below.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# NOTE(review): if ../best_models/teacher.pth was trained with a mean/mean
# normalisation, the teacher now receives differently scaled inputs -- confirm
# how the teacher checkpoint was preprocessed.

# Training pipeline: PIL image in, augmented and normalised tensor out.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.5, 1.0), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
    ], p=0.3),
    transforms.RandomApply([
        transforms.RandomAffine(degrees=15, translate=(0.05, 0.05), scale=(0.9, 1.1))
    ], p=0.5),
    transforms.ToTensor(),
    # Divide by the channel std (not by the mean).
    transforms.Normalize(mean=mean, std=std),
    # Erase after normalisation: value='random' fills with N(0, 1) noise, which
    # matches the statistics of a normalised tensor rather than [0, 1] pixels.
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random'),
])

# Deterministic evaluation pipeline: resize, centre crop, normalise.
val_transform = transforms.Compose([
    transforms.Resize(RESIZE_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])


In [None]:
CSV_PATH = Path("../data/human_poses_data/train_answers.csv")
TRAIN_DIR = Path("../data/human_poses_data/img_train")

df = pd.read_csv(CSV_PATH)

# Stratified 80/20 split over image ids so both parts keep the class balance.
train_ids, val_ids = train_test_split(
    df['img_id'].values,
    stratify=df['target_feature'],
    test_size=0.2,
    random_state=42,
)

# Rows are selected by id membership, so each part keeps the CSV's row order.
train_df = df.loc[df['img_id'].isin(train_ids)].reset_index(drop=True)
val_df = df.loc[df['img_id'].isin(val_ids)].reset_index(drop=True)

# Both splits read images from the same directory; only the transform differs.
train_dataset = HumanPosesDataset(data_df=train_df, img_dir=TRAIN_DIR, transform=train_transform)
val_dataset = HumanPosesDataset(data_df=val_df, img_dir=TRAIN_DIR, transform=val_transform)

# Loader settings shared by both splits.
_loader_common = dict(batch_size=32, pin_memory=True)

train_loader = DataLoader(
    train_dataset,
    shuffle=True,
    num_workers=4,
    persistent_workers=True,
    **_loader_common,
)
val_loader = DataLoader(val_dataset, shuffle=False, num_workers=2, **_loader_common)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

In [None]:
# Count of distinct labels present in the training CSV.
num_classes = np.unique(df['target_feature']).size
print(f"Количество классов: {num_classes}")

# Переписанные функции обучения и валидации (с поддержкой дистилляции знаний)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """Knowledge-distillation loss against a teacher evaluated inside the loss.

    The result is ``alpha * T**2 * KL(teacher || student) + (1 - alpha) * CE``,
    where both logit sets are divided by the temperature ``T`` before the
    softmax and CE uses label smoothing.

    Args:
        teacher_model: module producing teacher logits; put into eval mode here.
        temperature: softening temperature ``T``.
        alpha: weight of the distillation term (``1 - alpha`` weights CE).
        ce_smoothing: ``label_smoothing`` passed to ``F.cross_entropy``.
    """

    def __init__(self, teacher_model, temperature=4.0, alpha=0.7, ce_smoothing=0.1):
        super().__init__()
        self.teacher = teacher_model.eval()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_smoothing = ce_smoothing

    def forward(self, student_logits, _, labels):
        """Compute the combined loss for one batch.

        Args:
            student_logits: student outputs with classes along dim 1.
            _: unused positional placeholder kept for call-site compatibility.
            labels: 2-tuple ``(targets, x_teacher)``; ``x_teacher`` is the
                teacher's input batch.

        Raises:
            ValueError: if ``labels`` is not a 2-tuple.
        """
        if not (isinstance(labels, tuple) and len(labels) == 2):
            raise ValueError("Expected labels to be a tuple: (targets, x_teacher)")
        targets, x_teacher = labels

        device = student_logits.device
        x_teacher = x_teacher.to(device)
        targets = targets.to(device)

        # The teacher only provides soft targets; no graph is built for it.
        with torch.no_grad():
            teacher_logits = self.teacher(x_teacher)

        temp = self.temperature
        teacher_probs = F.softmax(teacher_logits / temp, dim=1)
        student_log_probs = F.log_softmax(student_logits / temp, dim=1)
        # T**2 scaling is the usual KD correction that keeps the soft-target
        # gradient magnitude comparable across temperatures.
        kd_term = temp ** 2 * F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')

        ce_term = F.cross_entropy(student_logits, targets, label_smoothing=self.ce_smoothing)
        return self.alpha * kd_term + (1 - self.alpha) * ce_term


def distill_batch_augment(images, labels, teacher_transform,
                          mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Build the teacher's input batch from the student's (normalised) batch.

    Args:
        images: student batch of normalised image tensors, presumably
            (N, 3, H, W) -- channels are assumed to be on dim 1.
        labels: targets for the batch; passed through untouched.
        teacher_transform: PIL-input transform producing the teacher's tensor
            (e.g. ``val_transform``), or ``None`` to give the teacher exactly
            the tensors the student sees.
        mean, std: per-channel statistics of the ``Normalize`` step that
            produced ``images``; used to undo it before converting to PIL.
            Must match the student's transform.

    Returns:
        ``(images, (labels, x_teacher))`` -- the label format expected by
        ``DistillationLoss``.
    """
    if teacher_transform is None:
        return images, (labels, images)

    # ToPILImage converts float tensors via mul(255).byte(): normalised values
    # (negative or > 1) would wrap around and corrupt the teacher's input.
    # Undo Normalize and clamp to [0, 1] before the conversion.
    mean_t = torch.as_tensor(mean, dtype=torch.float32).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=torch.float32).view(-1, 1, 1)
    pixels = (images.detach().float().cpu() * std_t + mean_t).clamp_(0.0, 1.0)

    to_pil = transforms.ToPILImage()
    x_teacher = torch.stack(
        [teacher_transform(to_pil(img)) for img in pixels]
    ).to(images.device)
    return images, (labels, x_teacher)

In [None]:
from tqdm import tqdm
from sklearn.metrics import f1_score
import torch
from torch.amp import autocast

def _loss_and_f1_targets(criterion, logits, labels, device):
    """Return ``(loss, targets_for_f1)`` for one batch under either label convention."""
    if isinstance(labels, tuple) and len(labels) == 2:
        # DistillationLoss convention: labels == (targets, x_teacher).
        targets, x_teacher = labels
        loss = criterion(logits, None, (targets, x_teacher.to(device)))
        return loss, targets
    # Plain criteria (e.g. nn.CrossEntropyLoss) need targets on the logits' device.
    if isinstance(labels, torch.Tensor):
        labels = labels.to(device)
    return criterion(logits, labels), labels


def training_epoch(model, optimizer, criterion, train_loader, device, tqdm_desc, batch_augment_fn=None, scheduler=None, scaler=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: network to train (switched to train mode).
        optimizer: optimizer over ``model``'s parameters.
        criterion: called as ``criterion(logits, labels)``, or as
            ``criterion(logits, None, (targets, x_teacher))`` when ``labels`` is
            a 2-tuple (the ``DistillationLoss`` convention).
        train_loader: yields ``(images, labels)``; ``len(train_loader.dataset)``
            normalises the summed loss.
        device: device for images and plain tensor labels.
        tqdm_desc: progress-bar label.
        batch_augment_fn: optional ``(images, labels) -> (images, labels)`` hook,
            applied after images are moved to ``device``.
        scheduler: if given, stepped after every batch.
        scaler: if given, enables CUDA autocast and ``GradScaler``-based backward.

    Returns:
        ``(train_loss, train_f1)``: sample-weighted mean loss and macro F1.
    """
    model.train()
    train_loss = 0.0
    all_preds, all_labels = [], []

    for images, labels in tqdm(train_loader, desc=tqdm_desc):
        images = images.to(device)

        if batch_augment_fn is not None:
            images, labels = batch_augment_fn(images, labels)

        optimizer.zero_grad()

        if scaler is not None:
            # Forward and loss under autocast; the scaled backward runs outside it.
            with autocast(device_type='cuda'):
                logits = model(images)
                loss, labels_for_f1 = _loss_and_f1_targets(criterion, logits, labels, device)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            logits = model(images)
            loss, labels_for_f1 = _loss_and_f1_targets(criterion, logits, labels, device)
            loss.backward()
            optimizer.step()

        if scheduler is not None:
            scheduler.step()

        train_loss += loss.item() * images.size(0)
        all_preds.append(logits.detach().argmax(dim=1).cpu())
        # Soft / mixed targets of shape (N, C) are reduced to class ids for F1.
        if labels_for_f1.ndim > 1:
            labels_for_f1 = labels_for_f1.argmax(dim=1)
        all_labels.append(labels_for_f1.cpu())

    train_loss /= len(train_loader.dataset)
    train_f1 = f1_score(torch.cat(all_labels), torch.cat(all_preds), average='macro')
    return train_loss, train_f1

@torch.no_grad()
def validation_epoch(model, criterion, val_loader, device, tqdm_desc, teacher_model=None, teacher_transform=None):
    model.eval()
    val_loss = 0.0
    all_preds, all_labels = [], []

    to_pil = transforms.ToPILImage()

    for images, labels in tqdm(val_loader, desc=tqdm_desc):
        images = images.to(device)
        labels = labels.to(device)

        if isinstance(criterion, DistillationLoss):
            x_teacher = torch.stack([
                teacher_transform(to_pil(img.cpu())).to(device)
                for img in images
            ])
            logits = model(images)
            loss = criterion(logits, None, (labels, x_teacher))
        else:
            logits = model(images)
            loss = criterion(logits, labels)

        val_loss += loss.item() * images.size(0)
        all_preds.append(logits.argmax(dim=1).cpu())
        all_labels.append(labels.cpu())

    val_loss /= len(val_loader.dataset)
    val_f1 = f1_score(torch.cat(all_labels), torch.cat(all_preds), average='macro')
    return val_loss, val_f1

In [None]:
import os
import torch
import wandb
from src.utils import set_seed, plot_losses


class Trainer:
    """Epoch loop around ``training_epoch`` / ``validation_epoch``.

    It adds checkpointing, optional Weights & Biases logging and a live
    loss/F1 plot.

    Args:
        model: network to train; moved to ``device`` on construction.
        train_loader, val_loader: loaders for the two splits.
        num_epochs: last epoch index (inclusive) run by ``train``.
        optimizer, criterion: passed through to the epoch functions.
        scheduler: stepped every batch if it is a ``OneCycleLR``, otherwise
            once per epoch.
        device: device for the model and batches.
        experiment_name: checkpoint filename prefix and W&B project name.
        save_dir: checkpoint directory (created if missing).
        use_wandb: whether to initialise and log to W&B.
        seed: passed to ``set_seed`` and recorded in the W&B config.
        batch_augment_fn: optional ``(images, labels) -> (images, labels)``
            hook applied to every training batch.
        scaler: optional AMP ``GradScaler``; enables mixed-precision training.
        teacher_transform: PIL transform used to rebuild the teacher's
            validation input for distillation. When ``None``, the
            notebook-level ``val_transform`` is used if it is defined.
    """

    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        num_epochs: int,
        optimizer,
        criterion,
        scheduler=None,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
        experiment_name: str = 'experiment',
        save_dir: str = 'checkpoints',
        use_wandb: bool = False,
        seed: int = 42,
        batch_augment_fn=None,
        scaler=None,
        teacher_transform=None,
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_epochs = num_epochs
        self.device = torch.device(device)
        self.model.to(self.device)
        self.batch_augment_fn = batch_augment_fn
        self.teacher_transform = teacher_transform

        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.experiment_name = experiment_name

        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.scaler = scaler

        from torch.optim.lr_scheduler import OneCycleLR
        # OneCycleLR is meant to be stepped every batch; others once per epoch.
        self.step_scheduler_per_batch = isinstance(scheduler, OneCycleLR)

        self.best_f1 = 0.0
        self.best_epoch = 0
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_f1': [], 'val_f1': []
        }

        self.use_wandb = use_wandb
        if use_wandb:
            wandb.init(
                project=experiment_name,
                config={
                    'num_epochs': num_epochs,
                    'optimizer': str(optimizer),
                    'device': device,
                    'criterion': str(criterion),
                    'scheduler': str(scheduler),
                    'seed': seed
                }
            )

        set_seed(seed)

    def save_checkpoint(self, epoch: int, is_best: bool = False):
        """Write training state and metric history into ``save_dir``.

        The saved state covers the model, optimizer, scheduler and scaler.
        With ``is_best`` the file is ``<experiment_name>_best.pth``, overwritten
        on each improvement. Otherwise it is ``<experiment_name>_epoch<epoch>.pth``.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler is not None else None,
            # Needed to resume AMP training with the same loss scale.
            'scaler_state_dict': self.scaler.state_dict() if self.scaler is not None else None,
            'best_f1': self.best_f1,
            'history': self.history,
        }

        if is_best:
            filename = f'{self.experiment_name}_best.pth'
        else:
            filename = f'{self.experiment_name}_epoch{epoch}.pth'

        torch.save(checkpoint, os.path.join(self.save_dir, filename))

    def train(self, start_epoch: int = 1):
        """Train epochs ``start_epoch..num_epochs`` (inclusive); return ``self.history``.

        Each epoch runs:
          - a training pass and a validation pass;
          - a scheduler step (per epoch unless OneCycleLR);
          - W&B logging if enabled;
          - best/per-epoch checkpoints;
          - a refreshed loss/F1 plot.
        """
        # Validate with this trainer's own model/criterion/loader/device rather
        # than notebook globals, so a Trainer always evaluates what it trains.
        teacher_transform = self.teacher_transform
        if teacher_transform is None:
            # Default to the notebook-level `val_transform`, which the existing
            # cells rely on, when no transform was passed explicitly.
            teacher_transform = globals().get('val_transform')
        teacher_model = getattr(self.criterion, 'teacher', None)

        for epoch in range(start_epoch, self.num_epochs + 1):
            print(f"\nEpoch {epoch}/{self.num_epochs}")

            train_loss, train_f1 = training_epoch(
                self.model, self.optimizer, self.criterion,
                self.train_loader, self.device, f"Train {epoch}",
                batch_augment_fn=self.batch_augment_fn,
                scheduler=self.scheduler if self.step_scheduler_per_batch else None,
                scaler=self.scaler,
            )

            val_loss, val_f1 = validation_epoch(
                model=self.model,
                criterion=self.criterion,
                val_loader=self.val_loader,
                device=self.device,
                tqdm_desc=f"Val {epoch}",
                teacher_model=teacher_model,
                teacher_transform=teacher_transform,
            )

            if self.scheduler is not None and not self.step_scheduler_per_batch:
                self.scheduler.step()

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_f1'].append(train_f1)
            self.history['val_f1'].append(val_f1)

            metrics = {
                'train/loss': train_loss,
                'train/f1': train_f1,
                'val/loss': val_loss,
                'val/f1': val_f1,
                'epoch': epoch
            }

            if self.use_wandb:
                wandb.log(metrics)

            if val_f1 > self.best_f1:
                self.best_f1 = val_f1
                self.best_epoch = epoch
                self.save_checkpoint(epoch, is_best=True)

            self.save_checkpoint(epoch)

            plot_losses(
                self.history['train_loss'],
                self.history['val_loss'],
                self.history['train_f1'],
                self.history['val_f1'],
                clear=True
            )

        print(f"Training completed. Best Val F1: {self.best_f1:.4f} at epoch {self.best_epoch}")
        return self.history

# Обучение

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"✅ Using device: {device}")

In [None]:
from src.models.miniconvnext import MiniConvNeXt
from src.models.teacher import ConvNeXtTeacher
from src.utils import load_best_model

TEACHER_CKPT = '../best_models/teacher.pth'

# Size both heads from the label count computed from the data (`num_classes`)
# instead of a hard-coded 16. If it disagrees with the teacher checkpoint, the
# load below presumably fails (assuming load_best_model loads strictly --
# TODO confirm) rather than silently training a mis-sized head.
student_model = MiniConvNeXt(num_classes=num_classes)
student_model = student_model.to(device)

teacher_model = ConvNeXtTeacher(num_classes=num_classes)
load_best_model(teacher_model, TEACHER_CKPT, device)
teacher_model.eval().to(device)
# The teacher is only used for inference inside DistillationLoss; freeze it so
# its parameters can never receive gradients.
teacher_model.requires_grad_(False)

In [None]:
from torch.amp import GradScaler

NUM_EPOCH = 50

# Gradient scaler for the mixed-precision branch of training_epoch.
scaler = GradScaler()

# 0.7 * T^2-scaled KL to the teacher + 0.3 * label-smoothed CE, with T = 4.
criterion = DistillationLoss(teacher_model=teacher_model, temperature=4.0, alpha=0.7)

optimizer = torch.optim.AdamW(student_model.parameters(), weight_decay=1e-4, lr=3e-4)
# Cosine decay over the whole run; Trainer steps it once per epoch since it is
# not a OneCycleLR.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCH)


In [None]:
from src.utils import MixupCutMixAugmenter

# MixUp/CutMix + teacher-view hook. It is defined here but NOT used by the run
# below, which distils on the plain augmented batch.
mixup_cutmix_fn = MixupCutMixAugmenter(alpha=1.0, p_mixup=0.3)


def mixup_then_distill(images, labels):
    """Apply MixUp/CutMix, then build the teacher batch with `val_transform`."""
    mixed_images, mixed_labels = mixup_cutmix_fn(images, labels)
    return distill_batch_augment(mixed_images, mixed_labels, val_transform)


def distill_with_val_view(images, labels):
    """Trainer batch hook: the teacher's inputs are rebuilt with `val_transform`."""
    return distill_batch_augment(images, labels, val_transform)


trainer = Trainer(
    model=student_model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCH,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    scaler=scaler,
    batch_augment_fn=distill_with_val_view,
    experiment_name="distillation_1",
    seed=42,
    use_wandb=True,
)

history = trainer.train()

In [None]:
from src.utils import load_best_model

# Best-validation-F1 student weights written by the "distillation_1" Trainer
# (save_dir='checkpoints', `<experiment_name>_best.pth`).
STAGE1_BEST_CKPT = 'checkpoints/distillation_1_best.pth'

load_best_model(student_model, STAGE1_BEST_CKPT, device)

In [None]:
from torch.amp import GradScaler

# Second stage: 25 more epochs on the student weights restored above, with a
# fresh optimizer, cosine schedule, loss and scaler.
NUM_EPOCH = 25

optimizer = torch.optim.AdamW(
    params=student_model.parameters(),
    lr=3e-4,
    weight_decay=1e-4,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=NUM_EPOCH)

criterion = DistillationLoss(teacher_model, temperature=4.0, alpha=0.7)

scaler = GradScaler()


In [None]:
from src.utils import MixupCutMixAugmenter

# Stage-2 run: same distillation hook as stage 1 (teacher view rebuilt with
# `val_transform`), logged under its own experiment name.
stage2_options = dict(
    experiment_name="distillation_1_next",
    use_wandb=True,
    seed=42,
    scaler=scaler,
    batch_augment_fn=lambda batch, targets: distill_batch_augment(batch, targets, val_transform),
)

trainer = Trainer(
    model=student_model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCH,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    **stage2_options,
)

history = trainer.train()