# NetworkPSA Analysis Using a Recurrent Neural Network with Attention
- Run on the Cori PyTorch GPU kernel or the LEGEND kernel.
- It can also run on the Cori PyTorch CPU kernel, but it will be extremely slow.

In [3]:
import numpy as np
import os
import argparse
import time
import math
import random
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn import init
import torch.nn.functional as F
import torch
import torch.utils.data as data_utils
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import gzip
import pickle
import numpy as np
from torch.autograd import Variable
from scipy import sparse
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.preprocessing import StandardScaler
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
import torchsnooper

from tqdm import tqdm

# Run on the first CUDA device when PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")

In [4]:
# Analysis window: these two constants define the range of each waveform,
# measured relative to its t0, that we'd like to analyze.
LSPAN = 100  # number of time samples prior to t0
HSPAN = 200  # number of time samples after t0

In [5]:
def get_roc(sig, bkg):
    """Compute the ROC curve and AUC for a signal/background split.

    Parameters
    ----------
    sig, bkg : sequence of float or array-like
        Scores (classifier output or any discriminating variable) for signal
        and background events respectively. Lists and NumPy arrays are both
        accepted; inputs are flattened.

    Returns
    -------
    fpr, tpr, thr : numpy.ndarray
        False-positive rate, true-positive rate and cut thresholds from
        ``sklearn.metrics.roc_curve``.
    auc : float
        Area under the ROC curve from ``sklearn.metrics.roc_auc_score``.
    """
    sig = np.ravel(sig)
    bkg = np.ravel(bkg)
    # Signal is labelled 1 and background 0, in the same order as pred_y.
    test_y = np.concatenate([np.ones(sig.size, dtype=int), np.zeros(bkg.size, dtype=int)])
    # np.concatenate (not ``sig + bkg``) so NumPy-array inputs are joined
    # rather than added element-wise.
    pred_y = np.concatenate([sig, bkg])
    auc = roc_auc_score(test_y, pred_y)
    fpr, tpr, thr = roc_curve(test_y, pred_y)
    return fpr, tpr, thr, auc

## Dataset Object
- Built from Majorana data that was extracted and saved in `.pickle` format.
- The contents of the pickle files are described in `WaveformExtraction.ipynb`.

