In [None]:
!pip3 install torchvision

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = datasets.CIFAR10(root='./data', train=True,
                                      download=True, transform=transform)
subset_indices = torch.randperm(len(trainset))[:20000]  # i have negative patience
trainset = Subset(trainset, subset_indices)

train_loader = DataLoader(trainset, batch_size=1024, shuffle=True, num_workers=2)

testset = datasets.CIFAR10(root='./data', train=False,
                                     download=True, transform=transform)
subset_indices = torch.randperm(len(testset))[:5000]
testset = Subset(testset, subset_indices)
test_loader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [11]:
class PatchEmbed(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=96):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2  # (32 // 4) ** 2 = 64 patches
        #we use convolution as an epic hack to replace patching
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        x = self.proj(x)  # (B, embed_dim, 8, 8)  with patch_size=4
        x = x.flatten(2)   # (B, embed_dim, 64)
        x = x.transpose(1, 2)  # (B, 64, embed_dim)  <- Correct order for MultiheadAttention
        return x


class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, drop):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Block(nn.Module):
    def __init__(self, dim, n_heads, mlp_ratio, drop, attn_drop):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, n_heads, dropout=attn_drop, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        hidden_features = int(dim * mlp_ratio)
        self.mlp = MLP(in_features=dim, hidden_features=hidden_features, out_features=dim, drop=drop)

    def forward(self, x):
        residual = x
        x = self.norm1(x)
        x, _ = self.attn(x, x, x)
        x = x + residual

        residual = x
        x = self.norm2(x)
        x = self.mlp(x)
        x = x + residual
        return x


class ViT(nn.Module):
    def __init__(self, img_size=32, patch_size=4, in_channels=3, n_classes=10, embed_dim=96,  # Adjusted embed_dim
                 depth=6, n_heads=8, mlp_ratio=4., drop_rate=0.4, attn_drop_rate=0.2):  # Adjusted depth, n_heads
        super().__init__()

        self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.n_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        self.blocks = nn.Sequential(*[
            Block(dim=embed_dim, n_heads=n_heads, mlp_ratio=mlp_ratio, drop=drop_rate, attn_drop=attn_drop_rate)
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, n_classes)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x) # (B, num_patches, embed_dim)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1) # (B, num_patches + 1, embed_dim)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        x = self.blocks(x)

        x = self.norm(x)
        x = x[:, 0]  # CLS token
        x = self.head(x)
        return x

In [12]:
model = ViT()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.03)
criterion = nn.CrossEntropyLoss()

num_epochs = 20

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct_train = 0
    total_train = 0

    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total_train += labels.size(0)
        correct_train += (predicted == labels).sum().item()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Training Accuracy: {100 * correct_train / total_train:.2f}%')
    running_loss = 0.0


model.eval()
correct_test = 0
total_test = 0

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total_test += labels.size(0)
        correct_test += (predicted == labels).sum().item()

test_accuracy = 100 * correct_test / total_test
print(f'Test Accuracy: {test_accuracy:.2f}%')

Epoch [1/20], Loss: 2.0679, Training Accuracy: 22.14%
Epoch [2/20], Loss: 1.9384, Training Accuracy: 27.36%
Epoch [3/20], Loss: 1.8322, Training Accuracy: 31.09%
Epoch [4/20], Loss: 1.7521, Training Accuracy: 33.84%
Epoch [5/20], Loss: 1.6940, Training Accuracy: 36.17%
Epoch [6/20], Loss: 1.6300, Training Accuracy: 38.91%
Epoch [7/20], Loss: 1.6062, Training Accuracy: 39.95%
Epoch [8/20], Loss: 1.5485, Training Accuracy: 42.96%
Epoch [9/20], Loss: 1.5009, Training Accuracy: 44.38%
Epoch [10/20], Loss: 1.4515, Training Accuracy: 46.02%
Epoch [11/20], Loss: 1.4142, Training Accuracy: 47.98%
Epoch [12/20], Loss: 1.3829, Training Accuracy: 49.69%
Epoch [13/20], Loss: 1.3432, Training Accuracy: 50.73%
Epoch [14/20], Loss: 1.3096, Training Accuracy: 51.95%
Epoch [15/20], Loss: 1.2992, Training Accuracy: 52.28%
Epoch [16/20], Loss: 1.2796, Training Accuracy: 53.34%
Epoch [17/20], Loss: 1.2579, Training Accuracy: 54.03%
Epoch [18/20], Loss: 1.2351, Training Accuracy: 54.58%
Epoch [19/20], Loss