In [11]:
import numpy as np
import tensorflow as tf
import time
import pandas as pd
from tqdm import trange
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping
import seaborn as sns
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist, fashion_mnist, cifar10


In [12]:
# ---------------------------
# Experiment configuration (module-level defaults).
# NOTE(review): run_experiment() takes parameters named seed, encoder_epochs,
# batch_size_encoder, pretrain_encoder, feature_dim, Nhid, theta_h_base and T;
# presumably these values are forwarded to it by a calling cell not shown here.
# ---------------------------
seed = 42                 # RNG seed
encoder_epochs = 100      # epochs for the encoder pretraining fit
batch_size_encoder = 128  # batch size for the encoder pretraining fit
pretrain_encoder = True   # set False to skip
feature_dim = 256         # encoder output size
Nhid = 300                # hidden spiking neurons
Nout = 2                 # placeholder output-neuron count; run_experiment() overwrites Nout with the number of distinct labels in y_train
T = 25                    # timesteps for spikes

# Spiking neuron parameters
# (run_experiment() re-assigns the globals lam = 0.9 and theta_o = 0.5 itself.)
lam = 0.9              # leak (membrane decay)
theta_h_base = 0.5     # hidden threshold base
theta_o = 0.5          # output threshold
eta_out = 5e-4         # LR for W2 (supervised Hebbian)
eta_in = 2e-4          # LR for W1 and, in the 2SADP architecture, W1_2 (κ-based SADP)
decay = 0.9995         # mild decay (multiplicative) per update
norm_eps = 1e-6        # added to the column norms when W1 / W1_2 are renormalised (avoids division by zero)
clip_w2 = 5.0          # W2 entries are clipped to [-clip_w2, clip_w2] after each update


In [13]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import numpy as np
import os

def load_dataset(root='Data', img_size=(28, 28), batch_size=64):
    """
    Read an image-folder dataset fully into memory as NumPy arrays.

    `root` is scanned with Keras' `flow_from_directory`, so each subfolder of
    `root` is treated as one class. Images are resized to `img_size`, loaded
    as RGB, scaled by 1/255 and split with `validation_split=0.2`: the
    'training' subset becomes the train set and the 'validation' subset
    becomes the test set. Files are read in directory order (no shuffling).

    Parameters
    ----------
    root : str
        Directory containing one subfolder per class.
    img_size : tuple of int
        Target (height, width) for every image.
    batch_size : int
        Number of images pulled from the directory iterator per step.

    Returns
    -------
    ((x_train, y_train), (x_test, y_test), input_shape)
        `x_*` are float32 arrays of shape (N, *img_size, 3), `y_*` are int32
        label vectors, and `input_shape` is `x_train.shape[1:]`.
    """
    image_source = ImageDataGenerator(rescale=1./255, validation_split=0.2)

    def materialize(subset_name):
        """Copy every batch of one subset ('training' or 'validation') into dense arrays."""
        flow = image_source.flow_from_directory(
            root,
            target_size=img_size,
            color_mode='rgb',
            batch_size=batch_size,
            class_mode='sparse',
            subset=subset_name,
            shuffle=False,
        )

        total = flow.samples
        images = np.zeros((total, *img_size, 3), dtype=np.float32)
        labels = np.zeros((total,), dtype=np.int32)

        filled = 0
        for batch_images, batch_labels in flow:
            stop = filled + batch_images.shape[0]
            images[filled:stop] = batch_images
            labels[filled:stop] = batch_labels
            filled = stop
            # The loop is ended explicitly once every sample has been copied,
            # rather than relying on the iterator to stop on its own.
            if filled >= total:
                break
        return images, labels

    train_pair = materialize('training')
    test_pair = materialize('validation')
    x_train, y_train = train_pair
    x_test, y_test = test_pair

    input_shape = x_train.shape[1:]
    n_classes = len(np.unique(y_train))

    print(f"✅ Loaded LUNG_IMAGE_SETS dataset successfully")
    print(f"Input shape: {input_shape}")
    print(f"Train: {x_train.shape}, Test: {x_test.shape}")
    print(f"Classes found: {n_classes}")

    return (x_train, y_train), (x_test, y_test), input_shape