In [6]:
class DetectorDataset(Dataset):
    """Balanced DEP (signal) / SEP (background) waveform dataset.

    Events are read from two pickle files, each a sequence of pickled event
    dicts. After shuffling, ``dsize`` events are kept from each file. DEP
    events occupy indices ``[0, n_dep)`` with label 1, and SEP events follow
    with label 0.

    The event dicts are read with the keys ``"wf"``, ``"t0"``, ``"avse"``,
    ``"tstart"`` and ``"detector"``.
    """

    def __init__(self, dep="DEP_P42575A_10percent.pickle", sep="SEP_P42575A_10percent.pickle", dsize=-1):
        """Load, shuffle and truncate the DEP and SEP event lists.

        Parameters
        ----------
        dep, sep : str
            Paths to the DEP (signal) and SEP (background) pickle files.
        dsize : int
            Number of events to keep from each file. ``-1`` keeps as many as
            the smaller file provides, so both classes have the same size.
        """
        DEP_dict = self.event_loader(dep)
        SEP_dict = self.event_loader(sep)

        if dsize == -1:
            dsize = min(len(DEP_dict), len(SEP_dict))

        # Shuffle each list in place, then keep the first `dsize` events of each.
        np.random.shuffle(DEP_dict)
        np.random.shuffle(SEP_dict)
        DEP_dict = DEP_dict[:dsize]
        SEP_dict = SEP_dict[:dsize]
        self.event_dict = DEP_dict + SEP_dict
        self.label = [1] * len(DEP_dict) + [0] * len(SEP_dict)

        self.size = len(self.event_dict)
        print(self.size)

        # No scaler until one is attached with set_scaler().
        self.scaler = None

        # Largest "tstart" over all kept events.
        self.max_offset = np.max(self.get_field_from_dict(self.event_dict, "tstart"))

        # Unique (sorted) detector names among the kept events.
        self.detector_name = np.unique(self.get_field_from_dict(self.event_dict, "detector"))

    def __len__(self):
        return self.size

    def build_scaler(self):
        """Fit and return a ``StandardScaler`` on all processed waveforms.

        The scaler is returned, not stored; use ``set_scaler`` to attach it.
        """
        wf_array = np.stack([self.get_wf(i) for i in range(self.size)], axis=0)
        scaler = StandardScaler()
        scaler.fit(wf_array)
        return scaler

    def get_scaler(self):
        """Return the scaler attached with ``set_scaler`` (``None`` if unset)."""
        return getattr(self, "scaler", None)

    def set_scaler(self, scaler):
        """Attach ``scaler`` to this dataset."""
        self.scaler = scaler

    def get_wf(self, idx):
        """Return the baseline-subtracted, windowed waveform rescaled to [0, 1].

        The window spans ``LSPAN`` samples before to ``HSPAN`` samples after
        the event's ``"t0"`` index. The baseline is the mean of the samples
        before index ``t0 - 50``.
        """
        event = self.event_dict[idx]
        wf = np.array(event["wf"]).flatten()
        # The in-place float subtraction below fails on integer arrays, so
        # promote those to float64 while leaving float dtypes untouched.
        if not np.issubdtype(wf.dtype, np.floating):
            wf = wf.astype(np.float64)
        midindex = event["t0"]

        # Baseline subtraction.
        # NOTE(review): assumes t0 > 50 and t0 >= LSPAN; smaller values give an
        # empty baseline or a negative (wrapping) window start.
        wf -= np.average(wf[:(midindex - 50)])

        # Extract the analysis window around t0.
        wf = wf[midindex - LSPAN:midindex + HSPAN]

        # Rescale to [0, 1]; a flat window maps to all zeros instead of NaN.
        wf_min = np.min(wf)
        wf_range = np.max(wf) - wf_min
        if wf_range == 0:
            return np.zeros_like(wf)
        return (wf - wf_min) / wf_range

    def __getitem__(self, idx):
        """Return ``(waveform, label, avse)`` for event ``idx``."""
        event = self.event_dict[idx]
        return self.get_wf(idx), self.label[idx], event["avse"]

    def return_label(self):
        """Return the list of class labels (1 = DEP, 0 = SEP)."""
        # The labels live in self.label (the former self.trainY was never set).
        return self.label

    def return_detector_array(self):
        """Return the array of unique detector names."""
        return self.detector_name

    def event_loader(self, address):
        """Read every pickled object from ``address`` until end of file.

        The file is a concatenation of pickle records. ``encoding='latin1'``
        is the usual setting for loading Python 2-era pickles.
        NOTE: ``pickle`` can execute arbitrary code; only load trusted files.
        """
        wf_list = []
        with open(address, "rb") as openfile:
            while True:
                try:
                    wf_list.append(pickle.load(openfile, encoding="latin1"))
                except EOFError:
                    break
        return wf_list

    def get_field_from_dict(self, input_dict, fieldname):
        """Return ``event[fieldname]`` for every event in ``input_dict``."""
        return [event[fieldname] for event in input_dict]

    def plot_offset_correction(self):
        """Placeholder: creates two stacked subplots and draws nothing yet."""
        # TODO: unfinished; no data is plotted.
        plt.subplot(211)
        plt.subplot(212)
# next(iter(DetectorDataset()))

