In [1]:
import os
import torch

from collections import defaultdict
from tqdm import tqdm
from transformers import AutoTokenizer
from datasets import DatasetBuilder
from torch.cuda import nvtx

from modeling_fastdllm.modeling_llada import LLaDAModelLM
from fastdllm_get_loglikelihood import get_loglikelihood
from fastdllm_generate import generate, generate_with_dual_cache

from jinyu_utils.jinyu_dataset import jinyu_load_dataset
from jinyu_utils.jinyu_tokenizer import Tokenizer_
from jinyu_utils.jinyu_preprocess_wiki import parse_lines_with_index, merge_subdocs, PATTEN_REG_WIKI, simple_calculate_sim

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
id_model = 'GSAI-ML/LLaDA-8B-Instruct'
# path_cache_base = os.environ['HF_HUB_CACHE']
# folder_model = '--'.join(['models'] + id_model.split('/'))
# path_cache_model = os.path.join(path_cache_base, folder_model)
# path_snapshot_model = os.path.join(path_cache_model, 'snapshots')
# folder_snapshot_model_1 = [entity for entity in os.listdir(path_snapshot_model) if entity[0] != '.'][0]
# path_snapshot_model_1 = os.path.join(path_snapshot_model, folder_snapshot_model_1)
# print(path_snapshot_model_1)


In [3]:
'''load tokenizer'''
tokenizer = AutoTokenizer.from_pretrained(
    id_model,
    # local_files_only=True,
    trust_remote_code=True
)

if tokenizer.padding_side != 'left':
    tokenizer.padding_side = 'left'
# end

assert tokenizer.pad_token_id != 126336

In [4]:
'''load model'''
model_kwargs = {}
model = LLaDAModelLM.from_pretrained(
    id_model,
    # local_files_only=True,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    **model_kwargs
)

model = model.eval()
device_for_input = model.get_input_embeddings().weight.device

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████| 6/6 [00:02<00:00,  2.32it/s]


In [5]:
'''load dataset'''
ds = jinyu_load_dataset(1, split='test')['text']


In [6]:
'''preprocess dataset'''
docs, _ = parse_lines_with_index(PATTEN_REG_WIKI, ds)
docs = docs['subdocs']

In [7]:
samples = []
for doc in docs:
    lines_1 = doc['texts']
    paragraph_1 = ' '.join(lines_1)
    lines_remain, titles = merge_subdocs(doc['subdocs'])
    paragraph_remain = ' '.join(lines_remain)
    prefix = paragraph_1
    target = paragraph_remain
    samples.append({'prefix': prefix, 'target': target})
# end


In [8]:
samples = samples[:5]
len(samples)

5

In [None]:
prompt = "Lily can run 12 kilometers per hour for 4 hours. After that, she runs 6 kilometers per hour. How many kilometers can she run in 8 hours?"

# Add special tokens for the Instruct model. The Base model does not require the following two lines.
m = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(m, add_generation_prompt=True, tokenize=False)

input_ids = tokenizer(prompt)['input_ids']
input_ids = torch.tensor(input_ids).to(device_for_input).unsqueeze(0)

with torch.inference_mode():
    nvtx.range_push("INFER")
    out = generate_with_dual_cache(model, input_ids, steps=64, gen_length=128, block_length=32, temperature=0., remasking='low_confidence')
    torch.cuda.synchronize()
    nvtx.range_pop()
# end


In [10]:
print(tokenizer.batch_decode(out[0][:, input_ids.shape[1]:], skip_special_tokens=True)[0])

Lily can run 12 kilometers per hour for 4 hours, so she runs a total of 12 kilometers per hour x 4 hours = 48 kilometers.
After that, she runs 6 kilometers per hour for the remaining 4 hours, so she runs a total of 6 kilometers per hour x 4 hours = 24 kilometers.
Therefore, Lily can run a total of 48 kilometers + 24 kilometers = 72 kilometers in 8 hours.
Conclusively: 72


In [19]:
out[0].shape

torch.Size([1, 201])

In [16]:
sims_target = []
for sample, predict in zip(samples, outputs):
    sims_target.append(simple_calculate_sim(sample['target'], predict))
# end
sims_target = [(idx, sim) for idx, sim in enumerate(sims_target)]
sims_target_sorted = sorted(sims_target, key=lambda copus: -copus[1])


