In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import easydel as ed

2025-09-02 11:04:19,151	INFO worker.py:1942 -- Started a local Ray instance. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


In [2]:
import jax
import jax.numpy as jnp
from gidd_easydel.model import GiddForDiffusionLM, GiddConfig

hidden_size = 512
num_layers = 8

model = GiddForDiffusionLM(
    config=GiddConfig(
        vocab_size=131072,
        hidden_size=hidden_size,
        intermediate_size=4*hidden_size,
        num_hidden_layers=num_layers,
        num_attention_heads=8,
        head_dim=64,
        is_causal=False,
        max_position_embeddings=512,
        resid_scale=4.0,
        init_scale=0.4 / hidden_size**0.5,
        emb_init_scale=0.02,
        head_init_scale=0.0,
        weight_scaling=1.0,
        head_scaling=512 / hidden_size,
        use_qk_norm=True,
        gradient_checkpointing=ed.EasyDeLGradientCheckPointers.NONE,
        attn_mechanism="vanilla",
        attn_dtype=jnp.bfloat16,
        attention_bias=True,
        mlp_bias=True,
        # scan_layers=True,
    ),
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    precision=jax.lax.Precision.HIGH,
    rngs=ed.Rngs(0),
)
# model = model.shard_model()  # Shard the newly created model across devices.
model_state = model.to_state()

E0000 00:00:1756803864.254794 4063027 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1756803864.258629 4063027 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1756803864.269895 4063027 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1756803864.269908 4063027 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1756803864.269909 4063027 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1756803864.269910 4063027 computation_placer.cc:177] computation placer already registered. Please check linka

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

In [3]:
import optax
import typing as tp
import chex

from copy import deepcopy
from gidd_easydel.diffusion_trainer import DiffusionTrainer, DiffusionConfig
from gidd_easydel.optimizer import lapropw

total_steps = 100000
warmup_steps = 2000
cooldown_steps = 10000
lr = 0.1
aux_lr = 0.02
adam_eps = 1e-8
weight_decay = 0.0



def wsd_lr_schedule(total_steps: int, base_lr: float, warmup_steps: int = 0, cooldown_steps: int = 0) -> tp.Callable[[chex.Numeric], chex.Numeric]:
    """
    Implements a warmup-stable-decay learning rate schedule.
    
    Args:
        base_lr (float): Base learning rate.
        warmup_steps (int): Number of steps for warmup.
        cooldown_steps (int): Number of steps for decay.
        total_steps (int): Total number of training steps.
        curr_step (int): Current training step.
    
    Returns:
        float: Adjusted learning rate for the current step.
    """

    def sqrt1m_schedule(init_value: float, decay_steps: int) -> tp.Callable[[chex.Numeric], chex.Numeric]:
        def schedule(count: chex.Numeric) -> chex.Numeric:
            count = jnp.minimum(count, decay_steps)
            return init_value * (1 - (count / max(1, decay_steps))**0.5)
        return schedule

    return optax.schedules.join_schedules([
        optax.schedules.linear_schedule(0, base_lr, warmup_steps),
        optax.schedules.constant_schedule(base_lr),
        sqrt1m_schedule(base_lr, cooldown_steps),
    ], [warmup_steps, total_steps - cooldown_steps])


class CustomDiffusionConfig(DiffusionConfig):
        # Hacky: override the `get_optimizer_and_scheduler` method to implement per-layer learning rates
        def get_optimizer_and_scheduler(self, steps):
            optimizer_kwargs = deepcopy(self.optimizer_kwargs)
            clip_grad = optimizer_kwargs.pop("clip_grad", None)

            bulk_schedule = wsd_lr_schedule(
                total_steps=total_steps,
                base_lr=lr / hidden_size,
                warmup_steps=warmup_steps,
                cooldown_steps=cooldown_steps,
            )
            aux_schedule = wsd_lr_schedule(
                total_steps=total_steps,
                base_lr=aux_lr,
                warmup_steps=warmup_steps,
                cooldown_steps=cooldown_steps,
            )

            def param_label_fn(params: tp.Any) -> str:
                def label_leaf(path: str, param: chex.Array) -> str:
                    path = ''.join(str(k) for k in path)

                    if "norm" in path:
                        return "ln_params"
                    elif "embed_tokens" in path:
                        return "emb_unemb_params"
                    elif "bias" in path:
                        return "bias_params"
                    elif "lm_head" in path:
                        return "emb_unemb_params"
                    elif param.ndim > 1:
                        return "bulk_params"
                    else:
                        raise ValueError(f"Unknown parameter type: {path}")

                labels = jax.tree.map_with_path(label_leaf, params)
                return labels


            opt_kwargs = dict(b1=0.9, b2=0.99, eps=adam_eps / hidden_size / num_layers)
            optimizer = optax.multi_transform({
                "bulk_params": lapropw(learning_rate=bulk_schedule, weight_decay=weight_decay * hidden_size, **opt_kwargs),
                "ln_params": lapropw(learning_rate=aux_schedule, weight_decay=0.0, **opt_kwargs),
                "bias_params": lapropw(learning_rate=aux_schedule, weight_decay=0.0, **opt_kwargs),
                "emb_unemb_params": lapropw(learning_rate=aux_schedule, weight_decay=0.0, **opt_kwargs),
            }, param_label_fn)

            if clip_grad:
                tx = optax.chain(
                    optax.clip_by_global_norm(clip_grad),
                    optimizer,
                )
            else:
                tx = optimizer

            if optimizer_kwargs.get("gradient_accumulation_steps", 0) > 1:
                tx = optax.MultiSteps(tx, optimizer_kwargs["gradient_accumulation_steps"])
            
            # the LR schedule returned here is only used for logging purposes
            return optimizer, bulk_schedule

