##### Imports

In [None]:
%%capture
!pip install datasets
!pip install numpy
!pip install pandas
!pip3 install http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-linux_x86_64.whl
!pip3 install torchvision
!pip install simcse
!pip install gensim==4.1.2
!pip install cython
!pip install nltk

import nltk
nltk.download('punkt')

import datasets
from datasets import load_dataset, list_datasets
import pandas as pd 
import re 
import numpy as np 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import random
from simcse import SimCSE
random.seed(10)
torch.manual_seed(0)
np.random.seed(0)
import re
import time
import math
from IPython.utils import io
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from gensim.models import FastText
from nltk.tokenize import word_tokenize

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive (Colab-only API).
drive.mount('/content/drive')

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
# `device` is read as a global by the embedding modules and model loaders below.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Work from the project folder on the mounted Drive; later cells load relative paths (models/..., checkpoints/...).
cd /content/drive/My\ Drive/484-finalProject

##### Load Wikitext dataset (please run!)

Headings without any text below them (e.g. headings followed only by a table) are excluded.

In [None]:
class Node(object):
    """One vertex of a document outline tree.

    Attributes:
        text: heading or paragraph text stored on the node.
        level: value supplied by the caller; elsewhere in this notebook 0 is
            the article root, positive values are headings and -1 marks a
            paragraph (leaf).
        parent: the parent Node, or None until linkParent() succeeds.
        children: child nodes in insertion order.
    """

    def __init__(self, txt: str, level: int):
        self.children = []
        self.parent = None
        self.level = level
        self.text = txt

    def insertChild(self, child):
        """Append `child` to this node's children (the child's parent is not touched)."""
        self.children.append(child)

    def linkParent(self, parent):
        """Record `parent` once; a second call only prints an error and keeps the first parent."""
        if self.parent is None:
            self.parent = parent
            return
        print("ERROR: node ", self.text, "already has a parent")
            
        
class Tree(object):
    """Outline tree of one article, built from its flat heading/paragraph list."""

    def __init__(self, document):
        """Build the tree.

        Args:
            document: dict with a 'title' and a 'document' list whose entries
                are dicts of the form {"text": ..., "type": ...}. Type -1 is a
                paragraph; positive types are heading levels.
        """
        self.root = Node(document['title'], level=0)
        entries = document['document']
        # largest type value present; 0 when the article has no entries
        self.depth = np.amax([entry['type'] for entry in entries], initial=0)

        cursor = self.root  # heading that new entries are currently attached under
        for entry in entries:
            node = Node(entry['text'], entry['type'])
            is_paragraph = node.level == -1
            if not is_paragraph and node.level <= cursor.level:
                # heading at the same or a shallower level: climb back to the
                # closest ancestor that is strictly shallower than it
                while cursor.level >= node.level:
                    cursor = cursor.parent
            cursor.insertChild(node)
            node.linkParent(cursor)
            # only headings become the new attachment point; paragraphs stay leaves
            if node.level > 0:
                cursor = node
        return

    def printTree(self):
        """Print a banner (title, max depth) followed by every node's text, depth-first."""
        print("======== PRINTING TREE =========")
        print("TITLE: ", self.root.text)
        print("MAX DEPTH: ", self.depth)
        print("===============================")
        pending = [self.root]
        while pending:
            node = pending.pop()
            print(node.text)
            # paragraphs (level -1) are never descended into
            if node.level != -1:
                pending.extend(reversed(node.children))
  

In [None]:
## helper functions 

# get type of text 
def checkHeading(txt):
    """Classify one raw WikiText line.

    Returns:
        -2 for an empty string, -1 for a line that is not a heading, otherwise
        the heading level derived from the number of whitespace-preceded '='
        signs (0 for ' = Title = ', 1 for ' = = Section = = ', ...).
    """
    if txt == '':
        return -2
    if re.search(r'^\s=.+\s=\s\n', txt) is None:
        return -1
    # a heading carries the same number of '=' markers on each side
    marker_count = len(re.findall(r'\s=', txt))
    return marker_count // 2 - 1

# load documents to feed to tree 
def createDocuments(data):
    """Group raw WikiText lines into per-article documents.

    Args:
        data: iterable of raw text lines; each is classified with checkHeading().

    Returns:
        (documents_with, documents_without), two parallel lists with one entry
        per article:
        - documents_with[k] is {'title': str, 'document': [{'text', 'type'}, ...]}
          holding headings (type >= 1) and paragraphs (type -1) in order;
        - documents_without[k] is the list of that article's paragraph strings.

    Text that appears before the first level-0 title is discarded. The article
    that is still open when the input ends is flushed as well (it used to be
    dropped), and an empty input yields ([], []) instead of raising.
    """
    documents_with = []
    documents_without = []

    def _flush(title, with_headings, paragraphs):
        # remove headings left dangling at the end of the article (nothing below them)
        while len(with_headings) > 1 and with_headings[-1]['type'] != -1:
            with_headings.pop(-1)
        documents_with.append({'title': title, 'document': with_headings})
        documents_without.append(paragraphs)

    document_with = []
    document_without = []
    curTitle = ''
    for i in data:
        c = checkHeading(i)
        if c == -2:
            continue
        if c > -1:
            # strip the surrounding '=' markers from the heading text
            i = re.findall(r'=\s([^=]+)\s=', i)[0]
        if c == 0:
            # a new article title closes the article collected so far
            _flush(curTitle, document_with, document_without)
            curTitle = i
            document_with = []
            document_without = []
        else:
            # a heading at the same or a shallower level directly after another
            # heading means the previous heading had no text: drop it
            if (len(document_with) > 1 and document_with[-1]['type'] != -1
                    and c <= document_with[-1]['type'] and c != -1):
                document_with.pop(-1)
            document_with.append({'text': i, 'type': c})
            if c == -1:
                document_without.append(i)

    # flush the article that was still open when the input ended
    _flush(curTitle, document_with, document_without)

    # the first flushed entry is the untitled placeholder preceding the first title
    documents_with.pop(0)
    documents_without.pop(0)
    return documents_with, documents_without

loadData() creates a list of data points containing the title of article, raw text (paragraphs), and the tree representation of heading structures.

