# KGC Control Experiments

We run two control experiments to check the correctness of the metric calculation
and to obtain an upper performance bound for chat-based LLMs that propose mentions.

In [1]:
import irt2

# Root data directory (from irt2.ENV); dataset and evaluation paths in this
# notebook are built relative to it.
p_data = irt2.ENV.DIR.DATA

In [2]:
from irt2.types import Split, Task, Sample, MID, RID, VID
from irt2.dataset import IRT2
from irt2.evaluation import Predictions

import random
from typing import Iterable, Literal


# Open-KGC tasks: (mention id, relation id) query -> set of ground-truth vertex ids.
Tasks = dict[tuple[MID, RID], set[VID]]


def true_vids(tasks: Tasks, ds: IRT2, **_) -> Predictions:
    """Oracle baseline that cheats and always answers correctly.

    Every task is answered with exactly its ground-truth vertices, each with
    score 1. It serves as a sanity check for the metric calculation.

    Args:
        tasks: Maps (mid, rid) queries to their ground-truth vertex ids.
        ds: Unused; accepted so all models share one call signature.

    Yields:
        ((mid, rid), predictions) where predictions is a lazy iterable of
        (vid, 1) pairs.
    """
    for key, gt_vids in tasks.items():
        mid, rid = key
        scored = ((vid, 1) for vid in gt_vids)
        yield (mid, rid), scored

def true_mentions(
    tasks: Tasks,
    ds: IRT2,
    split: Literal['validation', 'test'],
    **_,
) -> Predictions:
    """Oracle baseline that cheats and knows the correct mentions.

    For each task, the mention strings of all ground-truth vertices are
    collected and passed to ``ds.find_by_mention``. The vertices it returns
    are the answer. This simulates a perfect mention proposer.

    Args:
        tasks: Maps (mid, rid) queries to their ground-truth vertex ids.
        ds: Dataset providing the id mappings and the mention lookup.
        split: 'validation' searches the train and valid splits; 'test'
            additionally searches the test split.

    Yields:
        ((mid, rid), predictions) where predictions is a lazy iterable of
        (vid, 1) pairs for every vertex found by mention.
    """
    splits = (Split.train, Split.valid)
    if split == 'test':
        splits += (Split.test,)

    ids = ds.idmap
    for (mid, rid), gt_vids in tasks.items():
        # Use a distinct name (gt_mid) so the query's `mid` is not shadowed.
        # Fall back to () for vertices without registered mentions:
        # `.get` would return None, and iterating None raises TypeError.
        mentions = {
            ids.mid2str[gt_mid]
            for vid in gt_vids
            for gt_mid in ids.vid2mids.get(vid, ())
        }

        pr_vids = ds.find_by_mention(*mentions, splits=splits)
        yield (mid, rid), ((vid, 1) for vid in pr_vids)


def random_guessing(
    tasks: Tasks,
    ds: IRT2,
    split: Literal['validation', 'test'],
    seed: int,
    k: int = 100,
    **_,
) -> Predictions:
    """Baseline that guesses vertices uniformly at random.

    For each task, up to ``k`` distinct candidate vertices are sampled. Each
    one gets a random score in [0, 1).

    Args:
        tasks: Maps (mid, rid) queries to ground-truth vertex ids (only the
            queries are used).
        ds: Dataset whose known vertices form the candidate pool.
        split: 'validation' samples from train and valid vertices; 'test'
            additionally includes test vertices.
        seed: Seed for a private random generator (global ``random`` state
            is left untouched).
        k: Number of guesses per task; capped at the candidate pool size.

    Yields:
        ((mid, rid), predictions) with predictions a list of (vid, score).
    """
    rng = random.Random(seed)

    ids = ds.idmap
    candidates = ids.split2vids[Split.train] | ids.split2vids[Split.valid]
    if split == 'test':
        candidates |= ids.split2vids[Split.test]

    # Sort so the sample depends only on the seed, not on set iteration order.
    pool = sorted(candidates)
    # rng.sample raises ValueError when asked for more items than available.
    n = min(k, len(pool))

    for mid, rid in tasks:
        guesses = rng.sample(pool, k=n)
        # Score eagerly. With a lazy generator, the rng.random() calls would
        # interleave with the consumer's iteration order, so results would
        # depend on how the predictions are consumed.
        yield (mid, rid), [(vid, rng.random()) for vid in guesses]



# Registry of control models, keyed by the name used in the experiment
# configs and written to the reports. Each model is called by `run` as
# model(tasks, ds=..., split=..., seed=...) and yields
# ((mid, rid), iterable of (vid, score)) pairs.
MODELS = {
    'true-vertices': true_vids,
    'true-mentions': true_mentions,
    'random-guessing': random_guessing,
}

