In [1]:
import os
os.chdir("../")
print("Moved up")

Moved up


In [2]:
import importlib
import copy
import pickle
import typing as tp
from functools import partial

import jax
import jax.numpy as jnp
import jax.flatten_util as fu
from flax import linen as nn  # Linen API
import numpy as np
import matplotlib.pyplot as plt
import optax
import math

from tqdm import tqdm
import time

import lib_data
import utils
import modules
import callbacks

# Cap JAX's GPU memory preallocation at 90% of device memory.
# NOTE(review): JAX reads this variable when its backend is first initialised;
# the `jax.devices()` call below appears to be the first backend use in this
# notebook, so this line must stay above it to take effect.
%env XLA_PYTHON_CLIENT_MEM_FRACTION=.9
import warnings
# Hide FutureWarnings to keep cell output readable.
warnings.simplefilter(action='ignore', category=FutureWarning)

# Show which accelerators JAX sees (this presumably also initialises the backend).
print("devices", jax.devices())

# Qualitative colour maps used by the plotting cells.
dark_colours = plt.get_cmap('tab10')
shade_colours = plt.get_cmap('Set3')
all_colours = plt.get_cmap('tab20')


def light_colours(i):
    """Return the lighter partner of colour pair ``i`` from the 'tab20' map.

    'tab20' lists its colours as (dark, light) pairs, so the light member of
    pair ``i`` is at the odd index ``2 * i + 1``.
    """
    odd_index = 2 * i + 1
    return all_colours(odd_index)

2025-04-30 21:37:45.461095: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746045465.475055  351374 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746045465.479436  351374 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


env: XLA_PYTHON_CLIENT_MEM_FRACTION=.9
devices [CudaDevice(id=0)]


In [None]:
# Request deterministic GPU ops and reductions from XLA.
# NOTE(review): XLA_FLAGS is read when the JAX/XLA backend is created. The
# imports cell above already calls `jax.devices()`, so setting the variable
# here most likely has no effect in this kernel. To take effect it has to be
# set before that first backend use (e.g. at the top of the imports cell).
# %env XLA_FLAGS=--xla_gpu_deterministic_ops=true
# %env XLA_FLAGS=--xla_gpu_deterministic_reductions=true
# Both flags go in one assignment: a second `%env XLA_FLAGS=...` would replace the first.
%env XLA_FLAGS=--xla_gpu_deterministic_ops=true --xla_gpu_deterministic_reductions=true

# Dataset

In [4]:
#-----------------------------------------------------------------------------------------------------------------------------
n_out = 1

n_train: int = 20000
n_eval: int = 1000
n_hess: int = 10

importlib.reload(lib_data)


def __get_datasets():
    """Load the WikiText-2 language-modelling datasets and build a name tag.

    Uses 128-token blocks and the module-level ``n_train`` / ``n_eval`` caps.

    Returns:
        (datasets, data_name): ``datasets`` as returned by
        ``lib_data.get_wikitext2_dataset`` (index 0 is used as the training
        set, 1 as the eval set and 2 as the source of Hessian batches
        elsewhere in this notebook), and a string used in experiment names.
    """
    datasets = lib_data.get_wikitext2_dataset(block_size=128, max_train_samples=n_train, max_eval_samples=n_eval)

    print("Train:", len(datasets[0]), " Eval:", len(datasets[1]), " Hess:", len(datasets[2]))
    x, y = datasets[0][0]
    print("Input shape:", x.shape, "Target shape:", y.shape)

    data_name = "wiki2_" + str(n_out) + "cl_" + str(n_train) + "_" + str(n_eval)
    return datasets, data_name


# The other factory helpers use the `__name__` form (`__get_arch__`,
# `__get_optim__`, `__get_cbs__`), and the training loop reloads data via
# `__get_datasets__()`. Expose the loader under that name too so the call
# resolves instead of raising NameError.
__get_datasets__ = __get_datasets

datasets, data_name = __get_datasets()

Loading tokenized dataset from disk...
Building LM datasets...
Flattening token sequences...
Total 2391884 tokens. Creating 18686 chunks with stride 128
Flattening token sequences...
Total 283287 tokens. Creating 1000 chunks with stride 128
Train: 18686  Eval: 1000  Hess: 18686
Input shape: (127,) Target shape: (127,)


