In [None]:
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from IPython.display import display, clear_output
from tqdm import tqdm

#import utils 
from utils import draw_weights, selectivity_metric

In [None]:
class HebbNet:
    """Single linear layer trained with local Hebbian-style learning rules.

    The weight matrix ``W`` has shape ``(n_units, input_dim)`` and every row is
    the weight vector of one unit.  Unsupervised learning is done by `update`;
    the remaining methods are evaluation helpers (a ridge-regression read-out
    and a majority-vote labelling of units).

    Supported rules: ``"hebb"``, ``"oja"`` and ``"BCM"``.  Supported modes:
    ``"normal"`` (all units learn from every sample) and ``"WTA"`` (only the
    most active unit learns from a sample).  ``"BCM"`` is only implemented for
    ``mode="WTA"``; with ``mode="normal"`` it leaves the weights unchanged
    (apart from the optional normalisation).
    """

    def __init__(self, input_dim: int, n_units: int = 64, rule: str = "oja", norm_weights=False, mode="normal", seed=None):
        """Create the layer with random, row-normalised weights.

        Args:
            input_dim: Number of input features.
            n_units: Number of units (rows of ``W``).
            rule: Learning rule, one of ``"oja"``, ``"hebb"``, ``"BCM"``.
            norm_weights: If true, re-normalise weight rows to unit length
                after every update.
            mode: ``"normal"`` or ``"WTA"`` (winner-take-all).
            seed: Optional seed for the weight initialisation; ``None`` (the
                default) draws fresh entropy exactly as before.
        """
        assert rule in ("oja", "hebb", "BCM"), "rule must be 'oja' or 'hebb' or 'BCM' "
        assert mode in ("normal", "WTA"), "mode must be 'normal' or 'WTA'"
        self.rule = rule
        self.mode = mode
        self.norm_weights = norm_weights
        self.rng = np.random.default_rng(seed)
        self.input_dim = input_dim
        self.W = self.rng.normal(0, 0.1, size=(n_units, input_dim)).astype(np.float32)
        self.W /= np.linalg.norm(self.W, axis=1, keepdims=True) + 1e-9
        self.threshold = np.zeros(n_units)  # per-unit BCM sliding threshold

        # Read-out weights; stays None until train_linear_classifier is called
        # so that predict can detect the untrained state.
        self.W_out = None

        # BCM hps
        self.gamma = 0.5
        self.tau_th = 0.09
        self.tau_w = 0.001

    def forward(self, X: np.ndarray) -> np.ndarray:
        """Return the linear unit activations ``X @ W.T``, shape ``(B, n_units)``."""
        return X @ self.W.T

    def update(self, X: np.ndarray, eta: float) -> None:
        """Apply one learning step to ``W`` using a whole batch.

        Activations are computed once from the weights at the start of the
        call.

        * ``mode="normal"``: the per-sample updates are averaged over the
          batch and applied in a single step.
            - hebb: ``dW = y x^T``
            - oja:  ``dW = y (x - W^T y)^T`` (residual after reconstructing
              ``x`` from all units)
            - BCM:  not implemented in this mode (no weight change).
        * ``mode="WTA"``: for every sample only the unit with the largest
          activation is updated, sample by sample.  The BCM branch uses
          ``self.tau_w`` as its step size and ignores ``eta``.

        Args:
            X: Batch of inputs, shape ``(B, input_dim)``.
            eta: Learning rate for the hebb and oja rules.
        """
        n_samples = X.shape[0]
        if n_samples == 0:
            # Nothing to learn from; also avoids dividing by zero below.
            return

        Y = self.forward(X)

        if self.mode == "normal":
            if self.rule == "hebb":
                delta_w = Y.T @ X
            elif self.rule == "oja":
                delta_w = Y.T @ (X - Y @ self.W)
            else:
                delta_w = None  # BCM has no "normal"-mode implementation

            if delta_w is not None:
                # Average the UPDATE over the batch (not the weights themselves).
                self.W += (eta / n_samples) * delta_w

            if self.norm_weights:
                # Per-row normalisation, consistent with __init__ and WTA mode.
                self.W /= np.linalg.norm(self.W, axis=1, keepdims=True) + 1e-9

        elif self.mode == "WTA":
            # pick winners per input vec (img)
            winners = Y.argmax(axis=1)

            # Mean squared activity of every unit over the batch; Y is fixed
            # for the whole call, so compute it once instead of per sample.
            mean_sq_activity = np.mean(Y ** 2, axis=0) if self.rule == "BCM" else None

            for i, w in enumerate(winners):
                x_vec = X[i]    # input that unit w "won"
                y = Y[i, w]     # scalar activation of the winning unit

                if self.rule == "hebb":
                    self.W[w] += eta * y * x_vec

                elif self.rule == "oja":
                    self.W[w] += eta * y * (x_vec - y * self.W[w])

                elif self.rule == "BCM":
                    cur_threshold = mean_sq_activity[w]

                    # Moving-average threshold (stored for inspection only;
                    # the weight step below uses the current batch value).
                    self.threshold[w] = (self.gamma * self.threshold[w] + (1 - self.gamma) * cur_threshold) * self.tau_th

                    t = cur_threshold
                    # The epsilon keeps the step finite when the unit was
                    # completely silent on this batch (t == 0).
                    delta_w = ((y * (y - t)) * x_vec) / (t + 1e-9)
                    delta_w /= n_samples  # scale by the number of samples

                    self.W[w] += delta_w * self.tau_w

                if self.norm_weights:
                    self.W[w] /= np.linalg.norm(self.W[w]) + 1e-9

    # ----- helpers ------

    # Supervised linear read-out (ridge regression)
    def train_linear_classifier(self, X: np.ndarray, y: np.ndarray, reg: float = 1e-3):
        """Fit the read-out matrix ``W_out`` by closed-form ridge regression.

        Args:
            X: Inputs, shape ``(N, input_dim)``.
            y: Class indices ``0..C-1`` for every row of ``X``; values are
                cast to integers before being used as indices.
            reg: Ridge regularisation strength.
        """
        A = self.forward(X)                 # (N, H)
        N, H = A.shape
        labels = np.asarray(y).astype(np.int64)
        classes = int(labels.max()) + 1
        # One-hot encode labels -> Y (N, C)
        Y = np.zeros((N, classes), dtype=np.float32)
        Y[np.arange(N), labels] = 1.0
        # Closed-form ridge solution
        I = np.eye(H, dtype=np.float32)
        self.W_out = np.linalg.solve(A.T @ A + reg * I, A.T @ Y)  # (H, C), stored on the object

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Return predicted class indices using the trained linear read-out.

        Raises:
            RuntimeError: If `train_linear_classifier` has not been called.
        """
        # getattr also covers instances on which W_out was never assigned.
        if getattr(self, "W_out", None) is None:
            raise RuntimeError("Linear classifier not trained. Call train_linear_classifier first.")
        logits = self.forward(X) @ self.W_out  # (B, C)
        return logits.argmax(axis=1)

    def linear_accuracy(self, X: np.ndarray, y: np.ndarray) -> float:
        """Accuracy using the supervised read-out (must be trained)."""
        preds = self.predict(X)
        return float((preds == y).mean())

    # Evaluation with majority votes.
    def majority_labels(self, X: np.ndarray, y: np.ndarray) -> list[int]:
        """Label every unit with the most frequent class among the samples it wins.

        Returns:
            One label per unit; ``-1`` for units that win no sample.
        """
        winners = self.forward(X).argmax(axis=1)
        bag = [[] for _ in range(self.W.shape[0])]
        for w, lbl in zip(winners, y):
            bag[w].append(lbl)
        return [max(set(b), key=b.count) if b else -1 for b in bag]

    def accuracy(self, X: np.ndarray, y: np.ndarray, unit_labels: Sequence[int]) -> float:
        """Accuracy when each sample is assigned the label of its winning unit."""
        preds = np.take(unit_labels, self.forward(X).argmax(axis=1))
        return float((preds == y).mean())




In [None]:


# load data
# Alternative dataset (gratings):
# data = np.loadtxt("gratings.csv", delimiter=",")

from sklearn.datasets import fetch_openml

# as_frame=False makes fetch_openml return plain NumPy arrays, which is what
# HebbNet operates on.
data_raw, y = fetch_openml(name='mnist_784', version=1, return_X_y=True, as_frame=False)

# preprocess
data = np.asarray(data_raw, dtype=np.float32) / 255.
data -= data.mean(axis=0)  # balance data around mean
# for mnist eps=0.001 !

print(data.shape)
n_samples, n_in = data.shape

# hps #
SEED = 0
n_epochs = 5
n_units = 10
batch_size = 100
tau_th = 0.05
tau_w = 0.001

# NOTE(review): the original cell instantiated `BCM_Model_Curti`, which is not
# defined or imported anywhere above and fails on a fresh kernel.  HebbNet's
# BCM rule (only implemented in winner-take-all mode) is used instead.
model = HebbNet(n_in, n_units, rule="BCM", mode="WTA")
model.tau_th = tau_th
model.tau_w = tau_w

shuffle_rng = np.random.default_rng(SEED)

# learning loop
for epoch in range(n_epochs):

    # Fresh sample order every epoch; `data` itself is left untouched so the
    # cell can be re-run safely.
    order = shuffle_rng.permutation(n_samples)

    # Iterate over all minibatches
    for i in range(n_samples // batch_size):
        batch_idx = order[i * batch_size:(i + 1) * batch_size]
        minibatch = data[batch_idx]  # shape (batch_size, n_in)
        # The BCM branch steps with model.tau_w; eta is passed for API completeness.
        model.update(minibatch, eta=tau_w)
    # assumes draw_weights accepts an (n_units, n_in) array — TODO confirm against utils
    draw_weights(model.W, epoch)

