In [1]:
import os
os.chdir("../")
print("Moved up")

Moved up


In [2]:
import importlib
import copy
import pickle
import typing as tp
from functools import partial

%env XLA_PYTHON_CLIENT_PREALLOCATE=false

import jax
import jax.numpy as jnp
import jax.flatten_util as fu
from flax import linen as nn  # Linen API
import numpy as np
import matplotlib.pyplot as plt
import optax
import math

from tqdm import tqdm
import time

import lib_data
import utils
import modules
import callbacks

# %env XLA_PYTHON_CLIENT_MEM_FRACTION=.9
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

print("devices", jax.devices())

shade_colours = plt.get_cmap('Set3')
dark_colours = plt.get_cmap('tab10')
all_colours = plt.get_cmap('tab20')

def light_colours(i):
    return all_colours(2*i+1)

env: XLA_PYTHON_CLIENT_PREALLOCATE=false


2025-05-05 12:24:57.460904: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746444297.474260 1406811 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746444297.478294 1406811 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


devices [CudaDevice(id=0)]


# Dataset

In [3]:
#-----------------------------------------------------------------------------------------------------------------------------
n_out = 1

n_train: int = 100000
n_eval: int = 10000
n_hess: int = 1

importlib.reload(lib_data)
def __get_datasets():
    datasets = lib_data.get_wikitext2_dataset(block_size=1024, max_train_samples=n_train, max_eval_samples=n_eval)

    print("Train:", len(datasets[0]), " Eval:", len(datasets[1]), " Hess:", len(datasets[2]))
    x, y = datasets[0][0]
    print("Input shape:", x.shape, "Target shape:", y.shape)

    data_name = "wiki2_"+str(n_out)+"cl_"+str(n_train) + "_" + str(n_eval) 
    # print(len(datasets))
    return datasets, data_name

datasets, data_name = __get_datasets()

⚙Tokenizing raw Wikitext-2...


Saving the dataset (0/1 shards):   0%|          | 0/4358 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/36718 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3760 [00:00<?, ? examples/s]

Saved tokenized dataset to: ./cached_wikitext2/tokenized_gpt2
Building LM datasets...
Flattening token sequences...
Total 2391884 tokens. Creating 2335 chunks with stride 1024
Flattening token sequences...
Total 283287 tokens. Creating 276 chunks with stride 1024
Train: 2335  Eval: 276  Hess: 2335
Input shape: (1023,) Target shape: (1023,)


# Architecture

In [4]:
#-----------------------------------------------------------------------------------------------------------------------------
importlib.reload(modules)
from ml_collections import ConfigDict

class MLPBlock(nn.Module):
    config: ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        features = x.shape[-1]
        x = nn.LayerNorm(dtype=self.config.dtype)(x)
        x = nn.Dense(self.config.mlp_expansion * features, dtype=self.config.dtype)(x)
        x = nn.gelu(x, approximate=True)
        x = nn.Dense(features, dtype=self.config.dtype)(x)
        x = nn.Dropout(rate=self.config.dropout_rate)(x, deterministic=not self.train)
        return x

