In [1]:
import os
import torch
import numpy as np

import torch.nn.functional as F

from collections import defaultdict
from tqdm import tqdm
from modeling_llada.modeling_llada import LLaDAModelLM
from transformers import AutoTokenizer
from datasets import Dataset, load_dataset

from llada_get_loglikelihood import forward_process, get_log_likelihood
from llada_generate import get_num_transfer_tokens, add_gumbel_noise

from jinyu_utils.jinyu_tokenizer import Tokenizer_
from jinyu_utils.jinyu_preprocess_wiki import parse_lines_with_index, merge_subdocs, PATTEN_REG_WIKI, simple_calculate_sim
from jinyu_utils.jinyu_dataset import jinyu_load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- run configuration -------------------------------------------------------
# Checkpoint identifier handed to the tokenizer / model `from_pretrained` calls.
id_model = 'GSAI-ML/LLaDA-8B-Base'

# Device string the model is moved to after loading.
device = 'cuda:0'

# --- special token ids (token text given in the trailing comment) -------------
ID_TOKEN_PADDING = 126081  # '|endoftext|'
ID_TOKEN_MASK = 126336     # '|mdm_mask|'
ID_TOKEN_EOT = 126348      # '|eot_id|'


In [3]:
# Load the tokenizer that belongs to the `id_model` checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    id_model,
    # local_files_only=True,
    trust_remote_code=True,
)

# Force left padding; assigning unconditionally is equivalent to the guarded
# assignment and keeps the cell idempotent on re-run.
tokenizer.padding_side = 'left'

# The mask id marks positions that still have to be generated, so a pad token
# sharing that id would be indistinguishable from a masked position.
assert tokenizer.pad_token_id != ID_TOKEN_MASK, (
    f'pad_token_id ({tokenizer.pad_token_id}) collides with the mask token id ({ID_TOKEN_MASK})'
)

In [4]:
# Load the LLaDA model weights in bfloat16 and move them to `device`.
# Extra keyword arguments for `from_pretrained` can be added to `model_kwargs`.
model_kwargs = {}
model = LLaDAModelLM.from_pretrained(
    id_model,
    # `trust_remote_code` is intentionally not passed: it only applies to the
    # Auto* classes and transformers warns that it is ignored on a concrete class.
    torch_dtype=torch.bfloat16,
    **model_kwargs
)

# Inference only: switch to eval mode before moving to the target device.
model = model.eval().to(device)

# Token ids must live on the same device as the input embedding matrix.
device_for_input = model.get_input_embeddings().weight.device

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
Loading checkpoint shards: 100%|██████████| 6/6 [00:00<00:00,  8.01it/s]


In [5]:
# Fetch the 'test' split and keep only its 'text' column.
# NOTE(review): the meaning of the positional argument `1` is defined by
# jinyu_load_dataset and is not visible in this notebook — confirm.
_split_test = jinyu_load_dataset(1, split='test')
ds = _split_test['text']

In [6]:
# Preprocess dataset: parse the raw text lines with the wiki heading pattern,
# then keep the 'subdocs' entry of the parsed result as the document list.
# The second return value of the parser is not needed here.
parsed_root, _ = parse_lines_with_index(PATTEN_REG_WIKI, ds)
docs = parsed_root['subdocs']

In [7]:
def _doc_to_sample(doc):
    """Turn one parsed document into a {'prefix', 'target'} pair.

    'prefix' is the document's own 'texts' joined with single spaces;
    'target' is the merged lines of its 'subdocs' joined the same way.
    The titles returned by merge_subdocs are not used.
    """
    prefix_text = ' '.join(doc['texts'])
    remaining_lines, _titles = merge_subdocs(doc['subdocs'])
    return {'prefix': prefix_text, 'target': ' '.join(remaining_lines)}
# end


samples = [_doc_to_sample(doc) for doc in docs]


In [8]:
# Keep only the leading sample(s) for this walkthrough and show how many remain.
NUM_SAMPLES_KEPT = 1
samples = samples[:NUM_SAMPLES_KEPT]
len(samples)

1

In [9]:
def _tokenize_plain(text):
    """Tokenize `text` without special tokens and return PyTorch tensors."""
    return tokenizer(text, add_special_tokens=False, return_tensors="pt")
# end


# Take the first sample and split it into prompt text and continuation text.
sample = samples[0]
prompt_, target_ = sample['prefix'], sample['target']

