In [1]:
from model.LogTransformer import Transformer
from ml_collections import ConfigDict
import jax.numpy as jnp
import optax
import jax
from utils.single_gpu import Batch, TrainState, accumulate_gradients, print_metrics
import numpy as np

from typing import Any, Dict, Tuple
PyTree = Any
Metrics = Dict[str, Tuple[jax.Array, ...]]
import functools 

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
# Data dimensions shared by the synthetic batch and the model.
data_config = ConfigDict(
    {
        "batch_size": 64,
        "seq_len": 512,
        "vocab_size": 512 // 2,
    }
)

# Model hyper-parameters; sequence length and vocabulary size mirror the data config.
model_config = ConfigDict(
    {
        "hidden_size": 1024,
        "dropout_rate": 0.1,
        "mlp_expansion": 4,
        "num_layers": 12,
        "head_dim": 128,
        "causal_mask": True,
        "max_seq_len": data_config.seq_len,
        "vocab_size": data_config.vocab_size,
        "num_outputs": data_config.vocab_size,
        "dtype": jnp.bfloat16,
        "out_dtype": jnp.float32,
        "softmax_dtype": jnp.float32,
        "scan_layers": True,
        "remat": ("MLP", "Attn"),
        "determinant": True,
    }
)
# Number of heads follows from the hidden width and the per-head width (1024 // 128 = 8).
model_config.num_heads = model_config.hidden_size // model_config.head_dim

# Optimizer settings; num_minibatches is the gradient-accumulation split used by the train step.
optimizer_config = ConfigDict(
    {
        "learning_rate": 4e-4,
        "num_minibatches": 4,
    }
)

# Top-level config bundling the three sub-configs with the global seed.
config = ConfigDict(
    {
        "model": model_config,
        "optimizer": optimizer_config,
        "data": data_config,
        "seed": 42,
    }
)

In [10]:
# Adam driven by a warmup + exponential-decay schedule: the learning rate ramps
# from 0 to the configured peak over 10 steps, then is multiplied by 0.99 per step.
optimizer = optax.adam(
    learning_rate=optax.warmup_exponential_decay_schedule(
        init_value=0,
        peak_value=config.optimizer.learning_rate,
        warmup_steps=10,
        decay_rate=0.99,
        transition_steps=1,
    )
)
# Model definition built from the model sub-config.
model = Transformer(config=config.model)

In [11]:
# Synthetic token ids sampled with a fixed key from [1, vocab_size); the id 0 is
# never sampled, so it is free to serve as the leading filler in the inputs below.
tokens = jax.random.randint(
    key=jax.random.PRNGKey(0),
    shape=(config.data.batch_size, config.data.seq_len),
    minval=1,
    maxval=config.data.vocab_size,
)
batch_transformer = Batch(
    # Set the first element to 0, shift the remaining tokens right by one,
    # and drop the last element.
    inputs=jnp.pad(tokens[:, :-1], pad_width=((0, 0), (1, 0)), constant_values=0),
    labels=tokens,
)

In [12]:
# Sanity check: the shifted inputs keep the (batch_size, seq_len) shape of `tokens`.
batch_transformer.inputs.shape

(64, 512)

In [13]:
# Split the seed key: one key initialises the parameters, the other is stored
# on the train state (passed as `rng` below).
model_rng, state_rng = jax.random.split(jax.random.PRNGKey(config.seed))
# Initialise with a minibatch-sized slice of the inputs
# (batch_size // num_minibatches rows) rather than the full batch.
params = model.init(
    model_rng,
    batch_transformer.inputs[: config.data.batch_size // config.optimizer.num_minibatches],
    train=False,
)["params"]
# Bundle the apply function, parameters, optimizer and RNG key into one train state.
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer,
    rng=state_rng,
)

In [16]:
# Forward pass on the full batch with train=False, using the current parameters.
out = state.apply_fn(
    {"params": state.params},
    batch_transformer.inputs,
    train=False,
    rngs={"dropout": state.rng},
)

In [17]:
# Shape of the forward-pass output computed in the previous cell.
# NOTE(review): the saved output of this cell is (64,), while the same expression
# later in the notebook shows (64, 512, 2048) — the saved outputs appear to come
# from different runs/configs. Restart the kernel and run all cells to refresh them.
out.shape

(64,)

In [17]:
# Scratch check of netket's `logdet_cmplx` on a random 4x4 matrix
# (presumably a complex-valued log-determinant — confirm against netket docs).
# NOTE(review): this import lives mid-notebook; consider moving it to the import cell.
from netket import jax as nkjax

M = jax.random.normal(jax.random.PRNGKey(0), (4,4))
# Show the dtype of the sampled matrix (the saved output reports float64).
print(M.dtype)
# NOTE(review): `A` is not read anywhere else in the visible notebook.
A = nkjax.logdet_cmplx(M)

float64


In [6]:
def get_num_params(state: TrainState) -> int:
    """Count the scalar entries across all leaves of ``state.params``.

    Args:
        state: Object exposing a ``params`` pytree whose leaves are arrays.

    Returns:
        Total number of parameter elements as a built-in ``int``
        (``0`` when the pytree has no leaves).
    """
    # Use each leaf's element count rather than ``np.prod(leaf.shape)``:
    # ``np.prod(())`` is the float ``1.0``, so a single 0-d leaf would turn the
    # whole total into a float and contradict the ``int`` return annotation.
    return sum(int(np.size(leaf)) for leaf in jax.tree_util.tree_leaves(state.params))


