In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
import string
import math
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score, roc_curve
from sklearn.model_selection import GridSearchCV
%matplotlib inline
import torch
import torch.nn as nn # Contains all the functions we need to to train our network
import torch.nn.functional as F # Contains some additional functions such as activations
from torch.autograd import Variable

In [None]:
def get_positional_embeddings(sequence_length, d):
    """Return the sinusoidal positional-encoding table of "Attention Is All You Need".

    Column ``col`` of row ``pos`` (token index, counted from 0) holds
    ``sin(pos / 10000**(col / d))`` when ``col`` is even and
    ``cos(pos / 10000**((col - 1) / d))`` when ``col`` is odd, so every
    (even, odd) column pair shares one frequency. ``col`` plays the role of
    ``2i`` / ``2i + 1`` in the paper, and ``pos`` is the paper's ``pos``.

    Args:
        sequence_length: number of token positions (rows of the table).
        d: encoding width (columns of the table).

    Returns:
        Tensor of shape ``(sequence_length, d)`` in torch's default float dtype.
    """
    def _angle_value(pos, col):
        # Odd columns reuse the frequency of the even column just before them.
        if col % 2 == 1:
            return math.cos(pos / 10000 ** ((col - 1) / d))
        return math.sin(pos / 10000 ** (col / d))

    table = torch.ones(sequence_length, d)
    for pos in range(sequence_length):
        table[pos] = torch.tensor([_angle_value(pos, col) for col in range(d)])
    return table

In [None]:
class Selective_Transformer(nn.Module):
    """Transformer encoder that appends sinusoidal positional features to every token.

    Each token embedding (``input_dim`` wide) is concatenated with a
    ``pos_dim``-wide positional encoding and projected to
    ``hidden_d + pos_dim`` features. The result passes through ``n_blocks``
    ``MyViTBlock`` layers, is mean-pooled over the token axis, and is mapped
    to ``out_d`` outputs.

    Args:
        input_dim: width of each input token embedding.
        hidden_d: width of the value part of the hidden state; the full hidden
            width is ``hidden_d + pos_dim``.
        out_d: number of outputs of the final linear layer (raw logits, no
            softmax, intended for a cross-entropy loss).
        n_heads: attention heads per block.
        n_blocks: number of ``MyViTBlock`` layers.
        pos_dim: width of the positional encoding.
        attends: attention mode forwarded to every block.
    """

    def __init__(self, input_dim, hidden_d, out_d, n_heads, n_blocks, pos_dim, attends):
        super(Selective_Transformer, self).__init__()
        self.pos_dim = pos_dim
        self.attends = attends
        self.hidden_d = hidden_d + pos_dim

        self.linear = nn.Linear(input_dim + pos_dim, self.hidden_d)

        # Keyword arguments so the call does not depend on where MyViTBlock
        # places its optional mlp_ratio parameter.
        self.blocks = nn.ModuleList(
            [MyViTBlock(self.hidden_d, n_heads, attends=attends, pos_dim=pos_dim)
             for _ in range(n_blocks)]
        )

        self.mlp = nn.Linear(self.hidden_d, out_d)  # raw logits for CELoss

    def forward(self, sentence):
        """Compute outputs for ``sentence``.

        Only ``sentence[0]`` is used. The positional table is sized from
        ``sentence.shape[-2]``, so the token axis is second-to-last and the
        features are last. NOTE(review): presumably ``sentence`` is
        ``(1, seq_len, input_dim)``; confirm against the caller. If
        ``sentence[0]`` is 2-D, it is treated as a batch of one sequence.

        Returns:
            ``torch.unsqueeze(self.mlp(pooled), 0)``, where ``pooled`` is the
            mean of the final block output over the token axis.
        """
        seq_len = sentence.shape[-2]
        tokens = sentence[0]
        if tokens.dim() == 2:
            # (seq_len, input_dim) -> (1, seq_len, input_dim): blocks iterate over a batch axis.
            tokens = tokens.unsqueeze(0)

        # Build the table on the input's device/dtype instead of a global `device`.
        pos = get_positional_embeddings(seq_len, self.pos_dim).to(
            device=tokens.device, dtype=tokens.dtype
        )
        # Share the same table across the batch, then append it to the features:
        # (B, seq_len, input_dim + pos_dim).
        pos = pos.unsqueeze(0).expand(tokens.shape[0], -1, -1)
        out = torch.cat((tokens, pos), dim=-1)

        out = self.linear(out)  # (B, seq_len, hidden_d + pos_dim)
        for block in self.blocks:
            out = block(out)
        out = torch.mean(out, dim=1)  # mean-pool over tokens -> (B, hidden_d + pos_dim)

        return torch.unsqueeze(self.mlp(out), 0)


