In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datasets import load_dataset
from torchvision.transforms.v2 import (
    Compose, RandomResizedCrop, RandomHorizontalFlip, CutMix, MixUp,
    Resize, CenterCrop, PILToTensor, RandAugment
)
from timm.optim import AdamW
from PIL import Image
import numpy as np
import math
from fastprogress import progress_bar

In [None]:
class config:
    """Hyperparameters for the training run, read as class attributes (e.g. ``config.lr``)."""

    # --- input resolution ---
    crop_size = 160          # output size of the training RandomResizedCrop
    valid_crop_size = 256    # Resize / CenterCrop size used for validation
    channels = 3

    # --- optimisation schedule ---
    batch_size = 128
    lr = 1e-3
    weight_decay = 0.05
    warmup_epochs = 20
    epochs = 100

    # --- regularisation / stabilisation ---
    mixup_alpha = 0.8
    cutmix_alpha = 1.0
    label_smoothing = 0.1
    ema_decay = 0.9999
    grad_clip = 1.0          # max global grad norm passed to clip_grad_norm_

In [None]:
class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel gating.

    Global-average-pools the input to one value per channel, passes that through a
    two-layer 1x1-conv bottleneck with GELU, and multiplies the input by the
    resulting sigmoid gate (one scalar per sample and channel).

    Args:
        channels: number of input/output channels (dim 1 of the input).
        reduction: bottleneck reduction ratio. The hidden width is
            ``channels // reduction``, clamped to at least 1 so that channel
            counts smaller than ``reduction`` still get a working bottleneck.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Without the clamp, channels < reduction gives a zero-width bottleneck and
        # the gate degenerates to an input-independent sigmoid(bias).
        hidden = max(1, channels // reduction)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, hidden, 1),
            nn.GELU(),
            nn.Conv2d(hidden, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled by its per-channel gate (assumes x is (N, C, H, W))."""
        return x * self.fc(x)

In [None]:
class HybridBlock(nn.Module):
    """ConvNeXt-style block followed by a channel-attention residual branch.

    Branch 1: 7x7 depthwise conv -> LayerNorm -> Linear(dim, expansion*dim) -> GELU
    -> Linear(expansion*dim, dim), added back onto the block input.
    Branch 2: LayerNorm -> ChannelAttention, added back onto the output of branch 1.
    Both residual branches pass through the same stochastic-depth module.

    Args:
        dim: channel count of the (N, C, H, W) input and output.
        expansion: hidden-width multiplier of the pointwise MLP.
        drop_path: stochastic-depth probability; 0 (or less) disables it.
    """

    def __init__(self, dim, expansion=4, drop_path=0.):
        super().__init__()
        hidden = expansion * dim
        # Spatial mixing: depthwise 7x7 conv, padding 3 keeps H and W unchanged.
        self.dwconv = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        # Channel-mixing MLP, applied in channels-last layout.
        self.norm1 = nn.LayerNorm(dim)
        self.pwconv1 = nn.Linear(dim, hidden)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden, dim)
        # Second branch: normalised features gated by channel attention.
        self.norm2 = nn.LayerNorm(dim)
        self.attn = ChannelAttention(dim)
        if drop_path > 0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        shortcut = x
        # Branch 1: depthwise conv, then LayerNorm + MLP over channels in (N, H, W, C).
        h = self.dwconv(x).permute(0, 2, 3, 1)
        h = self.pwconv2(self.act(self.pwconv1(self.norm1(h))))
        x = shortcut + self.drop_path(h.permute(0, 3, 1, 2))
        # Branch 2: LayerNorm over channels, back to (N, C, H, W), then channel attention.
        gated_in = self.norm2(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        return x + self.drop_path(self.attn(gated_in))

In [None]:
class MetaNet(nn.Module):
    """Hierarchical ConvNeXt-like image classifier built from HybridBlock stages.

    Layout: patchify stem (4x4 stride-4 conv + GroupNorm + GELU), then for each
    stage ``i`` a stack of ``depths[i]`` HybridBlocks at width ``dims[i]``. A
    2x2 stride-2 conv downsampling layer sits between consecutive stages. The
    network ends with global average pooling, LayerNorm and a linear classifier.

    Args:
        in_chans: number of input image channels.
        num_classes: classifier output size.
        depths: number of HybridBlocks per stage (any number of stages).
        dims: channel width per stage; must have the same length as ``depths``
            and every entry must be divisible by 16 (GroupNorm uses 16 groups).
        drop_path_rate: stochastic-depth rate of the last block; per-block rates
            grow linearly from 0 across all blocks of the network.

    Raises:
        ValueError: if ``depths`` and ``dims`` have different lengths.
    """

    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024],
                 drop_path_rate=0.4):
        super().__init__()
        if len(depths) != len(dims):
            raise ValueError(
                f"depths and dims must have the same length, got {len(depths)} and {len(dims)}"
            )
        num_stages = len(depths)

        # Patchify stem: non-overlapping 4x4 patches -> dims[0] channels.
        self.stem = nn.Sequential(
            nn.Conv2d(in_chans, dims[0], 4, 4),
            nn.GroupNorm(16, dims[0]),
            nn.GELU()
        )

        # self.stages holds [blocks_0, down_0, blocks_1, down_1, ..., blocks_last];
        # forward() runs them in order.
        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        curr = 0

        for i in range(num_stages):
            stage = nn.Sequential(
                *[HybridBlock(dims[i], drop_path=dp_rates[curr + j])
                  for j in range(depths[i])]
            )
            self.stages.append(stage)
            curr += depths[i]
            if i < num_stages - 1:
                # Downsample: halve H and W, widen to the next stage's channel count.
                self.stages.append(nn.Sequential(
                    nn.Conv2d(dims[i], dims[i+1], 2, 2),
                    nn.GroupNorm(16, dims[i+1]),
                    nn.GELU()
                ))

        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.LayerNorm(dims[-1]),
            nn.Linear(dims[-1], num_classes)
        )

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal (std 0.02) weights and zero biases for conv/linear layers."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map an (N, in_chans, H, W) batch to (N, num_classes) logits."""
        x = self.stem(x)
        for stage in self.stages:
            x = stage(x)
        return self.head(x)

In [None]:
class DropPath(nn.Module):
    """Stochastic depth: zero the entire input of randomly chosen samples.

    In training mode each sample along dim 0 is kept with probability
    ``1 - drop_prob`` and, when kept, rescaled by ``1 / (1 - drop_prob)`` so the
    expected value is unchanged. In eval mode, or when ``drop_prob == 0``, the
    input is returned as-is.

    Args:
        drop_prob: probability of dropping a sample; 1.0 drops every sample
            (the output is all zeros rather than NaN).
    """

    def __init__(self, drop_prob=0.):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        # new_empty keeps x's dtype and device, so half-precision activations are
        # not promoted to float32 by the multiply below.
        mask = x.new_empty(shape).bernoulli_(keep_prob)
        if keep_prob > 0:
            # Rescale kept samples; skipped for keep_prob == 0 to avoid inf * 0 = NaN.
            mask.div_(keep_prob)
        return x * mask

class EMA:
    """Exponential moving average of a model's trainable parameters.

    Usage: call ``register()`` once after building the model and ``update()``
    after every optimizer step. Wrap evaluation in ``apply_shadow()`` /
    ``restore()`` to temporarily swap the averaged weights into the model.
    Only parameters with ``requires_grad=True`` are tracked; buffers are not.

    Args:
        model: the module whose parameters are averaged.
        decay: weight kept on the previous average at each update, i.e.
            ``shadow = decay * shadow + (1 - decay) * param``.
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = {}   # parameter name -> averaged tensor
        self.backup = {}   # parameter name -> live tensor while the shadow is applied

    def _trainable(self):
        """Yield ``(name, param)`` for every parameter this EMA tracks."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                yield name, param

    @torch.no_grad()
    def register(self):
        """Snapshot the current trainable parameters as the initial average."""
        for name, param in self._trainable():
            self.shadow[name] = param.detach().clone()

    @torch.no_grad()
    def update(self):
        """Move each shadow tensor toward its live parameter by ``1 - decay``.

        Done in place with ``lerp_``, so no full-model temporaries are
        allocated on every training step.
        """
        weight = 1.0 - self.decay
        for name, param in self._trainable():
            self.shadow[name].lerp_(param.detach(), weight)

    def apply_shadow(self):
        """Swap the averaged weights into the model, remembering the live ones."""
        for name, param in self._trainable():
            self.backup[name] = param.data
            param.data = self.shadow[name]

    def restore(self):
        """Put back the live weights saved by ``apply_shadow`` and clear the backup."""
        for name, param in self._trainable():
            param.data = self.backup[name]
        self.backup = {}

In [None]:
# Batch-level mixing augmentations, indexed by the training loop:
# [0] is CutMix and [1] is MixUp. Both turn integer labels into soft targets.
# The alphas are taken from `config` so the values declared there take effect.
mix_transforms = [
    CutMix(alpha=config.cutmix_alpha, num_classes=1000),
    MixUp(alpha=config.mixup_alpha, num_classes=1000),
]

# Per-image training augmentation: random resized crop, horizontal flip and
# RandAugment on the PIL image, then conversion to a uint8 CHW tensor.
rand_crop = Compose([
    RandomResizedCrop(config.crop_size),
    RandomHorizontalFlip(),
    RandAugment(),
    PILToTensor()
])

# Deterministic validation preprocessing: Lanczos resize of the shorter side,
# square center crop, conversion to a uint8 CHW tensor.
cent_crop = Compose([
    Resize(config.valid_crop_size, interpolation=Image.Resampling.LANCZOS),
    CenterCrop(config.valid_crop_size),
    PILToTensor()
])

def _collate_images(batch, transform, size):
    """Stack transformed images and class labels from a list of dataset samples.

    Shared implementation of ``train_collate_fn`` and ``valid_collate_fn``.

    Args:
        batch: list of samples, each a mapping with an image under ``'jpg'``
            (presumably a PIL image, since ``.convert("RGB")`` is called on it)
            and a class index under ``'cls'``.
        transform: callable mapping an RGB image to a uint8 tensor of shape
            (config.channels, size, size).
        size: spatial side length every transformed image must have.

    Returns:
        (x, y): float images of shape (B, config.channels, size, size) scaled to
        [-0.5, 0.5], and int64 labels of shape (B,).
    """
    B = len(batch)
    x = torch.zeros((B, config.channels, size, size), dtype=torch.uint8)
    y = torch.zeros(B, dtype=torch.long)
    for i, sample in enumerate(batch):
        y[i] = sample['cls']
        x[i] = transform(sample['jpg'].convert("RGB"))
    # uint8 [0, 255] -> float [-0.5, 0.5]
    x = x.float() / 255 - 0.5
    return x, y


def train_collate_fn(batch):
    """Collate training samples using the random augmentation pipeline ``rand_crop``."""
    return _collate_images(batch, rand_crop, config.crop_size)


def valid_collate_fn(batch):
    """Collate validation samples using the deterministic ``cent_crop`` pipeline."""
    return _collate_images(batch, cent_crop, config.valid_crop_size)

In [None]:
device = "cuda:0"

# Both splits come from the same Hugging Face dataset. The loaders differ only in
# collate function, batch size (validation uses twice the training batch) and shuffling.
_loader_opts = dict(num_workers=32, pin_memory=True)

train_loader = DataLoader(
    load_dataset('timm/imagenet-1k-wds', split='train'),
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=train_collate_fn,
    **_loader_opts,
)
valid_loader = DataLoader(
    load_dataset('timm/imagenet-1k-wds', split='validation'),
    batch_size=2 * config.batch_size,
    collate_fn=valid_collate_fn,
    **_loader_opts,
)

In [None]:
# Model
model = MetaNet().to(device)

# Parameter count, in millions.
print(sum(p.numel() for p in model.parameters())/1e6)

# EMA of the weights; the training loop evaluates and checkpoints this copy.
ema = EMA(model, config.ema_decay)
ema.register()

# Optimizer: 1-D parameters (all biases plus GroupNorm/LayerNorm affine weights)
# are excluded from weight decay, following the common timm/ConvNeXt recipe;
# only conv/linear weight matrices are decayed.
decay_params = [p for p in model.parameters() if p.requires_grad and p.ndim > 1]
no_decay_params = [p for p in model.parameters() if p.requires_grad and p.ndim <= 1]
optimizer = AdamW(
    [
        {'params': decay_params, 'weight_decay': config.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ],
    lr=config.lr,
)
criterion = nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)

# Scheduler (stepped once per optimizer step): linear warmup, then cosine decay.
steps_per_epoch = len(train_loader)
warmup_steps = config.warmup_epochs * steps_per_epoch
# max(1, ...) avoids division by zero when epochs == warmup_epochs.
decay_steps = max(1, (config.epochs - config.warmup_epochs) * steps_per_epoch)

def lr_lambda(current_step):
    """LR multiplier: ramps linearly up to 1.0 over warmup, then follows a cosine to 0."""
    if current_step < warmup_steps:
        return (current_step + 1) / warmup_steps
    # Clamp so stepping past the planned schedule stays at the minimum instead of rising again.
    progress = min(1.0, (current_step - warmup_steps) / decay_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

In [None]:
# Training
# autocast/GradScaler expect a device *type* ("cuda"), not an indexed device string like "cuda:0".
amp_device = torch.device(device).type
# CUDA autocast defaults to float16, which needs loss scaling so small gradients don't underflow.
scaler = torch.amp.GradScaler(amp_device)

best_acc = 0.0
for epoch in range(config.epochs):
    model.train()
    for x, y in progress_bar(train_loader):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        # Batch-level mixing: CutMix (index 0) or MixUp (index 1) with equal probability;
        # labels become soft targets, which CrossEntropyLoss accepts.
        mix = mix_transforms[0] if np.random.rand() < 0.5 else mix_transforms[1]
        x, y = mix(x, y)

        with torch.amp.autocast(amp_device):
            logits = model(x)
            loss = criterion(logits, y)

        # Optimize. Gradients are unscaled before clipping so the threshold applies
        # to the true gradient norm.
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        scaler.step(optimizer)
        scaler.update()
        ema.update()
        scheduler.step()

    # Validation runs on the EMA weights. The live weights are restored even if evaluation raises.
    model.eval()
    ema.apply_shadow()
    try:
        total, correct = 0, 0
        with torch.no_grad():
            for x, y in progress_bar(valid_loader):
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                with torch.amp.autocast(amp_device):
                    logits = model(x)
                preds = logits.argmax(dim=1)
                total += y.size(0)
                correct += (preds == y).sum().item()
        acc = 100 * correct / max(total, 1)
        if acc > best_acc:
            best_acc = acc
            # Saved while the EMA weights are applied, so the checkpoint holds the EMA model.
            torch.save(model.state_dict(), 'best_model.pth')
    finally:
        ema.restore()

    print(f'Epoch {epoch+1}/{config.epochs} | Val Acc: {acc:.2f}% | Best Acc: {best_acc:.2f}%')