In [1]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import os

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# SAVE_DIR must point at the folder your meta-learning script saved its models
# to (its SAVE_DIR); generated figures are written under PLOTS_DIR.
SAVE_DIR = './trained_models'
PLOTS_DIR = './weight_plots'

# One .pt file per model class, named class<k>_models_weights.pt.
CLASS1_WEIGHTS_PATH, CLASS2_WEIGHTS_PATH = (
    os.path.join(SAVE_DIR, f'class{idx}_models_weights.pt') for idx in (1, 2)
)

# Plotting only reads weights, so the CPU is sufficient.
DEVICE = torch.device("cpu")

# Make sure the output folder for figures exists before anything is saved.
os.makedirs(PLOTS_DIR, exist_ok=True)

# --- SimpleNN Model Definition (MUST MATCH YOUR TRAINING SCRIPT) ---
# Layer names (fc1, relu1, fc3) and layer sizes must be IDENTICAL to the
# SimpleNN used in training, otherwise saved state_dicts will not load into it.
# NOTE(review): the recorded run output in this file reports fc1.weight shape
# (16, 784) and fc3.weight shape (10, 16), i.e. the saved models appear to use
# 16 hidden units; construct SimpleNN(hidden_size=16) to match them (confirm
# against the training script).
class SimpleNN(nn.Module):
    """Single-hidden-layer MLP: 28*28 flattened pixels -> ReLU hidden layer -> 10 outputs.

    Args:
        hidden_size (int): Number of neurons in the single hidden layer.
            Defaults to 8, the width this file originally hard-coded.
    """

    def __init__(self, hidden_size=8):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, hidden_size)  # input pixels -> hidden layer
        self.relu1 = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, 10)  # hidden layer -> one output per digit 0-9

    def forward(self, x):
        """Return raw output scores of shape (N, 10).

        Every 28*28 consecutive elements of ``x`` (in row-major order) form one
        sample, so inputs such as (N, 1, 28, 28), (N, 28, 28) or (N, 784) work.
        """
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. a
        # transposed or strided batch; it only copies when it has to.
        x = x.reshape(-1, 28 * 28)
        x = self.relu1(self.fc1(x))
        return self.fc3(x)

# --- Helper Function to Load Weights ---
def load_model_weights(path):
    """Read the saved model weights (expected: a list of state_dicts) from a .pt file.

    Tensors are mapped onto DEVICE regardless of the device they were saved
    from, and ``weights_only=True`` restricts unpickling to tensors and plain
    containers rather than arbitrary Python objects.

    Args:
        path (str): Location of the .pt file.

    Returns:
        Whatever ``torch.load`` deserializes from ``path``.

    Raises:
        FileNotFoundError: If ``path`` does not exist; the message explains
            how to generate the file.
    """
    if os.path.exists(path):
        print(f"Loading weights from: {path}")
        return torch.load(path, map_location=DEVICE, weights_only=True)

    missing_msg = (f"Model weights file not found: {path}\n"
                   "Please ensure you have run the 'MNIST Subset Neural Network Experiment' Canvas "
                   "to generate and save these model weights first, and that SAVE_DIR is correct.")
    raise FileNotFoundError(missing_msg)

