## IMPORTS AND ALL THAT

In [1]:
%%capture
!pip install datasets
!pip install numpy
!pip install pandas
!pip3 install http://download.pytorch.org/whl/cu92/torch-0.4.1-cp36-cp36m-linux_x86_64.whl
!pip3 install torchvision

!pip install simcse

!pip install gensim==4.1.2
!pip install cython
!pip install nltk

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
import random
random.seed(10)
torch.manual_seed(0)
np.random.seed(0)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import re
import time
import math
from IPython.utils import io

from tqdm import tqdm

import nltk
nltk.download('punkt')
from gensim.models.doc2vec import Doc2Vec, TaggedDocument
from simcse import SimCSE
from nltk.tokenize import word_tokenize

from gensim.models import FastText

from google.colab import drive
drive.mount('/content/drive')

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
Mounted at /content/drive


Make sure to add 484-finalProject as shortcut in drive.

In [3]:
cd /content/drive/My\ Drive/484-finalProject

/content/drive/.shortcut-targets-by-id/154EdVBqpeIxqbeun-wGvxSgsYCVV1wdV/484-finalProject


##PREPARE DATASET

In [4]:
import datasets
from datasets import load_dataset, list_datasets
import pandas as pd 
import re 
import numpy as np 

In [5]:
class Node(object):
    '''
    each node contains 
    - parent 
    - children 
    - text
    '''
    def __init__(self,txt: str, level:int):
        self.text = txt 
        self.level = level
        self.parent = None 
        self.children = []
    def insertChild(self,child):
        self.children.append(child)
    def linkParent(self, parent):
        if(self.parent != None):
            print("ERROR: node ", self.text, "already has a parent")
        else:
            self.parent = parent 
            
        
class Tree(object):
    def __init__(self,document):
        self.root = Node(document['title'],level=0)
        self.depth = np.amax([v['type'] for v in document['document']], initial=0)
        curNode = self.root 
        # para of format {"text", "type"}
        for para in document['document']:
            newNode = Node(para['text'],para['type'])
            
            # growing in depth
            if(newNode.level == -1 or newNode.level > curNode.level):
                curNode.insertChild(newNode)
                newNode.linkParent(curNode)
                if(newNode.level > 0):
                    curNode = newNode 
                
            # new heading belong to the same or lower level of subheading 
            else: 
                # trace back to the heading level that new heading is immediately under 
                while(curNode.level>=newNode.level):
                    curNode = curNode.parent
                curNode.insertChild(newNode)
                newNode.linkParent(curNode)
                curNode = newNode 
        return 
    
    def printTree(self):
        print("======== PRINTING TREE =========")
        print("TITLE: ", self.root.text)
        print("MAX DEPTH: ", self.depth)
        print("===============================")
        def printNode(curNode):
            print(curNode.text)
            if(curNode.level == -1):
                return 
            
            for child in curNode.children:
                printNode(child)
            return 
        printNode(self.root)       

In [6]:
## helper functions 

# get type of text 
def checkHeading(txt):
    if(txt == ''):
        return -2
    if(re.search(r'^\s=.+\s=\s\n',txt)):
        return int(len(re.findall(r'\s=',txt))/2 - 1)
    return -1 

# load documents to feed to tree 
def createDocuments(data):
    documents_with = []
    documents_without = []
    document_with = []
    document_without = []
    curTitle = ''
    for i in data:
        c = checkHeading(i)
        if(c==-2):
            continue
        if(c>-1):
            # strip heading 
            i = re.findall(r'=\s([^=]+)\s=', i)[0]
        if(c==0):
            
            # clear out empty headings 
            while(len(document_with)>1 and document_with[-1]['type']!=-1):
                document_with.pop(-1)
            documents_with.append({'title': curTitle, 'document':document_with})
            documents_without.append(document_without)
            curTitle = i
            document_with = []
            document_without = []
            
        else:
            # clear out empty headings GOOFY HELP HOW TO CLEAN THIS UP 
            if(len(document_with)>1 and document_with[-1]['type']!=-1 and c <= document_with[-1]['type'] and c!=-1):
                document_with.pop(-1)
            document_with.append({'text':i,'type':c})
            if(c==-1):
                document_without.append(i)
            
    documents_with.pop(0)
    documents_without.pop(0)
    return documents_with, documents_without

