# This code uses the multi head self attention (Transformer) model to perform Network PSA

In [1]:
import numpy as np
import os
import argparse
import time
import math
import random
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn import init
import torch.nn.functional as F
import torch
import torch.utils.data as data_utils
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
import gzip
import pickle5 as pickle
import numpy as np
from torch.autograd import Variable
from scipy import sparse
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.preprocessing import StandardScaler
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import cm
import torchsnooper

from torch import nn, einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from tqdm import tqdm

DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
'''
These 2 parameters define the range of waveform we'd like to analyze
'''
LSPAN = 300 # number of time samples prior to t0
HSPAN = 300 # number of time samples after t0
DATA_EMB_DIM = 500

gets the false positive rate, true positive rate, cutting threshold and area under curve using the given signal and background array

In [None]:
def get_roc(sig,bkg):
    testY = np.array([1]*len(sig) + [0]*len(bkg))
    predY = np.array(sig+bkg)
    auc = roc_auc_score(testY, predY)
    fpr, tpr, thr = roc_curve(testY, predY)
    return fpr,tpr,thr,auc

## Load Data

In [4]:
class DetectorDataset(Dataset):

    def __init__(self, dep="DEP_P42575A_10percent.pickle", sep = "SEP_P42575A_10percent.pickle",dsize=-1):
        
        DEP_dict = self.event_loader(dep)
        SEP_dict = self.event_loader(sep)

        if dsize == -1:
            dsize = min(len(DEP_dict), len(SEP_dict))
        
        #Shuffle dataset and select #dsize event from DEP and SEP
        np.random.shuffle(DEP_dict)
        np.random.shuffle(SEP_dict)
        DEP_dict = DEP_dict[:dsize]
        SEP_dict = SEP_dict[:dsize]
        self.event_dict = DEP_dict + SEP_dict
        self.label = ([1]*len(DEP_dict)) + ([0] * len(SEP_dict))
        
        self.size = len(self.event_dict)
        print(self.size)
        
        #Get offset values:
        self.max_offset = np.max(self.get_field_from_dict(DEP_dict,"tstart") + self.get_field_from_dict(SEP_dict,"tstart"))
        
        #Get all unique detector name:
        self.detector_name = np.unique(self.get_field_from_dict(DEP_dict,"detector") + self.get_field_from_dict(SEP_dict,"detector"))
        
        self.data_emb_dim = DATA_EMB_DIM

    def __len__(self):
        return self.size
    
    def build_scaler(self):
        '''
        '''
        wf_array = []
        for i in range(self.size):
            wf_array.append(self.get_wf(i).reshape(1,-1))
        wf_array = np.concatenate(wf_array,axis=0)
        scaler = StandardScaler()
        scaler.fit(wf_array)
        return scaler
    
    def get_scaler(self):
        return self.scaler
    
    def set_scaler(self,scaler):
        self.scaler = scaler
    
    def get_wf(self,idx):
        event = self.event_dict[idx]
        wf = np.array(event["wf"]).flatten()
        midindex = event["t0"]

        #baseline subtraction
        wf -= np.average(wf[:(midindex-50)])
        
        #Extract waveform from its t0
        wfbegin = midindex - LSPAN
        wfend = midindex + HSPAN
        
        wf = wf[wfbegin:wfend]
        wf = (wf - np.min(wf)) / (np.max(wf) - np.min(wf))#rescale wf to between 0 and 1

        
        return wf

    def __getitem__(self, idx):
        event = self.event_dict[idx]
        wf = np.array(event["wf"]).flatten()
        midindex = event["t0"]

        wf = self.get_wf(idx)
        
        avse = event["avse"]
        tdrift = event["tDrift"]
        
        return self.tokenizer(wf), self.label[idx], avse
        
    def return_label(self):
        return self.trainY

    def return_detector_array(self):
        return self.detector_name
    
    #Load event from .pickle file
    def event_loader(self, address):
        wf_list = []
        with (open(address, "rb")) as openfile:
            while True:
#                 if len(wf_list) > 2000:
#                     break
                try:
                   wf_list.append(pickle.load(openfile, encoding='latin1'))
                except EOFError:
                    break
        return wf_list
    
    def get_field_from_dict(self, input_dict, fieldname):
        field_list = []
        for event in input_dict:
            field_list.append(event[fieldname])
        return field_list
    
    def tokenizer(self,x):
        '''
        Tokenize the waveform
        Originally the waveform is a 1D array of floatpoint numbers, say its shape is wf.shape = (1000,)
        After tokenizing, its shape becomes wf.shape = (1000,self.data_emb_dim)
        Where now in each dimension it is an one-hot vector with length of self.data_emb_dim
        Example:
            suppose self.data_emb_dim = 10
            Then originally we have wf[5] = 0.25
            Then after tokenizing, wf[5] becomes [0 0 1 0 0 0 0 0 0 0]
        '''
        nbins = self.data_emb_dim
        x += np.random.rand(len(x)) * 0.02 - 0.01
        x = np.clip(x,0.0,1.0)
        token_range = np.linspace(0.0,1.0,nbins+1)
        return np.array([np.argwhere(xval>=token_range)[-1,0] for xval in x])
    
    def get_data_emb(self):
        return self.data_emb_dim
        
# print(next(iter(DetectorDataset())))

## Modules of the multi-head self attention network

In [5]:
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        '''
        The multi-head self attention moduel
        q: query
        k: key
        v: value
        q,k,v are produced by multiplying kernel matrices to the input at each time
        '''
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

In [6]:
#Load dataset
def load_data(batch_size):

    dataset = DetectorDataset()
    test_dataset = DetectorDataset(dep="DEP_P42575A_Co56.pickle")
    validation_split = .3 #Split data set into training & testing with 7:3 ratio
    shuffle_dataset = True
    random_seed= 42222

    #make sure we have the same amount of signal/bkg in the training/test dataset
    division = 2
    dataset_size = int(len(dataset)/division)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    if shuffle_dataset :
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    train_indices += list(division*dataset_size - 1-np.array(train_indices))
    val_indices += list(division*dataset_size- 1-np.array(val_indices))

    np.random.shuffle(train_indices)
    np.random.shuffle(val_indices)
    
    test_dataset_size = int(len(test_dataset))
    test_indices = list(range(test_dataset_size))
    if shuffle_dataset :
        np.random.seed(random_seed)
        np.random.shuffle(test_indices)

    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(test_indices)

    train_loader = data_utils.DataLoader(dataset, batch_size=batch_size, sampler=train_sampler, drop_last=True)
    test_loader = data_utils.DataLoader(test_dataset, batch_size=batch_size,sampler=valid_sampler,  drop_last=True)

    return train_loader,test_loader, dataset.get_data_emb()

## The Transformer Classifier

In [7]:
class WaveTransformer(nn.Module):
    def __init__(self,data_emb, get_attention = False):
        super(WaveTransformer, self).__init__()
        self.seg = 1      #Segment waveform to reduce its length. If the original waveform is (2000,1), then segment it with self.seg=5 can reduce its length to (400,5)
        self.embedding_dim = 64  # Embedding dimension
        self.seq_len = (HSPAN + LSPAN)//self.seg
        self.embedding = nn.Embedding(data_emb, self.embedding_dim)
        # This is the classification token, it attends different parts of the wavefom to make classification decision
        self.cls_token = nn.Parameter(torch.randn(1, 1, self.embedding_dim))
        # dim, depth, heads, dim_head, mlp_dim, dropout = 0.
        self.transformer = Transformer(self.embedding_dim,1,10,32, 1024,0.3)
        # Positional encoder, learning during training. It's length is seq_len + 1 because we need to hold the classification token at the beginning
        self.pos_encoder = nn.Parameter(torch.randn(1, self.seq_len+1, self.embedding_dim))
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(self.embedding_dim),
            nn.Linear(self.embedding_dim, 1)
        )