# --- Helper Function to Plot Differences for a Pair of Models ---
def _plot_weight_set_differences(model1_sd, model2_sd, comparison_prefix, plots_dir):
    """
    Calculates and plots the differences between two model state dictionaries.
    Args:
        model1_sd (dict): State dictionary of the first model.
        model2_sd (dict): State dictionary of the second model.
        comparison_prefix (str): Prefix for plot titles and filenames (e.g., "Class1_Intra").
        plots_dir (str): Directory to save the plots.
    """
    print(f"\n--- Plotting differences for: {comparison_prefix} ---")

    # --- Plotting fc1.weight differences ---
    if 'fc1.weight' in model1_sd and 'fc1.weight' in model2_sd:
        fc1_weight_m1 = model1_sd['fc1.weight'].cpu().numpy()
        fc1_weight_m2 = model2_sd['fc1.weight'].cpu().numpy()
        fc1_weight_diff = fc1_weight_m2 - fc1_weight_m1
        fc1_weight_abs_diff = np.abs(fc1_weight_diff)

        print(f"  fc1.weight shape: {fc1_weight_m1.shape}")
        print(f"  fc1.weight difference (mean abs): {np.mean(fc1_weight_abs_diff):.4f}")

        # Plot 1: Histogram of fc1.weight differences
        plt.figure(figsize=(10, 6))
        plt.hist(fc1_weight_diff.flatten(), bins=50, color='skyblue', edgecolor='black')
        plt.title(f'Histogram of fc1.weight Differences ({comparison_prefix})')
        plt.xlabel('Weight Difference Value')
        plt.ylabel('Frequency')
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, f'{comparison_prefix}_fc1_weight_diff_histogram.png'))
        plt.close()
        print(f"  Saved fc1 weight difference histogram to {plots_dir}/{comparison_prefix}_fc1_weight_diff_histogram.png")

        # Plot 2: Image representation of fc1.weight (Model 1 of the pair)
        plt.figure(figsize=(12, 4))
        plt.imshow(fc1_weight_m1, cmap='viridis', aspect='auto')
        plt.colorbar(label='Weight Value')
        plt.title(f'fc1.weight Matrix (First Model in {comparison_prefix})')
        plt.xlabel('Input Features (Flattened 28x28 pixels)')
        plt.ylabel('Output Neurons (Hidden Layer)')
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, f'{comparison_prefix}_fc1_weight_model1_image.png'))
        plt.close()
        print(f"  Saved fc1 weight image (First Model in {comparison_prefix}) to {plots_dir}/{comparison_prefix}_fc1_weight_model1_image.png")


        # Plot 3: Image representation of fc1.weight differences
        plt.figure(figsize=(12, 4))
        plt.imshow(fc1_weight_diff, cmap='coolwarm', aspect='auto')
        plt.colorbar(label='Difference Value')
        plt.title(f'fc1.weight Differences ({comparison_prefix})')
        plt.xlabel('Input Features (Flattened 28x28 pixels)')
        plt.ylabel('Output Neurons (Hidden Layer)')
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, f'{comparison_prefix}_fc1_weight_diff_image.png'))
        plt.close()
        print(f"  Saved fc1 weight difference image to {plots_dir}/{comparison_prefix}_fc1_weight_diff_image.png")

    # --- Plotting fc3.weight differences ---
    if 'fc3.weight' in model1_sd and 'fc3.weight' in model2_sd:
        fc3_weight_m1 = model1_sd['fc3.weight'].cpu().numpy()
        fc3_weight_m2 = model2_sd['fc3.weight'].cpu().numpy()
        fc3_weight_diff = fc3_weight_m2 - fc3_weight_m1
        fc3_weight_abs_diff = np.abs(fc3_weight_diff)

        print(f"  fc3.weight shape: {fc3_weight_m1.shape}")
        print(f"  fc3.weight difference (mean abs): {np.mean(fc3_weight_abs_diff):.4f}")

        # Plot 4: Histogram of fc3.weight differences
        plt.figure(figsize=(10, 6))
        plt.hist(fc3_weight_diff.flatten(), bins=20, color='lightgreen', edgecolor='black')
        plt.title(f'Histogram of fc3.weight Differences ({comparison_prefix})')
        plt.xlabel('Weight Difference Value')
        plt.ylabel('Frequency')
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, f'{comparison_prefix}_fc3_weight_diff_histogram.png'))
        plt.close()
        print(f"  Saved fc3 weight difference histogram to {plots_dir}/{comparison_prefix}_fc3_weight_diff_histogram.png")

        # Plot 5: Image representation of fc3.weight differences
        plt.figure(figsize=(8, 6))
        plt.imshow(fc3_weight_diff, cmap='coolwarm', aspect='auto')
        plt.colorbar(label='Difference Value')
        plt.title(f'fc3.weight Differences ({comparison_prefix})')
        plt.xlabel('Input Neurons (Hidden Layer)')
        plt.ylabel('Output Neurons (Digit Classes)')
        plt.xticks(np.arange(fc3_weight_diff.shape[1])) # Show all input neuron labels
        plt.yticks(np.arange(fc3_weight_diff.shape[0])) # Show all output neuron labels
        plt.tight_layout()
        plt.savefig(os.path.join(plots_dir, f'{comparison_prefix}_fc3_weight_diff_image.png'))
        plt.close()
        print(f"  Saved fc3 weight difference image to {plots_dir}/{comparison_prefix}_fc3_weight_diff_image.png")