In [None]:
## load wiki dataset 
def loadData(split='test', min_size=-1):
    """
    Prepare the dataset as a list of dictionaries, one per article, containing:
    - "title": document title (string)
    - "paragraphs": list of paragraph strings
    - "tree": Tree representation of the heading structure

    Args:
        split: WikiText-103 split name passed to load_dataset().
        min_size: articles with fewer than this many paragraphs are skipped
            (the default -1 keeps everything).
    """
    data_raw = load_dataset("wikitext", 'wikitext-103-v1', split=split)
    data_raw = data_raw['text']
    documents_with, documents_without = createDocuments(data_raw)

    data = []
    # Iterate over both lists in lockstep: the previous manual counter was not
    # advanced when an article was skipped, which misaligned every later pair.
    for document, paragraphs in zip(documents_with, documents_without):
        if len(paragraphs) < min_size:
            continue
        data.append({
            "title": document['title'],
            "paragraphs": paragraphs,
            "tree": Tree(document)
        })
    return data

##### Preparing LCA loss evaluation function + tools for trees (please run!)

In [None]:
# tree-related helper functions

# iterate over the leaves of the tree rooted at node, in left-to-right (preorder) order
def preorder(node):
    """Yield every node without children below `node` (or `node` itself), left to right.

    Internal nodes are traversed but never yielded.
    """
    stack = [node]
    while stack:
        current = stack.pop()
        if not current.children:
            yield current
        # push children reversed so the leftmost child is visited first
        stack.extend(reversed(current.children))

# only prints leaves, i.e. text representations of paragraph
# note: depends on accurate text, level population
# text should be indices
def print_tree(curNode):
    """Print the tree as nested bracket lists of leaf texts, e.g. [[0, 1], 2].

    Nodes with level -1 print their text; every other node prints its children
    inside brackets. A node with level 0 ends the output with a newline.
    """
    if curNode.level != -1:
        print('[', end='')
        remaining = len(curNode.children)
        for child in curNode.children:
            print_tree(child)
            remaining -= 1
            if remaining > 0:
                print(', ', end='')
        print(']', end='')
        if curNode.level == 0:
            print() # final print after entire tree is printed
    else:
        print(curNode.text, end='')

# text should be snippets
def print_snippet_tree(curNode, indent='  '):
    """Print the tree one node per line, indenting two more spaces per depth.

    Nodes with level -1 print their text; every other node prints the literal
    '[heading]' followed by its children. A node with level 0 ends the output
    with a newline.
    """
    if curNode.level != -1:
        print(indent+'[heading]')
        deeper = indent + '  '
        last = len(curNode.children) - 1
        for position, child in enumerate(curNode.children):
            print_snippet_tree(child, deeper)
            if position < last:
                print()
        if curNode.level == 0:
            print() # final print after entire tree is printed
    else:
        print(indent + curNode.text, end='')

def clone_tree(root):
    """Return a deep copy of the tree under `root` (text, level, children and parent links).

    The copy of `root` itself has no parent.
    """
    duplicate = Node(root.text, root.level)
    for child in root.children:
        child_copy = clone_tree(child)
        duplicate.insertChild(child_copy)
        child_copy.linkParent(duplicate)
    return duplicate

# return a copy of the tree whose leaves carry their 0-based position instead of paragraph text (for concise printing)
def indexified_tree(root):
    """Clone the tree and overwrite each leaf's text with its 0-based left-to-right index."""
    relabelled = clone_tree(root)
    position = 0
    for leaf in preorder(relabelled):
        leaf.text = position
        position += 1
    return relabelled

# return a copy of the tree whose leaves carry a short snippet of the matching paragraph (for readable printing)
def textified_tree(root, paras):
    """Clone the tree and set the k-th leaf's text to the first 40 characters of paras[k] plus '...'."""
    snippet_tree = clone_tree(root)
    leaf_number = 0
    for leaf in preorder(snippet_tree):
        leaf.text = paras[leaf_number][:40] + '...'
        leaf_number += 1
    return snippet_tree

# print_tree(roots[0])
# print_tree(train_y[0].root)

In [None]:
# lca-related helper functions and lca loss
# note: assumed indexified trees
def trace_helper(node, i, trace):
    """Depth-first search for the node whose text equals `i`.

    On success returns True and appends to `trace` the child index taken at
    each ancestor, deepest ancestor first (i.e. the path in reverse). Returns
    False and leaves `trace` untouched when `i` is not found.
    """
    if node.text == i:
        return True
    # index of the first child whose subtree contains i, if any
    hit = next(
        (position for position, child in enumerate(node.children)
         if trace_helper(child, i, trace)),
        None,
    )
    if hit is None:
        return False
    trace.append(hit)
    return True

def get_trace(root, i):
    """Return the list of child indices leading from `root` down to the node with text `i`.

    The list is empty when `i` is the root itself or is not present.
    """
    bottom_up = []
    trace_helper(root, i, bottom_up)
    return bottom_up[::-1]

def compute_lca_dist(root, i, j):
    """Number of edges on the path between nodes `i` and `j` through their lowest common ancestor.

    Computed from the two root-to-node traces: total trace length minus twice
    the length of their common prefix. If one trace is a prefix of the other
    (including identical traces), the plain sum of both lengths is returned.
    """
    path_i = get_trace(root, i)
    path_j = get_trace(root, j)
    total = len(path_i) + len(path_j)
    for shared, (step_i, step_j) in enumerate(zip(path_i, path_j)):
        if step_i != step_j:
            return total - 2 * shared
    return total

def lca_loss(root1, root2, num_paras):
    """Mean squared difference of LCA distances between consecutive paragraphs.

    Both trees must be indexified, i.e. their leaves carry the 0-based
    paragraph positions 0..num_paras-1 as text (as produced by
    indexified_tree and GreedyDecoder.encode_decode_with_text).

    Args:
        root1, root2: roots of the two trees to compare.
        num_paras: number of paragraphs (leaves) in each tree.

    Returns:
        Average over the num_paras-1 adjacent pairs (i-1, i) of the squared
        difference between the pair's LCA distance in each tree; 0 when there
        are fewer than two paragraphs.
    """
    if num_paras <= 1:
        return 0
    loss = 0
    # Leaves are 0-indexed, so adjacent pairs are (0,1) ... (num_paras-2, num_paras-1).
    # The earlier 1-indexed loop skipped pair (0,1) and looked up a non-existent leaf.
    for i in range(1, num_paras):
        j = i - 1
        dist1 = compute_lca_dist(root1, i, j)
        dist2 = compute_lca_dist(root2, i, j)
        loss += (dist1 - dist2) * (dist1 - dist2)
    return loss / (num_paras - 1)

def batch_lca_loss(roots1, roots2, num_paras):
    """Average lca_loss over a batch of tree pairs.

    Args:
        roots1, roots2: parallel sequences of tree roots.
        num_paras: parallel sequence with the paragraph count of each pair.

    The sum of per-pair losses is divided by len(roots1).
    """
    per_pair = (
        lca_loss(first, second, count)
        for first, second, count in zip(roots1, roots2, num_paras)
    )
    return sum(per_pair) / len(roots1)

