In [None]:
# different statistics and descriptive plots regarding irt2 metrics

In [1]:
%load_ext autoreload
%autoreload 2

In [4]:
import irt2
from irt2.dataset import IRT2

def load_datasets(
    root=None,
    names=(
        'irt2-cde-tiny',
        'irt2-cde-small',
        'irt2-cde-medium',
        'irt2-cde-large',
    ),
):
    """Load a set of IRT2 datasets that live below a common root directory.

    Parameters
    ----------
    root : pathlib.Path, optional
        Directory containing the dataset folders. Defaults to
        ``irt2.ENV.DIR.DATA / 'irt2' / 'cde'``.
    names : iterable of str
        Sub-directory names below ``root`` to load, in order. Pass e.g.
        ``('tiny-weighted', 'small-weighted', 'medium-weighted',
        'large-weighted')`` to load the weighted variants instead.

    Returns
    -------
    dict
        Maps each directory name (``path.name``) to the result of
        ``IRT2.from_dir`` for that directory, in the order of ``names``.
    """
    if root is None:
        # resolved lazily so the default follows the current irt2 environment
        root = irt2.ENV.DIR.DATA / 'irt2' / 'cde'

    paths = (root / name for name in names)
    return {path.name: IRT2.from_dir(path=path) for path in paths}

# load every configured dataset and print a short summary for each one
datasets = load_datasets()
for name, dataset in datasets.items():
    print(f"\nloaded {name}:")
    print(f"{dataset.graph}")
    # NOTE(review): the message says "ratios" but counts `dataset.relations`,
    # while get_triple_counts reads `dataset.ratios` — confirm both describe
    # the same set of relations
    print(f"got ratios for {len(dataset.relations)} relations")



loaded irt2-cde-tiny:
IRT graph: [IRT2/CDE-T] (12389 entities)
got ratios for 5 relations

loaded irt2-cde-small:


IRT graph: [IRT2/CDE-S] (14207 entities)
got ratios for 12 relations

loaded irt2-cde-medium:


IRT graph: [IRT2/CDE-M] (15020 entities)
got ratios for 45 relations

loaded irt2-cde-large:


IRT graph: [IRT2/CDE-L] (15020 entities)
got ratios for 45 relations


In [7]:
# load source graph

import irt2
from irt2.graph import Graph
from irt2.graph import load_graph


def load_source_graph(config):
    """Build the source graph described by a dataset creation config.

    ``config`` is a mapping such as the ``'create'`` section of an IRT2
    dataset's config. Its ``'graph loader'`` and ``'graph name'`` entries are
    passed through unchanged; every path listed under ``'graph loader args'``
    and ``'graph loader kwargs'`` is resolved against ``irt2.ENV.DIR.ROOT``
    before everything is handed to ``load_graph``.
    """
    loader = config['graph loader']
    graph_name = config['graph name']

    root = irt2.ENV.DIR.ROOT
    positional = [root / rel for rel in config['graph loader args']]
    keyword = {key: root / rel for key, rel in config['graph loader kwargs'].items()}

    return load_graph(loader, graph_name, *positional, **keyword)


# NOTE(review): the 'create' config of the first loaded dataset is used, which
# presumably relies on all datasets being sampled from the same source graph —
# confirm
source_graph = load_source_graph(config=list(datasets.values())[0].config['create'])
print(str(source_graph))

IRT graph: [CodEx-M] (17050 entities)


In [8]:
# We want to check whether the greedy subsampling approach actually yields
# comparable ratios between relations, or whether some larger bias can be
# spotted.
#
# (Think of a relation "fifty" in which 50% of all vertices take part, and
# another relation "hunny" (hundred) in which all vertices take part. In the
# most extreme case, if by chance the greedy approach saturates "hunny" without
# ever encountering any "fifty" vertices, no vertex for "fifty" will be
# included in the closed world.)

import numpy as np
import numpy.ma as ma

from pprint import pprint

from irt2.graph import Relation
from ktz.collections import Incrementer


def get_triple_counts(source_graph, datasets):
    """Count the triples per relation for the source graph and every dataset.

    Parameters
    ----------
    source_graph : Graph
        The graph the datasets were derived from; its relations are obtained
        via ``Relation.from_graph`` and reported under the name ``'original'``.
    datasets : dict
        Maps dataset names to loaded datasets. Each dataset's ``ratios`` must
        support ``len`` and iterate over relations exposing ``name`` and
        ``triples``.

    Returns
    -------
    tuple
        ``(ds_names, rel_names, counts)``. ``counts`` is a masked array with
        one row per entry of ``ds_names`` and one column per relation of the
        original graph; column ``j`` belongs to ``rel_names[j]``. Entries for
        relations a dataset does not contain stay masked. Rows are ordered by
        ascending number of relations, and ``'original'`` is always the last
        row, because the normalised table treats the last row as its baseline.
    """
    reldic = {'original': Relation.from_graph(g=source_graph)}
    reldic |= {name: dataset.ratios for name, dataset in datasets.items()}

    shape = (len(reldic), len(reldic['original']))
    counts = ma.array(np.zeros(shape, dtype=np.uint), mask=True)

    idxs = Incrementer()

    # Start with the smallest dataset so that relations shared by (almost)
    # all datasets receive the lowest column indices. 'original' is forced to
    # the end even if a dataset covers just as many relations; a plain sort
    # by size would keep it before such a dataset (stable sort, inserted
    # first) and break the "last row is the baseline" invariant.
    lis = sorted(
        reldic.items(),
        key=lambda kv: (kv[0] == 'original', len(kv[1])),
    )

    for i, (_, rels) in enumerate(lis):
        # within one dataset, larger relations get smaller column indices
        for rel in sorted(rels, key=lambda rel: len(rel.triples), reverse=True):
            # tuple indexing writes data and unmasks the cell directly,
            # instead of relying on counts[i] being a shared-mask view
            counts[i, idxs[rel.name]] = len(rel.triples)

    ds_names = [name for name, _ in lis]
    rel_names = list(idxs.keys())

    return ds_names, rel_names, counts


