In [1]:
import sys
sys.path.append("/home/yixiuz/md4")

In [2]:
from md4.configs.md4.text8 import get_config
from collections.abc import Callable, Mapping, Sequence
import copy
import functools
from typing import Any

from absl import logging
from clu import metric_writers
from clu import metrics
from clu import parameter_overview
from clu import periodic_actions
from etils import epath
import flax
import flax.jax_utils as flax_utils
import flax.linen as nn
import grain.python as graåçin
import jax
from jax.experimental import checkify
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax
from orbax import checkpoint as orbax_checkpoint

from md4 import input_pipeline
from md4 import input_pipeline_v2
from md4 import sampling
from md4 import utils
from md4.models import utils as model_utils

from md4.train import *
from md4.train import _get_checkpoint_manager

import pickle
from tqdm import tqdm

config = get_config()
# workdir = "/root/md4/expt_contantlr_deeper__model_untie"

from clu import parameter_overview

2025-05-12 04:16:07.839417: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1747023367.897541  671941 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1747023367.915132  671941 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
  from .autonotebook import tqdm as notebook_tqdm


# Load checkpoint

In [3]:
# workdir = "gs://maskdiff/SIC/text8_sic_time_weighted"
# workdir = "gs://maskdiff/SIC/text8_sic_fixed"
# workdir = "gs://maskdiff/SIC/text8_sic_fixed_32_steps"
# workdir = "gs://maskdiff/SIC/text8_sic_zero"
workdir = "gs://maskdiff/SIC/text8_base"

In [4]:
with open(config.vocab_dir, "rb") as f:
    vocab = pickle.load(f)

workdir = epath.Path(workdir)
workdir.mkdir(parents=True, exist_ok=True)

rng = utils.get_rng(config.seed)
logging.info("Using random seed %s.", rng)
writer = metric_writers.create_default_writer(
    workdir, just_logging=jax.process_index() > 0
)

# Learning rate schedule.
assert config.batch_size % jax.device_count() == 0
per_device_batch_size = config.batch_size // jax.device_count()
num_train_steps = input_pipeline.get_num_train_steps(config)
steps_per_epoch = num_train_steps // config.num_epochs
logging.info(
    "num_train_steps=%d, steps_per_epoch=%d", num_train_steps, steps_per_epoch
)

schedule_fn = functools.partial(
    get_learning_rate,
    base_learning_rate=config.learning_rate,
    num_steps=num_train_steps,
    warmup_steps=config.warmup_steps,
    schedule_type=config.learning_rate_schedule,
)

# Build input pipeline.
rng, data_seed = jax.random.split(rng)
data_seed = int(
    jax.random.randint(data_seed, [], minval=0, maxval=np.iinfo(np.int32).max)
)
# The input pipeline runs on each process and loads data for local TPUs.
create_datasets = (
    input_pipeline_v2.create_datasets
    if config.get("use_v2_input_pipeline", None)
    else input_pipeline.create_datasets
)
train_loader, eval_loaders, dataset_info = create_datasets(config, data_seed)
train_iter = iter(train_loader)
# Initialize model.
rng, model_rng = jax.random.split(rng)
data_shape = input_pipeline.get_data_shape(config)
# Note: parameters are initialized in half precision if mixed_precision_training=True
# We could also try casting them to half precision here
model, optimizer, train_state, metrics_class = (
    create_train_state(  # pylint: disable=invalid-name
        config,
        model_rng,
        input_shape=(per_device_batch_size // config.num_microbatches,)
        + data_shape,
        schedule_fn=schedule_fn,
    )
)
# # Set up checkpointing of the model and the input pipeline.
# checkpoint_manager = _get_checkpoint_manager(config, workdir)
# # Retrieve data from previous checkpoints if possible.
checkpointed_state = dict(train_state=train_state, 
    step=0
    # train_iter=train_iter
)

from clu import checkpoint
checkpoint_dir = str(workdir / "checkpoints")
# The vdm code initalizes two checkpoints, one for loading and one for saving
# which I don't understand
# ckpt = checkpoint.MultihostCheckpoint(checkpoint_dir)
ckpt = checkpoint.Checkpoint(checkpoint_dir, max_to_keep=10)
checkpoint_to_restore = ckpt.get_latest_checkpoint_to_restore_from()

if checkpoint_to_restore:
    checkpointed_state = ckpt.restore_or_initialize(checkpointed_state)
# state_restore_dict = ckpt.restore_dict(checkpoint_to_restore)
# checkpointed_state = restore_partial(checkpointed_state, state_restore_dict)
train_state = checkpointed_state["train_state"]

2025-05-12 04:16:19.898351: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:152] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: UNKNOWN ERROR (303)