##### Model 2: Greedy decoding + LCA loss

In [None]:
MAX_DEPTH = 8
class GreedyDecoder:
    """Bottom-up greedy outline decoder.

    Starting from one leaf per paragraph, each of the first MAX_DEPTH-1 rounds
    merges runs of adjacent groups whose similarity reaches that round's
    threshold; whatever is left is finally attached to a single root.

    Args:
        thresholds: sequence indexed by round; entries 0..MAX_DEPTH-2 are read.
        similarity: callable(emb_a, emb_b) -> comparable score.
        encode: callable(paragraph) -> 1-D embedding; only used by the
            encode_* methods, not by decode().
    """

    def __init__(self, thresholds, similarity, encode):
        self.thresholds = thresholds
        self.similarity = similarity
        self.encode = encode

    def update_levels_parents(self, node, depth):
        """Fill in `level` and parent links for the tree rooted at `node`.

        Childless nodes get level -1; every other node gets its depth, with
        `node` itself at `depth`.
        """
        if len(node.children) == 0: # leaf paragraph node
            node.level = -1
            return
        node.level = depth
        for ch in node.children:
            ch.linkParent(node)
            self.update_levels_parents(ch, depth+1)

    def encode_decode(self, paragraphs):
        """Encode every paragraph and return the root of the decoded tree."""
        embs = [self.encode(para) for para in paragraphs]
        return self.decode(embs)

    def encode_decode_with_text(self, X, indexify=True):
        """Decode paragraphs X and populate leaf texts, e.g. for printing the tree.

        With indexify=True each leaf's text is its 0-based position, otherwise
        it is the paragraph itself.
        """
        root = self.encode_decode(X)
        for idx, leaf in enumerate(preorder(root)):
            leaf.text = idx if indexify else X[idx]
        return root

    def batch_encode_decode_with_text(self, X):
        """Apply encode_decode_with_text (indexified) to every paragraph list in X."""
        return [self.encode_decode_with_text(X_i) for X_i in X]

    def decode(self, embs):
        """Greedily build a tree from a list of paragraph embeddings.

        For each round, adjacent groups whose similarity reaches the round's
        threshold are joined under a new node and represented by the mean of
        their embeddings. Returns the root Node with levels/parents populated.
        """
        if len(embs) == 0:
            return Node('', 0) # should never happen
        roots = [Node(i, -1) for i in range(len(embs))]
        dim = embs[0].shape[0]
        for depth in range(MAX_DEPTH-1): # in last layer, everything must be joined together
            # split the current groups into runs of adjacent, similar-enough groups
            next_idxs = [[0]]
            for i in range(1, len(embs)):
                if self.similarity(embs[i-1], embs[i]) >= self.thresholds[depth]:
                    next_idxs[-1].append(i)
                else:
                    next_idxs.append([i])

            # update roots and embs
            next_roots = []
            next_embs = []
            for comp in next_idxs:
                # don't add trivial (1 -> 1) edges
                if len(comp) == 1:
                    next_roots.append(roots[comp[0]])
                    next_embs.append(embs[comp[0]])
                    continue

                next_roots.append(Node('', -2)) # meaningless params since we only need tree structure
                # Allocate the accumulator on the same device as the embeddings
                # being merged; a CPU-only accumulator fails for device tensors.
                next_embs.append(torch.zeros(dim, device=getattr(embs[comp[0]], 'device', None)))
                for idx in comp:
                    next_roots[-1].insertChild(roots[idx])
                    next_embs[-1] += embs[idx]
                next_embs[-1] /= len(comp)
            roots = next_roots
            embs = next_embs

        # join everything together in the last layer
        root = Node('', 0)
        for node in roots:
            root.insertChild(node)

        # update parents and levels, then return
        self.update_levels_parents(root, 0)
        return root

In [None]:
# Module-level encoders used by doc2vec_enc / simcse_enc below.
# The Doc2Vec path ("models/256_d2v.model") is relative to the current working directory.
doc2vec_model = Doc2Vec.load("models/{}_d2v.model".format(256))
simcse_model = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased")

def sim(v1, v2):
    """Cosine similarity of two 1-D tensors (dot product over the product of their norms)."""
    # return torch.dot(v1, v2) / torch.norm(torch.sub(v1, v2))
    norm_product = torch.norm(v1) * torch.norm(v2)
    return torch.dot(v1, v2) / norm_product

def doc2vec_enc(v):
    """Embed one paragraph with the module-level Doc2Vec model and return it as a tensor.

    NOTE(review): the text is tokenized as-is, whereas Doc2VecEmbedding
    lower-cases first; confirm which matches how the model was trained.
    """
    tokens = word_tokenize(v)
    vector = doc2vec_model.infer_vector(tokens)
    return torch.tensor(vector)

def simcse_enc(v):
    """Embed one paragraph as the mean of the SimCSE embeddings of its sentences.

    Newlines are removed, then the text is split on '.', '!' and '?'. Empty
    fragments produced by the split (e.g. after a final period) are encoded
    too and take part in the mean. Output printed during encoding is captured
    and discarded.
    """
    flattened = re.sub('\n', '', v).strip()
    sentences = re.split('[.]|[!]|[?]', flattened)
    with io.capture_output():
        sentence_vecs = simcse_model.encode(sentences, device=device, batch_size=len(sentences))
    return torch.mean(sentence_vecs, dim=0)

def get_doc2vec_greedy():
  """Build a GreedyDecoder that uses Doc2Vec paragraph embeddings and `sim`.

  NOTE(review): most thresholds exceed 1, the upper bound of the cosine
  similarity computed by `sim`; presumably they were tuned against a
  different similarity -- confirm.
  """
  # thresholds = [1.769, 1.261, 0.007, 2.289, 1.793, 0.590, 3.620, 3.969] # 128 thresholds
  doc2vec_thresholds = [1.769, 1.261, 0.007, 2.289, 1.793, 0.590, 3.620, 3.969] # 256 thresholds
  return GreedyDecoder(doc2vec_thresholds, similarity=sim, encode=doc2vec_enc)

