# IRT2 - Inductive Reasoning with Text

This notebook shows how to load the IRT2 dataset.
It then examines some of the dataset's properties in detail to give insight into its data model.

In [None]:
# Pick up edits to imported modules (e.g. irt2) live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import textwrap
from collections import Counter
from collections import defaultdict
from itertools import islice
from typing import Generator
from typing import Optional

from tabulate import tabulate

import irt2
from irt2.dataset import IRT2
from irt2.dataset import MID

#  folder convention:
#     data/irt2/<graph>/<size>
#  where
#     graph = cde|fb
#     size  = tiny|small|medium|large (abbreviated: T|S|M|L)
#  for example:
#     path = 'data/irt2/cde/small'
#
#  Change GRAPH and SIZE below to load a different dataset variant.

GRAPH = 'cde'
SIZE = 'small'

data = IRT2.from_dir(path=irt2.ENV.DIR.DATA / 'irt2' / GRAPH / SIZE)
print(data)  # print() already calls str()

In [None]:
# This iterates all text contexts and may take a while on the first call;
# repeated calls are cheap because the return value is cached.
print(data.description)

In [None]:
# The configuration file used to create the dataset holds further
# information. The individual options are explained in the original
# configuration files in /conf.
import yaml

config_yaml = yaml.dump(data.config)
print(config_yaml)

In [None]:
# Show the first ten vertices and the first ten relations together with their ids.

for section, id_label, id2name in (
    ('vertices', 'vid', data.vertices),
    ('relations', 'rid', data.relations),
):
    print(f'\n{section}:')
    print(f'    {id_label} name')
    for key, name in islice(id2name.items(), 10):
        print(f'{key:7d} {name}')

In [None]:
# Show example closed-world triples. The unpacking below treats each
# triple as (head vid, tail vid, relation rid).

triple_rows = []
for head, tail, rel in islice(data.closed_triples, 20):
    triple_rows.append((
        head, data.vertices[head],
        rel, data.relations[rel],
        tail, data.vertices[tail],
    ))

print(tabulate(triple_rows, headers=('VID', 'head', 'RID', 'relation', 'VID', 'tail')))

In [None]:
# this showcases how to access text contexts

def count_contexts(
    contexts: Generator,
    n: Optional[int] = None,
    mentions: Optional[dict] = None,
) -> dict:
    """Count text contexts in total, per mention id and per origin.

    Args:
        contexts: iterable of context objects, each exposing ``mid``,
            ``mention``, ``data`` and ``origin`` attributes.
        n: stop after the first ``n`` contexts; ``None`` reads all of them.
        mentions: container of known mention ids used to sanity-check each
            context. Defaults to ``data.mentions`` of the notebook-global
            dataset so existing calls keep working.

    Returns:
        dict with ``total`` (int), ``mids`` (Counter: mid -> count) and
        ``origins`` (Counter: origin -> count). A summary is also printed.

    Raises:
        AssertionError: if a context refers to an unknown mention id or its
            mention string does not occur in its text.
    """
    if mentions is None:
        # fall back to the dataset loaded at the top of the notebook
        mentions = data.mentions

    counts = dict(total=0, mids=Counter(), origins=Counter())

    # islice(..., None) consumes the whole iterable
    for context in islice(contexts, n):

        assert context.mid in mentions, f'unknown mention id {context.mid}'
        assert context.mention in context.data, \
            f'mention {context.mention!r} not found in its context text'

        counts['total'] += 1
        counts['mids'][context.mid] += 1
        counts['origins'][context.origin] += 1

    print(f'  read {counts["total"]} relevant contexts')
    print(f'  for {len(counts["mids"])} mentions from {len(counts["origins"])} origins')

    return counts

# Each split's contexts are opened through a context manager, which takes
# care of opening and closing the underlying files. The managed object
# is a generator of irt2.dataset.Context objects.

n = 10_000

context_sources = (
    ('closed-world (training)', data.closed_contexts),
    ('open-world (validation)', data.open_contexts_validation),
    ('open-world (test)', data.open_contexts_test),
)

split_counts = []
for split_label, open_split in context_sources:
    with open_split() as contexts:
        print(f'\ncounting {split_label} contexts')
        split_counts.append(count_contexts(contexts, n=n))

ctx_counts_closed, ctx_counts_open_val, ctx_counts_open_test = split_counts