In [11]:
# train_state.params
# Count the number of parameters
def count_parameters(params):
  total = 0
  for param in jax.tree_util.tree_leaves(params):
    total += param.size
  return total
    
num_params = count_parameters(flax_utils.unreplicate(train_state.params))
print("Number of parameters: ", num_params)

Number of parameters:  153259


In [12]:
# overview = parameter_overview.get_parameter_overview(flax_utils.unreplicate(train_state.params))
# print(overview)

In [13]:
# We can't do flax serialization so long as we're using grain for the data loader
# train_iter = checkpointed_state["train_iter"]
# Distribute training.
train_state = flax_utils.replicate(train_state)
train_step_func = functools.partial(
    train_step,
    model=model,
    optimizer=optimizer,
    train_metrics_class=metrics_class,
    learning_rate_fn=schedule_fn,
    ema_rate=config.ema_rate,
    num_microbatches=config.get("num_microbatches", None),
)
if config.check_nans:
    train_step_func = checkify.checkify(
        train_step_func, errors=checkify.float_checks
    )
p_train_step = jax.pmap(train_step_func, axis_name="batch", donate_argnums=(0,))
p_eval_step = jax.pmap(
    functools.partial(
        eval_step,
        model=model,
        eval_metrics_class=metrics_class,
        ema_rate=config.ema_rate,
    ),
    axis_name="batch",
)
hooks = []
report_progress = periodic_actions.ReportProgress(
    num_train_steps=num_train_steps, writer=writer
)
if jax.process_index() == 0:
    hooks += [
        report_progress,
        periodic_actions.Profile(num_profile_steps=5, logdir=workdir),
    ]
train_metrics = None
# Unreplicating from TPU is costly, so we only do it once at the start.
# initial_step = int(flax.jax_utils.unreplicate(train_state.step))
initial_step = checkpointed_state["step"]
logging.info("Initial step is %d", initial_step)

In [14]:
batch = utils.reshape_batch(next(train_iter))

# if config.check_nans:
#     errs, (train_state, metrics_update) = p_train_step(
#         train_state=train_state, batch=batch
#     )
#     errs.throw()
# else:
#     train_state, metrics_update = p_train_step(train_state=train_state, batch=batch)
# metric_update = flax_utils.unreplicate(metrics_update)

# train_metrics = (
#     metric_update if train_metrics is None else train_metrics.merge(metric_update)
# )

In [None]:
import jax.random as jr
from jax.scipy.special import logsumexp
import tensorflow_probability.substrates.jax as tfp

tfd = tfp.distributions

