In [1]:
# import event simulation files
import GenEvents as ge
import PlotEvents as pe
import EventData as ed
import Params as ps

In [2]:
import numpy as np
import tqdm 
import math
import matplotlib.pyplot as plt
import csv

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

import snntorch as snn
from snntorch import spikegen
import snntorch.spikeplot as splt
from snntorch import surrogate
import snntorch.functional as SF

# Prefer the GPU whenever PyTorch can see one; everything below moves tensors to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Selected device is', device)

Selected device is cpu


  return torch._C._cuda_getDeviceCount() > 0


In [3]:
# Events per split.
n_evt = 100000
n_evt_test = 5000
n_evt_val = 10000

# Simulation settings shared by all three splits, kept in one place so the
# splits cannot silently drift apart.
NOISE_FRAC = 0.5
BKG_FRAC = 0.3
INEFF = False
# NOTE(review): no random seed is set before generation; whether the splits are
# reproducible depends on GenEvents.generate_noisy_evts — confirm.

evt_arr,      muon_list,      max_n      = ge.generate_noisy_evts(n_evt,      noise_frac=NOISE_FRAC, bkg_frac=BKG_FRAC, ineff=INEFF)
evt_arr_test, muon_list_test, max_n_test = ge.generate_noisy_evts(n_evt_test, noise_frac=NOISE_FRAC, bkg_frac=BKG_FRAC, ineff=INEFF)
evt_arr_val,  muon_list_val,  max_n_val  = ge.generate_noisy_evts(n_evt_val,  noise_frac=NOISE_FRAC, bkg_frac=BKG_FRAC, ineff=INEFF)

100%|█████████████████████████████████| 100000/100000 [01:26<00:00, 1159.07it/s]
100%|█████████████████████████████████████| 5000/5000 [00:04<00:00, 1166.30it/s]
100%|███████████████████████████████████| 10000/10000 [00:08<00:00, 1200.53it/s]


In [4]:
# transform each event (muon_hits list) to np.array of fixed size
def convert_to(muon_list, size, target_dtype, features):
    """Pack variable-length hit lists into one zero-padded tensor.

    Args:
        muon_list: sequence of events; each event is a sequence of hits and every
            hit supports ``hit[name]`` for each name in ``features``.
        size: number of hit slots per event (an event with more hits raises IndexError).
        target_dtype: numpy dtype of the buffer, and hence of the returned tensor.
        features: feature names, in output column order.

    Returns:
        torch.Tensor of shape ``(len(muon_list), size, len(features))``; slots past
        an event's last hit stay zero.
    """
    n_features = len(features)
    buffer = np.zeros((len(muon_list), size, n_features), dtype=target_dtype)
    for evt_idx, hits in enumerate(muon_list):
        for hit_idx, hit in enumerate(hits):
            # Values are cast to target_dtype on assignment (e.g. ints -> float32).
            buffer[evt_idx, hit_idx] = [hit[name] for name in features]
    return torch.tensor(buffer)

feature_list = ['layer', 'wire_num', 'bx', 't0', 'signal']

# Same conversion for every split; each split is padded to its own maximum hit count.
mu_arr, mu_arr_test, mu_arr_val = (
    convert_to(hits, size=n_max, target_dtype=np.float32, features=feature_list)
    for hits, n_max in (
        (muon_list, max_n),
        (muon_list_test, max_n_test),
        (muon_list_val, max_n_val),
    )
)

