In [None]:
!pip install torch torchvision matplotlib


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt
import time
import os


In [None]:
class LowRankCNN(nn.Module):
    """Small two-conv CNN whose first dense layer is factorised into two thin layers.

    The conv stack is sized for 1-channel 28x28 images: two 2x max-pools leave a
    32 x 7 x 7 feature map. The 1568 -> 128 projection is split into ``fc1_A``
    (1568 -> rank, no bias) followed directly by ``fc1_B`` (rank -> 128, with bias).
    There is no nonlinearity between them, so the pair is one linear map whose
    weight matrix has rank at most ``rank``. It uses fewer parameters than a dense
    1568 x 128 layer when ``rank`` is small.

    Args:
        rank: Inner dimension of the factorised dense layer.
    """

    def __init__(self, rank=20):
        super().__init__()
        # Feature extractor: 1 -> 16 -> 32 channels; padding=1 keeps H/W until pooling.
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        # Low-rank factorisation of the flattened-features -> 128 dense layer.
        flat_dim = 32 * 7 * 7
        self.fc1_A = nn.Linear(flat_dim, rank, bias=False)
        self.fc1_B = nn.Linear(rank, 128, bias=True)
        # Classifier head producing 10 logits.
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        """Return unnormalised logits with 10 entries per sample."""
        h = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 28x28 input -> 16 x 14 x 14
        h = F.max_pool2d(F.relu(self.conv2(h)), 2)  # -> 32 x 7 x 7
        # reshape(-1, ...) (not flatten(1)) so an unbatched (C, H, W) input still works.
        flat = h.reshape(-1, self.fc1_A.in_features)
        hidden = F.relu(self.fc1_B(self.fc1_A(flat)))
        return self.fc2(hidden)


In [None]:
def train(model, dataloader, criterion, optimizer, device):
    """Run one training epoch over ``dataloader`` and return the mean loss.

    Puts ``model`` in training mode. For every batch it moves inputs and targets
    to ``device``, computes ``criterion(model(images), labels)``, backpropagates,
    and takes one ``optimizer`` step.

    Args:
        model: Module being trained.
        dataloader: Iterable of ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Sample-weighted mean of the per-batch losses as a float. Returns ``nan``
        if the dataloader yields no samples. The weighting assumes ``criterion``
        averages over the batch (the default ``reduction='mean'``). Callers that
        ignore the return value behave exactly as before.
    """
    model.train()
    loss_sum, n_seen = 0.0, 0
    for images, labels in dataloader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        # Accumulate a detached device tensor instead of calling .item() each batch,
        # so training does not force a host sync on every step.
        batch_size = labels.size(0)
        loss_sum = loss_sum + loss.detach() * batch_size
        n_seen += batch_size
    return float(loss_sum) / n_seen if n_seen else float("nan")

def evaluate(model, dataloader, device):
    """Measure top-1 accuracy and elapsed time of ``model`` over ``dataloader``.

    Switches ``model`` to eval mode and runs it under ``torch.no_grad()``.

    Args:
        model: Classifier returning logits with classes along dim 1.
        dataloader: Iterable of ``(images, labels)`` batches.
        device: Device the batches are moved to before the forward pass.

    Returns:
        ``(accuracy_percent, elapsed_seconds)``.

        - Accuracy is ``nan`` when the dataloader yields no samples, instead of
          raising ``ZeroDivisionError``.
        - The elapsed time covers the whole loop, including any data loading and
          transform work done by the dataloader.
    """
    model.eval()
    correct, total = 0, 0
    # perf_counter is a monotonic, high-resolution clock meant for intervals;
    # time.time() can jump when the system clock is adjusted.
    start = time.perf_counter()
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            # .item() copies to the host, so queued device work for this batch has
            # finished before the timer is read.
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    elapsed = time.perf_counter() - start
    accuracy = 100 * correct / total if total else float("nan")
    return accuracy, elapsed


In [None]:
# Convert images to [0, 1] tensors, then standardise with mean 0.1307 and std 0.3081.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Both splits are downloaded to ./data if missing and share the same transform.
train_dataset, test_dataset = (
    datasets.MNIST('./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Only the training batches are shuffled; the evaluation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=64)
test_loader_clean = DataLoader(test_dataset, shuffle=False, batch_size=1000)


In [None]:
def add_noise(dataset, noise_level=0.3, mean=0.1307, std=0.3081, generator=None):
    """Build a Gaussian-noise-corrupted, normalised copy of an image dataset.

    Reads the raw pixels from ``dataset.data``, bypassing ``dataset.transform``,
    and divides by 255 to reach [0, 1]. It then adds zero-mean Gaussian noise with
    standard deviation ``noise_level``, clips back to [0, 1], and standardises with
    ``mean`` / ``std``.

    Args:
        dataset: Object exposing ``.data`` and ``.targets``. ``.data`` is assumed to
            hold (N, H, W) 8-bit pixels, as torchvision's MNIST does (TODO confirm
            for other datasets).
        noise_level: Standard deviation of the additive noise, in [0, 1] pixel units.
        mean: Normalisation mean. The default equals the notebook's Normalize constant.
        std: Normalisation std. The default equals the notebook's Normalize constant.
        generator: Optional ``torch.Generator`` for reproducible noise. ``None``
            draws from the global RNG, as the original implementation did.

    Returns:
        ``TensorDataset`` of ``(image, target)`` pairs, with images of shape (1, H, W).
    """
    pixels = dataset.data.float() / 255.0
    noise = torch.randn(pixels.shape, generator=generator, dtype=pixels.dtype,
                        device=pixels.device) * noise_level
    noisy = torch.clamp(pixels + noise, 0.0, 1.0)
    normalised = (noisy - mean) / std
    # Add the channel axis so each sample is (1, H, W), like ToTensor() output.
    return TensorDataset(normalised.unsqueeze(1), dataset.targets)

# Seed the global RNG first. add_noise draws its noise from it, so this keeps the
# noisy test set identical across Restart & Run All. It also fixes every later
# global-RNG draw in the notebook.
NOISE_SEED = 0
NOISE_LEVEL = 0.3
torch.manual_seed(NOISE_SEED)
test_loader_noisy = DataLoader(add_noise(test_dataset, NOISE_LEVEL), batch_size=1000)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed before building the model so weight init and the training shuffle order are
# reproducible.
TRAIN_SEED = 0
torch.manual_seed(TRAIN_SEED)
model = LowRankCNN(rank=20).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

NUM_EPOCHS = 3
for epoch in range(1, NUM_EPOCHS + 1):
    train(model, train_loader, criterion, optimizer, device)
    print(f"Epoch {epoch}/{NUM_EPOCHS} complete")

# NOTE: the clean loader applies ToTensor/Normalize per batch, while the noisy loader
# serves precomputed tensors. The two timings therefore include different data-loading
# costs and are not a like-for-like inference comparison.
acc_clean, time_clean = evaluate(model, test_loader_clean, device)
acc_noisy, time_noisy = evaluate(model, test_loader_noisy, device)

MODEL_PATH = "low_rank_model.pth"
torch.save(model.state_dict(), MODEL_PATH)
size_mb = os.path.getsize(MODEL_PATH) / (1024 ** 2)

print(f"✅ Clean Accuracy: {acc_clean:.2f}% | Time: {time_clean:.2f}s")
print(f"✅ Noisy Accuracy: {acc_noisy:.2f}% | Time: {time_noisy:.2f}s")
print(f"📦 Model Size: {size_mb:.2f} MB")
