In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets

In [2]:
#One-hot encoding
def create_target(t):
        target_vector = np.zeros(10)
        for i in range(10):
            if i == t:
                target_vector[i] = 1
        return target_vector

#Hard sigmoid [-1, 1]
def hsig(x):
    z = np.clip(x, -1, 1)
    return np.copy(z)

def d_hsig(x):
    z = (x > -1) & (x < 1)
    return np.copy(z)

In [3]:
#Load mini MNIST data
digits = datasets.load_digits()
data = digits.data
targets = digits.target

#Standardize data
inputs = data - np.mean(data)
inputs = inputs/(np.std(data))

In [5]:
#Define network hyperparameters
n_x = 64
n_h = 50
n_y = 10

alpha1 = 0.01
alpha2 = 0.005
beta = 1
epsilon = 0.1

#Weight intialization
W1 = np.random.uniform(0, (4/(n_x + n_y)), (n_x, n_h))
W2 = np.random.uniform(0, (4/(n_h)), (n_h, n_y))

#Bias initialization
bh = np.random.uniform(0, 4/(n_x + n_y), n_h)
by = np.random.uniform(0, 4/(n_h), n_y)

for ex in range(5000):
    #Randomly sample from data
    rnd = np.random.randint(0, 1497)
    x = inputs[rnd]# + 0.1 * np.random.rand(64)
    t = create_target(targets[rnd])
    
    #Random activation initialization
    h = np.random.uniform(-1, 1, n_h)
    y = np.random.uniform(-1, 1, n_y)
    
    #Free Phase
    for itr in range(100):
        #Calculate free gradient steps
        dh = d_hsig(h) * (np.dot(x, W1) + np.dot(y, W2.T) + bh) - h
        dy = d_hsig(y) * (np.dot(h, W2) + by) - y
        
        #Update activations
        h = hsig(h + epsilon * dh)
        y = hsig(y + epsilon * dy)
        
    #Store free equilibrium states
    h_free = np.copy(h)
    y_free = np.copy(y)
    
    #Weakly Clamped Phase
    for itr in range(20):
        #Calculate weakly clamped gradient steps
        dy = d_hsig(y) * (np.dot(h, W2) + by) - y + beta * (t - y)
        dh = d_hsig(h) * (np.dot(x, W1) + np.dot(y, W2.T) + bh) - h
        
        #Update activations
        h = hsig(h + epsilon * dh)
        y = hsig(y + epsilon * dy)
        
    #Store weakly clamped activations
    h_clamped = np.copy(h)
    y_clamped = np.copy(y)
    
    #Update weights
    W1 += alpha1 * (1/beta) * (np.outer(x, h_clamped) - np.outer(x, h_free))
    W2 += alpha2 * (1/beta) * (np.outer(h_clamped, y_clamped) - np.outer(h_free, y_free))
    
    #Print Mean Squared Error
    if ex % 100 == 0:
        print(np.dot(t - y_free, t - y_free))
        
    #Learning rate schedule
    if ex % 2500 == 2499:
        alpha1 /= 10
        alpha2 /= 10

1.1612889050268784
1.001700025380486
0.774104817181942
0.9093371567822642
0.677803654435267
0.8732613662927812
0.6242772298381177
0.5832786636047895
0.7428377875771537
0.4817724614350415
0.5767140151931699
0.6559420251002694
0.6068828221280779
0.23284466387543792
1.098541055451883
0.33033568073324515
0.11065713720255016
0.7810601546973898
0.12587057973827565
0.1900078491134798
0.24159488073016672
0.638292455606521
0.5449044426178697
0.5025879874404235
0.11906591523436101
0.7528702050913768
0.2857744099225166
0.5706575204362844
0.3246933610877407
0.7059262394079027
0.18579000580669458
0.11091981553030163
0.5018088822353577
0.14690292842555402
0.38173122384629515
0.2633553594525888
0.3085222690949818
0.35138575597522226
0.4816320432986371
0.3109104919385766
0.39603232240742325
0.3555251740371096
0.09659973131700954
0.4536248906036674
0.7349519320222381
0.14282091536618424
0.27127050816676
0.2158156495817466
0.32717198246865775
0.2640209801368415


In [9]:
#Test Accuracy
score = 0
for test in range(200):
    rnd = np.random.randint(1497, 1797)
    x = inputs[rnd]
    t = create_target(targets[rnd])
    h = np.random.uniform(-1, 1, n_h)
    y = np.random.uniform(-1, 1, n_y)
    
    #Free Phase
    for itr in range(100):
        dh = d_hsig(h) * (np.dot(x, W1) + np.dot(y, W2.T) + bh) - h
        dy = d_hsig(y) * (np.dot(h, W2) + by) - y
        
        h = hsig(h + epsilon * dh)
        y = hsig(y + epsilon * dy)
        
    h_free = np.copy(h)
    y_free = np.copy(y)
    
    if np.argmax(y_free) == targets[rnd]:
        score += 1
    
print(score/200)

0.895
