In [None]:
from brian2 import *
from pathlib import Path
import numpy as np
import imageio.v3 as iio
import matplotlib.pyplot as plt
import gc

# Shared config and IO helpers (used by training + validation cells)
# The position of a name in CLASS_NAMES is the integer label assigned to that
# class by list_images_with_labels.
CLASS_NAMES = ("circle", "square", "triangle")
# NOTE(review): absolute, machine-specific paths. TRAIN_ROOT and VAL_ROOT also
# sit under different parent folders ("archive/training/shapes" vs
# "archive/shapes") — confirm this is intended.
TRAIN_ROOT = Path("/Users/duaanaveed/Downloads/archive/training/shapes/train")
VAL_ROOT = Path("/Users/duaanaveed/Downloads/archive/shapes/valid")

def normalise01(img):
    """Convert an image to a 2-D float array with values in [0, 1].

    Parameters
    ----------
    img : array-like
        Grayscale (H, W) or multi-channel (H, W, C) image.

    Returns
    -------
    numpy.ndarray
        Float array. Multi-channel input is collapsed by averaging over the
        last axis (every channel is included in the mean).

    Notes
    -----
    The scale factor is chosen from the *source dtype*, not from the pixel
    values: unsigned-integer images are divided by the maximum value their
    dtype can hold (255 for uint8, 65535 for uint16). This keeps the scaling
    identical for every image of a dataset; a value-based test would leave a
    nearly black uint8 image (max pixel <= 1) unscaled and would map 16-bit
    images far outside [0, 1].

    For all other dtypes the legacy heuristic is kept: data whose maximum
    exceeds 1.0 is assumed to be on a 0-255 scale.
    """
    img = np.asarray(img)
    src_dtype = img.dtype  # remember the dtype before converting to float
    img = img.astype(float, copy=False)
    if img.ndim == 3:
        img = img.mean(axis=-1)
    if src_dtype.kind == "u":
        return img / float(np.iinfo(src_dtype).max)
    # `img.size` guard: max() of an empty array raises
    if img.size and img.max() > 1.0:
        img = img / 255.0
    return img

def list_images_with_labels(root: Path, class_names):
    """Collect the PNG files of every class folder together with their labels.

    For each name in ``class_names`` the folder ``root/<name>`` is searched
    for ``*.png`` files (sorted by path). The label of a file is the index of
    its class name in ``class_names``.

    Returns
    -------
    (list[Path], numpy.ndarray)
        File paths and an integer label array of the same length.
    """
    paths = []
    labels = []
    for cls in class_names:
        found = sorted((root / cls).glob("*.png"))
        paths.extend(found)
        labels.extend(class_names.index(cls) for _ in found)
    return paths, np.array(labels, int)

def load_images(root: Path, class_names):
    """Load and normalise every class image found under ``root``.

    Files that cannot be read or normalised are reported and skipped. The
    returned labels and paths are filtered in the same way, so entry ``n`` of
    all three return values always refers to the same file.

    Parameters
    ----------
    root : Path
        Folder containing one sub-folder per class.
    class_names : sequence of str
        Class folder names; the index of a name is its label.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray, list[Path])
        Stacked images, integer labels and file paths of the images that were
        loaded successfully.

    Raises
    ------
    ValueError
        If no image could be loaded (also raised by ``np.stack`` when the
        loaded images do not share one shape).
    """
    paths, labels = list_images_with_labels(root, class_names)
    imgs, kept = [], []
    for idx, p in enumerate(paths):
        try:
            imgs.append(normalise01(iio.imread(p)))
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort the run
            print(f"Skipped {p.name}: {e}")
        else:
            kept.append(idx)
    if not imgs:
        raise ValueError(f"No readable images found under {root}")
    # Drop the labels/paths of skipped files so they stay aligned with imgs
    return np.stack(imgs), labels[kept], [paths[i] for i in kept]

# Setup output directory (created relative to the current working directory)
directory = "TESTSTDP_medium_epoh_colour_test"
Path(directory).mkdir(parents=True, exist_ok=True)

# Load training images (shared loader)
imgs, labels, train_paths = load_images(TRAIN_ROOT, CLASS_NAMES)
print(f"Loaded {len(imgs)} training images of shape {imgs[0].shape}")

