In [1]:
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer
from modeling_fastdllm.modeling_llada import LLaDAModelLM

from fastdllm_generate import add_gumbel_noise, get_num_transfer_tokens

from jinyu_utils.jinyu_tokenizer import Tokenizer_
from jinyu_utils.jinyu_preprocess_wiki import parse_lines_with_index, merge_subdocs, PATTEN_REG_WIKI
from jinyu_utils.jinyu_dataset import jinyu_load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
ID_TOKEN_MASK = 126336 # '|mdm_mask|'
ID_TOKEN_PADDING = 126081 # '|endoftext|'
ID_TOKEN_EOT = 126348 # '|eot_id|'

TYPES_REMASKING = {'truth_top_k', 'random_top_k'}

device = 'cuda:1'
id_model = 'GSAI-ML/LLaDA-8B-Base'

In [3]:
'''load tokenizer'''
tokenizer = AutoTokenizer.from_pretrained(
    id_model,
    # local_files_only=True,
    trust_remote_code=True
)

if tokenizer.padding_side != 'left':
    tokenizer.padding_side = 'left'
# end

assert tokenizer.pad_token_id != 126336

In [4]:
'''load model'''
model_kwargs = {}
model = LLaDAModelLM.from_pretrained(
    id_model,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **model_kwargs
)

model = model.eval().to(device)
device_for_input = model.get_input_embeddings().weight.device

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|██████████| 6/6 [00:01<00:00,  4.50it/s]


In [5]:
def get_transfer_index(
    logits: torch.Tensor,
    temperature: float,
    remasking: str,
    mask_index: torch.Tensor,   # (B, L) bool
    x: torch.Tensor,            # (B, L) long
    y: torch.Tensor,            # (B, L) long
    num_transfer_tokens,        # (B,) or (B,1) long tensor, or None when threshold is used
    threshold: float = None,
):
    """
    Returns:
        x0: (B, L) long — proposed tokens
        transfer_index: (B, L) bool — which positions to update this step
    """
    # 1) Sample proposal x0
    # Gumbel-noise for exploration; if temperature==0, add_gumbel_noise should no-op
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1)  # (B, L), long

    # 2) Confidence for chosen tokens (or random)
    p = F.softmax(logits.to(torch.float64), dim=-1)
    x0_p = torch.gather(p, dim=-1, index=y.unsqueeze(-1)).squeeze(-1)  # (B, L), float64
    # x0_p = torch.rand(x0.shape, device=x0.device, dtype=torch.float64)  # (B, L)  # removed by jinyu

    # Only modify masked spots; keep others as original x and set their confidence to -inf
    x0 = torch.where(mask_index, x0, x) # mask_index is only this block

    neg_inf = torch.tensor(torch.finfo(x0_p.dtype).min, device=x0.device, dtype=x0_p.dtype)
    confidence = torch.where(mask_index, x0_p, neg_inf)  # (B, L)   # so only the masked part has confidence

    # Ensure shape (B,) long    jinyu: re-calculate num_transfer_token every time(I think)
    if num_transfer_tokens.dim() == 2 and num_transfer_tokens.size(1) == 1:
        num_transfer_tokens = num_transfer_tokens.squeeze(1)
    # end

    num_transfer_tokens = num_transfer_tokens.to(dtype=torch.long, device=confidence.device)
    num_transfer_tokens = torch.clamp(num_transfer_tokens, min=0)   # jinyu: can it be negative???


    # Sort confidences descending (masked positions are valid; others are -inf)
    # idx: (B, L) gives positions in original sequence sorted by confidence
    if remasking == 'random_top_k':
        idx_sorted_random = torch.argsort(
            torch.where(
                mask_index,
                torch.rand(confidence.shape[0], confidence.shape[1], device=confidence.device),
                confidence
            ),
            dim=1,
            descending=True
        )
        idx_sorted = idx_sorted_random  # for your read
    elif remasking == 'truth_top_k':
        idx_sorted = torch.argsort(confidence, dim=1, descending=True)
    else:
        raise NotImplementedError()
    # end

    B, L = confidence.shape
    # Build a mask that is True for the first k[b] columns in each row (sorted order)
    cols = torch.arange(L, device=confidence.device).unsqueeze(0).expand(B, L)   # (B, L)
    k_expanded = num_transfer_tokens.unsqueeze(1).expand(B, L)                   # (B, L)
    select_sorted = cols < k_expanded                                            # (B, L) bool for top k

    # Scatter the sorted True/False back to original column order
    # Use integer scatter then cast to bool (scatter_ on bool can be finicky across versions)
    transfer_int = torch.zeros(B, L, device=confidence.device, dtype=torch.int8) # (B, L)
    transfer_int = transfer_int.scatter(1, idx_sorted, select_sorted.to(torch.int8))
    transfer_index = transfer_int.bool() & mask_index  # ensure we never select unmasked

    return x0, x0_p, transfer_index

