<a href="https://colab.research.google.com/github/kevinwatkins/deep-learning-sandbox/blob/master/eleutherai/EleutherAI_training_overview.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import io, re, time, ast, requests, json, dateutil, IPython.display, numpy as np, matplotlib.pyplot as plt, matplotlib.image as mpimg
# Cache of downloaded reference images keyed by file name; filled lazily by
# download_openai_image().
openai_images = {}
# Host used to build both the omniboard URL and the per-run tensorboard URLs.
server_name = 'vm.eleuther.ai'
# Port used to build the omniboard base URL.
omniboard_port = 8081
# Repository URL: compared against the repository entries recorded for each
# run and used to build commit links in the summary table.
repo = 'https://github.com/EleutherAI/GPTNeo/'

## Tensorboard list

In [None]:
# Each entry is (name, url_stem, run_stem):
#   name     - run name, also used as the key into `configs` / `run_datas`
#   url_stem - base URL of a tensorboard server, or None when there is
#              nothing to scrape for this run
#   run_stem - prefix used to build the tensorboard `run` query parameter
# More entries are appended by add_omniboard_run().
tensorboards = [
    ('gpt3-small', None, None),
    ('gpt3-175b', None, None),
]

## Configurations

Some configurations are given in a cell here; others are loaded from omniboard. Some OpenAI GPT-3 models are included for comparison.

In [None]:
# Run name -> configuration dict. The two entries below are hard-coded here;
# add_omniboard_run() adds one entry per scraped omniboard run.
# `mesh_shape` is a comma-separated list of "name:size" pairs that
# get_n_cores() multiplies together.
configs = json.loads('''{

"gpt3-small": {
    "train_batch_size": 250,
    "train_steps": 585938,
    "n_head": 12,
    "n_vocab": 50257,
    "n_layer": 12,
    "n_embd": 768,
    "n_ctx": 2048,
    "mesh_shape": "dummy:2048"
},

"gpt3-175b": {
    "train_batch_size": 1600,
    "train_steps": 91553,
    "n_head": 96,
    "n_vocab": 50257,
    "n_layer": 96,
    "n_embd": 12288,
    "n_ctx": 2048,
    "mesh_shape": "dummy:2048"
}

}''')
# Run name -> list of (wall_time, step, loss) samples; filled by
# add_omniboard_run() for runs that have a `loss` metric.
run_datas = {}

## Omniboard connection

In [None]:
# Base URL of the omniboard (sacred) web front end; the REST endpoints used
# below live under `{omniboard_uri}/api/v1/`.
omniboard_uri = f'http://{server_name}:{omniboard_port}/sacred'
# Disabled: would read a "<first>:<second>" pair from a local `omniboard_auth`
# file into a tuple (presumably HTTP basic-auth credentials - confirm before
# re-enabling).
# with open('omniboard_auth', 'r') as f:
#     global omniboard_auth
#     m = re.match(r'^(.+):(.+)$', f.read())
#     omniboard_auth = (m[1], m[2])

def parse_json_datetime(s):
    """Convert an ISO-8601 datetime string to a POSIX timestamp.

    Args:
        s: ISO-8601 formatted datetime string.

    Returns:
        Seconds since the epoch, as a float.
    """
    # `import dateutil` on its own does not load the `parser` submodule, so
    # import it explicitly instead of relying on some other library having
    # already pulled it in.
    import dateutil.parser
    return dateutil.parser.isoparse(s).timestamp()

def get_omniboard_runs(timeout=15):
    """Fetch the summary record of every run known to omniboard.

    Args:
        timeout: Seconds to wait for the HTTP request before giving up, so an
            unresponsive server cannot hang the notebook.

    Returns:
        Dict mapping each run's `_id` to its JSON record (records whose `_id`
        is null are dropped), or None when the request or decoding fails.
    """
    try:
        resp = requests.get(f'{omniboard_uri}/api/v1/Runs', params={
            'select': '_id,start_time,heartbeat,status,omniboard.tags'
        }, timeout=timeout)
        resp.raise_for_status()

        return {r['_id']: r for r in resp.json() if r['_id'] is not None}
    except Exception as e:
        # Best effort: report why the listing failed, but let
        # KeyboardInterrupt / SystemExit propagate (a bare `except:` would
        # swallow those as well).
        print(f'could not list omniboard runs: {e!r}')
        return None

