In [None]:
from md4.configs.hollow_md4.text8 import get_config
from collections.abc import Callable, Mapping, Sequence
import copy
import functools
from typing import Any

from absl import logging
from clu import metric_writers
from clu import metrics
from clu import parameter_overview
from clu import periodic_actions
from etils import epath
import flax
import flax.jax_utils as flax_utils
import flax.linen as nn
import grain.python as grain
import jax
from jax.experimental import checkify
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
from orbax import checkpoint as orbax_checkpoint

from md4 import input_pipeline
from md4 import input_pipeline_v2
from md4 import sampling
from md4 import utils
from md4.models import utils as model_utils

from md4.train import *
from md4.train import _get_checkpoint_manager

import pickle
from tqdm import tqdm

config = get_config()
workdir = "/root/md4/expt_contantlr_deeper__model_untie"

2025-04-04 23:26:07.458418: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743809167.471789   47256 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743809167.476108   47256 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1743809167.487916   47256 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1743809167.487930   47256 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1743809167.487932   47256 computation_placer.cc:177] computation placer alr

cv2 not found


In [None]:
workdir = epath.Path(workdir)
workdir.mkdir(parents=True, exist_ok=True)

rng = utils.get_rng(config.seed)
logging.info("Using random seed %s.", rng)
writer = metric_writers.create_default_writer(
    workdir, just_logging=jax.process_index() > 0
)

# Learning rate schedule.
assert config.batch_size % jax.device_count() == 0
per_device_batch_size = config.batch_size // jax.device_count()
num_train_steps = input_pipeline.get_num_train_steps(config)
steps_per_epoch = num_train_steps // config.num_epochs
logging.info("num_train_steps=%d, steps_per_epoch=%d", num_train_steps, steps_per_epoch)
schedule_fn = functools.partial(
    get_learning_rate,
    base_learning_rate=config.learning_rate,
    num_steps=num_train_steps,
    warmup_steps=config.warmup_steps,
    schedule_type=config.learning_rate_schedule,
)

# Build input pipeline.
rng, data_seed = jax.random.split(rng)
data_seed = int(
    jax.random.randint(data_seed, [], minval=0, maxval=np.iinfo(np.int32).max)
)
# The input pipeline runs on each process and loads data for local TPUs.
create_datasets = (
    input_pipeline_v2.create_datasets
    if config.get("use_v2_input_pipeline", None)
    else input_pipeline.create_datasets
)
train_loader, eval_loaders, dataset_info = create_datasets(config, data_seed)

train_iter = iter(train_loader)

# Initialize model.
rng, model_rng = jax.random.split(rng)
data_shape = input_pipeline.get_data_shape(config)
# Note: parameters are initialized in half precision if mixed_precision_training=True
# We could also try casting them to half precision here
model, optimizer, train_state, metrics_class = (
    create_train_state(  # pylint: disable=invalid-name
        config,
        model_rng,
        input_shape=(per_device_batch_size // config.num_microbatches,) + data_shape,
        schedule_fn=schedule_fn,
    )
)

# Set up checkpointing of the model and the input pipeline.
checkpoint_manager = _get_checkpoint_manager(config, workdir)

# Retrieve data from previous checkpoints if possible.
checkpointed_state = dict(train_state=train_state, train_iter=train_iter)
if checkpoint_manager.latest_step() is not None:
    checkpointed_state = checkpoint_manager.restore(
        checkpoint_manager.latest_step(), items=checkpointed_state
    )
train_state = checkpointed_state["train_state"]
train_iter = checkpointed_state["train_iter"]

I0000 00:00:1743809180.681189   47256 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 92886 MB memory:  -> device: 0, name: NVIDIA H100 NVL, pci bus id: 0000:03:00.0, compute capability: 9.0
I0000 00:00:1743809180.683373   47256 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 92898 MB memory:  -> device: 1, name: NVIDIA H100 NVL, pci bus id: 0000:04:00.0, compute capability: 9.0
I0000 00:00:1743809180.685348   47256 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:2 with 92898 MB memory:  -> device: 2, name: NVIDIA H100 NVL, pci bus id: 0000:63:00.0, compute capability: 9.0
I0000 00:00:1743809180.687302   47256 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:3 with 92898 MB memory:  -> device: 3, name: NVIDIA H100 NVL, pci bus id: 0000:64:00.0, compute capability: 9.0
I0000 00:00:1743809180.689231   47256 gpu_device.cc:2019] Created device /job:localhost/

In [None]:
# Distribute training.
train_state = flax_utils.replicate(train_state)
train_step_func = functools.partial(
    train_step,
    model=model,
    optimizer=optimizer,
    train_metrics_class=metrics_class,
    learning_rate_fn=schedule_fn,
    ema_rate=config.ema_rate,
    num_microbatches=config.get("num_microbatches", None),
)
if config.check_nans:
    train_step_func = checkify.checkify(train_step_func, errors=checkify.float_checks)
p_train_step = jax.pmap(train_step_func, axis_name="batch", donate_argnums=(0,))
p_eval_step = jax.pmap(
    functools.partial(
        eval_step,
        model=model,
        eval_metrics_class=metrics_class,
        ema_rate=config.ema_rate,
    ),
    axis_name="batch",
)

hooks = []
report_progress = periodic_actions.ReportProgress(
    num_train_steps=num_train_steps, writer=writer
)
if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=workdir),
    ]
train_metrics = None

# Unreplicating from TPU is costly, so we only do it once at the start.
initial_step = int(flax.jax_utils.unreplicate(train_state.step))