def test_ancestral_sample_step_informed(self, rng, i, timesteps, zt, conditioning=None):

    B, D = zt.shape[:2]

    rng_body = jax.random.fold_in(rng, i)
    s, t = self.get_sampling_grid(i, timesteps)
    cond = self.get_cond_embedding(conditioning)

    alpha_t = self.noise_schedule.alpha(t)
    alpha_s = self.noise_schedule.alpha(s)

    rng_pstep, rng_cstep = jr.split(rng_body, 2)

    # Predictor (ancestral)
    logits, _ = self.predict_x(zt, t, cond=cond)
    mean_preds = jax.nn.softmax(logits, axis=-1)

    unmask_prob = (alpha_s - alpha_t) / (1 - alpha_t)
    probs_vocab = unmask_prob * mean_preds

    probs_mask = jnp.ones(list(zt.shape) + [1]) * (1 - unmask_prob)
    probs = jnp.concatenate([probs_vocab, probs_mask], axis=-1)

    rng_pstep_1, rng_pstep_2 = jax.random.split(rng_pstep)

    to_unmask = tfd.Categorical(probs=probs).sample(seed=rng_pstep_1)
    is_mask_zt = zt == self.vocab_size
    zs = jnp.where(is_mask_zt, to_unmask, zt)
    is_mask_zs = zs == self.vocab_size

    if self.loss_type == 'sic_zero':
      # Also sample the other positions
      denoising_pred = tfd.Categorical(probs=mean_preds).sample(seed=rng_pstep_2)
      zs = jnp.where(is_mask_zs, denoising_pred, zs)

    zs_orig = zs

    if self.k == 0:
        return zs, zs

    # Corrector (gibbs)
    rng_cstep_1, rng_cstep_2 = jr.split(rng_cstep, 2)
    logits, _ = self.predict_x(zs, s, cond=cond)
    logits -= logsumexp(logits, axis=-1, keepdims=True)
    mean_preds = jax.nn.softmax(logits, axis=-1)

    jump_target = tfd.Categorical(probs=mean_preds).sample(seed=rng_cstep_1)
    # Figure out locations with the lowest score
    # Since the score is proportional to the denoising prob anyways, we're just gonna use the logits again
    b_idx, d_idx = jnp.indices((B, D))
    scores = logits[b_idx, d_idx, zs]
    # Add temperature annealing
    # This is minus since conventionally we add noise and take max
    scores -= self.gibbs_temp * jr.gumbel(rng_cstep_2, shape=(B, D))
    scores = jnp.where(is_mask_zs, jnp.inf, scores)
    
    # Trick: sort and then find the kth smallest
    k = self.k
    # k = jnp.ceil(D * alpha_s).astype(jnp.int32)
    # k = jnp.clip(k, 1, D)
    thres = jnp.sort(scores, axis=-1)[:, k-1:k]
    # thres = jax.lax.dynamic_slice(jnp.sort(scores, axis=-1), (0, k-1), (B, 1))
    zs_corr = jnp.where((scores <= thres) & (zs != self.vocab_size)
                        , 
                        jump_target, 
                        zs
                        # self.vocab_size
                        )

    if self.loss_type == 'sic_zero':
      # Re-mask
      zs_corr = jnp.where(is_mask_zs, self.vocab_size, zs_corr)

    return zs_orig, zs_corr

def remdm_sample_step(self, rng, i, timesteps, zt, conditioning=None):

    # Maximum remasking rate (?)
    # TODO: we need to put this into the MD4 class
    # The paper had sigma_cap between 0.02 and 0.04
    # In othr words remasking is highly sensitive
    sigma_cap = 0.02 #self.sigma_cap

    MASK = self.vocab_size

    rng_body = jax.random.fold_in(rng, i)
    s, t = self.get_sampling_grid(i, timesteps)
    cond = self.get_cond_embedding(conditioning)

    alpha_t = self.noise_schedule.alpha(t)
    alpha_s = self.noise_schedule.alpha(s)

    logits, _ = self.predict_x(zt, t, cond=cond)
    mean_preds = jax.nn.softmax(logits, axis=-1)

    sigma_t = jnp.minimum((1 - alpha_s) / alpha_t,  sigma_cap)

    unmask_prob = (alpha_s - (1 - sigma_t) * alpha_t) / (1 - alpha_t)
    probs_vocab = unmask_prob * mean_preds

    probs_mask = jnp.ones(list(zt.shape) + [1]) * (1 - unmask_prob)
    probs = jnp.concatenate([probs_vocab, probs_mask], axis=-1)

    probs = jnp.where(zt[...,None] != MASK, 
        # Non-mask, mask with probability sigma_t
        # Shape: [B, D, S+1]
        jnp.concatenate([jax.nn.one_hot(zt, self.vocab_size) * (1 - sigma_t),
                         jnp.ones(list(zt.shape) + [1]) * sigma_t], axis=-1),
        # Backward rate = - dalphat / (1 - alphat) * denoising_probs
        # Shape: [B, D, S+1]
        probs
    )

    to_unmask = tfd.Categorical(probs=probs).sample(seed=rng_body)
    # is_mask = zt == self.vocab_size
    # zs = jnp.where(is_mask, to_unmask, zt)
    zs = to_unmask

    return zs, zs