In [5]:
def custom_spikegen(data, num_step=40, batch_size=100, linearise=True):
    """Convert a batch of zero-padded hit rows into a binary spike train.

    Each hit row is read as ``(layer, wire, bx, t0, ...)``. A hit emits one spike
    at time index ``bx - floor(t0) + ps.bx_oot`` on cell ``(layer - 1, wire - 1)``,
    so layers and wires are 1-based in the input. Reading an event stops at its
    first row with ``bx == 0``, which is treated as padding.

    Args:
        data: tensor indexable as ``data[event][hit][feature]`` (in this notebook,
            batches of the arrays built by ``convert_to``).
        num_step: number of time steps in the spike train.
        batch_size: number of event slots to allocate; slots beyond ``len(data)``
            stay all-zero.
        linearise: if True, flatten the (layer, wire) grid into one feature axis.

    Returns:
        CPU tensor with ``data.dtype``. Its shape is
        ``(num_step, batch_size, ps.NLAYERS * ps.NWIRES)`` when ``linearise`` is
        True, else ``(num_step, batch_size, ps.NLAYERS, ps.NWIRES)``.

    Raises:
        ValueError: ``data`` holds more events than ``batch_size`` slots.
        IndexError: a hit's time index falls outside ``[0, num_step)``. Negative
            indices used to wrap silently to the end of the spike train.
    """
    n_events = len(data)
    if n_events > batch_size:
        raise ValueError(
            f"custom_spikegen: got {n_events} events but only batch_size={batch_size} slots")

    spike_data = torch.zeros(size=(num_step, batch_size, ps.NLAYERS, ps.NWIRES), dtype=data.dtype)

    # Read the batch as plain Python numbers once. This avoids per-element tensor
    # indexing, and a device sync per int() when data lives on CUDA.
    events = data.detach().cpu().tolist()
    for n_batch, evt in enumerate(events):
        for hit in evt:
            layer = int(hit[0])
            wire = int(hit[1])
            bx = int(hit[2])
            t0 = math.floor(hit[3])
            if bx == 0:
                # bx == 0 marks the zero padding after the event's last real hit.
                break
            time_idx = bx - t0 + ps.bx_oot
            if not 0 <= time_idx < num_step:
                raise IndexError(
                    f"custom_spikegen: hit with bx={bx}, t0={t0} maps to time step "
                    f"{time_idx}, outside [0, {num_step})")
            spike_data[time_idx, n_batch, layer - 1, wire - 1] = 1

    if linearise:
        spike_data = spike_data.view(num_step, batch_size, -1)
    return spike_data

In [6]:
def gen_target_chamber(muon_arr, num_step=40, numMuons=4, progress=True):
    """Build one-hot per-event timing targets from the padded hit arrays.

    For every event, the target time is the bx of the ``numMuons``-th earliest
    signal hit (column 4 == 1). It is measured from the event's start time,
    ``min(500, earliest bx > 0)``. Despite its name, ``numMuons`` counts signal
    *hits*.

    Args:
        muon_arr: tensor of shape (n_events, max_hits, n_features). Column 2 is
            read as bx and column 4 as the signal flag.
        num_step: length of each target row.
        numMuons: which signal hit (1-based, in bx order) defines the target time.
        progress: show a tqdm progress bar (new; pass False for quiet runs).

    Returns:
        Tensor of shape (n_events, num_step) with ``muon_arr.dtype``. Each row holds
        a single 1 at the target time, or is all-zero when the event has no hit with
        bx > 0 or fewer than ``numMuons`` signal hits.

    Raises:
        IndexError: the target time is >= ``num_step``.
    """
    n_events = len(muon_arr)
    target = torch.zeros(size=(n_events, num_step), dtype=muon_arr.dtype)

    events = enumerate(muon_arr)
    if progress:
        events = tqdm.tqdm(events, total=n_events)

    for i, evt in events:
        bx = evt[:, 2]
        hit_bx = bx[bx > 0]
        if hit_bx.numel() == 0:
            # Empty event: leave its row at zero and go on. This used to be
            # `break`, which silently left every later event without a target.
            continue

        # NOTE(review): 500 caps the start time; its origin is not visible here.
        start_time = min(500, hit_bx.min().item())

        true_bx = sorted(bx[evt[:, 4] == 1.].tolist())
        if len(true_bx) < numMuons:
            continue

        muon_bx = true_bx[numMuons - 1]
        # NOTE(review): a signal hit with bx <= 0 would make this index negative,
        # and a negative index wraps to the end of the row. Confirm this cannot happen.
        target[i, int(muon_bx - start_time)] = 1
    return target

In [7]:
class CustomDataset(Dataset):
    """Map-style dataset of ``(input, target)`` pairs with an optional transform.

    Args:
        input_data: iterable of inputs (e.g. a tensor iterated over its first dim).
        target: iterable of targets aligned with ``input_data``. The two are
            zipped, so the shorter one sets the dataset length.
        transform: optional callable applied to each ``(input, target)`` pair on access.
    """

    def __init__(self, input_data, target, transform=None):
        self.data = [(inp, tgt) for inp, tgt in zip(input_data, target)]
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        pair = self.data[idx]
        return self.transform(pair) if self.transform else pair