In [6]:
def calculate_ppl_and_conf(probs_all, mask_target, eps=1e-12):
    probs_collected = probs_all[mask_target].reshape(mask_target.shape[0], -1)  # [B, K]

    # Arithmetic mean confidence (what you currently call mean_prob)
    mean_prob = probs_collected.mean(dim=-1)  # [B]

    # Per-token NLL and per-row PPL (geometric-mean based)
    nll_collected = -torch.log(probs_collected + eps)   # [B, K]
    nll_per = nll_collected.mean(dim=-1)                 # [B]
    ppl_per = torch.exp(nll_per)                        # [B]

    # Geometric mean confidence (this one is directly tied to PPL)
    # geo_prob = torch.exp(torch.log(probs_collected + eps).mean(dim=1))  # [B]
    # And ppl_per == 1 / geo_prob (up to eps effects)

    return ppl_per.item(), mean_prob.item()

In [7]:
@torch.no_grad()
def run_with_dual_cache(
    model, prompt, target,
    steps=128,
    gen_length=128,
    block_length=128,
    temperature=0.,
    remasking="low_confidence",
    mask_id=126336,
    is_eval=True
):

    B = prompt.shape[0]
    length_prompt = int(prompt.shape[1])  # Python int, not Tensor

    target = target[:, :gen_length]

    # x: (B,length_prompt + gen_length)
    x = torch.full((B, length_prompt + gen_length), mask_id, dtype=torch.long, device=model.device)
    x[:, :length_prompt] = prompt

    y = torch.full(x.shape, mask_id, dtype=torch.long).to(model.device)
    y[:, -target.shape[-1]:] = target   # expand target into y

    probs_all = torch.zeros(x.shape, dtype=torch.float64).to(model.device)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks


    nfe = 0

    for id_block in range(num_blocks):
        position_start = length_prompt + id_block * block_length
        position_end = position_start + block_length

        # Masks/indices for the current block
        block_mask_index = (x[:, position_start:position_end] == mask_id)  # (B, block_length)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)  # (B, steps_per_block)

        # 1) Warm KV-cache on the full prefix once per block
        out_full = model(x, use_cache=True)
        past_key_values = out_full.past_key_values
        nfe += 1

        # Build a replace_position tensor indicating the block range (static slice)
        mask_position_replace = torch.zeros_like(x, dtype=torch.bool)
        mask_position_replace[:, position_start:position_end] = True  # boolean mask (not a dynamic slice bound)

        # Step 0: do an initial transfer on the full logits
        mask_masked_full = (x == mask_id)
        # Do not touch beyond current block in this phase
        mask_masked_full[:, position_end:] = False

        quota0 = num_transfer_tokens[:, 0]  # (B,)
        x0, x0_p, transfer_index = get_transfer_index(
            out_full.logits,
            temperature,
            remasking,
            mask_masked_full,
            x,
            y,
            quota0
        )

        # In-place update via torch.where (no tensor-slice assignment with mask)
        # x = torch.where(transfer_index, x0, x)   # -> replace by jinyu
        if is_eval:
            x[transfer_index] = y[transfer_index]
            probs_all[transfer_index] = x0_p[transfer_index]
        else:
            x[transfer_index] = x0[transfer_index]
        # end

        # 2) Semi-autoregressive refinement, fixed number of steps (graph-friendly)
        #    Each iteration runs on the current block with KV-cache and replace_position
        for step_in_block in range(1, steps_per_block):
            # Evaluate logits only for current block with cache
            if (x[:, position_start:position_end] == mask_id).sum() == 0:
                break
            # end

            logits_blk = model(
                x[:, position_start:position_end],
                past_key_values=past_key_values,
                use_cache=True,
                replace_position=mask_position_replace
            ).logits  # shape expected by get_transfer_index*

            # Mask and quota for this step (all tensor ops)
            mask_blk = (x[:, position_start:position_end] == mask_id)  # (B, block_length)
            blk_x = x[:, position_start:position_end]
            blk_y = y[:, position_start:position_end]
            blk_prob = probs_all[:, position_start:position_end]

            quota_i = num_transfer_tokens[:, step_in_block]  # (B,)
            blk_x0, blk_x0_p, transfer_idx_blk = get_transfer_index(
                logits_blk,
                temperature,
                remasking,
                mask_blk,
                blk_x,
                blk_y,
                quota_i
            )

            if is_eval:
                blk_x[transfer_idx_blk] = blk_y[transfer_idx_blk]
                blk_prob[transfer_idx_blk] = blk_x0_p[transfer_idx_blk]
            else:
                blk_x[transfer_idx_blk] = blk_x0[transfer_idx_blk]
            # end

            nfe += 1

    return probs_all, y != mask_id
