In [1]:
import torch
import torch.nn.functional as F
from jinyu_utils import jinyu_dataset

from transformers import AutoTokenizer
from datasets import load_dataset
from abc import ABC, abstractmethod

from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hugging Face Hub id of the checkpoint whose tokenizer is used below.
id_model = 'GSAI-ML/LLaDA-8B-Base'

# NOTE(review): trust_remote_code=True allows Python code shipped in the model
# repo to run locally during loading — confirm the repo is trusted.
tokenizer = AutoTokenizer.from_pretrained(id_model, trust_remote_code=True)


In [14]:
# The dataset spec comes from the project helper and is unpacked into
# load_dataset's positional arguments; only the 'test' split is loaded.
name_dataset = jinyu_dataset.LIST_DATASET[1]
ds = load_dataset(*name_dataset, split='test')

# Drop rows whose "text" is missing or consists solely of whitespace.
ds = ds.filter(lambda row: row["text"] is not None and row["text"].strip() != "")


In [None]:
# Sequence-length budget, in tokens.
len_prompt = 128                    # leading positions left unmasked as the prompt
len_max = 4096                      # truncation limit handed to the tokenizer
len_target = len_max - len_prompt   # what remains of the budget after the prompt

# Batch size and generation block size.
size_batch = 32
size_block = 32                     # examples must fit the prompt plus one such block

# Token id written into masked positions.
# NOTE(review): presumably the mask-token id of the tokenizer loaded above — confirm.
id_mask = 126336

class Tokenizer_(ABC):
    """Abstract callable that tokenizes a single dataset example.

    Concrete subclasses provide `_tokenize`; calling an instance forwards the
    example to that hook and returns whatever it produces.
    """

    def __init__(self, tokenizer, len_max):
        """Remember the tokenizer and the length limit for use by subclasses."""
        self.len_max = len_max
        self.tokenizer = tokenizer
    # end

    def __call__(self, ds_each):
        """Tokenize `ds_each` by delegating to `_tokenize`."""
        result = self._tokenize(ds_each)
        return result
    # end

    @abstractmethod
    def _tokenize(self, ds_each):
        """Subclass hook: turn one example into its tokenized representation."""
        return None
    # end
# end

class Tokenizer_wiki_simple(Tokenizer_):
    """Tokenize the ``'text'`` field of one example into a list of token ids."""

    def _tokenize(self, ds_each):
        """Encode ``ds_each['text']`` with the tokenizer supplied at construction.

        Args:
            ds_each: a single example; must provide a ``'text'`` entry.

        Returns:
            dict with
            ``'ids_input'`` – the tokenizer's ``"input_ids"`` for the text, and
            ``'length'``    – ``len`` of those ids.
        """
        # Use the instance's tokenizer rather than a notebook-level global named
        # `tokenizer`, so the object passed to __init__ is the one that runs.
        ids = self.tokenizer(
            ds_each['text'],
            add_special_tokens=False,               # avoids BOS/EOS being injected by tokenizer
            truncation=(self.len_max is not None),  # truncation and max_length are a pair
            max_length=self.len_max,
        )["input_ids"]

        return {
            'ids_input': ids,
            'length': len(ids)
        }
    # end tokenize
# end

# Tokenize each example; the original columns are removed so only the
# tokenizer callable's outputs remain.
ds = ds.map(
    Tokenizer_wiki_simple(tokenizer, len_max),
    remove_columns=ds.column_names,
)

# Keep examples long enough for the prompt plus a first generation block,
# then order them by ascending length.
ds = ds.filter(lambda row: row["length"] >= len_prompt + size_block).sort("length")

In [25]:
class Collater_(ABC):
    """Abstract callable that turns a list of examples into one batch.

    Concrete subclasses provide `_collate`; calling an instance forwards the
    batch to that hook and returns whatever it produces.
    """

    def __init__(self, len_max, len_prompt, len_target, id_mask):
        """Remember the length settings and the mask token id for subclasses."""
        self.id_mask = id_mask
        self.len_target = len_target
        self.len_prompt = len_prompt
        self.len_max = len_max
    # end

    def __call__(self, ds_batch):
        """Collate `ds_batch` by delegating to `_collate`."""
        collated = self._collate(ds_batch)
        return collated
    # end

    @abstractmethod
    def _collate(self, ds_batch):
        """Subclass hook: assemble the given examples into a batch."""
        return None
    # end
# end


