# KGC Control Experiments

We run control experiments to check the correctness of the metric calculation
and to obtain an approximate performance bound for chat-based LLMs that propose mentions.

In [1]:
# IPython autoreload extension: re-import changed modules (e.g. irt2)
# before each cell is executed, so library edits show up without a restart.
%load_ext autoreload
%autoreload 2

import irt2

# Data root of the project; the experiment CSVs below are written to
# p_data / "evaluation".
p_data = irt2.ENV.DIR.DATA

In [2]:
from irt2.types import Split, Task, Sample, MID, RID, VID
from irt2.dataset import IRT2
from irt2.evaluation import Predictions

import random
from typing import Iterable, Literal


# Open-world KGC tasks: each (mention id, relation id) pair is mapped to the
# set of vertex ids that form its ground truth.
Tasks = dict[tuple[MID, RID], set[VID]]


def true_vids(tasks: Tasks, ds: IRT2, **_) -> Predictions:
    """Oracle baseline that cheats by answering with the ground truth.

    Every ground-truth vertex of a task is returned with score 1, so this
    model always answers correctly. ``ds`` and any extra keyword arguments
    are ignored.
    """
    for task_key, gt_vids in tasks.items():
        mid, rid = task_key
        scored = ((vid, 1) for vid in gt_vids)
        yield (mid, rid), scored

def true_mentions(
    tasks: Tasks,
    ds: IRT2,
    split: Literal['validation', 'test'],
    **_,
) -> Predictions:
    """Oracle baseline that cheats by knowing the correct mentions.

    For each task, the mention strings of all ground-truth vertices are
    collected and resolved back to vertices with ``ds.find_by_mention``.
    Train and validation mentions are always searched; test mentions are
    added only when ``split == 'test'``. Every resolved vertex gets score 1.
    """
    searched = [Split.train, Split.valid]
    if split == 'test':
        searched.append(Split.test)
    searched = tuple(searched)

    idmap = ds.idmap
    for (mid, rid), gt_vids in tasks.items():
        # use a distinct name for the ground-truth mention ids so the task's
        # own ``mid`` is not shadowed
        mention_strs = set()
        for vid in gt_vids:
            for gt_mid in idmap.vid2mids.get(vid):
                mention_strs.add(idmap.mid2str[gt_mid])

        found = ds.find_by_mention(*mention_strs, splits=searched)
        yield (mid, rid), ((vid, 1) for vid in found)


def random_guessing(
    tasks: Tasks,
    ds: IRT2,
    split: Literal['validation', 'test'],
    seed: int,
    k: int = 100,
    **_,
) -> Predictions:
    """This model is just guessing randomly.

    For every task, ``k`` distinct candidate vertices are drawn uniformly
    without replacement and each receives a uniform random score in
    [0, 1). Candidates are the train and validation vertices; test vertices
    are added when ``split == 'test'``.

    Args:
        tasks: maps (mid, rid) tasks to their ground-truth vertex ids; only
            the task keys are used.
        ds: dataset providing ``idmap.split2vids``.
        split: 'validation' or 'test'; selects the candidate pool.
        seed: seed of a private ``random.Random`` instance.
        k: number of guesses per task; capped at the number of candidates.

    Returns:
        One ((mid, rid), [(vid, score), ...]) pair per task.
    """
    rng = random.Random(seed)

    ids = ds.idmap
    candidates = ids.split2vids[Split.train] | ids.split2vids[Split.valid]
    if split == 'test':
        candidates |= ids.split2vids[Split.test]

    # sort so that the draw depends only on the seed, not on set iteration order
    pool = sorted(candidates)
    # rng.sample raises ValueError if k exceeds the population size
    n_guesses = min(k, len(pool))

    for mid, rid in tasks:
        guesses = rng.sample(pool, k=n_guesses)
        # draw the scores eagerly: with a lazy generator the rng calls would
        # interleave with the consumer's iteration order and the results for
        # a given seed would depend on how the evaluator reads them
        scored = [(vid, rng.random()) for vid in guesses]
        yield (mid, rid), scored



