## Preliminaries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn
dtype = torch.float
# Prefer the first CUDA GPU when one is present and fall back to the CPU
# otherwise, so the notebook also runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

import importlib
import time
from copy import deepcopy
from IPython.display import display
from ipywidgets import FloatProgress

import random
import copy

import mnist_loader
importlib.reload(mnist_loader)

import pickle

import NeuralNetwork as NN
import Layer

importlib.reload(NN)
importlib.reload(Layer);

## Load MNIST Data

In [None]:
MNIST_padding = torch.nn.ConstantPad2d(2, 0.0)

def DrawDigit(x, dim):
    plt.imshow(np.reshape(x.cpu(), (dim,dim)), cmap='gray')

In [None]:
# Full dataset, zero-padded to 32x32. The sample count is inferred (-1) instead
# of hard-coded, so this works for any number of samples the loader returns
# (assumes each sample holds 28*28 = 784 values).
train, validate, test = mnist_loader.load_data_wrapper()
train = [
    MNIST_padding(torch.reshape(torch.tensor(train[0]).float().to(device), (-1, 28, 28))),
    torch.tensor(train[1]).float().to(device),
]
test = [
    MNIST_padding(torch.reshape(torch.tensor(test[0]).float().to(device), (-1, 28, 28))),
    torch.tensor(test[1]).float().to(device),
]

In [None]:
howmany = 1000  # number of training AND test samples to keep (e.g. 1000 or 50)
train, validate, test = mnist_loader.load_data_wrapper()
# Reshape with -1 so the cell still works when a split has fewer than
# `howmany` samples (assumes 28*28 = 784 values per sample).
train = [
    MNIST_padding(torch.reshape(torch.tensor(train[0][:howmany]).float().to(device), (-1, 28, 28))),
    torch.tensor(train[1][:howmany]).float().to(device),
]
test = [
    MNIST_padding(torch.reshape(torch.tensor(test[0][:howmany]).float().to(device), (-1, 28, 28))),
    torch.tensor(test[1][:howmany]).float().to(device),
]

In [None]:
# `howmany` training samples but the full test set, zero-padded to 32x32.
# Sample counts are inferred (-1) rather than hard-coded.
howmany = 1000
train, validate, test = mnist_loader.load_data_wrapper()
train = [
    MNIST_padding(torch.reshape(torch.tensor(train[0][:howmany]).float().to(device), (-1, 28, 28))),
    torch.tensor(train[1][:howmany]).float().to(device),
]
test = [
    MNIST_padding(torch.reshape(torch.tensor(test[0]).float().to(device), (-1, 28, 28))),
    torch.tensor(test[1]).float().to(device),
]

In [None]:
# Retrieve one randomly chosen training example of each digit class 0-9.
# Labels are computed once (argmax of each flattened target row). Each class is
# then sampled uniformly from its matching indices; this raises instead of
# looping forever when a class is absent from a small training subset.
train_labels = train[1].reshape(len(train[1]), -1).argmax(dim=1).cpu().numpy()
idx = []
for digit in range(10):
    candidates = np.flatnonzero(train_labels == digit)
    if candidates.size == 0:
        raise ValueError(f"no training sample with label {digit}")
    idx.append(int(np.random.choice(candidates)))

print(idx)

images = train[0][idx]
classes = train[1][idx]

print(classes)

# No Padding / 1D top layer

In [None]:
train, validate, test = mnist_loader.load_data_wrapper()
# Convert the images (index 0) and labels (index 1) of each split to float
# tensors on `device`, without any padding.
train = [torch.tensor(part).float().to(device) for part in train[:2]]
test = [torch.tensor(part).float().to(device) for part in test[:2]]

In [None]:
howmany = 1000
train, validate, test = mnist_loader.load_data_wrapper()
# First `howmany` training samples, full test set; float tensors on `device`.
train = [torch.tensor(part[:howmany]).float().to(device) for part in train[:2]]
test = [torch.tensor(part).float().to(device) for part in test[:2]]

In [None]:
# Insert a singleton dimension at axis 1 of both image tensors.
# NOTE: not idempotent -- every re-run adds another axis.
train[0], test[0] = train[0].unsqueeze(1), test[0].unsqueeze(1)

### Transform input à la Whittington & Bogacz

In [None]:
def inv_logistic_raw(y, bound=5.0):
    """Inverse logistic (logit) transform, clipped to [-bound, bound].

    Computes log(y / (1 - y)) elementwise. y == 0 gives -inf and y == 1 gives
    +inf; both are clipped, so they map to -bound and +bound respectively.
    Inputs outside [0, 1] produce NaN, which is left unchanged.

    Parameters
    ----------
    y : torch.Tensor
        Values, normally in [0, 1].
    bound : float, optional
        Symmetric clipping limit (default 5).

    Returns
    -------
    torch.Tensor
        A new tensor; `y` itself is not modified.
    """
    z = torch.log(y / (1.0 - y))
    return torch.clamp(z, -bound, bound)

