<a href="https://colab.research.google.com/github/josemoti1999/sysdl_project/blob/master/GA_Spiking_Conv.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os
import matplotlib.pyplot as plt
import torchvision.datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import copy
from torch.autograd import Variable

In [0]:
# Genetic-algorithm hyperparameters consumed by train():
POPULATION_SIZE = 100   # mutated children generated per batch (the parent is kept alongside them)
MUTATION_POWER = 2e-2   # standard deviation of the Gaussian noise added to every parameter

In [0]:
class GeneticAlgorithm():
    """One generation of a mutation-only genetic algorithm over a PyTorch model.

    The population is the current model plus ``population_size`` mutated copies.
    Each individual is scored by negative log-likelihood on a single batch, and
    the lowest-loss individual is returned. The parent is part of the population,
    so the returned loss is never higher than the parent's loss on this batch.
    """

    def __init__(self, device, model, target, data):
        # device: where candidates are evaluated. The model itself is kept on the
        # CPU; candidates are moved to ``device`` only while they are scored.
        self.device = device
        self.model = model.to("cpu")
        self.target = target.to(self.device)
        self.data = data.to(self.device)

    def mutate(self, population_size=20, mutation_power=0.02):
        """Return ``[parent]`` followed by ``population_size`` mutated children.

        Each child is a deep copy of the parent with i.i.d. Gaussian noise of
        standard deviation ``mutation_power`` added to every parameter. All
        returned models live on the CPU.
        """
        models_list = [self.model.to("cpu")]
        # Parameters are leaf tensors that require grad; updating them in place is
        # only legal with autograd disabled, so do not rely on the caller for it.
        with torch.no_grad():
            for _ in range(population_size):
                child = copy.deepcopy(self.model).to(self.device)
                for param in child.parameters():
                    param += mutation_power * torch.randn_like(param)
                models_list.append(child.to("cpu"))
        return models_list

    def find_best_model(self, population_size=20, mutation_power=0.02):
        """Mutate, score every candidate on the stored batch and return the best one.

        Returns:
            (model, loss): the lowest-loss candidate (moved back to the CPU) and
            its NLL loss as a 0-dim tensor computed on ``self.device``. Ties go to
            the earliest candidate, so the parent wins a tie.
        """
        candidates = self.mutate(population_size, mutation_power)
        best_model, best_loss = None, None
        with torch.no_grad():
            for candidate in candidates:
                candidate.to(self.device)
                loss = F.nll_loss(candidate(self.data), self.target)
                # Move every candidate back after scoring so their parameters do
                # not accumulate on the device.
                candidate.to("cpu")
                if best_loss is None or loss < best_loss:
                    best_model, best_loss = candidate, loss
        return best_model, best_loss

In [0]:
def train(model, device, train_set_loader, epoch, logging_interval=100):
    """Run one epoch of genetic-algorithm training and log the result of each batch.

    For every batch, one GeneticAlgorithm generation (POPULATION_SIZE mutated
    children plus the current model) is scored on that batch. The lowest-loss
    individual becomes the model for the next batch. No gradients are used.

    Args:
        model: the current best individual; assumed to return log-probabilities.
        device: the device on which candidates and batches are evaluated.
        train_set_loader: an iterable of (data, target) batches with a ``.dataset``.
        epoch: the epoch index, used only for logging.
        logging_interval: accepted for API compatibility but currently unused;
            every batch is logged.

    Returns:
        (model, accuracy): the final best individual and its accuracy in percent
        on the last batch of the epoch (0.0 if the loader yields no batches).
    """
    model.eval()
    correct = 0.0  # stays defined even if the loader is empty
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(train_set_loader):
            ga = GeneticAlgorithm(device, model, target, data)
            model, loss = ga.find_best_model(POPULATION_SIZE, MUTATION_POWER)
            model.to(device)
            output = model(data.to(device))
            pred = output.max(1, keepdim=True)[1]  # get the index of the max log-probability
            # Use the configured device instead of .cuda() so CPU-only runs work.
            correct = pred.eq(target.to(device).view_as(pred)).float().mean().item()
            print('Train Epoch: {} [{}/{} ({:.0f}%)] Loss: {:.6f} Accuracy: {:.2f}%'.format(
                epoch, (batch_idx+1) * len(data), len(train_set_loader.dataset),
                100. * (batch_idx+1) / len(train_set_loader), loss.item(),
                100. * correct))
    return model, 100. * correct

