In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms
from tqdm import tqdm
import psutil
import threading
import time
import torchvision

In [None]:
"""
# Biased MNIST CNN Training - CrossEntropy Version

## Model Architecture:
- SimpleCNN with 2 convolutional layers followed by a fully connected layer
- Input: RGB images (3 channels, 160x160)
- Conv1: 3 -> 16 channels, 3x3 kernel, stride 2, padding 1, followed by ReLU and MaxPool
- Conv2: 16 -> 32 channels, 3x3 kernel, stride 2, padding 1, followed by ReLU and MaxPool
- Final feature map size: 32 x 10 x 10 
- Fully connected layer: 3200 -> 10 (one for each digit)

## Loss Function:
- CrossEntropyLoss
- Works directly with class indices (0-9)
- Good for multi-class classification problems

## Training Details:
- Optimizer: Adam with learning rate 0.001
- Batch size: 32
- Epochs: 20 for subset training (10% of training data)
- Dataset: Biased MNIST with correlation level 0.5
- Uses data normalization based on dataset statistics

## How to Run:
1. Place your biased MNIST dataset in the correct paths:
   - Default paths assume: 'biased_mnist/full_0.5/trainval'
2. Run the subset training function for faster iteration:
   - run_biased_mnist_cnn_subset()

## Expected Outputs:
- Training/validation accuracies and loss will be displayed during training
- Model will be saved as 'simple_biased_mnist_cnn_subset_05.pth'

## Notes:
- Memory monitoring is included to track RAM usage
- You can adjust the correlation level by changing paths and filenames
- Evaluating on standard MNIST (test_on_standard_mnist) is experimental: it is not yet clear that it is a valid counterfactual test for these models.
"""

In [None]:
# Memory monitoring class
class MemoryMonitor:
    """Background sampler of this process's resident memory (RSS), in MB.

    A daemon thread reads the RSS of the current process every ``interval``
    seconds and records the latest and the largest value seen.

    Args:
        interval: Seconds to wait between two samples.
    """

    def __init__(self, interval=1.0):
        self.interval = interval
        self.running = False
        self.thread = None
        self.max_memory = 0
        self.current_memory = 0
        # Lets stop() wake the sampler immediately instead of waiting out
        # the remainder of a sleep interval.
        self._stop_event = threading.Event()

    def memory_monitor_func(self):
        """Thread body: sample RSS until ``running`` is cleared or stop is signalled."""
        # The process handle does not change, so build it once, not per sample.
        process = psutil.Process(os.getpid())
        while self.running:
            memory_mb = process.memory_info().rss / (1024 * 1024)  # bytes -> MB

            self.current_memory = memory_mb
            self.max_memory = max(self.max_memory, memory_mb)

            # Interruptible sleep: returns True as soon as stop() sets the event.
            if self._stop_event.wait(self.interval):
                break

    def start(self):
        """Start the sampler thread; a no-op if it is already running."""
        # A second start() used to spawn another thread and orphan the first.
        if self.running and self.thread is not None and self.thread.is_alive():
            return
        self._stop_event.clear()
        self.running = True
        self.thread = threading.Thread(target=self.memory_monitor_func)
        self.thread.daemon = True
        self.thread.start()

    def stop(self):
        """Signal the sampler to finish and wait up to 2 seconds for it to exit."""
        self.running = False
        self._stop_event.set()
        if self.thread:
            self.thread.join(timeout=2.0)

    def get_memory_usage(self):
        """Return ``{'current': <latest MB>, 'max': <peak MB>}``; both 0 before any sample."""
        return {
            'current': self.current_memory,
            'max': self.max_memory
        }

