In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import time
import os
import matplotlib.pyplot as plt
from torchvision import transforms
from itertools import product

import torch.ao.quantization as quantization
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, prepare_qat_fx

In [None]:
class MNISTDataset(Dataset):
    """Image dataset loaded from an MNIST-style CSV file.

    The first CSV column holds the label; the remaining columns are pixel
    values that are reshaped to 28x28 images (assumes 784 pixel columns per
    row) and normalized to zero mean / unit std.

    Args:
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.
        transform: Optional callable applied to each image given as a float
            tensor of shape (1, 28, 28); the leading dim of its result is
            squeezed away before returning.
        mean: Optional normalization mean. When ``None`` it is computed from
            this file's pixels.
        std: Optional normalization std. When ``None`` it is computed from
            this file's pixels. Pass a training set's ``pixels_mean`` /
            ``pixels_std`` when building an evaluation set so both splits
            share one normalization.
    """

    def __init__(self, csv_file, transform=None, mean=None, std=None):
        data = pd.read_csv(csv_file)
        self.labels = data.iloc[:, 0].values
        self.pixels = data.iloc[:, 1:].values.astype('float32').reshape(-1, 28, 28)

        # Normalize with the supplied statistics, or with this file's own.
        # np.float32 keeps the pixel array float32 even if a float64 is passed.
        self.pixels_mean = self.pixels.mean() if mean is None else np.float32(mean)
        self.pixels_std = self.pixels.std() if std is None else np.float32(std)
        self.pixels = (self.pixels - self.pixels_mean) / self.pixels_std

        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Always work on a (1, 28, 28) tensor: image transforms such as
        # CenterCrop expect a channel dim, and squeezing it afterwards must
        # also work when no transform is configured.
        image = torch.tensor(self.pixels[idx]).unsqueeze(0)
        if self.transform:
            image = self.transform(image)
        return image.squeeze(0), torch.tensor(self.labels[idx])
    

