In [None]:
# notebook setup

# reload edited modules automatically before each cell executes, so
# changes to the irt2 sources are picked up without restarting the kernel
%load_ext autoreload
%autoreload 2

# logging configuration provided by irt2 for notebook use
from irt2 import ipynb
ipynb.setup_logging()

In [None]:
# ---  configure here
# both values select the config overlays merged in the next cell
# (irt2.ENV.DIR.CONF / 'create' / f'cde-{sampling}.yaml' and
# f'cde-size-{size}.yaml') and form the output folder name
# f'{size}-{sampling}'; valid values are presumably those for which
# such overlay files exist

sampling = 'uniform'
size = 'medium'


In [None]:
import irt2
from irt2.create import EID, Mention

from pathlib import Path
from pprint import pprint

from ktz.collections import ryaml

import numpy as np
import matplotlib.pyplot as plt


# output folder name and the layered configuration: the base cde.yaml
# first, then the sampling and size overlays chosen in the config cell
# (ryaml receives the files in this order; presumably later files take
# precedence - confirm in ktz.collections.ryaml)
FOLDER = f'{size}-{sampling}'
config = ryaml(
    irt2.ENV.DIR.CONF / 'create' / 'cde.yaml',
    irt2.ENV.DIR.CONF / 'create' / f'cde-{sampling}.yaml',
    irt2.ENV.DIR.CONF / 'create' / f'cde-size-{size}.yaml',
)

pprint(config)

In [None]:
from ktz.string import args_hash
from ktz.functools import Cascade

# set up cascade for development

def setup_cascade(config, keys):
    """Create the Cascade used to cache the expensive notebook steps.

    Each cache name in ``keys`` is mapped to a file name of the form
    ``create.ipynb-<hash>-<name>`` below ``irt2.ENV.DIR.CACHE``. The
    ``<hash>`` is computed from the config values that affect all
    ``@run.cache`` loaders, so changing one of them presumably yields new
    cache names rather than reusing stale results.

    Args:
        config: the merged notebook configuration (mapping)
        keys: names of the cache slots to register

    Returns:
        the configured ktz ``Cascade`` instance
    """
    # add config parameters here that affect all @run.cache loaders.
    # A tuple (not a set literal) keeps the key order fixed: str hashing
    # is randomized per interpreter process (PYTHONHASHSEED), so a set may
    # iterate in a different order after a kernel restart, which would
    # change the dict handed to args_hash if it is order-sensitive.
    params = ('source matches', 'prune mentions')
    digest = args_hash({k: config[k] for k in params})
    print(digest)

    prefix = 'create.ipynb'

    cascade = {name: f'{prefix}-{digest}-{name}' for name in keys}
    return Cascade(path=irt2.ENV.DIR.CACHE, **cascade)


run = setup_cascade(config, keys=['matches', 'mentions', 'split'])

In [None]:
from irt2 import create


@run.cache('matches')
def load_matches():
    """Build the index of source matches (cached through the cascade).

    The matches file is ``config['source matches']`` relative to
    ``irt2.ENV.DIR.ROOT``; a short size summary is printed.
    """
    source = irt2.ENV.DIR.ROOT / config['source matches']
    index = create.index_matches(path=source)

    print('\nmatch index:')
    print(f'  total: {len(index.flat)}')
    for kind, label in (('page', '  pages'), ('eid', '   eids')):
        print(f'{label}: {len(index.keys(kind))}')

    return index


matches = load_matches()

In [None]:
# plot how matches distribute over entities

def plot_dist(
    y,
    title: str,
    subtitle: str,
    filenames: list[Path] = None,
):
    """Draw ``y`` as a line over its positions on a log-scaled y-axis.

    Args:
        y: values to plot; x runs over 0 .. len(y) - 1
        title: figure title
        subtitle: axes title, drawn in grey
        filenames: optional paths the figure is saved to (each of them)
    """
    fig, ax = plt.subplots()
    fig.suptitle(title)

    ax.set_title(subtitle, color='#999')
    ax.set_yscale('log')

    positions = np.arange(len(y))
    ax.plot(positions, y, color='#666')

    # a missing/empty filenames argument simply saves nothing
    for target in filenames or ():
        fig.savefig(target)


