In [2]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as dsets
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import torch.optim as optim
import torch.nn.functional as F

In [55]:
import pickle as pkl
from collections import defaultdict
import pandas as pd
import os
import numpy as np
from tqdm import tqdm, tqdm_notebook
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

In [4]:
# %run ../twitter15/twitter15-datapreprocess.ipynb

## Loading Labels

In [5]:
twitter15_label_file = '../twitter15/label.txt'
twitter15_text_file = '../twitter15/source_tweets.txt'

In [6]:
def load_labels(file):
    """Read a ``<label>:<tweet_id>`` file into a ``{tweet_id: label}`` dict.

    Parameters
    ----------
    file : str
        Path to a text file holding one ``<label>:<tweet_id>`` record per line.

    Returns
    -------
    dict
        Maps each tweet id (as ``int``) to its label string. When an id occurs
        more than once, the last record wins.

    Raises
    ------
    IndexError
        If a non-blank line contains no ``:`` separator.
    ValueError
        If the second field of a line is not an integer.
    """
    labels = {}

    # The context manager closes the handle even if parsing fails part-way.
    with open(file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Blank lines (e.g. a trailing newline) carry no record.
                continue
            fields = line.split(':')
            labels[int(fields[1])] = fields[0]

    return labels

In [7]:
twitter15_labels = load_labels(twitter15_label_file)
twitter15_labels

{731166399389962242: 'unverified',
 714598641827246081: 'unverified',
 691809004356501505: 'non-rumor',
 693204708933160960: 'non-rumor',
 551099691702956032: 'true',
 767223774072676354: 'non-rumor',
 715515982584881152: 'unverified',
 514106273852174337: 'true',
 500319801344929795: 'unverified',
 495366618818830336: 'false',
 532206910796468224: 'false',
 560187970389819392: 'false',
 531568534066057217: 'false',
 489829414704648192: 'false',
 524925730053181440: 'unverified',
 766989078294306816: 'non-rumor',
 499530130487017472: 'unverified',
 520284654755381249: 'false',
 767515401831997440: 'non-rumor',
 565999191982616577: 'true',
 554343513887105024: 'false',
 767715489485205505: 'non-rumor',
 553467311261503488: 'unverified',
 553960736964476928: 'false',
 500303431928922113: 'unverified',
 538900739880665088: 'unverified',
 516420964834611201: 'unverified',
 80080680482123777: 'false',
 687945926774800385: 'non-rumor',
 436089462326849536: 'false',
 568589712644026368: 'fals

## Data structures required to unpickle the trees

In [8]:
class Node:
    """One vertex of a propagation tree.

    Attributes are stored exactly as passed in: ``uid``, ``tid``,
    ``time_stamp`` and ``label``. Children are kept in ``children`` (a dict
    keyed by the child's ``uid``), mirrored as a list in ``childrenList``
    and counted in ``num_children``.
    """

    def __init__(self,uid,tid,time_stamp,label):
        self.uid = uid
        self.tid = tid
        self.time_stamp = time_stamp
        self.label = label
        self.children = {}
        self.childrenList = []
        self.num_children = 0

    def add_child(self,node):
        """Attach ``node`` under this node, replacing any child with the same uid.

        ``num_children`` only grows when the uid was not present before;
        ``childrenList`` is rebuilt from the dict after every call.
        """
        is_new_uid = node.uid not in self.children
        self.children[node.uid] = node
        if is_new_uid:
            self.num_children += 1
        self.childrenList = list(self.children.values())

In [9]:
class Tree:
    """A propagation tree identified by the tweet id and user id of its root.

    ``height`` and ``nodes`` are initialised to 0 and are not maintained by
    any method of this class.
    """

    def __init__(self,root):
        self.root = root
        self.tweet_id = root.tid
        self.uid = root.uid
        self.height = 0
        self.nodes = 0

    def show(self):
        """Print the uids of the tree in breadth-first order.

        The integer 0 is used as a separator in the queue: one is queued after
        the root and one after each node's children, and popping it prints a
        blank line.
        """
        queue = [self.root,0]

        while len(queue) != 0:
            toprint = queue.pop(0)
            if toprint == 0:
                print('\n')
            else:
                print(toprint.uid,end=' ')
                queue += toprint.children.values()
                queue.append(0)

    def insertnode(self,curnode,parent,child):
        """Attach ``child`` under the first node below ``curnode`` whose uid matches ``parent.uid``.

        Parameters
        ----------
        curnode : node
            Root of the subtree to search.
        parent : node
            Only ``parent.uid`` is used, to locate the attachment point.
        child : node
            Node passed to the matching node's ``add_child``.

        Returns
        -------
        int or None
            1 when ``curnode`` itself is the parent, 2 when the parent was
            found among its descendants, ``None`` when no node matched.
        """
        if curnode.uid == parent.uid:
            curnode.add_child(child)
            return 1

        if parent.uid in curnode.children:
            # Direct child matches: the recursive call attaches immediately.
            self.insertnode(curnode.children[parent.uid],parent,child)
            return 2

        # Depth-first search through the children. The status is propagated so
        # the search stops at the first match at any depth; previously this
        # branch returned None, so a deep match was reported as a miss and the
        # search carried on into sibling subtrees.
        for node in curnode.children.values():
            if self.insertnode(node,parent,child):
                return 2
        return None

In [10]:
def loadPklFileNum(datapath,incSize,fileNum):
    """Load one pickled tree shard.

    The file name is ``<datapath><incSize>inc_<fileNum>.pickle``; ``datapath``
    is used as a plain string prefix, so a directory must end with its
    separator. Returns whatever object the pickle contains.
    """
    shard_path = f'{datapath}{incSize}inc_{fileNum}.pickle'
    with open(shard_path, 'rb') as shard:
        return pkl.load(shard)

In [11]:
def loadTreeFilesOfIncrement(datapath,incSize):
    """Load and merge every pickled tree shard for one increment size.

    Parameters
    ----------
    datapath : str
        Directory containing the shard files.
    incSize : int or str
        Increment size; every file whose name contains ``'<incSize>inc'`` is
        loaded.

    Returns
    -------
    dict
        Union of the dicts stored in the shards. Shards are read in sorted
        file-name order, so on duplicate keys the later file name wins.
    """
    twittertrees = {}

    marker = str(incSize)+'inc'
    # List the directory that was passed in (not the notebook-level global),
    # and sort so the merge order does not depend on the filesystem.
    files = sorted(x for x in os.listdir(datapath) if marker in x)

    for file in tqdm(files):
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this at shard files you produced yourself.
        with open(os.path.join(datapath, file),'rb') as handle:
            partialTrees = pkl.load(handle)
        twittertrees.update(partialTrees)

    return twittertrees

In [13]:
# Directory holding the pickled trees. Set the T15_TREE_DIR environment
# variable to run on another machine; the default is the original location.
# Keep the trailing slash: loadPklFileNum builds file names by concatenation.
t15Datapath = os.environ.get(
    'T15_TREE_DIR',
    '/home/nikhil.pinnaparaju/Research/Temporal Tree Encoding/twitter15/pickledTrees/',
)
# twitter15_trees = loadPklFileNum(t15Datapath,20,1)

In [14]:
twitter15_trees = loadTreeFilesOfIncrement(t15Datapath,20)

100%|██████████| 16/16 [02:10<00:00,  8.17s/it]


In [15]:
# Pair every loaded tree entry with its label; entries without a label are dropped.
X, y = [], []
for tid, tree in twitter15_trees.items():
    if tid not in twitter15_labels:
        continue
    X.append(tree)
    y.append(twitter15_labels[tid])

In [16]:
# Use the first CUDA device when one is visible, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

## Loading UserData

In [17]:
%run ../twitter15/userdata_parser.ipynb

100%|██████████| 34/34 [04:44<00:00,  8.36s/it]


In [18]:
# Cast every stored user vector to float32 and move it to the selected device.
# NOTE(review): `userVects` is not defined in this notebook; presumably it is
# created by the %run of userdata_parser.ipynb above - confirm.
for key in tqdm(userVects):
    userVects[key] = userVects[key].float().to(device)

# Wrap the mapping in a defaultdict so that looking up an unknown key yields
# (and stores) a fresh 8-element fallback tensor on the same device.
# NOTE(review): the origin of these eight constants is not shown here -
# document where they come from.
userVects = defaultdict(lambda:torch.tensor([1.1100e+02, 1.5000e+01, 0.0000e+00, 7.9700e+02, 4.7300e+02, 0.0000e+00,
        8.3326e+04, 1.0000e+00]).to(device),userVects)

100%|██████████| 430343/430343 [00:16<00:00, 26437.39it/s]


## Loading All Architectures

In [19]:
%run ./temporal_tree_model.ipynb 

In [20]:
# Assign each distinct label string a consecutive integer id, in order of
# first appearance in twitter15_labels.
labelMap = {}
for label in twitter15_labels.values():
    labelMap.setdefault(label, len(labelMap))
labelCount = len(labelMap)
labelMap

{'unverified': 0, 'non-rumor': 1, 'true': 2, 'false': 3}

## Optimizer, loss function, and model instance for the regular Temporal Tree Encoder

In [22]:
criterion = torch.nn.CrossEntropyLoss()

In [23]:
# model = treeEncoder(torch.cuda.is_available(),8,30,userVects,twitter15_labels,labelMap,criterion,device)
# model = model.to(device)

In [24]:
# map_location remaps the stored tensors onto the device chosen above, so a
# checkpoint saved on a GPU machine also loads on a CPU-only one.
checkpoint = torch.load('./tempTreeEnc.pth', map_location=device)

In [25]:
# Build the model, move it to the selected device and restore the saved weights.
# NOTE(review): lstmTreeEncoder is not defined in this notebook (presumably it
# comes from the %run of temporal_tree_model.ipynb); the meaning of the
# positional arguments 8 and 30 should be confirmed against its definition.
model = lstmTreeEncoder(torch.cuda.is_available(),8,30,userVects,twitter15_labels,labelMap,criterion,device)
model = model.to(device)
model.load_state_dict(checkpoint['state_dict'])

<All keys matched successfully>

In [26]:
optimizer = torch.optim.Adam(model.parameters(),lr = 0.01)

In [27]:
sample_pred = model(twitter15_trees[537913349338435584])
sample_pred

100%|██████████| 12/12 [00:02<00:00,  4.48it/s]


tensor([[ 0.2583, -0.0960,  0.0710, -0.0551]], device='cuda:0',
       grad_fn=<AddmmBackward>)

In [28]:
model

lstmTreeEncoder(
  (treeEnc): treeEncoder(
    (criterion): CrossEntropyLoss()
    (ix): Linear(in_features=8, out_features=30, bias=True)
    (ih): Linear(in_features=30, out_features=30, bias=True)
    (fx): Linear(in_features=8, out_features=30, bias=True)
    (fh): Linear(in_features=30, out_features=30, bias=True)
    (ux): Linear(in_features=8, out_features=30, bias=True)
    (uh): Linear(in_features=30, out_features=30, bias=True)
    (ox): Linear(in_features=8, out_features=30, bias=True)
    (oh): Linear(in_features=30, out_features=30, bias=True)
    (outputModule): OutputModule(
      (l1): Linear(in_features=30, out_features=4, bias=True)
      (logsoftmax): LogSoftmax()
    )
  )
  (topLevelLSTM): LSTM(30, 15)
  (fc): Linear(in_features=15, out_features=4, bias=True)
)

## Training

In [29]:
epochs = 5

In [30]:
# Rebuild the (trees, label) pairs, skipping entries that have no label.
X, y = [], []
for tid, tree in twitter15_trees.items():
    if tid not in twitter15_labels:
        continue
    X.append(tree)
    y.append(twitter15_labels[tid])

In [31]:
# Seed the shuffle so the train/test partition is identical on every run.
# NOTE(review): the model above is restored from a checkpoint; if that
# checkpoint was trained on a different split, test items may have been seen
# during training - confirm before reporting validation numbers.
x_train, x_test, y_train, y_test = train_test_split(X, y, random_state=42)

In [32]:
x_train[0]

[<__main__.Tree at 0x7fba9b1fc490>,
 <__main__.Tree at 0x7fba9b1fca10>,
 <__main__.Tree at 0x7fba9b207490>,
 <__main__.Tree at 0x7fba9b211410>,
 <__main__.Tree at 0x7fba9b21d890>,
 <__main__.Tree at 0x7fba9b1b2250>,
 <__main__.Tree at 0x7fba9b1c9110>,
 <__main__.Tree at 0x7fba9b1de4d0>,
 <__main__.Tree at 0x7fba9b176d90>]

#### Trainer for Tree Encoder

# Per-tree training loop for the plain tree encoder.
# NOTE(review): this expects `model(tree.root)` to return ((h, c), loss) and
# `model.predict` to exist; the model built earlier in this notebook is an
# lstmTreeEncoder that is called on whole tree sequences - confirm which model
# this cell is meant for before running it.
count = 0

# One mean loss per epoch.
train_iterwise = []
val_iterwise = []

for epoch in range(epochs):
    train_losses = []
    val_losses = []

    for treeSet in tqdm(x_train):
        for tree in treeSet:
            count += 1
            optimizer.zero_grad()

            # The model returns its own loss for this tree.
            (h,c),loss = model(tree.root)

            loss.backward()
            train_losses.append(loss.item())
            optimizer.step()

            # Every 10000 training steps, score the final tree of each test sequence.
            if count % 10000 == 0:
                preds = []
                labels = []

                # No autograd graph is needed for evaluation.
                with torch.no_grad():
                    for valSet in x_test:
                        finalTree = valSet[-1]
                        preds.append(model.predict(finalTree.root))
                        labels.append(labelMap[finalTree.root.label])

                    predTensor = torch.stack(preds)
                    labelTensor = torch.tensor(labels).to(device)
                    val_loss = criterion(predTensor.reshape(-1,4), labelTensor.reshape(-1))

                print('Loss Value: ', val_loss.item())
                val_losses.append(val_loss.item())

    # An epoch without any validation pass records NaN.
    train_iterwise.append(np.mean(train_losses) if train_losses else float('nan'))
    val_iterwise.append(np.mean(val_losses) if val_losses else float('nan'))

#### Trainer for Temporal Tree Encoder

In [48]:
# Validation losses are appended to this file by the training cell below,
# which is also the cell that closes the handle.
lossfile = './lossesTempEnc.txt'
# Append mode: results of earlier runs stay in the file.
f = open(lossfile, "a")

In [None]:
# Train on one tree sequence per step; every 500 steps, write the validation
# loss over the whole test split to the loss file opened in the previous cell.
count = 0
try:
    for epoch in range(epochs):
        for treeSet in tqdm_notebook(x_train):
            count += 1
            optimizer.zero_grad()

            pred = model(treeSet)

            # The label is read from the root of the first tree in the sequence.
            label = torch.tensor(labelMap[treeSet[0].root.label]).reshape(-1).to(device)
            loss = criterion(pred,label)

            loss.backward()
            optimizer.step()

            if count % 500 == 0:
                model.eval()
                # Evaluation needs no autograd graph.
                with torch.no_grad():
                    preds = [model(valTreeSet) for valTreeSet in x_test]

                    # Stack only after ALL predictions are collected, so the
                    # prediction batch matches the number of test labels.
                    predTensor = torch.stack(preds)
                    labelTensor = torch.tensor([labelMap[lbl] for lbl in y_test]).to(device)
                    val_loss = criterion(predTensor.reshape(-1,4), labelTensor.reshape(-1))
                model.train()

                # One value per line so the file can be parsed afterwards.
                f.write(str(val_loss.item()) + '\n')
                f.flush()
finally:
    # Close the handle even when training is interrupted or raises.
    f.close()

HBox(children=(IntProgress(value=0, max=1117), HTML(value='')))





 22%|██▏       | 2/9 [00:00<00:00, 16.10it/s][A[A

 33%|███▎      | 3/9 [00:00<00:00, 12.37it/s][A[A

 44%|████▍     | 4/9 [00:00<00:00,  9.44it/s][A[A

 56%|█████▌    | 5/9 [00:00<00:00,  7.56it/s][A[A

 67%|██████▋   | 6/9 [00:00<00:00,  6.26it/s][A[A

 78%|███████▊  | 7/9 [00:01<00:00,  5.25it/s][A[A

 89%|████████▉ | 8/9 [00:01<00:00,  4.48it/s][A[A

100%|██████████| 9/9 [00:01<00:00,  5.21it/s][A[A


  0%|          | 0/10 [00:00<?, ?it/s][A[A

 20%|██        | 2/10 [00:00<00:00, 15.93it/s][A[A

 30%|███       | 3/10 [00:00<00:00, 12.44it/s][A[A

 40%|████      | 4/10 [00:00<00:00,  9.55it/s][A[A

 50%|█████     | 5/10 [00:00<00:00,  7.47it/s][A[A

 60%|██████    | 6/10 [00:00<00:00,  6.20it/s][A[A

 70%|███████   | 7/10 [00:01<00:00,  5.31it/s][A[A

 80%|████████  | 8/10 [00:01<00:00,  4.58it/s][A[A

 90%|█████████ | 9/10 [00:01<00:00,  4.01it/s][A[A

100%|██████████| 10/10 [00:02<00:00,  4.94it/s][A[A


  0%|          | 0/7 [00:00<?, ?it/s]

 83%|████████▎ | 15/18 [00:04<00:01,  2.30it/s][A[A

 89%|████████▉ | 16/18 [00:04<00:00,  2.13it/s][A[A

 94%|█████████▍| 17/18 [00:05<00:00,  1.98it/s][A[A

100%|██████████| 18/18 [00:05<00:00,  3.03it/s][A[A


  0%|          | 0/12 [00:00<?, ?it/s][A[A

 17%|█▋        | 2/12 [00:00<00:00, 18.47it/s][A[A

 25%|██▌       | 3/12 [00:00<00:00, 14.03it/s][A[A

 33%|███▎      | 4/12 [00:00<00:00, 10.50it/s][A[A

 42%|████▏     | 5/12 [00:00<00:00,  8.20it/s][A[A

 50%|█████     | 6/12 [00:00<00:00,  6.80it/s][A[A

 58%|█████▊    | 7/12 [00:01<00:00,  5.69it/s][A[A

 67%|██████▋   | 8/12 [00:01<00:00,  4.85it/s][A[A

 75%|███████▌  | 9/12 [00:01<00:00,  4.21it/s][A[A

 83%|████████▎ | 10/12 [00:01<00:00,  3.71it/s][A[A

 92%|█████████▏| 11/12 [00:02<00:00,  3.32it/s][A[A

100%|██████████| 12/12 [00:02<00:00,  4.40it/s][A[A


  0%|          | 0/43 [00:00<?, ?it/s][A[A

  5%|▍         | 2/43 [00:00<00:02, 16.09it/s][A[A

  7%|▋         | 3/43 [00:00<00:03,

 24%|██▍       | 6/25 [00:00<00:02,  7.29it/s][A[A

 28%|██▊       | 7/25 [00:00<00:02,  6.02it/s][A[A

 32%|███▏      | 8/25 [00:01<00:03,  5.12it/s][A[A

 36%|███▌      | 9/25 [00:01<00:03,  4.41it/s][A[A

 40%|████      | 10/25 [00:01<00:03,  3.90it/s][A[A

 44%|████▍     | 11/25 [00:02<00:04,  3.50it/s][A[A

 48%|████▊     | 12/25 [00:02<00:04,  3.17it/s][A[A

 52%|█████▏    | 13/25 [00:02<00:04,  2.89it/s][A[A

 56%|█████▌    | 14/25 [00:03<00:04,  2.65it/s][A[A

 60%|██████    | 15/25 [00:03<00:04,  2.44it/s][A[A

 64%|██████▍   | 16/25 [00:04<00:03,  2.26it/s][A[A

 68%|██████▊   | 17/25 [00:04<00:03,  2.10it/s][A[A

 72%|███████▏  | 18/25 [00:05<00:03,  1.95it/s][A[A

 76%|███████▌  | 19/25 [00:06<00:03,  1.82it/s][A[A

 80%|████████  | 20/25 [00:06<00:02,  1.71it/s][A[A

 84%|████████▍ | 21/25 [00:07<00:02,  1.61it/s][A[A

 88%|████████▊ | 22/25 [00:08<00:01,  1.52it/s][A[A

 92%|█████████▏| 23/25 [00:09<00:01,  1.45it/s][A[A

 96%|█████████

 71%|███████▏  | 30/42 [00:16<00:11,  1.04it/s][A[A

 74%|███████▍  | 31/42 [00:17<00:10,  1.00it/s][A[A

 76%|███████▌  | 32/42 [00:18<00:10,  1.03s/it][A[A

 79%|███████▊  | 33/42 [00:19<00:09,  1.06s/it][A[A

 81%|████████  | 34/42 [00:20<00:08,  1.10s/it][A[A

 83%|████████▎ | 35/42 [00:21<00:07,  1.13s/it][A[A

 86%|████████▌ | 36/42 [00:23<00:07,  1.17s/it][A[A

 88%|████████▊ | 37/42 [00:24<00:06,  1.20s/it][A[A

 90%|█████████ | 38/42 [00:25<00:04,  1.24s/it][A[A

 93%|█████████▎| 39/42 [00:26<00:03,  1.27s/it][A[A

 95%|█████████▌| 40/42 [00:28<00:02,  1.31s/it][A[A

 98%|█████████▊| 41/42 [00:29<00:01,  1.34s/it][A[A

100%|██████████| 42/42 [00:31<00:00,  1.35it/s][A[A


  0%|          | 0/6 [00:00<?, ?it/s][A[A

 33%|███▎      | 2/6 [00:00<00:00, 17.67it/s][A[A

 50%|█████     | 3/6 [00:00<00:00, 13.93it/s][A[A

 67%|██████▋   | 4/6 [00:00<00:00, 10.64it/s][A[A

 83%|████████▎ | 5/6 [00:00<00:00,  8.25it/s][A[A

100%|██████████| 6/6 [00:00

100%|██████████| 17/17 [00:05<00:00,  3.09it/s][A[A


  0%|          | 0/7 [00:00<?, ?it/s][A[A

 29%|██▊       | 2/7 [00:00<00:00, 19.02it/s][A[A

 43%|████▎     | 3/7 [00:00<00:00, 14.65it/s][A[A

 57%|█████▋    | 4/7 [00:00<00:00, 11.14it/s][A[A

 71%|███████▏  | 5/7 [00:00<00:00,  8.67it/s][A[A

 86%|████████▌ | 6/7 [00:00<00:00,  6.98it/s][A[A

100%|██████████| 7/7 [00:00<00:00,  7.33it/s][A[A


  0%|          | 0/95 [00:00<?, ?it/s][A[A

  2%|▏         | 2/95 [00:00<00:06, 14.88it/s][A[A

  3%|▎         | 3/95 [00:00<00:07, 12.18it/s][A[A

  4%|▍         | 4/95 [00:00<00:09,  9.60it/s][A[A

  5%|▌         | 5/95 [00:00<00:11,  7.61it/s][A[A

  6%|▋         | 6/95 [00:00<00:14,  6.33it/s][A[A

  7%|▋         | 7/95 [00:01<00:16,  5.38it/s][A[A

  8%|▊         | 8/95 [00:01<00:18,  4.64it/s][A[A

  9%|▉         | 9/95 [00:01<00:21,  4.07it/s][A[A

 11%|█         | 10/95 [00:02<00:23,  3.63it/s][A[A

 12%|█▏        | 11/95 [00:02<00:25,  3.30it/s]

 80%|████████  | 12/15 [00:02<00:01,  2.87it/s][A[A

 87%|████████▋ | 13/15 [00:03<00:00,  2.60it/s][A[A

 93%|█████████▎| 14/15 [00:03<00:00,  2.38it/s][A[A

100%|██████████| 15/15 [00:04<00:00,  3.44it/s][A[A


  0%|          | 0/6 [00:00<?, ?it/s][A[A

 33%|███▎      | 2/6 [00:00<00:00, 16.06it/s][A[A

 50%|█████     | 3/6 [00:00<00:00, 13.03it/s][A[A

 67%|██████▋   | 4/6 [00:00<00:00, 10.09it/s][A[A

 83%|████████▎ | 5/6 [00:00<00:00,  7.91it/s][A[A

100%|██████████| 6/6 [00:00<00:00,  7.71it/s][A[A


  0%|          | 0/20 [00:00<?, ?it/s][A[A

 10%|█         | 2/20 [00:00<00:01, 14.50it/s][A[A

 15%|█▌        | 3/20 [00:00<00:01, 11.74it/s][A[A

 20%|██        | 4/20 [00:00<00:01,  9.23it/s][A[A

 25%|██▌       | 5/20 [00:00<00:02,  7.35it/s][A[A

 30%|███       | 6/20 [00:00<00:02,  6.18it/s][A[A

 35%|███▌      | 7/20 [00:01<00:02,  5.41it/s][A[A

 40%|████      | 8/20 [00:01<00:02,  4.75it/s][A[A

 45%|████▌     | 9/20 [00:01<00:02,  4.20it/

 57%|█████▋    | 13/23 [00:03<00:03,  2.59it/s][A[A

 61%|██████    | 14/23 [00:03<00:03,  2.37it/s][A[A

 65%|██████▌   | 15/23 [00:04<00:03,  2.19it/s][A[A

 70%|██████▉   | 16/23 [00:04<00:03,  2.04it/s][A[A

 74%|███████▍  | 17/23 [00:05<00:03,  1.90it/s][A[A

 78%|███████▊  | 18/23 [00:06<00:02,  1.78it/s][A[A

 83%|████████▎ | 19/23 [00:06<00:02,  1.68it/s][A[A

 87%|████████▋ | 20/23 [00:07<00:01,  1.58it/s][A[A

 91%|█████████▏| 21/23 [00:08<00:01,  1.50it/s][A[A

 96%|█████████▌| 22/23 [00:09<00:00,  1.43it/s][A[A

100%|██████████| 23/23 [00:09<00:00,  2.32it/s][A[A


  0%|          | 0/35 [00:00<?, ?it/s][A[A

  6%|▌         | 2/35 [00:00<00:01, 16.57it/s][A[A

  9%|▊         | 3/35 [00:00<00:02, 12.99it/s][A[A

 11%|█▏        | 4/35 [00:00<00:03, 10.01it/s][A[A

 14%|█▍        | 5/35 [00:00<00:03,  7.86it/s][A[A

 17%|█▋        | 6/35 [00:00<00:04,  6.53it/s][A[A

 20%|██        | 7/35 [00:01<00:04,  5.60it/s][A[A

 23%|██▎       | 8/35 [0

  7%|▋         | 2/30 [00:00<00:01, 15.44it/s][A[A

 10%|█         | 3/30 [00:00<00:02, 12.56it/s][A[A

 13%|█▎        | 4/30 [00:00<00:02,  9.86it/s][A[A

 17%|█▋        | 5/30 [00:00<00:03,  7.80it/s][A[A

 20%|██        | 6/30 [00:00<00:03,  6.49it/s][A[A

 23%|██▎       | 7/30 [00:01<00:04,  5.54it/s][A[A

 27%|██▋       | 8/30 [00:01<00:04,  4.77it/s][A[A

 30%|███       | 9/30 [00:01<00:05,  4.16it/s][A[A

 33%|███▎      | 10/30 [00:01<00:05,  3.67it/s][A[A

 37%|███▋      | 11/30 [00:02<00:05,  3.29it/s][A[A

 40%|████      | 12/30 [00:02<00:06,  2.97it/s][A[A

 43%|████▎     | 13/30 [00:03<00:06,  2.70it/s][A[A

 47%|████▋     | 14/30 [00:03<00:06,  2.48it/s][A[A

 50%|█████     | 15/30 [00:04<00:06,  2.29it/s][A[A

 53%|█████▎    | 16/30 [00:04<00:06,  2.13it/s][A[A

 57%|█████▋    | 17/30 [00:05<00:06,  1.99it/s][A[A

 60%|██████    | 18/30 [00:05<00:06,  1.87it/s][A[A

 63%|██████▎   | 19/30 [00:06<00:06,  1.76it/s][A[A

 67%|██████▋   | 2

 37%|███▋      | 27/73 [00:13<00:39,  1.16it/s][A[A

 38%|███▊      | 28/73 [00:14<00:40,  1.12it/s][A[A

 40%|███▉      | 29/73 [00:15<00:41,  1.07it/s][A[A

 41%|████      | 30/73 [00:16<00:41,  1.03it/s][A[A

 42%|████▏     | 31/73 [00:17<00:42,  1.02s/it][A[A

 44%|████▍     | 32/73 [00:18<00:43,  1.05s/it][A[A

 45%|████▌     | 33/73 [00:19<00:43,  1.09s/it][A[A

 47%|████▋     | 34/73 [00:20<00:44,  1.13s/it][A[A

 48%|████▊     | 35/73 [00:22<00:44,  1.16s/it][A[A

 49%|████▉     | 36/73 [00:23<00:44,  1.20s/it][A[A

 51%|█████     | 37/73 [00:24<00:44,  1.23s/it][A[A

 52%|█████▏    | 38/73 [00:26<00:44,  1.27s/it][A[A

 53%|█████▎    | 39/73 [00:27<00:44,  1.30s/it][A[A

 55%|█████▍    | 40/73 [00:28<00:44,  1.34s/it][A[A

 56%|█████▌    | 41/73 [00:30<00:43,  1.37s/it][A[A

 58%|█████▊    | 42/73 [00:31<00:43,  1.41s/it][A[A

 59%|█████▉    | 43/73 [00:33<00:43,  1.44s/it][A[A

 60%|██████    | 44/73 [00:35<00:42,  1.48s/it][A[A

 62%|█████

 54%|█████▍    | 13/24 [00:03<00:04,  2.68it/s][A[A

 58%|█████▊    | 14/24 [00:03<00:04,  2.46it/s][A[A

 62%|██████▎   | 15/24 [00:04<00:03,  2.27it/s][A[A

 67%|██████▋   | 16/24 [00:04<00:03,  2.11it/s][A[A

 71%|███████   | 17/24 [00:05<00:03,  1.97it/s][A[A

 75%|███████▌  | 18/24 [00:05<00:03,  1.85it/s][A[A

 79%|███████▉  | 19/24 [00:06<00:02,  1.74it/s][A[A

 83%|████████▎ | 20/24 [00:07<00:02,  1.65it/s][A[A

 88%|████████▊ | 21/24 [00:08<00:01,  1.56it/s][A[A

 92%|█████████▏| 22/24 [00:08<00:01,  1.48it/s][A[A

 96%|█████████▌| 23/24 [00:09<00:00,  1.41it/s][A[A

100%|██████████| 24/24 [00:10<00:00,  2.31it/s][A[A


  0%|          | 0/42 [00:00<?, ?it/s][A[A

  5%|▍         | 2/42 [00:00<00:02, 17.21it/s][A[A

  7%|▋         | 3/42 [00:00<00:02, 13.38it/s][A[A

 10%|▉         | 4/42 [00:00<00:03, 10.27it/s][A[A

 12%|█▏        | 5/42 [00:00<00:04,  8.03it/s][A[A

 14%|█▍        | 6/42 [00:00<00:05,  6.63it/s][A[A

 17%|█▋        | 7/42 [

In [None]:
# torch.save(model,'./tempTreeEnc.pth')
torch.save({'state_dict': model.state_dict()}, './tempTreeEnc.pth')

## Model Validation

In [None]:
# Score the inner tree encoder on the final tree of every test sequence,
# printing and recording the running loss over the items seen so far.
preds = []
labels = []
# Start from an empty history so this cell also runs on a fresh kernel.
val_losses = []

# Evaluation only: skip building autograd graphs.
with torch.no_grad():
    for valSet in x_test:
        finalTree = valSet[-1]
        preds.append(model.treeEnc.predict(finalTree.root))
        labels.append(labelMap[finalTree.root.label])

        predTensor = torch.stack(preds)
        labelTensor = torch.tensor(labels).to(device)
        loss = criterion(predTensor.reshape(-1,4), labelTensor.reshape(-1))
        print('Loss Value: ', loss.item())
        val_losses.append(loss.item())

## Plotting Losses

In [None]:
from matplotlib import pyplot as plt
import seaborn as sns

In [None]:
train_iterwise

In [None]:
# x and y are passed by keyword: current seaborn releases reject positional data vectors.
iterNums = list(range(len(train_iterwise)))
ax = sns.lineplot(x=iterNums, y=train_iterwise)
ax.set(xlabel='epoch', ylabel='mean training loss');

In [None]:
# x and y are passed by keyword: current seaborn releases reject positional data vectors.
iterNums = list(range(len(val_losses)))
ax = sns.lineplot(x=iterNums, y=val_losses)
ax.set(xlabel='evaluation step', ylabel='validation loss');

In [None]:
len(train_losses)

In [None]:
# Total number of trees over all training sequences, then the mean per sequence.
lenAggreg = sum(len(subset) for subset in x_train)
print(lenAggreg)
print(lenAggreg/len(x_train))

# Training Temporal Decay Model

In [None]:
# Same arguments as the checkpoint-restoring construction further down:
# `cuda` is not defined anywhere in this notebook, and `labels` is the list of
# integer ids left behind by the validation cell, not the tweet-id -> label dict.
model = temporalDecayTreeEncoder(torch.cuda.is_available(),8,30,userVects,twitter15_labels,labelMap,criterion,device)

In [None]:
testModel(x_train[0][0].root)

In [None]:
# map_location remaps the stored tensors onto the device chosen earlier, so a
# checkpoint saved on a GPU machine also loads on a CPU-only one.
checkpoint = torch.load('./tempDecayTreeEnc.pth', map_location=device)

In [None]:
# Build the temporal-decay model, move it to the selected device and restore
# the saved weights.
# NOTE(review): temporalDecayTreeEncoder is not defined in this notebook
# (presumably it comes from the %run of temporal_tree_model.ipynb) - confirm
# the meaning of the positional arguments 8 and 30 there.
model = temporalDecayTreeEncoder(torch.cuda.is_available(),8,30,userVects,twitter15_labels,labelMap,criterion,device)
model = model.to(device)
model.load_state_dict(checkpoint['state_dict'])

In [None]:
optimizer = torch.optim.Adam(model.parameters(),lr = 0.01)

In [None]:
lossfile = './lossesDecayTempEnc.txt'
f = open(lossfile, "a")

In [None]:
# Train on one tree sequence per step; every 500 steps, write the validation
# loss over the whole test split to the loss file opened in the previous cell.
count = 0
try:
    for epoch in range(epochs):
        for treeSet in tqdm_notebook(x_train):
            count += 1
            optimizer.zero_grad()

            pred = model(treeSet)

            # The label is read from the root of the first tree in the sequence.
            label = torch.tensor(labelMap[treeSet[0].root.label]).reshape(-1).to(device)
            loss = criterion(pred,label)

            loss.backward()
            optimizer.step()

            if count % 500 == 0:
                model.eval()
                # Evaluation needs no autograd graph.
                with torch.no_grad():
                    preds = [model(valTreeSet) for valTreeSet in x_test]

                    # Stack only after ALL predictions are collected, so the
                    # prediction batch matches the number of test labels.
                    predTensor = torch.stack(preds)
                    labelTensor = torch.tensor([labelMap[lbl] for lbl in y_test]).to(device)
                    val_loss = criterion(predTensor.reshape(-1,4), labelTensor.reshape(-1))
                model.train()

                # One value per line so the file can be parsed afterwards.
                f.write(str(val_loss.item()) + '\n')
                f.flush()
finally:
    # Close the handle even when training is interrupted or raises.
    f.close()

In [None]:
# torch.save(model,'./tempTreeEnc.pth')
torch.save({'state_dict': model.state_dict()}, './tempDecayTreeEnc.pth')

In [None]:
# Score the inner tree encoder on the final tree of every test sequence,
# printing and recording the running loss over the items seen so far.
preds = []
labels = []
# Start from an empty history so this cell also runs on a fresh kernel.
val_losses = []

# Evaluation only: skip building autograd graphs.
with torch.no_grad():
    for valSet in x_test:
        finalTree = valSet[-1]
        preds.append(model.treeEnc.predict(finalTree.root))
        labels.append(labelMap[finalTree.root.label])

        predTensor = torch.stack(preds)
        labelTensor = torch.tensor(labels).to(device)
        loss = criterion(predTensor.reshape(-1,4), labelTensor.reshape(-1))
        print('Loss Value: ', loss.item())
        val_losses.append(loss.item())

In [None]:
labelMap[y_test[1]]

In [None]:
sampleout = model(x_test[0])

In [None]:
sampleout[0]

In [None]:
sampleout[0].max(0)