# Image geometry; one input neuron is created per pixel further below
H, W = imgs[0].shape
Npix = H * W

# Simulation parameters
defaultclock.dt = 1*ms
Ton = 300*ms      # Stimulus presentation duration
Tisi = 50*ms      # Inter-stimulus interval
epochs = 5
max_rate = 10*Hz  # Maximum input firing rate
max_rate_scalar = float(max_rate/Hz)  # same value without Brian2 units, for plain numpy code

# Precompute Poisson rates for all training images (keeps Brian2 units)
# Rate is inverted w.r.t. intensity: a pixel value of 0 maps to max_rate, 1 maps to 0 Hz
train_rates = (1.0 - imgs.reshape(len(imgs), -1)) * max_rate

def img_to_rates_np(img):
    """Map an image to a flat vector of firing rates as plain floats.

    The intensity is inverted (``1 - pixel``) and scaled by the module-level
    ``max_rate_scalar``; the result carries no Brian2 units.
    """
    flat = img.reshape(-1)
    inverted = 1.0 - flat
    return inverted * max_rate_scalar

# Input layer: Poisson neurons encoding pixel intensities
# Rates start at 0 Hz and are overwritten for every stimulus in the training loop
IN = PoissonGroup(Npix, rates=0*Hz)

# Output layer parameters
tau_m = 20*ms     # Membrane time constant
v_rest = 0*mV     # Resting potential
v_thr = 4*mV      # Spike threshold
v_reset = 0*mV    # Reset potential
tau_e = 5*ms      # Excitatory current decay
tau_i = 10*ms     # Inhibitory current decay

# Output layer: 3 LIF neurons with teacher signal
# v      : membrane potential, driven by ge (excitation), gi (inhibition) and Iteach
# ge, gi : exponentially decaying synaptic variables, expressed in volt
# Iteach : per-neuron teacher input set from Python before each presentation
# s      : per-neuron spike counter, incremented in the reset statement
eqs_out = '''
dv/dt  = (-(v - v_rest) + ge - gi + Iteach) / tau_m : volt
dge/dt = -ge/tau_e : volt
dgi/dt = -gi/tau_i : volt
Iteach : volt
s : integer
'''

OUT = NeuronGroup(
    N=3, model=eqs_out,
    threshold='v>v_thr',
    reset='v=v_reset; s += 1',
    refractory=6*ms,
    method='euler'
)

# Initial state of the output neurons
OUT.v = v_rest
if hasattr(OUT, 'ge'): OUT.ge = 0
if hasattr(OUT, 'gi'): OUT.gi = 0

# BDNF profiles for modulating plasticity.
# Each profile holds multipliers for the base learning rate (eta_scale), the
# pre-trace increment (ltp_scale) and the post-trace increment (ltd_scale),
# plus the number of epochs between weight normalisations (norm_every).
BDNF_PROFILES = dict(
    low=dict(eta_scale=0.5, ltp_scale=0.8, ltd_scale=1.1, norm_every=1),
    medium=dict(eta_scale=1.0, ltp_scale=1.0, ltd_scale=1.0, norm_every=2),
    high=dict(eta_scale=1.5, ltp_scale=1.2, ltd_scale=0.9, norm_every=3),
    high_homeostatic=dict(eta_scale=1.3, ltp_scale=1.1, ltd_scale=0.95, norm_every=3),
)

# Profile used for this run
BDNF_LEVEL = "medium"

# Base STDP parameters (before profile scaling)
eta_base, Apre_base, Apost_base = 0.02, 0.015, -0.016

# Apply BDNF profile scaling
_bd = BDNF_PROFILES[BDNF_LEVEL]
norm_every = _bd["norm_every"]
eta = eta_base * _bd["eta_scale"]
Apre = Apre_base * _bd["ltp_scale"]
Apost = Apost_base * _bd["ltd_scale"]

print(f"[BDNF={BDNF_LEVEL}] eta={eta:.4f}, Apre={Apre:.4f}, Apost={Apost:.4f}, norm_every={norm_every}")

# STDP parameters
taupre = 10*ms    # decay time constant of the presynaptic trace apre
taupost = 10*ms   # decay time constant of the postsynaptic trace apost
wmax = 0.05       # upper bound enforced by clip() in on_pre / on_post
wscale = 5*mV     # converts the dimensionless weight w into the ge increment

