In [None]:
import sys, os, pickle
sys.path.append("..")
import torch
from trainer.recorder import Recorder
import numpy as np
import matplotlib.pyplot as plt
from datetime import MAXYEAR, MINYEAR
from datetime import datetime, time, timedelta

def today():
    """Return the current local date as a naive ``datetime`` at 00:00:00."""
    return datetime.combine(datetime.now().date(), time.min)
def days_ago(days):
    """Return midnight of the day that lies ``days`` days before today."""
    midnight = today()
    offset = timedelta(days=days)
    return midnight - offset
# Extreme timestamps, usable as open-ended bounds of a (start, end) window.
maxtime = datetime(year=MAXYEAR, month=1, day=1)
mintime = datetime(year=MINYEAR, month=1, day=1)

def stdmean(logger: Recorder, *labels, summarize=None):
    '''
    Aggregate logged data over every run of `logger`.

    For each run, the data for `labels` is fetched with `logger.get_data`
    and passed through `summarize` (identity when omitted), which must
    return a dict. The values are grouped by key across runs and the result
    maps each key to `[mean, std]` of its group.
    '''
    def identity(run_data):
        return run_data

    reduce_run = identity if summarize is None else summarize

    # key -> list with one summarized value per run
    grouped = {}
    for run in logger:
        summary = reduce_run(logger.get_data(run, *labels))
        for key in summary:
            grouped.setdefault(key, []).append(summary[key])

    result = {}
    for key, values in grouped.items():
        stacked = torch.tensor(values)
        result[key] = [stacked.mean().item(), stacked.std().item()]
    return result

def stdmean_acc(logger: Recorder):
    """
    Mean/std across runs of the best validation accuracy and of the test
    accuracy recorded at that same position, both scaled to percent.
    """
    def best_val_and_test(series):
        val_pct = 100 * torch.tensor(series['val/acc'])
        best_step = val_pct.argmax()
        return {
            'val/acc': val_pct.max().item(),
            'test/acc': 100 * series['test/acc'][best_step],
        }

    return stdmean(logger, 'val/acc', 'test/acc', summarize=best_val_and_test)

import re

def select(loggers, filters: dict, time=None):
    """
    Return the loggers whose `info` matches `filters`, optionally restricted
    to a time window.

    `filters` is compared against `logger.info` recursively:
      * dict   -> every key must exist in the target mapping and match;
      * list / tuple -> matched position by position against the leading
        elements of the target sequence (the target may be longer);
      * str vs str -> the filter is a regular expression applied with
        `re.match` (anchored at the start only; add `$` for an exact match);
      * anything else -> plain equality.

    A target that has the wrong shape for the filter (not a mapping, not a
    sequence, or a sequence shorter than the filter list) is a non-match
    rather than an error, so one oddly shaped log cannot abort the query.

    `time`, when given, is a `(start, end)` pair; only loggers with
    `start <= logger.time < end` are kept.
    """
    from collections.abc import Mapping, Sequence

    def matches_value(actual, wanted):
        if isinstance(wanted, dict):
            if not isinstance(actual, Mapping):
                return False
            return all(
                key in actual and matches_value(actual[key], sub)
                for key, sub in wanted.items()
            )
        if isinstance(wanted, (list, tuple)):
            # str is a Sequence too, but never a valid target for a list filter
            if isinstance(actual, str) or not isinstance(actual, Sequence):
                return False
            if len(actual) < len(wanted):
                return False
            return all(matches_value(a, w) for a, w in zip(actual, wanted))
        if isinstance(wanted, str) and isinstance(actual, str):
            return re.match(wanted, actual) is not None
        return actual == wanted

    def matches_info(info):
        assert isinstance(filters, dict)
        return matches_value(info, filters)

    filtered = [logger for logger in loggers if matches_info(logger.info)]
    if time is not None:
        start, end = time
        filtered = [logger for logger in filtered if start <= logger.time < end]
    return filtered

import glob
def load_logs(pattern):
    """
    Load every pickled log whose path matches the glob `pattern`.

    Files are visited in order of modification time (oldest first). Each
    loaded object is annotated with:
      * `fname` - the file's base name,
      * `time`  - the file's modification time, truncated to whole seconds,
    and its `info` attribute is converted with `vars()` when it is not
    already a dict.

    Anything that cannot be opened or unpickled (including directories
    matched by the pattern) is reported and skipped; it never aborts the
    whole load. Returns the list of successfully loaded logs.
    """
    def load_pkl(fname):
        try:
            with open(fname, 'rb') as fp:
                # NOTE: unpickling executes arbitrary code; only point this
                # at log directories you trust.
                logs = pickle.load(fp)
            logs.fname = os.path.basename(fname)
            logs.time = datetime.fromtimestamp(os.path.getmtime(fname)).replace(microsecond=0)
            if not isinstance(logs.info, dict):
                logs.info = vars(logs.info)
            return logs
        except Exception as err:
            # best-effort loader: report the reason, but let Ctrl-C through
            print("ignoring log:", fname, f"({type(err).__name__}: {err})")
            return None

    files = sorted(glob.glob(pattern), key=os.path.getmtime)
    loaded = (load_pkl(f) for f in files)
    return [trace for trace in loaded if trace is not None]