In [4]:
# ---------------------------
# Poisson Encoder
# ---------------------------
def poisson_encode_features(batch_feats, T):
    """
    Convert per-feature firing probabilities into a random binary spike train.

    For every sample, timestep and feature an independent uniform number is
    drawn from NumPy's global RNG; a spike (1.0) is emitted when the draw is
    strictly below the corresponding feature value.

    Parameters
    ----------
    batch_feats : 2-D array, shape (B, Nin)
        Feature values used directly as per-timestep spike probabilities.
    T : int
        Number of timesteps to generate.

    Returns
    -------
    float32 array of shape (B, T, Nin) containing 0.0 / 1.0.
    """
    n_samples, n_features = batch_feats.shape
    uniform_draws = np.random.rand(n_samples, T, n_features).astype(np.float32)
    # Insert a time axis so the same probabilities are reused at every timestep.
    fires = np.less(uniform_draws, batch_feats[:, np.newaxis, :])
    return fires.astype(np.float32)

In [5]:
from tensorflow.keras import layers, models, regularizers

def build_encoder(input_shape=(28, 28, 3), output_dim=256, dropout_rate=0.4, l2_reg=1e-4):
    """
    Build a small convolutional encoder that maps an image to a feature vector.

    Architecture (in order):
      * two stages of [Conv3x3 -> BN -> Conv3x3 -> BN -> MaxPool(2) -> Dropout]
        with 64 then 128 filters and dropout 0.25 then 0.3,
      * one Conv3x3(256) -> BN followed by global average pooling,
      * Dense(512, relu) -> BN -> Dropout(`dropout_rate`) -> Dense(`output_dim`, relu).

    All convolutions and the 512-unit dense layer carry an L2 kernel penalty
    of `l2_reg`; the final projection does not.

    Parameters
    ----------
    input_shape : tuple
        Shape of a single input image (H, W, C).
    output_dim : int
        Size of the returned feature vector.
    dropout_rate : float
        Dropout applied in the dense head only.
    l2_reg : float
        L2 regularisation factor.

    Returns
    -------
    A Keras `Model` named 'encoder_small_retina' (input image -> features).
    """

    def conv_bn(tensor, n_filters):
        """3x3 'same' ReLU convolution with L2 penalty, followed by batch normalisation."""
        tensor = layers.Conv2D(n_filters, 3, padding='same', activation='relu',
                               kernel_regularizer=regularizers.l2(l2_reg))(tensor)
        return layers.BatchNormalization()(tensor)

    image_in = layers.Input(shape=input_shape)

    # Two down-sampling stages: (filters, dropout after pooling)
    h = image_in
    for n_filters, stage_dropout in ((64, 0.25), (128, 0.3)):
        h = conv_bn(h, n_filters)
        h = conv_bn(h, n_filters)
        h = layers.MaxPooling2D(2)(h)
        h = layers.Dropout(stage_dropout)(h)

    # Final convolution, then collapse the spatial dimensions
    h = conv_bn(h, 256)
    h = layers.GlobalAveragePooling2D()(h)

    # Dense projection head
    h = layers.Dense(512, activation='relu', kernel_regularizer=regularizers.l2(l2_reg))(h)
    h = layers.BatchNormalization()(h)
    h = layers.Dropout(dropout_rate)(h)
    features = layers.Dense(output_dim, activation='relu')(h)

    return models.Model(image_in, features, name='encoder_small_retina')