def get_omniboard_details(run_id, timeout=15):
    """Fetch one omniboard run and extract the fields this notebook needs.

    Args:
        run_id: Omniboard (sacred) run id.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        Dict with any of the following keys that could be determined:
            'config'     - the run's config dict, taken from the record when
                           it contains `n_head`, otherwise parsed from a
                           `params = ...` line of the captured output
            'tb_port'    - int port from a `Tensorboard at port: N` line
            'start_time' - POSIX timestamp of the first timestamped line of
                           the captured output
            'git_commit' - commit of a recorded repository whose url matches
                           `repo`
        An empty dict is returned when the request or parsing fails.
    """
    try:
        resp = requests.get(f'{omniboard_uri}/api/v1/Runs/{run_id}', params={
            'select': 'captured_out,experiment,config,start_time'
        }, timeout=timeout)
        resp.raise_for_status()

        r_json = resp.json()
        r = {}

        if 'config' in r_json and 'n_head' in r_json['config']:
            r['config'] = r_json['config']

        if 'captured_out' in r_json:
            r_out = r_json['captured_out']
            match = re.search(r'^Tensorboard at port: (\d+)$', r_out, re.MULTILINE)
            if match:
                r['tb_port'] = int(match[1])
            if 'config' not in r:
                # Fall back to the parameters printed in the captured output.
                match = re.search(r'^params = (.+)$', r_out, re.MULTILINE)
                if match:
                    r['config'] = ast.literal_eval(match[1])
            # NOTE(review): the '.' before the six fractional digits is
            # unescaped and therefore matches any character.
            match = re.search(r'^(\d\d\d\d-\d\d-\d\d) (\d\d:\d\d:\d\d.\d\d\d\d\d\d): ', r_out, re.MULTILINE)
            if match:
                iso_dt = f'{match[1]}T{match[2]}Z'
                r['start_time'] = parse_json_datetime(iso_dt)

        if 'experiment' in r_json and 'repositories' in r_json['experiment']:
            wanted_url = repo.casefold()
            git_commits = {
                entry['commit']
                for entry in r_json['experiment']['repositories']
                if entry['url'].casefold() == wanted_url
            }
            if git_commits:
                r['git_commit'] = git_commits.pop()

        # The record's own `start_time` is deliberately not used: it is not
        # reliable because existing runs get restarted under new sacred ids.

        return r
    except Exception as e:
        # Best effort: report the failure and fall back to "no details", but
        # let KeyboardInterrupt / SystemExit propagate.
        print(f'could not read omniboard run {run_id}: {e!r}')
        return {}

def get_omniboard_run_data(run_id, timeout=15):
    """Fetch the `loss` metric of one omniboard run.

    Args:
        run_id: Omniboard (sacred) run id.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        List of (wall_time, step, loss) tuples with one entry per distinct
        step, ordered by step (for a repeated step the first reported sample
        is kept), or None when the run has no `loss` metric or the request
        fails.
    """
    try:
        resp = requests.get(f'{omniboard_uri}/api/v1/Metrics', params={
            'query': json.dumps({'run_id': str(run_id)})
        }, timeout=timeout)
        resp.raise_for_status()

        metrics = {
            r['name']: r
            for r in resp.json()
            if r['run_id'] == run_id
        }
        if 'loss' not in metrics:
            return None
        loss_metric = metrics['loss']

        wall_time = np.array([parse_json_datetime(t) for t in loss_metric['timestamps']], dtype=float)
        step = np.array(loss_metric['steps'], dtype=int)
        loss = np.array(loss_metric['values'], dtype=float)

        # np.unique returns the sorted distinct steps together with the index
        # of the first occurrence of each.
        _, unique_indices = np.unique(step, return_index=True)

        return list(zip(
            wall_time[unique_indices],
            # Convert to plain Python ints (presumably so the samples stay
            # JSON-serialisable - numpy integers are not).
            step[unique_indices].tolist(),
            loss[unique_indices],
        ))

    except Exception as e:
        # Best effort: report the failure and fall back to "no data", but let
        # KeyboardInterrupt / SystemExit propagate.
        print(f'could not read metrics of omniboard run {run_id}: {e!r}')
        return None

