<a href="https://colab.research.google.com/github/hirokame/BPBG/blob/main/BackProp_withBG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import cupy as cp
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm

transform = transforms.ToTensor()

# download MNIST dataset (both splits are stored under ../data)
train_data = datasets.MNIST('../data', train=True, download=True, transform=transform)
test_data = datasets.MNIST('../data', train=False, download=True, transform=transform)

# set DataLoader
batch_size = 4
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Evaluation does not need shuffling. The BG network below also divides each
# layer's activity by a mean taken over the whole batch, so a fixed batch
# composition keeps its test metrics reproducible from epoch to epoch.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

In [None]:
'''Set the layer number and size'''
input_size = 784
hidden_size = [100, 50]
output_size = 10


'''Set the device'''
# All weights are placed on this device. Fall back to the CPU when no GPU is
# available so the notebook can still run (the name is kept for compatibility).
cuda0 = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


'''Initialize the weight'''
# Feed-forward weights, drawn uniformly from [-1, 1).
w12 = (torch.rand(input_size, hidden_size[0]) * 2 - 1).to(cuda0)
w23 = (torch.rand(hidden_size[0], hidden_size[1]) * 2 - 1).to(cuda0)
w34 = (torch.rand(hidden_size[1], output_size) * 2 - 1).to(cuda0)


'''Lateral inhibition inside the layer'''
lateral_inhibition = False

# Within-layer inhibitory weights: entries in (-1/n, 0], with a zero diagonal
# so that no unit inhibits itself.
w11 = -torch.div(torch.rand(input_size, input_size), input_size).to(cuda0)
w11.fill_diagonal_(0)
w22 = -torch.div(torch.rand(hidden_size[0], hidden_size[0]), hidden_size[0]).to(cuda0)
w22.fill_diagonal_(0)
w33 = -torch.div(torch.rand(hidden_size[1], hidden_size[1]), hidden_size[1]).to(cuda0)
w33.fill_diagonal_(0)


'''Call Torch.nn functions'''
identity = nn.Identity()
# act = nn.Sigmoid()
act = nn.ReLU()
# Explicit class dimension: activations are (batch, classes). Omitting `dim`
# relies on a deprecated implicit choice and emits a warning.
softmax = nn.Softmax(dim=1)


# ---- Hyper-parameters ----
# Base Hebbian learning rates for w12, w23 and w34 (shrunk by the annealing
# schedule in the training loop).
eta1 = eta2 = eta3 = 3.0e-2
# Leak coefficients subtracted from w12 / w23 on every update (0 = no leak).
alpha1 = alpha2 = 0
# Gain applied to the error signal fed back through the BG loop.
beta = 1.0
# Number of training epochs.
epoch = 200
# Small constant added to the normalising mean to avoid division by zero.
eps = 1.0e-7

# ---- Training and test log ----
history = dict(
    train_loss=[],
    test_loss=[],
    test_acc=[],
    pred_hist=np.empty((0, 10)),  # one row of predicted-digit counts per epoch
)