# Plasticity switch (1.0 during training, 0.0 during testing)
# Referenced by name inside the synapse code strings below
plasticity = 1.0

# Synapses with STDP
# on_pre : deliver w*wscale to the target's ge, bump apre, update w from apost
# on_post: bump apost, update w from apre
# Both updates are multiplied by `plasticity` and clipped to [0, wmax]
S = Synapses(IN, OUT,
    model='''
        w : 1
        dapre/dt  = -apre/taupre  : 1 (event-driven)
        dapost/dt = -apost/taupost : 1 (event-driven)
    ''',
    on_pre='''
        ge_post += w * wscale
        apre += Apre
        w = clip(w + plasticity*eta*apost, 0, wmax)
    ''',
    on_post='''
        apost += Apost
        w = clip(w + plasticity*eta*apre, 0, wmax)
    ''',
    method='euler'
)

# Full connectivity with random initialisation
S.connect(True)
S.w = '0.002*rand()'
S.apre = 0
S.apost = 0

# Lateral inhibition between output neurons
# Every output spike adds ginh to gi of each *other* output neuron (i != j), 1 ms later
ginh = 0.5*mV
Sinhib = Synapses(OUT, OUT, on_pre='gi_post += ginh')
Sinhib.connect(condition='i != j')
Sinhib.delay = 1*ms

# Spike monitors
# NOTE(review): neither monitor is read in the visible cells, yet both record
# for the whole training run — confirm they are needed.
spike_inp = SpikeMonitor(IN)
spike_out = SpikeMonitor(OUT)

# Training loop
# Each image is shown for Ton with the teacher current on the neuron of its
# class, followed by a silent interval of Tisi with all state reset.
net = Network(collect())
teacher_amp = 6*mV

# Running totals over all epochs
correct = 0
seen = 0

for ep in range(epochs):
    # New random presentation order for every epoch
    order = np.random.permutation(len(imgs))
    for k in order:
        label_scalar = int(labels[k])

        # Set input firing rates (precomputed)
        IN.rates = train_rates[k]

        # Apply teacher signal to correct output neuron
        OUT.Iteach = 0*volt
        OUT.Iteach[label_scalar] = teacher_amp
        OUT.s = 0  # restart the per-presentation spike counters

        # Present stimulus
        net.run(Ton)

        # Make prediction based on spike counts
        # The teacher current is active during this window, so the counts
        # include teacher-driven spikes of the labelled neuron.
        counts = OUT.s[:]
        pred = int(np.argmax(counts))
        is_ok = (pred == label_scalar)
        correct += int(is_ok)
        seen += 1
        print(f"img {k:02d} label={labels[k]} counts={counts} → pred={pred} {'✓' if is_ok else '✗'}")

        # Inter-stimulus interval: reset activity
        IN.rates = 0*Hz
        OUT.Iteach[:] = 0*mV
        OUT.v = v_rest
        OUT.ge = 0*mV
        OUT.gi = 0*mV
        S.apre = 0
        S.apost = 0
        net.run(Tisi)

    # Epoch summary
    print(f"[epoch {ep+1}] running accuracy: {correct}/{seen} = {100*correct/seen:.1f}%")

    # Periodic weight normalisation
    # Incoming weights of each output neuron are scaled by 1/sqrt(their sum).
    # NOTE(review): the scaled weights are not clipped to wmax here — confirm
    # that exceeding wmax until the next STDP update is acceptable.
    if (ep + 1) % norm_every == 0:
        j = np.array(S.j[:])
        for out_idx in np.unique(j):
            idx = np.where(j == out_idx)[0]
            total = float(np.sum(S.w[idx]))
            if total > 0:
                S.w[idx] *= (1.0 / total)**0.5
        print(f"[BDNF={BDNF_LEVEL}] soft-normalised incoming weights")

# Disable plasticity after training
plasticity = 0.0

# Build weight matrix for visualisation
# Rows are presynaptic (pixel) indices, columns are output neurons
pre_idx = np.asarray(S.i[:])
post_idx = np.asarray(S.j[:])
n_post = len(np.unique(post_idx))
n_pre = len(np.unique(pre_idx))
print(f"n_pre = {n_pre}, n_post = {n_post}, total weights = {len(S.w[:])}")

