In [None]:
!pip install einops



In [None]:
import torch
import torch.nn as nn
import math
from sklearn.decomposition import PCA
from torch.nn.parameter import Parameter
from einops import rearrange,repeat
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from einops.layers.torch import Rearrange
import torch.nn.functional as F
from torch import optim

In [None]:
# helpers
def pair(t):
    """Return ``t`` unchanged when it is already a tuple, otherwise ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

In [None]:
class PreNorm(nn.Module):
    """Apply ``nn.LayerNorm(dim)`` to the input, then call the wrapped ``fn`` on the result.

    Args:
        dim: size of the trailing axis that the layer norm normalises over.
        fn: callable (typically an ``nn.Module``) invoked on the normalised input.
    """

    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Normalise ``x`` and forward it, plus any keyword arguments, to ``fn``."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)

In [None]:
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm encoder blocks, each with two residual branches.

    Every block is a pair ``(attention, feed_forward)``; both members are wrapped in
    ``PreNorm`` and their output is added back to their input.

    ``PCAAttention`` and ``FeedForward`` are defined in later cells of this notebook;
    the names are only looked up when a ``Transformer`` is instantiated.

    Args:
        dim: token embedding width.
        depth: number of (attention, feed-forward) blocks.
        heads: number of attention heads.
        dim_head: width of each attention head.
        mlp_dim: hidden width of the feed-forward branch.
        dropout: dropout probability forwarded to both branches.
        pca_components: width of the key bottleneck inside ``PCAAttention``.
            Defaults to 16, the value that was previously hard-coded here.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0., pca_components=16):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            attention = PCAAttention(
                dim,
                heads=heads,
                dim_head=dim_head,
                dropout=dropout,
                pca_components=pca_components,
            )
            feed_forward = FeedForward(dim, mlp_dim, dropout=dropout)
            self.layers.append(nn.ModuleList([
                PreNorm(dim, attention),
                PreNorm(dim, feed_forward),
            ]))

    def forward(self, x):
        """Run ``x`` through every block; the output has the same shape as the input."""
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

In [None]:
class LRLinearSuper(nn.Module):
    """Linear map factorised into two smaller linear layers, ``U(VT(x))``.

    The inner width is ``int(round(min(in_channel, out_channel) * sample_ratio))``.

    Args:
        in_channel: size of the input's trailing axis.
        out_channel: size of the output's trailing axis.
        bias: whether ``U`` carries a bias term (``VT`` never does).
        fused: if true, ``forward`` multiplies the two weight matrices together first
            and applies the product in a single ``F.linear`` call.
        sample_ratio: fraction of ``min(in_channel, out_channel)`` used as inner width.
    """

    def __init__(self, in_channel, out_channel, bias=True, fused=False, sample_ratio=1.0):
        super().__init__()
        self.bias = bias
        self.fused = fused
        self.sample_ratio = sample_ratio
        self.num_components = min(in_channel, out_channel)

        inner_width = int(round(self.num_components * sample_ratio))
        self.VT = nn.Linear(in_channel, inner_width, bias=False)
        self.U = nn.Linear(inner_width, out_channel, bias=bias)

    def forward(self, x):
        """Apply the factorised map to ``x`` along its trailing axis."""
        if not self.fused:
            return self.U(self.VT(x))

        # Fused path: collapse the two factors into one (out_channel, in_channel) matrix.
        merged_weight = torch.matmul(self.U.weight, self.VT.weight)
        if self.bias:
            return F.linear(x, merged_weight, self.U.bias)
        return F.linear(x, merged_weight)

In [None]:
# PCA-based Low Rank Linear Layer
class PCALinear(nn.Module):
    """Linear layer evaluated through a rank-``rank`` PCA factorisation of its weight.

    On the first forward pass the detached weight matrix (one row per output feature)
    is centred over its rows and decomposed with an SVD.  Keeping the leading ``rank``
    factors gives

        weight  ~=  weight_mean + left_vectors @ diag(singular_values) @ components

    and ``forward`` applies the three factors in sequence instead of the dense weight.
    When ``rank == min(in_features, out_features)`` the result matches
    ``F.linear(input, weight, bias)`` up to floating-point error.

    The previous implementation multiplied the rank-sized projection by ``components``
    a second time, which fails with a shape error whenever ``rank != in_features`` and
    could never produce ``out_features`` outputs; the left singular vectors are used
    for that step now.

    Notes:
        * The factorisation is a frozen snapshot: it is computed once from detached
          weights and reused until ``perform_pca`` or ``reset_parameters`` runs again.
          Gradients therefore reach ``bias`` and the input, but not ``weight``.
        * The cached factors are non-persistent buffers, so they follow
          ``.to(device)`` but do not appear in ``state_dict()``.
        * The sign of each component is whatever the SVD returns; it cancels out in
          the reconstruction.

    Args:
        in_features: size of the input's trailing axis.
        out_features: size of the output's trailing axis.
        rank: number of principal components to keep.

    Raises:
        ValueError: if ``rank`` is not within ``1..min(in_features, out_features)``.
    """

    # Names of the cached factorisation buffers.
    _FACTOR_BUFFERS = ("components", "singular_values", "left_vectors", "weight_mean")

    def __init__(self, in_features, out_features, rank):
        super(PCALinear, self).__init__()
        max_rank = min(in_features, out_features)
        if not 1 <= rank <= max_rank:
            raise ValueError(
                'rank must be between 1 and min(in_features, out_features)={}, got {}'.format(
                    max_rank, rank
                )
            )
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank

        self.weight = Parameter(torch.empty(out_features, in_features))
        self.bias = Parameter(torch.empty(out_features))

        # Cached factors start empty and are filled lazily by the first forward pass.
        for name in self._FACTOR_BUFFERS:
            self.register_buffer(name, None, persistent=False)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise ``weight`` and ``bias`` and drop any cached factorisation."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)
        # The old factors describe the previous weights, so force a recompute.
        for name in self._FACTOR_BUFFERS:
            setattr(self, name, None)

    def perform_pca(self):
        """Factorise the current (detached) weight and cache the leading ``rank`` factors."""
        with torch.no_grad():
            weight = self.weight.detach()
            # PCA centres the samples; here each row (output feature) is one sample.
            mean = weight.mean(dim=0)
            left, singular, right_t = torch.linalg.svd(weight - mean, full_matrices=False)
        keep = self.rank
        self.weight_mean = mean
        self.left_vectors = left[:, :keep].contiguous()
        self.singular_values = singular[:keep].contiguous()
        self.components = right_t[:keep].contiguous()

    def forward(self, input):
        """Apply the low-rank approximation of ``weight`` (plus ``bias``) to ``input``."""
        if any(getattr(self, name) is None for name in self._FACTOR_BUFFERS):
            self.perform_pca()

        # (..., in_features) -> (..., rank): coordinates in the principal basis, scaled.
        transformed_input = input.matmul(self.components.T) * self.singular_values
        # (..., rank) -> (..., out_features)
        output = torch.nn.functional.linear(transformed_input, self.left_vectors, self.bias)
        # Add back the contribution of the row mean removed before the SVD; it is the
        # same scalar for every output feature, hence the broadcast over the last axis.
        return output + input.matmul(self.weight_mean).unsqueeze(-1)

    def extra_repr(self):
        return 'in_features={}, out_features={}, rank={}, bias={}'.format(
            self.in_features, self.out_features, self.rank, self.bias is not None
        )

In [None]:
# Pre-norm layer
class PreNorm(nn.Module):
    """Layer-normalise the input over its trailing ``dim`` axis, then call ``fn`` on it.

    NOTE: an identical ``PreNorm`` is defined in an earlier cell of this notebook; when
    the cells run top to bottom, this later definition is the one left bound to the name.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm, self.fn = nn.LayerNorm(dim), fn

    def forward(self, x, **kwargs):
        """Return ``fn(norm(x), **kwargs)``."""
        x = self.norm(x)
        return self.fn(x, **kwargs)

