In [None]:
import os
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import glob
import ast
from datetime import datetime
from tabulate import tabulate
from math import ceil
import random
import sys


# `display` is otherwise only available as an IPython kernel builtin; import it explicitly
# from the public IPython.display module together with HTML.
from IPython.display import HTML, display
from operator import itemgetter

# Widen the notebook ".container" element and stop <pre> output from wrapping,
# so wide text tables printed below are not line-wrapped.
display(HTML("<style>.container {width:90% !important;}</style>"))
display(HTML("<style>pre { white-space: pre !important; }</style>"))

# Global Variables

In [None]:
# Root under which read_stderr looks for "workdir/<env>/*/<xp>.stderr" files.
WORKDIR = '/used/with/slurm/runs'
# Directories whose entries are individual experiment dumps.
PATHS = ['/path/to/train/logs']
df_save_path = None

PREFIX_INDICATOR = "valid"
XPS_PATHS = []

# Filled by the parameter-parsing cell: union of all params, and the keys that vary.
VAR_ARGS = set()
ALL_ARGS = {}

# Argument groups that are never reported as "varying" parameters.
DATA_LOAD_ARGS = [
    'reload_size', 'batch_load', 'shuffle', 'reuse', 'num_reuse_samples',
    'times_reused', 'output_int_base', 'correctQ', 'balanced_base', 'add_unred_perc',
]
MODEL_ARGS = [
    'max_output_len', 'max_len', 'xav_init', 'gelu_activation', 'norm_attention',
    'dropout', 'attention_dropout', 'use_circreg', 'reg_value',
]
EVAL_ARGS = [
    'eval_only', 'eval_from_exp', 'eval_data', 'eval_verbose', 'eval_verbose_print',
    'stopping_criterion', 'validation_metrics',
]
RUN_ARGS = [
    'fp16', 'amp', 'debug_slurm', 'debug', 'cpu', 'local_rank', 'master_port',
    'windows', 'nvidia_apex', 'is_slurm_job', 'node_id', 'global_rank', 'world_size',
]
UNWANTED_ARGS = [
    *DATA_LOAD_ARGS,
    *MODEL_ARGS,
    *EVAL_ARGS,
    *RUN_ARGS,
    'dim_red', 'data_cols', 'dense_cols',
]


In [None]:
class Patterns:
    """Substrings searched for in experiment stderr files and train.log lines."""
    RuntimeError = "RuntimeError:"
    CudaOOM = "CUDA out of memory"
    Terminated = "Exited with exit code 1"
    Forced = "Force Terminated"
    Signal10 = "Signal handler called with signal 10"
    Signal15 = "Signal handler called with signal 15"
    EpochStart = "============ Starting epoch"
    EpochEnd = "============ End of epoch"
    EpochLog = '__log__:'
    IterationLog = "- LR:"
    Cancelled = "CANCELLED AT"
    # read_stderr checks for this to ignore requeue cancellations; it was missing, so any
    # "CANCELLED AT" line raised AttributeError.
    # NOTE(review): confirm this matches the requeue message of your Slurm version.
    Requeue = "DUE TO JOB REQUEUE"
    NodeFailure = "DUE TO NODE FAILURE"

# Parsing Experiment Parameters

In [None]:
def _load_params(xp_path):
    """Return the ``__dict__`` of ``<xp_path>/params.pkl``, or None if that file is missing."""
    pa = os.path.join(xp_path, 'params.pkl')
    if not os.path.exists(pa):
        return None
    # NOTE(review): pickle.load can execute arbitrary code; only point PATHS at trusted runs.
    with open(pa, 'rb') as f:
        return pickle.load(f).__dict__


# Rebuild the list (instead of `+=`) so re-running this cell does not duplicate experiments.
# Only sub-directories are experiments; sorting makes the order (and which experiment's
# values end up in ALL_ARGS, since later updates win) deterministic.
XPS_PATHS = [
    os.path.join(PATH_ENV, name)
    for PATH_ENV in PATHS
    for name in sorted(os.listdir(PATH_ENV))
    if os.path.isdir(os.path.join(PATH_ENV, name))
]
print(len(XPS_PATHS),"experiments found")