# Single vectorised assignment instead of a Python loop over every synapse
Wmat = np.zeros((n_pre, n_post))
Wmat[pre_idx, post_idx] = np.asarray(S.w[:])

# Save receptive field visualisations
# One figure per output neuron: its incoming weights reshaped to the image grid
for k in range(n_post):
    fig, ax = plt.subplots()
    im = ax.imshow(Wmat[:, k].reshape(H, W))
    ax.set_title(f'OUT{k} receptive field')
    fig.colorbar(im, ax=ax)
    fig.tight_layout()

    out_png = Path(directory) / f"rf_OUT{k}.png"
    fig.savefig(out_png, dpi=150, bbox_inches="tight")
    print(f"saved: {out_png}")

    plt.show()
    plt.close(fig)  # release the figure so they do not accumulate in memory

# Save trained weights
def save_receptive_fields(S, H, W, path="trained_receptive_fields_medium.npz", bdnf_level=None):
    """Save trained synaptic weights to disk as a compressed ``.npz`` archive.

    Parameters
    ----------
    S : object
        Synapse container exposing ``S.i[:]``, ``S.j[:]`` (pre/post indices)
        and ``S.w[:]`` (weights) as array-likes of equal length.
    H, W : int
        Image height and width; the dense matrix has ``H * W`` rows.
    path : str or Path
        Output file.
    bdnf_level : str, optional
        Profile name stored alongside the weights. Defaults to the
        module-level ``BDNF_LEVEL``.

    Archive contents
    ----------------
    ``W`` (float32, shape ``(H*W, n_post)``), ``H``, ``WIMG``, ``n_pre``,
    ``n_post`` (int32 scalars) and ``bdnf_level`` (unicode scalar).

    The profile name is stored as a unicode array rather than an object array:
    an object array is pickled inside the archive and cannot be read back with
    the default ``np.load(..., allow_pickle=False)``.
    """
    if bdnf_level is None:
        bdnf_level = BDNF_LEVEL

    pre = np.asarray(S.i[:])
    post = np.asarray(S.j[:])

    n_pre = H * W
    # Size the matrix by the largest target index so that an output neuron
    # without any synapse cannot push indices out of range
    n_post = int(post.max()) + 1 if post.size else 0

    Wdense = np.zeros((n_pre, n_post), dtype=np.float32)
    Wdense[pre, post] = np.asarray(S.w[:], dtype=np.float32)

    np.savez_compressed(
        path,
        W=Wdense,
        H=np.int32(H),
        WIMG=np.int32(W),
        n_pre=np.int32(n_pre),
        n_post=np.int32(n_post),
        bdnf_level=np.array(str(bdnf_level)),
    )
    print(f"[saved] {path}  shape={Wdense.shape}")

# Persist the trained weights; the validation cell loads this file by name
save_receptive_fields(S, H, W, "trained_receptive_fields_medium.npz")

# Cleanup
# Drop the training images and paths (not needed after this point) and run a
# garbage-collection pass
del imgs, train_paths
gc.collect()

In [None]:
from pathlib import Path
import numpy as np
import imageio.v3 as iio

# Configuration (reuse shared helpers from first cell)
# normalise01, list_images_with_labels and max_rate_scalar are defined in the
# first cell, which therefore has to be run before this one.
CLASS_NAMES = ("circle", "square", "triangle")
VAL_ROOT = Path("/Users/duaanaveed/Downloads/archive/shapes/valid")
use_norm = True

# Load trained receptive fields
# The context manager closes the .npz archive once the arrays are read
with np.load("trained_receptive_fields_medium.npz") as bundle:
    weights = bundle["W"]
    H = int(bundle["H"])
    W = int(bundle["WIMG"])

# Load validation dataset with shared loader
paths, y_true = list_images_with_labels(VAL_ROOT, list(CLASS_NAMES))
X = []
for p in paths:
    im = normalise01(iio.imread(p))
    # Every image must match the geometry the weights were trained on
    if im.shape != (H, W):
        raise ValueError(f"{p.name} has {im.shape}, expected {(H, W)}")
    X.append(im)
X = np.stack(X)
print(f"Loaded {len(X)} validation images of shape {X[0].shape}")

# Precompute rate-encoded validation vectors (unitless for numpy ops)
val_rates = (1.0 - X.reshape(len(X), -1)) * max_rate_scalar