def ancestral_sample_step(self, rng, i, timesteps, zt, conditioning=None):
    rng_body = jax.random.fold_in(rng, i)
    s, t = self.get_sampling_grid(i, timesteps)
    cond = self.get_cond_embedding(conditioning)

    alpha_t = self.noise_schedule.alpha(t)
    alpha_s = self.noise_schedule.alpha(s)

    logits, _ = self.predict_x(zt, t, cond=cond)
    mean_preds = jax.nn.softmax(logits, axis=-1)

    unmask_prob = (alpha_s - alpha_t) / (1 - alpha_t)
    probs_vocab = unmask_prob * mean_preds

    probs_mask = jnp.ones(list(zt.shape) + [1]) * (1 - unmask_prob)
    probs = jnp.concatenate([probs_vocab, probs_mask], axis=-1)

    to_unmask = tfd.Categorical(probs=probs).sample(seed=rng_body)
    is_mask = zt == self.vocab_size
    zs = jnp.where(is_mask, to_unmask, zt)
    return zs, zs

def ancestral_sample_step_uninformed(self, rng, i, timesteps, zt, conditioning=None):

    rng_body = jax.random.fold_in(rng, i)
    s, t = self.get_sampling_grid(i, timesteps)
    cond = self.get_cond_embedding(conditioning)

    alpha_t = self.noise_schedule.alpha(t)
    alpha_s = self.noise_schedule.alpha(s)

    rng_pstep, rng_cstep = jr.split(rng_body, 2)

    # Predictor (ancestral)
    logits, _ = self.predict_x(zt, t, cond=cond)
    mean_preds = jax.nn.softmax(logits, axis=-1)

    unmask_prob = (alpha_s - alpha_t) / (1 - alpha_t)
    probs_vocab = unmask_prob * mean_preds

    probs_mask = jnp.ones(list(zt.shape) + [1]) * (1 - unmask_prob)
    probs = jnp.concatenate([probs_vocab, probs_mask], axis=-1)

    to_unmask = tfd.Categorical(probs=probs).sample(seed=rng_pstep)
    is_mask = zt == self.vocab_size
    zs = jnp.where(is_mask, to_unmask, zt)

    # Corrector (uninformed)
    # Need to compute the backward rates from the logits
    # then sample with euler step...
    MASK = self.vocab_size

    logits, _ = self.predict_x(zs, s, cond=cond)
    logits -= logsumexp(logits, axis=-1, keepdims=True)
    # Shape: [B, D, S]
    mean_preds = jax.nn.softmax(logits, axis=-1)

    B, D, S = logits.shape
    b_idx, d_idx = jnp.indices((B, D))

    def _euler_update(key, x, rates):
        eps = 1e-8
        # Mask out the self transitions
        rates = rates.at[b_idx, d_idx, x].set(0.0)
        sum_rates = jnp.sum(rates, axis=-1)
        # transition_logit = jnp.log(-jnp.expm1(-rates)) # Prob = 1 - exp(-rate)
        transition_logit = jnp.log(-jnp.expm1(-sum_rates))[...,None] + jnp.log(rates) - jnp.log(sum_rates + eps)[...,None]
        transition_logit = transition_logit.at[b_idx, d_idx, x].set(-sum_rates)
        
        out = jr.categorical(key, transition_logit).astype(jnp.int32)
        return out

    dalpha_s = self.noise_schedule.dalpha(s)
    # Compute the rate matrix
    # Shape: [B, D, S+1]
    rates = jnp.where(zs[...,None] != MASK, 
        # Forward rate = - dalphat / alphat
        # Shape: [1, 1, S+1]
        jnp.concatenate([jnp.zeros((S,)), jnp.array((-dalpha_s / alpha_s,))])[None, None],
        # Backward rate = - dalphat / (1 - alphat) * denoising_probs
        # Shape: [B, D, S+1]
        jnp.concatenate([mean_preds * (-dalpha_s / (1-alpha_s)), jnp.zeros((B, D, 1))], axis=-1)
    )
    
    # The forward_backward_corrector shouldn't be used when s is 0
    zs_corr = jax.lax.cond(s == 0, lambda x: x, 
        lambda x: _euler_update(rng_cstep, x, rates * self.uninformed_step_size * (t-s)), zs)

    return zs, zs_corr

