In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Use the first CUDA device when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Mini-batch size shared by the training and the test loader.
batch_size = 8

# ToTensor yields values in [0, 1]; Normalize with mean 0.5 / std 0.5 maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# MNIST train / test splits, downloaded into ./mnist_data when not already present.
trainset = torchvision.datasets.MNIST(
    root='./mnist_data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(
    root='./mnist_data', train=False, download=True, transform=transform)

# Training batches are reshuffled every epoch; test batches keep the dataset order.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=0)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=0)

0.3%

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./mnist_data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ./mnist_data\MNIST\raw\train-images-idx3-ubyte.gz to ./mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz


100.0%
2.0%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz
Extracting ./mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz to ./mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%
100.0%


Extracting ./mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ./mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./mnist_data\MNIST\raw



In [2]:
neurons = 25  # number of 5x5 convolution kernels (output channels of conv1)
classes = 10  # number of output scores produced by the linear read-out


class Net(nn.Module):
    """Single frozen convolution layer followed by a trainable linear classifier.

    Pipeline: conv1 (1 -> ``neurons`` channels, 5x5) -> ReLU -> 2x2 max-pool
    -> flatten -> fc1 (``neurons*12*12`` -> ``classes``).

    The ``12*12`` factor matches 28x28 inputs: a 5x5 kernel without padding
    gives 24x24 maps, which the 2x2 pooling halves to 12x12.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, neurons, 5)
        # Freeze the kernel weights (and bias): only fc1 is left trainable.
        self.conv1.requires_grad_(False)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(neurons*12*12, classes)

    def forward(self, x):
        """Return the raw class scores for a batch ``x``."""
        feature_maps = F.relu(self.conv1(x))
        pooled = self.pool(feature_maps)
        # Keep the batch dimension, collapse everything else into one vector.
        flat = pooled.flatten(start_dim=1)
        return self.fc1(flat)

# Build the model and move its parameters to the selected device.
net = Net().to(device)

# Cross-entropy loss applied to the raw scores returned by the network.
criterion = nn.CrossEntropyLoss()

# SGD with momentum over everything net.parameters() yields; conv1 was frozen
# in Net.__init__ (requires_grad=False), so only fc1 accumulates gradients.
optimizer = optim.SGD(params=net.parameters(), lr=0.001, momentum=0.9)

def init_weights(m):
    """Re-initialise ``m`` in place when it is an ``nn.Linear`` layer.

    The weight matrix is drawn with Xavier/Glorot uniform initialisation and
    every bias entry is set to 0.01. Modules of any other type are left
    untouched, which makes the function suitable for ``module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    m.bias.data.fill_(0.01)

def train(epochs=10):
    """Train the global ``net`` on ``trainloader``.

    Uses the module-level ``criterion``, ``optimizer`` and ``device``. Progress
    is printed as one ``*`` per completed epoch between ``[`` and ``]``.

    Parameters
    ----------
    epochs : int, optional
        Number of full passes over ``trainloader`` (default 10).
    """
    # test() switches the network to eval mode and never switches it back, so
    # make sure we are in training mode before updating any weights.
    net.train()

    print("start training [", end="")
    for _ in range(epochs):  # loop over the dataset multiple times
        for inputs, labels in trainloader:
            # move the current mini-batch to the device the model lives on
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
        print('*', end='')
    print("] finished.")

def test():
    """Evaluate the global ``net`` on ``testloader`` and print its accuracy.

    The network is put into eval mode and gradients are disabled while the
    test set is processed. The fraction of correctly classified samples is
    printed with six decimals.

    Returns
    -------
    float
        The accuracy that was printed (correct / number of test samples).
    """
    net.eval()
    correct = 0
    with torch.no_grad():
        for images, labels in testloader:
            images = images.to(device)
            labels = labels.to(device)
            scores = net(images)
            # predicted class = index of the largest score in each row
            predicted = scores.argmax(dim=1)
            # .item() keeps the running count a plain Python int, so the final
            # division is true division regardless of tensor division rules
            correct += (predicted == labels.view_as(predicted)).sum().item()

    accuracy = correct / len(testloader.dataset)
    print('{:.6f}'.format(accuracy))
    return accuracy

In [20]:
# First pass training
import numpy as np
import pickle
import sys
# Make ./src importable; presumably needed so pickle can locate the class of
# the stored network object -- TODO confirm.
sys.path.append('./src')

# Re-initialise the linear read-out (init_weights only touches nn.Linear).
net.apply(init_weights)

work_path = './snapshots_02292024_1014/'

# Load the synapses from the trained network.
# NOTE: pickle.load can execute arbitrary code contained in the file; only
# load snapshots that you produced yourself.
with open(work_path+'net_6000.pkl', 'rb') as snapshot_file:
    fe_net = pickle.load(snapshot_file)
synapse = fe_net.synapses