loadData() creates a list of data points containing the title of article, raw text (paragraphs), and the tree representation of heading structures.

In [7]:
## load wiki dataset 
def loadData(train = False, min_size=-1):
    """
    prepare dataset for training, which is a list of dictionaries containing:  
    - document title (string)
    - paragraphs (list of string)
    - tree representation of headings
    """
    
    data_raw = load_dataset("wikitext",'wikitext-103-v1',split='train' if train else 'test')
    data_raw = data_raw['text']
    documents_with, documents_without = createDocuments(data_raw)
    
    data = []
    i = 0
    for document in documents_with:
        tree = Tree(document)
        if len(documents_without[i]) < min_size:
          continue
#         tree.printTree()
        data.append({
            "title":document['title'],
            "paragraphs":documents_without[i],
            "tree": tree
        })
        i+=1
    return data    

In [8]:
data = loadData(train=True)

Downloading builder script:   0%|          | 0.00/2.03k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.25k [00:00<?, ?B/s]

Downloading and preparing dataset wikitext/wikitext-103-v1 (download: 181.42 MiB, generated: 522.23 MiB, post-processed: Unknown size, total: 703.64 MiB) to /root/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126...


Downloading data:   0%|          | 0.00/190M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/1801350 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

Dataset wikitext downloaded and prepared to /root/.cache/huggingface/datasets/wikitext/wikitext-103-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126. Subsequent calls will reuse this data.


## DA LCA

##### Preparing LCA loss evaluation function + tools for trees (please run!)

In [9]:
# tree-related helper functions

# iterate over a tree rooted at node in preorder traversal
def preorder(node):
    if len(node.children) == 0:
        yield node
    for ch in node.children:
        yield from preorder(ch)

# only prints leaves, i.e. text representations of paragraph
# note: depends on accurate text, level population
# text should be indices
def print_tree(curNode):
    if curNode.level == -1:
        print(curNode.text, end='')
        return
    print('[', end='')
    for idx, child in enumerate(curNode.children):
        print_tree(child)
        if idx < len(curNode.children) - 1:
            print(', ', end='')
    print(']', end='')
    if curNode.level == 0:
        print() # final print after entire tree is printed

# text should be snippets
def print_snippet_tree(curNode, indent='  '):
    if curNode.level == -1:
        print(indent + curNode.text, end='')
        return
    print(indent+'[heading]')
    for idx, child in enumerate(curNode.children):
        print_snippet_tree(child, indent+'  ')
        if idx < len(curNode.children) - 1:
            print()
    if curNode.level == 0:
        print() # final print after entire tree is printed

def clone_tree(root):
    root_copy = Node(root.text, root.level)
    for ch in root.children:
        root_copy.insertChild(clone_tree(ch))
        root_copy.children[-1].linkParent(root_copy)
    return root_copy

# return indexified tree with text as paragraphs to text as indices of paragraphs for more concise printing
def indexified_tree(root):
    root_copy = clone_tree(root)
    for idx, node in enumerate(preorder(root_copy)):
        node.text = idx
    return root_copy

# return indexified tree with text as paragraphs to text as indices of paragraphs for more concise printing
def textified_tree(root, paras):
    root_copy = clone_tree(root)
    for idx, node in enumerate(preorder(root_copy)):
        node.text = paras[idx][:40] + '...'
    return root_copy

# print_tree(roots[0])
# print_tree(train_y[0].root)

In [10]:
# lca-related helper functions and lca loss
# note: assumed indexified trees
def trace_helper(node, i, trace):
    if node.text == i:
        return True
    for idx, ch in enumerate(node.children):
        if trace_helper(ch, i, trace):
            trace.append(idx)
            return True
    return False

def get_trace(root, i):
    trace = []
    trace_helper(root, i, trace)
    trace.reverse()
    return trace

def compute_lca_dist(root, i, j):
    trace_i = get_trace(root, i)
    trace_j = get_trace(root, j)
    # print(trace_i)
    # print(trace_j)
    for idx in range(min(len(trace_i), len(trace_j))):
        if trace_i[idx] != trace_j[idx]:
            return len(trace_i) + len(trace_j) - 2 * idx
    return len(trace_i) + len(trace_j)