In [4]:
arguments = CustomDiffusionConfig(
    ## Trainer arguments
    model_name="gidd",  # for wandb project name
    use_wandb=True,
    num_train_epochs=1,
    total_batch_size=8,
    do_last_save=True,
    max_sequence_length=512,
    # This is MANDATORY for streaming datasets. It tells the trainer how many
    # steps constitute one "epoch". Should be ~ (total_dataset_size // total_batch_size).
    per_epoch_training_steps=98_000_000,
    max_training_steps=total_steps,
    learning_rate=lr / hidden_size,
    optimizer=ed.EasyDeLOptimizers.ADAMW,
    scheduler=ed.EasyDeLSchedulers.COSINE,
    warmup_steps=warmup_steps,
    weight_decay=0.02,
    save_optimizer_state=True,
    clip_grad=1.0,
    report_steps=50,
    log_steps=100,
    metrics_aggregation="mean",
    # progress_bar_type="json",
    use_grain=False,
)

In [5]:
from transformers import AutoTokenizer
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("dvruette/nemotron-cc-bpe")

train_dataset = load_dataset("dvruette/nemotron-cc-65btok", split="train", streaming=True)

trainer = DiffusionTrainer(
        arguments=arguments,
        model_state=model_state,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=None,
        seed=0,
        dtype=jnp.bfloat16,
    )

Resolving data files:   0%|          | 0/300 [00:00<?, ?it/s]