def _create_counts_table(ds_names, rel_names, counts, normalise=False):
    """Arrange triple counts as table rows: one row per relation.

    Parameters
    ----------
    ds_names : sequence of str
        Dataset names; ``ds_names[i]`` labels row ``i`` of ``counts``.
    rel_names : sequence of str
        Relation names of the form ``'<id>:<label>'``; ``rel_names[j]``
        labels column ``j`` of ``counts``. Only the first ``':'`` separates
        id and label, so labels may contain colons themselves; a name
        without any ``':'`` gets an empty label.
    counts : numpy.ma.MaskedArray
        Counts with one row per dataset and one column per relation; masked
        entries mark relations missing from a dataset.
    normalise : bool
        If true, turn every row into relative frequencies and subtract the
        last row (the baseline) from every row.

    Returns
    -------
    list of tuple
        A header row ``('id', 'relation', *ds_names)`` followed by one row
        per relation. Masked entries become ``None``. Commas are removed from
        labels, keeping the rows safe for naive comma joining.
    """
    if normalise:
        # keepdims yields an (n, 1) column so the division broadcasts per row
        counts = counts / counts.sum(axis=1, keepdims=True)
        counts = counts - counts[-1]

    rel_ids, rel_labels = [], []
    for rel_name in rel_names:
        # partition splits on the first ':' only and never raises
        rel_id, _, label = rel_name.partition(':')
        rel_ids.append(rel_id)
        rel_labels.append(label.replace(',', ''))

    cols = [
        ['id'] + rel_ids,
        ['relation'] + rel_labels,
    ]
    cols.extend([name] + row.tolist() for name, row in zip(ds_names, counts))

    # transpose the column lists into rows
    return list(zip(*cols))


def print_counts_csv(*args, **kwargs):
    """Print the triple-count tables (absolute and normalised) as CSV.

    All arguments are forwarded unchanged to ``get_triple_counts``. Two tables
    are printed, each preceded by a heading line: the raw counts
    (``counts:``) and the normalised table produced by
    ``_create_counts_table(..., normalise=True)`` (``norms:``).

    Masked cells (``None``) are printed as empty fields. Fields containing
    the delimiter, quotes or line breaks are quoted by the ``csv`` module
    instead of silently shifting the column layout.
    """
    import csv
    import io

    table_args = get_triple_counts(*args, **kwargs)

    def to_csv(rows):
        # csv.writer renders None as '' and quotes only where necessary;
        # the final line terminator is dropped because print() adds one
        buf = io.StringIO()
        csv.writer(buf, lineterminator='\n').writerows(rows)
        return buf.getvalue().removesuffix('\n')

    print('\ncounts:')
    print(to_csv(_create_counts_table(*table_args)))

    print('\nnorms:')
    print(to_csv(_create_counts_table(*table_args, normalise=True)))


# print absolute and normalised per-relation triple counts for the source
# graph and all loaded datasets as CSV
print_counts_csv(source_graph=source_graph, datasets=datasets)


counts:
id,relation,irt2-cde-tiny,irt2-cde-small,irt2-cde-medium,irt2-cde-large,original
P106,occupation,1504,3140,3094,27529,71596
P27,country of citizenship,499,984,1002,8805,16828
P1412,languages spoken written or signed,431,815,841,6774,12584
P30,continent,355,387,389,389,391
P19,place of birth,139,413,403,3693,7214
P495,country of origin,,1022,831,1110,2049
P140,religion,,285,277,1690,2651
P108,employer,,203,302,2309,4795
P159,headquarters location,,150,162,163,169
P119,place of burial,,107,110,1034,1972
P452,industry,,16,16,16,17
P40,child,,5,4,176,391
P530,diplomatic relation,,,6199,6199,6225
P463,member of,,,4478,7912,11490
P136,genre,,,2301,6523,11761
P17,country,,,1280,1280,1323
P69,educated at,,,610,5213,9752
P840,narrative location,,,600,792,1506
P37,official language,,,401,401,403
P161,cast member,,,387,2770,9249
P264,record label,,,331,2090,3456
P172,ethnic group,,,301,1390,2293
P20,place of death,,,298,2783,5442
P509,cause of death,,,215,1865,3210
P101,field of work,,,2