In [None]:
import sys, os, pickle
sys.path.append("..")
import torch
from trainer.recorder import Recorder
import matplotlib.pyplot as plt
from datetime import MAXYEAR, MINYEAR
from datetime import datetime, time, timedelta

def today():
    '''Return the current local date as a naive datetime at 00:00:00.'''
    current_date = datetime.now().date()
    # time.min is midnight, so the result marks the very start of the day.
    return datetime.combine(current_date, time.min)
def days_ago(days):
    '''Return the midnight datetime `days` days before today (see `today()`).'''
    return today() - timedelta(days=days)
# Sentinel bounds: January 1st of the largest / smallest year `datetime` supports.
# Presumably meant as open-ended endpoints for the `time=(start, end)` window of
# `select` -- confirm against callers.
maxtime = datetime(year=MAXYEAR, month=1, day=1)
mintime = datetime(year=MINYEAR, month=1, day=1)

def stdmean(logger: Recorder, *labels, summarize=None):
    '''
    Compute the mean and standard deviation of logged data across all runs.

    Iterating `logger` yields the runs; for each run the data returned by
    `logger.get_data(run, *labels)` is passed through `summarize`, and the
    values it returns are collected per label over all runs.

    Args:
        logger: iterable of runs exposing `get_data(run, *labels)`.
        *labels: labels forwarded to `logger.get_data`.
        summarize: optional callable mapping the per-run dict to another dict
            (it may rename or add labels). Defaults to the identity.

    Returns:
        dict mapping each label produced by `summarize` to `[mean, std]`,
        both taken over every element collected for that label.

    Note:
        `std` is torch's default (unbiased) estimate, so a label with a
        single collected value yields NaN for the standard deviation.
    '''
    if summarize is None:
        summarize = lambda run_dict: run_dict
    series_dict = {}
    for run in logger:
        summary_dict = summarize(logger.get_data(run, *labels))
        for new_label in summary_dict:
            series_dict.setdefault(new_label, []).append(summary_dict[new_label])
    stdmean_dict = {}
    for label, series in series_dict.items():
        t = torch.tensor(series)
        # mean()/std() reject integer and bool tensors, so promote those
        # (e.g. counts returned by `summarize`) to the default float dtype.
        if not (t.is_floating_point() or t.is_complex()):
            t = t.to(torch.get_default_dtype())
        stdmean_dict[label] = [t.mean().item(), t.std().item()]
    return stdmean_dict

def stdmean_acc(logger: Recorder):
    '''
    Mean/std over runs of the best validation accuracy and of the test
    accuracy logged at that same position, both scaled by 100.
    '''
    def get_acc(val_test):
        # Model selection on validation: take the entry with the highest
        # 'val/acc' and report the 'test/acc' stored at the same index.
        val_curve = 100 * torch.tensor(val_test['val/acc'])
        best_idx = val_curve.argmax()
        return {
            'val/acc': val_curve.max().item(),
            'test/acc': 100 * val_test['test/acc'][best_idx],
        }
    return stdmean(logger, 'val/acc', 'test/acc', summarize=get_acc)

def select(loggers, filters: dict, time=None):
    '''
    Return the loggers whose `info` matches `filters`, optionally restricted
    to a time window.

    Matching rules (applied recursively):
      * every key of `filters` must be present in the matched mapping;
      * a dict filter value is matched recursively against the logged value;
      * a list/tuple filter value is matched element by element against the
        leading elements of the logged sequence (dict elements recursively,
        anything else by equality);
      * any other filter value must compare equal to the logged value.

    Logs whose structure does not fit the filter (e.g. a string where the
    filter expects a nested dict, or a logged list shorter than the filter
    list) simply do not match instead of raising.

    Args:
        loggers: iterable of objects exposing `.info` (and `.time` when
            `time` is given).
        filters: nested dict describing the required values.
        time: optional `(start, end)` pair; keeps loggers with
            `start <= logger.time < end`.

    Returns:
        list of the matching loggers, in their original order.
    '''
    def lookup(container, key):
        # Returns (found, value); a container that cannot be queried with
        # `key` (wrong type, too short, ...) counts as "not found".
        try:
            if key not in container:
                return False, None
            return True, container[key]
        except (TypeError, IndexError, KeyError):
            return False, None

    def match_sequence(actual, expected):
        for i, sub_filter in enumerate(expected):
            try:
                item = actual[i]
            except (TypeError, IndexError, KeyError):
                return False
            if isinstance(sub_filter, dict):
                if not match(item, sub_filter):
                    return False
            elif item != sub_filter:
                return False
        return True

    def match(to_match, filters: dict):
        assert isinstance(filters, dict)
        for key, expected in filters.items():
            found, actual = lookup(to_match, key)
            if not found:
                return False
            if isinstance(expected, dict):
                if not match(actual, expected):
                    return False
            elif isinstance(expected, (list, tuple)):
                if not match_sequence(actual, expected):
                    return False
            elif actual != expected:
                return False
        return True

    filtered = [logger for logger in loggers if match(logger.info, filters)]
    if time is not None:
        start, end = time
        filtered = [logger for logger in filtered if start <= logger.time < end]
    return filtered

