In [None]:
# Make the parent directory importable so the sibling `util` module resolves.
import sys
sys.path.append('..')

In [None]:
import math
import pickle
from collections import Counter
from fractions import Fraction
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from util import ds_names, ds_names_short, colors_group, step_plot
from itertools import combinations
from scipy.stats import pearsonr
import seaborn as sns

In [None]:
# Input locations (only read from in this notebook).
p_incident = Path('../data/incident')
p_core = Path('../data/core')

# Output locations, created on first run.
p_figures = Path('figures')
p_dumps = Path('dumps')
for p_out in (p_figures, p_dumps):
    p_out.mkdir(exist_ok=True)

# Large fonts and thick lines for the per-dataset figures.
plt.rcParams.update({'font.size': 30, 'text.usetex': False, 'lines.linewidth': 7})
def _entropy_bits(counts, total):
    """Return the Shannon entropy, in bits, of an empirical distribution.

    `counts` is an iterable of positive occurrence counts and `total` is the
    number of observations the counts are normalised by.
    """
    return -sum(c / total * math.log2(c / total) for c in counts)


# For every dataset: for each threshold t with a stored per-node value, compute
# how many extra bits that value adds on top of the node degree, then persist
# the {t: info gain} mapping and plot it against t.
for ds_name in ds_names:
    p_incident_ds = p_incident / ds_name
    p_core_ds = p_core / ds_name
    p_dumps_ds = p_dumps / ds_name
    p_dumps_ds.mkdir(exist_ok=True)

    # NOTE(review): pickle.load can execute arbitrary code; only run this on
    # trusted, locally generated files.
    # (The unused load of 'i2edges.pkl' was dropped: nothing here reads it.)
    with (p_incident_ds / 'v2edges.pkl').open('rb') as f:
        v2edges = pickle.load(f)

    n = len(v2edges)
    # Fix one node order so degrees and per-t values line up position by position.
    nodes_sorted = sorted(v2edges)
    # Degree of a node = number of entries stored for it in v2edges.
    degree_sorted = [len(v2edges[v]) for v in nodes_sorted]
    cnt_degree = Counter(degree_sorted)
    entropy_degree = _entropy_bits(cnt_degree.values(), n)

    t2info_gain = dict()
    for c_t in tqdm(list(p_core_ds.glob('c*.pkl'))):
        # The file name minus its first two characters and the '.pkl' suffix is
        # parsed as a fraction, with '-' standing in for '/'.
        # NOTE(review): assumes every match of 'c*.pkl' has a two-character
        # prefix before the number - confirm against the files on disk.
        t = Fraction(c_t.name[2:-4].replace('-', '/'))
        with c_t.open('rb') as f:
            v2c_t = pickle.load(f)
        c_t_sorted = [v2c_t[v] for v in nodes_sorted]
        # Joint distribution of (degree, value at t) over all nodes.
        cnt_combined = Counter(zip(degree_sorted, c_t_sorted))
        entropy_combined = _entropy_bits(cnt_combined.values(), n)
        # H(degree, value) - H(degree) is the conditional entropy H(value | degree).
        t2info_gain[t] = entropy_combined - entropy_degree

    with (p_dumps_ds / 't2info_gain.pkl').open('wb') as f:
        pickle.dump(t2info_gain, f)

    plt.clf()
    step_plot(t2info_gain)
    plt.xlabel('t')
    plt.ylabel('info. gain')
    plt.savefig(p_figures / '{}.png'.format(ds_name), bbox_inches='tight')
    plt.savefig(p_figures / '{}.pdf'.format(ds_name), bbox_inches='tight')
    plt.show()

In [None]:
def _sample_step_values(t2value, samples):
    """Evaluate a step function, given by its breakpoints, on ascending samples.

    Each sample `s` receives the value stored under the smallest key `t` with
    `s <= t`. Samples larger than every key receive no value, so the returned
    list can be shorter than `samples`; callers should check its length.
    """
    values = []
    i_s = 0
    for t in sorted(t2value):
        # Assign this breakpoint's value to every pending sample it covers.
        while i_s < len(samples) and samples[i_s] <= t:
            values.append(t2value[t])
            i_s += 1
        if i_s == len(samples):
            # Every sample is filled. Stop here so that remaining keys cannot
            # index past the end of `samples` (the original inner `break` only
            # left the while loop and could raise IndexError on the next key).
            break
    return values


ds2info_gains = dict()
ds_num = len(ds_names)
# The sampling grid is identical for all datasets, so build it once.
t_samples = np.linspace(0, 1, 100)
for ds_index, ds_name in enumerate(ds_names):
    p_dumps_ds = p_dumps / ds_name
    with (p_dumps_ds / 't2info_gain.pkl').open('rb') as f:
        t2info_gain = pickle.load(f)
    info_gains = _sample_step_values(t2info_gain, t_samples)
    # All curves must have the same length to be correlated pairwise later.
    assert len(info_gains) == len(t_samples), (
        '{}: only {} of {} sample points are covered by a t value'.format(
            ds_name, len(info_gains), len(t_samples)))
    ds2info_gains[ds_index] = info_gains
# Symmetric matrix of pairwise Pearson correlations between the resampled
# info-gain curves; the diagonal keeps its initial value of 1.
pearson_IG_matrix = np.ones((ds_num, ds_num))
index_pairs = list(combinations(range(ds_num), 2))
for row, col in tqdm(index_pairs):
    corr, _p_value = pearsonr(ds2info_gains[row], ds2info_gains[col])
    pearson_IG_matrix[row, col] = corr
    pearson_IG_matrix[col, row] = corr

# Persist the matrix next to the notebook.
matrix_path = Path('pearson_IG_matrix.pkl')
with matrix_path.open('wb') as f:
    pickle.dump(pearson_IG_matrix, f)

# Smaller font than the per-dataset plots: one tick label per dataset.
plt.rcParams.update({'font.size': 15})
plt.clf()
ax = sns.heatmap(
    pearson_IG_matrix,
    xticklabels=False,
    yticklabels=ds_names_short,
    cmap='RdBu',
)
# Colour every y tick label with the colors_group entry of the dataset whose
# short name it shows.
for ticklabel in ax.get_yticklabels():
    ds_index = ds_names_short.index(ticklabel.get_text())
    ticklabel.set_color(colors_group[ds_index])
heatmap_fig = ax.figure
for out_name in ('pearson_IG_matrix.png', 'pearson_IG_matrix.pdf'):
    heatmap_fig.savefig(out_name, bbox_inches='tight')
plt.show()

# Average the pairwise correlations overall, and separately for pairs whose
# dataset names share / do not share the prefix before the first '-'.
domains = [name.split('-')[0] for name in ds_names]
r_total, r_within, r_cross = [], [], []
for i_1, i_2 in tqdm(list(combinations(range(ds_num), 2))):
    r = pearson_IG_matrix[i_1][i_2]
    same_domain = domains[i_1] == domains[i_2]
    (r_within if same_domain else r_cross).append(r)
    r_total.append(r)
print('Global average: {:.3f}'.format(np.mean(r_total)))
print('Within-domain average: {:.3f}'.format(np.mean(r_within)))
print('Cross-domain average: {:.3f}'.format(np.mean(r_cross)))