def train_many_epochs(model, tot_epochs=100):
    """Alternate GA training and test-set evaluation for ``tot_epochs`` epochs.

    Relies on the module-level ``device``, ``train_set_loader`` and
    ``test_set_loader``. Returns the best individual after the last epoch.
    """
    for epoch_idx in range(tot_epochs):
        # train() also reports last-batch accuracy; only the model is carried forward.
        model, _ = train(model, device, train_set_loader, epoch_idx, logging_interval=100)
        test(model, device, test_set_loader)
    return model

def test(model, device, test_set_loader):

    model.eval()
    test_loss = 0
    correct = 0

    with torch.no_grad():
        for data, target in test_set_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduce=True).item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_set_loader.dataset)
    print("")
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
        test_loss,
        correct, len(test_set_loader.dataset),
        100. * correct / len(test_set_loader.dataset)))
    print("")

def download_mnist(data_path):
    """Download (if needed) and load the MNIST train and test splits.

    Images are converted with ``ToTensor`` and then normalized with mean 0.5 and
    std 1.0, i.e. shifted by -0.5.

    Returns:
        (training_set, testing_set) as torchvision MNIST datasets.
    """
    # makedirs also creates missing parents and tolerates an existing directory.
    os.makedirs(data_path, exist_ok=True)
    transformation = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (1.0,))])
    training_set = torchvision.datasets.MNIST(data_path, train=True, transform=transformation, download=True)
    testing_set = torchvision.datasets.MNIST(data_path, train=False, transform=transformation, download=True)
    return training_set, testing_set

In [0]:
# Data location and batch size shared by both loaders.
batch_size = 1000
DATA_PATH = './data'

training_set, testing_set = download_mnist(DATA_PATH)

# Only the training loader shuffles; the test loader keeps the dataset order.
train_set_loader = torch.utils.data.DataLoader(dataset=training_set, batch_size=batch_size, shuffle=True)
test_set_loader = torch.utils.data.DataLoader(dataset=testing_set, batch_size=batch_size, shuffle=False)

In [0]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Spiking Conv2d RNN

In [0]:
import torch
import torch.nn as nn

class SpikingConvLayerRNN(nn.Module):
    """Convolutional layer of leaky, thresholded spiking units, advanced one step per call.

    Per call:
        inner = conv(x) + decay_multiplier * prev_inner
        spike = (inner > threshold) as float
        inner = inner - (penalty_threshold / threshold) * inner * spike

    ``forward`` returns the (state, spikes) pair from the *previous* call, a
    one-step delay; the first call after ``reset_state`` returns zeros. The
    freshly computed values are stored for the next call.
    """

    def __init__(self, in_channels, out_channels, kernel_size, decay_multiplier=0.9, threshold=2.0, penalty_threshold=2.5):
        super(SpikingConvLayerRNN, self).__init__()
        # Reads the module-level ``device`` global; this layer does not otherwise use it.
        self.device = device
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.threshold = threshold
        self.decay_multiplier = decay_multiplier
        self.penalty_threshold = penalty_threshold

        self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = kernel_size)

        self.init_parameters()
        self.reset_state()

    def init_parameters(self):
        """Xavier-uniform initialize every parameter with >= 2 dims; 1-D biases are left as nn.Conv2d set them."""
        for param in self.parameters():
            if param.dim() >= 2:
                nn.init.xavier_uniform_(param)

    def reset_state(self):
        """Forget the recurrent state; ``forward`` re-creates it lazily, shaped like the conv output."""
        self.prev_inner = None
        self.prev_outer = None

    def forward(self, x):
        """Advance one time step on input ``x``; return the previous (inner state, spikes)."""
        input_excitation = self.conv(x)

        # Identity comparison: ``tensor == None`` relies on Tensor.__eq__ fallbacks.
        if self.prev_inner is None or self.prev_outer is None:
            self.prev_inner = torch.zeros_like(input_excitation)
            self.prev_outer = torch.zeros_like(input_excitation)

        inner_excitation = input_excitation + self.prev_inner * self.decay_multiplier
        outer_excitation = (inner_excitation > self.threshold).float()

        # Units that spiked have their inner state scaled by (1 - penalty/threshold).
        decrease = (self.penalty_threshold/self.threshold * inner_excitation)
        decrease *= outer_excitation
        inner_excitation = inner_excitation - decrease

        delayed_return_state = self.prev_inner
        delayed_return_output = self.prev_outer
        self.prev_inner = inner_excitation
        self.prev_outer = outer_excitation

        return delayed_return_state, delayed_return_output