def add_omniboard_run(run_id, status):
    """Register one omniboard run in `configs`, `tensorboards` and `run_datas`.

    The run is stored under the name `<run_id>-<last component of model_path>`.
    Runs without a usable config or model_path, and runs whose name is already
    present in `configs`, are skipped.

    Args:
        run_id: Omniboard (sacred) run id.
        status: The run's status string as reported by omniboard.

    Returns:
        None. Failures are reported and otherwise ignored so that one broken
        run does not stop the scrape.
    """
    try:
        details = get_omniboard_details(run_id)
        if 'config' not in details:
            return
        run_config = details['config']
        run_config['omniboard_id'] = run_id
        run_config['omniboard_status'] = status
        for key in ('git_commit', 'start_time'):
            if key in details:
                run_config[key] = details[key]

        # The run name is derived from the last path component of model_path;
        # skip explicitly when there is none rather than failing on `None[1]`.
        match = re.search(r'/([^/]+)$', run_config.get('model_path') or '')
        if not match:
            print(f'skipping omniboard run {run_id}: no usable model_path')
            return
        name = f'{run_id}-{match[1]}'
        if name in configs:
            return
        configs[name] = run_config

        # Only a run that is still RUNNING has a tensorboard worth scraping.
        if 'tb_port' in details and status == 'RUNNING':
            tb_port = int(details['tb_port'])
            tensorboards.append((name, f'http://{server_name}:{tb_port}', ''))
        else:
            tensorboards.append((name, None, None))

        run_data = get_omniboard_run_data(run_id)
        if run_data:
            run_datas[name] = run_data
    except Exception as e:
        # Best effort: report and move on, but let KeyboardInterrupt /
        # SystemExit propagate.
        print(f'skipping omniboard run {run_id}: {e!r}')

### Scrape the omniboard runs

In [None]:
# Pull the run listing; nothing below can work without it.
omniboard_runs = get_omniboard_runs()
if omniboard_runs is None:
    raise Exception('omniboard is down')

# Keep runs that are still RUNNING plus any run tagged 'foomboard', and
# process them in ascending id order.
runs_to_include = sorted(
    run_key
    for run_key, run_info in omniboard_runs.items()
    if run_info['status'] == 'RUNNING'
    or (
        'omniboard' in run_info
        and 'tags' in run_info['omniboard']
        and 'foomboard' in run_info['omniboard']['tags']
    )
)
for run_id in runs_to_include:
    add_omniboard_run(run_id, omniboard_runs[run_id]['status'])

## Data loading

The last cell in this section will attempt to load data from all the tensorboards and create (or overwrite) local copies. The rest of the notebook operates on the local copies.

Note: this section has been largely superseded by the omniboard scraping above.

In [None]:
def run_config_path(run):
    """Return the local file name used to store the config of `run`."""
    return '{}-config.json'.format(run)
def run_data_path(run):
    """Return the local file name used to store the loss data of `run`."""
    return '{}.json'.format(run)

In [None]:
# Timeout, in seconds, passed to the tensorboard HTTP requests made by
# get_config() and get_run_data().
request_timeout = 15
def get_config(run):
    """Return the config of `run`, fetching it from its tensorboard if needed.

    A config already present in `configs` is returned as is. Otherwise the
    `run_config` text summary is requested from the run's tensorboard and the
    dict literal it contains is parsed.

    Args:
        run: Run name; must be a key of `tbruns` unless it is in `configs`.

    Returns:
        The config dict, or None when the request or parsing fails.

    Raises:
        KeyError: If `run` is in neither `configs` nor `tbruns`.
    """
    if run in configs:
        return configs[run]
    url_stem, run_stem = tbruns[run]
    url = f'{url_stem}/data/plugin/text/text'
    try:
        resp = requests.get(url, params={
            'tag': 'run_config',
            'run': f'{run_stem}config',
        }, timeout=request_timeout)
        resp.raise_for_status()

        # The last entry's text is the config wrapped in a <p> element;
        # strip the wrapper and evaluate the remaining Python literal.
        latest_text = resp.json()[-1]['text']
        literal_text = re.sub(r'<p>(.*)</p>', r'\1', latest_text)
        return ast.literal_eval(literal_text)
    except Exception:
        # Best effort (the caller reports the miss), but let
        # KeyboardInterrupt / SystemExit propagate.
        return None