def extract_acc_curve(log: Recorder):
    """
    Collect the train/val/test accuracy series of every run in `log`.

    Returns a dict with one numpy array per accuracy key (one entry per run,
    in iteration order of `log`) plus `'runs'`, set to `log.num_runs`.
    """
    acc_keys = ('train/acc', 'val/acc', 'test/acc')
    per_run = [log.get_data(run, *acc_keys) for run in log]
    curves = {
        key: np.array([run_acc[key] for run_acc in per_run])
        for key in acc_keys
    }
    curves['runs'] = log.num_runs
    return curves

In [None]:
from scipy.interpolate import pchip_interpolate as interpolate

def plot_on(ax, xss, yss, epochs=None):
    """
    Draw one smoothed mean +/- std convergence curve per entry of `yss` on `ax`.

    ax:     object providing `plot` and `fill_between` (a matplotlib Axes).
    xss:    dict mapping each key of `yss` to its 1-D x positions, or None
            to use 1..n for every curve.
    yss:    dict mapping a legend label to a 2-D array of accuracies; the
            mean and std are taken over axis 0 and axis 1 runs along x.
    epochs: when given, only the first `epochs + 1` points of each curve
            (and of its x positions) are drawn.
    """
    def make_conv_figure(ax, xs, acc_curves, **kwargs):
        '''
        xs: x-axis values, epochs or training time
        acc_curves: y-axis values, accuracy (of multiple runs)
        '''
        acc_curves = np.asarray(acc_curves)
        if epochs is not None:
            acc_curves = acc_curves[:, :(epochs+1)]
        num_points = acc_curves.shape[1]
        if xs is None:
            xs = np.arange(1, num_points+1, 1, dtype=int)
        else:
            # keep x and y the same length after the `epochs` truncation
            xs = np.asarray(xs)[:num_points]
        mean = acc_curves.mean(axis=0)
        std = acc_curves.std(axis=0)
        if np.isnan(std).any():
            std[:] = 0
        # dense grid for the monotone (PCHIP) interpolation; linspace includes
        # the last x so the curve reaches the final data point
        interp_xs = np.linspace(xs.min(), xs.max(), 1000*num_points)
        lower = interpolate(xs, mean-std, interp_xs)
        upper = interpolate(xs, mean+std, interp_xs)
        mean = interpolate(xs, mean, interp_xs)
        ax.plot(interp_xs, mean, marker=',', ls='-', lw=2, **kwargs)
        # the shaded band must not create a second legend entry
        kwargs.pop('label', None)
        ax.fill_between(interp_xs, lower, upper, alpha=0.1, **kwargs)

    for k in yss:
        xs = None if xss is None else xss[k]
        make_conv_figure(ax, xs=xs, acc_curves=yss[k], label=k)

In [None]:
# Load every pickled log under ../logdir/acc (load_logs orders them by
# modification time, oldest first).
logs = load_logs('../logdir/acc/*')

Retrieve the convergence curve (in wallclock time) of training GraphSAGE on ogbn-papers100M

In [None]:
# For each sampling configuration: list every matching log, then keep only
# the last match (the most recently modified file, given load_logs' ordering).
# NOTE: each `x = x[-1]` raises IndexError when nothing matched.

# Baseline: neighbour sampling ('ns'), 3-layer SAGE trained for 30 epochs.
ns = select(logs,
    {'dataset': {'root': '.*/ogbn_papers100M'},
     'model': {'arch': '^sage$', 'num_layers': 3, 'epochs': 30},
     'sample': {'train': [{'sampler': 'ns', 'batch_size': 1000}]},
    },
    # time=(days_ago(30), maxtime)
)
for log in ns:
    print(log.md5)
    print(log.info['sample']['train'])
    print(log.stdmean(), "\n")
ns = ns[-1]

# Random partitioning.
rnd = select(logs,
    {'dataset': {'root': '.*/ogbn_papers100M'},
     'model': {'arch': '^sage$'},
     'sample': {'train': [{
         'P': 1024, 'batch_size': 128, 'partition': 'rand'
        }]},
    },
    # time=(days_ago(30), maxtime)
)
for log in rnd:
    print(log.info['sample']['train'])
    print(log.stdmean(), "\n")
rnd = rnd[-1]

# 'fennel-wlb' partitioning, num_repeats=1, with pivots.
hbx1 = select(logs,
    {'dataset': {'root': '.*/ogbn_papers100M'},
     'model': {'arch': '^sage$'},
     'sample': {'train': [{
         'P': 1024, 'batch_size': 128, 'partition': 'fennel-wlb',
          'num_repeats': 1, 'pivots': True,
        }]},
    },
    # time=(days_ago(30), maxtime)
)
for log in hbx1:
    print(log.info['sample']['train'])
    print(log.stdmean(), "\n")