In [None]:
# Define the BiasedMNISTDataset class with JSON integration
class BiasedMNISTDataset(Dataset):
    """Image-folder dataset of ``<index>.jpg`` files with labels from a JSON file.

    Args:
        root_dir: Directory containing the ``.jpg`` images. The integer before
            the first ``.`` in each filename is used as the image index.
        transform: Optional callable applied to the PIL image in ``__getitem__``.
        json_path: Optional path to a JSON list of objects with ``index`` and
            ``digit`` keys; ``digit`` becomes the label of image ``index``.

    Images without a JSON entry fall back to ``index % 10``; filenames whose
    index cannot be parsed get label 0. The number of such fallbacks is
    reported, because those labels are guesses rather than ground truth.
    """

    def __init__(self, root_dir, transform=None, json_path=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Load labels from JSON (best effort: a broken file is reported, not fatal).
        label_dict = {}
        if json_path and os.path.exists(json_path):
            try:
                with open(json_path, 'r') as f:
                    json_data = json.load(f)

                if isinstance(json_data, list):
                    for item in json_data:
                        if isinstance(item, dict) and 'index' in item and 'digit' in item:
                            label_dict[item['index']] = item['digit']
                print(f"Loaded {len(label_dict)} labels from JSON file")
            except Exception as e:
                print(f"Error loading JSON: {e}")

        if not os.path.exists(root_dir):
            print(f"Directory not found: {root_dir}")
            return

        fallback_count = 0
        # os.listdir order is arbitrary; sorting makes sample order (and any
        # seeded subset drawn from it) reproducible across runs and machines.
        for filename in sorted(os.listdir(root_dir)):
            if not filename.endswith('.jpg'):
                continue

            self.image_paths.append(os.path.join(root_dir, filename))

            try:
                index = int(os.path.basename(filename).split('.')[0])

                if index in label_dict:
                    label = label_dict[index]
                else:
                    label = index % 10
                    fallback_count += 1

            except (ValueError, IndexError) as e:
                print(f"Error parsing filename {filename}: {e}")
                label = 0
                fallback_count += 1

            self.labels.append(label)

        if fallback_count:
            print(f"WARNING: {fallback_count} of {len(self.labels)} images in {root_dir} "
                  f"have no JSON label and use a fallback label")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; the image is always 3-channel RGB."""
        img_path = self.image_paths[idx]
        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to load so the file can be closed right away. Converting
        # to RGB guarantees the 3 channels the rest of the notebook expects.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label

In [None]:
# Define a simplified CNN architecture
class SimpleCNN(nn.Module):
    """Two-stage conv net followed by a single linear classifier.

    Each stage is Conv2d(3x3, stride 2, padding 1) -> ReLU -> MaxPool(2, 2),
    so every stage divides the spatial size by 4. For a 160x160 input:
    160 -> 80 -> 40 after stage one, 40 -> 20 -> 10 after stage two,
    leaving a 32 x 10 x 10 feature map for the linear layer.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Convolutional feature extractors (3 -> 16 -> 32 channels).
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1)

        # Parameter-free layers shared by both stages.
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)

        # 32 channels x 10 x 10 spatial positions (see class docstring).
        flat_features = 32 * 10 * 10
        self.fc = nn.Linear(flat_features, num_classes)

    def forward(self, x):
        """Map a batch of images to a batch of ``num_classes`` logits."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(self.relu(conv(x)))
        flattened = x.view(x.shape[0], -1)
        return self.fc(flattened)

In [None]:
# Calculate dataset statistics with a small sample
def calculate_stats_fast(dataset, sample_size=1000):
    """Estimate per-channel mean and std from a random sample of the dataset.

    Args:
        dataset: Dataset yielding ``(image, label)`` pairs whose images collate
            into a ``[batch, 3, height, width]`` tensor.
        sample_size: Maximum number of randomly chosen images to use. If the
            dataset is smaller, all of it is used.

    Returns:
        ``(mean, std)``: two tensors of shape ``[3]``.

    Raises:
        ValueError: If no image could be sampled (empty dataset or
            ``sample_size`` of 0), since the statistics would be NaN.
    """
    indices = torch.randperm(len(dataset))[:sample_size]
    # The dataset may hold fewer images than requested; report the real count.
    actual_size = len(indices)

    # Temporary loader over the sampled images only.
    mini_loader = DataLoader(
        Subset(dataset, indices),
        batch_size=100,
        shuffle=False,
        num_workers=0
    )

    print(f"Calculating statistics using {actual_size} random samples (fast mode)...")

    channels_sum = torch.zeros(3)
    channels_squared_sum = torch.zeros(3)
    num_samples = 0

    memory_monitor = MemoryMonitor()
    memory_monitor.start()

    try:
        progress_bar = tqdm(mini_loader, desc="Calculating stats")

        for data, _ in progress_bar:
            # data: [batch_size, 3, height, width]
            batch_count = data.size(0)
            # Weight each batch by its size: averaging per-batch means treats
            # a short final batch as if it were full and skews the result.
            channels_sum += torch.mean(data, dim=[0, 2, 3]) * batch_count
            channels_squared_sum += torch.mean(data**2, dim=[0, 2, 3]) * batch_count
            num_samples += batch_count

            mem_usage = memory_monitor.get_memory_usage()
            progress_bar.set_postfix({'RAM': f"{mem_usage['current']:.1f}MB"})

    finally:
        memory_monitor.stop()

        mem_usage = memory_monitor.get_memory_usage()
        print(f"Statistics calculation - Maximum RAM usage: {mem_usage['max']:.1f}MB")

    if num_samples == 0:
        raise ValueError("Cannot compute statistics: no images were sampled")

    mean = channels_sum / num_samples
    # E[x^2] - E[x]^2 can round to a tiny negative number for near-constant
    # channels; clamp so the square root never yields NaN.
    variance = torch.clamp(channels_squared_sum / num_samples - mean**2, min=0.0)
    std = variance**0.5

    print(f"Fast statistics calculation complete. Using sample of {actual_size} images.")
    return mean, std

def create_subset_loader(dataset, fraction=0.1, batch_size=32):
    subset_size = int(len(dataset) * fraction)
    indices = torch.randperm(len(dataset))[:subset_size]
    
    subset_dataset = Subset(dataset, indices)
    
    # DataLoader
    subset_loader = DataLoader(
        subset_dataset, 
        batch_size=batch_size,
        shuffle=True,
        num_workers=0
    )
    
    return subset_loader, subset_size

def check_label_range(dataset):
    """Check on a random sample that labels lie in the digit range 0-9.

    Up to 1000 randomly chosen items are read from ``dataset`` and the second
    element of each item is taken as its label.

    Args:
        dataset: Indexable dataset whose items are ``(image, label)`` pairs.

    Returns:
        True if every sampled label is within 0-9; False if any label is out
        of range or if the check itself failed (for example on an empty
        dataset). Failures are printed rather than raised.
    """
    try:
        sample_size = min(1000, len(dataset))
        indices = torch.randperm(len(dataset))[:sample_size]

        labels = [dataset[i.item()][1] for i in indices]
        min_label = min(labels)
        max_label = max(labels)

        print(f"Label range (from sample of {sample_size}): {min_label} to {max_label}")

        # Check both ends: negative labels are as invalid for a 10-class
        # CrossEntropyLoss as labels >= 10.
        if min_label < 0 or max_label >= 10:
            print(f"WARNING: Found labels outside expected range (0-9): min={min_label}, max={max_label}")
            return False
        return True
    except Exception as e:
        print(f"Error checking label range: {e}")
        return False

In [None]:
# CHANGES MADE: 
# - Using CrossEntropyLoss instead of MSE
# - Removed one-hot encoding as it's not needed for CrossEntropyLoss

def train_model(model, train_loader, test_loader, num_epochs=3):
    """Train ``model`` with Adam and CrossEntropyLoss, evaluating after every epoch.

    Args:
        model: Network to train; it is moved to CUDA when available, else CPU.
        train_loader: Loader yielding ``(inputs, integer class labels)`` batches.
        test_loader: Loader passed to ``evaluate_model`` after each epoch.
        num_epochs: Number of passes over ``train_loader``.

    Returns:
        The trained model (on the training device).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model = model.to(device)

    # CrossEntropyLoss consumes raw logits and integer class indices directly.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    memory_monitor = MemoryMonitor()
    memory_monitor.start()

    try:
        for epoch_idx in range(num_epochs):
            epoch_number = epoch_idx + 1

            # evaluate_model() leaves the model in eval mode, so re-enable
            # training mode at the start of every epoch.
            model.train()
            loss_sum = 0.0
            n_correct = 0
            n_seen = 0

            print(f"Epoch {epoch_number}/{num_epochs}")
            batches = tqdm(train_loader, desc="Training")

            for inputs, labels in batches:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                logits = model(inputs)
                loss = criterion(logits, labels)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                loss_sum += batch_loss

                # Index of the largest logit per row is the predicted class.
                predicted = logits.data.max(1)[1]
                n_seen += labels.size(0)
                n_correct += (predicted == labels).sum().item()

                current_ram = memory_monitor.get_memory_usage()['current']
                batches.set_postfix({
                    'loss': f"{batch_loss:.4f}",
                    'acc': f"{100 * n_correct / n_seen:.2f}%",
                    'RAM': f"{current_ram:.1f}MB"
                })

            # Mean of the per-batch losses, and accuracy over all seen samples.
            epoch_loss = loss_sum / len(train_loader)
            epoch_acc = 100 * n_correct / n_seen
            peak_ram = memory_monitor.get_memory_usage()["max"]
            print(f'Epoch {epoch_number}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%, Max RAM: {peak_ram:.1f}MB')

            test_accuracy = evaluate_model(model, test_loader, device, memory_monitor)
            print(f'Test Accuracy: {test_accuracy:.2f}%')

    finally:
        # Always stop the sampler thread, even if training raised.
        memory_monitor.stop()
        peak_ram = memory_monitor.get_memory_usage()['max']
        print(f"Maximum RAM usage: {peak_ram:.1f}MB")

    return model