with torch.no_grad():
    inputs_encoded = _tokenize_plain(prompt_)
    targets_encoded = _tokenize_plain(target_)

    # Move the token ids to the device that holds the input embeddings.
    ids_input = inputs_encoded['input_ids'].to(device_for_input)
    ids_target = targets_encoded['input_ids'].to(device_for_input)
    # generate_with_ppl(model, ids_input, ids_target[:, :128], steps=32, gen_length=128, block_length=32, remasking='confidence_top_k')

# end



In [None]:
with torch.no_grad():
    model = model
    prompt = ids_input
    target = ids_target[:,:128]
    attention_mask=None
    steps=128
    gen_length=128
    block_length=128
    temperature=0
    remasking='random'
    mask_id=126336
    logits_eos_inf=False
    confidence_eos_eot_inf=False
    
    # (batch, full_length)
    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    y = torch.full(x.shape, mask_id, dtype=torch.long).to(model.device)
    y[:, -target.shape[-1]:] = target   # expand target into y
    ppl = torch.zeros(x.shape).to(model.device)

    # fill prompts
    x[:, :prompt.shape[1]] = prompt.clone()

    if attention_mask is not None:
        attention_mask = torch.cat([attention_mask, torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)], dim=-1)
    # end

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    num_block = 0
    mask_block = (x[:, prompt.shape[1] + num_block*block_length : prompt.shape[1]+(num_block + 1)*block_length]) == mask_id
    nums_transfer_tokens = get_num_transfer_tokens(mask_block, steps_per_block)    # [[7,7,6],..] if steps_per_block = 3 and remainder = 2
    del mask_block

    step_per_block = 0
    
    logits = model(x, attention_mask=attention_mask).logits

    if logits_eos_inf:
        logits[:, :, ID_TOKEN_PADDING] = -torch.inf
    # end

    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1) # b, l -> [[id0_5, id1_3, ...],..] (this is the index of native max)

    if confidence_eos_eot_inf:
        logits_with_noise[:, :, ID_TOKEN_PADDING] = logits[:, :, ID_TOKEN_EOT] = -torch.inf
    # end

    p = F.softmax(logits, dim=-1)
    del logits

    index_p = None  # we are going to handle the index_p
    match remasking:
        case 'confidence_top_k' | 'random':
            index_p = x0.unsqueeze(-1)
        case 'truth_top_k':
            index_p = y.unsqueeze(-1)
        case _:
            raise NotImplementedError(remasking)
        # end
    # end match
    x0_p = torch.squeeze(p.gather(dim=-1, index=index_p), -1) # b, l [[0.9, 0.7],..]

    # set mask
    mask_current_full = torch.where(x==mask_id, True, False)    # set prompt to False
    mask_current_full[:, (prompt.shape[1]+(num_block+1)*block_length):] = False # set future block to False
    x0 = torch.where(mask_current_full, x0, x)  # restore non-current-block tokens
    x0_p_current = torch.where(mask_current_full, x0_p, -np.inf)

    # update x0, keep x0=x0 for the current & future blocks, override x0=x for prompt

    mask_transfered = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
    k=nums_transfer_tokens[:, step_per_block]

    if remasking == 'random':
        idx_mask_current = mask_current_full[0].nonzero(as_tuple=True)[0] # 1d
        idx_unmask_k = torch.randint(
            idx_mask_current[0], idx_mask_current[-1],
            (mask_current_full.shape[0], k),
            device=x0.device
        )
    else:
        _, idx_unmask_k = torch.topk(x0_p_current, k)
    # end if-else

    mask_transfered.scatter_(-1, idx_unmask_k, True) # VALID format of this mask_transfered[idx_unmask_k] = True
    x[mask_transfered] = x0[mask_transfered]
# end with

In [None]:
# # [[8,8,7]] when steps_per_block = 3 and remainder = 2, split remainder to early steps
# def get_num_transfer_tokens(mask_block, steps_per_block):
#     '''
#     In the reverse process, the interval [0, 1] is uniformly discretized into steps intervals.
#     Furthermore, because LLaDA employs a linear noise schedule (as defined in Eq. (8)),
#     the expected number of tokens transitioned at each step should be consistent.

#     This function is designed to precompute the number of tokens that need to be transitioned at each step.
#     '''
#     num_mask = mask_block.sum(dim=1, keepdim=True)

#     # number of unmask per step
#     base = num_mask // steps_per_block

#     # leftover masked tokens after the even split; they are handed out one apiece to the earliest steps
#     remainder = num_mask % steps_per_block


#     num_transfer_tokens = torch.zeros(num_mask.size(0), steps_per_block, device=mask_block.device, dtype=torch.int64) + base

