In [None]:


from torch.utils.data import DataLoader
from tqdm import tqdm

import sys

sys.path.append('..')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class NTXentLoss(nn.Module):
    """Normalized temperature-scaled cross-entropy (NT-Xent) contrastive loss.

    Row ``k`` of the first input and row ``k`` of the second input form a
    positive pair; every other row of the combined ``2 * batch`` set is a
    negative for that row.

    Args:
        temperature: Divisor applied to the cosine similarities before the
            softmax cross-entropy.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, zis, zjs):
        """Compute the mean NT-Xent loss over both views.

        Args:
            zis: Embeddings of the first view, one row per sample.
            zjs: Embeddings of the second view, row-aligned with ``zis``.

        Returns:
            Scalar tensor: cross-entropy averaged over all ``2 * batch`` rows.
        """
        n = zis.size(0)
        target_device = zis.device

        # Unit-length rows make the dot products below cosine similarities.
        views = torch.cat(
            [F.normalize(zis, dim=1), F.normalize(zjs, dim=1)],
            dim=0,
        )
        scores = views @ views.T

        # A sample must never be a candidate for itself: -inf on the diagonal
        # removes it from the softmax.
        self_pairs = torch.eye(2 * n, device=target_device).bool()
        scores = scores.masked_fill(self_pairs, float('-inf'))
        scores = scores / self.temperature

        # Row k of the first half matches row k + n; row k + n matches row k.
        first_half = torch.arange(n, device=target_device)
        targets = torch.cat([first_half + n, first_half], dim=0)

        return F.cross_entropy(scores, targets)


In [None]:
def train_simclr(model, dataloader, loss_fn, optimizer, scheduler, epochs=10):
    """Run SimCLR pre-training with mixed precision where available.

    Args:
        model: Module mapping a batch of images to projections; it must expose
            an ``encoder`` attribute, whose ``state_dict`` is checkpointed.
        dataloader: Iterable yielding ``(x1, x2)`` batches of two views.
        loss_fn: Callable taking the two projection batches and returning a
            scalar loss tensor.
        optimizer: Optimizer over the model parameters.
        scheduler: LR scheduler; stepped once per epoch.
        epochs: Number of passes over ``dataloader``.

    Side effects:
        Prints the mean loss per epoch and writes
        ``encoder_ssl_epoch{N}.pt`` every 5th epoch and after the last epoch.
    """
    model.train()

    # Take the device from the model itself instead of relying on a notebook
    # global that may not exist yet (or may disagree with where the model is).
    device = next(model.parameters()).device

    # fp16 autocast and gradient scaling are only meaningful on CUDA; on the
    # CPU fallback both are turned into no-ops so the loop still runs.
    use_amp = device.type == 'cuda'
    scaler = torch.amp.GradScaler(device.type, enabled=use_amp)

    # Guard the average against an empty dataloader.
    num_batches = max(len(dataloader), 1)

    for epoch in range(epochs):
        total_loss = 0.0

        for x1, x2 in tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}"):
            x1, x2 = x1.to(device), x2.to(device)

            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                z1 = model(x1)
                z2 = model(x2)
                loss = loss_fn(z1, z2)

            optimizer.zero_grad()
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

        scheduler.step()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / num_batches:.4f}")

        # Only the encoder is saved: the projection head is not part of the checkpoint.
        if (epoch + 1) % 5 == 0 or (epoch + 1) == epochs:
            torch.save(model.encoder.state_dict(), f"encoder_ssl_epoch{epoch+1}.pt")
            print(f"✅ Saved checkpoint: encoder_ssl_epoch{epoch+1}.pt")


In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"✅ Using device: {device}")

In [None]:
import torch
import torch.nn as nn

class LayerNorm2d(nn.Module):
    """LayerNorm applied over the channel axis of a channels-first 4-D tensor.

    Args:
        channels: Size of the channel axis (axis 1 of the input).
        eps: Numerical-stability term forwarded to ``nn.LayerNorm``.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        self.norm = nn.LayerNorm(channels, eps=eps)

    def forward(self, x):
        """Normalize ``x`` across channels and return it in the input layout."""
        # nn.LayerNorm normalizes the trailing axis, so move channels last,
        # normalize, then restore the channels-first layout.
        channels_last = x.permute(0, 2, 3, 1)
        normalized = self.norm(channels_last)
        return normalized.permute(0, 3, 1, 2)