class AttentionBlock(nn.Module):
    config: ConfigDict
    mask: tp.Optional[jax.Array]
    train: bool

    @nn.compact
    def __call__(self, x):
        features = x.shape[-1]
        x = nn.LayerNorm(dtype=self.config.dtype)(x)
        qkv = nn.DenseGeneral(
            features=(self.config.num_heads, self.config.head_dim * 3),
            axis=-1, dtype=self.config.dtype
        )(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)

        scale = q.shape[-1] ** -0.5
        q = q.astype(self.config.softmax_dtype) * scale
        k = k.astype(self.config.softmax_dtype)

        q = q.transpose(0, 2, 1, 3) # [B T H D] to [B H T D]
        k = k.transpose(0, 2, 1, 3) # [B T H D] to [B H T D]
        v = v.transpose(0, 2, 1, 3) # [B T H D] to [B H T D]
    
        attn = q @ k.swapaxes(-2, -1) # [B H T D] @ [B H D T] -> [B H T T]
    
        if self.mask is not None:
            attn = jnp.where(self.mask, attn, jnp.finfo(self.config.softmax_dtype).min)
    
        attn = nn.softmax(attn, axis=-1).astype(self.config.dtype)
        attn = nn.Dropout(rate=self.config.dropout_rate)(attn, deterministic=not self.train)
        y = attn @ v # [B H T T] @ [B H T D] -> [B H T D]
        y = y.transpose(0, 2, 1, 3) # [B H T D] -> [B T H D]
        y = y.reshape(x.shape)  # [B T H D] -> [B T C(H*D)]
        y = nn.Dense(features, dtype=self.config.dtype)(y)
        y = nn.Dropout(rate=self.config.dropout_rate)(y, deterministic=not self.train)
        return y

class TransformerBlock(nn.Module):
    config: ConfigDict
    mask: tp.Optional[jax.Array]
    train: bool

    @nn.compact
    def __call__(self, x):
        mlp = MLPBlock
        if "MLP" in self.config.remat:
            mlp = nn.remat(mlp, prevent_cse=False)
        attn = AttentionBlock
        if "Attn" in self.config.remat:
            attn = nn.remat(attn, prevent_cse=False)

        x = x + attn(config=self.config, mask=self.mask, train=self.train)(x)
        x = x + mlp(config=self.config, train=self.train)(x)
        return x

class Transformer(nn.Module):
    config: ConfigDict

    @nn.compact
    def __call__(self, x, mask=None, train=True):
        if mask is None and self.config.causal_mask:
            mask = nn.make_causal_mask(x, dtype=jnp.bool_)

        embed = nn.Embed(self.config.vocab_size, self.config.hidden_size, dtype=self.config.dtype, name='token_embed')
        x = embed(x) 
        pos_emb = self.param("pos_emb", nn.initializers.normal(0.02),
                             (self.config.max_seq_len, self.config.hidden_size)).astype(self.config.dtype)
        
        x += pos_emb[None, :x.shape[1]]

        block_fn = functools.partial(TransformerBlock, config=self.config, mask=mask, train=train)

        if self.config.scan_layers:
            block = block_fn(name="block")
            x, _ = nn.scan(
                lambda module, carry, _: (module(carry), None),
                variable_axes={"params": 0},
                split_rngs={"params": True, "dropout": True},
                length=self.config.num_layers
            )(block, x, ())
        else:
            for i in range(self.config.num_layers):
                x = block_fn(name=f"block_{i}")(x)

        x = nn.LayerNorm(dtype=self.config.dtype)(x)

        # weight tying
        logits = x @ embed.embedding.T
        return logits.astype(jnp.float32)

In [5]:
config = ConfigDict()
config.vocab_size = 50257
config.hidden_size = 384
config.num_layers = 6
config.num_heads = 6
config.head_dim = 64
config.mlp_expansion = 4
config.dropout_rate = 0.1
config.max_seq_len = 1024
config.num_outputs = 50257
config.dtype = jnp.float32
config.causal_mask = True
config.softmax_dtype = jnp.float32
config.remat = [],  #["MLP", "Attn"]
config.scan_layers = False

def __get_arch__():

    model = Transformer(config)
    model_name = f"Transformer_L{config.num_layers}_H{config.hidden_size}_Heads{config.num_heads}"
    return model, model_name

model_arch, model_name = __get_arch__()
print(model_name)

Transformer_L6_H384_Heads6


In [8]:
# config = ConfigDict()
# config.vocab_size = 50257
# config.hidden_size = 768
# config.num_layers = 12
# config.num_heads = 12
# config.head_dim = 64
# config.mlp_expansion = 4
# config.dropout_rate = 0.1
# config.max_seq_len = 1024
# config.num_outputs = 50257
# config.dtype = jnp.float32
# config.causal_mask = True
# config.softmax_dtype = jnp.float32
# config.remat = ["MLP", "Attn"]
# config.scan_layers = False