# Spiking RNN (fully connected)

In [0]:
class SpikingNeuronLayerRNN(nn.Module):
    """Fully connected layer of leaky, thresholded spiking units, advanced one step per call.

    Per call:
        inner = fc(x) + decay_multiplier * prev_inner
        spike = (inner > threshold) as float
        inner = inner - (penalty_threshold / threshold) * inner * spike

    ``forward`` returns the (state, spikes) pair from the *previous* call, a
    one-step delay; the first call after ``reset_state`` returns zeros.
    """

    def __init__(self, n_inputs=28*28, n_hidden=100, decay_multiplier=0.9, threshold=2.0, penalty_threshold=2.5):
        super(SpikingNeuronLayerRNN, self).__init__()
        # Reads the module-level ``device`` global; used to place the reset state.
        self.device = device
        self.n_inputs = n_inputs
        self.n_hidden = n_hidden
        self.decay_multiplier = decay_multiplier
        self.threshold = threshold
        self.penalty_threshold = penalty_threshold

        self.fc = nn.Linear(n_inputs, n_hidden)

        self.init_parameters()
        self.reset_state()

    def init_parameters(self):
        """Xavier-uniform initialize every parameter with >= 2 dims; 1-D biases are left as nn.Linear set them."""
        for param in self.parameters():
            if param.dim() >= 2:
                nn.init.xavier_uniform_(param)

    def reset_state(self):
        """Zero the recurrent state as un-batched ``(n_hidden,)`` vectors; ``forward`` expands them per batch."""
        self.prev_inner = torch.zeros([self.n_hidden]).to(self.device)
        self.prev_outer = torch.zeros([self.n_hidden]).to(self.device)

    def forward(self, x):
        """Advance one time step on ``x`` of shape (batch, n_inputs); return the previous (inner state, spikes)."""
        input_excitation = self.fc(x)

        if self.prev_inner.dim() == 1:
            batch_size = x.shape[0]
            # Replicate the reset state per example and follow the device/dtype of
            # the actual computation, which may differ from the global ``device``
            # if the module was moved with .to(...).
            self.prev_inner = torch.stack(batch_size * [self.prev_inner]).to(input_excitation)
            self.prev_outer = torch.stack(batch_size * [self.prev_outer]).to(input_excitation)

        inner_excitation = input_excitation + self.prev_inner * self.decay_multiplier
        outer_excitation = (inner_excitation > self.threshold).float()
        # Units that spiked have their inner state scaled by (1 - penalty/threshold).
        inner_excitation = inner_excitation - (self.penalty_threshold/self.threshold * inner_excitation) * outer_excitation

        delayed_return_state = self.prev_inner
        delayed_return_output = self.prev_outer
        self.prev_inner = inner_excitation
        self.prev_outer = outer_excitation
        return delayed_return_state, delayed_return_output

# Bridge layers

In [0]:

class InputDataToSpikingPerceptronLayer(nn.Module):
    """Parameter-free bridge that flattens every example of a batch to a 1-D vector."""

    def __init__(self, device):
        super(InputDataToSpikingPerceptronLayer, self).__init__()
        self.device = device

        self.reset_state()
        self.to(self.device)

    def reset_state(self):
        # Stateless layer: nothing to reset.
        return None

    def forward(self, x, is_2D=True):
        # ``is_2D`` is accepted but ignored; output is always (batch, -1).
        n_examples = x.size(0)
        flat = x.view(n_examples, -1)
        return flat


class OutputDataToSpikingPerceptronLayer(nn.Module):
    """Reduce per-time-step outputs over the time axis (dim 0).

    ``forward`` accepts either a tensor whose dim 0 is time, or a list/tuple of
    per-step tensors (stacked along a new dim 0). With ``average_output=True`` it
    returns the mean over time, otherwise the sum.
    """

    def __init__(self, average_output=True):
        super(OutputDataToSpikingPerceptronLayer, self).__init__()
        self.average_output = average_output

    def reducer(self, x, dim):
        """Mean or sum of ``x`` along ``dim``, depending on ``average_output``.

        A method rather than a lambda attribute, so the module can be pickled.
        """
        if self.average_output:
            return x.mean(dim=dim)
        return x.sum(dim=dim)

    def forward(self, x):
        if isinstance(x, (list, tuple)):
            x = torch.stack(x)
        return self.reducer(x, 0)