def lca_loss(root1, root2, num_paras):
    loss = 0
    for i in range(1, num_paras+1): # 1-indexed from indexify
        for j in range(i+1, num_paras+1):
            dist1 = compute_lca_dist(root1, i, j)
            dist2 = compute_lca_dist(root2, i, j)
            # print(i, j, dist1, dist2)
            loss += (dist1 - dist2) * (dist1 - dist2)
    if num_paras == 1:
        return loss
    return loss / num_paras / (num_paras - 1) * 2

def batch_lca_loss(roots1, roots2, num_paras):
    tt = 0
    for root1, root2, num in zip(roots1, roots2, num_paras):
        tt += lca_loss(root1, root2, num)
    return tt / len(roots1)

## DA MODEL / TRAINING

##### Model 3: Recursive split MLP -- note this code may have a decoding bug

In [11]:
def add_tree_breaks(node, breaks, cur_breaks, depths, depth):
    if node.level == -1: # this is paragraph node, not a heading
        depths.append(depth)
        breaks.append(cur_breaks)
        return
    for child in node.children:
        if child == node.children[0]: # first entry
            cur_breaks += 1
            depth = node.level
        if child == node.children[-1]:
            cur_breaks = 0
            depth = 0
        add_tree_breaks(child, breaks, cur_breaks, depths, depth)
        if child == node.children[0]:
            cur_breaks = 0
            depth = 0
        if child == node.children[-1]:
            cur_breaks = 0
            depth = 0
    cur_breaks = 0

def get_breaks(paras, trees):
    max_paras = max([len(x) for x in paras])
    breaks = []
    depths = []
    for tree in trees:
        breaks_i = []
        depths_i = []
        add_tree_breaks(tree.root, breaks_i, 0, depths_i, 0)
        breaks_i = torch.tensor(breaks_i)
        depths_i[0] += 1
        depths_i = torch.tensor(depths_i)
        breaks.append(breaks_i)
        depths.append(depths_i)
    return breaks, depths

def convert_dataset(dataset, window_size):
  breaks, depths = get_breaks([d['paragraphs'] for d in dataset], [d['tree'] for d in dataset])
  X = []
  D = []
  y = []
  for i in tqdm(range(len(dataset))):  # for article
      # print('{}'.format(100 * i / len(dataset)))
      article = dataset[i]['paragraphs']
      for p in range(1, len(article)):  # for para in article (excluding first, since its clearly always a break)
          depth = depths[i][p]
          isBreak = breaks[i][p]
          for b in range(1, depth + 1):  # for depths <= depth of para
              context = []
              for j in range(p - window_size, p + window_size + 1):  # for para in context
                  if j < 0 or j >= len(article):
                      context.append(None)
                  else:
                      context.append(article[j])
              X.append(context)
              D.append(np.array([b]))
              val = 1 if b > depth - isBreak else 0  # 1 if target para is first after break at depth b
              y.append(np.array([val]))
          if depth == 0:
              for b in range(1, 8):  # MAX_DEPTH is 8
                  context = []
                  for j in range(p - window_size, p + window_size + 1):  # for para in context
                      if j < 0 or j >= len(article):
                        context.append(None)
                      else:
                        context.append(article[j])
                  X.append(context)
                  D.append(np.array([0]))
                  y.append(np.array([0]))
  return X, torch.tensor(np.stack(D, axis=0)), torch.tensor(np.stack(y, axis=0))

In [12]:
def get_split_data(data, window_size, train_ratio=0.8, val_ratio=0.0001, test_ratio=0.0001):
    random.shuffle(data)
    n_data = len(data)
    train_idx = int(train_ratio * n_data)
    val_idx = int((train_ratio + val_ratio) * n_data)
    test_idx = int((train_ratio + val_ratio + test_ratio) * n_data)
    print('getting training data')
    train_data = convert_dataset(data[:train_idx], window_size=window_size)
    print('getting val data')
    val_data = convert_dataset(data[train_idx:val_idx], window_size=window_size)
    print('getting test data')
    test_data = convert_dataset(data[val_idx:test_idx], window_size=window_size)
    return train_data, val_data, test_data