# def __get_arch__():

#     model = Transformer(config)
#     model_name = f"Transformer_L{config.num_layers}_H{config.hidden_size}_Heads{config.num_heads}"
#     return model, model_name

# model_arch, model_name = __get_arch__()
# print(model_name)

# Optimizer

In [6]:
#-----------------------------------------------------------------------------------------------------------------------------
importlib.reload(modules)
from optax import contrib

def __get_optim__(warmup_steps, lr, b1, b2, b3, option="", rho=None, sync_period=1):
    # warmup_steps, lr, b1, b2, b3 = hyps['warmup_steps'], hyps['lr'], hyps['b1'], hyps['b2'], hyps['b3']
    if option == 'sam':
        assert rho is not None
        warmup_scheduler = optax.linear_schedule(init_value=0.0, end_value=lr,
                                                transition_steps=warmup_steps,
                                                transition_begin=0,)
        constant_scheduler = optax.constant_schedule(lr)
        lr_scheduler = optax.join_schedules([warmup_scheduler, constant_scheduler], boundaries=[warmup_steps])
        base_opt = modules.get_sgd_optimizer(lr_scheduler, b1, b2, b3, verbose=False)
        adv_opt = modules.get_sgd_optimizer(rho, b1, b2, b3, verbose=False)
        optimizer = contrib.sam(base_opt, adv_opt, sync_period=sync_period, opaque_mode=True) # sam opt
        optim_name = f"sgdFam-SAM_1b{b1}_2b{b2}_3b{b3}_lr{lr}_warmup{warmup_steps}_rho{rho}_syncT{sync_period}"

    else:
        warmup_scheduler = optax.linear_schedule(init_value=0.0, end_value=lr,
                                                transition_steps=warmup_steps,
                                                transition_begin=0,)
        constant_scheduler = optax.constant_schedule(lr)
        lr_scheduler = optax.join_schedules([warmup_scheduler, constant_scheduler], boundaries=[warmup_steps])
        optimizer = modules.get_sgd_optimizer(lr_scheduler, b1, b2, b3, verbose=False)
        optim_name = f"sgdFam_1b{b1}_2b{b2}_3b{b3}_lr{lr}_warmup{warmup_steps}"
        
    return optimizer, optim_name

optimizer, optim_name = __get_optim__(2, 0.1, 0, 0, 0)

# Model Params (Fixed and Tuned)

In [13]:
#-----------------------------------------------------------------------------------------------------------------------------
import itertools
### FIXED
warmup_steps = 2
bs = 1
eval_bs = 1
n_epochs = 10
loss_fn = optax.softmax_cross_entropy_with_integer_labels

# FLEXIBLE
# lr = 0.005
# beta_list = [(0., 0., 0.), (0., 0.99, 0.)]
# optim_hp = LR, B1, B2, B3, SAM, Rho, sync_period
optim_hp_list = [
    (5e-3, 0.9, 0.99, 0., False, 0., 1), 
]
seed_list = [x for x in range(1)]
# sam_list = [False, True]
s = [optim_hp_list, seed_list]
hyp_list = list(itertools.product(*s))
print(len(hyp_list))

1


# Callbacks

