In [26]:
import os
from abc import ABC, abstractmethod

import numpy as np
import torch

class Layer(ABC):
    """Abstract base for a hand-written layer that caches its last input/output.

    Subclasses implement ``forward`` (usually recording the input and output via
    ``setPrevIn`` / ``setPrevOut``) and ``gradient`` (derivative of the layer's
    output with respect to its input). ``backward`` chains an incoming gradient
    through ``gradient()`` with a matrix product.
    """

    def __init__(self):
        # These must use the same spelling as the accessors below: inside the
        # class body ``self.__prevIn`` is name-mangled to ``_Layer__prevIn``.
        # A trailing-dunder spelling (``__prevIn__``) is NOT mangled and would
        # create an unrelated attribute, leaving getPrevIn() unset.
        self.__prevIn = []
        self.__prevOut = []

    def setPrevIn(self, dataIn):
        """Cache the input of the most recent forward pass."""
        self.__prevIn = dataIn

    def setPrevOut(self, out):
        """Cache the output of the most recent forward pass."""
        self.__prevOut = out

    def getPrevIn(self):
        """Return the cached input (``[]`` before any forward pass)."""
        return self.__prevIn

    def getPrevOut(self):
        """Return the cached output (``[]`` before any forward pass)."""
        return self.__prevOut

    def backward(self, gradIn):
        """Return ``gradIn @ self.gradient()``.

        Appropriate when ``gradient()`` returns a matrix (Jacobian-like) that the
        incoming gradient can be right-multiplied by; layers whose derivative is
        element-wise should override this with an element-wise product.
        """
        return (gradIn @ self.gradient())

    @abstractmethod
    def forward(self, dataIn):
        """Compute the layer output for ``dataIn``."""
        pass

    @abstractmethod
    def gradient(self):
        """Return the derivative of the last output with respect to its input."""
        pass

class HardSigLayer(Layer):
    """Element-wise hard clipping activation: ``y = clip(x, -1, 1)``.

    Despite the name, the output range is [-1, 1] (a hard-tanh style clamp),
    not the [0, 1] range of a hard sigmoid. ``forward`` calls ``torch.clip``,
    so the input is expected to be a torch tensor.

    NOTE: the cached input/output are stored as-is (not detached), so when used
    inside an autograd computation they keep a reference to that graph until
    the next forward call replaces them.
    """

    def __init__(self):
        super().__init__()

    def forward(self, dataIn):
        """Clip ``dataIn`` to [-1, 1], cache input and output, and return the output."""
        self.setPrevIn(dataIn)
        out = torch.clip(dataIn, -1, 1)
        self.setPrevOut(out)
        return out

    def gradient(self):
        """Element-wise derivative of the last forward pass.

        1 where the cached output lies strictly inside (-1, 1) (the identity
        region) and 0 where it was clipped, returned in the output's dtype so it
        can be used directly in arithmetic.
        """
        out = self.getPrevOut()
        return ((out > -1) & (out < 1)).to(out.dtype)

    def backward(self, gradIn):
        """Chain ``gradIn`` through the activation with an element-wise product.

        The Jacobian of an element-wise activation is diagonal, so multiplying by
        the derivative mask is the correct chain rule; the inherited
        ``gradIn @ self.gradient()`` would instead treat the mask as a dense matrix.
        """
        return gradIn * self.gradient()

In [2]:
# Mini MNIST MLP to Match Equilibrium Propagation

import numpy as np
from sklearn import datasets

# Load the scikit-learn "digits" set (8x8 images, flattened to 64 features).
digits = datasets.load_digits()
data, targets = digits.data, digits.target

# Standardize with ONE scalar mean and ONE scalar std taken over every pixel of
# every image (not per feature).
# NOTE(review): these statistics also include the rows later held out for
# validation, so a little validation information leaks into preprocessing.
inputs = (data - np.mean(data)) / np.std(data)