def inv_logistic(dataset, chunk_size=4096):
    """Apply inv_logistic_raw to every sample of `dataset`, in place.

    Parameters
    ----------
    dataset : torch.Tensor or mutable sequence of tensors
        Overwritten in place; nothing is returned.
    chunk_size : int, optional
        For tensor datasets, number of rows (along the first dimension)
        transformed per vectorised call; bounds the temporary memory used.
    """
    if isinstance(dataset, torch.Tensor):
        # torch.split yields views, so copying into each chunk writes straight
        # back into `dataset`. This replaces a Python loop over every row with
        # a few vectorised calls.
        for chunk in torch.split(dataset, chunk_size):
            chunk.copy_(inv_logistic_raw(chunk))
        return
    for idx in range(len(dataset)):
        dataset[idx] = inv_logistic_raw(dataset[idx])

In [None]:
# Squash the targets a la Whittington & Bogacz with the affine map
# t -> 0.03 + 0.94 * t, so 0 -> 0.03 and 1 -> 0.97 (presumably one-hot targets).
# NOTE: not idempotent -- re-running this cell rescales the targets again.
train[1], test[1] = (0.03 + 0.94 * t for t in (train[1], test[1]))

In [None]:
# Pass the image values of both splits through the clipped logit, in place.
# NOTE: not idempotent -- re-running transforms the images a second time.
for split in (train, test):
    inv_logistic(split[0])

## Create Network

In [None]:
# Pick up edits to the project modules without restarting the kernel.
for _module in (NN, Layer):
    importlib.reload(_module)

In [None]:
# Free the memory held by a previously created network before building a new one.
# Guarded so this cell is a no-op (instead of a NameError) on a fresh kernel.
if 'net' in globals():
    net.Release()
    del net

In [None]:
# RF n=[#channels, width, height]
net = NN.NeuralNetwork()

'''
Image - RF - Dense - Class Network
'''

# Layers are added top-down: index 0 is the class layer, the last index is the image.

#classification layer
net.AddLayer(Layer.InputPELayer(n=10)) #n=[10]

#fully connected layer
net.AddLayer(Layer.PELayer(n=600))

#first RF layer: 8 channels -> n=[8, 28, 28]; total n = 8*28*28 = 6272
# (28 = 32 - 5 + 1, presumably the valid positions of a 5x5 RF on the 32x32 image)
net.AddLayer(Layer.RetinotopicPELayer(imsize=(28, 28), channels=8, receptive_field=5, receptive_field_spacing=1))

#input MNIST image (expects the 32x32 data produced by MNIST_padding)
net.AddLayer(Layer.RetinotopicPELayer(imsize=(32, 32), channels=1, receptive_field=1, receptive_field_spacing=1)) #n=[1, 32, 32]; total n=1024

In [None]:
'''
Image - RF - Dense - Dense - Class Network
'''

# Layers are added top-down: the 10-unit class layer is index 0 and the
# single-channel 32x32 image layer is index 4.
# NOTE(review): this network has 5 layers (indices 0-4); every adjacent pair,
# including 3-4, must be connected when wiring it up.
net = NN.NeuralNetwork()
net.AddLayer(Layer.InputPELayer(n=10))
net.AddLayer(Layer.PELayer(n=576))
net.AddLayer(Layer.PELayer(n=576))
net.AddLayer(Layer.RetinotopicPELayer(imsize=(32, 32), channels=6, receptive_field=5, receptive_field_spacing=3)) 
net.AddLayer(Layer.RetinotopicPELayer(imsize=(32, 32), channels=1, receptive_field=1, receptive_field_spacing=1))

In [None]:
'''
Image - Dense - Dense - Class Network
'''

# Fully connected network for the unpadded data: 10 class units at the top
# (index 0), two 600-unit hidden layers, and a 784-unit (= 28*28) image layer.
net = NN.NeuralNetwork()
net.AddLayer(Layer.InputPELayer(n=10))
net.AddLayer(Layer.PELayer(n=600))
net.AddLayer(Layer.PELayer(n=600))
net.AddLayer(Layer.TopPELayer(n=784))

In [None]:
# Create connections between every pair of consecutive layers (i -> i+1) of
# whichever network was built above, so both the 4- and 5-layer architectures
# end up fully wired.

af = 'tanh'
sym = False
shared = False #Whether to use shared convolutions for RF connections

# `shared` is passed only to the bottom-most connection, the one attached to
# the input-image layer (the last layer index).
n_links = len(net.layers) - 1
for i in range(n_links):
    is_bottom = (i == n_links - 1)
    kwargs = {'shared': shared} if is_bottom else {}
    net.Connect(i, i + 1, act=af, symmetric=sym, **kwargs)