# Load every params.pkl exactly once; the same dicts are reused for the comparison pass.
xp_params = {}
for path in XPS_PATHS:
    params = _load_params(path)
    if params is None:
        print("Unpickled experiment: ", path)
        continue
    xp_params[path] = params
    ALL_ARGS.update(params)
pickled_xp = len(xp_params)
print(pickled_xp, "pickled experiments found")
print()

# A key "varies" if some experiment lacks it or holds a value different from ALL_ARGS.
for params in xp_params.values():
    for key, value in ALL_ARGS.items():
        if key in params and np.all(value == params[key]):
            continue
        if key not in UNWANTED_ARGS:
            VAR_ARGS.add(key)


print("common args")
for key in ALL_ARGS:
    if key not in UNWANTED_ARGS and key not in VAR_ARGS:
        print(key,"=", ALL_ARGS[key])
print()

print(len(VAR_ARGS)," variables params out of", len(ALL_ARGS))
print(VAR_ARGS)


# Useful Functions to Parse Experiment Logs

In [None]:
def read_stderr(xp_path):
    """Scan the ``.stderr`` file of one experiment for failure markers.

    The file is searched as ``<WORKDIR or ~>/workdir/<env>/*/<xp>.stderr``, where ``<env>``
    and ``<xp>`` are the last two '/'-separated components of ``xp_path``.

    Returns:
        dict with ``env``, ``xp``, ``stderr`` (a file was parsed), ``log`` (always False
        here) and ``error`` (a RuntimeError line was seen). When exactly one stderr file
        is found it also holds ``runtime_errors`` (the matching lines) and the ``oom``,
        ``terminated``, ``forced`` and ``cancelled`` flags. With several matching files
        nothing is parsed and a "duplicate stderr" message is printed.
    """
    dirs = xp_path.split('/')
    EXP_ENV, xp = dirs[-2], dirs[-1]
    res = {"env": EXP_ENV, "xp": xp, "stderr": False, "log": False, "error": False}
    stderr_file = os.path.join(WORKDIR or os.path.expanduser("~"), 'workdir/'+EXP_ENV+'/*/'+xp+'.stderr')
    # Glob once so the duplicate check and the parsing see the same files.
    stderr_files = glob.glob(stderr_file)
    if len(stderr_files) > 1:
        print("duplicate stderr", EXP_ENV, xp)
        return res

    for name in stderr_files:
        res.update({"stderr": True, "runtime_errors": [], "oom": False, "terminated": False, "forced": False, "cancelled": False})
        with open(name, 'rt') as f:
            for line in f:
                if Patterns.RuntimeError in line:
                    res["error"] = True
                    res["runtime_errors"].append(line)
                if Patterns.CudaOOM in line:
                    # Was `= cuda`, an undefined name: every OOM line raised NameError.
                    res["oom"] = True
                if Patterns.Terminated in line:
                    res["terminated"] = True
                if Patterns.Forced in line:
                    res["forced"] = True
                # Cancellations due to a requeue or a node failure are not counted.
                if (Patterns.Cancelled in line
                        and Patterns.Requeue not in line
                        and Patterns.NodeFailure not in line):
                    res["cancelled"] = True
                if 'NaN detected' in line:
                    # Stop scanning at the first NaN report.
                    break

        # Report runtime errors that were not OOMs (previously referenced the undefined `cuda`).
        if res["runtime_errors"] and not res["oom"]:
            print(name, "runtime error no oom")
    return res

