In [1]:
# Resolve the project root (one level above the notebook's working directory)
# and make sure the output folders used by the experiments exist.
import os
DIR = os.path.abspath("../")

for subdir in ['data', 'figs', 'datasets']:
    # makedirs(exist_ok=True) replaces the racy isdir()+mkdir() check and also
    # creates DIR itself if it does not exist yet.
    os.makedirs(os.path.join(DIR, subdir), exist_ok=True)

from meta_opt.train_loops import train_standard_opt, train_hgd, train_meta_opt
from meta_opt.utils.experiment_utils import make, save_checkpoint, process_results, bcolors, plot, get_final_cparams
import meta_opt.configs as configs

import re
import matplotlib.pyplot as plt
import numpy as np
import dill as pkl
import optax

# ==================================================
# Configuration and seeds for the trials.
SEEDS = [0]

NAME = 'wmt_yeet'  # experiment name; also stored as CFG['experiment_name']
CFG = dict(
    # ---- training options ----
    workload='WMT',
    num_iters=5000,
    eval_every=200,
    num_eval_iters=-1,      # presumably -1 means "whole eval set" -- TODO confirm
    # NOTE(review): the recorded run with batch_size=32 ran out of accelerator
    # memory inside train_step; lower this if the device is too small.
    batch_size=32,
    full_batch=False,
    reset_every=int(1e9),   # very large period: presumably "do not reset after step 0"

    # ---- experiment options ----
    experiment_name=NAME,
    load_checkpoint=False,
    overwrite=True,         # whether to allow us to overwrite existing checkpoints or throw errors
    directory=DIR,
)

2024-03-13 00:57:56.472378: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [4]:
import itertools
from collections import defaultdict
from time import perf_counter

import jax
import tqdm

from meta_opt.nn import reset_model, train_step, eval
from meta_opt.workloads import get_workload
from meta_opt.utils.pytree_utils import pytree_sq_norm
from meta_opt.utils.experiment_utils import get_opt_hyperparams

# -------------------------------------------------------------------------------------------------
# ------------------------------ Standard Optax Optimizers ----------------------------------------
# -------------------------------------------------------------------------------------------------

def train_standard_opt(cfg, optimizer):
    """Train ``cfg``'s workload with a standard (non-meta) optax optimizer.

    Note: this notebook-local definition shadows the ``train_standard_opt``
    imported from ``meta_opt.train_loops`` in the first cell.

    Args:
        cfg: experiment config dict; forwarded unchanged to ``get_workload``.
        optimizer: optax optimizer used to build the train state. Its state is
            passed to ``get_opt_hyperparams``, so it is presumably expected to
            be wrapped in ``optax.inject_hyperparams`` -- TODO confirm.

    Returns:
        dict: ``'args'`` -> the run arguments (with ``optimizer_args`` and
        ``optimizer_name`` added), and for every step ``t``, ``t`` -> a dict
        with ``timestamp`` (seconds since the loop started) and ``loss``; on
        eval steps it additionally holds ``eval_<metric>``, ``param_sq_norm``
        and ``grad_sq_norm``.
    """
    tstate, train_ds, test_ds, rng, args = get_workload(cfg, optimizer)

    stats = defaultdict(dict)
    args['optimizer_args'] = get_opt_hyperparams(tstate.opt_state)
    args['optimizer_name'] = 'standard'
    stats['args'] = args

    num_iters = args['num_iters']
    batches = train_ds.as_numpy_iterator()
    if num_iters is not None and num_iters > 0:
        # Cap the run at `num_iters` steps (the progress-bar total) so an
        # indefinitely repeating dataset cannot make the loop run forever.
        batches = itertools.islice(batches, num_iters)

    t0 = perf_counter()
    last_eval_step = None
    # The context manager closes the bar even if a step raises (e.g. an XLA OOM).
    with tqdm.tqdm(batches, total=num_iters) as pbar:
        for t, batch in enumerate(pbar):
            # Re-initialise the model every `reset_every` steps (always at t == 0).
            if t % args['reset_every'] == 0:
                reset_rng, rng = jax.random.split(rng)
                tstate = reset_model(reset_rng, tstate)
                del reset_rng

            tstate, (loss, grads) = train_step(tstate, batch)

            # update all the stats
            s = {'timestamp': perf_counter() - t0, 'loss': loss}
            # Evaluate every `eval_every` steps, skipping step 0.
            if t % args['eval_every'] == 0 and t != 0:
                # NOTE: `eval` is meta_opt.nn.eval, which shadows the Python builtin.
                for k, v in eval(tstate, test_ds.as_numpy_iterator()).items():
                    s[f'eval_{k}'] = v
                s['param_sq_norm'] = pytree_sq_norm(tstate.params)
                s['grad_sq_norm'] = pytree_sq_norm(grads)
                last_eval_step = t

            stats[t] = s
            eval_loss = (round(stats[last_eval_step]['eval_loss'].item(), 3)
                         if last_eval_step is not None else 'N/A')
            pbar.set_postfix({'loss': round(s['loss'].item(), 3),
                              'eval_loss': eval_loss,
                              })

    return dict(stats)