In [None]:
def evaluate_model(model, test_loader, device=None, memory_monitor=None):
    """Compute the top-1 accuracy (in percent) of ``model`` on ``test_loader``.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        test_loader: Loader yielding ``(inputs, labels)`` batches.
        device: Device the batches are moved to. When None, the device of the
            model's own parameters is used (CPU for a parameter-free model).
        memory_monitor: Optional monitor whose current RAM reading is shown in
            the progress bar.

    Returns:
        Accuracy as a float percentage in [0, 100]; 0.0 (with a warning) if the
        loader yields no samples.
    """
    if device is None:
        # Follow the model rather than guessing "cuda if available": a model
        # that lives on the CPU would otherwise receive CUDA batches and fail
        # with a device-mismatch error.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    correct = 0
    total = 0
    progress_bar = tqdm(test_loader, desc="Evaluating")

    with torch.no_grad():
        for inputs, labels in progress_bar:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            postfix = {'acc': f"{100 * correct / total:.2f}%"}
            if memory_monitor:
                mem_usage = memory_monitor.get_memory_usage()
                postfix['RAM'] = f"{mem_usage['current']:.1f}MB"

            progress_bar.set_postfix(postfix)

    if total == 0:
        # An empty loader used to end in ZeroDivisionError.
        print("WARNING: evaluate_model received no samples; reporting 0.00% accuracy")
        return 0.0

    accuracy = 100 * correct / total
    return accuracy

