In [5]:
import os
import pickle
import numpy as np

In [6]:
# Root directory of the evaluation results. Defaults to the original location;
# set the LAMA_RESULTS_DIR environment variable to run the notebook against a
# different checkout without editing this cell.
basedir = os.environ.get(
    'LAMA_RESULTS_DIR',
    '/home/gamaga/work/cse517-lama/output/results')

In [7]:
# Every sub-directory of basedir is treated as one model. Plain files that may
# sit next to them (logs, editor/OS droppings) are skipped, because the cells
# below call os.listdir() on <basedir>/<model_name>/<dataset_name>.
model_names = sorted(
    entry for entry in os.listdir(basedir)
    if os.path.isdir(os.path.join(basedir, entry)))
print(model_names)

['bert_base', 'bert_large', 'elmo', 'elmo5B', 'transformerxl']


In [8]:
# Dataset sub-directory names looked up under each model's results directory.
dataset_names = [
    'ConceptNET',
    'GoogleRE',
    'Squad',
    'TREx',
]

In [10]:
def _mean_precision_at_1(results):
    """Return ``(num_queries, mean P@1)`` for one loaded ``result.pkl`` payload.

    Reads ``results['list_of_results']`` and averages the
    ``'sample_Precision1'`` value of every entry. When the list is empty the
    mean is undefined and ``None`` is returned in its place.
    """
    samples = results['list_of_results']
    num_queries = len(samples)
    if num_queries == 0:
        return 0, None
    total = sum(sample['sample_Precision1'] for sample in samples)
    return num_queries, total / num_queries


# all_precisions[dataset_name][model_name][template] -> average P@1 of that template.
all_precisions = {}
for dataset_name in dataset_names:
    all_precisions[dataset_name] = {}
    for model_name in model_names:
        print(dataset_name, model_name)
        all_precisions[dataset_name][model_name] = {}
        precisions = []
        dataset_dir = os.path.join(basedir, model_name, dataset_name)
        # os.listdir() returns entries in arbitrary order; sort so the printed
        # report is reproducible across runs and machines.
        for template in sorted(os.listdir(dataset_dir)):
            filename = os.path.join(dataset_dir, template, 'result.pkl')
            # NOTE(review): pickle.load can execute arbitrary code while
            # unpickling; only point basedir at result files you trust.
            with open(filename, 'rb') as f:
                results = pickle.load(f)
            num_queries, avg_precision = _mean_precision_at_1(results)
            if avg_precision is None:
                # An empty result list previously raised ZeroDivisionError;
                # report it and leave the template out of the aggregates.
                print('  Template: %s. Num queries: 0. Skipped (no queries).' % template)
                continue
            print('  Template: %s. Num queries: %d. Avg P@1: %f' % (template, num_queries, avg_precision))
            precisions.append(avg_precision)
            all_precisions[dataset_name][model_name][template] = avg_precision

        print('  ' + '='*30)
        # "Total" is the unweighted mean over templates (each template counts
        # once, regardless of its number of queries). Avoid np.mean([]) which
        # emits a RuntimeWarning when no template produced a score.
        total_p1 = np.mean(precisions) if precisions else float('nan')
        print('  Total P@1: %f' % total_p1)

ConceptNET bert_base
  Template: test. Num queries: 11460. Avg P@1: 0.156457
  Total P@1: 0.156457
ConceptNET bert_large
  Template: test. Num queries: 11460. Avg P@1: 0.192583
  Total P@1: 0.192583
ConceptNET elmo
  Template: test. Num queries: 11460. Avg P@1: 0.061082
  Total P@1: 0.061082
ConceptNET elmo5B
  Template: test. Num queries: 11460. Avg P@1: 0.062391
  Total P@1: 0.062391
ConceptNET transformerxl
  Template: test. Num queries: 11460. Avg P@1: 0.057068
  Total P@1: 0.057068
GoogleRE bert_base
  Template: place_of_death. Num queries: 765. Avg P@1: 0.130719
  Template: date_of_birth. Num queries: 1825. Avg P@1: 0.015342
  Template: place_of_birth. Num queries: 2937. Avg P@1: 0.149472
  Total P@1: 0.098511
