In [21]:
%pip install tqdm

Note: you may need to restart the kernel to use updated packages.


In [22]:
import math
from pathlib import Path
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm


# Choose the compute backend in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

Using device: mps


In [23]:
# --- Data pipeline configuration ---------------------------------------------
batch_size = 128  # samples per mini-batch, shared by both loaders
num_workers = 4   # worker processes per DataLoader

# Dataset root directory; created up front so the download has a target.
data_dir = Path("hands-on/data")
data_dir.mkdir(parents=True, exist_ok=True)

# Per-channel mean / std handed to Normalize.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2470, 0.2435, 0.2616)
normalize = transforms.Normalize(mean=cifar_mean, std=cifar_std)

# Training pipeline: random crop (with 4px padding) and horizontal flip for
# augmentation, followed by tensor conversion and normalization.
train_transforms = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion + normalization.
test_transforms = transforms.Compose([transforms.ToTensor(), normalize])

# CIFAR-10 splits (downloaded into data_dir when not already present).
train_ds = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transforms)
test_ds = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=test_transforms)

# Pinned host memory is only requested when running on CUDA.
pin_memory = device == "cuda"
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_ds, shuffle=False, **loader_kwargs)

class_names = train_ds.classes
print(f"Train samples: {len(train_ds)}, Test samples: {len(test_ds)}")


Train samples: 50000, Test samples: 10000


In [24]:
class PatchEmbedding(nn.Module):
    """Cut an image into non-overlapping square patches and embed each patch.

    A convolution whose kernel size equals its stride (``patch_size``) performs
    the patch extraction and the linear projection in a single operation.

    Args:
        img_size: Image side length used to derive ``num_patches``.
        patch_size: Side length of one square patch.
        in_channels: Number of channels of the input image.
        embed_dim: Dimensionality of each patch embedding.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=128):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return patch tokens of shape ``(B, num_patches, embed_dim)``."""
        feature_map = self.proj(x)                 # (B, embed_dim, H / patch, W / patch)
        tokens = feature_map.flatten(start_dim=2)  # (B, embed_dim, num_patches)
        return tokens.transpose(1, 2)              # (B, num_patches, embed_dim)


class MLP(nn.Module):
    """Two-layer feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        embed_dim: Size of the input and output features.
        mlp_ratio: Hidden width as a multiple of ``embed_dim`` (truncated to int).
        drop: Dropout probability; the same dropout module is applied after the
            activation and after the second linear layer.
    """

    def __init__(self, embed_dim, mlp_ratio=4.0, drop=0.0):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.fc1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()  # Gaussian Error Linear Unit; the usual choice in Transformers
        self.fc2 = nn.Linear(hidden_dim, embed_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the feed-forward network along the last dimension of ``x``."""
        expanded = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(expanded))

class Block(nn.Module):
    """Pre-norm Transformer encoder block.

    Applies LayerNorm -> multi-head self-attention and LayerNorm -> MLP, each
    wrapped in a residual connection with dropout on the branch output.

    Args:
        embed_dim: Feature size of every token.
        num_heads: Number of attention heads.
        mlp_ratio: Hidden-width multiplier forwarded to ``MLP``.
        drop: Dropout probability for the residual branches and inside the MLP.
        attn_drop: Dropout probability passed to the attention module.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, drop=0.0, attn_drop=0.0):
        super().__init__()
        # LayerNorm normalizes over the feature dimension (not across the batch).
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=attn_drop, batch_first=True)
        self.drop = nn.Dropout(drop)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_ratio, drop)

    def forward(self, x):
        """Return the block output; same shape as ``x``."""
        # Attention sub-layer (lower half of the block).
        normed = self.norm1(x)
        # The attention weights are never used, so skip computing them:
        # need_weights=False avoids materializing the averaged attention matrix
        # and lets PyTorch use its optimized attention implementation.
        attn_out, _ = self.attn(normed, normed, normed, need_weights=False)
        x = x + self.drop(attn_out)

        # Feed-forward sub-layer (upper half of the block).
        x = x + self.drop(self.mlp(self.norm2(x)))
        return x