@run.when('matches')
def plot_distribution_idx(index, **kwargs):
    """Plot the per-entity match counts, largest first.

    Args:
        index: the match index; the count for an entity is the size of
            ``index.dis(eid=...)``
        **kwargs: forwarded to ``plot_dist`` (title, subtitle, filenames)
    """
    per_entity = {eid: len(index.dis(eid=eid)) for eid in index.keys('eid')}
    ranked = sorted(per_entity.values(), reverse=True)

    plot_dist(y=np.array(ranked), **kwargs)


# plot the match counts; the subtitle names the source graph and dataset
# (an f-string, so the config values are actually interpolated)
plot_distribution_idx(
    index=matches,
    title="Matches Count",
    subtitle=f"{config['source graph']} ({config['source name']})",
)

In [None]:
from ktz.dataclasses import Index
from dataclasses import dataclass

# progress output for the mention mapping computed below
print("\nmapping EIDS to mentions")
print(f"pruning at threshold: {config['prune mentions']}")


@dataclass
class Mentions:
    """Entity-to-mention mapping (local definition).

    NOTE(review): this class is not referenced anywhere else in the
    visible notebook; ``get_mentions`` returns whatever
    ``create.get_mentions`` builds. Confirm whether this local definition
    is still needed.
    """

    # EID -> {mention: count}
    eid2mentions: dict[EID, dict[Mention, int]]
    # presumably normalized mention string -> mention; TODO confirm
    norm2mentions: dict[str, str]


@run.cache('mentions')
def get_mentions(index: Index, prune: int):
    """Map entities to their mentions (cached through the cascade).

    Args:
        index: the match index returned by ``load_matches``
        prune: threshold forwarded to ``create.get_mentions``
            (configured as ``config['prune mentions']``)

    Returns:
        the object built by ``create.get_mentions``; its ``eid2mentions``
        attribute maps each entity to a dict of mention counts
    """
    mentions = create.get_mentions(index=index, prune=prune)

    # eid2mentions is keyed by entity, so its length counts entities,
    # not mentions; report both numbers
    n_entities = len(mentions.eid2mentions)
    n_mentions = sum(len(mdic) for mdic in mentions.eid2mentions.values())
    print(f'retained {n_mentions} mentions for {n_entities} entities')

    return mentions


mentions = get_mentions(
    index=matches,
    prune=config['prune mentions'],
)

In [None]:
# some example mentions

@run.when('mentions')
def print_mention_counts(mentions, eid):
    """Print all mentions of ``eid`` with their counts, most frequent first.

    Counts are right-aligned in a five-character column; an empty line
    closes the listing.
    """
    counts = mentions.eid2mentions[eid]
    ranked = sorted(counts, key=counts.get, reverse=True)

    for mention in ranked:
        print(f"{counts[mention]:5d} {mention}")
    print()

print_mention_counts(mentions=mentions, eid='Q11708')
print_mention_counts(mentions=mentions, eid='Q49297')
print_mention_counts(mentions=mentions, eid='Q21077')

In [None]:
# entities with most mentions

@run.when('mentions', 'matches')
def match_examples(matches, mentions, top: int = 20):
    """Print the entities that have the most distinct mentions.

    One line per entity: its number of mentions, its EID and the
    ``entity`` attribute of one of its matches.

    Args:
        matches: the match index (queried with ``matches.get(eid=...)``)
        mentions: result of ``get_mentions``
        top: how many entities to list (default 20)
    """
    ranked = sorted(
        mentions.eid2mentions.items(),
        key=lambda item: len(item[1]),
        reverse=True,
    )

    for eid, mdic in ranked[:top]:
        # NOTE(review): if matches.get returns an unordered collection,
        # which match supplies the printed name is arbitrary
        sample = list(matches.get(eid=eid))[0]
        print(len(mdic), eid, sample.entity)


match_examples(matches, mentions)

In [None]:
from collections import Counter