# Transformations

# converts to Torch tensor of desired type
def to_tensor_and_dtype(input_variable, target_dtype=torch.float32):
    """Return ``input_variable`` as a torch tensor with dtype ``target_dtype``.

    Non-tensor inputs (lists, numpy arrays, scalars) are copied into a new tensor
    via ``torch.tensor``. Tensors, including subclasses such as ``nn.Parameter``,
    are only cast with ``.to``, which returns the same object when the dtype
    already matches. The old exact ``type(...) != torch.Tensor`` check re-copied
    subclasses through ``torch.tensor``.
    """
    if not isinstance(input_variable, torch.Tensor):
        input_variable = torch.tensor(input_variable)
    return input_variable.to(target_dtype)


def _input_to_float32(sample):
    """Cast the input of an ``(input, target)`` sample to float32; the target passes through."""
    return to_tensor_and_dtype(sample[0], target_dtype=torch.float32), sample[1]


transform = transforms.Compose([_input_to_float32])

In [8]:
# One-hot timing targets: the 4th signal hit, on a 40-step axis.
target      = gen_target_chamber(mu_arr,      num_step=40, numMuons=4)
target_test = gen_target_chamber(mu_arr_test, num_step=40, numMuons=4)
target_val  = gen_target_chamber(mu_arr_val,  num_step=40, numMuons=4)

train_dataset = CustomDataset(mu_arr,      target,      transform=transform)
test_dataset  = CustomDataset(mu_arr_test, target_test, transform=transform)
val_dataset   = CustomDataset(mu_arr_val,  target_val,  transform=transform)

batch_size = 100
nw = 0  # DataLoader worker processes (0 = load in the main process)

# custom_spikegen always allocates `batch_size` event slots. A short final batch
# would therefore produce more spike rows than targets; fail here, not mid-epoch.
assert len(train_dataset) % batch_size == 0, f"train set size {len(train_dataset)} is not a multiple of batch_size={batch_size}"
assert len(test_dataset) % batch_size == 0, f"test set size {len(test_dataset)} is not a multiple of batch_size={batch_size}"
assert len(val_dataset) % batch_size == 0, f"val set size {len(val_dataset)} is not a multiple of batch_size={batch_size}"

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,  num_workers=nw)
test_loader  = DataLoader(test_dataset,  batch_size=batch_size, shuffle=False, num_workers=nw)
val_loader   = DataLoader(val_dataset,   batch_size=batch_size, shuffle=False, num_workers=nw)

100000it [00:29, 3410.97it/s]
5000it [00:01, 3360.72it/s]
10000it [00:02, 3414.60it/s]


In [9]:
# Network size: one input per (layer, wire) cell, matching the linearised spike train.
num_inputs = ps.NLAYERS * ps.NWIRES
num_hidden = 100
num_outputs_chamber = 2

# Time axis and neuron constants.
num_steps = 40      # time steps per event spike train
window = 16         # time steps the network sees per prediction
beta = 0.8          # snn.Synaptic membrane decay rate
alpha = 0.8         # snn.Synaptic synaptic-current decay rate
threshold = 0.9     # snn.Synaptic firing threshold