class ConvNeXtBlock(nn.Module):
    """Residual block: depthwise 7x7 conv, LayerNorm2d, then a 1x1 conv MLP.

    The 1x1 convolutions expand to ``4 * dim`` channels, apply GELU, and
    project back to ``dim`` before the input is added back.

    Args:
        dim: Number of input and output channels.
    """

    def __init__(self, dim):
        super().__init__()
        hidden_dim = 4 * dim
        # groups=dim makes the 7x7 convolution depthwise (one filter per channel).
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)
        self.norm = LayerNorm2d(dim, eps=1e-6)
        self.pwconv1 = nn.Conv2d(dim, hidden_dim, kernel_size=1)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(hidden_dim, dim, kernel_size=1)

    def forward(self, x):
        """Return the block output, with the same channel count as ``x``."""
        branch = self.norm(self.dwconv(x))
        branch = self.pwconv2(self.act(self.pwconv1(branch)))
        return branch + x


class MiniConvNeXt(nn.Module):
    """Small ConvNeXt-style network: a stem, N stages, pooled features, head.

    Args:
        in_chans: Number of input image channels.
        num_classes: Size of the linear classification head. ``None`` replaces
            the head with ``nn.Identity`` so ``forward`` returns the features.
        depths: Number of ``ConvNeXtBlock`` per stage. Defaults to ``[2, 2, 2]``.
        dims: Channel width per stage. Defaults to ``[64, 128, 256]``.

    Raises:
        ValueError: If ``dims`` is empty or ``depths`` and ``dims`` differ in
            length.
    """

    def __init__(self, in_chans=3, num_classes=16,
                 depths=None, dims=None):
        super().__init__()
        if depths is None:
            depths = [2, 2, 2]
        if dims is None:
            dims = [64, 128, 256]
        # Fail early: a length mismatch would otherwise be silently truncated
        # by zip() and only show up as a shape error in forward().
        if len(dims) == 0:
            raise ValueError("dims must contain at least one stage width")
        if len(depths) != len(dims):
            raise ValueError(
                f"depths and dims must have the same length, "
                f"got {len(depths)} and {len(dims)}"
            )
        self.num_classes = num_classes
        self.downsample_layers = nn.ModuleList()
        # Stem: non-overlapping 4x4 patches (stride 4) followed by channel LayerNorm.
        stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
            LayerNorm2d(dims[0], eps=1e-6)
        )
        self.downsample_layers.append(stem)

        # One 2x downsampling layer between each pair of consecutive stages
        # (derived from len(dims) instead of a hard-coded count of 2).
        for i in range(len(dims) - 1):
            downsample_layer = nn.Sequential(
                LayerNorm2d(dims[i], eps=1e-6),
                nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
            )
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList([
            nn.Sequential(*[ConvNeXtBlock(dim) for _ in range(depth)])
            for dim, depth in zip(dims, depths)
        ])

        self.norm = nn.LayerNorm(dims[-1], eps=1e-6)
        self.head = nn.Linear(dims[-1], num_classes) if num_classes is not None else nn.Identity()

    def forward_features(self, x):
        """Return pooled, layer-normalized features of width ``dims[-1]``."""
        for down, stage in zip(self.downsample_layers, self.stages):
            x = down(x)
            x = stage(x)
        # Global average pooling over the two spatial axes.
        x = x.mean([-2, -1])
        x = self.norm(x)
        return x

    def forward(self, x):
        """Return head outputs (or raw features when the head is Identity)."""
        x = self.forward_features(x)
        return self.head(x)


