In [1]:
from pathlib import Path
import pickle as pkl
# Load the LAMA artifacts: cross-validation splits, per-relation examples and per-relation
# prompt templates. NOTE: unpickling can execute arbitrary code -- only load these files
# from a trusted source.
data_lama_path = Path("data_lama")


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``.

    The file is read inside a ``with`` block so the handle is closed as soon as loading
    finishes; the bare ``pkl.load(open(...))`` pattern leaves closing to the garbage collector.
    """
    with open(path, 'rb') as f:
        return pkl.load(f)


# One entry per fold; each entry has 'train' / 'val' / 'test' lists of relation ids.
cv_split = _load_pickle(data_lama_path / 'cross_validation_splits.pkl')
# Relation id (e.g. 'P19') -> list of {'<input>': subject string, '<label>': int} examples.
rel2data = _load_pickle(data_lama_path / 'data.pkl')
# Relation id -> list of prompt templates containing '<input>' and '<label>' placeholders.
rel2templates = _load_pickle(data_lama_path / 'templates.pkl')

In [2]:
# Class verbalizers: one non-empty label word per line. The line position is presumably the
# integer class id used as '<label>' in rel2data -- TODO confirm.
verbalizers = []
# Explicit UTF-8: the default text encoding is platform/locale dependent, and label words may
# contain non-ASCII characters.
with open(data_lama_path / 'class_verbalizers.txt', 'r', encoding='utf-8') as f:
    for line_no, line in enumerate(f, start=1):
        word = line.strip()
        assert len(word) != 0, f"empty verbalizer on line {line_no} of class_verbalizers.txt"
        verbalizers.append(word)

In [3]:
# Dataset size: number of relations (tasks), total number of examples across all relations,
# and number of label words. Every task classifies over the full verbalizer list (~21K-way).
(
    len(rel2data),
    sum(map(len, rel2data.values())),
    len(verbalizers),
)

(29, 11935, 21018)

In [4]:
# 8-fold cross-validation over relations (tasks):
# - the relations are divided into 8 partitions of roughly equal size;
# - each fold uses 6 partitions for training, 1 for validation and 1 for testing;
# - the union of the 8 test sets is the full set of 29 relations.
# Leakage guard: within a fold, no relation may appear in more than one role.
for fold_no, fold in enumerate(cv_split):
    fold_train, fold_val, fold_test = (set(fold[role]) for role in ('train', 'val', 'test'))
    assert fold_train.isdisjoint(fold_val), f"fold {fold_no}: relation(s) in both train and val"
    assert fold_train.isdisjoint(fold_test), f"fold {fold_no}: relation(s) in both train and test"
    assert fold_val.isdisjoint(fold_test), f"fold {fold_no}: relation(s) in both val and test"
len(cv_split)

8

In [5]:
# Each relation should be held out for testing exactly once: the per-fold test sets must be
# pairwise disjoint and together cover every relation in rel2data.
test_sets = [set(split["test"]) for split in cv_split]
all_test_tasks = set().union(*test_sets)
assert all_test_tasks == set(rel2data), "test splits do not cover exactly the relations in rel2data"
assert sum(map(len, test_sets)) == len(all_test_tasks), "a relation appears in more than one test split"
len(all_test_tasks)

29

In [6]:
# Peek at relation P19: each example maps '<input>' to a subject string and '<label>' to an
# integer class id.
p19_examples = rel2data['P19']
p19_examples[:5]

[{'<input>': 'Kandi Burruss', '<label>': 1213},
 {'<input>': 'Caroline Bynum', '<label>': 1213},
 {'<input>': 'Big Gipp', '<label>': 1213},
 {'<input>': 'Elise Broach', '<label>': 1213},
 {'<input>': 'Robbie Merrill', '<label>': 4683}]

In [7]:
# Prompt templates for relation P19; '<input>' and '<label>' are placeholders filled from an example.
p19_templates = rel2templates['P19']
p19_templates[:5]

['<input> was born in <label>.',
 '<input> is born in <label>.',
 '<input> was born <label>.',
 '<input> was born at <label>.',
 '<input> comes from <label>.']

In [8]:
# Select one cross-validation fold and gather, for its training and validation relations,
# the labelled examples (from rel2data) and the prompt templates (from rel2templates).
fold_idx = 0
train_tasks, val_tasks = (cv_split[fold_idx][role] for role in ('train', 'val'))

train_task2examples = {rel: rel2data[rel] for rel in train_tasks}
train_task2templates = {rel: rel2templates[rel] for rel in train_tasks}

val_task2examples = {rel: rel2data[rel] for rel in val_tasks}
val_task2templates = {rel: rel2templates[rel] for rel in val_tasks}

In [9]:
# Load previously computed evaluation results (pickles -- only open trusted files).
# NOTE(review): in fold_data, ('bert-large-cased', 0) maps to a list of 8 scores (presumably one
# per fold) while the other keys map to scalars; table_data appears to hold the mean of that
# list. Confirm the intended schema.
table_data = pkl.loads(Path("table_level_results_lama.pkl").read_bytes())
fold_data = pkl.loads(Path("fold_level_results.pkl").read_bytes())
fold_data, table_data

({('bert-base-cased', 0): 0.10848652538953074,
  ('bert-base-cased', 1): 0.09114081552780393,
  ('bert-base-cased', 2): 0.11496727479675059,
  ('bert-base-cased', 5): 0.14595798089610912,
  ('bert-large-cased', 0): [0.19096878937760975,
   0.09419799523812804,
   0.1320909474616652,
   0.12454447451695569,
   0.07326126240686288,
   0.07488589315134987,
   0.2013666280813856,
   0.1669728045693516]},
 {('bert-base-cased', 0): 0.10848652538953074,
  ('bert-base-cased', 1): 0.09114081552780393,
  ('bert-base-cased', 2): 0.11496727479675059,
  ('bert-base-cased', 5): 0.14595798089610912,
  ('bert-large-cased', 0): 0.13228609935041358})