def get_simcse_greedy(use_proj=False, mlp=None):
  """Build a GreedyDecoder over SimCSE embeddings.

  Args:
    use_proj: when True, paragraphs are encoded with mlp.emb.embed_one_para
      (so `mlp` must be given); otherwise with simcse_enc.
    mlp: model whose `emb` module provides embed_one_para; only read when
      use_proj is True.

  NOTE(review): several thresholds exceed 1, the upper bound of the cosine
  similarity computed by `sim` -- confirm they match the similarity in use.
  """
  simcse_thresholds = [1.668, 0.429, 0.810, 2.581, 1.015, 1.205, 1.755, 0.440] # simcse thresholds
  if use_proj:
    encoder = mlp.emb.embed_one_para
  else:
    encoder = simcse_enc
  return GreedyDecoder(simcse_thresholds, similarity=sim, encode=encoder)

def get_fasttext_greedy(use_proj=False, mlp=None):
  """Build a GreedyDecoder over projected FastText embeddings.

  Args:
    use_proj: when True, paragraphs are encoded with mlp.emb.embed_one_para
      (so `mlp` must be given). When False the decoder is created with
      encode=None, so only its decode() method on precomputed embeddings is
      usable.
    mlp: model whose `emb` module provides embed_one_para; only read when
      use_proj is True.
  """
  fasttext_thresholds = [1.41, 0.55, 1.72, 0.47, 2.8, 2.35, 1.3, 0.21]
  if use_proj:
    encoder = mlp.emb.embed_one_para
  else:
    encoder = None
  return GreedyDecoder(fasttext_thresholds, similarity=sim, encode=encoder)

In [None]:
def evaluate_thresholds(X, y, thresholds, similarity=sim, encode=doc2vec_enc):
    """Score one threshold vector: mean LCA loss of greedy decoding against gold trees.

    Args:
        X: list of articles, each a list of paragraph strings.
        y: list of gold trees (objects with a `.root`), parallel to X.
        thresholds: thresholds handed to GreedyDecoder.
        similarity, encode: similarity function and paragraph encoder to use.
    """
    decoder = GreedyDecoder(thresholds, similarity=similarity, encode=encode)
    predicted_roots = decoder.batch_encode_decode_with_text(X)
    gold_roots = [indexified_tree(tree.root) for tree in y]
    paragraph_counts = [len(article) for article in X]
    return batch_lca_loss(gold_roots, predicted_roots, paragraph_counts)
    # print('Actual:')
    # print_tree(y_hat[0])
    # print('Predicted:')
    # print_tree(y[0])

def random_search(X, y, n_tries=50, min_num=[0]*8, max_num=[3]*8, similarity=sim, encode=doc2vec_enc):
    """Random search over the 8 greedy-decoder thresholds; results are printed, nothing is returned.

    Each try draws threshold j uniformly from [min_num[j], max_num[j]] using
    the global `random` module and evaluates it with evaluate_thresholds().
    Afterwards, for every component, all tries are printed sorted by that
    component's value as "<loss> \t [<thresholds>]".
    """
    results = []
    for attempt in range(n_tries):
        print('try #' + str(attempt))
        thresholds = [random.uniform(min_num[j], max_num[j]) for j in range(8)]
        # print('thresholds:', thresholds)
        loss = evaluate_thresholds(X, y, thresholds, similarity=similarity, encode=encode)
        results.append({'loss': loss, 'thresholds': thresholds})

    # print sorted version by each index
    print('printing sorted by component...')
    for component in range(8):
        print('component', component)
        ordered = sorted(results, key=lambda entry: entry['thresholds'][component])
        for entry in ordered:
            formatted = ["{0:0.5f}".format(value) for value in entry['thresholds']]
            print(entry['loss'], '\t', formatted)
        print()

##### Model 3: Recursive split MLP

In [None]:
def recur_search(node, to_be_marked, curr_data):
  """Walk the tree depth-first and record, for every paragraph, the headings it opens.

  For each node with level -1 a tuple (levels, depth) is appended to
  `curr_data`: `levels` are the levels of the heading nodes entered since the
  previous paragraph, and `depth` is the paragraph's parent level + 1.

  Args:
    node: current tree node.
    to_be_marked: heading nodes entered since the last paragraph (appended to in place).
    curr_data: output list, one entry per paragraph in reading order.

  Returns:
    The list of still-unclaimed headings to carry on with (empty right after a paragraph).
  """
  if node.level == -1:
    opened_levels = [heading.level for heading in to_be_marked]
    # add data of which nodes you are first of, and also what is your depth
    curr_data.append((opened_levels, node.parent.level + 1))
    return []
  to_be_marked.append(node)
  pending = to_be_marked
  for child in node.children:
    pending = recur_search(child, pending, curr_data)
  return pending

def convert_dataset(data, window_size):
  """Turn articles into (context window, depth, label) training examples.

  For paragraph p of an article, one example is emitted for every depth d in
  [0 or 1, depth_of_p): d starts at 0 only for the first paragraph. The label
  is 1 when d is among the heading levels that paragraph opens (as reported by
  recur_search), else 0. One debug line per paragraph is printed.

  Args:
    data: list of dicts with 'paragraphs' (list of str) and 'tree' (object with `.root`).
    window_size: number of neighbouring paragraphs on each side; positions
      outside the article are None.

  Returns:
    dataX: list of context windows (lists of 2*window_size+1 items),
    dataD: tensor of shape (num_examples, 1) holding the depths,
    datay: tensor of shape (num_examples, 1) holding the 0/1 labels.
  """
  paras = [entry['paragraphs'] for entry in data]
  trees = [entry['tree'] for entry in data]
  dataX, dataD, datay = [], [], []
  for article, tree in zip(paras, trees):  # for each article
    curr_data = []
    recur_search(tree.root, [], curr_data)
    article_len = len(article)
    for p, paragraph in enumerate(article):  # for each paragraph
      context = [
          article[j] if 0 <= j < article_len else None
          for j in range(p - window_size, p + window_size + 1)
      ]
      breaks, depth = curr_data[p]
      print(breaks, '  d=', depth, '    ', paragraph[:20])
      first_depth = 0 if p == 0 else 1
      for d in range(first_depth, depth):
        dataX.append(context)
        dataD.append([d])
        datay.append([1 if d in breaks else 0])
  dataD = torch.tensor(np.array(dataD))
  datay = torch.tensor(np.array(datay))
  return dataX, dataD, datay