hbx1 = hbx1[-1]

# 'fennel-wlb' partitioning, num_repeats=2, with pivots.
hbx2 = select(logs,
    {'dataset': {'root': '.*/ogbn_papers100M'},
     'model': {'arch': '^sage$'},
     'sample': {'train': [{
         'P': 1024, 'batch_size': 128, 'partition': 'fennel-wlb',
         'num_repeats': 2, 'pivots': True
        }]},
    },
    # time=(days_ago(30), maxtime)
)
for log in hbx2:
    print(log.info['sample']['train'])
    print(log.stdmean(), "\n")
hbx2 = hbx2[-1]

# 'fennel-wlb' partitioning, num_repeats=2, without pivots.
hbp0 = select(logs,
    {'dataset': {'root': '.*/ogbn_papers100M'},
     'model': {'arch': '^sage$'},
     'sample': {'train': [{
         'P': 1024, 'batch_size': 128, 'partition': 'fennel-wlb',
         'num_repeats': 2, 'pivots': False
        }]},
    },
    # time=(days_ago(30), maxtime)
)
for log in hbp0:
    print(log.info['sample']['train'])
    print(log.stdmean(), "\n")
hbp0 = hbp0[-1]

# Per-run accuracy curves keyed by legend label.
# NOTE(review): 'r=1' is fed by the num_repeats=2 log and 'r=0' by the
# num_repeats=1 log - presumably r counts extra repeats; confirm.
acc_series = {
    'NS-Ext': extract_acc_curve(ns),
    'HB-rand': extract_acc_curve(rnd),
    'HB-ours(r=1)': extract_acc_curve(hbx2),
    'HB-ours(r=0)': extract_acc_curve(hbx1),
    # 'HB-ours(p=0)': extract_acc_curve(hbp0),
}

In [None]:
# from google sheets, local machine with 32GB
epochs = 30
# Hand-entered cost of one epoch per method.
# NOTE(review): units are not recorded here - presumably seconds; confirm.
epoch_time = {
    'NS-Ext': 181,
    # 85 epochs
    'HB-rand': 48,
    # 50 epochs
    'HB-ours(r=1)': 87,
    # 75 epochs
    'HB-ours(r=0)': 59,
    # 75 epochs
    # 'HB-ours(p=0)': 55,
}
# start from zeros
# x positions: elapsed time after 0, 1, ..., `epochs` epochs.
train_time = {
    k : np.arange(0, epochs+1, step=1) * epoch_time[k] for k in epoch_time
}
# Prepend a constant 0.2 column to every run so each curve has a point at
# time zero (0.2 is a hard-coded starting accuracy, not a logged value).
val_acc = {
    k : np.concatenate([
        np.ones((acc_series[k]['runs'], 1))*.2,
        acc_series[k]['val/acc']
    ], axis=1) for k in epoch_time
}

fig, axs = plt.subplots(figsize=(6, 5), squeeze=False)
fig.tight_layout()
# squeeze=False keeps `axs` 2-D even for a single subplot
ax = axs[0][0]
# fig.suptitle(f"Model Convergence Rate in Epochs", fontsize=16)
# fig.subplots_adjust(top=0.88)
# for ax, title in zip(axs[0], ('val/acc', 'test/acc')):
#     ax.set_title(title, fontsize=20)
ax.set_title('Model Convergence in Wallclock Time', fontsize=16)
ax.set_ylabel('Validation Acc', fontsize=16)
ax.set_xlabel('Training Time', fontsize=16)
ax.set_xlim([-200,4000])
ax.set_ylim([0.6,0.7])

plot_on(ax, train_time, val_acc, epochs=epochs)
plt.legend()
plt.show()

Retrieve the convergence curve (in wallclock time) of training GraphSAGE on mag240m-c

In [None]:

# Same procedure as for ogbn-papers100M, on the mag240m_c dataset: list the
# matching logs per configuration and keep the last (most recently modified).
# NOTE: each `x = x[-1]` raises IndexError when nothing matched.

# Baseline: neighbour sampling ('ns'), 3-layer SAGE trained for 30 epochs.
ns = select(logs,
    {'dataset': {'root': '.*/mag240m_c'},
     'model': {'arch': '^sage$', 'num_layers': 3, 'epochs': 30},
     'sample': {'train': [{'sampler': 'ns', 'batch_size': 1000}]},
    },
    # time=(days_ago(30), maxtime)
)
for log in ns:
    print(log.info['sample']['train'])
    print(log.stdmean(), "\n")
ns = ns[-1]

