In [9]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image
from torch.optim import Adam

from functools import partial
from typing import Any, Callable, List, Optional, Type, Union
from torch import Tensor

In [15]:
from attentions import HyperAttentionMLP


from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# cifar traininig data
def cifar_dataloader(batch_size, train=True, num_workers=8):
    transform = []
    if train:
        transform.extend([
            transforms.RandAugment(2, 14),
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
        ])
    transform.extend([
        transforms.Resize(32),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    transform = transforms.Compose(transform)
    # augs
    train_loader = DataLoader(datasets.CIFAR10('data', 
                                train=train, 
                                download=True, 
                                transform=transform),
                              batch_size=batch_size, 
                              shuffle=train,
                              num_workers=num_workers,
                              )
    return train_loader


# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class ResidualBlock(nn.Module):

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x, **kwargs)) + x


class Identity(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x, **kwargs):
        return x


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        self.adas = nn.ModuleList([])
        for _ in range(depth):
            block = nn.ModuleList([
                ResidualBlock(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                # ResidualBlock(dim, FeedForward(dim, mlp_dim, dropout = dropout)),
                ResidualBlock(dim, HyperAttentionMLP(dim, dim)),
            ])
            self.layers.append(block)
    def forward(self, x):
        for block in self.layers:
            for layer in block:
                x = layer(x)
        return x


class ViT(nn.Module):
    def __init__(self, image_size=32,
                        patch_size=2,
                        num_classes=10,
                        dim=512,
                        depth=6,
                        heads=8,
                        mlp_dim=2048,
                        channels = 3,
                        dim_head = 64, 
                        dropout = 0.1, 
                        emb_dropout = 0.1,
                        ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, 
                                        depth, 
                                        heads, 
                                        dim_head, 
                                        mlp_dim, 
                                        dropout,
                                        )

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)


In [16]:
# training
from types import SimpleNamespace
import datetime
os.makedirs('logs', exist_ok=True)
args = SimpleNamespace(
    batch_size=512,
    epochs=100,
    lr=1e-4,
    rank=8,
    weight_decay=1e-3,
    beta_1=0.9,
    beta_2=0.99,
    validate_every=1,
    save_every=1,
    use_fp16=False,
    compile=False,
    max_grad_norm=None,
    lr_warmp_up_steps=100,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
train_loader = cifar_dataloader(args.batch_size, train=True)
val_loader = cifar_dataloader(args.batch_size, train=False)
model = ViT(
    image_size=32,
    patch_size=4,
    num_classes=10,
    dim=512,
    depth=6,
    heads=8,
    mlp_dim=512,
    channels = 3,
    dim_head = 64, 
    dropout = 0.1, 
    emb_dropout = 0.1,

).to(device)
print(model)
model.train()
optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(args.beta_1, args.beta_2))
criterion = nn.CrossEntropyLoss()

run_stats = {
    'train_loss': [],
    'val_loss': [],
    'val_acc': [],
}
scaler = torch.amp.GradScaler(enabled=args.use_fp16)

if args.compile:
    model = torch.compile(model)


lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: min(1, x / args.lr_warmp_up_steps))

for epoch in range(args.epochs):
    model.train()
    train_losses = []
    pbar = tqdm(train_loader)
    for x, y in pbar:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(enabled=args.use_fp16, device_type='cuda'):
            y_hat = model(x)
            loss = criterion(y_hat, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        lr_scheduler.step()
        # loss = do_step()
        train_losses.append(loss.item())
        pbar.set_description(f'loss: {loss.item()}')
    run_stats['train_loss'].append(np.mean(train_losses))
    if epoch % args.validate_every == 0:
        model.eval()
        results = []
        val_losses = []
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                y_hat = model(x)
                loss = criterion(y_hat, y)
                y_pred = y_hat.argmax(dim=1)
                results.append((y == y_pred).float().mean().item())
                val_losses.append(loss.item())
        run_stats['val_loss'].append(np.mean(val_losses))
        run_stats['val_acc'].append(np.mean(results))
        model.train()
    if epoch % args.save_every == 0:
        torch.save(model.state_dict(), f'logs/model_{epoch}.pth')
    
    print(f'Epoch {epoch} train loss: {run_stats["train_loss"][-1]} val loss: {run_stats["val_loss"][-1]} val acc: {run_stats["val_acc"][-1]}')


Files already downloaded and verified
Files already downloaded and verified
ViT(
  (to_patch_embedding): Sequential(
    (0): Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4)
    (1): Linear(in_features=48, out_features=512, bias=True)
  )
  (dropout): Dropout(p=0.1, inplace=False)
  (transformer): Transformer(
    (layers): ModuleList(
      (0-5): 6 x ModuleList(
        (0): ResidualBlock(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): Attention(
            (to_qkv): Linear(in_features=512, out_features=1536, bias=False)
            (to_out): Sequential(
              (0): Linear(in_features=512, out_features=512, bias=True)
              (1): Dropout(p=0.1, inplace=False)
            )
          )
        )
        (1): ResidualBlock(
          (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (fn): HyperAttentionMLP(
            (q_proj): Linear(in_features=512, out_features=512, bias=False)
           

loss: 2.30153226852417: 100%|██████████| 98/98 [01:51<00:00,  1.14s/it]  


Epoch 0 train loss: 2.334795049258641 val loss: 2.209939432144165 val acc: 0.17691291347146035


loss: 2.1888492107391357: 100%|██████████| 98/98 [01:54<00:00,  1.17s/it]


Epoch 1 train loss: 2.1877633552162017 val loss: 2.1133996844291687 val acc: 0.20506089180707932


loss: 2.187307357788086: 100%|██████████| 98/98 [01:55<00:00,  1.17s/it] 


Epoch 2 train loss: 2.176220239425192 val loss: 2.0670367896556856 val acc: 0.22558019310235977


loss: 2.1318089962005615: 100%|██████████| 98/98 [01:54<00:00,  1.17s/it]


Epoch 3 train loss: 2.1654149975095476 val loss: 2.0938838362693786 val acc: 0.21059857532382012


loss: 2.2353618144989014: 100%|██████████| 98/98 [01:54<00:00,  1.17s/it]


Epoch 4 train loss: 2.1834609313886992 val loss: 2.1569111824035643 val acc: 0.19330193027853965


loss: 2.1528215408325195: 100%|██████████| 98/98 [01:54<00:00,  1.17s/it]


Epoch 5 train loss: 2.186429602759225 val loss: 2.0731096267700195 val acc: 0.23203125


loss: 2.147599220275879:   4%|▍         | 4/98 [00:06<02:40,  1.71s/it] 


KeyboardInterrupt: 