GoogleRE bert_large
  Template: place_of_death. Num queries: 765. Avg P@1: 0.139869
  Template: date_of_birth. Num queries: 1825. Avg P@1: 0.013699
  Template: place_of_birth. Num queries: 2937. Avg P@1: 0.160708
  Total P@1: 0.104759
GoogleRE elmo
  Template: place_of_deat

  Template: P740. Num queries: 936. Avg P@1: 0.000000
  Template: P1412. Num queries: 969. Avg P@1: 0.000000
  Template: P131. Num queries: 881. Avg P@1: 0.000000
  Template: P937. Num queries: 954. Avg P@1: 0.000000
  Template: P463. Num queries: 225. Avg P@1: 0.000000
  Template: P47. Num queries: 920. Avg P@1: 0.000000
  Template: P361. Num queries: 932. Avg P@1: 0.000000
  Template: P136. Num queries: 931. Avg P@1: 0.000000
  Template: P27. Num queries: 966. Avg P@1: 0.000000
  Template: P279. Num queries: 963. Avg P@1: 0.000000
  Template: P413. Num queries: 952. Avg P@1: 0.000000
  Template: P276. Num queries: 958. Avg P@1: 0.000000
  Template: P37. Num queries: 966. Avg P@1: 0.000000
  Template: P407. Num queries: 877. Avg P@1: 0.000000
  Template: P190. Num queries: 992. Avg P@1: 0.000000
  Template: P264. Num queries: 429. Avg P@1: 0.000000
  Template: P103. Num queries: 977. Avg P@1: 0.000000
  Template: P101. Num queries: 696. Avg P@1: 0.000000
  Template: P176. Num queries:

In [11]:
# Path to the relations metadata file. Defaults to the original location; set
# the LAMA_RELATIONS_FILE environment variable to read it from elsewhere.
relations_file = os.environ.get(
    'LAMA_RELATIONS_FILE',
    '/home/gamaga/work/cse517-lama/data/relations.jsonl')

# Decode explicitly as UTF-8 so reading the file does not depend on the
# machine's locale settings.
with open(relations_file, 'r', encoding='utf-8') as f:
    lines = f.readlines()

In [15]:
import json
import collections
# Group relation identifiers by their 'type' field:
# trex_templates[type] -> list of 'relation' values, in file order.
trex_templates = collections.defaultdict(list)
for line in lines:
    record = line.strip()
    if not record:
        # Blank lines (e.g. a trailing newline at the end of the file) are not
        # valid JSON and would make json.loads raise; skip them.
        continue
    data = json.loads(record)
    trex_templates[data['type']].append(data['relation'])

In [16]:
# Display the distinct 'type' values that were collected as grouping keys.
trex_templates.keys()

dict_keys(['N-1', 'N-M', '1-1'])

In [21]:
# For every model, report the unweighted mean TREx P@1 per relation type.
for model_name in model_names:
    print(model_name)
    # Per-template P@1 scores of this model on TREx.
    trex_scores = all_precisions['TREx'][model_name]
    for key, templates in trex_templates.items():
        # Relations without a stored score for this model are left out, so
        # "Len" is the number of relations that actually entered the mean.
        p1s = [trex_scores[rel] for rel in templates if rel in trex_scores]
        mean_p1 = np.mean(p1s)
        summary = '   Type: %s. Len: %d. P@1: %f' % (key, len(p1s), mean_p1)
        print(summary)

bert_base
   Type: N-1. Len: 23. P@1: 0.315556
   Type: N-M. Len: 16. P@1: 0.246647
   Type: 1-1. Len: 2. P@1: 0.679640
bert_large
   Type: N-1. Len: 23. P@1: 0.338304
   Type: N-M. Len: 16. P@1: 0.242937
   Type: 1-1. Len: 2. P@1: 0.744629
elmo
   Type: N-1. Len: 23. P@1: 0.003500
   Type: N-M. Len: 16. P@1: 0.000000
   Type: 1-1. Len: 2. P@1: 0.000000
elmo5B
   Type: N-1. Len: 23. P@1: 0.003869
   Type: N-M. Len: 16. P@1: 0.000000
   Type: 1-1. Len: 2. P@1: 0.000000
transformerxl
   Type: N-1. Len: 23. P@1: 0.180284
   Type: N-M. Len: 16. P@1: 0.164786
   Type: 1-1. Len: 2. P@1: 0.364688