# Install the synapses as the (frozen) 5x5 convolution kernels and clear the
# bias. copy_/zero_ write into the existing parameters in place.
with torch.no_grad():
    kernels = torch.from_numpy(synapse.astype(np.float32).reshape(neurons, 1, 5, 5))
    net.conv1.weight.copy_(kernels)
    net.conv1.bias.zero_()

train()
test()

# Keep the trained weights so later cells can restore this read-out.
torch.save(net.state_dict(), work_path+'ann_6000')

start training [**********] finished.
0.970100


In [21]:
# Restore the read-out trained on the final snapshot. map_location lets a
# checkpoint that was saved on a GPU be loaded on a CPU-only machine as well.
net.load_state_dict(torch.load(work_path+'ann_6000', map_location=device))

# Evaluate that fixed read-out with the kernels of every stored snapshot
# (net_100.pkl ... net_6000.pkl); test() prints one accuracy per snapshot.
for step in range(1, 61):
    # NOTE: pickle.load can execute arbitrary code contained in the file; only
    # load snapshots that you produced yourself.
    with open(work_path+f'net_{step}00.pkl', 'rb') as snapshot_file:
        fe_net = pickle.load(snapshot_file)
    synapse = fe_net.synapses
    with torch.no_grad():
        kernels = torch.from_numpy(synapse.astype(np.float32).reshape(neurons, 1, 5, 5))
        net.conv1.weight.copy_(kernels)
    test()

0.644200
0.759400
0.817700
0.790500
0.825700
0.907100
0.945600
0.957600
0.960300
0.964900
0.966700
0.966300
0.964700
0.966300
0.966400
0.967900
0.967800
0.967700
0.969100
0.968400
0.968000
0.968200
0.968500
0.969100
0.968800
0.968800
0.968900
0.968600
0.968600
0.968500
0.968900
0.968800
0.968600
0.968700
0.968800
0.969100
0.969000
0.969400
0.969300
0.969600
0.969100
0.969200
0.969600
0.969600
0.969400
0.968700
0.969700
0.969400
0.969600
0.969500
0.970400
0.969900
0.969600
0.970100
0.969700
0.969600
0.969800
0.969900
0.969500
0.970100


In [10]:
import matplotlib.pyplot as plt

def visualize(array, fname, w_min=0, w_max=1, size=None, arange=None):
    """Tile the entries of ``array`` as grayscale images and save the grid.

    The figure is written to ``fname + '.png'`` and closed afterwards.

    Parameters
    ----------
    array : array-like with ``.shape`` and ``.reshape``
        2-D ``(num, pixels)`` or 3-D ``(num, height, width)`` data; each entry
        along the first axis becomes one tile.
    fname : str
        Output path without extension; ``'.png'`` is appended.
    w_min, w_max : float, optional
        Lower / upper limit of the gray colour scale (passed as ``vmin`` /
        ``vmax`` to ``imshow``).
    size : sequence of two ints, optional
        Shape every entry is reshaped to before plotting. When omitted it is
        inferred: a near-square shape for 2-D input, ``array.shape[1:3]`` for
        3-D input.
    arange : (cols, rows), optional
        Grid layout. Defaults to a near-square grid large enough for ``num``
        tiles.

    Raises
    ------
    ValueError
        If ``size`` is omitted and ``array`` is neither 2-D nor 3-D.
    """
    num = array.shape[0]

    if arange is None:
        cols = int(np.ceil(np.sqrt(num)))
        rows = (num - 1) // cols + 1
    else:
        cols, rows = arange

    if size is None:
        ndim = len(array.shape)
        if ndim == 2:
            pixels = array.shape[1]
            y = int(np.ceil(np.sqrt(pixels)))
            x = (pixels - 1) // y + 1
            size = [x, y]
        elif ndim == 3:
            size = array.shape[1:3]
        else:
            raise ValueError(
                'cannot infer the tile size of a {}-D array; pass size explicitly'.format(ndim))

    # one tile edge of 10 pixels corresponds to one inch of figure
    scale = np.max(size)/10

    # Create the figure directly. Calling plt.clf() first would implicitly
    # create (and leave open) an empty default-sized figure.
    fig = plt.figure(figsize=(scale*cols, scale*rows))
    try:
        fig.subplots_adjust(bottom=.01, left=.01, right=.99, top=.99)

        for i in range(num):
            data = array[i].reshape(size)

            axs = fig.add_subplot(rows, cols, i+1)
            axs.set_xticks([])
            axs.set_yticks([])
            axs.imshow(data, vmin=w_min, vmax=w_max, cmap='gray')

        fig.savefig(fname+'.png')
    finally:
        # close exactly this figure, even when plotting or saving failed
        plt.close(fig)


In [12]:
# Tile the entries of `synapse` into a grid image and save it as 1.png.
visualize(synapse, fname='1')

<Figure size 640x480 with 0 Axes>