# Network

In [0]:
class SpikingConvNet(nn.Module):
    """Spiking CNN simulated for ``n_time_steps`` steps on a static 1-channel image.

    Per time step (shapes for the 1x28x28 MNIST inputs used in this notebook):
    conv1 (1->8, k=3) -> maxpool(3, stride 2) -> conv2 (8->16, k=3)
    -> maxpool(3, stride 2) -> flatten (16*4*4 = 256) -> fc1 (256->16) -> fc2 (16->10).
    The same input is fed at every step. The network output is the sum over steps
    ``begin_eval`` .. ``n_time_steps - 1`` of fc2's one-step-delayed inner state;
    ``forward`` applies log-softmax to it.
    """

    def __init__(self, n_time_steps, begin_eval):
        super(SpikingConvNet, self).__init__()
        assert 0 <= begin_eval < n_time_steps
        self.n_time_steps = n_time_steps
        self.begin_eval = begin_eval

        self.conv1 = SpikingConvLayerRNN(1, 8, 3,
                                         decay_multiplier=0.9, threshold=1, penalty_threshold=1.5)

        self.conv2 = SpikingConvLayerRNN(8, 16, 3,
                                         decay_multiplier=0.9, threshold=1, penalty_threshold=1.5)

        # 256 = 16 channels * 4 * 4 after two (conv k=3, maxpool 3/2) stages on 28x28.
        self.fc1 = SpikingNeuronLayerRNN(n_inputs=256, n_hidden=16,
                                         decay_multiplier=0.9, threshold=1, penalty_threshold=1.5)

        self.fc2 = SpikingNeuronLayerRNN(n_inputs=16, n_hidden=10,
                                         decay_multiplier=0.9, threshold=1, penalty_threshold=1.5)

        self.pool = nn.MaxPool2d(3, stride=2)

        # Sum (not average) the per-step outputs over time.
        self.output_conversion = OutputDataToSpikingPerceptronLayer(average_output=False)

    def forward_through_time(self, x):
        """Simulate the network on ``x`` (presumably (batch, 1, H, W)).

        Returns:
            out: fc2 inner states summed over time steps from ``begin_eval`` on.
            layers: ``[[states, outputs], ...]`` for conv1, conv2, fc1 and fc2.
                Only the fc1/fc2 lists are filled; the conv lists stay empty.
        """
        self.conv1.reset_state()
        self.conv2.reset_state()
        self.fc1.reset_state()
        self.fc2.reset_state()

        out = []

        all_conv1_states = []
        all_conv2_states = []
        all_fc1_states = []
        all_fc2_states = []

        all_conv1_outputs = []
        all_conv2_outputs = []
        all_fc1_outputs = []
        all_fc2_outputs = []

        for _ in range(self.n_time_steps):
            # The same static input is presented at every time step.
            conv1_state, conv1_output = self.conv1(x)

            conv2_input = self.pool(conv1_output)
            conv2_state, conv2_output = self.conv2(conv2_input)

            flat_input = self.pool(conv2_output)

            # Flatten per example using the real batch size of this input (not a
            # global), so partial batches and single-example calls work.
            flattened_stuff = flat_input.view(flat_input.size(0), -1)

            fc1_state, fc1_output = self.fc1(flattened_stuff)
            fc2_state, fc2_output = self.fc2(fc1_output)

            # Conv-layer traces are not recorded (presumably to limit memory), so
            # their lists in the returned structure remain empty.
            all_fc1_states.append(fc1_state)
            all_fc2_states.append(fc2_state)

            all_fc1_outputs.append(fc1_output)
            all_fc2_outputs.append(fc2_output)

            out.append(fc2_state)

        out = self.output_conversion(out[self.begin_eval:])

        return out,\
               [[all_conv1_states, all_conv1_outputs],\
                [all_conv2_states, all_conv2_outputs],\
                [all_fc1_states, all_fc1_outputs],\
                [all_fc2_states, all_fc2_outputs]]

    def forward(self, x):
        """Return log-softmax class scores of the time-summed fc2 state."""
        out, _ = self.forward_through_time(x)
        return F.log_softmax(out, dim = -1)

    def visualize_all_neurons(self, x):
        """
        WILL NOT WORK: forward_through_time leaves the conv-layer state/output
        lists empty, so torch.stack below raises on the first (conv1) layer.
        """
        assert x.shape[0] == 1 and len(x.shape) == 4, (
            "Pass only 1 example to SpikingNet.visualize(x) with outer dimension shape of 1.")
        _, layers_state = self.forward_through_time(x)

        for i, (all_layer_states, all_layer_outputs) in enumerate(layers_state):
            layer_state  =  torch.stack(all_layer_states).data.cpu().numpy().squeeze().transpose()
            layer_output = torch.stack(all_layer_outputs).data.cpu().numpy().squeeze().transpose()

            self.plot_layer(layer_state, title="Inner state values of neurons for layer {}".format(i))
            self.plot_layer(layer_output, title="Output spikes (activation) values of neurons for layer {}".format(i))

    def visualize_neuron(self, x, layer_idx, neuron_idx):
        """
        Plot one neuron's inner state and spikes over time for a single example.

        Only layer_idx 2 (fc1) and 3 (fc2) have recorded traces; for 0 and 1 the
        lists are empty and torch.stack raises.
        """
        assert x.shape[0] == 1 and len(x.shape) == 4, (
            "Pass only 1 example to SpikingNet.visualize(x) with outer dimension shape of 1.")
        _, layers_state = self.forward_through_time(x)

        all_layer_states, all_layer_outputs = layers_state[layer_idx]
        layer_state  =  torch.stack(all_layer_states).data.cpu().numpy().squeeze().transpose()
        layer_output = torch.stack(all_layer_outputs).data.cpu().numpy().squeeze().transpose()

        self.plot_neuron(layer_state[neuron_idx], title="Inner state values neuron {} of layer {}".format(neuron_idx, layer_idx))
        self.plot_neuron(layer_output[neuron_idx], title="Output spikes (activation) values of neuron {} of layer {}".format(neuron_idx, layer_idx))

    def plot_layer(self, layer_values, title):
        """Show a (neurons x time) array as a heat map."""
        width = max(16, layer_values.shape[0] / 8)
        height = max(4, layer_values.shape[1] / 8)
        plt.figure(figsize=(width, height))
        plt.imshow(
            layer_values,
            interpolation="nearest",
            cmap=plt.cm.rainbow
        )
        plt.title(title)
        plt.colorbar()
        plt.xlabel("Time")
        plt.ylabel("Neurons of layer")
        plt.show()

    def plot_neuron(self, neuron_through_time, title):
        """Plot one neuron's values over time as a line chart."""
        width = max(16, len(neuron_through_time) / 8)
        height = 4
        plt.figure(figsize=(width, height))
        plt.title(title)
        plt.plot(neuron_through_time)
        plt.xlabel("Time")
        plt.ylabel("Neuron's activation")
        plt.show()