def get_run_data(run):
    """Return the loss samples of `run`, fetching them from tensorboard if needed.

    Data already present in `run_datas` is returned as is. Otherwise the
    `loss` scalar series is requested from the run's tensorboard.

    Args:
        run: Run name; must be a key of `tbruns` unless it is in `run_datas`.

    Returns:
        The decoded JSON response, or None when the request fails.

    Raises:
        KeyError: If `run` is in neither `run_datas` nor `tbruns`.
    """
    if run in run_datas:
        return run_datas[run]
    url_stem, run_stem = tbruns[run]
    url = f'{url_stem}/data/plugin/scalars/scalars'
    try:
        resp = requests.get(url, params={
            'tag': 'loss',
            'run': f'{run_stem}.',
            'experiment': '',
        }, timeout=request_timeout)
        resp.raise_for_status()
        return resp.json()
    except Exception:
        # Best effort (the caller reports the miss), but let
        # KeyboardInterrupt / SystemExit propagate.
        return None

def update_config(run):
    """Refresh the local JSON copy of the config of `run` and report the outcome."""
    config = get_config(run)
    if not config:
        print(f'no luck refreshing {run} config')
        return
    with open(run_config_path(run), 'w') as out_file:
        serialized = json.dumps(config)
        out_file.write(serialized)
    print(f'refreshed {run} config')

def update_run_data(run):
    """Refresh the local JSON copy of the loss data of `run` and report the outcome."""
    data = get_run_data(run)
    if not data:
        print(f'no luck refreshing {run}')
        return
    with open(run_data_path(run), 'w') as out_file:
        serialized = json.dumps(data)
        out_file.write(serialized)
    print(f'refreshed {run}')

In [None]:
# Names of all known runs, in listing order.
runs = np.array([name for name, _, _ in tensorboards])
# Only runs that have a tensorboard URL can be scraped.
tbruns = {
    name: (url_stem, run_stem)
    for name, url_stem, run_stem in tensorboards
    if url_stem
}
for run in runs:
    if run not in tbruns:
        continue
    # A config that is already known is not fetched again.
    if run not in configs:
        update_config(run)
    update_run_data(run)

## Reformatting data as numpy arrays

In [None]:
def load_config(run):
    """Return the config of `run` from `configs` or from its local JSON copy.

    Args:
        run: Run name.

    Returns:
        The config dict, or None when the run is not in `configs` and its
        local copy is missing, unreadable or not valid JSON.
    """
    if run in configs:
        return configs[run]
    try:
        with open(run_config_path(run), 'r') as f:
            return json.load(f)
    except (OSError, ValueError):
        # OSError: the file cannot be opened or read. ValueError: it is not
        # valid JSON / text. Anything else is a real bug and should surface.
        return None
def load_run_data(run):
    """Return the loss samples of `run` from `run_datas` or from its local JSON copy.

    Args:
        run: Run name.

    Returns:
        The stored samples, or an empty list when the run is not in
        `run_datas` and its local copy is missing, unreadable or not valid
        JSON.
    """
    if run in run_datas:
        return run_datas[run]
    try:
        with open(run_data_path(run), 'r') as f:
            return json.load(f)
    except (OSError, ValueError):
        # OSError: the file cannot be opened or read. ValueError: it is not
        # valid JSON / text. Anything else is a real bug and should surface.
        return []
# Load every run's config, drop the runs for which none is available, then
# load the loss samples of the remaining runs in the same order.
run_config_dict = {name: load_config(name) for name in runs}
runs = [name for name in runs if run_config_dict[name] is not None]
run_data_list = list(map(load_run_data, runs))

In [None]:
def cfarray(key, default_value=np.nan):
    """Collect one config entry across all runs into a numpy array.

    Runs whose config lacks `key` contribute `default_value`.
    """
    per_run_values = []
    for run in runs:
        per_run_values.append(run_config_dict[run].get(key, default_value))
    return np.array(per_run_values)