In [10]:
class Net1layer2ndOrder(nn.Module):
    """Spiking network in snntorch: Linear -> Synaptic LIF -> Linear -> Synaptic LIF.

    Both LIF layers are ``snn.Synaptic`` neurons. Each carries a synaptic current
    (``syn``) as well as a membrane potential (``mem``).

    Args:
        input_feat: number of input features per time step.
        hidden: number of hidden neurons.
        out_feat: number of output neurons.
        window: number of time steps ``forward`` unrolls. The input must have at
            least this many entries along dim 0.
        learnable: if True, beta, alpha and threshold of both layers are trainable.
        lif_params: optional dict that overrides any of ``beta``, ``alpha`` or
            ``threshold`` for both layers. By default the notebook-level
            ``beta``/``alpha``/``threshold`` values are used, as before.
    """

    def __init__(self, input_feat, hidden, out_feat, window, learnable, lif_params=None):
        super().__init__()

        self.input_feat = input_feat  # number of input neurons
        self.hidden = hidden          # number of hidden neurons
        self.out_feat = out_feat      # number of output neurons
        self.window = window          # time steps unrolled per forward pass

        # Neuron constants: the notebook-level config values, read at construction
        # time, optionally overridden by the caller instead of editing globals.
        neuron_params = dict(beta=beta, alpha=alpha, threshold=threshold)
        if lif_params:
            neuron_params.update(lif_params)

        spike_grad = surrogate.fast_sigmoid()  # surrogate gradient, shared by both layers

        def make_lif():
            # Each call builds an independent layer with its own (learnable) parameters.
            return snn.Synaptic(**neuron_params, spike_grad=spike_grad,
                                learn_beta=learnable, learn_threshold=learnable,
                                learn_alpha=learnable)

        # Registration order is unchanged (fc_in, lif_in, fc_out, lif_out).
        self.fc_in = nn.Linear(in_features=self.input_feat, out_features=self.hidden)
        self.lif_in = make_lif()
        self.fc_out = nn.Linear(in_features=self.hidden, out_features=self.out_feat)
        self.lif_out = make_lif()

    def forward(self, x):
        """Run the network over the first ``self.window`` steps of ``x``.

        Args:
            x: input indexed by time on dim 0; each ``x[step]`` is fed to ``fc_in``.

        Returns:
            ``(spk, mem)``: output-layer spikes and membrane potentials, stacked on
            a new dim 0 of length ``self.window``.
        """
        # Fresh neuron state for every call.
        syn1, mem1 = self.lif_in.init_synaptic()
        syn2, mem2 = self.lif_out.init_synaptic()

        spk2_rec = []
        mem2_rec = []
        for step in range(self.window):
            cur1 = self.fc_in(x[step])
            spk1, syn1, mem1 = self.lif_in(cur1, syn1, mem1)

            cur2 = self.fc_out(spk1)
            spk2, syn2, mem2 = self.lif_out(cur2, syn2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)

In [11]:
def comp_accuracy(output, targets, type):
    """Count correct predictions in one batch under the chosen accuracy definition.

    Args:
        output: network output for the batch. Its expected layout depends on
            ``type``; for the spike-count modes it is a spike record with time on
            dim 0, as stacked by ``Net1layer2ndOrder.forward``.
        targets: ground truth laid out to match ``output`` for the chosen ``type``.
        type: one of
            'chamber'            - one class per event, predicted by summed spike count;
            'cell'               - one class per cell, scored per cell;
            'cell_per_evt'       - as 'cell', but an event counts only if every cell is right;
            'mse_timing'         - element-wise equality of ``output`` and ``targets``;
            'mse_timing_per_evt' - as 'mse_timing', one verdict per event (dim 1);
            'ce_timing'          - per-step, per-cell argmax classification;
            'bce_timing'         - sigmoid(output) thresholded at 0.5.
            The name shadows the builtin but is kept for caller compatibility.

    Returns:
        ``(total, correct)`` as plain Python ints, ready to be summed over batches.

    Raises:
        ValueError: unknown ``type``. It used to return ``(0, 0)``, which only
            surfaced later as a ZeroDivisionError in the caller.
    """
    mode = type

    if mode == 'chamber':
        # Sum spikes over time (dim 0); the most active output neuron is the prediction.
        _, predicted = output.sum(dim=0).max(1)
        return targets.size(0), (predicted == targets).sum().item()

    if mode in ('cell', 'cell_per_evt'):
        batch_size = targets.size(0)
        # Split the summed output into per-cell class scores and take the argmax.
        _, predicted = output.sum(dim=0).view(batch_size, ps.NLAYERS*ps.NWIRES, -1).max(2)
        if mode == 'cell':
            return targets.numel(), (predicted == targets).sum().item()
        correct = sum(bool((predicted[i] == targets[i]).all()) for i in range(batch_size))
        return batch_size, correct

    if mode == 'mse_timing':
        return targets.numel(), (output == targets).sum().item()

    if mode == 'mse_timing_per_evt':
        # Events run along dim 1; an event is correct only if every entry matches.
        batch_size = targets.size(1)
        correct = sum(bool((output[:, i] == targets[:, i]).all()) for i in range(batch_size))
        return batch_size, correct

    if mode == 'ce_timing':
        num_steps = targets.size(0)
        batch_size = targets.size(1)
        _, predicted = output.view(num_steps, batch_size, ps.NLAYERS*ps.NWIRES, -1).max(3)
        return targets.numel(), (predicted == targets).sum().item()

    if mode == 'bce_timing':
        predicted = nn.Sigmoid()(output)
        predicted[predicted < 0.5] = 0
        predicted[predicted >= 0.5] = 1
        return targets.numel(), (predicted == targets).sum().item()

    raise ValueError(f"comp_accuracy: unknown accuracy type {mode!r}")

