### Yin-Yang Dataset

In [None]:
import multiprocess as mp
import numpy as np
from numpy_ml.neural_nets.optimizers import Adam
import os
import matplotlib.pyplot as plt
import sys
import time
from torch.utils.data import DataLoader

# Make the directory that contains the local `py/` package importable so that
# `import py.snn_gradients` and `import py.dataset` resolve. This assumes the
# notebook is run from that directory (TODO confirm for other launch setups).
# Note: sys.path entries must be directories (or archives); appending a .py
# file path such as "py/snn_gradients.py" has no effect on imports.
_project_root = os.getcwd()
if _project_root not in sys.path:
    # Insert at the front so the local `py` package is found before any
    # installed distribution that happens to share the name.
    sys.path.insert(0, _project_root)

In [None]:
import py.snn_gradients as snn
import py.dataset as dataset

### Set Up Datasets

The code used to load and visualize the dataset was provided by https://github.com/lkriener/yin_yang_data_set.

In [None]:
# Build the three Yin-Yang splits with distinct fixed seeds so each split is
# reproducible: (size, seed) = train (5000, 42), validation (1000, 41),
# test (1000, 40).
dataset_train, dataset_validation, dataset_test = (
    dataset.YinYangDataset(size=n_samples, seed=split_seed)
    for n_samples, split_seed in ((5000, 42), (1000, 41), (1000, 40))
)

In [None]:
# Mini-batch size used for training.
batchsize_train = 20
# Evaluation batch size equals the test-set size, so the test loader yields
# a single batch; the validation loader reuses the same batch size.
batchsize_eval = len(dataset_test)

train_loader = DataLoader(dataset_train, shuffle=True, batch_size=batchsize_train)
val_loader = DataLoader(dataset_validation, shuffle=True, batch_size=batchsize_eval)
test_loader = DataLoader(dataset_test, shuffle=False, batch_size=batchsize_eval)

In [None]:
# Scatter-plot each split in the (x1, y1) plane, coloured by class label
# (classes 0, 1, 2 -> matplotlib colours C0, C1, C2).
fig, axes = plt.subplots(ncols=3, sharey=True, figsize=(15, 8))
titles = ['Training set', 'Validation set', 'Test set']
for ax, title, loader in zip(axes, titles, [train_loader, val_loader, test_loader]):
    ax.set_title(title)
    ax.set_aspect('equal', adjustable='box')
    # Convert whole batches to NumPy at once rather than unpacking every
    # sample tensor element by element. Each loader is iterated exactly once,
    # so points and labels stay aligned even for shuffling loaders.
    point_batches, label_batches = [], []
    for batch, batch_labels in loader:
        point_batches.append(batch.numpy())
        label_batches.append(batch_labels.numpy())
    points = np.concatenate(point_batches)
    cs = np.concatenate(label_batches)
    xs, ys = points[:, 0], points[:, 1]  # only the first two input columns are plotted
    for class_id in range(3):
        mask = cs == class_id
        ax.scatter(xs[mask], ys[mask], color=f'C{class_id}', edgecolor='k', alpha=0.7)
    ax.set_xlabel('x1')
axes[0].set_ylabel('y1')

### Set Up Network

In [None]:
# ---- Network / neuron hyperparameters ----
# BETA and ALPHA are passed as `beta` / `alpha` to both the SNN and the loss.
# (An alternative setting that was tried: BETA = 1.0, ALPHA = 0.2.)
BETA, ALPHA = 0.05, 0.2
THETA = 1.0  # passed as `threshold` to the network and `theta` to the loss
T = 10       # inputs are multiplied by T; T is also passed to each forward pass

# ---- Loss hyperparameters (tau_0, tau_1, gamma of SnnCrossEntropy) ----
TAU_0, TAU_1, GAMMA = 0.5, 6.4, 0.003

