In [1]:
%load_ext autoreload
%autoreload 2
from pathlib import Path
import torch
from typing import Dict, List, Iterator, Union, Any
import logging
import tempfile
import json
import shutil
from pathlib import Path
from pprint import pprint
from models.utils.wandb import RunData
from models.utils.allennlp import (
    load_config, load_dataset_reader, load_iterator, load_model,load_best_metrics, load_modules, 
    load_outputs, create_onepass_generator)
from allennlp.models import Model
from allennlp.data import DatasetReader, DataIterator
import pandas as pd
from copy import deepcopy
import plotly.graph_objs as go
import numpy as np
from ipywidgets import interact, fixed
# Root logging setup: INFO and above, each record rendered as
# "<time> - <level> - <logger name> - <message>".
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Logger used for this notebook's own progress messages.
logger = logging.getLogger(__name__)


In [13]:
# config
import os

# Identifier of the run whose files are downloaded and inspected below.
RUN = 'rmhlaa7z'
# Extra module names handed to `load_modules` before the model is built.
EXTRA_MODULES = ['models', 'datasets']
# Dataset directory. The default is the original author's local checkout;
# set the KB_COMPLETION_DATA_DIR environment variable to run the notebook on
# another machine without editing this cell.
DATA_DIR = os.environ.get(
    'KB_COMPLETION_DATA_DIR',
    '/Users/dhruv/UnsyncedDocuments/IESL/kb_completion/models/.data')
# Directory passed to RunData as `download_dir` (relative to the working dir).
DOWNLOADS_DIR = '.models/HYPO_50'
#DOWNLOADS_DIR = '/var/folders/r2/8mbb22091rb4vq4xtd19bbw80000gn/T/tmp5ypyp_4a'

In [14]:
# get the run files
logger.info("Downloading run files if needed...")
rd = RunData(RUN, group='iesl-boxes', project='kb-completion',
             download_dir=Path(DOWNLOADS_DIR))

# The skip pattern is written as a raw string so that `\d` reaches the callee
# as a regex digit class instead of being an invalid string escape (which
# newer Python versions warn about). The string value itself is unchanged.
# NOTE(review): presumably this skips the per-epoch metrics_epoch_<N>.json
# files -- confirm how RunData.download_files interprets `skip_patterns`.
rd.download_files(skip_patterns=[r'allennlp_serialization_dir/metrics_epoch_[\d]+.json$'])

2020-02-13 22:33:14,899 - INFO - __main__ - Downloading run files if needed...
2020-02-13 22:33:15,886 - INFO - models.utils.wandb - Setting up download dir as .models/HYPO_50/rmhlaa7z
2020-02-13 22:33:16,119 - INFO - models.utils.wandb - Downloading allennlp_serialization_dir/best.th
2020-02-13 22:33:17,264 - INFO - models.utils.wandb - Downloading allennlp_serialization_dir/config.json
2020-02-13 22:33:17,858 - INFO - models.utils.wandb - Downloading allennlp_serialization_dir/log/train/events.out.tfevents.1581556112.node163
2020-02-13 22:33:20,397 - INFO - models.utils.wandb - Downloading allennlp_serialization_dir/metrics.json
2020-02-13 22:33:26,192 - INFO - models.utils.wandb - Downloading allennlp_serialization_dir/model.tar.gz
2020-02-13 22:33:27,034 - INFO - models.utils.wandb - Downloading allennlp_serialization_dir/model_state_epoch_896.th
2020-02-13 22:33:27,844 - INFO - models.utils.wandb - Downloading allennlp_serialization_dir/stderr.log
2020-02-13 22:33:28,819 - INFO - 

In [15]:

# Pass the extra module names to the project's `load_modules` helper.
logger.info("Loading {} for AllenNLP".format(EXTRA_MODULES))
load_modules(EXTRA_MODULES)

# Config overrides: the validation dataset reader is redirected to DATA_DIR
# and to the test-split sample file, so "validation" here means the test set.
test_reader_settings = {
    'all_datadir': DATA_DIR,
    "validation_file": "classification_samples_test2id.txt",
}
overrides = {'validation_dataset_reader': test_reader_settings}
logger.info("Setting up config overrides ...")

