In [None]:
import os
import torch
import numpy as np

import torch.nn.functional as F

from collections import defaultdict
from tqdm import tqdm
from modeling_llada.modeling_llada import LLaDAModelLM
from transformers import AutoTokenizer
from datasets import Dataset, load_dataset

from llada_get_loglikelihood import forward_process, get_log_likelihood
from llada_generate import generate

from jinyu_utils.jinyu_tokenizer import Tokenizer_
from jinyu_utils.jinyu_preprocess_wiki import parse_lines_with_index, merge_subdocs, PATTEN_REG_WIKI, simple_calculate_sim
from jinyu_utils.jinyu_dataset import jinyu_load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Special-token ids used throughout the notebook; the trailing comment on each
# line gives the token string the id stands for.
ID_TOKEN_MASK = 126336 # '|mdm_mask|'
ID_TOKEN_PADDING = 126081 # '|endoftext|'
ID_TOKEN_EOT = 126348 # '|eot_id|'

# Checkpoint identifier; the cells below pass it to both the tokenizer and the model loader.
id_model = 'GSAI-ML/LLaDA-8B-Base'


In [3]:
'''load tokenizer'''
tokenizer = AutoTokenizer.from_pretrained(
    id_model,
    # local_files_only=True,
    trust_remote_code=True
)

# The generation routines append the masked region right after the prompt,
# so prompts are padded on the left to keep the real tokens adjacent to it.
if tokenizer.padding_side != 'left':
    tokenizer.padding_side = 'left'
# end

# The samplers treat every position equal to the mask id as "still to be generated";
# a pad token with the same id would therefore be overwritten during sampling.
assert tokenizer.pad_token_id != ID_TOKEN_MASK, 'pad token id collides with the mask token id'

In [None]:
'''load model'''
# Extra keyword arguments forwarded to from_pretrained (none by default).
model_kwargs = {}

# `trust_remote_code` is intentionally not passed: it only applies to the Auto*
# classes and is ignored (with a warning) when loading through a concrete model class.
model = LLaDAModelLM.from_pretrained(
    id_model,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    **model_kwargs
)

# Inference only: switch to evaluation mode.
model = model.eval()

# With device_map="auto" the weights may be spread over several devices;
# inputs are sent to the device that holds the input embedding matrix.
device_for_input = model.get_input_embeddings().weight.device

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████| 6/6 [00:01<00:00,  3.02it/s]


In [None]:
# Load the test split through the project loader and keep only its 'text' field.
# NOTE(review): the meaning of the positional argument `1` is defined by
# jinyu_load_dataset and is not visible here — confirm which dataset it selects.
ds = jinyu_load_dataset(1, split='test')['text']


In [6]:
'''preprocess dataset'''
# parse_lines_with_index returns a pair; only the first element is used here,
# and the documents of interest live under its 'subdocs' key.
parsed_root, _ = parse_lines_with_index(PATTEN_REG_WIKI, ds)
docs = parsed_root['subdocs']

In [7]:
# Build one (prefix, target) pair per document:
#   prefix = the document's own lines joined with spaces,
#   target = the lines returned by merge_subdocs for its sub-documents, joined the same way.
samples = []
for doc in docs:
    prefix_text = ' '.join(doc['texts'])
    # merge_subdocs also returns titles, which are not needed here.
    remaining_lines, _section_titles = merge_subdocs(doc['subdocs'])
    target_text = ' '.join(remaining_lines)
    samples.append({'prefix': prefix_text, 'target': target_text})
# end


In [8]:
# Keep only the first 50 samples for this run; the trailing expression displays the resulting count.
samples = samples[:50]
len(samples)

50

In [15]:
# Generate one continuation per sample prefix and keep only the decoded new tokens.
prompts = [sample['prefix'] for sample in samples]
outputs = []

# Sampling configuration shared by every prompt.
generation_kwargs = dict(
    steps=32,
    gen_length=128,
    block_length=32,
    temperature=0.,
    cfg_scale=0.,
    remasking='low_confidence',
)

with torch.no_grad():
    for prompt in tqdm(prompts, desc='starting to get outputs...'):
        batch = tokenizer(
            prompt,
            add_special_tokens=False,
            padding=True,
            return_tensors="pt"
        )
        input_ids = batch['input_ids'].to(device_for_input)
        attention_mask = batch['attention_mask'].to(device_for_input)

        generated = generate(model, input_ids, attention_mask, **generation_kwargs)

        # Drop the prompt tokens; what remains is the generated region.
        prompt_length = input_ids.shape[1]
        completion_ids = generated[:, prompt_length:]
        outputs.append(tokenizer.batch_decode(completion_ids, skip_special_tokens=True)[0])
    # end for
# end with


