# Demo of Self-Attention in PyTorch

In this notebook we test whether Self-Attention and Piece-wise Feed-Forward Neural Networks can solve a simple, logical, deterministic problem. Furthermore, we want to assess the degree to which the attention weights can "explain" the predictions. We see that this is the case for Piece-wise Feed-Forward Neural Networks, but it is not the case for Self-Attention Neural Networks. We conjecture that this is because the representation produced by Self-Attention depends on all elements of the input sequence.

In [1]:
import random
import numpy as np
from model_pytorch import *
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset,sampler,DataLoader

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from sklearn.metrics import roc_auc_score


In [2]:
# Generate the synthetic data set: every sample is a variable-length sequence
# (1 to 100 elements) of integer pairs with both components in 1..20. A sample
# is labelled 1 only when both the pair (17, 1) and the pair (8, 19) occur
# somewhere in its sequence, otherwise 0.
all_data = []
all_label = []
for _ in range(100000):
    # The sequence length is drawn first, then each pair in order. The
    # comprehension keeps its own loop variable, so the outer loop index is
    # no longer overwritten by the inner loop.
    rand_list = [
        (random.randint(1, 20), random.randint(1, 20))
        for _ in range(random.randint(1, 100))
    ]
    is_fraud = (17, 1) in rand_list and (8, 19) in rand_list
    all_label.append(int(is_fraud))
    all_data.append(np.array(rand_list))

In [3]:
# Collect the positive ("fraud") sequences. The samples have different
# lengths, so they are placed one by one into an explicit 1-D object array;
# letting np.array() infer a shape from ragged input raises ValueError on
# recent NumPy releases.
_all_data_arr = np.empty(len(all_data), dtype=object)
for _idx, _seq in enumerate(all_data):
    _all_data_arr[_idx] = _seq
ones = _all_data_arr[np.flatnonzero(all_label)]

In [4]:
# Hold out 10% of the samples for validation; the fixed seed makes the split
# reproducible.
train_X, val_X, train_y, val_y = train_test_split(
    all_data,
    all_label,
    test_size=0.1,
    random_state=123,
)


In [5]:
class dummy_data(Dataset):
    """Dataset that pairs each input sample with its label by position.

    ``input`` and ``output`` are stored as given and indexed in parallel.
    """

    def __init__(self, input, output):
        self.data, self.label = input, output

    def __len__(self):
        """Return the number of stored samples."""
        n_samples = len(self.data)
        return int(n_samples)

    def __getitem__(self, index):
        """Return the ``(X, y)`` pair stored at ``index``."""
        return self.data[index], self.label[index]

In [6]:
# Wrap both splits so they can be indexed as (sequence, label) pairs.
train_ds = dummy_data(input=train_X, output=train_y)
val_ds = dummy_data(input=val_X, output=val_y)

In [7]:
# Number of sequences per mini-batch.
batch_size = 32