In [None]:
class FeedForward(nn.Module):
    """Two-layer MLP branch built from fused low-rank linear layers.

    Layer order: low-rank ``dim -> hidden_dim``, GELU, dropout, low-rank
    ``hidden_dim -> dim``, dropout.

    Args:
        dim: input and output width.
        hidden_dim: width between the two low-rank layers.
        dropout: probability used by both dropout layers.
        ratio: ``sample_ratio`` handed to each ``LRLinearSuper``.
    """

    def __init__(self, dim, hidden_dim, dropout=0., ratio=0.5):
        super().__init__()
        expand = LRLinearSuper(dim, hidden_dim, fused=True, sample_ratio=ratio)
        contract = LRLinearSuper(hidden_dim, dim, fused=True, sample_ratio=ratio)
        stages = [
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.net(x)

In [None]:
class PCAAttention(nn.Module):
    """Multi-head self-attention whose keys pass through a learned linear bottleneck.

    Queries, keys and values come from fused low-rank projections (``LRLinearSuper``).
    Before the attention scores are computed, the keys are squeezed to
    ``pca_components`` features and expanded back by two bias-free ``nn.Linear``
    layers.  Despite the name, no PCA is computed here: both maps are ordinary
    trainable layers.

    Dropout is applied only after the output projection, not to the attention weights.

    Args:
        dim: width of the input and output tokens.
        heads: number of attention heads.
        dim_head: width of each head; the projections produce ``heads * dim_head``.
        dropout: dropout probability after the output projection.
        ratio: ``sample_ratio`` for every ``LRLinearSuper`` in this module.
        pca_components: width of the key bottleneck.
    """

    def __init__(self, dim, heads=8, dim_head=64, dropout=0., ratio=0.5, pca_components=32):
        super().__init__()
        self.heads = heads
        self.dim_head = dim_head
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)

        self.to_q = LRLinearSuper(dim, inner_dim, fused=True, sample_ratio=ratio)
        self.to_k = LRLinearSuper(dim, inner_dim, fused=True, sample_ratio=ratio)
        self.to_v = LRLinearSuper(dim, inner_dim, fused=True, sample_ratio=ratio)

        self.compress_k = nn.Linear(inner_dim, pca_components, bias=False)
        self.expand_k = nn.Linear(pca_components, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            LRLinearSuper(inner_dim, dim, fused=True, sample_ratio=ratio),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        """Attend over the token axis of ``x`` (batch, tokens, dim) and return the same shape."""
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(x)

        # The bottleneck acts on the flat (heads * dim_head) layout, which is the layout
        # the projection already returns, so it is applied before splitting into heads.
        # (Previously the keys were split into heads and merged straight back first.)
        k = self.expand_k(self.compress_k(k))

        q, k, v = (
            rearrange(t, 'b n (h d) -> b h n d', h=self.heads, d=self.dim_head)
            for t in (q, k, v)
        )

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        return self.to_out(out)


In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier built on the notebook's low-rank ``Transformer``.

    The image is cut into non-overlapping patches, each patch is flattened and linearly
    embedded, a learnable class token is prepended, learnable position embeddings are
    added, and the sequence is passed through the transformer.  The pooled token
    representation goes through ``LayerNorm`` + ``Linear`` to give class logits.

    Args:
        image_size: int or ``(height, width)`` of the input images.
        patch_size: int or ``(height, width)`` of each patch; must divide the image size.
        num_classes: number of output logits.
        dim: token embedding width.
        depth: number of transformer blocks.
        heads: number of attention heads.
        mlp_dim: hidden width of the feed-forward branches.
        pool: ``'cls'`` to classify from the class token, ``'mean'`` to average all tokens.
        channels: number of image channels.
        dim_head: width of each attention head.
        dropout: dropout probability inside the transformer.
        emb_dropout: dropout probability applied to the embedded sequence.

    Raises:
        AssertionError: if the image size is not divisible by the patch size, or if
            ``pool`` is neither ``'cls'`` nor ``'mean'``.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'
        # Without this check an unknown value skipped pooling entirely and the head
        # returned one set of logits per token instead of one per image.
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        self.pool = pool

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim),
        )

        # One position per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Return class logits of shape ``(batch, num_classes)`` for a batch of images."""
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        # ``pool`` was validated in __init__, so anything that is not 'mean' is 'cls'.
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