# end

In [18]:
ds = jinyu_load_dataset(1, split='test')['text']

'''preprocess dataset'''
docs, _ = parse_lines_with_index(PATTEN_REG_WIKI, ds)
docs = docs['subdocs']

samples = []
for doc in docs:
    lines_1 = doc['texts']
    paragraph_1 = ' '.join(lines_1)
    lines_remain, titles = merge_subdocs(doc['subdocs'])
    paragraph_remain = ' '.join(lines_remain)
    prefix = paragraph_1
    target = paragraph_remain
    samples.append(' '.join([paragraph_1, paragraph_remain]))
# end


len_prompt=64
gen_length=128

with torch.no_grad():
    for sample in samples[10:30]:
        sample_encoded = tokenizer(
            sample,
            add_special_tokens=False,
            return_tensors='pt'
        )

        ids_input = sample_encoded['input_ids'].squeeze(0)[:len_prompt].unsqueeze(0).to(device_for_input)
        ids_target = sample_encoded['input_ids'].squeeze(0)[len_prompt:len_prompt+gen_length].unsqueeze(0).to(device_for_input)

        probs_truth, mask_current_full = run_with_dual_cache(
            model, ids_input, ids_target,
            remasking='truth_top_k',
            block_length=32,
            steps=128,
            gen_length=gen_length,
            is_eval=True
        )

        print(calculate_ppl_and_conf(probs_truth, mask_current_full))

(7.114917882077981, 0.34863197475043933)
(5.1450709933436105, 0.38914112049785976)
(5.195045714652265, 0.39853902838583954)
(8.178791229822112, 0.3829883445662564)
(2.8754892782883608, 0.530275950735745)
(4.727720551167838, 0.43572830153074643)
(8.123064685423431, 0.35045478573543776)
(16.68613102213148, 0.21855766465620974)
(10.832268503524181, 0.2807977635031875)
(11.751660851799864, 0.23327153821161195)
(10.759011736385872, 0.27951908115644686)
(6.562285940917342, 0.35676243357806986)
(9.379813445599101, 0.33921640805823805)
(10.126998727695428, 0.28967419049822557)
(9.148136902144175, 0.306092517489537)
(7.680092016156881, 0.34737660388375474)
(5.4763812031289625, 0.4054353434282099)
(11.677416297867197, 0.19300162764524814)


RuntimeError: The expanded size of the tensor (160) must match the existing size (0) at non-singleton dimension 1.  Target sizes: [1, 160].  Tensor sizes: [0]