In [None]:
# MODULES FOR EMBEDDING 3 DIFFERENT WAYS: DOC2VEC, PROJECTED SIMCSE, PROJECTED FASTTEXT
# each one takes a batch of lists of paragraphs, outputs a batch of concatenated paragraph embeddings
# forward pass input: batch (len B) of list of paragraphs (len 2 * window_size + 1), each para variable length
# forward pass output: tensor of size B x ((2 * window_size + 1) * emb_dim)
# each one can also embed one paragraph at a time
class Doc2VecEmbedding(nn.Module):
  """Paragraph embedder backed by a saved Doc2Vec model (no trainable parameters).

  Args:
    window_size: accepted for interface parity with the other embedders; not used.
    emb_dim: embedding size; also selects the model file "models/<emb_dim>_d2v.model".
  """

  def __init__(self, window_size, emb_dim): 
    super().__init__()
    self.emb_dim = emb_dim
    self.doc2vec = Doc2Vec.load("models/{}_d2v.model".format(emb_dim))

  def embed_one_para(self, p):
    """Return the embedding of one paragraph on `device`; None maps to a zero vector."""
    if p is None:
      vector = np.zeros(shape=(self.emb_dim))
    else:
      vector = self.doc2vec.infer_vector(word_tokenize(p.lower()))
    return torch.tensor(vector).to(device)

  def forward(self, x):
    """Embed a batch of paragraph windows.

    Args:
      x: list (batch) of lists (window) of paragraph strings or None.

    Returns:
      Tensor on `device` with one row per window holding the concatenated
      paragraph embeddings; None entries contribute zeros.
    """
    with torch.no_grad():
      rows = []
      for window in x:
        vectors = []
        for p in window:
          if p is None:
            vectors.append(np.zeros(shape=(self.emb_dim)))
          else:
            vectors.append(self.doc2vec.infer_vector(word_tokenize(p.lower())))
        rows.append(np.concatenate(vectors, axis=0))
      return torch.tensor(np.stack(rows, axis=0)).to(device)

class SimCSEEmbedding(nn.Module):
  """Paragraph embedder: SimCSE sentence embeddings summarised by a bidirectional LSTM.

  A paragraph embedding is the concatenation of the LSTM's final hidden and
  cell states of both directions, i.e. 4 * int(emb_dim / 4) values.

  Args:
    window_size: neighbours on each side of the centre paragraph in a window.
    emb_dim: target paragraph embedding size.
    dropout: accepted but not used (the LSTM is built with dropout=0.0).
  """

  def __init__(self, window_size, emb_dim, dropout): 
    super().__init__()
    self.window_size = window_size
    self.emb_dim = emb_dim
    self.simcse = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased")
    self.SIMCSE_DIM = 768 # dim of simcse sentence embeddings
    lstm = nn.LSTM(input_size=self.SIMCSE_DIM,
                   hidden_size=int(emb_dim / 4),
                   num_layers=1,
                   bidirectional=True,
                   batch_first=True,
                   dropout=0.0)
    self.lstm = lstm.to(device)

  def embed_one_para(self, p):
    """Embed a single paragraph string and return the result on the CPU."""
    flattened = re.sub('\n', '', p)
    sentences = re.split('[.]|[!]|[?]', flattened.strip())
    with io.capture_output():
      # a tensor of len(sentences) x SIMCSE_DIM
      sentence_vecs = self.simcse.encode(sentences, device=device, batch_size=len(sentences), max_length=64)
    padded = torch.nn.utils.rnn.pad_sequence([sentence_vecs], batch_first=True).to(device)
    lengths = torch.tensor([row.shape[0] for row in padded])
    packed = torch.nn.utils.rnn.pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)
    _, (hidden, cell) = self.lstm(packed.to(device))
    states = torch.cat((hidden[0], cell[0], hidden[1], cell[1]), dim=1)
    return states.squeeze().cpu()

  def forward(self, x):
    """Embed a batch of paragraph windows.

    Args:
      x: list (batch) of lists (window, 2*window_size+1 items) of paragraph
        strings or None; None is fed as a single all-zero sentence vector.

    Returns:
      Tensor of shape (B, (2*window_size+1) * emb_dim).

    NOTE(review): sequence lengths are read from the already padded batch, so
    every paragraph is packed at the longest length and the LSTM also consumes
    the zero padding. Left unchanged because saved checkpoints were presumably
    trained with this behaviour -- confirm before changing.
    """
    batch_size = len(x)
    per_paragraph = []
    with torch.no_grad():
      for window in x:
        for p in window:
          if p is None:
            per_paragraph.append(torch.zeros((1, self.SIMCSE_DIM), device=device))
            continue
          flattened = re.sub('\n', '', p)
          sentences = re.split('[.]|[!]|[?]', flattened.strip())
          with io.capture_output():
            # a tensor of len(sentences) x SIMCSE_DIM
            sentence_vecs = self.simcse.encode(sentences, device=device, batch_size=len(sentences), max_length=64)
          per_paragraph.append(sentence_vecs)
      padded = torch.nn.utils.rnn.pad_sequence(per_paragraph, batch_first=True).to(device)
      lengths = torch.tensor([row.shape[0] for row in padded])
      packed = torch.nn.utils.rnn.pack_padded_sequence(padded, lengths, batch_first=True, enforce_sorted=False)

    _, (hidden, cell) = self.lstm(packed.to(device))
    para_embs = torch.cat((hidden[0], cell[0], hidden[1], cell[1]), dim=1)
    return para_embs.reshape(batch_size, (2 * self.window_size + 1) * self.emb_dim)


