In [1]:
import os
import torch
import numpy as np

import torch.nn.functional as F

from collections import defaultdict
from tqdm import tqdm
from modeling_llada.modeling_llada_with_kv_cache import LLaDAModelLM
from transformers import AutoTokenizer
from datasets import Dataset, load_dataset

from llada_get_loglikelihood import forward_process, get_log_likelihood
from llada_generate import get_num_transfer_tokens, add_gumbel_noise

from jinyu_utils.jinyu_tokenizer import Tokenizer_
from jinyu_utils.jinyu_preprocess_wiki import parse_lines_with_index, merge_subdocs, PATTEN_REG_WIKI
from jinyu_utils.jinyu_dataset import jinyu_load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
ID_TOKEN_MASK = 126336 # '|mdm_mask|'
ID_TOKEN_PADDING = 126081 # '|endoftext|'
ID_TOKEN_EOT = 126348 # '|eot_id|'

device = 'cuda:0'

id_model = 'GSAI-ML/LLaDA-8B-Base'


In [3]:
'''load tokenizer'''
tokenizer = AutoTokenizer.from_pretrained(
    id_model,
    # local_files_only=True,
    trust_remote_code=True
)

if tokenizer.padding_side != 'left':
    tokenizer.padding_side = 'left'
# end

assert tokenizer.pad_token_id != 126336

In [4]:
'''load model'''
model_kwargs = {}
model = LLaDAModelLM.from_pretrained(
    id_model,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    **model_kwargs
)

model = model.eval().to(device)
device_for_input = model.get_input_embeddings().weight.device

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00, 13.22it/s]


In [5]:
ds = jinyu_load_dataset(1, split='test')['text']

In [6]:
'''preprocess dataset'''
docs, _ = parse_lines_with_index(PATTEN_REG_WIKI, ds)
docs = docs['subdocs']

In [7]:
samples = []
for doc in docs:
    lines_1 = doc['texts']
    paragraph_1 = ' '.join(lines_1)
    lines_remain, titles = merge_subdocs(doc['subdocs'])
    paragraph_remain = ' '.join(lines_remain)
    prefix = paragraph_1
    target = paragraph_remain
    samples.append({'prefix': prefix, 'target': target})
# end


In [8]:
samples = samples[:1]
len(samples)

1

In [9]:
sample = samples[0]
prompt_ = sample['prefix']
target_ = sample['target']

with torch.no_grad():
    inputs_encoded = tokenizer(
        prompt_,
        add_special_tokens=False,
        return_tensors="pt"
    )

    targets_encoded = tokenizer(
        target_,
        add_special_tokens=False,
        return_tensors="pt"
    )

    ids_input = inputs_encoded['input_ids'].to(device_for_input)
    ids_target = targets_encoded['input_ids'].to(device_for_input)
    # generate_with_ppl(model, ids_input, ids_target[:, :128], steps=32, gen_length=128, block_length=32, remasking='confidence_top_k')

# end