net.SetTau(0.2)
net.learning_tau = 0.8
net.connections[-1].normalize_feedback=False

In [None]:
# For each connection, print the shapes of its M and W, then of the `b`
# attribute of the layer below it.
for conn in net.connections:
    for param in (conn.M, conn.W, conn.below.b):
        print(param.shape)

In [None]:
# For each connection, print the values of its M and W, then of the `b`
# attribute of the layer below it.
for conn in net.connections:
    for param in (conn.M, conn.W, conn.below.b):
        print(param)

## Train the Network

In [None]:
#Regular Learn procedure
# Quick run on the first 100 training samples. train[1] (labels) is passed
# first and train[0] (images) second.
# NOTE(review): argument roles of Learn are inferred from call order -- confirm in NeuralNetwork.Learn.

net.Reset()
net.SetWeightDecay(0.01)
net.SetvDecay(0.0)
net.Learn(train[1][:100], train[0][:100], T=25., dt=0.001, epochs=10, batch_size=10)

In [None]:
#FastLearn
# Train on the whole `train` split: train[1] (labels) first, train[0] (images) second.
# NOTE(review): Beta_one / Beta_two / ep look like Adam-style optimiser
# hyperparameters -- confirm against NeuralNetwork.FastLearn.

net.Reset()
net.l_rate = 0.001
net.SetWeightDecay(0.01)
net.SetvDecay(0.0)
net.FastLearn(train[1], train[0], test=None, T=80, beta_time=0.02, epochs=1, Beta_one=0.9, Beta_two=0.999, ep=0.00000001, batch_size=10, shuffle=True)

In [None]:
#Prime the network with feedforward states, then generate from those latent states

# Use the last n_show test samples. Slicing from the end (instead of a
# hard-coded 9990:10000) keeps this working when `test` holds fewer than
# 10000 samples, e.g. the 1000-sample subset.
n_show = 10
x_show = test[0][-n_show:]
# Side length of the square images, inferred so both the 32x32 padded data
# and the 784-value (28x28) unpadded data can be drawn.
side = int(round(x_show[0].numel() ** 0.5))

net.Allocate(n_show)
net.Reset()

net.SetBidirectional()
net.layers[0].SetFF()

net.SetExpectation(torch.unsqueeze(x_show, dim=1))
z = net.Generate(6., x_show, dt=0.002)

#remove expectation
net.layers[-1].SetFF()
net.SetExpectation(torch.zeros_like( net.layers[-1].e ).float().to(device))
net.layers[-1].e = torch.zeros_like( net.layers[-1].e ).float().to(device)
net.layers[-1].v = torch.zeros_like( net.layers[-1].v ).float().to(device)

w = net.Predict(6., z, dt=0.001)
plt.figure(figsize=[18,4])
# Top row: the original test images; bottom row: the regenerated images.
for n, zz in enumerate(NN.logistic(w)):
    plt.subplot(2, n_show, n + 1); DrawDigit(x_show[n], side)
    plt.subplot(2, n_show, n_show + n + 1); DrawDigit(zz, side)

In [None]:
#Generate from classes alone with Predict

# Last n_show test samples (robust to test sets smaller than 10000) and the
# side length of their square images (32 padded, 28 unpadded).
n_show = 10
x_show = test[0][-n_show:]
y_show = test[1][-n_show:]
side = int(round(x_show[0].numel() ** 0.5))

net.Reset()
net.Allocate(n_show)

net.SetvDecay(0.01)

net.SetBidirectional()
net.layers[0].SetFF()
net.layers[-1].SetFF()

z = net.Predict(6., y_show, dt=0.002)

plt.figure(figsize=[18,4])
# Top row: the test image for each class vector; bottom row: the generated image.
for n, zz in enumerate(z):
    plt.subplot(2, n_show, n + 1); DrawDigit(x_show[n], side)
    plt.subplot(2, n_show, n_show + n + 1); DrawDigit(zz, side)

In [None]:
#Generate from classes alone with FastPredict

# Last n_show test samples (robust to test sets smaller than 10000) and the
# side length of their square images (32 padded, 28 unpadded).
n_show = 10
x_show = test[0][-n_show:]
y_show = test[1][-n_show:]
side = int(round(x_show[0].numel() ** 0.5))

net.Reset()
net.Allocate(n_show)

net.SetvDecay(0.01)

net.SetBidirectional()
net.layers[0].SetFF()
net.layers[-1].SetFF()

z = net.FastPredict(y_show, T=255)

plt.figure(figsize=[18,4])
# Top row: the test image for each class vector; bottom row: the generated image.
for n, zz in enumerate(z):
    plt.subplot(2, n_show, n + 1); DrawDigit(x_show[n], side)
    plt.subplot(2, n_show, n_show + n + 1); DrawDigit(zz, side)