In [4]:
!pip install -q torch==1.10.1 torchvision 

In [26]:
import os
# CUDA device ordering / visibility are set here, before torch is imported below.
# PCI_BUS_ID orders devices by bus id; only device "0" is exposed to this process.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.nn.parameter import Parameter
import torchvision
import numpy as np
from SpykeTorch import snn
from SpykeTorch import functional as sf
from SpykeTorch import visualization as vis
from SpykeTorch import utils
from torchvision import transforms

# GPU toggle. Nothing in the visible cells reads it; presumably later cells use it
# to decide whether to move data/models to CUDA (TODO confirm). Set to True to enable.
use_cuda = False

In [27]:
class FlyEye(nn.Module):
    def __init__(self):
        """Create one spiking convolution per modelled cell type plus its STDP rule.

        Every layer is an ``snn.Convolutional`` with a single output feature map
        and a 1x1 kernel. The number of input channels of each layer equals the
        number of spike-waves that ``forward`` concatenates for it, assuming the
        network input itself has one channel (the photoreceptor layers R7/R8 are
        applied directly to the input with ``in_channels=1``).

        Attributes created:
            <name>: convolution for each cell type (``R1_6`` ... ``LC17``,
                ``central_brain``).
            stdp<name>: STDP rule bound to the convolution of the same name.
            anti_stdp_central_brain / stdp_central_brain: punishing / rewarding
                rules for the central brain (no stabilizer, weights kept in
                [0.2, 0.8]).
            ctx: last recorded input spikes, potentials, output spikes and winners.
            spk_cnt1, spk_cnt2: counters initialised to 0.

        NOTE(review): ``forward`` also reads per-layer thresholds
        (``conv<name>_t``) and ``decision_map``; neither is assigned here, so
        they have to be set on the instance elsewhere — confirm intended values.
        """
        # The original called super(fly_eye, self): `fly_eye` is not defined
        # anywhere, so construction failed with NameError.
        super().__init__()

        # (layer name, input channels). Channel counts follow the length of the
        # torch.cat list feeding each layer in forward(); the original used 1
        # everywhere, which cannot accept the concatenated inputs.
        layer_fan_in = (
            # photoreceptor cells (R1_6 is bypassed in forward: the raw input is used)
            ("R1_6", 1), ("R7", 1), ("R8", 1),
            # Lamina
            ("L1", 2), ("L2", 1), ("L3", 3), ("L4", 1), ("L5", 2),
            ("C2", 2), ("C3", 4),
            # Medulla
            ("Mi1", 6), ("Mi4", 6), ("Mi9", 5), ("Mi15", 2),
            ("Tm1", 7), ("Tm2", 5), ("Tm3", 7), ("Tm4", 5),
            ("Tm6", 4), ("Tm9", 5), ("Tm20", 6), ("TmY5a", 3),
            ("T2", 5), ("T2a", 7), ("T3", 12),
            # Lobula
            ("LC4", 10), ("LC17", 10),
        )
        for layer_name, fan_in in layer_fan_in:
            setattr(self, layer_name, snn.Convolutional(fan_in, 1, 1))

        # Central brain: receives LC4 and LC17.
        self.central_brain = snn.Convolutional(2, 1, 1)

        # STDP rules, one per layer, all with the same (potentiation, depression)
        # learning rate. The Mi4 rule is now registered as `stdpMi4` (it used to
        # be stored under the misspelled name `stdpMi14`, which stdp() never reads).
        stdp_rate = (0.004, -0.003)
        for layer_name, _ in layer_fan_in:
            setattr(self, "stdp" + layer_name,
                    snn.STDP(getattr(self, layer_name), stdp_rate))

        self.anti_stdp_central_brain = snn.STDP(self.central_brain, (-0.004, 0.005), False, 0.2, 0.8)
        self.stdp_central_brain = snn.STDP(self.central_brain, (0.004, -0.003), False, 0.2, 0.8)

        # Filled by forward() so that a following learning step can reuse it.
        self.ctx = {"input_spikes":None, "potentials":None, "output_spikes":None, "winners":None}
        self.spk_cnt1 = 0
        self.spk_cnt2 = 0

        
        
    def forward(self, input, max_layer):
        # input - данные после обработки
        
        
        
        
        # R1_6, R7, R8-------------------------------------------------------------------------------------------------
        R1_6 = input

        potR7 = self.convR7(input)
        spkR7, potR7 = sf.fire(potR7, self.convR7_t, True)

        if max_layer == 'R7':
            pot = sf.pointwise_inhibition(potR7)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = input
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        potR8 = self.convR8(input)
        spkR8, potR8 = sf.fire(potR8, self.convR8_t, True)
        
        if max_layer == 'R8':
            pot = sf.pointwise_inhibition(potR8)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = input
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        
        
        
        # L1, L2, L3, L4, L5------------------------------------------------------------------------------------------
        # L1: R1_6 + R8 as input
        potL1 = self.convL1(torch.cat([R1_6,spkR8], dim=1))
        spkL1, potL1 = sf.fire(potL1, self.convL1_t, True)
        if max_layer == 'L1':
            pot = sf.pointwise_inhibition(potL1)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([R1_6,spkR8], dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
                                                 
        # L2: R1_6 as input
        potL2 = self.convL2(R1_6)
        spkL2, potL2 = sf.fire(potL2, self.convL2_t, True)
        if max_layer == 'L2':
            pot = sf.pointwise_inhibition(potL2)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = R1_6
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
                                                 
        # L3: R1_6, R7, R8 as input
        potL3 = self.convL3(torch.cat([R1_6,spkR7,spkR8], dim=1))
        spkL3, potL3 = sf.fire(potL3, self.convL3_t, True)
        if max_layer == 'L3':
            pot = sf.pointwise_inhibition(potL3)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([R1_6,spkR7,spkR8], dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
                                                 
        # L4: R1_6 as input
        potL4 = self.convL4(R1_6)
        spkL4, potL4 = sf.fire(potL4, self.convL4_t, True)
        if max_layer == 'L4':
            pot = sf.pointwise_inhibition(potL4)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = R1_6
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        # L5: R1_6 + R8 as input
        potL5 = self.convL5(torch.cat([R1_6,spkR8], dim=1))
        spkL5, potL5 = sf.fire(potL5, self.convL5_t, True)
        if max_layer == 'L5':
            pot = sf.pointwise_inhibition(potL5)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([R1_6,spkR8], dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot                                         
    

        # C2, C3------------------------------------------------------------------------------------------------------
        # C2: L1, L5(5,5) as input
        potC2 = self.convC2(torch.cat([spkL1,
                                       sf.pooling(spkL5, 5, 1, 2)],
                                      dim=1))
        spkC2, potC2 = sf.fire(potC2, self.convC2_t, True)
        if max_layer == 'C2':
            pot = sf.pointwise_inhibition(potC2)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([spkL1,
                                       sf.pooling(spkL5, 5, 1, 2)],
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        # C3: L1, L2, L3, L5(3,3) as input
        potC3 = self.convC3(torch.cat([spkL1, spkL2, spkL3,
                                       sf.pooling(spkL5, 3, 1, 1)],
                                      dim=1))
        spkC3, potC3 = sf.fire(potC3, self.convC3_t, True)
        if max_layer == 'C3':
            pot = sf.pointwise_inhibition(potC3)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([spkL1, spkL2, spkL3,
                                       sf.pooling(spkL5, 3, 1, 1)],
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        
        # Mi1, Mi4, Mi9, Mi15-----------------------------------------------------------------------------------------
        # Mi1: R8, L1, L3, C2, C3, L5(3,3) as input
        potMi1 = self.convMi1(torch.cat([spkR8, spkL1, spkL3,spkC2,spkC3,
                                       sf.pooling(spkL5, 3, 1, 1)],
                                      dim=1))
        spkMi1, potMi1 = sf.fire(potMi1, self.convMi1_t, True)
        if max_layer == 'Mi1':
            pot = sf.pointwise_inhibition(potMi1)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([spkR8, spkL1, spkL3,spkC2,spkC3,
                                       sf.pooling(spkL5, 3, 1, 1)],
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        
        # Mi4: R8, L2, L3, C2, C3, L5(3,3) as input
        potMi4 = self.convMi4(torch.cat([spkR8, spkL2, spkL3,spkC2,spkC3, 
                                       sf.pooling(spkL5, 3, 1, 1)],
                                      dim=1))
        spkMi4, potMi4 = sf.fire(potMi4, self.convMi4_t, True)
        if max_layer == 'Mi4':
            pot = sf.pointwise_inhibition(potMi4)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([spkR8, spkL2, spkL3,spkC2,spkC3, 
                                       sf.pooling(spkL5, 3, 1, 1)],
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        # Mi9: R7, R8, L2, L3(3,3), L4(3,3) as input
        potMi9 = self.convMi9(torch.cat([spkR7, spkR8, spkL2, 
                                       sf.pooling(spkL3, 3, 1, 1),
                                       sf.pooling(spkL4, 3, 1, 1)],
                                      dim=1))
        spkMi9, potMi9 = sf.fire(potMi9, self.convMi9_t, True)
        if max_layer == 'Mi9':
            pot = sf.pointwise_inhibition(potMi9)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([spkR7, spkR8, spkL2, 
                                       sf.pooling(spkL3, 3, 1, 1),
                                       sf.pooling(spkL4, 3, 1, 1)],
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        # Mi15: R8(3,3), L5(3,3) as input
        potMi15 = self.convMi15(torch.cat([ 
                                       sf.pooling(spkR8, 3, 1, 1),
                                       sf.pooling(spkL5, 3, 1, 1)],
                                      dim=1))
        spkMi15, potMi15 = sf.fire(potMi15, self.convMi15_t, True)
        if max_layer == 'Mi15':
            pot = sf.pointwise_inhibition(potMi15)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([ 
                                       sf.pooling(spkR8, 3, 1, 1),
                                       sf.pooling(spkL5, 3, 1, 1)],
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        
        # Tm1, Tm2, Tm3, Tm4, Tm6, Tm9, Tm20, TmY5a------------------------------------------------------------------
        # Tm1: L2, L5, C2, C3, Mi1, Mi4, Mi9
        potTm1 = self.convTm1(torch.cat([spkL2, spkL5, spkC2, spkC3,
                                         spkMi1, spkMi4, spkMi9]
                                      dim=1))
        spkTm1, potTm1 = sf.fire(potTm1, self.convTm1_t, True)
        if max_layer == 'Tm1':
            pot = sf.pointwise_inhibition(potTm1)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([spkL2, spkL5, spkC2, spkC3,
                                         spkMi1, spkMi4, spkMi9]
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        
        # Tm2: L2, L4(3,3), C3, Mi1, Mi9
        potTm2 = self.convTm2(torch.cat([spkL2, spkC3, spkMi1, spkMi9,
                                         sf.pooling(spkL4, 3, 1, 1)]
                                      dim=1))
        spkTm2, potTm2 = sf.fire(potTm2, self.convTm2_t, True)
        if max_layer == 'Tm2':
            pot = sf.pointwise_inhibition(potTm2)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([spkL2, spkC3, spkMi1, spkMi9,
                                         sf.pooling(spkL4, 3, 1, 1)]
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        # Tm3: L3(3,3), C2(3,3), Mi4(3,3), Mi9(3,3), L1(5,5), L5(5,5), Mi1(5,5)
        potTm3 = self.convTm3(torch.cat([
                                         sf.pooling(spkL3, 3, 1, 1),
                                         sf.pooling(spkC2, 3, 1, 1),
                                         sf.pooling(spkMi4, 3, 1, 1),
                                         sf.pooling(spkMi9, 3, 1, 1),                
                                         sf.pooling(spkL1, 5, 1, 2),
                                         sf.pooling(spkL5, 5, 1, 2),
                                         sf.pooling(spkMi1, 5, 1, 2)]
                                      dim=1))
        spkTm3, potTm3 = sf.fire(potTm3, self.convTm3_t, True)
        if max_layer == 'Tm3':
            pot = sf.pointwise_inhibition(potTm3)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([
                                         sf.pooling(spkL3, 3, 1, 1),
                                         sf.pooling(spkC2, 3, 1, 1),
                                         sf.pooling(spkMi4, 3, 1, 1),
                                         sf.pooling(spkMi9, 3, 1, 1),                
                                         sf.pooling(spkL1, 5, 1, 2),
                                         sf.pooling(spkL5, 5, 1, 2),
                                         sf.pooling(spkMi1, 5, 1, 2)]
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        
        # Tm4: L4(3,3), Mi4(3), Mi9(3), L2(5), C3(5)
        potTm4 = self.convTm4(torch.cat([
                                         sf.pooling(spkL4, 3, 1, 1),
                                         sf.pooling(spkMi4, 3, 1, 1),
                                         sf.pooling(spkMi9, 3, 1, 1),                
                                         sf.pooling(spkL2, 5, 1, 2),
                                         sf.pooling(spkC3, 5, 1, 2)]
                                      dim=1))
        spkTm4, potTm4 = sf.fire(potTm4, self.convTm4_t, True)
        if max_layer == 'Tm4':
            pot = sf.pointwise_inhibition(potTm4)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([
                                         sf.pooling(spkL4, 3, 1, 1),
                                         sf.pooling(spkMi4, 3, 1, 1),
                                         sf.pooling(spkMi9, 3, 1, 1),                
                                         sf.pooling(spkL2, 5, 1, 2),
                                         sf.pooling(spkC3, 5, 1, 2)]
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        
        # Tm6: Mi1(3), Mi15(3), L5(5), Mi9(5)
        potTm6 = self.convTm6(torch.cat([sf.pooling(spkMi1, 3, 1, 1),
                                         sf.pooling(spkMi15, 3, 1, 1),               
                                         sf.pooling(spkL5, 5, 1, 2),
                                         sf.pooling(spkMi9, 5, 1, 2)]
                                      dim=1))
        spkTm6, potTm6 = sf.fire(potTm6, self.convTm6_t, True)
        if max_layer == 'Tm6':
            pot = sf.pointwise_inhibition(potTm6)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([sf.pooling(spkMi1, 3, 1, 1),
                                         sf.pooling(spkMi15, 3, 1, 1),               
                                         sf.pooling(spkL5, 5, 1, 2),
                                         sf.pooling(spkMi9, 5, 1, 2)]
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        # Tm9: L4(3), Mi4(3), L2, C2, C3
        potTm9 = self.convTm9(torch.cat([sf.pooling(spkL4, 3, 1, 1),
                                        sf.pooling(spkMi4, 3, 1, 1),
                                        spkL2, spkC2, spkC3]
                                      dim=1))
        spkTm9, potTm9 = sf.fire(potTm9, self.convTm9_t, True)
        if max_layer == 'Tm9':
            pot = sf.pointwise_inhibition(potTm9)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([sf.pooling(spkL4, 3, 1, 1),
                                        sf.pooling(spkMi4, 3, 1, 1),
                                        spkL2, spkC2, spkC3]
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        # Tm20: Mi4(3), R7, R8, L2, C3, Mi1
        potTm20 = self.convTm20(torch.cat([sf.pooling(spkMi4, 3, 1, 1),
                                        spkR7, spkR8, spkL2, spkC3, spkMi1]
                                      dim=1))
        spkTm20, potTm20 = sf.fire(potTm20, self.convTm20_t, True)
        if max_layer == 'Tm20':
            pot = sf.pointwise_inhibition(potTm20)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([sf.pooling(spkMi4, 3, 1, 1),
                                        spkR7, spkR8, spkL2, spkC3, spkMi1]
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        
        # TmY5a: L5(3), Mi4(3), Mi9(3)
        potTmY5a = self.convTmY5a(torch.cat([sf.pooling(spkL5, 3, 1, 1),
                                            sf.pooling(spkMi4, 3, 1, 1),
                                            sf.pooling(spkMi9, 3, 1, 1)]
                                      dim=1))
        spkTmY5a, potTmY5a = sf.fire(potTmY5a, self.convTmY5a_t, True)
        if max_layer == 'TmY5a':
            pot = sf.pointwise_inhibition(potTmY5a)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([sf.pooling(spkL5, 3, 1, 1),
                                            sf.pooling(spkMi4, 3, 1, 1),
                                            sf.pooling(spkMi9, 3, 1, 1)]
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        

        # T2, T2a, T3-----------------------------------------------------------------------------------------------
        # T2: Mi1(3), Tm1(3), Tm3(3), Tm4(3), TmY5a(3)
        potT2 = self.convT2(torch.cat([sf.pooling(spkMi1, 3, 1, 1),
                                            sf.pooling(spkTm1, 3, 1, 1),
                                            sf.pooling(spkTm3, 3, 1, 1),
                                            sf.pooling(spkTm4, 3, 1, 1),
                                            sf.pooling(spkTmY5a, 3, 1, 1)]
                                      dim=1))
        spkT2, potT2 = sf.fire(potT2, self.convT2_t, True)
        if max_layer == 'T2':
            pot = sf.pointwise_inhibition(potT2)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([sf.pooling(spkMi1, 3, 1, 1),
                                            sf.pooling(spkTm1, 3, 1, 1),
                                            sf.pooling(spkTm3, 3, 1, 1),
                                            sf.pooling(spkTm4, 3, 1, 1),
                                            sf.pooling(spkTmY5a, 3, 1, 1)]
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        
        # T2a: L2, L5, C2, C3, Mi4, Tm1, Tm2
        potT2a = self.convT2a(torch.cat([spkL2, spkL5, spkC2, spkC3, spkMi4, spkTm1, spkTm2], dim=1))
        spkT2a, potT2a = sf.fire(potT2a, self.convT2a_t, True)
        if max_layer == 'T2a':
            pot = sf.pointwise_inhibition(potT2a)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([spkL2, spkL5, spkC2, spkC3, spkMi4, spkTm1, spkTm2], dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        
        # T3: L2, L4, L5, C2, C3, Mi1, Mi9, Tm1, Tm2, Tm3, Tm6, TmY5a - 3 window
        potT3 = self.convT3(torch.cat([sf.pooling(spkL2, 3, 1, 1),
                                            sf.pooling(spkL4, 3, 1, 1),
                                            sf.pooling(spkL5, 3, 1, 1),
                                            sf.pooling(spkC2, 3, 1, 1),
                                            sf.pooling(spkC3, 3, 1, 1),
                                            sf.pooling(spkMi1, 3, 1, 1),
                                            sf.pooling(spkMi9, 3, 1, 1),
                                            sf.pooling(spkTm1, 3, 1, 1),
                                            sf.pooling(spkTm2, 3, 1, 1),
                                            sf.pooling(spkTm3, 3, 1, 1),
                                            sf.pooling(spkTm6, 3, 1, 1),
                                            sf.pooling(spkTmY5a, 3, 1, 1)]
                                      dim=1))
        spkT3, potT3 = sf.fire(potT3, self.convT3_t, True)
        if max_layer == 'T3':
            pot = sf.pointwise_inhibition(potT3)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([sf.pooling(spkL2, 3, 1, 1),
                                            sf.pooling(spkL4, 3, 1, 1),
                                            sf.pooling(spkL5, 3, 1, 1),
                                            sf.pooling(spkC2, 3, 1, 1),
                                            sf.pooling(spkC3, 3, 1, 1),
                                            sf.pooling(spkMi1, 3, 1, 1),
                                            sf.pooling(spkMi9, 3, 1, 1),
                                            sf.pooling(spkTm1, 3, 1, 1),
                                            sf.pooling(spkTm2, 3, 1, 1),
                                            sf.pooling(spkTm3, 3, 1, 1),
                                            sf.pooling(spkTm6, 3, 1, 1),
                                            sf.pooling(spkTmY5a, 3, 1, 1)]
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot

        # LC4, LC17------------------------------------------------------------------------------------------------
        # LC4: Tm1, Tm2, Tm3, Tm4, Tm6, Tm9, TmY5a, T2, T2a, T3 - 5 window
        potLC4 = self.convLC4(torch.cat([sf.pooling(spkTm1, 5, 1, 2),
                                            sf.pooling(spkTm2, 5, 1, 2),
                                            sf.pooling(spkTm3, 5, 1, 2),
                                            sf.pooling(spkTm4, 5, 1, 2),
                                            sf.pooling(spkTm6, 5, 1, 2),
                                            sf.pooling(spkTm9, 5, 1, 2),
                                            sf.pooling(spkTmY5a, 5, 1, 2),
                                            sf.pooling(spkT2, 5, 1, 2),
                                            sf.pooling(spkT2a, 5, 1, 2),
                                            sf.pooling(spkT3, 5, 1, 2)]
                                      dim=1))
        spkLC4, potLC4 = sf.fire(potLC4, self.convLC4_t, True)
        if max_layer == 'LC4':
            pot = sf.pointwise_inhibition(potLC4)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([sf.pooling(spkTm1, 5, 1, 2),
                                            sf.pooling(spkTm2, 5, 1, 2),
                                            sf.pooling(spkTm3, 5, 1, 2),
                                            sf.pooling(spkTm4, 5, 1, 2),
                                            sf.pooling(spkTm6, 5, 1, 2),
                                            sf.pooling(spkTm9, 5, 1, 2),
                                            sf.pooling(spkTmY5a, 5, 1, 2),
                                            sf.pooling(spkT2, 5, 1, 2),
                                            sf.pooling(spkT2a, 5, 1, 2),
                                            sf.pooling(spkT3, 5, 1, 2)]
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot
        
        
        # LC17: Tm2, Tm3, Tm4, Tm6, Tm9, Tm20, TmY5a, T2, T2a, T3 - 3 window
        potLC17 = self.convLC17(torch.cat([sf.pooling(spkTm2, 3, 1, 1),
                                            sf.pooling(spkTm3, 3, 1, 1),
                                            sf.pooling(spkTm4, 3, 1, 1),
                                            sf.pooling(spkTm6, 3, 1, 1),
                                            sf.pooling(spkTm9, 3, 1, 1),
                                            sf.pooling(spkTm20, 3, 1, 1),
                                            sf.pooling(spkTmY5a, 3, 1, 1),
                                            sf.pooling(spkT2, 3, 1, 1),
                                            sf.pooling(spkT2a, 3, 1, 1),
                                            sf.pooling(spkT3, 3, 1, 1)]
                                      dim=1))
        spkLC17, potLC17 = sf.fire(potLC17, self.convLC17_t, True)
        if max_layer == 'LC17':
            pot = sf.pointwise_inhibition(potLC17)
            spk = pot.sign()
            winners = sf.get_k_winners(pot, spikes=spk)
            self.ctx["input_spikes"] = torch.cat([sf.pooling(spkTm2, 3, 1, 1),
                                            sf.pooling(spkTm3, 3, 1, 1),
                                            sf.pooling(spkTm4, 3, 1, 1),
                                            sf.pooling(spkTm6, 3, 1, 1),
                                            sf.pooling(spkTm9, 3, 1, 1),
                                            sf.pooling(spkTm20, 3, 1, 1),
                                            sf.pooling(spkTmY5a, 3, 1, 1),
                                            sf.pooling(spkT2, 3, 1, 1),
                                            sf.pooling(spkT2a, 3, 1, 1),
                                            sf.pooling(spkT3, 3, 1, 1)]
                                      dim=1)
            self.ctx["potentials"] = pot
            self.ctx["output_spikes"] = spk
            self.ctx["winners"] = winners
            return spk, pot

        # central brain---------------------------------------------------------------------------------------------
        potCB = self.convCB(torch.cat([spkLC4, spkLC17]
                                      dim=1))
        spkCB, potCB = sf.fire(potCB)

        
        
        #------------------------------------------------------------------------------------------------------------
        if self.training:    
            # finish training
            winners = sf.get_k_winners(potCB, 1, 0, spkCB)
            self.ctx["input_spikes"] = torch.cat([spkLC4, spkLC17]
                                          dim=1)
            self.ctx["potentials"] = potCB
            self.ctx["output_spikes"] = spkCB
            self.ctx["winners"] = winners
            output = -1
            if len(winners) != 0:
                output = self.decision_map[winners[0][0]]
            return outputs
        else:
            winners = sf.get_k_winners(potCB, 1, 0, spkCB)
            output = -1
            if len(winners) != 0:
                output = self.decision_map[winners[0][0]]
            return output
        
        
        
        
        
        
        
    # STDP for layer layer_idx
    def stdp(self, layer_idx):
        """Apply the STDP rule of one layer to the state cached in ``self.ctx``.

        The rule is the attribute named ``stdp<layer_idx>`` on this object
        (e.g. ``self.stdpR7`` for ``'R7'``). It is called with the
        ``input_spikes``, ``potentials``, ``output_spikes`` and ``winners``
        entries of ``self.ctx``.

        Args:
            layer_idx: Name of the layer to update, e.g. ``'R7'`` or ``'LC17'``.
                Any name that is not listed below is ignored silently.
        """
        # Layers that have a dedicated ``stdp<name>`` rule handled by this method.
        stdp_layers = (
            'R7', 'R8',
            'L1', 'L2', 'L3', 'L4', 'L5',
            'C2', 'C3',
            'Mi1', 'Mi4', 'Mi9', 'Mi15',
            'Tm1', 'Tm2', 'Tm3', 'Tm4', 'Tm6', 'Tm9', 'Tm20', 'TmY5a',
            'T2', 'T2a', 'T3',
            'LC4', 'LC17',
        )
        if layer_idx not in stdp_layers:
            return

        # Resolve the rule by name instead of one hand-written branch per layer;
        # this removes the copy-paste surface where the 'L3' branch had a typo.
        learning_rule = getattr(self, f"stdp{layer_idx}")
        ctx = self.ctx
        learning_rule(ctx["input_spikes"], ctx["potentials"], ctx["output_spikes"], ctx["winners"])
       

    # learning rates updating
    def update_learning_rates(self, stdp_ap, stdp_an, anti_stdp_ap, anti_stdp_an):
        """Push new learning rates into both central-brain learning rules.

        Args:
            stdp_ap: First value forwarded to ``stdp_central_brain``.
            stdp_an: Second value forwarded to ``stdp_central_brain``.
            anti_stdp_ap: Forwarded to ``anti_stdp_central_brain`` as its
                second argument.
            anti_stdp_an: Forwarded to ``anti_stdp_central_brain`` as its
                first argument.
        """
        reward_rule = self.stdp_central_brain
        reward_rule.update_all_learning_rate(stdp_ap, stdp_an)

        # The anti-STDP rule takes its pair in swapped order: (an, ap).
        punish_rule = self.anti_stdp_central_brain
        punish_rule.update_all_learning_rate(anti_stdp_an, anti_stdp_ap)

    # reward signal for ultimate layer
    def reward(self):
        """Call ``stdp_central_brain`` on the state cached in ``self.ctx``.

        The rule receives the ``input_spikes``, ``potentials``,
        ``output_spikes`` and ``winners`` entries, in that order.
        """
        rule = self.stdp_central_brain
        cached = self.ctx
        rule(cached["input_spikes"], cached["potentials"], cached["output_spikes"], cached["winners"])
    
    # punishment signal for ultimate layer
    def punish(self):
        """Call ``anti_stdp_central_brain`` on the state cached in ``self.ctx``.

        The rule receives the ``input_spikes``, ``potentials``,
        ``output_spikes`` and ``winners`` entries, in that order.
        """
        rule = self.anti_stdp_central_brain
        cached = self.ctx
        rule(cached["input_spikes"], cached["potentials"], cached["output_spikes"], cached["winners"])

In [28]:
# training process for the layer layer_idx (R1_6, R7, R8, L1, L2, L3, .....)
def train_unsupervise(network, data, layer_idx):
    """Run one unsupervised pass over ``data`` for the layer ``layer_idx``.

    Switches the network to training mode, then for each sample calls the
    network with ``(sample, layer_idx)`` followed by ``network.stdp(layer_idx)``.

    Args:
        network: Callable model exposing ``train()`` and ``stdp(layer_idx)``.
        data: Indexable collection of samples.
        layer_idx: Layer name (R1_6, R7, R8, L1, L2, L3, ...) passed through
            to the network call and to ``stdp``.
    """
    network.train()
    for sample_no in range(len(data)):
        sample = data[sample_no]
        # Move the sample to the GPU only when the notebook-level flag asks for it.
        sample = sample.cuda() if use_cuda else sample
        network(sample, layer_idx)
        network.stdp(layer_idx)

In [29]:
def train_rl(network, data, target):
    """Run one reward/punish training pass over ``data``.

    For each sample the network is called up to ``'central_brain'``. A
    decision of ``-1`` counts as silence; a decision equal to the target
    triggers ``network.reward()``; any other decision triggers
    ``network.punish()``.

    Args:
        network: Callable model exposing ``train()``, ``reward()`` and ``punish()``.
        data: Indexable collection of samples.
        target: Indexable collection of labels, aligned with ``data``.

    Returns:
        Array ``[correct, wrong, silence]`` divided by ``len(data)``.
    """
    CORRECT, WRONG, SILENCE = 0, 1, 2
    network.train()
    tally = np.array([0,0,0]) # correct, wrong, silence
    for idx in range(len(data)):
        sample = data[idx]
        label = target[idx]
        if use_cuda:
            sample, label = sample.cuda(), label.cuda()
        # Forward pass all the way up to the last layer
        decision = network(sample, 'central_brain')
        if decision != -1:
            outcome = CORRECT if decision == label else WRONG
            tally[outcome] += 1
            if outcome == CORRECT:
                network.reward()
            else:
                network.punish()
        else:
            tally[SILENCE] += 1
    return tally / len(data)

In [30]:
def test(network, data, target):
    network.eval()
    perf = np.array([0,0,0]) # correct, wrong, silence
    for i in range(len(data)):
        data_in = data[i]
        target_in = target[i]
        if use_cuda:
            data_in = data_in.cuda()
            target_in = target_in.cuda()
        d = network(data_in, 'central_brain')
        if d != -1:
            if d == target_in:
                perf[0]+=1
            else:
                perf[1]+=1
        else:
            perf[2]+=1
    return perf/len(data)

In [31]:
# Image pre-processing transform: filtering, local normalization, then intensity-to-latency encoding
class Intensity2LatencyTransform:
    """Callable image transform producing a byte tensor of temporal spikes.

    Each call converts the image with ``transforms.ToTensor()``, multiplies
    it by 255, adds a leading dimension in place, applies the supplied
    filter, runs ``sf.local_normalization(..., 8)``, feeds the result to
    ``utils.Intensity2Latency(timesteps)`` and returns ``.sign().byte()`` of
    that output.

    Args:
        filter: Callable applied to the 255-scaled, unsqueezed image tensor.
        timesteps: Value passed to ``utils.Intensity2Latency`` (default 15).
    """

    def __init__(self, filter, timesteps = 15):
        self.to_tensor = transforms.ToTensor()
        self.temporal_transform = utils.Intensity2Latency(timesteps)
        self.filter = filter
        # Number of images processed so far; drives the progress print below.
        self.cnt = 0

    def __call__(self, image):
        """Transform one image; prints the running count every 1000 calls."""
        seen = self.cnt
        if seen % 1000 == 0:
            print(seen)
        self.cnt = seen + 1

        scaled = self.to_tensor(image) * 255
        scaled.unsqueeze_(0)  # in place: prepend a dimension before filtering
        normalized = sf.local_normalization(self.filter(scaled), 8)
        return self.temporal_transform(normalized).sign().byte()


In [1]:
# Bank of six DoG kernels: three window sizes (3, 7, 13), each with both sigma orderings.
kernels = [ utils.DoGKernel(3,3/9,6/9),
            utils.DoGKernel(3,6/9,3/9),
            utils.DoGKernel(7,7/9,14/9),
            utils.DoGKernel(7,14/9,7/9),
            utils.DoGKernel(13,13/9,26/9),
            utils.DoGKernel(13,26/9,13/9)]
# NOTE(review): `filter` shadows the Python builtin for the rest of the kernel session;
# the name is kept so anything else referring to it keeps working.
filter = utils.Filter(kernels, padding = 6, thresholds = 50)
pre_input = Intensity2LatencyTransform(filter)

# MNIST train/test sets, pre-processed by `pre_input`. The training and evaluation
# loops below iterate over these loaders, so they have to be defined here.
data_root = "data"
MNIST_train = utils.CacheDataset(torchvision.datasets.MNIST(root=data_root, train=True, download=True, transform = pre_input))
MNIST_test = utils.CacheDataset(torchvision.datasets.MNIST(root=data_root, train=False, download=True, transform = pre_input))
MNIST_loader = DataLoader(MNIST_train, batch_size=1000, shuffle=False)
MNIST_testLoader = DataLoader(MNIST_test, batch_size=len(MNIST_test), shuffle=False)

flyeye = FlyEye()
if use_cuda:
    flyeye.cuda()

# Training R7
print("Training R7")
if os.path.isfile("saved_R7.net"):
    # Reuse the cached weights instead of retraining the layer.
    flyeye.load_state_dict(torch.load("saved_R7.net"))
else:
    for epoch in range(2):
        print("Epoch", epoch)
        # One unsupervised STDP pass over every training batch
        for iteration, (data, targets) in enumerate(MNIST_loader):
            print("Iteration", iteration)
            train_unsupervise(flyeye, data, 'R7')
            print("Done!")
    torch.save(flyeye.state_dict(), "saved_R7.net")


# Training R8
# Training L1
# Training L2
# Training L3
# ...........
# Training LC17

# initial adaptive learning rates
# Base rates are read from the central-brain rules of the network built above.
apr = flyeye.stdp_central_brain.learning_rate[0][0].item()
anr = flyeye.stdp_central_brain.learning_rate[0][1].item()
app = flyeye.anti_stdp_central_brain.learning_rate[0][1].item()
anp = flyeye.anti_stdp_central_brain.learning_rate[0][0].item()

adaptive_min = 0
adaptive_int = 1
apr_adapt = ((1.0 - 1.0 / 10) * adaptive_int + adaptive_min) * apr
anr_adapt = ((1.0 - 1.0 / 10) * adaptive_int + adaptive_min) * anr
app_adapt = ((1.0 / 10) * adaptive_int + adaptive_min) * app
anp_adapt = ((1.0 / 10) * adaptive_int + adaptive_min) * anp

# perf
best_train = np.array([0.0,0.0,0.0,0.0]) # correct, wrong, silence, epoch
best_test = np.array([0.0,0.0,0.0,0.0]) # correct, wrong, silence, epoch

# Training Central Brain
print("Training Central Brain")
for epoch in range(680):
    print("Epoch #:", epoch)
    perf_train = np.array([0.0,0.0,0.0])
    # Reward/punish training over every training batch
    for data,targets in MNIST_loader:
        perf_train_batch = train_rl(flyeye, data, targets)
        print(perf_train_batch)
        #update adaptive learning rates
        # Reward rates scale with the batch's wrong ratio, punish rates with its correct ratio.
        apr_adapt = apr * (perf_train_batch[1] * adaptive_int + adaptive_min)
        anr_adapt = anr * (perf_train_batch[1] * adaptive_int + adaptive_min)
        app_adapt = app * (perf_train_batch[0] * adaptive_int + adaptive_min)
        anp_adapt = anp * (perf_train_batch[0] * adaptive_int + adaptive_min)
        flyeye.update_learning_rates(apr_adapt, anr_adapt, app_adapt, anp_adapt)
        perf_train += perf_train_batch
    # Average the per-batch performance over the epoch
    perf_train /= len(MNIST_loader)
    if best_train[0] <= perf_train[0]:
        best_train = np.append(perf_train, epoch)
    print("Current Train:", perf_train)
    print("   Best Train:", best_train)

    # Evaluate on the test set; checkpoint whenever the test result is at least as good as the best so far
    for data,targets in MNIST_testLoader:
        perf_test = test(flyeye, data, targets)
        if best_test[0] <= perf_test[0]:
            best_test = np.append(perf_test, epoch)
            torch.save(flyeye.state_dict(), "saved.net")
        print(" Current Test:", perf_test)
        print("    Best Test:", best_test)

NameError: name 'utils' is not defined