In [1]:
import GenEvents as ge
import PlotEvents as pe
import EventData as ed
import Params as ps

In [2]:
import numpy as np
import tqdm 
import math

import snntorch as snn
import torch
import torch.nn as nn

import snntorch.spikeplot as splt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from snntorch import surrogate
import snntorch.functional as SF

from snntorch import spikegen

import matplotlib.pyplot as plt
import snntorch.spikeplot as splt

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
num_events = 30000
num_events_test = 8000

events_arr, muon_list, max_n = ge.generate_noisy_evts(num_events, noise_frac=0.4, bkg_frac=0.5)
events_arr_test, muon_list_test, max_n_test = ge.generate_noisy_evts(num_events_test, noise_frac=0.4, bkg_frac=0.5)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30000/30000 [00:29<00:00, 1028.03it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 8000/8000 [00:07<00:00, 1122.83it/s]


In [52]:
def target_train(muon_list):
    trueTarget = []
    for event in muon_list:
        trueEvent = []
        for hit in event:
            if hit["signal"] == True:
                trueEvent.append(hit)
    
        trueTarget.append(trueEvent)   
        
    return trueTarget

target = get_target_train(muon_list)
target_test = get_target_train(muon_list_test)

In [81]:
class CustomDataset(Dataset):
    def __init__(self, input_data, target, transform=None):
        self.input_data = input_data
        self.target = target
        
        self.transform = transform

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, idx):
        
        input_sample = self.input_data[idx]
        target_sample = self.target[idx]
        
        if self.transform:
            input_sample = self.transform(input_sample)
            target_sample = torch.tensor(target_sample,dtype=torch.float32)
            
        return input_sample, target_sample


# Transformations

# transform each event (muon_hits list) to np.array of fixed size
def convert_to(muon_hits, size, target_dtype=np.int16,
               features=['layer', 'wire_num', 'bx', 't0']):
    padded_array = np.zeros(shape=(size, len(features)), dtype=target_dtype)
    for i, hit in enumerate(muon_hits):
        for j, f in enumerate(features):
            padded_array[i,j] = hit[f]  # BEWARE: implicit type conversions going on here

    return padded_array

# converts to Torch tensor of desired type
def to_tensor_and_dtype(input_variable, target_dtype=torch.float32):
    
    # Convert to PyTorch tensor
    tensor = torch.tensor(input_variable)
    # Force the tensor to have the specified dtype
    tensor = tensor.to(target_dtype)
    
    return tensor

feature_list = ['layer', 'wire_num', 'bx', 't0']
transform = transforms.Compose([
    lambda x: convert_to(x, size=max_n, target_dtype=np.float32,
                         features=feature_list),
    lambda x: to_tensor_and_dtype(x, target_dtype=torch.float32)
])

train_dataset = CustomDataset(muon_list, target, transform=transform)
test_dataset = CustomDataset(muon_list_test, target_test, transform=transform)

In [82]:
def custom_spikegen(data_it, num_step=40):
    spike_data = torch.zeros(size=(num_step,batch_size,ps.NLAYERS,ps.NWIRES), dtype=data_it.dtype)
    for n_batch, evt in enumerate(data_it):
        for hit in evt:
            layer = int(hit[0])
            wire = int(hit[1])
            bx = int(hit[2])
            t0 = math.floor(hit[3])
            if bx != 0:
                spike_data[bx-t0+ps.bx_oot, n_batch, layer-1, wire-1] = 1
            else:
                break
                
    return spike_data  

def get_target_train(target,num_step=40,plain=True):

    targets=[]
    for event in target:
        
        if plain:
            spike_data = np.zeros(num_step)
        else:
            spike_data=np.zeros((num_step,ps.NLAYERS,ps.NWIRES))
        
        if event:
            
            startTime = min(500,min([hit["bx"] for hit in event]))
                        
            for hit in event:
                bx = hit["bx"]
                time = bx-startTime
                if plain:
                    spike_data[time] = 1
                
                else:

                    layer = hit["layer"]
                    wire = hit["wire_num"]
        
                    spike_data[time, layer-1, wire-1] = 1
        
        targets.append(spike_data)
    return targets

In [83]:
# Network Architecture
num_inputs = ps.NLAYERS*ps.NWIRES
num_hidden = 100
num_outputs = 1#num_inputs

# Temporal Dynamics
num_steps = 40
beta = 0.8

