In [1]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1


In [35]:
import torch
from torch import nn
import torch.functional as F
import numpy as np

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

In [57]:
class MultiHeadedSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: model width; it is split into ``num_heads`` equal slices, so it is
            assumed to be divisible by ``num_heads`` (the head reshape fails otherwise).
        num_heads: number of attention heads.
        dropout: dropout probability applied to the softmaxed attention weights.

    After every forward call ``self.scores`` holds the attention weights
    (after dropout) with shape (batch, heads, seq, seq).
    """

    def __init__(self, dim, num_heads, dropout):
        super().__init__()
        self.proj_q = nn.Linear(dim, dim)
        self.proj_k = nn.Linear(dim, dim)
        self.proj_v = nn.Linear(dim, dim)
        self.drop = nn.Dropout(dropout)
        self.n_heads = num_heads
        self.scores = None  # most recent attention map, set by forward()

    def forward(self, x, mask):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: tensor of shape (b, s, dim).
            mask: None, or a (b, s) tensor where positions with value 0 are
                suppressed (a large negative offset is added to their logits).

        Returns:
            Tensor of shape (b, s, dim): the concatenated per-head outputs.
        """
        batch, seq, _ = x.shape

        def split_heads(t):
            # (b, s, h*w) -> (b, s, h, w) -> (b, h, s, w)
            return t.reshape(batch, seq, self.n_heads, -1).transpose(1, 2)

        query = split_heads(self.proj_q(x))
        key = split_heads(self.proj_k(x))
        value = split_heads(self.proj_v(x))

        # (b, h, s, w) x (b, h, w, s) -> (b, h, s, s), scaled by sqrt(head width)
        logits = torch.matmul(query, key.transpose(-2, -1)) / np.sqrt(key.size(-1))
        if mask is not None:
            keep = mask[:, None, None, :].float()
            logits = logits - 10000.0 * (1.0 - keep)
        weights = self.drop(torch.softmax(logits, dim=-1))

        # (b, h, s, s) x (b, h, s, w) -> (b, h, s, w) -> (b, s, h*w)
        context = torch.matmul(weights, value)
        merged = context.transpose(1, 2).reshape(batch, seq, -1)

        self.scores = weights
        return merged

In [64]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every sequence position.

    Computes ``fc2(GELU(fc1(x)))``: the last dimension is expanded from ``dim``
    to ``ff_dim`` and then projected back to ``dim``.
    """

    def __init__(self, dim, ff_dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, ff_dim)  # expansion
        self.fc2 = nn.Linear(ff_dim, dim)  # projection back to model width
        self.gelu = nn.GELU()

    def forward(self, x):
        hidden = self.fc1(x)
        activated = self.gelu(hidden)
        return self.fc2(activated)

In [65]:
class Block(nn.Module):
    """One pre-norm Transformer encoder layer.

    Two residual sub-layers are applied in sequence:
        x = x + Dropout(proj(attn(norm1(x), mask)))
        x = x + Dropout(pwff(norm2(x)))

    Args:
        dim: token width.
        num_heads: number of attention heads.
        ff_dim: hidden width of the position-wise MLP.
        dropout: probability used for the residual-branch dropout here and
            also passed to ``attn`` for its attention-weight dropout.
    """

    def __init__(self, dim, num_heads, ff_dim, dropout):
        super().__init__()
        self.attn = MultiHeadedSelfAttention(dim, num_heads, dropout)
        self.proj = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim, eps=1e-6)
        self.pwff = PositionWiseFeedForward(dim, ff_dim)
        self.norm2 = nn.LayerNorm(dim, eps=1e-6)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Attention sub-layer: normalise, attend, project, then add to the residual stream.
        attended = self.proj(self.attn(self.norm1(x), mask))
        stream = x + self.drop(attended)
        # Feed-forward sub-layer on the updated stream.
        fed = self.pwff(self.norm2(stream))
        return stream + self.drop(fed)

In [66]:
class Transformer(nn.Module):
    """A stack of ``num_layers`` encoder ``Block`` layers applied in order.

    The optional ``mask`` is passed unchanged to every block.
    """

    def __init__(self, num_layers, dim, num_heads, ff_dim, dropout):
        super().__init__()
        layers = [Block(dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        self.blocks = nn.ModuleList(layers)

    def forward(self, x, mask=None):
        out = x
        for layer in self.blocks:
            out = layer(out, mask)
        return out

In [67]:
class PositionalEmbedding1D(nn.Module):
    """Learned position table added to a token sequence.

    The table is a (1, seq_len, dim) parameter initialised to zeros; its leading
    singleton dimension broadcasts over the batch.
    """

    def __init__(self, seq_len, dim):
        super().__init__()
        table = torch.zeros(1, seq_len, dim)
        self.pos_embedding = nn.Parameter(table)

    def forward(self, x):
        # x is expected to be (batch, seq_len, dim) so the table broadcasts over batch.
        return torch.add(x, self.pos_embedding)

In [102]:
def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``.

    Lets size arguments be given either as one value (square) or as an
    explicit (height, width) tuple.
    """
    if isinstance(t, tuple):
        return t
    return (t, t)

