In [None]:
# Execute each per-condition notebook in the current kernel before the analysis cell.
# %run executes a notebook in this namespace, so names it defines become visible here
# (and may be overwritten by later cells).
# NOTE(review): presumably each notebook trains one BDNF condition and writes one of the
# trained_receptive_fields_*.npz bundles loaded in the next cell. Confirm the file names match.
%run "testing_LOW.ipynb"
%run "testing_MED.ipynb"
%run "testing_HOMEO.ipynb"
%run "testing_HIGH.ipynb"

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import imageio.v3 as iio

# Configuration
import os  # local import: only needed for the optional dataset-path override below

CLASS_NAMES = ("circle", "square", "triangle")

# Defaults to the original author's local dataset location. Set the
# SHAPES_VAL_ROOT environment variable to use another copy without editing.
VAL_ROOT = Path(
    os.environ.get(
        "SHAPES_VAL_ROOT",
        "/Users/duaanaveed/Downloads/archive/shapes/valid",
    )
).expanduser()

# BDNF conditions in presentation order (NOT sorted by performance).
# CONDITION_LABELS is index-aligned with BDNF_CONDITIONS.
BDNF_CONDITIONS = ["low", "medium", "high_homeostatic", "high"]
CONDITION_LABELS = ["Low", "Medium", "High Homeostatic", "High"]

# Condition -> saved receptive-field bundle (adjust if saved weights have
# different names). Relative paths resolve against the working directory.
WEIGHT_FILES = {
    "low": "trained_receptive_fields_low.npz",
    "medium": "trained_receptive_fields_medium.npz",
    "high_homeostatic": "trained_receptive_fields_homeostatic.npz",
    "high": "trained_receptive_fields_high.npz",
}

# Fail fast if the parallel tables drift apart: later code indexes
# CONDITION_LABELS by position and looks up WEIGHT_FILES per condition.
if len(CONDITION_LABELS) != len(BDNF_CONDITIONS):
    raise ValueError(
        f"CONDITION_LABELS has {len(CONDITION_LABELS)} entries but "
        f"BDNF_CONDITIONS has {len(BDNF_CONDITIONS)}"
    )
_unmapped_conditions = [c for c in BDNF_CONDITIONS if c not in WEIGHT_FILES]
if _unmapped_conditions:
    raise ValueError(f"No WEIGHT_FILES entry for conditions: {_unmapped_conditions}")

def normalize01(img):
    """Convert an image to floats in the [0, 1] range (supports grayscale/RGB).

    Multi-channel images are collapsed to grayscale by averaging channels.
    A trailing alpha channel (grayscale+alpha or RGBA) is dropped first, so
    transparency does not bias the intensity.

    Scaling:
      * uint8 / uint16 input is divided by its dtype maximum (255 / 65535).
        Every 8-bit image is therefore scaled, however dark, and 16-bit PNGs
        also land in [0, 1].
      * Any other dtype (float, signed int, bool) keeps the heuristic: divide
        by 255 only when the maximum exceeds 1.0.

    Args:
        img: array-like of shape (H, W) or (H, W, C).

    Returns:
        A new float ndarray. It has shape (H, W) for 3-D input; otherwise it
        has the input's shape.
    """
    arr = np.asarray(img)
    src_dtype = arr.dtype  # remember the storage type before the float cast

    # Drop alpha for LA (2-channel) and RGBA (4-channel) images.
    if arr.ndim == 3 and arr.shape[-1] in (2, 4):
        arr = arr[..., :-1]

    out = arr.astype(float)
    if out.ndim == 3:
        out = out.mean(axis=-1)

    if src_dtype in (np.uint8, np.uint16):
        return out / np.iinfo(src_dtype).max

    # Guard on size: .max() raises on an empty array.
    if out.size and out.max() > 1.0:
        out = out / 255.0
    return out

def list_images_with_labels(root: Path, class_names, pattern: str = "*.png"):
    """List image paths and integer labels.

    Each class is read from the sub-directory ``root / <class name>``. Its
    label is the class's position in ``class_names``. Files within a class
    are returned in sorted order. A missing class folder contributes no files.

    Args:
        root: dataset root directory (``str`` is also accepted).
        class_names: ordered class names; their order defines the labels.
        pattern: glob pattern for image files inside each class folder.

    Returns:
        ``(paths, labels)``: a list of ``Path`` objects and a parallel
        integer ndarray of labels.
    """
    root = Path(root)
    paths, labels = [], []
    # enumerate gives each class its own label directly. It avoids an
    # O(n) .index() lookup per file and stays correct if names repeat.
    for label, cls in enumerate(class_names):
        for f in sorted((root / cls).glob(pattern)):
            paths.append(f)
            labels.append(label)
    return paths, np.array(labels, int)

# Load the validation dataset once; every BDNF condition below is scored on it.
paths, y_true = list_images_with_labels(VAL_ROOT, CLASS_NAMES)
if not paths:
    # Without this check, np.stack([]) fails with an opaque
    # "need at least one array to stack" error.
    raise FileNotFoundError(
        f"No validation images found under {VAL_ROOT} for classes {CLASS_NAMES}"
    )

# np.stack requires every image to share the same (H, W) shape.
X = np.stack([normalize01(iio.imread(p)) for p in paths])
print(f"Loaded {len(X)} validation images")

max_rate_scalar = 10.0  # Matching training configuration

