In [1]:
# # prepare the repo
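# !git clone https://github.com/edogariu/meta-opt  # SECURITY: an access token used to be embedded in this URL; revoke it and authenticate through a credential helper or an environment variable instead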
# !mkdir meta-opt/data
# !ls -a meta-opt

# # get a link to the file
# from google.colab import drive
# drive.mount('/content/drive')
# !cp -r "meta-opt" "drive/My Drive/"
# DIR_PREFIX = "drive/My Drive/meta-opt"
# # !pip install kora -q  # library from https://stackoverflow.com/questions/62596466/how-can-i-run-notebooks-of-a-github-project-in-google-colab to help get ID
# # from kora.xattr import get_id
# # fid = get_id(f"{dir_prefix}meta_opt.ipynb")
# # print("https://colab.research.google.com/drive/"+fid)

# # install the package
# !pip install ./meta-opt

In [2]:
import os
import pickle as pkl
from collections import defaultdict
from copy import deepcopy
from time import perf_counter

import jax
import matplotlib.pyplot as plt
import numpy as np
import optax
import tensorflow as tf
import tqdm

from meta_opt.training.trainer import create_train_state, gradient_descent, reset_model, forward
from meta_opt.training.utils import cross_entropy, mse, MLP, CNN, load_mnist, load_cifar10

from meta_opt.meta_opt import MetaOpt
from meta_opt.gaps import MetaOptGAPS

### Todo
- add `ADAM` setting for meta opt and keep track of it in args. run experiments again w adam and see
- add `accuracy` to statistics each eval round
- add MP, cosine, cyclical learning rates, hedging, AGD, DoWG, D-adaptation, adagrad & rmsprop
- try other settings
- check "training instability" literature

In [3]:
def set_seed(seed):
    """Seed NumPy, TensorFlow and JAX from a single integer.

    Args:
        seed: integer seed, or ``None`` to draw one at random. A randomly drawn
            seed is printed so the run can be reproduced later.

    Returns:
        ``(rng, seed)``: the ``jax.random.PRNGKey`` built from the seed, and the
        integer seed that was actually used.
    """
    if seed is None:
        # `np.random.randint` needs explicit bounds (calling it with no arguments
        # raises TypeError); the drawn value stays inside the signed 32-bit range.
        seed = int(np.random.randint(0, 2**31 - 1))
        print('seed set to {}'.format(seed))
    np.random.seed(seed)
    tf.random.set_seed(seed)
    rng = jax.random.PRNGKey(seed)
    return rng, seed

def get_problem(seed, name, optimizer):
    """Build the dataset, model and initial train state for one experiment.

    Args:
        seed: seed handed to `set_seed` (``None`` lets it draw a random one).
        name: problem selector, either 'MNIST' or 'CIFAR'.
        optimizer: optimizer passed through to `create_train_state`.

    Returns:
        ``(tstate, train_ds, test_ds, rng, args)`` where `rng` is the half of the
        split PRNG key that was not used for initialisation and `args` is a dict
        describing the run configuration.

    Raises:
        NotImplementedError: if `name` is not one of the supported problems.
    """
    rng, seed = set_seed(seed)
    # one key initialises the model, the other is returned to the caller
    init_rng, rng = jax.random.split(rng)

    # choose the data loader and the architecture that belong to the problem
    if name not in ('MNIST', 'CIFAR'):
        raise NotImplementedError(name)
    if name == 'MNIST':
        train_ds, test_ds, loss_fn, input_dims = load_mnist(NUM_ITERS, BATCH_SIZE)
        model = MLP([28 * 28, 100, 100, 10])
    else:
        train_ds, test_ds, loss_fn, input_dims = load_cifar10(NUM_ITERS, BATCH_SIZE)
        model = CNN(channels=[3, 32, 64, 32], layer_dims=[512, 128, 10], drop_last_activation=False)

    tstate = create_train_state(init_rng, model, input_dims, optimizer, loss_fn)

    # snapshot of the global experiment settings, stored alongside the results
    args = dict(
        seed=seed,
        model=str(model),
        dataset=name,
        num_iters=NUM_ITERS,
        eval_every=EVAL_EVERY,
        batch_size=BATCH_SIZE,
        reset_every=RESET_EVERY,
        print_every=PRINT_EVERY,
    )

    return tstate, train_ds, test_ds, rng, args

# Standard Optimizers

In [4]:
def train_standard_opt(seed, problem_name, optimizer):
    """Train a model on `problem_name` with a fixed optax optimizer.

    Each iteration takes one `gradient_descent` step. Every `RESET_EVERY`
    iterations the model is re-initialised with `reset_model`, and every
    `EVAL_EVERY` iterations the mean of `forward(tstate, batch)` over the test
    batches is stored as 'eval_loss'.

    Args:
        seed: seed forwarded to `get_problem` (``None`` draws a random one).
        problem_name: problem selector understood by `get_problem`.
        optimizer: optax optimizer whose state exposes `.hyperparams` (e.g. one
            built with `optax.inject_hyperparams`); those hyperparameters are
            copied into the recorded run arguments.

    Returns:
        dict mapping 'args' to the run configuration and every 1-based iteration
        number to ``{'timestamp', 'loss'}`` plus 'eval_loss' on eval rounds.
    """
    tstate, train_ds, test_ds, rng, args = get_problem(seed, problem_name, optimizer)

    stats = defaultdict(dict)
    args['optimizer_args'] = deepcopy(tstate.opt_state.hyperparams)
    args['optimizer_args']['name'] = 'standard'
    stats['args'] = args

    t0 = perf_counter()
    for t, batch in enumerate(pbar := tqdm.tqdm(train_ds.as_numpy_iterator(), total=len(train_ds)), start=1):

        if t % RESET_EVERY == 0:
            reset_rng, rng = jax.random.split(rng)
            tstate = reset_model(reset_rng, tstate)
            del reset_rng

        tstate, (loss, grads) = gradient_descent(tstate, batch)

        # update all the stats
        s = {}
        s['timestamp'] = perf_counter() - t0
        s['loss'] = loss
        if t % EVAL_EVERY == 0:
            s['eval_loss'] = sum(forward(tstate, eval_batch) for eval_batch in test_ds.as_numpy_iterator()) / len(test_ds)
        stats[t] = s

        # print if we gotta
        if t % PRINT_EVERY == 0:
            # window = the last PRINT_EVERY iterations *including* the current one
            # (stats[t] was stored just above), so the newest eval round is counted
            window = [stats[i] for i in range(t - PRINT_EVERY + 1, t + 1) if i in stats]
            train_losses = [w['loss'] for w in window if 'loss' in w]
            eval_losses = [w['eval_loss'] for w in window if 'eval_loss' in w]
            avg_train_loss = np.mean(train_losses)
            # no eval round inside the window -> report nan without a numpy warning
            avg_eval_loss = np.mean(eval_losses) if eval_losses else float('nan')
            print(f'iters {t - PRINT_EVERY} - {t}')
            print(f'\tavg train loss: {avg_train_loss}')
            print(f'\tavg eval loss: {avg_eval_loss}')
        pbar.set_postfix({'loss': round(s['loss'].item(), 3)})

    return dict(stats)

# Meta-Opt

In [5]:
def train_meta_opt(seed, problem_name: str, m_method: str, meta_lr: float, use_adam: bool, H: int, HH: int, initial_lr: float):
    """Train a model with SGD plus a `MetaOpt` controller applied after every step.

    Each iteration takes one `gradient_descent` step with plain SGD and then
    passes the resulting state, gradients and batch to `MetaOpt.meta_step`.
    Every `RESET_EVERY` iterations the model is re-initialised and
    `episode_reset` is called on the controller.

    Args:
        seed: seed forwarded to `get_problem` (``None`` draws a random one).
        problem_name: problem selector understood by `get_problem`.
        m_method: forwarded to `MetaOpt`; when it equals 'scalar' the flattened
            controller matrix `cstate.M` is recorded every iteration.
        meta_lr: forwarded to `MetaOpt`.
        use_adam: forwarded to `MetaOpt`.
        H: forwarded to `MetaOpt`.
        HH: forwarded to `MetaOpt`.
        initial_lr: learning rate of the underlying `optax.sgd` optimizer.

    Returns:
        dict mapping 'args' to the run configuration and every 1-based iteration
        number to ``{'timestamp', 'loss'}`` plus 'eval_loss' on eval rounds and
        'M' for the scalar method.
    """
    optimizer = optax.sgd(learning_rate=initial_lr)
    tstate, train_ds, test_ds, rng, args = get_problem(seed, problem_name, optimizer)

    stats = defaultdict(dict)
    args['optimizer_args'] = {'name': 'meta',
                              'initial_lr': initial_lr,
                              'm_method': m_method,
                              'meta_lr': meta_lr,
                              'use_adam': use_adam,
                              'H': H,
                              'HH': HH
                              }
    stats['args'] = args
    meta_opt = MetaOpt(tstate, H=H, HH=HH, meta_lr=meta_lr, delta=1e-5, m_method=m_method, use_adam=use_adam)

    t0 = perf_counter()
    for t, batch in enumerate(pbar := tqdm.tqdm(train_ds.as_numpy_iterator(), total=len(train_ds)), start=1):

        if t % RESET_EVERY == 0:
            reset_rng, rng = jax.random.split(rng)
            tstate = reset_model(reset_rng, tstate)
            meta_opt = meta_opt.episode_reset()
            del reset_rng

        tstate, (loss, grads) = gradient_descent(tstate, batch)
        tstate = meta_opt.meta_step(tstate, grads, batch)

        # update all the stats
        s = {}
        s['timestamp'] = perf_counter() - t0
        s['loss'] = loss
        if t % EVAL_EVERY == 0:
            s['eval_loss'] = sum(forward(tstate, eval_batch) for eval_batch in test_ds.as_numpy_iterator()) / len(test_ds)
        if m_method == 'scalar':
            s['M'] = meta_opt.cstate.M.reshape(-1)
        stats[t] = s

        # print if we gotta
        if t % PRINT_EVERY == 0:
            # window = the last PRINT_EVERY iterations *including* the current one
            # (stats[t] was stored just above), so the newest eval round is counted
            window = [stats[i] for i in range(t - PRINT_EVERY + 1, t + 1) if i in stats]
            train_losses = [w['loss'] for w in window if 'loss' in w]
            eval_losses = [w['eval_loss'] for w in window if 'eval_loss' in w]
            avg_train_loss = np.mean(train_losses)
            # no eval round inside the window -> report nan without a numpy warning
            avg_eval_loss = np.mean(eval_losses) if eval_losses else float('nan')
            print(f'iters {t - PRINT_EVERY} - {t}')
            print(f'\tavg train loss: {avg_train_loss}')
            print(f'\tavg eval loss: {avg_eval_loss}')
        pbar.set_postfix({'loss': round(s['loss'].item(), 3)})

    return dict(stats)

# Gradient-based Adaptive Policy Selection (GAPS) Meta-Opt

In [6]:
def train_gaps_meta_opt(seed, problem_name: str, m_method: str, meta_lr: float, use_adam: bool, H: int, B: int, initial_lr: float):
    """Train a model with the GAPS variant of meta-opt (`MetaOptGAPS`).

    Unlike `train_meta_opt`, no separate `gradient_descent` call is made here:
    `MetaOptGAPS.meta_step` receives the state and batch and returns the new
    state together with ``(loss, grads)``. Every `RESET_EVERY` iterations the
    model is re-initialised and `episode_reset` is called on the controller.

    Args:
        seed: seed forwarded to `get_problem` (``None`` draws a random one).
        problem_name: problem selector understood by `get_problem`.
        m_method: forwarded to `MetaOptGAPS`; when it equals 'scalar' the
            flattened controller matrix `cstate.M` is recorded every iteration.
        meta_lr: forwarded to `MetaOptGAPS`.
        use_adam: forwarded to `MetaOptGAPS`.
        H: forwarded to `MetaOptGAPS`.
        B: forwarded to `MetaOptGAPS`.
        initial_lr: learning rate of the underlying `optax.sgd` optimizer.

    Returns:
        dict mapping 'args' to the run configuration and every 1-based iteration
        number to ``{'timestamp', 'loss'}`` plus 'eval_loss' on eval rounds and
        'M' for the scalar method.
    """
    optimizer = optax.sgd(learning_rate=initial_lr)
    tstate, train_ds, test_ds, rng, args = get_problem(seed, problem_name, optimizer)

    stats = defaultdict(dict)
    args['optimizer_args'] = {'name': 'gaps_meta',
                              'initial_lr': initial_lr,
                              'm_method': m_method,
                              'meta_lr': meta_lr,
                              'use_adam': use_adam,
                              'H': H,
                              'B': B
                              }
    stats['args'] = args

    meta_opt = MetaOptGAPS(tstate, H=H, B=B, meta_lr=meta_lr, use_adam=use_adam, delta=1e-5, m_method=m_method)

    t0 = perf_counter()
    for t, batch in enumerate(pbar := tqdm.tqdm(train_ds.as_numpy_iterator(), total=len(train_ds)), start=1):

        if t % RESET_EVERY == 0:
            reset_rng, rng = jax.random.split(rng)
            tstate = reset_model(reset_rng, tstate)
            meta_opt = meta_opt.episode_reset()
            del reset_rng

        # the GAPS controller replaces the plain `gradient_descent` step and hands
        # back the loss and gradients itself
        tstate, (loss, grads) = meta_opt.meta_step(tstate, batch)

        # update all the stats
        s = {}
        s['timestamp'] = perf_counter() - t0
        s['loss'] = loss
        if t % EVAL_EVERY == 0:
            s['eval_loss'] = sum(forward(tstate, eval_batch) for eval_batch in test_ds.as_numpy_iterator()) / len(test_ds)
        if m_method == 'scalar':
            s['M'] = meta_opt.cstate.M.reshape(-1)
        stats[t] = s

        # print if we gotta
        if t % PRINT_EVERY == 0:
            # window = the last PRINT_EVERY iterations *including* the current one
            # (stats[t] was stored just above), so the newest eval round is counted
            window = [stats[i] for i in range(t - PRINT_EVERY + 1, t + 1) if i in stats]
            train_losses = [w['loss'] for w in window if 'loss' in w]
            eval_losses = [w['eval_loss'] for w in window if 'eval_loss' in w]
            avg_train_loss = np.mean(train_losses)
            # no eval round inside the window -> report nan without a numpy warning
            avg_eval_loss = np.mean(eval_losses) if eval_losses else float('nan')
            print(f'iters {t - PRINT_EVERY} - {t}')
            print(f'\tavg train loss: {avg_train_loss}')
            print(f'\tavg eval loss: {avg_eval_loss}')
        pbar.set_postfix({'loss': round(s['loss'].item(), 3)})

    return dict(stats)

# Hypergradient Descent

In [7]:
def train_hgd(seed, problem_name: str, initial_lr: float, hypergrad_lr: float):
    """Train a model with SGD whose learning rate is adapted by hypergradient descent.

    After every `gradient_descent` step (from the second one on) the learning
    rate stored in the optimizer's injected hyperparameters is moved by
    ``hypergrad_lr`` times the inner product of the current and previous
    gradients. Every `RESET_EVERY` iterations the model is re-initialised with
    `reset_model`.

    Args:
        seed: seed forwarded to `get_problem` (``None`` draws a random one).
        problem_name: problem selector understood by `get_problem`.
        initial_lr: starting learning rate of the SGD optimizer.
        hypergrad_lr: step size used when updating the learning rate.

    Returns:
        dict mapping 'args' to the run configuration and every 1-based iteration
        number to ``{'timestamp', 'loss', 'lr'}`` plus 'eval_loss' on eval rounds.
    """
    # `inject_hyperparams` exposes the learning rate in the optimizer state so it
    # can be modified between steps
    optimizer = optax.inject_hyperparams(optax.sgd)(learning_rate=initial_lr)
    tstate, train_ds, test_ds, rng, args = get_problem(seed, problem_name, optimizer)

    stats = defaultdict(dict)
    args['optimizer_args'] = {'name': 'hgd',
                              'initial_lr': initial_lr,
                              'hypergrad_lr': hypergrad_lr,
                              }
    stats['args'] = args

    prev_grads = None
    t0 = perf_counter()
    for t, batch in enumerate(pbar := tqdm.tqdm(train_ds.as_numpy_iterator(), total=len(train_ds)), start=1):

        if t % RESET_EVERY == 0:
            reset_rng, rng = jax.random.split(rng)
            tstate = reset_model(reset_rng, tstate)
            del reset_rng

        tstate, (loss, grads) = gradient_descent(tstate, batch)
        if prev_grads is not None:
            # hypergradient = minus the dot product of the current and previous
            # gradients, accumulated over every parameter leaf
            hypergrad = -sum((g1 * g2).sum() for g1, g2 in zip(jax.tree_util.tree_leaves(grads), jax.tree_util.tree_leaves(prev_grads)))
            tstate.opt_state.hyperparams['learning_rate'] -= hypergrad_lr * hypergrad
        prev_grads = grads

        # update all the stats
        s = {}
        s['timestamp'] = perf_counter() - t0
        s['loss'] = loss
        s['lr'] = tstate.opt_state.hyperparams['learning_rate'].item()
        if t % EVAL_EVERY == 0:
            s['eval_loss'] = sum(forward(tstate, eval_batch) for eval_batch in test_ds.as_numpy_iterator()) / len(test_ds)
        stats[t] = s

        # print if we gotta
        if t % PRINT_EVERY == 0:
            # window = the last PRINT_EVERY iterations *including* the current one
            # (stats[t] was stored just above), so the newest eval round is counted
            window = [stats[i] for i in range(t - PRINT_EVERY + 1, t + 1) if i in stats]
            train_losses = [w['loss'] for w in window if 'loss' in w]
            eval_losses = [w['eval_loss'] for w in window if 'eval_loss' in w]
            avg_train_loss = np.mean(train_losses)
            # no eval round inside the window -> report nan without a numpy warning
            avg_eval_loss = np.mean(eval_losses) if eval_losses else float('nan')
            print(f'iters {t - PRINT_EVERY} - {t}')
            print(f'\tavg train loss: {avg_train_loss}')
            print(f'\tavg eval loss: {avg_eval_loss}')
        pbar.set_postfix({'loss': round(s['loss'].item(), 3)})

    return dict(stats)

# Run
Select the hyperparameters and the seeds to use for each trial.

In [8]:
# hyperparams
SEEDS = [18, 29, 69, 1]   # one training run per seed
NUM_ITERS = 25000         # passed to the dataset loaders in `get_problem`
EVAL_EVERY = 100          # evaluate on the test set every this many iterations
BATCH_SIZE = 1024         # passed to the dataset loaders in `get_problem`
RESET_EVERY = 5000        # re-initialise the model every this many iterations
PRINT_EVERY = int(1e10)   # far above NUM_ITERS, i.e. the periodic printout is effectively off

NAME = 'MNIST'
if 'DIR_PREFIX' not in globals(): DIR_PREFIX = '.'  # use this directory if unspecified

# `jax.default_backend()` reports the platform name directly; it replaces the
# deprecated `jax.lib.xla_bridge.get_backend().platform` and needs no extra import
print('dataset:', NAME)
print('using', jax.default_backend(), 'for jax')
print(f'saving data at `{DIR_PREFIX}/data/`')

dataset: MNIST
using cpu for jax
saving data at `./data/`


In [9]:
# uncomment the ones to run
results = defaultdict(list)
# results = pkl.load(open(f'{DIR_PREFIX}/data/{NAME}_raw.pkl', 'rb'))  # resume from an earlier checkpoint (only unpickle files you produced yourself)

# create the checkpoint folder up front: without it the first save raises
# FileNotFoundError and a finished training run is lost
os.makedirs(f'{DIR_PREFIX}/data', exist_ok=True)

for s in SEEDS:
    # NOTE: the commented meta-opt calls pass `use_adam` explicitly because the
    # training functions require it
#     results['sgd'].append(train_standard_opt(s, NAME, optax.inject_hyperparams(optax.sgd)(learning_rate=0.2)))
#     results['adam'].append(train_standard_opt(s, NAME, optax.inject_hyperparams(optax.adam)(learning_rate=0.001)))
#     results['adam_0.0001'].append(train_standard_opt(s, NAME, optax.inject_hyperparams(optax.adam)(learning_rate=1e-4)))
#     results['adam_0.0005'].append(train_standard_opt(s, NAME, optax.inject_hyperparams(optax.adam)(learning_rate=5e-4)))
#     results['meta_scalar'].append(train_meta_opt(s, NAME, 'scalar', meta_lr=0.008, H=4, HH=2, initial_lr=0.1, use_adam=False))
    results['meta_scalar'].append(train_meta_opt(s, NAME, 'scalar', meta_lr=0.001, H=4, HH=2, initial_lr=0.1, use_adam=True))
#     results['meta_diagonal'].append(train_meta_opt(s, NAME, 'diagonal', meta_lr=1e-1, H=4, HH=2, initial_lr=0.2, use_adam=False))
#     results['meta_GAPS'].append(train_gaps_meta_opt(s, NAME, 'scalar', meta_lr=0.005, H=6, B=6, initial_lr=0.2, use_adam=False))
#     results['hgd'].append(train_hgd(s, NAME, initial_lr=0.4, hypergrad_lr=1e-4))

    # checkpoint after every seed so an interrupted sweep keeps the finished runs
    with open(f'{DIR_PREFIX}/data/{NAME}_raw.pkl', 'wb') as f:
        pkl.dump(results, f)
    # report only once the file has been written and closed
    print(f'Saved checkpoint for seed #{s}')

  0%|                                              | 7/25000 [00:01<1:08:13,  6.11it/s, loss=2.04]

(Array([0.4454494 , 0.46873635, 0.49411494, 0.50407493], dtype=float32),)
(Array([0.4454494 , 0.46873635, 0.49411494, 0.5       ], dtype=float32),)

(Array([0.47103786, 0.4997273 , 0.52434003, 0.5459984 ], dtype=float32),)
(Array([0.47103786, 0.4997273 , 0.5       , 0.5       ], dtype=float32),)

(Array([0.46872386, 0.5024389 , 0.52695173, 0.54680884], dtype=float32),)
(Array([0.46872386, 0.5       , 0.5       , 0.5       ], dtype=float32),)

(Array([0.5026192 , 0.520077  , 0.55486447, 0.59197557], dtype=float32),)
(Array([0.5, 0.5, 0.5, 0.5], dtype=float32),)



  0%|                                               | 15/25000 [00:01<30:48, 13.51it/s, loss=1.25]

(Array([0.5146063 , 0.55175084, 0.619205  , 0.67365754], dtype=float32),)
(Array([0.5, 0.5, 0.5, 0.5], dtype=float32),)

(Array([0.54788065, 0.56706995, 0.61313033, 0.70841634], dtype=float32),)
(Array([0.5, 0.5, 0.5, 0.5], dtype=float32),)

(Array([0.5428356 , 0.594669  , 0.68257445, 0.72299576], dtype=float32),)
(Array([0.5, 0.5, 0.5, 0.5], dtype=float32),)

(Array([0.60373306, 0.6716294 , 0.71478987, 0.787097  ], dtype=float32),)
(Array([0.5, 0.5, 0.5, 0.5], dtype=float32),)

(Array([0.6191437 , 0.62117314, 0.6979269 , 0.770548  ], dtype=float32),)
(Array([0.5, 0.5, 0.5, 0.5], dtype=float32),)

(Array([0.6318585, 0.6196428, 0.6629517, 0.7311485], dtype=float32),)
(Array([0.5, 0.5, 0.5, 0.5], dtype=float32),)

(Array([0.54927385, 0.5672685 , 0.5712669 , 0.6251303 ], dtype=float32),)
(Array([0.5, 0.5, 0.5, 0.5], dtype=float32),)



  0%|                                              | 22/25000 [00:01<21:30, 19.36it/s, loss=0.795]

(Array([0.44126576, 0.438985  , 0.4654047 , 0.5423515 ], dtype=float32),)
(Array([0.44126576, 0.438985  , 0.4654047 , 0.5       ], dtype=float32),)

(Array([0.37377763, 0.3568725 , 0.364892  , 0.3903361 ], dtype=float32),)
(Array([0.37377763, 0.3568725 , 0.364892  , 0.3903361 ], dtype=float32),)

(Array([0.3751942 , 0.1909868 , 0.20857275, 0.25685218], dtype=float32),)
(Array([0.3751942 , 0.1909868 , 0.20857275, 0.25685218], dtype=float32),)

(Array([ 0.04087885, -0.07246463,  0.06491113,  0.16193897], dtype=float32),)
(Array([ 0.04087885, -0.07246463,  0.06491113,  0.16193897], dtype=float32),)

(Array([ 0.14649631,  0.08510147, -0.19676977, -0.38358098], dtype=float32),)
(Array([ 0.14649631,  0.08510147, -0.19676977, -0.38358098], dtype=float32),)

(Array([-0.22437796, -0.7571046 , -1.3025393 , -1.6616168 ], dtype=float32),)
(Array([-0.22437796, -0.5       , -0.5       , -0.5       ], dtype=float32),)



  0%|                                               | 29/25000 [00:01<17:13, 24.16it/s, loss=4.04]

(Array([ 0.17857361, -0.19539401, -1.7656782 ,  0.51891124], dtype=float32),)
(Array([ 0.17857361, -0.19539401, -0.5       ,  0.5       ], dtype=float32),)

(Array([ 0.9764075,  1.2703236, -1.5324805, -4.426737 ], dtype=float32),)
(Array([ 0.5,  0.5, -0.5, -0.5], dtype=float32),)

(Array([ 2.9351792, -2.3851118, -6.983225 , -2.2054734], dtype=float32),)
(Array([ 0.5, -0.5, -0.5, -0.5], dtype=float32),)

(Array([-1.6022667, -6.840847 , -3.158029 ,  5.241557 ], dtype=float32),)
(Array([-0.5, -0.5, -0.5,  0.5], dtype=float32),)

(Array([ 2.6946247 , -0.21407562, -4.039051  , -3.988716  ], dtype=float32),)
(Array([ 0.5       , -0.21407562, -0.5       , -0.5       ], dtype=float32),)

(Array([ -6.6058707, -11.020954 ,  -7.8426223,  -5.216161 ], dtype=float32),)
(Array([-0.5, -0.5, -0.5, -0.5], dtype=float32),)

(Array([-3.9367752e+00, -5.6867571e+00, -6.9180260e+00, -6.0614347e-03],      dtype=float32),)
(Array([-0.5       , -0.5       , -0.5       , -0.00606143], dtype=float32),)



  0%|                                               | 35/25000 [00:02<15:54, 26.15it/s, loss=1.75]

(Array([-1.8548092, -2.9615812, -3.5108354, -4.5666037], dtype=float32),)
(Array([-0.5, -0.5, -0.5, -0.5], dtype=float32),)

(Array([-0.25534657, -1.4650351 , -2.5918756 , -5.565876  ], dtype=float32),)
(Array([-0.25534657, -0.5       , -0.5       , -0.5       ], dtype=float32),)

(Array([-0.9564257, -2.3367043, -2.1024556, -0.2643596], dtype=float32),)
(Array([-0.5      , -0.5      , -0.5      , -0.2643596], dtype=float32),)

(Array([-1.6123364, -1.6707859, -2.9781399, -2.3056192], dtype=float32),)
(Array([-0.5, -0.5, -0.5, -0.5], dtype=float32),)

(Array([-1.3482414 , -1.5757707 , -1.9302169 , -0.92254287], dtype=float32),)
(Array([-0.5, -0.5, -0.5, -0.5], dtype=float32),)

(Array([ 0.03503628, -0.3881889 , -0.79508555, -0.5758077 ], dtype=float32),)
(Array([ 0.03503628, -0.3881889 , -0.5       , -0.5       ], dtype=float32),)

(Array([ 0.14224537, -0.34697354, -1.3675772 , -1.8487247 ], dtype=float32),)
(Array([ 0.14224537, -0.34697354, -0.5       , -0.5       ], dtype=float32),)



  0%|                                               | 42/25000 [00:02<15:00, 27.71it/s, loss=0.93]

(Array([-0.20155998, -0.6174903 , -1.1810323 , -0.07816149], dtype=float32),)
(Array([-0.20155998, -0.5       , -0.5       , -0.07816149], dtype=float32),)

(Array([-0.23823024, -0.7851046 , -0.2843645 ,  0.6119088 ], dtype=float32),)
(Array([-0.23823024, -0.5       , -0.2843645 ,  0.5       ], dtype=float32),)

(Array([-0.4014547 , -1.0880688 , -0.95019954, -0.5640795 ], dtype=float32),)
(Array([-0.4014547, -0.5      , -0.5      , -0.5      ], dtype=float32),)

(Array([-0.7972581 , -1.4199607 , -1.6711242 , -0.57195234], dtype=float32),)
(Array([-0.5, -0.5, -0.5, -0.5], dtype=float32),)

(Array([-0.9590737 , -1.6849728 , -0.99993455,  0.1848037 ], dtype=float32),)
(Array([-0.5      , -0.5      , -0.5      ,  0.1848037], dtype=float32),)

(Array([-1.4308153 , -0.9524628 , -0.63422453, -0.2364482 ], dtype=float32),)
(Array([-0.5      , -0.5      , -0.5      , -0.2364482], dtype=float32),)

(Array([-0.36369377, -0.56495756, -0.89431286, -0.03888795], dtype=float32),)
(Array([-0.36369377,

  0%|                                              | 50/25000 [00:02<13:24, 31.02it/s, loss=0.664]

(Array([-0.5514799, -1.8102401, -1.452806 , -0.5103133], dtype=float32),)
(Array([-0.5, -0.5, -0.5, -0.5], dtype=float32),)

(Array([-1.8260975, -1.8689654, -1.3573266, -0.3899077], dtype=float32),)
(Array([-0.5      , -0.5      , -0.5      , -0.3899077], dtype=float32),)

(Array([-1.4641662 , -1.5217766 , -0.85822535,  0.5604944 ], dtype=float32),)
(Array([-0.5, -0.5, -0.5,  0.5], dtype=float32),)

(Array([-1.1702007 , -1.0962522 , -0.45213062,  0.30284333], dtype=float32),)
(Array([-0.5       , -0.5       , -0.45213062,  0.30284333], dtype=float32),)

(Array([-1.0595938 , -0.97277975, -0.63104343, -0.067964  ], dtype=float32),)
(Array([-0.5     , -0.5     , -0.5     , -0.067964], dtype=float32),)

(Array([-1.0246134 , -1.3083924 , -1.2220699 , -0.42438936], dtype=float32),)
(Array([-0.5       , -0.5       , -0.5       , -0.42438936], dtype=float32),)

(Array([-1.124754  , -1.4490352 , -0.9784023 , -0.18774101], dtype=float32),)
(Array([-0.5       , -0.5       , -0.5       , -0.187741

  0%|                                              | 58/25000 [00:02<13:36, 30.56it/s, loss=0.497]

(Array([-1.36939   , -1.1646316 , -0.589677  ,  0.03749371], dtype=float32),)
(Array([-0.5       , -0.5       , -0.5       ,  0.03749371], dtype=float32),)

(Array([-1.057817  , -0.45747298,  0.10229653,  0.5569551 ], dtype=float32),)
(Array([-0.5       , -0.45747298,  0.10229653,  0.5       ], dtype=float32),)

(Array([-0.30139822, -0.15568393,  0.10259989,  0.32576913], dtype=float32),)
(Array([-0.30139822, -0.15568393,  0.10259989,  0.32576913], dtype=float32),)

(Array([-0.24400462, -0.09306026,  0.10922074,  0.16227472], dtype=float32),)
(Array([-0.24400462, -0.09306026,  0.10922074,  0.16227472], dtype=float32),)

(Array([-0.2688333 ,  0.01426213,  0.09226145,  0.1148192 ], dtype=float32),)
(Array([-0.2688333 ,  0.01426213,  0.09226145,  0.1148192 ], dtype=float32),)

(Array([-0.11494298, -0.08172885, -0.06416655, -0.0589228 ], dtype=float32),)
(Array([-0.11494298, -0.08172885, -0.06416655, -0.0589228 ], dtype=float32),)

(Array([-0.00232851,  0.03956948,  0.04876806,  0.05085713

  0%|                                               | 62/25000 [00:03<12:43, 32.66it/s, loss=0.43]

(Array([0.06315476, 0.09034434, 0.09481227, 0.0518526 ], dtype=float32),)
(Array([0.06315476, 0.09034434, 0.09481227, 0.0518526 ], dtype=float32),)

(Array([-0.01272944,  0.04239465,  0.03909428,  0.04434862], dtype=float32),)
(Array([-0.01272944,  0.04239465,  0.03909428,  0.04434862], dtype=float32),)

(Array([0.0197218 , 0.09476557, 0.07061   , 0.0508773 ], dtype=float32),)
(Array([0.0197218 , 0.09476557, 0.07061   , 0.0508773 ], dtype=float32),)

(Array([0.04585098, 0.05427087, 0.03776676, 0.02852542], dtype=float32),)
(Array([0.04585098, 0.05427087, 0.03776676, 0.02852542], dtype=float32),)

(Array([0.04020827, 0.06701098, 0.09022183, 0.03414785], dtype=float32),)
(Array([0.04020827, 0.06701098, 0.09022183, 0.03414785], dtype=float32),)

(Array([0.04764838, 0.07881881, 0.05152886, 0.01811745], dtype=float32),)
(Array([0.04764838, 0.07881881, 0.05152886, 0.01811745], dtype=float32),)

(Array([0.03111909, 0.08296469, 0.03894127, 0.01989247], dtype=float32),)
(Array([0.03111909, 0.08

  0%|▏                                             | 75/25000 [00:03<11:12, 37.06it/s, loss=0.463]

(Array([ 0.06173757,  0.07803267,  0.08748515, -0.05265494], dtype=float32),)
(Array([ 0.06173757,  0.07803267,  0.08748515, -0.05265494], dtype=float32),)

(Array([0.0284503 , 0.05262851, 0.077574  , 0.01701695], dtype=float32),)
(Array([0.0284503 , 0.05262851, 0.077574  , 0.01701695], dtype=float32),)

(Array([-0.00447237,  0.16463129,  0.04318792, -0.17560816], dtype=float32),)
(Array([-0.00447237,  0.16463129,  0.04318792, -0.17560816], dtype=float32),)

(Array([ 0.00411225,  0.10513968,  0.09489487, -0.01006142], dtype=float32),)
(Array([ 0.00411225,  0.10513968,  0.09489487, -0.01006142], dtype=float32),)

(Array([ 0.00169775,  0.2163945 ,  0.02888954, -0.23772036], dtype=float32),)
(Array([ 0.00169775,  0.2163945 ,  0.02888954, -0.23772036], dtype=float32),)

(Array([ 0.02842901,  0.10697287,  0.10356911, -0.03418174], dtype=float32),)
(Array([ 0.02842901,  0.10697287,  0.10356911, -0.03418174], dtype=float32),)

(Array([ 0.05494977,  0.24009222,  0.0075954 , -0.13430038], dtype

  0%|▏                                             | 83/25000 [00:03<11:25, 36.35it/s, loss=0.499]

(Array([ 0.06094846,  0.17941283,  0.0362764 , -0.08950211], dtype=float32),)
(Array([ 0.06094846,  0.17941283,  0.0362764 , -0.08950211], dtype=float32),)

(Array([-0.05117543,  0.10804068,  0.16533107, -0.10897331], dtype=float32),)
(Array([-0.05117543,  0.10804068,  0.16533107, -0.10897331], dtype=float32),)

(Array([ 0.0309626 ,  0.13196483,  0.00350883, -0.04098428], dtype=float32),)
(Array([ 0.0309626 ,  0.13196483,  0.00350883, -0.04098428], dtype=float32),)

(Array([-0.12336871,  0.17909902,  0.20013168, -0.13281178], dtype=float32),)
(Array([-0.12336871,  0.17909902,  0.20013168, -0.13281178], dtype=float32),)

(Array([ 0.03179813,  0.06397124,  0.02003829, -0.01477861], dtype=float32),)
(Array([ 0.03179813,  0.06397124,  0.02003829, -0.01477861], dtype=float32),)

(Array([-0.10477601,  0.18947464,  0.14859226, -0.166574  ], dtype=float32),)
(Array([-0.10477601,  0.18947464,  0.14859226, -0.166574  ], dtype=float32),)

(Array([-0.01592578,  0.02917702,  0.09291771,  0.01515842

  0%|▏                                             | 91/25000 [00:03<11:38, 35.66it/s, loss=0.416]

(Array([-0.04154609,  0.02553611,  0.11747816, -0.00193886], dtype=float32),)
(Array([-0.04154609,  0.02553611,  0.11747816, -0.00193886], dtype=float32),)

(Array([-0.01317213,  0.23360427,  0.05283033, -0.172587  ], dtype=float32),)
(Array([-0.01317213,  0.23360427,  0.05283033, -0.172587  ], dtype=float32),)

(Array([-0.01799677,  0.04949827,  0.07336936,  0.01206668], dtype=float32),)
(Array([-0.01799677,  0.04949827,  0.07336936,  0.01206668], dtype=float32),)

(Array([ 0.06326646,  0.14535856, -0.01087808, -0.08336184], dtype=float32),)
(Array([ 0.06326646,  0.14535856, -0.01087808, -0.08336184], dtype=float32),)

(Array([ 0.03961493,  0.07043482,  0.01815709, -0.00305931], dtype=float32),)
(Array([ 0.03961493,  0.07043482,  0.01815709, -0.00305931], dtype=float32),)

(Array([ 0.02199432,  0.14187086,  0.03583474, -0.03816827], dtype=float32),)
(Array([ 0.02199432,  0.14187086,  0.03583474, -0.03816827], dtype=float32),)

(Array([-0.0253452 ,  0.04423784,  0.05812573,  0.01703257

  0%|▏                                             | 99/25000 [00:03<11:15, 36.84it/s, loss=0.435]

(Array([0.03825404, 0.03419912, 0.03124522, 0.0050992 ], dtype=float32),)
(Array([0.03825404, 0.03419912, 0.03124522, 0.0050992 ], dtype=float32),)

(Array([ 0.05112522, -0.00919772,  0.0562861 ,  0.03240674], dtype=float32),)
(Array([ 0.05112522, -0.00919772,  0.0562861 ,  0.03240674], dtype=float32),)

(Array([ 0.05948737,  0.04107622,  0.05971609, -0.0579222 ], dtype=float32),)
(Array([ 0.05948737,  0.04107622,  0.05971609, -0.0579222 ], dtype=float32),)

(Array([ 0.01707978,  0.0323408 ,  0.02068475, -0.00735885], dtype=float32),)
(Array([ 0.01707978,  0.0323408 ,  0.02068475, -0.00735885], dtype=float32),)

(Array([ 0.00184421,  0.07409068,  0.07906868, -0.06461409], dtype=float32),)
(Array([ 0.00184421,  0.07409068,  0.07906868, -0.06461409], dtype=float32),)

(Array([ 0.02934797,  0.06790504,  0.01263546, -0.00237412], dtype=float32),)
(Array([ 0.02934797,  0.06790504,  0.01263546, -0.00237412], dtype=float32),)

(Array([-0.02359412,  0.08614893,  0.08105236, -0.04304416], dtype

  0%|▏                                            | 103/25000 [00:04<22:10, 18.71it/s, loss=0.482]

(Array([ 0.037245  ,  0.04545451,  0.01266224, -0.01491503], dtype=float32),)
(Array([ 0.037245  ,  0.04545451,  0.01266224, -0.01491503], dtype=float32),)

(Array([0.01348783, 0.05327598, 0.04744334, 0.00956526], dtype=float32),)
(Array([0.01348783, 0.05327598, 0.04744334, 0.00956526], dtype=float32),)

(Array([0.05802905, 0.06037642, 0.01861489, 0.00124061], dtype=float32),)
(Array([0.05802905, 0.06037642, 0.01861489, 0.00124061], dtype=float32),)

(Array([-0.02659946,  0.03659729,  0.07638484,  0.02600467], dtype=float32),)
(Array([-0.02659946,  0.03659729,  0.07638484,  0.02600467], dtype=float32),)



  0%|▏                                            | 112/25000 [00:04<15:54, 26.08it/s, loss=0.408]

(Array([0.00812244, 0.01597295, 0.01965786, 0.03428541], dtype=float32),)
(Array([0.00812244, 0.01597295, 0.01965786, 0.03428541], dtype=float32),)

(Array([-0.00182125,  0.01476791,  0.03675835,  0.00818595], dtype=float32),)
(Array([-0.00182125,  0.01476791,  0.03675835,  0.00818595], dtype=float32),)

(Array([0.02836894, 0.00662699, 0.02091054, 0.00095721], dtype=float32),)
(Array([0.02836894, 0.00662699, 0.02091054, 0.00095721], dtype=float32),)

(Array([0.01938152, 0.03211627, 0.03558858, 0.00454169], dtype=float32),)
(Array([0.01938152, 0.03211627, 0.03558858, 0.00454169], dtype=float32),)

(Array([0.04451832, 0.03565741, 0.01381009, 0.00757728], dtype=float32),)
(Array([0.04451832, 0.03565741, 0.01381009, 0.00757728], dtype=float32),)

(Array([0.02350478, 0.03815209, 0.01027637, 0.00112493], dtype=float32),)
(Array([0.02350478, 0.03815209, 0.01027637, 0.00112493], dtype=float32),)

(Array([0.04283221, 0.03932623, 0.02244464, 0.02456516], dtype=float32),)
(Array([0.04283221, 0.03

  0%|▏                                             | 120/25000 [00:04<13:47, 30.05it/s, loss=0.47]

(Array([0.01140428, 0.00759947, 0.03338159, 0.01452711], dtype=float32),)
(Array([0.01140428, 0.00759947, 0.03338159, 0.01452711], dtype=float32),)

(Array([0.00937873, 0.00907322, 0.00828837, 0.02585808], dtype=float32),)
(Array([0.00937873, 0.00907322, 0.00828837, 0.02585808], dtype=float32),)

(Array([0.00176238, 0.01766238, 0.02765678, 0.01061765], dtype=float32),)
(Array([0.00176238, 0.01766238, 0.02765678, 0.01061765], dtype=float32),)

(Array([0.02313406, 0.01450141, 0.02173297, 0.0058008 ], dtype=float32),)
(Array([0.02313406, 0.01450141, 0.02173297, 0.0058008 ], dtype=float32),)

(Array([ 0.00835036,  0.01310409,  0.00821836, -0.02473724], dtype=float32),)
(Array([ 0.00835036,  0.01310409,  0.00821836, -0.02473724], dtype=float32),)

(Array([ 0.02188932,  0.03714126,  0.03338588, -0.0042748 ], dtype=float32),)
(Array([ 0.02188932,  0.03714126,  0.03338588, -0.0042748 ], dtype=float32),)

(Array([ 0.02723997,  0.03444086, -0.00467656, -0.02941744], dtype=float32),)
(Array([ 0.0

  1%|▏                                            | 128/25000 [00:05<12:36, 32.87it/s, loss=0.402]

(Array([ 0.02294179,  0.00469493, -0.00947637,  0.00014955], dtype=float32),)
(Array([ 0.02294179,  0.00469493, -0.00947637,  0.00014955], dtype=float32),)

(Array([ 0.01874356,  0.04507338,  0.04353289, -0.02942711], dtype=float32),)
(Array([ 0.01874356,  0.04507338,  0.04353289, -0.02942711], dtype=float32),)

(Array([0.01992521, 0.00740991, 0.01743058, 0.00078612], dtype=float32),)
(Array([0.01992521, 0.00740991, 0.01743058, 0.00078612], dtype=float32),)

(Array([-0.02530563,  0.05439708,  0.05432124, -0.00694436], dtype=float32),)
(Array([-0.02530563,  0.05439708,  0.05432124, -0.00694436], dtype=float32),)

(Array([-0.02945273,  0.009873  ,  0.03467482, -0.0043363 ], dtype=float32),)
(Array([-0.02945273,  0.009873  ,  0.03467482, -0.0043363 ], dtype=float32),)

(Array([-0.02130864, -0.00132316,  0.02083221, -0.02170694], dtype=float32),)
(Array([-0.02130864, -0.00132316,  0.02083221, -0.02170694], dtype=float32),)

(Array([ 0.01611913,  0.04013431,  0.06898443, -0.05631943], dtype

  1%|▏                                            | 136/25000 [00:05<11:40, 35.52it/s, loss=0.404]

(Array([ 0.04325295,  0.02051018, -0.01516041,  0.01348273], dtype=float32),)
(Array([ 0.04325295,  0.02051018, -0.01516041,  0.01348273], dtype=float32),)

(Array([ 0.01792145,  0.04421653,  0.01404486, -0.01722052], dtype=float32),)
(Array([ 0.01792145,  0.04421653,  0.01404486, -0.01722052], dtype=float32),)

(Array([ 0.04229591,  0.0327012 ,  0.01742286, -0.03326852], dtype=float32),)
(Array([ 0.04229591,  0.0327012 ,  0.01742286, -0.03326852], dtype=float32),)

(Array([0.02535493, 0.02303434, 0.00269672, 0.00866772], dtype=float32),)
(Array([0.02535493, 0.02303434, 0.00269672, 0.00866772], dtype=float32),)

(Array([-0.02019953,  0.00672367,  0.03482536,  0.00147589], dtype=float32),)
(Array([-0.02019953,  0.00672367,  0.03482536,  0.00147589], dtype=float32),)

(Array([ 0.05094157,  0.01761243, -0.00645813,  0.00336959], dtype=float32),)
(Array([ 0.05094157,  0.01761243, -0.00645813,  0.00336959], dtype=float32),)

(Array([ 0.02433881,  0.00935999, -0.00375328,  0.02554352], dtype

  1%|▎                                            | 144/25000 [00:05<11:29, 36.07it/s, loss=0.373]

(Array([0.01205735, 0.01693126, 0.02547831, 0.04386278], dtype=float32),)
(Array([0.01205735, 0.01693126, 0.02547831, 0.04386278], dtype=float32),)

(Array([-0.00130969,  0.00627864,  0.00407555,  0.00393125], dtype=float32),)
(Array([-0.00130969,  0.00627864,  0.00407555,  0.00393125], dtype=float32),)

(Array([0.01255288, 0.01792648, 0.00654381, 0.01071435], dtype=float32),)
(Array([0.01255288, 0.01792648, 0.00654381, 0.01071435], dtype=float32),)

(Array([ 0.00447143,  0.0101815 ,  0.01954514, -0.00319846], dtype=float32),)
(Array([ 0.00447143,  0.0101815 ,  0.01954514, -0.00319846], dtype=float32),)

(Array([-0.01028373,  0.00875663,  0.03039633,  0.01189018], dtype=float32),)
(Array([-0.01028373,  0.00875663,  0.03039633,  0.01189018], dtype=float32),)

(Array([ 0.01295654,  0.00044274, -0.01227301,  0.00558646], dtype=float32),)
(Array([ 0.01295654,  0.00044274, -0.01227301,  0.00558646], dtype=float32),)

(Array([0.01006397, 0.01059358, 0.01454841, 0.0028788 ], dtype=float32),)


  1%|▎                                            | 152/25000 [00:05<11:17, 36.70it/s, loss=0.371]

(Array([ 0.02847956,  0.01959155, -0.01331031, -0.00065159], dtype=float32),)
(Array([ 0.02847956,  0.01959155, -0.01331031, -0.00065159], dtype=float32),)

(Array([ 0.01240742, -0.01484431,  0.00632382, -0.03087572], dtype=float32),)
(Array([ 0.01240742, -0.01484431,  0.00632382, -0.03087572], dtype=float32),)

(Array([ 0.02745897,  0.00607721, -0.05752735, -0.04641236], dtype=float32),)
(Array([ 0.02745897,  0.00607721, -0.05752735, -0.04641236], dtype=float32),)

(Array([ 0.00790554,  0.01997384,  0.02535626, -0.03322611], dtype=float32),)
(Array([ 0.00790554,  0.01997384,  0.02535626, -0.03322611], dtype=float32),)

(Array([ 0.04569816,  0.03523612, -0.02302476, -0.05890051], dtype=float32),)
(Array([ 0.04569816,  0.03523612, -0.02302476, -0.05890051], dtype=float32),)

(Array([-0.03122098,  0.01696814,  0.0302664 , -0.00031512], dtype=float32),)
(Array([-0.03122098,  0.01696814,  0.0302664 , -0.00031512], dtype=float32),)

(Array([ 0.04095746,  0.00298963, -0.01465735, -0.01296955

  1%|▎                                            | 160/25000 [00:05<11:26, 36.18it/s, loss=0.322]

(Array([ 0.01412989,  0.01271217, -0.01248401, -0.02439548], dtype=float32),)
(Array([ 0.01412989,  0.01271217, -0.01248401, -0.02439548], dtype=float32),)

(Array([0.03262875, 0.01513032, 0.0322378 , 0.00639246], dtype=float32),)
(Array([0.03262875, 0.01513032, 0.0322378 , 0.00639246], dtype=float32),)

(Array([-0.00478004, -0.00385507,  0.05232198, -0.01002537], dtype=float32),)
(Array([-0.00478004, -0.00385507,  0.05232198, -0.01002537], dtype=float32),)

(Array([ 0.01345103, -0.02008298,  0.00627412, -0.05002826], dtype=float32),)
(Array([ 0.01345103, -0.02008298,  0.00627412, -0.05002826], dtype=float32),)

(Array([-0.06025014,  0.0103232 ,  0.02339523, -0.08904558], dtype=float32),)
(Array([-0.06025014,  0.0103232 ,  0.02339523, -0.08904558], dtype=float32),)

(Array([ 0.0015104 ,  0.05253044,  0.02581017, -0.03072437], dtype=float32),)
(Array([ 0.0015104 ,  0.05253044,  0.02581017, -0.03072437], dtype=float32),)

(Array([-0.01066728,  0.02211554,  0.02529704, -0.02139613], dtype

  1%|▎                                            | 164/25000 [00:06<11:35, 35.73it/s, loss=0.356]

(Array([-0.05613472,  0.02943935,  0.04955605,  0.00275423], dtype=float32),)
(Array([-0.05613472,  0.02943935,  0.04955605,  0.00275423], dtype=float32),)

(Array([-0.02716395,  0.02821107,  0.02254865, -0.00045311], dtype=float32),)
(Array([-0.02716395,  0.02821107,  0.02254865, -0.00045311], dtype=float32),)

(Array([ 0.00071979,  0.00199423,  0.01197052, -0.00381877], dtype=float32),)
(Array([ 0.00071979,  0.00199423,  0.01197052, -0.00381877], dtype=float32),)

(Array([-0.01033419, -0.02677033,  0.01321635, -0.00265364], dtype=float32),)
(Array([-0.01033419, -0.02677033,  0.01321635, -0.00265364], dtype=float32),)

(Array([-0.0017853 , -0.00825633, -0.00058114, -0.00029801], dtype=float32),)
(Array([-0.0017853 , -0.00825633, -0.00058114, -0.00029801], dtype=float32),)

(Array([ 0.00247103, -0.0406956 ,  0.00042041, -0.01052508], dtype=float32),)
(Array([ 0.00247103, -0.0406956 ,  0.00042041, -0.01052508], dtype=float32),)

(Array([ 0.00012211,  0.01028742, -0.01485624,  0.00549284

  1%|▎                                            | 172/25000 [00:06<11:50, 34.96it/s, loss=0.382]

(Array([ 0.01084576,  0.00965827, -0.00263872, -0.01657926], dtype=float32),)
(Array([ 0.01084576,  0.00965827, -0.00263872, -0.01657926], dtype=float32),)

(Array([ 0.0037045 , -0.02555285,  0.01559096,  0.03087796], dtype=float32),)
(Array([ 0.0037045 , -0.02555285,  0.01559096,  0.03087796], dtype=float32),)

(Array([3.9915638e-03, 3.0763557e-03, 3.2864602e-03, 4.4271932e-05],      dtype=float32),)
(Array([3.9915638e-03, 3.0763557e-03, 3.2864602e-03, 4.4271932e-05],      dtype=float32),)

(Array([-0.00291543, -0.01771883,  0.03379628, -0.02041735], dtype=float32),)
(Array([-0.00291543, -0.01771883,  0.03379628, -0.02041735], dtype=float32),)

(Array([-0.00466188,  0.02029776, -0.02593433, -0.06016943], dtype=float32),)
(Array([-0.00466188,  0.02029776, -0.02593433, -0.06016943], dtype=float32),)

(Array([-0.01367094,  0.00422182,  0.01360309, -0.03254407], dtype=float32),)
(Array([-0.01367094,  0.00422182,  0.01360309, -0.03254407], dtype=float32),)

(Array([-0.02608663,  0.00435047

  1%|▎                                            | 180/25000 [00:06<12:03, 34.32it/s, loss=0.342]

(Array([-0.05825104, -0.0376239 ,  0.05236513,  0.00429942], dtype=float32),)
(Array([-0.05825104, -0.0376239 ,  0.05236513,  0.00429942], dtype=float32),)

(Array([ 0.04378783, -0.02832589,  0.00840962, -0.01069267], dtype=float32),)
(Array([ 0.04378783, -0.02832589,  0.00840962, -0.01069267], dtype=float32),)

(Array([ 0.00283398,  0.01121383, -0.00773481, -0.03553664], dtype=float32),)
(Array([ 0.00283398,  0.01121383, -0.00773481, -0.03553664], dtype=float32),)

(Array([-0.02060121,  0.01625388, -0.00306558, -0.01404216], dtype=float32),)
(Array([-0.02060121,  0.01625388, -0.00306558, -0.01404216], dtype=float32),)

(Array([-0.01292383, -0.01391675, -0.00218424,  0.00200314], dtype=float32),)
(Array([-0.01292383, -0.01391675, -0.00218424,  0.00200314], dtype=float32),)

(Array([ 0.07534817,  0.00480089, -0.01383224, -0.02830691], dtype=float32),)
(Array([ 0.07534817,  0.00480089, -0.01383224, -0.02830691], dtype=float32),)

(Array([-0.01565504, -0.02313818,  0.0067764 ,  0.01400927

  1%|▎                                            | 188/25000 [00:06<11:38, 35.53it/s, loss=0.321]

(Array([ 0.00739407,  0.03043217,  0.02924934, -0.03437103], dtype=float32),)
(Array([ 0.00739407,  0.03043217,  0.02924934, -0.03437103], dtype=float32),)

(Array([-0.00486333, -0.01046973, -0.0053041 ,  0.00056651], dtype=float32),)
(Array([-0.00486333, -0.01046973, -0.0053041 ,  0.00056651], dtype=float32),)

(Array([-0.00109229,  0.01925346,  0.00429221, -0.01975241], dtype=float32),)
(Array([-0.00109229,  0.01925346,  0.00429221, -0.01975241], dtype=float32),)

(Array([-0.01477617, -0.02012207,  0.0059223 ,  0.0152525 ], dtype=float32),)
(Array([-0.01477617, -0.02012207,  0.0059223 ,  0.0152525 ], dtype=float32),)

(Array([ 0.02066464, -0.00277966, -0.01791223,  0.0087889 ], dtype=float32),)
(Array([ 0.02066464, -0.00277966, -0.01791223,  0.0087889 ], dtype=float32),)

(Array([0.03239708, 0.01773723, 0.00051603, 0.00264784], dtype=float32),)
(Array([0.03239708, 0.01773723, 0.00051603, 0.00264784], dtype=float32),)

(Array([ 0.01441086,  0.01630066,  0.02086461, -0.00789335], dtype

  1%|▎                                            | 196/25000 [00:06<11:04, 37.34it/s, loss=0.306]

(Array([-0.00520729,  0.02167954, -0.00250411, -0.0075949 ], dtype=float32),)
(Array([-0.00520729,  0.02167954, -0.00250411, -0.0075949 ], dtype=float32),)

(Array([-0.01274579, -0.02870417,  0.00407374,  0.0283695 ], dtype=float32),)
(Array([-0.01274579, -0.02870417,  0.00407374,  0.0283695 ], dtype=float32),)

(Array([-0.03045315, -0.01527232,  0.03094282,  0.04360756], dtype=float32),)
(Array([-0.03045315, -0.01527232,  0.03094282,  0.04360756], dtype=float32),)

(Array([ 0.00953567, -0.0262489 , -0.00487518, -0.04519109], dtype=float32),)
(Array([ 0.00953567, -0.0262489 , -0.00487518, -0.04519109], dtype=float32),)

(Array([-0.02178628, -0.025028  , -0.00862117, -0.04319485], dtype=float32),)
(Array([-0.02178628, -0.025028  , -0.00862117, -0.04319485], dtype=float32),)

(Array([ 0.02156236, -0.00394671,  0.01551728,  0.0073252 ], dtype=float32),)
(Array([ 0.02156236, -0.00394671,  0.01551728,  0.0073252 ], dtype=float32),)

(Array([ 0.01754138,  0.00667944, -0.01762026, -0.00822134

  1%|▎                                            | 200/25000 [00:07<14:48, 27.92it/s, loss=0.299]

(Array([-0.01655638, -0.01027978, -0.00677114,  0.00723043], dtype=float32),)
(Array([-0.01655638, -0.01027978, -0.00677114,  0.00723043], dtype=float32),)

(Array([-0.0361855 ,  0.00556207,  0.01070607, -0.00286819], dtype=float32),)
(Array([-0.0361855 ,  0.00556207,  0.01070607, -0.00286819], dtype=float32),)

(Array([ 0.02043517, -0.03088197, -0.02782365,  0.00322442], dtype=float32),)
(Array([ 0.02043517, -0.03088197, -0.02782365,  0.00322442], dtype=float32),)



  1%|▎                                            | 208/25000 [00:07<12:58, 31.85it/s, loss=0.275]

(Array([-0.00620568, -0.03344737, -0.02274173, -0.04531056], dtype=float32),)
(Array([-0.00620568, -0.03344737, -0.02274173, -0.04531056], dtype=float32),)

(Array([ 0.01558117, -0.00042129, -0.01680222, -0.01846076], dtype=float32),)
(Array([ 0.01558117, -0.00042129, -0.01680222, -0.01846076], dtype=float32),)

(Array([-0.01261512, -0.02727126, -0.01444248,  0.03434306], dtype=float32),)
(Array([-0.01261512, -0.02727126, -0.01444248,  0.03434306], dtype=float32),)

(Array([ 0.00257787, -0.00517062, -0.00926578, -0.00274538], dtype=float32),)
(Array([ 0.00257787, -0.00517062, -0.00926578, -0.00274538], dtype=float32),)

(Array([ 0.03740507, -0.07034831, -0.05677281, -0.08653496], dtype=float32),)
(Array([ 0.03740507, -0.07034831, -0.05677281, -0.08653496], dtype=float32),)

(Array([-0.03946202, -0.01857124, -0.02615815,  0.01398953], dtype=float32),)
(Array([-0.03946202, -0.01857124, -0.02615815,  0.01398953], dtype=float32),)

(Array([ 0.03205126,  0.02834951, -0.02188337, -0.05345973

  1%|▍                                            | 216/25000 [00:07<12:00, 34.39it/s, loss=0.324]

(Array([ 0.01191242, -0.0085957 , -0.01729689,  0.01789567], dtype=float32),)
(Array([ 0.01191242, -0.0085957 , -0.01729689,  0.01789567], dtype=float32),)

(Array([-0.0183435 ,  0.0068002 , -0.02147951, -0.04989855], dtype=float32),)
(Array([-0.0183435 ,  0.0068002 , -0.02147951, -0.04989855], dtype=float32),)

(Array([-0.02993506,  0.05221108,  0.05366007,  0.02041424], dtype=float32),)
(Array([-0.02993506,  0.05221108,  0.05366007,  0.02041424], dtype=float32),)

(Array([-0.02371725, -0.04135685, -0.00058968,  0.08369254], dtype=float32),)
(Array([-0.02371725, -0.04135685, -0.00058968,  0.08369254], dtype=float32),)

(Array([-0.04494043, -0.01847875,  0.0354197 , -0.01577751], dtype=float32),)
(Array([-0.04494043, -0.01847875,  0.0354197 , -0.01577751], dtype=float32),)

(Array([ 0.003849  , -0.07005119,  0.00562123, -0.03997434], dtype=float32),)
(Array([ 0.003849  , -0.07005119,  0.00562123, -0.03997434], dtype=float32),)

(Array([-0.02371205,  0.01209054, -0.00538641, -0.01656658

  1%|▍                                              | 224/25000 [00:07<11:40, 35.35it/s, loss=0.3]

(Array([ 0.03075469,  0.02983472, -0.02231256, -0.01959957], dtype=float32),)
(Array([ 0.03075469,  0.02983472, -0.02231256, -0.01959957], dtype=float32),)

(Array([ 0.02159525, -0.00289273, -0.01456275, -0.00625544], dtype=float32),)
(Array([ 0.02159525, -0.00289273, -0.01456275, -0.00625544], dtype=float32),)

(Array([-0.00293518, -0.00550898, -0.02414544, -0.00024506], dtype=float32),)
(Array([-0.00293518, -0.00550898, -0.02414544, -0.00024506], dtype=float32),)

(Array([ 0.01677223,  0.04630069, -0.00224056, -0.02490811], dtype=float32),)
(Array([ 0.01677223,  0.04630069, -0.00224056, -0.02490811], dtype=float32),)

(Array([ 0.0164578 ,  0.00135271, -0.01604562, -0.00133365], dtype=float32),)
(Array([ 0.0164578 ,  0.00135271, -0.01604562, -0.00133365], dtype=float32),)

(Array([ 0.01182535,  0.02412796,  0.01374552, -0.01556176], dtype=float32),)
(Array([ 0.01182535,  0.02412796,  0.01374552, -0.01556176], dtype=float32),)

(Array([ 0.03059523,  0.0109677 , -0.02235654, -0.00977184

  1%|▍                                            | 232/25000 [00:08<11:49, 34.91it/s, loss=0.282]

(Array([-0.00418156,  0.01422876,  0.01116344, -0.0117943 ], dtype=float32),)
(Array([-0.00418156,  0.01422876,  0.01116344, -0.0117943 ], dtype=float32),)

(Array([ 0.00174768,  0.01206258,  0.01926392, -0.00391965], dtype=float32),)
(Array([ 0.00174768,  0.01206258,  0.01926392, -0.00391965], dtype=float32),)

(Array([ 0.02587368,  0.02632316, -0.00782861, -0.0283472 ], dtype=float32),)
(Array([ 0.02587368,  0.02632316, -0.00782861, -0.0283472 ], dtype=float32),)

(Array([-0.00850113,  0.01310399,  0.02399201, -0.01474383], dtype=float32),)
(Array([-0.00850113,  0.01310399,  0.02399201, -0.01474383], dtype=float32),)

(Array([ 0.00887883,  0.01017972, -0.00241636, -0.01795115], dtype=float32),)
(Array([ 0.00887883,  0.01017972, -0.00241636, -0.01795115], dtype=float32),)

(Array([0.03318997, 0.00223088, 0.00044338, 0.01486807], dtype=float32),)
(Array([0.03318997, 0.00223088, 0.00044338, 0.01486807], dtype=float32),)

(Array([-0.00446421, -0.01238414,  0.01070333,  0.01218306], dtype

  1%|▍                                            | 240/25000 [00:08<11:27, 36.03it/s, loss=0.333]

(Array([-0.00974458,  0.00556266,  0.00471205,  0.00915562], dtype=float32),)
(Array([-0.00974458,  0.00556266,  0.00471205,  0.00915562], dtype=float32),)

(Array([-0.0185463 , -0.02519409, -0.01762133,  0.00643789], dtype=float32),)
(Array([-0.0185463 , -0.02519409, -0.01762133,  0.00643789], dtype=float32),)

(Array([ 0.00974915,  0.00493622, -0.00502105, -0.00053744], dtype=float32),)
(Array([ 0.00974915,  0.00493622, -0.00502105, -0.00053744], dtype=float32),)

(Array([ 0.01737641, -0.00201567, -0.00364776, -0.02897724], dtype=float32),)
(Array([ 0.01737641, -0.00201567, -0.00364776, -0.02897724], dtype=float32),)

(Array([ 0.00793461, -0.00175509,  0.01006106,  0.0111961 ], dtype=float32),)
(Array([ 0.00793461, -0.00175509,  0.01006106,  0.0111961 ], dtype=float32),)

(Array([-0.00078969, -0.00214724,  0.0089338 , -0.01397507], dtype=float32),)
(Array([-0.00078969, -0.00214724,  0.0089338 , -0.01397507], dtype=float32),)

(Array([ 0.00626166, -0.02118837, -0.00082466, -0.01951103

  1%|▍                                            | 248/25000 [00:08<11:10, 36.93it/s, loss=0.281]

(Array([ 0.00152663, -0.00261915,  0.00689335, -0.01814926], dtype=float32),)
(Array([ 0.00152663, -0.00261915,  0.00689335, -0.01814926], dtype=float32),)

(Array([ 0.01371676,  0.02239313, -0.0240428 , -0.02864706], dtype=float32),)
(Array([ 0.01371676,  0.02239313, -0.0240428 , -0.02864706], dtype=float32),)

(Array([ 0.00326898, -0.00459318,  0.00557662,  0.00441295], dtype=float32),)
(Array([ 0.00326898, -0.00459318,  0.00557662,  0.00441295], dtype=float32),)

(Array([-0.00814889, -0.01343534,  0.00580222,  0.01406319], dtype=float32),)
(Array([-0.00814889, -0.01343534,  0.00580222,  0.01406319], dtype=float32),)

(Array([-0.00471444,  0.00917851,  0.00670929, -0.02318126], dtype=float32),)
(Array([-0.00471444,  0.00917851,  0.00670929, -0.02318126], dtype=float32),)

(Array([-0.01374902, -0.01007558, -0.01843717, -0.02553248], dtype=float32),)
(Array([-0.01374902, -0.01007558, -0.01843717, -0.02553248], dtype=float32),)

(Array([-0.00357195,  0.00934329,  0.02089019, -0.01201136

  1%|▍                                            | 256/25000 [00:08<10:55, 37.77it/s, loss=0.299]

(Array([-0.01819409, -0.00584883,  0.05890614, -0.03796316], dtype=float32),)
(Array([-0.01819409, -0.00584883,  0.05890614, -0.03796316], dtype=float32),)

(Array([ 0.00738543, -0.02661771,  0.03579157,  0.00052224], dtype=float32),)
(Array([ 0.00738543, -0.02661771,  0.03579157,  0.00052224], dtype=float32),)

(Array([ 0.0054155 , -0.01138908,  0.01047421,  0.03701453], dtype=float32),)
(Array([ 0.0054155 , -0.01138908,  0.01047421,  0.03701453], dtype=float32),)

(Array([-0.00683602, -0.02491909, -0.00142351, -0.01752304], dtype=float32),)
(Array([-0.00683602, -0.02491909, -0.00142351, -0.01752304], dtype=float32),)

(Array([ 0.01611957, -0.07862701,  0.01610375, -0.0317023 ], dtype=float32),)
(Array([ 0.01611957, -0.07862701,  0.01610375, -0.0317023 ], dtype=float32),)

(Array([ 0.02263166, -0.00488809, -0.01877541, -0.00138534], dtype=float32),)
(Array([ 0.02263166, -0.00488809, -0.01877541, -0.00138534], dtype=float32),)

(Array([ 0.01645084,  0.01267621, -0.00472615,  0.01484532

  1%|▍                                              | 264/25000 [00:08<10:47, 38.18it/s, loss=0.3]

(Array([ 0.00313753, -0.01673514,  0.00678889, -0.0299629 ], dtype=float32),)
(Array([ 0.00313753, -0.01673514,  0.00678889, -0.0299629 ], dtype=float32),)

(Array([ 0.0220079 ,  0.00385193,  0.00471511, -0.00297973], dtype=float32),)
(Array([ 0.0220079 ,  0.00385193,  0.00471511, -0.00297973], dtype=float32),)

(Array([-0.00775318, -0.03149482, -0.01088701,  0.0163438 ], dtype=float32),)
(Array([-0.00775318, -0.03149482, -0.01088701,  0.0163438 ], dtype=float32),)

(Array([-0.00921036, -0.02516146, -0.00519696, -0.00289827], dtype=float32),)
(Array([-0.00921036, -0.02516146, -0.00519696, -0.00289827], dtype=float32),)

(Array([ 0.0079366 , -0.00518369,  0.01091971,  0.00304468], dtype=float32),)
(Array([ 0.0079366 , -0.00518369,  0.01091971,  0.00304468], dtype=float32),)

(Array([0.01123819, 0.00049234, 0.02908169, 0.02472462], dtype=float32),)
(Array([0.01123819, 0.00049234, 0.02908169, 0.02472462], dtype=float32),)

(Array([-0.02546947, -0.02655569, -0.02502274, -0.01914463], dtype

  1%|▌                                             | 272/25000 [00:09<11:09, 36.92it/s, loss=0.29]

(Array([-0.01442153, -0.01648508, -0.00850379,  0.02190433], dtype=float32),)
(Array([-0.01442153, -0.01648508, -0.00850379,  0.02190433], dtype=float32),)

(Array([ 0.05157037,  0.00596649, -0.04504896, -0.06559294], dtype=float32),)
(Array([ 0.05157037,  0.00596649, -0.04504896, -0.06559294], dtype=float32),)

(Array([ 0.02387185,  0.00927416, -0.00572353, -0.01278138], dtype=float32),)
(Array([ 0.02387185,  0.00927416, -0.00572353, -0.01278138], dtype=float32),)

(Array([-0.02364005,  0.01696377, -0.00862619,  0.01824309], dtype=float32),)
(Array([-0.02364005,  0.01696377, -0.00862619,  0.01824309], dtype=float32),)

(Array([-0.01147209, -0.00419588,  0.01233939,  0.00459562], dtype=float32),)
(Array([-0.01147209, -0.00419588,  0.01233939,  0.00459562], dtype=float32),)

(Array([ 0.0017709 , -0.01713502, -0.00660535, -0.00981248], dtype=float32),)
(Array([ 0.0017709 , -0.01713502, -0.00660535, -0.00981248], dtype=float32),)

(Array([ 0.00448311,  0.01150604, -0.00067428,  0.00768341

  1%|▌                                            | 280/25000 [00:09<11:22, 36.20it/s, loss=0.295]

(Array([-0.02089249, -0.02467621,  0.0076208 , -0.00787594], dtype=float32),)
(Array([-0.02089249, -0.02467621,  0.0076208 , -0.00787594], dtype=float32),)

(Array([-0.00490955,  0.01269809,  0.0010569 ,  0.00886001], dtype=float32),)
(Array([-0.00490955,  0.01269809,  0.0010569 ,  0.00886001], dtype=float32),)

(Array([-0.00171369,  0.01032074,  0.00435266,  0.00255416], dtype=float32),)
(Array([-0.00171369,  0.01032074,  0.00435266,  0.00255416], dtype=float32),)

(Array([-0.00745234,  0.00562512, -0.02216074,  0.00642467], dtype=float32),)
(Array([-0.00745234,  0.00562512, -0.02216074,  0.00642467], dtype=float32),)

(Array([-0.01039844, -0.01130211, -0.01177885,  0.02680226], dtype=float32),)
(Array([-0.01039844, -0.01130211, -0.01177885,  0.02680226], dtype=float32),)

(Array([ 0.01154139,  0.00719302, -0.00808685, -0.0152988 ], dtype=float32),)
(Array([ 0.01154139,  0.00719302, -0.00808685, -0.0152988 ], dtype=float32),)

(Array([-0.00959389,  0.01204425,  0.02306409,  0.0097778 

  1%|▌                                             | 288/25000 [00:09<11:11, 36.78it/s, loss=0.28]

(Array([ 0.03161713, -0.04113017,  0.03895163,  0.01807322], dtype=float32),)
(Array([ 0.03161713, -0.04113017,  0.03895163,  0.01807322], dtype=float32),)

(Array([ 0.00843006, -0.01842071, -0.00828693,  0.00950004], dtype=float32),)
(Array([ 0.00843006, -0.01842071, -0.00828693,  0.00950004], dtype=float32),)

(Array([-0.01840535,  0.00438943,  0.01059456, -0.02480782], dtype=float32),)
(Array([-0.01840535,  0.00438943,  0.01059456, -0.02480782], dtype=float32),)

(Array([ 0.00525141,  0.01150358, -0.04643721,  0.00664917], dtype=float32),)
(Array([ 0.00525141,  0.01150358, -0.04643721,  0.00664917], dtype=float32),)

(Array([-0.04251546,  0.07460679,  0.02359704, -0.02839182], dtype=float32),)
(Array([-0.04251546,  0.07460679,  0.02359704, -0.02839182], dtype=float32),)

(Array([-0.06142074, -0.04737058,  0.01316664,  0.02237551], dtype=float32),)
(Array([-0.06142074, -0.04737058,  0.01316664,  0.02237551], dtype=float32),)

(Array([ 0.0050436 , -0.00319712, -0.00382701, -0.00724481

  1%|▌                                            | 296/25000 [00:09<10:52, 37.87it/s, loss=0.268]

(Array([ 0.00016411, -0.02568041,  0.01984005,  0.01877444], dtype=float32),)
(Array([ 0.00016411, -0.02568041,  0.01984005,  0.01877444], dtype=float32),)

(Array([ 0.0115875 ,  0.00800242,  0.00317998, -0.00625664], dtype=float32),)
(Array([ 0.0115875 ,  0.00800242,  0.00317998, -0.00625664], dtype=float32),)

(Array([-0.00260729, -0.00301904,  0.00025966, -0.0248655 ], dtype=float32),)
(Array([-0.00260729, -0.00301904,  0.00025966, -0.0248655 ], dtype=float32),)

(Array([ 0.00429892, -0.01481764, -0.02195029, -0.0036426 ], dtype=float32),)
(Array([ 0.00429892, -0.01481764, -0.02195029, -0.0036426 ], dtype=float32),)

(Array([-0.01936794,  0.00442141,  0.00353664,  0.0061649 ], dtype=float32),)
(Array([-0.01936794,  0.00442141,  0.00353664,  0.0061649 ], dtype=float32),)

(Array([ 0.00672681,  0.00170354,  0.00383245, -0.00903918], dtype=float32),)
(Array([ 0.00672681,  0.00170354,  0.00383245, -0.00903918], dtype=float32),)

(Array([ 0.00720639, -0.01684771,  0.00527823, -0.01739549

  1%|▌                                            | 300/25000 [00:10<14:16, 28.85it/s, loss=0.261]

(Array([ 0.01541083,  0.01400772, -0.00667902, -0.03224083], dtype=float32),)
(Array([ 0.01541083,  0.01400772, -0.00667902, -0.03224083], dtype=float32),)

(Array([-0.01167191,  0.00870248,  0.03103659, -0.00094307], dtype=float32),)
(Array([-0.01167191,  0.00870248,  0.03103659, -0.00094307], dtype=float32),)

(Array([-0.02483102,  0.04181327,  0.01859936, -0.02288404], dtype=float32),)
(Array([-0.02483102,  0.04181327,  0.01859936, -0.02288404], dtype=float32),)

(Array([-0.02107096, -0.00828406,  0.0200906 , -0.00319632], dtype=float32),)
(Array([-0.02107096, -0.00828406,  0.0200906 , -0.00319632], dtype=float32),)



  1%|▌                                            | 309/25000 [00:10<12:41, 32.43it/s, loss=0.262]

(Array([ 0.00453013,  0.01137989, -0.00188692, -0.03357133], dtype=float32),)
(Array([ 0.00453013,  0.01137989, -0.00188692, -0.03357133], dtype=float32),)

(Array([ 0.00378246,  0.00328512,  0.02819441, -0.00049572], dtype=float32),)
(Array([ 0.00378246,  0.00328512,  0.02819441, -0.00049572], dtype=float32),)

(Array([ 0.0013815 ,  0.00039448, -0.01041334,  0.00417848], dtype=float32),)
(Array([ 0.0013815 ,  0.00039448, -0.01041334,  0.00417848], dtype=float32),)

(Array([-0.0066003 ,  0.01754826,  0.00686753, -0.00442069], dtype=float32),)
(Array([-0.0066003 ,  0.01754826,  0.00686753, -0.00442069], dtype=float32),)

(Array([ 0.01898921, -0.00032446, -0.00450764,  0.00569108], dtype=float32),)
(Array([ 0.01898921, -0.00032446, -0.00450764,  0.00569108], dtype=float32),)

(Array([-0.01472   , -0.00362152, -0.00946505, -0.00677882], dtype=float32),)
(Array([-0.01472   , -0.00362152, -0.00946505, -0.00677882], dtype=float32),)

(Array([ 0.01496891,  0.0039848 ,  0.01197489, -0.02804929

  1%|▌                                            | 313/25000 [00:10<12:00, 34.26it/s, loss=0.257]

(Array([ 0.00981424, -0.0202113 , -0.00427897,  0.02024966], dtype=float32),)
(Array([ 0.00981424, -0.0202113 , -0.00427897,  0.02024966], dtype=float32),)

(Array([ 0.01231542, -0.0043438 ,  0.01194779, -0.03063349], dtype=float32),)
(Array([ 0.01231542, -0.0043438 ,  0.01194779, -0.03063349], dtype=float32),)

(Array([ 0.01836622, -0.04077922,  0.00030241, -0.0128588 ], dtype=float32),)
(Array([ 0.01836622, -0.04077922,  0.00030241, -0.0128588 ], dtype=float32),)

(Array([-0.01286154,  0.00215467, -0.02675651, -0.01071597], dtype=float32),)
(Array([-0.01286154,  0.00215467, -0.02675651, -0.01071597], dtype=float32),)

(Array([ 0.00591327, -0.00814157,  0.01358914, -0.00746913], dtype=float32),)
(Array([ 0.00591327, -0.00814157,  0.01358914, -0.00746913], dtype=float32),)

(Array([-0.02489516, -0.00316607,  0.0038718 ,  0.02201012], dtype=float32),)
(Array([-0.02489516, -0.00316607,  0.0038718 ,  0.02201012], dtype=float32),)

(Array([-0.03483398, -0.03984204, -0.01375951,  0.01010742

  1%|▌                                            | 321/25000 [00:10<12:30, 32.87it/s, loss=0.297]

(Array([-0.03711651, -0.02654606, -0.02187706, -0.02732895], dtype=float32),)
(Array([-0.03711651, -0.02654606, -0.02187706, -0.02732895], dtype=float32),)

(Array([-0.00788818, -0.01528888, -0.01161627,  0.00723328], dtype=float32),)
(Array([-0.00788818, -0.01528888, -0.01161627,  0.00723328], dtype=float32),)

(Array([ 0.01851318,  0.00431244, -0.01016842, -0.0036223 ], dtype=float32),)
(Array([ 0.01851318,  0.00431244, -0.01016842, -0.0036223 ], dtype=float32),)

(Array([-0.03067156, -0.00457546, -0.02381047,  0.01819757], dtype=float32),)
(Array([-0.03067156, -0.00457546, -0.02381047,  0.01819757], dtype=float32),)

(Array([ 0.01430678,  0.02369994,  0.00357326, -0.0284356 ], dtype=float32),)
(Array([ 0.01430678,  0.02369994,  0.00357326, -0.0284356 ], dtype=float32),)

(Array([-0.00569995, -0.00292511, -0.01260672, -0.00976049], dtype=float32),)
(Array([-0.00569995, -0.00292511, -0.01260672, -0.00976049], dtype=float32),)

(Array([ 0.0108144 , -0.00552718,  0.00479425,  0.00817291

  1%|▌                                             | 329/25000 [00:10<12:03, 34.08it/s, loss=0.26]

(Array([-0.00069697, -0.02196881, -0.01366266,  0.00063399], dtype=float32),)
(Array([-0.00069697, -0.02196881, -0.01366266,  0.00063399], dtype=float32),)

(Array([-0.00657834, -0.00895426,  0.00065846, -0.01424834], dtype=float32),)
(Array([-0.00657834, -0.00895426,  0.00065846, -0.01424834], dtype=float32),)

(Array([-0.01820242, -0.01097167, -0.00584834,  0.01626286], dtype=float32),)
(Array([-0.01820242, -0.01097167, -0.00584834,  0.01626286], dtype=float32),)

(Array([ 0.00438309,  0.01458078,  0.01742935, -0.00203452], dtype=float32),)
(Array([ 0.00438309,  0.01458078,  0.01742935, -0.00203452], dtype=float32),)

(Array([-0.01144928, -0.00325316,  0.01580358,  0.01770778], dtype=float32),)
(Array([-0.01144928, -0.00325316,  0.01580358,  0.01770778], dtype=float32),)

(Array([ 0.00036352,  0.00838869,  0.00944359, -0.0030586 ], dtype=float32),)
(Array([ 0.00036352,  0.00838869,  0.00944359, -0.0030586 ], dtype=float32),)

(Array([-0.01118598, -0.01234133,  0.00373405, -0.01281322

  1%|▌                                            | 337/25000 [00:11<11:22, 36.14it/s, loss=0.297]

(Array([-0.00241819, -0.00518983,  0.004244  ,  0.0175852 ], dtype=float32),)
(Array([-0.00241819, -0.00518983,  0.004244  ,  0.0175852 ], dtype=float32),)

(Array([-0.00289615, -0.00051753,  0.01586543, -0.01485808], dtype=float32),)
(Array([-0.00289615, -0.00051753,  0.01586543, -0.01485808], dtype=float32),)

(Array([-0.00299281, -0.00092192,  0.01817265, -0.00863467], dtype=float32),)
(Array([-0.00299281, -0.00092192,  0.01817265, -0.00863467], dtype=float32),)

(Array([-0.02158586, -0.0068625 , -0.00652875, -0.00305022], dtype=float32),)
(Array([-0.02158586, -0.0068625 , -0.00652875, -0.00305022], dtype=float32),)

(Array([ 0.0071645 ,  0.00292129, -0.00071698, -0.00347322], dtype=float32),)
(Array([ 0.0071645 ,  0.00292129, -0.00071698, -0.00347322], dtype=float32),)

(Array([ 0.00258986, -0.00466863,  0.01233749, -0.01700708], dtype=float32),)
(Array([ 0.00258986, -0.00466863,  0.01233749, -0.01700708], dtype=float32),)

(Array([ 0.02194615, -0.01252232,  0.00174582, -0.00739105

  1%|▌                                            | 346/25000 [00:11<10:47, 38.09it/s, loss=0.227]

(Array([-0.00609068,  0.0204443 ,  0.02023922,  0.00097817], dtype=float32),)
(Array([-0.00609068,  0.0204443 ,  0.02023922,  0.00097817], dtype=float32),)

(Array([0.00162151, 0.00361584, 0.0289315 , 0.00424692], dtype=float32),)
(Array([0.00162151, 0.00361584, 0.0289315 , 0.00424692], dtype=float32),)

(Array([-0.01731557, -0.02670115,  0.01285428,  0.00037753], dtype=float32),)
(Array([-0.01731557, -0.02670115,  0.01285428,  0.00037753], dtype=float32),)

(Array([-0.0081867 , -0.02656835, -0.00380736, -0.00178332], dtype=float32),)
(Array([-0.0081867 , -0.02656835, -0.00380736, -0.00178332], dtype=float32),)

(Array([-0.00584748,  0.00117937,  0.02623443, -0.01936707], dtype=float32),)
(Array([-0.00584748,  0.00117937,  0.02623443, -0.01936707], dtype=float32),)

(Array([ 0.0059198 , -0.00153206,  0.00783919,  0.02133603], dtype=float32),)
(Array([ 0.0059198 , -0.00153206,  0.00783919,  0.02133603], dtype=float32),)

(Array([-0.01223496, -0.00772041, -0.00731517,  0.00091731], dtype

  1%|▋                                            | 355/25000 [00:11<10:29, 39.18it/s, loss=0.288]

(Array([-0.0174078 , -0.01963904, -0.01017793, -0.00266428], dtype=float32),)
(Array([-0.0174078 , -0.01963904, -0.01017793, -0.00266428], dtype=float32),)

(Array([-0.00587899, -0.00164421, -0.01155873,  0.01952974], dtype=float32),)
(Array([-0.00587899, -0.00164421, -0.01155873,  0.01952974], dtype=float32),)

(Array([-0.00324754,  0.01572626, -0.01650208, -0.00312651], dtype=float32),)
(Array([-0.00324754,  0.01572626, -0.01650208, -0.00312651], dtype=float32),)

(Array([ 0.03786386, -0.04291899,  0.03344206, -0.01489852], dtype=float32),)
(Array([ 0.03786386, -0.04291899,  0.03344206, -0.01489852], dtype=float32),)

(Array([ 0.07149912, -0.02853113,  0.02167061, -0.02640182], dtype=float32),)
(Array([ 0.07149912, -0.02853113,  0.02167061, -0.02640182], dtype=float32),)

(Array([ 0.018293  , -0.0224121 ,  0.02119228,  0.06036052], dtype=float32),)
(Array([ 0.018293  , -0.0224121 ,  0.02119228,  0.06036052], dtype=float32),)

(Array([ 0.00119243, -0.01173441, -0.01909638,  0.0329896 

  1%|▋                                             | 364/25000 [00:11<10:18, 39.85it/s, loss=0.26]

(Array([ 0.06202793, -0.04986091, -0.0024046 ,  0.01800845], dtype=float32),)
(Array([ 0.06202793, -0.04986091, -0.0024046 ,  0.01800845], dtype=float32),)

(Array([ 0.05338489, -0.01565171, -0.02235496, -0.0058541 ], dtype=float32),)
(Array([ 0.05338489, -0.01565171, -0.02235496, -0.0058541 ], dtype=float32),)

(Array([ 0.00830432,  0.00884594, -0.02362924, -0.00903436], dtype=float32),)
(Array([ 0.00830432,  0.00884594, -0.02362924, -0.00903436], dtype=float32),)

(Array([-0.01280086,  0.01073863, -0.01724285,  0.01508791], dtype=float32),)
(Array([-0.01280086,  0.01073863, -0.01724285,  0.01508791], dtype=float32),)

(Array([ 0.00599638,  0.00210158, -0.00010679,  0.00509634], dtype=float32),)
(Array([ 0.00599638,  0.00210158, -0.00010679,  0.00509634], dtype=float32),)

(Array([ 0.00427495, -0.00910689, -0.01277504, -0.01153504], dtype=float32),)
(Array([ 0.00427495, -0.00910689, -0.01277504, -0.01153504], dtype=float32),)

(Array([-0.01759919, -0.02046736, -0.0044461 , -0.00263338

  1%|▋                                            | 372/25000 [00:11<10:19, 39.77it/s, loss=0.235]

(Array([ 0.00326175,  0.01480093,  0.00825959, -0.0081725 ], dtype=float32),)
(Array([ 0.00326175,  0.01480093,  0.00825959, -0.0081725 ], dtype=float32),)

(Array([ 0.00886971,  0.01615575,  0.00255911, -0.00808135], dtype=float32),)
(Array([ 0.00886971,  0.01615575,  0.00255911, -0.00808135], dtype=float32),)

(Array([ 0.00237805, -0.00095567,  0.01533551, -0.00257541], dtype=float32),)
(Array([ 0.00237805, -0.00095567,  0.01533551, -0.00257541], dtype=float32),)

(Array([0.00782844, 0.00800482, 0.00828055, 0.0016581 ], dtype=float32),)
(Array([0.00782844, 0.00800482, 0.00828055, 0.0016581 ], dtype=float32),)

(Array([-0.00288943, -0.00090004,  0.00221204,  0.01627531], dtype=float32),)
(Array([-0.00288943, -0.00090004,  0.00221204,  0.01627531], dtype=float32),)

(Array([-0.00632463, -0.01183636,  0.01236867,  0.00719311], dtype=float32),)
(Array([-0.00632463, -0.01183636,  0.01236867,  0.00719311], dtype=float32),)

(Array([-0.01811966,  0.00975939, -0.00771204, -0.03468984], dtype

  2%|▋                                             | 381/25000 [00:12<10:20, 39.70it/s, loss=0.26]

(Array([ 0.01843296, -0.01869106, -0.0493873 ,  0.00515002], dtype=float32),)
(Array([ 0.01843296, -0.01869106, -0.0493873 ,  0.00515002], dtype=float32),)

(Array([0.00135071, 0.03107638, 0.00603916, 0.01448736], dtype=float32),)
(Array([0.00135071, 0.03107638, 0.00603916, 0.01448736], dtype=float32),)

(Array([ 0.00146925, -0.00434695, -0.00211713, -0.00019735], dtype=float32),)
(Array([ 0.00146925, -0.00434695, -0.00211713, -0.00019735], dtype=float32),)

(Array([-0.01176846, -0.00059027,  0.02852223, -0.02644858], dtype=float32),)
(Array([-0.01176846, -0.00059027,  0.02852223, -0.02644858], dtype=float32),)

(Array([-0.01329137,  0.03256723, -0.03254146, -0.01112831], dtype=float32),)
(Array([-0.01329137,  0.03256723, -0.03254146, -0.01112831], dtype=float32),)

(Array([-0.01628377,  0.00974668, -0.00936976, -0.01038926], dtype=float32),)
(Array([-0.01628377,  0.00974668, -0.00936976, -0.01038926], dtype=float32),)

(Array([ 0.00293627, -0.00208385,  0.00116084, -0.00573803], dtype

  2%|▋                                            | 390/25000 [00:12<10:18, 39.82it/s, loss=0.256]

(Array([ 0.01593954,  0.05278697, -0.00099019, -0.01989599], dtype=float32),)
(Array([ 0.01593954,  0.05278697, -0.00099019, -0.01989599], dtype=float32),)

(Array([-0.00640866, -0.00244468, -0.01411918,  0.00993454], dtype=float32),)
(Array([-0.00640866, -0.00244468, -0.01411918,  0.00993454], dtype=float32),)

(Array([0.00464476, 0.0160177 , 0.0158982 , 0.00276031], dtype=float32),)
(Array([0.00464476, 0.0160177 , 0.0158982 , 0.00276031], dtype=float32),)

(Array([ 0.00464646,  0.01854614, -0.00022139, -0.00243069], dtype=float32),)
(Array([ 0.00464646,  0.01854614, -0.00022139, -0.00243069], dtype=float32),)

(Array([-0.0398876 ,  0.01923274,  0.00163637, -0.00924004], dtype=float32),)
(Array([-0.0398876 ,  0.01923274,  0.00163637, -0.00924004], dtype=float32),)

(Array([-0.00068883,  0.01305163,  0.0041808 , -0.00253485], dtype=float32),)
(Array([-0.00068883,  0.01305163,  0.0041808 , -0.00253485], dtype=float32),)

(Array([-0.01383088,  0.01440795, -0.00046825,  0.02181154], dtype

  2%|▋                                            | 395/25000 [00:12<10:15, 39.95it/s, loss=0.267]

(Array([ 0.04262122, -0.01096539,  0.05372253, -0.02884105], dtype=float32),)
(Array([ 0.04262122, -0.01096539,  0.05372253, -0.02884105], dtype=float32),)

(Array([ 0.00160113, -0.0442327 ,  0.00575164,  0.05748297], dtype=float32),)
(Array([ 0.00160113, -0.0442327 ,  0.00575164,  0.05748297], dtype=float32),)

(Array([-0.00075136,  0.00304369,  0.00823222, -0.00704312], dtype=float32),)
(Array([-0.00075136,  0.00304369,  0.00823222, -0.00704312], dtype=float32),)

(Array([-0.0005837 , -0.0358405 ,  0.02408287, -0.01851583], dtype=float32),)
(Array([-0.0005837 , -0.0358405 ,  0.02408287, -0.01851583], dtype=float32),)

(Array([-0.00709633, -0.00167943,  0.00736803,  0.00565872], dtype=float32),)
(Array([-0.00709633, -0.00167943,  0.00736803,  0.00565872], dtype=float32),)

(Array([-0.01346871, -0.00446596,  0.00627056,  0.00018527], dtype=float32),)
(Array([-0.01346871, -0.00446596,  0.00627056,  0.00018527], dtype=float32),)

(Array([ 0.00176773, -0.00133819, -0.00203986,  0.00038421

  2%|▋                                            | 405/25000 [00:12<12:14, 33.47it/s, loss=0.279]

(Array([-0.00712128, -0.00046546, -0.00250711,  0.00749189], dtype=float32),)
(Array([-0.00712128, -0.00046546, -0.00250711,  0.00749189], dtype=float32),)

(Array([ 0.00986917,  0.01208205, -0.01973622, -0.00646193], dtype=float32),)
(Array([ 0.00986917,  0.01208205, -0.01973622, -0.00646193], dtype=float32),)

(Array([-0.00512622, -0.01329585,  0.02685563,  0.00017572], dtype=float32),)
(Array([-0.00512622, -0.01329585,  0.02685563,  0.00017572], dtype=float32),)

(Array([-0.00879179,  0.01637253, -0.02059724, -0.00146624], dtype=float32),)
(Array([-0.00879179,  0.01637253, -0.02059724, -0.00146624], dtype=float32),)

(Array([-0.00839136, -0.00345883,  0.0197899 ,  0.00319572], dtype=float32),)
(Array([-0.00839136, -0.00345883,  0.0197899 ,  0.00319572], dtype=float32),)

(Array([ 0.02631129,  0.02263379, -0.01835828, -0.03665624], dtype=float32),)
(Array([ 0.02631129,  0.02263379, -0.01835828, -0.03665624], dtype=float32),)

(Array([-0.00216679,  0.00919337,  0.00192929,  0.00343078

  2%|▋                                            | 413/25000 [00:13<11:30, 35.61it/s, loss=0.208]

(Array([ 0.018013  ,  0.0108106 , -0.02227356, -0.01662864], dtype=float32),)
(Array([ 0.018013  ,  0.0108106 , -0.02227356, -0.01662864], dtype=float32),)

(Array([-0.00483561, -0.02247593, -0.00966563, -0.02255436], dtype=float32),)
(Array([-0.00483561, -0.02247593, -0.00966563, -0.02255436], dtype=float32),)

(Array([-0.03026876,  0.00474763,  0.02132349, -0.01801617], dtype=float32),)
(Array([-0.03026876,  0.00474763,  0.02132349, -0.01801617], dtype=float32),)

(Array([ 0.00700983,  0.04706323, -0.00367313, -0.05388681], dtype=float32),)
(Array([ 0.00700983,  0.04706323, -0.00367313, -0.05388681], dtype=float32),)

(Array([-0.03562094,  0.03652883,  0.04704989, -0.01962253], dtype=float32),)
(Array([-0.03562094,  0.03652883,  0.04704989, -0.01962253], dtype=float32),)

(Array([-0.01161564, -0.01883931, -0.00413657,  0.01496997], dtype=float32),)
(Array([-0.01161564, -0.01883931, -0.00413657,  0.01496997], dtype=float32),)

(Array([-0.02180631,  0.02561392,  0.01271636, -0.04263546

  2%|▊                                            | 421/25000 [00:13<11:05, 36.94it/s, loss=0.228]

(Array([-0.03731417,  0.01820131,  0.00652368, -0.00150247], dtype=float32),)
(Array([-0.03731417,  0.01820131,  0.00652368, -0.00150247], dtype=float32),)

(Array([-0.01082607,  0.01019999, -0.00715712,  0.00207689], dtype=float32),)
(Array([-0.01082607,  0.01019999, -0.00715712,  0.00207689], dtype=float32),)

(Array([-0.00405095,  0.01833961, -0.00239027, -0.0099652 ], dtype=float32),)
(Array([-0.00405095,  0.01833961, -0.00239027, -0.0099652 ], dtype=float32),)

(Array([-0.02357836,  0.01205629, -0.00471068,  0.00330017], dtype=float32),)
(Array([-0.02357836,  0.01205629, -0.00471068,  0.00330017], dtype=float32),)

(Array([-0.00562186,  0.0030007 , -0.01095269,  0.0093206 ], dtype=float32),)
(Array([-0.00562186,  0.0030007 , -0.01095269,  0.0093206 ], dtype=float32),)

(Array([ 0.0138794 ,  0.01852803, -0.01749588,  0.00403618], dtype=float32),)
(Array([ 0.0138794 ,  0.01852803, -0.01749588,  0.00403618], dtype=float32),)

(Array([ 0.00030373, -0.02148938,  0.01652096,  0.0024651 

  2%|▊                                            | 429/25000 [00:13<10:59, 37.28it/s, loss=0.259]

(Array([-0.00994904,  0.00618085, -0.0064682 , -0.00163948], dtype=float32),)
(Array([-0.00994904,  0.00618085, -0.0064682 , -0.00163948], dtype=float32),)

(Array([ 0.00465471, -0.01791852, -0.00308102, -0.00128052], dtype=float32),)
(Array([ 0.00465471, -0.01791852, -0.00308102, -0.00128052], dtype=float32),)

(Array([0.00670056, 0.01916602, 0.01157617, 0.00035025], dtype=float32),)
(Array([0.00670056, 0.01916602, 0.01157617, 0.00035025], dtype=float32),)

(Array([-0.02191528,  0.00589801,  0.00577502,  0.02306184], dtype=float32),)
(Array([-0.02191528,  0.00589801,  0.00577502,  0.02306184], dtype=float32),)

(Array([-0.00837789,  0.00126767, -0.00398894, -0.02976952], dtype=float32),)
(Array([-0.00837789,  0.00126767, -0.00398894, -0.02976952], dtype=float32),)

(Array([-0.00831242,  0.01225318,  0.01159186,  0.01719654], dtype=float32),)
(Array([-0.00831242,  0.01225318,  0.01159186,  0.01719654], dtype=float32),)

(Array([ 0.00394208, -0.00509527,  0.00985124,  0.0171237 ], dtype

  2%|▊                                            | 437/25000 [00:13<11:34, 35.38it/s, loss=0.257]

(Array([ 0.0060775 , -0.01563309,  0.01797637,  0.01751271], dtype=float32),)
(Array([ 0.0060775 , -0.01563309,  0.01797637,  0.01751271], dtype=float32),)

(Array([ 0.02141052, -0.01219908, -0.00112668,  0.034915  ], dtype=float32),)
(Array([ 0.02141052, -0.01219908, -0.00112668,  0.034915  ], dtype=float32),)

(Array([ 0.00296948, -0.02026879, -0.01231623, -0.00133879], dtype=float32),)
(Array([ 0.00296948, -0.02026879, -0.01231623, -0.00133879], dtype=float32),)

(Array([ 0.01558266, -0.05615755,  0.01402097, -0.01750594], dtype=float32),)
(Array([ 0.01558266, -0.05615755,  0.01402097, -0.01750594], dtype=float32),)

(Array([ 0.05330862, -0.00389676,  0.00621667,  0.02251967], dtype=float32),)
(Array([ 0.05330862, -0.00389676,  0.00621667,  0.02251967], dtype=float32),)

(Array([ 0.00154618, -0.03348535, -0.02645467,  0.02235045], dtype=float32),)
(Array([ 0.00154618, -0.03348535, -0.02645467,  0.02235045], dtype=float32),)

(Array([ 0.00616222,  0.01379677, -0.00532222, -0.0091554 

  2%|▊                                            | 445/25000 [00:13<11:20, 36.06it/s, loss=0.283]

(Array([ 0.02613195, -0.00410429, -0.02204127,  0.00552809], dtype=float32),)
(Array([ 0.02613195, -0.00410429, -0.02204127,  0.00552809], dtype=float32),)

(Array([-0.01040647, -0.00367088, -0.01402866, -0.00353144], dtype=float32),)
(Array([-0.01040647, -0.00367088, -0.01402866, -0.00353144], dtype=float32),)

(Array([ 0.0019061 , -0.00746923, -0.00222235, -0.01251131], dtype=float32),)
(Array([ 0.0019061 , -0.00746923, -0.00222235, -0.01251131], dtype=float32),)

(Array([-0.01319255, -0.00368053,  0.00468691,  0.02748225], dtype=float32),)
(Array([-0.01319255, -0.00368053,  0.00468691,  0.02748225], dtype=float32),)

(Array([0.00482386, 0.00955415, 0.01200815, 0.00969383], dtype=float32),)
(Array([0.00482386, 0.00955415, 0.01200815, 0.00969383], dtype=float32),)

(Array([-0.00891959, -0.01404578, -0.00786942, -0.00041479], dtype=float32),)
(Array([-0.00891959, -0.01404578, -0.00786942, -0.00041479], dtype=float32),)

(Array([-0.00304527,  0.00551379,  0.00881778, -0.00574341], dtype

  2%|▊                                            | 453/25000 [00:14<11:34, 35.36it/s, loss=0.251]

(Array([-0.01237312, -0.00028241, -0.00045884, -0.01218339], dtype=float32),)
(Array([-0.01237312, -0.00028241, -0.00045884, -0.01218339], dtype=float32),)

(Array([ 0.00335483, -0.00838382,  0.00790178, -0.01192655], dtype=float32),)
(Array([ 0.00335483, -0.00838382,  0.00790178, -0.01192655], dtype=float32),)

(Array([ 0.00679498,  0.01125584, -0.00250358,  0.00146889], dtype=float32),)
(Array([ 0.00679498,  0.01125584, -0.00250358,  0.00146889], dtype=float32),)

(Array([ 0.01908054,  0.02516806,  0.01478187, -0.0231845 ], dtype=float32),)
(Array([ 0.01908054,  0.02516806,  0.01478187, -0.0231845 ], dtype=float32),)

(Array([-0.01245654,  0.0003699 ,  0.0271573 ,  0.00732004], dtype=float32),)
(Array([-0.01245654,  0.0003699 ,  0.0271573 ,  0.00732004], dtype=float32),)

(Array([ 0.00614537, -0.00472477, -0.00439719,  0.01321561], dtype=float32),)
(Array([ 0.00614537, -0.00472477, -0.00439719,  0.01321561], dtype=float32),)

(Array([-0.00526603, -0.00539759,  0.00527547,  0.00204218

  2%|▊                                            | 461/25000 [00:14<11:22, 35.94it/s, loss=0.212]

(Array([ 0.02683175, -0.00028793,  0.00748756, -0.00883032], dtype=float32),)
(Array([ 0.02683175, -0.00028793,  0.00748756, -0.00883032], dtype=float32),)

(Array([-0.0170998 ,  0.01745121,  0.01185799, -0.00064114], dtype=float32),)
(Array([-0.0170998 ,  0.01745121,  0.01185799, -0.00064114], dtype=float32),)

(Array([-0.00215379, -0.00505486,  0.00817706, -0.01359215], dtype=float32),)
(Array([-0.00215379, -0.00505486,  0.00817706, -0.01359215], dtype=float32),)

(Array([ 0.00709438,  0.01029727, -0.01132417, -0.00795392], dtype=float32),)
(Array([ 0.00709438,  0.01029727, -0.01132417, -0.00795392], dtype=float32),)

(Array([ 0.00682776, -0.01219483, -0.01238883,  0.00492772], dtype=float32),)
(Array([ 0.00682776, -0.01219483, -0.01238883,  0.00492772], dtype=float32),)

(Array([ 0.0115426 , -0.01735168,  0.0021075 , -0.00302869], dtype=float32),)
(Array([ 0.0115426 , -0.01735168,  0.0021075 , -0.00302869], dtype=float32),)

(Array([ 0.02258056, -0.00014832,  0.01536242, -0.02208973

  2%|▊                                            | 469/25000 [00:14<10:59, 37.17it/s, loss=0.223]

(Array([-0.02669666,  0.01448537,  0.02185839, -0.02473267], dtype=float32),)
(Array([-0.02669666,  0.01448537,  0.02185839, -0.02473267], dtype=float32),)

(Array([-0.00137018, -0.01056197,  0.00011594, -0.00173024], dtype=float32),)
(Array([-0.00137018, -0.01056197,  0.00011594, -0.00173024], dtype=float32),)

(Array([-0.01165714,  0.00156747, -0.01477291, -0.02178212], dtype=float32),)
(Array([-0.01165714,  0.00156747, -0.01477291, -0.02178212], dtype=float32),)

(Array([ 0.0063503 , -0.02809665, -0.00382225, -0.00191809], dtype=float32),)
(Array([ 0.0063503 , -0.02809665, -0.00382225, -0.00191809], dtype=float32),)

(Array([ 0.02174136, -0.00784401, -0.00561487,  0.00546045], dtype=float32),)
(Array([ 0.02174136, -0.00784401, -0.00561487,  0.00546045], dtype=float32),)

(Array([-0.02274694,  0.00153873, -0.01018068,  0.00744896], dtype=float32),)
(Array([-0.02274694,  0.00153873, -0.01018068,  0.00744896], dtype=float32),)

(Array([ 0.00472955, -0.00755044, -0.00332487,  0.040287  

  2%|▊                                            | 477/25000 [00:14<11:23, 35.88it/s, loss=0.208]

(Array([-0.01993534,  0.03802625, -0.01965083,  0.00016682], dtype=float32),)
(Array([-0.01993534,  0.03802625, -0.01965083,  0.00016682], dtype=float32),)

(Array([-0.03010931,  0.0646488 ,  0.00509271,  0.03552068], dtype=float32),)
(Array([-0.03010931,  0.0646488 ,  0.00509271,  0.03552068], dtype=float32),)

(Array([-0.08079412, -0.00664135, -0.01291395,  0.02675059], dtype=float32),)
(Array([-0.08079412, -0.00664135, -0.01291395,  0.02675059], dtype=float32),)

(Array([-0.01818313, -0.01464049, -0.00606309,  0.01555181], dtype=float32),)
(Array([-0.01818313, -0.01464049, -0.00606309,  0.01555181], dtype=float32),)

(Array([-0.03287055,  0.02934355, -0.00539583,  0.0106511 ], dtype=float32),)
(Array([-0.03287055,  0.02934355, -0.00539583,  0.0106511 ], dtype=float32),)

(Array([-0.0181475 ,  0.00265427,  0.00152532,  0.01743907], dtype=float32),)
(Array([-0.0181475 ,  0.00265427,  0.00152532,  0.01743907], dtype=float32),)

(Array([ 0.01221708,  0.00197029, -0.00713158, -0.04480524

  2%|▊                                            | 485/25000 [00:15<11:13, 36.41it/s, loss=0.219]

(Array([-0.02941224, -0.00983717,  0.00939121, -0.00091239], dtype=float32),)
(Array([-0.02941224, -0.00983717,  0.00939121, -0.00091239], dtype=float32),)

(Array([ 0.04869247, -0.0132694 ,  0.01209199,  0.00577482], dtype=float32),)
(Array([ 0.04869247, -0.0132694 ,  0.01209199,  0.00577482], dtype=float32),)

(Array([ 0.00069732,  0.00261923,  0.00266016, -0.00209412], dtype=float32),)
(Array([ 0.00069732,  0.00261923,  0.00266016, -0.00209412], dtype=float32),)

(Array([-0.00544178, -0.01237049, -0.00300404, -0.00875969], dtype=float32),)
(Array([-0.00544178, -0.01237049, -0.00300404, -0.00875969], dtype=float32),)

(Array([ 0.01103009,  0.00955865, -0.00113094, -0.00627315], dtype=float32),)
(Array([ 0.01103009,  0.00955865, -0.00113094, -0.00627315], dtype=float32),)

(Array([ 0.00961852, -0.00519732, -0.00242836,  0.00755765], dtype=float32),)
(Array([ 0.00961852, -0.00519732, -0.00242836,  0.00755765], dtype=float32),)

(Array([-2.2448601e-02, -9.1795810e-05,  1.5707524e-02, -2

  2%|▉                                            | 493/25000 [00:15<11:00, 37.12it/s, loss=0.226]

(Array([-0.00031474, -0.00546558, -0.01556408, -0.00476784], dtype=float32),)
(Array([-0.00031474, -0.00546558, -0.01556408, -0.00476784], dtype=float32),)

(Array([ 0.00285542, -0.00356921, -0.01121725, -0.01015433], dtype=float32),)
(Array([ 0.00285542, -0.00356921, -0.01121725, -0.01015433], dtype=float32),)

(Array([0.00484654, 0.01978124, 0.01901785, 0.0053165 ], dtype=float32),)
(Array([0.00484654, 0.01978124, 0.01901785, 0.0053165 ], dtype=float32),)

(Array([-0.01535783, -0.01686779,  0.00092575,  0.01307374], dtype=float32),)
(Array([-0.01535783, -0.01686779,  0.00092575,  0.01307374], dtype=float32),)

(Array([ 0.0045078 ,  0.00081605, -0.0039958 , -0.00372497], dtype=float32),)
(Array([ 0.0045078 ,  0.00081605, -0.0039958 , -0.00372497], dtype=float32),)

(Array([-0.00246241,  0.00366138,  0.01366028,  0.02670349], dtype=float32),)
(Array([-0.00246241,  0.00366138,  0.01366028,  0.02670349], dtype=float32),)

(Array([ 0.00963791, -0.00273788, -0.00584   , -0.01193149], dtype

  2%|▉                                            | 497/25000 [00:15<11:23, 35.82it/s, loss=0.266]

(Array([ 0.00565015, -0.01578384, -0.01270296,  0.00028016], dtype=float32),)
(Array([ 0.00565015, -0.01578384, -0.01270296,  0.00028016], dtype=float32),)

(Array([ 0.01286778, -0.00887535,  0.00434086, -0.0141456 ], dtype=float32),)
(Array([ 0.01286778, -0.00887535,  0.00434086, -0.0141456 ], dtype=float32),)

(Array([1.6874598e-02, 2.4940982e-03, 1.3730746e-02, 5.7828147e-05],      dtype=float32),)
(Array([1.6874598e-02, 2.4940982e-03, 1.3730746e-02, 5.7828147e-05],      dtype=float32),)

(Array([ 0.00485384,  0.01409919,  0.00839896, -0.01527021], dtype=float32),)
(Array([ 0.00485384,  0.01409919,  0.00839896, -0.01527021], dtype=float32),)

(Array([-0.00774863,  0.00426051, -0.02180328, -0.01149231], dtype=float32),)
(Array([-0.00774863,  0.00426051, -0.02180328, -0.01149231], dtype=float32),)

(Array([6.1891042e-06, 2.4336197e-03, 8.0328546e-03, 3.5851279e-03],      dtype=float32),)
(Array([6.1891042e-06, 2.4336197e-03, 8.0328546e-03, 3.5851279e-03],      dtype=float32),)



  2%|▉                                            | 505/25000 [00:15<13:17, 30.70it/s, loss=0.237]

(Array([-0.00441143, -0.02419504,  0.00997035, -0.00045374], dtype=float32),)
(Array([-0.00441143, -0.02419504,  0.00997035, -0.00045374], dtype=float32),)

(Array([ 0.01860354, -0.00818601, -0.00984006, -0.0050471 ], dtype=float32),)
(Array([ 0.01860354, -0.00818601, -0.00984006, -0.0050471 ], dtype=float32),)

(Array([0.01120037, 0.01436585, 0.00388021, 0.00665484], dtype=float32),)
(Array([0.01120037, 0.01436585, 0.00388021, 0.00665484], dtype=float32),)

(Array([-0.00903175, -0.00632603,  0.00964281,  0.01265988], dtype=float32),)
(Array([-0.00903175, -0.00632603,  0.00964281,  0.01265988], dtype=float32),)

(Array([-0.00683512, -0.01080251,  0.0175034 ,  0.00360106], dtype=float32),)
(Array([-0.00683512, -0.01080251,  0.0175034 ,  0.00360106], dtype=float32),)

(Array([ 0.02258028,  0.01137454, -0.00626848,  0.0185361 ], dtype=float32),)
(Array([ 0.02258028,  0.01137454, -0.00626848,  0.0185361 ], dtype=float32),)

(Array([-0.02223188,  0.00932013, -0.02540595,  0.00078716], dtype

  2%|▉                                            | 513/25000 [00:15<11:48, 34.58it/s, loss=0.258]

(Array([ 0.00113292,  0.0017529 , -0.01228675, -0.00992233], dtype=float32),)
(Array([ 0.00113292,  0.0017529 , -0.01228675, -0.00992233], dtype=float32),)

(Array([ 0.01820439, -0.00288495, -0.00431016, -0.00328533], dtype=float32),)
(Array([ 0.01820439, -0.00288495, -0.00431016, -0.00328533], dtype=float32),)

(Array([ 0.00776968,  0.00864513,  0.00425661, -0.00488268], dtype=float32),)
(Array([ 0.00776968,  0.00864513,  0.00425661, -0.00488268], dtype=float32),)

(Array([-0.00716044, -0.00957185,  0.00080014,  0.00524137], dtype=float32),)
(Array([-0.00716044, -0.00957185,  0.00080014,  0.00524137], dtype=float32),)

(Array([-4.9891882e-05, -7.8825094e-03, -7.0348489e-03,  5.2643376e-03],      dtype=float32),)
(Array([-4.9891882e-05, -7.8825094e-03, -7.0348489e-03,  5.2643376e-03],      dtype=float32),)

(Array([ 0.00298151, -0.0092543 ,  0.00298028, -0.00128439], dtype=float32),)
(Array([ 0.00298151, -0.0092543 ,  0.00298028, -0.00128439], dtype=float32),)

(Array([ 0.02623946,  0.

  2%|▉                                            | 518/25000 [00:16<12:36, 32.35it/s, loss=0.222]


(Array([ 0.00032461, -0.00424438, -0.0089822 ,  0.00819868], dtype=float32),)
(Array([ 0.00032461, -0.00424438, -0.0089822 ,  0.00819868], dtype=float32),)

(Array([ 0.00761648,  0.00534021, -0.00984259, -0.00262236], dtype=float32),)
(Array([ 0.00761648,  0.00534021, -0.00984259, -0.00262236], dtype=float32),)



KeyboardInterrupt: 

In [None]:
# Clean the stats: regroup the raw per-trial records by statistic, reduce them across
# trials, and pickle the result.

# Stage 1 - gather.
# aggregated: experiment name -> 'args' (list with one entry per trial)
#                             or stat key -> timestamp -> list of per-trial values
aggregated = {}
for k, v in results.items():  # one entry per experiment
    aggregated[k] = {'args': []}

    for n in range(len(SEEDS)):  # one trial per seed
        aggregated[k]['args'].append(v[n]['args'])

        # the number of timestamps is always read from the first trial's args
        for t in range(1, v[0]['args']['num_iters'] + 1):
            for stat_key, value in v[n][t].items():  # every stat recorded at timestamp t
                aggregated[k].setdefault(stat_key, {}).setdefault(t, []).append(value)

# Stage 2 - reduce.
# ret:  stat key -> experiment name -> {'t': timestamps, 'avg': mean, 'std': std}
# args: experiment name -> list of per-trial args
ret = defaultdict(dict)
args = {}
for k, v in aggregated.items():
    for stat_key in v.keys():
        if stat_key == 'args':
            # args are passed through untouched rather than averaged
            args[k] = v[stat_key]
            continue
        # leading axis of arr is the timestamp, axis 1 is the trial
        arr = np.array(list(v[stat_key].values()))
        ret[stat_key][k] = {
            't': list(v[stat_key].keys()),
            'avg': np.mean(arr, axis=1),
            'std': np.std(arr, axis=1),
        }

# Stage 3 - persist the reduced statistics.
with open(f'{DIR_PREFIX}/data/{NAME}_processed.pkl', 'wb') as f:
    pkl.dump(ret, f)
    print('Saved processed results')

# Plot

In [None]:
# Plot
# One panel per recorded statistic; inside a panel, one curve per experiment showing the
# 'avg' series with a shaded band of +/- 1.96 * 'std' around it.
# squeeze=False makes plt.subplots return a 2-D array of Axes even when `ret` holds a single
# statistic; by default it returns a bare Axes object in that case and `ax[i]` fails.
fig, ax = plt.subplots(len(ret), 1, figsize=(10, 24), squeeze=False)
ax = ax[:, 0]  # there is only one column, so keep a 1-D array that is indexed as ax[i]

for i, stat_key in enumerate(ret.keys()):
    ax[i].set_title(stat_key)
    for experiment_name in ret[stat_key].keys():
        curve = ret[stat_key][experiment_name]
        ts, avgs, stds = curve['t'], curve['avg'], curve['std']
        if avgs.ndim == 2:
            # vector-valued statistic: draw every component as its own labelled curve
            for j in range(avgs.shape[1]):
                ax[i].plot(ts, avgs[:, j], label=f'{experiment_name} {str(j)}')
                ax[i].fill_between(ts, avgs[:, j] - 1.96 * stds[:, j], avgs[:, j] + 1.96 * stds[:, j], alpha=0.2)
        else:
            if stat_key == 'loss':
                # Smooth the loss with a length-n moving average. The slice keeps the centred
                # part of the full convolution so the output has as many points as the input.
                # np.convolve treats samples beyond either end as zero, so the first and last
                # smoothed points are pulled towards zero.
                n = 3
                kernel = [1 / n,] * n
                avgs = np.convolve(avgs, kernel)[n // 2:n // 2 + avgs.shape[0]]
                stds = np.convolve(stds, kernel)[n // 2:n // 2 + stds.shape[0]]
            ax[i].plot(ts, avgs, label=experiment_name)
            ax[i].fill_between(ts, avgs - 1.96 * stds, avgs + 1.96 * stds, alpha=0.2)
    ax[i].legend()


# ax[1].set_ylim(-0.2, 0.2)
# ax[2].set_ylim(-0.2, 0.2)
# plt.savefig(f'figs/{NAME}.pdf')

#### 