In [2]:
data = [
    [2.3976, 1.5328, 1.9044, 1.1937, 2.4184, 1.8649, 1],
    [2.3936, 1.4804, 1.9907, 1.2732, 2.2719, 1.8110, 1],
    [2.2880, 1.4585, 1.9867, 1.2451, 2.3389, 1.8099, 1],
    [2.2904, 1.4766, 1.8876, 1.2706, 2.2966, 1.7744, 1],
    [1.1201, 0.0587, 1.3154, 5.3783, 3.1849, 2.4276, 2],
    [0.9913, 0.1524, 1.2700, 5.3808, 3.0714, 2.3331, 2],
    [1.0915, 0.1881, 1.1387, 5.3703, 3.3561, 2.3833, 2],
    [1.0535, 0.1229, 1.2743, 5.3226, 3.0952, 2.3193, 2],
    [1.4871, 2.3448, 0.9918, 2.3160, 1.6783, 5.0850, 3],
    [1.3312, 2.2553, 0.9618, 2.4702, 1.7272, 5.0645, 3],
    [1.3646, 2.2945, 1.0562, 2.4763, 1.8051, 5.1470, 3],
    [1.4392, 2.2296, 1.1278, 2.4230, 1.7259, 5.0876, 3],
    [2.9364, 1.5323, 4.6109, 1.3160, 4.2000, 6.8749, 4],
    [2.9034, 1.4640, 4.6061, 1.4598, 4.2912, 6.9142, 4],
    [3.0181, 1.4918, 4.7051, 1.3521, 4.2623, 6.7966, 4],
    [2.9374, 1.4896, 4.7219, 1.3977, 4.1863, 6.8336, 4]
]


In [17]:
import numpy as np

data = np.array(data)

x = data[:, :-1]
y = data[:, -1]

In [19]:
y

array([1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., 4., 4., 4., 4.])

In [28]:
import io
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

