In [7]:
from pathlib import Path
import pickle as pkl
# Directory holding the example dataset files, relative to the notebook's working directory.
example_data_path = Path("example_data")


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file handle is closed as soon as loading finishes (the previous
    ``pkl.load(open(...))`` form left every handle open).
    """
    # NOTE(security): unpickling can execute arbitrary code; only load trusted files.
    with open(path, 'rb') as f:
        return pkl.load(f)


cv_split = _load_pickle(example_data_path / 'cross_validation_splits.pkl')
rel2data = _load_pickle(example_data_path / 'data.pkl')
rel2templates = _load_pickle(example_data_path / 'templates.pkl')

# One verbalizer (class word) per line, kept in file order.
# Presumably the integer '<label>' values in the data index into this list, so a
# blank line must fail loudly rather than be skipped -- TODO confirm.
verbalizers = []
with open(example_data_path / 'class_verbalizers.txt', 'r') as f:
    # Iterate the file lazily instead of materializing it with readlines().
    for line_no, line in enumerate(f, start=1):
        word = line.strip()
        assert len(word) != 0, f"empty verbalizer on line {line_no} of class_verbalizers.txt"
        verbalizers.append(word)

In [8]:
# Dataset overview: (number of relations/tasks, total examples over all relations,
# number of class verbalizers). Expect roughly 29 tasks, 12K examples, and a
# 21K-way classification for each task.
len(rel2data), sum(map(len, rel2data.values())), len(verbalizers)

(29, 11935, 21018)

In [6]:
# 8-fold cross-validation over relations (tasks):
#   - the relations are split into 8 partitions of roughly the same size;
#   - each fold uses 6 partitions for training, 1 for validation and 1 for testing;
#   - the union of the 8 test sets gives the full set of 29 relations.
# The number of folds is the length of cv_split.
len(cv_split)

8

In [18]:
# Dump the raw split structure. Per the printed output, it is a list with one dict
# per fold, keyed by 'train', 'val' and 'test', each holding a list of relation ids.
print(cv_split)


[{'train': ['P530', 'P108', 'P37', 'P127', 'P276', 'P47', 'P279', 'P131', 'P364', 'P138', 'P20', 'P527', 'P106', 'P937', 'P27', 'P361', 'P740', 'P1376', 'P101', 'P17', 'P1001'], 'val': ['P31', 'P159', 'P190', 'P407'], 'test': ['P19', 'P1412', 'P36', 'P495']}, {'train': ['P19', 'P1412', 'P36', 'P495', 'P276', 'P47', 'P279', 'P131', 'P31', 'P159', 'P190', 'P407', 'P364', 'P138', 'P20', 'P527', 'P361', 'P740', 'P1376', 'P101', 'P17', 'P1001'], 'val': ['P106', 'P937', 'P27'], 'test': ['P530', 'P108', 'P37', 'P127']}, {'train': ['P19', 'P1412', 'P36', 'P495', 'P530', 'P108', 'P37', 'P127', 'P31', 'P159', 'P190', 'P407', 'P106', 'P937', 'P27', 'P361', 'P740', 'P1376', 'P101', 'P17', 'P1001'], 'val': ['P364', 'P138', 'P20', 'P527'], 'test': ['P276', 'P47', 'P279', 'P131']}, {'train': ['P19', 'P1412', 'P36', 'P495', 'P276', 'P47', 'P279', 'P131', 'P364', 'P138', 'P20', 'P527', 'P106', 'P937', 'P27', 'P361', 'P740', 'P1376', 'P101', 'P17', 'P1001'], 'val': ['P530', 'P108', 'P37', 'P127'], 'test

In [15]:
# Sanity check: number of distinct tasks appearing in any fold's test set.
len(set().union(*(fold["test"] for fold in cv_split)))

29

In [20]:
# Peek at the first five examples of relation P19. Each record pairs an '<input>'
# string with an integer '<label>' (presumably an index into `verbalizers` -- confirm).
rel2data['P19'][:5]

[{'<input>': 'Kandi Burruss', '<label>': 1213},
 {'<input>': 'Caroline Bynum', '<label>': 1213},
 {'<input>': 'Big Gipp', '<label>': 1213},
 {'<input>': 'Elise Broach', '<label>': 1213},
 {'<input>': 'Robbie Merrill', '<label>': 4683}]

In [21]:
# First five templates for relation P19: strings containing the '<input>' and
# '<label>' placeholders.
rel2templates['P19'][:5]

['<input> was born in <label>.',
 '<input> is born in <label>.',
 '<input> was born <label>.',
 '<input> was born at <label>.',
 '<input> comes from <label>.']

In [22]:
# Choose one cross-validation fold and pull out its training / validation tasks.
fold_idx = 0
train_tasks = cv_split[fold_idx]['train']
val_tasks = cv_split[fold_idx]['val']

# Per-task examples and templates for the training tasks ...
train_task2examples = dict((task, rel2data[task]) for task in train_tasks)
train_task2templates = dict((task, rel2templates[task]) for task in train_tasks)

# ... and likewise for the validation tasks.
val_task2examples = dict((task, rel2data[task]) for task in val_tasks)
val_task2templates = dict((task, rel2templates[task]) for task in val_tasks)

In [35]:
# Show every training example of relation P20 in the selected fold.
# NOTE(review): this displays the full, unsliced list; consider `[:5]` as in the
# P19 preview to keep the notebook output short.
train_task2examples['P20']

[{'<input>': 'Louis Krages', '<label>': 1213},
 {'<input>': 'Donald L. Hollowell', '<label>': 1213},
 {'<input>': 'Elbert Tuttle', '<label>': 1213},
 {'<input>': 'Asa Griggs Candler', '<label>': 1213},
 {'<input>': 'Dieter Eppler', '<label>': 7464},
 {'<input>': 'Prince August Wilhelm of Prussia', '<label>': 7464},
 {'<input>': 'Erich Regener', '<label>': 7464},
 {'<input>': 'Johann Nepomuk David', '<label>': 7464},
 {'<input>': 'Georg Scheffers', '<label>': 1469},
 {'<input>': 'Muhammad Farid', '<label>': 1469},
 {'<input>': 'Michael Sachs', '<label>': 1469},
 {'<input>': 'Giulia Grisi', '<label>': 1469},
 {'<input>': 'Abraham Begeyn', '<label>': 1469},
 {'<input>': 'Mathilde Mallinger', '<label>': 1469},
 {'<input>': 'Peter Simon Pallas', '<label>': 1469},
 {'<input>': 'Jan Bouman', '<label>': 1469},
 {'<input>': 'Heinrich Strack', '<label>': 1469},
 {'<input>': 'Anita Berber', '<label>': 1469},
 {'<input>': 'Humphry Marshall', '<label>': 6063},
 {'<input>': 'Jean Calas', '<label>': 

In [36]:
# Show all templates of relation P20 in the selected fold.
# NOTE(review): the displayed list mixes "died ..." and "was born ..." phrasings;
# presumably intentional template noise -- confirm against templates.pkl.
train_task2templates['P20']

['<input> died in <label>.',
 '<input> was born in <label>.',
 '<input> is born in <label>.',
 '<input> was born <label>.',
 '<input> was born at <label>.',
 '<input> comes from <label>.',
 '<input> returned to <label>.',
 '<input> moved to <label>.',
 '<input> died at <label>.',
 '<input> died <label>.',
 '<input> died in the year <label>.',
 '<input> died in the <label>.']