In [13]:
# iterate over batches of data and labels
def batch_iter(data, batch_size, shuffle=False):
    X, D, y = data
    batch_num = math.ceil(len(X) / batch_size)
    index_array = list(range(len(X)))

    if shuffle:
        np.random.shuffle(index_array)

    for i in range(batch_num):
        indices = index_array[i * batch_size: (i + 1) * batch_size]
        batch_data_X = [X[idx] for idx in indices]
        batch_data_D = D[indices]
        batch_data_y = y[indices]

        yield batch_data_X, batch_data_D, batch_data_y

In [14]:
# MODULES FOR EMBEDDING 3 DIFFERENT WAYS: DOC2VEC, PROJECTED SIMCSE, PROJECTED FASTTEXT
# each one takes a batch of lists of paragraphs, outputs a batch of concatenated paragraph embeddings
# forward pass input: batch (len B) of list of paragraphs (len 2 * window_size + 1), each para variable length
# forward pass output: tensor of size B x ((2 * window_size + 1) * emb_dim)
class Doc2VecEmbedding(nn.Module):
  def __init__(self, window_size, emb_dim): 
    super().__init__()
    self.emb_dim = emb_dim
    self.doc2vec = Doc2Vec.load("models/{}_d2v.model".format(emb_dim))

  def forward(self, x):
    with torch.no_grad():
      batch = []
      for b in x:
        paras = []
        for p in b:
          paras.append(self.doc2vec.infer_vector(word_tokenize(p.lower())) if p is not None else np.zeros(shape=(self.emb_dim)))
        batch.append(np.concatenate(paras, axis=0))
      return torch.tensor(np.stack(batch, axis=0)).to(device)

class SimCSEEmbedding(nn.Module):
  def __init__(self, window_size, emb_dim, dropout): 
    super().__init__()
    self.window_size = window_size
    self.emb_dim = emb_dim
    self.simcse = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased")
    self.SIMCSE_DIM = 768 # dim of simcse sentence embeddings
    self.lstm = nn.LSTM(input_size=self.SIMCSE_DIM,
                        hidden_size=int(emb_dim / 4),
                        num_layers=1,
                        bidirectional=True,
                        batch_first=True, 
                        dropout=0.0).to(device)

  def forward(self, x):
    B = len(x)  # batch size
    batch = []
    with torch.no_grad():
      for b in x:
        for p in b:
          if p is not None:
            p = re.sub('\n', '', p)
            sents = re.split('[.]|[!]|[?]', p.strip())
            with io.capture_output() as captured:
                sents_emb = self.simcse.encode(sents, device=device, batch_size=len(sents), max_length=128)  # a tensor of len(sents) x SIMCSE_DIM
            batch.append(sents_emb)
          else:
            batch.append(torch.zeros((1, self.SIMCSE_DIM), device=device))
      batch_emb = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True).to(device)
      packed_in = torch.nn.utils.rnn.pack_padded_sequence(batch_emb, torch.tensor([a.shape[0] for a in batch_emb]), batch_first=True, enforce_sorted=False)

    _, (hidden, cell) = self.lstm(packed_in.to(device))
    para_embs = torch.cat((hidden[0], cell[0], hidden[1], cell[1]), dim=1)        
      
    return para_embs.reshape(B, (2 * self.window_size + 1) * self.emb_dim)


