# Ranking Task

In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
import irt2
from irt2.dataset import IRT2

def load_datasets():
    """Load the four IRT2 CDE variants found below the data directory.

    Returns:
        A dict keyed by each variant's directory name (for example
        ``'irt2-cde-tiny'``); every value is the object returned by
        ``IRT2.from_dir`` for that directory. Insertion order is
        tiny, small, medium, large.
    """
    base = irt2.ENV.DIR.DATA / 'irt2'
    variants = (
        'irt2-cde-tiny',
        'irt2-cde-small',
        'irt2-cde-medium',
        'irt2-cde-large',
    )

    loaded = {}
    for variant in variants:
        folder = base / variant
        loaded[folder.name] = IRT2.from_dir(path=folder)

    return loaded

datasets = load_datasets()

# Report each loaded variant: its graph and the number of entries in
# its `relations` collection.
for name in datasets:
    dataset = datasets[name]
    print(f"\nloaded {name}:")
    print(f"{dataset.graph}")
    print(f"got ratios for {len(dataset.relations)} relations")



loaded irt2-cde-tiny:
IRT graph: [IRT2/CDE-T] (12389 entities)
got ratios for 5 relations

loaded irt2-cde-small:


IRT graph: [IRT2/CDE-S] (14207 entities)
got ratios for 12 relations

loaded irt2-cde-medium:


IRT graph: [IRT2/CDE-M] (15020 entities)
got ratios for 45 relations

loaded irt2-cde-large:


IRT graph: [IRT2/CDE-L] (15020 entities)
got ratios for 45 relations


In [33]:
# Select the large variant as the working dataset. Later cells read this
# module-level `dataset` name (the Cascade setup and the final calls).
dataset = datasets['irt2-cde-large']

'IRT2/CDE-L'

In [None]:
import irt2
from ktz.functools import Cascade


# Cascade instance backing the `@run.cache(...)` / `@run.when(...)`
# decorators used further down. It is given the project's cache directory
# and a single entry, `context_stats`, whose value is the selected
# dataset's configured name.
# NOTE(review): this reads the module-level `dataset`, so the cell that
# selects the dataset has to run first.
# NOTE(review): the configured name presumably looks like 'IRT2/CDE-L'
# (contains a '/') — confirm how Cascade turns this value into a cache file.
run = Cascade(
    path=irt2.ENV.DIR.CACHE,
    context_stats=f"{dataset.config['create']['name']}"
)

In [52]:
import random

from collections import Counter
from collections import defaultdict


@run.cache('context_stats')
def load_context_stats(ds):
    """Count mention ids over the leading validation contexts of ``ds``.

    Only the first 100 contexts yielded by ``ds.open_contexts_val()`` are
    inspected, so mentions that occur later in the stream are not counted.
    NOTE(review): the cap looks like a leftover from quick experiments
    (the progress message suggests a full pass) — confirm before relying
    on the numbers.

    Args:
        ds: dataset object providing the ``open_contexts_val()`` context
            manager; each yielded context must expose a ``mid`` attribute.

    Returns:
        A plain dict. If at least one context was read it has the single
        key ``'mids'`` holding a ``Counter`` of mid -> occurrences; if no
        context was read the dict is empty.
    """
    from itertools import islice

    print('creating context stats; this might take a while')

    counters = defaultdict(Counter)

    with ds.open_contexts_val() as contexts:
        leading = islice(contexts, 100)
        for context in leading:
            # the 'mids' counter is created lazily on first access
            counters['mids'].update([context.mid])

    # hand back a regular dict so lookups of unknown keys do not
    # silently create new counters
    return {key: counter for key, counter in counters.items()}


@run.when('context_stats')
def show_ranking(ds, stats):
    """Print one randomly chosen head-ranking validation query of ``ds``.

    The query is a ``(vid, rid)`` pair; its ground truth is the collection
    of mention ids (mids) to be predicted. For every ground-truth mention
    the number of contexts recorded in ``stats`` is printed as well.

    Args:
        ds: dataset object; ``ds.open_ranking_val_heads`` must be a mapping
            of ``(vid, rid)`` -> iterable of mids, and ``ds.relations``,
            ``ds.vertices`` and ``ds.mentions`` must be indexable by
            rid, vid and mid respectively.
        stats: dict as produced by ``load_context_stats``; the optional
            key ``'mids'`` maps mid -> number of contexts.

    Raises:
        IndexError: if the ranking task contains no queries
            (raised by ``random.choice``).
    """
    # Use the dataset that was passed in. The previous version read the
    # notebook-global `dataset`, which silently ignored the `ds` argument.
    task = ds.open_ranking_val_heads

    # given (vid, rid) predict set of mids
    print(f"choosing from {len(task)} queries")

    # randomly pick one
    query, gt = random.choice(list(task.items()))
    vid, rid = query

    # `stats` is empty when no context was counted at all; treat that the
    # same as "zero contexts for every mention" instead of raising KeyError.
    mid_counts = stats.get('mids', {})

    print(f'Who was/were {rid=}: {ds.relations[rid]} {vid=}: {ds.vertices[vid]}')
    for mid in gt:
        print(f"  {mid=}: {ds.mentions[mid]}")
        print(f"     it has {mid_counts.get(mid, 0)} contexts")


# Gather the per-mention context counts for the selected dataset.
stats = load_context_stats(ds=dataset)
# Debugging aid: uncomment to list (count, mid) pairs, highest count first.
# print(sorted(((v, k) for k, v in stats['mids'].items()), reverse=True))
# Print one random ranking query together with those counts.
show_ranking(ds=dataset, stats=stats)

creating context stats; this might take a while
choosing from 4831 queries
Who was/were rid=32: P1050:medical condition vid=223: Q202837:cardiac arrest
  mid=23523: the joan rivers show
     it has 0 contexts