In [None]:
# Show the mentions of five vertices (entries 30-34) per split. The
# "matches" numbers come from the ctx_counts_* computed in the previous
# cell, which only scanned the first n contexts of each split.

mention_sections = (
    ('\nclosed-world (training) ', data.closed_mentions, ctx_counts_closed),
    ('\nopen-world (validation) ', data.open_mentions_val, ctx_counts_open_val),
    ('\nopen-world (test) ', data.open_mentions_test, ctx_counts_open_test),
)

for section_title, vid2mids, section_counts in mention_sections:
    print(section_title + '-' * 20)
    # `vid` and `mid` keep their names: the {vid=}/{mid=} f-strings print them
    for vid, mids in islice(vid2mids.items(), 30, 35):
        print(f'\n  {len(mids)} mentions of {data.vertices[vid]} ({vid=})')
        for mid in mids:
            mention = data.mentions[mid]
            print(f'    {mid=} {mention} ({section_counts["mids"][mid]} matches)')

In [None]:
# Some examples for the open-world head and tail tasks (validation split).
# A reverse lookup (mention id -> vertex id) over all splits is built
# first so every query mention can be shown together with its vertex.

from itertools import chain

# if a mention id occurs in several splits, the later split wins
mid2vid = {
    mid: vid
    for vid, mids in chain(
        data.closed_mentions.items(),
        data.open_mentions_val.items(),
        data.open_mentions_test.items(),
    )
    for mid in mids
}

N = 5  # maximum number of answers printed per query
N_QUERIES = 10  # number of queries shown per task


def print_answers(vids, limit):
    """Print at most `limit` answer vertices and how many were left out."""
    for vid in list(vids)[:limit]:
        print(f'  answer: {data.vertices[vid]}')

    if len(vids) > limit:
        print(f'  (+{len(vids) - limit} more)')


print('\nHEAD TASK ' + '-' * 20)
for (mid, rid), vids in islice(data.open_task_val_heads.items(), N_QUERIES):
    print(f'\n"{data.mentions[mid]}" ({data.vertices[mid2vid[mid]]}) {data.relations[rid]} ?')
    print_answers(vids, N)

print('\nTAIL TASK ' + '-' * 20)
for (mid, rid), vids in islice(data.open_task_val_tails.items(), N_QUERIES):
    print(f'\n? {data.relations[rid]} "{data.mentions[mid]}" ({data.vertices[mid2vid[mid]]})')
    # previously hard-coded [:5], which disagreed with the "(+k more)" count
    # whenever N != 5; the shared helper uses N for both
    print_answers(vids, N)


In [None]:
# Print some example texts: collect the contexts of the first 1000
# closed-world matches grouped by mention id, then show three mentions.

with data.closed_contexts() as contexts:
    texts = defaultdict(set)
    for ctx in islice(contexts, 1000):
        texts[ctx.mid].add(ctx)

texts = dict(texts)


# `mid` and `vertex` keep their names: the {mid=}/{vertex=} f-string prints them
for mid, mid_contexts in islice(texts.items(), 3):
    mention_text = data.mentions[mid]
    vertex = data.vertices[mid2vid[mid]]

    print(f'\ntext for {mention_text} ({mid=}) ({vertex=})')
    for context in mid_contexts:
        block = textwrap.indent(textwrap.fill(str(context.data), 80), '  ')
        print('\n' + block)

In [None]:
# A Graph instance built from the dataset allows a closer look at the
# training (closed-world) data.

from irt2.graph import Relation

# textual summary of the graph
print(data.graph.description)

# relations extracted from the graph; rendered as a table below
relations = Relation.from_graph(data.graph)
relation_count = len(relations)
print(f'got {relation_count} relations')

from tabulate import tabulate


def relation_table(relations):
    """Render relations as a plain-text table.

    Each row holds a 1-based running number, the relation's name, rid and
    ratio, and the sizes of its ``heads``, ``tails`` and ``triples``
    collections. Returns the string produced by ``tabulate``.
    """
    headers = ('#', 'name', 'rid', 'ratio', '#heads', '#tails', '#triples')
    rows = [
        (
            position,
            relation.name,
            relation.rid,
            relation.ratio,
            len(relation.heads),
            len(relation.tails),
            len(relation.triples),
        )
        for position, relation in enumerate(relations, start=1)
    ]
    return tabulate(rows, headers=headers)


# Print one table row per relation returned by Relation.from_graph above.
print(relation_table(relations))