class SimCLRModel(nn.Module):
    """Encoder plus a two-layer MLP projection head for contrastive training.

    Args:
        encoder: Backbone exposing ``forward_features`` and a ``norm``
            LayerNorm whose ``normalized_shape[0]`` is the feature width.
        projection_dim: Output width of the projection head.
    """

    def __init__(self, encoder: MiniConvNeXt, projection_dim=128):
        super().__init__()
        self.encoder = encoder
        # The encoder's final LayerNorm width is the feature dimension.
        feature_dim = encoder.norm.normalized_shape[0]
        head_layers = [
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, projection_dim),
        ]
        self.projection_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the projections of the encoder features of ``x``."""
        return self.projection_head(self.encoder.forward_features(x))

In [None]:
from src.ssl_dataset import SimCLRDataset

# --- SimCLR pre-training run -------------------------------------------------
image_dir = '../data/human_poses_data/img_train'
dataset = SimCLRDataset(image_folder=image_dir)

dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=128,
    num_workers=4,
    pin_memory=True,
)

NUM_EPOCHS = 50

# Headless encoder (num_classes=None) wrapped with the projection head.
encoder = MiniConvNeXt(num_classes=None).to(device)
model = SimCLRModel(encoder).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Cosine decay over the whole run, down to eta_min; stepped once per epoch
# inside train_simclr.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS,
    eta_min=1e-5,
)

loss_fn = NTXentLoss()

train_simclr(
    model,
    dataloader,
    loss_fn,
    optimizer,
    scheduler,
    epochs=NUM_EPOCHS,
)

 # # Full fine-tune

In [None]:
import sys

sys.path.append('..')

from src.models.miniconvnext import MiniConvNeXt

import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
from pathlib import Path
import numpy as np

from src.trainer import Trainer
from src.dataset import HumanPosesDataset
from sklearn.model_selection import train_test_split
import torch

In [None]:
import plotly.io as pio
# Select Plotly's "browser" renderer as the default for figures shown from this notebook.
pio.renderers.default = "browser"

# Датасет

In [None]:
from torchvision import transforms

# Per-channel statistics used to normalize inputs for both pipelines.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Training pipeline: geometric and photometric augmentation on the PIL image,
# then tensor conversion, random erasing and normalization.
_train_steps = [
    transforms.Resize(256),
    transforms.RandomResizedCrop(224, scale=(0.6, 1.0), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.2),
    transforms.ToTensor(),
    # RandomErasing operates on tensors, so it has to come after ToTensor.
    transforms.RandomApply([transforms.RandomErasing()], p=0.3),
    transforms.Normalize(mean=mean, std=std),
]
train_transform = transforms.Compose(_train_steps)

# Validation pipeline: deterministic resize + center crop, no augmentation.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]
val_transform = transforms.Compose(_val_steps)

In [None]:
CSV_PATH = Path("../data/human_poses_data/train_answers.csv")
TRAIN_DIR = Path("../data/human_poses_data/img_train")

df = pd.read_csv(CSV_PATH)

# Seeded 80/20 split of the image ids, stratified by the target label.
train_ids, val_ids = train_test_split(
    df['img_id'].values,
    stratify=df['target_feature'],
    test_size=0.2,
    random_state=42,
)

# Select the rows of each split by id; rows keep their order from the CSV.
train_df = df.loc[df['img_id'].isin(train_ids)].reset_index(drop=True)
val_df = df.loc[df['img_id'].isin(val_ids)].reset_index(drop=True)

# Both datasets read from the same image folder and differ only in the
# rows they cover and the transform they apply.
train_dataset = HumanPosesDataset(
    data_df=train_df, img_dir=TRAIN_DIR, transform=train_transform,
)
val_dataset = HumanPosesDataset(
    data_df=val_df, img_dir=TRAIN_DIR, transform=val_transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Validation dataset size: {len(val_dataset)}")

In [None]:
# Number of distinct labels present in the answers table.
num_classes = np.unique(df['target_feature']).size
print(f"Количество классов: {num_classes}")

# Модель

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"✅ Using device: {device}")

In [None]:
# Read the SSL checkpoint once. map_location="cpu" lets a checkpoint saved from
# a CUDA run load on a CPU-only machine; weights_only=True restricts unpickling
# to tensors/containers, which is all a state_dict needs.
ssl_state = torch.load("encoder_ssl_epoch25.pt", map_location="cpu", weights_only=True)

# Headless copy of the pretrained encoder (kept so `encoder` stays defined for later cells).
encoder = MiniConvNeXt(num_classes=None)
encoder.load_state_dict(ssl_state, strict=False)

# Classifier to fine-tune: the head size follows the label count computed from
# the data rather than a hard-coded 16.
model = MiniConvNeXt(num_classes=num_classes)
load_result = model.load_state_dict(ssl_state, strict=False)

# strict=False is needed because the checkpoint carries no classifier head, but
# it also hides any other mismatch, so report anything beyond the head.
# NOTE(review): assumes the head parameters are named "head.*" — confirm against
# src.models.miniconvnext.
missing_non_head = [k for k in load_result.missing_keys if not k.startswith("head.")]
if missing_non_head or load_result.unexpected_keys:
    print(
        f"⚠️ Checkpoint mismatch — missing: {missing_non_head}, "
        f"unexpected: {list(load_result.unexpected_keys)}"
    )

model.to(device)

In [None]:
from torch.amp import GradScaler

# Fine-tuning hyper-parameters for the first supervised stage.
NUM_EPOCH = 75

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

# One-cycle policy sized for NUM_EPOCH * len(train_loader) scheduler steps:
# the first 10% of steps ramp up to max_lr, the rest anneal with a cosine.
_one_cycle_kwargs = dict(
    max_lr=3e-4,
    epochs=NUM_EPOCH,
    steps_per_epoch=len(train_loader),
    pct_start=0.1,
    anneal_strategy='cos',
    div_factor=25.0,
    final_div_factor=1e4,
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, **_one_cycle_kwargs)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Gradient scaler handed to the Trainer.
scaler = GradScaler()

In [None]:
from src.utils import MixupCutMixAugmenter

# Batch-level augmentation passed to the Trainer.
# NOTE(review): presumably p_mixup is the probability of picking MixUp over
# CutMix — confirm in src.utils.
mixup_cutmix_fn = MixupCutMixAugmenter(alpha=1.0, p_mixup=0.3)

# Collect the Trainer configuration in one place, then run the experiment.
_trainer_kwargs = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCH,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    batch_augment_fn=mixup_cutmix_fn,
    experiment_name="ssl_1_1",
    use_wandb=True,
    seed=42,
    scaler=scaler,
)
trainer = Trainer(**_trainer_kwargs)

history = trainer.train()

In [None]:
from src.utils import load_best_model

# Load the checkpoint file of the "ssl_1_1" run for `model`.
# NOTE(review): presumably this restores the best weights into `model` in place —
# confirm against src.utils.load_best_model.
load_best_model(model, 'checkpoints/ssl_1_1_best.pth')

In [None]:
from torch.amp import GradScaler

# Second fine-tuning stage: a fresh optimizer, scheduler and scaler for another
# NUM_EPOCH epochs, continuing from the weights currently held by `model`.
NUM_EPOCH = 75

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

# One-cycle policy sized for NUM_EPOCH * len(train_loader) scheduler steps:
# the first 10% of steps ramp up to max_lr, the rest anneal with a cosine.
_one_cycle_kwargs = dict(
    max_lr=3e-4,
    epochs=NUM_EPOCH,
    steps_per_epoch=len(train_loader),
    pct_start=0.1,
    anneal_strategy='cos',
    div_factor=25.0,
    final_div_factor=1e4,
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, **_one_cycle_kwargs)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

# Gradient scaler handed to the Trainer.
scaler = GradScaler()

In [None]:
from src.utils import MixupCutMixAugmenter

# Batch-level augmentation passed to the Trainer.
# NOTE(review): presumably p_mixup is the probability of picking MixUp over
# CutMix — confirm in src.utils.
mixup_cutmix_fn = MixupCutMixAugmenter(alpha=1.0, p_mixup=0.3)

# Same configuration as the first stage, logged under a new experiment name.
_trainer_kwargs = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=NUM_EPOCH,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    batch_augment_fn=mixup_cutmix_fn,
    experiment_name="ssl_1_2",
    use_wandb=True,
    seed=42,
    scaler=scaler,
)
trainer = Trainer(**_trainer_kwargs)

history = trainer.train()