# State for the loss-driven learning-rate schedule.
latest_loss = 2.0
loss_list = []
def _bg_forward(data):
    """Run the forward pass of the BG network on one batch of images.

    Uses the module-level weights ``w12``/``w23``/``w34`` (plus ``w11``/``w22``/
    ``w33`` when ``lateral_inhibition`` is on), and ``act``, ``identity``,
    ``softmax`` and ``eps``. Each hidden/output layer's activity is divided by
    the mean over ALL of that layer's activities in the batch, so the result
    for one sample depends on the other samples in the same batch.

    Returns ``(x1_out, x2_out, x3_out, x4_out, output)``, where ``output`` is
    ``softmax(x4_out)``.
    """
    x1_in = torch.flatten(data, 1)  # (batch, 1, 28, 28) -> (batch, 784)
    if lateral_inhibition:
        x1_in = act(x1_in + torch.matmul(x1_in, w11))
    x1_out = identity(x1_in)

    x2_in = act(torch.matmul(x1_out, w12))  # (batch, 784) -> (batch, 100)
    if lateral_inhibition:
        x2_in = act(x2_in + torch.matmul(x2_in, w22))
    x2_in = torch.div(x2_in, torch.mean(x2_in) + eps)  # normalization
    x2_out = identity(x2_in)

    x3_in = act(torch.matmul(x2_out, w23))  # (batch, 100) -> (batch, 50)
    if lateral_inhibition:
        x3_in = act(x3_in + torch.matmul(x3_in, w33))
    x3_in = torch.div(x3_in, torch.mean(x3_in) + eps)  # normalization
    x3_out = identity(x3_in)

    x4_in = act(torch.matmul(x3_out, w34))  # (batch, 50) -> (batch, 10)
    x4_in = torch.div(x4_in, torch.mean(x4_in) + eps)  # normalization
    x4_out = identity(x4_in)

    return x1_out, x2_out, x3_out, x4_out, softmax(x4_out)


def _similarity(pre, post):
    """Return sim[b, i, j] = 1 / (|post[b, j] - pre[b, i]| + 1e-7).

    ``pre`` is (batch, n_pre) and ``post`` is (batch, n_post); the result is
    (batch, n_pre, n_post). Broadcasting makes this work for any batch size,
    including a short final batch.
    """
    return torch.reciprocal(torch.abs(post.unsqueeze(1) - pre.unsqueeze(2)) + 1.0e-7)


device = w12.device  # move each batch to wherever the weights live

