In [1]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from scipy import stats

from libs.utils.data import aggregate, mean_confidence_interval


# Path template for the per-model separability results: fname(model) -> CSV path.
fname = "results/separability/sep_{}.csv".format
# Embedding models compared in this notebook; later cells pair them positionally with plot styles.
MODELS = ["ComplEx", "DistMult", "RDF2Vec", "TransE", "TransH", "TransD"]

def load_df(model):
    """Load the separability results CSV for `model`.

    The first CSV column is used as the DataFrame index.
    """
    path = fname(model)
    return pd.read_csv(path, index_col=0)

In [6]:
# One line style / marker per model; paired with MODELS positionally via zip().
linestyles = ["solid", "dashed", "dotted", "dashdot", (0, (5, 10)), (0, (5, 1))]
markers = ["o", "v", "p", "P", "^", "<"]
style = [dict(linestyle=ls, marker=mk, linewidth=1) for ls, mk in zip(linestyles, markers)]

# Embed TrueType (Type 42) fonts in PS/PDF output instead of Type 3 fonts.
plt.rc("ps", fonttype=42)
plt.rc("pdf", fonttype=42)

N_BINS = 8  # number of bins passed to aggregate() along each distance axis

fig, (ax1, ax2) = plt.subplots(1, 2, sharey=True, figsize=(6, 3))
# Adjust margins on *this* figure. Calling plt.subplots_adjust before plt.subplots
# targeted a different (implicit) figure and left this one unadjusted.
fig.subplots_adjust(top=0.98, right=0.99, left=0.15, bottom=0.15, wspace=0.05)

# x-axis label prefix for each distance column (no trailing space: the label adds " Distance").
distances = {"mixed": "Mixed", "taxo": "Taxonomic", "geom": "Geometric", "euc": "Lexical", "cos": "Lexical", "neuc": "Lexical"}

# Read each model's results once and reuse them for both panels.
frames = {model: load_df(model) for model in MODELS}

# Left panel: separability vs taxonomic distance; right panel: vs lexical ("euc") distance.
for distance, ax in zip(["taxo", "euc"], [ax1, ax2]):
    for s, model in zip(style, MODELS):
        df = frames[model]
        x, ys = aggregate(df[distance], df.f1, N_BINS, func=mean_confidence_interval)
        # Each aggregated value unpacks into (mean, interval); only the mean is plotted for now.
        y, _err = zip(*ys)
        ax.plot(x, y, label=model, **s)

    ax.set(xlabel=f"{distances[distance]} Distance")

# Y axis is shared, so label it only on the left panel.
ax1.set(ylabel="Separability")
ax2.legend()

SAVE = False
RESNAME = "results/separability/sep_dist"
FORMATS = ["png", "pdf", "eps"]

if SAVE:
    for fmt in FORMATS:
        fig.savefig(f"{RESNAME}.{fmt}", format=fmt)
plt.show()

  keepdims=keepdims)
  ret = ret.dtype.type(ret / rcount)


In [5]:
# Separability (f1) aggregated over log-spaced class-size bins, one (x, y) entry per model.
# `col` picks the per-pair class-size summary used on the x axis ("hsize" or "gsize").
col = "hsize"
N_STEPS = 10  # number of log-spaced bin edges passed to aggregate()

datatable = {}
for m in MODELS:
    df = load_df(m)
    sizes = df[["ca", "cb"]].to_numpy(dtype=float)
    # Row-wise harmonic / geometric mean of the two class sizes, vectorised over rows
    # (same result as a per-row .apply, without the Python-level loop).
    df["hsize"] = stats.hmean(sizes, axis=1)
    df["gsize"] = stats.gmean(sizes, axis=1)

    lo, hi = df[col].min(), df[col].max()
    # Pad the range by 1 so the extreme values land inside the outer bins. When lo <= 1,
    # lo - 1 would be <= 0 and log10 would give -inf/nan edges, so halve lo instead.
    # NOTE(review): assumes sizes are positive — confirm against the input data.
    lo_edge = lo - 1 if lo > 1 else lo / 2
    steps = np.logspace(np.log10(lo_edge), np.log10(hi + 1), N_STEPS)

    x, y = aggregate(df[col], df.f1, steps)
    datatable[m] = (x, y)

In [7]:
# Separability vs class size (log x axis), one curve per model in `datatable`.
fig, ax = plt.subplots(figsize=(6, 3))
fig.subplots_adjust(bottom=0.15, top=0.95)

# NOTE(review): ComplEx's last x value is replaced with TransE's last x value, presumably so
# the curves end at the same point. This alters plotted data — confirm this is intended.
m1 = datatable["TransE"][0][-1]

for s, (m, (x, y)) in zip(style, datatable.items()):
    if m == "ComplEx":
        # Copy before overriding so `datatable` itself is not mutated by this plotting cell.
        x = list(x)
        x[-1] = m1
    ax.plot(x, y, label=m, **s)

ax.set(xscale="log", xlabel="Mean class size", ylabel="Separability")
ax.legend()

SAVE = False
RESNAME = "results/separability/sep_size"
FORMATS = ["png", "pdf", "eps"]

if SAVE:
    for fmt in FORMATS:
        fig.savefig(f"{RESNAME}.{fmt}", format=fmt)
plt.show()