In [None]:
import torch
import torch.nn as nn
import numpy as np
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F
import time
import torch.optim as optim
import sys
!pip install timm
import timm




In [None]:
# Preprocessing applied to every image of both splits:
#   1. downscale to 16x16 (SimpleNet flattens its input to 16*16*3 features),
#   2. convert to a tensor,
#   3. normalise each channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize((16, 16)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Train split first, then test split; both are downloaded into ./data if missing.
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64, num_workers=2)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=64, num_workers=2)

Files already downloaded and verified
Files already downloaded and verified


In [None]:

class SVDLinear(nn.Module):
    """Bias-free linear layer whose weight is kept in factored form ``U @ S @ V.T``.

    At construction a random ``(out_features, in_features)`` matrix is drawn,
    decomposed with a reduced SVD and truncated to ``rank`` components. The
    three factors ``U`` (out_features x rank), ``S`` (rank x rank) and ``V``
    (in_features x rank) are the trainable parameters.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
        rank: number of singular components to keep. ``None`` (or ``0``)
            keeps all ``min(in_features, out_features)`` components. Larger
            values are clamped to that maximum, so ``self.rank`` always
            equals the number of components actually stored.

    Raises:
        ValueError: if ``rank`` is negative.
    """
    def __init__(self, in_features, out_features, rank=None):
        super(SVDLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features

        max_rank = min(in_features, out_features)
        if rank is not None and rank < 0:
            # A negative value would otherwise silently slice from the end.
            raise ValueError(f"rank must be non-negative, got {rank}")
        # The reduced SVD yields at most `max_rank` components; never report more.
        self.rank = min(rank, max_rank) if rank else max_rank

        # Initialize weight and perform SVD (torch.svd is deprecated).
        init_weight = torch.randn(out_features, in_features)
        U, S, Vh = torch.linalg.svd(init_weight, full_matrices=False)

        # Reduce to desired rank by truncating SVD components.
        self.U = nn.Parameter(U[:, :self.rank].contiguous())
        # S is stored as a (rank, rank) matrix, so its off-diagonal entries are trainable too.
        self.S = nn.Parameter(torch.diag(S[:self.rank]))
        self.V = nn.Parameter(Vh[:self.rank, :].t().contiguous())

    def forward(self, x):
        """Apply the layer to ``x`` (last dimension ``in_features``).

        Computes ``x @ (U @ S @ V.T).T`` as ``x @ V @ S.T @ U.T`` so the dense
        ``out_features x in_features`` weight is never materialised.
        """
        return x @ self.V @ self.S.t() @ self.U.t()



class LowRankLinear(nn.Module):
    """Factorised linear map built from two stacked dense layers.

    The input is first projected down to ``rank`` features and then up to
    ``out_features``. Both stages are plain ``nn.Linear`` layers with their
    default bias, and no non-linearity is applied between them.
    """
    def __init__(self, in_features, out_features, rank):
        super(LowRankLinear, self).__init__()
        self.rank = rank
        # Down-projection is created before the up-projection.
        self.linear1 = nn.Linear(in_features=in_features, out_features=rank)
        self.linear2 = nn.Linear(in_features=rank, out_features=out_features)

    def forward(self, x):
        """Return ``linear2(linear1(x))``."""
        bottleneck = self.linear1(x)
        return self.linear2(bottleneck)

class EfficientAttention(nn.Module):
    """Multi-head self-attention computed in a ``rank``-dimensional space.

    Queries come from a dense ``nn.Linear(dim, rank)``; keys and values come
    from ``SVDLinear(dim, rank, rank // 2)`` layers, i.e. their projection
    weights are factored with ``rank // 2`` components. The attended result
    is projected back to ``dim`` features.

    Args:
        dim: feature size of the input and of the output.
        rank: total width of the attention space, split evenly across heads.
        num_heads: number of attention heads; must divide ``rank``.
        qkv_bias: whether the query projection has a bias. The key and value
            projections are ``SVDLinear`` layers constructed without any bias
            argument, so this flag does not reach them.

    Raises:
        ValueError: if ``rank`` or ``num_heads`` is not positive, or if
            ``rank`` is not a multiple of ``num_heads``.
    """
    def __init__(self, dim, rank, num_heads=1, qkv_bias=False):
        super().__init__()
        # Fail early with a clear message: otherwise a non-divisible rank only
        # surfaces as a reshape error in forward(), and rank < num_heads gives
        # head_dim == 0, which makes the scale computation divide by zero.
        if rank <= 0 or num_heads <= 0 or rank % num_heads != 0:
            raise ValueError(
                f"rank ({rank}) must be a positive multiple of num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        self.rank = rank
        self.head_dim = rank // num_heads
        self.scale = self.head_dim ** -0.5

        self.query = nn.Linear(dim, rank, bias=qkv_bias)
        self.key = SVDLinear(dim, rank, rank//2)
        self.value = SVDLinear(dim, rank, rank//2)
        self.proj = nn.Linear(rank, dim)

    def forward(self, x):
        """Attend over the sequence axis of ``x``.

        Args:
            x: tensor of shape ``(B, N, dim)``.

        Returns:
            Tensor of shape ``(B, N, dim)``.
        """
        B, N, _ = x.shape
        # Project, then split the rank axis into heads: (B, num_heads, N, head_dim).
        q = self.query(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        k = self.key(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)
        v = self.value(x).reshape(B, N, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        # Scaled dot-product scores, normalised over the key axis: (B, num_heads, N, N).
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        # Merge the heads back into a single rank-wide feature axis: (B, N, rank).
        x = (attn @ v).transpose(1, 2).reshape(B, N, self.rank)
        x = self.proj(x)
        return x

class EncodingFLORA(nn.Module):
    """Residual encoder block: normalised attention, then a normalised feed-forward net.

    Each branch normalises its input first and adds its output back onto the
    input. The feed-forward net expands ``dim`` to ``int(dim * hidden_mul)``
    through an ``SVDLinear`` of the given ``rank``, applies ``act_layer`` and
    contracts back to ``dim`` with a dense ``nn.Linear``.

    Args:
        dim: feature size of the block's input and output.
        num_heads: number of heads passed to ``EfficientAttention``.
        rank: rank passed to both the attention and the expanding ``SVDLinear``.
        hidden_mul: multiplier giving the hidden width of the feed-forward net.
        qkv_bias: forwarded to ``EfficientAttention``.
        act_layer: factory for the feed-forward activation.
        norm_layer: factory for the two normalisation layers; called with ``dim``.
    """
    def __init__(self, dim, num_heads=1, rank=128, hidden_mul=4, qkv_bias=False, act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super(EncodingFLORA, self).__init__()
        hidden_dim = int(dim * hidden_mul)

        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = EfficientAttention(dim, rank, num_heads=num_heads, qkv_bias=qkv_bias)

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        feed_forward = [
            SVDLinear(dim, hidden_dim, rank),
            act_layer(),
            nn.Linear(hidden_dim, dim),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Run both residual branches in sequence and return the result."""
        attended = x + self.attn(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))


In [None]:
def get_device():
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Chosen once here; later cells read `device` when placing the model and batches.
device = get_device()
print(device)


cpu


In [None]:
def init_weights(m):
    """Xavier-uniform initialise an ``nn.Linear``'s weight and zero its bias.

    Written for use with ``Module.apply``: any module that is not an
    ``nn.Linear`` is left untouched, and a layer without a bias only has its
    weight re-initialised.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    if m.bias is None:
        return
    nn.init.constant_(m.bias, 0)




In [None]:
class SimpleNet(nn.Module):
    """Two-layer classifier built from ``SVDLinear`` layers.

    The input is flattened to 768 features (16 * 16 * 3), mapped to 1024
    hidden units, passed through a ReLU and mapped to 10 output scores.
    """
    def __init__(self):
        super().__init__()
        self.svd_linear1 = SVDLinear(768, 1024, rank=512)
        self.relu = nn.ReLU()
        # 10 output scores, one per class. A 10 x 1024 weight has at most 10
        # singular components, so the requested rank of 512 cannot be reached here.
        self.svd_linear2 = SVDLinear(1024, 10, rank=512)

    def forward(self, x):
        """Return the 10 class scores for each image in ``x``."""
        flat = x.view(-1, 3 * 16 * 16)  # one row of 768 features per image
        hidden = self.relu(self.svd_linear1(flat))
        return self.svd_linear2(hidden)

# Build the network, run the initialiser over its submodules and place it on `device`.
# NOTE(review): init_weights only acts on nn.Linear modules, and SimpleNet holds
# only SVDLinear and ReLU submodules, so this apply() call matches nothing.
model = SimpleNet()
model.apply(init_weights)
model = model.to(device)
if next(model.parameters()).device.type == 'cuda':
    print("model moved to cuda")

In [None]:
# Sanity pass over the whole training loader: stop at the first batch whose
# images or labels contain a NaN or an infinite value.
for inputs, labels in train_loader:
    assert not inputs.isnan().any(), "Input has NaN values"
    assert not inputs.isinf().any(), "Input has Inf values"
    assert not labels.isnan().any(), "Labels have NaN values"
    assert not labels.isinf().any(), "Labels have Inf values"


In [None]:
# Number of passes over the training set; used by both the loop and the log line.
NUM_EPOCHS = 100

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001, betas=(0.95, 0.99))

# Training loop
model.train()  # make the mode explicit in case an earlier cell switched to eval()
for epoch in range(NUM_EPOCHS):
    start_time = time.time()
    running_loss = 0.0
    num_batches = 0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # accumulate statistics for the epoch summary
        running_loss += loss.item()
        num_batches += 1

    # Report once per epoch, after the last batch; skipped if the loader was empty.
    if num_batches:
        time_taken = time.time() - start_time
        print(f'Epoch[{epoch +1}/{NUM_EPOCHS}], Average Loss: {running_loss/num_batches: .3f},Time Taken: {time_taken} seconds.\n')
print('Finished Training')


Epoch[1/100], Average Loss:  3.062,Time Taken: 73.56437516212463 seconds.

Epoch[2/100], Average Loss:  2.410,Time Taken: 70.67940974235535 seconds.

Epoch[3/100], Average Loss:  1.996,Time Taken: 71.55696415901184 seconds.

Epoch[4/100], Average Loss:  1.699,Time Taken: 72.88640666007996 seconds.

Epoch[5/100], Average Loss:  1.411,Time Taken: 71.46317791938782 seconds.

Epoch[6/100], Average Loss:  1.263,Time Taken: 70.38732886314392 seconds.

Epoch[7/100], Average Loss:  1.073,Time Taken: 72.63445496559143 seconds.

Epoch[8/100], Average Loss:  0.975,Time Taken: 70.15279912948608 seconds.

Epoch[9/100], Average Loss:  0.859,Time Taken: 71.39410018920898 seconds.

Epoch[10/100], Average Loss:  0.783,Time Taken: 71.56018471717834 seconds.

Epoch[11/100], Average Loss:  0.721,Time Taken: 70.8158278465271 seconds.

Epoch[12/100], Average Loss:  0.673,Time Taken: 70.0799458026886 seconds.

Epoch[13/100], Average Loss:  0.636,Time Taken: 72.45615673065186 seconds.

Epoch[14/100], Average 

In [None]:
# Top-1 accuracy on the test split.
model.eval()  # inference mode for any layers that behave differently in training
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        # Batches must live on the same device as the model, otherwise the
        # forward pass fails with a device mismatch when running on CUDA.
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        # Index of the highest score along the class axis is the predicted label.
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Acccuracy: {100*correct/total}%')

Acccuracy: 45.77%