In [7]:
#-----------------------------------------------------------------------------------------------------------------------------
sws = 5
cb_freq = 1
hess_freq = int(1e8) # really large
importlib.reload(callbacks)
def __get_cbs__(state, compute_hessian=False):
    cbs = []
    cbs.append(callbacks.saveWeightsCB(sws, grad=True))
    # cbs.append(callbacks.thinCB(thin_freq=cb_freq))

    if compute_hessian:
        hvpCB = callbacks.hvpCB(loss_fn=loss_fn, batches=(datasets[2].data[:n_hess], datasets[2].targets[:n_hess]), 
                            save_freq=hess_freq, hess_bs=n_hess, state=state, bn=False)
        cbs.append(hvpCB)   
        specCB = callbacks.spectrumCB(n_eigs=20, n_evecs=10, 
                    loss_fn=loss_fn, seed=seed, hvpCB=hvpCB, save_freq=hess_freq, verbose=False)
        cbs.append(specCB)

        esCB = callbacks.earlyStopCB(acc_threshold=0.999, cbs=None, min_eps=sws, max_eps=n_epochs,conseq_eps=3,
                                 final_cbs=[hvpCB, specCB], verbose=False, low_eps=max(sws, 100), low_thresh=0.11, )
    else:
        esCB = callbacks.earlyStopCB(acc_threshold=0.999, cbs=None, min_eps=sws, max_eps=n_epochs, conseq_eps=5,
                                 verbose=False, low_eps=max(sws, 100), low_thresh=0., )
    cbs.append(esCB)
    return cbs


# Train State

In [9]:
from flax import struct                # Flax dataclasses
from clu import metrics
from flax.training import train_state  # Useful dataclass to keep train state
from perplexity import Perplexity
importlib.reload(modules)

@struct.dataclass
class Metrics(metrics.Collection):
    accuracy: metrics.Accuracy
    perplexity: Perplexity
    loss: metrics.Average.from_output('loss')

class TrainState(train_state.TrainState):
    metrics: Metrics
    rng: jax.Array

class TrainStateBN(train_state.TrainState):
    metrics: Metrics
    batch_stats: tp.Any
    rng: jax.Array

class TrainStateSAM(modules.TrainStateSAM):
    metrics: Metrics
    batch_stats: tp.Any
    rng: jax.Array

def create_train_state(model, optimizer, inputs, rng, option=""):
    """Creates an initial `TrainState`."""
    rng, model_rng = jax.random.split(rng)
    if option == "":
        params = model.init(model_rng, jnp.ones_like(inputs[0][jnp.newaxis, :]))['params'] # initialize parameters by passing a template image
        
        tx = optimizer
        return TrainState.create(
          apply_fn=model.apply, params=params, tx=tx, metrics=Metrics.empty(), rng=rng)
        
    elif option == "bn":
        variables = model.init(model_rng, jnp.ones_like(inputs[0][jnp.newaxis, :])) # initialize parameters by passing a template image
        params = variables['params']
        batch_stats = variables['batch_stats']
        
        tx = optimizer
        return TrainStateBN.create(
          apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats, 
          metrics=Metrics.empty(), rng=rng)
        
    elif option == "sam":
        variables = model.init(model_rng, jnp.ones_like(inputs[0][jnp.newaxis, :])) # initialize parameters by passing a template image
        params = variables['params']
        batch_stats = variables['batch_stats']
        
        tx = optimizer
        return TrainStateSAM.create(
          apply_fn=model.apply, params=params, tx=tx, batch_stats=batch_stats, 
          metrics=Metrics.empty(), rng=rng)
    else:
        raise NotImplementedError
        
        

# Training

In [10]:
import functools
# from flax.linen import tabulate
# tabulated_fn = tabulate(model, rngs={"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)},
#                            console_kwargs={"width": 200, "force_jupyter": False}) # Avoid output clipping in notebooks)

# print(tabulated_fn(sample_batch[0], train=False))
# # with open("gpt2_summary.txt", "w") as f:
#     # f.write(tab_fn(dummy_input, train=False))
#     # f.write(tabulated_fn(sample_batch[0], train=False))

In [11]:
@jax.jit
def _compute_metrics(*, state, batch):
    preds = state.apply_fn({'params': state.params}, batch[0], train=False)
    loss = loss_fn(preds, batch[1]).mean()
    metric_updates = state.metrics.single_from_model_output(
        logits=preds, labels=batch[1], loss=loss)
    metrics = state.metrics.merge(metric_updates)
    state = state.replace(metrics=metrics)
    return state

In [14]:
#-----------------------------------------------------------------------------------------------------------------------------
import training
importlib.reload(training)