In [None]:
class MyMSA(nn.Module):
    """Multi-head self-attention where the caller chooses which features drive attention.

    The hidden vector of width ``d`` is assumed to be laid out as
    ``[value features (d - pos_dim) | positional features (pos_dim)]``.
    This is the split MyViTBlock applies to this module's output.

    Modes (``attends``):
      * ``"val"``: head ``h`` reads the slice ``[h*d_head, (h+1)*d_head)`` of
        the full width ``d`` for queries, keys and values
        (``d_head = d // n_heads``). The output width is ``d``.
      * ``"pos"``: queries and keys come only from the positional section,
        split into ``n_heads`` slices of ``pos_dim // n_heads``. Values come
        only from the value section, split into ``n_heads`` slices of
        ``(d - pos_dim) // n_heads``. The output width is ``d - pos_dim``.

    Attention scores are divided by ``sqrt(self.d)``. ``self.d`` is
    ``d - pos_dim`` in "val" mode and ``pos_dim`` in "pos" mode.
    NOTE(review): the standard transformer scales by the square root of the
    per-head query/key width instead. This module keeps the existing choice.

    Args:
        d: full hidden width (value + positional features).
        n_heads: number of attention heads.
        attends: ``"val"`` or ``"pos"`` (see above).
        pos_dim: width of the trailing positional section.

    Raises:
        ValueError: if ``attends`` is not ``"val"`` or ``"pos"``.
        AssertionError: if the relevant widths are not divisible by ``n_heads``.
    """

    def __init__(self, d, n_heads=2, attends="val", pos_dim=0):  # d is hidden dim
        super(MyMSA, self).__init__()
        if attends not in ("val", "pos"):
            raise ValueError(f"attends must be 'val' or 'pos', got {attends!r}")

        self.val_dim = d - pos_dim  # width of the leading value section
        self.pos_dim = pos_dim
        self.n_heads = n_heads
        self.attends = attends
        self.d = self.val_dim if attends == "val" else pos_dim  # score-scaling width

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"
        if attends == "val":
            qk_head = v_head = d // n_heads
        else:
            assert pos_dim > 0 and pos_dim % n_heads == 0, \
                f"Can't divide positional dimension {pos_dim} into {n_heads} heads"
            assert self.val_dim % n_heads == 0, \
                f"Can't divide value dimension {self.val_dim} into {n_heads} heads"
            qk_head = pos_dim // n_heads
            v_head = self.val_dim // n_heads

        self.d_head = qk_head  # width of each head's query/key slice
        self.v_head = v_head   # width of each head's value slice
        self.q_mappings = nn.ModuleList([nn.Linear(qk_head, qk_head) for _ in range(n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(qk_head, qk_head) for _ in range(n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(v_head, v_head) for _ in range(n_heads)])
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        """Apply per-head attention independently to each sequence in the batch.

        Args:
            sequences: tensor of shape ``(N, seq_length, d)``.

        Returns:
            Tensor of shape ``(N, seq_length, d)`` in "val" mode, or
            ``(N, seq_length, d - pos_dim)`` in "pos" mode. Head outputs are
            concatenated along the feature axis.
        """
        # "pos" mode: queries/keys read the positional section, which starts
        # right after the value section. Values always start at column 0.
        qk_offset = self.val_dim if self.attends == "pos" else 0

        result = []
        for sequence in sequences:
            seq_result = []
            for head in range(self.n_heads):
                qk_start = qk_offset + head * self.d_head
                v_start = head * self.v_head
                qk_in = sequence[:, qk_start:qk_start + self.d_head]
                v_in = sequence[:, v_start:v_start + self.v_head]

                q = self.q_mappings[head](qk_in)
                k = self.k_mappings[head](qk_in)
                v = self.v_mappings[head](v_in)

                # Dot-product scores between q and k, then weighted sum over v.
                attention = self.softmax(q @ k.T / math.sqrt(self.d))
                seq_result.append(attention @ v)
            result.append(torch.hstack(seq_result))
        return torch.stack(result)


In [None]:
class MyViTBlock(nn.Module):
    """Pre-norm transformer block whose residual updates touch only the value features.

    The hidden vector of width ``hidden_d`` is split as
    ``[value features (hidden_d - pos_dim) | positional features (pos_dim)]``.
    Both residual branches (self-attention and MLP) add their output to the
    value features only. The positional features are carried through
    unchanged, so neither the attention nor the MLP can update position info.

    Args:
        hidden_d: full hidden width (value + positional features).
        n_heads: number of attention heads passed to ``MyMSA``.
        attends: attention mode passed to ``MyMSA``.
        pos_dim: width of the trailing positional section.
        mlp_ratio: how much wider the MLP hidden layer is than the value width.
    """

    def __init__(self, hidden_d, n_heads, attends, pos_dim, mlp_ratio=4):
        # mlp_ratio comes last so it can keep its default (defaulted parameters
        # must follow required ones). This also binds the positional call
        # MyViTBlock(hidden_d, n_heads, attends, pos_dim) correctly.
        super(MyViTBlock, self).__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads
        self.pos_dim = pos_dim
        val_d = hidden_d - pos_dim

        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads, attends, pos_dim)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(val_d, mlp_ratio * val_d),
            nn.GELU(),
            nn.Linear(mlp_ratio * val_d, val_d),
        )

    def forward(self, x):
        """Apply the attention and MLP residual branches to the value features of ``x``.

        Args:
            x: tensor of shape ``(N, seq_length, hidden_d)``.

        Returns:
            Tensor of the same shape. The last ``pos_dim`` features equal those of ``x``.
        """
        val_d = self.hidden_d - self.pos_dim
        x_val, x_pos = x[:, :, :val_d], x[:, :, val_d:]

        # Residual over multi-head self-attention; only its value-width output is used.
        attn = self.mhsa(self.norm1(x))[:, :, :val_d]
        out_val = x_val + attn
        out = torch.cat((out_val, x_pos), dim=2)

        # Residual over the MLP, which sees the normalised full vector's value part.
        out_val = out_val + self.mlp(self.norm2(out)[:, :, :val_d])
        return torch.cat((out_val, x_pos), dim=2)