class ViT(nn.Module):
    """Vision Transformer classifier.

    The image is cut into non-overlapping patches by a strided convolution, a
    learnable class token is optionally prepended, a learned 1D positional
    embedding is optionally added, and the token sequence is run through a
    pre-norm Transformer encoder. The first token's output is layer-normed and
    projected to ``num_classes`` logits.

    Args:
        image_size: int or (h, w) input resolution. Inputs must produce the same
            patch grid, because the positional embedding has a fixed length.
        patch_size: int or (fh, fw) patch size.
        num_classes: number of output logits.
        in_channels: number of input image channels.
        dim: token width.
        ff_dim: hidden width of the position-wise MLP.
        num_heads: attention heads per encoder layer.
        num_layers: number of encoder layers.
        attention_dropout_rate: NOTE(review): currently unused; the encoder applies
            ``dropout_rate`` to the attention weights as well.
        dropout_rate: dropout probability used throughout the encoder.
        representation_size: width of the optional pre-logits layer.
        load_repr_layer: the pre-logits (Linear + tanh) layer is built only when
            this is True and ``representation_size`` is truthy.
        classifier: ``'token'`` prepends a class token; any other value omits it
            (forward still reads position 0, i.e. the first patch, in that case).
        positional_embedding: ``'1d'`` adds a learned positional embedding; any
            other value disables it.
        pretrained: unused; kept for interface compatibility.
    """

    def __init__(
        self,
        image_size,
        patch_size,
        num_classes,
        in_channels,
        dim,
        ff_dim,
        num_heads,
        num_layers,
        attention_dropout_rate=0.0,
        dropout_rate=0.1,
        representation_size=None,
        load_repr_layer=False,
        classifier='token',
        positional_embedding='1d',
        pretrained=False
    ):
        super().__init__()
        h, w = pair(image_size)
        fh, fw = pair(patch_size)
        # Patch grid; a conv with stride == kernel and no padding drops any remainder pixels.
        gh, gw = h // fh, w // fw
        seq_len = gh * gw

        # Patch embedding: kernel == stride == patch size, i.e. one linear projection per patch.
        self.patch_embedding = nn.Conv2d(in_channels, dim, kernel_size=(fh, fw), stride=(fh, fw))

        # Class token
        if classifier == 'token':
            self.class_token = nn.Parameter(torch.zeros(1, 1, dim))
            seq_len += 1

        # Positional embedding: one entry per token, class token included.
        if positional_embedding == '1d':
            self.positional_embedding = PositionalEmbedding1D(seq_len, dim)

        # Encoder
        self.transformer = Transformer(num_layers, dim, num_heads, ff_dim, dropout_rate)

        # Optional pre-logits representation layer
        if representation_size and load_repr_layer:
            self.pre_logits = nn.Linear(dim, representation_size)
            pre_logit_size = representation_size
        else:
            pre_logit_size = dim

        # Classification head
        self.norm = nn.LayerNorm(pre_logit_size, eps=1e-6)
        self.fc = nn.Linear(pre_logit_size, num_classes)

    @torch.no_grad()
    def init_weights(self):
        """Re-initialise parameters for training from scratch.

        Every ``nn.Linear`` gets Xavier-normal weights and N(0, 1e-6) biases, the
        classifier head is zeroed, the positional embedding is drawn from
        N(0, 0.02) and the class token is zeroed. Not called automatically.
        """
        def _init(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)
        self.apply(_init)
        nn.init.constant_(self.fc.weight, 0)
        nn.init.constant_(self.fc.bias, 0)
        # These only exist in some configurations; guard them so non-default
        # `classifier` / `positional_embedding` settings do not raise AttributeError.
        if hasattr(self, 'positional_embedding'):
            nn.init.normal_(self.positional_embedding.pos_embedding, std=0.02)
        if hasattr(self, 'class_token'):
            nn.init.constant_(self.class_token, 0)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: image batch of shape (b, c, h, w) whose patch grid matches ``image_size``.

        Returns:
            Logits of shape (b, num_classes) when the ``fc`` head is present;
            otherwise the encoder output sequence.
        """
        b, _, _, _ = x.shape
        # (b c h w) -patch_embed-> (b d gh gw)
        x = self.patch_embedding(x)
        # (b d gh gw) -> (b s d)
        x = rearrange(x, 'b d gh gw -> b (gh gw) d')
        if hasattr(self, 'class_token'):
            # (b s d) -> (b s+1 d)
            x = torch.cat((repeat(self.class_token, '1 1 d -> b 1 d', b=b), x), dim=1)
        if hasattr(self, 'positional_embedding'):
            # Self-attention is permutation-invariant, so without this the encoder
            # cannot distinguish patch locations and the position table goes unused.
            x = self.positional_embedding(x)
        x = self.transformer(x)
        if hasattr(self, 'pre_logits'):
            x = torch.tanh(self.pre_logits(x))
        if hasattr(self, 'fc'):
            # Classify from the first token (the class token when classifier == 'token').
            x = self.norm(x)[:, 0]
            x = self.fc(x)
        return x

In [103]:
from torch.utils import model_zoo

# CIFAR-100, the dataset fine-tuned on below, has 100 classes.
NUM_CLASSES = 100

vit_b16_config = dict(
    image_size=(224, 224),
    patch_size=(16, 16),
    # Presumably the checkpoint's head size; load_state_dict is strict, so a mismatch would raise.
    num_classes=21843,
    in_channels=3,
    dim=768,
    ff_dim=3072,
    num_heads=12,
    num_layers=12,
    attention_dropout_rate=0.0,
    dropout_rate=0.1,
    representation_size=768,
    classifier='token',
    pre_trained_url="https://github.com/lukemelas/PyTorch-Pretrained-ViT/releases/download/0.0.2/B_16.pth"
)
# The URL is not a ViT constructor argument, so take it out of the config first.
url = vit_b16_config.pop('pre_trained_url', None)
vit_b16 = ViT(**vit_b16_config)
if url is not None:
    # map_location="cpu" lets the checkpoint load regardless of the device it was
    # saved from; the model is moved to the training device in a later cell.
    state_dict = model_zoo.load_url(url, map_location="cpu")
    ret = vit_b16.load_state_dict(state_dict)

# Replace the pretrained 21,843-way head with a fresh CIFAR-100 head,
# zero-initialised as in the ViT fine-tuning recipe.
vit_b16.fc = nn.Linear(vit_b16.fc.in_features, NUM_CLASSES)
nn.init.zeros_(vit_b16.fc.weight)
nn.init.zeros_(vit_b16.fc.bias)

In [104]:
import torchvision
import torchvision.transforms as transforms

In [106]:
# Data settings. SPLIT_SEED makes the train/validation split identical across runs.
SPLIT_SEED = 42
N_VALID = 10_000
BATCH_SIZE = 64
NUM_WORKERS = 2

# Upsample CIFAR images to the 224-pixel resolution the model is configured for,
# and map pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])


def make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader with the notebook-wide batch size and worker count."""
    return torch.utils.data.DataLoader(dataset,
                                       batch_size=BATCH_SIZE,
                                       shuffle=shuffle,
                                       num_workers=NUM_WORKERS)