# Architecture

In [5]:
#-----------------------------------------------------------------------------------------------------------------------------
importlib.reload(modules)
from ml_collections import ConfigDict

class MLPBlock(nn.Module):
    """Pre-LayerNorm feed-forward sub-block that returns only the residual branch.

    LayerNorm -> Dense(mlp_expansion * d) -> GELU -> Dense(d) -> Dropout, with
    ``d`` the size of the input's last axis. The caller adds the skip
    connection.
    """
    config: ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x):
        cfg = self.config
        width = x.shape[-1]
        hidden = nn.LayerNorm(dtype=cfg.dtype)(x)
        hidden = nn.Dense(cfg.mlp_expansion * width, dtype=cfg.dtype)(hidden)
        hidden = nn.gelu(hidden)
        projected = nn.Dense(width, dtype=cfg.dtype)(hidden)
        return nn.Dropout(rate=cfg.dropout_rate)(projected, deterministic=not self.train)

def dot_product_attention(query, key, value, mask, softmax_dtype=jnp.float32):
    """Multi-head scaled dot-product attention.

    Args:
        query: array laid out as ``(..., q_len, heads, head_dim)``.
        key: array laid out as ``(..., kv_len, heads, head_dim)``.
        value: array laid out as ``(..., kv_len, heads, head_dim)``.
        mask: optional boolean array broadcastable to
            ``(..., heads, q_len, kv_len)``; ``False`` entries are excluded
            from the softmax. ``None`` disables masking.
        softmax_dtype: dtype used for the logits and the softmax.

    Returns:
        Array of shape ``(..., q_len, heads, head_dim)`` in ``value``'s dtype.
    """
    scale = query.shape[-1] ** -0.5
    # Logits and softmax are computed in `softmax_dtype` (float32 by default).
    q = query.astype(softmax_dtype) * scale
    k = key.astype(softmax_dtype)
    logits = jnp.einsum("...qhd,...khd->...hqk", q, k)
    if mask is not None:
        # Most negative finite value rather than -inf: fully masked rows stay finite.
        logits = jnp.where(mask, logits, jnp.finfo(softmax_dtype).min)
    # Cast the probabilities to the *value* dtype. `q` is already in
    # `softmax_dtype`, so casting to its dtype would be a no-op and leave the
    # output upcast (e.g. float32 instead of bfloat16).
    weights = jax.nn.softmax(logits, axis=-1).astype(value.dtype)
    return jnp.einsum("...hqk,...khd->...qhd", weights, value)

class AttentionBlock(nn.Module):
    """Pre-LayerNorm multi-head self-attention sub-block (residual branch only).

    One DenseGeneral emits, per head, a fused ``3 * head_dim`` vector that is
    split into query/key/value. The attended heads are merged back to the input
    width by a second DenseGeneral over the (heads, head_dim) axes and passed
    through dropout. The caller adds the skip connection.
    """
    config: ConfigDict
    mask: tp.Optional[jax.Array]
    train: bool

    @nn.compact
    def __call__(self, x):
        cfg = self.config
        model_width = x.shape[-1]
        normed = nn.LayerNorm(dtype=cfg.dtype)(x)

        fused_qkv = nn.DenseGeneral(
            features=(cfg.num_heads, cfg.head_dim * 3),
            axis=-1,
            dtype=cfg.dtype,
        )(normed)
        q_part, k_part, v_part = jnp.split(fused_qkv, 3, axis=-1)

        attended = dot_product_attention(
            q_part, k_part, v_part, mask=self.mask, softmax_dtype=cfg.softmax_dtype
        )
        merged = nn.DenseGeneral(
            features=model_width, axis=(-2, -1), dtype=cfg.dtype
        )(attended)
        return nn.Dropout(rate=cfg.dropout_rate)(merged, deterministic=not self.train)