class Collater_wiki_simple(Collater_):
    """Collate tokenized examples into equal-length masked prompt/target tensors."""

    def _collate(self, ds_batch):
        """Build the masked prompt, masked target and mask for one batch.

        Every example is cut to the length of the shortest one in the batch,
        so no padding is introduced. Positions before ``self.len_prompt`` are
        the prompt region; every later position is the target region.

        Args:
            ds_batch: list of dicts, each holding an ``"ids_input"`` sequence
                of token ids.

        Returns:
            dict of three ``[B, len_min]`` tensors:
            ``'ids_prompt_masked_full'`` – ids with the target region set to ``self.id_mask``;
            ``'ids_target_masked_full'`` – ids with the prompt region set to ``self.id_mask``;
            ``'masks_masked_full'``      – bool mask, True on the target region.
        """
        len_min = min(len(ds_each["ids_input"]) for ds_each in ds_batch)

        ids_input = torch.stack(
            [torch.tensor(ds_each["ids_input"][:len_min], dtype=torch.long) for ds_each in ds_batch],
            dim=0,
        )  # [B, len_min]

        # Use the prompt length configured on this instance, not a notebook global.
        masks_input = torch.zeros_like(ids_input, dtype=torch.bool)
        masks_input[:, self.len_prompt:] = True

        # masked_fill returns new tensors, so ids_input itself is left intact.
        ids_target = ids_input.masked_fill(~masks_input, self.id_mask)
        ids_prompt = ids_input.masked_fill(masks_input, self.id_mask)

        return {
            'ids_prompt_masked_full': ids_prompt,
            'ids_target_masked_full': ids_target,
            'masks_masked_full': masks_input
        }
    # end _collate
# end

# Collate function used to assemble each batch.
collater = Collater_wiki_simple(
    len_max=len_max,
    len_prompt=len_prompt,
    len_target=len_target,
    id_mask=id_mask,
)

# shuffle=False keeps the dataset's existing order; drop_last=False keeps the
# final batch even when it is smaller than size_batch.
loader = DataLoader(
    dataset=ds,
    batch_size=size_batch,
    shuffle=False,
    collate_fn=collater,
    drop_last=False,
)

# Sanity check: inspect only the first batch produced by the loader.
for batch in loader:
    ids_prompt_masked_full = batch['ids_prompt_masked_full']
    ids_target_masked_full = batch['ids_target_masked_full']
    masks_masked_full = batch['masks_masked_full']

    print(ids_prompt_masked_full.shape)
    # Report the target tensor's own shape (previously the prompt's shape was printed twice).
    print(ids_target_masked_full.shape)
    print(masks_masked_full.shape)

    # Number of True mask entries divided by the batch size.
    print(masks_masked_full.sum() / masks_masked_full.shape[0])
    break
# end

torch.Size([32, 160])
torch.Size([32, 160])
torch.Size([32, 160])
tensor(32.)


In [None]:
# from torch.utils.data import DataLoader

# def eval_ppl(model, tokenizer, batch_size=8, device="cuda"):
#     ds = build_sorted_wikitext(tokenizer)

#     loader = DataLoader(
#         ds,
#         batch_size=batch_size,
#         shuffle=False,                 # keep sorted order
#         collate_fn=collate_truncate_to_min,
#         drop_last=False,
#     )

#     model.eval().to(device)

#     total_nll = 0.0
#     total_tokens = 0

#     for batch in loader:
#         input_ids = batch["input_ids"].to(device)   # [B, T]
#         # compute token-sum NLL for correct global aggregation:
#         out = model(input_ids=input_ids)
#         logits = out.logits

#         shift_logits = logits[:, :-1, :].contiguous()
#         shift_labels = input_ids[:, 1:].contiguous()

#         # sum loss over tokens (token-weighted), then normalize globally
#         loss_sum = F.cross_entropy(
#             shift_logits.view(-1, shift_logits.size(-1)),
#             shift_labels.view(-1),
#             reduction="sum",
#         )

#         total_nll += loss_sum.item()
#         total_tokens += shift_labels.numel()

#     ppl = float(torch.exp(torch.tensor(total_nll / max(total_tokens, 1))))
#     return ppl

In [None]:
# import torch
# import torch.nn.functional as F

# from transformers import AutoTokenizer
# from datasets import load_dataset

# def build_sorted_wikitext(tokenizer, name="wikitext-2-raw-v1", split="test", max_len=None):
#     ds = load_dataset("wikitext", name, split=split)

#     # remove empty lines
#     ds = ds.filter(lambda x: x["text"] is not None and len(x["text"].strip()) > 0)

#     def tok_fn(ex):
#         ids = tokenizer(
#             ex["text"],
#             add_special_tokens=False,   # avoids BOS/EOS being injected by tokenizer
#             truncation=(max_len is not None),
#             max_length=max_len,
#         )["input_ids"]
#         return {"input_ids": ids, "length": len(ids)}

#     ds = ds.map(tok_fn, remove_columns=ds.column_names)
#     ds = ds.filter(lambda x: x["length"] >= 2)  # need at least 2 tokens to score next-token
#     ds = ds.sort("length")                      # now sorted by length ascending
#     return ds

In [None]:
# @torch.no_grad()
# def batch_ppl_causal(model, input_ids):
#     # input_ids: [B, T]
#     out = model(input_ids=input_ids)
#     logits = out.logits  # [B, T, V]

#     # next-token prediction
#     shift_logits = logits[:, :-1, :].contiguous()   # [B, T-1, V]
#     shift_labels = input_ids[:, 1:].contiguous()    # [B, T-1]

#     # token-level NLL, averaged over all tokens in batch
#     loss = F.cross_entropy(
#         shift_logits.view(-1, shift_logits.size(-1)),
#         shift_labels.view(-1),
#         reduction="mean",
#     )
#     ppl = torch.exp(loss)
#     return ppl.item(), loss.item()