In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import time

# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device: {}'.format(device))

# MNIST pipeline: convert each image to a tensor, then normalize the single
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Both splits live under ./data and are downloaded on first use.
train_dataset = torchvision.datasets.MNIST('./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST('./data', train=False, transform=transform, download=True)

# Mini-batches of 64; only the training split is shuffled.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Model 1: Basic CNN with 5 fully connected layers
class BasicCNN(nn.Module):
    """Baseline CNN: three conv/pool stages followed by a stack of dense layers.

    Each convolution keeps the spatial size (3x3 kernel, padding 1) and each
    2x2 max-pool halves it, so a 1-channel 28x28 input ends up as 64 feature
    maps of 3x3 before the dense stack. The final layer emits 10 raw scores
    per sample (no softmax is applied here).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (1 -> 32 -> 64 -> 64 channels).
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        # One parameter-free pooling module is shared by all three stages.
        self.pool = nn.MaxPool2d(2, 2)
        # Dense stack: 576 -> 256 -> 128 -> 64 -> 32 -> 16, then 10 outputs.
        self.fc1 = nn.Linear(64 * 3 * 3, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, 32)
        self.fc5 = nn.Linear(32, 16)
        self.output = nn.Linear(16, 10)

    def forward(self, x):
        """Return the 10 class scores for each image in the batch ``x``."""
        # conv -> ReLU -> pool, three times.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(torch.relu(conv(x)))

        # Flatten the 64 x 3 x 3 feature maps into one vector per sample.
        x = x.view(-1, 64 * 3 * 3)

        # Hidden dense layers, each followed by ReLU.
        for dense in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            x = torch.relu(dense(x))

        # The output layer is left linear.
        return self.output(x)

# Model 2: CNN with Random Dense Connections (50% chance)
class RandomLinear(nn.Module):
    """Linear layer whose weight matrix is multiplied by a fixed random 0/1 mask.

    The mask is sampled once at construction: every entry is 1 with
    probability ``p`` and 0 otherwise. A masked-out weight contributes
    nothing to the output and receives zero gradient, so the corresponding
    input -> output connection is effectively absent. The bias is not masked.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        p: Probability that an individual connection is kept.
    """

    def __init__(self, in_features, out_features, p=0.5):
        super(RandomLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.fc = nn.Linear(in_features, out_features)
        # Register the mask as a buffer rather than a plain attribute so that
        # it follows the module through .to()/.cuda()/dtype casts and is
        # stored in state_dict(). Without this the mask stays on the CPU when
        # the model is moved to a GPU, and a reloaded checkpoint would be
        # paired with a freshly sampled (different) mask.
        # NOTE: checkpoints saved before the mask was a buffer lack the
        # "mask" key; load those with strict=False.
        self.register_buffer('mask', self.create_random_mask(p))

    def create_random_mask(self, p):
        """Sample an (out_features, in_features) tensor of 0/1 values, each 1 with probability ``p``."""
        return torch.bernoulli(torch.full((self.out_features, self.in_features), p))

    def forward(self, x):
        """Apply the linear map using only the connections kept by the mask."""
        masked_weight = self.fc.weight * self.mask
        return torch.nn.functional.linear(x, masked_weight, self.fc.bias)

class RandomDenseCNN(nn.Module):
    """CNN whose hidden dense layers are ``RandomLinear`` layers built with p=0.5.

    The convolutional front end is three conv/pool stages (3x3 kernels with
    padding 1, 2x2 max-pooling), which turn a 1-channel 28x28 input into 64
    feature maps of 3x3. The five hidden dense layers are ``RandomLinear``;
    the final 16 -> 10 layer is an ordinary ``nn.Linear`` and returns raw
    class scores.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor (1 -> 32 -> 64 -> 64 channels).
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        # One parameter-free pooling module is shared by all three stages.
        self.pool = nn.MaxPool2d(2, 2)
        # Hidden dense stack: 576 -> 256 -> 128 -> 64 -> 32 -> 16.
        self.fc1 = RandomLinear(64 * 3 * 3, 256, p=0.5)
        self.fc2 = RandomLinear(256, 128, p=0.5)
        self.fc3 = RandomLinear(128, 64, p=0.5)
        self.fc4 = RandomLinear(64, 32, p=0.5)
        self.fc5 = RandomLinear(32, 16, p=0.5)
        # The classifier head is a regular linear layer.
        self.output = nn.Linear(16, 10)

    def forward(self, x):
        """Return the 10 class scores for each image in the batch ``x``."""
        # conv -> ReLU -> pool, three times.
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(torch.relu(conv(x)))

        # Flatten the 64 x 3 x 3 feature maps into one vector per sample.
        x = x.view(-1, 64 * 3 * 3)

        # Hidden dense layers, each followed by ReLU.
        for dense in (self.fc1, self.fc2, self.fc3, self.fc4, self.fc5):
            x = torch.relu(dense(x))

        # The output layer is left linear.
        return self.output(x)

# Function to train and time the model
def train_model(model, train_loader, device, epochs=5):
    """Train ``model`` with Adam and cross-entropy loss, then print the elapsed time.

    Args:
        model: Network to train; it is moved to ``device`` in place.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device on which the model and every batch are placed.
        epochs: Number of full passes over ``train_loader``.

    Returns:
        None. The total training time is printed, not returned.
    """
    model.to(device)
    # Explicitly enter training mode: evaluate_model() leaves the model in
    # eval mode, so training it again afterwards would otherwise run layers
    # such as dropout or batch-norm with their inference behaviour.
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # perf_counter() is monotonic, so the measured interval is not affected
    # by system clock adjustments the way time.time() is.
    start_time = time.perf_counter()
    for epoch in range(epochs):
        running_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}', leave=False)
        for batches_seen, (images, labels) in enumerate(progress_bar, start=1):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Show the mean loss over the batches processed so far. Dividing
            # by the total number of batches would under-report the loss
            # until the very last batch of the epoch.
            progress_bar.set_postfix(loss=running_loss / batches_seen)

    elapsed_time = time.perf_counter() - start_time
    print(f'Time taken for {model.__class__.__name__}: {elapsed_time:.2f} seconds')