In [None]:
# CHANGES MADE:
# - Removed "mse" from the model filename

# Run the complete workflow with full dataset
def run_biased_mnist_cnn():
    """Train and evaluate SimpleCNN on the full biased-MNIST train and test folders.

    Steps: build un-normalized datasets, sanity-check sizes and labels,
    estimate channel statistics from the training set, rebuild the datasets
    with normalization, train for 3 epochs, evaluate, and save the weights to
    'simple_biased_mnist_cnn_full.pth'. Returns None.
    """
    # CHANGE THESE PATHS AND RENAME THE MODEL FILE FOR EACH BIAS LEVEL
    base_dir = 'biased_mnist'
    train_folder = f"{base_dir}/full_0.5/trainval"  # Using full_0.5 for now
    test_folder = f"{base_dir}/full/test"
    train_json_path = f"{base_dir}/full_0.5/trainval.json"  # Using full_0.5 for now
    test_json_path = f"{base_dir}/full/test.json"
    checkpoint_name = 'simple_biased_mnist_cnn_full.pth'

    def build_datasets(transform):
        """Create the (train, test) dataset pair sharing one transform."""
        train = BiasedMNISTDataset(train_folder, transform=transform, json_path=train_json_path)
        test = BiasedMNISTDataset(test_folder, transform=transform, json_path=test_json_path)
        return train, test

    # First pass: tensors only, used for sanity checks and statistics.
    print("Creating datasets...")
    raw_train, raw_test = build_datasets(transforms.Compose([transforms.ToTensor()]))

    print(f"Training dataset size: {len(raw_train)}")
    print(f"Test dataset size: {len(raw_test)}")

    if min(len(raw_train), len(raw_test)) == 0:
        print("Error: Empty dataset found. Please check the file paths.")
        return

    print("Checking label ranges...")
    # Both checks always run (train first) so each prints its own report.
    labels_ok = [check_label_range(ds) for ds in (raw_train, raw_test)]
    if not all(labels_ok):
        print("WARNING: Label range check failed. Please check the dataset.")

    mean, std = calculate_stats_fast(raw_train, sample_size=1000)
    print(f"Dataset mean: {mean}")
    print(f"Dataset std: {std}")

    # Second pass: same files, now normalized with the training statistics.
    train_dataset_normalized, test_dataset_normalized = build_datasets(
        transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
    )

    train_loader = DataLoader(train_dataset_normalized, batch_size=32, shuffle=True, num_workers=0)
    test_loader = DataLoader(test_dataset_normalized, batch_size=32, shuffle=False, num_workers=0)

    print(f"Using FULL training dataset with {len(train_dataset_normalized)} images")
    print(f"Using FULL test dataset with {len(test_dataset_normalized)} images")

    model = SimpleCNN(num_classes=10)
    print("Model architecture:")
    print(model)

    print("\nStarting training...")
    trained_model = train_model(model, train_loader, test_loader, num_epochs=3)

    # Final evaluation on the same device selection train_model uses.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    final_accuracy = evaluate_model(trained_model, test_loader, device, None)
    print(f"\nFinal Test Accuracy: {final_accuracy:.2f}%")

    torch.save(trained_model.state_dict(), checkpoint_name)
    print(f"Model saved as '{checkpoint_name}'")