class LVQ_1:
    def __init__(self, input_size, weight_size, hyper_params, weight_init_type='center'):
        self.input_size = input_size
        self.init_sigma = hyper_params['sigma']
        self.hyper_params = hyper_params
        self.fig_arr = []
        self.weight_init_type = weight_init_type

        # 初始化 dummy weight points
        self.weight = np.zeros((weight_size, input_size))
        self.labels = np.zeros(weight_size)

        # weight points init 為 -1 ~ 1 的 uniform distribution
        # self.weight = np.random.uniform(-1, 1, self.weight.shape)

        # 產生一個 (h, w, 2) 的 index 矩陣，後續 find_winner 計算較為方便
        self.position_mat = np.array([[(i, j) for j in range(weight_size[1])]\
                                      for i in range(weight_size[0])])
    
    def weight_init(self, weight, data, labels, type='uniform'):
        if type == 'uniform':
            # 对每个类别随机选择一些数据点作为初始原型向量
            unique_labels = np.unique(labels)
            prototypes_per_class = self.weight.shape[0] // len(unique_labels)
            
            idx = 0
            for label in unique_labels:
                class_data = data[labels == label]
                random_indices = np.random.choice(len(class_data), 
                                                prototypes_per_class, 
                                                replace=False)
                
                self.weight[idx:idx+prototypes_per_class] = class_data[random_indices]
                self.labels[idx:idx+prototypes_per_class] = label
                idx += prototypes_per_class
                
        elif type == 'center':
            # 使用每个类的中心点作为原型向量
            unique_labels = np.unique(labels)
            prototypes_per_class = self.weight.shape[0] // len(unique_labels)
            
            idx = 0
            for label in unique_labels:
                class_data = data[labels == label]
                class_center = np.mean(class_data, axis=0)
                
                self.weight[idx:idx+prototypes_per_class] = class_center
                self.labels[idx:idx+prototypes_per_class] = label
                idx += prototypes_per_class
                
        return weight
    
    def sigma_decaying(self, sigma, max_iter):
        return [sigma * np.exp(-it / max_iter) for it in range(max_iter)]
    
    def cal_alpha(self, winner_pos, sigma):
        distance = np.linalg.norm(winner_pos - self.position_mat, axis=2)
        # 用 gaussian function 計算每個 weight point 
        # 被更新程度的比例，距離 winner point 愈近愈大
        # alpha -> (h, w) -reshape-> (h, w, 1)
        return np.exp(-(distance**2) / (2 * sigma**2))\
            .reshape(self.weight.shape[0], self.weight.shape[1], 1)
    
    def cal_input_weight_distance(self, x):
        # 將所有 weight point 個別減去同個 x 後取歐式距離 (L2 norm)
        x_reshaped = x.reshape(1, 1, -1)
        distances = np.linalg.norm(self.weight - x_reshaped, axis=2)
        return distances

    def find_winner(self, x):
        distances = self.cal_input_weight_distance(x)

        # 找到距離當前 x 最近的 weight point 的 index
        winner_idx = np.unravel_index(distances.argmin(), distances.shape)
        return np.array(winner_idx)
    
    def train(self, data, save_gif=False, file_name='result.gif'):
        self.weight = self.weight_init(self.weight, data[:, :-1], self.weight_init_type)
        n_samples = len(data)
        np.random.shuffle(data)
        total_iter = self.hyper_params['epochs'] * n_samples
        err_arr = []
        data = list(zip(data[:, :-1], data[:, -1]))
        # self.sigma_arr = self.sigma_decaying(self.init_sigma, total_iter)
     
        for epoch in range(self.hyper_params['epochs']):    
            for i, (x, y) in enumerate(data):

                print(x, y)

                cur_iter = epoch * n_samples + i
                # sigma = self.sigma_arr[cur_iter]
                winner_pos = self.find_winner(x)
                # alpha = self.cal_alpha(winner_pos, sigma)
                # self.weight += (self.hyper_params['lr'] * alpha * (x - self.weight))
                # if i % 10 == 0:
                #     if save_gif:
                #         self.plot_mesh(data, cur_data=x, save_gif=True, title=[epoch, i])
                #     err_arr.append(self.cal_error(data))

        # if save_gif:
        #     self.fig_arr[0].save(file_name, save_all=True, 
        #                          append_images=self.fig_arr[1:], 
        #                          loop=0, duration=100)

        return err_arr

    def plot_convergence(self, err_arr):
        plt.figure(figsize=(6, 5))
        plt.plot(err_arr, 'b-')
        plt.xlabel('Epoch')
        plt.ylabel('Error')
        plt.grid(True)
        plt.show()

    def cal_error(self, data):
        total_err = 0
        # 算出所有 data 與最近的權重的距離後取平均
        for x in data:
            # 找到最近的權重向量
            min_distance = np.min(self.cal_input_weight_distance(x))
            total_err += min_distance
        return total_err / len(data)
    
    def plot_mesh(self, data, cur_data=None, show_plot=False, save_gif=False, title=None):
        fig = plt.figure(figsize=(5, 5))

        if cur_data is not None:
            plt.scatter(cur_data[0], cur_data[1], c='r', marker='x')
        
        plt.scatter(data[:, 0], data[:, 1], c='g', alpha=0.5, marker='x')
        
        # weight -> (h, w, input_size) -> (h*w, input_size)
        weight_2d = self.weight.reshape(-1, self.input_size)
        
        plt.scatter(weight_2d[:, 0], weight_2d[:, 1], marker='o', facecolors='none', edgecolors='black')
        
        # 水平方向的 weight point 連接
        for i in range(self.weight.shape[0]):
            for j in range(self.weight.shape[1] - 1):
                w1 = self.weight[i, j]
                w2 = self.weight[i, j + 1]
                plt.plot([w1[0], w2[0]], [w1[1], w2[1]], 'b-', linewidth=1)
        
        # 垂直方向的 weight point 連接
        for i in range(self.weight.shape[0] - 1):
            for j in range(self.weight.shape[1]):
                w1 = self.weight[i, j]
                w2 = self.weight[i + 1, j]
                plt.plot([w1[0], w2[0]], [w1[1], w2[1]], 'r-', linewidth=1)
        
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.xlabel('x')
        plt.ylabel('y')
        if title:
            plt.title(f'Epoch: {title[0]}, Iter: {title[1]}')
        if show_plot:
            plt.show()
        if save_gif:
            buf = io.BytesIO()
            plt.savefig(buf, format='png')
            buf.seek(0)
            image = Image.open(buf)
            self.fig_arr.append(image.copy())  
            plt.close(fig)
            buf.close()
        

In [None]:
model = LVQ_1(input_size=6, weight_size=(4, 4), hyper_params={'lr': 0.1, 'sigma': 1, 'epochs': 1000})

model.train(data)

In [30]:
model.weight

array([[[1.9402125 , 1.34826875, 2.2218375 , 2.60285625, 2.86935625,
         4.03293125],
        [1.9402125 , 1.34826875, 2.2218375 , 2.60285625, 2.86935625,
         4.03293125],
        [1.9402125 , 1.34826875, 2.2218375 , 2.60285625, 2.86935625,
         4.03293125],
        [1.9402125 , 1.34826875, 2.2218375 , 2.60285625, 2.86935625,
         4.03293125]],

       [[1.9402125 , 1.34826875, 2.2218375 , 2.60285625, 2.86935625,
         4.03293125],
        [1.9402125 , 1.34826875, 2.2218375 , 2.60285625, 2.86935625,
         4.03293125],
        [1.9402125 , 1.34826875, 2.2218375 , 2.60285625, 2.86935625,
         4.03293125],
        [1.9402125 , 1.34826875, 2.2218375 , 2.60285625, 2.86935625,
         4.03293125]],

       [[1.9402125 , 1.34826875, 2.2218375 , 2.60285625, 2.86935625,
         4.03293125],
        [1.9402125 , 1.34826875, 2.2218375 , 2.60285625, 2.86935625,
         4.03293125],
        [1.9402125 , 1.34826875, 2.2218375 , 2.60285625, 2.86935625,
         4.032931