In [None]:
# One array per config key, each with one element per entry of `runs`.
# Keys fetched without an explicit default yield NaN for runs that lack them.
train_batch_size = cfarray('train_batch_size')
n_head = cfarray('n_head')
n_vocab = cfarray('n_vocab')
n_layer = cfarray('n_layer')
n_embd = cfarray('n_embd')
n_ctx = cfarray('n_ctx')
mesh_shape = cfarray('mesh_shape')
train_steps = cfarray('train_steps')
omniboard_id = cfarray('omniboard_id')
start_time = cfarray('start_time')
# Text-valued keys default to the empty string instead of NaN.
git_commit = cfarray('git_commit', '')
tpu_name = cfarray('tpu_name', '')
omniboard_status = cfarray('omniboard_status', '')

In [None]:
# Number of samples recorded for each run.
n_rows = np.array([len(samples) for samples in run_data_list])
# Pack all runs into one (run, row, 3) array, NaN-padded to the longest run
# (and at least one row deep).
run_data = np.full([len(runs), max(1, max(n_rows)), 3], np.nan)
for i, data in enumerate(run_data_list):
    if len(data) == 0:
        continue
    run_data[i, :len(data), :] = np.array(data)

In [None]:
def get_n_cores(mesh_shape):
    """Return the product of the sizes in a "name:size,name:size" mesh string."""
    dim_lens = []
    for dim_spec in mesh_shape.split(','):
        dim_lens.append(int(dim_spec.split(':')[1]))
    return np.prod(dim_lens)

In [None]:
# Per-run, per-row views of the three sample fields, in the order
# (wall_time, step, loss). These are slices of `run_data`, so writing to
# them writes to `run_data`.
wall_time = run_data[:, :, 0]
step = run_data[:, :, 1]
loss = run_data[:, :, 2]

## Further calculated quantities

In [None]:
# Indices of the runs that have at least one sample (np.nonzero returns a
# 1-tuple of index arrays) and, for each of them, the index of its last row.
nz_rows = np.nonzero(n_rows)
nz_last_row = n_rows[nz_rows]-1
def last(a):
    """Return the last recorded value of each run from a (run, row) array.

    Runs without any samples get NaN.
    """
    r = np.full(len(runs), np.nan)
    # Pairs each non-empty run with its last row index.
    r[nz_rows] = a[nz_rows, nz_last_row]
    return r

def get_stem(r):
    """Return the tensorboard base URL of run `r`, or '' when it has none."""
    return tbruns[r][0] if r in tbruns else ''
# Tensorboard base URL of each run ('' when there is none).
tb_stem = np.array([get_stem(r) for r in runs])
# Reference point from which elapsed time and steps are measured. By default
# this is each run's first sample. The columns are copied because
# `wall_time[:, 0]` and `step[:, 0]` are views: assigning into them would
# overwrite the first recorded sample in `run_data` itself.
time_base = wall_time[:, 0].copy()
step_base = step[:, 0].copy()
for i in range(len(runs)):
    # When the run's start time is known, measure from the start of the run
    # (step 0) instead.
    if not np.isnan(start_time[i]):
        time_base[i] = start_time[i]
        step_base[i] = 0
# --- Observed progress (one value per run; NaN where a run has no samples) ---
wall_elapsed = last(wall_time) - time_base
batches_elapsed = last(step) - step_base
wall_per_batch = wall_elapsed/batches_elapsed
tokens_per_sec = train_batch_size*n_ctx/wall_per_batch
last_update_time = last(wall_time)
end_step = last(step)
end_loss = last(loss)
fraction_done = end_step/train_steps
# --- Hardware model ---
# Assumed FLOP/s delivered by one core - TODO confirm the source of this figure.
flops_per_core = 52.5e12
n_cores = np.array([get_n_cores(shape) for shape in mesh_shape], dtype=float)
total_flops = flops_per_core*n_cores
# --- Model and training size ---
# Parameter estimate: 12*n_layer*n_embd^2 plus the n_vocab*n_embd embedding.
approx_model_params = n_layer*(n_embd.astype(float)**2)*12 + n_vocab*n_embd.astype(float)
total_train_tokens = train_steps*n_ctx*train_batch_size.astype(float)
# Operation estimate: 6 operations per parameter per training token.
total_approx_ops = approx_model_params*total_train_tokens*6
# 1e15 ops/s sustained for 86400 s is one pflops-day.
total_pflops_days = total_approx_ops/1e15/86400
# Training time in days if every core ran at `flops_per_core`.
theo_train_days = total_approx_ops/total_flops/86400
train_tokens_elapsed = end_step*n_ctx*train_batch_size
approx_ops_elapsed = approx_model_params*train_tokens_elapsed*6
pflops_days_elapsed = approx_ops_elapsed/1e15/86400