class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: ``x + Attn(x)`` followed by ``x + MLP(x)``.

    ``config.remat`` may list ``"MLP"`` and/or ``"Attn"`` to rematerialise
    (gradient-checkpoint) those sub-blocks, trading recomputation for
    activation memory.
    """
    config: ConfigDict
    mask: tp.Optional[jax.Array]
    train: bool

    @nn.compact
    def __call__(self, x):
        # Turning off CSE prevention is only safe when the layer runs inside a
        # scan. For unrolled layers (scan_layers=False), XLA's common
        # subexpression elimination can merge the recomputation with the
        # forward pass and undo the memory savings of remat (see jax.checkpoint).
        prevent_cse = not self.config.get("scan_layers", False)

        mlp_cls = MLPBlock
        if "MLP" in self.config.remat:
            mlp_cls = nn.remat(mlp_cls, prevent_cse=prevent_cse)
        attn_cls = AttentionBlock
        if "Attn" in self.config.remat:
            attn_cls = nn.remat(attn_cls, prevent_cse=prevent_cse)

        x = x + attn_cls(config=self.config, mask=self.mask, train=self.train)(x)
        x = x + mlp_cls(config=self.config, train=self.train)(x)
        return x

class Transformer(nn.Module):
    """Transformer language model over integer token ids.

    Consists of a token embedding plus a learned positional embedding, then
    ``config.num_layers`` ``TransformerBlock``s. The blocks are unrolled, or
    stacked with ``nn.scan`` when ``config.scan_layers`` is set. A final
    LayerNorm and a Dense head with ``config.num_outputs`` units follow. The
    output is cast to float32.
    """
    config: ConfigDict

    @nn.compact
    def __call__(self, x, mask=None, train=True):
        """Return per-position logits for token ids ``x``.

        Args:
            x: integer token ids whose axis 1 is the sequence axis (presumably
                ``(batch, seq_len)``; the init template adds a batch axis).
            mask: optional attention mask forwarded to every block. When
                ``None`` and ``config.causal_mask`` is true, a causal mask is
                built with ``nn.make_causal_mask``.
            train: enables dropout inside the blocks (flax Dropout then draws
                from the ``'dropout'`` RNG stream).

        Raises:
            ValueError: if the sequence is longer than ``config.max_seq_len``,
                the number of learned positional embeddings.
        """
        cfg = self.config
        seq_len = x.shape[1]
        if seq_len > cfg.max_seq_len:
            # Without this check, slicing `pos_emb` below yields too few rows
            # and fails with an opaque broadcasting error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds config.max_seq_len={cfg.max_seq_len}"
            )

        if mask is None and cfg.causal_mask:
            mask = nn.make_causal_mask(x, dtype=jnp.bool_)

        x = nn.Embed(cfg.vocab_size, cfg.hidden_size, dtype=cfg.dtype)(x)
        pos_emb = self.param("pos_emb", nn.initializers.normal(0.02),
                             (cfg.max_seq_len, cfg.hidden_size)).astype(cfg.dtype)
        x += pos_emb[None, :seq_len]

        # `partial` comes from the imports cell, so this class does not depend
        # on a later cell having imported `functools`.
        block_fn = partial(TransformerBlock, config=cfg, mask=mask, train=train)

        if cfg.scan_layers:
            block = block_fn(name="block")
            x, _ = nn.scan(
                lambda module, carry, _: (module(carry), None),
                variable_axes={"params": 0},
                split_rngs={"params": True, "dropout": True},
                length=cfg.num_layers
            )(block, x, ())
        else:
            for i in range(cfg.num_layers):
                x = block_fn(name=f"block_{i}")(x)

        x = nn.LayerNorm(dtype=cfg.dtype)(x)
        x = nn.Dense(cfg.num_outputs, dtype=cfg.dtype)(x)
        return x.astype(jnp.float32)

In [6]:
def __get_arch__():
    """Construct the Transformer language model and a descriptive name for it.

    Returns:
        (model, model_name): an un-initialised ``Transformer`` and a tag of the
        form ``Transformer_L<layers>_H<hidden>_Heads<heads>``.
    """
    cfg = ConfigDict()
    cfg.vocab_size = 50257       # embedding table size
    cfg.hidden_size = 256        # model width
    cfg.num_layers = 4
    cfg.num_heads = 4
    cfg.head_dim = 64            # num_heads * head_dim == hidden_size here
    cfg.mlp_expansion = 4        # MLP hidden width = mlp_expansion * hidden_size
    cfg.dropout_rate = 0.1
    cfg.max_seq_len = 512        # number of learned positional embeddings
    cfg.num_outputs = 50257      # logits per position (equal to vocab_size)
    cfg.dtype = jnp.float32      # computation dtype passed to the flax layers
    cfg.causal_mask = True       # build a causal mask when none is passed
    cfg.softmax_dtype = jnp.float32
    cfg.remat = ["MLP", "Attn"]  # sub-blocks to rematerialise
    cfg.scan_layers = False      # unrolled layers rather than nn.scan

    name = "Transformer_L{}_H{}_Heads{}".format(cfg.num_layers, cfg.hidden_size, cfg.num_heads)
    return Transformer(cfg), name

model_arch, model_name = __get_arch__()
print(model_name)

Transformer_L4_H256_Heads4


# Optimizer

In [7]:
#-----------------------------------------------------------------------------------------------------------------------------
importlib.reload(modules)
from optax import contrib

def __get_optim__(warmup_steps, lr, b1, b2, b3, option="", rho=None, sync_period=1):
    """Build the optimizer and a name tag describing it.

    The learning rate ramps linearly from 0 to ``lr`` over ``warmup_steps``
    steps and then stays at ``lr``. The update rule comes from
    ``modules.get_sgd_optimizer``, which receives ``b1``, ``b2`` and ``b3``.

    Args:
        option: ``'sam'`` wraps the base optimizer with ``optax.contrib.sam``
            using an adversarial optimizer of step size ``rho`` (required in
            that case). Any other value returns the plain base optimizer.
        sync_period: SAM synchronisation period (only used with ``'sam'``).

    Returns:
        (optimizer, optim_name)
    """
    warmup = optax.linear_schedule(init_value=0.0, end_value=lr,
                                   transition_steps=warmup_steps, transition_begin=0,)
    lr_schedule = optax.join_schedules([warmup, optax.constant_schedule(lr)],
                                       boundaries=[warmup_steps])

    if option != 'sam':
        plain_opt = modules.get_sgd_optimizer(lr_schedule, b1, b2, b3, verbose=False)
        return plain_opt, f"sgdFam_1b{b1}_2b{b2}_3b{b3}_lr{lr}_warmup{warmup_steps}"

    assert rho is not None
    base_opt = modules.get_sgd_optimizer(lr_schedule, b1, b2, b3, verbose=False)
    adv_opt = modules.get_sgd_optimizer(rho, b1, b2, b3, verbose=False)
    sam_opt = contrib.sam(base_opt, adv_opt, sync_period=sync_period, opaque_mode=True)
    return sam_opt, f"sgdFam-SAM_1b{b1}_2b{b2}_3b{b3}_lr{lr}_warmup{warmup_steps}_rho{rho}_syncT{sync_period}"

# Default optimizer (lr 0.1, 2 warmup steps, all b's 0); each training run builds its own.
optimizer, optim_name = __get_optim__(2, 0.1, 0, 0, 0)

# Training Hyperparameters (Fixed and Tuned)

In [8]:
#-----------------------------------------------------------------------------------------------------------------------------
import itertools
# --- Fixed settings shared by every run ---
warmup_steps = 2
bs = 32            # training batch size
eval_bs = 32       # evaluation batch size
n_epochs = 2000    # upper bound on epochs (earlyStopCB's max_eps)
loss_fn = optax.softmax_cross_entropy_with_integer_labels

# --- Swept settings ---
# Each optimizer tuple is (lr, b1, b2, b3, use_sam, sam_rho, sam_sync_period),
# matching the unpacking in the training loop.
optim_hp_list = [
    (5e-3, 0.9, 0.99, 0., False, 0., 1),
]
seed_list = list(range(1))

# Every (optimizer settings, seed) combination; seeds vary fastest.
s = [optim_hp_list, seed_list]
hyp_list = [*itertools.product(*s)]
print(len(hyp_list))

1


# Callbacks

In [9]:
#-----------------------------------------------------------------------------------------------------------------------------
sws = 5                 # saveWeightsCB setting; also earlyStopCB's min_eps
cb_freq = 1             # save frequency used when locating previously saved runs
hess_freq = int(1e8)    # effectively never during training; see final_cbs below
importlib.reload(callbacks)


def __get_cbs__(state, compute_hessian=False):
    """Assemble the callbacks for one training run.

    Always returns a weight/gradient saver followed by an early-stopping
    callback. With ``compute_hessian``, Hessian-vector-product and spectrum
    callbacks are inserted between them and also handed to the early stopper
    as ``final_cbs``.

    NOTE: reads the module globals ``datasets``, ``n_hess``, ``loss_fn``,
    ``hess_freq``, ``sws``, ``n_epochs`` and, when ``compute_hessian`` is set,
    ``seed`` (assigned by the training loop before this is called).
    """
    save_cb = callbacks.saveWeightsCB(sws, grad=True)

    if not compute_hessian:
        stop_cb = callbacks.earlyStopCB(acc_threshold=0.999, cbs=None, min_eps=sws, max_eps=n_epochs, conseq_eps=5,
                                        verbose=False, low_eps=max(sws, 100), low_thresh=0., )
        return [save_cb, stop_cb]

    hess_data = datasets[2]
    hess_batches = (hess_data.data[:n_hess], hess_data.targets[:n_hess])
    hvp_cb = callbacks.hvpCB(loss_fn=loss_fn, batches=hess_batches,
                             save_freq=hess_freq, hess_bs=n_hess, state=state, bn=False)
    spec_cb = callbacks.spectrumCB(n_eigs=20, n_evecs=10,
                                   loss_fn=loss_fn, seed=seed, hvpCB=hvp_cb, save_freq=hess_freq, verbose=False)
    stop_cb = callbacks.earlyStopCB(acc_threshold=0.999, cbs=None, min_eps=sws, max_eps=n_epochs, conseq_eps=3,
                                    final_cbs=[hvp_cb, spec_cb], verbose=False, low_eps=max(sws, 100), low_thresh=0.11, )
    return [save_cb, hvp_cb, spec_cb, stop_cb]


# Train State

In [16]:
import flax
import jax.numpy as jnp
from clu.metrics import Metric
import optax


@flax.struct.dataclass
class Perplexity(Metric):
  """Computes perplexity from logits and integer labels.

  This assumes logits have shape [..., vocab_size] and labels have shape [...].
  The state is the summed token-level negative log-likelihood plus the number
  of tokens seen; ``compute()`` returns ``exp(total / count)``.
  """

  # Sum of per-token cross-entropy (negative log-likelihood), float32.
  total_log_likelihood: jnp.ndarray
  # Number of tokens accumulated, int32.
  total_tokens: jnp.ndarray

  # Annotations naming this class are quoted. Without
  # `from __future__ import annotations`, a bare `-> Perplexity` is evaluated
  # while the class body is still executing, before the name is bound, and
  # raises NameError on a fresh kernel.

  @classmethod
  def empty(cls) -> "Perplexity":
    """Return the identity state: zero NLL over zero tokens."""
    return cls(
        total_log_likelihood=jnp.array(0.0, dtype=jnp.float32),
        total_tokens=jnp.array(0, dtype=jnp.int32)
    )

  @classmethod
  def from_model_output(
      cls, *, logits: jnp.ndarray, labels: jnp.ndarray, **kwargs
  ) -> "Perplexity":
    """Build a state from one batch of model outputs.

    Args:
      logits: ``[..., vocab_size]`` unnormalised scores.
      labels: ``[...]`` int32 target ids.
      **kwargs: other model outputs, ignored.

    Raises:
      ValueError: if ranks don't line up or labels are not int32.
    """
    if logits.ndim != labels.ndim + 1 or labels.dtype != jnp.int32:
      raise ValueError(
          f"Expected labels.dtype==jnp.int32 and logits.ndim={logits.ndim}=="
          f"labels.ndim+1={labels.ndim + 1}"
      )

    # Flatten for token-level cross entropy
    vocab_size = logits.shape[-1]
    logits_flat = logits.reshape(-1, vocab_size)
    labels_flat = labels.reshape(-1)

    # Negative log-likelihood per token
    nll = optax.softmax_cross_entropy_with_integer_labels(logits_flat, labels_flat)

    return cls(
        total_log_likelihood=jnp.sum(nll),
        # Same dtype as `empty()`, so every state has identically typed leaves.
        total_tokens=jnp.array(labels_flat.size, dtype=jnp.int32)
    )

  def merge(self, other: "Perplexity") -> "Perplexity":
    """Combine two states by summing NLL totals and token counts."""
    return Perplexity(
        total_log_likelihood=self.total_log_likelihood + other.total_log_likelihood,
        total_tokens=self.total_tokens + other.total_tokens
    )

  def compute(self) -> jnp.ndarray:
    """Return ``exp(mean token NLL)`` (NaN when no tokens have been seen)."""
    avg_nll = self.total_log_likelihood / self.total_tokens
    return jnp.exp(avg_nll)


In [17]:
from flax import struct                # Flax dataclasses
from clu import metrics
from flax.training import train_state  # Useful dataclass to keep train state
importlib.reload(modules)

@struct.dataclass
class Metrics(metrics.Collection):
    """clu metric collection tracked for training and evaluation."""
    accuracy: metrics.Accuracy                  # computed from `logits` / `labels`
    perplexity: Perplexity                      # exp(mean token NLL), defined in the previous cell
    loss: metrics.Average.from_output('loss')   # running mean of the reported `loss`

class TrainState(train_state.TrainState):
    """flax TrainState plus a metrics collection and a PRNG key (option "")."""
    metrics: Metrics
    rng: jax.Array  # presumably consumed for dropout by the training step; confirm in `training`

class TrainStateBN(train_state.TrainState):
    """TrainState variant that also carries BatchNorm statistics (option "bn")."""
    metrics: Metrics
    batch_stats: tp.Any  # the model's 'batch_stats' collection
    rng: jax.Array

class TrainStateSAM(modules.TrainStateSAM):
    """SAM train state from `modules`, extended like the others (option "sam")."""
    metrics: Metrics
    batch_stats: tp.Any  # the model's 'batch_stats' collection
    rng: jax.Array

def create_train_state(model, optimizer, inputs, rng, option=""):
    """Creates an initial `TrainState`.

    The model is initialised on an all-ones template shaped like one example
    (``inputs[0]`` with a leading batch axis of 1), so only its shape and
    dtype matter. ``rng`` is split into an initialisation key and the key
    stored in the returned state.

    Args:
        model: flax module to initialise.
        optimizer: optax transformation stored as the state's ``tx``.
        inputs: a batch of example inputs (the training loop passes
            ``sample_batch[0]``); only ``inputs[0]`` is used.
        rng: JAX PRNG key.
        option: ``""`` -> ``TrainState``; ``"bn"`` -> ``TrainStateBN`` (the
            model must produce a ``batch_stats`` collection); ``"sam"`` ->
            ``TrainStateSAM``.

    Returns:
        A freshly initialised train state with empty metrics.

    Raises:
        NotImplementedError: for an unknown ``option``.
        KeyError: for ``option == "bn"`` when the model has no ``batch_stats``.
    """
    if option not in ("", "bn", "sam"):
        raise NotImplementedError

    rng, model_rng = jax.random.split(rng)
    # All-ones template with a batch axis of 1 (for the token model here: token id 1).
    template = jnp.ones_like(inputs[0][jnp.newaxis, :])
    variables = model.init(model_rng, template)
    params = variables['params']

    if option == "":
        return TrainState.create(
          apply_fn=model.apply, params=params, tx=optimizer, metrics=Metrics.empty(), rng=rng)

    if option == "bn":
        return TrainStateBN.create(
          apply_fn=model.apply, params=params, tx=optimizer, batch_stats=variables['batch_stats'],
          metrics=Metrics.empty(), rng=rng)

    # option == "sam": SAM is an optimizer choice and does not imply BatchNorm.
    # Models without normalisation statistics (such as the Transformer in this
    # notebook) have no 'batch_stats' collection, so use an empty one instead
    # of raising KeyError.
    # NOTE(review): confirm modules.TrainStateSAM and the SAM training step
    # accept an empty batch_stats.
    return TrainStateSAM.create(
      apply_fn=model.apply, params=params, tx=optimizer,
      batch_stats=variables.get('batch_stats', {}),
      metrics=Metrics.empty(), rng=rng)
        
        

# Training

In [18]:
import functools

In [None]:
#-----------------------------------------------------------------------------------------------------------------------------
import training
importlib.reload(training)

# Training: datasets, hps, arch_func, optim_func, cb_func -> train model.
load_files = False
compute_hessian = False
force_train = True  # True: always train; False: reuse the latest saved run when one is found

all_mh = []
all_exp_names = []


def _new_metrics_history():
    """Return an empty per-run metrics history for `training.train_model` to fill."""
    return {
        'train_loss': [],
        'train_accuracy': [],
        'train_perplexity': [],
        'test_loss': [],
        'test_accuracy': [],
        'test_perplexity': [],
    }


for hyp in hyp_list:
    lr, b1, b2, b3, sam, sam_rho, sync_T = hyp[0]
    seed = hyp[1]  # also read as a global by __get_cbs__
    option = 'sam' if sam else ""

    if datasets is None:
        # The dataset cell defines the loader as `__get_datasets`.
        datasets, data_name = __get_datasets()

    train_loader = lib_data.NumpyLoader(datasets[0], batch_size=bs, shuffle=True)
    # One example batch; only its shape is used to initialise the model.
    sample_batch = next(iter(train_loader))

    test_loader = lib_data.NumpyLoader(datasets[1], batch_size=eval_bs)
    dataloaders = [train_loader, test_loader]

    model, model_name = __get_arch__()
    model_name += "_seed" + str(seed)

    optim, optim_name = __get_optim__(warmup_steps, lr, b1, b2, b3, option=option, rho=sam_rho, sync_period=sync_T)
    optim_name += f"_epochs{n_epochs}_bs{bs}"

    init_rng = jax.random.PRNGKey(seed)
    state = create_train_state(model, optim, sample_batch[0], init_rng, option=option)
    del init_rng  # Must not be used anymore.

    cbs = __get_cbs__(state, compute_hessian=compute_hessian)
    cb_name_str = utils.get_callback_name_str(cbs)
    cb_name_list = utils.get_callback_name_list(cbs)
    num_params = utils.count_params(state.params)
    print("num params", num_params)

    experiment_name = utils.get_now() + "_" + data_name + "_" + model_name + "_" + optim_name

    # Try to reuse a previously saved run unless training is forced.
    metrics_history = None
    if not force_train:
        try:
            found_name, lse = utils.find_latest_exp(experiment_name, n_epochs, save_freq=cb_freq,
                                                    cbs=cb_name_list, unknown_lse=True, verbose=False)
            loaded = utils.load_thing("traj/" + found_name + "/metrics.pkl")
            print(f"tr_acc: {loaded['train_accuracy'][-1]:0%}, te_acc: {loaded['test_accuracy'][-1]:0%}")
            loaded['lse'] = [lse]
            if compute_hessian:
                loaded['eigvals'] = utils.load_thing("traj/" + found_name + "/eigvals.pkl")
                print(f"sharp: {loaded['eigvals'][-1][0]}")
            # Commit only after everything loaded, so a partial load never
            # leaks a reused name or history into a fresh training run.
            experiment_name, metrics_history = found_name, loaded
        except FileNotFoundError:
            pass  # No complete saved run: fall through to training below.

    if metrics_history is None:
        metrics_history = training.train_model(state, model, loss_fn, _new_metrics_history(), n_epochs, dataloaders,
                                               experiment_name, cbs, option=option, force_fb=False, tqdm_over_epochs=1,
                                               eval_freq=1, gradient_accumulation=16)

    all_mh.append(metrics_history)
    all_exp_names.append(experiment_name)

    print(experiment_name, "complete")
    print("\n ---------------------------------------------------------------------------------------------------------\n")


In [19]:
# Sanity check: optimizer settings vs. finished runs (expect runs = settings * len(seed_list)).
n_settings, n_runs = len(optim_hp_list), len(all_mh)
print(n_settings, n_runs)

13 39


In [24]:
optim_names = ['SGD', 'SGD-SAM', 'ADAM', 'ADAM-SAM-R0', 'ADAM-SAM', 'ADAM-UB-1e0', 'ADAM-UB-1e0-SAM', 'ADAM-UB-5e-1','ADAM-UB-1e-1', 'ADAM-UB-5e-2', 'ADAM-UB-1e-2', 'ADAM-UB-5e-3','ADAM-UB-1e-3']
stat_names = ['train_accuracy', 'test_accuracy', 'lse']

# Mean over seeds of the final value of each statistic, per optimizer setting.
# Runs are ordered as itertools.product(optim_hp_list, seed_list), so the runs
# of setting i are the len(seed_list) consecutive entries starting at
# i * len(seed_list).
n_seeds = len(seed_list)
# The hand-written names label an earlier 13-setting sweep. Use them only
# while they line up with optim_hp_list; otherwise label by the raw settings.
row_labels = (optim_names if len(optim_names) == len(optim_hp_list)
              else [str(hp) for hp in optim_hp_list])

for i, label in enumerate(row_labels):
    runs = all_mh[i * n_seeds:(i + 1) * n_seeds]
    if not runs or len(runs) < n_seeds:
        print(f"{label}: only {len(runs)} of {n_seeds} runs available, skipped")
        continue
    out = f"{label}"
    for stat in stat_names:
        # 'lse' is only recorded for runs reloaded from disk; skip a statistic
        # that some run lacks instead of raising KeyError.
        if all(stat in mh for mh in runs):
            out += f", {stat}:{sum(mh[stat][-1] for mh in runs) / n_seeds}"
    print(out)

SGD, train_accuracy:0.9998698234558105, test_accuracy:0.5381667017936707, lse:60.333333333333336
SGD-SAM, train_accuracy:0.9996744990348816, test_accuracy:0.5221666693687439, lse:62.0
ADAM, train_accuracy:0.9994141459465027, test_accuracy:0.5755000114440918, lse:1326.3333333333333
ADAM-SAM-R0, train_accuracy:0.9992188215255737, test_accuracy:0.5898333787918091, lse:1417.3333333333333
ADAM-SAM, train_accuracy:0.9994140863418579, test_accuracy:0.5566667318344116, lse:1307.0
ADAM-UB-1e0, train_accuracy:0.9991536140441895, test_accuracy:0.5726667642593384, lse:1045.6666666666667
ADAM-UB-1e0-SAM, train_accuracy:0.9995443224906921, test_accuracy:0.5693333745002747, lse:1139.0
ADAM-UB-5e-1, train_accuracy:0.9994140863418579, test_accuracy:0.5721666812896729, lse:1001.0
ADAM-UB-1e-1, train_accuracy:0.9996744990348816, test_accuracy:0.5730000734329224, lse:1139.6666666666667
ADAM-UB-5e-2, train_accuracy:0.9992188215255737, test_accuracy:0.561333417892456, lse:889.0
ADAM-UB-1e-2, train_accuracy:

In [None]:

# Final train/test accuracy of every run, labelled by its optimizer settings.
for i, run_history in enumerate(all_mh):
    settings = hyp_list[i][0]
    print(settings, run_history['train_accuracy'][-1], run_history['test_accuracy'][-1])

In [16]:

# Per-run (optimizer settings, final train accuracy, final test accuracy).
for i in range(len(all_mh)):
    settings, run = hyp_list[i][0], all_mh[i]
    print(settings, run['train_accuracy'][-1], run['test_accuracy'][-1])

(0.1, 0.0, 0.0, 0.0, False, 0.0, 1) 0.99921876 0.508
(0.1, 0.0, 0.0, 0.0, True, 0.1, 1) 0.99902344 0.5245
(0.005, 0.9, 0.99, 0.0, False, 0.0, 1) 0.99921876 0.5705
(0.005, 0.9, 0.99, 0.0, True, 0.0, 1) 0.99902344 0.583
(0.005, 0.9, 0.99, 0.0, True, 0.001, 1) 0.99902344 0.57750005
(0.005, 0.9, 0.99, -1.0, False, 0.0, 1) 0.9996094 0.586
(0.005, 0.9, 0.99, -1.0, True, 0.001, 1) 0.9996094 0.56450003
(0.005, 0.9, 0.99, -0.5, False, 0.0, 1) 0.9996094 0.5755
(0.005, 0.9, 0.99, -0.1, False, 0.0, 1) 0.9998047 0.573
(0.005, 0.9, 0.99, -0.05, False, 0.0, 1) 0.9994141 0.58100003
(0.005, 0.9, 0.99, -0.01, False, 0.0, 1) 0.99902344 0.54700005
(0.005, 0.9, 0.99, -0.005, False, 0.0, 1) 0.9142578 0.3535
(0.005, 0.9, 0.99, -0.001, False, 0.0, 1) 0.09980469 0.108500004


# Analysis