# Random partitioning.
rnd = select(logs,
    {'dataset': {'root': '.*/mag240m_c'},
     'model': {'arch': '^sage$'},
     'sample': {'train': [{
         'P': 1024, 'batch_size': 64, 'partition': 'rand'
        }]},
    },
    # time=(days_ago(30), maxtime)
)
for log in rnd:
    print(log.info['sample']['train'])
    print(log.stdmean(), "\n")
rnd = rnd[-1]

# 'fennel-wlb' partitioning, num_repeats=1, with pivots.
hbx1 = select(logs,
    {'dataset': {'root': '.*/mag240m_c'},
     'model': {'arch': '^sage$'},
     'sample': {'train': [{
         'P': 1024, 'batch_size': 64, 'partition': 'fennel-wlb',
          'num_repeats': 1, 'pivots': True,
        }]},
    },
    # time=(days_ago(30), maxtime)
)
for log in hbx1:
    print(log.info['sample']['train'])
    print(log.stdmean(), "\n")
hbx1 = hbx1[-1]

# 'fennel-wlb' partitioning, num_repeats=2, with pivots.
hbx2 = select(logs,
    {'dataset': {'root': '.*/mag240m_c'},
     'model': {'arch': '^sage$'},
     'sample': {'train': [{
         'P': 1024, 'batch_size': 64, 'partition': 'fennel-wlb',
          'num_repeats': 2, 'pivots': True,
        }]},
    },
    # time=(days_ago(30), maxtime)
)
for log in hbx2:
    print(log.info['sample']['train'])
    print(log.stdmean(), "\n")
hbx2 = hbx2[-1]

# 'fennel-wlb' partitioning, num_repeats=2, without pivots.
hbp0 = select(logs,
    {'dataset': {'root': '.*/mag240m_c'},
     'model': {'arch': '^sage$'},
     'sample': {'train': [{
         'P': 1024, 'batch_size': 64, 'partition': 'fennel-wlb',
         'num_repeats': 2, 'pivots': False,
        }]},
    },
)
for log in hbp0:
    print(log.info['sample']['train'])
    print(log.stdmean(), "\n")
hbp0 = hbp0[-1]

# Per-run accuracy curves keyed by legend label.
# NOTE(review): 'r=1' is fed by the num_repeats=2 log and 'r=0' by the
# num_repeats=1 log - presumably r counts extra repeats; confirm.
acc_series = {
    'NS-Ext': extract_acc_curve(ns),
    'HB-rand': extract_acc_curve(rnd),
    'HB-ours(r=1)': extract_acc_curve(hbx2),
    'HB-ours(r=0)': extract_acc_curve(hbx1),
    'HB-ours(p=0)': extract_acc_curve(hbp0),
}

In [None]:
# from google sheets, local machine with 32GB
epochs = 30
# Hand-entered cost of one epoch per method.
# NOTE(review): units are not recorded here - presumably seconds; confirm.
# 'HB-ours(p=0)' has no entry, so it is left out of the plot below.
epoch_time = {
    'NS-Ext': 610,
    'HB-rand': 82,
    'HB-ours(r=1)': 148,
    'HB-ours(r=0)': 100,
}
# start from zeros
# x positions: elapsed time after 0, 1, ..., `epochs` epochs.
train_time = {
    k : np.arange(0, epochs+1, step=1) * epoch_time[k] for k in epoch_time
}
# Prepend a constant 0.2 column to every run so each curve has a point at
# time zero (0.2 is a hard-coded starting accuracy, not a logged value).
val_acc = {
    k : np.concatenate([
        np.ones((acc_series[k]['runs'], 1))*.2,
        acc_series[k]['val/acc']
    ], axis=1) for k in epoch_time
}

fig, axs = plt.subplots(figsize=(6, 5), squeeze=False)
fig.tight_layout()
# squeeze=False keeps `axs` 2-D even for a single subplot
ax = axs[0][0]
# fig.suptitle(f"Model Convergence Rate in Epochs", fontsize=16)
# fig.subplots_adjust(top=0.88)
# for ax, title in zip(axs[0], ('val/acc', 'test/acc')):
#     ax.set_title(title, fontsize=20)
ax.set_title('Model Convergence in Wallclock Time', fontsize=16)
ax.set_ylabel('Validation Acc', fontsize=16)
ax.set_xlabel('Training Time', fontsize=16)
ax.set_xlim([-200,8000])
ax.set_ylim([0.56,0.66])

plot_on(ax, train_time, val_acc, epochs=epochs)
plt.legend()
plt.show()

Retrieve the convergence curve (in wallclock time) of training GraphSAGE on mag240m (full dataset).

In [None]:
# Neighbour-sampling runs on the full mag240m dataset; the trailing `$` in
# the root pattern excludes roots such as mag240m_c.
ns = select(logs,
    {'dataset': {'root': '.*/mag240m$'},
     'model': {'arch': '^sage$', 'num_layers': 3, 'epochs': 30},
     'sample': {'train': [{'sampler': 'ns', 'batch_size': 1000}]},
    },
)
for log in ns:
    print(log.info['sample']['train'])
    print(log.stdmean(), "\n")