# ---- Adam optimizer hyperparameters ----
BETA_1 = 0.9    # passed as `decay1` (first-moment decay rate)
BETA_2 = 0.999  # passed as `decay2` (second-moment decay rate)
EPS = 1e-8      # numerical-stability epsilon
ETA = 0.005     # learning rate (`lr`)

In [None]:
# Adam optimizer built from ETA, BETA_1, BETA_2 and EPS, with no
# learning-rate scheduler.
adam = Adam(lr=ETA, decay1=BETA_1, decay2=BETA_2, eps=EPS, lr_scheduler=None)

In [None]:
# Cross-entropy loss over the three Yin-Yang classes. It receives the same
# T, ALPHA, BETA and THETA as the network, plus the loss-specific TAU_0,
# TAU_1 and GAMMA.
_loss_kwargs = dict(
    num_classes=3, T=T, alpha=ALPHA, beta=BETA, theta=THETA,
    tau_0=TAU_0, tau_1=TAU_1, gamma=GAMMA,
)
ce_loss = snn.SnnCrossEntropy(**_loss_kwargs)

In [None]:
# Feed-forward SNN with 4 inputs, a 50-unit layer and a 3-unit output layer.
# Initial weights are drawn from the (unseeded) global NumPy RNG, so they
# differ between runs:
#   - layer 1: 50 vectors of length 4, mean 1.5, std 0.78;
#   - layer 2: 3 vectors of length 50, mean 0.93, std 0.1.
_n_in, _n_hidden, _n_out = 4, 50, 3
_initial_weights = [
    [np.random.normal(1.5, 0.78, _n_in) for _ in range(_n_hidden)],
    [np.random.normal(0.93, 0.1, _n_hidden) for _ in range(_n_out)],
]
net = snn.FeedForwardSNN(
    in_dim=_n_in, beta=BETA, alpha=ALPHA, threshold=THETA,
    layer_sizes=[_n_hidden, _n_out], weights=_initial_weights,
).build()

In [None]:
def train_single_example(data, labels, T, model, loss):
    """Run one forward and one backward pass of ``model`` on a single example.

    Args:
        data: Network input, passed unchanged to ``model.forward(data, T)``.
        labels: Target class of this example.
        T: Passed to ``model.forward``. It is also the default value for any
            output neuron whose entry in the final layer is empty.
        model: Object whose ``forward`` returns an indexable of at least six
            items. Item 1 holds per-layer outputs, the last being the output
            layer.
        loss: Object providing ``forward(output_layer, labels)`` and
            ``backward(out0, out1, out2, labels, out3, out4)``.

    Returns:
        Tuple ``(acc, l, grad, n_i)``:
            acc: number of matches between the predicted class and
                ``labels`` (0 or 1 for a scalar label).
            l: value returned by ``loss.forward``.
            grad: value returned by ``loss.backward``.
            n_i: item 5 of the forward output, returned unchanged.
    """
    outputs = model.forward(data, T)
    final_layer = outputs[1][-1]
    # Smallest recorded value per output neuron (presumably its earliest
    # spike time). Neurons with no entries count as T. The predicted class is
    # the neuron with the smallest such value (ties -> lowest index).
    first_values = np.asarray([np.min(neuron, initial=T) for neuron in final_layer])
    pred = np.argmin(first_values)
    acc = np.sum(pred == labels)
    l = loss.forward(final_layer, labels)
    grad = loss.backward(
        outputs[0], outputs[1], outputs[2], labels, outputs[3], outputs[4]
    )
    n_i = outputs[5]
    return acc, l, grad, n_i

### Train the SNN

