In [1]:
import sys

sys.path.append('..')

from torch.utils.data import DataLoader
import pandas as pd
from pathlib import Path
import numpy as np

from src.dataset import HumanPosesDataset
from sklearn.model_selection import train_test_split

In [2]:
import plotly.io as pio
pio.renderers.default = "browser"

# Датасет

In [3]:
from torchvision import transforms

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
    ], p=0.3),
    transforms.RandomApply([
        transforms.RandomAffine(degrees=15, translate=(0.05, 0.05), scale=(0.9, 1.1))
    ], p=0.5),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random'),
    transforms.Normalize(mean=mean, std=std)
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])


In [4]:
CSV_PATH = Path("../data/human_poses_data/train_answers.csv")
TRAIN_DIR = Path("../data/human_poses_data/img_train")

df = pd.read_csv(CSV_PATH)

train_ids, val_ids = train_test_split(
    df['img_id'].values,
    test_size=0.2,
    stratify=df['target_feature'],
    random_state=42
)

train_df = df[df['img_id'].isin(train_ids)].reset_index(drop=True)
val_df = df[df['img_id'].isin(val_ids)].reset_index(drop=True)

train_dataset = HumanPosesDataset(
    data_df=train_df,
    img_dir=TRAIN_DIR,
    transform=train_transform,
)

val_dataset = HumanPosesDataset(
    data_df=val_df,
    img_dir=TRAIN_DIR,
    transform=val_transform,
)



train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

Train dataset size: 9893
Validation dataset size: 2474


In [5]:
num_classes = len(np.unique(df['target_feature']))
print(f"Количество классов: {num_classes}")

Количество классов: 16


# Переписанные функции обучения

In [6]:
import random
import torch.nn as nn
import torch.nn.functional as F

def cross_entropy_with_smoothing(logits, targets, smoothing=0.0):
    confidence = 1.0 - smoothing
    log_probs = F.log_softmax(logits, dim=-1)
    n_classes = logits.size(-1)

    true_dist = torch.zeros_like(log_probs)
    true_dist.fill_(smoothing / (n_classes - 1))
    true_dist.scatter_(1, targets.unsqueeze(1), confidence)

    return F.cross_entropy(logits, targets, label_smoothing=smoothing, reduction='none')



class DistillationLoss(nn.Module):
    def __init__(self, teacher_model, temperature=4.0, alpha=0.7, ce_smoothing=0.1):
        super().__init__()
        self.teacher = teacher_model.eval()
        self.temperature = temperature
        self.alpha = alpha
        self.ce_smoothing = ce_smoothing

    def forward(self, student_logits, _, labels):
        if not (isinstance(labels, tuple) and len(labels) == 2):
            raise ValueError("Expected labels to be ((targets or (y_a, y_b, lam)), x_teacher)")

        inner_labels, x_teacher = labels
        x_teacher = x_teacher.to(student_logits.device)

        with torch.no_grad():
            teacher_logits = self.teacher(x_teacher)

        T = self.temperature
        soft_teacher = F.softmax(teacher_logits / T, dim=1)
        soft_student = F.log_softmax(student_logits / T, dim=1)
        distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T ** 2)


        if isinstance(inner_labels, tuple) and len(inner_labels) == 3:
            y_a, y_b, lam = inner_labels
            y_a = y_a.to(student_logits.device)
            y_b = y_b.to(student_logits.device)

            ce_loss_a = cross_entropy_with_smoothing(student_logits, y_a, self.ce_smoothing)
            ce_loss_b = cross_entropy_with_smoothing(student_logits, y_b, self.ce_smoothing)
            ce_loss = lam * ce_loss_a + (1 - lam) * ce_loss_b
            ce_loss = ce_loss.mean()

            labels_for_f1 = y_a
        else:
            targets = inner_labels.to(student_logits.device)
            ce_loss = cross_entropy_with_smoothing(student_logits, targets, self.ce_smoothing)
            ce_loss = ce_loss.mean()
            labels_for_f1 = targets

        total_loss = self.alpha * distill_loss + (1 - self.alpha) * ce_loss
        return total_loss, labels_for_f1

