In [137]:
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
import numpy as np
import random
from tqdm import tqdm

In [138]:
iris_set = load_iris()
n_classes = iris_set.target_names.__len__()

In [139]:
class Dataset:
  def __init__(self, dataset, indices, transform=None, encoder=None):
    self.dataset, self.indices = dataset, indices
    self.transform, self.encoder = transform, encoder
  def __getitem__(self, item: int):
    idx = self.indices[item]
    feature, label = self.dataset.data[idx], self.dataset.target[idx]
    if self.transform: feature = self.transform(feature)
    if self.encoder: label = self.encoder(label)
    return feature, label
  def __len__(self): return len(self.indices)

In [140]:
indices = random.sample(range(iris_set.data.__len__()), 100)
encoder = lambda index: np.eye(n_classes)[index]

# init Datasets
support_set = Dataset(iris_set, indices[:50], encoder=encoder)
query_set = Dataset(iris_set, indices[50:], encoder=encoder)

In [142]:
def sigmoid(x): return 1 / (1 + np.exp(-1 * x))
def relu(x): return np.maximum(0, x)

In [143]:
class ADALINE:
  def __init__(self, n_inpt, n_ouput):
    self.n_ouput, self.n_inpt = n_ouput, n_inpt
    self.weight = np.zeros(shape=(self.n_inpt, self.n_ouput))
  # __init__

  def forward(self, x): return relu(np.dot(x.T, self.weight))
# ADALINE

In [155]:
def GDR(model, lr):
  def _GDR(x, y):
    pred = model.forward(x)
    error = pred - y
    grads = np.dot(x.reshape(1, -1).T, error.reshape(1, -1))
    model.weight -= grads * lr
  return _GDR

In [158]:
progress_bar = tqdm(range(1000))

# init and train a model
model = ADALINE(4, 3)
optimizer = GDR(model, 0.001)
for _ in progress_bar:
  for feature, label in support_set:
    optimizer(feature, label)

100%|██████████| 1000/1000 [00:01<00:00, 591.85it/s]


In [159]:
count, n_samples = 0, len(query_set)
for feature, label in support_set:
  pred = model.forward(feature)
  if np.argmax(pred) == np.argmax(label): count += 1
print(f"accuracy: {count / n_samples:.2f}({count}/{n_samples})")

accuracy: 0.92(46/50)