# Registry of the control models by name; _run_all and run_subsampling look
# the prediction functions up here.
MODELS = {
    'true-vertices': true_vids,
    'true-mentions': true_mentions,
    'random-guessing': random_guessing,
}

In [3]:
from irt2 import evaluation
from ktz.collections import dflat

import yaml
from functools import partial
from typing import Callable


def flatten(report: dict):
    """Turn a nested evaluation report into a single flat row.

    The identifying fields (dataset, model, date, split) come first. They
    are followed by the metrics, flattened with ``dflat`` (nested keys
    joined by a single space) and ordered by key.
    """
    head_keys = ('dataset', 'model', 'date', 'split')
    row = {key: report[key] for key in head_keys}

    flat_metrics = dflat(report['metrics'], sep=' ')
    for key in sorted(flat_metrics):
        row[key] = flat_metrics[key]

    return row


def evaluate(
    ds: IRT2,
    name: str,
    split: str,
    head_predictions: Predictions,
    tail_predictions: Predictions,
):
    """Score KGC head/tail predictions and wrap the metrics in a report.

    Args:
        ds: dataset the predictions were made for.
        name: model name stored in the report.
        split: evaluated split, forwarded to the evaluation.
        head_predictions: predictions for the head tasks.
        tail_predictions: predictions for the tail tasks.

    Returns:
        The report built by ``evaluation.create_report``; it records this
        notebook as its source file.
    """
    task = 'kgc'
    scores = evaluation.evaluate(
        ds=ds,
        task=task,
        split=split,
        head_predictions=head_predictions,
        tail_predictions=tail_predictions,
    )

    provenance = dict(notebook='ipynb/control-experiments.ipynb')
    return evaluation.create_report(
        scores,
        ds,
        task=task,
        split=split,
        model=name,
        filenames=provenance,
    )