for e in range(epoch):
    print('epoch {} start'.format(e+1))
    epoch_abs_err = 0.0  # sum over the epoch of sum_k |target_oh - output|
    n_train = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)
        n = data.shape[0]  # actual batch size (the last batch may be shorter)

        # ---- Training part: forward propagation ----
        x1_out, x2_out, x3_out, x4_out, output = _bg_forward(data)

        # label -> one-hot; error signal = target - prediction
        target_oh = F.one_hot(target, num_classes=output_size)
        loss = target_oh - output

        # sim34[b, i, j]: similarity between layer-3 node i and layer-4 node j
        # sim24[b, i, j]: similarity between layer-2 node i and layer-4 node j
        sim34 = _similarity(x3_out, x4_out)  # (batch, 50, 10)
        sim24 = _similarity(x2_out, x4_out)  # (batch, 100, 10)

        # Error routed back through the BG loop, weighted by similarity
        # (gradient-like value: delta * similarity).
        # Disabled alternative: add these terms onto the forward activity,
        # e.g. act(x4_in + beta*loss), instead of replacing it.
        bg4 = beta*loss  # (batch, 10)
        bg3 = beta*torch.einsum('bn,bnm->bm', loss, sim34.transpose(1, 2))  # (batch,10)@(batch,10,50) -> (batch,50)
        bg2 = beta*torch.einsum('bn,bnm->bm', loss, sim24.transpose(1, 2))  # (batch,10)@(batch,10,100) -> (batch,100)

        # Dopamine modulation: raise the learning rate when the current loss
        # is large relative to latest_loss (never below the base rate).
        batch_abs_err = torch.sum(torch.abs(loss))
        current_loss = batch_abs_err / n
        boost = (current_loss/latest_loss)**4
        eta_dop1 = max(eta1, eta1*boost)
        eta_dop2 = max(eta2, eta2*boost)
        eta_dop3 = max(eta3, eta3*boost)

        # Hebbian plasticity (with leaky dynamics): outer product of pre- and
        # post-synaptic signals, averaged over the batch, scaled by the rate.
        w12 += -alpha1*w12 + eta_dop1*torch.mean(torch.einsum('bn,bm->bnm', x1_out, bg2), 0)
        w23 += -alpha2*w23 + eta_dop2*torch.mean(torch.einsum('bn,bm->bnm', x2_out, bg3), 0)
        w34 += eta_dop3*torch.mean(torch.einsum('bn,bm->bnm', x3_out, bg4), 0)

        epoch_abs_err += batch_abs_err.item()
        n_train += n

        # Learning rate update (simulated annealing): record the loss every
        # 2nd batch; every 100th batch, compare the running average against
        # latest_loss and shrink the base rates by 0.8 once it has halved.
        if batch_idx % 2 == 0:
            loss_list.append(current_loss.item())
            if batch_idx % 100 == 0:
                loss_ave = float(np.mean(loss_list))
                if loss_ave < 0.5*latest_loss:
                    latest_loss = loss_ave
                    eta1 *= 0.8
                    eta2 *= 0.8
                    eta3 *= 0.8
                if batch_idx % 1000 == 0:
                    print(loss_ave)
                # Reset only after averaging, so the average spans all the
                # losses recorded since the previous check.
                loss_list = []

    # Epoch train loss: mean over all training samples of sum_k |target_oh - output|.
    train_loss = epoch_abs_err / max(n_train, 1)
    history['train_loss'].append(train_loss)

    # ---- Test part ----
    test_loss = 0.0
    correct = 0
    preds = []
    for data, target in test_loader:
        data = data.to(device)
        target = target.to(device)

        *_, output = _bg_forward(data)
        target_oh = F.one_hot(target, num_classes=output_size)
        loss = target_oh - output

        test_loss += torch.sum(torch.abs(loss)).item()
        pred = output.argmax(dim=1)
        correct += (pred == target).sum().item()
        preds.append(pred.cpu())

    n_test = max(len(test_loader.dataset), 1)
    test_loss /= n_test
    acc = correct / n_test
    all_preds = torch.cat(preds) if preds else torch.empty(0, dtype=torch.long)
    histo = np.bincount(all_preds.numpy().astype('int64'), minlength=output_size)
    print('Test loss: {}, Test acc: {}'.format(test_loss, acc))

    history['test_loss'].append(test_loss)
    history['test_acc'].append(acc)
    history['pred_hist'] = np.vstack((history['pred_hist'], histo))
    print('histogram', histo)
    print('latest_loss', latest_loss, 'eta', eta1, eta2, eta3)

# ---- Learning curves of the BG network (one point per epoch) ----
epochs_axis = range(1, epoch + 1)

plt.figure()
for curve in ('train_loss', 'test_loss'):
    plt.plot(epochs_axis, history[curve], label=curve)
plt.xlabel('epoch')
plt.legend()

plt.figure()
plt.plot(epochs_axis, history['test_acc'])
plt.title('test accuracy')
plt.xlabel('epoch')
plt.show()

In [None]:
# Quick check of np.bincount: count of each value 0..8 in the sample list
# (1 is absent, so its count is 0).
np.bincount(np.asarray([0, 2, 3, 4, 5, 6, 7, 8]))

In [None]:
# Plot the recorded history. Epoch axes are derived from what was actually
# recorded (not a hard-coded 200), so this cell also works for shorter or
# interrupted runs.
plt.figure(figsize=(10, 4))
plt.title('Train/Test loss')
plt.plot(range(1, len(history['train_loss']) + 1), history['train_loss'], label='train_loss')
plt.plot(range(1, len(history['test_loss']) + 1), history['test_loss'], label='test_loss')
plt.ylabel('Loss')
plt.xlabel('epoch')
plt.legend()

plt.figure(figsize=(10, 4))
plt.plot(range(1, len(history['test_acc']) + 1), history['test_acc'], label='test_acc')
plt.title('test accuracy')
plt.xlabel('epoch')
plt.legend()

