In [1]:
import sys

# Extra import root added to sys.path (presumably where the `duck` package
# imported below lives — confirm for your environment).
DUCK_ROOT = '/fsx/matzeni/duck'

# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if DUCK_ROOT not in sys.path:
    sys.path.append(DUCK_ROOT)

In [2]:
import torch
import h5py
import json
from pathlib import Path
import pickle
from tqdm import tqdm
import logging
from transformers import AutoTokenizer
from typing import Any, Dict, List, Optional, Tuple
import copy
from einops import rearrange, repeat
import numpy as np
from duck.box_tensors import BoxTensor
from duck.task.duck_entity_disambiguation import Duck
from hydra import compose, initialize
import hydra
from duck.common.utils import make_reproducible
import collections
import logging
import math
# getLogger() called without a name returns the root logger.
logger = logging.getLogger()

In [3]:
# Hydra keeps a single global initialization per process; calling initialize()
# a second time (e.g. when this cell is re-run) raises
# "GlobalHydra is already initialized". Clearing first makes the cell re-runnable.
from hydra.core.global_hydra import GlobalHydra

GlobalHydra.instance().clear()

# Point Hydra at the "conf" config directory so compose() can be used later.
initialize(config_path="conf", version_base=None)

hydra.initialize()

In [5]:
# Location of the trained checkpoint to inspect.
ckpt_path = "/checkpoints/matzeni/duck/checkpoints/jaccard_kldiv_epoch=0_duck_loss=-0.794.ckpt"

# Restore the model, switch it to evaluation mode, then move it to the GPU.
duck = Duck.load_from_checkpoint(ckpt_path)
duck = duck.eval()
duck = duck.cuda()

In [6]:
# Compose the "duck_conf" configuration with no command-line style overrides.
config = compose(config_name="duck_conf", overrides=[])

In [7]:
# Instantiate the object described by the `data` section of the composed config.
datamodule = hydra.utils.instantiate(config.data)

In [8]:
def batch_to_cuda(batch):
    """Recursively move every tensor found in ``batch`` to the default CUDA device.

    Args:
        batch: A ``torch.Tensor``, a mapping, a sequence, or any other object.
            Containers may be nested arbitrarily.

    Returns:
        * tensors: the result of ``tensor.cuda()``;
        * mappings: a plain ``dict`` with the same keys and converted values;
        * non-text sequences: a plain ``list`` of converted items (so tuples
          come back as lists);
        * ``str`` / ``bytes`` / ``bytearray`` and anything else: unchanged.
    """
    # Imported locally so the function does not depend on ``collections.abc``
    # having been loaded as a side effect of some other import.
    from collections.abc import Mapping, Sequence

    if isinstance(batch, torch.Tensor):
        return batch.cuda()
    # Text-like leaves are Sequences too; return them before the generic
    # Sequence branch so they are not exploded into per-character or
    # per-byte lists.
    if isinstance(batch, (str, bytes, bytearray)):
        return batch
    if isinstance(batch, Mapping):
        return {key: batch_to_cuda(value) for key, value in batch.items()}
    if isinstance(batch, Sequence):
        return [batch_to_cuda(item) for item in batch]
    # Scalars, None and unknown objects pass through untouched.
    return batch

In [9]:
# Called before the dataloader is built; presumably fixes RNG state for
# reproducible batches (implementation not shown here — confirm in
# duck.common.utils).
make_reproducible(ngpus=config.trainer.devices)
train_data = datamodule.train_dataloader()
# Keep a single explicit iterator so later cells can pull batches with next().
train_data_iterator = iter(train_data)

In [10]:
# Skip the first batch from the iterator and work with the second one,
# moved to the GPU.
next(train_data_iterator)
batch = batch_to_cuda(next(train_data_iterator))

# Forward pass; the model returns its outputs keyed by name.
representations = duck(batch)
mentions, entities, entity_boxes, neighbors, neighbor_boxes = (
    representations[key]
    for key in ("mentions", "entities", "entity_boxes", "neighbors", "neighbor_boxes")
)

# Relation ids come from the input batch, not from the model output.
entity_relation_ids = batch["relation_ids"]
neighbor_relation_ids = batch["neighbor_relation_ids"]

# Rebind the three neighbor variables to whatever the model's
# in-batch-neighbor extension returns.
neighbors, neighbor_boxes, neighbor_relation_ids = duck._extend_with_in_batch_neighbors(
    neighbors, neighbor_boxes, neighbor_relation_ids
)

In [11]:
# Show the entity labels carried by the current batch.
print(batch["entity_labels"])

['Skeid Fotball', 'Atlético Madrid', 'FC Lokomotíva Košice', 'Japan', 'England', 'Switzerland', 'Greenwich Mean Time', 'Hull City A.F.C.', 'Mike Watkinson', 'Albania', "Finland men's national ice hockey team", 'United States', 'United States', 'Portugal', 'United Nations', 'London Stansted Airport', 'Germany', 'Croatia', 'Northamptonshire County Cricket Club', 'Paris', 'Super League', 'Husqvarna Motorcycles', 'Mikael Tillström', 'St. Louis Cardinals', 'Carl Fogarty', "Standard & Poor's", 'Pol Pot', 'Dakar', 'Olympic sports', 'Greenwich Mean Time', 'Sweden', 'Brisbane Lions']


In [12]:
# Display the entity boxes taken from the model output for this batch.
# NOTE(review): in the saved output every visible row of `left` and `right`
# is identical across entities — confirm whether that is expected.
entity_boxes

BoxTensor(
	left=tensor([[-0.0151,  0.0026, -0.0225,  ..., -0.0041, -0.0206,  0.0002],
        [-0.0151,  0.0026, -0.0225,  ..., -0.0041, -0.0206,  0.0002],
        [-0.0151,  0.0026, -0.0225,  ..., -0.0041, -0.0206,  0.0002],
        ...,
        [-0.0151,  0.0026, -0.0225,  ..., -0.0041, -0.0206,  0.0002],
        [-0.0151,  0.0026, -0.0225,  ..., -0.0041, -0.0206,  0.0002],
        [-0.0151,  0.0026, -0.0225,  ..., -0.0041, -0.0206,  0.0002]],
       device='cuda:0', grad_fn=<AddmmBackward>),
	right=tensor([[0.9895, 1.0000, 0.9884,  ..., 0.9928, 0.9911, 0.9933],
        [0.9895, 1.0000, 0.9884,  ..., 0.9928, 0.9911, 0.9933],
        [0.9895, 1.0000, 0.9884,  ..., 0.9928, 0.9911, 0.9933],
        ...,
        [0.9895, 1.0000, 0.9884,  ..., 0.9928, 0.9911, 0.9933],
        [0.9895, 1.0000, 0.9884,  ..., 0.9928, 0.9911, 0.9933],
        [0.9895, 1.0000, 0.9884,  ..., 0.9928, 0.9911, 0.9933]],
       device='cuda:0', grad_fn=<AddBackward0>)
)