starting to get outputs...: 100%|██████████| 50/50 [01:19<00:00,  1.59s/it]


In [16]:
# (sample index, score of simple_calculate_sim(target, prediction)), kept in sample order.
sims_target = [
    (idx, simple_calculate_sim(sample['target'], predict))
    for idx, (sample, predict) in enumerate(zip(samples, outputs))
]
# Same pairs, highest score first.
sims_target_sorted = sorted(sims_target, key=lambda pair: -pair[1])


In [17]:
# (sample index, score of simple_calculate_sim(prefix, prediction)), kept in sample order.
sims_prefix = [
    (idx, simple_calculate_sim(sample['prefix'], predict))
    for idx, (sample, predict) in enumerate(zip(samples, outputs))
]
# Same pairs, highest score first; show the ten predictions scoring highest against their prefix.
sims_prefix_sorted = sorted(sims_prefix, key=lambda pair: -pair[1])
sims_prefix_sorted[:10]


[(16, 0.9876543209876543),
 (0, 0.9807692307692307),
 (15, 0.8),
 (37, 0.6071428571428571),
 (35, 0.578125),
 (48, 0.5),
 (31, 0.4583333333333333),
 (5, 0.36231884057971014),
 (23, 0.3333333333333333),
 (32, 0.288135593220339)]

In [18]:
import json

# Dump every sample together with its prediction and both similarity scores as a JSON list.
with open('wikiraw_base_50_32', 'w+') as file:
    # sims_prefix / sims_target hold (index, score) pairs in sample order,
    # so position `idx` lines up with the idx-th sample.
    contents = [
        {
            'prefix': sample['prefix'],
            'target': sample['target'],
            'predict': predict,
            'sim_prefix': sims_prefix[idx][1],
            'sim_target': sims_target[idx][1]
        }
        for idx, (sample, predict) in enumerate(zip(samples, outputs))
    ]

    file.write(json.dumps(contents, indent=4))
# end

In [None]:
# [[8,8,7]] when steps_per_block = 3 and remainder = 2, split remainder to early steps
def get_num_transfer_tokens(mask_block, steps_per_block):
    '''
    Precompute, for every sequence in the batch, how many masked tokens get
    revealed at each denoising step of one block.

    The masked count of each row is split as evenly as possible over
    `steps_per_block` steps; when it does not divide evenly, the first
    `count % steps_per_block` steps each reveal one extra token
    (e.g. 23 masked tokens over 3 steps -> [8, 8, 7]).

    Args:
        mask_block: tensor of shape (batch, block_length) that is truthy where a token is still masked.
        steps_per_block: number of denoising steps spent on this block.

    Returns:
        int64 tensor of shape (batch, steps_per_block) on `mask_block.device`;
        each row sums to the number of masked tokens in that row.
    '''
    # (batch, 1): how many tokens are still masked in each row
    masked_per_row = mask_block.sum(dim=1, keepdim=True)

    # even share per step, and how many steps need one extra token
    even_share = masked_per_row // steps_per_block
    leftover = masked_per_row % steps_per_block

    schedule = torch.zeros(
        masked_per_row.size(0), steps_per_block, device=mask_block.device, dtype=torch.int64
    ) + even_share

    # step j receives an extra token iff j < leftover (broadcast over the batch)
    step_ids = torch.arange(steps_per_block, device=mask_block.device).unsqueeze(0)
    extra = (step_ids < leftover).to(torch.int64)

    return schedule + extra
# end


def add_gumbel_noise(logits, temperature):
    '''
    Perturb logits for Gumbel-max style categorical sampling.

    With `temperature == 0` the input tensor is returned untouched (greedy decoding).
    Otherwise the result is exp(logits) / (-log(u)) ** temperature, with u drawn
    uniformly from [0, 1), computed in float64: according to arXiv:2409.02908,
    low-precision Gumbel max improves perplexity for MDMs but hurts generation quality.

    Args:
        logits: tensor of unnormalized scores.
        temperature: sampling temperature; 0 disables the noise.

    Returns:
        `logits` itself when temperature is 0, else a float64 tensor of the same shape.
    '''
    if temperature != 0:
        logits64 = logits.to(torch.float64)
        uniform = torch.rand_like(logits64, dtype=torch.float64)
        scaled_noise = torch.pow(-torch.log(uniform), temperature)
        return torch.exp(logits64) / scaled_noise
    return logits
# end