# --- Main Plotting Logic ---
def plot_weight_differences():
    """
    Load the saved Class 1 and Class 2 model weights and plot their differences.

    Comparisons made:
      * first Class 1 model vs. first Class 2 model (inter-class);
      * for each class holding at least two models, its first vs. second model
        (intra-class); classes with fewer models are reported and skipped.

    Loading failures are reported with print and end the run early instead of
    raising, so the notebook cell completes either way.
    """
    try:
        class1_models, class2_models = (
            load_model_weights(weights_path)
            for weights_path in (CLASS1_WEIGHTS_PATH, CLASS2_WEIGHTS_PATH)
        )
    except FileNotFoundError as err:
        print(f"Error: {err}")
        return
    except Exception as err:
        print(f"An unexpected error occurred during weight loading: {err}")
        return

    if not (class1_models and class2_models):
        print("No models loaded. Please ensure the training script generated models.")
        return

    # Inter-class comparison: first model of each class.
    print("\nComparing weights: First Class 1 Model vs. First Class 2 Model")
    _plot_weight_set_differences(class1_models[0], class2_models[0], "Class1_vs_Class2", PLOTS_DIR)

    # Intra-class comparisons: first vs. second model within each class.
    for class_id, models in ((1, class1_models), (2, class2_models)):
        if len(models) < 2:
            print(f"\nSkipping intra-class comparison for Class {class_id}: "
                  f"Not enough Class {class_id} models loaded (need at least 2).")
            continue
        print(f"\nComparing weights: First Class {class_id} Model vs. "
              f"Second Class {class_id} Model (Intra-Class)")
        _plot_weight_set_differences(models[0], models[1], f"Class{class_id}_Intra", PLOTS_DIR)

    print(f"\nWeight difference plotting complete. Check the '{PLOTS_DIR}' directory for plots.")

# Entry point: run every comparison when this cell/script executes. Inside a
# Jupyter/IPython kernel __name__ is "__main__", so this also runs in the notebook.
if __name__ == "__main__":
    plot_weight_differences()


Loading weights from: ./trained_models/class1_models_weights.pt
Loading weights from: ./trained_models/class2_models_weights.pt

Comparing weights: First Class 1 Model vs. First Class 2 Model

--- Plotting differences for: Class1_vs_Class2 ---
  fc1.weight shape: (16, 784)
  fc1.weight difference (mean abs): 0.0381
  Saved fc1 weight difference histogram to ./weight_plots/Class1_vs_Class2_fc1_weight_diff_histogram.png
  Saved fc1 weight image (First Model in Class1_vs_Class2) to ./weight_plots/Class1_vs_Class2_fc1_weight_model1_image.png
  Saved fc1 weight difference image to ./weight_plots/Class1_vs_Class2_fc1_weight_diff_image.png
  fc3.weight shape: (10, 16)
  fc3.weight difference (mean abs): 0.1869
  Saved fc3 weight difference histogram to ./weight_plots/Class1_vs_Class2_fc3_weight_diff_histogram.png
  Saved fc3 weight difference image to ./weight_plots/Class1_vs_Class2_fc3_weight_diff_image.png

Comparing weights: First Class 1 Model vs. Second Class 1 Model (Intra-Class)

--- P