class FastTextEmbedding(nn.Module):
  """Paragraph embedder: FastText word vectors summarised by a bidirectional LSTM.

  A paragraph embedding is the concatenation of the LSTM's final hidden and
  cell states of both directions, i.e. 4 * int(emb_dim / 4) values.

  Args:
    window_size: neighbours on each side of the centre paragraph in a window.
    emb_dim: target paragraph embedding size.
    dropout: accepted but not used (the LSTM is built with dropout=0.0).
  """

  def __init__(self, window_size, emb_dim, dropout): 
    super().__init__()
    self.window_size = window_size
    self.emb_dim = emb_dim
    self.fasttext = FastText.load_fasttext_format('models/fast-text-300.bin').wv
    self.FASTTEXT_DIM = 300 # dim of fasttext word embeddings
    self.lstm = nn.LSTM(input_size=self.FASTTEXT_DIM,
                        hidden_size=int(emb_dim / 4),
                        num_layers=1,
                        bidirectional=True,
                        batch_first=True, 
                        dropout=0.0).to(device)

  def embed_one_para(self, p):
    """Embed a single paragraph string and return the result on the CPU.

    A paragraph without any word characters maps to a zero vector of the same
    size as a regular paragraph embedding (previously it had FASTTEXT_DIM
    entries, which did not match the other embeddings).
    """
    p = re.sub('\n', '', p)
    words = re.sub(r"[^\s\w]", "", p.strip()).split(' ')
    words = list(filter(None, words))
    if len(words) == 0:
        return torch.zeros(4 * self.lstm.hidden_size)
    words_emb = torch.stack([torch.tensor(self.fasttext[word]) for word in words]).to(device)  # a tensor of len(words) x FASTTEXT_DIM
    words_emb = torch.nn.utils.rnn.pad_sequence([words_emb], batch_first=True).to(device)
    packed_in = torch.nn.utils.rnn.pack_padded_sequence(words_emb, torch.tensor([a.shape[0] for a in words_emb]), batch_first=True, enforce_sorted=False)
    _, (hidden, cell) = self.lstm(packed_in.to(device))
    para_emb = torch.cat((hidden[0], cell[0], hidden[1], cell[1]), dim=1).squeeze().cpu()
    return para_emb

  def forward(self, x):
    """Embed a batch of paragraph windows.

    Args:
      x: list (batch) of lists (window, 2*window_size+1 items) of paragraph
        strings or None; None and word-less paragraphs are fed as a single
        all-zero word vector.

    Returns:
      Tensor of shape (B, (2*window_size+1) * emb_dim).

    NOTE(review): sequence lengths are read from the already padded batch, so
    every paragraph is packed at the longest length and the LSTM also consumes
    the zero padding. Left unchanged because saved checkpoints were presumably
    trained with this behaviour -- confirm before changing.
    """
    B = len(x)  # batch size
    batch = []
    with torch.no_grad():
      for b in x:
        for p in b:
          if p is not None:
            p = re.sub('\n', '', p)
            words = re.sub(r"[^\s\w]", "", p.strip()).split(' ')
            words = list(filter(None, words))
            if len(words) == 0:
                # keep the placeholder on the same device as the real word vectors
                batch.append(torch.zeros((1, self.FASTTEXT_DIM), device=device))
                continue
            words_emb = torch.stack([torch.tensor(self.fasttext[word]) for word in words]).to(device)  # a tensor of len(words) x FASTTEXT_DIM
            batch.append(words_emb)
          else:
            batch.append(torch.zeros((1, self.FASTTEXT_DIM), device=device))
      batch_emb = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True).to(device)
      packed_in = torch.nn.utils.rnn.pack_padded_sequence(batch_emb, torch.tensor([a.shape[0] for a in batch_emb]), batch_first=True, enforce_sorted=False)

    _, (hidden, cell) = self.lstm(packed_in.to(device))
    para_embs = torch.cat((hidden[0], cell[0], hidden[1], cell[1]), dim=1)        
      
    return para_embs.reshape(B, (2 * self.window_size + 1) * self.emb_dim)

In [None]:
class MLP(nn.Module):
  """Split classifier: given a window of paragraphs and a depth, predicts whether
  the centre paragraph starts a new section at that depth.

  Args:
    layer_dims: sizes of the hidden layers.
    window_size: number of neighbouring paragraphs on each side of the centre.
    emb_dim: dimension of each paragraph embedding.
    emb_method: one of 'doc2vec', 'simcse', 'fasttext'.
    dropout: dropout probability applied before the output layer.

  Raises:
    NotImplementedError: for an unknown emb_method.
  """

  def __init__(self, layer_dims, window_size, emb_dim, emb_method, dropout=0.1):
    super().__init__()
    self.window_size = window_size
    self.emb_dim = emb_dim # dimension of each paragraph embedding
    if emb_method == 'doc2vec':
      self.emb = Doc2VecEmbedding(window_size, emb_dim)
    elif emb_method == 'simcse':
      self.emb = SimCSEEmbedding(window_size, emb_dim, dropout=dropout)
    elif emb_method == 'fasttext':
      self.emb = FastTextEmbedding(window_size, emb_dim, dropout=dropout)
    else:
      raise NotImplementedError()

    in_feats = (2 * window_size + 1) * self.emb_dim
    self.layers = []
    # every layer takes one extra input feature: the depth is concatenated again at each layer
    for dim in layer_dims:
      self.layers.append(nn.Linear(in_features=in_feats + 1, out_features=dim))
      in_feats = dim
    self.layers.append(nn.Linear(in_features=in_feats + 1, out_features=1))
    self.layers = nn.ModuleList(self.layers)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, d):
    '''
    x is a batch (list) of windows (list) of paragraphs, which are strings
    d is the depths for the whole batch (tensor), one column per example
    returns a (B x 1) tensor of sigmoid outputs
    '''
    B = len(x)
    x = self.emb(x).float()
    x = x.reshape(B, (2 * self.window_size + 1) * self.emb_dim)
    for layer in self.layers[:-1]:
      x = F.relu(layer(torch.cat((x, d), dim=1)))
    x = self.dropout(x)
    x = torch.sigmoid(self.layers[-1](torch.cat((x, d), dim=1)))
    return x

  def recursive_outline(self, indices, node, threshold, growth, contexts, wordy):
    '''
    Attach the paragraphs `indices` below `node`, recursively splitting them
    into sections wherever the model's output exceeds the threshold.

    indices: paragraph positions (into contexts) to place under node
    node: heading node that receives the new children
    threshold: base decision threshold; for node.level >= 1 it is scaled by
               growth ** (node.level + 1)
    contexts: one window per paragraph; the centre item is the paragraph itself
    wordy: print the model output for every paragraph

    A paragraph whose output exceeds the threshold starts a new group. A group
    of several paragraphs becomes a new heading node (handled recursively), a
    group of one becomes a leaf. Every paragraph ends up as exactly one leaf.
    '''
    window_size = int((len(contexts[0]) - 1) / 2)

    def add_leaf(idx):
      leaf = Node(contexts[idx][window_size], -1)
      leaf.linkParent(node)
      node.insertChild(leaf)

    def add_section(sub_indices):
      section = Node('', node.level + 1)
      section.linkParent(node)
      node.insertChild(section)
      self.recursive_outline(sub_indices, section, threshold, growth, contexts, wordy)

    if len(indices) == 0:
      return
    if len(indices) == 1:
      add_leaf(indices[0])
      return

    # threshold and depth input are the same for every paragraph at this node
    t = threshold if node.level < 1 else threshold * (growth ** (node.level + 1))
    D = torch.tensor([node.level + 1]).to(device).unsqueeze(dim=0)
    outs = []
    for i in indices:
      out = self.forward([contexts[i]], D).squeeze()
      if wordy:
        print(out.cpu().item(), D.cpu().item(), '       ', contexts[i][window_size][:40])
      outs.append(out.cpu().item() > t)

    prev = 0          # start of the group currently being collected
    made_section = False
    for o in range(1, len(outs)):
      if not outs[o]:
        continue
      if o - prev > 1:
        add_section(indices[prev:o])
        made_section = True
      else:
        add_leaf(indices[prev])
      prev = o

    if made_section:
      add_section(indices[prev:])
    else:
      # Only the paragraphs not yet placed become leaves here. Iterating over
      # all of `indices` duplicated the single-paragraph leaves added above.
      for i in indices[prev:]:
        add_leaf(i)
    return

  def outline(self, article, threshold=0.17, growth=1.6, wordy=False):
    '''
    Predict the outline tree of an article (list of paragraph strings) and
    return its root Node. Switches the model to eval mode.
    '''
    self.eval()
    contexts = []
    for p in range(len(article)):
      context = []
      for j in range(p - self.window_size, p + self.window_size + 1):
        if j < 0 or j >= len(article):
          context.append(None)
        else:
          context.append(article[j])
      contexts.append(context)

    indices = list(range(len(article)))  # indices to recur with
    root = Node('root', 0)
    self.recursive_outline(indices, root, threshold, growth, contexts, wordy)

    def printNode(curNode):
      print(curNode.level, '       ', curNode.text[:50])
      if curNode.level == -1:
        return

      for child in curNode.children:
        printNode(child)
      return

    if wordy:
      printNode(root)
    return root

  def save(self, path: str):
    """ Save the model to a file.
    @param path (str): path to the model
    """

    params = {
        # 'args': dict(hid_dim=self.hid_dim, n_layers=self.n_layers, num_heads=self.num_heads,
        #       num_enc_layers=self.num_enc_layers, num_dec_layers=self.num_dec_layers, ff_dim=self.ff_dim, dropout=self.dropout),
        'state_dict': self.state_dict()
    }

    torch.save(params, path)


