<a href="https://colab.research.google.com/github/enlupi/SNN-MUC/blob/main/Scripts/DTFastSim.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Data Simulation


In [9]:
import GenEvents as GE
import PlotEvents as PE
import EventData as ED
from Params import *

In [2]:
import numpy as np
import tqdm 

import snntorch as snn
import torch
import torch.nn as nn

import snntorch.spikeplot as splt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from snntorch import surrogate

from snntorch import spikegen

import matplotlib.pyplot as plt
import snntorch.spikeplot as splt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# SNN

### Create datasets - preprocessing

In [3]:
num_events = 30000
num_events_test = 5000

# Train and test events are generated with identical noise/background settings.
gen_kwargs = dict(noise_frac=1, bkg_frac=0.5)
events_arr, muon_list, max_n = GE.generate_noisy_evts(num_events, **gen_kwargs)
events_arr_test, muon_list_test, max_n_test = GE.generate_noisy_evts(num_events_test, **gen_kwargs)

100%|██████████| 30000/30000 [00:49<00:00, 603.53it/s]
100%|██████████| 5000/5000 [00:09<00:00, 535.80it/s]


In [28]:
def gen_target(muon_list):
    """Build per-wire signal masks for a list of events.

    Args:
        muon_list: sequence of events; each event is an iterable of hit dicts
            carrying 'layer' and 'wire_num' (1-based, see the -1 below) and a
            'signal' flag.

    Returns:
        torch.int16 tensor of shape (len(muon_list), NLAYERS, NWIRES) holding 1
        where the event has a signal hit on that (layer, wire) cell, else 0.
    """
    target = np.zeros(shape=(len(muon_list), NLAYERS, NWIRES), dtype=np.int16)
    # Wrap the sequence itself (not enumerate) so tqdm can read its length and
    # display a real progress bar instead of a bare iteration counter.
    for i, evt in enumerate(tqdm.tqdm(muon_list)):
        for hit in evt:
            if hit['signal']:
                # Hit records use 1-based layer/wire numbers; convert to 0-based indices.
                target[i, hit['layer'] - 1, hit['wire_num'] - 1] = 1

    return torch.tensor(target, dtype=torch.int16)

In [29]:
# Per-wire signal masks: one (NLAYERS, NWIRES) int16 grid per event, train then test.
target, target_test = gen_target(muon_list), gen_target(muon_list_test)

30000it [00:00, 71569.25it/s]
5000it [00:00, 81952.97it/s]


In [36]:
# Display the hit records of one sample event.
sample_idx = 4
muon_list[sample_idx]