In [12]:
def accuracy_set(net, data_loader, loss_fn, accuracy_type, batch_size, linearise,num_steps=40):
    """Evaluate ``net`` on a loader with the same sliding-window scheme as training.

    Each batch is turned into a ``num_steps``-long spike train. The train is padded
    on the time axis by ``window - 1`` empty steps per side, where ``window`` is the
    notebook-level config value. For each step ``i`` the network sees
    ``padded[i : i + window]``, i.e. the ``window`` steps ending at original step
    ``i``, and its spikes are scored against ``targets[:, i]``.

    Args:
        net: network whose ``forward`` returns ``(spikes, membrane)``.
        data_loader: yields ``(data, targets)`` with ``targets`` of shape
            ``(batch, num_steps)``.
        loss_fn: called as ``loss_fn(spikes, targets[:, i].long())``.
        accuracy_type: forwarded to ``comp_accuracy``.
        batch_size, linearise: forwarded to ``custom_spikegen``.
        num_steps: number of time steps per event.

    Returns:
        ``(mean_loss, acc)``: the mean over batches of the loss summed over all
        steps, and ``correct / total`` accumulated through ``comp_accuracy``.

    Raises:
        ValueError: nothing was scored (empty loader, or ``comp_accuracy`` counted
            no samples).
    """
    net.eval()
    total = 0
    correct = 0
    batch_losses = []
    with torch.no_grad():
        for data, targets in data_loader:
            data = data.to(device)
            targets = targets.to(device)

            # create spike train
            spike_in = custom_spikegen(data, num_steps, batch_size, linearise).to(device)

            # F.pad lists (left, right) pairs starting from the LAST dim, so pad only
            # dim 0 (time). The old hard-coded (0,0,0,0,15,15) reached the time
            # axis only for the 3-D linearised input, and only when window == 16.
            pad = (0, 0) * (spike_in.dim() - 1) + (window - 1, window - 1)
            padded_input = torch.nn.functional.pad(spike_in, pad)

            # Start from a Python scalar so the running sum lives on the losses'
            # device; the old CPU torch.tensor(0.) accumulator mixed devices on CUDA.
            batch_loss = 0.
            for i in range(num_steps):
                spk_rec, _ = net(padded_input[i:i + window])
                batch_loss = batch_loss + loss_fn(spk_rec, targets[:, i].type(torch.long))

                tot, corr = comp_accuracy(spk_rec, targets[:, i], accuracy_type)
                total += tot
                correct += corr

            batch_losses.append(float(batch_loss))

    if total == 0:
        raise ValueError("accuracy_set: nothing was scored (empty loader or unknown accuracy_type?)")
    mean_loss = np.mean(batch_losses)
    acc = correct / total
    return mean_loss, acc