def distill_batch_augment(images, labels, teacher_transform):
    to_pil = transforms.ToPILImage()
    x_teacher = torch.stack([
        teacher_transform(to_pil(img.cpu())).to(images.device)
        for img in images
    ])

    return images, (labels, x_teacher)


class MixupCutMixAugmenter:
    def __init__(self, alpha=1.0, p_mixup=0.5, p=1.0):
        self.alpha = alpha
        self.p_mixup = p_mixup
        self.p = p

    def __call__(self, x, y):
        if random.random() > self.p:
            return x, y

        y = y.to(x.device)
        if random.random() < self.p_mixup:
            return self.mixup(x, y)
        else:
            return self.cutmix(x, y)


    def mixup(self, x, y):
        lam = np.random.beta(self.alpha, self.alpha)
        batch_size = x.size(0)
        index = torch.randperm(batch_size).to(x.device)

        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]
        return mixed_x, (y_a, y_b, lam)

    def cutmix(self, x, y):
        lam = np.random.beta(self.alpha, self.alpha)
        batch_size, _, height, width = x.size()
        index = torch.randperm(batch_size).to(x.device)

        cut_rat = np.sqrt(1. - lam)
        cut_w = int(width * cut_rat)
        cut_h = int(height * cut_rat)

        cx = random.randint(0, width)
        cy = random.randint(0, height)

        bbx1 = np.clip(cx - cut_w // 2, 0, width)
        bby1 = np.clip(cy - cut_h // 2, 0, height)
        bbx2 = np.clip(cx + cut_w // 2, 0, width)
        bby2 = np.clip(cy + cut_h // 2, 0, height)

        x[:, :, bby1:bby2, bbx1:bbx2] = x[index, :, bby1:bby2, bbx1:bbx2]
        y_a, y_b = y, y[index]
        return x, (y_a, y_b, lam)

def distill_mixupcutmix_augment(mixcut_fn, teacher_transform):
    def wrapper(images, labels):
        images, labels = mixcut_fn(images, labels)
        to_pil = transforms.ToPILImage()
        x_teacher = torch.stack([
            teacher_transform(to_pil(img.cpu())).to(images.device)
            for img in images
        ])
        return images, (labels, x_teacher)
    return wrapper


In [7]:
from tqdm import tqdm
from sklearn.metrics import f1_score
import torch
from torch.amp import autocast
from torchvision import transforms

def training_epoch(model, optimizer, criterion, train_loader, device, tqdm_desc, batch_augment_fn=None, scheduler=None, scaler=None):
    model.train()
    train_loss = 0.0
    all_preds, all_labels = [], []

    is_onecycle = isinstance(scheduler, torch.optim.lr_scheduler.OneCycleLR)

    step_scheduler_per_batch = isinstance(scheduler, OneCycleLR)

    for images, labels in tqdm(train_loader, desc=tqdm_desc):
        images = images.to(device)

        if batch_augment_fn is not None:
            images, labels = batch_augment_fn(images, labels)

        optimizer.zero_grad()

        if scaler is not None:
            with autocast(device_type='cuda'):
                logits = model(images)
                if isinstance(criterion, DistillationLoss):
                    loss, labels_for_f1 = criterion(logits, None, labels)
                else:
                    if isinstance(labels, tuple) and len(labels) == 3:
                        y_a, y_b, lam = labels
                        loss = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)
                        labels_for_f1 = y_a
                    else:
                        loss = criterion(logits, labels)
                        labels_for_f1 = labels

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            if is_onecycle and scheduler is not None:
                scheduler.step()


        else:
            logits = model(images)
            if isinstance(criterion, DistillationLoss):
                loss, labels_for_f1 = criterion(logits, None, labels)
            else:
                if isinstance(labels, tuple) and len(labels) == 3:
                    y_a, y_b, lam = labels
                    loss = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)
                    labels_for_f1 = y_a
                else:
                    loss = criterion(logits, labels)
                    labels_for_f1 = labels

            loss.backward()

            if scheduler is not None and step_scheduler_per_batch:
                scheduler.step()
            optimizer.step()

        train_loss += loss.item() * images.size(0)
        all_preds.append(logits.detach().argmax(dim=1).cpu())
        all_labels.append(labels_for_f1.cpu())

    train_loss /= len(train_loader.dataset)
    train_f1 = f1_score(torch.cat(all_labels), torch.cat(all_preds), average='macro')
    return train_loss, train_f1


@torch.no_grad()
def validation_epoch(model, criterion, val_loader, device, tqdm_desc, teacher_model=None, teacher_transform=None):
    model.eval()
    val_loss = 0.0
    all_preds, all_labels = [], []

    to_pil = transforms.ToPILImage()

    for images, labels in tqdm(val_loader, desc=tqdm_desc):
        images = images.to(device)
        labels = labels.to(device)

        if isinstance(criterion, DistillationLoss):
            x_teacher = torch.stack([
                val_transform(to_pil(img.cpu())).to(device)
                for img in images
            ])
            logits = model(images)
            loss, _ = criterion(logits, None, (labels, x_teacher))
        else:
            logits = model(images)
            loss = criterion(logits, labels)

        val_loss += loss.item() * images.size(0)
        all_preds.append(logits.argmax(dim=1).cpu())
        all_labels.append(labels.cpu())

    val_loss /= len(val_loader.dataset)
    val_f1 = f1_score(torch.cat(all_labels), torch.cat(all_preds), average='macro')
    return val_loss, val_f1


In [8]:
import os
import torch
import wandb
from src.utils import set_seed, plot_losses
from torch.optim.lr_scheduler import OneCycleLR


class Trainer:
    def __init__(
        self,
        model,
        train_loader,
        val_loader,
        num_epochs: int,
        optimizer,
        criterion,
        scheduler=None,
        device: str = 'cuda' if torch.cuda.is_available() else 'cpu',
        experiment_name: str = 'experiment',
        save_dir: str = 'checkpoints',
        use_wandb: bool = False,
        seed: int = 42,
        batch_augment_fn=None,
        scaler=None,
    ):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.num_epochs = num_epochs
        self.device = torch.device(device)
        self.model.to(self.device)
        self.batch_augment_fn = batch_augment_fn

        os.makedirs(save_dir, exist_ok=True)
        self.save_dir = save_dir
        self.experiment_name = experiment_name

        self.criterion = criterion
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.scaler = scaler

        self.step_scheduler_per_batch = isinstance(scheduler, OneCycleLR)

        self.best_f1 = 0.0
        self.best_epoch = 0
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_f1': [], 'val_f1': []
        }

        self.use_wandb = use_wandb
        if use_wandb:
            wandb.init(
                project=experiment_name,
                config={
                    'num_epochs': num_epochs,
                    'optimizer': str(optimizer),
                    'device': device,
                    'criterion': str(criterion),
                    'scheduler': str(scheduler),
                    'seed': seed
                }
            )

        set_seed(seed)

    def save_checkpoint(self, epoch: int, is_best: bool = False):
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
            'best_f1': self.best_f1,
            'history': self.history
        }

        if is_best:
            path = os.path.join(self.save_dir, f'{self.experiment_name}_best.pth')
        else:
            path = os.path.join(self.save_dir, f'{self.experiment_name}_epoch{epoch}.pth')

        torch.save(checkpoint, path)

    def train(self, start_epoch: int = 1):
        for epoch in range(start_epoch, self.num_epochs + 1):
            print(f"\nEpoch {epoch}/{self.num_epochs}")

            train_loss, train_f1 = training_epoch(
                self.model, self.optimizer, self.criterion,
                self.train_loader, self.device, f"Train {epoch}",
                batch_augment_fn=self.batch_augment_fn,
                scheduler=self.scheduler if self.step_scheduler_per_batch else None,
                scaler=self.scaler if self.scaler else None,
            )

            val_loss, val_f1 = validation_epoch(
                model=self.model,
                criterion=self.criterion,
                val_loader=self.val_loader,
                device=self.device,
                tqdm_desc=f"Val {epoch}",
                teacher_model=None,
                teacher_transform=None
            )

            if self.scheduler is not None and not self.step_scheduler_per_batch:
                self.scheduler.step()

            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            self.history['train_f1'].append(train_f1)
            self.history['val_f1'].append(val_f1)

            metrics = {
                'train/loss': train_loss,
                'train/f1': train_f1,
                'val/loss': val_loss,
                'val/f1': val_f1,
                'epoch': epoch
            }

            if self.scheduler:
                metrics['lr'] = self.optimizer.param_groups[0]['lr']


            if self.use_wandb:
                wandb.log(metrics)

            if val_f1 > self.best_f1:
                self.best_f1 = val_f1
                self.best_epoch = epoch
                self.save_checkpoint(epoch, is_best=True)

            self.save_checkpoint(epoch)

            plot_losses(
                self.history['train_loss'],
                self.history['val_loss'],
                self.history['train_f1'],
                self.history['val_f1'],
                clear=True
            )

        print(f"Training completed. Best Val F1: {self.best_f1:.4f} at epoch {self.best_epoch}")
        return self.history