In [3]:
# Sanity check: feature matrix and label vector must have the same row count.
print(inputs.shape, targets.shape, sep='\n')

(1797, 64)
(1797,)


In [4]:
# Fixed, unshuffled split: the first N_TRAIN rows train the model and the
# remaining rows (300 of the 1797) are held out for validation.
N_TRAIN = 1497
train_x, valid_x = np.array(inputs[:N_TRAIN]), np.array(inputs[N_TRAIN:])
train_y, valid_y = np.array(targets[:N_TRAIN]), np.array(targets[N_TRAIN:])
assert len(train_x) + len(valid_x) == len(inputs) == len(train_y) + len(valid_y)

In [5]:
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
from PIL import Image
import torch
#from torchvision import datasets, models, transforms
from torchvision import models, transforms
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import time
import copy

In [6]:
from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Map-style dataset pairing row ``i`` of ``x`` with entry ``i`` of ``y``.

    ``x`` and ``y`` must have the same first dimension (checked with ``assert``).
    Items are returned exactly as indexed, without copying or type conversion.
    """

    def __init__(self, x, y):
        super().__init__()
        assert x.shape[0] == y.shape[0]
        self.x, self.y = x, y

    def __len__(self):
        """Number of samples, taken from the first dimension of ``y``."""
        n_samples = self.y.shape[0]
        return n_samples

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        sample = (self.x[index], self.y[index])
        return sample

In [7]:
# Wrap the NumPy train/validation splits as map-style torch datasets.
traindata, validation = MyDataset(train_x, train_y), MyDataset(valid_x, valid_y)

In [8]:
# Datasets keyed by phase (the name is historical: these are digit feature
# vectors, not image files).
image_datasets = {'train': traindata, 'validation': validation}

# One single-sample loader per phase; only the training split is shuffled.
dataloaders = {
    phase: torch.utils.data.DataLoader(
        dataset,
        batch_size=1,
        shuffle=(phase == 'train'),
        num_workers=0,
    )
    for phase, dataset in image_datasets.items()
}

In [9]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [27]:
class LinReg(nn.Module):
    """Two-layer MLP whose hidden and output activations are hard clips to [-1, 1].

    Despite the name this is not a linear regression. The input is flattened to
    ``d*d`` (= 64) features and passed through ``fc0`` (64 -> 50), a clip,
    dropout (p=0.25), ``fc1`` (50 -> 10) and a final clip.

    Because the returned scores are bounded to [-1, 1], ``nn.CrossEntropyLoss``
    cannot reach zero on them. With 10 classes the best case (target at +1,
    others at -1) is about 0.80.

    The two ``HardSigLayer`` attributes are not ``nn.Module`` instances, so they
    hold no parameters and do not appear in ``state_dict()``.
    """

    def __init__(self):
        super().__init__()
        self.d = 8  # side length of the square input; fc0 takes d**2 features
        self.flatten = nn.Flatten()
        self.fc0 = nn.Linear(in_features=self.d ** 2, out_features=50)
        self.hs0 = HardSigLayer()
        self.do0 = nn.Dropout(p=0.25)
        self.fc1 = nn.Linear(in_features=50, out_features=10)
        self.hs1 = HardSigLayer()

    def forward(self, x):
        """Return clipped class scores of shape ``(batch, 10)`` for input ``x``."""
        hidden = self.fc0(self.flatten(x.float()))
        hidden = self.do0(self.hs0.forward(hidden))
        return self.hs1.forward(self.fc1(hidden))
    
# Build the network and move its parameters to the selected device.
model = LinReg().to(device=device)

In [28]:
# Cross-entropy on the clipped class scores; Adam with PyTorch's default
# hyper-parameters over all of the model's parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters())

In [29]:
def train_model(model, criterion, optimizer, num_epochs=50, loaders=None,
                target_device=None, save_path='models/pytorch/weights.h5'):
    """Train ``model`` and checkpoint the weights with the best validation accuracy.

    Each epoch runs a 'train' pass (with gradient updates) and then a
    'validation' pass (no gradients), printing the mean loss and accuracy of
    each. Whenever validation accuracy strictly improves, ``model.state_dict()``
    is written to ``save_path`` with ``torch.save``; the parent directory is
    created if it does not exist.

    Args:
        model: the ``nn.Module`` to train; it is updated in place.
        criterion: loss callable taking ``(outputs, labels)``; assumed to return
            a per-batch mean (it is re-weighted by batch size below).
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of train+validation rounds.
        loaders: dict with 'train' and 'validation' DataLoaders. Defaults to the
            notebook-level ``dataloaders``.
        target_device: device batches are moved to. Defaults to the
            notebook-level ``device``.
        save_path: file that receives the best state dict. The ``.h5`` suffix is
            historical; the content is ``torch.save`` format, not HDF5.

    Returns:
        ``(model, best_acc)``. ``model`` holds the weights from the LAST epoch;
        the best weights exist only at ``save_path``. ``best_acc`` is the best
        validation accuracy as a Python float (0.0 if it never exceeded 0).
    """
    if loaders is None:
        loaders = dataloaders  # notebook-level global
    if target_device is None:
        target_device = device  # notebook-level global

    best_acc = 0.0
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch+1, num_epochs))
        print('-' * 10)

        for phase in ['train', 'validation']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is exactly model.eval()
            loader = loaders[phase]

            # Accumulate on-device to avoid a host sync every batch.
            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loader:
                inputs = inputs.to(target_device)
                labels = labels.long().to(target_device)

                # Only record the autograd graph during the training pass.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                _, preds = torch.max(outputs, 1)
                running_loss += loss.detach() * inputs.size(0)
                running_corrects += torch.sum(preds == labels)

            # Normalise by dataset size; max(..., 1) keeps an empty loader from
            # dividing by zero (its loss/accuracy are then reported as 0).
            n_samples = max(len(loader.dataset), 1)
            epoch_loss = float(running_loss) / n_samples
            epoch_acc = float(running_corrects) / n_samples

            if phase == 'validation' and epoch_acc > best_acc:
                print('saving best model...')
                save_dir = os.path.dirname(save_path)
                if save_dir:
                    os.makedirs(save_dir, exist_ok=True)
                torch.save(model.state_dict(), save_path)
                best_acc = epoch_acc

            print('{} loss: {:.4f}, acc: {:.4f}'.format(phase,
                                                        epoch_loss,
                                                        epoch_acc))
    return model, best_acc

In [30]:
model_trained, accuracy = train_model(model, criterion, optimizer, num_epochs=100)
# This is the best accuracy on the validation split, which was also used to
# pick the checkpoint, so it is an optimistic estimate rather than a test score.
print('\nBest validation accuracy: %f' % accuracy)

Epoch 1/100
----------
train loss: 1.2775, acc: 0.8257
saving best model...
validation loss: 1.0940, acc: 0.8733
Epoch 2/100
----------
train loss: 0.9889, acc: 0.9345
saving best model...
validation loss: 1.0347, acc: 0.8900
Epoch 3/100
----------
train loss: 0.9414, acc: 0.9439
saving best model...
validation loss: 1.0241, acc: 0.9000
Epoch 4/100
----------
train loss: 0.9141, acc: 0.9586
validation loss: 0.9960, acc: 0.8967
Epoch 5/100
----------
train loss: 0.8918, acc: 0.9646
validation loss: 1.0006, acc: 0.9000
Epoch 6/100
----------
train loss: 0.8881, acc: 0.9599
saving best model...
validation loss: 0.9725, acc: 0.9067
Epoch 7/100
----------
train loss: 0.8886, acc: 0.9613
saving best model...
validation loss: 0.9720, acc: 0.9100
Epoch 8/100
----------
train loss: 0.8761, acc: 0.9666
validation loss: 0.9837, acc: 0.9033
Epoch 9/100
----------
train loss: 0.8673, acc: 0.9746
saving best model...
validation loss: 0.9721, acc: 0.9167
Epoch 10/100
----------
train loss: 0.8640, ac

train loss: 0.8113, acc: 0.9940
validation loss: 0.9570, acc: 0.9233
Epoch 88/100
----------
train loss: 0.8151, acc: 0.9893
validation loss: 0.9462, acc: 0.9133
Epoch 89/100
----------
train loss: 0.8124, acc: 0.9947
validation loss: 0.9426, acc: 0.9233
Epoch 90/100
----------
train loss: 0.8121, acc: 0.9940
validation loss: 0.9616, acc: 0.9133
Epoch 91/100
----------
train loss: 0.8115, acc: 0.9967
validation loss: 0.9486, acc: 0.9167
Epoch 92/100
----------
train loss: 0.8139, acc: 0.9920
validation loss: 0.9814, acc: 0.9067
Epoch 93/100
----------
train loss: 0.8148, acc: 0.9933
validation loss: 0.9803, acc: 0.9033
Epoch 94/100
----------
train loss: 0.8188, acc: 0.9940
validation loss: 0.9575, acc: 0.9133
Epoch 95/100
----------
train loss: 0.8130, acc: 0.9940
validation loss: 0.9498, acc: 0.9233
Epoch 96/100
----------
train loss: 0.8124, acc: 0.9960
validation loss: 0.9635, acc: 0.9167
Epoch 97/100
----------
train loss: 0.8116, acc: 0.9953
validation loss: 0.9540, acc: 0.9167
E

In [31]:
# NOTE(review): this file is not the one train_model writes
# ('models/pytorch/weights.h5'), so the weights exported below come from a
# separately kept checkpoint, not necessarily from the run above.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
# torch.load unpickles the file: only load trusted checkpoints (consider
# weights_only=True on PyTorch versions that support it).
model.load_state_dict(torch.load('models/pytorch/hsig_9333_acc.h5', map_location=device))

<All keys matched successfully>

In [32]:
# Summarise every registered parameter tensor: name, Python type, shape and
# whether it receives gradients.
for name, param in list(model.named_parameters()):
    print('name: ', name)
    print(type(param))
    print('param.shape: ', param.shape)
    print('param.requires_grad: ', param.requires_grad)
    print('=====')

name:  fc0.weight
<class 'torch.nn.parameter.Parameter'>
param.shape:  torch.Size([50, 64])
param.requires_grad:  True
=====
name:  fc0.bias
<class 'torch.nn.parameter.Parameter'>
param.shape:  torch.Size([50])
param.requires_grad:  True
=====
name:  fc1.weight
<class 'torch.nn.parameter.Parameter'>
param.shape:  torch.Size([10, 50])
param.requires_grad:  True
=====
name:  fc1.bias
<class 'torch.nn.parameter.Parameter'>
param.shape:  torch.Size([10])
param.requires_grad:  True
=====


In [37]:
# fc0 is Linear(64, 50), whose weight is stored as (out, in) = (50, 64);
# transpose to (in, out) = (64, 50) before saving.
W1 = model.fc0.weight.detach().cpu().numpy().T
print(W1.shape)
np.save('w1.npy', W1)

(64, 50)


In [38]:
# fc1 is Linear(50, 10), whose weight is stored as (out, in) = (10, 50);
# transpose to (in, out) = (50, 10) before saving.
W2 = model.fc1.weight.detach().cpu().numpy().T
print(W2.shape)
np.save('w2.npy', W2)

(50, 10)


In [39]:
# Hidden-layer bias of fc0: one value per hidden unit (50).
bh = model.fc0.bias.detach().cpu().numpy()
print(bh.shape)
np.save('bh.npy', bh)

(50,)


In [40]:
# Output-layer bias of fc1: one value per class (10).
by = model.fc1.bias.detach().cpu().numpy()
print(by.shape)
np.save('by.npy', by)

(10,)
