In [None]:
import json
import numpy as np
import torch.nn.functional as F
import os
import sys
import scipy.stats
import matplotlib.pyplot as plt
from pathlib import Path
from multiprocessing import Pool
# Route IPython's `!cmd` shell escapes through os.system.
# NOTE(review): this presumably bypasses IPython's own `{var}`/`$var`
# expansion in `!` lines -- confirm before relying on interpolation there.
get_ipython().system = os.system

In [None]:
# Parameters
GPU_ID = 0                # exported to the environment as GPU_ID
NUM_THREAD = 8            # size of the multiprocessing pool used for evaluation
N = 100                   # exported to the environment as N
A2C_NAME = '1201_512'     # exported to the environment as A2C_NAME
A2C_T = 15000000          # used in the dataset file name and in output file names

# Number of seeds to train and evaluate (a value of 100 is noted as the alternative).
NUM_SEED = 1  # 100



In [None]:
# Paths
META_DIR_PATH = './data/meta_models'
ENCODER_DIR_PATH = '{}/encoder'.format(META_DIR_PATH)
DATASET_PATH = './data/records/{}.pickle.gzip'.format(A2C_T)

# Export the configuration as environment variables (values must be strings).
# Presumably these are read by the shell scripts launched with `!bash` below.
os.environ.update({
    'GPU_ID': str(GPU_ID),
    'N': str(N),
    'A2C_NAME': str(A2C_NAME),
    'A2C_T': str(A2C_T),
    'META_DIR_PATH': str(META_DIR_PATH),
    'ENCODER_DIR_PATH': str(ENCODER_DIR_PATH),
    'DATASET_PATH': str(DATASET_PATH),
})

# Training and evaluating Encoder and Manifestor

In [None]:
for seed in range(NUM_SEED):
    !bash ./train_encoder.sh
    for mapping in '012 021 102 120 201 210'.split():
        os.environ['ENCODER_SEED'] = str(seed)
        os.environ['MODEL_SEED'] = str(seed + 1000)
        os.environ['MAPPING'] = mapping
        !bash ./train_meta.sh Manifestor

# Draw Fig. 3

In [None]:
def softmax(x):
    """Row-wise softmax of a 2-D array-like.

    Parameters
    ----------
    x : array-like with at least two dimensions
        Scores; normalisation is performed along axis 1.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``x`` whose entries along axis 1 are
        non-negative and sum to 1.
    """
    x = np.asarray(x)
    if x.size == 0:
        # Nothing to normalise; keep the empty-in / empty-out behaviour.
        return np.exp(x)
    # Subtracting the row maximum leaves the result mathematically unchanged
    # but keeps np.exp from overflowing to inf (which would yield nan rows).
    shifted = x - x.max(axis=1, keepdims=True)
    exp_shifted = np.exp(shifted)
    return exp_shifted / exp_shifted.sum(axis=1, keepdims=True)

def calc_loss2(g2, r, match_rate):
    """Match-rate-weighted, normalised cross entropy pooled over all episodes.

    The three arguments are iterated in parallel, one item per episode.
    For each episode and each step ``i`` (one step per entry of that episode's
    ``match_rate``), a target distribution is built by summing the episode's
    ``r`` over the column window ``[i, i + 100)``, dividing by 20 and applying
    a row-wise softmax. The element-wise cross entropy against ``g2`` is
    ``-log(g2) * target``.

    After stacking all episodes, each step's cross entropy (summed over
    columns) is divided by the grand total, multiplied by its match rate and
    divided by the sum of all match rates; the sum of these values is returned.

    NOTE(review): the window length 100 and the divisor 20 are hard-coded;
    their meaning is not visible here.
    """
    ce_chunks = []
    weight_chunks = []
    for predicted, signal, matches in zip(g2, r, match_rate):
        matches = np.array(matches)
        n_steps = matches.shape[0]
        # One row per step: windowed column sums of the episode's signal.
        windowed = np.array([
            signal[:, start:start + 100].sum(axis=1) / 20
            for start in range(n_steps)
        ])
        target = softmax(windowed)
        ce_chunks.append(-np.log(predicted) * target)
        weight_chunks.append(matches)

    cross_entropy = np.vstack(ce_chunks)
    weights = np.hstack(weight_chunks)

    # Share of the total cross entropy contributed by each step ...
    per_step = cross_entropy.sum(axis=1) / cross_entropy.sum()
    # ... re-weighted by the normalised match rate.
    weighted = per_step * weights / weights.sum()
    return weighted.sum()