In [None]:
# CIFAR-10 preprocessing: resize to 32x32, convert to a tensor, then normalise each
# channel with the fixed mean / std constants below.
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010])
transform = transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor(), normalize])

# Training split first, then the held-out split; both share the same preprocessing.
train_dataset, test_dataset = (
    datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Only the training loader shuffles; batch size and worker count are shared.
loader_kwargs = dict(batch_size=64, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)




Files already downloaded and verified
Files already downloaded and verified


In [None]:
def compute_accuracy(outputs, labels):
    """Return the fraction of rows in ``outputs`` whose highest score is at ``labels``.

    Args:
        outputs: 2-D tensor of per-class scores, one row per sample.
        labels: 1-D tensor of target class indices, one per row of ``outputs``.

    Returns:
        A Python float: number of correct predictions divided by ``labels.size(0)``.
    """
    predicted_classes = outputs.max(dim=1).indices
    hits = predicted_classes.eq(labels).float()
    return hits.sum().item() / labels.size(0)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ViT(
    image_size = 32,
    patch_size = 4,
    num_classes = 10,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dim_head = 64,
    dropout = 0.1,
    emb_dropout = 0.1
).to(device)

criterion = nn.CrossEntropyLoss()
# NOTE(review): the run recorded with these settings stayed close to chance accuracy;
# lr=0.003 is worth revisiting before drawing conclusions from the model.
optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

# Training loop with accuracy.
# Metrics are accumulated per sample rather than per batch: the loaders do not drop the
# last batch, which can be smaller than the rest, and a plain mean of per-batch values
# would give its samples extra weight.
for epoch in range(10):
    model.train()
    train_loss_sum = 0.0
    train_correct = 0.0
    train_seen = 0
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        batch_size = labels.size(0)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # ``criterion`` and ``compute_accuracy`` both return batch means, so scale them
        # back up by the batch size before summing.
        train_loss_sum += loss.item() * batch_size
        train_correct += compute_accuracy(outputs, labels) * batch_size
        train_seen += batch_size

    avg_train_loss = train_loss_sum / train_seen
    avg_train_accuracy = train_correct / train_seen
    print(f'Epoch {epoch+1}, Loss: {avg_train_loss:.4f}, Accuracy: {avg_train_accuracy:.4f}')

# Testing loop for accuracy after training
model.eval()
test_correct = 0.0
test_seen = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        test_correct += compute_accuracy(outputs, labels) * labels.size(0)
        test_seen += labels.size(0)

avg_test_accuracy = test_correct / test_seen
print(f'Test Accuracy: {avg_test_accuracy:.4f}')

  self.pid = os.fork()


Epoch 1, Loss: 2.3330, Accuracy: 0.1230
Epoch 2, Loss: 2.2178, Accuracy: 0.1630
Epoch 3, Loss: 2.2069, Accuracy: 0.1661
Epoch 4, Loss: 2.1975, Accuracy: 0.1662
Epoch 5, Loss: 2.1605, Accuracy: 0.1801
Epoch 6, Loss: 2.2352, Accuracy: 0.1480
Epoch 7, Loss: 2.2818, Accuracy: 0.1306
Epoch 8, Loss: 2.2685, Accuracy: 0.1348
Epoch 9, Loss: 2.2546, Accuracy: 0.1421
Epoch 10, Loss: 2.2622, Accuracy: 0.1367
Test Accuracy: 0.1436