In [3]:
from irt2 import evaluation
from ktz.collections import dflat

import yaml
from functools import partial
from typing import Callable


def flatten(report: dict):
    """Turn a nested evaluation report into one flat, CSV-ready row.

    The identifying fields (dataset, model, date, split) come first. They are
    followed by the metrics, flattened with ``dflat(..., sep=' ')`` and
    ordered by key.
    """
    row = {key: report[key] for key in ('dataset', 'model', 'date', 'split')}

    flat_metrics = dflat(report['metrics'], sep=' ')
    for key in sorted(flat_metrics):
        row[key] = flat_metrics[key]

    return row


def evaluate(
    ds: IRT2,
    name: str,
    split: str,
    head_predictions: Predictions,
    tail_predictions: Predictions,
):
    """Score KGC head/tail predictions and wrap the result in a report.

    Args:
        ds: Dataset the predictions were made on.
        name: Model name recorded in the report.
        split: Evaluated split ('validation' or 'test').
        head_predictions: Predictions for the head tasks.
        tail_predictions: Predictions for the tail tasks.

    Returns:
        Whatever ``evaluation.create_report`` returns for these metrics.
    """
    task = 'kgc'
    predictions = dict(
        head_predictions=head_predictions,
        tail_predictions=tail_predictions,
    )
    metrics = evaluation.evaluate(ds=ds, task=task, split=split, **predictions)

    origin = dict(notebook='ipynb/control-experiments.ipynb')
    return evaluation.create_report(
        metrics,
        ds,
        task=task,
        split=split,
        model=name,
        filenames=origin,
    )



def run(
    ds: IRT2,
    name: str,
    model: Callable,
    split: str,
    seed: int,
):
    """Run one control model on one split of ``ds`` and evaluate it.

    Args:
        ds: The (possibly subsampled) dataset.
        name: Model name written into the report.
        model: Predictor, e.g. an entry of ``MODELS``. It is called once for
            the head tasks and once for the tail tasks, with ``ds``, ``split``
            and ``seed`` as keyword arguments.
        split: Either 'validation' or 'test'.
        seed: Passed through to the model.

    Returns:
        The report produced by ``evaluate``.

    Raises:
        ValueError: If ``split`` is neither 'validation' nor 'test'.
    """
    # Validate explicitly: an `assert` is stripped under `python -O`, which
    # would leave the task variables unbound further down.
    if split == 'validation':
        head_tasks, tail_tasks = ds.open_kgc_val_heads, ds.open_kgc_val_tails
    elif split == 'test':
        head_tasks, tail_tasks = ds.open_kgc_test_heads, ds.open_kgc_test_tails
    else:
        raise ValueError(f"split must be 'validation' or 'test', got {split!r}")

    predictor = partial(model, ds=ds, split=split, seed=seed)

    return evaluate(
        ds=ds,
        name=name,
        split=split,
        head_predictions=predictor(head_tasks),
        tail_predictions=predictor(tail_tasks),
    )


In [5]:
import csv
from pathlib import Path
from ktz.collections import dconv
from irt2.loader import LOADER


def _run_all(datasets, models, splits, seed, at_most=None):
    """Yield one flattened report per (dataset, model, split) combination.

    Each dataset is loaded with its configured loader and then reduced with
    ``tasks_subsample(to=at_most, seed=seed)``. Every model is then run on
    every split, in config order.
    """
    for cfg in datasets:
        load = LOADER[cfg['loader']]
        ds = load(cfg['path']).tasks_subsample(to=at_most, seed=seed)
        print(str(ds), f'{at_most=}', f'{seed=}')

        combos = ((model, split) for model in models for split in splits)
        for model, split in combos:
            print('  - ', model, split)
            yield flatten(run(ds, model, MODELS[model], split, seed))


def run_all(out, datasets, models, splits, seed, at_most=None):
    """Run all configured control experiments and write one CSV row per run.

    Args:
        out: Target CSV path (a pathlib.Path); parent directories are created.
        datasets: Dataset configs with 'path' and 'loader' keys.
        models: Names of entries in ``MODELS``.
        splits: Splits to evaluate ('validation' and/or 'test').
        seed: Seed for subsampling and for the models.
        at_most: Limit forwarded to ``tasks_subsample(to=...)``.

    The CSV header is taken from the first report, prefixed with the
    'at most' and 'seed' columns.
    """
    out.parent.mkdir(exist_ok=True, parents=True)

    print(f'write results to {out}')
    # The csv module requires newline='' so it can control line endings
    # itself; otherwise Windows gets blank lines between rows.
    with out.open(mode='w', newline='', encoding='utf-8') as fd:
        writer = None

        for flat in _run_all(datasets, models, splits, seed, at_most):
            if writer is None:
                header = ['at most', 'seed'] + list(flat.keys())
                writer = csv.DictWriter(fd, fieldnames=header)
                writer.writeheader()

            writer.writerow(flat | {'at most': at_most, 'seed': seed})