# Normalise receptive field weights if enabled
# Each column (one output neuron) is scaled to unit L2 norm; 1e-12 avoids 0/0
Wuse = weights / (np.linalg.norm(weights, axis=0, keepdims=True) + 1e-12) if use_norm else weights

# Compute projection scores and predictions
scores = val_rates @ Wuse
y_pred = scores.argmax(1)
# Margin = best score minus second-best score of each image
margins = scores.max(1) - np.partition(scores, -2, axis=1)[:, -2]

# Evaluation metrics
acc = (y_pred == y_true).mean()
cm = np.zeros((len(CLASS_NAMES), len(CLASS_NAMES)), int)
# Unbuffered add so repeated (true, pred) pairs are all counted
np.add.at(cm, (y_true, y_pred), 1)

print(f"Template accuracy (use_norm={use_norm}): {100*acc:.1f}%")
print("Confusion Matrix (rows=true, cols=pred):")
print(cm)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def contribution_map(idx, class_idx):
    """Return the per-pixel terms of one image's score for one class.

    Multiplies the rate vector of validation image ``idx`` element-wise with
    the weight column of ``class_idx`` and reshapes the product to ``(H, W)``.
    """
    per_pixel = val_rates[idx] * Wuse[:, class_idx]
    return per_pixel.reshape(H, W)

def cosine_scores(idx):
    """Cosine similarity between validation image ``idx`` and every weight column.

    Both the image's rate vector and each column of ``Wuse`` are divided by
    their L2 norm (plus 1e-12 to avoid division by zero) before the dot
    product is taken.
    """
    eps = 1e-12
    rates = val_rates[idx]
    unit_rates = rates / (np.linalg.norm(rates) + eps)
    column_norms = np.linalg.norm(Wuse, axis=0, keepdims=True)
    unit_columns = Wuse / (column_norms + eps)
    return unit_rates @ unit_columns

# Visualise single example
k = 40  # Index to examine different images
im = X[k]
true_lbl = CLASS_NAMES[y_true[k]]
pred_lbl = CLASS_NAMES[y_pred[k]]
cos = cosine_scores(k)

# One panel for the image plus one per class (a fixed count of 5 left an
# empty panel for the three classes)
fig, axes = plt.subplots(1, 1 + len(CLASS_NAMES), figsize=(15, 3))
axes[0].imshow(im, cmap='gray')
axes[0].set_title(f"Image\ntrue={true_lbl}\npred={pred_lbl}")

# Contribution map of the image to each class score
for c in range(len(CLASS_NAMES)):
    axes[1+c].imshow(contribution_map(k, c), cmap='magma')
    axes[1+c].set_title(f"{CLASS_NAMES[c]}\ncos={cos[c]:.3f}\nscore={scores[k,c]:.1f}")

for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Rank validation images by decision margin, smallest first
idx_sorted = np.argsort(margins)
# First six = lowest margins (hardest), last six = highest margins (easiest)
hard, easy = idx_sorted[:6], idx_sorted[-6:]

def show_panel(idxs, title):
    """Display a panel of images with their contribution maps.

    One row is drawn per index in ``idxs``: the validation image followed by
    its contribution map for every class.

    Parameters
    ----------
    idxs : sequence of int
        Indices into the validation arrays (``X``, ``y_true``, ``y_pred``,
        ``margins``, ``paths``).
    title : str
        Text shown on the first line of each row's image title.

    An empty ``idxs`` prints a short message instead of failing inside
    ``plt.subplots``.
    """
    n = len(idxs)
    if n == 0:
        print(f"{title}: no images to display")
        return

    # squeeze=False always yields a 2-D axes array, also for a single row
    fig, axes = plt.subplots(n, 1+len(CLASS_NAMES), figsize=(12, 2.5*n), squeeze=False)

    for row, k in enumerate(idxs):
        im = X[k]
        t = y_true[k]
        p = y_pred[k]
        cos = cosine_scores(k)

        # Show original image
        axes[row, 0].imshow(im, cmap='gray')
        axes[row, 0].set_title(
            f"{title}\n{paths[k].name}\n"
            f"true={CLASS_NAMES[t]}  pred={CLASS_NAMES[p]}\n"
            f"margin={margins[k]:.2f}"
        )
        axes[row, 0].axis('off')

        # Show contribution maps for each class
        for c in range(len(CLASS_NAMES)):
            axes[row, 1+c].imshow(contribution_map(k, c), cmap='magma')
            axes[row, 1+c].set_title(
                f"{CLASS_NAMES[c]}\n"
                f"cos={cos[c]:.2f}\n"
                f"score={scores[k,c]:.1f}"
            )
            axes[row, 1+c].axis('off')

    plt.tight_layout()
    plt.show()