def read_params(res, xp_path):
    """Add an experiment's hyper-parameters to ``res`` and return it.

    Reads ``<xp_path>/params.pkl``. Every key of the global ``VAR_ARGS`` is copied from
    the pickle (None when the pickle lacks it). The keys batch_size, N, hamming, Q and
    sigma are always set: through ``VAR_ARGS`` when they vary, otherwise from the shared
    ``ALL_ARGS`` value (None if no experiment defines them).
    If the pickle is missing, a message is printed and ``res`` is returned unchanged.
    """
    pa = os.path.join(xp_path, 'params.pkl')
    if not os.path.exists(pa):
        print("pickle", pa, "not found")
        return res
    # NOTE(review): pickle.load can execute arbitrary code; only read trusted experiment dirs.
    with open(pa, 'rb') as f:
        params = pickle.load(f).__dict__
    for key in VAR_ARGS:
        res[key] = params.get(key)
    for key in ["batch_size", "N", "hamming", "Q", "sigma"]:
        if key not in VAR_ARGS:
            # .get: a key no experiment defines would otherwise raise KeyError.
            res[key] = ALL_ARGS.get(key)
    return res
            
def read_train_log(res, xp_path, max_epoch=None):
    """Parse ``<xp_path>/train.log`` and add training statistics to ``res``.

    If the log is missing, ``res`` is returned unchanged. Otherwise it sets:
      * ``log`` (True) and ``nans`` (a 'NaN detected' line was hit; parsing stops there);
      * ``curr_epoch``: last epoch whose start was logged (-1 if none);
      * ``nonzeros_epoch``: first epoch at which a checkpoint was saved while the count of
        'Nonzero bits not identified. ' lines in that epoch was not exactly 4 (-1 if never);
      * ``nb_sig10`` / ``nb_sig15``: number of signal-handler lines for signal 10 / 15;
      * ``train_loss``: per-iteration losses (None for 'nan');
      * ``val_loss``: ``valid_xe_loss`` of every ``__log__:`` line;
      * ``train_acc`` and ``Max acc`` (mean of the last 5 accuracies, ignoring 'nan')
        when more than 5 accuracies were logged;
      * ``best_xe_loss`` / ``last_xe_loss``: min / last of ``val_loss`` (None if empty).

    Args:
        max_epoch: if given, parsing stops at the start of this epoch.

    Lines that fail to parse are reported and skipped.
    """
    pa = os.path.join(xp_path, 'train.log')
    if not os.path.exists(pa):
        return res
    res.update({"log": True, "nans": False, "curr_epoch": -1, "nonzeros_epoch": 9999, "nb_sig10": 0, "nb_sig15": 0, "train_loss": [], "val_loss": []})
    train_acc = []
    nonzeros_not_matched = 0
    curr_epoch = None  # no epoch marker seen yet
    with open(pa, 'rt') as f:
        for line in f:
            try:
                if 'NaN detected' in line:
                    res["nans"] = True
                    break
                # Counters live in `res`; the old bare `nb_sig10 += 1` raised UnboundLocalError.
                if Patterns.Signal10 in line:
                    res["nb_sig10"] += 1
                if Patterns.Signal15 in line:
                    res["nb_sig15"] += 1

                if 'Nonzero bits not identified. ' in line:
                    nonzeros_not_matched += 1
                if Patterns.EpochStart in line:
                    curr_epoch = int(line.split('epoch ')[1].split()[0])
                    if curr_epoch == max_epoch:
                        break
                    res["curr_epoch"] = curr_epoch
                    nonzeros_not_matched = 0
                # A checkpoint saved before any epoch marker cannot be attributed to an epoch.
                if ' - Saving checkpoint to ' in line and curr_epoch is not None:
                    if nonzeros_not_matched != 4:
                        res["nonzeros_epoch"] = min(res["nonzeros_epoch"], curr_epoch)
                if Patterns.EpochEnd in line:
                    curr_epoch = int(line.split('epoch ')[1].split()[0])
                    if curr_epoch != res["curr_epoch"]:
                        print("epoch mismatch", curr_epoch, "in", xp_path)

                if Patterns.IterationLog in line:
                    # Expected shape: "... LOSS: <loss>[||ACC1: <acc>] - ..."
                    parts = line.split("LOSS: ")[1].split(' - ')[0].split('||')
                    if 'ACC1: ' in line:
                        acc = parts[1].split('ACC1: ')[1].strip()
                        train_acc.append(None if acc == 'nan' else float(acc))
                    loss = parts[0].strip()
                    res["train_loss"].append(None if loss == 'nan' else float(loss))
                if Patterns.EpochLog in line:
                    res["val_loss"].append(float(line.split('valid_xe_loss\": ')[1].split(',')[0].split('}')[0]))

            except Exception as e:
                print(e, "exception in", xp_path)
                continue

    if len(train_acc) > 5:
        res["train_acc"] = train_acc
        # None ('nan') entries become NaN and are ignored instead of crashing np.mean.
        res["Max acc"] = np.nanmean(np.array(train_acc[-5:], dtype=float))
    if res["nonzeros_epoch"] == 9999:
        res["nonzeros_epoch"] = -1
    # No validation yet (e.g. run still in its first epoch): min([]) used to raise.
    val_loss = res["val_loss"]
    res["best_xe_loss"] = min(val_loss) if val_loss else None
    res["last_xe_loss"] = val_loss[-1] if val_loss else None
    return res

