In [7]:
import torch
import torch.nn.functional as F
from jinyu_utils import jinyu_dataset

from transformers import AutoTokenizer
from datasets import load_dataset

In [14]:
# Tokenizer for the checkpoint under evaluation.
# trust_remote_code=True lets custom Python code from the model repository run
# during loading -- only use it with model IDs you trust.
id_model = 'GSAI-ML/LLaDA-8B-Base'
tokenizer = AutoTokenizer.from_pretrained(id_model, trust_remote_code=True)


In [None]:
# Evaluation corpus: the selected LIST_DATASET entry is unpacked as the
# positional arguments of load_dataset.
# NOTE(review): index 1 is assumed to be a WikiText config -- confirm in jinyu_dataset.
name_dataset = jinyu_dataset.LIST_DATASET[1]
ds = load_dataset(
    *name_dataset,
    split='test',
)

In [13]:
# Drop rows whose text is missing or consists only of whitespace.
ds = ds.filter(
    lambda row: row["text"] is not None and row["text"].strip() != ""
)

Filter: 100%|██████████| 4358/4358 [00:00<00:00, 254756.47 examples/s]


In [None]:
# Tokenize every line, keep only lines with at least 2 tokens (next-token
# scoring needs one context token and one target), then sort by token count
# in ascending order.
# NOTE(review): this cell is not idempotent -- `remove_columns` drops "text",
# so re-running it on the already-tokenized `ds` raises a KeyError in tok_fn.
# Re-run the loading/filter cells first.
max_len = 4096  # per-line token cap; longer lines are truncated to this length

def tok_fn(ex):
    """Tokenize one row's "text" without special tokens.

    Returns a dict with "input_ids" (the token ids) and "length" (their count),
    which the filter/sort steps below rely on.
    """
    ids = tokenizer(
        ex["text"],
        add_special_tokens=False,   # avoids BOS/EOS being injected by tokenizer
        truncation=(max_len is not None),
        max_length=max_len,
    )["input_ids"]
    return {"input_ids": ids, "length": len(ids)}

ds = ds.map(tok_fn, remove_columns=ds.column_names)
ds = ds.filter(lambda x: x["length"] >= 2)  # need at least 2 tokens to score next-token
ds = ds.sort("length")                      # now sorted by length ascending

In [None]:
import torch

def collate_truncate_to_min(batch):
    """Collate examples into a single [B, min_len] LongTensor by right-truncation.

    Every sequence is cut to the length of the shortest sequence in ``batch``,
    so the rows stack without padding. Tokens past ``min_len`` are dropped and
    therefore never scored.

    Args:
        batch: non-empty list of dicts whose "input_ids" is a sequence of ints
            (a Python list, or a 1-D tensor if the dataset is torch-formatted).

    Returns:
        {"input_ids": LongTensor of shape [B, min_len]}.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_truncate_to_min received an empty batch")

    min_len = min(len(x["input_ids"]) for x in batch)

    # as_tensor accepts both lists and existing tensors without a copy warning.
    rows = [
        torch.as_tensor(x["input_ids"][:min_len], dtype=torch.long)
        for x in batch
    ]
    input_ids = torch.stack(rows, dim=0)  # [B, min_len]

    return {"input_ids": input_ids}

In [None]:


@torch.no_grad()
def batch_ppl_causal(model, input_ids):
    """Next-token perplexity of ``model`` on one batch of token ids.

    The logits at position t are scored against token t+1 for t in [0, T-1),
    and the token-level negative log-likelihood is averaged over every target
    in the batch.

    NOTE(review): this assumes ``model`` is autoregressive (logits at t predict
    token t+1). The LLaDA checkpoint loaded earlier in this notebook is, to my
    understanding, a masked-diffusion LM -- confirm this metric is meaningful.

    Args:
        model: callable accepting ``input_ids=`` and returning an object with a
            ``.logits`` tensor of shape [B, T, V].
        input_ids: integer tensor of shape [B, T] with B >= 1 and T >= 2.

    Returns:
        (ppl, loss) as Python floats, where ppl = exp(loss).

    Raises:
        ValueError: if ``input_ids`` is not 2-D, is empty, or has T < 2
            (no next-token targets; the mean NLL would otherwise be NaN).
    """
    if input_ids.dim() != 2 or input_ids.size(0) < 1 or input_ids.size(1) < 2:
        raise ValueError(
            "expected input_ids of shape [B, T] with B >= 1 and T >= 2, "
            f"got {tuple(input_ids.shape)}"
        )

    logits = model(input_ids=input_ids).logits  # [B, T, V]

    # Upcast before log-softmax: fp16/bf16 logits lose precision in the NLL.
    shift_logits = logits[:, :-1, :].float()  # [B, T-1, V]
    shift_labels = input_ids[:, 1:]           # [B, T-1]

    # token-level NLL, averaged over all tokens in batch
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_labels.reshape(-1),
        reduction="mean",
    )
    ppl = torch.exp(loss)
    return ppl.item(), loss.item()

In [None]:
import torch
import torch.nn.functional as F

from transformers import AutoTokenizer
from datasets import load_dataset

def build_sorted_wikitext(tokenizer, name="wikitext-2-raw-v1", split="test", max_len=None):
    """Load a WikiText split, tokenize it line by line, and sort by token count.

    Args:
        tokenizer: callable that takes text plus tokenizer kwargs and returns a
            mapping with "input_ids" (e.g. a Hugging Face tokenizer). It is
            called with add_special_tokens=False.
        name: WikiText config passed to ``load_dataset("wikitext", name, ...)``.
        split: dataset split to load.
        max_len: if not None, every line is truncated to at most this many tokens.

    Returns:
        A dataset with columns "input_ids" and "length", containing only lines
        with at least 2 tokens, ordered by ascending "length".
    """
    truncate = max_len is not None

    def has_text(row):
        # Drop missing and whitespace-only lines.
        text = row["text"]
        return text is not None and text.strip() != ""

    def encode(row):
        token_ids = tokenizer(
            row["text"],
            add_special_tokens=False,   # keep BOS/EOS out of the scored token stream
            truncation=truncate,
            max_length=max_len,
        )["input_ids"]
        return {"input_ids": token_ids, "length": len(token_ids)}

    def scorable(row):
        # Next-token scoring needs at least one (context, target) pair.
        return row["length"] >= 2

    raw = load_dataset("wikitext", name, split=split)
    nonempty = raw.filter(has_text)
    tokenized = nonempty.map(encode, remove_columns=nonempty.column_names)
    return tokenized.filter(scorable).sort("length")