@run.when('mentions')
def plot_mention_counts(mentions, title, subtitle, **kwargs):
    """Scatter-plot how many entities share each number of mentions.

    The number of entities is on the x-axis and the number of mentions
    per entity on the y-axis.

    Args:
        mentions: result of ``get_mentions``
        title: figure title
        subtitle: axes title, drawn in grey
        **kwargs: forwarded to ``Axes.scatter`` (e.g. color, marker)
    """
    # number of mentions -> number of entities having that many mentions
    counts = Counter(len(countdic) for countdic in mentions.eid2mentions.values())

    fig, ax = plt.subplots()
    fig.suptitle(title)
    ax.set_title(subtitle, color='#999')
    ax.set_xlabel('number of entities')
    ax.set_ylabel('mentions per entity')

    # zip(*...) of an empty sequence cannot be unpacked into two names
    if not counts:
        return

    # deliberate orientation: the mention count (first tuple element) is y
    y, x = zip(*sorted(counts.items()))
    ax.scatter(x, y, **kwargs)


plot_mention_counts(
    mentions=mentions,
    title="Mentions per Entity",
    subtitle=config['source name'],
    color='#333',
    marker='.',
)

In [None]:
from irt2.graph import Graph
from irt2.graph import load_graph


graph = load_graph(
    config['graph loader'],
    config['graph name'],
    # loader paths in the config are relative to irt2.ENV.DIR.ROOT
    *(irt2.ENV.DIR.ROOT / rel for rel in config['graph loader args']),
    **{
        name: irt2.ENV.DIR.ROOT / rel
        for name, rel in config['graph loader kwargs'].items()
    },
)


print(graph.description)

In [None]:
# split triples/vertices/mentions
# (cw = closed world, ow = open world):
#  - select all mentions of concept entities
#  - shuffle the remaining mentions and split them randomly into
#    cw / ow-validation / ow-test
#  - assign each vertex to either cw or ow based on the mention split
#  - assign triples based on the vertex split

from irt2.create import Split
from IPython.core.debugger import set_trace

def create_split(config, graph, mentions):
    """Split graph and mentions into closed-world and open-world parts.

    Reads the split ratios, seed, relation filters and sampling mode from
    ``config``, builds the split with ``Split.create``, runs its
    ``check()`` and prints a summary.

    Args:
        config: merged notebook configuration
        graph: the graph returned by ``load_graph``
        mentions: result of ``get_mentions``

    Returns:
        the ``Split``, after ``split.check()`` has run
    """
    # ':.0%' rounds; int(ratio * 100) truncated float artefacts such as
    # 0.29 * 100 == 28.999... down to 28
    ratio_train = config['target mention split']
    print(f"targeting {ratio_train:.0%} closed-world mentions")

    ratio_val = config['target validation split']
    print(f"using {ratio_val:.0%} open-world mentions for validation")

    split = Split.create(
        graph,
        mentions=mentions,
        seed=config['seed'],
        ratio_train=ratio_train,
        ratio_val=ratio_val,
        concept_rels=config['concept relations'],
        include_rels=config['include relations'],
        exclude_rels=config['exclude relations'],
        # optional setting: None when not configured
        prune=config.get('target mention count'),
        sampling=config['sampling'],
    )

    print('running self-check...')
    split.check()
    print('self-check passed')

    print(split.description)
    print(f'\nretained {len(split.relations)}/{len(graph.source.rels)} relations')

    return split

split = create_split(config, graph, mentions)

In [None]:
from tabulate import tabulate
from irt2.create import create_dataset

# write the dataset to irt2.ENV.DIR.DATA / 'irt2' / 'cde' / FOLDER
# NOTE(review): overwrite=True means an existing folder of that name is not
# protected - presumably its contents are replaced; confirm in
# irt2.create.create_dataset before re-running against a finished dataset
dataset, counts = create_dataset(
    out=irt2.ENV.DIR.DATA / 'irt2' / 'cde' / FOLDER,
    config=config,
    split=split,
    overwrite=True
)

# summary: the dataset itself, then the counts as a table sorted by key
print(f"\n{dataset}\n")
print(tabulate(sorted(counts.items())))