# <center>Validation/Test precision@100 - Supplementary Table 1</center>

This notebook contains the code to:
- Compute validation and test precision@100.
- Getting the table presented as Supplementary Table 1.

In [1]:
import os
import sys
import seaborn as sns
import pandas as pd

PREDICTIONS_PATH = '../data/predictions'

## Load precision data

In [2]:
prec_test = pd.read_csv(os.path.join(PREDICTIONS_PATH, 'precision_test-ensemble-all.csv'), sep='\t')
prec_val = pd.read_csv(os.path.join(PREDICTIONS_PATH, 'precision_val-ensemble-all.csv'), sep='\t')

prec_test['dataset'] = 'test'
prec_val['dataset'] = 'val'

prec = pd.concat([prec_val, prec_test], ignore_index=True)

### Filter K=100 

In [3]:
prec = prec[prec['K'] == 100]
prec = prec[['kg', 'dataset', 'model', 'precision']]

## Precision Table

### Auxiliary functions

In [4]:
kg2title = {
    'openbiolink': 'OpenbioLink',
    'biokg': 'BioKG',
}

model2title = {
    'rescal': 'RESCAL',
    'transe': 'TransE',
    'distmult': 'DistMult',
    'ermlp': 'ERMLP',
    'transh': 'TransH',
    'complex': 'ComplEx',
    'hole': 'HolE',
    'conve': 'ConvE',
    'rotate': 'RotatE',
    'mure': 'MuRE',   
}

model_names = [
    'RESCAL',
    'TransE',
    'DistMult',
    'ERMLP',
    'TransH',
    'ComplEx',
    'HolE',
    'ConvE',
    'RotatE',
    'MuRE',
]

# Supplementary Table 1

In [5]:
models_order = {m:i for i,m in enumerate(model_names)}
prec['Model'] = prec['model'].apply(lambda x: model2title[x])
prec = prec.iloc[prec['Model'].map(models_order).argsort()]
prec.drop('model', axis=1).set_index(['kg', 'Model', 'dataset']).unstack().sort_values('Model', key=lambda x: x.map(models_order))

Unnamed: 0_level_0,Unnamed: 1_level_0,precision,precision
Unnamed: 0_level_1,dataset,test,val
kg,Model,Unnamed: 2_level_2,Unnamed: 3_level_2
openbiolink,RESCAL,0.0,0.0
biokg,RESCAL,3.0,3.0
openbiolink,TransE,53.0,38.0
biokg,TransE,38.0,26.0
biokg,DistMult,1.0,1.0
openbiolink,DistMult,5.0,4.0
biokg,ERMLP,2.0,1.0
openbiolink,ERMLP,24.0,14.0
biokg,TransH,23.0,20.0
openbiolink,TransH,25.0,23.0
