In [1]:
import os
import sys

# Location of the `duck` checkout so `import duck` works from this notebook.
# Override with the DUCK_ROOT environment variable when running elsewhere.
DUCK_ROOT = os.environ.get("DUCK_ROOT", "/fsx/matzeni/duck")
# Guard so re-running this cell does not append duplicate sys.path entries.
if DUCK_ROOT not in sys.path:
    sys.path.append(DUCK_ROOT)

In [2]:
import torch
import h5py
import json
from pathlib import Path
import pickle
from tqdm import tqdm
import logging
from transformers import AutoTokenizer
from typing import Any, Dict, List, Optional, Tuple
import copy
from einops import rearrange, repeat
import numpy as np
from duck.box_tensors import BoxTensor
from duck.task.duck_loss import BoxEDistance
from duck.task.duck_entity_disambiguation import Duck
from hydra import compose, initialize
import hydra
from duck.common.utils import seed_prg
import collections
import logging
import math
from duck.box_tensors.functional import cat_box
from matplotlib import pyplot as plt
from duck.common.utils import cartesian_to_spherical, load_json, load_jsonl
from einops import repeat, rearrange
import pandas as pd
from omegaconf import open_dict
import gc

In [3]:
# Inference-only notebook: switch autograd off globally so no graphs are recorded.
torch.set_grad_enabled(mode=False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f886c2529a0>

In [4]:
from hydra.core.global_hydra import GlobalHydra

# Hydra permits only one global initialization per process; clear a previous
# one so this cell can be re-run without restarting the kernel.
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()
# Point Hydra at the ../conf config directory for the compose() call below.
initialize(config_path="../conf", version_base=None)

hydra.initialize()

In [5]:
# Alternative checkpoint kept for reference:
# ckpt_path = "/checkpoints/matzeni/duck/checkpoints/duck/duck_hard_epoch=0_val_f1=0.887_130087_last.ckpt"
ckpt_path = "/checkpoints/matzeni/duck/checkpoints/duck/duck_aida_epoch=1_val_f1=0.884_130100.ckpt"

# Restore the model, put it in eval mode and move it to the current CUDA device.
duck = (
    Duck.load_from_checkpoint(ckpt_path)
    .eval()
    .cuda()
)

In [6]:
# Compose the `duck` config with no overrides.
config = compose(config_name="duck", overrides=[])

# Evaluation-only tweaks to the data config. open_dict lifts struct mode so
# these keys can be assigned even if they are not declared in the schema.
with open_dict(config):
    data_cfg = config.data
    data_cfg.relation_threshold = 0
    # Use the AIDA train file listed under val_paths as the "training" split
    # (presumably to avoid loading the full training set), then drop all
    # validation sets. Order matters: read val_paths before clearing it.
    data_cfg.train_path = data_cfg.val_paths.aida_train
    data_cfg.val_paths = {}
    data_cfg.batch_size = 1
    data_cfg.transform.max_mention_len = 512
datamodule = hydra.utils.instantiate(config.data)

In [7]:
# Key each test dataloader by its dataset name. This assumes the datamodule
# returns loaders in the same order as config.data.test_paths -- verify if the
# datamodule changes.
dataset_names = list(config.data.test_paths.keys())
test_loader_list = list(datamodule.test_dataloader())
# Fail loudly instead of raising IndexError or silently dropping datasets.
if len(test_loader_list) != len(dataset_names):
    raise ValueError(
        f"Got {len(test_loader_list)} test dataloaders for "
        f"{len(dataset_names)} test datasets: {dataset_names}"
    )
test_dataloaders = dict(zip(dataset_names, test_loader_list))

In [8]:
def eval_loop(dataloader, model: Optional[Any] = None) -> Tuple[float, torch.Tensor]:
    """Measure top-1 entity-disambiguation accuracy of a Duck model.

    For each batch the mentions and their candidate entities are encoded,
    every mention is scored against its candidates, and the id of the
    highest-scoring valid candidate is compared with the gold entity id.

    Args:
        dataloader: Iterable of batches from the Duck datamodule. Each batch
            must provide "mentions", "candidate_tokens" ("data",
            "attention_mask"), "candidates" ("data", "attention_mask") and
            "entity_ids".
        model: Object exposing ``batch_to_device``, ``encode_mention``,
            ``entity_encoder`` and ``score``. Defaults to the notebook-level
            ``duck`` checkpoint.

    Returns:
        ``(accuracy, matches)``. ``matches`` is a boolean tensor with one entry
        per evaluated mention, and ``accuracy`` is its mean as a Python float.
        If nothing was evaluated, ``accuracy`` is ``nan`` and ``matches`` is
        empty.

    Note:
        Batches whose candidate tokens are 2-D are skipped, and mentions with
        no valid candidate are dropped. Neither counts towards the
        denominator, so the value printed as "MicroF1" in this notebook is
        accuracy over mentions that have at least one candidate.
    """
    if model is None:
        model = duck  # notebook-level checkpoint loaded earlier
    batch_matches = []
    for batch in tqdm(dataloader):
        batch = model.batch_to_device(batch)
        candidate_tokens = batch["candidate_tokens"]["data"]
        # 2-D candidate tokens presumably mean the batch has no per-mention
        # candidate lists -- TODO confirm. Check before encoding mentions so
        # skipped batches cost no forward pass.
        if candidate_tokens.dim() == 2:
            continue
        mentions = model.encode_mention(batch["mentions"])

        # Encode all (b, n, d) candidate token sequences in a single pass, then
        # restore the per-mention layout: "b n d -> (b n) d" and back.
        n = candidate_tokens.size(1)
        token_mask = batch["candidate_tokens"]["attention_mask"].bool()
        encoded, _ = model.entity_encoder(
            candidate_tokens.flatten(0, 1),
            attention_mask=token_mask.flatten(0, 1),
        )
        encoded = encoded.reshape(-1, n, encoded.size(-1))

        scores = model.score(mentions, encoded)
        candidate_mask = batch["candidates"]["attention_mask"].bool()
        # Use -inf (not 0) so padded slots can never win argmax, even when all
        # real candidates have negative scores.
        scores[~candidate_mask] = float("-inf")
        best = scores.argmax(dim=-1)
        # Look up each mention's best slot in its OWN candidate row
        # (unsqueeze(0) would read every mention's id from row 0).
        preds = batch["candidates"]["data"].gather(-1, best.unsqueeze(-1)).squeeze(-1)

        # Only mentions with at least one valid candidate are evaluated.
        has_candidates = candidate_mask.any(dim=-1)
        gold = batch["entity_ids"]
        batch_matches.append(preds[has_candidates] == gold[has_candidates])

    if not batch_matches:
        # torch.cat([]) would raise; report "no data" explicitly instead.
        return float("nan"), torch.zeros(0, dtype=torch.bool)
    matches = torch.cat(batch_matches)
    return matches.float().mean().item(), matches

In [9]:
# Evaluate every test set; keep the scalar score and the per-mention matches.
microf1: Dict[str, float] = {}
results: Dict[str, torch.Tensor] = {}
for dataset, dataloader in test_dataloaders.items():
    print(f"Evaluating on {dataset}")
    f1, matches = eval_loop(dataloader)
    microf1[dataset], results[dataset] = f1, matches
    print(f"MicroF1 on {dataset}: {f1:.4f}")

Evaluating on aida


100%|██████████| 4484/4484 [08:19<00:00,  8.97it/s]


MicroF1 on aida: 0.9373
Evaluating on ace2004


100%|██████████| 240/240 [00:31<00:00,  7.60it/s]


MicroF1 on ace2004: 0.9500
Evaluating on acquaint


100%|██████████| 702/702 [01:05<00:00, 10.74it/s]


MicroF1 on acquaint: 0.9131
Evaluating on clueweb


100%|██████████| 11078/11078 [17:57<00:00, 10.28it/s]


MicroF1 on clueweb: 0.7829
Evaluating on msnbc


100%|██████████| 651/651 [00:56<00:00, 11.59it/s]


MicroF1 on msnbc: 0.9462
Evaluating on wiki


100%|██████████| 6785/6785 [09:28<00:00, 11.93it/s]

MicroF1 on wiki: 0.8590





In [12]:
# Unweighted mean of the per-dataset scores (each dataset counts equally).
average_microf1 = sum(microf1.values()) / len(microf1)
print("Average: ", average_microf1)

Average:  0.8980885545412699