In [8]:
# The collate function defines how individual samples are merged into a batch.
def my_collate(batch):
    """Zero-pad a list of ``(sequence, label)`` samples into batch tensors.

    Args:
        batch: non-empty list of ``(sequence, label)`` pairs. Each sequence is
            array-like with shape ``(length, 2)``; lengths may differ.

    Returns:
        A tuple ``(first, second, labels)``. ``first`` and ``second`` are
        float64 tensors of shape ``(len(batch), max_length)`` holding the
        first and second component of every pair, padded with zeros on the
        right. ``labels`` is a tensor built from the sample labels. All three
        live on the GPU when CUDA is available, otherwise on the CPU.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if len(batch) == 0:
        raise ValueError("my_collate received an empty batch")

    # Fall back to the CPU instead of failing when no GPU is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    texts = [sample[0] for sample in batch]
    labels = np.array([sample[1] for sample in batch])

    maxlen = max(len(txt) for txt in texts)
    # Shorter sequences keep zeros in their trailing positions.
    text_stack = np.zeros(shape=(len(batch), maxlen, 2))
    for row, txt in enumerate(texts):
        text_stack[row, :len(txt), :] = txt

    first = torch.tensor(text_stack[:, :, 0]).to(device)
    second = torch.tensor(text_stack[:, :, 1]).to(device)
    return first, second, torch.tensor(labels).to(device)

In [9]:
# Validation batches are served in their stored order (no shuffling).
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, collate_fn=my_collate)

In [10]:
# Training batches are reshuffled every time the loader is iterated.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, collate_fn=my_collate)

In [11]:
class simple_fraud_model(nn.Module):
    """Binary sequence classifier over sequences of integer pairs.

    Pipeline: embed both pair components, concatenate the embeddings, run
    them through one encoder block (``EncoderLayer`` or ``FeedForward``),
    aggregate with ``multi_attention``, then apply a two-layer MLP with a
    sigmoid output.

    Args:
        d_model: width of the concatenated pair embedding fed to the encoder.
        heads: number of heads passed to ``EncoderLayer`` (ignored when
            ``SelfA`` is false).
        nlay: accepted for backward compatibility; currently unused.
        dropout: dropout value passed to ``EncoderLayer``.
        SelfA: if true the encoder block is ``EncoderLayer``, otherwise
            ``FeedForward``.
        return_w: default for ``forward``'s ``return_w`` when the caller does
            not pass one.
    """

    def __init__(self, d_model=32, heads=1, nlay=1, dropout=0, SelfA=True, return_w=False):
        super().__init__()
        self.return_w = return_w
        self.selfa = bool(SelfA)

        # Split d_model between the two components so the concatenation is
        # exactly d_model wide, also for odd d_model.
        emb_d_1 = d_model // 2
        emb_d_2 = d_model - emb_d_1

        # 21 rows: the data uses ids 1..20 and the collate step pads with 0.
        self.embedding_1 = nn.Embedding(num_embeddings=21, embedding_dim=emb_d_1)
        self.embedding_2 = nn.Embedding(num_embeddings=21, embedding_dim=emb_d_2)

        if self.selfa:
            self.encoder_layers = EncoderLayer(d_model=d_model, heads=heads, dropout=dropout, share_params=True)
        else:
            self.encoder_layers = FeedForward(d_model)

        # NOTE(review): value_dim is fixed at 32 while fully_con expects
        # d_model inputs -- presumably the two must match; confirm against
        # multi_attention before using d_model != 32.
        self.mula = multi_attention(input_dim=d_model, key_dim=d_model, nheads=1, return_weights=True, value_dim=32)

        self.fully_con = nn.Linear(d_model, d_model * 4)
        self.relu = nn.ReLU()
        self.final_fully_con = nn.Linear(d_model * 4, 1)
        self.sig = nn.Sigmoid()

    def forward(self, x1, x2, return_w=None):
        """Score a batch of pair sequences.

        Args:
            x1: integer tensor with the first component of every pair.
            x2: integer tensor with the second component of every pair.
            return_w: if true also return the weights produced by the
                aggregation layer; ``None`` falls back to the value given to
                the constructor.

        Returns:
            ``preds`` (1-D tensor of sigmoid outputs), or ``(preds, weights)``
            when ``return_w`` is true.
        """
        if return_w is None:
            return_w = self.return_w

        cat = torch.cat([self.embedding_1(x1), self.embedding_2(x2)], dim=2)

        if self.selfa:
            # EncoderLayer also returns its own weights; they are not used.
            feat, _ = self.encoder_layers(cat)
        else:
            feat = self.encoder_layers(cat)

        ag, weights = self.mula(feat)
        fc = self.relu(self.fully_con(ag.squeeze()))
        # Flatten instead of a bare squeeze() so a batch of one stays 1-D.
        # Assumes one score per sample, as the training loss (compared
        # against per-sample labels) requires.
        preds = self.sig(self.final_fully_con(fc)).reshape(-1)

        if return_w:
            return preds, weights
        return preds

In [12]:
def train_eval(atnm,train,opti,crit,eval_metrics,iterator,n_iter,writer):
    """Run one pass over ``iterator``, either training or evaluating ``atnm``.

    Args:
        atnm: model called as ``atnm(x1, x2, False)`` and returning one
            prediction per sample.
        train: if true, run in training mode (backward pass, optimizer step,
            per-batch loss logging); otherwise run in evaluation mode and log
            the mean loss once at the end.
        opti: optimizer stepped after every training batch.
        crit: loss function called as ``crit(predictions, targets)``.
        eval_metrics: if true, collect predictions and labels and log the
            ROC AUC of the whole pass.
        iterator: iterable of ``(x1, x2, labels)`` batches.
        n_iter: global step before this pass; incremented once per training
            batch and used as the step for every logged scalar.
        writer: object exposing ``add_scalar(tag, value, step)``.

    Returns:
        The ROC AUC of the pass when ``eval_metrics`` is true, else ``None``.
    """
    is_training = bool(train)
    collect = bool(eval_metrics)

    if is_training:
        loss_tag, roc_tag = "train_loss", "train_roc"
        atnm.train()
    else:
        loss_tag, roc_tag = "val_loss", "val_roc"
        atnm.eval()

    batch_losses = []
    store_label = []
    store_preds = []

    # Autograd is only needed while training; disabling it during evaluation
    # avoids building graphs that are never used.
    with torch.set_grad_enabled(is_training):
        for batch in iterator:
            predictions = atnm(batch[0].long(), batch[1].long(), False)
            # Keep the targets on the same device as the predictions rather
            # than assuming a GPU.
            targets = batch[2].float().to(predictions.device)
            loss = crit(predictions, targets)

            if is_training:
                opti.zero_grad()
                loss.backward()
                opti.step()
                n_iter = n_iter + 1
                writer.add_scalar(loss_tag, loss.item(), n_iter)
            else:
                # Evaluation losses are averaged and written once after the
                # loop; the step counter is not advanced.
                batch_losses.append(loss.item())

            if collect:
                store_preds.append(predictions.detach().cpu().numpy())
                store_label.append(targets.detach().cpu().numpy())

    if not is_training and batch_losses:
        writer.add_scalar(loss_tag, np.mean(batch_losses), n_iter)

    if not collect:
        return None

    roc = roc_auc_score(np.concatenate(store_label), np.concatenate(store_preds))
    writer.add_scalar(roc_tag, roc, n_iter)
    return roc

In [19]:
# Build the feed-forward variant of the model (SelfA=False) on the GPU,
# together with its binary cross-entropy loss and Adam optimizer.
sf = simple_fraud_model(SelfA=False)
sf = sf.cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(params=sf.parameters(), lr=5e-5)

In [22]:
# Summary writer for this run; scalars are logged under logs/pff_residual.
writer = SummaryWriter(log_dir="logs/pff_residual")

In [23]:
# Number of passes over the training data.
max_epochs = 20

In [24]:
# Train for max_epochs epochs and validate after each one.
# len(train_dl) is the true number of batches per epoch (it includes the
# final partial batch), so the global step is an integer and consecutive
# epochs do not overlap on the step axis.
steps_per_epoch = len(train_dl)
for epoch in range(max_epochs):
    n_iter = steps_per_epoch * epoch
    roc_t = train_eval(
        atnm=sf,
        train=True,
        opti=optimizer,
        crit=criterion,
        eval_metrics=True,
        iterator=iter(train_dl),
        n_iter=n_iter,
        writer=writer,
    )
    # Validation measures the model as it is after this epoch's updates, so
    # it is logged at the end-of-epoch step.
    roc_v = train_eval(
        atnm=sf,
        train=False,
        opti=optimizer,
        crit=criterion,
        eval_metrics=True,
        iterator=iter(val_dl),
        n_iter=n_iter + steps_per_epoch,
        writer=writer,
    )
    print(roc_v)

0.7554393598925176
0.8205884881820161
0.8583355415425753
0.9602899449967341
0.9916014302156436
0.9992824786516028
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0


In [25]:
# Fresh iterator over the validation batches, consumed by the explanation loop.
it=iter(val_dl)


# Explaining predictions

In this loop we show that the model correctly recovers the "reason" for a fraud prediction: the presence of the two problematic input elements. Here Self Attention generally performs much better.



In [26]:
# For every validation batch: predict, check the predicted positives against
# the labels, then print the input pairs whose attention weight exceeds 0.5.
# Gradients are not needed for inspection, so autograd is switched off.
with torch.no_grad():
    for batch in it:
        # Both input components as integer index tensors.
        i1 = batch[0].long()
        i2 = batch[1].long()

        # Predictions together with the aggregation-layer attention weights.
        predictions, weights = sf(i1, i2, True)
        preds_np = np.round(predictions.detach().cpu().numpy())
        weights_np = weights.detach().cpu().numpy()

        # Batch positions predicted as fraud, and those labelled as fraud.
        fraud_elem = np.where(preds_np > 0.5)[0]
        true_elem = np.where(batch[2].cpu().numpy())[0]

        # array_equal also copes with index arrays of different length, and
        # the check runs for every batch so missed positives are caught too.
        if not np.array_equal(fraud_elem, true_elem):
            print("predicted wrong")
            break

        if fraud_elem.size == 0:
            continue

        i1 = i1.detach().cpu().numpy()
        i2 = i2.detach().cpu().numpy()

        for fraud_ele in fraud_elem:
            my_weight = np.squeeze(np.round(weights_np[fraud_ele], 2))
            # Pairs are stacked as (second, first) component, which is the
            # column order of the printed rows.
            in_elem = np.stack([np.squeeze(i2[fraud_ele]), np.squeeze(i1[fraud_ele])], axis=1)

            # Sequence positions that receive more than half of the weight scale.
            atn_elem = np.where(my_weight > 0.5)[0]
            in_elem_j = in_elem[atn_elem]

            print(in_elem_j)


[[19  8]
 [ 1 17]]
[[ 1 17]
 [19  8]]
[[19  8]
 [ 1 17]]
[[19  8]
 [ 1 17]]
[[19  8]
 [ 1 17]]
[[ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]]
[[19  8]
 [19  8]
 [ 1 17]]
[[19  8]
 [ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]]
[[19  8]
 [ 1 17]]
[[19  8]
 [ 1 17]
 [ 1 17]]
[[ 1 17]
 [19  8]]
[[19  8]
 [ 1 17]]
[[19  8]
 [ 1 17]]
[[ 1 17]
 [ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]]
[[19  8]
 [ 1 17]]
[[ 1 17]
 [19  8]]
[[19  8]
 [ 1 17]]
[[ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]]
[[19  8]
 [ 1 17]]
[[19  8]
 [ 1 17]]
[[ 1 17]
 [19  8]]
[[19  8]
 [ 1 17]]
[[ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]]
[[19  8]
 [ 1 17]]
[[ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]
 [19  8]]
[[19  8]
 [ 1 17]]
[[19  8]
 [ 1 17]]
[[ 1 17]
 [19  8]]
[[19  8]
 [ 1 17]
 [19  8]]
[[19  8]
 [ 1 17]]
[[19  8]
 [ 1 17]]
[[ 1 17]
 [19  8]]
[[19  8]
 [ 1 17]]
[[ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]]
[[19  8]
 [ 1 17]]
[[ 1 17]
 [19  8]]
[[ 1 17]
 [19  8]]
[[ 1 17]
 [19  

In [None]:
# Fraud rule for the most recently generated sequence: both pairs must be
# present. Each pair needs its own membership test -- a bare tuple after
# `and` is always truthy.
(17, 1) in rand_list and (8, 19) in rand_list

In [249]:
# Rounded attention weight at position 16 of the last sequence inspected by
# the explanation loop.
my_weight[16]

0.01