# Render the lowest-margin and highest-margin validation images selected above
show_panel(hard, "Hardest")
show_panel(easy, "Easiest")

In [None]:
import matplotlib.pyplot as plt

def avg_contrib_for_subset(mask, class_idx):
    """Compute the average contribution heatmap for a subset of images.

    Parameters
    ----------
    mask : boolean array
        Selects the validation images to average over.
    class_idx : int
        Class whose contribution maps are averaged.

    Returns
    -------
    numpy.ndarray
        ``(H, W)`` mean contribution map, or zeros when ``mask`` selects
        nothing.
    """
    idxs = np.where(mask)[0]
    # Check for an empty selection first: np.stack raises on an empty list,
    # so a fallback placed after it can never be reached
    if len(idxs) == 0:
        return np.zeros((H, W))
    C = np.stack([contribution_map(i, class_idx) for i in idxs], axis=0)
    return C.mean(axis=0)

# One figure per non-empty (true class, predicted class) combination
for t_idx, t_name in enumerate(CLASS_NAMES):
    for p_idx, p_name in enumerate(CLASS_NAMES):
        mask = (y_true == t_idx) & (y_pred == p_idx)
        n_hits = mask.sum()
        if n_hits == 0:
            continue

        n_cols = len(CLASS_NAMES) + 1
        fig, axes = plt.subplots(1, n_cols, figsize=(12, 3))

        # Left panel: mean of the images that fall into this combination
        mean_img = np.mean(X[mask], axis=0)
        axes[0].imshow(mean_img, cmap='gray')
        axes[0].set_title(f"Avg IMG\ntrue={t_name}, pred={p_name}\nN={n_hits}")
        axes[0].axis('off')

        # Remaining panels: mean contribution map towards each class
        for c, ax in enumerate(axes[1:]):
            ax.imshow(avg_contrib_for_subset(mask, c), cmap='magma')
            ax.set_title(f"Avg contrib → {CLASS_NAMES[c]}")
            ax.axis('off')

        plt.tight_layout()
        plt.show()

In [None]:
import matplotlib.pyplot as plt

# Per class: histogram of the score its own weight column gives its true members
for c, name in enumerate(CLASS_NAMES):
    own_scores = scores[y_true == c, c]
    plt.figure(figsize=(5, 3))
    plt.hist(own_scores, bins=20, alpha=0.8)
    plt.title(f"Score distribution for true {name} → class {name}")
    plt.xlabel("Raw score")
    plt.ylabel("Count")
    plt.show()

# Single figure overlaying the margin histogram of every true class
plt.figure(figsize=(5, 3))
for c, name in enumerate(CLASS_NAMES):
    class_margins = margins[y_true == c]
    plt.hist(class_margins, bins=20, alpha=0.5, label=name)
plt.title("Top-1 margin by true class")
plt.xlabel("margin (top - runner-up)")
plt.legend()
plt.show()

In [None]:
import numpy as np

# Install pandas if needed
!pip install pandas -q
import pandas as pd

# Get top-2 predictions for each image
# A descending sort per row gives the best class in column 0 and the
# runner-up in column 1; this works for any number of classes >= 2.
top2 = np.argsort(-scores, axis=1)[:, :2]
top2_scores = np.take_along_axis(scores, top2, axis=1)

# Build detailed results table
rows = []
for i, ((a, b), (s1, s2)) in enumerate(zip(top2, top2_scores)):
    rows.append({
        "file": paths[i].name,
        "true": CLASS_NAMES[y_true[i]],
        "pred": CLASS_NAMES[y_pred[i]],
        "runner_up": CLASS_NAMES[b],
        "pred_score": s1,
        "runner_score": s2,
        "margin": s1 - s2
    })

df = pd.DataFrame(rows).sort_values("margin")

# Show most and least confident predictions
display(df.head(10))  # Lowest margins (least confident)
display(df.tail(10))  # Highest margins (most confident)