# keep the last match; raises IndexError when nothing matched
ns = ns[-1]
# NOTE(review): index 2 of `ns.log` is hard-coded - presumably the third
# run; confirm against Recorder.
print(ns.log[2]['train/loss'])
print(ns.log[2]['val/acc'])

Dataset: ogbn-arxiv with 20% nodes selected randomly as the training set, the remaining split in half as validation and test sets.

In [None]:
# ogbn_arxiv_r with 20% random split for training nodes
logs = load_logs('../logdir/acc/*')
# Shared filter values for every partition-based query below.
has_pivots = True
repeats = 1

# One query per sampling / partitioning scheme; each prints the matching
# logs with their modification time, sampling config and summary statistics.
ns = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/ogbn_arxiv_r'},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'sampler': 'ns'}]},
    })
for log in ns:
    print(log.time,  log.info['sample'], log.stdmean())

mts_tb = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/ogbn_arxiv_r'},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'metis-tb', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
print()
for log in mts_tb:
    print(log.time,  log.info['sample'], log.stdmean())

mts_wtb = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/ogbn_arxiv_r'},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'metis-wtb', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
print()
for log in mts_wtb:
    print(log.time,  log.info['sample'], log.stdmean())

# NOTE: string filters are regexes applied with re.match (prefix match), so
# 'metis' also matches 'metis-w', 'metis-tb', ... and 'fennel' below also
# matches 'fennel-lb' / 'fennel-wlb'.
mts_c = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/ogbn_arxiv_r'},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'metis', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
print()
for log in mts_c:
    print(log.time,  log.info['sample'], log.stdmean())

print()
mts_w = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/ogbn_arxiv_r'},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'metis-w', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
for log in mts_w:
    print(log.time,  log.info['sample'], log.stdmean())

# print()
# fnlvnl = select(logs, {
#     'dataset': {'root': '/mnt/md0/hb_datasets/ogbn_arxiv_r'},
#     'model': {'arch': 'sage', 'epochs': 100},
#     'sample': {'train': [{'partition': 'fennel-vnl', 'pivots': has_pivots, 'num_repeats': repeats}]},
#     })
# for log in fnlvnl:
#     print(log.time, log.info['sample'], log.stdmean())

print()
fnl = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/ogbn_arxiv_r'},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'fennel', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
for log in fnl:
    print(log.time, log.info['sample'], log.stdmean())

fnllb = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/ogbn_arxiv_r'},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'fennel-lb', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
print()
for log in fnllb:
    print(log.time, log.info['sample'], log.stdmean())

fnlwlb = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/ogbn_arxiv_r'},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'fennel-wlb', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
print()
for log in fnlwlb:
    print(log.time, log.info['sample'], log.stdmean())

rnd = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/ogbn_arxiv_r'},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'rand', 'num_repeats': 1}]},
    })
print()
for log in rnd:
    print(log.time, log.info['sample'], log.stdmean())

# Grouping key built from the first training-sampler entry of a log's info.
def key_fn(info):
    cluster_info = info['sample']['train'][0]
    return cluster_info['P'], cluster_info['batch_size'], cluster_info['num_repeats']
# NOTE(review): these calls pass a list of logs plus a key function, but the
# extract_acc_curve defined in this notebook takes a single log and no key
# function, so this raises TypeError as written - presumably the cell
# targets an earlier version of the helper; confirm before re-running.
acc_series = {
    'NS (in-mem)': extract_acc_curve(ns, lambda x: 'NS'),
    # 'HB-min-w': extract_acc_curve(mts_w, key_fn),
    'HB-min-c': extract_acc_curve(mts_c, key_fn),
    # 'HB-min-tb': extract_acc_curve(mts_tb, key_fn),
    # 'HB-nlb': extract_acc_curve(fnl, key_fn),
    # 'HB-lb': extract_acc_curve(fnllb, key_fn),
    'HB-ours': extract_acc_curve(fnlwlb, key_fn),
    'Marius': extract_acc_curve(rnd, key_fn),
}
# `plot` is defined in a later cell; that cell must have been run first.
plot(acc_series, ylim=[0.51, 0.75])


Dataset: ogbn-arxiv with standard splits from OGB

In [None]:
logs = load_logs('../logdir/acc/*')
# Shared filter values for every partition-based query below.
has_pivots = True
repeats = 1
# `select` applies string filters with re.match, which anchors only at the
# start. Without the trailing `/?$` this pattern would also match the
# random-split dataset root `.../ogbn_arxiv_r` and mix its logs into the
# standard-split comparison.
arxiv_root = '/mnt/md0/hb_datasets/ogbn_arxiv/?$'

# One query per sampling / partitioning scheme; each prints the matching
# logs with their modification time, sampling config and summary statistics.
ns = select(logs, {
    'dataset': {'root': arxiv_root},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'sampler': 'ns'}]},
    })
