In [1]:
import sys

sys.path.append('..')

import torch
import torch.nn as nn
from torchvision import transforms
from timm import create_model

import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
from pathlib import Path
import numpy as np

from src.trainer import Trainer
from src.dataset import HumanPosesDataset
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import plotly.io as pio
pio.renderers.default = "browser"

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    def __init__(self, teacher_model, temperature=4.0, alpha=0.7):
        super().__init__()
        self.teacher = teacher_model.eval()
        self.temperature = temperature
        self.alpha = alpha

    def forward(self, student_logits, _, labels):
        if isinstance(labels, tuple) and len(labels) == 2:
            targets, x_teacher = labels
        else:
            raise ValueError("Expected labels to be (targets, x_teacher) tuple")

        x_teacher = x_teacher.to(student_logits.device)
        with torch.no_grad():
            teacher_logits = self.teacher(x_teacher)

        T = self.temperature
        soft_teacher = F.softmax(teacher_logits / T, dim=1)
        soft_student = F.log_softmax(student_logits / T, dim=1)
        distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (T ** 2)

        ce_loss = F.cross_entropy(student_logits, targets)
        return self.alpha * distill_loss + (1 - self.alpha) * ce_loss

def distill_batch_augment(images, labels, val_transform):
    x_teacher = torch.stack([val_transform(img) for img in images.cpu()])
    return images, (labels, x_teacher.to(images.device))


# Датасет

In [4]:
from torchvision import transforms

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.5, 1.0), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
    transforms.RandomApply([
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))
    ], p=0.3),
    transforms.RandomApply([
        transforms.RandomAffine(degrees=15, translate=(0.05, 0.05), scale=(0.9, 1.1))
    ], p=0.5),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.25, scale=(0.02, 0.2), ratio=(0.3, 3.3), value='random'),
    transforms.Normalize(mean=mean, std=mean)
])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=mean)
])


In [5]:
CSV_PATH = Path("../data/human_poses_data/train_answers.csv")
TRAIN_DIR = Path("../data/human_poses_data/img_train")

df = pd.read_csv(CSV_PATH)

train_ids, val_ids = train_test_split(
    df['img_id'].values,
    test_size=0.2,
    stratify=df['target_feature'],
    random_state=42
)

train_df = df[df['img_id'].isin(train_ids)].reset_index(drop=True)
val_df = df[df['img_id'].isin(val_ids)].reset_index(drop=True)

train_dataset = HumanPosesDataset(
    data_df=train_df,
    img_dir=TRAIN_DIR,
    transform=train_transform,
)

val_dataset = HumanPosesDataset(
    data_df=val_df,
    img_dir=TRAIN_DIR,
    transform=val_transform,
)



train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

Train dataset size: 9893
Validation dataset size: 2474


In [6]:
num_classes = len(np.unique(df['target_feature']))
print(f"Количество классов: {num_classes}")

Количество классов: 16


# Модель

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Using device: {device}")

✅ Using device: cuda


In [8]:
from src.models.miniconvnext import MiniConvNeXt
from src.models.teacher import ConvNeXtTeacher
from src.utils import load_best_model

student_model = MiniConvNeXt(num_classes=16)
student_model = student_model.to(device)

teacher_model = ConvNeXtTeacher(num_classes=16)
load_best_model(teacher_model, '../best_models/teacher.pth', device)
teacher_model.eval().to(device)


Mapping deprecated model name convnext_large_in22k to current convnext_large.fb_in22k.



✅ Loaded model weights from ../best_models/teacher.pth


ConvNeXtTeacher(
  (backbone): ConvNeXt(
    (stem): Sequential(
      (0): Conv2d(3, 192, kernel_size=(4, 4), stride=(4, 4))
      (1): LayerNorm2d((192,), eps=1e-06, elementwise_affine=True)
    )
    (stages): Sequential(
      (0): ConvNeXtStage(
        (downsample): Identity()
        (blocks): Sequential(
          (0): ConvNeXtBlock(
            (conv_dw): Conv2d(192, 192, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=192)
            (norm): LayerNorm((192,), eps=1e-06, elementwise_affine=True)
            (mlp): Mlp(
              (fc1): Linear(in_features=192, out_features=768, bias=True)
              (act): GELU()
              (drop1): Dropout(p=0.0, inplace=False)
              (norm): Identity()
              (fc2): Linear(in_features=768, out_features=192, bias=True)
              (drop2): Dropout(p=0.0, inplace=False)
            )
            (shortcut): Identity()
            (drop_path): Identity()
          )
          (1): ConvNeXtBlock(
            (

In [9]:
from torch.amp import GradScaler

NUM_EPOCH = 15

optimizer = torch.optim.AdamW(
    student_model.parameters(),
    lr=3e-4,
    weight_decay=1e-4
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCH
)

criterion = DistillationLoss(
    teacher_model=teacher_model,
    temperature=4.0,
    alpha=0.7
)

scaler = GradScaler()


In [11]:
from src.utils import MixupCutMixAugmenter

mixup_cutmix_fn = MixupCutMixAugmenter(alpha=1.0, p_mixup=0.3)

trainer = Trainer(
    model=student_model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCH,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    batch_augment_fn=None, #ВРЕМЕННО НЕ ЗАБЫТЬ БЫ
    experiment_name="",
    use_wandb=True,
    seed=42,
    scaler=scaler,
)

history = trainer.train()


Epoch 1/15


Train 1:   0%|          | 0/310 [00:00<?, ?it/s]


TypeError: DistillationLoss.forward() missing 1 required positional argument: 'labels'