# Demo of Self-Attention in PyTorch

In this notebook we test whether Self-Attention and position-wise feed-forward neural networks can solve a simple, deterministic logical problem. Furthermore, we want to assess to what degree the attention weights can "explain" the predictions. We see that this is the case for position-wise feed-forward networks, but it is *not* the case for Self-Attention networks. We conjecture that this is because the representation produced by Self-Attention depends on all elements of the input sequence.

In [1]:
import random
import numpy as np
from model_pytorch import *
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset,sampler,DataLoader

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from sklearn.metrics import roc_auc_score


In [2]:
# Seed every RNG the notebook uses: `random` drives the data generation, while torch
# drives weight initialisation and the DataLoader shuffling (numpy seeded for completeness).
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# More complex generation

In [3]:
# First we draw the set of "fraud" rules. Each rule is [pair_a, pair_b, threshold]:
# two (x, y) integer pairs with coordinates in 1..20, plus a uniform value in [0, 1).
n_fraud = 20
fraud_list = [
    [
        np.array([random.randint(1, 20), random.randint(1, 20)]),  # pair_a
        np.array([random.randint(1, 20), random.randint(1, 20)]),  # pair_b
        random.random(),                                           # threshold
    ]
    for _ in range(n_fraud)
]
    

In [4]:
# The fraud rules drawn above, each as [pair_a, pair_b, threshold].
# Intended condition for a sample to be labelled fraud:
#   both pairs of a rule must be present in the sample,
#   and on the row holding pair_a the numeric value should be bigger than the threshold,
#   while on the row holding pair_b it should be lower.
# NOTE(review): add_1 currently only checks that both pairs are present; the threshold
# part of the condition is not implemented (its value check is commented out).
fraud_list