# Run a single SGD baseline on the first seed. Binding the result keeps the
# stats for later cells and stops Jupyter from auto-displaying the whole
# per-step stats dict as this cell's output.
CFG['seed'] = SEEDS[0]
sgd_stats = train_standard_opt(CFG, optax.inject_hyperparams(optax.sgd)(learning_rate=0.1))

  0%|          | 0/5000 [00:21<?, ?it/s]E0313 01:01:32.054113 4162103 pjrt_stream_executor_client.cc:2804] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 6590442048 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  797.73MiB
              constant allocation:    1.00MiB
        maybe_live_out allocation:    1.56GiB
     preallocated temp allocation:    6.14GiB
  preallocated temp fragmentation:   65.00MiB (1.03%)
                 total allocation:    8.47GiB
              total fragmentation:  915.75MiB (10.55%)
Peak buffers:
	Buffer 1:
		Size: 1000.00MiB
		Operator: op_name="jit(train_step)/jit(main)/transpose(jvp(Transformer))/Transformer.decode/decoder/div" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33
		XLA Label: fusion
		Shape: f32[32,256,32000]

	Buffer 2:
		Size: 500.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.decode/decoder/share

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 6590442048 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  797.73MiB
              constant allocation:    1.00MiB
        maybe_live_out allocation:    1.56GiB
     preallocated temp allocation:    6.14GiB
  preallocated temp fragmentation:   65.00MiB (1.03%)
                 total allocation:    8.47GiB
              total fragmentation:  915.75MiB (10.55%)
Peak buffers:
	Buffer 1:
		Size: 1000.00MiB
		Operator: op_name="jit(train_step)/jit(main)/transpose(jvp(Transformer))/Transformer.decode/decoder/div" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33
		XLA Label: fusion
		Shape: f32[32,256,32000]
		==========================

	Buffer 2:
		Size: 500.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.decode/decoder/shared_embedding.attend/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=float32]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33
		XLA Label: fusion
		Shape: f16[8192,32000]
		==========================

	Buffer 3:
		Size: 125.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.decode/decoder/shared_embedding.attend/transpose[permutation=(1, 0)]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33
		Entry Parameter Subshape: f32[32000,1024]
		==========================

	Buffer 4:
		Size: 125.00MiB
		Operator: op_name="jit(train_step)/jit(main)/add" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33
		XLA Label: fusion
		Shape: f32[32000,1024]
		==========================

	Buffer 5:
		Size: 125.00MiB
		Operator: op_name="jit(train_step)/jit(main)/transpose(jvp(Transformer))/Transformer.decode/decoder/shared_embedding.attend/dot_general[dimension_numbers=(((0, 1), (0, 1)), ((), ())) precision=None preferred_element_type=float32]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33
		XLA Label: custom-call
		Shape: f32[32000,1024]
		==========================

	Buffer 6:
		Size: 64.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.encode/encoder/encoderblock_5/MlpBlock_0/Dense_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33
		XLA Label: custom-call
		Shape: f16[8192,4096]
		==========================

	Buffer 7:
		Size: 64.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.encode/encoder/encoderblock_5/MultiHeadDotProductAttention_0/convert_element_type[new_dtype=float16 weak_type=False]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33 deduplicated_name="input_exponential_reduce_fusion"
		XLA Label: fusion
		Shape: f16[32,16,256,256]
		==========================

	Buffer 8:
		Size: 64.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.encode/encoder/encoderblock_4/MlpBlock_0/Dense_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33
		XLA Label: custom-call
		Shape: f16[8192,4096]
		==========================

	Buffer 9:
		Size: 64.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.encode/encoder/encoderblock_4/MultiHeadDotProductAttention_0/convert_element_type[new_dtype=float16 weak_type=False]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33 deduplicated_name="input_exponential_reduce_fusion"
		XLA Label: fusion
		Shape: f16[32,16,256,256]
		==========================

	Buffer 10:
		Size: 64.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.encode/encoder/encoderblock_3/MlpBlock_0/Dense_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33
		XLA Label: custom-call
		Shape: f16[8192,4096]
		==========================

	Buffer 11:
		Size: 64.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.encode/encoder/encoderblock_3/MultiHeadDotProductAttention_0/convert_element_type[new_dtype=float16 weak_type=False]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33 deduplicated_name="input_exponential_reduce_fusion"
		XLA Label: fusion
		Shape: f16[32,16,256,256]
		==========================

	Buffer 12:
		Size: 64.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.encode/encoder/encoderblock_2/MlpBlock_0/Dense_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33
		XLA Label: custom-call
		Shape: f16[8192,4096]
		==========================

	Buffer 13:
		Size: 64.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.encode/encoder/encoderblock_2/MultiHeadDotProductAttention_0/convert_element_type[new_dtype=float16 weak_type=False]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33 deduplicated_name="input_exponential_reduce_fusion"
		XLA Label: fusion
		Shape: f16[32,16,256,256]
		==========================

	Buffer 14:
		Size: 64.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.encode/encoder/encoderblock_1/MlpBlock_0/Dense_0/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33
		XLA Label: custom-call
		Shape: f16[8192,4096]
		==========================

	Buffer 15:
		Size: 64.00MiB
		Operator: op_name="jit(train_step)/jit(main)/jvp(Transformer)/Transformer.encode/encoder/encoderblock_1/MultiHeadDotProductAttention_0/convert_element_type[new_dtype=float16 weak_type=False]" source_file="/tmp/ipykernel_4162103/4186806824.py" source_line=33 deduplicated_name="input_exponential_reduce_fusion"
		XLA Label: fusion
		Shape: f16[32,16,256,256]
		==========================