#     for i in range(num_mask.size(0)):
#         num_transfer_tokens[i, :remainder[i]] += 1

#     return num_transfer_tokens
# # end


# def add_gumbel_noise(logits, temperature):
#     '''
#     The Gumbel max is a method for sampling categorical distributions.
#     According to arXiv:2409.02908, for MDM, low-precision Gumbel Max improves perplexity score but reduces generation quality.
#     Thus, we use float64.
#     '''
#     if temperature == 0:
#         return logits
#     logits = logits.to(torch.float64)
#     noise = torch.rand_like(logits, dtype=torch.float64)
#     gumbel_noise = (- torch.log(noise)) ** temperature
#     return logits.exp() / gumbel_noise
# end


In [None]:
# @ torch.no_grad()
# def generate_with_ppl(model, prompt, target, attention_mask=None, steps=128, gen_length=128, block_length=128, temperature=0,
#              remasking='low_confidence', mask_id=126336, logits_eos_inf=False, confidence_eos_eot_inf=False):

#     # (batch, full_length)
#     x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
#     y = torch.full(x.shape, mask_id, dtype=torch.long).to(model.device)
#     y[:, -target.shape[-1]:] = target   # expand target into y
#     ppl = torch.zeros(x.shape).to(model.device)

#     # fill prompts
#     x[:, :prompt.shape[1]] = prompt.clone()

#     if attention_mask is not None:
#         attention_mask = torch.cat([attention_mask, torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)], dim=-1)
#     # end

#     assert gen_length % block_length == 0
#     num_blocks = gen_length // block_length

#     assert steps % num_blocks == 0
#     steps_per_block = steps // num_blocks

#     for num_block in range(num_blocks):

#         mask_block = (x[:, prompt.shape[1] + num_block*block_length : prompt.shape[1]+(num_block + 1)*block_length]) == mask_id
#         nums_transfer_tokens = get_num_transfer_tokens(mask_block, steps_per_block)    # [[7,7,6],..] if steps_per_block = 3 and remainder = 2
#         del mask_block

#         for step_per_block in range(steps_per_block):
            
#             logits = model(x, attention_mask=attention_mask).logits

#             if logits_eos_inf:
#                 logits[:, :, ID_TOKEN_PADDING] = -torch.inf
#             # end

#             logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
#             x0 = torch.argmax(logits_with_noise, dim=-1) # b, l -> [[id0_5, id1_3, ...],..] (this is the index of native max)

#             if confidence_eos_eot_inf:
#                 logits_with_noise[:, :, ID_TOKEN_PADDING] = logits[:, :, ID_TOKEN_EOT] = -torch.inf
#             # end

#             p = F.softmax(logits, dim=-1)
#             del logits

#             index_p = None  # we are going to handle the index_p
#             match remasking:
#                 case 'confidence_top_k' | 'random':
#                     index_p = x0.unsqueeze(-1)
#                 case 'truth_top_k':
#                     index_p = y.unsqueeze(-1)
#                 case _:
#                     raise NotImplementedError(remasking)
#                 # end
#             # end match
#             x0_p = torch.squeeze(p.gather(dim=-1, index=index_p), -1) # b, l [[0.9, 0.7],..]

#             # set mask
#             mask_current_full = torch.where(x==mask_id, True, False)    # set prompt to False
#             mask_current_full[:, (prompt.shape[1]+(num_block+1)*block_length):] = False # set future block to False
#             x0 = torch.where(mask_current_full, x0, x)  # restore non-current-block tokens
#             x0_p_current = torch.where(mask_current_full, x0_p, -np.inf)

#             # update x0, keep x0=x0 for the current & future blocks, override x0=x for prompt

#             mask_transfered = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)

#             if remasking == 'random':
#                 idx_mask_current = mask_current_full[0].nonzero(as_tuple=True)[0] # 1d
#                 idx_unmask_k = torch.randint(
#                     idx_mask_current[0], idx_mask_current[-1],
#                     (mask_current_full.shape[0], k)
#                 )
#             else:
#                 _, idx_unmask_k = torch.topk(x0_p_current, k=nums_transfer_tokens[:, step_per_block])
#             # end if-else

#             mask_transfered.scatter_(-1, idx_unmask_k, True) # scatter_ is the valid way to express mask_transfered[idx_unmask_k] = True
#             print(mask_transfered)
#             break
#             x[mask_transfered] = x0[mask_transfered]
#         # end steps_per_block
#     # end num_blocks
#     return x
# # end