import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
import string
import math
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix, accuracy_score, roc_auc_score, roc_curve
from sklearn.model_selection import GridSearchCV
%matplotlib inline
import torch
import torch.nn as nn # Contains all the functions we need to to train our network
import torch.nn.functional as F # Contains some additional functions such as activations
from torch.autograd import Variable

In [3]:
def get_positional_embeddings(sequence_length, d):
    """Return the sinusoidal positional-encoding table of "Attention Is All You Need".

    Row ``pos`` (0-based token index, i.e. ``pos`` in the paper) and column ``j``
    (feature index; ``j = 2i`` / ``2i + 1`` in the paper) hold::

        sin(pos / 10000 ** (j / d))          for even j
        cos(pos / 10000 ** ((j - 1) / d))    for odd j

    so each (sin, cos) column pair shares one frequency.

    Args:
        sequence_length: number of positions (rows), indexed 0 .. sequence_length - 1.
        d: embedding dimension (columns).

    Returns:
        Tensor of shape ``(sequence_length, d)`` in torch's default float dtype.
    """
    # Vectorised form of the per-element loop. Computed in float64 (as the
    # scalar math.sin / math.cos path was) and cast once at the end.
    positions = torch.arange(sequence_length, dtype=torch.float64).unsqueeze(1)
    cols = torch.arange(d)
    # Even column j and odd column j + 1 share the exponent j / d.
    exponents = (cols - cols % 2).to(torch.float64) / d
    angles = positions / torch.pow(10000.0, exponents)  # (sequence_length, d)
    table = torch.where(cols % 2 == 0, torch.sin(angles), torch.cos(angles))
    return table.to(torch.get_default_dtype())

In [4]:
class Transformer(nn.Module):
    """Encode one token sequence into a single output vector.

    Pipeline: add sinusoidal positional embeddings (width ``input_dim``) to the
    input features, project to ``hidden_d`` with a linear layer, run ``n_blocks``
    ``MyViTBlock`` layers, mean-pool over the tokens, and apply a final linear
    head producing raw logits.

    Args:
        input_dim: feature size of each input token.
        hidden_d: model width inside the attention blocks.
        out_d: size of the output vector.
        n_heads: attention heads per block.
        n_blocks: number of stacked ``MyViTBlock`` layers.
    """

    def __init__(self, input_dim, hidden_d, out_d, n_heads, n_blocks):
        super().__init__()
        self.linear = nn.Linear(input_dim, hidden_d)
        self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])
        # Plain linear head, no softmax: the logits are meant for a cross-entropy loss.
        self.mlp = nn.Linear(hidden_d, out_d)

    def forward(self, sentence):
        """Encode ``sentence[0]`` and return one pooled output vector.

        Args:
            sentence: tensor whose last two dims are (seq_len, input_dim). Only
                ``sentence[0]`` is read, so a leading dim of size 1 is expected,
                e.g. (1, seq_len, input_dim) or (1, 1, seq_len, input_dim).
                NOTE(review): a bare 2-D (seq_len, input_dim) tensor would be
                reduced to its first token by ``sentence[0]``.

        Returns:
            Tensor of shape (1, 1, out_d) for the inputs described above.
        """
        seq_len, feat_d = sentence.shape[-2], sentence.shape[-1]
        # Build the table on the input's own device rather than a global `device`.
        pos = get_positional_embeddings(seq_len, feat_d).to(sentence.device).unsqueeze(0)
        out = sentence[0] + pos            # (1, seq_len, input_dim)
        out = self.linear(out)             # (1, seq_len, hidden_d)
        for block in self.blocks:
            out = block(out)
        out = torch.mean(out, dim=1)       # mean-pool over tokens -> (1, hidden_d)
        return torch.unsqueeze(self.mlp(out), 0)  # (1, 1, out_d)


