In [None]:
# Machine Learning Course Project - Binara Siriwardhana - Federico Pappani - Dalarna University

%pylab inline
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Hyperparameter grid explored by the experiment below.
# Dropout probabilities 0.1, 0.2, ..., 0.9 (k / 10 yields exactly the float
# literal 0.k, so these compare equal to hand-written values).
dropout_rates = [k / 10 for k in range(1, 10)]
# Width multipliers for the CNN's conv and hidden layers: 1, 2, 4, 8.
network_sizes = [2 ** p for p in range(4)]
# Number of passes over the training data for every configuration.
num_epochs = 15

# CNN implementation
class CNN(nn.Module):
    """Small image classifier: one conv layer, 2x2 max-pool, two linear layers.

    Args:
        dropout_rate: probability passed to ``nn.Dropout``; dropout is applied
            to the hidden activations between ``fc1`` and ``fc2``.
        network_size: width multiplier. The conv layer has
            ``32 * network_size`` output channels and the hidden linear layer
            ``128 * network_size`` units.

    Inputs must be 3-channel 32x32 images (e.g. CIFAR-10). The padded 3x3
    conv keeps 32x32 and the 2x2 pool halves it to 16x16, which fixes the
    size of ``fc1``. The model outputs 10 class scores per image.
    """

    def __init__(self, dropout_rate, network_size):
        super().__init__()
        # Keep the width on the instance so forward() never depends on globals.
        self.network_size = network_size
        # Flattened feature count after conv (32x32 kept) + 2x2 pool (16x16).
        self.flat_features = 32 * network_size * 16 * 16
        self.conv1 = nn.Conv2d(3, 32 * network_size, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(self.flat_features, 128 * network_size)
        self.fc2 = nn.Linear(128 * network_size, 10)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Return class scores of shape (batch, 10) for a (batch, 3, 32, 32) input."""
        x = self.pool(torch.relu(self.conv1(x)))
        # Flatten per sample. Unlike view(-1, N), a wrong input size surfaces
        # as a shape error in fc1 instead of silently changing the batch size.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# Load the CIFAR-10 training and test splits into ./data (downloaded if missing).
# ToTensor (on PIL images) scales pixels to [0, 1]; Normalize with mean=std=0.5
# per RGB channel then maps them to [-1, 1]. Both splits share this transform.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
# Training batches of 64 are reshuffled each time the loader is iterated.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=64, shuffle=True, num_workers=2
)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform
)
# Test batches of 64 are served in a fixed order.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=64, shuffle=False, num_workers=2
)

# NOTE(review): nothing in this section appends to `losses`; kept in case
# other cells rely on the name.
losses = []
# Define training loop and evaluation function
def train_model(model, dropout_rate, network_size, *, epochs=None, loader=None,
                lr=0.001, momentum=0.9):
    """Train ``model`` in place with SGD and cross-entropy loss.

    Prints the mean training loss after every epoch.

    Args:
        model: the ``nn.Module`` to train; its parameters are updated in place
            and it is left in train mode.
        dropout_rate: used only to label the progress messages.
        network_size: used only to label the progress messages.
        epochs: number of passes over ``loader``. Defaults to the module-level
            ``num_epochs``.
        loader: iterable of ``(inputs, labels)`` batches. Defaults to the
            module-level ``trainloader``.
        lr: SGD learning rate (default matches the original experiment).
        momentum: SGD momentum (default matches the original experiment).

    Returns:
        list of float: the mean loss of each epoch, or ``nan`` for an epoch
        whose loader yielded no batches.
    """
    if epochs is None:
        epochs = num_epochs
    if loader is None:
        loader = trainloader

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    # Make sure dropout is active even if the model was put in eval mode earlier.
    model.train()

    epoch_losses = []
    for epoch in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Guard against an empty loader instead of dividing by an unset index.
        mean_loss = running_loss / num_batches if num_batches else float('nan')
        epoch_losses.append(mean_loss)
        print(f'Dropout Rate {dropout_rate}, Network Size {network_size}, Epoch {epoch+1}, Loss: {mean_loss}')
    return epoch_losses

def evaluate_model(model, loader=None):
    """Return the classification accuracy of ``model`` in percent.

    The model is switched to eval mode for the pass, so layers such as dropout
    behave deterministically. Afterwards it is restored to whichever
    train/eval mode it was in before the call.

    Args:
        model: classifier whose output for a batch is a (batch, classes)
            score tensor; the predicted class is the argmax over dim 1.
        loader: iterable of ``(images, labels)`` batches. Defaults to the
            module-level ``testloader``.

    Returns:
        float: ``100 * correct / total``, or ``nan`` if the loader yields no
        samples.
    """
    if loader is None:
        loader = testloader

    was_training = model.training
    # Without eval(), dropout stays active and randomly corrupts test accuracy.
    model.eval()
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for images, labels in loader:
                predicted = model(images).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        return float('nan')
    return 100 * correct / total

# Run the experiment grid: for every (dropout rate, network size) pair, build a
# fresh CNN, train it, evaluate it on the test loader, and record a
# (dropout_rate, network_size, accuracy) tuple in `results`.
results = []

# NOTE(review): `dropout_rate`, `network_size`, `model` and `accuracy` are bound
# at module level here; check whether CNN.forward reads the global
# `network_size` before renaming or reordering anything in this loop.
for dropout_rate in dropout_rates:
    for network_size in network_sizes:
        model = CNN(dropout_rate, network_size)
        # Training prints the per-epoch loss for this configuration.
        train_model(model, dropout_rate, network_size)
        accuracy = evaluate_model(model)
        results.append((dropout_rate, network_size, accuracy))

# Summarise the experiment grid: one line per (dropout rate, network size) run,
# in the order the runs were executed.
for rate, size, acc in results:
    print('Dropout Rate: {}, Network Size: {}, Accuracy: {}%'.format(rate, size, acc))

# Arrange the flat `results` list into a grid for the heatmap:
# accuracies[r][c] is the accuracy for dropout_rates[r] and network_sizes[c].
# Combinations that never appear in `results` keep the placeholder 0.
accuracies = [[0 for _ in network_sizes] for _ in dropout_rates]

for dropout_rate, network_size, accuracy in results:
    dropout_idx, size_idx = dropout_rates.index(dropout_rate), network_sizes.index(network_size)
    accuracies[dropout_idx][size_idx] = accuracy

# Heatmap of test accuracy. Rows are dropout rates (index 0 at the bottom,
# because origin='lower') and columns are network-size multipliers.
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(accuracies, cmap='viridis', origin='lower', aspect='auto')
fig.colorbar(im, ax=ax, label='Accuracy (%)')
ax.set_xticks(range(len(network_sizes)))
ax.set_xticklabels(network_sizes)
ax.set_yticks(range(len(dropout_rates)))
ax.set_yticklabels(dropout_rates)
ax.set_xlabel('Network Size', fontsize=12)
ax.set_ylabel('Dropout Rate', fontsize=12)
ax.set_title('CNN Accuracy Heatmap', fontsize=14)

# Print each cell's accuracy at its (column, row) position on the heatmap.
for row, row_values in enumerate(accuracies):
    for col, value in enumerate(row_values):
        ax.text(col, row, f'{value:.2f}%', ha='center', va='center', color='k', fontsize=10)
plt.show()