In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision, torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np

In [2]:
class MLP(nn.Module):
    def __init__(self, dim, hid_size, p_drop=0.1):
        super().__init__()
        
        self.layers = nn.Sequential(
            nn.Linear(dim, hid_size),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(hid_size, dim),
            nn.Dropout(p_drop))

    def forward(self, x):
        return self.layers(x)
        

In [3]:
class Attention(nn.Module):
    def __init__(self, dim, n_heads, p_drop=0.1):
        super().__init__()
        self.hid_size = dim // n_heads
        self.n_heads = n_heads
        self.scale = np.sqrt(self.hid_size)
        self.first_linear = nn.Linear(dim, 3*n_heads*self.hid_size, bias=False)
        self.out = nn.Sequential(
            nn.Linear(n_heads*self.hid_size, dim, bias=False),
            nn.Dropout(p=p_drop))
    
    def make_attention(self, q, k, v):
        A = F.softmax(torch.matmul(q, k.permute(0, 2, 1)) / self.scale, dim=-1)
        return torch.matmul(A, v)
    
    def forward(self, inputs):
        x = self.first_linear(inputs)
        x = torch.cat([self.make_attention(x[:, :, 3*i*self.hid_size:(3*i + 1)*self.hid_size],
                    x[:, :, (3*i + 1)*self.hid_size:(3 * i + 2)*self.hid_size],
                    x[:, :, (3*i + 2)*self.hid_size:(3 * i + 3)*self.hid_size]) for i in range(self.n_heads)],
                                                                                                         dim=2)
        return self.out(x)
        

In [4]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, dim, mlp_size, n_patches, n_heads, p_drop=0.1):
        super().__init__()
        self.attention = Attention(dim, n_heads, p_drop)
        self.mlp = MLP(dim, mlp_size, p_drop)
        self.ln1 = nn.LayerNorm([n_patches + 1, dim])
        self.ln2 = nn.LayerNorm([n_patches + 1, dim])
        
    def forward(self, inputs):
        x = self.attention(self.ln1(inputs)) + inputs      
        return self.mlp(self.ln2(x)) + x
        