# Configuration for the control experiments executed by `main` below.
# Commented-out entries are disabled; uncomment them to include the
# corresponding dataset, model, or split.
# NOTE(review): the captured output of this cell lists all three models and
# both splits, so it was produced with a fuller configuration than this one.
# NOTE(review): the captured output reports identical statistics for
# irt2-cde-medium and irt2-cde-large; confirm the large path is intended.
all_config = {
    'datasets': [
        {
            'path': p_data / 'irt2' / 'irt2-cde-tiny',
            'loader': 'irt2',
        },
        {
            'path': p_data / 'irt2' / 'irt2-cde-small',
            'loader': 'irt2',
        },
        {
            'path': p_data / 'irt2' / 'irt2-cde-medium',
            'loader': 'irt2',
        },
        {
            'path': p_data / 'irt2' / 'irt2-cde-large',
            'loader': 'irt2',
        },
        # {   'path': p_data/ 'blp' / 'umls',
        #     'loader': 'blp/umls',
        # },
        {
            'path': p_data/ 'blp' / 'WN18RR',
            'loader': 'blp/wn18rr',
        },
        {
            'path': p_data/ 'blp' / 'FB15k-237',
            'loader': 'blp/fb15k237',
        },
        # {
        #     'path': p_data/ 'blp' / 'Wikidata5M',
        #     'loader': 'blp/wikidata5m',
        # },
    ],
    # Keys of MODELS to run.
    'models': [
        # 'true-vertices',
        'true-mentions',
        # 'random-guessing',
    ],
    'splits': [
        'validation',
        # 'test',
    ],
    'at_most': 1000,  # forwarded to tasks_subsample(to=...) for every dataset
    'seed': 31189,  # used for subsampling and passed on to each model
}

def main(config):
    """Run the control experiments described by ``config``.

    Results are written to
    ``<p_data>/evaluation/control-experiments-<at_most>-<seed>.csv``.
    """
    at_most, seed = config['at_most'], config['seed']
    target = p_data / "evaluation" / f"control-experiments-{at_most}-{seed}.csv"
    run_all(out=target, **config)


main(all_config)

write results to /home/felix/Complex/dkg/irt2/data/evaluation/control-experiments-1000-31189.csv


IRT2/CDE-T: 12389 vertices | 5 relations | 23894 mentions at_most=1000 seed=31189
  -  true-vertices validation
  -  true-vertices test


  -  true-mentions validation


  -  true-mentions test


  -  random-guessing validation
  -  random-guessing test


IRT2/CDE-S: 14207 vertices | 12 relations | 28582 mentions at_most=1000 seed=31189
  -  true-vertices validation
  -  true-vertices test
  -  true-mentions validation


  -  true-mentions test


  -  random-guessing validation
  -  random-guessing test


IRT2/CDE-M: 15020 vertices | 45 relations | 32666 mentions at_most=1000 seed=31189
  -  true-vertices validation
  -  true-vertices test
  -  true-mentions validation


  -  true-mentions test


  -  random-guessing validation
  -  random-guessing test


IRT2/CDE-L: 15020 vertices | 45 relations | 32666 mentions at_most=1000 seed=31189
  -  true-vertices validation
  -  true-vertices test
  -  true-mentions validation


  -  true-mentions test


  -  random-guessing validation
  -  random-guessing test


BLP/WN18RR: 40943 vertices | 11 relations | 40943 mentions at_most=1000 seed=31189
  -  true-vertices validation
  -  true-vertices test
  -  true-mentions validation


  -  true-mentions test


  -  random-guessing validation
  -  random-guessing test


BLP/FB15K237: 14951 vertices | 237 relations | 14951 mentions at_most=1000 seed=31189
  -  true-vertices validation
  -  true-vertices test
  -  true-mentions validation


  -  true-mentions test


  -  random-guessing validation
  -  random-guessing test


In [9]:
from typing import Iterable


# Datasets for the subsampling experiment: `run_subsampling` evaluates each
# of them with the 'true-mentions' model on the validation split at several
# subsampling percentages.
subsample_config = {
    'datasets': [
        {
            'path': p_data / 'irt2' / 'irt2-cde-tiny',
            'loader': 'irt2',
        },
        {
            'path': p_data / 'irt2' / 'irt2-cde-small',
            'loader': 'irt2',
        },
        {
            'path': p_data / 'irt2' / 'irt2-cde-medium',
            'loader': 'irt2',
        },
        {
            'path': p_data / 'irt2' / 'irt2-cde-large',
            'loader': 'irt2',
        },
        {
            'path': p_data/ 'blp' / 'WN18RR',
            'loader': 'blp/wn18rr',
        },
        {
            'path': p_data/ 'blp' / 'FB15k-237',
            'loader': 'blp/fb15k237',
        },
        # {
        #     'path': p_data/ 'blp' / 'Wikidata5M',
        #     'loader': 'blp/wikidata5m',
        # },
    ],
    'seed': 31189,  # used for subsampling and passed on to the model
}


