# core

> SOM class implementation

In [None]:
#| default_exp core

In [None]:
#| export
import numpy as np
from sklearn.decomposition import PCA
from fastcore.all import *
import matplotlib.pyplot as plt

In [None]:
#| exports
def dist_fn(x:np.ndarray, # input vector
            ws:np.ndarray # weight vectors, features along the last axis
            )->np.ndarray: # norm of `ws - x` taken over the last axis
    "Default distance function: Euclidean norm of `ws - x` along the last axis"
    return np.linalg.norm(ws - x, axis=-1)

In [None]:
#| exports
class SOM:
    "Self-Organizing Map whose units live on a 2-D grid"
    def __init__(self,
                 grid_sz:tuple, # grid size as (rows, cols)
                 input_dim:int, # length of each weight vector for random initialization
                 init:str='random', # initialization method: 'random' or 'pca'
                 dist_fn=dist_fn # distance function, called as `dist_fn(x, weights)`
                 )->None: # nothing is returned; attributes are set in place
        "Create an untrained SOM; the weights are only allocated by the first call to `fit`"
        # No weights yet: `fit` checks for None and initializes them then
        self.weights = None
        # Save every constructor argument as an attribute of the same name
        store_attr()

In [None]:
#| exports
@patch
def _initialize_weights_pca(self:SOM, 
                            X:np.ndarray # data matrix
                            )->np.ndarray: # weights of shape (rows, cols, n_features)
    "Initialize weights on the plane spanned by the first two principal components of `X`"
    pca = PCA(n_components=2)
    pca.fit(X)
    
    # Size each axis from its own grid dimension so non-square grids get
    # weights of shape (rows, cols, n_features) instead of (rows, rows, ...)
    n_rows, n_cols = self.grid_sz[0], self.grid_sz[1]
    
    # Grid coordinates in [-1, 1], scaled by the standard deviation along
    # the first (columns) and second (rows) principal component
    alpha = np.linspace(-1, 1, n_cols) * np.sqrt(pca.explained_variance_[0])
    beta = np.linspace(-1, 1, n_rows) * np.sqrt(pca.explained_variance_[1])
    
    # Default 'xy' indexing: both outputs have shape (len(beta), len(alpha)),
    # i.e. (n_rows, n_cols), with alpha varying along columns
    alpha_grid, beta_grid = np.meshgrid(alpha, beta)
    
    # Each unit is a linear combination of the first two principal components
    return (alpha_grid[..., np.newaxis] * pca.components_[0] + 
            beta_grid[..., np.newaxis] * pca.components_[1])

In [None]:
#| exports
@patch
def _initialize_weights(self:SOM, 
                        X:np.ndarray=None, # data matrix (only needed for 'pca')
                        method:str='random' # initialization method: 'random' or 'pca'
                        )->np.ndarray: # weights matrix
    "Initialize weights using either random or PCA initialization"
    if method == 'random':
        # Standard-normal weights of shape (*grid_sz, input_dim)
        return np.random.randn(*self.grid_sz, self.input_dim)
    if method == 'pca':
        if X is None: 
            raise ValueError("Data matrix X required for PCA initialization")
        return self._initialize_weights_pca(X)
    # Fail loudly instead of returning None, which would only surface later
    # as an obscure error when the weights are first used
    raise ValueError(f"Unknown initialization method {method!r}; expected 'random' or 'pca'")

In [None]:
#| exports
@patch
def _find_bmu(self:SOM, 
              x:np.ndarray, # input vector
              )->tuple: # coordinates of the best matching unit
    "Return the grid coordinates of the unit whose weights are closest to `x`"
    dists = self.dist_fn(x, self.weights)
    # argmin works on the flattened distances; map that index back onto the grid
    flat_winner = np.argmin(dists)
    return np.unravel_index(flat_winner, self.grid_sz)

