In [1]:
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer
from modeling_fastdllm.modeling_llada import LLaDAModelLM

from fastdllm_generate import add_gumbel_noise, get_num_transfer_tokens

from jinyu_utils.jinyu_tokenizer import Tokenizer_
from jinyu_utils.jinyu_preprocess_wiki import parse_lines_with_index, merge_subdocs, PATTEN_REG_WIKI
from jinyu_utils.jinyu_dataset import jinyu_load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Special-token ids used by the notebook. The token strings in the trailing comments
# are the author's annotations — presumably taken from the LLaDA tokenizer; verify there.
ID_TOKEN_MASK = 126336 # '|mdm_mask|'
ID_TOKEN_PADDING = 126081 # '|endoftext|'
ID_TOKEN_EOT = 126348 # '|eot_id|'

# Names of the `remasking` strategies (the two branches get_transfer_index handles).
TYPES_REMASKING = {'truth_top_k', 'random_top_k'}

# Device the model is moved to, and the checkpoint id passed to `from_pretrained`.
device = 'cuda:1'
id_model = 'GSAI-ML/LLaDA-8B-Base'

In [3]:
# Load the tokenizer that belongs to the checkpoint named by `id_model`.
tokenizer = AutoTokenizer.from_pretrained(
    id_model,
    # local_files_only=True,
    trust_remote_code=True
)

# Always pad on the left (assigning unconditionally is equivalent to the guarded form).
tokenizer.padding_side = 'left'

# The pad token must not be the mask token: the decoding code treats every position
# equal to the mask id as "still to be generated".
assert tokenizer.pad_token_id != ID_TOKEN_MASK, (
    f"pad token id {tokenizer.pad_token_id} collides with the mask token id {ID_TOKEN_MASK}"
)

In [4]:
# Load the checkpoint in bfloat16 through the concrete LLaDAModelLM class.
# `trust_remote_code` is intentionally not passed: it only applies to Auto classes and
# is ignored (with a warning) when calling `from_pretrained` on a concrete class.
model_kwargs = {}
model = LLaDAModelLM.from_pretrained(
    id_model,
    torch_dtype=torch.bfloat16,
    **model_kwargs
)

# Inference only: switch to eval mode and move the weights to the configured device.
model = model.eval().to(device)

# Input ids must live on the same device as the input embedding matrix.
device_for_input = model.get_input_embeddings().weight.device

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|██████████| 6/6 [00:01<00:00,  5.78it/s]


In [5]:
# Load the 'text' column of the test split.
ds = jinyu_load_dataset(1, split='test')['text']

# Parse the raw lines into documents and keep the list stored under 'subdocs'.
parsed, _ = parse_lines_with_index(PATTEN_REG_WIKI, ds)
docs = parsed['subdocs']


def _doc_to_sample(doc):
    """Turn one parsed document into a {'prefix', 'target'} pair.

    'prefix' is the document's own `texts` joined with single spaces; 'target' is the
    space-joined lines that `merge_subdocs` returns for the document's `subdocs`.
    The second value returned by `merge_subdocs` (titles) is not used.
    """
    prefix_text = ' '.join(doc['texts'])
    remaining_lines, _titles = merge_subdocs(doc['subdocs'])
    return {'prefix': prefix_text, 'target': ' '.join(remaining_lines)}
# end


samples = [_doc_to_sample(doc) for doc in docs]


In [6]:
# Keep only the first sample so the remaining cells run on a single example,
# then display how many samples are left.
samples = samples[:1]
len(samples)

1

In [7]:
# Tokenise the first sample's prefix (prompt) and target without special tokens,
# then move both id tensors to the device that holds the model's input embeddings.
sample = samples[0]
prompt_, target_ = sample['prefix'], sample['target']