In [4]:
batch = utils.reshape_batch(next(train_iter))

if config.check_nans:
    errs, (train_state, metrics_update) = p_train_step(
        train_state=train_state, batch=batch
    )
    errs.throw()
else:
    train_state, metrics_update = p_train_step(train_state=train_state, batch=batch)
metric_update = flax_utils.unreplicate(metrics_update)

train_metrics = (
    metric_update if train_metrics is None else train_metrics.merge(metric_update)
)

2025-04-04 23:27:35.897089: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743809255.910521   49398 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743809255.914822   49398 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1743809255.926160   49398 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1743809255.926178   49398 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1743809255.926180   49398 computation_placer.cc:177] computation placer alr

cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
batch_size 64
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found


In [23]:
model.timesteps = 100

In [None]:
if hasattr(model, "sample_step"):
    # with report_progress.timed("sample"):
    _, sample_rng = jax.random.split(rng)
    dummy_loader = train_loader
    dummy_batch = utils.reshape_batch(next(iter(dummy_loader)))
    dummy_inputs = dummy_batch[config.task_type]
    if "label" in dummy_batch:
        conditioning = dummy_batch["label"].astype("int32")
    else:
        conditioning = None

    samples = sampling.generate(
        model,
        train_state,
        flax_utils.replicate(sample_rng),
        dummy_inputs,
        conditioning=conditioning,
    )

    all_samples = jax.pmap(lambda x: jax.lax.all_gather(x, "batch"), axis_name="batch")(
        samples
    )
    all_samples = flax_utils.unreplicate(all_samples)
    all_samples = all_samples.reshape(-1, *data_shape)
    if config.task_type == "image":
        sample_grid = utils.generate_image_grids(all_samples)
        writer.write_images(step, {"samples": sample_grid})
        del all_samples, sample_grid
    elif config.task_type == "text":
        # pass
        tokenizer = dataset_info["tokenizer"]
        texts = utils.detokenize_texts(all_samples, tokenizer)
        # writer.write_texts(step, {"samples": texts})

2025-04-04 23:42:36.318397: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1743810156.331331   64180 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1743810156.335586   64180 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-04 23:42:36.338298: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
W0000 00:00:1743810156.346623   64180 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1743810156.346642   64180 computation_pl

cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found
cv2 not found


In [25]:
tokenizer = dataset_info["tokenizer"]
texts = utils.detokenize_texts(all_samples, tokenizer)
concated = " ".join(texts)
generated_words = concated.split()
len(texts), len(generated_words)

(512, 22859)

In [9]:
with open("/root/md4/data_dir/text8/text8_vocab.pkl", "rb") as f:
    vocab = pickle.load(f)

In [None]:
# check if generated words are in vocab
s = 0
for word in tqdm(generated_words):
    if word in vocab:
        s += 1
print(s / len(generated_words))

  1%|          | 233/22859 [00:00<00:19, 1173.70it/s]

100%|██████████| 22859/22859 [00:18<00:00, 1229.59it/s]

0.8886652959447044





In [None]:
# # Distribute training.
# train_state = flax_utils.replicate(train_state)
# train_step_func = functools.partial(
#     train_step,
#     model=model,
#     optimizer=optimizer,
#     train_metrics_class=metrics_class,
#     learning_rate_fn=schedule_fn,
#     ema_rate=config.ema_rate,
#     num_microbatches=config.get("num_microbatches", None),
# )
# if config.check_nans:
#     train_step_func = checkify.checkify(
#         train_step_func, errors=checkify.float_checks
#     )
# p_train_step = jax.pmap(train_step_func, axis_name="batch", donate_argnums=(0,))
# p_eval_step = jax.pmap(
#     functools.partial(
#         eval_step,
#         model=model,
#         eval_metrics_class=metrics_class,
#         ema_rate=config.ema_rate,
#     ),
#     axis_name="batch",
# )

# hooks = []
# report_progress = periodic_actions.ReportProgress(
#     num_train_steps=num_train_steps, writer=writer
# )
# if jax.process_index() == 0:
#     hooks += [
#         report_progress,
#         periodic_actions.Profile(num_profile_steps=5, logdir=workdir),
#     ]
# train_metrics = None

# # Unreplicating from TPU is costly, so we only do it once at the start.
# initial_step = int(flax.jax_utils.unreplicate(train_state.step))

In [None]:
# Ignore sample step for now
if hasattr(model, "sample_step"):
    # with report_progress.timed("sample"):
    _, sample_rng = jax.random.split(rng)
    dummy_loader = train_loader
    dummy_batch = utils.reshape_batch(next(iter(dummy_loader)))
    dummy_inputs = dummy_batch[config.task_type]
    if "label" in dummy_batch:
        conditioning = dummy_batch["label"].astype("int32")
    else:
        conditioning = None

    samples = sampling.generate(
        model,
        train_state,
        flax_utils.replicate(sample_rng),
        dummy_inputs,
        conditioning=conditioning,
    )

    all_samples = jax.pmap(lambda x: jax.lax.all_gather(x, "batch"), axis_name="batch")(
        samples
    )
    all_samples = flax_utils.unreplicate(all_samples)
    all_samples = all_samples.reshape(-1, *data_shape)
    if config.task_type == "image":
        sample_grid = utils.generate_image_grids(all_samples)
        writer.write_images(step, {"samples": sample_grid})
        del all_samples, sample_grid
    elif config.task_type == "text":
        # pass
        # tokenizer = dataset_info["tokenizer"]
        texts = utils.detokenize_texts(all_samples, tokenizer)
        # writer.write_texts(step, {"samples": texts})