# Both the config and the model are read from the same downloaded
# serialization directory.
serialization_dir = rd.download_dir / "allennlp_serialization_dir"

# Stored config with the overrides applied.
config = load_config(serialization_dir, overrides_dict=overrides)
#pprint(config.as_dict())

# Model built from that config.
model = load_model(serialization_dir, config=config)

# The test threshold is pinned by hand here. The alternative, left disabled,
# reads it from the stored best metrics:
#best_metrics = load_best_metrics(rd.download_dir/"allennlp_serialization_dir")
#model.test_threshold = best_metrics['best_validation_threshold']
model.test_threshold = -5.59



2020-02-13 22:33:34,455 - INFO - __main__ - Loading ['models', 'datasets'] for AllenNLP
2020-02-13 22:33:34,696 - INFO - __main__ - Setting up config overrides ...
2020-02-13 22:33:34,766 - INFO - allennlp.common.params - Converting Params object to dict; logging of default values will not occur when dictionary parameters are used subsequently.
2020-02-13 22:33:34,767 - INFO - allennlp.common.params - CURRENTLY DEFINED PARAMETERS: 
2020-02-13 22:33:34,768 - INFO - allennlp.common.params - model.box_type = DeltaBoxTensor
2020-02-13 22:33:34,771 - INFO - allennlp.common.params - model.debug = False
2020-02-13 22:33:34,772 - INFO - allennlp.common.params - model.embedding_dim = 2
2020-02-13 22:33:34,773 - INFO - allennlp.common.params - model.init_interval_center = 0.2
2020-02-13 22:33:34,774 - INFO - allennlp.common.params - model.init_interval_delta = 1
2020-02-13 22:33:34,774 - INFO - allennlp.common.params - model.num_entities = 164228
2020-02-13 22:33:34,775 - INFO - allennlp.common.

In [16]:
# create vocab dict
from datasets.file_readers.openke import SamplesIdReader

# The vocab file lives in DATA_DIR (set in the config cell); deriving the path
# from it avoids repeating the absolute directory here.
vocab_file = Path(DATA_DIR) / 'noun_closure.tsv.vocab'

# Each non-blank line is read as "<integer id> <name>", separated by whitespace.
id2name = {}   # integer id -> name
name2idx = {}  # name -> integer id
with open(vocab_file) as f:
    for line_no, line in enumerate(f, start=1):
        pair = line.split()  # split() already strips surrounding whitespace
        if not pair:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        if len(pair) < 2:
            raise ValueError(
                "{}:{}: expected '<id> <name>', got {!r}".format(
                    vocab_file, line_no, line))
        idx = int(pair[0])
        id2name[idx] = pair[1]
        name2idx[pair[1]] = idx
    

In [222]:
# Names to draw, in display order. A later cell fills in each entry's
# 'idx', 'color' and 'pos', so the values given here are only placeholders.
# Entries that are commented out are kept as ready-to-enable alternatives.
nodes = {
    # 'cat.n.01': {},
    'animal.n.01': dict(idx=None, color='yellow'),
    'dog.n.01': dict(idx=None, color='red'),
    # 'mammal.n.01': {'idx': None, 'color': 'blue'},
    'hound.n.01': dict(),
    'working_dog.n.01': dict(),
    # 'watchdog.n.01': {'idx': None, 'color': 'green'},
    # 'pet.n.01': {},
    'domestic_animal.n.01': dict(),
    'terrier.n.01': dict(),
    # 'poodle.n.01': {},
    'caterpillar.n.01': dict(),
    # 'sedan.n.01': {},
    'larva.n.01': dict(),
    # 'beet_armyworm.n.01': {},
    # 'invertebrate.n.01': {},
}
#nodes = {'physical_entity.n.01':{},
#         'matter.n.03': {},
#         'sediment.n.01':{},
#         'substance.n.04': {},
#         'substance.n.07': {},
#        }


In [223]:
# assign idx
# Palette used for the boxes, assigned to nodes in dict order.
colors = [
    '#1f77b4',  # muted blue
    '#ff7f0e',  # safety orange
    '#2ca02c',  # cooked asparagus green
    '#d62728',  # brick red
    '#9467bd',  # muted purple
    '#8c564b',  # chestnut brown
    '#e377c2',  # raspberry yogurt pink
    '#7f7f7f',  # middle gray
    '#bcbd22',  # curry yellow-green
    '#17becf',  # blue-teal
    '#17b4cf',
    '#17b41f',
    '#11b4cf',
]