#     @torchsnooper.snoop()
    def forward(self, x):
        b,n = x.size()
        x = self.embedding(x)  # [batch, seq_len, data_embedding_dimension] -> [batch, seq_len, self.embedding_dim]
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b) # [1, 1, self.embedding_dim] -> [batch, 1, self.embedding_dim]
        x = torch.cat((cls_tokens, x), dim=1) # {[batch, seq_len, self.embedding_dim], [batch, 1, self.embedding_dim]} -> [batch, seq_len+1, self.embedding_dim]
        x += self.pos_encoder[:, :(n + 1)]
        
        x = self.transformer(x)[:,0] # Feed through the transformer model, then select only the 0th element in the sequence (that is, the classification token)
        return self.mlp_head(x)      # Feed cls_token to a fully connected NN for classification decision

    def get_emb_dim(self):
        return self.embedding_dim

In [8]:
#Load data
BATCH_SIZE = 4
train_loader, test_loader, data_emb_dim = load_data(BATCH_SIZE)

11576
4868


In [9]:
#This feeds the waveform into classifier and get sigmoid output for signal and background events
def get_sigmoid(waveform_in, labels_in ,classifier_in):
    waveform_in = waveform_in.to(DEVICE).long()
    labels_in = labels_in.to(DEVICE).float()
    outputs_in  = classifier_in(waveform_in)

    lb_data_in = labels_in.cpu().data.numpy().flatten()
    outpt_data_in = outputs_in.cpu().data.numpy().flatten()

    signal_in = np.argwhere(lb_data_in == 1.0)
    bkg_in = np.argwhere(lb_data_in == 0.0)

    return list(outpt_data_in[signal_in].flatten()), list(outpt_data_in[bkg_in].flatten())

## Train the transformer classifier

In [10]:
NUM_EPOCHS = 50

#Define RNN network
RNNclassifier = WaveTransformer(data_emb_dim)
LEARNING_RATE =RNNclassifier.get_emb_dim()**-0.5
RNNclassifier.to(DEVICE)



print("#params", sum(x.numel() for x in RNNclassifier.parameters()))

RNNcriterion = torch.nn.BCEWithLogitsLoss() #BCEWithLogitsLoss does not require the last layer to be sigmoid
RNNcriterion = RNNcriterion.to(DEVICE)

# Warmup training scheme
# This allows the attention mechanism to learn general features of data in first few epochs
warmup_size = 4000 # Warm up step used in transformer paper
print("Warmup Size: %d"%(warmup_size))
lmbda = lambda epoch: min((epoch+1)**-0.5, (epoch+1)*warmup_size**-1.5)
RNNoptimizer = torch.optim.AdamW(RNNclassifier.parameters(),lr=LEARNING_RATE, betas=(0.9, 0.98),eps=1e-9)
scheduler = torch.optim.lr_scheduler.LambdaLR(RNNoptimizer, lr_lambda=lmbda)