In [None]:
@torch.no_grad()
def generate_with_ppl(model, prompt, target, attention_mask=None, steps=128, gen_length=128, block_length=128, temperature=0.,
             remasking='low_confidence', mask_id=126336, logits_eos_inf=False, confidence_eos_eot_inf=False):
    '''
    Block-wise masked-diffusion sampling in which the order of unmasking can be driven
    either by the model's own confidence or by the probability it assigns to a
    reference continuation (`target`).

    The token written at an unmasked position is always the model's (noisy) argmax;
    `remasking` only decides WHICH positions are unmasked at each step.
    Despite the name, no perplexity is computed or returned yet.

    Args:
        model: Mask predictor; called as model(x, attention_mask=...) and read through `.logits`.
        prompt: Tensor of prompt token ids, shape (batch, L).
        target: Tensor of reference token ids; copied right-aligned into a buffer shaped like the
            full sequence (prompt + generation).
            NOTE(review): if target is shorter than gen_length it lines up with the END of the
            generated region, not its start — confirm this is the intended alignment.
        attention_mask: Optional mask for the prompt, shape (batch, L); extended with ones over the generated region.
        steps: Total sampling steps; must be divisible by the number of blocks.
        gen_length: Number of tokens to generate; must be divisible by block_length.
        block_length: Block length (semi-autoregressive generation when smaller than gen_length).
        temperature: Gumbel-noise temperature; 0 means greedy.
        remasking: Ranking used to pick the positions unmasked at each step:
            'low_confidence' / 'confidence_top_k' - probability of the predicted token;
            'truth_top_k' - probability of the reference token taken from `target`;
            'random' - uniform random scores.
        mask_id: Token id of [MASK] (126336).
        logits_eos_inf: Whether to set the logit of the EOS token to -inf.
        confidence_eos_eot_inf: Whether to set the confidence of EOS and EoT tokens to -inf.

    Returns:
        Long tensor of shape (batch, L + gen_length): the prompt followed by the generated tokens.

    Raises:
        NotImplementedError: if `remasking` is not one of the supported strategies.
    '''
    if remasking not in ('low_confidence', 'confidence_top_k', 'truth_top_k', 'random'):
        raise NotImplementedError(remasking)
    # end

    prompt_length = prompt.shape[1]

    # (batch, full_length), fully masked, then the prompt is written at the front
    x = torch.full((prompt.shape[0], prompt_length + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt_length] = prompt.clone()

    # Reference ids aligned with x. Positions not covered by `target` keep mask_id,
    # which is a valid vocabulary index and therefore safe to gather with.
    y = torch.full(x.shape, mask_id, dtype=torch.long).to(model.device)
    if target.shape[-1] > 0:
        y[:, -target.shape[-1]:] = target   # right-align target inside y
    # end

    if attention_mask is not None:
        attention_mask = torch.cat([attention_mask, torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)], dim=-1)
    # end

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    for num_block in range(num_blocks):
        block_start = prompt_length + num_block * block_length
        block_end = block_start + block_length

        mask_block = (x[:, block_start:block_end] == mask_id)

        # e.g. [[8,8,7],..] when 23 tokens are masked and steps_per_block = 3
        nums_transfer_tokens = get_num_transfer_tokens(mask_block, steps_per_block)

        for step_per_block in range(steps_per_block):
            mask_masked_full = (x == mask_id)

            logits = model(x, attention_mask=attention_mask).logits

            if logits_eos_inf:
                logits[:, :, ID_TOKEN_PADDING] = -torch.inf
            # end

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1) # b, l -> predicted token id per position

            if confidence_eos_eot_inf:
                # NOTE(review): this writes the padding id on logits_with_noise but the EoT id on
                # logits; both only affect the confidence below when temperature == 0 (where the
                # two names alias the same tensor). Kept as written — confirm intent.
                logits_with_noise[:, :, ID_TOKEN_PADDING] = logits[:, :, ID_TOKEN_EOT] = -torch.inf
            # end

            if remasking == 'random':
                # random ranking: ignore the model's probabilities entirely
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                p = F.softmax(logits, dim=-1)
                # score each position by the probability of the reference token or of the prediction
                scored_ids = y.to(p.device) if remasking == 'truth_top_k' else x0
                x0_p = torch.squeeze(p.gather(dim=-1, index=scored_ids.unsqueeze(-1)), -1) # b, l [[0.9, 0.7],..]
            # end

            x0_p[:, block_end:] = -np.inf    # future blocks can never be selected

            # already-decoded positions keep their token and get -inf confidence
            x0 = torch.where(mask_masked_full, x0, x)
            confidence = torch.where(mask_masked_full, x0_p, -np.inf)

            mask_transfered = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for batchid in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[batchid], k=nums_transfer_tokens[batchid, step_per_block])
                mask_transfered[batchid, select_index] = True
            x[mask_transfered] = x0[mask_transfered]
        # end steps_per_block
    # end num_blocks
    return x
# end

