In [1]:
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import glob
import json
from pathlib import Path

In [2]:
def load_json(dir):
    try:
        with open(dir) as f:
            data = json.load(f)
            
        return data
    
    except:
        return None

In [3]:
RESULT_DIR = "..\\result_summary"

In [9]:
def parse_file_nm(file_nm):
    mcts_info = None
    
    components = file_nm.split("\\")
    
    if "debug" in components:
        components.remove("debug")
    
    method = components[2]
        
    print(components)
    diff_cut = 1
    
    if 'am' in method:
        _, _, method, _, prob_type, num_probs, model_info, baseline_info, _ = components
        
    elif 'mcts' in method:
        _, _, method, diff_cut, _, prob_type, num_probs, model_info, baseline_info, mcts_info, _ = components
        diff_cut = float(diff_cut.split("-")[-1])

    num_probs = num_probs.split("-")[0]
    num_probs = int(num_probs.split("_")[1])
    baseline_info = baseline_info.split("-")[-1]
    
    return method, prob_type, num_probs, model_info, baseline_info, mcts_info, diff_cut

In [10]:
# parse all json file named with "all_result_avg.json" from the RESULT_DIR

all_result = {}

for file in glob.glob(f"{RESULT_DIR}/**/all_result_avg.json", recursive=True):
    data = load_json(file)
    # print(file)
    method, prob_type, num_probs, model_info, baseline_info, mcts_info, diff_cut = parse_file_nm(file)
    
    if data is not None:
        
        if method not in all_result:
            all_result[method] = {}
        
        if prob_type not in all_result[method]:
            all_result[method][prob_type] = {}
            
        if num_probs not in all_result[method][prob_type]:
            all_result[method][prob_type][num_probs] = {}
        
        if mcts_info is None:
            if model_info not in all_result[method][prob_type][num_probs]:
                all_result[method][prob_type][num_probs][model_info] = {}
            
            if baseline_info not in all_result[method][prob_type][num_probs][model_info]:
                all_result[method][prob_type][num_probs][model_info][baseline_info] = data
                
        else:
            if model_info not in all_result[method][prob_type][num_probs]:
                all_result[method][prob_type][num_probs][model_info] = {}
                
            if baseline_info not in all_result[method][prob_type][num_probs][model_info]:
                all_result[method][prob_type][num_probs][model_info][baseline_info] = {}
                
            if mcts_info not in all_result[method][prob_type][num_probs][model_info][baseline_info]:
                all_result[method][prob_type][num_probs][model_info][baseline_info][mcts_info] = {}
                
            if diff_cut not in all_result[method][prob_type][num_probs][model_info][baseline_info][mcts_info]:
                all_result[method][prob_type][num_probs][model_info][baseline_info][mcts_info][diff_cut] = data
                
    # break