for epoch in range(NUM_EPOCHS):
    for i, (waveform, labels, avse) in enumerate(train_loader):
        RNNclassifier.train()
        waveform = waveform.to(DEVICE).long()
        labels = labels.to(DEVICE).float()
        labels = labels.view(-1,1)
        
        #Train RNN
        RNNoutputs  = RNNclassifier(waveform)
        RNNloss = RNNcriterion(RNNoutputs, labels)
        
        RNNloss.backward()
        RNNoptimizer.step()        # update parameters of net
        RNNoptimizer.zero_grad()   # reset gradient
        scheduler.step()

    print('\rEpoch [{0}/{1}], Iter [{2}/{3}] Loss: {4:.4f}'.format(
        epoch+1, NUM_EPOCHS, i+1, len(train_loader),
        RNNloss.item(), end=""),end="")
    sigmoid_s_RNN = []
    sigmoid_b_RNN = []
    avse_s = []
    avse_b = []

    for waveform,labels,avse in tqdm(test_loader):

        RNNclassifier.eval()

        with torch.no_grad():
            sig_RNN, bkg_RNN = get_sigmoid(waveform, labels, RNNclassifier)

            lb_data = labels.cpu().data.numpy().flatten()
            avse_data = avse.cpu().data.numpy().flatten()
            
            signal = np.argwhere(lb_data == 1.0)
            bkg = np.argwhere(lb_data == 0.0)
            
            sigmoid_s_RNN += sig_RNN
            sigmoid_b_RNN += bkg_RNN
            
            avse_s += list(avse_data[signal].flatten())
            avse_b += list(avse_data[bkg].flatten())

    #Set the range of scatter plot from 5% to 95% quantile of sigmoid output
    xlow = np.quantile(sigmoid_s_RNN+sigmoid_b_RNN,0.05)
    xhi = np.quantile(sigmoid_s_RNN+sigmoid_b_RNN,0.95)

    # Plot the ROC curve for RNN and AvsE
    fpr_rnn, tpr_rnn, thr_rnn, auc_rnn = get_roc(sigmoid_s_RNN, sigmoid_b_RNN)
    fpr_avse, tpr_avse, thr_avse, auc_avse = get_roc(avse_s, avse_b)
    rej_tpr = tpr_avse[np.argmin(np.abs(thr_avse+1.0))]
    plt.plot(fpr_rnn,tpr_rnn,label="Transformer AUC: %.3f SEP Remain: %.1f%%"%(auc_rnn,fpr_rnn[np.argmin(np.abs(tpr_rnn-rej_tpr))]*100.0))
    plt.plot(fpr_avse,tpr_avse,label="AvsE AUC: %.3f SEP Remain: %.1f%%"%(auc_avse,fpr_avse[np.argmin(np.abs(thr_avse+1.0))]*100.0))
    plt.legend()
    plt.savefig("ROC.png",dpi=200)
    plt.show()
    plt.cla()
    plt.clf()
    plt.close()
    
    #Save CNN and RNN models.
    torch.save(RNNclassifier.state_dict(), 'RNN.pt')
    

#params 259521
Warmup Size: 4000


RuntimeError: CUDA error: an illegal memory access was encountered

In [None]:
# confusion plot of A vs. E and RNN classifier
xlow = np.quantile(sigmoid_s_RNN+sigmoid_b_RNN,0)
xhi = np.quantile(sigmoid_s_RNN+sigmoid_b_RNN,1.00)
ylow = -10
yhi =2

threshold_rnn = thr_rnn[np.argmin(np.abs(tpr_rnn-tpr_avse[np.argmin(np.abs(thr_avse+1.0))]))]

#Plot sigmoid output for DEP events
plt.hist2d(sigmoid_s_RNN, avse_s,bins = (np.linspace(xlow,xhi,100),np.linspace(ylow,yhi,100)),cmap="PuRd",norm=matplotlib.colors.LogNorm())
plt.axhline(y=-1,color="blue")
plt.axvline(x=threshold_rnn,color="blue")
plt.title("Network Output of DEP")
plt.xlabel("RNN Output")
plt.ylabel("AvsE Corrected")
# plt.legend()
plt.savefig("AR_signal.png",dpi=200)
plt.show()
plt.cla()
plt.clf()
plt.close()

#Plot sigmoid output for SEP events
plt.hist2d(sigmoid_b_RNN, avse_b,bins = (np.linspace(xlow,xhi,100),np.linspace(ylow,yhi,100)),cmap="PuRd",norm=matplotlib.colors.LogNorm())
plt.axhline(y=-1,color="blue")
plt.axvline(x=threshold_rnn,color="blue")
plt.title("Network Output of SEP")
plt.xlabel("RNN Output")
plt.ylabel("AvsE Corrected")
# plt.legend()
plt.savefig("AR_bkg.png",dpi=200)
plt.show()
plt.cla()
plt.clf()
plt.close()

