In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from numpy.linalg import eig
import pandas as pd
from minisom import MiniSom  
import math
import ipynb
from sklearn.metrics import classification_report
import cv2
from itertools import chain
# Import custom functions from seperate workbook
from ipynb.fs.full.helper_functions import *

%matplotlib inline


In [2]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

In [3]:
def classify(som, x_test, x_train, y_train):
    
    # For each image we assign the label to the winning nuron
    winmap = som.labels_map(x_train, y_train)
    # Here we fiding the default class to assign something to the output when thhere is no winmap value for a given value
    default_class = np.sum(list(winmap.values())).most_common()[0][0]
    result = []
    for d in x_test:
        # Find the winning node for the given sample
        win_position = som.winner(d)
        # Check if that sample is in the list of posible results
        if win_position in winmap:
            # it is! great, now append the most common result for that sample for that node and its given image
            result.append(winmap[win_position].most_common()[0][0])
        else:
            result.append(default_class)
    return result

In [4]:
def create_train_som(x_train, n_features):
    
    # Create SOM dimensions
    som_nurons = int((math.sqrt(5*math.sqrt(n_features))))
#     print(som_nurons)
    x = som_nurons
    y = som_nurons
    
    #Create and train SOM
    som = MiniSom(x, y, n_features, sigma=0.3, learning_rate=0.5) # initialization of 6x6 SOM
    som.random_weights_init(x_train)
    print("Training...")
    som.train_random(x_train,100, verbose=False) # training with 100 iterations
    print("...ready!")
    return som

In [5]:
def flatten_and_reshape(data):
    res = np.reshape(data, (data.shape[0], data.shape[1]*data.shape[2]))
    return(res)

In [10]:
x_test_flat = flatten_and_reshape(x_test)
x_train_flat = flatten_and_reshape(x_train)

In [11]:
first_som = create_train_som(x_train_flat, 784)

Training...
...ready!


In [20]:
winmap = first_som.labels_map(x_train_flat, y_train)
# default_class = np.sum(list(winmap.values())).most_common()[0][0]
# result = []

In [18]:
default_class = np.sum(list(winmap.values())).most_common()[0][0]

In [19]:
default_class

1

In [17]:
winmap

defaultdict(list,
            {(4,
              1): Counter({5: 663,
                      3: 1551,
                      8: 470,
                      9: 38,
                      2: 81,
                      0: 9,
                      6: 1,
                      1: 2,
                      7: 1}),
             (8,
              2): Counter({0: 2031,
                      5: 80,
                      2: 47,
                      9: 11,
                      3: 56,
                      6: 63,
                      8: 11,
                      7: 1,
                      4: 1}),
             (1,
              9): Counter({4: 1186,
                      9: 847,
                      8: 74,
                      0: 16,
                      2: 113,
                      6: 66,
                      5: 93,
                      7: 86,
                      3: 37}),
             (4,
              5): Counter({1: 1095,
                      7: 17,
                      2: 65,
            

In [21]:
list(winmap.values())

[Counter({5: 663, 3: 1551, 8: 470, 9: 38, 2: 81, 0: 9, 6: 1, 1: 2, 7: 1}),
 Counter({0: 2031, 5: 80, 2: 47, 9: 11, 3: 56, 6: 63, 8: 11, 7: 1, 4: 1}),
 Counter({4: 1186, 9: 847, 8: 74, 0: 16, 2: 113, 6: 66, 5: 93, 7: 86, 3: 37}),
 Counter({1: 1095, 7: 17, 2: 65, 4: 14, 8: 14, 3: 1, 6: 3, 9: 1}),
 Counter({9: 1366, 4: 649, 1: 1, 8: 64, 2: 4, 7: 133, 5: 234, 3: 12}),
 Counter({2: 455,
          8: 1205,
          3: 108,
          5: 21,
          6: 5,
          4: 7,
          9: 10,
          1: 5,
          7: 17,
          0: 3}),
 Counter({1: 944, 7: 32, 8: 25, 2: 2, 4: 29, 3: 25, 5: 12, 9: 12}),
 Counter({4: 599,
          9: 555,
          7: 553,
          3: 13,
          8: 89,
          1: 4,
          2: 42,
          5: 26,
          0: 6,
          6: 2}),
 Counter({5: 933,
          4: 189,
          8: 186,
          3: 56,
          2: 93,
          6: 76,
          0: 35,
          7: 19,
          1: 9,
          9: 16}),
 Counter({3: 414, 8: 21, 9: 2}),
 Counter({6: 1

In [22]:
np.sum(list(winmap.values()))

Counter({5: 5421,
         3: 6131,
         8: 5851,
         9: 5949,
         2: 5958,
         0: 5923,
         6: 5918,
         1: 6742,
         7: 6265,
         4: 5842})