# Encode validation images to rates once. The encoding does not depend on the
# condition, so it does not need recomputing for each weight file. Intensities
# are inverted (1 - pixel) before scaling, so darker pixels get higher rates.
# NOTE(review): presumably this mirrors the training notebooks' encoder; confirm.
val_rates = (1.0 - X.reshape(len(X), -1)) * max_rate_scalar

# Collect margins for each condition
all_margins = {cond: [] for cond in BDNF_CONDITIONS}

for condition in BDNF_CONDITIONS:
    weight_file = WEIGHT_FILES.get(condition)
    if weight_file is None:
        print(f"Warning: no weight file configured for {condition}, skipping")
        continue
    if not Path(weight_file).exists():
        print(f"Warning: {weight_file} not found, skipping {condition}")
        continue

    # Load trained weights. The NpzFile keeps an open file handle, so close it
    # promptly; indexing it returns in-memory arrays that outlive the context.
    with np.load(weight_file) as bundle:
        weights = bundle["W"]
        H = int(bundle["H"])
        W = int(bundle["WIMG"])

    # Validate image dimensions (np.stack guarantees all images share X's shape)
    if X.shape[1:] != (H, W):
        print(
            f"Warning: Image shape mismatch for {condition} "
            f"(images {X.shape[1:]}, weights expect {(H, W)})"
        )
        continue

    # val_rates @ weights needs one row per pixel. A runner-up score also
    # needs at least 2 receptive fields (columns), or np.partition fails.
    if weights.ndim != 2 or weights.shape[0] != H * W or weights.shape[1] < 2:
        print(f"Warning: unexpected weight shape {weights.shape} for {condition}, skipping")
        continue

    # Scale each receptive field (column) to unit L2 norm; epsilon avoids /0
    Wuse = weights / (np.linalg.norm(weights, axis=0, keepdims=True) + 1e-12)

    # Projection score of every image onto every receptive field: (n_images, n_fields)
    scores = val_rates @ Wuse

    # Top-1 margin: best score minus runner-up score per image (always >= 0)
    margins = scores.max(axis=1) - np.partition(scores, -2, axis=1)[:, -2]

    all_margins[condition] = margins
    print(f"{condition}: {len(margins)} margins computed, mean={margins.mean():.3f}")

# Prepare data for violin plot: keep only conditions that produced margins.
# Data and label are paired in one pass so the two lists stay aligned.
plot_data, plot_labels = [], []
for cond, label in zip(BDNF_CONDITIONS, CONDITION_LABELS):
    if len(all_margins[cond]) > 0:
        plot_data.append(all_margins[cond])
        plot_labels.append(label)

if not plot_data:
    # Both ax.violinplot([]) and max([]) below would raise on empty input.
    print("No margins were computed for any condition; skipping violin plot.")
else:
    fig, ax = plt.subplots(figsize=(10, 6))

    # One violin per condition, with mean and median markers
    parts = ax.violinplot(plot_data, positions=range(len(plot_data)),
                          showmeans=True, showmedians=True, widths=0.7)

    # Style the violin bodies uniformly
    for pc in parts['bodies']:
        pc.set_facecolor('#8888ff')
        pc.set_alpha(0.7)
        pc.set_edgecolor('black')
        pc.set_linewidth(1.5)

    # Customize mean/median markers
    parts['cmeans'].set_color('red')
    parts['cmeans'].set_linewidth(2)
    parts['cmedians'].set_color('blue')
    parts['cmedians'].set_linewidth(2)

    # Set axis properties
    ax.set_xticks(range(len(plot_labels)))
    ax.set_xticklabels(plot_labels, fontsize=12)
    ax.set_xlabel('BDNF Condition', fontsize=14, fontweight='bold')
    ax.set_ylabel('Top-1 Margin', fontsize=14, fontweight='bold')
    ax.set_title('Top-1 Margin Distribution Across BDNF Conditions', fontsize=16, fontweight='bold')

    # Margins are max-minus-runner-up scores, hence >= 0. Anchor the shared
    # y-axis at 0 and leave 10% headroom above the largest margin.
    y_min = 0
    y_max = max(float(m.max()) for m in plot_data) * 1.1
    if y_max <= y_min:
        # Every margin is 0 (tied scores): avoid a degenerate (0, 0) y-range.
        y_max = 1.0
    ax.set_ylim(y_min, y_max)

    # Add grid for readability, drawn behind the violins
    ax.grid(axis='y', alpha=0.3, linestyle='--')
    ax.set_axisbelow(True)

    # Proxy artists for the mean/median legend entries
    from matplotlib.lines import Line2D
    legend_elements = [
        Line2D([0], [0], color='red', linewidth=2, label='Mean'),
        Line2D([0], [0], color='blue', linewidth=2, label='Median'),
    ]
    ax.legend(handles=legend_elements, loc='upper right')

    plt.tight_layout()
    plt.savefig('bdnf_margin_violin_plot.png', dpi=300, bbox_inches='tight')
    plt.show()

# Print per-condition margin statistics, skipping conditions that were not
# evaluated (their entry in all_margins is still the initial empty list).
print("\nSummary Statistics:")
for idx, cond in enumerate(BDNF_CONDITIONS):
    cond_margins = all_margins[cond]
    if not len(cond_margins):
        continue
    label = CONDITION_LABELS[idx]
    mean_val = cond_margins.mean()
    median_val = np.median(cond_margins)
    std_val = cond_margins.std()
    print(f"{label:20s} - Mean: {mean_val:.4f}, Median: {median_val:.4f}, Std: {std_val:.4f}")