class FFNN(nn.Module):
    """Fully connected classifier: flatten -> [Linear -> ReLU] * num_hidden_layers -> Linear.

    Args:
        input_size: Number of features per sample after flattening.
        hidden_size: Width of every hidden layer.
        num_classes: Number of output logits.
        num_hidden_layers: Number of hidden Linear+ReLU layers. ``0`` builds a
            single Linear layer mapping inputs directly to logits.

    Raises:
        ValueError: If ``num_hidden_layers`` is negative.
    """

    def __init__(self, input_size, hidden_size, num_classes, num_hidden_layers):
        super(FFNN, self).__init__()
        if num_hidden_layers < 0:
            raise ValueError(f"num_hidden_layers must be >= 0, got {num_hidden_layers}")
        # Consecutive layer widths: input, hidden..., output. Layers are created
        # in the same order as before, so initialization and state_dict keys
        # are unchanged for num_hidden_layers >= 1.
        widths = [input_size] + [hidden_size] * num_hidden_layers + [num_classes]
        self.layers = nn.ModuleList(
            nn.Linear(in_features, out_features)
            for in_features, out_features in zip(widths[:-1], widths[1:])
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten everything but the batch dim
        for layer in self.layers[:-1]:
            x = self.relu(layer(x))
        return self.layers[-1](x)  # Raw logits (no activation on the output)

In [None]:
def create_dataloader(dataset_path, batch_size, is_train=True, crop_size=20):
    """Build a DataLoader over an MNIST CSV with a square center crop.

    Args:
        dataset_path: CSV path passed to ``MNISTDataset``.
        batch_size: Samples per batch.
        is_train: Shuffle batches when True (training); keep file order otherwise.
        crop_size: Side length of the center crop (default 20 -> 20x20 images).
            An ``FFNN`` consuming these batches needs ``input_size == crop_size ** 2``.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    transform = transforms.Compose([transforms.CenterCrop(crop_size)])
    dataset = MNISTDataset(dataset_path, transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=is_train)

def print_size_of_model(model, label=""):
    """Print and return the serialized size of ``model.state_dict()`` in bytes.

    The state dict is written with ``torch.save`` to a private temporary file,
    which is removed even if saving fails. A fixed name in the working
    directory (which could clobber a user's file) is never used.

    Args:
        model: Module whose ``state_dict`` is measured.
        label: Name shown in the printed line.

    Returns:
        Size of the saved file in bytes (the printed value is in decimal MB).
    """
    import tempfile

    fd, path = tempfile.mkstemp(suffix=".p")
    os.close(fd)  # torch.save reopens the path itself
    try:
        torch.save(model.state_dict(), path)
        size = os.path.getsize(path)
    finally:
        os.remove(path)
    print("model: ", label, ' \t', 'Size (MB):', size / 1e6)
    return size

def count_parameters(model):
    """Return how many scalar parameters of ``model`` have ``requires_grad`` set."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable


In [None]:
def train_model(model, train_loader, val_loader, epochs, learning_rate, device):
    """Train ``model`` in full precision with Adam and cross-entropy loss.

    After each epoch the model is evaluated on ``val_loader`` and its top-1
    accuracy is printed.

    Args:
        model: Module producing class logits; should already be on ``device``.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches. May be
            empty, in which case accuracy is reported as n/a.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: Device every batch is moved to.

    Returns:
        The same ``model`` instance, trained in place.
    """
    print(f"Training normal precision model for {epochs} epochs")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            loss = criterion(model(data), target)
            loss.backward()
            optimizer.step()

        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                predicted = model(data).argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        if total:
            print(f'Epoch {epoch+1}, Accuracy: {100 * correct / total:.2f}%')
        else:
            # Avoid ZeroDivisionError when the validation loader is empty.
            print(f'Epoch {epoch+1}, Accuracy: n/a (empty validation set)')

    return model

In [None]:
def train_model_mixed_precision(model, train_loader, val_loader, epochs, learning_rate, device):
    """Train ``model`` with automatic mixed precision (AMP) when on CUDA.

    On a CUDA device the forward pass and loss run under float16 autocast, and
    the loss is scaled by a GradScaler to avoid gradient underflow. On any
    other device both are disabled, so training runs in float32 without the
    "CUDA is not available" warnings the CUDA-only AMP APIs emit there.
    Validation runs without autocast.

    Args:
        model: Module producing class logits; should already be on ``device``.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches. May be
            empty, in which case accuracy is reported as n/a.
        epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.
        device: ``torch.device`` or device string every batch is moved to.

    Returns:
        The same ``model`` instance, trained in place.
    """
    print(f"Training model with mixed precision for {epochs} epochs")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    device_type = torch.device(device).type
    use_amp = device_type == "cuda"

    # torch.amp.GradScaler supersedes the deprecated torch.cuda.amp.GradScaler
    # in newer PyTorch; fall back for older versions. When disabled the scaler
    # is a pass-through (scale() returns the loss, step() calls optimizer.step()).
    amp_module = getattr(torch, "amp", None)
    if amp_module is not None and hasattr(amp_module, "GradScaler"):
        scaler = amp_module.GradScaler(device_type, enabled=use_amp)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    for epoch in range(epochs):
        model.train()
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()

            # enabled=False makes this a no-op, so it is safe on CPU-only machines.
            with torch.autocast(device_type="cuda", enabled=use_amp):
                output = model(data)
                loss = criterion(output, target)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

        # Validation
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                predicted = model(data).argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()

        if total:
            print(f'Epoch {epoch+1}, Accuracy: {100 * correct / total:.2f}%')
        else:
            # Avoid ZeroDivisionError when the validation loader is empty.
            print(f'Epoch {epoch+1}, Accuracy: n/a (empty validation set)')

    return model

In [None]:
def measure_inference_time(model, test_loader, batch_size, num_runs=5, device=None):
    """Time forward passes of ``model`` on one batch taken from ``test_loader``.

    The first batch is fetched once and truncated to ``data[:batch_size]``.
    A ``batch_size`` at least as large as the loader's batch uses the whole
    batch; ``1`` times single-sample inference. One untimed warm-up pass runs
    first, so one-time setup costs are excluded, then ``num_runs`` passes are
    timed with ``time.perf_counter``.

    Args:
        model: Module to time; switched to eval mode.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        batch_size: Number of samples per timed forward pass.
        num_runs: Number of timed passes.
        device: Optional device to move the batch to. For a CUDA device the
            GPU is synchronized around each timed call so asynchronous kernels
            are included in the measurement.

    Returns:
        ``(mean_seconds, std_seconds)`` over the timed passes.
    """
    model.eval()
    data, _ = next(iter(test_loader))
    data = data[:batch_size]

    sync_cuda = False
    if device is not None:
        data = data.to(device)
        sync_cuda = torch.device(device).type == "cuda"

    times = []
    with torch.no_grad():
        model(data)  # warm-up pass, not timed
        for _ in range(num_runs):
            if sync_cuda:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            model(data)
            if sync_cuda:
                torch.cuda.synchronize()
            times.append(time.perf_counter() - start_time)

    return np.mean(times), np.std(times)

In [None]:
def evaluate_model(model, test_loader, mixed_precision=False, device=None):
    """Return the top-1 accuracy (percent) of ``model`` on ``test_loader``.

    Args:
        model: Module producing class logits; switched to eval mode.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        mixed_precision: Run the forward pass under CUDA float16 autocast.
            It only takes effect when evaluating on a CUDA device; elsewhere
            the forward pass runs in the model's own precision.
        device: Device to run on. Defaults to the device of the model's first
            parameter, or CPU if the model has no parameters. Every batch is
            moved there, so a model already placed on a GPU works.

    Returns:
        Accuracy as a float in [0, 100].

    Raises:
        ValueError: If ``test_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")
    device = torch.device(device)
    use_amp = mixed_precision and device.type == "cuda"

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # enabled=False makes this a no-op, so it is safe on CPU-only machines.
            with torch.autocast(device_type="cuda", enabled=use_amp):
                outputs = model(data)
            predicted = outputs.argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return 100 * correct / total

In [None]:
def main(seed=0):
    """Train the baseline FFNN on center-cropped MNIST CSVs and return it.

    Args:
        seed: Passed to ``torch.manual_seed`` before the model and loaders are
            built, so weight initialization and the training shuffle order
            repeat across runs (GPU kernels may still add nondeterminism).

    Returns:
        The trained ``FFNN`` (on the selected device), so later cells can
        evaluate, time, or quantize it instead of retraining.
    """
    # Hyperparameters
    input_size = 20 * 20  # must match the 20x20 CenterCrop applied by create_dataloader
    hidden_size = 1024
    num_classes = 10
    num_hidden_layers = 2
    batch_size = 64
    learning_rate = 0.001
    epochs = 2

    torch.manual_seed(seed)

    # Create model on the best available device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = FFNN(input_size, hidden_size, num_classes, num_hidden_layers).to(device)

    # Create dataloaders
    train_loader = create_dataloader('data/mnist_train.csv', batch_size, True)
    test_loader = create_dataloader('data/mnist_test.csv', batch_size, False)

    # Train base model
    model = train_model(model, train_loader, test_loader, epochs, learning_rate, device)
    return model


    

In [None]:
# Run the pipeline defined in main(): build the model and dataloaders, then train it.
main()