In [5]:
class MyMSA(nn.Module):
    """Multi-head self-attention with an independent Q/K/V projection per head.

    The ``d`` input features are split into ``n_heads`` contiguous slices of
    ``d // n_heads`` features. Each head attends using only its own slice, and
    the per-head outputs are concatenated back to width ``d`` (there is no
    output projection).

    Args:
        d: token feature size (hidden dim); must be divisible by ``n_heads``.
        n_heads: number of attention heads.
    """

    def __init__(self, d, n_heads=2):
        super().__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        d_head = d // n_heads  # number of features each head sees
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, sequences):
        """Apply self-attention to every sequence in the batch at once.

        Args:
            sequences: tensor of shape (N, seq_length, d).

        Returns:
            Tensor of shape (N, seq_length, d): per-head attention outputs
            concatenated along the feature axis.
        """
        # Scaled dot-product attention divides by sqrt(d_k), the per-head key
        # size (previously sqrt(d), the full model width, was used).
        scale = math.sqrt(self.d_head)
        heads = zip(self.q_mappings, self.k_mappings, self.v_mappings)
        head_outputs = []
        for head, (q_mapping, k_mapping, v_mapping) in enumerate(heads):
            # Each head only looks at its own contiguous slice of the features.
            seg = sequences[..., head * self.d_head:(head + 1) * self.d_head]
            q, k, v = q_mapping(seg), k_mapping(seg), v_mapping(seg)
            attention = self.softmax(q @ k.transpose(-2, -1) / scale)  # (N, seq, seq)
            head_outputs.append(attention @ v)  # attention-weighted sum of values
        return torch.cat(head_outputs, dim=-1)


In [6]:
class MyViTBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Computes ``x + MSA(LayerNorm(x))``, then adds ``MLP(LayerNorm(.))`` of that
    result as a second residual. The MLP expands to ``mlp_ratio * hidden_d``
    features with a GELU in between, then projects back to ``hidden_d``.

    Args:
        hidden_d: token feature size; input and output width of the block.
        n_heads: number of heads for the ``MyMSA`` attention layer.
        mlp_ratio: expansion factor of the MLP's inner layer relative to ``hidden_d``.
    """

    def __init__(self, hidden_d, n_heads, mlp_ratio=4):
        super().__init__()
        self.hidden_d = hidden_d
        self.n_heads = n_heads
        inner_d = mlp_ratio * hidden_d  # width of the MLP's expanded layer

        # Creation order (norm1, mhsa, norm2, mlp) fixes parameter order and init RNG use.
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMSA(hidden_d, n_heads)
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, inner_d),
            nn.GELU(),
            nn.Linear(inner_d, hidden_d),
        )

    def forward(self, x):
        """Apply the attention sub-layer and then the MLP sub-layer, each with a residual."""
        attended = x + self.mhsa(self.norm1(x))
        return attended + self.mlp(self.norm2(attended))

In [7]:
class DownsamplingTransformer(nn.Module):
    """Two-level (hierarchical) transformer over a batch of padded sentences.

    ``layer1`` encodes each sentence's token embeddings into one ``hidden_d``
    vector. ``layer2`` then treats those sentence vectors as a sequence and
    encodes them into a single ``out_d`` output.

    Args:
        input_dim: feature size of each token embedding.
        hidden_d: width of both transformers and of the per-sentence vectors.
        out_d: size of the final output.
        n_heads: attention heads per block.
        n_blocks: number of blocks in each of the two transformers.
    """

    def __init__(self, input_dim, hidden_d, out_d, n_heads, n_blocks):
        # Previously __init__ was defined twice, and the effective one called
        # super(tiered_transformer, self) -- an undefined name -- so construction
        # always raised NameError.
        super().__init__()
        self.layer1 = Transformer(input_dim, hidden_d, hidden_d, n_heads, n_blocks)
        self.layer2 = Transformer(hidden_d, hidden_d, out_d, n_heads, n_blocks)

    def forward(self, x):
        """Encode every sentence with ``layer1``, then the sentence sequence with ``layer2``.

        Args:
            x: tensor of shape (1, N, padded_len, input_dim): N sentences, each
                padded to padded_len tokens. Only ``x[0]`` is used.
                Assumes padding rows are all zeros and come after the real tokens
                -- TODO confirm against the data pipeline.

        Returns:
            The output of ``layer2`` for the sequence of sentence vectors.

        Raises:
            ValueError: if every sentence is entirely padding.
        """
        sentences = x[0]
        # A token is real if any feature is non-zero (previously only feature 0 was checked).
        lengths = (sentences != 0).any(dim=-1).sum(dim=1)
        encoded = []
        for sentence, length in zip(sentences, lengths.tolist()):
            if length == 0:
                # Fully padded slot: encoding it would mean-pool zero tokens -> NaN.
                continue
            # Keep a leading dim of size 1: Transformer.forward reads sentence[0],
            # so a bare (length, input_dim) tensor would collapse to its first token.
            encoded.append(self.layer1(sentence[:length].unsqueeze(0)))
        if not encoded:
            raise ValueError("DownsamplingTransformer received no non-empty sentences")
        # Each layer1 output is (1, 1, hidden_d); join along the sequence axis so
        # layer2 sees every sentence (concatenating on dim 0 made it see only the first).
        sentence_vectors = torch.cat(encoded, dim=1)
        return self.layer2(sentence_vectors)