In [66]:
import numpy as np
import pandas as pd
import os
from help_functions import *

In [56]:
class SOM:
    def __init__(self, map_width, map_height, input_dim, neighbourhood_function,
                 distance_k, learning_rate=0.5, sigma=None, decay_type='exponential', beta=0.999):
        self.map_width = map_width
        self.map_height = map_height
        self.input_dim = input_dim
        self.learning_rate = learning_rate
        self.distance_k = distance_k
        self.beta = beta
        self.time = 1
        self.label_map_db = {}
        self.tmp = {}
        
        if sigma is None:
            self.sigma = max(map_width, map_height) / 2.0
        else:
            self.sigma = sigma
            
        if self.distance_k == np.inf:
            self.calculate_distance_func = chebyshev_distance
        elif self.distance_k == 1:
            self.calculate_distance_func = manhattan_distance
        elif self.distance_k == 2:
            self.calculate_distance_func = euclidean_distance
        elif self.distance_k < 1:
            raise ValueError('Distance must have positive non-zero k value')
        else:
            self.calculate_distance_func = lambda a, b, axis: generic_distance(a, b, axis, self.distance_k)
            
        # TODO add other weight initialization options
        self.weights = np.random.rand(self.map_width, self.map_height, self.input_dim)
        
        if neighbourhood_function == 'gaussian':
            self.neighbourhood_func = gaussian_neighbourhood
        elif neighbourhood_function == 'rectangular':
            self.neighbourhood_func = rectangular_neighbourhood
        elif neighbourhood_function == 'triangular':
            self.neighbourhood_func = triangular_neighbourhood
        elif neighbourhood_function == 'cosine':
            self.neighbourhood_func = cosine_down_to_zero_neighbourhood
        else:
            raise ValueError(f'Unknown neighbourhood function {neighbourhood_function}')
        
        
        if decay_type == 'exponential' and 0 < self.beta < 1:
            self.calculate_decay = decay_exponential
        elif decay_type == 'power' and self.beta < 0:
            self.calculate_decay = decay_power
        else:
            raise ValueError(f'Unknown decay type or invalid beta')
            
    def get_weights(self):
        return self.weights
    
    def get_weight_of_node(self, node_idx):
        return self.weights[node_idx[0]][node_idx[1]]
    
    def update_time(self):
        self.time += 1
        
    def find_BMU(self, input_vector):
        dists = self.calculate_distance_func(self.weights, input_vector, 2)

        min_index = np.argmin(dists)
        bmu_idx = np.unravel_index(min_index, dists.shape)
        return bmu_idx
    
    def calculate_grid_distances(self, bmu_idx):
        x_coords, y_coords = np.meshgrid(np.arange(self.map_width), 
                                         np.arange(self.map_height), indexing='ij')
        dist_sq = (x_coords - bmu_idx[0])**2 + (y_coords - bmu_idx[1])**2
        return np.sqrt(dist_sq)
            
    def calculate_neighbourhood_influence(self, bmu_idx, sigma_t):
        grid_dists = self.calculate_grid_distances(bmu_idx)
        return self.neighbourhood_func(grid_dists, sigma_t)
    
    def update_weights(self, input_vector, bmu_idx):
        eta_t = self.calculate_decay(self.learning_rate, self.beta, self.time)
        sigma_t = self.calculate_decay(self.sigma, self.beta, self.time)
        
        # shape (map_width, map_height)
        influence = self.calculate_neighbourhood_influence(bmu_idx, sigma_t)
        
        # shape (map_width, map_height, input_dim)
        diff = input_vector - self.weights
        
        # reshaping to (map_width, map_height, 1) to broadcast over diff
        influence_new = influence[:, :, np.newaxis]
        
        self.weights += eta_t * influence_new * diff
        
    def calculate_QE(self, data):
        diff_total = 0
        for sample in data:
            bmu_idx = self.find_BMU(sample)      
            weight = self.get_weight_of_node(bmu_idx)
            
            diff_total += self.calculate_distance_func(weight, sample, 0)
        
        return diff_total / data.shape[0]
    
    def get_label_map(self, data, label, epoch):
    # https://medium.com/data-science/understanding-self-organising-map-neural-network-with-python-code-7a77f501e985
        map = np.empty(shape=(self.map_width, self.map_height), dtype=object)

        for row in range(self.map_width):
          for col in range(self.map_height):
            map[row][col] = []
    
        for i, sample in enumerate(data):
            bmu_idx = self.find_BMU(sample)
            map[bmu_idx[0]][bmu_idx[1]].append(label[i])
        
        self.tmp[epoch] = np.copy(map)
        
        for row in range(self.map_width):
            for col in range(self.map_height):
                label_list = map[row][col]
                if len(label_list)==0:
                  label = None
                else:
                  label = max(label_list, key=label_list.count)
                map[row][col] = label
                
        return map
    
    def train_online(self, data, label, num_epochs):
        num_samples = data.shape[0]
        
        for epoch in range(num_epochs):
            for sample in range(num_samples):
                input_vector = data[sample]
                
                bmu_idx = self.find_BMU(input_vector)
                
                self.update_weights(input_vector, bmu_idx)
                
            self.update_time()
            
            if epoch % 10 == 0:
                print(f"Epoch {epoch+1}/{num_epochs} complete. Current sigma: {self.calculate_decay(self.sigma, self.beta, self.time):.4f}, learning rate: {self.calculate_decay(self.learning_rate, self.beta, self.time):.4f}, QE: {self.calculate_QE(data):.4f}")
                
            if epoch % 50 == 0:    
                self.label_map_db[epoch] = self.get_label_map(data, label, epoch)
    
    