In [None]:
# CHANGES MADE:
# - Removed "mse" from the model filename

# Run the complete workflow with subset of data
def run_biased_mnist_cnn_subset():
    """Train and evaluate SimpleCNN on random subsets of biased MNIST.

    Same pipeline as the full run, but trains on 10% of the training images
    and evaluates on 20% of the test images, for 20 epochs. The weights are
    saved to 'simple_biased_mnist_cnn_subset_05.pth'. Returns None.
    """
    # SET PATHS CORRECTLY FOR THE DESIRED BIAS LEVEL
    base_dir = 'biased_mnist'
    train_folder = f"{base_dir}/full_0.5/trainval"  # Using full_0.5 for now
    test_folder = f"{base_dir}/full/test"
    train_json_path = f"{base_dir}/full_0.5/trainval.json"  # Change when changing correlation levels
    test_json_path = f"{base_dir}/full/test.json"
    output_path = 'simple_biased_mnist_cnn_subset_05.pth'

    # Un-normalized view of the data, for sanity checks and statistics.
    tensor_only = transforms.Compose([
        transforms.ToTensor()
    ])

    print("Creating datasets...")
    plain_train = BiasedMNISTDataset(train_folder, transform=tensor_only, json_path=train_json_path)
    plain_test = BiasedMNISTDataset(test_folder, transform=tensor_only, json_path=test_json_path)

    print(f"Training dataset size: {len(plain_train)}")
    print(f"Test dataset size: {len(plain_test)}")

    if not len(plain_train) or not len(plain_test):
        print("Error: Empty dataset found. Please check the file paths.")
        return

    print("Checking label ranges...")
    train_labels_ok = check_label_range(plain_train)
    test_labels_ok = check_label_range(plain_test)

    if not train_labels_ok or not test_labels_ok:
        print("WARNING: Label range check failed. Please check the dataset.")

    mean, std = calculate_stats_fast(plain_train, sample_size=1000)
    print(f"Dataset mean: {mean}")
    print(f"Dataset std: {std}")

    # Normalized view of the same files, using the training-set statistics.
    with_normalization = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    train_dataset_normalized = BiasedMNISTDataset(
        train_folder, transform=with_normalization, json_path=train_json_path
    )
    test_dataset_normalized = BiasedMNISTDataset(
        test_folder, transform=with_normalization, json_path=test_json_path
    )

    # Fractions of each split to use; the resulting image counts depend on
    # the dataset size and are printed below.
    train_fraction = 0.1
    test_fraction = 0.2

    train_loader, train_subset_size = create_subset_loader(
        train_dataset_normalized, fraction=train_fraction, batch_size=32
    )
    test_loader, test_subset_size = create_subset_loader(
        test_dataset_normalized, fraction=test_fraction, batch_size=32
    )

    print(f"Using {train_subset_size} training images ({train_fraction*100:.1f}% of dataset)")
    print(f"Using {test_subset_size} test images ({test_fraction*100:.1f}% of dataset)")

    model = SimpleCNN(num_classes=10)
    print("Model architecture:")
    print(model)

    print("\nStarting training...")
    trained_model = train_model(model, train_loader, test_loader, num_epochs=20)

    # Final evaluation on the same device selection train_model uses.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    final_accuracy = evaluate_model(trained_model, test_loader, device, None)
    print(f"\nFinal Test Accuracy: {final_accuracy:.2f}%")

    torch.save(trained_model.state_dict(), output_path)
    print(f"Model saved as '{output_path}'")

In [None]:
def load_trained_model(model_path):
    """Load SimpleCNN weights from ``model_path`` and return the model in eval mode.

    Args:
        model_path: Path to a state dict saved with ``torch.save``.

    Returns:
        A 10-class ``SimpleCNN`` on the CPU, in eval mode.
    """
    model = SimpleCNN(num_classes=10)
    # map_location='cpu' lets a checkpoint saved from a CUDA run load on a
    # machine without a GPU; without it torch.load tries to restore the
    # tensors onto the device they were saved from.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # you trust (or pass weights_only=True where the installed torch supports it).
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model.eval()
    return model