import glob
def load_logs(dir):
    '''
    Load every pickled log file matched by the glob pattern `dir`.

    Each loaded object is annotated with:
      * `fname`: the file's base name;
      * `time`: the file's modification time, truncated to whole seconds;
      * `info`: converted to a plain dict via `vars()` when it is not
        already one.

    Directories matched by the pattern are skipped, and files are loaded in
    sorted path order so the result does not depend on filesystem order.

    Args:
        dir: glob pattern (e.g. '../logdir/acc/*'), not a bare directory.

    Returns:
        list of the loaded log objects.
    '''
    def load_pkl(fname):
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file -- only point this at log directories you trust.
        with open(fname, 'rb') as fp:
            logs = pickle.load(fp)
        logs.fname = os.path.basename(fname)
        logs.time = datetime.fromtimestamp(os.path.getmtime(fname)).replace(microsecond=0)
        if not isinstance(logs.info, dict):
            logs.info = vars(logs.info)
        return logs
    files = sorted(f for f in glob.glob(dir) if os.path.isfile(f))
    return [load_pkl(f) for f in files]

def extract_acc_curve(logs: list[Recorder], key_fn=None):
    '''
    Group the per-run accuracy data of several logs by a configuration key.

    Args:
        logs: logs exposing `.info`, iteration over runs and
            `get_data(run, *labels)`.
        key_fn: callable mapping `log.info` to a hashable, sortable key.
            When omitted, each log is keyed by its position in `logs`.

    Returns:
        dict, ordered by key, mapping each key to
        `{'train/acc': [...], 'val/acc': [...], 'test/acc': [...]}` where
        every list holds one entry per run. Logs that share a key contribute
        their runs to the same lists (in input order) rather than replacing
        each other.
    '''
    acc_keys = ('train/acc', 'val/acc', 'test/acc')
    # Evaluate the key once per log; it is needed for sorting and grouping.
    if key_fn is None:
        keyed_logs = list(enumerate(logs))
    else:
        keyed_logs = [(key_fn(log.info), log) for log in logs]
    # Stable sort on the key only, so logs with equal keys keep input order.
    keyed_logs.sort(key=lambda pair: pair[0])

    proc_data = {}
    for block_info, log in keyed_logs:
        curves = proc_data.setdefault(block_info, {k: [] for k in acc_keys})
        for run in log:
            run_acc = log.get_data(run, *acc_keys)
            for k in acc_keys:
                curves[k].append(run_acc[k])
    return proc_data

In [None]:
# Load the accuracy-study logs and pick the ogbn-arxiv_r runs of the two
# 'hb' settings that are compared below.
logs = load_logs('../logdir/acc_study/*')
mts = select(logs, dict(dataset='ogbn-arxiv_r', hb='metis'))
fnl = select(logs, dict(dataset='ogbn-arxiv_r', hb='fennel'))

# Print timestamp, configuration and summary statistics of every selected log.
for group in (mts, fnl):
    for log in group:
        print(log.time, log.info, log.stdmean())

# fnllb = select(logs, {
#     'dataset': {'root': '/mnt/md0/hb_datasets/ogbn_arxiv'},
#     'model': {'arch': 'sage', 'epochs': 100},
#     'sample': {'train': [{'partition': 'fennel-lb', 'num_repeats': 2}]},
#     })
# for log in fnllb:
#     print(log.time, log.info['sample'], log.stdmean())