load_files = False
compute_hessian = False
force_train = True

all_mh = []
all_exp_names = []

for hyp in hyp_list:
    
    metrics_history = {'train_loss': [],
                   'train_accuracy': [],
                   'train_perplexity': [],
                   'test_loss': [],
                   'test_accuracy': [],
                   'test_perplexity': [],
                      }

    lr, b1, b2, b3, sam, sam_rho, sync_T = hyp[0]
    seed = hyp[1]
    option = 'sam' if sam else ""
    
    if datasets is None:
        datasets, data_name = __get_datasets__()
    
    train_loader = lib_data.NumpyLoader(datasets[0], batch_size=bs, shuffle=False)
    for sample_batch in train_loader:
        break
    
    test_loader = lib_data.NumpyLoader(datasets[1], batch_size=eval_bs)
    dataloaders = [train_loader, test_loader]
    
    model, model_name = __get_arch__()
    model_name += "_seed"+str(seed)

    optim, optim_name = __get_optim__(warmup_steps, lr, b1, b2, b3, option=option, rho=sam_rho, sync_period=sync_T)
    optim_name += f"_epochs{n_epochs}_bs{bs}"

    init_rng = jax.random.PRNGKey(seed)
    state = create_train_state(model, optim, sample_batch[0], init_rng, option=option)
    del init_rng  # Must not be used anymore.
    num_params = utils.count_params(state.params)
    print("num params", num_params)
    
    # state = load_params(state, config)
    
    sample_out = state.apply_fn({'params': state.params,}, sample_batch[0], train=False)
    print("output dim", sample_out.shape)

    break

    # evaluate perplexity

    # train_bar = tqdm(train_loader, desc='train', total=len(train_loader))
    # test_bar = tqdm(test_loader, desc="validation", total=len(test_loader))

    # for batch in train_bar:
    #     state = _compute_metrics(state=state, batch=batch)
    # for metric, value in state.metrics.compute().items():  # compute metrics
    #     metrics_history[f'train_{metric}'].append(value)  # record metrics
    # utils.reset_metrics(state)
    
    # for batch in test_bar:
    #     state = _compute_metrics(state=state, batch=batch)
    # for metric, value in state.metrics.compute().items():  # compute metrics
    #     metrics_history[f'test_{metric}'].append(value)  # record metrics
    # utils.reset_metrics(state)
    # print(metrics_history)
    
    # train model
    cbs = __get_cbs__(state, compute_hessian=compute_hessian)
    cb_name_str = utils.get_callback_name_str(cbs)
    cb_name_list = utils.get_callback_name_list(cbs)
    # break

    experiment_name = utils.get_now() + "_" + data_name + "_" + model_name + "_" + optim_name
    
    try:
        if force_train:
            raise FileNotFoundError
        experiment_name, lse = utils.find_latest_exp(experiment_name, n_epochs, save_freq=cb_freq, 
                                                   cbs=cb_name_list, unknown_lse=True, verbose=False)
        metrics_history = utils.load_thing("traj/" + experiment_name + "/metrics.pkl")
        print(f"tr_acc: {metrics_history['train_accuracy'][-1]:0%}, te_acc: {metrics_history['test_accuracy'][-1]:0%}")
        metrics_history['lse'] = [lse]
        if compute_hessian:
            eigvals = utils.load_thing("traj/" + experiment_name + "/eigvals.pkl")
            metrics_history['eigvals'] = eigvals
            print(f"sharp: {metrics_history['eigvals'][-1][0]}")

    except FileNotFoundError:
        metrics_history = training.train_model(state, model, loss_fn, metrics_history, n_epochs, dataloaders, \
                                                   experiment_name, cbs, option=option, force_fb=False, tqdm_over_epochs=1, 
                                              eval_freq=1, gradient_accumulation=32, tqdm_over_batch=False)         
        
    all_mh.append(metrics_history)
    all_exp_names.append(experiment_name)
    
    print(experiment_name, "complete")
    print("\n ---------------------------------------------------------------------------------------------------------\n")