In [6]:
# ---------------------------
# LIF forward pass (1SADP / 2SADP)
def forward_lif_features(x_batch_feats, y_targets=None, teacher_force=False,
                         architecture='1SADP', T_local=25):
    """
    Simulate the leaky integrate-and-fire network for `T_local` timesteps.

    The input features are Poisson-encoded into spikes, then propagated
    through one hidden layer ('1SADP') or two hidden layers ('2SADP') and the
    output layer. Every layer follows the same rule per timestep: the membrane
    potential is multiplied by `lam`, the weighted input spikes are added,
    neurons whose potential exceeds their threshold emit a spike and are reset
    to 0.

    Reads the module-level globals W1, W1_2, W2, theta_h, theta_h2,
    Nhid_global, Nout, lam and theta_o. Any `architecture` value other than
    '2SADP' is treated as the single-hidden-layer network.

    Parameters
    ----------
    x_batch_feats : 2-D array, shape (B, Nin)
        Feature values passed to `poisson_encode_features`.
    y_targets, teacher_force :
        Accepted but not used by this function.
    architecture : str
        '1SADP' or '2SADP'.
    T_local : int
        Number of simulation timesteps.

    Returns
    -------
    (spikes_h, spikes_o) for '1SADP', or (spikes_h, spikes_h2, spikes_o) for
    '2SADP'; each is a float32 array of shape (B, T_local, layer_size).
    """
    global W1, W1_2, W2, theta_h, theta_h2, Nin, Nout, Nhid_global, lam, theta_o

    two_hidden = (architecture == '2SADP')
    n_batch = x_batch_feats.shape[0]

    def zeros(*shape):
        """float32 zero array of the given shape."""
        return np.zeros(shape, dtype=np.float32)

    # Membrane potentials
    v_hid = zeros(n_batch, Nhid_global)
    v_hid2 = zeros(n_batch, Nhid_global) if two_hidden else None
    v_out = zeros(n_batch, Nout)

    # Spike records, one slice per timestep
    spikes_h = zeros(n_batch, T_local, Nhid_global)
    spikes_h2 = zeros(n_batch, T_local, Nhid_global) if two_hidden else None
    spikes_o = zeros(n_batch, T_local, Nout)

    # Stochastic input spike train, shape (B, T_local, Nin)
    in_spikes = poisson_encode_features(x_batch_feats, T_local)

    for step in range(T_local):
        # ----- hidden layer 1: leak, integrate, fire, reset -----
        v_hid = lam * v_hid + np.einsum('bi,ij->bj', in_spikes[:, step], W1)
        fired_h = v_hid > theta_h
        out_h = fired_h.astype(np.float32)
        v_hid[fired_h] = 0.0
        spikes_h[:, step, :] = out_h

        # ----- optional hidden layer 2 -----
        presyn = out_h
        if two_hidden:
            v_hid2 = lam * v_hid2 + np.einsum('bi,ij->bj', out_h, W1_2)
            fired_h2 = v_hid2 > theta_h2
            out_h2 = fired_h2.astype(np.float32)
            v_hid2[fired_h2] = 0.0
            spikes_h2[:, step, :] = out_h2
            presyn = out_h2

        # ----- output layer, driven by the last hidden layer -----
        v_out = lam * v_out + np.einsum('bi,ij->bj', presyn, W2)
        fired_o = v_out > theta_o
        v_out[fired_o] = 0.0
        spikes_o[:, step, :] = fired_o.astype(np.float32)

    if two_hidden:
        return spikes_h, spikes_h2, spikes_o
    return spikes_h, spikes_o



# ---------------------------
# Weight update per batch
# ---------------------------
# def update_weights_batch(x_batch_feats, y_batch, spikes_h, spikes_o, architecture='1SADP'):
#     global W1, W1_2, W2

#     B = x_batch_feats.shape[0]
#     targets = np.zeros((B, Nout), dtype=np.float32)
#     targets[np.arange(B), y_batch] = 1.0

#     # output spike counts
#     out_counts = spikes_o.sum(axis=1)
#     preds = np.argmax(out_counts, axis=1)

#     # Errors
#     errors = targets[:, None, :] - spikes_o
#     dW2 = np.einsum('bti,btj->ij', spikes_h if architecture=='1SADP' else spikes_h, errors) / float(B)
#     W2 += eta_out * dW2

#     # Kappa-based SADP for W1
#     batch_idx = np.arange(B)[:, None]
#     time_idx = np.arange(spikes_o.shape[1])[None, :]
#     class_idx = y_batch[:, None]
#     target_spikes = spikes_o[batch_idx, time_idx, class_idx].astype(np.float32)

#     agree = np.mean(spikes_h == target_spikes[..., None], axis=1).astype(np.float32)
#     pa = np.mean(spikes_h, axis=1).astype(np.float32)
#     pb = np.mean(target_spikes, axis=1).astype(np.float32)[:, None]
#     pe = pa * pb + (1 - pa) * (1 - pb)
#     kappa_vals = (agree - pe) / (1.0 - pe + 1e-9)

#     dW1 = np.einsum('bi,bj->ij', x_batch_feats, kappa_vals) / float(B)
#     W1 += eta_in * dW1

#     # decay & normalize
#     W1 *= decay
#     W2 *= decay
#     W1 /= (np.linalg.norm(W1, axis=0, keepdims=True) + norm_eps)
#     np.clip(W2, -clip_w2, clip_w2, out=W2)