In [None]:
@torch.no_grad()
def generate(model, prompt, attention_mask=None, steps=128, gen_length=128, block_length=128, temperature=0.,
             remasking='low_confidence', mask_id=126336, logits_eos_inf=False, confidence_eos_eot_inf=False,
             cfg_scale=0.):
    '''
    Block-wise (semi-autoregressive) masked-diffusion sampling.

    The sequence starts as the prompt followed by `gen_length` mask tokens. Blocks are
    denoised left to right; within a block, each step predicts every masked position
    and commits the highest-ranked ones according to `remasking`.

    Args:
        model: Mask predictor; called as model(x, attention_mask=...) and read through `.logits`.
        prompt: A tensor of prompt token ids, shape (batch, L).
        attention_mask: Optional mask for the prompt, shape (batch, L); extended with ones over the generated region.
        steps: Sampling steps, less than or equal to gen_length; must be divisible by the number of blocks.
        gen_length: Generated answer length; must be divisible by block_length.
        block_length: Block length, less than or equal to gen_length. If less than gen_length, it means using semi_autoregressive remasking.
        temperature: Categorical distribution sampling temperature.
        remasking: Remasking strategy. 'low_confidence' or 'random'.
        mask_id: The token id of [MASK] is 126336.
        logits_eos_inf: Whether to set the logits of EOS token to -inf. See Appendix B.4 of LLaDA for details
        confidence_eos_eot_inf: Whether to set the confidence of EOS and EoT token to -inf. See Appendix B.4 of LLaDA for details
        cfg_scale: Unsupervised classifier-free guidance scale; 0 (default) disables guidance.

    Returns:
        Long tensor of shape (batch, L + gen_length): the prompt followed by the generated tokens.

    Raises:
        NotImplementedError: if `remasking` is not a supported strategy.
    '''
    # (batch, full_length)
    x = torch.full((prompt.shape[0], prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)

    # fill prompts
    x[:, :prompt.shape[1]] = prompt.clone()

    if attention_mask is not None:
        attention_mask = torch.cat([attention_mask, torch.ones((prompt.shape[0], gen_length), dtype=attention_mask.dtype, device=model.device)], dim=-1)
    # end

    # positions holding prompt tokens; masked out to build the unconditional input for guidance
    prompt_index = (x != mask_id)

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps_per_block = steps // num_blocks

    for num_block in range(num_blocks):
        block_start = prompt.shape[1] + num_block * block_length
        block_end = block_start + block_length

        block_mask_index = (x[:, block_start:block_end] == mask_id)
        # e.g. [[8,8,7],..] when 23 tokens are masked and steps_per_block = 3
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps_per_block)

        for i in range(steps_per_block):
            mask_index = (x == mask_id)

            if cfg_scale > 0.:
                # Run the conditional input and an unconditional copy (prompt masked out)
                # in one forward pass, then extrapolate away from the unconditional logits.
                un_x = x.clone()
                un_x[prompt_index] = mask_id
                x_ = torch.cat([x, un_x], dim=0)
                attention_mask_ = None if attention_mask is None else torch.cat([attention_mask, attention_mask], dim=0)
                logits = model(x_, attention_mask=attention_mask_).logits
                logits, un_logits = torch.chunk(logits, 2, dim=0)
                logits = un_logits + (cfg_scale + 1) * (logits - un_logits)
            else:
                logits = model(x, attention_mask=attention_mask).logits
            # end

            if logits_eos_inf:
                logits[:, :, ID_TOKEN_PADDING] = -torch.inf
            # end

            logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
            x0 = torch.argmax(logits_with_noise, dim=-1) # b, l -> [[id0_5, id1_3, ...],..]

            if confidence_eos_eot_inf:
                # NOTE(review): padding id is written on logits_with_noise, EoT id on logits;
                # kept as written — confirm intent.
                logits_with_noise[:, :, ID_TOKEN_PADDING] = logits[:, :, ID_TOKEN_EOT] = -torch.inf
            # end

            if remasking == 'low_confidence':
                # confidence = probability the model assigns to its own prediction
                p = F.softmax(logits, dim=-1)
                x0_p = torch.squeeze(
                    torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l [[0.9, 0.7],..]
            elif remasking == 'random':
                x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
            else:
                raise NotImplementedError(remasking)
            # end

            # future blocks can never be selected in this block's steps
            x0_p[:, block_end:] = -np.inf

            # already-decoded positions keep their token and get -inf confidence
            x0 = torch.where(mask_index, x0, x)
            confidence = torch.where(mask_index, x0_p, -np.inf)

            mask_transfered = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
            for batchid in range(confidence.shape[0]):
                _, select_index = torch.topk(confidence[batchid], k=num_transfer_tokens[batchid, i])
                mask_transfered[batchid, select_index] = True
            x[mask_transfered] = x0[mask_transfered]
        # end steps_per_block
    # end num_blocks
    return x
# end