def read_secret_rec(res, xp_path):
    """Scan ``secret_recovery_<epoch>.pkl`` files from ``res["curr_epoch"]`` down to epoch 0.

    Sets ``success`` / ``success_epoch``: True and the lowest epoch whose pickle has a
    non-empty ``success`` entry (False / -1 otherwise). ``success_methods`` is overwritten
    by every readable pickle, so it ends up holding the ``success`` entry of the
    lowest-numbered pickle found. A pickle containing ``partial_success`` sets
    ``nonzeros_epoch`` to its epoch (the lowest such epoch wins).
    When ``curr_epoch`` is absent (no train.log was parsed), no pickle is scanned.
    Missing pickles, other than the one for the current epoch, are reported.
    """
    res.update({"success_epoch": -1, "success": False})
    # .get: read_train_log does not set curr_epoch when train.log is missing.
    last_epoch = res.get("curr_epoch", -1)
    for result_epoch in range(last_epoch, -1, -1):
        pa = os.path.join(xp_path, f'secret_recovery_{result_epoch}.pkl')
        if not os.path.exists(pa):
            if result_epoch != last_epoch:
                print("secret recovery pickle", xp_path, "not found")
            continue
        try:
            # NOTE(review): pickle.load can execute arbitrary code; only read trusted dirs.
            with open(pa, 'rb') as f:
                pk = pickle.load(f)
            if not isinstance(pk, dict):
                pk = pk.__dict__
            res['success_methods'] = pk['success']
            if len(pk['success']) > 0:
                res["success_epoch"] = result_epoch
                res["success"] = True
            if 'partial_success' in pk:
                res["nonzeros_epoch"] = result_epoch
        except Exception as e:
            # Best effort: a corrupt pickle should not stop the whole scan.
            print('error reading secret recovery pickle', pa, e)
    return res
        

In [None]:
data = []
failed = {}  # "<N> ; <batch_size>" -> number of runs with a RuntimeError
for xp_path in XPS_PATHS:
    res = read_stderr(xp_path)
    res = read_params(res, xp_path)
    res = read_train_log(res, xp_path, None)
    res = read_secret_rec(res, xp_path)
    data.append(res)
    if res["error"]:
        # .get: N / batch_size are absent when params.pkl could not be read.
        key = f'{res.get("N")} ; {res.get("batch_size")}'
        failed[key] = failed.get(key, 0) + 1
print(failed)
print(len(data), "experiments read")
print(len([d for d in data if d["stderr"] is False]),"stderr not found")
print(len([d for d in data if d["error"] is True]),"runtime errors")
print(len([d for d in data if d.get("oom") is True]),"oom errors")
print(len([d for d in data if d.get("terminated") is True]),"exit code 1")
print(len([d for d in data if d.get("forced") is True]),"Force Terminated")

In [None]:
def compose(f,g):
    """Return the composition f ∘ g: a one-argument callable computing ``f(g(x))``."""
    def composed(x):
        inner = g(x)
        return f(inner)
    return composed

