In [1]:
import numpy as np
import torch
import torchtext
from tqdm import tqdm
from torch.utils.data import DataLoader
import torch.nn as nn
from collections import defaultdict
import math
import scripts.state as state
import config as cfg
import data_loader as dl
from model import *
import copy
from scripts.evaluate import compute_metrics
import eval_utils as util


In [2]:
pose_set = dl.convert_to_labels(dl.get_file_contents(cfg.pose_set_file))
tagset = dl.convert_to_labels(dl.get_file_contents(cfg.tagset_file))
# Inverse of `tagset` (index -> action label), used to decode model predictions.
reverse_tagset = {index: label for label, index in tagset.items()}

GLOVE_DIM = 300
CHECKPOINT_PATH = "checkpoints/c_840b300_lr_0_0001/save/model_val.torch"
glove = torchtext.vocab.GloVe(name="840B", dim=GLOVE_DIM)

model = TDParser2(
    pose_set_dim=len(pose_set),
    pos_embedding_dim=cfg.pos_embedding_dim,
    glove_dim=GLOVE_DIM,
    tagset_dim=len(tagset),
    combine="concatenate",
)
model.to(cfg.DEVICE)
# This notebook only runs inference: eval() switches modules such as dropout /
# batch-norm to inference behaviour so predictions are deterministic.
model.eval()

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=cfg.DEVICE))

<All keys matched successfully>

In [3]:
def run_parse_tree(model, parse_data, pose_set, glove, device, reverse_tagset):
    """Greedily parse one sentence with the transition-based ``model``.

    At every step the current stack/buffer features are turned into model
    inputs (GloVe vectors for the words, ``pose_set`` indices for the POS
    tags), the model scores the actions, and the best *legal* action is applied
    to the parse state until ``state.is_final_state`` reports completion.

    Args:
        model: parser network, called as ``model(word_embeds, pos_labels)``.
            Callers should put it in eval mode (``model.eval()``) beforehand.
        parse_data: dict with parallel lists ``"words"`` and ``"pos"``.
        pose_set: mapping from POS tag string to its integer label.
        glove: torchtext GloVe vectors (``get_vecs_by_tokens`` is used).
        device: torch device the model lives on.
        reverse_tagset: mapping from action index to action label; labels
            contain ``"SHIFT"``, ``"REDUCE_L"`` or ``"REDUCE_R"``, and for
            reduces the dependency label is taken from character 9 onward
            (presumably a ``"REDUCE_L_<label>"`` format — confirm against the
            tagset file).

    Returns:
        ``(ps, pred_actions)``: the final ``state.ParseState`` and the list of
        action labels that were applied, in order.

    Raises:
        ValueError: if ``parse_data["words"]`` and ``parse_data["pos"]``
            have different lengths.
    """
    words = parse_data["words"]
    pos_tags = parse_data["pos"]
    if len(words) != len(pos_tags):
        raise ValueError(
            f"parse_data has {len(words)} words but {len(pos_tags)} POS tags"
        )

    words_labels = dl.convert_sentence_to_labels(words)
    # NOTE(review): idx comes from a word -> label lookup, so a word repeated in
    # the sentence may get the same idx twice — confirm state.Token does not
    # rely on idx being unique.
    init_buffer = [
        state.Token(idx=words_labels[word], word=word, pos=tag)
        for word, tag in zip(words, pos_tags)
    ]
    ps = state.ParseState(stack=[], parse_buffer=init_buffer, dependencies=[])

    # Length of the "REDUCE_L_" / "REDUCE_R_" prefix stripped off to obtain the
    # dependency label.
    arc_prefix_len = len("REDUCE_L_")
    pred_actions = []

    # Pure inference: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        while not state.is_final_state(ps, cwindow=2):
            sw, sp = dl.get_stack(ps.stack)
            bw, bp = dl.get_buffer(ps.parse_buffer)

            in_words = sw + bw
            in_pos = sp + bp
            pos_labels = (
                torch.tensor([pose_set[tag] for tag in in_pos], dtype=torch.long)
                .unsqueeze(0)
                .to(device)
            )
            word_embeds = glove.get_vecs_by_tokens(in_words).unsqueeze(0).to(device)

            pred_action_probs = model(word_embeds, pos_labels)

            # Keep the two best actions so there is a fallback if the best is illegal.
            _, top_indices = torch.topk(pred_action_probs, 2)
            top_indices = top_indices.squeeze()
            current_action = reverse_tagset[top_indices[0].item()]

            # Repair illegal actions: a reduce needs at least two stack items,
            # a shift needs a non-empty buffer.
            if "REDUCE" in current_action and len(ps.stack) <= 1:
                current_action = "SHIFT"
            elif "SHIFT" in current_action and len(ps.parse_buffer) < 1:
                current_action = reverse_tagset[top_indices[1].item()]

            if "REDUCE_L" in current_action:
                state.left_arc(ps, current_action[arc_prefix_len:])
            elif "REDUCE_R" in current_action:
                state.right_arc(ps, current_action[arc_prefix_len:])
            else:
                state.shift(ps)

            pred_actions.append(current_action)

    return ps, pred_actions

