In [None]:
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from models import MLP, LSTM
from train import train_model
from utils import shuffle_labels


In [None]:
# Seed every RNG the notebook may draw from. torch's global RNG drives model
# weight init and DataLoader shuffling; Python `random` and NumPy are seeded as
# well because `shuffle_labels` is opaque here and may use either of them —
# TODO confirm in utils.shuffle_labels.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Load MNIST; (0.1307,) / (0.3081,) are the standard MNIST pixel mean / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform,
)

# NOTE(review): these loaders are not referenced by any cell shown here —
# train_model() receives the datasets directly together with the `batch_size`
# hyperparameter. They are kept so any downstream code relying on them still works.
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=128, shuffle=False)

In [None]:
# Hyperparameters shared by both models and by the training loop.
input_size = 28 ** 2    # one 28x28 MNIST image as a flat vector (MLP input width)
hidden_size = 25        # hidden width used by both the MLP and the LSTM
num_classes = 10        # number of output classes
num_epochs = 50         # epochs of training per label shuffle
batch_size = 6000       # batch size handed to train_model()
learning_rate = 0.01    # Adam learning rate for both models
dropout_rate = 0.0      # 0.0 => dropout disabled

# Prefer the GPU when one is visible to torch.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Construct both networks. The MLP is deliberately built before the LSTM:
# weight initialisation presumably draws from the global torch RNG, so swapping
# the order would change the initial weights for a given seed.
mlp_model = MLP(
    input_size=input_size,
    hidden_size=hidden_size,
    num_classes=num_classes,
    dropout_rate=dropout_rate,
).to(device)

lstm_model = LSTM(
    input_size=28,  # 28 features per time step (presumably one image row) — confirm in models.LSTM
    hidden_size=hidden_size,
    num_layers=1,
    num_classes=num_classes,
    dropout_rate=dropout_rate,
).to(device)

# Each model gets its own loss object and its own Adam optimiser.
mlp_model_criterion = nn.CrossEntropyLoss()
lstm_model_criterion = nn.CrossEntropyLoss()

mlp_model_optimizer = torch.optim.Adam(mlp_model.parameters(), lr=learning_rate)
lstm_model_optimizer = torch.optim.Adam(lstm_model.parameters(), lr=learning_rate)

# (name, model, criterion, optimizer) tuples iterated by the training loop.
models = [
    ('mlp', mlp_model, mlp_model_criterion, mlp_model_optimizer),
    ('lstm', lstm_model, lstm_model_criterion, lstm_model_optimizer),
]


In [None]:
num_shuffles = 30
train_logs = {}

for model_name, model, criterion, optimizer in models:
    # Per-run histories for this model; entry k holds the results of shuffle k.
    all_train_losses = []
    all_test_accuracies = []
    all_label_mappings = []

    # NOTE(review): `model` and `optimizer` are not re-initialised between runs,
    # so each shuffle continues from the previous run's weights and Adam state.
    # Presumably intentional (the runs are plotted end-to-end) — confirm.
    for run in range(num_shuffles):
        print(f"\n[{model_name}] Starting Run {run + 1}/{num_shuffles}")

        # Draw a new label mapping for the training set and reuse it for the
        # test set so evaluation is scored against the same relabelling.
        shuffled_train, label_mapping = shuffle_labels(train_dataset)
        shuffled_test, _ = shuffle_labels(test_dataset, label_mapping)

        # Show where each of the `num_classes` labels was mapped in this run.
        print("Label mapping for this run:")
        print("Original:  ", " ".join(str(i) for i in range(num_classes)))
        print("Mapped to: ", " ".join(str(label_mapping[i]) for i in range(num_classes)))

        train_losses, test_accuracies = train_model(
            model=model,
            train_data=shuffled_train,
            test_data=shuffled_test,
            num_epochs=num_epochs,
            device=device,
            criterion=criterion,
            optimizer=optimizer,
            batch_size=batch_size,
        )

        all_train_losses.append(train_losses)
        all_test_accuracies.append(test_accuracies)
        all_label_mappings.append(label_mapping)

    train_logs[model_name] = {
        'all_train_losses': all_train_losses,
        'all_test_accuracies': all_test_accuracies,
        'all_label_mappings': all_label_mappings,
    }


In [None]:
def _plot_concatenated_runs(ax, runs, title, xlabel, ylabel):
    """Plot every run's series end-to-end on ``ax`` and mark where each run starts.

    Args:
        ax: Matplotlib Axes to draw on.
        runs: List of per-run sequences (e.g. one list of losses per label shuffle).
        title: Axes title.
        xlabel: X-axis label.
        ylabel: Y-axis label.

    Returns:
        The flattened list of values that was plotted.
    """
    combined = [value for run in runs for value in run]
    ax.plot(combined)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    # Boundaries come from the actual run lengths, so the markers stay aligned
    # even when runs differ in length, and no global run count is needed.
    boundary = 0
    for run in list(runs)[:-1]:
        boundary += len(run)
        ax.axvline(x=boundary, color='red', linestyle='--', alpha=0.5)
    return combined


def _save_single_panel(runs, title, xlabel, ylabel, path):
    """Render one concatenated-runs panel to ``path``, close the figure, return the plotted values."""
    fig, ax = plt.subplots(figsize=(15, 5))
    combined = _plot_concatenated_runs(ax, runs, title, xlabel, ylabel)
    fig.savefig(path)
    plt.close(fig)  # written to disk only; avoid accumulating open figures in the loop
    return combined


for model_name, train_log in train_logs.items():
    all_train_losses = train_log['all_train_losses']
    all_test_accuracies = train_log['all_test_accuracies']

    # Include the model name so each saved figure identifies which model it shows.
    loss_title = f'Training Loss ({model_name})'
    acc_title = f'Test Accuracy ({model_name})'

    # Standalone figures, saved to disk.
    combined_losses = _save_single_panel(
        all_train_losses, loss_title, 'Training Steps', 'Loss',
        f'mnist_reshuffle_{model_name}_loss.png',
    )
    combined_accuracies = _save_single_panel(
        all_test_accuracies, acc_title, 'Epochs', 'Accuracy (%)',
        f'mnist_reshuffle_{model_name}_accuracy.png',
    )

    # Both panels stacked in one figure, saved and displayed inline.
    fig_combined, (ax1, ax2) = plt.subplots(2, 1, figsize=(15, 10))
    _plot_concatenated_runs(ax1, all_train_losses, loss_title, 'Training Steps', 'Loss')
    _plot_concatenated_runs(ax2, all_test_accuracies, acc_title, 'Epochs', 'Accuracy (%)')
    fig_combined.tight_layout()
    fig_combined.savefig(f'mnist_reshuffle_{model_name}_combined.png')
    plt.show()