In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.backends.cudnn as cudnn
from tensorboardX import SummaryWriter
import model.simplenet as simplenet
import matplotlib.pyplot as plt
import torch.utils.data as Data
import imageio
import numpy as np
import random
import pandas as pd

def setup_seed(seed):
    """Seed every random number generator used here and pin cuDNN behaviour.

    Seeds PyTorch (CPU, all CUDA devices and the current CUDA device), NumPy
    and Python's ``random`` module with *seed*.  cuDNN is then put in
    deterministic mode, disabled outright, and its benchmark autotuner is
    switched off.  Enabling ``cudnn.benchmark`` instead would trade
    reproducibility for speed.
    """
    for seeder in (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    ):
        seeder(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.enabled = False
    cudnn_backend.benchmark = False

def savefig(name):
    """Save the current matplotlib figure to *name* (600 dpi, tight bounding box)."""
    plt.savefig(name, bbox_inches='tight', dpi=600)

def change_weight(weight, net=None):
    '''
    Replace the ``classifier.0.weight_a`` entry of a quadratic network.

    Parameters
    ----------
    weight : torch.Tensor
        New value for ``classifier.0.weight_a``.  ``load_state_dict`` copies
        it into the existing parameter, so it must match that parameter's
        shape.
    net : torch.nn.Module, optional
        Network to update.  When omitted, the module-level ``model`` is used
        (the original behaviour).  That global is looked up at call time, so
        it only needs to exist when the function is called, not when it is
        defined.

    Returns
    -------
    None
    '''
    if net is None:
        net = model  # module-level global, resolved at call time
    model_dict = net.state_dict()
    model_dict['classifier.0.weight_a'] = weight
    net.load_state_dict(model_dict)


from sklearn.datasets import load_digits


# Load the scikit-learn handwritten-digits dataset once (it was previously
# loaded twice) and split it into a feature matrix and a label tensor.
digits_bunch = load_digits()
digits = digits_bunch['data']
target = torch.from_numpy(digits_bunch['target'])

# Scale each sample (row) by its own maximum, vectorised over all rows.
# A row whose maximum is 0 is left as zeros instead of becoming NaN (0/0).
row_max = digits.max(axis=1, keepdims=True)
digits = digits / np.where(row_max == 0, 1, row_max)

device = torch.device('cpu')
writer = SummaryWriter()

# NOTE(review): training and "validation" use the same samples, so the
# reported test accuracy is really accuracy on the training set.
train_dataset = torch.from_numpy(digits)
valid_dataset = torch.from_numpy(digits)

n_features = digits.shape[1]

# Mark each input feature 1 with probability 0.75, else 0.
# setup_seed() is not called in this cell, so the draw differs between runs.
EI_distribution = torch.bernoulli(torch.ones(n_features)*0.75)
print('number of excitatory neuron:{}'.format(EI_distribution.sum()))
# kappa_matrix[i, j] is -1 when features i and j are both marked 1 in
# EI_distribution, and +1 otherwise.  The training loop zeroes weight entries
# whose product with it is negative.
kappa_matrix = -torch.ones(n_features, n_features)
kappa_matrix[EI_distribution==0,:] = 1
kappa_matrix[:,EI_distribution==0] = 1

# Each loader yields the whole dataset as a single batch, in original order.
# The training loop compares every batch with the full `target` tensor, which
# only lines up under this configuration.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=len(train_dataset), shuffle=False)
valid_loader = torch.utils.data.DataLoader(dataset=valid_dataset, batch_size=len(valid_dataset))


# Quadratic network built with the feature layout sampled in EI_distribution.
model = simplenet.SimpleNet_0(num_eigens=5,EI_distribution=EI_distribution)

print(model)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Histories appended to by the training loop (image_list is not used in this cell).
test_accuracy = []
train_accuracy = [0]  # starts with a placeholder 0 entry, kept as before
all_train_loss = []
image_list = []
# Running best validation accuracy and lowest validation loss.  min_loss
# starts at +inf so the final report is always an observed value.  The old
# sentinel of 1 was printed verbatim whenever every loss stayed above 1.
best_prec = 0
min_loss = float('inf')

NUM_EPOCHS = 10000

for epoch in range(NUM_EPOCHS):

    # Re-applied every epoch.  It is constant for now, but this is where a
    # learning-rate schedule would go.
    for param_group in optimizer.param_groups:
        param_group['lr'] = 0.01

    # ---- train for one epoch ----
    model.train()
    for i, batch in enumerate(train_loader):
        # Predictions are compared with the full `target` tensor.  That is only
        # correct while each batch is the whole dataset in original order
        # (batch_size == dataset size, shuffle=False).
        output = model(batch.to(torch.float32))
        loss = criterion(output, target)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Sign constraint: zero every entry of weight_a whose product with
        # kappa_matrix is negative.  weight_a is indexed as [out, i, j], and
        # kappa_matrix (i x j) is broadcast over the leading output axis.  This
        # replaces a per-output loop hard-coded to range(10), which had to
        # match out_features.
        # NOTE(review): state_dict() tensors normally share storage with the
        # parameters, so the in-place zeroing likely already updates the model;
        # change_weight() then reloads it explicitly.
        weight_a = model.state_dict()['classifier.0.weight_a']
        new_weight = weight_a.data
        new_weight[(new_weight * kappa_matrix) < 0] = 0
        change_weight(new_weight)

        _, predicted = torch.max(output.data, 1)
        prec = (predicted == target).sum().item() / target.size(0)

        train_accuracy.append(prec)
        # Store a plain float.  Appending the loss tensor itself kept a
        # reference to its autograd history for every step.
        all_train_loss.append(loss.item())

        print('Epoch [{}/{}], Step [{}/{}], Loss: {:.5f}, Train_Acc:{:.2f}%'.format(epoch+1, NUM_EPOCHS, i, len(train_loader), loss.item(), prec*100))

    # ---- evaluate once per epoch, after all training steps ----
    # This previously ran inside the batch loop, which left the model in eval
    # mode for any later training batches of the same epoch.
    model.eval()
    with torch.no_grad():
        for j, batch in enumerate(valid_loader):
            output = model(batch.to(torch.float32))

            _, predicted = torch.max(output.data, 1)
            valid_total = output.shape[0]
            valid_correct = (predicted == target).sum().item()
            loss = criterion(output, target)

            prec = valid_correct / valid_total
            print('Accuary on test images:{:.2f}%, loss:{:.2f}'.format(prec*100, loss.item()))
            test_accuracy.append(prec)
            best_prec = max(prec, best_prec)
            min_loss = min(min_loss, loss.item())

print('Best accuracy is: {:.2f}%, Minimum loss is: {:.4f}'.format(best_prec*100, min_loss))

number of excitatory neuron:48.0
SimpleNet_0(
  (classifier): Sequential(
    (0): Dales_General_quadratic(in_features=64, out_features=10, bias=False)
  )
)
Epoch [1/10000], Step [0/1], Loss: 2.44621, Train_Acc:12.85%
Accuary on test images:13.58%, loss:2.39
Epoch [2/10000], Step [0/1], Loss: 2.39482, Train_Acc:13.58%
Accuary on test images:14.91%, loss:2.35
Epoch [3/10000], Step [0/1], Loss: 2.34586, Train_Acc:14.91%
Accuary on test images:16.69%, loss:2.30
Epoch [4/10000], Step [0/1], Loss: 2.29905, Train_Acc:16.69%
Accuary on test images:18.48%, loss:2.25
Epoch [5/10000], Step [0/1], Loss: 2.25403, Train_Acc:18.48%
Accuary on test images:20.98%, loss:2.21
Epoch [6/10000], Step [0/1], Loss: 2.21067, Train_Acc:20.98%
Accuary on test images:23.82%, loss:2.17
Epoch [7/10000], Step [0/1], Loss: 2.16881, Train_Acc:23.82%
Accuary on test images:27.66%, loss:2.13
Epoch [8/10000], Step [0/1], Loss: 2.12826, Train_Acc:27.66%
Accuary on test images:30.94%, loss:2.09
Epoch [9/10000], Step [0/1