def print_table(data, args, sort=False):
    """Pretty-print one row per experiment dict in ``data`` and one column per key in ``args``.

    Missing keys are shown as None. With ``sort=True`` rows are ordered by their first
    column, converted with ``float``, from largest to smallest.
    """
    rows = [[d.get(v) for v in args] for d in data]
    if sort:
        rows.sort(key=lambda row: float(row[0]), reverse=True)
    print(tabulate(rows, headers=args, tablefmt="pretty"))


    
def speed_table(data, args, indic, sort=False, percent=95):
    """Print, for each experiment having ``indic``, the first index where it reaches ``percent``.

    The index is inserted as the second column ("first epoch"); it is capped at 1000,
    which is also the value when ``percent`` is never reached. Experiments without
    ``indic`` are skipped. With ``sort=True`` rows are ordered by that index, ascending.
    """
    rows = []
    for d in data:
        if indic not in d:
            continue
        hits = [i for i, v in enumerate(d[indic]) if v >= percent]
        first_hit = min(hits + [1000])
        row = [d.get(v) for v in args]
        row.insert(1, first_hit)
        rows.append(row)
    headers = args.copy()
    headers.insert(1, 'first epoch')
    if sort:
        rows.sort(key=lambda row: float(row[1]))
    print(tabulate(rows, headers=headers, tablefmt="pretty"))

def training_curve(data, indic, beg=0, end=-1, maxval=None, minval=None, smooth=1):
    """Plot the ``indic`` series of every experiment in ``data`` on one new figure.

    Args:
        data: experiment dicts; those without ``indic`` are skipped.
        indic: key of a per-experiment list of values (e.g. "val_loss").
        beg, end: slice bounds applied to each series; ``end=-1`` keeps the tail.
        maxval, minval: optional y-axis limits.
        smooth: if != 1, plot the mean of consecutive non-overlapping windows of
            ``smooth`` points (an incomplete trailing window is dropped), at x = window start.
    """
    # Set the size *before* creating the figure: changing rcParams after plotting
    # does not resize the figure already drawn.
    plt.rcParams['figure.figsize'] = [10, 10]
    plt.figure()
    for d in data:
        if indic not in d:
            continue
        values = d[indic][beg:] if end == -1 else d[indic][beg:end]
        if smooth != 1:
            # dtype=float turns None entries (logged 'nan') into NaN so means can be computed.
            series = np.array(values, dtype=float)
            num_points = len(series) // smooth
            window_means = series[:num_points * smooth].reshape(-1, smooth).mean(axis=1)
            plt.plot([i * smooth for i in range(num_points)], window_means)
        else:
            plt.plot(values)
    plt.ylim(minval, maxval)
    plt.title(indic.replace("_", " ").title())
    plt.show()
    
def filter_xp(xp, filt):
    """Return True iff, for every key of ``filt``, ``xp`` has that key with a value in ``filt[key]``.

    An empty filter accepts every experiment.
    """
    return all(key in xp and xp[key] in allowed for key, allowed in filt.items())

def xp_stats(data, splits, best_arg, best_value):
    """Print how often each ``split:value`` occurs among experiments with ``best_arg >= best_value``.

    Experiments without a usable ``best_arg`` (key missing or None, e.g. "Max acc" is only
    set when enough accuracies were logged) are skipped instead of raising.
    Prints the number of selected experiments, then each label with its count, sorted.
    """
    counts = {}
    nb = 0
    for d in data:
        score = d.get(best_arg)
        if score is None or score < best_value:
            continue
        nb += 1
        for s in splits:
            if s not in d:
                continue
            label = s + ':' + str(d[s])
            counts[label] = counts.get(label, 0) + 1

    print()
    print(f"{nb} experiments with accuracy over {best_value}")
    for label in sorted(counts):
        print(label, ' : ', counts[label])
    print()


                   
                   

In [None]:
xp_filter = {}

# Columns hidden from the tables (presumably per-run seeds/secrets/paths; adjust as needed).
table_args = VAR_ARGS - set(['env_base_seed','secret_col','secret','master_addr','dump_path'])
fdata = [d for d in data if filter_xp(d, xp_filter)]