full_train_ds = torchvision.datasets.CIFAR100(root='/content/', download=True, transform=transform, train=True)
# Hold out N_VALID training images for validation with a seeded generator so the
# held-out set does not change between kernel restarts.
train_ds, valid_ds = torch.utils.data.random_split(
    full_train_ds,
    [len(full_train_ds) - N_VALID, N_VALID],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = make_loader(train_ds, shuffle=True)
valid_loader = make_loader(valid_ds, shuffle=False)

test_ds = torchvision.datasets.CIFAR100(root='/content/', download=True, transform=transform, train=False)
test_loader = make_loader(test_ds, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vit_b16 = vit_b16.to(device)

optim = torch.optim.SGD(vit_b16.parameters(), lr=1e-3, momentum=0.9)
criterion = torch.nn.CrossEntropyLoss()
epochs = 30
log_every = 100  # mini-batches between training-loss reports


def evaluate(model, loader, loss_fn, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Switches the model to eval mode (disabling dropout); the caller must call
    ``model.train()`` before training again.

    Returns:
        (mean loss per batch, accuracy in percent); both NaN if ``loader`` is empty.
    """
    model.eval()
    loss_sum, n_batches, n_correct, n_seen = 0.0, 0, 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            loss_sum += loss_fn(logits, targets).item()
            n_batches += 1
            n_correct += (logits.argmax(dim=1) == targets).sum().item()
            n_seen += targets.size(0)
    if n_seen == 0:
        return float('nan'), float('nan')
    return loss_sum / n_batches, 100.0 * n_correct / n_seen


for epoch in range(epochs):
    # evaluate() leaves the model in eval mode; re-enable dropout for training.
    vit_b16.train()
    running_loss = 0.0
    for i, (x, y) in enumerate(train_loader):
        x, y = x.to(device), y.to(device)

        optim.zero_grad()
        loss = criterion(vit_b16(x), y)
        loss.backward()
        optim.step()

        running_loss += loss.item()
        if i % log_every == log_every - 1:    # report every `log_every` mini-batches
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
            running_loss = 0.0

    valid_loss, valid_acc = evaluate(vit_b16, valid_loader, criterion, device)
    print(f"[{epoch + 1}] Acc : {valid_acc:.2f}%  Loss (mean per batch) : {valid_loss:.3f}")