In [None]:
import os
import torch

from collections import defaultdict
from tqdm import tqdm
from modeling.modeling_llada import LLaDAModelLM
from transformers import AutoTokenizer
from datasets import Dataset, load_dataset

from llada_get_loglikelihood import forward_process, get_log_likelihood
from llada_generate import generate

from jinyu_utils.jinyu_tokenizer import Tokenizer_
from jinyu_utils.jinyu_preprocess_wiki import parse_lines_with_index, merge_subdocs, PATTEN_REG_WIKI, simple_calculate_sim
from jinyu_utils.jinyu_dataset import jinyu_load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint identifier passed to `from_pretrained` for both the tokenizer and the model.
id_model = "GSAI-ML/LLaDA-8B-Base"

In [3]:
# --- Load tokenizer --------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(id_model, trust_remote_code=True)

# Pad on the left so that, in a padded batch, every prompt ends at the final
# position and generated tokens follow it directly. Assigning 'left' when the
# side is already 'left' leaves it unchanged.
tokenizer.padding_side = 'left'

# NOTE(review): 126336 appears to be the special (mask) token id used by the
# LLaDA generation / likelihood helpers; the pad token must not collide with
# it — confirm against llada_generate.
assert tokenizer.pad_token_id != 126336

In [4]:
# --- Load model ------------------------------------------------------------
# LLaDAModelLM is used directly rather than through an Auto class, so
# `trust_remote_code` has no effect here (transformers warns that it is
# ignored); it is therefore not passed.
model_kwargs = {}  # optional extra keyword arguments for from_pretrained
model = LLaDAModelLM.from_pretrained(
    id_model,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    **model_kwargs,
)

# Switch to evaluation mode (e.g. disables dropout) for inference.
model = model.eval()

# With device_map="auto" the weights may be spread over several devices;
# inputs are moved to whichever device holds the input embedding table.
device_for_input = model.get_input_embeddings().weight.device

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████| 6/6 [00:02<00:00,  2.96it/s]


In [5]:
# --- Load dataset ----------------------------------------------------------
# NOTE(review): the argument `2` selects a dataset inside the project-local
# `jinyu_load_dataset`; the sample output below (GSM8K-style "#### 18"
# answers) suggests it is GSM8K — confirm.
ds = jinyu_load_dataset(2)["test"]

In [6]:
# Sanity check: display the first question of the loaded test split.
ds[0]["question"]

"Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"

In [7]:
# Build prefix/target pairs (question -> reference answer) from the first
# N_SAMPLES examples of the test split. `islice` stops after N_SAMPLES rows
# instead of converting the whole split and truncating afterwards.
from itertools import islice

N_SAMPLES = 50  # number of test questions to evaluate

samples = [
    {'prefix': data['question'], 'target': data['answer']}
    for data in islice(ds, N_SAMPLES)
]
print(samples[0])

{'prefix': "Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?", 'target': 'Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer’s market.\n#### 18'}


In [11]:
# Generate one completion per prompt, processing prompts one at a time.
# All generation settings live in one place; the output file written by a
# later cell encodes the step count in its name, so keep them in sync.
GEN_KWARGS = dict(
    steps=32,
    gen_length=128,
    block_length=128,
    temperature=0.,
    cfg_scale=0.,
    remasking='low_confidence',
)

prompts = [sample['prefix'] for sample in samples]
outputs = []
with torch.no_grad():
    for prompt in tqdm(prompts, desc='starting to get outputs...'):
        # A single, unbatched prompt: no padding is required.
        encoded_inputs = tokenizer(
            prompt,
            add_special_tokens=False,
            return_tensors="pt",
        )
        input_ids = encoded_inputs['input_ids'].to(device_for_input)
        attention_mask = encoded_inputs['attention_mask'].to(device_for_input)

        out = generate(model, input_ids, attention_mask, **GEN_KWARGS)

        # `out` is assumed to contain the prompt followed by the completion
        # (batch size 1); decode only the newly generated tail.
        prompt_len = input_ids.shape[1]
        outputs.append(tokenizer.decode(out[0, prompt_len:], skip_special_tokens=True))


starting to get outputs...: 100%|██████████| 50/50 [00:43<00:00,  1.16it/s]


In [12]:
import json

# Save each prompt, its reference answer, and the model prediction as a JSON
# list of records.
# NOTE(review): the file name hard-codes the sample count (50) and the number
# of generation steps (32), and has no .json extension; keep it in sync with
# the cells above.

# zip() would silently drop unmatched items (e.g. after an interrupted
# generation cell) and write a partial file; fail before touching the file.
assert len(samples) == len(outputs), (
    f'expected one prediction per sample, got {len(outputs)} predictions '
    f'for {len(samples)} samples'
)

contents = [
    {
        'prefix': sample['prefix'],
        'target': sample['target'],
        'predict': predict,
    }
    for sample, predict in zip(samples, outputs)
]

with open('gsm8k_base_50_32', 'w', encoding='utf-8') as file:
    json.dump(contents, file, indent=4)

In [10]:
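# '''get log likelihood parts'''
# NOTE: this entire cell is disabled (every line is commented out). If
# re-enabled, it tokenizes each prefix/target sample pair, drops pairs whose
# combined prefix + target length exceeds 4096, and collects
# get_log_likelihood(model, prefix, target) for each remaining pair into `out`.
# Re-enabling it reassigns `ds` and `out`, overwriting their earlier values.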


# class Tokenizer_test(Tokenizer_):
#     def _tokenize(self, e):
#         prefix, target = self._encode_pair(e['prefix'], e['target'])
#         return {
#             'prefix_text': e['prefix'],
#             'target_text': e['target'],
#             'prefix': prefix,
#             'target': target
#         }
#     # end
# # end


# ds =  Dataset.from_list(samples)
# ds = ds.map(Tokenizer_test(tokenizer))
# ds = ds.with_format("torch")
# ds = ds.filter(lambda x: len(x["prefix"]) + len(x['target']) <= 4096)

# out = []
# with torch.no_grad():
#     for elem in tqdm(ds, desc="Computing likelihood..."):
#         prefix = elem["prefix"]
#         target = elem["target"]

#         ll = get_log_likelihood(model, prefix, target)
#         out.append(ll)
#     # end
# # end