#### HERE
class FastTextEmbedding(nn.Module):
  def __init__(self, window_size, emb_dim, dropout): 
    super().__init__()
    self.window_size = window_size
    self.emb_dim = emb_dim
    self.fasttext = FastText.load_fasttext_format('models/fast-text-300.bin')
    self.FASTTEXT_DIM = 300 # dim of fasttext word embeddings
    self.lstm = nn.LSTM(input_size=self.FASTTEXT_DIM,
                        hidden_size=int(emb_dim / 4),
                        num_layers=1,
                        bidirectional=True,
                        batch_first=True, 
                        dropout=0.0).to(device)

  def forward(self, x):
    B = len(x)  # batch size
    batch = []
    with torch.no_grad():
      for b in x:
        for p in b:
          if p is not None:
            p = re.sub('\n', '', p)
            words = re.sub("[^\s\w]", "", p.strip()).split(' ')
            words = list(filter(None, words))
            words_emb = torch.stack([torch.tensor(self.fasttext[word]) for word in words]).to(device)  # a tensor of len(words) x FASTTEXT_DIM
            batch.append(words_emb)
          else:
            batch.append(torch.zeros((1, self.FASTTEXT_DIM), device=device))
      batch_emb = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True).to(device)
      packed_in = torch.nn.utils.rnn.pack_padded_sequence(batch_emb, torch.tensor([a.shape[0] for a in batch_emb]), batch_first=True, enforce_sorted=False)

    _, (hidden, cell) = self.lstm(packed_in.to(device))
    para_embs = torch.cat((hidden[0], cell[0], hidden[1], cell[1]), dim=1)        
      
    return para_embs.reshape(B, (2 * self.window_size + 1) * self.emb_dim)

In [15]:
class MLP(nn.Module):
  def __init__(self, layer_dims, window_size, emb_dim, emb_method, dropout=0.1):
        super().__init__()
        self.window_size = window_size
        self.emb_dim = emb_dim # dimension of each paragraph embedding
        if emb_method == 'doc2vec':
          self.emb = Doc2VecEmbedding(window_size, emb_dim)
        elif emb_method == 'simcse':
          self.emb = SimCSEEmbedding(window_size, emb_dim, dropout=dropout)
        elif emb_method == 'fasttext':
          self.emb = FastTextEmbedding(window_size, emb_dim, dropout=dropout)
        else:
          raise NotImplementedError()

        in_feats = (2 * window_size + 1) * self.emb_dim
        self.layers = []
        for dim in layer_dims:
          self.layers.append(nn.Linear(in_features=in_feats + 1, out_features=dim))
          in_feats = dim
        self.layers.append(nn.Linear(in_features=in_feats + 1, out_features=1))
        self.layers = nn.ModuleList(self.layers)
        self.dropout = nn.Dropout(dropout)

  def forward(self, x, d):
      '''
      x is a batch (list) of windows (list) of paragraphs, which are strings
      d is the depths for the whole batch (tensor)
      '''
      B = len(x)
      x = self.emb(x).float()  # x is a (B x (2W + 1) x E) tensor
      x = x.reshape(B, (2 * self.window_size + 1) * self.emb_dim)
      for layer in self.layers[:-1]:
        x = F.relu(layer(torch.cat((x, d), dim=1)))
      x = self.dropout(x)
      x = torch.sigmoid(self.layers[-1](torch.cat((x, d), dim=1)))
      return x

  def recursive_outline(self, subarticle, node):
    if len(subarticle) == 1:
        new = Node(subarticle[0], -1)
        new.linkParent(node)
        node.insertChild(new)
        return
    outs = [False]
    for p in range(1, len(subarticle)):
        context = []
        for j in range(p - self.window_size, p + self.window_size + 1):
            if j < 0 or j >= len(subarticle):
                context.append(None)
            else:
                context.append(subarticle[j])
        X = [context]
        D = torch.tensor([node.level + 1]).to(device).unsqueeze(dim=0)
        out = self.forward(X, D).squeeze()
        outs.append(out.cpu().item() > 0.77)
    prev = 0
    flag = True
    for o in range(len(outs)):
        if outs[o]:
            new = Node('', node.level + 1)
            new.linkParent(node)
            node.insertChild(new)
            self.recursive_outline(subarticle[prev:o], new)
            prev = o
            flag = False
    if flag:
      for p in range(len(subarticle)):
        new = Node(subarticle[p], -1)
        new.linkParent(node)
        node.insertChild(new)

        # else:
        #     new = Node(subarticle[o], -1)
        #     new.linkParent(node)
        #     node.insertChild(new)
    else:
       new = Node('', node.level + 1)
       new.linkParent(node)
       node.insertChild(new)
       self.recursive_outline(subarticle[prev:], new)
    return

  def outline(self, article, wordy=False):
      self.eval()
      root = Node('root', 1)
      self.recursive_outline(article, root)
      new = Node(article[0], -1)
      curr = root
      while len(curr.children) > 0 and curr.level != -1:
        curr = curr.children[0]
      new.linkParent(curr.parent)
      curr.parent.insertChild(new)
      if len(article) > 1:
          new = Node(article[len(article)-1], -1)
          curr = root
          while len(curr.children) > 0 and curr.level != -1:
            curr = curr.children[-1]
          new.linkParent(curr.parent)
          curr.parent.insertChild(new)

      def printNode(curNode):
          print(curNode.level, '       ', curNode.text)
          if curNode.level == -1:
              return

          for child in curNode.children:
              printNode(child)
          return

      if wordy:
        printNode(root)
      return root

  def save(self, path: str):
      """ Save the model to a file.
      @param path (str): path to the model
      """

      params = {
          # 'args': dict(hid_dim=self.hid_dim, n_layers=self.n_layers, num_heads=self.num_heads,
          #       num_enc_layers=self.num_enc_layers, num_dec_layers=self.num_dec_layers, ff_dim=self.ff_dim, dropout=self.dropout),
          'state_dict': self.state_dict()
      }

      torch.save(params, path)