# For every selected node: look up its vocabulary index, pick a color, and
# read its box corners from the model.
for color_i, (name, node) in enumerate(nodes.items()):
    node["idx"] = name2idx[name]
    # Cycle through the palette so that selecting more nodes than there are
    # colors reuses colors instead of raising IndexError.
    node['color'] = colors[color_i % len(colors)]
    with torch.no_grad():
        t = model.h(torch.tensor(node['idx'], dtype=torch.long))
    # `t.z` and `t.Z` must each unpack into exactly two values, which are
    # stored as the (x0, y0) and (x1, y1) corners respectively.
    # NOTE(review): presumably z/Z are the min/max corners of a 2-D box --
    # confirm against the box tensor class.
    (x0, y0), (x1, y1) = t.z.detach().tolist(), t.Z.detach().tolist()
    node["pos"] = {"x0": x0, "y0": y0, "x1": x1, "y1": y1}
    
    

In [224]:
# Inspect the populated entries (idx, color, pos) before plotting.
pprint(nodes)

{'animal.n.01': {'color': '#1f77b4',
                 'idx': 7865,
                 'pos': {'x0': -1.3256515264511108,
                         'x1': -0.4576614499092102,
                         'y0': 0.8527862429618835,
                         'y1': 1.2844929695129395}},
 'caterpillar.n.01': {'color': '#e377c2',
                      'idx': 28394,
                      'pos': {'x0': -0.9928223490715027,
                              'x1': -0.8595743775367737,
                              'y0': 0.8939406275749207,
                              'y1': 0.9827897548675537}},
 'dog.n.01': {'color': '#ff7f0e',
              'idx': 40241,
              'pos': {'x0': -1.0852129459381104,
                      'x1': -0.6826506853103638,
                      'y0': 1.045533537864685,
                      'y1': 1.1646744012832642}},
 'domestic_animal.n.01': {'color': '#9467bd',
                          'idx': 5230,
                          'pos': {'x0': -1.1678781509399414,
                

In [225]:
import chart_studio.plotly as py
import plotly.graph_objs as go

In [226]:
def add_boxes(data, fig, label_dx=0.05, label_dy=0.006):
    """Draw one labelled rectangle per entry of ``data`` onto ``fig``.

    Args:
        data: Mapping from a display name to an info dict with keys
            ``"pos"`` and ``"color"``. ``info["pos"]`` is forwarded as
            keyword arguments to ``fig.add_shape`` and must contain at least
            ``"x0"`` and ``"y0"`` (in this notebook it holds ``x0``, ``y0``,
            ``x1``, ``y1``). ``info["color"]`` colors both the rectangle
            outline and the label text.
        fig: Figure to draw on; it is modified in place.
        label_dx: Horizontal offset of the text label from ``x0``.
        label_dy: Vertical offset of the text label from ``y0``.

    Returns:
        The same ``fig`` object, to allow chaining.
    """
    for name, info in data.items():
        pos = info["pos"]
        color = info["color"]
        # Text-only scatter trace: the name, placed just off the (x0, y0) corner.
        fig.add_trace(go.Scatter(
            x=[pos["x0"] + label_dx],
            y=[pos["y0"] + label_dy],
            text=[name],
            name=name,
            mode="text",
            textfont=dict(color=color),
        ))
        # The rectangle itself, with its coordinates taken from `pos`.
        fig.add_shape(type="rect", **pos, line=dict(color=color))

    return fig
        

In [231]:


# Build the figure: an empty plotly Figure with one labelled rectangle added
# per entry of `nodes`.
fig = go.Figure()
add_boxes(nodes, fig)

# All layout settings applied in a single call: canvas size, legend anchor,
# and the legend hidden.
# Optional, left disabled: lock the axes to a 1:1 aspect ratio with
#     yaxis=dict(scaleanchor="x", scaleratio=1)
fig.update_layout(
    width=1000,
    height=800,
    legend=dict(x=0.8, y=1),
    showlegend=False,
)

In [232]:
# Render the figure assembled in the previous cell.
fig.show()