## [TODO]

- Dropout everywhere
- Should glove embeddings be updated?
- Dealing with unknown tokens
- You and stocknet-code have included y_T in calculation in ATA, while paper equations have not. Think about it
- Correct y_size to 2 from 1 because two classes *aye*, *no*

# Creating the Model

In [2]:
import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.distributions as ds

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import pandas as pd
import numpy as np

In [4]:
from collections import defaultdict

In [5]:
word_embed_size = 100
embedder_hidden_size = 100
gru_num_layers = 1
window_size = 4

dropout_vmd_in = 0.3
vmd_hidden_size = 150
g_size = 50
y_size = 2

num_epochs = 1

In [6]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.Embedder = Embedder()
        self.VMD = VMD(self)
        self.ATA = ATA(self)
    
    
    def forward(self, X, Y):
        """
            
        """
        
        X = self.Embedder(X)
        self.VMD(X, Y)
        self.ATA(Y)
#         print("X", X.shape)
#         print("kl", self.kl.shape)
#         print("g", self.g.shape)
#         print("y", self.y.shape)
        # result = something
#         return X  
    
    
    def get_z_and_z_distr(self, arg, is_prior):
        mean = nn.Linear(arg.size(-1), vmd_hidden_size)(arg)
        stddev = nn.Linear(arg.size(-1), vmd_hidden_size)(arg)      
        stddev = torch.sqrt(torch.exp(stddev))
        epsilon = torch.randn(vmd_hidden_size)
        
        z = mean if is_prior else mean + torch.mul(epsilon, stddev)
        z_pdf = ds.normal.Normal(loc=mean, scale=stddev)
        return z, z_pdf

In [7]:
class Embedder(nn.Module):
    def __init__(self):
        super(Embedder, self).__init__()
        self.bi_gru = nn.GRU(word_embed_size, embedder_hidden_size, num_layers=gru_num_layers, batch_first=True, bidirectional=True)
        

    def forward(self, X):
        """
            X: window_size (for speaker's history) * max_words_in_a_speech * input_size
            
        """
#         X = get_glove_padded_seq(X)
        h_0 = torch.randn(2*gru_num_layers, window_size, embedder_hidden_size) # multiplied by 2 because bidirectional
        _, h_n = self.bi_gru(X, h_0)
        h_f = h_n[0, :, :] 
        h_b = h_n[1, :, :]
        msg_embed = (h_f + h_b) / 2
        return msg_embed
    

In [8]:
class VMD(nn.Module):
    def __init__(self, model):
        super(VMD, self).__init__()
#         self.input_dropout = nn.Dropout(p=dropout_vmd_in)
        self.model = model    
        self.gru = nn.GRU(embedder_hidden_size, vmd_hidden_size, num_layers=gru_num_layers, bidirectional=False)
        
        
    def forward(self, X, Y):
#         X = self.input_dropout(X)
        h_0 = torch.randn(gru_num_layers, vmd_hidden_size) 
        h_s, _ = self.gru(X, h_0) # h_s: window_size * vmd_hidden_size
        
        z_prior = []
        z_post = []
        kl = []
        z_post_t_minus_1 = torch.randn(vmd_hidden_size)
        for t in range(window_size):
            h_z_prior_t = nn.Linear(embedder_hidden_size+2*vmd_hidden_size, vmd_hidden_size)(
                torch.cat([X[t], h_s[t], z_post_t_minus_1])
            )
            h_z_prior_t = nn.Tanh()(h_z_prior_t)

            h_z_post_t = nn.Linear(embedder_hidden_size+2*vmd_hidden_size+y_size, vmd_hidden_size)(
                torch.cat([X[t], h_s[t], Y[t], z_post_t_minus_1])
            )
            h_z_post_t = nn.Tanh()(h_z_post_t)
            
            z_prior_t, z_prior_t_pdf = self.model.get_z_and_z_distr(h_z_prior_t, is_prior=True)
            z_post_t, z_post_t_pdf = self.model.get_z_and_z_distr(h_z_post_t, is_prior=False)
            z_post_t_minus_1 = z_post_t
            
            kl_t = ds.kl.kl_divergence(z_prior_t_pdf, z_post_t_pdf)
            
            z_prior.append(z_prior_t)
            z_post.append(z_post_t)
            kl.append(kl_t)
        
        z_prior = torch.stack(z_prior) # window_size * vmd_hidden_size
        z_post = torch.stack(z_post) # window_size * vmd_hidden_size
        self.model.kl = torch.stack(kl).sum(dim=1) # window_size
        
        self.model.g = nn.Linear(2*vmd_hidden_size, g_size)(
            torch.cat([h_s, z_post],dim=1) # TODO: check if X is also to be concatenated as in Eqn 21
        )        
        self.model.g = nn.Tanh()(self.model.g) #
        
        y = nn.Linear(g_size, y_size)(self.model.g)
        self.model.y = nn.Softmax(dim=1)(y)