In [None]:
def load_data(batch_size, dep="DEP_P42575A_10percent.pickle", sep="SEP_P42575A_10percent.pickle",
              test_dep="DEP_P42575A_Co56.pickle", test_sep="SEP_P42575A_10percent.pickle",
              validation_split=0.3, random_seed=42222, shuffle_dataset=True):
    """Build the training and test DataLoaders.

    ``DetectorDataset`` stores ``n`` DEP events at indices ``[0, n)``
    followed by ``n`` SEP events at ``[n, 2n)``. Each selected DEP index
    ``i`` is therefore paired with the SEP index ``2n - 1 - i``, which keeps
    the training subset class-balanced.

    A ``validation_split`` fraction of the index pairs is held out and not
    used for training. Evaluation instead uses a separate dataset built from
    ``test_dep`` / ``test_sep``.

    Parameters
    ----------
    batch_size : int
        Batch size of both loaders.
    dep, sep : str
        Pickle files of the training dataset.
    test_dep, test_sep : str
        Pickle files of the test dataset.
    validation_split : float
        Fraction of DEP/SEP index pairs excluded from training.
    random_seed : int
        Seed of the local RNG that picks the held-out pairs.
    shuffle_dataset : bool
        If False, the first ``validation_split`` fraction of pairs is held out.

    Returns
    -------
    train_loader, test_loader : torch.utils.data.DataLoader
    detector_names : numpy.ndarray
        Unique detector names of the training dataset.
    """
    dataset = DetectorDataset(dep=dep, sep=sep)
    test_dataset = DetectorDataset(dep=test_dep, sep=test_sep)

    # Select on the DEP half only; the matching SEP index is derived below.
    n_pairs = len(dataset) // 2
    pair_indices = np.arange(n_pairs)
    split = int(np.floor(validation_split * n_pairs))
    if shuffle_dataset:
        # A local seeded RNG keeps the split reproducible without resetting
        # NumPy's global random state for the rest of the notebook.
        np.random.RandomState(random_seed).shuffle(pair_indices)
    train_dep = pair_indices[split:]
    train_indices = np.concatenate([train_dep, 2 * n_pairs - 1 - train_dep]).tolist()

    test_indices = list(range(len(test_dataset)))

    # SubsetRandomSampler draws a fresh random order every epoch, so the
    # index lists need no pre-shuffling.
    train_sampler = SubsetRandomSampler(train_indices)
    test_sampler = SubsetRandomSampler(test_indices)

    train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, drop_last=True)
    # drop_last=False: evaluate every test event instead of discarding a
    # partial final batch.
    test_loader = DataLoader(test_dataset, batch_size=batch_size, sampler=test_sampler, drop_last=False)

    return train_loader, test_loader, dataset.return_detector_array()

In [None]:
class FCNet(nn.Module):
    """Fully connected head mapping ``first_unit`` inputs to ``last_unit`` outputs.

    There are three hidden layers whose widths are 50 %, 25 % and 10 % of
    ``first_unit`` (each truncated to int). Each hidden layer is followed by
    LeakyReLU and Dropout(0.2). The final Linear layer has no activation.
    """

    def __init__(self, first_unit, last_unit):
        super().__init__()
        widths = [first_unit] + [int(first_unit * frac) for frac in (0.5, 0.25, 0.1)]
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.extend([torch.nn.Linear(n_in, n_out), torch.nn.LeakyReLU(), torch.nn.Dropout(0.2)])
        layers.append(torch.nn.Linear(widths[-1], last_unit))
        # The attribute name ``fcnet`` and the layer order fix the state_dict
        # keys (``fcnet.0.weight`` ...) used by saved checkpoints.
        self.fcnet = nn.Sequential(*layers)

    def forward(self, x):
        return self.fcnet(x)

## The Recurrent Neural Network Model