# --- Efficiency and time-remaining estimate ---
# Seconds one batch would take at the assumed per-core rate, and the ratio
# of that to the observed seconds per batch.
theo_wall_per_batch = n_ctx*train_batch_size*approx_model_params*6/total_flops
theo_eff = theo_wall_per_batch/wall_per_batch
wall_remaining = (total_approx_ops-approx_ops_elapsed)/total_flops/theo_eff
est_finish_time = last_update_time + wall_remaining

# --- Per-sample quantities, shape (run, row), used for the plots ---
row_train_tokens = step*n_ctx[:, np.newaxis]*train_batch_size[:, np.newaxis].astype(float)
row_approx_n_ops = approx_model_params[:, np.newaxis]*row_train_tokens*6
row_pflops_days = row_approx_n_ops/1e15/86400

## Model size comparison

In [None]:
# Fragments of the markdown table; add_row() appends to this list.
md_text = []
# SI prefixes indexed by power of 1000 (index 0 means "no prefix").
si_prefixes = ' kMGTPEZY'

def format_val(val, format_spec):
    """Format one table cell.

    Args:
        val: Value to format.
        format_spec: '' or a regular format spec, 'utc' to render a POSIX
            timestamp as a UTC date-time, or a spec ending in 'S' to render
            a number in fixed-point notation with an SI prefix.

    Returns:
        The formatted string; a float NaN always yields ''.
    """
    if isinstance(val, float) and np.isnan(val):
        return ''
    if format_spec == 'utc':
        return time.strftime('%Y-%m-%d %H:%M:%S UTC', time.gmtime(val))
    if not format_spec.endswith('S'):
        return format(val, format_spec)

    # SI formatting: scale down by 1000 until the value drops below 1000 or
    # the prefixes run out, then format the mantissa as fixed-point.
    fixed_spec = format_spec[:-1] + 'f'
    largest_index = len(si_prefixes) - 1
    exponent = 0
    while val/(1e3**exponent) >= 1e3 and exponent < largest_index:
        exponent += 1
    text = format(val/(1e3**exponent), fixed_spec)
    if exponent > 0:
        text = f'{text} {si_prefixes[exponent]}'
    return text

def add_row(name, values=(), format_spec=''):
    """Append one markdown table row to `md_text`.

    Args:
        name: Text of the row's first cell.
        values: Iterable of values for the remaining cells. The default is
            an immutable empty tuple rather than a shared mutable list.
        format_spec: Spec passed to format_val() for every value.
    """
    # `md_text` is only mutated in place, so no `global` rebinding is needed.
    md_text.append(f'| {name} ')
    md_text.extend(f'| {format_val(v, format_spec)} ' for v in values)
    md_text.append('|\n')

# Build the summary table: one column per run, one row per quantity.
blanks = [''] * len(runs)
# `repo` may end with '/'; strip it so the commit links do not contain '//'.
repo_root = repo.rstrip('/')

add_row(' ', runs)
add_row('----', ['----:' for r in runs])
add_row('**Model shape**', blanks)
add_row('git_commit', [f'[{c[:7]}]({repo_root}/tree/{c})' if len(c) else '' for c in git_commit])
add_row('n_head', n_head)
add_row('n_vocab', n_vocab)
add_row('n_layer', n_layer)
add_row('n_embd', n_embd)
add_row('n_ctx', n_ctx)
add_row('approx_model_params', approx_model_params, '.2S')
add_row('**Training size**', blanks)
add_row('train_batch_size', train_batch_size)
add_row('train_steps', train_steps)
add_row('total_train_tokens', total_train_tokens, '.2S')
add_row('total_approx_ops', total_approx_ops, '.2e')
add_row('total_pflops_days', total_pflops_days, '.2f')
add_row('**TPU**', blanks)
add_row('tpu_name', tpu_name)
add_row('n_cores', n_cores, '.0f')
add_row('total_flops', total_flops, '.2S')
add_row('theo_train_days', theo_train_days, '.2f')
add_row('**Training progress**', blanks)
# s[7:] drops the leading 'http://' from the link text.
add_row('tb_url', [f'[{s[7:]}]({s})' if len(s) else '' for s in tb_stem])
# Runs that did not come from omniboard have a NaN id and get an empty cell.
add_row('sacred_id', [f'[{int(sacred_id)}]({omniboard_uri}?runId={int(sacred_id)})'
    if not np.isnan(sacred_id) else '' for sacred_id in omniboard_id])