model.test_sample_step = test_ancestral_sample_step_informed
model.ancestral_sample_step = ancestral_sample_step
model.ancestral_sample_step_uninformed = ancestral_sample_step_uninformed
model.remdm_sample_step = remdm_sample_step

In [74]:
jax.nn.one_hot(jnp.array([1, 2, 3]), 4)

Array([[0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]], dtype=float32)

In [75]:
# Copied and modified from sampling.py
def get_attr(train_state, key):
  if hasattr(train_state, key):
    return getattr(train_state, key)
  else:
    return train_state[key]

@functools.partial(jax.pmap, axis_name='batch', static_broadcasted_argnums=0)
def generate(model, train_state, rng, dummy_inputs, conditioning=None):
  """Generate samples from the diffusion model."""
  rng = jax.random.fold_in(rng, jax.lax.axis_index('batch'))
  variables = {
      'params': get_attr(train_state, 'ema_params'),
      **get_attr(train_state, 'state'),
  }
  rng, sub_rng = jax.random.split(rng)
  zt = model.apply(
      variables,
      dummy_inputs.shape[0],
      method=model.prior_sample,
      rngs={'sample': sub_rng},
  )
  rng, sub_rng = jax.random.split(rng)

  timesteps = model.timesteps

  print("Using sampler:" + model.sampler)

  if model.sampler == 'informed' or model.sampler == 'uninformed':
    timesteps //= 2

  print("Timesteps: " + str(timesteps))

  # def body_fn(i, zt):
  #   return model.apply(
  #       variables,
  #       sub_rng,
  #       i,
  #       timesteps,
  #       zt,
  #       conditioning=conditioning,
  #       # We are hard coding the sample step to be our test_sample_step
  #       method=model.test_sample_step,
  #   )

  # z0 = jax.lax.fori_loop(
  #     lower=0, upper=timesteps, body_fun=body_fn, init_val=zt
  # )

  if model.sampler == 'informed':
    method = model.test_sample_step
  elif model.sampler == 'uninformed':
    method = model.ancestral_sample_step_uninformed
  elif model.sampler == 'ancestral':
    method = model.ancestral_sample_step
  elif model.sampler == 'remdm':
    method = model.remdm_sample_step

  def step_fn(zt_corr, i):

    zt_pred, zt_corr = model.apply(
        variables,
        sub_rng,
        i,
        timesteps,
        zt_corr,
        conditioning=conditioning,
        method=method#model.test_sample_step,
    )

    return zt_corr, (zt_pred, zt_corr)

  # Use scan instead
  z0, out = jax.lax.scan(
      step_fn,
      init=zt,
      xs=jnp.arange(timesteps),
  )

  sample = model.apply(
      variables,
      z0,
      conditioning=conditioning,
      method=model.decode,
      rngs={'sample': rng},
  )

  return sample, out

In [76]:
# model.sampler = "uninformed"
model.sampler = "remdm"

results_dict = {}

k = 4
tem = .5

model.k = k
model.gibbs_temp = tem

model.uninformed_step_size = .1

model.sigma_cap = 0.3

# for k in k_range:
#     for tem in tem_range:
#         results_dict[k, tem] = {}

timesteps = 256
model.timesteps = timesteps # Informed sampler will use half the timesteps


_, sample_rng = jax.random.split(rng)
dummy_loader = train_loader
dummy_batch = utils.reshape_batch(next(iter(dummy_loader)))
dummy_inputs = dummy_batch[config.task_type]

# # Only 1 sample per device
# dummy_inputs = dummy_inputs[:,:1]

if "label" in dummy_batch:
    conditioning = dummy_batch["label"].astype("int32")
else:
    conditioning = None