In [84]:
class Net(nn.Module):
    """Simple spiking neural network in snntorch."""

    def __init__(self, input_feat, hidden,out_feat,timesteps):
        super().__init__()
        
        self.input_feat = input_feat # number of input neurons 
        self.hidden = hidden # number of hidden neurons
        self.out_feat = out_feat # number of output neurons
        
        self.timesteps = timesteps # number of time steps to simulate the network
        spike_grad = surrogate.fast_sigmoid() # surrogate gradient function
        
        self.fc_in = nn.Linear(in_features=self.input_feat, out_features=self.hidden)
        self.lif_in = snn.Leaky(beta=beta,spike_grad=spike_grad,threshold=0.9)
        
        self.fc_out = nn.Linear(in_features=self.hidden, out_features=self.out_feat)
        self.lif_out = snn.Leaky(beta=beta,spike_grad=spike_grad,threshold=0.9)
    
    def forward(self, x):
        """Forward pass for several time steps."""

        # Initalize membrane potential
        mem1 = self.lif_in.init_leaky()
        mem2 = self.lif_out.init_leaky()

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        # Loop over 
        for step in range(self.timesteps):
                
            cur1 = self.fc_in(x[step])
            spk1, mem1 = self.lif_in(cur1, mem1)
            cur2 = self.fc_out(spk1)
            spk2, mem2 = self.lif_out(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)
    
net = Net(num_inputs, num_hidden, num_outputs, num_steps).to(device)

In [85]:
def print_batch_accuracy(net, data, targets, batch_size, train=False):
    output, _ = net(data)
    _, idx = output.sum(dim=0).max(1)
    acc = np.mean((targets == idx).detach().cpu().numpy())
    
    if train:
        print(f"Train set accuracy for a single minibatch: {acc*100:.2f}%")
    else:
        print(f"Test set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer(net, data, targets, batch_size, epoch, counter, iter_counter,
                  loss_hist, test_loss_hist, test_data, test_targets):
    print(f"Epoch {epoch}, Iteration {iter_counter}")
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    print_batch_accuracy(net, data, targets, batch_size, train=True)
    print_batch_accuracy(net, test_data, test_targets, batch_size, train=False)
    print("\n")

In [86]:
batch_size = 60
nw=0

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=nw)

In [94]:
num_epochs = 5
loss_hist = []
iter_counter = 0

#loss = nn.MSELoss()
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))


# Outer training loop
for epoch in range(num_epochs):
    batch_counter = 0
    # Minibatch training loop
    for data_it, targets_it in train_loader:
        data_it = data_it.to(device)
        targets_it = targets_it.to(device)

        # create spike train
        spike_in = custom_spikegen(data_it, num_steps)
        spike_in = spike_in.to(device)

        # forward pass
        net.train()
        spk_rec, mem_rec = net(spike_in.view(num_steps, batch_size, -1))

        loss_val = torch.tensor(0.)
        # initialize the loss & sum over time
        for i in range(batch_size):
        
            loss_val += loss_fn(spk_rec[:,i], targets_it[i,:])
        
        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())
        
        if iter_counter % 50 == 0:
            print("Epoch:",epoch)
            print("Batch:",batch_counter)
            print("Iteration:",iter_counter)
            print("Loss:",loss_val.item(),"\n")
        
        batch_counter += 1
        iter_counter += 1

Epoch: 0
Batch: 0
Iteration: 0
Loss: 2.6000003814697266 

Epoch: 0
Batch: 50
Iteration: 50
Loss: 2.9512503147125244 

Epoch: 0
Batch: 100
Iteration: 100
Loss: 2.7400004863739014 

Epoch: 0
Batch: 150
Iteration: 150
Loss: 3.2400004863739014 

Epoch: 0
Batch: 200
Iteration: 200
Loss: 3.9225010871887207 

Epoch: 0
Batch: 250
Iteration: 250
Loss: 3.9700005054473877 

Epoch: 0
Batch: 300
Iteration: 300
Loss: 4.2962493896484375 

Epoch: 0
Batch: 350
Iteration: 350
Loss: 3.923750400543213 

Epoch: 0
Batch: 400
Iteration: 400
Loss: 3.081249713897705 

Epoch: 0
Batch: 450
Iteration: 450
Loss: 2.9337499141693115 

Epoch: 1
Batch: 0
Iteration: 500
Loss: 3.970001220703125 

Epoch: 1
Batch: 50
Iteration: 550
Loss: 3.2637503147125244 

Epoch: 1
Batch: 100
Iteration: 600
Loss: 3.3187499046325684 

Epoch: 1
Batch: 150
Iteration: 650
Loss: 4.159999847412109 

Epoch: 1
Batch: 200
Iteration: 700
Loss: 2.4150004386901855 

Epoch: 1
Batch: 250
Iteration: 750
Loss: 3.976250171661377 

Epoch: 1
Batch: 300
It