[{'layer': 3,
  'wire_num': 2,
  'bx': 516,
  'tdc': 27,
  'label': 0,
  't0': 500.3666666666667,
  'psi': 0.3154152764339655,
  'x0': 4.176478920061371,
  'signal': False},
 {'layer': 3,
  'wire_num': 2,
  'bx': 504,
  'tdc': 27,
  'label': 1,
  't0': 500.3666666666667,
  'psi': 0.3154152764339655,
  'x0': 4.176478920061371,
  'signal': True},
 {'layer': 3,
  'wire_num': 3,
  'bx': 493,
  'tdc': 21,
  'label': 0,
  't0': 500.3666666666667,
  'psi': 0.3154152764339655,
  'x0': 4.176478920061371,
  'signal': False},
 {'layer': 2,
  'wire_num': 3,
  'bx': 514,
  'tdc': 9,
  'label': -1,
  't0': 500.3666666666667,
  'psi': 0.3154152764339655,
  'x0': 4.176478920061371,
  'signal': True},
 {'layer': 1,
  'wire_num': 2,
  'bx': 502,
  'tdc': 0,
  'label': -1,
  't0': 500.3666666666667,
  'psi': 0.3154152764339655,
  'x0': 4.176478920061371,
  'signal': True},
 {'layer': 2,
  'wire_num': 1,
  'bx': 517,
  'tdc': 13,
  'label': 0,
  't0': 500.3666666666667,
  'psi': 0.3154152764339655,
  'x0'

In [45]:
class CustomDataset(Dataset):
    """Map-style dataset pairing each input sample with its target.

    Args:
        input_data: indexable sequence of input samples; its length is the
            dataset length.
        target: indexable sequence of targets aligned with ``input_data``.
        transform: optional callable applied to the input sample (not the
            target) every time an item is fetched.
    """

    def __init__(self, input_data, target, transform=None):
        self.input_data, self.target, self.transform = input_data, target, transform

    def __len__(self):
        return len(self.input_data)

    def __getitem__(self, idx):
        sample = self.input_data[idx]
        label = self.target[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label
    
def pad(muon_hits, max_n_hit):
    """Copy hits into a zero-initialised ``ED.hit_dtype`` record array of length max_n_hit.

    Only the 'bx', 'tdc', 'label' and 'signal' fields are copied from each hit;
    any other field of ``ED.hit_dtype`` keeps its zero value. More hits than
    ``max_n_hit`` raises IndexError.
    """
    copied_fields = ('bx', 'tdc', 'label', 'signal')
    records = np.zeros(max_n_hit, dtype=ED.hit_dtype)
    for idx, hit in enumerate(muon_hits):
        for field in copied_fields:
            records[idx][field] = hit[field]

    return records

def convert_to(muon_hits, size, target_dtype=np.int16,
               features=('layer', 'wire_num', 'bx', 'tdc', 'signal')):
    """Pack a variable-length list of hits into a fixed-size 2-D array.

    Args:
        muon_hits: iterable of hit mappings; each must provide every key in
            ``features``.
        size: number of rows of the output (hits beyond the given ones are
            zero rows).
        target_dtype: numpy dtype of the output array.
        features: ordered keys to extract, one output column per key. A tuple
            default avoids the shared-mutable-default pitfall; any sequence
            (e.g. a list) is still accepted.

    Returns:
        Array of shape (size, len(features)); row i holds hit i's features,
        remaining rows are zero.

    Raises:
        IndexError: if there are more hits than ``size``.
    """
    padded_array = np.zeros(shape=(size, len(features)), dtype=target_dtype)
    for i, hit in enumerate(muon_hits):
        for j, f in enumerate(features):
            padded_array[i, j] = hit[f]

    return padded_array



def to_tensor_and_dtype(input_variable, target_dtype=torch.float32):
    """Return ``input_variable`` as a new tensor cast to ``target_dtype``.

    The data is first copied with ``torch.tensor`` (dtype inferred from the
    input) and only then converted, so e.g. Python floats are materialised in
    torch's default float dtype before the cast.
    """
    return torch.tensor(input_variable).to(target_dtype)

# Transform applied to every event's hit list: pack it into a fixed-length
# (PAD_SIZE, 5) int16 array with columns layer, wire_num, bx, tdc, signal.
# One length is shared by train and test. Padding only to the training max_n
# makes convert_to raise IndexError for any test event with more than max_n hits.
# NOTE(review): assumes max_n / max_n_test are the largest per-event hit counts
# returned by GE.generate_noisy_evts -- confirm.
PAD_SIZE = max(max_n, max_n_test)
pad_transform = transforms.Compose([
    lambda x: convert_to(x, size=PAD_SIZE, target_dtype=np.int16,
                         features=['layer', 'wire_num', 'bx', 'tdc', 'signal'])
])

In [46]:
# Inputs are raw hit lists (padded on access); targets are the per-wire signal masks.
train_dataset = CustomDataset(input_data=muon_list, target=target, transform=pad_transform)
test_dataset = CustomDataset(input_data=muon_list_test, target=target_test, transform=pad_transform)

In [47]:
batch_size = 60
nw = 0  # DataLoader worker processes; 0 loads batches in the main process

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=nw,
)

In [48]:
# Pull one minibatch from the loader to sanity-check what it yields.
data = iter(train_loader)
first_batch = next(data)
data_it, targets_it = first_batch

print(type(data))

# TODO: spike-encode the batch once num_steps is defined, e.g.
# spike_data = spikegen.rate(data_it, num_steps=num_steps)

<class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'>


In [None]:
def _minmax_scale(values):
    """Linearly rescale ``values`` to [0, 1]; return all zeros when every value is equal.

    The zero-range guard replaces the 0/0 -> NaN that the plain formula
    produces when all cells share one value.
    """
    lo = np.min(values)
    hi = np.max(values)
    if hi == lo:
        return np.zeros_like(values, dtype=float)
    return (values - lo) / (hi - lo)


def event_to_matrix(event, rec=True):
    """Turn an event's per-cell 'bx'/'tdc' counts into a normalised time matrix.

    Args:
        event: mapping/record whose 'tdc' and 'bx' entries are equally shaped
            arrays (one value per detector cell).
        rec: if True, invert the normalised times so earlier hits get larger
            values, then rescale to [0, 1] again.

    Returns:
        Float array with the shape of ``event['tdc']``, values in [0, 1]. Cells
        whose combined time is negative are clamped to 0. If all cells end up
        with the same time the result is all zeros (instead of NaN).
    """
    # Combine coarse 'bx' and fine 'tdc' counts into a single time value.
    # NOTE(review): constants look like 25 ns per bx and 30 tdc counts per bx -- confirm.
    times = event['tdc'] * (25 / 30) + event['bx'] * 25
    times[times < 0] = 0  # `times` is a fresh array, so the input event is not modified

    scaled = _minmax_scale(times)

    if rec:
        # Reciprocal of the positive entries only; zeros (including the earliest
        # time, which min-max scaling maps to 0) stay 0.
        positive = scaled > 0
        scaled[positive] = 1 / scaled[positive]
        scaled = _minmax_scale(scaled)

    return scaled

In [None]:
def print_batch_accuracy(net, data, targets, batch_size, train=False):
    """Print the spike-count accuracy of ``net`` on a single minibatch.

    ``data`` is flattened to (batch_size, -1) and fed to ``net``. The first value
    ``net`` returns is summed over dim 0 (time). The predicted class of each
    sample is the argmax of that sum along dim 1.
    """
    spk_out, _ = net(data.view(batch_size, -1))
    spike_counts = spk_out.sum(dim=0)
    _, predicted = spike_counts.max(1)
    matches = (targets == predicted).detach().cpu().numpy()
    acc = np.mean(matches)

    split = "Train" if train else "Test"
    print(f"{split} set accuracy for a single minibatch: {acc*100:.2f}%")

def train_printer(net, data, targets, batch_size, epoch, counter, iter_counter,
                  loss_hist, test_loss_hist, test_data, test_targets):
    """Print epoch/iteration, the losses stored at index ``counter``, then train and test minibatch accuracy."""
    header = f"Epoch {epoch}, Iteration {iter_counter}"
    print(header)
    print(f"Train Set Loss: {loss_hist[counter]:.2f}")
    print(f"Test Set Loss: {test_loss_hist[counter]:.2f}")
    batches = ((data, targets, True), (test_data, test_targets, False))
    for batch_x, batch_y, is_train in batches:
        print_batch_accuracy(net, batch_x, batch_y, batch_size, train=is_train)
    print("\n")

In [None]:
def _build_input_matrices(events, n_events):
    """Return (reciprocal, linear) time matrices, each (n_events, NLAYERS, NWIRES), from events['mc']."""
    rec = np.zeros((n_events, NLAYERS, NWIRES))
    lin = np.zeros((n_events, NLAYERS, NWIRES))
    for i in range(n_events):
        rec[i] = event_to_matrix(events['mc'][i], True)
        lin[i] = event_to_matrix(events['mc'][i], False)
    return rec, lin


def _build_event_labels(events, n_events):
    """Return float64 labels: 0. where an event's 'n_true_hits' is 0, else 1."""
    labels = np.ones(n_events)
    for i in range(n_events):
        if events[i]['n_true_hits'] == 0:
            labels[i] = 0
    return labels


# NOTE(review): labels are float64 (0./1.); cast to long before using them as
# nn.CrossEntropyLoss class-index targets.

# train set
input_data_r, input_data_t = _build_input_matrices(events_arr, num_events)
target_data = _build_event_labels(events_arr, num_events)

# test set -- both matrices come from events_arr_test. Previously test_data_t was
# filled from the *training* events (events_arr), mismatching it with its labels.
test_data_r, test_data_t = _build_input_matrices(events_arr_test, num_events_test)
# NOTE(review): this rebinds `target_test`, which earlier held the per-wire
# masks used for test_dataset; consider a distinct name.
target_test = _build_event_labels(events_arr_test, num_events_test)

### SNN architecture

In [None]:
# Network architecture: one input per (layer, wire) cell, hidden width, output classes.
num_inputs, num_hidden, num_outputs = NLAYERS * NWIRES, 100, 2

# Temporal dynamics: simulation steps per sample and the Leaky neurons' membrane decay.
num_steps, beta = 25, 0.7

batch_size = 60

In [None]:
class Net(nn.Module):
    """Simple spiking neural network in snntorch: Linear -> Leaky -> Linear -> Leaky.

    Args:
        input_feat: number of input features (size of the flattened input).
        hidden: number of hidden Leaky neurons.
        out_feat: number of output Leaky neurons.
        timesteps: number of time steps simulated per forward pass.

    NOTE(review): the membrane decay ``beta`` is read from the module-level
    global at construction time, not passed in -- keep that value in sync.
    """

    def __init__(self, input_feat, hidden, out_feat, timesteps):
        super().__init__()

        self.input_feat = input_feat  # number of input neurons
        self.hidden = hidden          # number of hidden neurons
        self.out_feat = out_feat      # number of output neurons

        self.timesteps = timesteps    # number of time steps to simulate the network
        spike_grad = surrogate.fast_sigmoid()  # surrogate gradient for the spike non-linearity

        self.fc_in = nn.Linear(in_features=self.input_feat, out_features=self.hidden)
        self.lif_in = snn.Leaky(beta=beta, spike_grad=spike_grad)

        self.fc_out = nn.Linear(in_features=self.hidden, out_features=self.out_feat)
        self.lif_out = snn.Leaky(beta=beta, spike_grad=spike_grad)

    def forward(self, x):
        """Simulate the network for ``self.timesteps`` steps on a constant input ``x``.

        Args:
            x: input tensor whose last dimension is ``input_feat`` (it goes
                straight into ``fc_in``).

        Returns:
            (spikes, membranes) of the output layer, each stacked along a new
            leading time dimension of length ``self.timesteps``.
        """
        # Initialize membrane potentials
        mem1 = self.lif_in.init_leaky()
        mem2 = self.lif_out.init_leaky()

        # x does not change between steps, so its projection is computed once
        # instead of re-running fc_in inside the time loop.
        cur1 = self.fc_in(x)

        # Record the final layer
        spk2_rec = []
        mem2_rec = []

        for _ in range(self.timesteps):
            spk1, mem1 = self.lif_in(cur1, mem1)
            cur2 = self.fc_out(spk1)
            spk2, mem2 = self.lif_out(cur2, mem2)
            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=0), torch.stack(mem2_rec, dim=0)
    
# Build the network from the config cell so `timesteps` always equals the
# `num_steps` the training loss loop indexes over (same values as before).
# NOTE(review): hidden size stays 20 although the config sets num_hidden = 100;
# decide which is intended.
net = Net(num_inputs, 20, num_outputs, num_steps).to(device)

### SNN training

In [None]:
num_epochs = 1
loss_hist = []
test_loss_hist = []
counter = 0

loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=5e-4, betas=(0.9, 0.999))