['..', 'result_summary', 'am', 'pretrained_result', 'cvrp', 'N_100-B_64', 'shared_mha-128-6-32-4-swiglu-10-0.0001', '1562-1-mean', 'all_result_avg.json']
['..', 'result_summary', 'am', 'pretrained_result', 'cvrp', 'N_100-B_64', 'shared_mha-128-6-32-4-swiglu-10-0.0001', '1562-1-val', 'all_result_avg.json']
['..', 'result_summary', 'am', 'pretrained_result', 'cvrp', 'N_20-B_64', 'shared_mha-128-6-32-4-relu-10-0.0001', '1562-1-mean', 'all_result_avg.json']
['..', 'result_summary', 'am', 'pretrained_result', 'cvrp', 'N_20-B_64', 'shared_mha-128-6-32-4-relu-10-0.0001', '1562-1-val', 'all_result_avg.json']
['..', 'result_summary', 'am', 'pretrained_result', 'cvrp', 'N_20-B_64', 'shared_mha-128-6-32-4-swiglu-10-0.0001', '1562-1-mean', 'all_result_avg.json']
['..', 'result_summary', 'am', 'pretrained_result', 'cvrp', 'N_20-B_64', 'shared_mha-128-6-32-4-swiglu-10-0.0001', '1562-1-val', 'all_result_avg.json']
['..', 'result_summary', 'am', 'pretrained_result', 'cvrp', 'N_50-B_64', 'shared_mha-12

In [44]:
def get_parital_result(method, prob_type, num_prob):       
    partial_result = {'model_info': [], 'baseline': [], 'mcts_info': [], 'epoch': [], 'score': [], 'runtime': [], 'score_std': [], 'diff_cut': []}
    
    def _append_epoch_result(model_info, baseline_info, epoch, mcts, epoch_score, epoch_score_std, epoch_runtime, diff_cut):
        partial_result['model_info'].append(model_info)
        partial_result['baseline'].append(baseline_info)
        partial_result['mcts_info'].append(mcts)
        partial_result['epoch'].append(epoch)                
        partial_result['score'].append(epoch_score)
        partial_result['score_std'].append(epoch_score_std)
        partial_result['runtime'].append(epoch_runtime)
        partial_result['diff_cut'].append(diff_cut)

    target_result = all_result[method][prob_type][num_prob]
    
    if "am" in method:
        for model_info in target_result:
            for baseline_info in target_result[model_info]:
                all_epoch_results = target_result[model_info][baseline_info]
                
                for epoch in all_epoch_results:
                    current_epoch_result = all_epoch_results[epoch]['result_avg']
                    epoch_score = current_epoch_result['score']
                    epoch_score_std = all_epoch_results[epoch]['result_std']['score']
                    epoch_runtime = current_epoch_result['runtime']
                    
                    _append_epoch_result(model_info, baseline_info, epoch, "am", epoch_score,epoch_score_std, epoch_runtime, 1)
                    
    elif 'mcts' in method:
        for model_info in target_result:
            for baseline_info in target_result[model_info]:
                all_mcts_results = target_result[model_info][baseline_info]
                
                for mcts_info in all_mcts_results:
                    all_diff_results = all_mcts_results[mcts_info]
                    
                    for diff_cut in all_diff_results:
                        all_epoch_results = all_diff_results[diff_cut]
                    
                        for epoch in all_epoch_results:
                            current_epoch_result = all_epoch_results[epoch]['result_avg']
                            epoch_score = current_epoch_result['score']
                            epoch_score_std = all_epoch_results[epoch]['result_std']['score']
                            epoch_runtime = current_epoch_result['runtime']
                            
                            _append_epoch_result(model_info, baseline_info, epoch, mcts_info, epoch_score,epoch_score_std, epoch_runtime, diff_cut)                                         
            
    df = pd.DataFrame.from_dict(partial_result)
    df['activation'] = df['model_info'].apply(lambda x: x.split("-")[5])
    df['enc_layers'] = df['model_info'].apply(lambda x: x.split('-')[2])
    df = df.drop(columns=['model_info'])
    
    df['train_score'] = df.apply(lambda x: x['epoch'].split("-")[1].split("=")[1], axis=1).astype(float)
    df['epoch'] =  df.apply(lambda x: x['epoch'].split("-")[0].split("=")[1], axis=1).astype(int)
    
    if 'mcts' in method:
        df['cpuct'] = df['mcts_info'].apply(lambda x: x.split("-")[2].split('_')[1])
        df['ns'] = df['mcts_info'].apply(lambda x: x.split("-")[0].split('_')[1])
        df['cpuct'] = df['cpuct'].astype(float)
        df['ns'] = df['ns'].astype(int)
    
    elif 'am' in method:
        df['cpuct'] = 0
        df['ns'] = 0
        
    df = df.drop(columns=['mcts_info'])
    

        
    return df

In [45]:
def plot_bar_result(base_df, baseline, activation, prob_type, num_prob, plot_dev=False, hue='cpuct'):
    if prob_type == 'tsp':
        y_ranges = {20: (3.75, 3.95), 50: (5.7, 5.875), 100: (7.95, 8.25)}
        
    elif prob_type == 'cvrp':
        y_ranges = {20: (6.1, 7.5), 50: (9, 12), 100: (7.95, 8.25)}
        
    _df = base_df[(base_df['baseline'] == baseline) & (base_df['activation'] == activation)]
    
    # drop rows where the mcts_info is 0.8
    _df = _df[_df['cpuct'] != '0.8']
    
    # sort by the train score
    _df = _df.sort_values(by=['epoch', 'cpuct', 'ns']).reset_index(drop=True)
    _df['score_std'] = _df['score_std'].astype(float)
    
    _df['cpuct'] = _df['cpuct'].astype(object)
    _df['cpuct'] = _df['cpuct'].fillna('am')
    
    _df['ns'] = _df['ns'].astype(object)
    _df['ns'] = _df['ns'].fillna('am')
    
    plt.figure(figsize=(12, 7))
    
    ax = sns.barplot(data=_df, x=_df['epoch'], y='score', hue=hue)
    
    if plot_dev:
        bar_coords = [(rect.get_x(), rect.get_y(), rect.get_width(), rect.get_height()) for rect in ax.patches]
                
        for i, (x, y, w, h) in enumerate(bar_coords):
            x_pos = x + w / 2
            y_top = y + h + 0.01
            
            plt.errorbar(x=x_pos, y=y_top, yerr=_df['score_std'][i], fmt='none', color='black', capsize=4)
    
    plt.title(f"{prob_type}_{num_prob}_{baseline}_{activation}")
    plt.ylim(*y_ranges[num_prob])
    plt.legend()
    
    path = Path(f"../result_image/bars")
    
    if not path.exists():
        path.mkdir(parents=True)
        
    
    plt.savefig(f"{path}/{prob_type}_{num_prob}_{baseline}_{activation}.png")
    plt.show()

In [46]:
# tsp_20 = pd.concat([get_parital_result('am', 'tsp', 20), get_parital_result('mcts_v2', 'tsp', 20)])

# for _baseline in ['mean', 'val']:
#     for _activation in ['relu', 'swiglu']:
#         plot_bar_result(tsp_20, _baseline, _activation, 'tsp', 20, hue='cpuct')

In [48]:
tsp_100 = pd.concat([get_parital_result('am', 'tsp', 100), get_parital_result('mcts', 'tsp', 100)])
tsp_100.sort_values(by=['score']).head(15)

Unnamed: 0,baseline,epoch,score,runtime,score_std,diff_cut,activation,enc_layers,train_score,cpuct,ns
132,val,299,7.952167,10.248525,0.072568,0.75,swiglu,6,7.86777,1.1,1000
108,mean,299,7.952278,1.087797,0.073115,0.75,swiglu,6,7.90623,1.1,100
121,mean,299,7.952995,4.318947,0.073776,0.75,swiglu,6,7.90623,1.1,500
139,val,249,7.953231,5.765433,0.071119,0.75,swiglu,6,7.84949,1.1,500
128,val,299,7.953631,1.265813,0.074261,0.75,swiglu,6,7.86777,1.1,100
115,mean,299,7.95392,8.395795,0.072868,0.75,swiglu,6,7.90623,1.1,1000
126,val,249,7.956674,1.367164,0.071928,0.75,swiglu,6,7.84949,1.1,100
133,val,249,7.957373,11.524362,0.071122,0.75,swiglu,6,7.84949,1.1,1000
124,mean,199,7.958446,4.116848,0.075965,0.75,swiglu,6,7.99446,1.1,500
112,mean,199,7.958824,1.077001,0.075142,0.75,swiglu,6,7.99446,1.1,100


In [11]:
tsp_100.pivot_table(index=['epoch', 'ns', 'activation', 'baseline'], values=['score',], aggfunc=[np.mean, np.std])

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,mean
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,score
epoch,ns,activation,baseline,Unnamed: 4_level_2
121,0,relu,val,9.041163
121,100,relu,val,8.264503
121,250,relu,val,8.27292
121,500,relu,val,8.279451
121,1000,relu,val,8.349693
169,0,swiglu,mean,9.087556
169,100,swiglu,mean,8.211455
169,250,swiglu,mean,8.250589
169,500,swiglu,mean,8.303616
169,1000,swiglu,mean,8.352575


In [162]:
cond1 = tsp_100.activation == 'relu' 
cond2 = tsp_100.baseline == 'mean'
cond3 = tsp_100.activation == 'swiglu'
cond4 = tsp_100.baseline == 'val'

relu_mean_and_swiglu_val = tsp_100[(cond1 & cond2) | (cond3 & cond4)]

relu_mean_and_swiglu_val.pivot_table(index=['ns', 'activation', 'baseline'], values=['score', 'runtime', 'score_std'], aggfunc=[np.mean]).sort_values(by=[('mean', 'score')])

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,mean,mean,mean
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,runtime,score,score_std
ns,activation,baseline,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2
1000,relu,mean,154.726169,8.015666,0.109672
250,relu,mean,46.888426,8.018076,0.091403
100,relu,mean,20.850545,8.023215,0.086113
500,relu,mean,86.333362,8.023653,0.111706
500,swiglu,val,93.836482,8.102727,0.096249
100,swiglu,val,22.163065,8.103705,0.10686
250,swiglu,val,50.614391,8.107124,0.098039
1000,swiglu,val,172.673865,8.139727,0.222364
0,swiglu,val,0.228124,8.491095,0.159627
0,relu,mean,0.220766,8.529293,0.342256


In [164]:
cvrp_20 = pd.concat([get_parital_result('am', 'cvrp', 20), get_parital_result('mcts', 'cvrp', 20)])
cvrp_20.sort_values(by=['score'])

KeyError: 'cvrp'

In [151]:
# tsp_50 = pd.concat([get_parital_result('am', 'tsp', 50), get_parital_result('mcts', 'tsp', 50, leave_only_puct=True)])

# for _baseline in ['mean', 'val']:
#     for _activation in ['relu', 'swiglu']:
#         plot_bar_result(tsp_50, _baseline, _activation, 'tsp', 50)

In [152]:
# tsp_100 = pd.concat([get_parital_result('am', 'tsp', 100), get_parital_result('mcts', 'tsp', 100, leave_only_puct=True)])

# for _baseline in ['mean', 'val']:
#     for _activation in ['relu', 'swiglu']:
#         plot_bar_result(tsp_100, _baseline, _activation, 'tsp', 100)