# Report the total parameter count, using "_" as the thousands separator.
print(f"Number of parameters: {get_num_params(state):_}")

Number of parameters: 155_877_376


In [7]:
# Peek at the sampled token ids.
# NOTE(review): the saved output contains ids above 2000, but the config cell sets
# vocab_size = 512 // 2 = 256 and the ids are sampled below vocab_size — the saved
# output appears to come from a run with a different config; re-run to refresh it.
tokens

Array([[ 699,  693,  193, ...,  141, 1855,  430],
       [1930, 1630,  811, ..., 1884, 1227,  469],
       [ 521,   32, 1572, ..., 1088, 1466, 2023],
       ...,
       [1312, 1531,  948, ...,  114,  988, 1095],
       [1387,  462, 2036, ..., 1100, 1100,  700],
       [ 680, 1930,  898, ..., 1259, 1643, 1876]], dtype=int64)

In [7]:
def next_token_pred_loss(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[PyTree, Metrics]:
    """Cross-entropy loss and metrics for next-token prediction.

    Args:
        params: Model parameters, passed to ``apply_fn`` under the ``"params"`` key.
        apply_fn: Callable producing logits from the variables and ``batch.inputs``.
        batch: Batch providing ``inputs`` and integer ``labels``.
        rng: Random key handed to ``apply_fn`` as the ``"dropout"`` RNG.

    Returns:
        The mean loss, and a metrics dict whose ``"loss"`` and ``"accuracy"``
        entries are ``(summed value, number of label elements)`` pairs.
    """
    targets = batch.labels
    logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": rng})
    elementwise_loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
    is_correct = jnp.equal(jnp.argmax(logits, axis=-1), targets)
    # Total number of label elements; metrics are stored as (sum, count) pairs.
    num_targets = np.prod(targets.shape)
    step_metrics = {
        "loss": (elementwise_loss.sum(), num_targets),
        "accuracy": (is_correct.sum(), num_targets),
    }
    return elementwise_loss.mean(), step_metrics

In [8]:
@functools.partial(
    jax.jit,
    donate_argnames=(
        "state",
        "metrics",
    ),
)
def train_step_transformer(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    """Training step function.

    Executes a full training step with gradient accumulation for the next-token prediction task.
    The function is jitted with ``state`` and ``metrics`` donated, so callers should
    rebind both names to the returned values instead of reusing the arguments.

    Args:
        state: Current training state.
        metrics: Current metrics, accumulated from previous training steps, or
            ``None`` to start from this step's metrics.
        batch: Training batch.

    Returns:
        Tuple with updated training state (parameters, optimizer state, etc.) and metrics.
    """
    # Split the random number generator for the current step.
    rng, step_rng = jax.random.split(state.rng)
    # Determine gradients and metrics for the full batch.
    # NOTE: num_minibatches is read from the notebook-level `config` global.
    grads, step_metrics = accumulate_gradients(
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=next_token_pred_loss,
        use_scan=True,
    )
    # Optimizer step.
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Accumulate metrics across training steps.
    if metrics is None:
        metrics = step_metrics
    else:
        # Use jax.tree_util.tree_map: the top-level jax.tree_map alias is deprecated.
        metrics = jax.tree_util.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

In [9]:
# Abstractly evaluate one train step (metrics=None) to obtain the shape/dtype
# structure of the metrics it returns.
_, metric_shapes = jax.eval_shape(
    train_step_transformer,
    state,
    None,
    batch_transformer,
)
# Zero-initialised metrics with that structure, used as the accumulator.
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
metrics = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)

In [12]:
# NOTE(review): this import lives mid-notebook; consider moving it to the import cell.
from tqdm.auto import tqdm

# Run four training steps on the same batch, accumulating into `metrics`.
for _ in tqdm(range(4)):
    state, metrics = train_step_transformer(state, metrics, batch_transformer)
# One more step with a fresh zeroed accumulator, so the printed metrics cover
# only that final step.
# jax.tree_util.tree_map replaces the deprecated top-level jax.tree_map alias.
final_metrics = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)
state, final_metrics = train_step_transformer(state, final_metrics, batch_transformer)
print_metrics(final_metrics, "Final metrics - Transformer")

100%|██████████| 4/4 [00:28<00:00,  7.02s/it]


 Final metrics - Transformer 
accuracy: 0.000641
loss: 7.778461


In [10]:
# Inspect the abstract metric structure and the current accumulated "loss" entry.
# NOTE(review): the saved output shows a zero loss accumulator, i.e. this cell
# appears to have been executed before the training loop — confirm the cell order.
print(metric_shapes)
print(metrics["loss"])

{'accuracy': (ShapeDtypeStruct(shape=(), dtype=int64), ShapeDtypeStruct(shape=(), dtype=int64)), 'loss': (ShapeDtypeStruct(shape=(), dtype=float32), ShapeDtypeStruct(shape=(), dtype=int64))}
(Array(0., dtype=float32), Array(0, dtype=int64))


In [14]:
# Forward pass on the full batch with train=False, using the current parameters.
out = state.apply_fn(
    {"params": state.params},
    batch_transformer.inputs,
    train=False,
    rngs={"dropout": state.rng},
)

In [16]:
# Shape of the forward-pass output computed in the previous cell.
# NOTE(review): the saved output's last dimension is 2048, whereas the config cell
# sets num_outputs = vocab_size = 256 — the saved output looks stale; re-run to refresh.
out.shape

(64, 512, 2048)

In [18]:
# Shape of the model inputs, for comparison with the output shape above.
batch_transformer.inputs.shape

(64, 512)