def key_fn(info):
    '''Group key: (dataset, num_blocks, num_blocks // block_ratio).'''
    dataset = info['dataset']
    num_blocks = info['num_blocks']
    return dataset, num_blocks, num_blocks // info['block_ratio']

# Accuracy curves per method, grouped by `key_fn`; read by `make_conv_figure`.
acc_series = {
    method: extract_acc_curve(method_logs, key_fn)
    for method, method_logs in (('HB-metis', mts), ('HB-lb', fnl))
}

In [None]:
logs = load_logs('../logdir/acc/*')

def _arxiv_sage_filter(train_cfg):
    '''Filter for the 100-epoch SAGE runs on ogbn_arxiv whose first training sampler matches `train_cfg`.'''
    return {
        'dataset': {'root': '/mnt/md0/hb_datasets/ogbn_arxiv'},
        'model': {'arch': 'sage', 'epochs': 100},
        'sample': {'train': [train_cfg]},
    }

# One selection per sampling scheme; only the first training-sampler entry differs.
ns = select(logs, _arxiv_sage_filter({'sampler': 'ns'}))
mts = select(logs, _arxiv_sage_filter({'partition': 'metis', 'num_repeats': 2}))
fnl = select(logs, _arxiv_sage_filter({'partition': 'fennel', 'num_repeats': 2}))
fnllb = select(logs, _arxiv_sage_filter({'partition': 'fennel-lb', 'num_repeats': 2}))
rnd = select(logs, _arxiv_sage_filter({'partition': 'rand', 'num_repeats': 2}))

# Print timestamp, sampling configuration and summary statistics per log.
for group in (ns, mts, fnl, fnllb, rnd):
    for log in group:
        print(log.time, log.info['sample'], log.stdmean())

def key_fn(info):
    '''Group key taken from the first training sampler: (P, batch_size, num_repeats).'''
    first_train_cfg = info['sample']['train'][0]
    return tuple(first_train_cfg[field] for field in ('P', 'batch_size', 'num_repeats'))
# Accuracy curves per method; read by `make_conv_figure`. The NS logs are all
# given the single constant key 'NS', the others are grouped by `key_fn`.
acc_series = {
    method: extract_acc_curve(method_logs, group_key)
    for method, method_logs, group_key in (
        ('NS', ns, lambda x: 'NS'),
        ('HB-metis', mts, key_fn),
        ('HB-lb', fnllb, key_fn),
        ('HB-no-lb', fnl, key_fn),
        ('HB-marius', rnd, key_fn),
    )
}

In [None]:

from scipy.interpolate import make_interp_spline
def make_conv_figure(axs, shfl_method, ylim=[0.5, 0.8], stderr=True, run_idx=2, max_epochs=25):
    '''
    Draw the accuracy curves of one shuffling method onto a pair of axes.

    Curves are read from the module-level `acc_series[shfl_method]`, which
    maps a block key to `{'val/acc': ..., 'test/acc': ...}`; each value is
    converted to a 2-D tensor whose dim 0 is averaged over and whose dim 1 is
    the x-axis. `axs[0]` receives 'val/acc' and `axs[1]` 'test/acc'.

    Args:
        axs: indexable holding (at least) two matplotlib axes.
        shfl_method: key into `acc_series`; also the legend label. Methods
            whose name contains 'HB' get the block key appended.
        ylim: y-limits applied to both axes.
        stderr: if True, plot the spline-smoothed mean over dim 0 with a
            +/- one standard deviation band; if False, plot one raw curve.
        run_idx: row of the curve tensor drawn when `stderr` is False
            (raises IndexError if there are not enough rows).
        max_epochs: number of leading points of that curve to draw.
    '''
    # for i, acc_type in enumerate(('train/acc', 'val/acc')):
    for i, acc_type in enumerate(('val/acc', 'test/acc')):
        ax = axs[i]
        # Set the limits on the axes we were given; pyplot's implicit
        # "current axes" need not be one of them.
        ax.set_ylim(ylim)
        ax.set_title(acc_type)
        ax.margins(x=0)
        acc_blocks = acc_series[shfl_method]
        for block_info in acc_blocks:
            acc_curves = torch.tensor(acc_blocks[block_info][acc_type])
            label = shfl_method
            if 'HB' in shfl_method:
                label += f' {block_info}'.replace(', ', '/')
            if not stderr:
                ax.plot(acc_curves[run_idx][:max_epochs], marker=',', label=label, ls='-', lw=1)
                continue
            # Smoothing is only needed for the mean/std view, so it is done
            # here rather than unconditionally (the spline needs enough points).
            xs = range(acc_curves.size(1))
            mean = acc_curves.mean(dim=0)
            std = acc_curves.std(dim=0)
            interp_xs = torch.arange(0, mean.size(0)-1, mean.size(0)/1000)
            lower = make_interp_spline(xs, mean-std)(interp_xs)
            upper = make_interp_spline(xs, mean+std)(interp_xs)
            mean = make_interp_spline(xs, mean)(interp_xs)
            # Methods containing 'GS' are highlighted with a thick red line.
            line_style = {'color': 'red', 'ls': '-', 'lw': 2} if 'GS' in shfl_method else {}
            line, = ax.plot(interp_xs, mean, marker=',', label=label, **line_style)
            # Reuse the line's colour so the band always matches its curve.
            ax.fill_between(interp_xs, lower, upper, alpha=0.1, interpolate=True,
                            color=line.get_color())

    # Only the last panel drawn ('test/acc') gets a legend.
    ax.legend(fontsize=12)

# fig, axs = plt.subplots(1, 3, figsize=(18, 18), sharey='row', dpi=200)
# fig.tight_layout()
# One row with two panels sharing the y-axis; squeeze=False keeps `axs` 2-D
# so the row/column indexing below works.
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), sharey='row', squeeze=False)
# fig.suptitle(f"Model Convergence under Different Shuffling Schemes", fontsize=20)
# fig.subplots_adjust(top=0.92)
# for ax, title in zip(axs[0], ('val/acc', 'test/acc')):
#     ax.set_title(title, fontsize=20)

# Label only the outer edges of the grid: y on the first column, x on the last row.
for ax in axs[:, 0]:
    ax.set_ylabel('accuracy', fontsize=20)
for ax in axs[-1, :]:
    ax.set_xlabel('epoch', fontsize=20)

# Overlay every method in `acc_series` on the same pair of panels.
for k in acc_series.keys():
    make_conv_figure(axs[0], shfl_method=k, stderr=False)
plt.show()
# fig.savefig(
#     'label_accuracy.pdf', bbox_inches = "tight"
# )

In [None]:
# Hard-coded discrepancy value per shuffling method and (a, b) block key; these
# are the x-coordinates used by `make_acc_figure`. The keys are presumably
# (num_blocks, num_blocks // block_ratio) as built by `extract_acc` -- confirm.
emd_dict = {
    'random hier-batching': dict.fromkeys([(64, 4), (64, 8), (64, 16)], 2.2e-3),
    'metis hier-batching': dict(zip(
        [(64, 4), (64, 8), (64, 16)],
        [9.2e-3, 6.9e-3, 4.4e-3],
    )),
    # disabled metis entries: (1024,64): 4.2e-3, (1024,128): 3.3e-3, (1024,256): 2.8e-3
    'fennel-LB hier-batching': dict(zip(
        [(64, 4), (64, 8), (64, 16), (1024, 16)],
        [2.2e-3, 2.1e-3, 2.1e-3, 2.1e-3],
    )),
    'global shuffling': {(64, 8): 2.1e-3},
    'shuffling once': {(64, 8): 2.1e-3},
}