#     return preds, float(np.mean(kappa_vals))



def update_weights_batch(x_batch_feats, y_batch,
                         spikes_h, spikes_h2, spikes_o,
                         architecture='1SADP'):
    """
    Apply one batch of local learning-rule updates to the global weights.

    Steps, in order:
      1. W2 (last hidden -> output): error-driven update using
         (one-hot target - output spike) at every timestep.
      2. W1 (input -> hidden 1): each hidden neuron gets a chance-corrected
         agreement score kappa = (observed - expected) / (1 - expected)
         between its spike train and the spike train of the correct-class
         output neuron; W1 moves along x_batch_feats^T @ kappa.
      3. W1_2 (hidden 1 -> hidden 2, '2SADP' only): same kappa rule, with the
         mean hidden-1 firing rate as presynaptic term; then decayed and
         column-normalised.
      4. W1 and W2 are multiplied by `decay`; W1 columns are normalised to
         unit L2 norm and W2 is clipped to [-clip_w2, clip_w2].

    Reads the globals Nout, eta_out, eta_in, decay, norm_eps and clip_w2, and
    modifies the globals W1, W1_2 and W2 in place.

    Parameters
    ----------
    x_batch_feats : array (B, Nin)
    y_batch : integer array (B,) of class indices
    spikes_h : array (B, T, Nhid) - hidden-1 spikes
    spikes_h2 : array (B, T, Nhid) or None - hidden-2 spikes ('2SADP' only)
    spikes_o : array (B, T, Nout) - output spikes
    architecture : '1SADP' or '2SADP'

    Returns
    -------
    (preds, mean_kappa) : predicted class per sample (argmax of output spike
    counts) and the mean hidden-1 kappa as a Python float.
    """
    global W1, W1_2, W2

    n_batch, _, _ = spikes_o.shape
    rows = np.arange(n_batch)
    two_hidden = (architecture == '2SADP')

    # One-hot encoding of the labels
    targets = np.zeros((n_batch, Nout), dtype=np.float32)
    targets[rows, y_batch] = 1.0

    # Predicted class = output neuron with the most spikes
    preds = np.argmax(spikes_o.sum(axis=1), axis=1)

    # ---- output layer (supervised) ----
    errors = targets[:, None, :] - spikes_o
    presyn_out = spikes_h2 if two_hidden else spikes_h
    W2 += eta_out * (np.einsum('bti,btj->ij', presyn_out, errors) / n_batch)

    # Spike train of the correct-class output neuron, shape (B, T)
    target_spikes = spikes_o[rows, :, y_batch].astype(np.float32)
    target_rate = np.mean(target_spikes, axis=1)[:, None]

    def agreement_kappa(layer_spikes):
        """Per-neuron chance-corrected agreement with the target spike train."""
        observed = np.mean(layer_spikes == target_spikes[:, :, None], axis=1)
        layer_rate = np.mean(layer_spikes, axis=1)
        expected = layer_rate * target_rate + (1.0 - layer_rate) * (1.0 - target_rate)
        return (observed - expected) / (1.0 - expected + 1e-9)

    # ---- input -> hidden 1 ----
    kappa1 = agreement_kappa(spikes_h)
    W1 += eta_in * (np.einsum('bi,bj->ij', x_batch_feats, kappa1) / n_batch)

    # ---- hidden 1 -> hidden 2 ----
    if two_hidden:
        kappa2 = agreement_kappa(spikes_h2)
        hidden1_rate = np.mean(spikes_h, axis=1)
        W1_2 += eta_in * (np.einsum('bi,bj->ij', hidden1_rate, kappa2) / n_batch)
        W1_2 *= decay
        W1_2 /= (np.linalg.norm(W1_2, axis=0, keepdims=True) + norm_eps)

    # ---- decay, then normalise / clip ----
    W1 *= decay
    W2 *= decay
    W1 /= (np.linalg.norm(W1, axis=0, keepdims=True) + norm_eps)
    np.clip(W2, -clip_w2, clip_w2, out=W2)

    return preds, float(np.mean(kappa1))