# NOTE(review): `train_rate` / `test_rate` are not defined anywhere in this
# notebook, so this cell raises NameError on a fresh kernel. TODO: build them
# in an earlier cell (e.g. datasets over input_data_r/target_data and
# test_data_r/target_test).
train_loader_rate = DataLoader(train_rate, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader_rate = DataLoader(test_rate, batch_size=batch_size, shuffle=False, num_workers=0)


def _summed_loss(mem_rec, targets):
    """Sum the loss of the output membrane potential over all `num_steps` time steps."""
    total = torch.zeros((1), dtype=torch.float, device=device)
    for step in range(num_steps):
        total += loss(mem_rec[step], targets)
    return total


# Outer training loop
for epoch in range(num_epochs):
    iter_counter = 0

    # Minibatch training loop
    for data, targets in train_loader_rate:
        data = data.to(device)
        # CrossEntropyLoss needs integer class indices; 0./1. float labels must be cast.
        targets = targets.to(device).long()

        # forward pass -- flatten with the real batch length so a final partial
        # batch does not break view()
        net.train()
        spk_rec, mem_rec = net(data.view(data.size(0), -1))

        loss_val = _summed_loss(mem_rec, targets)

        # Gradient calculation + weight update
        optimizer.zero_grad()
        loss_val.backward()
        optimizer.step()

        # Store loss history for future plotting
        loss_hist.append(loss_val.item())

        # Test set
        with torch.no_grad():
            net.eval()
            # NOTE(review): a fresh iterator with shuffle=False always yields the
            # same first test batch.
            test_data, test_targets = next(iter(test_loader_rate))
            test_data = test_data.to(device)
            test_targets = test_targets.to(device).long()

            test_spk, test_mem = net(test_data.view(test_data.size(0), -1))

            test_loss = _summed_loss(test_mem, test_targets)
            test_loss_hist.append(test_loss.item())

            # Print train/test loss/accuracy
            # NOTE(review): train_printer still flattens with `batch_size`, so a
            # partial batch landing on a print step would fail there.
            if counter % 50 == 0:
                train_printer(
                    net, data, targets, batch_size, epoch,
                    counter, iter_counter,
                    loss_hist, test_loss_hist,
                    test_data, test_targets)

        counter += 1
        iter_counter += 1

In [None]:
# Plot train vs. test loss per iteration
# (each test point is measured on one test batch).
fig, ax = plt.subplots(facecolor="w", figsize=(10, 5))
for history in (loss_hist, test_loss_hist):
    ax.plot(history)
ax.set_title("Loss Curves")
ax.legend(["Train Loss", "Test Loss"])
ax.set_xlabel("Iteration")
ax.set_ylabel("Loss")
plt.show()

In [None]:
total = 0
correct = 0

# Accuracy does not depend on sample order, so iterate deterministically.
# drop_last=False (also the default) keeps a final partial batch.
test_loader = DataLoader(test_rate, batch_size=batch_size, shuffle=False, drop_last=False)

with torch.no_grad():
    net.eval()

    for data, targets in test_loader:
        data = data.to(device)
        targets = targets.to(device)

        # forward pass, flattened with the real batch length
        test_spk, _ = net(data.view(data.size(0), -1))

        # predicted class = output neuron with the most spikes summed over time
        _, predicted = test_spk.sum(dim=0).max(1)
        total += targets.size(0)
        correct += (predicted == targets).sum().item()

# An empty test set reports nan instead of raising ZeroDivisionError.
accuracy = 100 * correct / total if total else float("nan")
print(f"Total correctly classified test set images: {correct}/{total}")
print(f"Test Set Accuracy: {accuracy:.2f}%")