# Training: datasets, hps, arch_func, optim_func, cb_func, -> train model


num params 30339456
output dim (1, 1023, 50257)


In [14]:
#-----------------------------------------------------------------------------------------------------------------------------
import training
importlib.reload(training)

load_files = False
compute_hessian = False
force_train = True

all_mh = []
all_exp_names = []

for hyp in hyp_list:
    
    metrics_history = {'train_loss': [],
                   'train_accuracy': [],
                   'train_perplexity': [],
                   'test_loss': [],
                   'test_accuracy': [],
                   'test_perplexity': [],
                      }

    lr, b1, b2, b3, sam, sam_rho, sync_T = hyp[0]
    seed = hyp[1]
    option = 'sam' if sam else ""
    
    if datasets is None:
        datasets, data_name = __get_datasets__()
    
    train_loader = lib_data.NumpyLoader(datasets[0], batch_size=bs, shuffle=True)
    for sample_batch in train_loader:
        break
    
    test_loader = lib_data.NumpyLoader(datasets[1], batch_size=eval_bs)
    dataloaders = [train_loader, test_loader]
    
    model, model_name = __get_arch__()
    model_name += "_seed"+str(seed)

    optim, optim_name = __get_optim__(warmup_steps, lr, b1, b2, b3, option=option, rho=sam_rho, sync_period=sync_T)
    optim_name += f"_epochs{n_epochs}_bs{bs}"

    init_rng = jax.random.PRNGKey(seed)
    state = create_train_state(model, optim, sample_batch[0], init_rng, option=option)
    del init_rng  # Must not be used anymore.

    state = load_params(state)
    
    sample_out = state.apply_fn({'params': state.params,}, sample_batch[0], train=False)
    print("output dim", sample_out.shape)
    # break
    cbs = __get_cbs__(state, compute_hessian=compute_hessian)
    cb_name_str = utils.get_callback_name_str(cbs)
    cb_name_list = utils.get_callback_name_list(cbs)
    # break
    num_params = utils.count_params(state.params)
    print("num params", num_params)

    experiment_name = utils.get_now() + "_" + data_name + "_" + model_name + "_" + optim_name
    
    try:
        if force_train:
            raise FileNotFoundError
        experiment_name, lse = utils.find_latest_exp(experiment_name, n_epochs, save_freq=cb_freq, 
                                                   cbs=cb_name_list, unknown_lse=True, verbose=False)
        metrics_history = utils.load_thing("traj/" + experiment_name + "/metrics.pkl")
        print(f"tr_acc: {metrics_history['train_accuracy'][-1]:0%}, te_acc: {metrics_history['test_accuracy'][-1]:0%}")
        metrics_history['lse'] = [lse]
        if compute_hessian:
            eigvals = utils.load_thing("traj/" + experiment_name + "/eigvals.pkl")
            metrics_history['eigvals'] = eigvals
            print(f"sharp: {metrics_history['eigvals'][-1][0]}")

    except FileNotFoundError:
        metrics_history = training.train_model(state, model, loss_fn, metrics_history, n_epochs, dataloaders, \
                                                   experiment_name, cbs, option=option, force_fb=False, tqdm_over_epochs=1, 
                                              eval_freq=1, gradient_accumulation=1)         
        
    all_mh.append(metrics_history)
    all_exp_names.append(experiment_name)
    
    print(experiment_name, "complete")
    print("\n ---------------------------------------------------------------------------------------------------------\n")
# Training: datasets, hps, arch_func, optim_func, cb_func, -> train model


output dim (1, 127, 50257)
num params 124439808
Training model 250501-1848_wiki2_1cl_10_10_Transformer_L12_H768_Heads12_seed0_sgdFam_1b0.9_2b0.99_3b0.0_lr0.005_warmup2_epochs0_bs1


epochs: 0it [00:00, ?it/s]