add_row('status', omniboard_status)
add_row('start_time', start_time, 'utc')
add_row('n_updates', n_rows)
add_row('last_update_time', last_update_time, 'utc')
add_row('wall_time_secs', wall_elapsed, '.1f')
add_row('latest_batch', end_step, '.0f')
add_row('latest_loss', end_loss, '.2f')
add_row('fraction_done', fraction_done, '.4f')
add_row('train_tokens_elapsed', train_tokens_elapsed, '.2S')
add_row('approx_ops_elapsed', approx_ops_elapsed, '.2e')
add_row('pflops_days_elapsed', pflops_days_elapsed, '.2f')
add_row('secs_per_batch', wall_per_batch, '.2f')
add_row('tokens_per_sec', tokens_per_sec, '.0f')
add_row('theo_eff', theo_eff, '.3f')
add_row('wall_remaining_secs', wall_remaining, '.0f')
add_row('est_finish_time', est_finish_time, 'utc')

# Render the table: a <style> block that sets the table styling, followed by
# the accumulated markdown rows joined into one document.
display(
    IPython.display.HTML('''<style>
      table {
        table-layout: fixed;
        border-collapse: collapse;
        font-size: 9pt;
      }
      tbody tr:nth-child(odd) {
        background-color: #f0f0f0;
      }
      th {
        width: 115px;
      }
      td, th {
        padding: 2px 5px;
      }
    </style>'''),
    IPython.display.Markdown(''.join(md_text))
)

## Loss vs compute plot

Note: the "compute" axis is the theoretical pflops-days that would be consumed if the tensor operations could be run at 100% efficiency. In practice, the actual pflops-days will be greater. Also, the number of floating point operations is only approximated here.

Note 2: As of this writing, EleutherAI is using its own BPE vocabulary, so in theory, losses from EleutherAI and OpenAI cannot be directly compared.

In [None]:
# When True, both plots are drawn on top of a downloaded reference image and
# the grid is hidden. When False, the grid is shown and the compute plot
# draws a dotted frontier line instead.
show_openai_runs = True

### Implementation details

In [None]:
# Line colours for the runs, as matplotlib 'xkcd:<name>' colour specs;
# reused cyclically when there are more runs than colours.
plot_colors = [
    'xkcd:' + color_name
    for color_name in 'red,light blue,light green,yellow,grey,light purple,tan,black'.split(',')
]

In [None]:
def download_openai_image(name, timeout=15):
    """Return the reference image `name`, downloading it on first use.

    Downloaded images are cached in `openai_images`, so each file is fetched
    at most once.

    Args:
        name: File name of the image under the foomboard storage bucket.
        timeout: Seconds to wait for the HTTP request before giving up, so an
            unresponsive server cannot hang the notebook.

    Returns:
        The image as decoded by matplotlib.image.imread.

    Raises:
        requests.RequestException: If the download fails or times out.
    """
    if name not in openai_images:
        resp = requests.get(
            f'https://storage.googleapis.com/via-whereas/foomboard/{name}',
            timeout=timeout,
        )
        resp.raise_for_status()
        openai_images[name] = mpimg.imread(io.BytesIO(resp.content))
    return openai_images[name]