In [None]:
#| exports
@patch
def _grid_distances(self:SOM, 
                    bmu_pos:tuple # coordinates of the best matching unit
                    )->np.ndarray: # squared grid distances, shape (rows, cols)
    "Squared Euclidean distance, in grid coordinates, from every unit to the BMU"
    # Column vector of row indices and row vector of column indices,
    # so the sum below broadcasts to the full (rows, cols) grid
    row_ids = np.arange(self.grid_sz[0])[:, np.newaxis]
    col_ids = np.arange(self.grid_sz[1])[np.newaxis, :]
    row_term = (bmu_pos[0] - row_ids)**2
    col_term = (bmu_pos[1] - col_ids)**2
    return row_term + col_term

In [None]:
#| exports
@patch
def _neighborhood_function(self:SOM, 
                           grid_dist:np.ndarray, # grid distances
                           sigma:float # neighborhood radius
                           )->np.ndarray: # neighborhood function values
    "Gaussian neighborhood: `exp(-grid_dist / (2 * sigma**2))`"
    denom = 2*sigma**2
    return np.exp(-grid_dist/denom)

In [None]:
#| exports
@patch
def _update_weights(self:SOM, 
                    x:np.ndarray, # input vector
                    learning_rate:float, # learning rate
                    sigma:float # neighborhood radius
                    )->None: # weights are modified in place
    "Pull every unit towards `x`, weighted by its neighborhood value around the BMU"
    winner = self._find_bmu(x)
    influence = self._neighborhood_function(self._grid_distances(winner), sigma)
    # Trailing axis lets the per-unit influence broadcast over the feature axis
    delta = learning_rate * influence[..., np.newaxis] * (x - self.weights)
    self.weights += delta

In [None]:
#| exports
def exp_sched(start_val, # value returned at step 0
              end_val, # value returned at step `n_steps`
              i, # current step
              n_steps # number of steps over which to decay
              ): # scheduled value at step `i`
    "Exponential schedule going from `start_val` at `i=0` to `end_val` at `i=n_steps`"
    if n_steps <= 0:
        # Nothing to decay over; without this guard the rate is infinite and
        # step 0 evaluates inf * 0, which yields NaN
        return start_val
    decay = -np.log(end_val/start_val)/n_steps
    return start_val * np.exp(-decay * i)