In [None]:
# Directory holding the per-seed goal-loss JSON files read by process(); the
# Fig. 3 histogram and its statistics are written here as well.
# The previous `.format(A2C_T)` call was a no-op (the string has no `{}`
# placeholder) and has been dropped; the resulting path is unchanged.
DATA_DIR = Path('./data/meta_models/goal_loss')
def process(seed):
    """Compare the losses of all label mappings for one seed.

    For each of the six mappings, ``DATA_DIR/'<seed>_<mapping>.json'`` is read
    (a list of records with keys ``g0``, ``g1``, ``g2``, ``r_`` and
    ``match_rate``). The mapping whose ``g2`` argmax agrees most often with
    ``g1`` is taken as the best one; its digits are relabelled to ``'ABC'`` so
    that every mapping becomes an ordering relative to the best mapping.

    Returns
    -------
    tuple
        ``(n_win, ratio, best_accuracy, best_mapping)`` where ``ratio`` holds
        ``loss[order] / loss['ABC']`` for the orders ACB, BAC, BCA, CAB, CBA
        and ``n_win`` counts the ratios greater than 1.
    """
    array_keys = ['g0', 'g1', 'g2', 'r_']
    records_by_mapping = {}
    accuracy_by_mapping = {}
    for mapping in '012 021 102 120 201 210'.split():
        json_path = DATA_DIR / '{}_{}.json'.format(seed, mapping)
        with json_path.open() as f:
            raw_records = json.load(f)

        record = {key: np.array([line[key] for line in raw_records]) for key in array_keys}
        # match_rate stays a plain list of per-record values.
        record['match_rate'] = [line['match_rate'] for line in raw_records]
        records_by_mapping[mapping] = record

        predicted = record['g2'].reshape((-1, 3)).argmax(axis=1)
        accuracy_by_mapping[mapping] = (record['g1'].reshape(-1) == predicted).mean()

    best_mapping, best_accuracy = max(accuracy_by_mapping.items(), key=lambda item: item[1])

    # Digit -> letter table chosen so that best_mapping itself reads 'ABC'.
    to_letters = str.maketrans(dict(zip(best_mapping, 'ABC')))
    losses = {
        mapping.translate(to_letters): calc_loss2(record['g2'], record['r_'], record['match_rate'])
        for mapping, record in records_by_mapping.items()
    }

    baseline = losses['ABC']
    ratio = np.array([losses[order] / baseline for order in 'ACB BAC BCA CAB CBA'.split()])
    n_win = (ratio > 1).sum()
    return n_win, ratio, best_accuracy, best_mapping

# Evaluate every seed in a pool of NUM_THREAD worker processes.
with Pool(processes=NUM_THREAD) as worker_pool:
    rtn = worker_pool.map(process, range(NUM_SEED))

In [None]:
threshold = 0  # stored verbatim in the statistics JSON

# Split the per-seed result tuples (n_win, ratio, accuracy, best_mapping)
# into parallel lists, ignoring seeds that produced no result.
_completed = [result for result in rtn if result is not None]
n_win = [result[0] for result in _completed]
ratios = [result[1] for result in _completed]
accuracies = [result[2] for result in _completed]
best_mappings = [result[3] for result in _completed]
assert len(n_win) == NUM_SEED

# Persist the best mapping of every seed as one space-separated line.
Path(META_DIR_PATH, 'best_mappings.txt').write_text(' '.join(best_mappings))

# Five alternative orderings are compared against the best one per seed.
n_total = len(ratios) * 5
ratios = np.array(ratios)