In [5]:
class ViT(nn.Module):
    def __init__(self, n_classes, dim, mlp_size, n_patches, patch_size=16, depth=8, n_heads=8,
                n_channels=3, p_drop=0.1):
        super().__init__()
        self.patch_size = patch_size
        self.init_linear = nn.Linear(patch_size**2 * n_channels, dim, bias=False)
        self.xclass = nn.Parameter(torch.zeros(1,1,dim))
        self.E_pos = nn.Parameter(torch.randn(n_patches + 1, dim) * 0.02)
        self.dropout = nn.Dropout(p=p_drop)
        self.transformer = nn.Sequential(*[TransformerEncoderBlock(dim, mlp_size, n_patches,
                                                                   n_heads, p_drop) for _ in range(depth)])
        self.mlp_head = nn.Sequential(nn.LayerNorm(dim),
                                      nn.Linear(dim, n_classes))
        
    def patching_images(self, inputs):
        b, c, h, w = inputs.shape
        assert h % self.patch_size == 0 and w % self.patch_size == 0, "Need to change patch_size"
        x = inputs.reshape(b, c, h // self.patch_size, self.patch_size, w // self.patch_size, self.patch_size)
        x = x.permute(0, 2, 4, 1, 3, 5)
        return x.flatten(start_dim=3).flatten(start_dim=1, end_dim=2)
    
    def forward(self, inputs):
        x = self.patching_images(inputs)
        x = self.init_linear(x)
        x = torch.cat([self.xclass.repeat((x.shape[0], 1, 1)), x], dim=1) + self.E_pos
        x = self.dropout(x)
        x = self.transformer(x)
        return self.mlp_head(x[:, 0])
    

In [6]:
from tqdm.notebook import tqdm

def train(model, train_loader, val_loader, opt, scheduler=None, n_epochs=300, filename='best_transformer.pt'):
    train_loss = []
    val_accuracy = []
    val_loss = []
    best_accuracy = 0
    
    with tqdm(range(n_epochs * (len(val_loader) + len(train_loader)))) as pbar:
        for epoch in range(n_epochs):
            model.train()
            epoch_train_loss = 0
            epoch_val_loss = 0
            epoch_val_accuracy = 0

            for X_batch, y_batch in train_loader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                loss = F.cross_entropy(model(X_batch), y_batch).mean()
                loss.backward()
                opt.step()
                if scheduler:
                    scheduler.step()
                opt.zero_grad() 
                epoch_train_loss += loss.item()
                pbar.update()
            train_loss.append(epoch_train_loss / len(train_loader))
            neptune.log_metric('train_loss', train_loss[-1])
            print("Epoch:", epoch + 1)
            print("Train loss: %.3f" % train_loss[-1])
            
            with torch.no_grad():
                model.eval()
                for (X_batch, y_batch) in val_loader:
                    logits = model(X_batch.to(device))
                    y_pred = torch.argmax(logits, dim=1)
                    epoch_val_accuracy += np.mean((y_batch == y_pred.cpu()).numpy())
                    loss = F.cross_entropy(logits, y_batch.to(device)).mean()
                    epoch_val_loss += loss.item()
                    pbar.update()
                val_accuracy.append(epoch_val_accuracy / len(val_loader))
                neptune.log_metric('val_accuracy', val_accuracy[-1])
                val_loss.append(epoch_val_loss / len(val_loader))
                neptune.log_metric('val_loss', val_loss[-1])
                if val_accuracy[-1] > best_accuracy:
                    torch.save(model, filename)
                    best_accuracy = val_accuracy[-1]
                print("Val loss: %.3f" % val_loss[-1])
                print("Val accuracy: %.3f" % val_accuracy[-1])
            
    return model

In [8]:
class WarmupSchedule:
    def __init__(self, dataset_len, batch_size, n_epochs, warmup_percent=1):
        self.n_iterations = n_epochs * (dataset_len // batch_size)
        self.warmuplen = (self.n_iterations * warmup_percent) // 100
        
    def get_lr_coef(self, i):
        if i < self.warmuplen:
            return (i + 1) / self.warmuplen
        return (self.n_iterations - i) / (self.n_iterations - self.warmuplen)

In [7]:
import neptune
NEPTUNE_API_TOKEN = 'sometoken'
neptune.init(f'calistro/vit', api_token=NEPTUNE_API_TOKEN)



Project(calistro/vit)

In [9]:
batch_size=200

train_dir = '../../../../data/evdmsivets/tiny-imagenet-200/train'
test_dir = '../../../../data/evdmsivets/tiny-imagenet-200/val'

train_transforms = torchvision.transforms.Compose([
    transforms.RandomRotation([-10, 10]),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

test_transforms = torchvision.transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

dataset = torchvision.datasets.ImageFolder(train_dir, transform=train_transforms)
val_dataset = torchvision.datasets.ImageFolder(test_dir, transform=test_transforms)

train_batch_gen = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=10)
val_batch_gen = DataLoader(val_dataset, batch_size=batch_size, num_workers=10)
  

### Training

In [13]:
device='cuda:5'

model = ViT(n_classes=200, dim=384, mlp_size=1536, n_patches=64, patch_size=8,
            depth=6, n_heads=6, n_channels=3, p_drop=0.05).to(device)

neptune.create_experiment()

opt = torch.optim.Adam(model.parameters(), lr=1e-3)
  
n_epochs = 150
lr_schedule = WarmupSchedule(len(dataset), batch_size, n_epochs, warmup_percent=2)

scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_schedule.get_lr_coef)
model = train(model, train_batch_gen, val_batch_gen, opt, scheduler, n_epochs)

https://ui.neptune.ai/calistro/vit/e/VIT-12


HBox(children=(FloatProgress(value=0.0, max=82500.0), HTML(value='')))

Epoch: 1
Train loss: 4.862
Val loss: 4.545
Val accuracy: 0.073
Epoch: 2
Train loss: 4.278
Val loss: 4.007
Val accuracy: 0.139
Epoch: 3
Train loss: 3.909
Val loss: 3.732
Val accuracy: 0.177
Epoch: 4
Train loss: 3.556
Val loss: 3.457
Val accuracy: 0.223
Epoch: 5
Train loss: 3.253
Val loss: 3.218
Val accuracy: 0.268
Epoch: 6
Train loss: 3.032
Val loss: 3.018
Val accuracy: 0.300
Epoch: 7
Train loss: 2.848
Val loss: 2.899
Val accuracy: 0.330
Epoch: 8
Train loss: 2.690
Val loss: 2.867
Val accuracy: 0.335
Epoch: 9
Train loss: 2.537
Val loss: 2.805
Val accuracy: 0.347
Epoch: 10
Train loss: 2.379
Val loss: 2.715
Val accuracy: 0.371
Epoch: 11
Train loss: 2.223
Val loss: 2.718
Val accuracy: 0.369
Epoch: 12
Train loss: 2.054
Val loss: 2.721
Val accuracy: 0.379
Epoch: 13
Train loss: 1.884
Val loss: 2.778
Val accuracy: 0.375
Epoch: 14
Train loss: 1.700
Val loss: 2.840
Val accuracy: 0.381
Epoch: 15
Train loss: 1.512
Val loss: 2.965
Val accuracy: 0.377
Epoch: 16
Train loss: 1.344
Val loss: 3.144
Val a

Epoch: 129
Train loss: 0.006
Val loss: 6.957
Val accuracy: 0.375
Epoch: 130
Train loss: 0.006
Val loss: 6.920
Val accuracy: 0.379
Epoch: 131
Train loss: 0.005
Val loss: 7.046
Val accuracy: 0.375
Epoch: 132
Train loss: 0.005
Val loss: 6.946
Val accuracy: 0.380
Epoch: 133
Train loss: 0.005
Val loss: 6.991
Val accuracy: 0.377
Epoch: 134
Train loss: 0.005
Val loss: 7.009
Val accuracy: 0.374
Epoch: 135
Train loss: 0.004
Val loss: 6.938
Val accuracy: 0.378
Epoch: 136
Train loss: 0.004
Val loss: 7.060
Val accuracy: 0.376
Epoch: 137
Train loss: 0.004
Val loss: 7.009
Val accuracy: 0.378
Epoch: 138
Train loss: 0.004
Val loss: 6.892
Val accuracy: 0.382
Epoch: 139
Train loss: 0.003
Val loss: 6.932
Val accuracy: 0.382
Epoch: 140
Train loss: 0.003
Val loss: 6.890
Val accuracy: 0.383
Epoch: 141
Train loss: 0.003
Val loss: 6.910
Val accuracy: 0.385
Epoch: 142
Train loss: 0.003
Val loss: 6.935
Val accuracy: 0.381
Epoch: 143
Train loss: 0.003
Val loss: 6.837
Val accuracy: 0.385
Epoch: 144
Train loss: 0.

In [15]:
neptune.stop()

The model was trained from scratch on the TinyImageNet dataset, achieving a validation accuracy of 0.38.