In [17]:
sims_prefix = []
for sample, predict in zip(samples, outputs):
    sims_prefix.append(simple_calculate_sim(sample['prefix'], predict))
# end
sims_prefix = [(idx, sim) for idx, sim in enumerate(sims_prefix)]
sims_prefix_sorted = sorted(sims_prefix, key=lambda copus: -copus[1])
sims_prefix_sorted[:10]


[(16, 0.9876543209876543),
 (0, 0.9807692307692307),
 (15, 0.8),
 (37, 0.6071428571428571),
 (35, 0.578125),
 (48, 0.5),
 (31, 0.4583333333333333),
 (5, 0.36231884057971014),
 (23, 0.3333333333333333),
 (32, 0.288135593220339)]

In [18]:
import json

with open('wikiraw_base_50_32', 'w+') as file:
    idx = 0
    contents = []

    for sample, predict in zip(samples, outputs):
        content = {
            'prefix': sample['prefix'],
            'target': sample['target'],
            'predict': predict,
            'sim_prefix': sims_prefix[idx][1],
            'sim_target': sims_target[idx][1]
        }

        contents.append(content)

        idx += 1
    # end

    file.write(json.dumps(contents, indent=4))
# end

In [13]:
# '''one-by-one testing'''
# idx = 43
# prompts = [samples[idx]['prefix']]
# with torch.no_grad():
#     for prompt in tqdm(prompts, desc='starting to get outputs...'):
#         encoded_inputs = tokenizer(
#                 prompt,
#                 add_special_tokens=False,
#                 padding=True,
#                 return_tensors="pt"
#         )

#         input_ids = encoded_inputs['input_ids'].to(device_for_input)
#         attention_mask = encoded_inputs['attention_mask'].to(device_for_input)

#         out = generate(model, input_ids, attention_mask, steps=32, gen_length=128, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
#         output = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)
#         print(simple_calculate_sim(samples[idx]['prefix'], output[0]))
#         print(samples[idx]['prefix'])
#         print(output[0])
#     # end for
# # end with


In [14]:
'''get log likelihood parts'''

# samples = []
# for doc in docs:
#     lines_1 = doc['texts']
#     paragraph_1 = ' '.join(lines_1)
#     lines_remain, titles = merge_subdocs(doc)
#     paragraph_remain = ' '.join(lines_remain)
#     prefix = 'I will give you a general description of a person. I will also give you some subtitles and you need to give me the detail of them respectively . '
#     prefix += paragraph_1
#     prefix += " Titles are : "
#     prefix += ' , '.join(titles)

#     target = paragraph_remain
#     samples.append({'prefix': prefix, 'target': target})
# # end

# class Tokenizer_test(Tokenizer_):
#     def _tokenize(self, e):
#         prefix, target = self._encode_pair(e['prefix'], e['target'])
#         return {
#             'prefix_text': e['prefix'],
#             'target_text': e['target'],
#             'prefix': prefix,
#             'target': target
#         }
#     # end
# # end


# ds =  Dataset.from_list(samples)
# ds = ds.map(Tokenizer_test(tokenizer))
# ds = ds.with_format("torch")
# ds = ds.filter(lambda x: len(x["prefix"]) + len(x['target']) <= 4096)

# out = []
# with torch.no_grad():
#     for elem in tqdm(ds, desc="Computing likelihood..."):
#         prefix = elem["prefix"]
#         target = elem["target"]

#         ll = get_log_likelihood(model, prefix, target)
#         out.append(ll)
#     # end
# # end


'get log likelihood parts'

In [None]:
# prompts = [sample['prefix'] for sample in samples]
# outputs = []
# with torch.no_grad():
#     for prompt in tqdm(prompts, desc='starting to get outputs...'):
#         encoded_inputs = tokenizer(
#                 prompt,
#                 add_special_tokens=False,
#                 padding=True,
#                 return_tensors="pt"
#         )

#         input_ids = encoded_inputs['input_ids'].to(device_for_input)
#         attention_mask = encoded_inputs['attention_mask'].to(device_for_input)

#         out = generate(model, input_ids, attention_mask, steps=32, gen_length=128, block_length=32, temperature=0., cfg_scale=0., remasking='low_confidence')
#         output = tokenizer.batch_decode(out[:, input_ids.shape[1]:], skip_special_tokens=True)[0]
#         outputs.append(output)
#     # end for
# # end with