In [7]:
def train_snn(n_epochs=25, batch_size=128, n_samples=None):
    """
    Train the single-hidden-layer ('1SADP') spiking network on the cached features.

    Each epoch draws a fresh random permutation of the training set, keeps the
    first `n_train` indices, and processes them in full batches of
    `batch_size`; a trailing partial batch is dropped. Every batch runs
    `forward_lif_features` (with its default number of timesteps) followed by
    `update_weights_batch`, which modifies the global weights.

    Reads the globals `train_feats_norm` and `y_train`.

    Parameters
    ----------
    n_epochs : int
        Number of passes over the (sub)sampled training set.
    batch_size : int
        Samples per weight update.
    n_samples : int or None
        If given, use at most this many training samples per epoch.

    Returns
    -------
    (None, n_epochs_run, epoch_accuracies)
        `epoch_accuracies[i]` is the fraction of correctly classified samples
        among those actually processed in epoch i (0.0 if no full batch fits).
    """
    global W1, W2   # ensure weight updates are applied

    n_train = len(train_feats_norm) if n_samples is None else min(n_samples, len(train_feats_norm))
    n_epochs_run = 0
    epoch_accuracies = []

    for epoch in range(n_epochs):
        n_epochs_run += 1
        t0 = time.time()

        idx = np.random.permutation(len(train_feats_norm))[:n_train]
        num_batches = n_train // batch_size
        # Only full batches are processed, so this is the number of samples
        # that actually contribute to `correct` below.
        n_seen = num_batches * batch_size

        correct = 0
        kappa_log = []

        for bi in trange(num_batches, desc=f"SNN Epoch {epoch+1}/{n_epochs}"):
            s = bi * batch_size
            e = s + batch_size
            batch_idx = idx[s:e]

            Xb = train_feats_norm[batch_idx]
            yb = y_train[batch_idx]

            # forward
            sh, so = forward_lif_features(Xb, teacher_force=False)

            preds, kappa_val = update_weights_batch(
                Xb, yb,
                spikes_h=sh,
                spikes_h2=None,
                spikes_o=so,
                architecture='1SADP'
            )

            correct += np.sum(preds == yb)
            kappa_log.append(kappa_val)

        epoch_time = time.time() - t0
        # Divide by the samples that were evaluated, not by n_train: the
        # dropped tail batch would otherwise bias the accuracy downwards.
        acc = correct / float(n_seen) if n_seen else 0.0
        avg_k = float(np.mean(kappa_log)) if kappa_log else 0.0
        epoch_accuracies.append(acc)

        print(
            f"Epoch {epoch+1}/{n_epochs} | "
            f"Train acc = {acc:.4f} | "
            f"avg κ = {avg_k:.6f} | "
            f"Time = {epoch_time:.2f}s"
        )

    return None, n_epochs_run, epoch_accuracies


In [8]:
# -----------------------------------------------------
# Evaluate SNN Function
# -----------------------------------------------------
def evaluate_snn(n_samples=None, batch_size=256):
    """
    Measure classification accuracy of the current network on the test features.

    A random permutation of the test set is drawn and its first `n_test`
    entries are evaluated in batches (the last batch may be smaller). For each
    batch the network is simulated with `forward_lif_features` and the
    predicted class is the output neuron with the highest spike count.

    Reads the globals `test_feats_norm` and `y_test`; no weights are changed.

    Parameters
    ----------
    n_samples : int or None
        If given, evaluate at most this many test samples.
    batch_size : int
        Samples per forward pass.

    Returns
    -------
    Accuracy (correct / n_test); the value is also printed.
    """
    total_available = len(test_feats_norm)
    n_test = total_available if n_samples is None else min(n_samples, total_available)
    chosen = np.random.permutation(total_available)[:n_test]
    num_batches = int(np.ceil(n_test / batch_size))

    hits_per_batch = []
    for batch_no in trange(num_batches, desc="SNN Testing"):
        lo = batch_no * batch_size
        hi = min(lo + batch_size, n_test)
        members = chosen[lo:hi]

        # Only the output spikes are needed for the prediction.
        _, out_spikes = forward_lif_features(test_feats_norm[members], teacher_force=False)
        predicted = np.argmax(out_spikes.sum(axis=1), axis=1)
        hits_per_batch.append(np.sum(predicted == y_test[members]))

    correct = sum(hits_per_batch)
    acc = correct / float(n_test)
    print(f"SNN Test Acc = {acc:.4f}")
    return acc