with torch.no_grad():
    # Encode prompt first, then target — each as a "pt" tensor batch of size one.
    inputs_encoded, targets_encoded = (
        tokenizer(text, add_special_tokens=False, return_tensors="pt")
        for text in (prompt_, target_)
    )

    ids_input = inputs_encoded['input_ids'].to(device_for_input)
    ids_target = targets_encoded['input_ids'].to(device_for_input)
# end

In [None]:
def get_transfer_index(
    logits: torch.Tensor,
    temperature: float,
    remasking: str,
    mask_index: torch.Tensor,   # (B, L) bool
    x: torch.Tensor,            # (B, L) long
    y: torch.Tensor,            # (B, L) long
    num_transfer_tokens,        # (B,) or (B,1) long tensor; required
    threshold: float = None,
):
    """
    Pick which masked positions to reveal in this step.

    Args:
        logits: model logits; softmax/argmax are taken over the last dimension.
        temperature: forwarded to `add_gumbel_noise` when sampling the proposal.
        remasking: 'truth_top_k' ranks masked positions by the probability the model
            assigns to the ground-truth token in `y`; 'random_top_k' ranks them by
            uniform random scores.
        mask_index: (B, L) bool — positions eligible for selection.
        x: (B, L) long — current sequence; copied into `x0` where `mask_index` is False.
        y: (B, L) long — ground-truth token ids used to look up probabilities.
        num_transfer_tokens: (B,) or (B, 1) — how many positions to select per row.
            Negative values are clamped to 0.
        threshold: accepted for signature compatibility but currently unused.

    Returns:
        x0: (B, L) long — proposed tokens (argmax of the noised logits) at masked
            positions, `x` elsewhere
        x0_p: (B, L) float64 — softmax probability of the token in `y` at every
            position (not restricted to masked positions)
        transfer_index: (B, L) bool — which positions to update this step

    Raises:
        NotImplementedError: for an unsupported `remasking` value, or when
            `num_transfer_tokens` is None (threshold-based selection is not implemented).
    """
    # Fail fast, before the expensive full-vocabulary softmax below.
    if remasking not in ('truth_top_k', 'random_top_k'):
        raise NotImplementedError(f"unsupported remasking strategy: {remasking!r}")
    # end
    if num_transfer_tokens is None:
        raise NotImplementedError(
            "num_transfer_tokens is required: threshold-based selection is not implemented"
        )
    # end

    # 1) Sample proposal x0 from the (optionally Gumbel-noised) logits
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1)  # (B, L), long

    # 2) Probability the model assigns to the ground-truth token y at each position
    p = F.softmax(logits.to(torch.float64), dim=-1)
    x0_p = torch.gather(p, dim=-1, index=y.unsqueeze(-1)).squeeze(-1)  # (B, L), float64

    # Only modify masked spots; keep the original x everywhere else
    x0 = torch.where(mask_index, x0, x)

    # Unmasked positions get the lowest finite score so they always sort last
    lowest = torch.finfo(x0_p.dtype).min
    confidence = x0_p.masked_fill(~mask_index, lowest)  # (B, L)

    # Normalise the per-row budget to shape (B,), long, non-negative
    if num_transfer_tokens.dim() == 2 and num_transfer_tokens.size(1) == 1:
        num_transfer_tokens = num_transfer_tokens.squeeze(1)
    # end
    num_transfer_tokens = num_transfer_tokens.to(dtype=torch.long, device=confidence.device)
    num_transfer_tokens = torch.clamp(num_transfer_tokens, min=0)

    B, L = confidence.shape

    # 3) Rank positions: idx_sorted[b] lists column indices from best to worst score
    if remasking == 'random_top_k':
        # Uniform scores in [0, 1) at masked positions; cast explicitly so the result
        # does not depend on implicit dtype promotion inside torch.where.
        random_scores = torch.rand(B, L, device=confidence.device).to(confidence.dtype)
        scores = torch.where(mask_index, random_scores, confidence)
    else:  # 'truth_top_k'
        scores = confidence
    # end
    idx_sorted = torch.argsort(scores, dim=1, descending=True)

    # 4) True for the first k[b] entries of each row in sorted order
    cols = torch.arange(L, device=confidence.device).unsqueeze(0).expand(B, L)   # (B, L)
    k_expanded = num_transfer_tokens.unsqueeze(1).expand(B, L)                   # (B, L)
    select_sorted = cols < k_expanded                                            # (B, L) bool

    # Scatter the sorted selection back to original column order.
    # Integer scatter then cast to bool (scatter on bool can be finicky across versions).
    transfer_int = torch.zeros(B, L, device=confidence.device, dtype=torch.int8)
    transfer_int = transfer_int.scatter(1, idx_sorted, select_sorted.to(torch.int8))
    transfer_index = transfer_int.bool() & mask_index  # never select unmasked positions

    return x0, x0_p, transfer_index

