# KGC Control Experiments

We run two control experiments: one to check the correctness of the metric calculation,
and one to obtain an upper performance bound for chat-based LLMs that propose mentions.
A random-guessing model is included as an additional lower-bound baseline.

In [1]:
import irt2

# Root data directory taken from irt2's environment configuration; all
# dataset and evaluation paths in this notebook are built relative to it.
p_data = irt2.ENV.DIR.DATA

In [2]:
from irt2.types import Split, Task, Sample, MID, RID, VID
from irt2.dataset import IRT2
from irt2.evaluation import Predictions

import random
from typing import Iterable, Literal


# Open-world KGC tasks: each (mention id, relation id) query maps to the set
# of vertex ids that are its correct answers.
Tasks = dict[tuple[MID, RID], set[VID]]


def true_vids(tasks: Tasks, ds: IRT2, **_) -> Predictions:
    """Oracle control model: answers every task with exactly its gold vertices.

    The model cheats by reading the expected answers straight from ``tasks``
    and giving each of them the constant score 1. It serves as a sanity check
    of the metric computation. ``ds`` and any further keyword arguments are
    accepted only for interface compatibility and are ignored.

    Yields:
        ((mid, rid), iterable of (vid, score)) pairs, one per task.
    """
    for key, gold in tasks.items():
        mid, rid = key
        scored = ((vid, 1) for vid in gold)
        yield (mid, rid), scored

def true_mentions(
    tasks: Tasks,
    ds: IRT2,
    split: Literal['validation', 'test'],
    **_,
) -> Predictions:
    """This model cheats and knows the correct mentions.

    For every task, all mentions of the gold vertices are collected and
    passed to ``ds.find_by_mention``. Every vertex returned by that lookup is
    predicted with the constant score 1. This bounds what a model that
    proposes mentions (rather than vertices) can achieve.

    Args:
        tasks: Maps (mention id, relation id) queries to their gold vertex ids.
        ds: Dataset providing the id mappings and the mention lookup.
        split: With 'validation', the lookup searches train and validation
            mentions. With 'test', it additionally searches test mentions.

    Yields:
        ((mid, rid), iterable of (vid, score)) pairs, one per task.
    """
    splits = (Split.train, Split.valid)
    if split == 'test':
        # test evaluation may also resolve mentions introduced in the test split
        splits += (Split.test, )

    ids = ds.idmap
    for (mid, rid), gt_vids in tasks.items():
        # Surface forms of all gold vertices. A gold vertex missing from
        # vid2mids contributes no mentions (``.get`` would otherwise return
        # None and break the iteration).
        mentions = {
            ids.mid2str[gt_mid]
            for gt_vid in gt_vids
            for gt_mid in ids.vid2mids.get(gt_vid, ())
        }

        pr_vids = ds.find_by_mention(
            *mentions,
            splits=splits,
        )

        yield (mid, rid), ((vid, 1) for vid in pr_vids)


def random_guessing(
    tasks: Tasks,
    ds: IRT2,
    split: Literal['validation', 'test'],
    seed: int,
    k: int = 100,
    **_,
) -> Predictions:
    """This model is just guessing randomly.

    Candidates are all vertices of the train and validation splits, plus the
    test split when ``split == 'test'``. For each task, ``k`` distinct
    candidates are sampled (or all of them, if fewer exist), and each one
    receives a uniformly random score.

    Args:
        tasks: Maps (mention id, relation id) queries to their gold vertex
            ids. Only the keys are used.
        ds: Dataset providing the split-to-vertex mapping.
        split: 'validation' or 'test'. Selects the candidate pool.
        seed: Seed of a private random generator. The same seed and tasks
            give the same predictions.
        k: Maximum number of vertices predicted per task.

    Yields:
        ((mid, rid), list of (vid, score)) pairs, one per task.
    """
    rng = random.Random(seed)

    ids = ds.idmap
    candidates = ids.split2vids[Split.train] | ids.split2vids[Split.valid]
    if split == 'test':
        candidates |= ids.split2vids[Split.test]

    # Sort so that sampling does not depend on set iteration order
    # (assumes VIDs are orderable, e.g. ints).
    population = sorted(candidates)
    # rng.sample raises ValueError if asked for more items than exist.
    n_sample = min(k, len(population))

    for (mid, rid), _gold in tasks.items():
        # Draw scores eagerly so the result does not depend on the order in
        # which the evaluator consumes the per-task iterables.
        yield (mid, rid), [
            (vid, rng.random()) for vid in rng.sample(population, k=n_sample)
        ]