samples = generate(
    model,
    train_state,
    flax_utils.replicate(sample_rng),
    dummy_inputs,
    conditioning=conditioning,
)

2025-05-12 05:04:49.230901: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-12 05:04:49.250707: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-12 05:04:49.267263: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-12 05:04:49.267359: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-05-12 05:04:49.288562: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for 

Using sampler:remdm
Timesteps: 256


In [77]:
all_samples = jax.pmap(
    lambda x: jax.lax.all_gather(x, "batch"), axis_name="batch"
)(samples)
all_samples = flax_utils.unreplicate(all_samples)
z0, (zt_pred, zt_corr) = all_samples
z0 = z0.reshape(-1, *data_shape)

predictor_steps = timesteps if model.sampler == "ancestral" else timesteps // 2

# zt_pred has shape (num_devices, num_steps, batch_size, seq_len)
# We want to swap axes and reshape to (num_steps, batch_size, seq_len)
zt_pred = jnp.swapaxes(zt_pred, 0, 1)
zt_pred = zt_pred.reshape(predictor_steps, -1, *data_shape)
zt_corr = jnp.swapaxes(zt_corr, 0, 1)
zt_corr = zt_corr.reshape(predictor_steps, -1, *data_shape)

tokenizer = dataset_info["tokenizer"]
# foo = lambda z: utils.detokenize_texts(z, tokenizer)
corr_text = np.apply_along_axis(tokenizer.decode, -1, zt_corr[:,0])
pred_text = np.apply_along_axis(tokenizer.decode, -1, zt_pred[:,0])

texts = utils.detokenize_texts(z0, tokenizer)

# zt_pred = jax.vmap(utils.detokenize_texts, in_axes=(0, None))(np.asarray(zt_pred), tokenizer)
# zt_corr = jax.vmap(foo)(np.array(zt_corr))

# all_samples = all_samples.reshape(-1, *data_shape)
# tokenizer = dataset_info["tokenizer"]
# texts = utils.detokenize_texts(all_samples, tokenizer)
# results_dict[k, tem] = texts

In [78]:
import numpy as np

def visualize_predictor_corrector(z_pred: np.ndarray, z_corr: np.ndarray,
                                  predictor_color: str = '\033[94m',  # Blue
                                  corrector_color: str = '\033[92m',  # Green
                                  reset_color: str = '\033[0m'):
    """
    Visualizes alternating predictor/corrector updates.

    Args:
        z_pred: Numpy array of shape (timesteps, seq_len), predictor outputs.
        z_corr: Numpy array of shape (timesteps, seq_len), corrector outputs.
        predictor_color: Color for predictor changes.
        corrector_color: Color for corrector changes.
    """

    alt_color = '\033[90m'  # Grey

    T = z_pred.shape[0]
    assert z_corr.shape[0] == T, "z_pred and z_corr must have same number of timesteps"
    # assert z_pred.shape[1] == z_corr.shape[1], "Sequence lengths must match"

    def highlight_diff(prev, curr, color, reset):
        return ''.join(
            f"{color}{c}{reset}" if c != p else f"{reset}{c}{reset}"
            for p, c in zip(prev, curr)
        )

    print(f"{''.join(z_pred[0])}")
    corr_highlight = highlight_diff(z_pred[0], z_corr[0], corrector_color, reset=alt_color)
    print(f"{corr_highlight}")
    for t in range(1, T):
        # Predictor step
        pred_highlight = highlight_diff(z_corr[t-1], z_pred[t], predictor_color, reset=reset_color)
        print(f"{pred_highlight}")

        # Corrector step
        corr_highlight = highlight_diff(z_pred[t], z_corr[t], corrector_color, reset=alt_color)
        print(f"{corr_highlight}")

# z_pred = np.array([
#     list("__________"),
#     list("h_________"),
#     list("he________"),
#     list("hel_______"),
# ])

# z_corr = np.array([
#     list("__________"),
#     list("t_________"),
#     list("de________"),
#     list("hgl_______"),
# ])

count_to_visualize = 100

visualize_predictor_corrector(pred_text[-count_to_visualize:], corr_text[-count_to_visualize:])