# Use the `oom` flag set by read_stderr: `error` is true for any RuntimeError, OOM or not.
oomtab = [d for d in fdata if d.get("oom") is True]
print(f"CUDA out of memory ({len(oomtab)})")
print_table(oomtab, table_args)

forcetab = [d for d in fdata if d.get("forced") is True]
print(f"Forced terminations ({len(forcetab)})")
print_table(forcetab, table_args)

unstartedtab = [d for d in fdata if "curr_epoch" in d and d["curr_epoch"] < 0]
print(f"Not started ({len(unstartedtab)})")
print_table(unstartedtab, table_args)

crypto = False
# Run-status columns first; the first one (curr_epoch) is the sort key of the table below.
runargs = ["curr_epoch", "best_xe_loss", "last_xe_loss", "nans", "error", *table_args]

runningtab = [d for d in fdata if "curr_epoch" in d and d["curr_epoch"] >= 0]
print(f"Running experiments ({len(runningtab)})")
print_table(runningtab, runargs, sort=True)

In [None]:
plt.rcParams['figure.figsize'] = [10, 10]

# (indicator, extra keyword arguments for training_curve)
curve_specs = [
    ("val_loss", {}),
    ("train_loss", {}),
    ("train_acc", {"smooth": 10}),  # averaged over windows of 10 points
]
for indicator, options in curve_specs:
    training_curve(fdata, indicator, **options)
# Optional: speed_table(runningtab, runargs, "beam_acc" , sort=True,percent=85)

# Get experiment results as a DataFrame

In [None]:
df = pd.DataFrame(runningtab)
print(df.shape)
# A column only exists if at least one experiment set it ('success_methods' needs a readable
# secret_recovery_*.pkl), so reindex to get NaN instead of a KeyError.
summary = df.reindex(columns=['hamming', 'curr_epoch', 'success_methods'])
for h, n_epoch, methods in summary.itertuples(index=False):
    print(h, n_epoch, methods)

In [None]:

def groupby_and_count(df, criterions, groupby_vars):
    """Summarise experiment outcomes per group of ``groupby_vars``.

    Args:
        df: one row per experiment; must contain ``groupby_vars``, ``criterions``,
            ``success_epoch`` and ``nonzeros_epoch`` (-1 meaning "never").
        criterions: columns summarised as "<sum>/<non-null count>" (for booleans:
            "<number of True>/<number of runs>").
        groupby_vars: columns to group by (e.g. ['hamming', 'N']).

    Returns:
        DataFrame indexed by the groups, with one column per criterion plus:
        ``epochs`` (sorted success_epoch values != -1, comma-joined),
        ``partial success`` ("<runs with nonzeros_epoch != -1>/<runs in group>") and
        ``nonzeros_epochs`` (sorted nonzeros_epoch values != -1, comma-joined).
    """
    grouped = df.groupby(groupby_vars)
    df_count = grouped[criterions].sum()
    df_total = grouped[criterions].count()
    for criterion in criterions:
        df_count[criterion] = df_count[criterion].astype("int").astype("str") + "/" + df_total[criterion].astype("int").astype("str")

    def _join_epochs(epochs):
        # Sorted epochs other than the -1 "never" marker, as "a,b,c".
        return ','.join(str(ep) for ep in sorted(epochs[epochs != -1]))

    # Everything below is computed per group of `groupby_vars` so it aligns with df_count's
    # index (the previous per-'hamming' lists broke as soon as several N shared a hamming).
    df_count['epochs'] = grouped['success_epoch'].apply(_join_epochs)
    partial = grouped['nonzeros_epoch'].apply(lambda s: int((s != -1).sum()))
    df_count['partial success'] = partial.astype("int").astype("str") + "/" + grouped.size().astype("int").astype("str")
    df_count['nonzeros_epochs'] = grouped['nonzeros_epoch'].apply(_join_epochs)
    return df_count

ended = groupby_and_count(df, ["success"], ['hamming', 'N'])
# Move N into the columns, then transpose: rows are (statistic, N), columns are hamming values.
ended.unstack().T