In [None]:
def show_loss_compute_plot():
    """Plot loss against compute (pflops-days) for every run with data.

    Both axes are logarithmic. When `show_openai_runs` is set, the curves are
    drawn on a transparent axes layered over a cropped reference image;
    otherwise a dotted frontier line and a grid are drawn instead.
    """
    fig = plt.figure(figsize=(6.37, 6), dpi=100)

    if show_openai_runs:
        # Background layer: the reference figure, cropped (presumably to its
        # plot area - confirm against the image) and shown without axes.
        fig.add_subplot(label='image')
        img = download_openai_image('LanguageModelingComputePareto.png')
        plt.imshow(img[112:1520, 380:1912])
        plt.axis('off')

    # Foreground layer: a second axes with a transparent background so the
    # image underneath stays visible.
    ax = fig.add_subplot(label='runs')
    ax.patch.set_alpha(0)

    color_index = 0
    for i in range(len(runs)):
        # Use only the recorded rows of this run whose step is non-zero.
        (nzsteps,) = np.nonzero(step[i, :n_rows[i]])
        if len(nzsteps) > 0:
            plt.plot(row_pflops_days[i, nzsteps], loss[i, nzsteps],
                     plot_colors[color_index%len(plot_colors)], label=runs[i])
            # Colours advance only for runs that were actually plotted.
            color_index += 1

    if not show_openai_runs:
        # Without the reference image, draw the frontier
        # loss = 2.57 * compute**-0.048 as a dotted black line.
        frontier_xs = np.array([1e-6, 1e4])
        frontier_ys = 2.57*(frontier_xs**-0.048)
        plt.plot(frontier_xs, frontier_ys, 'k:')

    plt.xlabel('Compute (pflops-days)')
    plt.xscale('log')
    plt.xlim(1e-6, 1e4)
    plt.xticks(10.**np.arange(-6, 6, 2))

    plt.ylabel('Loss (per token, base e)')
    plt.yscale('log')
    plt.ylim(1.5, 6)
    # Explicit tick labels so the log axis shows plain numbers.
    y_ticks = [1.5] + list(range(2, 7))
    plt.yticks(y_ticks, labels=[str(t) for t in y_ticks])

    plt.legend(loc='lower left')

    # The grid is shown only when there is no reference image.
    plt.grid(not show_openai_runs)
    plt.show()

### The plot

In [None]:
# Render the loss-vs-compute figure.
show_loss_compute_plot()

## Loss vs training tokens plot

### Implementation details

In [None]:
def show_loss_tokens_plot():
    """Plot loss against training tokens (in billions) for every run with data.

    Both axes are linear. When `show_openai_runs` is set, the curves are
    drawn on a transparent axes layered over a cropped reference image;
    otherwise a grid is shown.
    """
    fig = plt.figure(figsize=(7.5, 5.69), dpi=100)

    if show_openai_runs:
        # Background layer: the cropped reference figure. The limits are in
        # image pixels and extend beyond the crop (presumably to line the
        # image up with the data axes - confirm against the image).
        fig.add_subplot(label='image')
        img = download_openai_image('training_curves.png')
        plt.xlim(-3, 1538-214)
        plt.ylim(845-58, -198)
        plt.imshow(img[58:845, 214:1538])
        plt.axis('off')

    # Foreground layer: a second axes with a transparent background so the
    # image underneath stays visible.
    ax = fig.add_subplot(label='runs')
    ax.patch.set_alpha(0)

    color_index = 0
    for i in range(len(runs)):
        # Use only the recorded rows of this run whose step is non-zero.
        (nzsteps,) = np.nonzero(step[i, :n_rows[i]])
        if len(nzsteps) > 0:
            # Token counts are scaled to billions for the x axis.
            plt.plot(row_train_tokens[i, nzsteps]*1e-9, loss[i, nzsteps],
                     plot_colors[color_index%len(plot_colors)], label=runs[i])
            # Colours advance only for runs that were actually plotted.
            color_index += 1

    plt.xlabel('Training tokens (billions)')
    plt.xlim(-10, 300)
    plt.xticks(np.arange(0, 301, 50))

    plt.ylabel('Loss (per token, base e)')
    plt.ylim(1.5, 4.0)
    plt.yticks(np.arange(1.5, 4.001, 0.25))

    plt.legend(loc='upper right')

    # The grid is shown only when there is no reference image.
    plt.grid(not show_openai_runs)
    plt.show()

### The plot

In [None]:
# Render the loss-vs-training-tokens figure.
show_loss_tokens_plot()