# Training

In [0]:
# Simulate 128 steps per example; every step contributes to the output (begin_eval=0).
spiking_model = SpikingConvNet(begin_eval=0, n_time_steps=128)

In [0]:
# 20 epochs of GA training; keep the best individual found.
spiking_model = train_many_epochs(model=spiking_model, tot_epochs=20)






Test set: Average loss: 0.0057, Accuracy: 2159/10000 (21.59%)


Test set: Average loss: 0.0056, Accuracy: 1800/10000 (18.00%)


Test set: Average loss: 0.0049, Accuracy: 1870/10000 (18.70%)


Test set: Average loss: 0.0038, Accuracy: 2086/10000 (20.86%)


Test set: Average loss: 0.0032, Accuracy: 2343/10000 (23.43%)


Test set: Average loss: 0.0034, Accuracy: 2015/10000 (20.15%)


Test set: Average loss: 0.0032, Accuracy: 2228/10000 (22.28%)


Test set: Average loss: 0.0030, Accuracy: 1944/10000 (19.44%)


Test set: Average loss: 0.0030, Accuracy: 1994/10000 (19.94%)


Test set: Average loss: 0.0028, Accuracy: 1994/10000 (19.94%)


Test set: Average loss: 0.0027, Accuracy: 2018/10000 (20.18%)