def extract_acc(logs: list[Recorder]):
    '''
    Summarize each log's accuracy, keyed by its block configuration.

    Args:
        logs: logs exposing `.info` with `num_blocks` and `block_ratio`,
            either as dict keys (the form `load_logs` produces) or as
            attributes.

    Returns:
        dict mapping `(num_blocks, num_blocks // block_ratio)` to the result
        of `stdmean_acc(log)`. If several logs share a key, the last one wins.
    '''
    def field(info, name):
        # `load_logs` converts `info` to a plain dict, while an unconverted
        # namespace-style `info` only supports attribute access.
        return info[name] if isinstance(info, dict) else getattr(info, name)

    acc_dict = {}
    for log in logs:
        num_blocks = field(log.info, 'num_blocks')
        block_ratio = field(log.info, 'block_ratio')
        acc_dict[(num_blocks, num_blocks // block_ratio)] = stdmean_acc(log)
    return acc_dict

# Accuracy summary per shuffling method, as consumed by `make_acc_figure`.
# NOTE(review): rd_logs, hb_logs, fl_logs and gs_logs are not assigned in any
# of the visible notebook cells -- confirm they are defined before this cell
# runs, otherwise it fails on a fresh kernel.
acc_method_dict = {
    method: extract_acc(method_logs)
    for method, method_logs in (
        ('random hier-batching', rd_logs),
        ('metis hier-batching', hb_logs),
        ('fennel-LB hier-batching', fl_logs),
        ('global shuffling', gs_logs),
        # ('shuffling once', ss_logs),
    )
}

# Marker/colour per shuffling method; unpacked into `ax.errorbar` by `make_acc_figure`.
style_dict = {
    'random hier-batching': dict(marker='o', color='green'),
    'metis hier-batching': dict(marker='^', color='blue'),
    'fennel-LB hier-batching': dict(marker='v', color='brown'),
    'global shuffling': dict(marker='D', color='red'),
    'shuffling once': dict(marker='h', color='orange'),
}

def make_acc_figure(ax, label, ylim=[68, 74], set_label=False):
    '''
    Plot accuracy (mean with std error bars) against label discrepancy.

    Reads the module-level `acc_method_dict` (method -> block key ->
    label -> (mean, std)), `emd_dict` (method -> block key -> x value) and
    `style_dict` (method -> errorbar style kwargs).

    Args:
        ax: the matplotlib axes to draw on.
        label: which accuracy entry to plot (e.g. 'val/acc'); also the title.
        ylim: y-limits applied to `ax`.
        set_label: if True, the first point plotted for each method carries
            the method name as its legend label and a legend is added.
    '''
    # Apply the limits to `ax` itself; pyplot's implicit "current axes" may
    # be a different panel of the figure.
    ax.set_ylim(ylim)
    for shfl_method, acc_dict in acc_method_dict.items():
        needs_label = set_label
        for block_info, stats in acc_dict.items():
            # Skip keys whose two entries are equal ...
            if block_info[0] == block_info[1]:
                continue
            # ... and configurations without a recorded discrepancy value.
            if block_info not in emd_dict[shfl_method]:
                continue
            mean, std = stats[label]
            emd = emd_dict[shfl_method][block_info]
            # Label one point per method so the legend lists each method once.
            legend_kwargs = {'label': shfl_method} if needs_label else {}
            needs_label = False
            ax.errorbar([emd], [mean], yerr=[std], capsize=3,
                        **style_dict[shfl_method], **legend_kwargs)
    if set_label:
        ax.legend(fontsize=12)
    ax.set_title(label, fontsize=16)

# Accuracy-vs-discrepancy figure: two side-by-side panels (validation and test
# accuracy) that share the x-axis.
fig, axs = plt.subplots(1, 2, figsize=(8, 4), sharex=True, dpi=160)
# Lay out the panels first, then add the suptitle and reserve headroom for it.
fig.tight_layout()
fig.suptitle(f"Model Accuracy / Mini-batch Label Discrepancy", fontsize=16)
fig.subplots_adjust(top=0.85)
# Only the first call asks for labelled series and a legend (set_label=True).
make_acc_figure(axs[0], 'val/acc', ylim=[70, 73], set_label=True)
make_acc_figure(axs[1], 'test/acc', ylim=[69, 72])
axs[0].set_ylabel('accuracy', fontsize=16)
# A single caption centred under both panels, used instead of per-axes x labels.
fig.text(0.5, -0.04, 'mean discrepancy of label distribution', ha='center', fontsize=16)
# NOTE(review): bbox_inches='tight' is presumably what keeps the caption placed
# at y = -0.04 inside the saved PDF -- confirm.
fig.savefig("label_discrepancy.pdf", bbox_inches='tight')
plt.show()