class VisionTransformer(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> prepend a learnable class token -> add a
    learnable positional embedding -> dropout -> ``depth`` encoder blocks ->
    LayerNorm -> linear head applied to the class token.

    Args:
        img_size: Input image side length.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels.
        num_classes: Number of output logits.
        embed_dim: Token embedding size.
        depth: Number of encoder blocks.
        num_heads: Attention heads per block.
        mlp_ratio: Hidden-width multiplier of each block's MLP.
        drop: Dropout probability (after the positional embedding and in blocks).
        attn_drop: Dropout probability inside the attention modules.
    """

    def __init__(self, img_size = 32, patch_size = 4, in_channels = 3, num_classes = 10, embed_dim = 128, depth = 6,
                 num_heads = 4,
                 mlp_ratio = 4.0,
                 drop = 0.1,
                 attn_drop = 0.0):
        super().__init__()
        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        # Learnable class token and positional embedding (one position per patch
        # plus one for the class token). Both start as small zero-mean values:
        # a uniform [0, 1) init (torch.rand) would add a constant ~0.5 offset of
        # comparatively large magnitude to every token at the start of training.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.pos_drop = nn.Dropout(drop)

        self.transformer = nn.ModuleList(
            [
                Block(embed_dim, num_heads, mlp_ratio, drop=drop, attn_drop=attn_drop) for _ in range(depth)
            ]
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)


    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)`` for a batch of images."""
        batch = x.shape[0]
        x = self.patch_embed(x)                         # (B, num_patches, embed_dim)
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)           # (B, num_patches + 1, embed_dim)
        x = x + self.pos_embed                          # token embedding + positional embedding
        x = self.pos_drop(x)

        for block in self.transformer:
            x = block(x)

        x = self.norm(x)
        cls_out = x[:, 0]                               # representation of the class token
        return self.head(cls_out)

In [25]:
# Optimisation hyperparameters.
lr = 3e-4
epochs = 10

# Build the network and move its parameters to the selected device.
model = VisionTransformer(
    img_size=32,
    patch_size=4,
    num_classes=10,
    embed_dim=128,
    depth=6,
    num_heads=4,
    drop=0.1,
).to(torch.device(device))

# Loss, optimizer, and a cosine learning-rate schedule spanning all epochs.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

# Report the total parameter count in millions.
num_params = sum(param.numel() for param in model.parameters())
print(f"Model params: {num_params/1e6:.2f}M")

Model params: 1.21M


In [26]:
def train_one_epoch(model, dataloader, optimizer, device, loss_fn=None):
    """Train ``model`` for one pass over ``dataloader``.

    Args:
        model: Module to optimize; switched to training mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar loss.
            When omitted, the notebook-level ``criterion`` is used, which keeps
            existing four-argument calls working unchanged.

    Returns:
        Tuple ``(mean_loss, accuracy)`` over all samples seen in the epoch.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    if loss_fn is None:
        # Backward-compatible default: fall back to the loss defined in the notebook.
        loss_fn = criterion

    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(dataloader, leave=False):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, labels)
        loss.backward()
        # Gradient clipping: cap the global gradient norm at 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_count = labels.size(0)
        # Weight the batch loss by batch size so the epoch value is a per-sample
        # mean (assumes loss_fn returns the batch mean).
        running_loss += loss.item() * batch_count
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_count

    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute epoch metrics")
    return running_loss / total, correct / total

@torch.inference_mode()
def evaluate(model, dataloader, device, loss_fn=None):
    """Compute mean loss and accuracy of ``model`` over ``dataloader``.

    Runs under ``torch.inference_mode`` and switches the model to eval mode
    (the model is left in eval mode afterwards).

    Args:
        model: Module to evaluate.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        device: Device the batches are moved to.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar loss.
            When omitted, the notebook-level ``criterion`` is used, which keeps
            existing three-argument calls working unchanged.

    Returns:
        Tuple ``(mean_loss, accuracy)`` over all evaluated samples.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    if loss_fn is None:
        # Backward-compatible default: fall back to the loss defined in the notebook.
        loss_fn = criterion

    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = loss_fn(outputs, labels)

        batch_count = labels.size(0)
        # Weight the batch loss by batch size so the result is a per-sample mean
        # (assumes loss_fn returns the batch mean).
        running_loss += loss.item() * batch_count
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_count

    if total == 0:
        raise ValueError("dataloader yielded no samples; cannot compute evaluation metrics")
    return running_loss / total, correct / total


In [None]:
# Full training run: per epoch, one optimisation pass over the training set,
# one evaluation pass over the test set (reported as "val"), then one
# learning-rate scheduler step.
for epoch in range(epochs):
    train_metrics = train_one_epoch(model, train_loader, optimizer, device)
    val_metrics = evaluate(model, test_loader, device)
    scheduler.step()

    train_loss, train_acc = train_metrics
    val_loss, val_acc = val_metrics
    print(f"Epoch {epoch + 1}/{epochs} | train loss {train_loss:.4f} acc {train_acc:.3f} | val loss {val_loss:.4f} acc {val_acc:.3f}")

                                                 