||||||||a|||||||||||r|t||||t||||||||||||||||||||||||||||||||||a|||||||||||y ||||||||||| |||||| ||||||||||||||||||||||| ||||||||||||||| |||||||e|||ha|||||||||||||||||||||||||||||||f|||||||||||||| |||||m||||||||||||o| |||||||h|||s|||||||||||t|||||||||||| |||
[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90ma[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90mr[90m[90m|[90m[90mt[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90mt[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90ma[90m[90m|[90m[90m|[90m[90m|[90m[90m|[90m[90m|

In [79]:
# concated = " ".join(texts)
    
# generated_words = concated.split()
# # len(texts), len(generated_words)

# # check if generated words are in vocab
# s = 0
# for word in tqdm(generated_words):
#     if word in vocab:
#         s += 1

# print(s / len(generated_words))

In [80]:
from multiprocessing import Pool

concated = " ".join(texts)
    
generated_words = concated.split()

cpus=16
chunk_size = len(generated_words) // cpus + 1

def check_word_in_vocab(i):
    s = 0
    for word in tqdm(generated_words[i:i + chunk_size]):
        if word == " ":
            continue
        if word in vocab:
            s += len(word) + 1
    return s
# Use a pool of workers to check the words in parallel

with Pool(cpus) as p:
    results = p.map(check_word_in_vocab,  range(0, len(generated_words), chunk_size))

# print(sum(results) / len(generated_words))
print(sum(results) / len(concated))

  self.pid = os.fork()
  self.pid = os.fork()
100%|██████████| 1497/1497 [00:06<00:00, 241.78it/s]
100%|██████████| 1497/1497 [00:06<00:00, 238.58it/s]
 96%|█████████▋| 1441/1497 [00:06<00:00, 219.96it/s]
100%|██████████| 1497/1497 [00:06<00:00, 228.68it/s]
100%|██████████| 1483/1483 [00:06<00:00, 226.12it/s]
100%|██████████| 1497/1497 [00:06<00:00, 228.06it/s]
 97%|█████████▋| 1453/1497 [00:06<00:00, 240.14it/s]
100%|██████████| 1497/1497 [00:06<00:00, 224.55it/s]
100%|██████████| 1497/1497 [00:06<00:00, 224.25it/s]
100%|██████████| 1497/1497 [00:06<00:00, 213.99it/s]
100%|██████████| 1497/1497 [00:06<00:00, 223.60it/s]
100%|██████████| 1497/1497 [00:06<00:00, 222.35it/s]
100%|██████████| 1497/1497 [00:06<00:00, 222.04it/s]
100%|██████████| 1497/1497 [00:06<00:00, 222.06it/s]
100%|██████████| 1497/1497 [00:06<00:00, 221.46it/s]
100%|██████████| 1497/1497 [00:06<00:00, 218.54it/s]
https://symbolize.stripped_domain/r/?trace=https://symbolize.stripped_domain/r/?trace=55de35cc14b8,55de35b

0.8048380109892616


In [None]:
# SIC_zero
# Informed (256 steps): 0.94
# Informed (256 steps, k=4, temp=0.5): 0.9861151182919792
# Informed (32 steps): 0.760611
# Informed (32 steps， k=16): 0.8057085628442664
# Informed (32 steps， k=16, temp=0.5): 0.8898350020273853

In [None]:
# workdir = "gs://maskdiff/SIC/text8_sic_fixed_32_steps"
# Informed: 0.8630749014454665
# Informed (256 steps): 0.980857811674092
# Ancestral (256 steps): ~0.85

In [None]:
# workdir = "gs://maskdiff/SIC/text8_sic_fixed"
# Ancestral: 0.7536640360766629
# Uninformed: 0.7263929743786502
# Informed: 0.9155978150920494

In [None]:
# import pickle 
# with open('/home/yixiuz/baseline_acc.pkl', 'rb') as f:
#     baseline_acc = pickle.load(f)
# baseline_acc

{16: 0.7461044912923923,
 32: 0.8284363094977449,
 64: 0.8732092941998603,
 128: 0.89815259641001,
 256: 0.9017050711547782}