def run(
    ds: IRT2,
    name: str,
    model: Callable,
    split: str,
    seed: int,
):
    """Run one control model on the open-world KGC tasks of a split.

    Args:
        ds: dataset providing the head and tail tasks.
        name: model name recorded in the report.
        model: prediction function, called as
            ``model(tasks, ds=ds, split=split, seed=seed)``.
        split: either 'validation' or 'test'.
        seed: forwarded to the model.

    Returns:
        The report produced by ``evaluate``.

    Raises:
        ValueError: if ``split`` is neither 'validation' nor 'test'.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would leave the task sets unbound below.
    if split == 'validation':
        head_tasks, tail_tasks = ds.open_kgc_val_heads, ds.open_kgc_val_tails
    elif split == 'test':
        head_tasks, tail_tasks = ds.open_kgc_test_heads, ds.open_kgc_test_tails
    else:
        raise ValueError(f"split must be 'validation' or 'test', got {split!r}")

    predictor = partial(
        model,
        ds=ds,
        split=split,
        seed=seed,
    )

    return evaluate(
        ds=ds,
        name=name,
        split=split,
        head_predictions=predictor(head_tasks),
        tail_predictions=predictor(tail_tasks),
    )


In [4]:
# Print the kernel's current working directory (IPython shell escape).
!pwd

/home/felix/Complex/dkg/irt2/ipynb


In [5]:
import csv
from pathlib import Path
from ktz.collections import dconv, dflat
from irt2.loader import from_config_file


def _run_all(datasets_config, models, splits, seed: int):
    """Evaluate every model on every split of every configured dataset.

    Datasets are loaded with ``from_config_file(**datasets_config)``.
    Progress is printed. For each (dataset, split, model) one flat result
    row is yielded: the subsample percentage and task counts merged with
    the flattened evaluation report.

    Args:
        datasets_config: keyword arguments for ``from_config_file``.
        models: names of entries in ``MODELS``.
        splits: any of 'validation' and 'test'.
        seed: forwarded to every model run.

    Raises:
        ValueError: for an unknown split or model name. This is raised on
            the first ``next()``, before any dataset is loaded.
    """
    # Materialize once: splits are re-iterated for every dataset and models
    # for every split, so a one-shot iterator would silently skip runs.
    models, splits = list(models), list(splits)

    # Validate up front so a typo fails before the (expensive) dataset
    # loading. An unknown split would otherwise leave the task counts unbound,
    # or reuse the previous split's counts.
    bad_splits = [split for split in splits if split not in ('validation', 'test')]
    if bad_splits:
        raise ValueError(f"unknown split(s) {bad_splits}; use 'validation' or 'test'")

    bad_models = [model for model in models if model not in MODELS]
    if bad_models:
        raise ValueError(f'unknown model(s) {bad_models}; choose from {list(MODELS)}')

    datasets = from_config_file(
        root_path=irt2.ENV.DIR.ROOT,
        **datasets_config,
    )

    for _, dataset in datasets:
        print('\n', str(dataset))

        for split in splits:
            if split == 'validation':
                n_heads = len(dataset.open_kgc_val_heads)
                n_tails = len(dataset.open_kgc_val_tails)
            else:  # 'test' (validated above)
                n_heads = len(dataset.open_kgc_test_heads)
                n_tails = len(dataset.open_kgc_test_tails)

            # Subsample fraction configured for this split in the loader
            # options. It stays None when there is no "subsample" option or
            # no entry for this split.
            options = dataset.meta['loader']
            percentage = None
            if "subsample" in options:
                percentage = options["subsample"].get(split, None)

            print(
                '  ' + split,
                f'percentage={percentage}',
                f'{n_heads} head and {n_tails} tail tasks'
                f' = {n_heads + n_tails}',
                sep='\n    - ',
            )

            meta = {
                'percentage': percentage,
                'total tasks': n_heads + n_tails,
                'head tasks': n_heads,
                'tail tasks': n_tails,
            }

            for model in models:
                print('    - model: ', model)
                report = run(dataset, model, MODELS[model], split, seed)

                # headline metric, printed for progress only
                h10 = report['metrics']['all']['micro']['hits_at_10']
                print(f'    - result: {h10:2.3f}')

                yield meta | flatten(report)


def run_all(out, datasets_config, models, splits, seed: int):
    """Run all control experiments and write one CSV row per result.

    Args:
        out: ``pathlib.Path`` of the CSV file; parent directories are
            created.
        datasets_config: keyword arguments for ``from_config_file``.
        models: names of entries in ``MODELS`` to evaluate.
        splits: splits to evaluate ('validation' and/or 'test').
        seed: seed passed to every model; also written as the first CSV
            column.
    """
    out.parent.mkdir(exist_ok=True, parents=True)

    print(f'write results to {out}')
    # The csv module requires newline='' on the file object. Otherwise its
    # '\r\n' row terminators get translated again on some platforms.
    with out.open(mode='w', newline='', encoding='utf-8') as fd:
        writer = None

        for flat in _run_all(datasets_config, models, splits, seed):
            if writer is None:
                # the header is taken from the keys of the first row
                header = ['seed'] + list(flat.keys())

                writer = csv.DictWriter(fd, fieldnames=header)
                writer.writeheader()

            writer.writerow(flat | {'seed': seed})



# Configuration of the control experiments. main() forwards all keys as
# keyword arguments to run_all().
# NOTE(review): the output recorded below this cell was written to
# control-experiments-full-subsampled.yaml.csv. It therefore stems from the
# commented-out full-subsampled.yaml config, not the active one; rerun to
# refresh it.
all_config = {
    # keyword arguments for from_config_file(); swap the active
    # 'config_file' line to switch dataset collections
    'datasets_config': {
        'config_file': irt2.ENV.DIR.CONF / 'datasets' / 'original-subsampled.yaml',
        # 'config_file': irt2.ENV.DIR.CONF / 'datasets' / 'full-subsampled.yaml',
        # presumably excludes datasets by loader name -- TODO confirm
        # 'without': ('blp/wikidata5m', )
    },
    # names of entries in MODELS
    'models': ['true-mentions'],
    'splits': [
        'validation',
        'test',
    ],
    # passed to every model and written to the CSV
    'seed': 31189,
}


def main(config):
    """Run the control experiments described by ``config``.

    Results go to ``p_data / "evaluation" /
    "control-experiments-<config file name>.csv"``. Every entry of
    ``config`` is forwarded to ``run_all`` as a keyword argument.
    """
    dataset_cfg_file = config['datasets_config']['config_file']
    csv_path = p_data / "evaluation" / f"control-experiments-{dataset_cfg_file.name}.csv"
    run_all(out=csv_path, **config)


# Run the control experiments with the configuration defined above.
main(all_config)
print('done')

write results to /home/felix/Complex/dkg/irt2/data/evaluation/control-experiments-full-subsampled.yaml.csv



 BLP/FB15K237: 14541 vertices | 237 relations | 14541 mentions
  validation
    - percentage=0.03
    - 452 head and 574 tail tasks = 1026
    - model:  true-mentions


    - result: 1.000
  test
    - percentage=0.03
    - 541 head and 744 tail tasks = 1285
    - model:  true-mentions


    - result: 1.000



 BLP/WIKIDATA5M: 4818582 vertices | 822 relations | 11804166 mentions
  validation
    - percentage=0.09
    - 520 head and 579 tail tasks = 1099
    - model:  true-mentions


    - result: 0.998
  test
    - percentage=0.08
    - 471 head and 529 tail tasks = 1000
    - model:  true-mentions


    - result: 0.998



 BLP/WN18RR: 40943 vertices | 11 relations | 40943 mentions
  validation
    - percentage=0.06
    - 482 head and 573 tail tasks = 1055
    - model:  true-mentions


    - result: 0.987
  test
    - percentage=0.06
    - 516 head and 601 tail tasks = 1117
    - model:  true-mentions


    - result: 0.987



 IRT2/CDE-L: 15020 vertices | 45 relations | 32666 mentions
  validation
    - percentage=0.05
    - 75 head and 994 tail tasks = 1069
    - model:  true-mentions


    - result: 0.669
  test
    - percentage=0.02
    - 66 head and 931 tail tasks = 997
    - model:  true-mentions


    - result: 0.625



 IRT2/CDE-M: 15020 vertices | 45 relations | 32666 mentions
  validation
    - percentage=0.04
    - 70 head and 1007 tail tasks = 1077
    - model:  true-mentions


    - result: 0.661
  test
    - percentage=0.01
    - 72 head and 1007 tail tasks = 1079
    - model:  true-mentions


    - result: 0.678



 IRT2/CDE-S: 14207 vertices | 12 relations | 28582 mentions
  validation
    - percentage=0.08
    - 30 head and 1104 tail tasks = 1134
    - model:  true-mentions
    - result: 0.628
  test
    - percentage=0.02
    - 27 head and 1102 tail tasks = 1129
    - model:  true-mentions


    - result: 0.694



 IRT2/CDE-T: 12389 vertices | 5 relations | 23894 mentions
  validation
    - percentage=0.17
    - 25 head and 975 tail tasks = 1000
    - model:  true-mentions
    - result: 0.747
  test
    - percentage=0.02
    - 25 head and 1031 tail tasks = 1056
    - model:  true-mentions


    - result: 0.793
done


In [6]:
from typing import Iterable
from irt2.loader import from_config_file

def run_subsampling(out, config_file, percentages: Iterable[float], seed: int):
    """Evaluate the true-mentions oracle on increasingly large task subsamples.

    For every dataset of ``config_file`` and every percentage, the
    validation KGC tasks are subsampled with
    ``dataset.tasks_subsample_kgc``. The 'true-mentions' model is then
    evaluated on them, and one CSV row per (dataset, percentage) is written
    to ``out``.

    Args:
        out: ``pathlib.Path`` of the CSV file; parent directories are
            created.
        config_file: dataset configuration passed to ``from_config_file``.
        percentages: validation subsample fractions (e.g. 0.05 for 5%).
        seed: seed for subsampling and the model; written to the CSV.

    Raises:
        ValueError: if a dataset is already subsampled by its loader
            config; the sweep expects the unsubsampled task sets.
    """
    # Materialize once: the percentages are iterated again for every
    # dataset, so a one-shot iterator would cover only the first dataset.
    percentages = list(percentages)

    out.parent.mkdir(exist_ok=True, parents=True)

    print(f'write results to {out}')
    # newline='' as required by the csv module
    with out.open(mode='w', newline='', encoding='utf-8') as fd:
        writer = None

        datasets = from_config_file(
            config_file,
            root_path=irt2.ENV.DIR.ROOT
        )

        for _, dataset in datasets:
            print(str(dataset))
            print(dataset.meta['loader'])
            # explicit check instead of ``assert`` (stripped under -O)
            if "subsample" in dataset.meta['loader']:
                raise ValueError(
                    f'{dataset} is already subsampled by its loader config;'
                    ' use a dataset config without "subsample"'
                )

            for percentage in percentages:
                # round() instead of int(): int(0.29 * 100) == 28
                print(f'  - {round(percentage * 100):3d}%', f'{seed=}')
                sub_ds = dataset.tasks_subsample_kgc(
                    seed=seed,
                    percentage_val=percentage,
                )

                report = run(
                    sub_ds,
                    name='true-mentions',
                    model=MODELS['true-mentions'],
                    split='validation',
                    seed=seed,
                )

                flat = flatten(report)

                if writer is None:
                    # the header is taken from the keys of the first report
                    header = ['percentage', 'head tasks', 'tail tasks', 'seed'] + list(flat.keys())
                    writer = csv.DictWriter(fd, fieldnames=header)
                    writer.writeheader()

                writer.writerow(flat | {
                    'percentage': percentage,
                    'head tasks': len(sub_ds.open_kgc_val_heads),
                    'tail tasks': len(sub_ds.open_kgc_val_tails),
                    'seed': seed
                })


def subsample_experiments(config_file, percentages, seed):
    """Run the subsampling sweep for the datasets in ``config_file``.

    Results are written to ``p_data / "evaluation" /
    "subsample-experiments-<config file name>.csv"``.
    """
    target = p_data / "evaluation" / f"subsample-experiments-{config_file.name}.csv"
    run_subsampling(
        out=target,
        config_file=config_file,
        percentages=percentages,
        seed=seed,
    )


# Subsampling sweep over the validation tasks: 1-9% in steps of 1,
# 10-35% in steps of 5, and 40-100% in steps of 20.
# NOTE(review): the output recorded below this cell was written to
# subsample-experiments-full.yaml.csv. It therefore stems from the
# commented-out full.yaml config, not original.yaml; rerun to refresh it.
subsample_experiments(
    irt2.ENV.DIR.CONF / 'datasets' / 'original.yaml',
    # irt2.ENV.DIR.CONF / 'datasets' / 'full.yaml',
    (
        [x/100 for x in range(1, 10)] +
        [x/100 for x in range(10, 40, 5)] +
        [x/100 for x in range(40, 101, 20)]
    ),
    seed=31189,
)

print('done')

write results to /home/felix/Complex/dkg/irt2/data/evaluation/subsample-experiments-full.yaml.csv


BLP/FB15K237: 14541 vertices | 237 relations | 14541 mentions
{'loader': 'blp/fb15k237', 'path': 'data/blp/FB15k-237'}
  -   1% seed=31189
  -   2% seed=31189


  -   3% seed=31189


  -   4% seed=31189


  -   5% seed=31189


  -   6% seed=31189


  -   7% seed=31189


  -   8% seed=31189


  -   9% seed=31189


  -  10% seed=31189


  -  15% seed=31189


  -  20% seed=31189


  -  25% seed=31189


  -  30% seed=31189


  -  35% seed=31189


  -  40% seed=31189


  -  60% seed=31189


  -  80% seed=31189


  - 100% seed=31189


BLP/WIKIDATA5M: 4818582 vertices | 822 relations | 11804166 mentions
{'loader': 'blp/wikidata5m', 'path': 'data/blp/Wikidata5M'}
  -   1% seed=31189


  -   2% seed=31189


  -   3% seed=31189


  -   4% seed=31189


  -   5% seed=31189


  -   6% seed=31189


  -   7% seed=31189


  -   8% seed=31189


  -   9% seed=31189


  -  10% seed=31189


  -  15% seed=31189


  -  20% seed=31189


  -  25% seed=31189


  -  30% seed=31189


  -  35% seed=31189


  -  40% seed=31189


  -  60% seed=31189


  -  80% seed=31189


  - 100% seed=31189


BLP/WN18RR: 40943 vertices | 11 relations | 40943 mentions
{'loader': 'blp/wn18rr', 'path': 'data/blp/WN18RR'}
  -   1% seed=31189


  -   2% seed=31189
  -   3% seed=31189


  -   4% seed=31189


  -   5% seed=31189


  -   6% seed=31189


  -   7% seed=31189


  -   8% seed=31189


  -   9% seed=31189


  -  10% seed=31189


  -  15% seed=31189


  -  20% seed=31189


  -  25% seed=31189


  -  30% seed=31189


  -  35% seed=31189


  -  40% seed=31189


  -  60% seed=31189


  -  80% seed=31189


  - 100% seed=31189


IRT2/CDE-L: 15020 vertices | 45 relations | 32666 mentions
{'loader': 'irt2', 'path': 'data/irt2/irt2-cde-large', 'kwargs': {'mode': 'full'}}
  -   1% seed=31189
  -   2% seed=31189


  -   3% seed=31189


  -   4% seed=31189


  -   5% seed=31189


  -   6% seed=31189


  -   7% seed=31189


  -   8% seed=31189


  -   9% seed=31189


  -  10% seed=31189


  -  15% seed=31189


  -  20% seed=31189


  -  25% seed=31189


  -  30% seed=31189


  -  35% seed=31189


  -  40% seed=31189


  -  60% seed=31189


  -  80% seed=31189


  - 100% seed=31189


IRT2/CDE-M: 15020 vertices | 45 relations | 32666 mentions
{'loader': 'irt2', 'path': 'data/irt2/irt2-cde-medium', 'kwargs': {'mode': 'full'}}
  -   1% seed=31189


  -   2% seed=31189


  -   3% seed=31189


  -   4% seed=31189


  -   5% seed=31189


  -   6% seed=31189


  -   7% seed=31189


  -   8% seed=31189


  -   9% seed=31189


  -  10% seed=31189


  -  15% seed=31189


  -  20% seed=31189


  -  25% seed=31189


  -  30% seed=31189


  -  35% seed=31189


  -  40% seed=31189


  -  60% seed=31189


  -  80% seed=31189


  - 100% seed=31189


IRT2/CDE-S: 14207 vertices | 12 relations | 28582 mentions
{'loader': 'irt2', 'path': 'data/irt2/irt2-cde-small', 'kwargs': {'mode': 'full'}}
  -   1% seed=31189
  -   2% seed=31189


  -   3% seed=31189
  -   4% seed=31189


  -   5% seed=31189
  -   6% seed=31189


  -   7% seed=31189


  -   8% seed=31189


  -   9% seed=31189


  -  10% seed=31189


  -  15% seed=31189


  -  20% seed=31189


  -  25% seed=31189


  -  30% seed=31189


  -  35% seed=31189


  -  40% seed=31189


  -  60% seed=31189


  -  80% seed=31189


  - 100% seed=31189


IRT2/CDE-T: 12389 vertices | 5 relations | 23894 mentions
{'loader': 'irt2', 'path': 'data/irt2/irt2-cde-tiny', 'kwargs': {'mode': 'full'}}
  -   1% seed=31189


  -   2% seed=31189
  -   3% seed=31189


  -   4% seed=31189
  -   5% seed=31189


  -   6% seed=31189
  -   7% seed=31189


  -   8% seed=31189
  -   9% seed=31189


  -  10% seed=31189
  -  15% seed=31189


  -  20% seed=31189


  -  25% seed=31189


  -  30% seed=31189


  -  35% seed=31189


  -  40% seed=31189


  -  60% seed=31189


  -  80% seed=31189


  - 100% seed=31189


done