#         print("y:", self.model.y)

        self.model.g_T = self.model.g[-1, :] 
        # TODO: 1. will g_T be different during training and evaluation?
        # TODO: 2. is g_T definitely what you think it is?


In [9]:
class ATA(nn.Module):
    def __init__(self, model):
        super(ATA, self).__init__()
        self.model = model
        self.alpha = 0.5
        
        
    def forward(self, Y):
        linear_i = nn.Linear(g_size, g_size, bias=False)(self.model.g)
        linear_i = nn.Tanh()(linear_i) # (window_size, g_size)
        w_i = nn.init.xavier_normal_(torch.zeros((g_size, 1))) # (g_size, 1)
        v_i = linear_i @ w_i # (window_size, 1)
        
        linear_d = nn.Linear(g_size, g_size, bias=False)(self.model.g)
        linear_d = nn.Tanh()(linear_d)
        g_T = self.model.g_T[:, None]
        v_d = linear_d @ g_T
        
        aux_score = torch.mul(v_i, v_d)
        aux_score[-1, :] = np.NINF
        
        v_starred = nn.Softmax(dim=0)(aux_score)
        self.model.v_starred = torch.where(torch.isnan(v_starred), torch.tensor(0, dtype=torch.float32), v_starred)
        
        att_c = self.model.v_starred.T @ self.model.y
#         print(att_c.shape, self.model.g_T.shape)
        self.model.y_T = nn.Linear(y_size+g_size, y_size)(torch.cat([torch.squeeze(att_c), self.model.g_T]))
        self.model.y_T = nn.Softmax(dim=0)(self.model.y_T)  
#         print("y_T:", self.model.y_T.shape)
        
        self.calculate_loss(Y)
        
        
    def calculate_loss(self, Y):
#         print(self.model.v_starred)
        v_aux = self.alpha * self.model.v_starred
        likelihood_aux = torch.sum(torch.mul(Y, torch.log(self.model.y)), dim=1)
        
        kl_lambda = self.get_kl_lambda()
        obj_aux = likelihood_aux - kl_lambda * self.model.kl
        
        y_T_orig = Y[-1, :]
        likelihood_T = torch.sum(torch.mul(y_T_orig, torch.log(self.model.y_T)))
        
        kl_T = self.model.kl[-1]
        obj_T = likelihood_T - kl_lambda * kl_T
        
        self.model.loss = obj_T + torch.sum(torch.mul(obj_aux, v_aux))
        
        
    def get_kl_lambda(self):
        # TODO: implement KL annealing
        return 0.5

# Prepare train data

In [10]:
data = pd.read_csv("dataset/ParlVote/ParlVote_concat.csv")

In [11]:
# speakers[speaker_id: int] = list of tuples (debate_id, speech, vote)
speakers = defaultdict(lambda: [])
for idx, row in data.iterrows():
    speakers[row["speaker_id"]].append([row["debate_id"], row["speech"], row["vote"]])
        
for k, v in speakers.items():
    speakers[k] = sorted(v)

## Analysis of speaker sequence length

In [12]:
mx = 0
for usid in data.speaker_id.unique():
    mx = max(mx, len(speakers[usid]))
print(mx)

284


In [13]:
speaker_history_lens = []
for sid, each in speakers.items():
    speaker_history_lens.append(len(each)) 
speaker_history_lens.sort()

In [14]:
shl = pd.Series(speaker_history_lens)
shl.describe()

count    1346.000000
mean       24.859584
std        26.420400
min         1.000000
25%         8.000000
50%        16.000000
75%        32.000000
max       284.000000
dtype: float64

In [15]:
(len(list(filter(lambda x: x >= 4, speaker_history_lens))) / len(speaker_history_lens)) * 100

90.04457652303121

- The above cell shows 90% speakers have at least *window_size*=4 speeches
- If we use all data, we need to pad speeches that have (< *window_size*=4) speeches
- Decision: Drop these 10% speakers

In [16]:
speakers = dict(filter(lambda val: len(val[1]) >= 4, speakers.items()))

In [17]:
100 * (len(speakers) / len(speaker_history_lens))

90.04457652303121

## Download Glove-100d

In [18]:
# ! wget -r -nc https://nlp.stanford.edu/data/glove.6B.zip

In [19]:
# !unzip nlp.stanford.edu/data/glove.6B.zip