for log in ns:
    print(log.time,  log.info['sample'], log.stdmean())

mts_w = select(logs, {
    'dataset': {'root': arxiv_root},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'metis-w', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
for log in mts_w:
    print(log.time,  log.info['sample'], log.stdmean())

mts = select(logs, {
    'dataset': {'root': arxiv_root},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'metis', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
for log in mts:
    print(log.time,  log.info['sample'], log.stdmean())

fnl = select(logs, {
    'dataset': {'root': arxiv_root},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'fennel', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
print()
for log in fnl:
    print(log.time, log.info['sample'], log.stdmean())

fnllb = select(logs, {
    'dataset': {'root': arxiv_root},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'fennel-lb', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
print()
for log in fnllb:
    print(log.time, log.info['sample'], log.stdmean())

fnlwlb = select(logs, {
    'dataset': {'root': arxiv_root},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'fennel-wlb', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
print()
# report the fennel-wlb logs just selected (this loop used to re-print fnllb)
for log in fnlwlb:
    print(log.time, log.info['sample'], log.stdmean())

rnd = select(logs, {
    'dataset': {'root': arxiv_root},
    'model': {'arch': 'sage', 'epochs': 100},
    'sample': {'train': [{'partition': 'rand', 'num_repeats': 1}]},
    })
print()
for log in rnd:
    print(log.time, log.info['sample'], log.stdmean())

# Grouping key built from the first training-sampler entry of a log's info.
def key_fn(info):
    cluster_info = info['sample']['train'][0]
    return cluster_info['P'], cluster_info['batch_size'], cluster_info['num_repeats']
# NOTE(review): these calls pass a list of logs plus a key function, but the
# extract_acc_curve defined in this notebook takes a single log and no key
# function, so this raises TypeError as written - presumably the cell
# targets an earlier version of the helper; confirm before re-running.
acc_series = {
    'NS (in-mem)': extract_acc_curve(ns, lambda x: 'NS'),
    # 'HB-min-w': extract_acc_curve(mts_w, key_fn),
    'HB-min-c': extract_acc_curve(mts, key_fn),
    'HB-lb': extract_acc_curve(fnllb, key_fn),
    'HB-ours': extract_acc_curve(fnlwlb, key_fn),
    'Marius': extract_acc_curve(rnd, key_fn),
}

# `plot` is defined in a later cell; that cell must have been run first.
plot(acc_series, ylim=[0.4, 0.7], epochs=15)


Dataset: ogbn-papers100M, SAGE

In [None]:
def plot(acc_series, ylim=None, stderr=True, save_to=None, epochs=None):
    """
    Plot train and validation accuracy against epoch for several methods.

    acc_series: dict mapping a legend label to a dict of blocks; every block
                maps 'train/acc' and 'val/acc' to a 2-D collection of
                curves (mean/std are taken over dim 0).
    ylim:       y-axis limits, forwarded to `plt.ylim`.
    stderr:     accepted for compatibility; currently unused.
    save_to:    when given, the figure is also written to this path.
    epochs:     when given, only the first `epochs` points of each curve
                are drawn.

    Every curve is prefixed with a (0, 0) point so that it starts at the
    origin.
    """
    def make_conv_figure(axs, shfl_method, ylim=(0.58, 0.74), stderr=True, intervals=1):
        plt.ylim(ylim)
        titles = ['train', 'validation']
        for i, acc_type in enumerate(('train/acc', 'val/acc')):
            ax = axs[i]
            ax.set_title(titles[i])
            ax.margins(x=0)
            # grid() without arguments toggles visibility; this function runs
            # once per series, so the grid must be switched on explicitly
            ax.grid(True)
            acc_blocks = acc_series[shfl_method]
            for block_info in acc_blocks:
                acc_curves = torch.tensor(acc_blocks[block_info][acc_type])
                if epochs is not None:
                    acc_curves = acc_curves[:, :epochs]
                xs = range(0, acc_curves.size(1) * intervals + 1, intervals)
                mean = acc_curves.mean(dim=0)
                std = acc_curves.std(dim=0)
                # prepend the origin so xs and the curves have equal length
                mean = torch.tensor([0] + mean.tolist())
                std = torch.tensor([0] + std.tolist())
                # std is NaN when there is a single run; draw no band then
                if std.isnan().any().item():
                    std[:] = 0
                label = shfl_method
                ax.plot(xs, mean, marker=',', label=label)
                ax.fill_between(xs, mean-std, mean+std, alpha=0.1, interpolate=True)

        # `ax` is the last (validation) panel here; it carries the legend
        ax.legend(fontsize=12)

    fig, axs = plt.subplots(1, 2, figsize=(10, 5), sharey='row', squeeze=False)
    fig.tight_layout()
    fig.suptitle("Model Convergence Rate in Epochs", fontsize=16)
    fig.subplots_adjust(top=0.88)
    for ax in axs[:,0]:
        ax.set_ylabel('accuracy', fontsize=16)
    for ax in axs[-1]:
        ax.set_xlabel('epoch', fontsize=16)

    for k in acc_series:
        make_conv_figure(axs[0], shfl_method=k, stderr=stderr, ylim=ylim)
    # save before show: some backends release the figure once it is shown
    if save_to is not None:
        fig.savefig(
            save_to, bbox_inches = "tight"
        )
    plt.show()

logs = load_logs('../logdir/acc/*')
has_pivots = True
repeats = 1

# All SAGE runs on ogbn-papers100M (the sampler filter is disabled, so this
# is not restricted to neighbour sampling despite the variable name).
ns = select(logs, {
    'dataset': {'name': 'ogbn-papers100M'},
    'model': {'arch': 'sage'},
    # 'sample': {'train': [{'sampler': 'ns'}]},
    })
for log in ns:
    print(log.time,  log.info['sample'], log.stdmean(epochs=30))

# raises IndexError when the selection above is empty
print(ns[0].get_series(0, 'train/time'))
# ours = select(logs, {
#     'dataset': {'name': 'ogbn-papers100M'},
#     'model': {'arch': 'sage'},
#     'sample': {'train': [{'partition': 'fennel-w', 'num_repeats': 1}]}
#     })
# for log in ours:
#     print(log.time,  log.info['sample'], log.stdmean(epochs=30))

# Grouping key built from the first training-sampler entry of a log's info.
def key_fn(info):
    cluster_info = info['sample']['train'][0]
    return cluster_info['P'], cluster_info['batch_size'], cluster_info['num_repeats']
# NOTE(review): `ours` is only assigned in the commented-out block above, so
# this raises NameError unless it survives from another cell. The calls also
# pass two arguments to extract_acc_curve, which takes a single log in this
# notebook - confirm which helper version this cell targets.
acc_series = {
    'NS (in-mem)': extract_acc_curve(ns, lambda x: 'NS'),
    'HB-ours': extract_acc_curve(ours, key_fn),
    # 'Marius': extract_acc_urve(rnd, key_fn),
}
plot(acc_series, ylim=[0.47, 0.74], epochs=30)


In [None]:
logs = load_logs('../logdir/acc/*')
# Shared filter values for the partition-based queries below.
has_pivots = True
repeats = 1

# One query per sampling / partitioning scheme on mag240m_c; each prints the
# matching logs with their modification time, sampling config and summary.
ns = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/mag240m_c'},
    'model': {'arch': 'sage'},
    'sample': {'train': [{'sampler': 'ns'}]},
    })
for log in ns:
    print(log.time,  log.info['sample'], log.stdmean())

fnlwlb = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/mag240m_c'},
    'model': {'arch': 'sage', 'epochs': 30},
    'sample': {'train': [{'partition': 'fennel-wlb', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
print()
for log in fnlwlb:
    print(log.time, log.info['sample'], log.stdmean())

# NOTE: filters are regexes applied with re.match (prefix match), so
# 'fennel-w' also matches 'fennel-wlb'.
fnlw = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/mag240m_c'},
    'model': {'arch': 'sage', 'epochs': 30},
    'sample': {'train': [{'partition': 'fennel-w', 'pivots': has_pivots, 'num_repeats': repeats}]},
    })
print()
for log in fnlw:
    print(log.time, log.info['sample'], log.stdmean())

rnd = select(logs, {
    'dataset': {'root': '/mnt/md0/hb_datasets/mag240m_c'},
    'model': {'arch': 'sage', 'epochs': 30},
    'sample': {'train': [{'partition': 'rand', 'num_repeats': 1}]},
    })
print()
for log in rnd:
    print(log.time, log.info['sample'], log.stdmean())

# Grouping key built from the first training-sampler entry of a log's info.
def key_fn(info):
    cluster_info = info['sample']['train'][0]
    return cluster_info['P'], cluster_info['batch_size'], cluster_info['num_repeats']
# NOTE(review): these calls pass a list of logs plus a key function, but the
# extract_acc_curve defined in this notebook takes a single log and no key
# function - presumably the cell targets an earlier version; confirm.
acc_series = {
    'NS (in-mem)': extract_acc_curve(ns, lambda x: 'NS'),
    'HB-ours': extract_acc_curve(fnlwlb, key_fn),
    'HB-ours(w)': extract_acc_curve(fnlw, key_fn),
    'Rand+p': extract_acc_curve(rnd, key_fn),
}
plot(acc_series, ylim=[0.51, 0.75])


In [None]:
# Dump the sampling config, model config, summary statistics and md5 of
# every .pkl log recorded for mag240m_c.
# NOTE(review): the variable is named `papers` but the filter selects the
# mag240m_c dataset root.
logs = load_logs('../logdir/acc/*.pkl')
papers = select(logs, {'dataset': {'root': '/mnt/md0/hb_datasets/mag240m_c'}})
for trace in papers:
    print(trace.info['sample']['train'], trace.info['model'], trace.stdmean(), trace.md5, end='\n\n')

In [None]:
# Hand-entered label-distribution discrepancy for each shuffling method.
# Inner keys are two-element tuples looked up with the block-info pairs
# produced by `extract_acc`; the values are the x positions of the accuracy
# figure.
emd_dict = {
    'random hier-batching': dict.fromkeys(
        ((64, 4), (64, 8), (64, 16)), 2.2e-3,
    ),
    'metis hier-batching': dict(zip(
        ((64, 4), (64, 8), (64, 16)),
        (9.2e-3, 6.9e-3, 4.4e-3),
        # disabled: (1024,64): 4.2e-3, (1024,128): 3.3e-3, (1024,256): 2.8e-3
    )),
    'fennel-LB hier-batching': dict([
        ((64, 4), 2.2e-3),
        ((64, 8), 2.1e-3),
        ((64, 16), 2.1e-3),
        ((1024, 16), 2.1e-3),
    ]),
    'global shuffling': {(64, 8): 2.1e-3},
    'shuffling once': {(64, 8): 2.1e-3},
}

def extract_acc(logs: list[Recorder]):
    """
    Summarize each log with `stdmean_acc`, keyed by the pair
    `(num_blocks, num_blocks // block_ratio)` read from the log's `info`
    attributes. A later log with the same pair replaces an earlier one.
    """
    def block_key(info):
        return info.num_blocks, info.num_blocks // info.block_ratio

    return {block_key(log.info): stdmean_acc(log) for log in logs}

# Accuracy summaries per shuffling method, keyed like `emd_dict`.
# NOTE(review): rd_logs, hb_logs, fl_logs and gs_logs are not assigned in
# any cell shown here - presumably left over from an earlier session;
# confirm where they come from before re-running.
acc_method_dict = {
    'random hier-batching': extract_acc(rd_logs),
    'metis hier-batching': extract_acc(hb_logs),
    'fennel-LB hier-batching': extract_acc(fl_logs),
    'global shuffling': extract_acc(gs_logs),
    # 'shuffling once': extract_acc(ss_logs),
}

# Marker and colour used for each shuffling method's error bars.
style_dict = {
    'random hier-batching': dict(marker='o', color='green'),
    'metis hier-batching': dict(marker='^', color='blue'),
    'fennel-LB hier-batching': dict(marker='v', color='brown'),
    'global shuffling': dict(marker='D', color='red'),
    'shuffling once': dict(marker='h', color='orange'),
}

def make_acc_figure(ax, label, ylim=(68, 74), set_label=False):
    """
    Draw accuracy against label discrepancy for every shuffling method on `ax`.

    ax:        axes to draw on; the limits, points, legend and title are
               all applied to this axes.
    label:     which accuracy to plot; a key of the per-block summaries in
               `acc_method_dict` (e.g. 'val/acc'), also used as the title.
    ylim:      y-axis limits for `ax`.
    set_label: when True, the first point drawn for each method carries the
               method name and a legend is added.

    Blocks whose two key components are equal, or that have no entry in
    `emd_dict` for the method, are skipped.
    """
    # set the limits on the axes we were given; plt.ylim would act on the
    # figure's *current* axes, which is not necessarily `ax`
    ax.set_ylim(ylim)
    for shfl_method, acc_dict in acc_method_dict.items():
        needs_label = set_label
        for block_info, stats in acc_dict.items():
            if block_info[0] == block_info[1]:
                continue
            if block_info not in emd_dict[shfl_method]:
                continue
            mean, std = stats[label]
            emd = emd_dict[shfl_method][block_info]
            # label only one point per method so the legend has no duplicates
            extra = {'label': shfl_method} if needs_label else {}
            needs_label = False
            ax.errorbar([emd], [mean], yerr=[std], capsize=3, **style_dict[shfl_method], **extra)
    if set_label:
        ax.legend(fontsize=12)
    ax.set_title(label, fontsize=16)

# Two panels sharing the x axis: validation accuracy (left, with the
# legend) and test accuracy (right).
fig, axs = plt.subplots(1, 2, figsize=(8, 4), sharex=True, dpi=160)
fig.tight_layout()
fig.suptitle(f"Model Accuracy / Mini-batch Label Discrepancy", fontsize=16)
fig.subplots_adjust(top=0.85)
make_acc_figure(axs[0], 'val/acc', ylim=[70, 73], set_label=True)
make_acc_figure(axs[1], 'test/acc', ylim=[69, 72])
axs[0].set_ylabel('accuracy', fontsize=16)
# common x-axis caption, centred below both panels
fig.text(0.5, -0.04, 'mean discrepancy of label distribution', ha='center', fontsize=16)
fig.savefig("label_discrepancy.pdf", bbox_inches='tight')
plt.show()