In [39]:
@torch.no_grad()
def run_with_prefix_cache(
    model, prompt, target,
    steps=128,
    gen_length=128,
    block_length=128,
    temperature=0.,
    remasking='truth_top_k',
    mask_id=126336,
    is_eval=True
):
    """
    Block-wise masked decoding that reuses the key/value cache of everything before
    the current block.

    The sequence is `prompt` followed by `gen_length` mask tokens, processed in blocks
    of `block_length`. For each block, one full forward pass produces logits and a KV
    cache; the cache is trimmed to the positions before the block and reused for the
    remaining steps, which only feed the tokens from the block start onwards.

    Args:
        model: callable as `model(ids, use_cache=True[, past_key_values=...])`, returning
            an object with `.logits` and `.past_key_values`; must expose `.device`.
        prompt: (B, P) long tensor of prompt token ids.
        target: (B, T) long tensor of ground-truth ids; only the first `gen_length`
            are used.
        steps: total number of reveal steps; must be divisible by the number of blocks.
        gen_length: number of positions to fill; must be divisible by `block_length`.
        block_length: size of each block.
        temperature: forwarded to `get_transfer_index`.
        remasking: selection strategy forwarded to `get_transfer_index`.
        mask_id: token id that marks a not-yet-revealed position.
        is_eval: if True, revealed positions are filled with the ground truth from
            `target` and the probability reported by `get_transfer_index` is recorded;
            if False, they are filled with the model's proposed tokens.

    Returns:
        probs_all: (B, P + gen_length) float64 — recorded probabilities at positions
            revealed in eval mode, zeros elsewhere (all zeros when `is_eval` is False).
        mask: (B, P + gen_length) bool — True where a target token was placed.

    Raises:
        ValueError: if `is_eval` is True and `target` has fewer than `gen_length` tokens.
            Uncovered positions would be "revealed" as the mask token again, so the
            block could never finish.
        AssertionError: if the divisibility requirements above are violated.
    """
    target = target[:, :gen_length]
    if is_eval and target.shape[-1] != gen_length:
        raise ValueError(
            f"is_eval=True needs at least gen_length={gen_length} target tokens, "
            f"got {target.shape[-1]}"
        )
    # end

    assert gen_length % block_length == 0, "gen_length must be a multiple of block_length"
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0, "steps must be a multiple of the number of blocks"
    steps_per_block = steps // num_blocks

    device = model.device
    prompt_length = prompt.shape[1]

    # x: working sequence = prompt followed by gen_length mask tokens
    x = torch.full((prompt.shape[0], prompt_length + gen_length), mask_id, dtype=torch.long, device=device)
    x[:, :prompt_length] = prompt

    # y: ground truth laid out like x, right-aligned, mask_id wherever no target exists
    y = torch.full(x.shape, mask_id, dtype=torch.long, device=device)
    if target.shape[-1] > 0:  # `-0:` would address the whole row instead of nothing
        y[:, -target.shape[-1]:] = target
    # end

    probs_all = torch.zeros(x.shape, dtype=torch.float64, device=device)

    for num_block in range(num_blocks):
        current_block_start = prompt_length + num_block * block_length
        current_block_end = current_block_start + block_length

        # Per-step reveal budget for this block, one column per step
        block_mask_index = (x[:, current_block_start:current_block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        # Step 0: full forward pass, which also produces the KV cache
        output = model(x, use_cache=True)
        past_key_values = output.past_key_values

        # Only masked positions up to the end of the current block are eligible
        mask_index = (x == mask_id)
        mask_index[:, current_block_end:] = False

        x0, x0_p, transfer_index = get_transfer_index(
            output.logits,
            temperature,
            remasking,
            mask_index,
            x,
            y,
            num_transfer_tokens[:, 0]
        )

        if is_eval:
            x[transfer_index] = y[transfer_index]
            probs_all[transfer_index] = x0_p[transfer_index]
        else:
            x[transfer_index] = x0[transfer_index]
        # end

        # Keep only the cache entries for positions before the block; they are reused
        # unchanged for every remaining step of this block.
        past_key_values = [
            tuple(cached[:, :, :current_block_start] for cached in layer_cache)
            for layer_cache in past_key_values
        ]

        step = 1
        while (x[:, current_block_start:current_block_end] == mask_id).any():
            # Indices below are relative to current_block_start
            mask_index = (x[:, current_block_start:] == mask_id)
            mask_index[:, block_length:] = False

            # jinyu: assuming the returned logits cover only the non-cached suffix
            logits = model(x[:, current_block_start:], past_key_values=past_key_values, use_cache=True).logits

            x0, x0_p, transfer_index = get_transfer_index(
                logits,
                temperature,
                remasking,
                mask_index,
                x[:, current_block_start:],
                y[:, current_block_start:],
                num_transfer_tokens[:, step]
            )

            # The slices are views, so masked assignment writes through to x / probs_all
            if is_eval:
                x[:, current_block_start:][transfer_index] = y[:, current_block_start:][transfer_index]
                probs_all[:, current_block_start:][transfer_index] = x0_p[transfer_index]
            else:
                x[:, current_block_start:][transfer_index] = x0[transfer_index]
            # end

            step += 1
        # end while
    # end for block

    return probs_all, y != mask_id
# end

In [40]:
def calculate_ppl_and_conf(probs_all, mask_target, eps=1e-12):
    """Reduce per-token probabilities to (perplexity, mean probability).

    Args:
        probs_all: tensor of per-position probabilities.
        mask_target: bool tensor of the same shape selecting the positions to score;
            the selected values are reshaped to [B, K] with B = mask_target.shape[0].
        eps: added to every probability before the log to avoid log(0).

    Returns:
        (ppl, mean_prob) as Python floats: ppl is exp of the mean negative log
        probability (i.e. 1 / geometric mean, up to eps), mean_prob the arithmetic mean.
        Both are extracted with `.item()`, so only a single row is supported.
    """
    num_rows = mask_target.shape[0]
    target_probs = probs_all[mask_target].reshape(num_rows, -1)   # [B, K]

    # Arithmetic-mean confidence
    mean_confidence = target_probs.mean(dim=-1)                   # [B]

    # Per-token negative log-likelihood -> per-row perplexity
    token_nll = torch.log(target_probs + eps).neg()               # [B, K]
    perplexity = token_nll.mean(dim=-1).exp()                     # [B]

    return perplexity.item(), mean_confidence.item()

In [None]:
# Score the ground-truth continuation with the 'truth_top_k' reveal order
# (gen_length is left at the function default), then print (perplexity, mean probability).
probs_truth, mask_current_full = run_with_prefix_cache(
    model,
    ids_input,
    ids_target,
    steps=128,
    block_length=64,
    remasking='truth_top_k',
)
print(calculate_ppl_and_conf(probs_truth, mask_current_full))

(4.0182959601783335, 0.5236314106856976)


: 