In [20]:
# ! head -n1 glove.6B.100d.txt

In [21]:
embeds = dict()
for line in open("nlp.stanford.edu/data/glove.6B.100d.txt", "r"):
    line = line.split()
    embeds[line[0]] = torch.tensor(list(map(float, line[1:])), dtype=torch.float32)

In [22]:
# glove_vocab = set()
# for f in open("nlp.stanford.edu/data/glove.6B.100d.txt", "r"):
#     glove_vocab.add(f.split()[0])
# print(len(glove_vocab))

# del glove_vocab

## Preprocess

In [23]:
from nltk.tokenize import TweetTokenizer

In [24]:
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
# import nltk
# nltk.download('punkt')
# nltk.download('stopwords')

In [25]:
def preprocess(s: str):
    
    # Lowercasing
    s = s.lower()
    
    # Tokenization
    ttk = TweetTokenizer()
    tokens = ttk.tokenize(s)
#     tokens = word_tokenize(s)
    
    # Stopwords removal
    tokens = [w for w in tokens if not w in stopwords.words('english')]
    
#     # Stemming
#     ps = nltk.PorterStemmer()
#     tokens = [ps.stem(w) for w in tokens]
    return tokens

In [26]:
# %%timeit
# embed_keys = set(embeds.keys())
# ans = 0
# i = 0
# sids = speakers.keys()
# for sid in sids:
    
#     s = speakers[sid][0][1]
#     pre = preprocess(s)
#     ans += len(set(pre).difference((set(pre).intersection(embed_keys))))
    
#     print("i", i, " done:", ans)
#     i += 1

In [27]:
# embeds['abhorrent']

In [28]:
# input1 = embeds["drank"]
# input2 = embeds["drunk"]
# cos = nn.CosineSimilarity(dim=0)
# output = cos(input1, input2)

## Create PaddedTensor for input

In [29]:
from torch.nn.utils.rnn import pad_sequence

In [30]:
def get_X_Y_padded_tensors(debates):
        """
            X: window_size * max_seq_len * word_embed_size 
            Y: window_size
        """
        X = []
        Y = []
        max_tokens = 0
        for _, speech, vote in debates:
            tokens = preprocess(speech)
#             print("len preprocessed speech:", len(tokens))
            X.append(torch.stack([embeds.get(token, torch.randn(word_embed_size)) for token in tokens]))
            y = torch.zeros(2)
            y[vote] = 1. # y = [0, 1] if vote=1 else [1, 0]
            Y.append(y)
        
        X = pad_sequence(X, batch_first=True)
        return X, torch.stack(Y)

In [31]:
model = Model()

In [32]:
model.Embedder

Embedder(
  (bi_gru): GRU(100, 100, batch_first=True, bidirectional=True)
)

In [33]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [34]:
freq = [len(v) for k, v in speakers.items()]

In [35]:
freq_sum = torch.tensor(freq).sum()

In [36]:
(freq_sum.item())

33181

In [37]:
sum_ = 0
ans = 0
freq_sum = torch.tensor(freq).sum().item()
for i in range(len(freq)-1, -1, -1):
    if sum_ + (freq[i]) > 0.2 * freq_sum:
        break
    sum_ += (freq[i])
    ans = i
ans

837

In [39]:
model.state_dict()

RecursionError: maximum recursion depth exceeded while calling a Python object

In [38]:
%%timeit
for epoch in range(num_epochs):
    scnt = 0
    done_cnt = 0
    loss_batch = 0
    for speaker_num, (speaker_id, debates) in enumerate(speakers.items()):
        for i in range(window_size, len(debates)+1):
            X, Y = get_X_Y_padded_tensors(debates[i-window_size: i]) 
            model(X, Y)
            loss_batch += model.loss
            
            optimizer.zero_grad()
            model.loss.backward()
            optimizer.step()
            
            done_cnt += 1
            if done_cnt%5==0:
                print("Speaker num:", speaker_num, "\t", "done:", done_cnt, "\t", "perc:", 100*(done_cnt/(0.8*freq_sum)), "\tloss:", loss_batch, "\n"+"-"*50+"\n")
                loss_batch = 0
                
            if done_cnt%5 == 0:
                torch.save(model.state_dict(), "final_model.pt")

        if scnt > ans:
            break
        scnt += 1

Speaker num: 0 	 done: 5 	 perc: 0.01883608088966577 	loss: tensor(-79.5806, grad_fn=<AddBackward0>) 
--------------------------------------------------



RecursionError: maximum recursion depth exceeded while calling a Python object

In [1]:
import sys
sys.setrecursionlimit(10000)