# history['pred_hist'] holds one row per epoch and one column per digit, i.e.
# shape (n_epochs, 10). Transpose it (a reshape would interleave epochs and
# digits) so that pred_hist[i] is the per-epoch prediction count for digit i.
pred_hist = history['pred_hist'].T
plt.figure(figsize=(10, 4))
for i in range(2):  # only digits 0 and 1 are shown; widen the range for more
    plt.plot(range(1, pred_hist.shape[1] + 1), pred_hist[i], label='digit {}'.format(i))
plt.title('digit histogram')
plt.xlabel('epoch')
plt.legend()


plt.show()

In [None]:
# Column means of each weight matrix: the average incoming weight of every
# post-synaptic unit (activations are computed as x @ w).
print('w12', torch.mean(w12, dim=0))
print('w23', torch.mean(w23, dim=0))
print('w34', torch.mean(w34, dim=0))

In [None]:
class MyNet(torch.nn.Module):
    """Baseline 784-100-50-10 MLP with sigmoid hidden units.

    ``forward`` takes flattened images of shape (batch, 784) and returns
    log-probabilities (log-softmax over dim 1), as expected by ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        # Registration order fc1 -> fc2 -> fc3 fixes both the parameter order
        # and the order in which the random initial weights are drawn.
        self.fc1 = torch.nn.Linear(784, 100)
        self.fc2 = torch.nn.Linear(100, 50)
        self.fc3 = torch.nn.Linear(50, 10)

    def forward(self, x):
        hidden1 = torch.sigmoid(self.fc1(x))
        hidden2 = torch.sigmoid(self.fc2(hidden1))
        logits = self.fc3(hidden2)
        return F.log_softmax(logits, dim=1)

In [None]:
if __name__ == '__main__':
    # set the training epoch
    epoch = 5

    # save the log
    # NOTE(review): this rebinds the notebook-level `epoch` and `history`,
    # overwriting the BG-network results recorded by the earlier cell.
    history = {
        'train_loss': [],
        'test_loss': [],
        'test_acc': []
    }

    # initiate the network
    net: torch.nn.Module = MyNet()

    # set the optimizer
    optimizer = torch.optim.Adam(params=net.parameters(), lr=1.0e-3)

    n_train = len(train_loader.dataset)
    n_test = len(test_loader.dataset)

    for e in range(epoch):

        #####     Training Part     #####

        net.train(True)
        running_loss = 0.0  # sum of per-sample NLL over the epoch
        seen = 0            # samples processed so far this epoch

        for i, (data, target) in enumerate(train_loader):

            data = data.view(-1, 784)

            optimizer.zero_grad()
            output = net(data)
            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * data.shape[0]
            seen += data.shape[0]

            if i % 100 == 0:
                # Progress counts real samples instead of a hard-coded 64 per batch.
                print("Training log: {} epoch ({}/{} train data) Loss: {}".format(e+1, seen, n_train, loss.item()))

        # Epoch train loss = mean NLL over every training sample (not just the last batch).
        history['train_loss'].append(running_loss / max(seen, 1))


        #####     Test Part     #####

        net.eval()
        test_loss = 0
        correct = 0

        with torch.no_grad():
            for data, target in test_loader:
                data = data.view(-1, 784)
                output = net(data)
                test_loss += F.nll_loss(output, target, reduction="sum").item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()

        test_loss /= n_test
        acc = correct / n_test

        print('Test loss (avg): {}, Accuracy: {}'.format(test_loss, acc))

        history['test_loss'].append(test_loss)
        history['test_acc'].append(acc)

    # plot figure
    plt.figure()
    plt.plot(range(1, epoch+1), history['train_loss'], label='train_loss')
    plt.plot(range(1, epoch+1), history['test_loss'], label='test_loss')
    plt.xlabel('epoch')
    plt.legend()
    plt.savefig('loss.png')

    plt.figure()
    plt.plot(range(1, epoch+1), history['test_acc'])
    plt.title('test accuracy')
    plt.xlabel('epoch')
    plt.savefig('test_acc.png')