Training complete 250501-1848_wiki2_1cl_10_10_Transformer_L12_H768_Heads12_seed0_sgdFam_1b0.9_2b0.99_3b0.0_lr0.005_warmup2_epochs0_bs1
250501-1848_wiki2_1cl_10_10_Transformer_L12_H768_Heads12_seed0_sgdFam_1b0.9_2b0.99_3b0.0_lr0.005_warmup2_epochs0_bs1 complete

 ---------------------------------------------------------------------------------------------------------



In [15]:
print(all_mh[0])

{'train_loss': [Array(4.0963697, dtype=float32)], 'train_accuracy': [Array(0.28031495, dtype=float32)], 'train_perplexity': [Array(60.12163, dtype=float32)], 'test_loss': [Array(4.162267, dtype=float32)], 'test_accuracy': [Array(0.32283464, dtype=float32)], 'test_perplexity': [Array(64.21695, dtype=float32)], 'lse': 0}


In [17]:
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
input_ids = tokenizer("The meaning of life is", return_tensors="jax").input_ids
out = state.apply_fn({'params': state.params}, input_ids, train=False)
decoded = tokenizer.decode(jnp.argmax(out, axis=-1)[0])
print(decoded)


 of the is not


In [18]:
# print(sample_out.shape)
token_ids = sample_batch[0]
# token_ids = jnp.argmax(nn.softmax(sample_out), axis=-1)  # shape: [batch_size, seq_len]
print(token_ids.shape)
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
for text in decoded:
    print(text)


(1, 127)
 game and follows the " Nameless " , a penal military unit serving the nation of Gallia during the Second Europan War who perform secret black operations and are pitted against the Imperial unit " Calamaty Raven " . 
 The game began development in 2010 , carrying over a large portion of the work done on Valkyria Chronicles II . While it retained the standard features of the series , it also underwent multiple adjustments , such as making the game more forgiving for series newcomers . Character designer Raita Honjou and composer Hitoshi Sakimoto both returned from previous entries , along with Valkyria Chronicles II director Takeshi


In [76]:
# token_ids = sample_batch[0]
token_ids_in = sample_batch[0]
token_ids_out = jnp.argmax(nn.softmax(sample_out), axis=-1)  # shape: [batch_size, seq_len]
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

decoded_in = tokenizer.batch_decode(token_ids_in, skip_special_tokens=True)
decoded_out = tokenizer.batch_decode(token_ids_out, skip_special_tokens=True)
# print(decoded_in)
# print(decoded_out)

from itertools import zip_longest

for in_str, out_str in zip(decoded_in, decoded_out):
    in_words = in_str.split()
    out_words = out_str.split()
    for in_word, out_word in zip_longest(in_words, out_words, fillvalue=""):
        print(f"{in_word:<15} | {out_word}")
    print("-" * 40)


=               | �---
Valkyria        | polls--
Chronicles      | Democrats
III             | ------
=               | Democrats-------------
Senjō           | Democratsn----
no              | Democrats--
Valkyria        | Democrats
3               | Democratsn-------
:               | -------
Unrecorded      | --
Chronicles      | --
(               | -
Japanese        | b
:               | ----
戦場のヴァルキュリア3     | -
,               | -
lit             | -
.               | --
Valkyria        | --
of              | -
the             | -
Battlefield     | -
3               | --
)               | -
,               | -
commonly        | --
referred        | we
to              | --
as              | we-b
Valkyria        | -
Chronicles      | --
III             | �
outside         | -
Japan           | -
,               | �
is              | �
a               | -
tactical        | �
role            | we
@-@             | I
playing         | we
video           | -
game            | �
develop

In [19]:
print(len(optim_hp_list), len(all_mh))

13 39