In [None]:
# Network output for DEP events
plt.hist(sigmoid_s_RNN, bins = np.linspace(xlow,xhi,100),color="red",histtype="step",label="DEP")
plt.hist(sigmoid_b_RNN, bins = np.linspace(xlow,xhi,100),color="blue",histtype="step",label="SEP")
plt.title("RNN output")
plt.xlabel("RNN Sigmoid Output")
plt.ylabel("RNN output")
plt.legend()
plt.savefig("RNN1d.png",dpi=200)
plt.show()
plt.cla()
plt.clf()
plt.close()

In [None]:
def plot_attention(waveform, attscore,bkg=False):
    '''
    This function plots the attention score distribution on given waveform
    waveform: the vector of original waveform
    attscore: the attention score obtained from the RNN
    '''
    from matplotlib import cm
    from matplotlib import gridspec
    colormap_normal = cm.get_cmap("cool")
    
    waveform=np.array(waveform)
    attscore = np.array(attscore)
    fig = plt.figure(figsize=(20, 12))
    gs = gridspec.GridSpec(1, 2, width_ratios=[8,1]) 

    plt.subplot(gs[0])
    rescale = lambda y: (y - np.min(y)) / (np.max(y) - np.min(y))
    len_wf = len(waveform)
    print(np.linspace(0,len_wf,len_wf).shape,waveform.shape, rescale(attscore).shape)
    plt.bar(np.linspace(0,len_wf,len_wf),waveform,width=1.5, color=colormap_normal(rescale(attscore)))
    plt.xlabel("Time Sample")
    plt.ylabel("ADC Counts")

    loss_ax_scale = fig.add_subplot(gs[1])
    loss_ax_scale.set_xticks([])
    loss_ax_scale.tick_params(length=0)
    plt.yticks([1,72], ["High Attention", "Low Attention"],rotation=90)  # Set text labels and properties.

    # loss_ax_scale.set_yticks([1.0,0.0])
    # loss_ax_scale.set_yticklabels(["High Attention", "Low Attention"],rotation=90)
    loss_scale = np.linspace(1.0, 0.0, 100)

    for i in range(0,1):
        loss_scale = np.vstack((loss_scale,loss_scale))
    loss_scale = loss_ax_scale.imshow(np.transpose(loss_scale),cmap=colormap_normal, interpolation='nearest')

    plt.tight_layout()
    if bkg:
        plt.savefig("att_nhit_bkg.png",dpi=200)
    else:
        plt.savefig("att_nhit_sig.png",dpi=200)
    plt.cla()
    plt.clf()
    plt.close()

In [None]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
wf = np.array([0]*160+np.linspace(0,1,20).tolist() + np.linspace(1,0.8,120).tolist())+np.random.randn(300)*0.02
wf = (wf - np.min(wf)) / (np.max(wf) - np.min(wf))#rescale wf to between 0 and 1

#Set the RNN to attention score mode
attentionRNN = RNN(True)
attentionRNN.to(DEVICE).double()
model_dict = attentionRNN.state_dict()
pretrained_dict = torch.load('RNN.pt',map_location='cpu')
model_dict.update(pretrained_dict) 
attentionRNN.load_state_dict(pretrained_dict)

attentionRNN.eval()

with torch.no_grad():

    waveform = torch.tensor(wf).to(DEVICE).view(1,-1,1).expand(32,300,1).double()
    attention  = attentionRNN(waveform)
    attention = torch.sum(attention,dim=1)
    
    ibatch=0
    wf = waveform[ibatch]#.view(600,3)[:,0]
    attention = attention[ibatch]
    plot_attention(wf.cpu().data.numpy().flatten(), attention.cpu().data.numpy().flatten())
    assert 0
                
