# Imports and Setup

In [None]:
import os
import random
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import display, Image

# Make the project's ``src`` package importable.
# The imports below use the package form (``from src import ...``), which
# requires the directory that *contains* ``src`` -- the project root -- to be
# on sys.path; adding only ``../src`` is not sufficient for that form.
# ``../src`` itself is still added, in case modules inside it import each
# other by bare name (not visible from here).
# Assumes the notebook runs from a direct subdirectory of the project root
# (e.g. ``notebooks/``) -- TODO confirm if the notebook is moved.
PROJECT_ROOT = os.path.abspath('..')
SRC_DIR = os.path.join(PROJECT_ROOT, 'src')
for _extra_path in (PROJECT_ROOT, SRC_DIR):
    if _extra_path not in sys.path:  # avoid duplicates when the cell is re-run
        sys.path.append(_extra_path)

from src import config
from src.train_phase1 import run_phase1
from src.phase1_utils import compute_and_save_importance, create_consensus_mask, get_param_names
from src.visualization import plot_matrix
from src.train_phase2 import run_phase2_lodo
from src.models import FeedForwardHead

# --- Reproducibility ---
# Seed the Python, NumPy and PyTorch (CPU and, if present, all CUDA devices)
# random number generators from the single project-wide seed.
random.seed(config.SEED)
np.random.seed(config.SEED)
torch.manual_seed(config.SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(config.SEED)

print("Setup complete. Using config:")
print(f"  - Device: {config.DEVICE}")
print(f"  - Domains: {config.DOMAINS}")
print(f"  - Phase 1 Domains for Importance: {config.PH1_DOMAINS}")

# Run Phase 1 Training

In [None]:
# Phase 1 training -- slow. The plotting cell further down loads a shared
# "initial" checkpoint plus "final" weights for every domain in
# config.PH1_DOMAINS, so presumably one model is trained per entry of that
# list (TODO confirm in src/train_phase1.py); the model count is not fixed.
run_phase1()

# Compute Importance and Create Consensus Masks

In [None]:
def _compute_domain_importances(domains, names):
    """Call compute_and_save_importance for every (domain, parameter) pair.

    Domains form the outer loop, so all parameters of one domain are handled
    before moving on to the next; a progress line is printed per domain.
    """
    for dom in domains:
        print(f"\nComputing importance for domain: {dom}")
        for pname in names:
            compute_and_save_importance(dom, pname)


def _build_consensus_masks(names, domains):
    """Call create_consensus_mask once per parameter name over ``domains``."""
    print("\nCreating consensus masks...")
    for pname in names:
        create_consensus_mask(pname, domains=domains)


# A throwaway head is built here only so its parameter names can be listed.
ffn_temp = FeedForwardHead(config.EMBEDDING_DIM)
param_names = get_param_names(ffn_temp)

# Step 1: per-domain importance; step 2: consensus masks across those domains.
_compute_domain_importances(config.PH1_DOMAINS, param_names)
_build_consensus_masks(param_names, config.PH1_DOMAINS)

print("\nImportance and consensus masks generated successfully.")

# Visualize Weights and Importance Masks

In [None]:
def visualize_all(param_name):
    """Save every Phase 1 figure for one model parameter.

    Loads ``.npy`` arrays for ``param_name`` and renders each one with
    ``plot_matrix`` into a ``.png`` under ``config.PLOTS_DIR``, in this order:

    1. the shared initial weights (``<PHASE1_WEIGHTS_DIR>/initial``);
    2. the final weights of every domain in ``config.PH1_DOMAINS``
       (``<PHASE1_WEIGHTS_DIR>/final/domain_<d>``);
    3. each domain's ``_norm`` and ``_binary`` importance masks
       (``<PHASE1_IMPORTANCE_DIR>/domain_<d>``), drawn with vmin=0, vmax=1;
    4. the ``_norm`` and ``_binary`` consensus masks
       (``<PHASE1_IMPORTANCE_DIR>/consensus``), also with vmin=0, vmax=1.

    Binary masks are drawn with the ``'gray_r'`` colormap. Dots in
    ``param_name`` are replaced by underscores in every file name.

    Args:
        param_name: Parameter name, e.g. ``'layers.0.weight'``.
    """
    stem = param_name.replace('.', '_')
    unit_range = {'vmin': 0, 'vmax': 1}
    binary_look = {'cmap': 'gray_r', **unit_range}

    def save_plot(matrix, subdir, suffix, title, **style):
        # Every figure is written to <PLOTS_DIR>/<subdir>/<stem><suffix>.png
        out_png = os.path.join(config.PLOTS_DIR, subdir, f"{stem}{suffix}.png")
        plot_matrix(matrix, out_png, title, **style)

    def load_weights(*subdirs):
        return np.load(os.path.join(config.PHASE1_WEIGHTS_DIR, *subdirs, f"{stem}.npy"))

    def load_mask(group, variant):
        return np.load(os.path.join(config.PHASE1_IMPORTANCE_DIR, group, f"{stem}_{variant}.npy"))

    # 1. Shared initial weights.
    save_plot(load_weights("initial"), "initial", "", f"Initial Shared - {param_name}")

    # 2. Final weights, one figure per domain.
    for dom in config.PH1_DOMAINS:
        save_plot(load_weights("final", f"domain_{dom}"), f"final_{dom}", "",
                  f"Final {dom} - {param_name}")

    # 3. Per-domain importance masks; both are loaded before either is plotted.
    for dom in config.PH1_DOMAINS:
        norm_mask = load_mask(f"domain_{dom}", "norm")
        binary_mask = load_mask(f"domain_{dom}", "binary")
        save_plot(norm_mask, f"importance_{dom}", "_norm",
                  f"Importance Norm {dom} - {param_name}", **unit_range)
        save_plot(binary_mask, f"importance_{dom}", "_binary",
                  f"Importance Binary {dom} - {param_name}", **binary_look)

    # 4. Consensus masks.
    consensus_norm = load_mask("consensus", "norm")
    consensus_binary = load_mask("consensus", "binary")
    save_plot(consensus_norm, "consensus", "_norm",
              f"Consensus Norm - {param_name}", **unit_range)
    save_plot(consensus_binary, "consensus", "_binary",
              f"Consensus Binary - {param_name}", **binary_look)

# Example: render every figure for 'layers.0.weight' (presumably the first
# hidden layer's weight matrix of FeedForwardHead -- confirm in src/models).
example_param = 'layers.0.weight'
visualize_all(example_param)

# Show the saved consensus binary mask inline for a quick visual check.
print(f"Displaying Consensus Binary Mask for '{example_param}':")
example_png = os.path.join(
    config.PLOTS_DIR,
    "consensus",
    f"{example_param.replace('.', '_')}_binary.png",
)
display(Image(filename=example_png))

# Run Phase 2 Leave-One-Domain-Out (LODO) Training

In [None]:
# Leave-one-domain-out (LODO): each domain in config.DOMAINS takes one turn as
# the held-out test domain while all remaining domains are used for training.
# lodo_results maps held-out domain -> value returned by run_phase2_lodo
# (its test accuracy). It is filled incrementally, so results gathered before
# a failure are still available in the notebook.
lodo_results = {}

for held_out in config.DOMAINS:
    source_domains = [dom for dom in config.DOMAINS if dom != held_out]
    lodo_results[held_out] = run_phase2_lodo(source_domains, held_out)

# Final Results Summary

In [None]:
print("--- LODO Final Results ---")

# Snapshot keys and values as plain lists so the table, the mean and the
# chart all share one ordering, and matplotlib receives ordinary sequences
# instead of dict views.
result_domains = list(lodo_results.keys())
result_accs = list(lodo_results.values())

for domain, acc in zip(result_domains, result_accs):
    print(f"  - Accuracy on held-out '{domain}': {acc:.4f}")

if result_accs:
    avg_acc = sum(result_accs) / len(result_accs)
    print(f"\nAverage LODO Accuracy: {avg_acc:.4f}")

    # Bars sit at explicit integer positions (labelled with the domain names)
    # so the value labels below line up with their bars whatever type the
    # domain keys are.
    positions = list(range(len(result_domains)))
    plt.figure(figsize=(8, 5))
    plt.bar(positions, result_accs)
    plt.xticks(positions, [str(d) for d in result_domains])
    plt.ylabel("Accuracy")
    plt.title("Leave-One-Domain-Out Test Accuracy")
    # Headroom above 1.0 so labels drawn at acc + 0.01 stay inside the axes.
    plt.ylim(0, 1.1)
    for pos, acc in zip(positions, result_accs):
        plt.text(pos, acc + 0.01, f"{acc:.2%}", ha='center')
    plt.show()
else:
    # Avoid dividing by zero when no LODO run has produced results.
    print("\nNo LODO results to summarize; run the Phase 2 LODO cell first.")