In [None]:
#| exports
class Scheduler:
    "Step-wise schedule that refreshes its value once every `step_size` samples"
    def __init__(self, start_val, end_val, step_size, 
                 n_samples, n_epochs, decay_fn=exp_sched):
        # Save every constructor argument as an attribute of the same name
        store_attr()
        self.current_step = 0
        self.current_value = start_val
        # Keep at least one step: with fewer than `step_size` samples in total
        # the floor division is 0, and a decay function that divides by the
        # number of steps would then produce NaN
        self.total_steps = max(1, (n_samples * n_epochs) // step_size)
    
    def step(self, total_samples):
        "Return the scheduled value, recomputing it when `total_samples` is a multiple of `step_size`"
        if total_samples % self.step_size == 0:
            self.current_value = self.decay_fn(
                self.start_val, self.end_val, 
                self.current_step, self.total_steps
            )
            # Advance only after evaluating, so the first refresh uses step 0
            self.current_step += 1
        return self.current_value

In [None]:
#| exports
@patch
def fit(self:SOM, 
        X:np.ndarray, # data matrix
        n_epochs:int=20, # number of epochs
        lr_scheduler:Scheduler=None, # learning rate scheduler
        sigma_scheduler:Scheduler=None, # neighborhood radius scheduler
        shuffle:bool=True, # shuffle data
        verbose:bool=True # verbose
        )->tuple: # tuple of weights, quantization errors, and topographic errors
    "Train the SOM on `X`, recording quantization and topographic error after every epoch"
    # Weights are created on the first call only; later calls continue training
    if self.weights is None:
        self.weights = self._initialize_weights(X, self.init)
    
    n_samples = len(X)
    
    # Default schedules: learning rate 1.0 -> 0.01 and radius
    # max(grid_sz)/2 -> 1.0, both refreshed every 100 samples
    if lr_scheduler is None:
        lr_scheduler = Scheduler(1.0, 0.01, 100, n_samples, n_epochs)
    if sigma_scheduler is None:
        sigma_scheduler = Scheduler(max(self.grid_sz)/2, 1.0, 100, n_samples, n_epochs)
    
    qe_errors = []
    te_errors = []
    for epoch in range(n_epochs):
        X_ = np.random.permutation(X) if shuffle else X.copy()
        
        # Count samples across epochs so the schedulers see a global index
        for total_samples, x in enumerate(X_, start=epoch * n_samples):
            lr = lr_scheduler.step(total_samples)
            sigma = sigma_scheduler.step(total_samples)
            self._update_weights(x, lr, sigma)
        
        # Errors are measured on the data in its original order
        qe = self.quantization_error(X)
        te = self.topographic_error(X)
        qe_errors.append(qe)
        te_errors.append(te)
        
        if verbose:
            print(f'Epoch: {epoch+1} | QE: {qe:.4f}, TE: {te:.4f}')
    
    return self.weights, qe_errors, te_errors

In [None]:
#| exports
@patch
def quantization_error(self:SOM, 
                       X:np.ndarray # data matrix
                       )->float: # quantization error
    "Mean, over the rows of `X`, of the distance to the closest unit"
    bmu_dists = [self.dist_fn(x, self.weights).min() for x in X]
    return np.mean(bmu_dists)

In [None]:
#| exports
@patch
def topographic_error(self:SOM, 
                      X:np.ndarray # data matrix (must be non-empty)
                      )->float: # topographic error as a percentage (0-100)
    "Calculate percentage of data vectors where 1st and 2nd BMUs are not adjacent"
    def _check_bmu_adjacency(x):
        distances = self.dist_fn(x, self.weights)
        # kth=1 already places the two smallest distances in the first two
        # slots, so maps with only two units work as well
        flat_indices = np.argpartition(distances.flatten(), 1)[:2]
        # One array per grid axis, each holding that coordinate of both units
        coords = np.unravel_index(flat_indices, self.grid_sz)
        # Not adjacent if the two units differ by more than 1 along any axis
        return any(np.abs(first - second) > 1 for first, second in coords)
    
    n_errors = sum(_check_bmu_adjacency(x) for x in X)
    return 100 * n_errors / len(X)  # Return percentage

In [None]:
#| exports
@patch
def transform(self:SOM, 
              X:np.ndarray # data matrix
              )->np.ndarray: # bmu coordinates
    "Map each row of `X` to the grid coordinates of its Best Matching Unit (BMU)"
    out = np.zeros((len(X), 2), dtype=int)
    # Each `slot` is a view onto one row of `out`, so it is filled in place
    for slot, x in zip(out, X):
        slot[:] = self._find_bmu(x)
    return out

In [None]:
#| exports
@patch  
def predict(self:SOM, 
            X:np.ndarray # data matrix
            )->np.ndarray: # bmu coordinates
    "sklearn-style alias: returns exactly what `transform` returns for `X`"
    bmu_coords = self.transform(X)
    return bmu_coords

In [None]:
#| exports
@patch  
def _calculate_umatrix(self:SOM)->np.ndarray:
    "Calculate U-Matrix values for current weights"
    n_rows, n_cols = self.grid_sz[0], self.grid_sz[1]
    # Offsets for 8 neighbors (defined once, shared by every cell)
    nbr_offsets = [
        (-1,-1), (-1,0), (-1,1),  # top-left, top, top-right
        (0,-1),          (0,1),   # left, right
        (1,-1),  (1,0),  (1,1)    # bottom-left, bottom, bottom-right
    ]
    
    def _neighbor_distances(pos):
        "Weighted mean distance from the unit at `pos` to its in-grid neighbors"
        r, c = pos
        distances, weights = [], []
        for dr, dc in nbr_offsets:
            nbr_r, nbr_c = r+dr, c+dc
            # Skip neighbors that fall outside the grid
            if 0 <= nbr_r < n_rows and 0 <= nbr_c < n_cols:
                # Inverse grid distance: diagonal neighbors count less
                weights.append(1/np.sqrt(dr**2 + dc**2))
                distances.append(
                    self.dist_fn(self.weights[r,c], self.weights[nbr_r,nbr_c]))
        if not distances:
            # A 1x1 grid has no neighbors; np.average would raise on empty weights
            return 0.0
        return np.average(distances, weights=weights)
        
    umatrix = np.zeros(self.grid_sz)
    for i, j in np.ndindex(self.grid_sz):
        umatrix[i,j] = _neighbor_distances((i,j))
    return umatrix

In [None]:
#| exports
@patch
def umatrix(self:SOM)->np.ndarray:
    "Public accessor: compute and return the U-Matrix of the current weights"
    u = self._calculate_umatrix()
    return u

In [None]:
#| exports
@patch
def plot_umatrix(self:SOM, 
                 figsize=(8,6), # figure size passed to `plt.figure`
                 cmap='viridis_r' # colormap passed to `plt.imshow`
                 ):
    "Draw the U-Matrix as an image with a colorbar on a new figure"
    plt.figure(figsize=figsize)
    heatmap = plt.imshow(self.umatrix(), cmap=cmap, interpolation='nearest')
    plt.colorbar(heatmap, label='Average distance to neighbors')
    plt.title('U-Matrix')
    plt.tight_layout()

In [None]:
#| eval: false
# Load and normalize the scikit-learn digits dataset
# NOTE: `load_digits` is scikit-learn's bundled digits dataset, not MNIST itself;
# each sample has 64 features, matching `input_dim=64` below
from sklearn.datasets import load_digits

X, y = load_digits(return_X_y=True)
# Subtract each sample's own mean (axis=-1) and divide by the global maximum of X
X_norm = (X - np.mean(X, axis=-1, keepdims=True))/X.max()

# Create a 20x20 SOM with PCA-initialized weights and train it for 20 epochs
som = SOM(grid_sz=(20,20), input_dim=64, init='pca')
som.fit(X_norm, n_epochs=20);

Epoch: 1 | QE: 2.1917, TE: 1.8920
Epoch: 2 | QE: 2.0051, TE: 2.9494
Epoch: 3 | QE: 1.9656, TE: 1.6694
Epoch: 4 | QE: 1.8192, TE: 1.1686
Epoch: 5 | QE: 1.7921, TE: 1.0017
Epoch: 6 | QE: 1.7082, TE: 0.8904
Epoch: 7 | QE: 1.6437, TE: 0.5565
Epoch: 8 | QE: 1.5971, TE: 1.0017
Epoch: 9 | QE: 1.5459, TE: 1.0017
Epoch: 10 | QE: 1.4812, TE: 0.5008
Epoch: 11 | QE: 1.4391, TE: 0.3339
Epoch: 12 | QE: 1.3891, TE: 0.5008
Epoch: 13 | QE: 1.3530, TE: 0.3339
Epoch: 14 | QE: 1.3136, TE: 0.5008
Epoch: 15 | QE: 1.2779, TE: 0.8904
Epoch: 16 | QE: 1.2454, TE: 0.7791
Epoch: 17 | QE: 1.2159, TE: 0.6121
Epoch: 18 | QE: 1.1905, TE: 0.7234
Epoch: 19 | QE: 1.1688, TE: 0.7791
Epoch: 20 | QE: 1.1502, TE: 0.7791


In [None]:
#| hide
# Run nbdev's notebook export (presumably writing the `#| export`/`#| exports`
# cells to the module named by `default_exp` — confirm against the nbdev docs)
import nbdev; nbdev.nbdev_export()