# Обучение

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Using device: {device}")

✅ Using device: cuda


In [10]:
from src.models.teacher import ConvNeXtTeacher
from src.utils import load_best_model
from src.models.miniconvnext import MiniConvNeXt

student_model = MiniConvNeXt(num_classes=16).to(device)
student_model = student_model.to(device)

teacher_model = ConvNeXtTeacher(num_classes=16)
load_best_model(teacher_model, '../best_models/teacher.pth', device)
teacher_model.eval().to(device)


IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html


Mapping deprecated model name convnext_large_in22k to current convnext_large.fb_in22k.



✅ Loaded model weights from ../best_models/teacher.pth


ConvNeXtTeacher(
  (backbone): ConvNeXt(
    (stem): Sequential(
      (0): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((192,), eps=1e-06, elementwise_affine=True)
    )
    (stages): Sequential(
      (0): ConvNeXtStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): ConvNeXtBlock(
            (conv_dw): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192)
            (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=192, out_features=768, bias=True)
              (act): GELU()
              (drop1): Dropout(p=0.0, inplace=False)
              (norm): Identity()
              (fc2): Linear(in_features=768, out_features=192, bias=True)
              (drop2): Dropout(p=0.0, inplace=False)
            )
            (shortcut): Identity()
            (drop_path): Identity()
          )
          (1): ConvNeXtBlock(
            (

In [11]:
from torch.amp import GradScaler

NUM_EPOCH = 45

optimizer = torch.optim.AdamW(student_model.parameters(), lr=3e-4, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3e-4,
    steps_per_epoch=len(train_loader),
    epochs=NUM_EPOCH,
    pct_start=0.1,
    anneal_strategy='cos',
    div_factor=10.0,
    final_div_factor=500
)

criterion = DistillationLoss(
    teacher_model=teacher_model,
    temperature=4.0,
    alpha=0.7,
    ce_smoothing=0.1
)

scaler = GradScaler()

In [12]:
augment_fn = distill_mixupcutmix_augment(
    mixcut_fn = MixupCutMixAugmenter(alpha=0.5, p_mixup=0.3, p=0.5),
    teacher_transform = val_transform
)


trainer = Trainer(
    model=student_model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCH,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    batch_augment_fn=augment_fn,
    experiment_name="distillation_1",
    use_wandb=True,
    seed=42,
    scaler=scaler,
)

history = trainer.train()




In [11]:
from src.utils import load_best_model

load_best_model(student_model, 'checkpoints/distillation_1_best.pth', device)

✅ Loaded model weights from checkpoints/distillation_1_best.pth


MiniConvNeXt(
  (downsample_layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d(
        (norm): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
      )
    )
    (1): Sequential(
      (0): LayerNorm2d(
        (norm): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
      )
      (1): Conv2d(64, 128, kernel_size=(2, 2), stride=(2, 2))
    )
    (2): Sequential(
      (0): LayerNorm2d(
        (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      )
      (1): Conv2d(128, 256, kernel_size=(2, 2), stride=(2, 2))
    )
  )
  (stages): ModuleList(
    (0): Sequential(
      (0): ConvNeXtBlock(
        (dwconv): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64)
        (norm): LayerNorm2d(
          (norm): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        )
        (pwconv1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (act): GELU(approximate

In [12]:
from torch.amp import GradScaler

NUM_EPOCH = 75

optimizer = torch.optim.AdamW(student_model.parameters(), lr=3e-4, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCH,
    eta_min=1e-6,
)

criterion = DistillationLoss(
    teacher_model=teacher_model,
    temperature=4.0,
    alpha=0.7,
    ce_smoothing=0.1
)

scaler = GradScaler()

In [13]:
augment_fn = distill_mixupcutmix_augment(
    mixcut_fn = MixupCutMixAugmenter(alpha=0.4, p_mixup=0.3, p=0.5),
    teacher_transform = val_transform
)


trainer = Trainer(
    model=student_model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCH,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    batch_augment_fn=augment_fn,
    experiment_name="distillation_1_2",
    use_wandb=True,
    seed=42,
    scaler=scaler,
)

history = trainer.train()




In [17]:
load_best_model(student_model, 'checkpoints/distillation_1_2_best.pth', device)

✅ Loaded model weights from checkpoints/distillation_1_2_best.pth


MiniConvNeXt(
  (downsample_layers): ModuleList(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d(
        (norm): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
      )
    )
    (1): Sequential(
      (0): LayerNorm2d(
        (norm): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
      )
      (1): Conv2d(64, 128, kernel_size=(2, 2), stride=(2, 2))
    )
    (2): Sequential(
      (0): LayerNorm2d(
        (norm): LayerNorm((128,), eps=1e-06, elementwise_affine=True)
      )
      (1): Conv2d(128, 256, kernel_size=(2, 2), stride=(2, 2))
    )
  )
  (stages): ModuleList(
    (0): Sequential(
      (0): ConvNeXtBlock(
        (dwconv): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64)
        (norm): LayerNorm2d(
          (norm): LayerNorm((64,), eps=1e-06, elementwise_affine=True)
        )
        (pwconv1): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1))
        (act): GELU(approximate

In [18]:
from torch.amp import GradScaler

NUM_EPOCH = 75

optimizer = torch.optim.AdamW(student_model.parameters(), lr=3e-4, weight_decay=1e-4)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCH,
    eta_min=1e-7,
)

criterion = DistillationLoss(
    teacher_model=teacher_model,
    temperature=4.0,
    alpha=0.7,
    ce_smoothing=0.1
)

scaler = GradScaler()

In [19]:
augment_fn = distill_mixupcutmix_augment(
    mixcut_fn = MixupCutMixAugmenter(alpha=0.5, p_mixup=0.3, p=0.75),
    teacher_transform = val_transform
)


trainer = Trainer(
    model=student_model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCH,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    batch_augment_fn=augment_fn,
    experiment_name="distillation_1_3",
    use_wandb=True,
    seed=42,
    scaler=scaler,
)

history = trainer.train()




Train 75: 100%|██████████| 310/310 [00:44<00:00,  6.93it/s]
Val 75:  68%|██████▊   | 53/78 [00:15<00:07,  3.32it/s]


KeyboardInterrupt: 