In [74]:
df = pd.read_csv("../data/iris.csv")
label = df['Iris-setosa'].to_numpy()
X = df.drop(['Iris-setosa'], axis=1).to_numpy()

In [81]:
som = SOM(7,7,4,"gaussian",2)
som.get_weights()

array([[[0.57112988, 0.69578727, 0.21343256, 0.69410226],
        [0.78945795, 0.70255532, 0.85626917, 0.46847736],
        [0.36391888, 0.25021033, 0.39334212, 0.11045078],
        [0.06983724, 0.20433889, 0.8938741 , 0.76739274],
        [0.62102194, 0.26027367, 0.56952434, 0.18247861],
        [0.51776608, 0.66052052, 0.67995302, 0.719808  ],
        [0.30629007, 0.20603815, 0.88967521, 0.28117964]],

       [[0.05207909, 0.66826667, 0.45606206, 0.42213594],
        [0.47093477, 0.73089919, 0.42779269, 0.0091991 ],
        [0.52656972, 0.51328612, 0.33383474, 0.81161017],
        [0.15413106, 0.23719762, 0.22588592, 0.31140156],
        [0.31017119, 0.86928476, 0.27603176, 0.80381623],
        [0.36511074, 0.65186342, 0.13018653, 0.59952739],
        [0.30366702, 0.20442162, 0.88578551, 0.65235018]],

       [[0.21485873, 0.31997245, 0.5149197 , 0.57572044],
        [0.70197738, 0.00376095, 0.32311219, 0.75009531],
        [0.10905173, 0.61203394, 0.2368448 , 0.71547202],
        [0

In [82]:
#data = np.array([[1, 1,1], [3, 4,3], [1, 8,4], [7, 8,7]])
#label = np.array([1, 3, 3, 2])
som.train_online(X, label, 1000)

Epoch 1/1000 complete. Current sigma: 3.4930, learning rate: 0.4990, QE: 2.0104
Epoch 11/1000 complete. Current sigma: 3.4582, learning rate: 0.4940, QE: 2.0248
Epoch 21/1000 complete. Current sigma: 3.4238, learning rate: 0.4891, QE: 2.0087
Epoch 31/1000 complete. Current sigma: 3.3897, learning rate: 0.4842, QE: 2.0212
Epoch 41/1000 complete. Current sigma: 3.3560, learning rate: 0.4794, QE: 2.0181
Epoch 51/1000 complete. Current sigma: 3.3226, learning rate: 0.4747, QE: 2.0039
Epoch 61/1000 complete. Current sigma: 3.2895, learning rate: 0.4699, QE: 2.0017
Epoch 71/1000 complete. Current sigma: 3.2567, learning rate: 0.4652, QE: 1.9997
Epoch 81/1000 complete. Current sigma: 3.2243, learning rate: 0.4606, QE: 1.9980
Epoch 91/1000 complete. Current sigma: 3.1922, learning rate: 0.4560, QE: 1.9961
Epoch 101/1000 complete. Current sigma: 3.1604, learning rate: 0.4515, QE: 1.9874
Epoch 111/1000 complete. Current sigma: 3.1290, learning rate: 0.4470, QE: 1.9925
Epoch 121/1000 complete. Cu

In [83]:
som.label_map_db[900]

array([['Iris-versicolor', 'Iris-versicolor', 'Iris-versicolor',
        'Iris-versicolor', 'Iris-versicolor', None, 'Iris-setosa'],
       ['Iris-virginica', None, 'Iris-versicolor', 'Iris-versicolor',
        'Iris-versicolor', None, 'Iris-setosa'],
       ['Iris-virginica', 'Iris-virginica', 'Iris-virginica',
        'Iris-virginica', 'Iris-versicolor', None, 'Iris-setosa'],
       ['Iris-virginica', None, None, None, 'Iris-versicolor', None,
        None],
       ['Iris-virginica', None, 'Iris-virginica', 'Iris-virginica', None,
        'Iris-versicolor', None],
       ['Iris-virginica', 'Iris-virginica', 'Iris-virginica', None,
        'Iris-virginica', 'Iris-versicolor', None],
       ['Iris-virginica', 'Iris-virginica', 'Iris-virginica',
        'Iris-virginica', 'Iris-versicolor', 'Iris-versicolor',
        'Iris-versicolor']], dtype=object)

In [80]:
som.get_weights()

array([[[7.10487001, 3.14074567, 5.93746063, 2.14259747],
        [6.92265889, 3.15343418, 5.77271283, 2.1796904 ],
        [6.77640235, 3.14932557, 5.6061204 , 2.19452977],
        [6.68974589, 3.12982147, 5.45045823, 2.17332037],
        [6.64469266, 3.09927724, 5.29102359, 2.09644204],
        [6.61812682, 3.06136445, 5.08703998, 1.93379267],
        [6.60028824, 3.02832426, 4.83164329, 1.70587915],
        [6.57703321, 3.03046506, 4.56310094, 1.4986614 ],
        [6.50975084, 3.09295083, 4.25105176, 1.33091814]],

       [[6.91649071, 3.11185031, 5.82304745, 2.14721025],
        [6.759113  , 3.14441207, 5.68096147, 2.19023165],
        [6.64171116, 3.15398622, 5.5424811 , 2.19815152],
        [6.57528222, 3.1398134 , 5.41656735, 2.16665691],
        [6.54148538, 3.10791383, 5.28425282, 2.08887912],
        [6.51603629, 3.06800847, 5.08810886, 1.93161092],
        [6.49358324, 3.0425838 , 4.78711949, 1.69020092],
        [6.44356067, 3.08137479, 4.36953797, 1.42454262],
        [6.2