# Function to evaluate the model
def evaluate_model(model, test_loader, device):
    """Print the classification accuracy of ``model`` over ``test_loader``.

    Args:
        model: Network to evaluate; it is moved to ``device`` and left in
            evaluation mode when the function returns.
        test_loader: Iterable yielding ``(images, labels)`` batches.
        device: Device on which the model and every batch are placed.

    Returns:
        None. The accuracy (in percent) is printed, not returned.
    """
    # Move the model as well as the batches, mirroring train_model(), so the
    # function also works on a model that has not been trained/moved yet.
    model.to(device)
    model.eval()  # Set to evaluation mode
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Predicted class = index of the largest score along dim 1.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    name = model.__class__.__name__
    if total == 0:
        # An empty loader has no defined accuracy; report that instead of
        # failing with ZeroDivisionError.
        print(f'Accuracy of {name}: n/a (no samples evaluated)')
        return

    accuracy = 100 * correct / total
    print(f'Accuracy of {name}: {accuracy:.2f}%')

# Training and comparing both models
# Build both networks up front, then train and score them one after the other.
basic_model = BasicCNN()
random_model = RandomDenseCNN()


def _train_then_evaluate(model, train_heading, eval_heading):
    """Train ``model`` and report its test accuracy, printing a heading before each step."""
    print(train_heading)
    train_model(model, train_loader, device)
    print(eval_heading)
    evaluate_model(model, test_loader, device)


_train_then_evaluate(basic_model, "Training BasicCNN:", "Evaluating BasicCNN:")
_train_then_evaluate(random_model, "\nTraining RandomDenseCNN:", "Evaluating RandomDenseCNN:")


Using device: cpu
Training BasicCNN:


                                                                                                                       

Time taken for BasicCNN: 146.57 seconds
Evaluating BasicCNN:
Accuracy of BasicCNN: 99.11%

Training RandomDenseCNN:


                                                                                                                       

Time taken for RandomDenseCNN: 149.69 seconds
Evaluating RandomDenseCNN:
Accuracy of RandomDenseCNN: 98.69%