[34m[1mwandb[0m: Currently logged in as: [33mdvruette[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[38;5;99m(11:04:37 easydel.utils.helpers)[0m time took for configure dataloaders : [92m3.2410 sec[0m
[38;5;99m(11:04:37 easydel.utils.helpers)[0m time took for configure Model, Optimizer, Scheduler and Config : [94m0.1419 ms[0m
[38;5;99m(11:04:37 easydel.utils.helpers)[0m time took for configure sharded state : [94m333.0500 ms[0m
[38;5;99m(11:04:37 gidd_easydel.diffusion_trainer.diffusion_trainer)[0m Configuring functions for DiffusionTrainer...
[38;5;99m(11:04:37 gidd_easydel.diffusion_trainer.diffusion_trainer)[0m Functions configured successfully.
[38;5;99m(11:04:37 easydel.utils.helpers)[0m time took for configure functions and sharding them : [94m29.9489 ms[0m
[38;5;99m(11:04:37 gidd_easydel.diffusion_trainer.diffusion_trainer)[0m Initialized DiffusionTrainer


In [6]:
trainer.model_state.opt_state.__class__

optax.transforms._combining.PartitionState

In [33]:
from jax.experimental.array_serialization import serialization as array_serialization

In [39]:
trainer.model_state.opt_state.inner_states['bias_params'].inner_state[1]


EmptyState()

In [42]:
import optax
trainer.model_state.opt_state.inner_states['bias_params'].inner_state[1] == optax.EmptyState()

True

In [48]:
import jax

path_value_pairs, treedef = jax.tree_util.tree_flatten_with_path(trainer.model_state.opt_state)

path_strs = ["".join([str(p) for p in path]) for path, _ in path_value_pairs]

In [62]:
path_value_pairs[0][0]

(GetAttrKey(name='inner_states'),
 DictKey(key='bias_params'),
 GetAttrKey(name='inner_state'),
 SequenceKey(idx=0),
 GetAttrKey(name='count'))

In [50]:
treedef

PyTreeDef(CustomNode(namedtuple[PartitionState], [{'bias_params': CustomNode(namedtuple[MaskedState], [(CustomNode(namedtuple[ScaleByLapropState], [*, CustomNode(State[('lm_head', 'model')], [{'kernel': CustomNode(VariableState[(<class 'flax.nnx.variablelib.Param'>, ())], [CustomNode(namedtuple[MaskedNode], [])])}, {'embed_tokens': {'embedding': CustomNode(VariableState[(<class 'flax.nnx.variablelib.Param'>, ())], [CustomNode(namedtuple[MaskedNode], [])])}, 'layers': {0: {'attn_layernorm': {'kernel': CustomNode(VariableState[(<class 'flax.nnx.variablelib.Param'>, ())], [CustomNode(namedtuple[MaskedNode], [])])}, 'mlp': {'down_proj': {'bias': CustomNode(VariableState[(<class 'flax.nnx.variablelib.Param'>, ())], [*]), 'kernel': CustomNode(VariableState[(<class 'flax.nnx.variablelib.Param'>, ())], [CustomNode(namedtuple[MaskedNode], [])])}, 'up_proj': {'bias': CustomNode(VariableState[(<class 'flax.nnx.variablelib.Param'>, ())], [*]), 'kernel': CustomNode(VariableState[(<class 'flax.nnx.v

In [49]:
import random

# random.shuffle(path_value_pairs)

jax.tree_util.tree_unflatten(treedef, [val for _, val in path_value_pairs])

PartitionState(inner_states={'bias_params': MaskedState(inner_state=(ScaleByLapropState(count=Array(0, dtype=int32), mu=State({
  'lm_head': {
    'kernel': VariableState(
      type=Param,
      value=()
    )
  },
  'model': {
    'embed_tokens': {
      'embedding': VariableState(
        type=Param,
        value=()
      )
    },
    'layers': {
      0: {
        'attn_layernorm': {
          'kernel': VariableState(
            type=Param,
            value=()
          )
        },
        'mlp': {
          'down_proj': {
            'bias': VariableState( # 512 (2.0 KB)
              type=Param,
              value=Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                     0., 

In [37]:
path_value_pairs.__class__

list

In [30]:

# print("filtered path/vals: ", [(path_str, val) for path_str, (path, val) in zip(path_strs, path_value_pairs) if "['inner_state']['1']" in path_str])
path_strs

[".inner_states['bias_params'].inner_state[0].count",
 ".inner_states['bias_params'].inner_state[0].mu['model']['layers'][0]['mlp']['down_proj']['bias'].value",
 ".inner_states['bias_params'].inner_state[0].mu['model']['layers'][0]['mlp']['up_proj']['bias'].value",
 ".inner_states['bias_params'].inner_state[0].mu['model']['layers'][0]['self_attn']['k_bias'].value",
 ".inner_states['bias_params'].inner_state[0].mu['model']['layers'][0]['self_attn']['v_bias'].value",
 ".inner_states['bias_params'].inner_state[0].mu['model']['layers'][1]['mlp']['down_proj']['bias'].value",
 ".inner_states['bias_params'].inner_state[0].mu['model']['layers'][1]['mlp']['up_proj']['bias'].value",
 ".inner_states['bias_params'].inner_state[0].mu['model']['layers'][1]['self_attn']['k_bias'].value",
 ".inner_states['bias_params'].inner_state[0].mu['model']['layers'][1]['self_attn']['v_bias'].value",
 ".inner_states['bias_params'].inner_state[0].mu['model']['layers'][2]['mlp']['down_proj']['bias'].value",
 ".inne

In [11]:
import importlib
from eformer.pytree import PyTree, is_flatten, serialization, int_key_to_string
import eformer.pytree._tree_util
importlib.reload(eformer.pytree._tree_util)
from eformer.pytree._tree_util import flatten_dict, unflatten_dict

serialized = serialization.to_state_dict(trainer.model_state.opt_state)
# from_state_dict(trainer.model_state.opt_state, serialized)

In [8]:
serialized["inner_states"]["bulk_params"]["inner_state"]['2']

{'count': Array(0, dtype=int32)}

In [9]:
int_key_to_string(serialized["inner_states"]["bulk_params"]["inner_state"]['1'])

entering: keep_empty_nodes=True
testing instance: keep_empty_nodes=True
inside: keep_empty_nodes=True
empty: () True
keep empty: ()


{}

In [13]:
flat_dict = flatten_dict(serialized, keep_empty_nodes=True, sep=".")

In [17]:
unflat_dict = unflatten_dict(flat_dict, sep=".")

In [19]:
unflat_dict["inner_states"]["bulk_params"]["inner_state"]['1']

{}

In [29]:
trainer.model_state.opt_state.inner_states["bulk_params"][0][2]

ScaleByScheduleState(count=Array(0, dtype=int32))

In [51]:
trainer.model_state.opt_state.__class__, trainer.model_state.graphstate.__class__

(optax.transforms._combining.PartitionState, flax.nnx.statelib.State)

In [50]:
import gidd_easydel.model

import jax
import jax.numpy as jnp

path = "outputs/diffusion_trainer/testing/2025-09-01/gidd-L8-D512-H8-N512-bs=8-lr=0.1-testing/16-34-29/gidd/run-1000"

model_state = ed.EasyDeLState.load_state(
    path,
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    precision=jax.lax.Precision.HIGH,
)

RuntimeError: await wasn't used with future

In [7]:
ckpt_dir = "/pub/hofmann-scratch/dvruette/gidd-checkpoints/gidd-checkpoints_us-east5/cooldown_exp_1/2025-08-30/gidd-L12-D768-H12-N2048-bs=128-T=16k-lr=0.5-cd=0.2/07-18-15"

In [9]:
from gidd_easydel.train import get_latest_checkpoint

ckpt_path, _ = get_latest_checkpoint(os.path.join(ckpt_dir, "gidd"))
ckpt_path

LocalPath('/pub/hofmann-scratch/dvruette/gidd-checkpoints/gidd-checkpoints_us-east5/cooldown_exp_1/2025-08-30/gidd-L12-D768-H12-N2048-bs=128-T=16k-lr=0.5-cd=0.2/07-18-15/gidd/run-16000')

In [None]:
import jax
import jax.numpy as jnp

model_state = ed.EasyDeLState.load_state(
    ckpt_path,
    dtype=jnp.bfloat16,
    param_dtype=jnp.bfloat16,
    precision=jax.lax.Precision.HIGH,
    
)

Loading shards:   0%|          | 0/171 [00:00<?, ?it/s]

[38;5;99m(16:08:05 easydel.infra.mixins.bridge)[0m Generation config file not found, using a generation config created from the model config.
[38;5;99m(16:08:08 easydel.infra.base_state)[0m Optimizer state loaded from /pub/hofmann-scratch/dvruette/gidd-checkpoints/gidd-checkpoints_us-east5/cooldown_exp_1/2025-08-30/gidd-L12-D768-H12-N2048-bs=128-T=16k-lr=0.5-cd=0.2/07-18-15/gidd/run-16000


In [15]:
model_state.mesh

Mesh(axis_sizes=(1, 1, 1, 1, 1), axis_names=('dp', 'fsdp', 'ep', 'tp', 'sp'), axis_types=(Auto, Auto, Auto, Auto, Auto))

In [22]:
model_state.model.model.layers[0].attn_layernorm

GiddRMSNorm( # Param: 6,144 (12.3 KB)
  config="GiddConfig(\n  vocab_size=131072,\n  hidden_size=768,\n  intermediate_size=3072,\n  num_hidden_layers=12,\n  num_attention_heads=12,\n  head_dim=64,\n  is_causal=False,\n  attn_soft_cap=30.0,\n  max_position_embeddings=2048,\n  resid_scale=4.0,\n  rms_norm_eps=1e-06,\n  use_qk_norm=True,\n  init_scale=0.014433756729740645,\n  emb_init_scale=0.02,\n  head_init_scale=0.0,\n  weight_scaling=1.0,\n  head_scaling=0.6666666666666666,\n  bos_token_id=0,\n  eos_token_id=1,\n  rope_theta=10000.0,\n  tie_word_embeddings=False,\n  gradient_checkpointing='',\n  rope_scaling=None,\n  scan_mlp_chunk_size=1024,\n  bits=None,\n  pretraining_tp=1,\n  attention_bias=True,\n  mlp_bias=True,\n  scan_layers=False,\n)",
  dtype=bfloat16,
  epsilon=1e-06,
  kernel=Param( # 6,144 (12.3 KB)
    value=Array(shape=(6144,), dtype=dtype(bfloat16))
  ),
  param_dtype=float32
)

In [16]:
input_ids = jnp.ones((1, 2048), dtype=jnp.int32) * 3

with model_state.mesh:
    model_state.model(input_ids=input_ids)

ValueError: Incompatible shapes for broadcasting: shapes=[(6144,), (1, 2048, 768)]