In [None]:
def run(model, ids_input, ids_target, remasking):
    with torch.no_grad():
        model = model
        prompt = ids_input
        target = ids_target[:,:128]
        attention_mask=None
        steps=128
        gen_length=128
        block_length=128
        temperature=0
        remasking=remasking
        mask_id=126336
        logits_eos_inf=False
        confidence_eos_eot_inf=False

        _tuples_log_unmask = []
        
        # (batch, full_length)
        x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
        x[:, :prompt.shape[1]] = prompt.clone()

        y = torch.full(x.shape, mask_id, dtype=torch.long).to(model.device)
        y[:, -target.shape[-1]:] = target   # expand target into y

        probs_all = torch.zeros(x.shape, dtype=torch.bfloat16).to(model.device)


        # fill prompts

        if attention_mask is not None:
            attention_mask = torch.cat([attention_mask, torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)], dim=-1)
        # end

        assert gen_length % block_length == 0
        num_blocks = gen_length // block_length

        assert steps % num_blocks == 0
        steps_per_block = steps // num_blocks

        for num_block in range(num_blocks):
            mask_block = (x[:, prompt.shape[1] + num_block*block_length : prompt.shape[1]+(num_block + 1)*block_length]) == mask_id
            nums_transfer_tokens = get_num_transfer_tokens(mask_block, steps_per_block)    # [[7,7,6],..] if steps_per_block = 3 and remainder = 2
            del mask_block

            for step_per_block in range(steps_per_block):   # TODO: 1 -> steps_per_block
            
                logits = model(x, attention_mask=attention_mask, cache_kv=True).logits

                if logits_eos_inf:
                    logits[:, :, ID_TOKEN_PADDING] = -torch.inf
                # end

                logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
                x0 = torch.argmax(logits_with_noise, dim=-1) # b, l -> [[id0_5, id1_3, ...],..] (this is the index of native max)

                if confidence_eos_eot_inf:
                    logits_with_noise[:, :, ID_TOKEN_PADDING] = logits[:, :, ID_TOKEN_EOT] = -torch.inf
                # end

                p = F.softmax(logits, dim=-1)
                del logits

                index_p = None  # we are going to handle the index_p
                match remasking:
                    case 'generate_top_k':
                        index_p = x0.unsqueeze(-1)    # ALERT: original code
                        # index_p = y.unsqueeze(-1)
                    case 'truth_top_k' | 'random_top_k':
                        index_p = y.unsqueeze(-1)
                    case _:
                        raise NotImplementedError(remasking)
                    # end
                # end match
                x0_p = torch.squeeze(p.gather(dim=-1, index=index_p), -1) # b, l [[0.9, 0.7],..]

                # set mask
                mask_current_full = torch.where(x==mask_id, True, False)    # set prompt to False
                mask_current_full[:, (prompt.shape[1]+(num_block+1)*block_length):] = False # set future block to False
                # print((mask_current_full == True).sum())  # DEBUG Random issue here
                # print(x0_p[mask_current_full])

                x0 = torch.where(mask_current_full, x0, x)  # restore non-current-block tokens
                x0_p_current_full = torch.where(mask_current_full, x0_p, -np.inf)

                # update x0, keep x0=x0 for the current & future blocks, override x0=x for prompt
                mask_transfered = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
                k=nums_transfer_tokens[:, step_per_block]

                if remasking == 'random_top_k':
                    idx_mask_current = mask_current_full.nonzero(as_tuple=True)[-1].reshape(mask_current_full.shape[0], -1) # 1d
                    perm = torch.argsort(torch.rand(idx_mask_current.shape[0], idx_mask_current.shape[1], device=mask_current_full.device), dim=-1)
                    idx_unmask_k = idx_mask_current.gather(-1, perm)[:,:k]
                else:
                    _, idx_unmask_k = torch.topk(x0_p_current_full, k)
                # end if-else

                mask_transfered.scatter_(-1, idx_unmask_k, True)            # VALID format of this mask_transfered[idx_unmask_k] = True

                x[mask_transfered] = y[mask_transfered]                     # original code:   x[mask_transfered] = x0[mask_transfered]
                probs_all[mask_transfered] = x0_p_current_full[mask_transfered]   #
                _tuples_log_unmask.append((y[mask_transfered], idx_unmask_k))
                # end
            # end for steps
        # end for blocks
    # end with
    return probs_all, y != mask_id, _tuples_log_unmask
# end

In [11]:
# def calculate_ppl_and_conf(probs_all, mask_target):
#     probs_collected = probs_all[mask_target].reshape(mask_target.shape[0], -1)    # (b, collected)
#     mean_prob = probs_collected.mean(dim=-1)
#     nll_collected = -torch.log(probs_collected + 1e-12)
#     print(nll_collected)
#     ppl = torch.exp(nll_collected.mean())

#     return ppl.item(), mean_prob.item()
# # end

In [12]:
def calculate_ppl_and_conf(probs_all, mask_target, eps=1e-12):
    probs_collected = probs_all[mask_target].reshape(mask_target.shape[0], -1)  # [B, K]

    # Arithmetic mean confidence (what you currently call mean_prob)
    mean_prob = probs_collected.mean(dim=-1)  # [B]

    # Per-token NLL and per-row PPL (geometric-mean based)
    nll_collected = -torch.log(probs_collected + eps)   # [B, K]
    nll_per = nll_collected.mean(dim=-1)                 # [B]
    ppl_per = torch.exp(nll_per)                        # [B]

    # Geometric mean confidence (this one is directly tied to PPL)
    # geo_prob = torch.exp(torch.log(probs_collected + eps).mean(dim=1))  # [B]
    # And ppl_per == 1 / geo_prob (up to eps effects)

    return ppl_per.item(), mean_prob.item()

In [None]:
probs_truth, mask_current_full, tuples_log_unmask = run(model, ids_input, ids_target, 'truth_top_k')
# probs_random, mask_current_full = run(model, ids_input, ids_target, 'random_top_k')

In [14]:
print(calculate_ppl_and_conf(probs_truth, mask_current_full))
# print(calculate_ppl_and_conf(probs_random, mask_current_full))

(3.203125, 0.56640625)


In [20]:
model.model.transformer.blocks[-1]._k_previous.shape

torch.Size([1, 558, 4096])

In [16]:
model

LLaDAModelLM(
  (model): LLaDAModel(
    (transformer): ModuleDict(
      (wte): Embedding(126464, 4096)
      (emb_drop): Dropout(p=0.0, inplace=False)
      (ln_f): RMSLayerNorm()
      (blocks): ModuleList(
        (0-31): 32 x LLaDALlamaBlock(
          (dropout): Dropout(p=0.0, inplace=False)
          (act): SiLU()
          (attn_out): Linear(in_features=4096, out_features=4096, bias=False)
          (ff_out): Linear(in_features=12288, out_features=4096, bias=False)
          (rotary_emb): RotaryEmbedding()
          (attn_norm): RMSLayerNorm()
          (ff_norm): RMSLayerNorm()
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (ff_proj): Linear(in_features=4096, out_features=12288, bias=False)
          (up_proj): Linear(in_features=4096, out_features=12288, bias=False)
        )
      )
    