In [None]:
# RUN THIS CELL FOR DOC2VEC MLP MODEL, with either SMALL or LARGE to indicate model size
def get_doc2vec_mlp(use_large):
  """Build the Doc2Vec MLP, load its checkpoint from checkpoints/ and return it in eval mode on `device`.

  Args:
    use_large: truthy selects the large variant (window 4, four hidden layers),
      otherwise the small one (window 3, three hidden layers).
  """
  if use_large:
    checkpoint_name = 'test_large_mlp_doc2vec.bin'
    window = 4  # number of neighbors to consider in each direction
    hidden_sizes = [5096, 1024, 256, 64]  # sizes of hidden layers in MLP
  else:
    checkpoint_name = 'test_mlp_doc2vec.bin'
    window = 3
    hidden_sizes = [1024, 256, 64]
  checkpoint_path = 'checkpoints/{}'.format(checkpoint_name)
  # checkpoint_path = 'checkpoints/{}'.format('mlp_big_windowdoc2vec.bin')

  model = MLP(layer_dims=hidden_sizes, window_size=window, emb_dim=256, emb_method='doc2vec')
  # load the weights onto the CPU first, then move the model to the target device
  checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
  model.load_state_dict(checkpoint['state_dict'])
  model = model.to(device)
  model.eval()
  return model

In [None]:
# RUN THIS CELL FOR FASTTEXT MLP MODEL, with either SMALL or LARGE to indicate model size
def get_fasttext_mlp(use_large):
  """Build the FastText MLP, load its checkpoint from checkpoints/ and return it in eval mode on `device`.

  Args:
    use_large: truthy selects the large variant (window 4, 512-dim paragraph
      embeddings, four hidden layers), otherwise the small one (window 2,
      256-dim embeddings, three hidden layers).
  """
  if use_large:
    checkpoint_name = 'test_large_mlp_fasttext.bin'
    window = 4  # number of neighbors to consider in each direction
    paragraph_dim = 512  # dim each paragraph becomes
    hidden_sizes = [5096, 1024, 256, 64]  # sizes of hidden layers in MLP
  else:
    checkpoint_name = 'mlp_fasttext.bin'
    window = 2
    paragraph_dim = 256
    hidden_sizes = [1024, 256, 64]
  checkpoint_path = 'checkpoints/{}'.format(checkpoint_name)

  model = MLP(layer_dims=hidden_sizes, window_size=window, emb_dim=paragraph_dim, emb_method='fasttext')
  # load the weights onto the CPU first, then move the model to the target device
  checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
  model.load_state_dict(checkpoint['state_dict'])
  model = model.to(device)
  model.eval()
  return model

In [None]:
# RUN THIS CELL FOR SIMCSE MLP MODEL
def get_simcse_mlp():
  """Build the SimCSE MLP (window 3, 256-dim paragraph embeddings), load
  checkpoints/test_mlp_simcse.bin and return the model in eval mode on `device`."""
  checkpoint_path = 'checkpoints/{}'.format('test_mlp_simcse.bin')
  # checkpoint_path = 'checkpoints/{}'.format('evan_mlp_simcse.bin')
  # checkpoint_path = 'checkpoints/{}'.format('mlp_simcse.bin')

  model = MLP(
      layer_dims=[1024, 256, 64],  # sizes of hidden layers in MLP
      window_size=3,               # number of neighbors to consider in each direction
      emb_dim=256,                 # dim each paragraph becomes
      emb_method='simcse',         # one of 'doc2vec', 'simcse', 'fasttext'
  )
  # load the weights onto the CPU first, then move the model to the target device
  checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
  model.load_state_dict(checkpoint['state_dict'])
  model = model.to(device)
  model.eval()
  return model

##### Eval

In [None]:
# LOAD DATASET
# NOTE(review): with the next line commented out, `train` is not defined by this cell.
# train = loadData('train')
val = loadData('validation')
test = loadData('test')
# Evaluate on validation + test together: the validation articles are appended to `test` in place.
test.extend(val)

In [None]:
# MAKE MODELS
# Pick exactly one MLP and one greedy decoder by (un)commenting; both names are read by the cells below.
# mlp = get_doc2vec_mlp(use_large=True)
mlp = get_simcse_mlp()
# mlp = get_fasttext_mlp(use_large=True)

# greedy = get_doc2vec_greedy()
# With use_proj=False the decoder encodes with simcse_enc and ignores `mlp`.
greedy = get_simcse_greedy(use_proj=False, mlp=mlp)
# greedy = get_fasttext_greedy(use_proj=True, mlp=mlp)

In [None]:
# RUN MLP OVER WHOLE SAMPLE SET
from tqdm import tqdm

sample = test
# Predict an outline for every article and relabel its leaves with 0-based paragraph indices.
roots = []
for i in tqdm(range(len(sample))):
  out = mlp.outline(sample[i]['paragraphs'], threshold=0.15, growth=2.0, wordy=False)
  roots.append(indexified_tree(out))