In [24]:
optim_names = ['SGD', 'SGD-SAM', 'ADAM', 'ADAM-SAM-R0', 'ADAM-SAM', 'ADAM-UB-1e0', 'ADAM-UB-1e0-SAM', 'ADAM-UB-5e-1','ADAM-UB-1e-1', 'ADAM-UB-5e-2', 'ADAM-UB-1e-2', 'ADAM-UB-5e-3','ADAM-UB-1e-3']
stat_names = ['train_accuracy', 'test_accuracy', 'lse']
for i in range(len(optim_hp_list)):
    stats = [ 0 for j in range(len(stat_names))]
    for j in range(len(stats)):
        for k in range(3*i, 3*i+3):
            stats[j] += all_mh[k][stat_names[j]][-1]
    out = f"{optim_names[i]}"
    for j in range(len(stats)):
        out += f", {stat_names[j]}:{stats[j]/3}"
    print(out)
    # print(optim_hp_list[i], np.mean(all_mh[3*i:3*i+3]['train_accuracy'][-1]), np.mean(all_mh[3*i:3*i+3]['test_accuracy'][-1]))

SGD, train_accuracy:0.9998698234558105, test_accuracy:0.5381667017936707, lse:60.333333333333336
SGD-SAM, train_accuracy:0.9996744990348816, test_accuracy:0.5221666693687439, lse:62.0
ADAM, train_accuracy:0.9994141459465027, test_accuracy:0.5755000114440918, lse:1326.3333333333333
ADAM-SAM-R0, train_accuracy:0.9992188215255737, test_accuracy:0.5898333787918091, lse:1417.3333333333333
ADAM-SAM, train_accuracy:0.9994140863418579, test_accuracy:0.5566667318344116, lse:1307.0
ADAM-UB-1e0, train_accuracy:0.9991536140441895, test_accuracy:0.5726667642593384, lse:1045.6666666666667
ADAM-UB-1e0-SAM, train_accuracy:0.9995443224906921, test_accuracy:0.5693333745002747, lse:1139.0
ADAM-UB-5e-1, train_accuracy:0.9994140863418579, test_accuracy:0.5721666812896729, lse:1001.0
ADAM-UB-1e-1, train_accuracy:0.9996744990348816, test_accuracy:0.5730000734329224, lse:1139.6666666666667
ADAM-UB-5e-2, train_accuracy:0.9992188215255737, test_accuracy:0.561333417892456, lse:889.0
ADAM-UB-1e-2, train_accuracy:

In [None]:

for i in range(len(all_mh)):
    print(hyp_list[i][0], all_mh[i]['train_accuracy'][-1], all_mh[i]['test_accuracy'][-1])

In [16]:

for i in range(len(all_mh)):
    print(hyp_list[i][0], all_mh[i]['train_accuracy'][-1], all_mh[i]['test_accuracy'][-1])

(0.1, 0.0, 0.0, 0.0, False, 0.0, 1) 0.99921876 0.508
(0.1, 0.0, 0.0, 0.0, True, 0.1, 1) 0.99902344 0.5245
(0.005, 0.9, 0.99, 0.0, False, 0.0, 1) 0.99921876 0.5705
(0.005, 0.9, 0.99, 0.0, True, 0.0, 1) 0.99902344 0.583
(0.005, 0.9, 0.99, 0.0, True, 0.001, 1) 0.99902344 0.57750005
(0.005, 0.9, 0.99, -1.0, False, 0.0, 1) 0.9996094 0.586
(0.005, 0.9, 0.99, -1.0, True, 0.001, 1) 0.9996094 0.56450003
(0.005, 0.9, 0.99, -0.5, False, 0.0, 1) 0.9996094 0.5755
(0.005, 0.9, 0.99, -0.1, False, 0.0, 1) 0.9998047 0.573
(0.005, 0.9, 0.99, -0.05, False, 0.0, 1) 0.9994141 0.58100003
(0.005, 0.9, 0.99, -0.01, False, 0.0, 1) 0.99902344 0.54700005
(0.005, 0.9, 0.99, -0.005, False, 0.0, 1) 0.9142578 0.3535
(0.005, 0.9, 0.99, -0.001, False, 0.0, 1) 0.09980469 0.108500004


# Analysis