In [20]:
def train(model, train_data, val_data, lr=0.002, batch_size=32, grad_clip=5.0, lr_decay=0.5,
          max_epoch=50, log_every=5, valid_niter=25, max_patience=4, max_num_trial=5, model_path='mlp.bin'):
    model.train()
    model.float()

    # # initialize model parameters
    # for p in model.parameters():
    #     p.data.uniform_(-0.1, 0.1)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('use device: %s' % device)

    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    print('{} parameters!'.format(sum([np.prod(p.size()) for p in model_parameters])))

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr)

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = cum_tgt_words = report_tgt_words = 0
    cum_examples = report_examples = epoch = valid_num = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()
    print('begin Maximum Likelihood training')

    train_losses = []
    val_losses = []
    loss_fn = nn.BCELoss(reduction='sum')
    
    while True:
        epoch += 1
        batch_num = math.ceil(len(train_data[0]) / batch_size)
        current_iter = 0
        for batch in batch_iter(train_data, batch_size=batch_size, shuffle=True):
            X, D, y = batch
            D = D.to(device)
            y = y.to(dtype=torch.float32, device=device)

            model.train()
            current_iter += 1
            train_iter += 1

            optimizer.zero_grad()
            batch_size = len(X)
            out = model(X, D)
            train_loss = loss_fn(out, y)
            train_loss.backward()

            # clip gradient
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

            optimizer.step()

            report_loss += train_loss.item()
            cum_loss += train_loss.item()
            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % log_every == 0:
                print('epoch %d (%d / %d), iter %d, avg train loss %f, '
                      'cum examples %d, time elapsed %.2f sec' %
                      (epoch, current_iter, batch_num, train_iter,
                       report_loss / report_examples,
                       cum_examples,
                       time.time() - begin_time))

                train_time = time.time()
                report_loss = report_examples = 0.

            # perform validation
            if train_iter % valid_niter == 0:
                model.eval()
                with torch.no_grad():
                    print('epoch %d, iter %d, cum loss %f, cum examples %d' % (epoch, train_iter,
                            cum_loss / cum_examples,
                            cum_examples))
                    train_losses.append(cum_loss / cum_examples)
                    cum_loss = cum_examples = 0.

                    print('begin validation ...')

                    val_cum_loss = 0.
                    val_cum_examples = 0

                    count = 0
                    NUM_BATCHES = 8  # number of batches to validate over each time
                    for batch in batch_iter(val_data, batch_size, shuffle=True):
                        if count >= NUM_BATCHES:
                            break
                        X, D, y = batch
                        D = D.to(device)
                        y = y.to(dtype=torch.float32, device=device)

                        current_iter += 1
                        train_iter += 1

                        batch_size = len(X)
                        out = model(X, D)
                        val_loss = loss_fn(out, y)
                        val_cum_loss += val_loss.item()
                        val_cum_examples += batch_size
                        count += 1

                    val_losses.append(val_cum_loss / val_cum_examples)
                    valid_metric = -val_cum_loss / val_cum_examples # metric for evaluating whether model is improving on val data

                    print('validation: iter %d, val loss %f' % (train_iter, val_cum_loss / val_cum_examples))

                    is_better = len(hist_valid_scores) == 0 or valid_metric > max(hist_valid_scores)
                    hist_valid_scores.append(valid_metric)

                    if is_better:
                        patience = 0
                        print('epoch %d, iter %d: save currently the best model to [%s]' %
                                (epoch, train_iter, model_path))
                        model.save(model_path)
                        torch.save(optimizer.state_dict(), model_path + '.optim')
                        np.save('{}_train.npy'.format(model_path.split('.')[-2]), np.array(train_losses))
                        np.save('{}_val.npy'.format(model_path.split('.')[-2]), np.array(val_losses))
                    elif patience < max_patience:
                        patience += 1
                        print('hit patience %d' % patience)

                        if patience == max_patience:
                            num_trial += 1
                            print('hit #%d trial' % num_trial)
                            if num_trial == max_num_trial:
                                print('early stop!')
                                exit(0)

                            # decay lr, and restore from previously best checkpoint
                            lr = optimizer.param_groups[0]['lr'] * lr_decay
                            print('load previously best model and decay learning rate to %f' % lr)

                            # load model
                            params = torch.load(model_path, map_location=lambda storage, loc: storage)
                            model.load_state_dict(params['state_dict'])
                            model = model.to(device)
                            train_losses = list(np.load('{}_train.npy'.format(model_path.split('.')[-2]), allow_pickle=True))
                            val_losses = list(np.load('{}_val.npy'.format(model_path.split('.')[-2]), allow_pickle=True))

                            print('restore parameters of the optimizers')
                            optimizer.load_state_dict(torch.load(model_path + '.optim'))

                            # set new lr
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr

                            # reset patience
                            patience = 0

        if epoch == max_epoch:
            print('reached maximum number of epochs!')
            break