def load_standard_mnist():
    """Return the standard MNIST test split as 3-channel 160x160 tensors.

    The images are downloaded to './data' if missing, resized to 160x160,
    normalized with mean 0.1307 and std 0.3081, and the single grayscale
    channel is repeated three times so the shape matches the RGB input used
    for the biased MNIST images.
    """
    preprocessing = transforms.Compose([
        transforms.Resize((160, 160)),  # match the biased MNIST image size
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    class MNISTtoRGB(torch.utils.data.Dataset):
        """Wrap a single-channel dataset, stacking each image into three channels."""

        def __init__(self, mnist_dataset):
            self.mnist_dataset = mnist_dataset

        def __len__(self):
            return len(self.mnist_dataset)

        def __getitem__(self, idx):
            img, label = self.mnist_dataset[idx]
            # [1, H, W] -> [3, H, W] by concatenating three copies.
            return torch.cat((img,) * 3, dim=0), label

    grayscale_test = torchvision.datasets.MNIST(
        root='./data',
        train=False,
        download=True,
        transform=preprocessing
    )
    return MNISTtoRGB(grayscale_test)

In [None]:
# CHANGES MADE:
# - Removed "mse" from model filenames

def test_on_standard_mnist():
    """Evaluate the saved subset models on the standard MNIST test set.

    Each checkpoint that can be loaded is evaluated and a summary table is
    printed. A checkpoint that fails to load is reported with the reason and
    skipped. If none can be loaded, the function returns without downloading
    MNIST. Returns None.
    """
    # (correlation level, checkpoint path, biased-test accuracy shown in the table)
    # NOTE(review): the biased-test accuracies are hard-coded results from
    # earlier runs; update them when the models are retrained.
    model_specs = [
        ("0.1", 'simple_biased_mnist_cnn_subset_01.pth', "37.55%"),
        ("0.5", 'simple_biased_mnist_cnn_subset_05.pth', "18.00%"),
    ]

    models = {}
    for level, path, _ in model_specs:
        try:
            models[level] = load_trained_model(path)
            print(f"Successfully loaded model trained on correlation level {level}")
        except Exception as exc:
            # Catch Exception, not everything: a bare except would also swallow
            # KeyboardInterrupt, and the reason for the failure was never shown.
            print(f"Could not load model from {path}: {exc}")

    if not models:
        print("No models could be loaded; nothing to evaluate.")
        return

    print("Loading standard MNIST test set...")
    mnist_dataset = load_standard_mnist()
    mnist_loader = DataLoader(mnist_dataset, batch_size=64, shuffle=False, num_workers=0)
    print(f"Loaded {len(mnist_dataset)} standard MNIST test images")

    print("\nEvaluating models on standard (unbiased) MNIST:")

    # Put model and batches on the same device explicitly; a freshly loaded
    # model is not guaranteed to be on the device the evaluator would pick.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    results = {}
    for level, _, _ in model_specs:
        if level not in models:
            continue
        print(f"\nEvaluating model trained on correlation level {level}:")
        accuracy = evaluate_model(models[level].to(device), mnist_loader, device, None)
        results[level] = accuracy
        print(f"Accuracy on standard MNIST: {accuracy:.2f}%")

    # Print summary
    print("\nSummary of model performance on standard MNIST:")
    print("-" * 50)
    print("| Correlation Level | Biased Test Acc | Standard MNIST Acc |")
    print("|-------------------|-----------------|-------------------|")

    for level, _, biased_acc in model_specs:
        if level in results:
            print(f"| {level}               | {biased_acc}          | {results[level]:.2f}%             |")

    print("-" * 50)
    print("Note: A model that relies heavily on bias features will perform poorly on standard MNIST")

In [None]:
# Entry point: choose which workflow runs by (un)commenting the calls below.

# Full-dataset training (3 epochs); saves 'simple_biased_mnist_cnn_full.pth'.
# run_biased_mnist_cnn()

# Subset training (10% of train, 20% of test, 20 epochs);
# saves 'simple_biased_mnist_cnn_subset_05.pth'.
run_biased_mnist_cnn_subset()

# Evaluate the saved subset checkpoints ('..._subset_01.pth', '..._subset_05.pth')
# on standard MNIST, used here as a counterfactual test set.
# test_on_standard_mnist()