In [7]:
### data generation

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
from tqdm import tqdm
import random
import os # Import the os module for path operations

# --- Experiment settings ---
NUM_MODELS = 10         # independent networks trained per label subset
BATCH_SIZE = 64
NUM_EPOCHS = 3          # deliberately small so that many models train quickly
LEARNING_RATE = 0.001

# Digit labels that make up each training subset
CLASS1_LABELS = [0, 1, 2, 3]
CLASS2_LABELS = [2, 3, 4, 5]

# Choose the compute device in order of preference: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Seed every RNG used by the script (torch, NumPy, Python's random).
# Note: the seed is set once, so each of the NUM_MODELS models still gets its
# own initialization and its own shuffling order.
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    # Trade cuDNN autotuning for deterministic kernel selection
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Nothing extra is done for MPS: no MPS-specific seeding call is made here, so
# minor run-to-run variation on Apple Silicon may have to be accepted.

# --- Data Loading and Preprocessing ---
# Images become tensors and are then standardised with the MNIST mean and std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Fetch (downloading on first use) the train split first, then the test split
print("Loading MNIST dataset...")
full_train_dataset, full_test_dataset = (
    datasets.MNIST('./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
print("MNIST dataset loaded.")

# --- Dataset Subset Creation Function ---
def create_subset(dataset, labels_to_include):
    """
    Build a ``Subset`` of ``dataset`` restricted to the given labels.

    When the dataset exposes its labels through a ``targets`` attribute (and
    has no ``target_transform`` that would alter them), the labels are read
    from there directly. This avoids loading and transforming every sample
    just to look at its label. Otherwise the dataset is iterated as
    ``(sample, label)`` pairs.

    Args:
        dataset: Dataset yielding ``(sample, label)`` pairs.
        labels_to_include: Collection of labels to keep.

    Returns:
        A ``Subset`` over the indices whose label is in ``labels_to_include``,
        in their original order.
    """
    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        # Tensors / arrays are converted to plain Python values so membership
        # tests behave like they do for the labels returned by __getitem__.
        labels = targets.tolist() if hasattr(targets, "tolist") else list(targets)
        indices = [i for i, label in enumerate(labels) if label in labels_to_include]
    else:
        indices = [
            i for i, (_, label) in enumerate(dataset)
            if label in labels_to_include
        ]
    return Subset(dataset, indices)

# Build the train/test subsets for both label groups
# (order: Class 1 train, Class 1 test, Class 2 train, Class 2 test)
print("Creating dataset subsets...")
class1_train_dataset, class1_test_dataset = (
    create_subset(split, CLASS1_LABELS)
    for split in (full_train_dataset, full_test_dataset)
)
class2_train_dataset, class2_test_dataset = (
    create_subset(split, CLASS2_LABELS)
    for split in (full_train_dataset, full_test_dataset)
)
print("Dataset subsets created.")

# --- Neural Network Definition (3 LAYERS) ---
class SimpleNN(nn.Module):
    """Fully connected classifier: 784 -> 64 -> 32 -> 10 with ReLU in between.

    The output layer always has 10 logits (digits 0-9), even when the model is
    trained on only a subset of the digits.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the order they are applied; the attribute
        # names determine the state_dict keys (fc1.*, fc2.*, fc3.*).
        self.fc1 = nn.Linear(28 * 28, 64)   # flattened image -> hidden 1
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)        # hidden 1 -> hidden 2
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(32, 10)        # hidden 2 -> class logits

    def forward(self, x):
        """Flatten ``x`` to rows of 784 values and return the 10 class logits."""
        flat = x.view(-1, 28 * 28)
        hidden1 = self.relu1(self.fc1(flat))
        hidden2 = self.relu2(self.fc2(hidden1))
        return self.fc3(hidden2)

# --- Training Function ---
def train_model(model, train_loader, criterion, optimizer, num_epochs):
    """Run ``num_epochs`` passes over ``train_loader``, updating ``model`` in place.

    Each batch is moved to DEVICE, the loss from ``criterion`` is
    back-propagated, and ``optimizer`` takes one step. Nothing is returned.
    """
    model.train()
    for _ in range(num_epochs):
        for inputs, labels in train_loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), labels)
            batch_loss.backward()
            optimizer.step()

# --- Evaluation Function ---
def evaluate_model(model, data_loader, description=""):
    """Evaluate ``model`` on ``data_loader`` and return top-1 accuracy in percent.

    Args:
        model: Network producing one row of class scores per sample.
        data_loader: Loader yielding ``(data, target)`` batches.
        description: Unused; accepted so call sites can label the evaluation.

    Returns:
        Accuracy in the range 0-100, or NaN when no sample was evaluated
        (empty dataset, or a loader that yields no batches).
    """
    model.eval()
    if len(data_loader.dataset) == 0:
        return float('nan')
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        # A non-empty dataset can still produce zero batches (e.g. drop_last=True
        # with fewer samples than one batch); avoid dividing by zero.
        return float('nan')
    return 100 * correct / total

# --- Main Experiment Logic ---
def run_experiment(train_dataset, self_test_dataset, cross_test_dataset,
                   train_description, self_test_description, cross_test_description):
    """
    Train NUM_MODELS fresh SimpleNN models on one dataset and report accuracies.

    Every model is trained for NUM_EPOCHS on ``train_dataset`` and then scored
    on the training data, on ``self_test_dataset`` and on
    ``cross_test_dataset``. Per-model accuracies and their mean / standard
    deviation are printed.

    Args:
        train_dataset: Dataset used for training and for the training accuracy.
        self_test_dataset: Test data for the labels the model was trained on.
        cross_test_dataset: Test data for the other label group.
        train_description: Text naming ``train_dataset`` in the output.
        self_test_description: Text naming ``self_test_dataset`` in the output.
        cross_test_description: Text naming ``cross_test_dataset`` in the output.

    Returns:
        A list with one state_dict per trained model. The tensors are detached
        CPU copies, so the list does not hold on to device memory and the
        saved file is not tied to the training device.
    """
    train_accuracies = []
    self_test_accuracies = []
    cross_test_accuracies = []
    trained_model_weights = []

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    self_test_loader = DataLoader(self_test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    cross_test_loader = DataLoader(cross_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print(f"\n--- Training {NUM_MODELS} models for {train_description} ---")
    for i in tqdm(range(NUM_MODELS), desc=f"Training models for {train_description}"):
        model = SimpleNN().to(DEVICE)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

        train_model(model, train_loader, criterion, optimizer, NUM_EPOCHS)

        # Snapshot the weights as independent CPU tensors rather than keeping
        # references to the live parameters on DEVICE.
        trained_model_weights.append({
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        })

        train_acc = evaluate_model(model, train_loader, f"{train_description} Training")
        self_test_acc = evaluate_model(model, self_test_loader, f"{self_test_description} Testing")
        cross_test_acc = evaluate_model(model, cross_test_loader, f"{cross_test_description} Testing")

        train_accuracies.append(train_acc)
        self_test_accuracies.append(self_test_acc)
        cross_test_accuracies.append(cross_test_acc)

        # tqdm.write keeps the per-model report from corrupting the progress bar
        tqdm.write(f"  Model {i+1}/{NUM_MODELS}:")
        tqdm.write(f"    Training Acc ({train_description}): {train_acc:.2f}%")
        tqdm.write(f"    Testing Acc ({self_test_description}): {self_test_acc:.2f}%")
        tqdm.write(f"    Testing Acc ({cross_test_description}): {cross_test_acc:.2f}%")

    # Aggregate statistics over all trained models
    avg_train_acc = np.mean(train_accuracies)
    std_train_acc = np.std(train_accuracies)
    avg_self_test_acc = np.mean(self_test_accuracies)
    std_self_test_acc = np.std(self_test_accuracies)
    avg_cross_test_acc = np.mean(cross_test_accuracies)
    std_cross_test_acc = np.std(cross_test_accuracies)

    print(f"\n--- Results for models trained on {train_description} ---")
    print(f"Average Training Accuracy ({train_description}): {avg_train_acc:.2f}% (Std Dev: {std_train_acc:.2f})")
    print(f"Average Testing Accuracy ({self_test_description}): {avg_self_test_acc:.2f}% (Std Dev: {std_self_test_acc:.2f})")
    print(f"Average Testing Accuracy ({cross_test_description}): {avg_cross_test_acc:.2f}% (Std Dev: {std_cross_test_acc:.2f})")
    print("-" * 50)

    return trained_model_weights

# Directory that receives the serialized weight lists (created on demand)
SAVE_DIR = './trained_models'
os.makedirs(SAVE_DIR, exist_ok=True)

# Descriptions reused in the progress and result output of both experiments
_class1_train_tag = f"Class 1 (Digits {CLASS1_LABELS})"
_class1_test_tag = f"Class 1 Test (Digits {CLASS1_LABELS})"
_class2_train_tag = f"Class 2 (Digits {CLASS2_LABELS})"
_class2_test_tag = f"Class 2 Test (Digits {CLASS2_LABELS})"

# Class 1 models: train on Class 1, test on Class 1 and cross-test on Class 2
class1_model_weights = run_experiment(
    class1_train_dataset, class1_test_dataset, class2_test_dataset,
    _class1_train_tag, _class1_test_tag, _class2_test_tag,
)
class1_save_path = os.path.join(SAVE_DIR, 'class1_models_weights.pt')
torch.save(class1_model_weights, class1_save_path)
print(f"Saved Class 1 model weights to {class1_save_path}")

# Class 2 models: train on Class 2, test on Class 2 and cross-test on Class 1
class2_model_weights = run_experiment(
    class2_train_dataset, class2_test_dataset, class1_test_dataset,
    _class2_train_tag, _class2_test_tag, _class1_test_tag,
)
class2_save_path = os.path.join(SAVE_DIR, 'class2_models_weights.pt')
torch.save(class2_model_weights, class2_save_path)
print(f"Saved Class 2 model weights to {class2_save_path}")

# Final summary of what is now held in memory
print("\nExperiment complete!")
print(f"Class 1 model weights stored in 'class1_model_weights' (list of {len(class1_model_weights)} state_dicts).")
print(f"Class 2 model weights stored in 'class2_model_weights' (list of {len(class2_model_weights)} state_dicts).")


Using device: mps
Loading MNIST dataset...
MNIST dataset loaded.
Creating dataset subsets...
Dataset subsets created.

--- Training 10 models for Class 1 (Digits [0, 1, 2, 3]) ---


Training models for Class 1 (Digits [0, 1, 2, 3]):  10%|█         | 1/10 [00:07<01:03,  7.09s/it]

  Model 1/10:
    Training Acc (Class 1 (Digits [0, 1, 2, 3])): 99.30%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 98.80%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 51.15%


Training models for Class 1 (Digits [0, 1, 2, 3]):  20%|██        | 2/10 [00:14<00:57,  7.13s/it]

  Model 2/10:
    Training Acc (Class 1 (Digits [0, 1, 2, 3])): 99.23%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 98.92%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 51.40%


Training models for Class 1 (Digits [0, 1, 2, 3]):  30%|███       | 3/10 [00:21<00:48,  6.97s/it]

  Model 3/10:
    Training Acc (Class 1 (Digits [0, 1, 2, 3])): 99.05%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 98.77%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 51.10%


Training models for Class 1 (Digits [0, 1, 2, 3]):  40%|████      | 4/10 [00:27<00:41,  6.91s/it]

  Model 4/10:
    Training Acc (Class 1 (Digits [0, 1, 2, 3])): 99.20%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 98.97%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 51.33%


Training models for Class 1 (Digits [0, 1, 2, 3]):  50%|█████     | 5/10 [00:34<00:34,  6.99s/it]

  Model 5/10:
    Training Acc (Class 1 (Digits [0, 1, 2, 3])): 99.30%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 99.06%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 51.48%


Training models for Class 1 (Digits [0, 1, 2, 3]):  60%|██████    | 6/10 [00:42<00:28,  7.01s/it]

  Model 6/10:
    Training Acc (Class 1 (Digits [0, 1, 2, 3])): 99.21%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 99.09%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 51.51%


Training models for Class 1 (Digits [0, 1, 2, 3]):  70%|███████   | 7/10 [00:49<00:21,  7.09s/it]

  Model 7/10:
    Training Acc (Class 1 (Digits [0, 1, 2, 3])): 99.20%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 98.94%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 51.25%


Training models for Class 1 (Digits [0, 1, 2, 3]):  80%|████████  | 8/10 [00:56<00:14,  7.22s/it]

  Model 8/10:
    Training Acc (Class 1 (Digits [0, 1, 2, 3])): 99.23%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 99.11%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 51.48%


Training models for Class 1 (Digits [0, 1, 2, 3]):  90%|█████████ | 9/10 [01:04<00:07,  7.25s/it]

  Model 9/10:
    Training Acc (Class 1 (Digits [0, 1, 2, 3])): 99.00%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 98.99%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 51.35%


Training models for Class 1 (Digits [0, 1, 2, 3]): 100%|██████████| 10/10 [01:11<00:00,  7.12s/it]


  Model 10/10:
    Training Acc (Class 1 (Digits [0, 1, 2, 3])): 99.02%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 98.97%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 51.30%

--- Results for models trained on Class 1 (Digits [0, 1, 2, 3]) ---
Average Training Accuracy (Class 1 (Digits [0, 1, 2, 3])): 99.17% (Std Dev: 0.10)
Average Testing Accuracy (Class 1 Test (Digits [0, 1, 2, 3])): 98.96% (Std Dev: 0.11)
Average Testing Accuracy (Class 2 Test (Digits [2, 3, 4, 5])): 51.34% (Std Dev: 0.13)
--------------------------------------------------
Saved Class 1 model weights to ./trained_models/class1_models_weights.pt

--- Training 10 models for Class 2 (Digits [2, 3, 4, 5]) ---


Training models for Class 2 (Digits [2, 3, 4, 5]):  10%|█         | 1/10 [00:06<01:00,  6.77s/it]

  Model 1/10:
    Training Acc (Class 2 (Digits [2, 3, 4, 5])): 98.36%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 97.88%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 47.63%


Training models for Class 2 (Digits [2, 3, 4, 5]):  20%|██        | 2/10 [00:13<00:55,  6.92s/it]

  Model 2/10:
    Training Acc (Class 2 (Digits [2, 3, 4, 5])): 98.80%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 98.24%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 48.16%


Training models for Class 2 (Digits [2, 3, 4, 5]):  30%|███       | 3/10 [00:21<00:49,  7.07s/it]

  Model 3/10:
    Training Acc (Class 2 (Digits [2, 3, 4, 5])): 98.71%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 98.19%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 48.28%


Training models for Class 2 (Digits [2, 3, 4, 5]):  40%|████      | 4/10 [00:28<00:42,  7.04s/it]

  Model 4/10:
    Training Acc (Class 2 (Digits [2, 3, 4, 5])): 98.48%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 98.31%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 48.23%


Training models for Class 2 (Digits [2, 3, 4, 5]):  50%|█████     | 5/10 [00:34<00:34,  6.95s/it]

  Model 5/10:
    Training Acc (Class 2 (Digits [2, 3, 4, 5])): 98.93%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 98.57%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 48.21%


Training models for Class 2 (Digits [2, 3, 4, 5]):  60%|██████    | 6/10 [00:41<00:27,  6.94s/it]

  Model 6/10:
    Training Acc (Class 2 (Digits [2, 3, 4, 5])): 98.61%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 98.16%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 47.82%


Training models for Class 2 (Digits [2, 3, 4, 5]):  70%|███████   | 7/10 [00:48<00:20,  6.82s/it]

  Model 7/10:
    Training Acc (Class 2 (Digits [2, 3, 4, 5])): 98.72%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 98.34%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 47.92%


Training models for Class 2 (Digits [2, 3, 4, 5]):  80%|████████  | 8/10 [00:55<00:13,  6.87s/it]

  Model 8/10:
    Training Acc (Class 2 (Digits [2, 3, 4, 5])): 98.59%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 98.47%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 48.38%


Training models for Class 2 (Digits [2, 3, 4, 5]):  90%|█████████ | 9/10 [01:02<00:06,  6.83s/it]

  Model 9/10:
    Training Acc (Class 2 (Digits [2, 3, 4, 5])): 98.45%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 97.96%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 47.82%


Training models for Class 2 (Digits [2, 3, 4, 5]): 100%|██████████| 10/10 [01:09<00:00,  6.91s/it]

  Model 10/10:
    Training Acc (Class 2 (Digits [2, 3, 4, 5])): 98.74%
    Testing Acc (Class 2 Test (Digits [2, 3, 4, 5])): 98.16%
    Testing Acc (Class 1 Test (Digits [0, 1, 2, 3])): 48.11%

--- Results for models trained on Class 2 (Digits [2, 3, 4, 5]) ---
Average Training Accuracy (Class 2 (Digits [2, 3, 4, 5])): 98.64% (Std Dev: 0.17)
Average Testing Accuracy (Class 2 Test (Digits [2, 3, 4, 5])): 98.23% (Std Dev: 0.20)
Average Testing Accuracy (Class 1 Test (Digits [0, 1, 2, 3])): 48.06% (Std Dev: 0.23)
--------------------------------------------------
Saved Class 2 model weights to ./trained_models/class2_models_weights.pt

Experiment complete!
Class 1 model weights stored in 'class1_model_weights' (list of 10 state_dicts).
Class 2 model weights stored in 'class2_model_weights' (list of 10 state_dicts).





In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import numpy as np
import os
import random
from tqdm import tqdm
from scipy.linalg import orthogonal_procrustes # For Step 3

# --- Paths ---
SAVE_DIR = './trained_models'    # directory holding the saved weight lists
PLOTS_DIR = './surgery_plots'    # output directory for surgery-related plots
os.makedirs(PLOTS_DIR, exist_ok=True)

CLASS1_WEIGHTS_PATH = os.path.join(SAVE_DIR, 'class1_models_weights.pt')
CLASS2_WEIGHTS_PATH = os.path.join(SAVE_DIR, 'class2_models_weights.pt')

# --- Data loading / probe training parameters ---
BATCH_SIZE = 64
NUM_EPOCHS_PROBE = 5           # epochs used when training the linear probe
LEARNING_RATE_PROBE = 0.001

# --- Label groups ---
CLASS1_LABELS = [0, 1, 2, 3]             # digits Model A was originally trained on
CLASS2_LABELS = [2, 3, 4, 5]             # digits Model B was originally trained on
SHARED_LABELS_FOR_ALIGNMENT = [2, 3]     # digits common to both groups, used for Procrustes alignment
TARGET_DIGIT_TO_TRANSFER = 4             # digit whose knowledge is transferred from Model B to Model A
OOC_DIGIT_TO_MONITOR = 5                 # digit Model A should ideally remain ignorant of

# --- Surgical parameters ---
ALPHA = 0.8            # scaling factor applied to the transported probe when it is added (Step 6)
K_ROWS_TO_PRUNE = 1    # number of rows of A's classifier weight that are edited (Step 5)

# Choose the compute device in order of preference: CUDA, then Apple MPS, then CPU
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# Seed every RNG used by the script (torch, NumPy, Python's random)
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    # Trade cuDNN autotuning for deterministic kernel selection
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# No MPS-specific seeding is performed.

# --- Data Loading and Preprocessing (MNIST) ---
# Images become tensors and are then standardised with the MNIST mean and std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Fetch (downloading on first use) the train split first, then the test split
print("Loading MNIST dataset...")
full_train_dataset, full_test_dataset = (
    datasets.MNIST('./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
print("MNIST dataset loaded.")

# --- Dataset Subset Creation Function ---
def create_subset(dataset, labels_to_include):
    """Build a ``Subset`` of ``dataset`` restricted to the given labels.

    When the dataset exposes its labels through a ``targets`` attribute (and
    has no ``target_transform`` that would alter them), the labels are read
    from there directly. This avoids loading and transforming every sample
    just to look at its label. Otherwise the dataset is iterated as
    ``(sample, label)`` pairs.

    Args:
        dataset: Dataset yielding ``(sample, label)`` pairs.
        labels_to_include: Collection of labels to keep.

    Returns:
        A ``Subset`` over the indices whose label is in ``labels_to_include``,
        in their original order.
    """
    targets = getattr(dataset, "targets", None)
    if targets is not None and getattr(dataset, "target_transform", None) is None:
        # Tensors / arrays are converted to plain Python values so membership
        # tests behave like they do for the labels returned by __getitem__.
        labels = targets.tolist() if hasattr(targets, "tolist") else list(targets)
        indices = [i for i, label in enumerate(labels) if label in labels_to_include]
    else:
        indices = [
            i for i, (_, label) in enumerate(dataset)
            if label in labels_to_include
        ]
    return Subset(dataset, indices)

# Datasets used by the pipeline for evaluation, probe training and alignment.

# Test-split subsets used for evaluation: Model A's own digits (CLASS1_LABELS),
# the digit to transfer, and the out-of-class digit that is only monitored.
(eval_set_A_labels_test_dataset,
 target_digit_test_dataset,
 ooc_digit_test_dataset) = (
    create_subset(full_test_dataset, wanted)
    for wanted in (CLASS1_LABELS, [TARGET_DIGIT_TO_TRANSFER], [OOC_DIGIT_TO_MONITOR])
)

# Probe data drawn from Model B's label group (CLASS2_LABELS): it contains the
# target digit together with the other digits Model B was trained on, so a
# probe can learn "target digit vs. the rest".
probe_train_dataset_B, probe_test_dataset_B = (
    create_subset(split, CLASS2_LABELS)
    for split in (full_train_dataset, full_test_dataset)
)

# Data for the Orthogonal Procrustes alignment: test samples of the digits
# that appear in both label groups.
shared_data_for_alignment = create_subset(full_test_dataset, SHARED_LABELS_FOR_ALIGNMENT)

# --- SimpleNN Model Definition (MUST MATCH YOUR TRAINING SCRIPT) ---
# This architecture is critical for loading saved weights correctly.
# It should be identical to the SimpleNN in 'mnist-subset-nn-experiment' Canvas.
class SimpleNN(nn.Module):
    """Fully connected classifier: 784 -> 64 -> 32 -> 10 with ReLU in between.

    The layer names and sizes define the state_dict layout, so they have to
    match the network whose saved weights are loaded into this class.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 64)   # flattened image -> hidden 1
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(64, 32)        # hidden 1 -> hidden 2
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(32, 10)        # hidden 2 -> logits for digits 0-9

    def get_hidden_features(self, x):
        """Return the post-ReLU activations of ``fc2`` (Step 1).

        These 32-dimensional vectors are the inputs of the classifier
        layer ``fc3``.
        """
        flat = x.view(-1, 28 * 28)
        hidden1 = self.relu1(self.fc1(flat))
        return self.relu2(self.fc2(hidden1))

    def forward(self, x):
        """Return the 10 class logits: ``fc3`` applied to the hidden features."""
        return self.fc3(self.get_hidden_features(x))

# --- Probe Network Definition (for Step 2) ---
class ProbeNN(nn.Module):
    """Linear probe mapping a ``feature_dim``-sized vector to a single logit.

    Intended for binary classification on hidden features (Step 2).
    """

    def __init__(self, feature_dim):
        super().__init__()
        # One output unit: the logit of the binary decision
        self.linear = nn.Linear(feature_dim, 1)

    def forward(self, x):
        """Return the probe logit for every row of ``x``."""
        logit = self.linear(x)
        return logit

# --- Helper Functions ---
def load_model_weights(path):
    """Load the list of state_dicts stored in the ``.pt`` file at ``path``.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    if os.path.exists(path):
        print(f"Loading weights from: {path}")
        # Tensors are mapped directly onto DEVICE while loading.
        # weights_only=True is passed as the recommended, safer loading mode.
        return torch.load(path, map_location=DEVICE, weights_only=True)
    raise FileNotFoundError(f"Model weights file not found: {path}\n"
                            "Please ensure you have run the 'MNIST Subset Neural Network Experiment' Canvas "
                            "to generate and save these model weights first, and that SAVE_DIR is correct.")

def train_model(model, train_loader, criterion, optimizer, num_epochs):
    """Run ``num_epochs`` passes over ``train_loader``, updating ``model`` in place.

    Each batch is moved to DEVICE, the loss from ``criterion`` is
    back-propagated, and ``optimizer`` takes one step. Nothing is returned.
    """
    model.train()
    for _ in range(num_epochs):
        for inputs, labels in train_loader:
            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)
            optimizer.zero_grad()
            batch_loss = criterion(model(inputs), labels)
            batch_loss.backward()
            optimizer.step()

def evaluate_model_accuracy(model, data_loader):
    """Evaluate ``model`` on ``data_loader`` and return top-1 accuracy in percent.

    Args:
        model: Network producing one row of class scores per sample.
        data_loader: Loader yielding ``(data, target)`` batches.

    Returns:
        Accuracy in the range 0-100, or NaN when no sample was evaluated
        (empty dataset, or a loader that yields no batches).
    """
    model.eval()
    if len(data_loader.dataset) == 0:
        return float('nan')  # Handle empty datasets gracefully
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    if total == 0:
        # A non-empty dataset can still produce zero batches (e.g. drop_last=True
        # with fewer samples than one batch); avoid dividing by zero.
        return float('nan')
    return 100 * correct / total

def get_hidden_activations(model, data_loader):
    """Collect hidden-layer activations and labels for every batch of a loader.

    Args:
        model: Model exposing ``get_hidden_features(data)``.
        data_loader: Loader yielding ``(data, target)`` batches.

    Returns:
        A ``(features, labels)`` pair of CPU tensors concatenated over all
        batches. When the loader yields no batches, two empty tensors are
        returned (``shape[0] == 0``) so callers can detect the situation
        instead of hitting the error ``torch.cat`` raises for an empty list.
    """
    model.eval()  # Evaluation mode for feature extraction
    feature_batches = []
    label_batches = []
    with torch.no_grad():  # No gradients needed for feature extraction
        for data, target in data_loader:
            data = data.to(DEVICE)
            # Keep results on the CPU so they can be concatenated cheaply;
            # labels are only stored, so they are not sent to DEVICE first.
            feature_batches.append(model.get_hidden_features(data).cpu())
            label_batches.append(target.cpu())
    if not feature_batches:
        return torch.empty(0), torch.empty(0, dtype=torch.long)
    return torch.cat(feature_batches), torch.cat(label_batches)

# --- Main Surgical Pipeline ---
def run_model_surgery_pipeline():
    print("\n--- Starting Model Surgery Pipeline ---")

    # Load pre-trained Class 1 (Model A) and Class 2 (Model B) models
    try:
        class1_models_sd = load_model_weights(CLASS1_WEIGHTS_PATH)
        class2_models_sd = load_model_weights(CLASS2_WEIGHTS_PATH)
    except FileNotFoundError as e:
        print(f"Error: {e}")
        print("Please ensure you have run the 'MNIST Subset Neural Network Experiment' Canvas to generate these models.")
        return
    except RuntimeError as e:
        print(f"Error loading model weights: {e}")
        print("This often means the SimpleNN architecture in this script does not match the saved weights.")
        print("Please ensure the SimpleNN in 'mnist-subset-nn-experiment' Canvas is identical to this script's SimpleNN,")
        print("and then re-run 'mnist-subset-nn-experiment' to regenerate compatible weights.")
        return

    # Pick one model from each class for the surgery
    model_A = SimpleNN().to(DEVICE)
    # random.choice returns a state_dict, load it into the model instance
    model_A.load_state_dict(random.choice(class1_models_sd))
    print("Loaded Model A (trained on 0,1,2,3).")

    model_B = SimpleNN().to(DEVICE)
    model_B.load_state_dict(random.choice(class2_models_sd))
    print("Loaded Model B (trained on 2,3,4,5).")

    # --- Step 2: Train a linear probe on B for “digit 4 vs not-4” ---
    print("\nStep 2: Training a linear probe on Model B for 'digit 4 vs not-4'")
    
    # Prepare data for probe training:
    # Labels for probe: 1 for TARGET_DIGIT_TO_TRANSFER (4), 0 for other digits in CLASS2_LABELS (2,3,5)
    probe_data_B_raw = []
    probe_labels_B_binary = []
    
    for img, label in tqdm(probe_train_dataset_B, desc="Collecting probe training data from Model B's domain"):
        probe_data_B_raw.append(img)
        # Assign binary label: 1 if it's the target digit, 0 otherwise
        probe_labels_B_binary.append(1 if label == TARGET_DIGIT_TO_TRANSFER else 0)

    # Convert lists to tensors
    probe_data_B_raw = torch.stack(probe_data_B_raw)
    probe_labels_B_binary = torch.tensor(probe_labels_B_binary, dtype=torch.float32).unsqueeze(1) # Ensure shape (N, 1)

    # Create a DataLoader for the raw image data to pass through Model B
    probe_train_loader_raw = DataLoader(
        torch.utils.data.TensorDataset(probe_data_B_raw, probe_labels_B_binary),
        batch_size=BATCH_SIZE, shuffle=True
    )

    # Instantiate the probe network. Feature_dim is the output size of SimpleNN's penultimate layer (fc2).
    feature_dim = model_B.fc2.out_features # This is 32 for our current SimpleNN
    probe_net = ProbeNN(feature_dim).to(DEVICE)
    probe_criterion = nn.BCEWithLogitsLoss() # Use BCEWithLogitsLoss for binary classification with logits
    probe_optimizer = optim.Adam(probe_net.parameters(), lr=LEARNING_RATE_PROBE)

    # Freeze Model B's weights and collect hidden features for probe training
    model_B.eval() # Ensure Model B is in eval mode for feature extraction, no gradients needed for Model B
    probe_hidden_features = []
    probe_targets = []
    with torch.no_grad(): # No gradients needed for Model B's forward pass
        for data, target in probe_train_loader_raw:
            data = data.to(DEVICE)
            hidden_feats = model_B.get_hidden_features(data) # Extract features
            probe_hidden_features.append(hidden_feats)
            probe_targets.append(target.to(DEVICE)) # Target is already (batch_size, 1) here
    
    # Create a DataLoader for the extracted hidden features to train the probe
    probe_train_loader_features = DataLoader(
        torch.utils.data.TensorDataset(torch.cat(probe_hidden_features), torch.cat(probe_targets)),
        batch_size=BATCH_SIZE, shuffle=True
    )

    # Train the probe network
    print("  Training the probe network...")
    train_model(probe_net, probe_train_loader_features, probe_criterion, probe_optimizer, NUM_EPOCHS_PROBE)

    # Extract W4 (the weight of the linear probe layer). It's a 1x_feature_dim tensor.
    W4 = probe_net.linear.weight.data.clone().detach().squeeze(0) # Shape (feature_dim,)
    print(f"  Probe W4 extracted. Shape: {W4.shape}")

    # --- Step 3: Align B’s hidden basis to A’s hidden basis with a bridging map ---
    print("\nStep 3: Aligning hidden bases using Orthogonal Procrustes")
    
    # Collect hidden vectors for shared digits (2 and 3) from the test set for alignment
    shared_loader = DataLoader(shared_data_for_alignment, batch_size=BATCH_SIZE, shuffle=False)
    
    # Get hidden activations from Model A and Model B for the shared data
    H_A_shared, _ = get_hidden_activations(model_A, shared_loader)
    H_B_shared, _ = get_hidden_activations(model_B, shared_loader)

    if H_A_shared.shape[0] == 0 or H_B_shared.shape[0] == 0:
        print("  Warning: No shared data found for alignment. Skipping Procrustes. This may lead to poor transfer.")
        R = torch.eye(feature_dim, device=DEVICE) # Use identity matrix if no shared data
    else:
        # Orthogonal Procrustes: find R such that R @ H_B_shared.T approx H_A_shared.T
        # scipy's orthogonal_procrustes expects (num_samples, num_features)
        # It returns R_np such that B @ R_np is closest to A. So, H_B_shared @ R_np approx H_A_shared.
        # We want R such that R @ H_B_shared.T to be H_A_shared.T, then we need to transpose inputs for orthogonal_procrustes
        # or use it as orthogonal_procrustes(A, B) finds R such that A @ R is closest to B.
        # So, if we want R @ H_B_shared.T to be H_A_shared.T, then we need to transpose inputs for orthogonal_procrustes
        # or use it as orthogonal_procrustes(H_A_shared, H_B_shared) to get R such that H_A_shared @ R is closest to H_B_shared
        # The common usage is to find R such that X @ R is closest to Y. So, we want H_B_shared @ R to be H_A_shared.
        # Therefore, we pass (H_B_shared, H_A_shared) to orthogonal_procrustes.
        R_np, _ = orthogonal_procrustes(H_B_shared.numpy(), H_A_shared.numpy())
        R = torch.tensor(R_np, dtype=torch.float32, device=DEVICE)
        print(f"  Orthogonal Procrustes matrix R computed. Shape: {R.shape}")

    # --- Step 4: Transport the probe ---
    print("\nStep 4: Transporting the probe")
    # W_tilde_4 = R @ W4 (R is feature_dim x feature_dim, W4 is feature_dim)
    # W4 needs to be a column vector for matrix multiplication, then squeeze back to 1D
    W_tilde_4 = (R @ W4.unsqueeze(1)).squeeze(1) # R is (dim, dim), W4 is (dim,), result (dim,)
    print(f"  Transported probe W_tilde_4 computed. Shape: {W_tilde_4.shape}")

    # --- Step 5: Locate the behaviour region in A ---
    print("\nStep 5: Locating the behaviour region in Model A's classifier layer")
    # Get Model A's classifier weights (fc3.weight)
    # This layer takes input from the second hidden layer (32 neurons) and outputs 10 logits.
    W_clf_A = model_A.fc3.weight.data.clone().detach() # Shape (10, 32)
    
    # Calculate cosine similarity for each row (output neuron's weight vector) of W_clf_A with W_tilde_4
    # Normalize W_clf_A rows (each output neuron's weight vector)
    W_clf_A_norm = torch.nn.functional.normalize(W_clf_A, p=2, dim=1) # Normalize along the feature dimension
    # Normalize W_tilde_4 (the transported probe)
    W_tilde_4_norm = torch.nn.functional.normalize(W_tilde_4, p=2, dim=0) # Normalize as a vector

    # Cosine similarity: (num_classes, feature_dim) @ (feature_dim, 1) -> (num_classes, 1)
    cosine_similarities = (W_clf_A_norm @ W_tilde_4_norm.unsqueeze(1)).squeeze(1) # Result shape (10,)
    
    # Choose the bottom-k rows (largest negative similarity, meaning most "opposite" direction)
    # Sort in ascending order (smallest values first), take the first k indices
    sorted_indices = torch.argsort(cosine_similarities, descending=False)
    selected_rows_indices = sorted_indices[:K_ROWS_TO_PRUNE]
    print(f"  Selected rows for surgical edit (indices): {selected_rows_indices.tolist()}")
    print(f"  Corresponding cosine similarities: {cosine_similarities[selected_rows_indices].tolist()}")

    # --- Step 6: Surgical edit ---
    print("\nStep 6: Performing surgical edit on Model A's classifier weights")
    # Create a copy of model_A's state_dict to modify
    modified_model_A_sd = model_A.state_dict()
    
    # Get the weight tensor for fc3.weight (the classifier layer)
    fc3_weight_tensor = modified_model_A_sd['fc3.weight'].clone().detach()

    for idx in selected_rows_indices:
        # Apply the edit: v_i <- v_i + alpha * W_tilde_4
        # W_tilde_4 is (feature_dim,), fc3_weight_tensor[idx] is also (feature_dim,)
        fc3_weight_tensor[idx] = fc3_weight_tensor[idx] + ALPHA * W_tilde_4.to(fc3_weight_tensor.device)
    
    # Update the state dict with the modified weight tensor
    modified_model_A_sd['fc3.weight'] = fc3_weight_tensor
    print(f"  Surgical edit applied to {K_ROWS_TO_PRUNE} row(s) of Model A's fc3.weight.")

    # --- Step 7: Add an output weight for class 4 ---
    print("\nStep 7: Copying output weight and bias for class 4 from Model B to Model A")
    # Get class 4 weights and bias from Model B's fc3 layer
    class4_weight_B = model_B.fc3.weight.data[TARGET_DIGIT_TO_TRANSFER].clone().detach() # Shape (feature_dim,)
    class4_bias_B = model_B.fc3.bias.data[TARGET_DIGIT_TO_TRANSFER].clone().detach() # Shape ()

    # Model A already has 10 output neurons. We overwrite the weights/bias
    # corresponding to the TARGET_DIGIT_TO_TRANSFER (which is 4, index 4).
    modified_model_A_sd['fc3.weight'][TARGET_DIGIT_TO_TRANSFER] = class4_weight_B.to(DEVICE)
    modified_model_A_sd['fc3.bias'][TARGET_DIGIT_TO_TRANSFER] = class4_bias_B.to(DEVICE)
    print(f"  Class {TARGET_DIGIT_TO_TRANSFER} weight and bias copied from Model B to Model A's output layer.")

    # Load the modified state_dict into a new Model A instance
    surgically_modified_model_A = SimpleNN().to(DEVICE)
    surgically_modified_model_A.load_state_dict(modified_model_A_sd)
    print("Surgically modified Model A created and loaded with new weights.")

    # --- Step 8: Sanity tests ---
    print("\nStep 8: Performing Sanity Tests on Surgically Modified Model A")

    # Evaluate on original A's digits (0,1,2,3)
    acc_A_orig_digits = evaluate_model_accuracy(surgically_modified_model_A, DataLoader(eval_set_A_labels_test_dataset, BATCH_SIZE))
    print(f"  Surgically Modified Model A Accuracy on original digits ({CLASS1_LABELS}): {acc_A_orig_digits:.2f}%")

    # Evaluate on target digit (4)
    acc_target_digit = evaluate_model_accuracy(surgically_modified_model_A, DataLoader(target_digit_test_dataset, BATCH_SIZE))
    print(f"  Surgically Modified Model A Accuracy on target digit ({TARGET_DIGIT_TO_TRANSFER}): {acc_target_digit:.2f}%")

    # Evaluate on out-of-class digit (5)
    acc_ooc_digit = evaluate_model_accuracy(surgically_modified_model_A, DataLoader(ooc_digit_test_dataset, BATCH_SIZE))
    print(f"  Surgically Modified Model A Accuracy on out-of-class digit ({OOC_DIGIT_TO_MONITOR}): {acc_ooc_digit:.2f}%")

    # For comparison, evaluate original Model A's performance
    print("\n--- Original Model A Performance (for comparison) ---")
    orig_A_acc_orig_digits = evaluate_model_accuracy(model_A, DataLoader(eval_set_A_labels_test_dataset, BATCH_SIZE))
    orig_A_acc_target_digit = evaluate_model_accuracy(model_A, DataLoader(target_digit_test_dataset, BATCH_SIZE))
    orig_A_acc_ooc_digit = evaluate_model_accuracy(model_A, DataLoader(ooc_digit_test_dataset, BATCH_SIZE))
    print(f"  Original Model A Accuracy on original digits ({CLASS1_LABELS}): {orig_A_acc_orig_digits:.2f}%")
    print(f"  Original Model A Accuracy on target digit ({TARGET_DIGIT_TO_TRANSFER}): {orig_A_acc_target_digit:.2f}%")
    print(f"  Original Model A Accuracy on out-of-class digit ({OOC_DIGIT_TO_MONITOR}): {orig_A_acc_ooc_digit:.2f}%")

    print("\nModel surgery pipeline complete!")

# Entry point: run the full surgery pipeline only when this code is executed
# as the main program, not when it is imported from another module.
if __name__ == "__main__":
    run_model_surgery_pipeline()


Using device: mps
Loading MNIST dataset...
MNIST dataset loaded.

--- Starting Model Surgery Pipeline ---
Loading weights from: ./trained_models/class1_models_weights.pt
Loading weights from: ./trained_models/class2_models_weights.pt
Loaded Model A (trained on 0,1,2,3).
Loaded Model B (trained on 2,3,4,5).

Step 2: Training a linear probe on Model B for 'digit 4 vs not-4'


Collecting probe training data from Model B's domain: 100%|██████████| 23352/23352 [00:00<00:00, 26822.04it/s]


  Training the probe network...
  Probe W4 extracted. Shape: torch.Size([32])

Step 3: Aligning hidden bases using Orthogonal Procrustes
  Orthogonal Procrustes matrix R computed. Shape: torch.Size([32, 32])

Step 4: Transporting the probe
  Transported probe W_tilde_4 computed. Shape: torch.Size([32])

Step 5: Locating the behaviour region in Model A's classifier layer
  Selected rows for surgical edit (indices): [5]
  Corresponding cosine similarities: [-0.18029682338237762]

Step 6: Performing surgical edit on Model A's classifier weights
  Surgical edit applied to 1 row(s) of Model A's fc3.weight.

Step 7: Copying output weight and bias for class 4 from Model B to Model A
  Class 4 weight and bias copied from Model B to Model A's output layer.
Surgically modified Model A created and loaded with new weights.

Step 8: Performing Sanity Tests on Surgically Modified Model A
  Surgically Modified Model A Accuracy on original digits ([0, 1, 2, 3]): 98.92%
  Surgically Modified Model A Ac

In [None]:
# Second attempt: instead of re-running the whole pipeline, adjust the
# surgical edit itself.
print("--- Improved Model Surgery Attempt ---")

# The function below first reports how well the alignment matrix maps Model B's
# features onto Model A's, then applies a more aggressive classifier edit.
def improved_model_surgery(*, alpha=1.5, k_rows=3, probe_mix=0.3, n_alignment_samples=100):
    """Apply a stronger classifier edit to a fresh copy of Model A and evaluate it.

    The function performs, in order:

    1. A quality check of the alignment matrix ``R``: the relative Frobenius
       error between ``H_A_shared`` and ``H_B_shared @ R`` on the first
       ``n_alignment_samples`` rows.
    2. A copy of ``model_A``'s weights into a new ``SimpleNN`` instance, so
       ``model_A`` itself is not modified.
    3. An additive edit ``row += alpha * W_tilde_4`` on the ``k_rows`` rows of
       ``fc3.weight`` whose cosine similarity with ``W_tilde_4`` is lowest.
    4. Replacement of the target-class row with Model B's row for that class
       plus ``probe_mix * W_tilde_4``, and a copy of Model B's bias for that
       class. If the target row was also selected in step 3, that edit is
       overwritten here.
    5. Evaluation on the three test subsets, with results printed.

    Depends on names defined elsewhere in the notebook: ``H_A_shared``,
    ``H_B_shared``, ``R``, ``W_tilde_4``, ``model_A``, ``model_B``,
    ``SimpleNN``, ``DEVICE``, ``TARGET_DIGIT_TO_TRANSFER``, ``BATCH_SIZE``,
    ``evaluate_model_accuracy`` and the three ``*_test_dataset`` objects.

    NOTE(review): the alignment check below uses the row-vector convention
    ``H_B @ R ~= H_A``. Under that convention a probe direction ``w`` from
    Model B's space maps to ``R.T @ w`` -- confirm that ``W_tilde_4`` was
    built consistently with this.

    Args:
        alpha: Scale of the additive edit applied to the selected rows.
        k_rows: Number of lowest-similarity rows to edit; must be >= 0.
        probe_mix: Weight of ``W_tilde_4`` added to Model B's target-class row.
        n_alignment_samples: Number of leading shared samples used for the
            alignment check.

    Returns:
        The edited ``SimpleNN`` instance.

    Raises:
        ValueError: If ``k_rows`` is negative (a negative slice bound would
            silently select almost every row).
    """
    if k_rows < 0:
        raise ValueError(f"k_rows must be non-negative, got {k_rows}")

    # --- Alignment quality check ---
    print("Checking alignment quality...")
    sample_features_A = H_A_shared[:n_alignment_samples]
    sample_features_B = H_B_shared[:n_alignment_samples]

    # Move R to wherever the features live instead of assuming they are on CPU.
    R_local = R.to(device=sample_features_B.device, dtype=sample_features_B.dtype)
    aligned_B_features = sample_features_B @ R_local
    reconstruction_error = torch.norm(sample_features_A - aligned_B_features) / torch.norm(sample_features_A)
    print(f"Reconstruction error after alignment: {reconstruction_error:.4f}")

    print("\nTrying improved surgical approaches...")

    # Work on a fresh copy so the original Model A stays intact for comparison.
    improved_model_A = SimpleNN().to(DEVICE)
    improved_model_A.load_state_dict(model_A.state_dict())

    target_class = TARGET_DIGIT_TO_TRANSFER

    with torch.no_grad():
        classifier_weights = improved_model_A.fc3.weight.data
        classifier_bias = improved_model_A.fc3.bias.data

        # Align the probe with the classifier's device/dtype once, up front.
        probe = W_tilde_4.to(device=classifier_weights.device, dtype=classifier_weights.dtype)

        # Rows pointing most strongly *against* the probe come first.
        cosine_sims = torch.nn.functional.cosine_similarity(
            classifier_weights, probe.unsqueeze(0), dim=1
        )
        bottom_k_indices = torch.argsort(cosine_sims)[:k_rows]

        print(f"Modifying rows: {bottom_k_indices.tolist()}")
        print(f"Their cosine similarities: {cosine_sims[bottom_k_indices].tolist()}")

        # Push each selected row towards the probe direction.
        for idx in bottom_k_indices:
            classifier_weights[idx] += alpha * probe

        # Target-class row: Model B's row for that class, nudged along the probe.
        donor_row = model_B.fc3.weight.data[target_class].to(
            device=classifier_weights.device, dtype=classifier_weights.dtype
        )
        classifier_weights[target_class] = donor_row + probe_mix * probe
        classifier_bias[target_class] = model_B.fc3.bias.data[target_class].to(
            device=classifier_bias.device, dtype=classifier_bias.dtype
        )

    # --- Evaluation ---
    print("\nTesting improved surgical model...")
    acc_orig = evaluate_model_accuracy(improved_model_A, DataLoader(eval_set_A_labels_test_dataset, BATCH_SIZE))
    acc_target = evaluate_model_accuracy(improved_model_A, DataLoader(target_digit_test_dataset, BATCH_SIZE))
    acc_ooc = evaluate_model_accuracy(improved_model_A, DataLoader(ooc_digit_test_dataset, BATCH_SIZE))

    print(f"Improved Model - Original digits accuracy: {acc_orig:.2f}%")
    print(f"Improved Model - Target digit {target_class} accuracy: {acc_target:.2f}%")
    print(f"Improved Model - OOC digit 5 accuracy: {acc_ooc:.2f}%")

    return improved_model_A

# Run the improved surgery and keep the edited model for later inspection.
improved_model = improved_model_surgery()

In [None]:
# Debug: Let's check the probe training and model predictions more carefully
import matplotlib.pyplot as plt

def debug_probe_performance():
    """Report how well the trained probe separates the target digit on its own training set.

    Rebuilds the binary "target digit vs. everything else" labels from
    ``probe_train_dataset_B``, extracts hidden features with ``model_B``,
    scores them with ``probe_net`` and prints the overall accuracy (threshold
    0.5) together with the mean predicted probability for each of the two
    groups.

    Returns:
        A tuple ``(features, binary_labels, probabilities)``: the hidden
        features produced by ``model_B``, the float 0/1 labels, and the
        sigmoid outputs of the probe.
    """
    print("--- Debugging Probe Performance ---")

    # Materialise the dataset once, then split it into images and 0/1 labels.
    samples = list(probe_train_dataset_B)
    image_batch = torch.stack([image for image, _ in samples])
    binary_targets = torch.tensor(
        [1 if digit == TARGET_DIGIT_TO_TRANSFER else 0 for _, digit in samples],
        dtype=torch.float32,
    )

    # Hidden features from Model B, computed without tracking gradients.
    model_B.eval()
    with torch.no_grad():
        hidden = model_B.get_hidden_features(image_batch.to(DEVICE))

    # Probe scores -> probabilities -> hard 0/1 decisions.
    probe_net.eval()
    with torch.no_grad():
        positive_prob = torch.sigmoid(probe_net(hidden).squeeze())
        hard_decisions = (positive_prob > 0.5).float()

    # Labels live on the CPU, so compare there.
    accuracy = (hard_decisions.cpu() == binary_targets).float().mean()
    print(f"Probe accuracy on training data: {accuracy:.3f}")

    # Labels are strictly 0 or 1, so the two masks are complements.
    is_target = binary_targets == 1
    is_other = ~is_target

    print(f"Probe predictions for digit 4 (should be high): {positive_prob[is_target].mean():.3f}")
    print(f"Probe predictions for non-4 digits (should be low): {positive_prob[is_other].mean():.3f}")

    return hidden, binary_targets, positive_prob

def debug_model_predictions(num_examples=5):
    """Print per-sample predictions of the edited Model A next to the original Model A.

    Takes the first ``num_examples`` samples of ``target_digit_test_dataset``
    (no shuffling) and prints, for the surgically modified model, the target
    labels, predicted labels, maximum softmax probability and the probability
    assigned to ``TARGET_DIGIT_TO_TRANSFER``; then prints the original
    ``model_A``'s predicted labels for the same samples.

    Both models are switched to eval mode first so the comparison is made
    under the same conditions; note that this leaves ``model_A`` in eval mode.

    Depends on names defined elsewhere in the notebook:
    ``surgically_modified_model_A``, ``model_A``, ``target_digit_test_dataset``,
    ``TARGET_DIGIT_TO_TRANSFER`` and ``DEVICE``.

    Args:
        num_examples: How many leading samples to inspect; must be >= 1.

    Raises:
        ValueError: If ``num_examples`` is smaller than 1.
    """
    if num_examples < 1:
        raise ValueError(f"num_examples must be at least 1, got {num_examples}")

    print("\n--- Debugging Model Predictions ---")

    # Unshuffled loader: the first batch is exactly the first `num_examples` samples.
    digit_4_loader = DataLoader(target_digit_test_dataset, batch_size=num_examples, shuffle=False)

    first_batch = next(iter(digit_4_loader), None)
    if first_batch is None:
        # Make an empty dataset visible instead of silently printing nothing.
        print("No samples available for the target digit; nothing to inspect.")
        return

    surgically_modified_model_A.eval()
    # The reference model must be in eval mode too, otherwise any train-time
    # layers would make the side-by-side comparison inconsistent.
    model_A.eval()

    target_class = TARGET_DIGIT_TO_TRANSFER

    with torch.no_grad():
        data, target = first_batch
        data, target = data.to(DEVICE), target.to(DEVICE)

        # Predictions of the edited model.
        logits = surgically_modified_model_A(data)
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(logits, dim=1)

        print(f"Target labels: {target.cpu()}")
        print(f"Predicted labels: {preds.cpu()}")
        print(f"Max probabilities: {probs.max(dim=1)[0].cpu()}")
        print(f"Prob for digit {target_class}: {probs[:, target_class].cpu()}")

        # Reference: what the unedited Model A predicts on the same inputs.
        original_logits = model_A(data)
        original_preds = torch.argmax(original_logits, dim=1)
        print(f"Original Model A predictions: {original_preds.cpu()}")

# Run both diagnostics: first check the probe itself, then inspect what the
# edited model predicts on target-digit samples. The probe outputs are kept
# at module level for further inspection.
probe_features, probe_labels, probe_probs = debug_probe_performance()
debug_model_predictions()