## DA USER INTERACTION AREA

In [17]:
# MODEL SPECS
WINDOW_SIZE = 3  # number of neighbors to consider in each direction
EMB_DIM = 256  # dim each paragraph becomes, via magic :)
EMB_METHOD = 'simcse'  # one of 'doc2vec', 'simcse', 'fasttext'
MLP_ARCHITECTURE = [1024, 256, 64]  # sizes of hidden layers in MLP

model = MLP(layer_dims=MLP_ARCHITECTURE, window_size=WINDOW_SIZE, emb_dim=EMB_DIM, emb_method=EMB_METHOD)

Downloading:   0%|          | 0.00/252 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/689 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/418M [00:00<?, ?B/s]

In [None]:
# TRAINING STUFF
train_data, val_data, test_data = get_split_data(data, window_size=WINDOW_SIZE, train_ratio=0.65, val_ratio=0.3, test_ratio=0.01)
train(model, train_data, val_data, lr=0.002, batch_size=128, grad_clip=5.0, lr_decay=0.5,
      max_epoch=100, log_every=5, valid_niter=25, max_patience=5, max_num_trial=5, model_path='checkpoints/mlp_{}.bin'.format(EMB_METHOD))

getting training data


100%|██████████| 19138/19138 [00:24<00:00, 797.22it/s] 


getting val data


100%|██████████| 8833/8833 [00:10<00:00, 856.39it/s] 


getting test data


100%|██████████| 295/295 [00:00<00:00, 878.90it/s]


use device: cuda:0
2543298 parameters!
begin Maximum Likelihood training
epoch 1 (5 / 24400), iter 5, avg train loss 0.068518, cum examples 640, time elapsed 63.39 sec


In [19]:
# EVAL STUFF
params = torch.load('checkpoints/mlp_{}.bin'.format(EMB_METHOD), map_location=lambda storage, loc: storage)
model.load_state_dict(params['state_dict'])
model = model.to(device)

# article = data[20]['paragraphs']
sample = data[200:300]
roots = [indexified_tree(model.outline(a['paragraphs'])) for a in sample]
golds = [indexified_tree(a['tree'].root) for a in sample]
lens = [len(a['paragraphs']) for a in sample]
batch_lca_loss(roots, golds, lens)

1.64601969725356