In [None]:
# Train the network for NUM_EPOCHS passes over the training set.
# - Per-example gradients are averaged over each mini-batch.
# - Each batch gets one Adam step.
# - Progress is printed every 10 batches.
NUM_EPOCHS = 5
start_time = time.perf_counter()
losses = []  # mean training loss of every mini-batch, in processing order
for epoch in range(NUM_EPOCHS):
    for batch_idx, (batch, batch_labels) in enumerate(train_loader):

        batch_loss = 0
        batch_acc = 0
        batch_grad = [np.zeros_like(w, dtype=np.float64) for w in net.weights]
        batch_ro = []  # last return value of train_single_example, per example
        bsz = len(batch)

        for j in range(bsz):
            # Same input encoding as at prediction time: scale by T and
            # reshape to a (4, 1) column.
            data = np.reshape(batch[j].numpy() * T, (4, 1))
            label = batch_labels[j].numpy()
            a, l, g, ro = train_single_example(data, label, T, net, ce_loss)
            batch_acc += a / bsz
            batch_loss += l / bsz
            # Accumulate the batch-mean gradient layer by layer. Distinct
            # names avoid shadowing the epoch/batch loop variables.
            batch_grad = [running + 1 / bsz * g_layer for running, g_layer in zip(batch_grad, g)]
            batch_ro.append(ro)

        losses.append(batch_loss)
        # One Adam step per weight tensor. The name passed to Adam.update
        # identifies the parameter and should stay stable across steps
        # (presumably it keys the optimizer's per-parameter state; confirm
        # in numpy_ml).
        new_weights = [
            adam.update(w, grad_w, f"w_layer_{layer}")
            for layer, (w, grad_w) in enumerate(zip(net.weights, batch_grad))
        ]
        # NOTE(review): the meaning of the trailing -10 is defined by
        # FeedForwardSNN.update in py/snn_gradients.py; document it there.
        net.update(new_weights, T, batch_ro, -10)

        if batch_idx % 10 == 0:
            end_time = time.perf_counter()
            print("epoch=", epoch, "batch=", batch_idx)
            print("Time elapsed (sec)=", end_time - start_time)
            print("loss=", batch_loss)
            print("acc=", batch_acc)

### Inspect and Visualize Predictions

In [None]:
# Classify every test point with the trained network and plot the predicted
# classes next to the ground truth. The test loader is read only once, so
# xs, ys, cs (true labels) and pred_cs (predictions) share one order.
fig, axes = plt.subplots(ncols=2, sharey=True, figsize=(15, 8))
titles = ['Predictions', 'Test Set']

point_batches, label_batches, predictions = [], [], []
for batch, batch_labels in test_loader:
    batch_points = batch.numpy()
    point_batches.append(batch_points)
    label_batches.append(batch_labels.numpy())
    for item in batch_points:
        # Same input encoding as in training: scale by T, shape (4, 1).
        data = np.reshape(item * T, (4, 1))
        output = net.forward(data, T)[1][-1]
        # Predicted class = output neuron with the smallest recorded value
        # (presumably the earliest spike). Neurons with no entries count as T.
        pred = np.argmin([np.min(neuron, initial=T) for neuron in output])
        predictions.append(pred)

points = np.concatenate(point_batches)
xs, ys = points[:, 0], points[:, 1]
cs = np.concatenate(label_batches)   # ground-truth labels
pred_cs = np.asarray(predictions)    # predicted labels, same order as cs

for ax, title, classes in zip(axes, titles, (pred_cs, cs)):
    ax.set_title(title)
    ax.set_aspect('equal', adjustable='box')
    for class_id in range(3):
        mask = classes == class_id
        ax.scatter(xs[mask], ys[mask], color=f'C{class_id}', edgecolor='k', alpha=0.7)
    ax.set_xlabel('x1')
axes[0].set_ylabel('y1')

In [None]:
# Number of test points whose predicted class equals the ground-truth label
# (assumes pred_cs and cs are aligned in test-loader order).
num_correct = np.sum(pred_cs == cs)
num_correct

In [None]:
# Mean training loss of each mini-batch, in the order the batches were
# processed (all epochs concatenated along the x-axis).
plt.plot(losses)
plt.xlabel('mini-batch (all epochs)')
plt.ylabel('mean training loss')
plt.title('Training loss')
plt.show()