[[array([2, 9]), array([ 3, 14]), 0.26655381253138355],
 [array([ 2, 13]), array([18, 18]), 0.33219769850967984],
 [array([2, 6]), array([ 5, 11]), 0.5609620089366762],
 [array([8, 6]), array([ 1, 14]), 0.7735879800290875],
 [array([20, 13]), array([3, 1]), 0.3154589737358633],
 [array([15,  4]), array([2, 3]), 0.6671381677849223],
 [array([5, 1]), array([10, 14]), 0.5734080856809272],
 [array([ 9, 16]), array([ 2, 10]), 0.3434621581559646],
 [array([16,  7]), array([20, 17]), 0.5648140371541568],
 [array([11,  1]), array([13, 17]), 0.43438667569188183],
 [array([18, 20]), array([16, 17]), 0.6684702219031471],
 [array([12, 17]), array([2, 6]), 0.7258278416111333],
 [array([ 3, 16]), array([9, 6]), 0.33490052325216046],
 [array([18, 13]), array([ 3, 15]), 0.3876556275275983],
 [array([11,  1]), array([ 7, 19]), 0.08815729492607605],
 [array([12,  1]), array([12,  8]), 0.40571885207293545],
 [array([17,  2]), array([10, 16]), 0.9640591723038818],
 [array([ 8, 18]), array([16,  6]), 0.153

In [5]:
# The labelling rule: does a sample contain both pairs of one fraud rule?
def add_1(a,b,v,rand_list):
    """Return True iff both pairs ``a`` and ``b`` occur in ``rand_list``.

    Args:
        a, b: length-2 arrays, the two (x, y) pairs of one fraud rule.
        v: the rule's numeric threshold -- accepted but not used yet.
        rand_list: 2-D array whose first two columns hold the (x, y) pairs of a sample.

    The pairs may sit on any rows, in any order; if ``a`` equals ``b`` a single
    matching row satisfies both.
    """
    pairs = rand_list[:, 0:2]
    # A row matches a pair only when both of its coordinates agree.
    has_a = (pairs == a).all(axis=1).any()
    has_b = (pairs == b).all(axis=1).any()
    return bool(has_a and has_b)

In [6]:
# Here we generate the synthetic data set:
#   all_data   - one (n_rows, 3) array per sample: x, y (ints 1..20) and a uniform value
#   all_label  - 1 if at least one fraud rule fires on the sample, else 0
#   all_reason - the fraud rules that fired (lets us check whether an explanation is correct)
N_SAMPLES = 100000

all_data = []
all_label = []
all_reason = []

for sample_idx in range(N_SAMPLES):
    # The row loop uses its own name; it previously reused (shadowed) the outer `j`.
    # The random draws happen in exactly the same order as before.
    n_rows = random.randint(1, 100)
    rand_list = np.array([
        (random.randint(1, 20), random.randint(1, 20), random.random())
        for _ in range(n_rows)
    ])

    # Every rule that fires on this sample is a "reason" for its fraud label.
    reason_list = [
        rule for rule in fraud_list
        if add_1(a=rule[0], b=rule[1], v=rule[2], rand_list=rand_list)
    ]

    all_label.append(1 if reason_list else 0)
    all_data.append(rand_list)
    all_reason.append(reason_list)

In [7]:
# Fraction of the generated samples held out (from the end of the list) for validation.
test_share = 0.1

In [8]:
# Index of the first validation sample; every sample before it is used for training.
train_split_elem = int(len(all_data) * (1 - test_share))

In [9]:
# Sequential (unshuffled) split into training and validation parts:
# samples before train_split_elem train the model, the remainder validate it.
train_X, val_X = all_data[:train_split_elem], all_data[train_split_elem:]
train_y, val_y = all_label[:train_split_elem], all_label[train_split_elem:]
train_r, val_r = all_reason[:train_split_elem], all_reason[train_split_elem:]

In [10]:
# Share of positive (fraud) labels in the training split.
np.asarray(train_y).mean()

0.2667888888888889

In [11]:
# Share of positive (fraud) labels in the validation split.
np.asarray(val_y).mean()

0.2664

In [12]:
class dummy_data(Dataset):
    """Map-style dataset over three parallel sequences.

    Item ``i`` is the tuple ``(input[i], output[i], reason[i])``.
    """

    def __init__(self, input,output,reason):
        # Stored as data / label / reason.
        self.data, self.label, self.reason = input, output, reason

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.label[index], self.reason[index]

In [13]:
train_ds=dummy_data(train_X,train_y,train_r)
val_ds=dummy_data(val_X,val_y,val_r)

In [14]:
batch_size = 128  # samples per mini-batch, for both training and validation

In [15]:
# The collate function defines how samples get batched.
# Here we use simple dynamic padding: each batch is zero-padded to its own longest sample.
def my_collate(batch, device="cuda"):
    """Pad and stack a list of ``(rows, label, reasons)`` samples.

    Args:
        batch: list of samples as returned by ``dummy_data.__getitem__``; ``rows`` is an
            (n_rows, 3) array with columns id_1, id_2 and the numeric value.
        device: where to place the tensors; defaults to ``"cuda"`` (the original behaviour).

    Returns:
        ``(x1, x2, x3, labels, reasons)``: x1/x2/x3 are float64 tensors of shape
        (batch, maxlen) with columns 0/1/2, zero-padded; ``labels`` is a tensor of the
        labels; ``reasons`` is a 1-D object array holding each sample's reason list.
    """
    texts = [x[0] for x in batch]
    labels = np.array([x[1] for x in batch])

    # Reason lists have different lengths per sample, so np.array() on them is ragged
    # and raises ValueError on NumPy >= 1.24. Fill an object array element by element.
    reason = np.empty(len(batch), dtype=object)
    for k, sample in enumerate(batch):
        reason[k] = sample[2]

    batch_size = len(batch)
    maxlen = max(len(x) for x in texts)
    text_stack = np.zeros(shape=(batch_size, maxlen, 3))
    for enu, txt in enumerate(texts):
        text_stack[enu, 0:len(txt), :] = txt

    def _to_device(arr):
        return torch.tensor(arr).to(device)

    return (
        _to_device(text_stack[:, :, 0]),
        _to_device(text_stack[:, :, 1]),
        _to_device(text_stack[:, :, 2]),
        _to_device(labels),
        reason,
    )

In [16]:
# Validation batches in a fixed order (no shuffling), padded by my_collate.
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=my_collate)

In [17]:
# Training batches, reshuffled every epoch and padded by my_collate.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=my_collate)

In [18]:
class simple_fraud_model(nn.Module):
    """Classify a padded sequence of (id_1, id_2, value) rows as fraud / not fraud.

    Pipeline:
    - embed id_1 and id_2 (21 ids each; my_collate pads with id 0);
    - append the raw value, so every row has ``d_model`` features;
    - encode the rows with ``EncoderLayer`` (``SelfA=True``) or ``FeedForward`` (``SelfA=False``);
    - aggregate them with ``multi_attention``, which also returns the aggregation weights
      used for explanations;
    - map the result through a two-layer ReLU MLP to a sigmoid probability.

    Args:
        d_model: number of features per row; the embedding sizes always add up to it,
            also for odd values.
        heads: number of heads for ``EncoderLayer`` (only used when ``SelfA``).
        nlay: currently unused.
        dropout: dropout passed to ``EncoderLayer`` (only used when ``SelfA``).
        SelfA: choose the self-attention encoder (True) or the feed-forward one (False).
        return_w: default for ``forward``'s ``return_w`` when the caller omits it.
    """

    def __init__(self, d_model=32,heads=1,nlay=1,dropout=0,SelfA=True,return_w=False):
        super().__init__()
        self.return_w=return_w

        # Split d_model between the two embeddings; the second gives up one unit for the
        # raw numeric value. Taking d_model - emb_d_1 (rather than a second d_model/2)
        # keeps the concatenation exactly d_model wide for odd d_model as well.
        emb_d_1=int(d_model/2)
        emb_d_2=d_model-emb_d_1

        # Submodules are created in the original order so parameter initialisation
        # consumes the RNG identically.
        self.embedding_1=nn.Embedding(num_embeddings=21,embedding_dim=emb_d_1)
        self.embedding_2=nn.Embedding(num_embeddings=21,embedding_dim=emb_d_2-1)

        if SelfA:
            self.encoder_layers=EncoderLayer(d_model=d_model,heads=heads,dropout=dropout,share_params=True)
        else:
            self.encoder_layers=FeedForward(d_model)

        # Attention aggregation over the rows; returns (aggregate, weights).
        self.mula=multi_attention(input_dim=d_model,key_dim=d_model,nheads=1,return_weights=True,value_dim=d_model)

        self.fully_con=nn.Linear(d_model,d_model*4)
        self.relu=nn.ReLU()

        self.fully_con_1=nn.Linear(d_model*4,d_model*4)
        self.relu_1=nn.ReLU()

        self.final_fully_con=nn.Linear(d_model*4,1)
        self.sig=nn.Sigmoid()
        self.selfa=SelfA

    def forward(self, x1,x2,x3,return_w=None):
        """Return fraud probabilities, plus the aggregation weights if ``return_w``.

        ``x1``/``x2`` are integer id tensors and ``x3`` the numeric values, presumably all
        of shape (batch, seq_len) as produced by my_collate -- TODO confirm.
        ``return_w`` defaults to the value given to the constructor.
        """
        if return_w is None:
            return_w = self.return_w

        e1=self.embedding_1(x1)
        e2=self.embedding_2(x2)
        cat=torch.cat([e1,e2,x3.unsqueeze(2)],dim=2)

        if self.selfa:
            # The encoder's own attention weights are not used further.
            feat,_enc_weights=self.encoder_layers(cat)
        else:
            feat=self.encoder_layers(cat)

        ag,weights=self.mula(feat)
        # NOTE(review): squeeze() also drops the batch dimension for a batch of size 1.
        fc=self.relu(self.fully_con(ag.squeeze()))
        fc=self.relu_1(self.fully_con_1(fc))
        preds=self.sig(self.final_fully_con(fc))

        if return_w:
            # Only the aggregation weights are returned, for both encoders. Combining them
            # with the encoder weights (e.g. torch.bmm(w2.squeeze(), weights)) is not done.
            return preds.squeeze(),weights
        return preds.squeeze()

In [19]:
#A train and evaluation function
def train_eval(atnm,train,opti,crit,eval_metrics,iterator,n_iter,writer):
    '''Run one pass over ``iterator``, training or evaluating ``atnm``.

    Args:
    atnm: the model to train/evaluate; called as ``atnm(x1, x2, x3, True)`` and expected
        to return ``(predictions, weights)``.
    train: if truthy, back-propagate and step ``opti`` after every batch; otherwise only
        evaluate, with gradient tracking disabled.
    opti: the optimizer (only used when training).
    crit: the loss function.
    eval_metrics: if truthy, keep all predictions/labels of the pass and compute the ROC AUC.
    iterator: iterable of batches as produced by ``my_collate``.
    n_iter: global step at the start of the pass (tensorboard x-axis); incremented per
        training batch.
    writer: the tensorboard writer used to log losses (and the AUC).

    Returns:
    The ROC AUC of the pass if ``eval_metrics`` is truthy, otherwise None.
    '''
    if train:
        name1, name2 = "train_loss", "train_roc"
        atnm.train()
    else:
        loss_val = []
        name1, name2 = "val_loss", "val_roc"
        atnm.eval()

    if eval_metrics:
        store_label = []
        store_preds = []

    # Only build autograd graphs when we back-propagate; evaluation used to keep a
    # graph alive for every batch.
    with torch.set_grad_enabled(bool(train)):
        for batch in iterator:
            predictions, _ = atnm(batch[0].long(), batch[1].long(), batch[2].float(), True)
            # Labels follow the predictions' device (was a hard-coded .cuda()).
            loss = crit(predictions, batch[3].float().to(predictions.device))

            if train:
                opti.zero_grad()
                loss.backward()
                opti.step()
                n_iter = n_iter + 1
                writer.add_scalar(name1, loss.item(), n_iter)
            else:
                # Validation loss is averaged over the pass and written once at the end.
                loss_val.append(loss.item())

            if eval_metrics:
                store_preds.append(predictions.detach().cpu().numpy())
                store_label.append(batch[3].float().cpu().numpy())

    if not train:
        writer.add_scalar(name1, np.mean(loss_val), n_iter)

    if eval_metrics:
        store_preds = np.concatenate(store_preds)
        store_label = np.concatenate(store_label)
        roc = roc_auc_score(store_label, store_preds)
        writer.add_scalar(name2, roc, n_iter)
        return roc
    

# Explaining predictions

In this loop we show that the model correctly recovers the "reason" for a fraud prediction, namely the presence of the two problematic input elements. Here, Self-Attention generally performs much better.



In [20]:
def explanation_score(sf,it):
    """Collect attention-based explanations for every sample predicted as fraud.

    For each batch from ``it``, ``sf`` is run with ``return_w=True``. For every sample
    whose rounded prediction is 1, the input rows whose (2-decimal rounded) weight
    exceeds 0.5 form the explanation. They are stored together with the sample's true
    reasons.

    Args:
        sf: the model; called as ``sf(i1, i2, x3, True)`` and expected to return
            ``(predictions, weights)``.
        it: iterator over batches as produced by ``my_collate``; ``batch[4]`` holds the
            per-sample reasons.

    Returns:
        ``(error_score, error_list, reason_wrong)``:
        - ``error_list``: one entry per explained sample. If at least one row passed
          the 0.5 threshold, it is a list of (1, 4) arrays (id_1, id_2, value, weight),
          one per attended row. Otherwise it is the empty explanation array.
        - ``reason_wrong``: aligned with ``error_list``. For a non-empty explanation of a
          sample where a rule fired, it holds the first fired rule's two pairs as a
          (2, 2) array. Otherwise it holds the sample's reasons unchanged.
        - ``error_score``: ``len(error_list) / len(reason_list)``.
        If nothing is predicted as fraud, ``(1, [], [])`` is returned.

    NOTE(review): the comparison that would decide whether an explanation is wrong is
    commented out. Every explanation is therefore appended to ``error_list``, and
    ``error_score`` is always 1. The training loop's sensitivity computation iterates
    over ``error_list``, so changing this changes that metric as well.
    NOTE(review): this runs without ``torch.no_grad()``, so autograd graphs are built
    for every batch.
    """
    reason_list=[]
    explanation_list=[]
    # NOTE(review): weight_list is never used.
    weight_list=[]
    sf.eval()
    
    for batch in it:

        #We get all three inputs (two id columns and the numeric value)
        i1=batch[0].long()
        i2=batch[1].long()
        i3=batch[2].float()

        #per-sample reasons, indexed by position in the batch
        r=batch[4]

        #We get both prediction and attention weights 
        predictions,weights=sf(i1,i2,batch[2].float(),True)

        #Round the predicted probabilities to 0/1
        preds_np=np.round(predictions.detach().cpu().numpy())

        #get the attention weights 
        weights_np=weights.detach().cpu().numpy()
        #positions in the batch predicted as fraud
        fraud_elem=np.where(preds_np>0.5)[0]

        #Only if we actually predict something we care about an explanation (we can change that threshold)
        if fraud_elem.size>0:
            
            i1=i1.detach().cpu().numpy()
            i2=i2.detach().cpu().numpy()
            i3=i3.detach().cpu().numpy()

            #Build one explanation per sample predicted as fraud
            for fraud_ele in fraud_elem:

                #get the weights (we round because we want to look at them later)
                #assumes one weight per input row after squeezing -- TODO confirm against multi_attention
                my_weight=np.squeeze(np.round(weights_np[fraud_ele],2))

                #get the element in the input the model put attention on
                atn_elem=np.where(my_weight>0.5)[0]

                #get the actual value so we can check if the actual confidence is a good measure
                atn_value=np.expand_dims(my_weight[np.where(my_weight>0.5)[0]],axis=1)

                #Next we stack up all inputs: one row per position with columns id_1, id_2, value
                in_elem=np.stack([np.squeeze(i1[fraud_ele]),np.squeeze(i2[fraud_ele]),np.squeeze(i3[fraud_ele])],axis=1)

                #we get the input elements our model has picked out 
                in_elem_j=in_elem[atn_elem]
            
                #then we append the attention weight as a 4th column
                in_elem_j=np.concatenate([in_elem_j,atn_value],axis=1)

                #in the end we save everything in a list for evaluation
                explanation_list.append(in_elem_j)
                reason_list.append(r[fraud_ele])
                
                
    if (len(reason_list))>0 and (len(explanation_list))>0:
        error_list=[]
        reason_wrong=[]

        for explanation,reason in zip(explanation_list,reason_list):
            e=explanation
            r=reason
            if len(e)>0:
                #one (1, 4) array per attended row
                splits=np.split(e[:,0:4],axis=0,indices_or_sections=len(e))

                if len(r)>0:
                    #unique attended (id_1, id_2) pairs -- currently only used by the disabled check below
                    e=np.array([j[0][0:2] for j in splits])
                    #new_array = [tuple(row) for row in e]
                    #e = np.unique(new_array)

                    e=np.unique(e,axis=0)

                    #only the first fired rule is compared: its two pairs as a (2, 2) array
                    r=np.array(r[0][0:2])

                    #NOTE(review): with this check disabled, every explanation counts as an error
                    #if not np.array_equal(np.sort(r.flatten()),np.sort(e.flatten())):
                    error_list.append(splits)
                    reason_wrong.append(r)
                else:
                    #predicted fraud although no rule fired
                    error_list.append(splits)
                    reason_wrong.append(r)
            else:
                #no row received enough weight to be part of the explanation
                error_list.append(e)
                reason_wrong.append(r)        
                    #print(reason_wrong)
        #print(reason_list)
        #so we define it as an error if we did not return the exact two conditions that lead to a "fraud" in a particular claim 
        error_score=len(error_list)/len(reason_list)
        
        return error_score,error_list,reason_wrong
    else:
        return 1,[],[]
    

In [21]:
# Feed-forward encoder variant (SelfA=False) with 32 features per row, on the GPU.
sf = simple_fraud_model(d_model=32, SelfA=False).cuda()

In [22]:
# Binary cross-entropy on the model's sigmoid outputs; Adam with a small learning rate.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(params=sf.parameters(), lr=5e-5)

In [23]:
# TensorBoard logs for the feed-forward ("pff") run.
pff_log_dir = "logs/pff"
writer = SummaryWriter(log_dir=pff_log_dir)

In [24]:
max_epochs = 200  # upper bound on training epochs; the run below was stopped manually


In [25]:
def _explanation_sensitivity(error_list, reason_wrong):
    """Fraction of explained predictions whose attended rows cover every reason pair.

    Args:
        error_list, reason_wrong: the aligned lists returned by ``explanation_score``.

    Returns:
        ``(sensitivity, len_list, incorrect_predictions)``:
        - ``len_list`` holds the unique attended (id_1, id_2) pairs of each scored explanation.
        - ``incorrect_predictions`` holds the ``(explanation, reasons)`` pairs of samples
          for which no fraud rule fired.
        - ``sensitivity`` is NaN when no explanation could be scored.
    """
    incorrect_predictions = []
    len_list = []
    summ = []
    for err, rea in zip(error_list, reason_wrong):
        if len(err) > 0 and len(rea) > 0:
            attended = [x[0][0:2] for x in err]
            len_list.append(np.unique(np.stack(attended), axis=0))
            # Compare each reason pair with every attended pair: a full match is a row that
            # sums to 2. The pair counts as recovered if any attended row matches (this also
            # handles inputs that appear several times). All pairs must be recovered.
            comp_lists = [m == attended for m in rea]
            all_inputs_appear = np.sum([any(np.sum(x, axis=1) == 2) for x in comp_lists]) == len(comp_lists)
            summ.append(all_inputs_appear)
        if len(err) == 0:
            # No row received enough weight (happens at the start of training).
            summ.append(0)
        if len(rea) == 0:
            incorrect_predictions.append((err, rea))
    sensitivity = np.sum(summ) / len(summ) if summ else float("nan")
    return sensitivity, len_list, incorrect_predictions


sensitivity_list = []
totally_correct_list = []
auc_list = []

for j in range(max_epochs):
    # Global step at the start of this epoch: one step per training batch.
    # len(train_dl) counts the final partial batch; the old len/batch_size was a float.
    n_iter = len(train_dl) * j

    roc_t = train_eval(sf, True, optimizer, criterion, True,
                       iter(train_dl), n_iter=n_iter, writer=writer)

    roc_v = train_eval(atnm=sf,
                       train=False,
                       opti=optimizer,
                       crit=criterion,
                       eval_metrics=True,
                       iterator=iter(val_dl),
                       writer=writer,
                       n_iter=n_iter)

    # Explanations need no gradients.
    # NOTE(review): explanation_score's first value is always 1 (its correctness check is
    # commented out), so the "totally_correct" curve is constant.
    with torch.no_grad():
        eval_v, error_list, reason_wrong = explanation_score(sf, it=iter(val_dl))

    # Reset every epoch so len_list never leaks over from a previous epoch (or is undefined).
    len_list = []
    incorrect_predictions = []
    if len(error_list) > 0 and len(reason_wrong) > 0:
        sensitivity, len_list, incorrect_predictions = _explanation_sensitivity(error_list, reason_wrong)
    else:
        sensitivity = 0

    auc_list.append(roc_v)
    totally_correct_list.append(eval_v)
    sensitivity_list.append(sensitivity)

    writer.add_scalar("auc", roc_v, n_iter)
    writer.add_scalar("totally_correct", eval_v, n_iter)
    writer.add_scalar("sensitivtiy", sensitivity, n_iter)

    print(roc_v)
    print(eval_v)
    print(sensitivity)
    # Mean number of distinct attended pairs per explanation (NaN if none were scored).
    print(np.mean([len(x) for x in len_list]) if len_list else float("nan"))
    

0.7946640410857968
1.0
0.5461465271170314
53.97621313035204
0.7967356157957304
1.0
0.6482558139534884
61.87063953488372
0.799437924497562
1.0
0.8029045643153527
67.03941908713693
0.8025201626108115
1.0
0.7929901423877328
64.08652792990142
0.8054224651314347
1.0
0.8195329087048833
58.25053078556263
0.8083457213347479
1.0
0.762970498474059
49.38250254323499
0.8127887719371498
1.0
0.7372188139059305
40.53169734151329
0.8175473558345696
1.0
0.6185682326621924
32.115212527964204
0.8224563764282276
1.0
0.584703947368421
23.06578947368421
0.8287392320073619
1.0
0.5867346938775511
17.008163265306123
0.8341078520587109
1.0
0.5909090909090909
12.984848484848484
0.8390099392604163
1.0
0.640807651434644
12.075451647183847
0.8430949351750878
1.0
0.6409978308026031
10.582429501084599
0.8459613426812853
1.0
0.6410496719775071
9.082474226804123
0.8494498622122668
1.0
0.6483300589390962
9.003929273084479
0.8515582273931511
1.0
0.6838021338506305
8.881668283220174
0.8530185378944922
1.0
0.70018621973929

KeyboardInterrupt: 

2.0


In [51]:
# NOTE(review): the stored output shows a single fraud rule, although the generation cell
# uses n_fraud=20, so this cell was run against different kernel state (note the
# out-of-order execution counts). Restart & Run All to get consistent outputs.
fraud_list

[[array([18, 13]), array([6, 1]), 0.6842163988189459]]

In [65]:
# Unique (id_1, id_2) pairs attended to in one explanation from the most recent epoch.
# NOTE(review): hard-coded index into state left over by the training loop; fails if fewer
# than 113 explanations were scored.
len_list[112]

array([[ 6.,  1.],
       [18., 13.]])

In [61]:
# Explanations from the most recent epoch: per explained sample, the attended rows as
# (id_1, id_2, value, attention weight).
error_list

[[array([[18.        , 13.        ,  0.57301956,  0.95999998]]),
  array([[6.        , 1.        , 0.46191886, 0.97000003]])],
 [array([[6.        , 1.        , 0.09993429, 0.97000003]]),
  array([[18.        , 13.        ,  0.53477967,  0.95999998]])],
 [array([[6.        , 1.        , 0.38609537, 0.97000003]]),
  array([[18.        , 13.        ,  0.39514545,  0.95999998]]),
  array([[18.        , 13.        ,  0.09564076,  0.94999999]])],
 [array([[6.        , 1.        , 0.01215715, 0.97000003]]),
  array([[18.        , 13.        ,  0.68229145,  0.94999999]])],
 [array([[18.        , 13.        ,  0.71779233,  0.94999999]]),
  array([[6.        , 1.        , 0.47639892, 0.97000003]])],
 [array([[18.        , 13.        ,  0.85839802,  0.94999999]]),
  array([[6.        , 1.        , 0.12299316, 0.97000003]])],
 [array([[18.        , 13.        ,  0.16897102,  0.94999999]]),
  array([[6.        , 1.        , 0.02462346, 0.97000003]])],
 [array([[18.        , 13.        ,  0.6295735