In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F
!pip install timm
import timm
import time

Collecting timm
  Downloading timm-0.9.16-py3-none-any.whl (2.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->timm)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->timm)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->timm)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->timm)
  Using cache

In [None]:
def low_rank_linear(in_features, out_features, rank, bias=True):
    """Factor an ``in_features -> out_features`` linear map through a ``rank``-wide bottleneck.

    The result is ``Linear(in_features, rank, bias=False)`` followed by
    ``Linear(rank, out_features, bias=bias)``. It only reduces parameters when
    ``rank`` is smaller than ``min(in_features, out_features)``.

    Args:
        in_features: input width.
        out_features: output width.
        rank: width of the intermediate bottleneck.
        bias: whether the second (up) projection carries a bias; the first never does.

    Returns:
        An ``nn.Sequential`` holding the two projections (down, then up).
    """
    down_proj = nn.Linear(in_features, rank, bias=False)
    up_proj = nn.Linear(rank, out_features, bias=bias)
    return nn.Sequential(down_proj, up_proj)

class LowRankSelfAttention(nn.Module):
    """Multi-head self-attention whose fused QKV projection is low-rank factored.

    The ``dim -> 3 * dim`` QKV projection is built with ``low_rank_linear`` using a
    bottleneck of width ``3 * rank``; the output projection is a plain ``nn.Linear``.

    Args:
        dim: token embedding width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        rank: scales the fused QKV bottleneck, which is ``3 * rank`` wide.
        qkv_bias: whether the QKV up-projection has a bias.
        qk_scale: explicit multiplier for the attention logits. When None, defaults
            to ``head_dim ** -0.5``.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``. Without this check,
            construction would succeed and only the reshape in ``forward`` would fail.
    """

    def __init__(self, dim, num_heads, rank, qkv_bias=False, qk_scale=None):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Test against None instead of using `or`, so an explicit qk_scale of 0.0 is honored.
        self.scale = qk_scale if qk_scale is not None else self.head_dim ** -0.5

        # Low-rank fused QKV projection: dim -> 3*rank -> 3*dim
        self.qkv = low_rank_linear(dim, dim * 3, rank=rank * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply attention to ``x`` of shape (B, N, C), where C == dim; returns shape (B, N, C)."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, heads, N, N)
        attn = attn.softmax(dim=-1)
        # Merge heads back: (B, heads, N, head_dim) -> (B, N, heads, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

class LowRankTransformerBlock(nn.Module):
    """Pre-norm transformer encoder block built from low-rank factored linears.

    Computes ``x = x + attn(norm1(x))``, then ``x = x + mlp(norm2(x))``.

    Args:
        dim: token embedding width.
        num_heads: attention heads; ``dim`` must be divisible by it.
        rank: forwarded to ``LowRankSelfAttention``, where the QKV bottleneck is ``3 * rank``.
        qkv_bias, qk_scale: forwarded to the attention layer.
        act_layer: activation class instantiated between the two MLP projections.
        norm_layer: normalization class applied before attention and before the MLP.
        mlp_ratio: MLP hidden width as a multiple of ``dim`` (default 4).
        mlp_rank: bottleneck width used for both MLP projections.
            With the default ``None``, the original ranks ``hidden`` and ``dim`` are
            kept. Neither is smaller than ``min(in_features, out_features)``, so the
            MLP is *not* compressed. It actually has more parameters than a plain
            two-``nn.Linear`` MLP. Pass an int smaller than ``dim`` to make the MLP
            genuinely low-rank.
    """

    def __init__(self, dim, num_heads, rank, qkv_bias=False, qk_scale=None, act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, mlp_ratio=4, mlp_rank=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = LowRankSelfAttention(dim, num_heads, rank, qkv_bias, qk_scale)
        self.norm2 = norm_layer(dim)

        hidden = int(dim * mlp_ratio)
        # Defaults reproduce the original (uncompressed) factorization; see mlp_rank above.
        rank_up = hidden if mlp_rank is None else mlp_rank
        rank_down = dim if mlp_rank is None else mlp_rank
        self.mlp = nn.Sequential(
            low_rank_linear(dim, hidden, rank=rank_up),
            act_layer(),
            low_rank_linear(hidden, dim, rank=rank_down),
        )

    def forward(self, x):
        """Apply the two pre-norm residual sub-layers; the output has the same shape as ``x``."""
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x


In [None]:
# Data-pipeline configuration. IMG_SIZE must match ViT_FLORA's `img_size` (default 16).
DATA_ROOT = './data'
IMG_SIZE = 16
BATCH_SIZE = 64
NUM_WORKERS = 2
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
# NOTE(review): these std values are the widely copied ones. The per-channel pixel std
# of the CIFAR-10 training set is commonly reported as ~(0.2470, 0.2435, 0.2616).
# Verify before changing, because the reported accuracy was obtained with these values.
CIFAR10_STD = [0.2023, 0.1994, 0.2010]

transform = transforms.Compose([
    # Resize so the input matches the model's patch grid (img_size // patch_size per side)
    transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

train_dataset = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=transform)

# Pinned host memory speeds up host-to-GPU copies; it is only useful when CUDA is present.
PIN_MEMORY = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 60784776.72it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
class ViT_FLORA(nn.Module):
    """Small Vision Transformer whose encoder blocks use low-rank factored linears.

    Pipeline:
        1. Conv2d patch embedding.
        2. Prepend a learnable CLS token.
        3. Add a learnable positional embedding.
        4. Run ``depth`` ``LowRankTransformerBlock`` layers.
        5. Apply a linear head to the CLS token.

    Args:
        img_size: side length of the input image. The positional embedding is sized
            for ``(img_size // patch_size) ** 2`` patches plus one CLS token. If
            ``img_size`` is not a multiple of ``patch_size``, the strided conv drops
            the trailing border pixels.
        patch_size: patch side length (conv kernel size and stride).
        token_len: embedding width of every token.
        num_classes: number of output logits.
        num_heads: attention heads per block.
        rank_factor: attention rank as a fraction of ``token_len``.
        depth: number of transformer blocks (default 12, the previously hard-coded value).
        in_chans: number of input image channels (default 3).
    """

    def __init__(self, img_size=16, patch_size=4, token_len=256, num_classes=10, num_heads=8,
                 rank_factor=0.5, depth=12, in_chans=3):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.token_len = token_len
        self.cls_token = nn.Parameter(torch.zeros(1, 1, token_len))
        # +1 position for the CLS token prepended in forward()
        self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, token_len))
        timm.layers.trunc_normal_(self.pos_embed, std=.02)
        timm.layers.trunc_normal_(self.cls_token, std=.02)

        self.patch_emb = nn.Conv2d(in_chans, token_len, kernel_size=patch_size, stride=patch_size)
        attn_rank = int(rank_factor * token_len)
        self.transformer = nn.Sequential(*[
            LowRankTransformerBlock(dim=token_len, num_heads=num_heads, rank=attn_rank)
            for _ in range(depth)
        ])
        self.head = nn.Linear(token_len, num_classes)

    def forward(self, x):
        """Map an image batch ``x`` of shape (B, in_chans, H, W) to logits of shape (B, num_classes)."""
        B = x.shape[0]
        # (B, token_len, H/p, W/p) -> (B, num_patches, token_len)
        x = self.patch_emb(x).flatten(2).transpose(1, 2)
        x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
        x = x + self.pos_embed  # out-of-place add; broadcasts over the batch
        x = self.transformer(x)
        return self.head(x[:, 0])  # classify from the CLS token


In [None]:
import sys

def get_device():
    """Pick the compute device: CUDA when torch reports it as available, else the CPU."""
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
device = get_device()
print(device)

# Training configuration, kept in one place so the run is easy to read and tune.
SEED = 0
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
LOG_EVERY = 100  # print a progress line every LOG_EVERY mini-batches, not every batch

# Seed before building the model so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(SEED)

model = ViT_FLORA().to(device)
if next(model.parameters()).is_cuda:
    print("model moved to cuda")

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

num_batches = len(train_loader)
for epoch in range(NUM_EPOCHS):
    model.train()  # evaluation cells may have switched the model to eval mode
    start_time = time.time()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if (i + 1) % LOG_EVERY == 0:
            print(f'Epoch [{epoch + 1}/{NUM_EPOCHS}], Batch [{i + 1}/{num_batches}], '
                  f'Loss: {loss.item():.3f}, Running Avg: {running_loss / (i + 1):.3f}')

    time_taken = time.time() - start_time
    avg_loss = running_loss / max(num_batches, 1)  # guard against an empty loader
    print(f'Epoch [{epoch + 1}/{NUM_EPOCHS}], Average Loss: {avg_loss:.3f}, '
          f'Time Taken: {time_taken:.1f} seconds.\n')



cpu


  self.pid = os.fork()


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch [4/10], Batch [489/782], Loss: 1.139
Epoch [4/10], Batch [490/782], Loss: 1.149
Epoch [4/10], Batch [491/782], Loss: 1.077
Epoch [4/10], Batch [492/782], Loss: 1.095
Epoch [4/10], Batch [493/782], Loss: 0.985
Epoch [4/10], Batch [494/782], Loss: 1.104
Epoch [4/10], Batch [495/782], Loss: 1.141
Epoch [4/10], Batch [496/782], Loss: 1.056
Epoch [4/10], Batch [497/782], Loss: 1.046
Epoch [4/10], Batch [498/782], Loss: 1.001
Epoch [4/10], Batch [499/782], Loss: 1.374
Epoch [4/10], Batch [500/782], Loss: 1.158
Epoch [4/10], Batch [501/782], Loss: 1.049
Epoch [4/10], Batch [502/782], Loss: 1.071
Epoch [4/10], Batch [503/782], Loss: 1.122
Epoch [4/10], Batch [504/782], Loss: 1.032
Epoch [4/10], Batch [505/782], Loss: 0.890
Epoch [4/10], Batch [506/782], Loss: 1.196
Epoch [4/10], Batch [507/782], Loss: 1.096
Epoch [4/10], Batch [508/782], Loss: 0.913
Epoch [4/10], Batch [509/782], Loss: 1.083
Epoch [4/10], Batch [510/782], L

In [None]:
# Top-1 accuracy of the trained model on the CIFAR-10 test split.
model.eval()  # inference mode; a no-op for the current layers, but correct if dropout/BN are added
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        # Inputs must be on the model's device; without this the cell fails when training ran on CUDA.
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total if total else float('nan')  # avoid ZeroDivisionError on an empty loader
print(f'Accuracy: {accuracy:.2f}%')

Acccuracy: 60.97%