In [7]:
class RNN(nn.Module):
    """GRU waveform classifier with dot-product attention over time steps.

    Each waveform is min-max rescaled to [0, 1], quantised into
    ``1 / emb_tick + 1`` integer levels and embedded. A 3-layer GRU reads the
    embedded sequence, and its final hidden state attends over the GRU
    outputs. The attention context and the final hidden state are
    concatenated and fed to an ``FCNet`` that emits one raw score per
    waveform.

    Parameters
    ----------
    get_attention : bool
        If True, ``forward`` returns the ``[batch, seq_len]`` attention scores
        instead of the classifier output.
    bidirectional : bool
        Whether the GRU is bidirectional (default True).
    """

    def __init__(self, get_attention=False, bidirectional=True):
        super(RNN, self).__init__()

        self.bidirec = bidirectional
        feed_in_dim = 512
        # Waveform segment length; not used by forward().
        self.seg = 1
        self.emb_dim = 128
        self.emb_tick = 1 / 500.0
        # One embedding row per quantisation level 0 .. 1/emb_tick.
        self.embedding = nn.Embedding(int(1 / self.emb_tick) + 1, self.emb_dim)
        self.seq_len = (HSPAN + LSPAN) // self.seg
        self.RNNLayer = torch.nn.GRU(
            input_size=self.emb_dim,
            hidden_size=feed_in_dim // 2,
            num_layers=3,
            batch_first=True,
            bidirectional=bidirectional,
            dropout=0.5,
        )
        if bidirectional:
            feed_in_dim *= 2
        # FCNet input = context vector + final hidden state, each feed_in_dim // 2 wide.
        self.fcnet = FCNet(feed_in_dim, 1)
        # A bias-free Linear is a plain matrix multiplication.
        self.attention_weight = nn.Linear(feed_in_dim // 2, feed_in_dim // 2, bias=False)
        # Not used in forward(); kept so the state_dict layout is unchanged.
        self.norm = torch.nn.BatchNorm1d(feed_in_dim // 2)
        self.get_attention = get_attention

    def forward(self, x):
        """Score a batch of waveforms.

        x : tensor, assumed ``[batch, seq_len]`` (TODO confirm); the embedding
            turns it into ``[batch, seq_len, emb_dim]`` for the batch-first GRU.
        Returns ``[batch, 1]`` raw scores, or ``[batch, seq_len]`` attention
        scores when ``get_attention`` is set.
        """
        # Per-waveform min-max rescale to [0, 1]. A flat waveform has zero
        # range; dividing by 1 there maps it to zeros instead of NaN, which
        # would not convert to a valid embedding index.
        x_min = x.min(dim=-1, keepdim=True)[0]
        x_range = x.max(dim=-1, keepdim=True)[0] - x_min
        x_range = torch.where(x_range > 0, x_range, torch.ones_like(x_range))
        x = (x - x_min) / x_range
        # Quantise to integer levels 0 .. 1/emb_tick and embed.
        x = self.embedding((x / self.emb_tick).long())
        bsize = x.size(0)
        output, hidden = self.RNNLayer(x)  # output: [batch, seq_len, channel]
        if self.bidirec:
            # Last layer's forward and backward final states -> [batch, channel].
            hidden = hidden[-2:].transpose(0, 1).reshape(bsize, -1)
        else:
            hidden = hidden[-1]

        # Attention score of step t: (W @ output_t) . hidden
        w_attention = self.attention_weight(output)  # [batch, seq_len, channel]
        scores = torch.bmm(w_attention, hidden.unsqueeze(-1)).squeeze(-1)  # [batch, seq_len]
        attention_score = torch.softmax(scores, dim=-1)  # softmax over seq_len
        if self.get_attention:
            return attention_score

        # Attention-weighted sum of the GRU outputs over seq_len (broadcast).
        context = torch.sum(attention_score.unsqueeze(-1) * output, dim=1)
        # Concatenate the context vector with the final hidden state.
        return self.fcnet(torch.cat([context, hidden], dim=-1))

In [None]:
# Build the training / test loaders used by the cells below.
BATCH_SIZE = 32
train_loader, test_loader, det_array = load_data(batch_size=BATCH_SIZE)

In [None]:
def get_sigmoid(waveform_in, labels_in, classifier_in, device=None):
    """Run ``classifier_in`` on a batch and split its outputs by label.

    Despite the name, no sigmoid is applied here: the classifier's outputs
    are returned as they are.

    Parameters
    ----------
    waveform_in : torch.Tensor
        Batch of waveforms; moved to ``device`` before inference.
    labels_in : torch.Tensor
        Per-event labels; 1 selects the first returned list, 0 the second.
    classifier_in : callable
        Model mapping the waveform batch to one output per event.
    device : torch.device or str, optional
        Where to run the classifier; defaults to the module-level ``DEVICE``.

    Returns
    -------
    (list, list)
        Flattened outputs of label-1 events and of label-0 events, in batch
        order. Events with any other label are dropped.
    """
    if device is None:
        device = DEVICE
    outputs_in = classifier_in(waveform_in.to(device))

    # Labels are only compared, so keep them on the CPU (no device round-trip).
    lb_data_in = labels_in.detach().cpu().float().numpy().flatten()
    outpt_data_in = outputs_in.detach().cpu().numpy().flatten()

    return list(outpt_data_in[lb_data_in == 1.0]), list(outpt_data_in[lb_data_in == 0.0])

## Training RNN

In [None]:
NUM_EPOCHS = 50
LEARNING_RATE = 0.1

# Define the RNN network and move it to the compute device.
RNNclassifier = RNN()
RNNclassifier.to(DEVICE)

print("#params", sum(x.numel() for x in RNNclassifier.parameters()))

# BCEWithLogitsLoss applies the sigmoid itself, so the last layer of the
# network does not need one.
RNNcriterion = torch.nn.BCEWithLogitsLoss().to(DEVICE)

# Warm-up learning-rate schedule (as in the Transformer paper). The LR factor
# grows linearly for `warmup_size` scheduler steps and then decays as
# step**-0.5. scheduler.step() is called once per batch, so the argument of
# the schedule counts optimizer steps, not epochs. The warm-up lets the
# attention mechanism learn general features of the data first.
warmup_size = 4000
print("Warmup Size: %d" % (warmup_size))


def lmbda(step):
    """Multiplicative LR factor used by LambdaLR at scheduler step ``step``."""
    return min((step + 1) ** -0.5, (step + 1) * warmup_size ** -1.5)


RNNoptimizer = torch.optim.AdamW(RNNclassifier.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.98), eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(RNNoptimizer, lr_lambda=lmbda)

for epoch in range(NUM_EPOCHS):
    # ---- Training ----
    RNNclassifier.train()
    for i, (waveform, labels, avse) in enumerate(train_loader):
        waveform = waveform.to(DEVICE)
        labels = labels.to(DEVICE).float().view(-1, 1)

        RNNoutputs = RNNclassifier(waveform)
        RNNloss = RNNcriterion(RNNoutputs, labels)

        RNNloss.backward()
        RNNoptimizer.step()       # update parameters of net
        RNNoptimizer.zero_grad()  # reset gradient
        scheduler.step()          # advance the per-step warm-up schedule

    print('\rEpoch [{0}/{1}], Iter [{2}/{3}] Loss: {4:.4f}'.format(
        epoch + 1, NUM_EPOCHS, i + 1, len(train_loader), RNNloss.item()), end="")

    # ---- Evaluation on the test set ----
    RNNclassifier.eval()
    sigmoid_s_RNN = []
    sigmoid_b_RNN = []
    avse_s = []
    avse_b = []
    with torch.no_grad():
        for waveform, labels, avse in tqdm(test_loader):
            sig_RNN, bkg_RNN = get_sigmoid(waveform, labels, RNNclassifier)
            sigmoid_s_RNN += sig_RNN
            sigmoid_b_RNN += bkg_RNN

            # Split the AvsE values by label (1 = DEP, 0 = SEP).
            lb_data = labels.cpu().numpy().flatten()
            avse_data = avse.cpu().numpy().flatten()
            avse_s += list(avse_data[lb_data == 1.0])
            avse_b += list(avse_data[lb_data == 0.0])

    # ROC curves for the RNN and for AvsE.
    fpr_rnn, tpr_rnn, thr_rnn, auc_rnn = get_roc(sigmoid_s_RNN, sigmoid_b_RNN)
    fpr_avse, tpr_avse, thr_avse, auc_avse = get_roc(avse_s, avse_b)
    # Reference working point: the AvsE threshold closest to -1.
    i_avse = np.argmin(np.abs(thr_avse + 1.0))
    rej_tpr = tpr_avse[i_avse]
    # SEP fraction kept by an RNN cut with the same DEP acceptance (TPR).
    sep_remain_rnn = fpr_rnn[np.argmin(np.abs(tpr_rnn - rej_tpr))] * 100.0
    sep_remain_avse = fpr_avse[i_avse] * 100.0
    plt.plot(fpr_rnn, tpr_rnn, label="RNN AUC: %.3f SEP Remain: %.1f%%" % (auc_rnn, sep_remain_rnn))
    plt.plot(fpr_avse, tpr_avse, label="AvsE AUC: %.3f SEP Remain: %.1f%%" % (auc_avse, sep_remain_avse))
    plt.legend()
    plt.savefig("ROC.png", dpi=200)
    plt.show()
    plt.close()

    # Save the RNN weights after every epoch.
    torch.save(RNNclassifier.state_dict(), 'RNN.pt')
    

## Plotting
- 2-D histograms ("confusion plots") comparing AvsE with the RNN output

In [None]:
# 2-D histograms of RNN output vs. corrected AvsE, one for DEP and one for SEP.
# x range: the full span of RNN outputs; y range: a fixed AvsE window.
xlow = np.quantile(sigmoid_s_RNN + sigmoid_b_RNN, 0)
xhi = np.quantile(sigmoid_s_RNN + sigmoid_b_RNN, 1.00)
ylow = -10
yhi = 2

# RNN threshold whose TPR best matches the TPR of the AvsE threshold nearest -1.
avse_ref_tpr = tpr_avse[np.argmin(np.abs(thr_avse + 1.0))]
threshold_rnn = thr_rnn[np.argmin(np.abs(tpr_rnn - avse_ref_tpr))]

for rnn_out, avse_vals, plot_title, out_png in (
    (sigmoid_s_RNN, avse_s, "Network Output of DEP", "AR_signal.png"),
    (sigmoid_b_RNN, avse_b, "Network Output of SEP", "AR_bkg.png"),
):
    plt.hist2d(
        rnn_out,
        avse_vals,
        bins=(np.linspace(xlow, xhi, 100), np.linspace(ylow, yhi, 100)),
        cmap="PuRd",
        norm=matplotlib.colors.LogNorm(),
    )
    # Reference cut lines: AvsE = -1 and the matched RNN threshold.
    plt.axhline(y=-1, color="blue")
    plt.axvline(x=threshold_rnn, color="blue")
    plt.title(plot_title)
    plt.xlabel("RNN Output")
    plt.ylabel("AvsE Corrected")
    plt.savefig(out_png, dpi=200)
    plt.show()
    plt.cla()
    plt.clf()
    plt.close()

- Network output as a 1-D histogram

In [None]:
# 1-D histograms of the RNN output for DEP (signal) and SEP (background)
# events, over the x range (xlow, xhi) computed earlier in the notebook.
bins_1d = np.linspace(xlow, xhi, 100)
plt.hist(sigmoid_s_RNN, bins=bins_1d, color="red", histtype="step", label="DEP")
plt.hist(sigmoid_b_RNN, bins=bins_1d, color="blue", histtype="step", label="SEP")
plt.title("RNN output")
# get_sigmoid returns the classifier output without applying a sigmoid, so
# label the axis the same way as the 2-D plots.
plt.xlabel("RNN Output")
# The y axis of a histogram is the number of events per bin.
plt.ylabel("Number of events")
plt.legend()
plt.savefig("RNN1d.png", dpi=200)
plt.show()
plt.cla()
plt.clf()
plt.close()

- Plot the attention scores of a given event

In [None]:
def plot_attention(waveform, attscore, bkg=False):
    """Plot a waveform as bars coloured by their attention score.

    Parameters
    ----------
    waveform : array-like
        Waveform samples; one bar is drawn per sample.
    attscore : array-like
        Attention score per sample (same length as ``waveform``). It is
        rescaled to [0, 1] and mapped through the "cool" colormap.
    bkg : bool
        Output file: ``att_nhit_bkg.png`` if True, else ``att_nhit_sig.png``.
    """
    from matplotlib import gridspec

    # plt.get_cmap works across Matplotlib versions, whereas
    # matplotlib.cm.get_cmap was deprecated in 3.7 and later removed.
    colormap_normal = plt.get_cmap("cool")

    waveform = np.asarray(waveform)
    attscore = np.asarray(attscore, dtype=float)

    def rescale(y):
        # Min-max rescale to [0, 1]; a constant input maps to zeros, not NaN.
        y_min = np.min(y)
        y_range = np.max(y) - y_min
        if y_range == 0:
            return np.zeros_like(y)
        return (y - y_min) / y_range

    fig = plt.figure(figsize=(20, 12))
    gs = gridspec.GridSpec(1, 2, width_ratios=[8, 1])

    # Left panel: the waveform, one bar per integer time sample.
    plt.subplot(gs[0])
    plt.bar(np.arange(len(waveform)), waveform, width=1.0, color=colormap_normal(rescale(attscore)))
    plt.xlabel("Time Sample")
    plt.ylabel("ADC Counts")

    # Right panel: a vertical colour scale from high (top) to low (bottom) attention.
    loss_ax_scale = fig.add_subplot(gs[1])
    loss_ax_scale.set_xticks([])
    loss_ax_scale.tick_params(length=0)
    plt.yticks([1, 72], ["High Attention", "Low Attention"], rotation=90)
    loss_scale = np.linspace(1.0, 0.0, 100)
    loss_scale = np.vstack((loss_scale, loss_scale))  # (2, 100); transposed to two columns below
    loss_ax_scale.imshow(np.transpose(loss_scale), cmap=colormap_normal, interpolation='nearest')

    plt.tight_layout()
    # Save before show(): show() can close or clear the figure (e.g. inline
    # backends), which would leave the saved PNG blank.
    plt.savefig("att_nhit_bkg.png" if bkg else "att_nhit_sig.png", dpi=200)
    plt.show()
    plt.close(fig)

In [None]:
# Build an RNN in attention-score mode (it returns attention scores instead of
# the classifier output) and load the trained weights into it.
attentionRNN = RNN(True)
attentionRNN.to(DEVICE).double()
model_dict = attentionRNN.state_dict()
pretrained_dict = torch.load('RNN.pt', map_location='cpu')
# Overlay the trained weights on the fresh model's state and load the merged
# dict. Loading `pretrained_dict` directly would make the merge pointless.
model_dict.update(pretrained_dict)
attentionRNN.load_state_dict(model_dict)
attentionRNN.eval()

# Waveforms of the first test batch.
wf = next(iter(test_loader))[0]

with torch.no_grad():
    # as_tensor avoids the copy (and warning) torch.tensor() makes of a tensor.
    waveform = torch.as_tensor(wf).to(DEVICE)
    attention = attentionRNN(waveform)
    print(attention.size())

    # Plot the attention scores of one event of the batch.
    ibatch = 0
    wf = waveform[ibatch]
    attention = attention[ibatch]
    plot_attention(wf.cpu().numpy().flatten(), attention.cpu().numpy().flatten())
                
                