def run_subsampling(out, datasets, seed, percentages: Iterable[float]):
    """Evaluate the 'true-mentions' oracle on subsamples of increasing size.

    For every dataset and every percentage p, the dataset is reduced with
    ``tasks_subsample(to=int(p * n), seed=seed)``. Here n is the number of
    validation head plus tail tasks. One CSV row is written per run.

    Args:
        out: Target CSV path (a pathlib.Path); parent directories are created.
        datasets: Dataset configs with 'path' and 'loader' keys.
        seed: Seed for subsampling and for the model.
        percentages: Fractions of the validation tasks (e.g. 0.01 for 1%).
            Any iterable is accepted, including a one-shot generator.
    """
    # Materialize once: the percentages are iterated again for every dataset,
    # so a generator would be exhausted after the first one.
    percentages = list(percentages)

    out.parent.mkdir(exist_ok=True, parents=True)

    print(f'write results to {out}')
    # The csv module requires newline='' so it can control line endings itself.
    with out.open(mode='w', newline='', encoding='utf-8') as fd:
        writer = None

        for dataset_config in datasets:
            ds = LOADER[dataset_config['loader']](dataset_config['path'])
            print(str(ds))

            # Loop-invariant: total validation tasks of the full dataset.
            total = len(ds.open_kgc_val_heads) + len(ds.open_kgc_val_tails)

            for percentage in percentages:
                at_most = int(percentage * total)

                # Percent formatting keeps fractional percentages; the former
                # int(percentage * 100) printed 2.5% as "2%".
                print(f'  - {percentage:6.1%} = {at_most}', f'{seed=}')
                sub_ds = ds.tasks_subsample(to=at_most, seed=seed)

                report = run(
                    sub_ds,
                    name='true-mentions',
                    model=MODELS['true-mentions'],
                    split='validation',
                    seed=seed,
                )

                flat = flatten(report)

                if writer is None:
                    header = ['percentage', 'at most', 'seed'] + list(flat.keys())
                    writer = csv.DictWriter(fd, fieldnames=header)
                    writer.writeheader()

                writer.writerow(
                    flat | {'percentage': percentage, 'at most': at_most, 'seed': seed}
                )


def subsample_experiments(
    config,
    ks: Iterable[int],
    percentages: Iterable[float] = (0.01, 0.025),
):
    """Run the subsampling experiment and write its CSV.

    Results go to ``<p_data>/evaluation/subsample-experiments-<seed>.csv``.

    Args:
        config: Dict with 'datasets' and 'seed', forwarded to
            ``run_subsampling``.
        ks: Currently unused; kept so existing calls keep working.
        percentages: Fractions of the validation tasks to evaluate. The
            default is the previously hard-coded 1% and 2.5%. For a full
            sweep pass e.g. ``[x / 100 for x in range(5, 101, 5)]``.
    """
    root = p_data / "evaluation"
    fcsv = "subsample-experiments-{seed}.{suffix}".format(
        seed=config['seed'],
        suffix='csv',
    )

    run_subsampling(
        out=root / fcsv,
        percentages=percentages,
        **config,
    )


subsample_experiments(subsample_config, ks=[50, 100])

write results to /home/felix/Complex/dkg/irt2/data/evaluation/subsample-experiments-31189.csv


IRT2/CDE-T: 12389 vertices | 5 relations | 23894 mentions
  -   1% = 58 seed=31189
  -   2% = 147 seed=31189


IRT2/CDE-S: 14207 vertices | 12 relations | 28582 mentions
  -   1% = 141 seed=31189
  -   2% = 354 seed=31189


IRT2/CDE-M: 15020 vertices | 45 relations | 32666 mentions
  -   1% = 269 seed=31189


  -   2% = 673 seed=31189


IRT2/CDE-L: 15020 vertices | 45 relations | 32666 mentions
  -   1% = 213 seed=31189


  -   2% = 534 seed=31189


BLP/WN18RR: 40943 vertices | 11 relations | 40943 mentions
  -   1% = 175 seed=31189


  -   2% = 439 seed=31189


BLP/FB15K237: 14951 vertices | 237 relations | 14951 mentions
  -   1% = 342 seed=31189


  -   2% = 855 seed=31189