In [4]:
### Sentence taken from the training data, used as a sanity check of the loaded model
parse_data = {
    "words": ["Aesthetic", "Appreciation", "and", "Spanish", "Art", ":"],
    "pos": ["ADJ", "NOUN", "CCONJ", "ADJ", "NOUN", "PUNCT"],
}
ps, _ = run_parse_tree(model, parse_data, pose_set, glove, cfg.DEVICE, reverse_tagset)
for dep in ps.dependencies:
    head_word, child_word = dep.source.word, dep.target.word
    print(f"s:{head_word}, t:{child_word}, l:{dep.label}")

s:Appreciation, t:Aesthetic, l:amod
s:Art, t:Spanish, l:amod
s:Art, t:and, l:cc
s:Appreciation, t:Art, l:nmod
s:Appreciation, t::, l:punct


In [17]:
# Parse a hand-written sentence and print each predicted arc (source, target, label).
parse_data = {
    "words": ["Mary", "had", "a", "little", "lamb", "."],
    "pos": ["PROPN", "AUX", "DET", "ADJ", "NOUN", "PUNCT"],
}
ps, _ = run_parse_tree(model, parse_data, pose_set, glove, cfg.DEVICE, reverse_tagset)
for dep in ps.dependencies:
    head_word, child_word = dep.source.word, dep.target.word
    print(f"s:{head_word}, t:{child_word}, l:{dep.label}")

s:lamb, t:little, l:amod
s:lamb, t:a, l:det
s:lamb, t:had, l:aux
s:lamb, t:Mary, l:nmod
s:lamb, t:., l:punct


In [18]:
# Parse a hand-written sentence and print each predicted arc (source, target, label).
parse_data = {
    "words": ["I", "ate", "the", "fish", "raw", "."],
    "pos": ["PRON", "VERB", "DET", "NOUN", "ADJ", "PUNCT"],
}
ps, _ = run_parse_tree(model, parse_data, pose_set, glove, cfg.DEVICE, reverse_tagset)
for dep in ps.dependencies:
    head_word, child_word = dep.source.word, dep.target.word
    print(f"s:{head_word}, t:{child_word}, l:{dep.label}")

s:ate, t:I, l:nsubj
s:fish, t:the, l:det
s:raw, t:fish, l:nsubj
s:ate, t:raw, l:conj
s:ate, t:., l:punct


In [19]:
# Parse a longer hand-written sentence and print each predicted arc (source, target, label).
parse_data = {
    "words": ["With", "neural", "networks", ",", "I", "love", "solving", "problems", "."],
    "pos": ["ADP", "ADJ", "NOUN", "PUNCT", "PRON", "VERB", "VERB", "NOUN", "PUNCT"],
}
ps, _ = run_parse_tree(model, parse_data, pose_set, glove, cfg.DEVICE, reverse_tagset)
for dep in ps.dependencies:
    head_word, child_word = dep.source.word, dep.target.word
    print(f"s:{head_word}, t:{child_word}, l:{dep.label}")

s:networks, t:neural, l:amod
s:networks, t:With, l:case
s:networks, t:,, l:punct
s:love, t:I, l:nsubj
s:solving, t:problems, l:obj
s:love, t:solving, l:advcl
s:networks, t:love, l:acl
s:networks, t:., l:punct


In [5]:
## Run the parser over the hidden split; "results.txt" is presumably the output file
hidden_dataset = dl.ParsingDatasetEval(
    data_file=cfg.hidden_data_path,
    pose_set=pose_set,
    tagset=tagset,
    glove=glove,
    split='hidden',
)

# NOTE(review): the dataset receives the (large) GloVe object; with a "spawn"
# start method each of the worker processes may need its own copy — consider
# num_workers=0 if memory is tight.
hidden_loader = DataLoader(
    dataset=hidden_dataset,
    collate_fn=dl.custom_collate_fn,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    pin_memory=torch.cuda.is_available(),
)

util.run_hidden_data(model, hidden_loader, reverse_tagset, "results.txt", cfg.DEVICE, complete=False)