In [9]:
def run_experiment(dataset_name=None, feature_dim=128, encoder_epochs=10,
                   batch_size_encoder=128, pretrain_encoder=True,
                   Nhid=256, theta_h_base=0.5, seed=42,
                   architecture='1SADP', T=25,
                   encoding_type='cnn+poisson'):
    """
    encoding_type: 'poisson_only' or 'cnn+poisson'
    architecture: '1SADP' or '2SADP'
    T: temporal dimension (timesteps)
    """

    global x_train, y_train, x_test, y_test
    global train_feats_norm, test_feats_norm
    global W1, W1_2, W2, theta_h, theta_h2, Nin, Nout, Nhid_global, lam, theta_o

    print("\n" + "="*60)
    print(f"Running experiment | {architecture} | T={T} | Encoding: {encoding_type}")
    print("="*60)

    tf.random.set_seed(seed)
    np.random.seed(seed)

    # ---------------------------
    # Use loaded dataset if available
    # ---------------------------
    if 'x_train' not in globals() or 'x_test' not in globals():
        if dataset_name is None:
            raise ValueError("No dataset loaded and no dataset_name provided.")
        else:
            print(f"Loading dataset: {dataset_name}")
            (x_train, y_train), (x_test, y_test), input_shape = load_dataset(dataset_name)
    else:
        print("Using already loaded dataset...")
        input_shape = x_train.shape[1:]

    # ---------------------------
    # Encoder + feature extraction
    # ---------------------------
    if encoding_type == 'cnn+poisson':
        encoder = build_encoder(input_shape=input_shape, output_dim=feature_dim)

        if pretrain_encoder:
            inp = encoder.input
            feat = encoder.output
            out = layers.Dense(len(np.unique(y_train)), activation='softmax')(feat)
            clf = models.Model(inp, out)
            clf.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
            print("Pretraining encoder on provided dataset...")
            clf.fit(x_train, y_train, validation_split=0.1,
                    epochs=encoder_epochs, batch_size=batch_size_encoder,
                    verbose=2)

        print("Extracting features from CNN...")
        train_feats = encoder.predict(x_train, batch_size=256, verbose=1)
        test_feats  = encoder.predict(x_test, batch_size=256, verbose=1)

        minf = train_feats.min(axis=0, keepdims=True)
        maxf = train_feats.max(axis=0, keepdims=True)
        rangef = (maxf - minf) + 1e-9
        train_feats_norm = (train_feats - minf) / rangef
        test_feats_norm  = (test_feats  - minf) / rangef

    elif encoding_type == 'poisson_only':
        train_feats_norm = x_train.reshape(len(x_train), -1).astype(np.float32)
        test_feats_norm = x_test.reshape(len(x_test), -1).astype(np.float32)

        train_feats_norm /= train_feats_norm.max()
        test_feats_norm /= test_feats_norm.max()

    else:
        raise ValueError("encoding_type must be 'poisson_only' or 'cnn+poisson'")

    # ---------------------------
    # Initialize network
    # ---------------------------
    Nin = train_feats_norm.shape[1]
    Nout = len(np.unique(y_train))
    Nhid_global = Nhid
    lam = 0.9
    theta_o = 0.5

    np.random.seed(seed)
    W1 = np.random.normal(0, 0.1, (Nin, Nhid_global)).astype(np.float32)
    W2 = np.random.normal(0, 0.1, (Nhid_global, Nout)).astype(np.float32)
    theta_h = (theta_h_base + 0.05 * np.random.randn(Nhid_global)).astype(np.float32)

    if architecture == '2SADP':
        W1_2 = np.random.normal(0, 0.1, (Nhid_global, Nhid_global)).astype(np.float32)
        theta_h2 = (theta_h_base + 0.05 * np.random.randn(Nhid_global)).astype(np.float32)
        print(f"Initialized W1_2: {W1_2.shape}, theta_h2: {theta_h2.shape}")
    else:
        W1_2, theta_h2 = None, None

    print(f"Feature dim (Nin): {Nin}, Nhid: {Nhid_global}, Nout: {Nout}")
    if architecture == '1SADP':
        print(f"W1: {W1.shape}, W2: {W2.shape}, θ_h: {theta_h.shape}")
    else:
        print(f"W1: {W1.shape}, W1_2: {W1_2.shape}, W2: {W2.shape}, θ_h: {theta_h.shape}, θ_h2: {theta_h2.shape}")