In [13]:
def train_net(net, train_loader, val_loader, num_epochs, loss_fn, optimizer, accuracy_type, 
              batch_size, num_steps = 40,linearise=True):
    """Train ``net`` with the sliding-window scheme and validate after every epoch.

    Each training batch becomes a ``num_steps``-long spike train, padded on the
    time axis by ``window - 1`` steps per side (``window`` is the notebook-level
    config value). The losses of all ``num_steps`` windows ``padded[i : i + window]``
    against ``targets[:, i]`` are summed, and back-propagated once per batch.

    Args:
        net: network whose ``forward`` returns ``(spikes, membrane)``.
        train_loader, val_loader: yield ``(data, targets)`` with ``targets`` of
            shape ``(batch, num_steps)``.
        num_epochs: number of passes over ``train_loader``.
        loss_fn: called as ``loss_fn(spikes, targets[:, i].long())``.
        optimizer: optimizer over ``net``'s parameters.
        accuracy_type: forwarded to ``accuracy_set`` / ``comp_accuracy``.
        batch_size, linearise: forwarded to ``custom_spikegen``.
        num_steps: number of time steps per event, for training and validation.

    Returns:
        ``(loss_hist, loss_val_hist, acc_val_hist)``: per-batch training losses,
        and per-epoch validation losses and accuracies.
    """
    net.to(device)

    loss_hist = []
    loss_val_hist = []
    acc_val_hist = []
    iter_counter = 0

    for epoch in range(num_epochs):
        net.train()
        for batch_counter, (data, targets) in enumerate(train_loader):
            data = data.to(device)
            targets = targets.to(device)

            # create spike train
            spike_in = custom_spikegen(data, num_steps, batch_size, linearise).to(device)

            # Pad only dim 0 (time); F.pad lists (left, right) pairs from the last dim.
            pad = (0, 0) * (spike_in.dim() - 1) + (window - 1, window - 1)
            padded_input = nn.functional.pad(spike_in, pad)

            # Start from a Python scalar so the sum lives on the losses' device;
            # the old CPU torch.tensor(0.) accumulator mixed devices on CUDA.
            batch_loss = 0.
            for i in range(num_steps):
                spk_rec, _ = net(padded_input[i:i + window])
                batch_loss = batch_loss + loss_fn(spk_rec, targets[:, i].type(torch.long))

            # Gradient calculation + weight update
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()

            # Store loss history for future plotting
            loss_hist.append(batch_loss.item())
            if iter_counter % 50 == 0:
                print("Epoch:", epoch)
                print("Batch:", batch_counter)
                print("Iteration:", iter_counter)
                print("Loss:", batch_loss.item(), "\n")
            iter_counter += 1

        # Validate with the same num_steps used for training; previously
        # accuracy_set silently fell back to its default of 40.
        mean_loss_val, acc_val = accuracy_set(net, val_loader, loss_fn, accuracy_type,
                                              batch_size, linearise, num_steps=num_steps)

        loss_val_hist.append(mean_loss_val)
        acc_val_hist.append(acc_val)
        print(f"Validation Set Loss: {mean_loss_val}")
        print(f"Validation Set Accuracy: {100 *acc_val:.2f}%")
        print("\n--------------------------------------------------\n")

    return loss_hist, loss_val_hist, acc_val_hist

In [14]:
# Seed torch's global RNG so weight initialisation and the shuffled training
# order are reproducible on Restart & Run All. The events generated in earlier
# cells are not affected by this seed.
SEED = 42
torch.manual_seed(SEED)

net = Net1layer2ndOrder(num_inputs, num_hidden, num_outputs_chamber, window, learnable=True).to(device)

loss_fn = SF.ce_count_loss()  # cross-entropy on output spike counts
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

num_epochs = 5
accuracy_type = 'chamber'

loss_hist, loss_val_hist, acc_val_hist = train_net(net,
                                                   train_loader, val_loader,
                                                   num_epochs, loss_fn, optimizer,
                                                   accuracy_type, batch_size,
                                                   num_steps=num_steps)

Epoch: 0
Batch: 0
Iteration: 0
Loss: 59.376651763916016 

Epoch: 0
Batch: 50
Iteration: 50
Loss: 3.565741777420044 

Epoch: 0
Batch: 100
Iteration: 100
Loss: 3.013195037841797 

Epoch: 0
Batch: 150
Iteration: 150
Loss: 3.114096164703369 

Epoch: 0
Batch: 200
Iteration: 200
Loss: 2.5198585987091064 

Epoch: 0
Batch: 250
Iteration: 250
Loss: 2.8296031951904297 

Epoch: 0
Batch: 300
Iteration: 300
Loss: 3.3197968006134033 

Epoch: 0
Batch: 350
Iteration: 350
Loss: 2.7453818321228027 

Epoch: 0
Batch: 400
Iteration: 400
Loss: 2.8191676139831543 

Epoch: 0
Batch: 450
Iteration: 450
Loss: 2.7915916442871094 



KeyboardInterrupt: 