# Gold trees get the same index labels so both sides can be compared leaf by leaf.
golds = [indexified_tree(a['tree'].root) for a in sample]
lens = [len(a['paragraphs']) for a in sample]
# Mean LCA loss over the whole sample.
print(batch_lca_loss(golds, roots, lens))

In [None]:
# RUN GREEDY OVER WHOLE SAMPLE SET
sample = test
X, y = [d['paragraphs'] for d in sample], [d['tree'] for d in sample]
# Decode every article; predicted leaves are labelled with 0-based paragraph indices.
y_hat = greedy.batch_encode_decode_with_text(X)
y = [indexified_tree(y_i.root) for y_i in y]
lens = [len(x) for x in X]
# NOTE: only the loss of the first article is printed; the batch average is commented out.
# print(batch_lca_loss(y, y_hat, lens))
print(lca_loss(y[0], y_hat[0], lens[0]))

In [None]:
# INFERENCE MODELS ON INDIVIDUAL DATA POINTS
# 23672 from train is Dan Dugan
sample_i = 23672
# The training split is needed here but is not loaded anywhere else on a fresh
# kernel, so load it on first use (slow: this reads the full training split).
if 'train' not in globals():
    train = loadData('train')
sample = train

y_hat_mlp = mlp.outline(sample[sample_i]['paragraphs'], threshold=0.25, growth=2.0, wordy=False)
y_hat_greedy = greedy.encode_decode_with_text(sample[sample_i]['paragraphs'])

# Print the three outlines with each leaf replaced by a snippet of its paragraph.
print('---------------MLP------------------')
print_snippet_tree(textified_tree(y_hat_mlp, sample[sample_i]['paragraphs']))
print()
print()
print('---------------GREEDY------------------')
print_snippet_tree(textified_tree(y_hat_greedy, sample[sample_i]['paragraphs']))
print()
print()
print('---------------GROUND TRUTH------------------')
print_snippet_tree(textified_tree(sample[sample_i]['tree'].root, sample[sample_i]['paragraphs']))


In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

%matplotlib inline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import seaborn as sns

def Nmaxelements(list1, N):
    """Return the positions of the N largest values of list1, largest first.

    Positions refer to the original list. The previous implementation removed
    each maximum before searching again, which shifted later positions and
    mutated the caller's list. Ties are returned in order of appearance. If N
    exceeds len(list1), all positions are returned.
    """
    ranked = sorted(range(len(list1)), key=lambda position: list1[position], reverse=True)
    return ranked[:N]

# MAKE tSNE GRAPHS
# We will load several (N) largest articles, and see if we can cluster paragraphs from the articles by article id
# np.random.seed(0)
N = 12
# indices = np.random.permutation(len(test))[:N]
indices = Nmaxelements([len(d['paragraphs']) for d in test], N)
articles = [test[i] for i in indices]
# One row per paragraph: its embedding from the currently selected `mlp`, labelled with its article number.
dataX = []
datay = []
for a in range(len(articles)):
  for p in articles[a]['paragraphs']:
    dataX.append(mlp.emb.embed_one_para(p).detach().cpu().numpy())
    datay.append(a)
dataX = np.stack(dataX, axis=0)
datay = np.array(datay)
# print(dataX.shape, datay.shape)

# Reduce to 50 principal components first, then run t-SNE on the reduced data.
pca = PCA(n_components=50)
pca_result = pca.fit_transform(dataX)

time_start = time.time()
# NOTE(review): init='random' with no random_state set here, so the layout may differ between runs -- confirm this is acceptable.
tsne = TSNE(n_components=2, verbose=0, perplexity=40, n_iter=1000, learning_rate=200.0, init='random')
tsne_results = tsne.fit_transform(pca_result)
print('t-SNE done! Time elapsed: {} seconds'.format(time.time()-time_start))

plt.figure(figsize=(6,6))
ax1 = plt.subplot(1, 1, 1)
# NOTE(review): the title and the output file name say FastText, but the embeddings come from whichever
# `mlp` is active (the model cell currently selects SimCSE) -- keep them in sync when switching models.
plt.title('tSNE Plot by Paragraph Source (Projected FastText)')
sns.scatterplot(x=tsne_results[:, 0], y=tsne_results[:, 1], hue=datay,
                palette=sns.color_palette("hls", N),
                legend="auto", alpha=0.3, ax=ax1)
plt.savefig('fasttext.pdf')

In [None]:
# # investigate distribution of Doc2Vec distributions
# def plot_similarity_distribution(similarity=sim, encode=simcse_enc):
#     sims = []
#     n = len(train_X[0])
#     for x in range(n):
#         for y in range(x+1,n):
#             # print(doc2vec_enc(train_X[0][x]), doc2vec_enc(train_X[0][y]))
#             sims.append(similarity(encode(train_X[0][x]), encode(train_X[0][y])))
#     plt.hist(sims)
#     plt.xlabel('Similarity')
#     plt.ylabel('Count')
#     plt.show()

In [None]:
# plot_similarity_distribution()

##### Training set statistics

Include stats of  
- number of paragraphs per article  
- average length of paragraphs per article 
- maximum depth of articles  

In [None]:
from matplotlib import pyplot as plt 
import numpy as np
import scipy.stats as stats
def getStat(data):
    """Print summary statistics and show histograms for a loaded dataset.

    Reports, in this order: paragraphs per article, sentences per paragraph
    (counted as the number of '.' characters) and the depth of each article's
    tree.

    Args:
        data: list of dicts with 'paragraphs' (list of str) and 'tree'
            (object with a `.depth` attribute), as returned by loadData().
    """
    # number of paragraphs per article
    paragraphs_per_article = np.array([len(article['paragraphs']) for article in data])
    plt.hist(paragraphs_per_article, 40)
    print("========= number of paragraphs per article========")
    print(pd.Series(paragraphs_per_article).describe())
    plt.show()

    # number of sentences per paragraph
    sentences_per_paragraph = [
        paragraph.count('.')
        for article in data
        for paragraph in article['paragraphs']
    ]
    print("========= number of sentences per paragraph========")
    print(pd.Series(sentences_per_paragraph).describe())
    plt.hist(sentences_per_paragraph, 40)
    plt.show()

    # depth of articles
    article_depths = [article['tree'].depth for article in data]
    print("========= maximum depth of articles========")
    print(pd.Series(article_depths).describe())
    plt.hist(article_depths)
    plt.show()

    return
    

In [None]:
# getStat(data)