In [None]:
class MyEncoder(json.JSONEncoder):
    """JSON encoder that understands NumPy scalars and arrays.

    NumPy integers, floats and booleans are converted to the matching Python
    built-in; arrays are converted to (nested) lists. Anything else is
    delegated to ``json.JSONEncoder.default``, which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is not a subclass of bool, so json cannot encode it natively.
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)
# Columns of `ratios` follow the orders ACB, BAC, BCA, CAB, CBA.
# The double argsort turns each row into 0-based ranks; the column mean plus 1
# is therefore the mean 1-based rank of every order across seeds.
print(ratios.argsort(axis=1).argsort(axis=1).mean(axis=0) + 1)
print(ratios.mean(axis=0))
print(ratios.min(axis=0))
print(ratios.max(axis=0))

# Summary over all seeds and orders; written next to the figure as JSON.
statistics = dict(
    threshold=threshold,
    n_win=sum(n_win),
    n_total=n_total,
    percentage=sum(n_win)/n_total,
    min=ratios.min(),
    max=ratios.max(),
    mean=ratios.mean(),
    std=ratios.std(),
    median=np.median(ratios)
)
print(statistics)

# Stacked histogram: ratios below 1 in red, ratios of at least 1 in blue.
# NOTE(review): plt.hist is called without density=True, so the y-axis shows
# counts although it is labelled 'Density' -- confirm the intended label.
plt.figure(figsize=(6,4))
plt.hist([ratios[ratios<1], ratios[ratios>=1]], histtype='barstacked', bins=np.arange(0.8, 1.55, 0.005),  ec='white', color=['tab:red','tab:blue'])
plt.xlabel('Ratio')
plt.ylabel('Density')
plt.savefig(DATA_DIR/'hist_{}.pdf'.format(A2C_T))
with (DATA_DIR/'hist_{}_stats.json'.format(A2C_T)).open('w') as f:
    json.dump(statistics, f, cls=MyEncoder)

# Central 95% interval of a normal distribution with the sample mean and std.
# The level is passed positionally: SciPy renamed the keyword from `alpha` to
# `confidence`, so the positional form works on both old and new releases.
scipy.stats.norm.interval(.95, loc=ratios.mean(), scale=ratios.std())

# Training ablation and optimal models

In [None]:
for seed, mapping in zip(range(NUM_SEED), best_mappings):
    os.environ['ENCODER_SEED'] = str(seed)
    os.environ['MODEL_SEED'] = str(seed + 1000)
    os.environ['MAPPING'] = mapping
    !bash ./train_meta.sh ablation
    !bash ./train_meta.sh optimal

# Draw Fig. 4

In [None]:
NUM_OPTIMAL = 5
optimals = sorted((Path(META_DIR_PATH)/'optimal').glob('*'), key=lambda p: int(p.name))[:NUM_OPTIMAL]
optimals = [str(list(p.glob('**/Transformer_state_seqlast19.pt'))[0]) for p in optimals]

os.environ['TEACHER_PATHS'] = ';'.join(optimals)
for seed, mapping in zip(range(NUM_SEED), best_mappings):
    os.environ['ENCODER_SEED'] = str(seed)
    os.environ['MODEL_SEED'] = str(seed + 1000)
    os.environ['MAPPING'] = mapping
        
    !bash ./eval_meta.sh Manifestor
    !bash ./eval_meta.sh ablation


In [None]:
def load_data(target):
    """Load the per-seed comparison arrays of one model family.

    Parameters
    ----------
    target : str
        Either ``'ablation'`` or ``'Manifestor'``.

    Returns
    -------
    list of dict
        One ``{array name: ndarray}`` mapping per entry of ``best_mappings``,
        read from ``<META_DIR_PATH>/<target>/comparison/<index>_<mapping>.npz``.
    """
    assert target in ['ablation', 'Manifestor']
    comparison_data_dir = Path(META_DIR_PATH)/target/'comparison'
    data = []
    for i, mapping in enumerate(best_mappings):
        # np.load keeps the .npz archive open and reads arrays lazily; copy the
        # arrays out and close it so repeated calls do not leak file handles.
        with np.load(comparison_data_dir/'{}_{}.npz'.format(i, mapping)) as archive:
            data.append({name: archive[name] for name in archive.files})
    return data

def calc_acc(ys, ts):
    """Fraction of rows where the argmax of ``ys`` matches that of ``ts``.

    ``ts`` is first averaged over axis 1; its per-row argmax is then compared
    with the per-row argmax of ``ys`` and the number of matches is divided by
    the number of rows in ``ys``.
    """
    predicted = np.argmax(ys, axis=1)
    reference = np.argmax(np.mean(ts, axis=1), axis=1)
    n_matches = np.sum(reference == predicted)
    return n_matches / ys.shape[0]

def _accuracy_per_seed(target):
    """calc_acc applied to every record returned by load_data(target)."""
    return np.array([calc_acc(**arrays) for arrays in load_data(target)])

acc_manifestor = _accuracy_per_seed('Manifestor')
acc_ablation = _accuracy_per_seed('ablation')

In [None]:
# Inspect the first Manifestor record; load the data once instead of twice.
_first_manifestor = load_data('Manifestor')[0]
_first_manifestor['ys'], _first_manifestor['ts']

In [None]:
def print_stats(l):
    """Print mean, std, min, max and median of ``l`` on a single line."""
    summary = (l.mean(), l.std(), l.min(), l.max(), np.median(l))
    print(*summary)

# Summary line for each model family (ablation first), then a Mann-Whitney U
# test comparing the two accuracy samples.
for _acc in (acc_ablation, acc_manifestor):
    print_stats(_acc)
_u_test = scipy.stats.mannwhitneyu(acc_ablation, acc_manifestor)
print(_u_test)

In [None]:
import pandas as pd
from matplotlib import pyplot as plt
import seaborn as sns

# Long-format table: one row per (model type, per-seed accuracy).
_rows = [('manifestor', acc) for acc in acc_manifestor]
_rows += [('ablation', acc) for acc in acc_ablation]
df = pd.DataFrame(_rows, columns=['type', 'accuracy'])

# First and third colour of the current seaborn palette.
_colors = sns.color_palette()
palette = (_colors[0], _colors[2])

fig, ax = plt.subplots(figsize=(3,4))
sns.boxenplot(x='type', y='accuracy', data=df, ax=ax, palette=palette)
ax.set_xlabel('')
ax.set_ylabel('Accuracy')
ax.set_ylim([0.25,0.97])
# Tick labels are set by hand and rely on the row order used for `df` above.
ax.set_xticklabels(['Manifestor', 'Ablation'])
plt.savefig('{}/B-2.pdf'.format(META_DIR_PATH))

In [None]:
# Central 95% interval of a normal distribution with each sample's mean/std.
# The level is passed positionally: SciPy renamed the keyword from `alpha` to
# `confidence`, so the positional form works on both old and new releases.
print(scipy.stats.norm.interval(.95, loc=acc_manifestor.mean(), scale=acc_manifestor.std()))
print(scipy.stats.norm.interval(.95, loc=acc_ablation.mean(), scale=acc_ablation.std()))