def most_likely_relation(tasks: Tasks, ds: IRT2, **_) -> Predictions:
    """This model takes the most likely known tail based on the relation.

    Not implemented yet. It raises explicitly instead of silently returning
    ``None``, which the evaluation could not consume as predictions. Extra
    keyword arguments (e.g. ``split``, ``seed`` passed by ``run``) are
    accepted so the signature matches the other control models.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError('most_likely_relation is not implemented yet')


In [3]:
from itertools import islice

ds = IRT2.from_dir(p_data / 'irt2' / 'irt2-cde-tiny')
model = true_vids(ds.open_kgc_val_tails, ds)

# Peek at a single task: print the query (mention, relation), followed by
# every vertex the oracle model returns for it.
for task, predicted in islice(model, 1):
    mid, rid = task
    print(ds.idmap.mid2str[mid], ds.idmap.rid2str[rid])
    for vid, _score in predicted:
        print('  - ', ds.idmap.vid2str[vid])


coleridge P106:occupation
  -  Q4964182:philosopher
  -  Q49757:poet
  -  Q36180:writer
  -  Q1234713:theologian


In [7]:
from irt2 import evaluation
from ktz.collections import dflat

import yaml
from functools import partial
from typing import Callable


def flatten(report: dict):
    """Turn a nested evaluation report into a single flat row for the CSV.

    The identifying fields (dataset, model, date, split) come first. They are
    followed by the entries of ``report['metrics']`` flattened with ``dflat``,
    using a single space as the key separator. On a key clash, the metric
    entry wins.
    """
    identifying = ('dataset', 'model', 'date', 'split')
    row = {field: report[field] for field in identifying}
    row.update(dflat(report['metrics'], sep=' '))
    return row


def evaluate(
    ds: IRT2,
    name: str,
    split: str,
    head_predictions: Predictions,
    tail_predictions: Predictions,
):
    """Score KGC head/tail predictions and wrap the metrics in a report.

    Args:
        ds: Dataset the predictions refer to.
        name: Model name recorded in the report.
        split: Split the predictions were made for.
        head_predictions: Predictions for the head tasks.
        tail_predictions: Predictions for the tail tasks.

    Returns:
        The report built by ``evaluation.create_report``. It references this
        notebook as its source.
    """
    task = 'kgc'

    scores = evaluation.evaluate(
        ds=ds,
        task=task,
        split=split,
        head_predictions=head_predictions,
        tail_predictions=tail_predictions,
    )

    report = evaluation.create_report(
        scores,
        ds,
        task,
        split=split,
        model=name,
        filenames={'notebook': 'ipynb/control-experiments.ipynb'},
    )
    return report



def run(
    ds: IRT2,
    name: str,
    model: Callable,
    split: str,
    seed: int = 31198,
):
    """Evaluate one control model on the open-world KGC tasks of one split.

    Args:
        ds: Dataset to evaluate on.
        name: Model name recorded in the report.
        model: Control model, called as
            ``model(tasks, ds=ds, split=split, seed=seed)`` once for the head
            tasks and once for the tail tasks. It must tolerate unused
            keyword arguments.
        split: 'validation' or 'test'. Selects which task set is evaluated.
        seed: Forwarded to the model (used by stochastic models).

    Returns:
        The report produced by ``evaluate``.

    Raises:
        ValueError: If ``split`` is neither 'validation' nor 'test'.
    """
    # The task set must match the split reported, otherwise validation
    # results would be labelled as test results.
    if split == 'validation':
        head_tasks, tail_tasks = ds.open_kgc_val_heads, ds.open_kgc_val_tails
    elif split == 'test':
        head_tasks, tail_tasks = ds.open_kgc_test_heads, ds.open_kgc_test_tails
    else:
        raise ValueError(
            f"unknown split {split!r}; expected 'validation' or 'test'"
        )

    curried = partial(
        model,
        ds=ds,
        split=split,
        seed=seed,
    )

    return evaluate(
        ds=ds,
        name=name,
        split=split,
        head_predictions=curried(head_tasks),
        tail_predictions=curried(tail_tasks),
    )


# Sanity check on the tiny dataset loaded above. The other control models
# can be checked the same way by passing `true_vids` or `random_guessing`.
run(ds, name='known-mentions', model=true_mentions, split='validation')

{'date': '2024-04-24T16:42:21.123379',
 'dataset': 'IRT2/CDE-T',
 'model': 'known-mentions',
 'task': 'kgc',
 'split': 'validation',
 'metrics': {'head': {'micro': {'mrr': 0.3998384997562972,
    'hits_at_1': 0.2706081081081081,
    'hits_at_10': 0.6567567567567567},
   'macro': {'mrr': 0.7769782342473822,
    'hits_at_1': 0.6834948917354642,
    'hits_at_10': 0.95504071696998}},
  'tail': {'micro': {'mrr': 0.35972883281774637,
    'hits_at_1': 0.22116982386174808,
    'hits_at_10': 0.7046360917248256},
   'macro': {'mrr': 0.32757701271011713,
    'hits_at_1': 0.20583638274068813,
    'hits_at_10': 0.6333079708624418}},
  'all': {'micro': {'mrr': 0.37295278181954084,
    'hits_at_1': 0.23746936957006015,
    'hits_at_10': 0.6888505235018935},
   'macro': {'mrr': 0.3388788616119889,
    'hits_at_1': 0.21784886520988556,
    'hits_at_10': 0.641399125734815}}},
 'notebook': 'ipynb/control-experiments.ipynb'}

In [8]:
import csv
from irt2.loader import LOADER


def _run_all(datasets, models, splits):
    """Yield one flattened report row per (dataset, model, split) combination.

    Each dataset is loaded through ``LOADER`` when the iteration reaches it.
    Progress is printed as each dataset is loaded and each run starts.
    """
    for loader_key, dataset_path in datasets:
        dataset = LOADER[loader_key](dataset_path)
        print(str(dataset))

        for model_name, model_fn in models:
            for split_name in splits:
                print('  - ', model_name, split_name)
                yield flatten(run(dataset, model_name, model_fn, split_name))


def run_all(out, datasets, models, splits):
    """Evaluate every (dataset, model, split) combination and write a CSV.

    Args:
        out: Path of the CSV file. Missing parent directories are created and
            an existing file is overwritten.
        datasets: Pairs of (LOADER key, dataset path).
        models: Pairs of (report name, model callable).
        splits: Split names, e.g. 'validation' or 'test'.

    The CSV header is taken from the first row. Each row is flushed right
    after it is written, so results computed before a failure are kept.
    """
    out.parent.mkdir(exist_ok=True, parents=True)

    # The csv module requires newline='' (otherwise spurious blank lines
    # appear on Windows). Explicit encoding avoids locale-dependent output.
    with out.open(mode='w', newline='', encoding='utf-8') as fd:
        writer = None

        for flat in _run_all(datasets, models, splits):
            if writer is None:
                writer = csv.DictWriter(fd, fieldnames=list(flat.keys()))
                writer.writeheader()

            writer.writerow(flat)
            fd.flush()


run_all(
    out=p_data / 'evaluation' / 'control-experiments.csv',
    datasets=(
        *(
            ('irt2', p_data / 'irt2' / f'irt2-cde-{size}')
            for size in ('tiny', 'small', 'medium', 'large')
        ),
        ('blp-umls', p_data / 'blp' / 'umls'),
        ('blp-wn18rr', p_data / 'blp' / 'WN18RR'),
        ('blp-fb15k237', p_data / 'blp' / 'FB15k-237'),
        # Wikidata5M is disabled: an earlier run aborted on it with
        # "StatisticsError: mean requires at least one data point".
        # ('blp-wikidata5m', p_data / 'blp' / 'Wikidata5M'),
    ),
    models=(
        ('true-vertices', true_vids),
        ('true-mentions', true_mentions),
        ('random-guessing', random_guessing),
    ),
    splits=('validation', ),
)

IRT2/CDE-T: 12389 vertices | 5 relations | 23894 mentions
  -  true-vertices validation


  -  true-mentions validation


  -  random-guessing validation


IRT2/CDE-S: 14207 vertices | 12 relations | 28582 mentions
  -  true-vertices validation


  -  true-mentions validation


  -  random-guessing validation


IRT2/CDE-M: 15020 vertices | 45 relations | 32666 mentions
  -  true-vertices validation


  -  true-mentions validation


  -  random-guessing validation


IRT2/CDE-L: 15020 vertices | 45 relations | 32666 mentions
  -  true-vertices validation


  -  true-mentions validation


  -  random-guessing validation


UMLS (BLP): 135 vertices | 46 relations | 135 mentions
  -  true-vertices validation
  -  true-mentions validation
  -  random-guessing validation


WN18RR (BLP): 40943 vertices | 11 relations | 40943 mentions
  -  true-vertices validation


  -  true-mentions validation


  -  random-guessing validation


FB15K237 (BLP): 14951 vertices | 237 relations | 14951 mentions
  -  true-vertices validation


  -  true-mentions validation


  -  random-guessing validation


WIKIDATA5M (BLP): 4809397 vertices | 822 relations | 11794981 mentions
  -  true-vertices validation


StatisticsError: mean requires at least one data point