# Imports

In [1]:
import json, numpy as np, torch, re, io, contextlib, sys, warnings
from pathlib import Path
from tokenizers import ByteLevelBPETokenizer, AddedToken
from peft import PeftModel

# Paths

In [2]:
#RUN = Path("../fine_tuning_small/checkpoints") # if you're testing small one, comment out this one and comment the below one
RUN = Path("../fine_tuning/checkpoints")

BASE = Path("../pretrain_checkpoint/best_base.pt") # --base-ckpt used for SFT
#BEST = RUN / "lora_best" # (or lora_best_ema if you prefer)
BEST = RUN / "lora_best_ema"


CONF = Path("../configs/model_config_124M_sft.json")
VOCAB = Path("../bpe/bpe_model-vocab.json")
MERGES = Path("../bpe/bpe_model-merges.txt")


# for sanity check (we did this but i had a problem with generated output)
SFT_IDS = Path("../final_npy/train_input_ids.npy")
SFT_MSK = Path("../final_npy/train_loss_mask.npy")

# ENV configuration stuffs

In [3]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    AMP_DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    AMP_DTYPE = torch.float32


torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x1e6567597f0>

# Tokenizer

In [4]:
bpe_decode = ByteLevelBPETokenizer(str(VOCAB), str(MERGES), lowercase=False, add_prefix_space=True)
bpe_prompt = ByteLevelBPETokenizer(str(VOCAB), str(MERGES), lowercase=False, add_prefix_space=True)

# read IDs from the tokenizer; do not add/append anything
SPECIALS = ["<|PAD|>", "<|UNKNOWN|>", "<|START|>", "<|END|>",
            "<|SYSTEM|>", "<|INFOSTART|>", "<|INFOEND|>", "<|USER|>", "<|ASSISTANT|>"]
SID = {t: bpe_decode.token_to_id(t) for t in SPECIALS}

# hard assertions to catch if there's anything wrong
assert SID["<|PAD|>"] == 0, f"PAD_ID mismatch: {SID['<|PAD|>']}"
assert SID["<|END|>"] == 3, f"END_ID mismatch: {SID['<|END|>']}"
assert SID["<|ASSISTANT|>"] == 6, "ASSISTANT_ID mismatch"

PAD_ID = SID["<|PAD|>"]
END_ID = SID["<|END|>"]
ASSISTANT_ID = SID["<|ASSISTANT|>"]
for _tag in ("<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>"):
    assert SID[_tag] is not None, f"Missing token in vocab: {_tag}"

# Model config & load

In [5]:
sys.path.append("../python_files")
from Dummy_Model_sft import DummyModel

# Accepts raw dict or {'model':...}/{ 'state_dict':... } and strips common prefixes
def load_state_dict_safely(model: torch.nn.Module, raw_state, strict: bool=False, verbose: bool=True):
    state = raw_state.get("model", raw_state.get("state_dict", raw_state))
    clean = {}
    for k, v in state.items():
        if k.startswith("module."): k = k[7:]
        if k.startswith("model."): k = k[6:]
        clean[k] = v
    missing, unexpected = model.load_state_dict(clean, strict=strict)
    if verbose:
        if missing: print(f"[load] missing ({len(missing)}): {missing[:5]} ...")
        if unexpected: print(f"[load] unexpected ({len(unexpected)}): {unexpected[:5]} ...")
        if not missing and not unexpected:
            print("[load] state dict loaded cleanly.")
    return model

cfg = json.loads(Path(CONF).read_text())


def _safe_torch_load(path: Path):
    try:
        return torch.load(path, map_location=DEVICE, weights_only=True)
    except TypeError:
        warnings.filterwarnings(
            "ignore",
            message="You are using `torch.load` with `weights_only=False`",
            category=FutureWarning,
        )
        return torch.load(path, map_location=DEVICE)


def _load_base_only() -> torch.nn.Module:
    base = DummyModel(cfg).to(DEVICE).eval()
    raw  = _safe_torch_load(BASE)
    load_state_dict_safely(base, raw, strict=False, verbose=True)
    return base


def load_with_adapter(adapter_dir: Path, merge: bool=False) -> torch.nn.Module:
    base = _load_base_only()
    model = PeftModel.from_pretrained(base, str(adapter_dir), is_trainable=False).to(DEVICE).eval()
    lp = [n for n,_ in model.named_parameters() if "lora_" in n]
    print(f"[PEFT] attached → LoRA tensors: {len(lp)}")
    if merge:
        model.merge_and_unload()
    return model


# Template

In [6]:
# should be exactly the same we used when training for SFT
SYSTEM_PROMPT_DEFAULT = (
    "Be a helpful, concise assistant with a light, friendly tone. "
    "Answer directly in 1–3 sentences. Don’t use steps or bullet lists unless the user asks. "
    "Use the content between INFOSTART/INFOEND only as context. Do not mention it, ‘memory,’ or any internal tags. "
    "Avoid speculation and say when unsure. Keep replies safe, accurate, and on-topic."
)

# autodetect whether training had leading space/newline after <|ASSISTANT|>
def _detect_assistant_prefix(ids_path: Path, sample_k: int = 256) -> str:
    if not ids_path.exists():
        return ""
    arr = np.load(ids_path, mmap_mode="r")
    spaces = newlines = letters = 0
    for i in range(min(sample_k, arr.shape[0])):
        row = arr[i].tolist()
        txt = bpe_decode.decode([t for t in row if t != PAD_ID])
        j = txt.find("<|ASSISTANT|>")
        if j == -1:
            continue
        j += len("<|ASSISTANT|>")
        ch = txt[j:j+1]
        if ch == " ": spaces += 1
        elif ch == "\n": newlines += 1
        elif ch: letters += 1
    if spaces >= max(newlines, letters): return " "
    if newlines > 0: return "\n"
    return ""

ASSIST_PREFIX = _detect_assistant_prefix(SFT_IDS)


def build_prompt_ids(user_text: str, info_text: str, system_text: str = SYSTEM_PROMPT_DEFAULT):
    """
    Builds:
      <|START|><|SYSTEM|>{SYSTEM}\n
      <|INFOSTART|>{INFO}<|INFOEND|>\n
      <|USER|>{USER}\n
      <|ASSISTANT|>{ASSIST_PREFIX}
    """
    txt = (
        "<|START|><|SYSTEM|>" + system_text + "\n"
        "<|INFOSTART|>" + (info_text if info_text else "") + "<|INFOEND|>\n"
        "<|USER|>" + user_text.strip() + "\n"
        "<|ASSISTANT|>" + ASSIST_PREFIX
    )
    return bpe_prompt.encode(txt, add_special_tokens=False).ids


# Prepare Sampling #1

In [7]:
# EOS stop + small sentence cap + strong tag-leak/style guards
SENTENCE_LIMIT = 3
MIN_BEFORE_EOS = 1
EOS_BOOST = 2.0

_STEP_PAT = re.compile(r"(?i)\bstep\s*\d+")
_SENT_PAT = re.compile(r"(?i)(identify the (given )?sentence|the sentence is)")

_SENT_END_RE = re.compile(r'[.!?](?=\s|$)')

def _count_sents(s: str) -> int:
    return len(_SENT_END_RE.findall(s))


def _trim_sents(s: str, limit: int) -> str:
    if limit <= 0: return ""
    if _count_sents(s) <= limit: return s.strip()
    idx = 0; kept = 0
    for m in _SENT_END_RE.finditer(s):
        kept += 1; idx = m.end()
        if kept >= limit: break
    return s[:idx].strip()

VOCAB_SIZE = int(cfg.get("vocab_size", bpe_decode.get_vocab_size()))
TOKEN_STR = {i: bpe_decode.decode([i]) for i in range(VOCAB_SIZE)}

# guard tag words tolerant to separators/mixed case
_TAG_PAT = re.compile(
    r"(?i)(i[^a-z0-9]?n[^a-z0-9]?f[^a-z0-9]?o[^a-z0-9]?s[^a-z0-9]?t[^a-z0-9]?a[^a-z0-9]?r[^a-z0-9]?t"
    r"|i[^a-z0-9]?n[^a-z0-9]?f[^a-z0-9]?o[^a-z0-9]?e[^a-z0-9]?n[^a-z0-9]?d"
    r"|s[^a-z0-9]?y[^a-z0-9]?s[^a-z0-9]?t[^a-z0-9]?e[^a-z0-9]?m"
    r"|u[^a-z0-9]?s[^a-z0-9]?e[^a-z0-9]?r"
    r"|a[^a-z0-9]?s[^a-z0-9]?s[^a-z0-9]?i[^a-z0-9]?s[^a-z0-9]?t[^a-z0-9]?a[^a-z0-9]?n[^a-z0-9]?t"
    r"|s[^a-z0-9]?t[^a-z0-9]?a[^a-z0-9]?r[^a-z0-9]?t"
    r"|e[^a-z0-9]?n[^a-z0-9]?d)"
)

_TAG_WORDS = {"info", "start", "end", "system", "user", "assistant", "infost", "fost"}


def _would_form_taglike(prev_tail: str, cand_str) -> bool:
    cand = cand_str if isinstance(cand_str, str) else str(cand_str)
    s = (prev_tail + cand)[-96:]
    if "<" in s or ">" in s or "|" in s:
        return True
    return bool(_TAG_PAT.search(s))


def _has_emoji_or_symbol(s: str) -> bool:
    return any(
        (0x1F300 <= ord(ch) <= 0x1FAFF) or
        (0x2600  <= ord(ch) <= 0x26FF) or
        (0x2700  <= ord(ch) <= 0x27BF)
        for ch in s
    )


# ban tokens likely to spark tag leakage or junk style
def _build_ban_ids() -> set[int]:
    bad = set()
    for i in range(VOCAB_SIZE):
        s  = TOKEN_STR[i]
        sl = s.lower()
        if ("<" in s) or (">" in s) or ("|" in s):
            bad.add(i); continue
        if _has_emoji_or_symbol(s):
            bad.add(i); continue
        ns = s.replace(" ", "")
        if ("#" in ns) or ("http" in sl) or ("www" in sl) or (".com" in sl) or (".net" in sl) or (".org" in sl):
            bad.add(i); continue
        
        # also ban tokens that *by themselves* look like tag words or their stubs
        if _TAG_PAT.fullmatch(sl):
            bad.add(i); continue
    bad.discard(END_ID) # never ban EOS
    bad.discard(PAD_ID) # never ban PAD
    return bad


def _violates_ngram(seq, cand, n):
    if n <= 1 or len(seq) < n-1: return False
    s = seq + [cand]; last = tuple(s[-n:])
    for i in range(len(s)-n):
        if tuple(s[i:i+n]) == last:
            return True
    return False

# negation-aware steering
NEG_WINDOW_TOKENS = 6
NEG_CONT_PENALTY  = 0.8
POS_PIVOT_BONUS   = 0.6

_NEG_HIT_RE = re.compile(r"\b(?:not|no|never|cannot|can['’]t|won['’]t|don['’]t|doesn['’]t|didn['’]t|isn['’]t|aren['’]t|wasn['’]t|weren['’]t)\b", re.I)

_NEG_CONT_PHRASES = [
    " I can't", " I cannot", " can't", " cannot", " unable to", " I won't", " I do not",
    " I don't", " I am not able", " unfortunately", " sorry", " I cannot help",
]
_POS_PIVOTS = [
    " but", " however", " still", " instead", " here's", " here is",
    " I can", " we can", " let's", " try", " you can",
]

def _ids_for_phrases(phrases):
    out = []
    for p in phrases:
        ids = bpe_decode.encode(p, add_special_tokens=False).ids
        if ids:
            out.append(ids)
    return out

NEG_CONT_IDS = _ids_for_phrases(_NEG_CONT_PHRASES)
POS_PIVOT_IDS = _ids_for_phrases(_POS_PIVOTS)
MAX_NEG_CONT = max((len(s) for s in NEG_CONT_IDS), default=0)
MAX_POS_PVT  = max((len(s) for s in POS_PIVOT_IDS), default=0)

def _would_complete_any(prev_ids: list[int], cand_id: int, seqs: list[list[int]], max_len: int) -> bool:
    if not seqs:
        return False
    ctx = prev_ids[-(max_len-1):] if max_len > 1 else []
    test = ctx + [cand_id]
    for seq in seqs:
        L = len(seq)
        if len(test) >= L and test[-L:] == seq:
            return True
    return False

# Prepare Sampling #2

In [8]:
# helper used mid-loop to block unsafe tails before they fully appear
def _unsafe_tail_would_appear(seq, candidate_id):
    test = bpe_decode.decode(seq + [candidate_id])
    tail = test[-128:].lower()

    if ("<" in tail) or (">" in tail) or ("|" in tail):
        return True

    # also block tag-like substrings without brackets (e.g., "fostart", "infostart")
    if _TAG_PAT.search(tail):
        return True
    if any(w in tail for w in _TAG_WORDS | {"fostart", "infostart", "infoend"}):
        return True

    return False

# generating token ids
@torch.no_grad()
def generate_ids(model, prompt_ids, user_words: set[str] | None = None,
                 max_new_tokens=120, temperature=0.0, top_p=1.0, top_k=0,
                 repetition_penalty=1.07, no_repeat_ngram_size=3, seed=None,
                 ban_token_ids: set[int] | None = None,
                 eos_token_id=None, pad_token_id=None, generator=None,
                 first_token_boost: dict[int, float] | None = None):

    dev = next(model.parameters()).device
    STOP = int(eos_token_id if eos_token_id is not None else END_ID)
    PAD  = int(pad_token_id if pad_token_id is not None else PAD_ID)
    g = generator if generator is not None else (torch.Generator(device=dev).manual_seed(int(seed)) if seed is not None else None)

    x = torch.tensor([prompt_ids], dtype=torch.long, device=dev)
    gen = []
    neg_window = 0
    min_before_eos = MIN_BEFORE_EOS
    eos_boost      = EOS_BOOST
    user_word_tids = None
    if user_words:
        user_word_tids = {
            tid for tid, s in TOKEN_STR.items()
            if s and (ss := s.strip().lower()) in user_words and 3 <= len(ss) <= 12
        }

    # discourage scaffolds unless the prompt explicitly asks
    want_steps = False
    if user_words:
        hints = {"steps", "step", "bullet", "list", "instructions", "how-to", "procedure", "outline"}
        want_steps = any(h in user_words for h in hints)

    step_tids = None
    if not want_steps:
        # catches "Step 1:" and simple bullets
        candidates = {"step", "Step", ":", "1", "2", "3", "-", "•", "Identify", "sentence"}
        step_tids = {tid for tid, s in TOKEN_STR.items() if s and s.strip() in candidates}

    # nudge away from "the sentence is" rut
    tsi_tids = {tid for tid, s in TOKEN_STR.items() if s and s.strip().lower() in {"the", "sentence", "is"}}

    for _ in range(max_new_tokens):
        with torch.autocast(device_type="cuda", dtype=AMP_DTYPE, enabled=(dev.type=="cuda")):
            logits, _ = model(input_ids=x)
        next_logits = logits[:, -1, :].float()
        if first_token_boost and len(gen) == 0:
            for tid, bonus in first_token_boost.items():
                next_logits[0, tid] += float(bonus)

        # light scaffold nudges
        if step_tids:
            next_logits[0, list(step_tids)] -= 0.40 # stronger push from "Step 1:" unless asked
        if tsi_tids:
            next_logits[0, list(tsi_tids)] -= 0.10 # nudge away from "the sentence is"

        # hard bans (keep EOS allowed)
        next_logits[0, PAD] = float("-inf")
        if ban_token_ids:
            idx = [t for t in ban_token_ids if t != STOP]
            if idx:
                next_logits[0, idx] = float("-inf")

        # EOS shaping
        if len(gen) < min_before_eos:
            next_logits[0, STOP] = float("-inf")
        else:
            next_logits[0, STOP] += eos_boost

        if len(gen) < 3:
            scaffold_words = {"Q", "Q:", "Question", "Question:", "A", "A:", "Answer", "Answer:"}
            scaffold_tids = {tid for tid,s in TOKEN_STR.items() if s and s.strip() in scaffold_words}
            if scaffold_tids:
                next_logits[0, list(scaffold_tids)] -= 1.5 # stronger push away at start

        # repetition penalty
        if gen and repetition_penalty and repetition_penalty != 1.0:
            recent = gen[-64:]
            for tid in set(recent):
                val = next_logits[0, tid]
                next_logits[0, tid] = torch.where(val > 0, val / repetition_penalty, val * repetition_penalty)

        # temperature / top-k / top-p
        do_sample = (temperature and temperature > 0) or (top_p and 0 < top_p < 1.0) or (top_k and top_k > 0)

        if temperature and temperature > 0:
            next_logits /= float(temperature)

        # top-k
        if top_k is not None and top_k > 0:
            v, _ = torch.topk(next_logits, k=min(top_k, next_logits.size(-1)))
            next_logits[next_logits < v[:, [-1]]] = float("-inf")

        # top-p
        if top_p is not None and 0 < top_p < 1.0:
            probs = torch.softmax(next_logits, dim=-1)
            sp, si = torch.sort(probs, descending=True)
            csum = torch.cumsum(sp, dim=-1)
            keep = csum <= top_p
            keep[..., 0] = True
            filt = torch.full_like(sp, float("-inf"))

            # store filtered logits (log-probs) back into next_logits in original index order
            filt[keep] = torch.log(sp[keep] + 1e-12)
            next_logits = torch.full_like(next_logits, float("-inf"))
            next_logits[0, si[0]] = filt[0]

        # entropy-based nudge toward EOS when degenerate
        with torch.no_grad():
            p = torch.softmax(next_logits, dim=-1)
            ent = float(-(p * torch.log(p + 1e-12)).sum())
        if len(gen) >= 6 and ent < 3.4:
            next_logits[0, STOP] += 1.5

        # guard n-gram pass over top candidates
        running_tail = bpe_decode.decode(gen[-64:]) if gen else ""
        if _NEG_HIT_RE.search((running_tail or "").lower()):
            neg_window = max(neg_window, NEG_WINDOW_TOKENS)
        probs = torch.softmax(next_logits, dim=-1)
        _, cand_idx = torch.sort(probs, descending=True)
        tmp_logits = next_logits.clone()
        max_scan = min(256, cand_idx.size(-1))
        for j in range(max_scan):
            cid = int(cand_idx[0, j])
            if cid == STOP:
                continue # let EOS through
            cand_str = TOKEN_STR.get(cid) or bpe_decode.decode([cid])
            if _would_form_taglike(running_tail, cand_str) or (
                no_repeat_ngram_size and _violates_ngram(gen, cid, no_repeat_ngram_size)
            ):
                tmp_logits[0, cid] = float("-inf")

            # steer away from continuing refusals; nudge toward constructive pivots
            if neg_window > 0:
                if _would_complete_any(gen, cid, NEG_CONT_IDS, MAX_NEG_CONT):
                    tmp_logits[0, cid] -= NEG_CONT_PENALTY
                if _would_complete_any(gen, cid, POS_PIVOT_IDS, MAX_POS_PVT):
                    tmp_logits[0, cid] += POS_PIVOT_BONUS


        # choose next token
        if torch.isneginf(tmp_logits).all():
            # soft fallback if we banned everything
            fallback = next_logits.clone()
            fallback[0, PAD] = float("-inf")
            if len(gen) < min_before_eos:
                fallback[0, STOP] = float("-inf")
            next_id = int(torch.argmax(fallback, dim=-1))
        else:
            if do_sample:
                # proper sampling from filtered distribution
                samp_p = torch.softmax(tmp_logits, dim=-1)
                next_id = int(torch.multinomial(samp_p[0], num_samples=1, generator=g))
            else:
                # fall back to greedy
                next_id = int(torch.argmax(tmp_logits, dim=-1))

        # if choosing the argmax still forms unsafe tail, pick the next safe
        if _unsafe_tail_would_appear(gen, next_id):
            ok = False
            for j2 in range(min(32, cand_idx.size(-1))):
                alt = int(cand_idx[0, j2])
                if alt == next_id or alt == PAD:
                    continue
                if len(gen) < MIN_BEFORE_EOS and alt == STOP:
                    continue
                if not _unsafe_tail_would_appear(gen, alt):
                    next_id = alt
                    ok = True
                    break
            if not ok:
                fb = next_logits.clone()
                fb[0, PAD] = float("-inf")
                if len(gen) < min_before_eos:
                    fb[0, STOP] = float("-inf")
                next_id = int(torch.argmax(fb, dim=-1))

        gen.append(next_id)
        x = torch.cat([x, torch.tensor([[next_id]], device=dev)], dim=1)

        # secondary stops
        text_so_far = bpe_decode.decode(gen)
        if next_id == STOP: break
        if "<|END|>" in text_so_far: break
        if _count_sents(text_so_far) >= SENTENCE_LIMIT: break
        if _STEP_PAT.search(text_so_far) or _SENT_PAT.search(text_so_far):
            min_before_eos = 0
            eos_boost = 3.0

        if len(gen) >= 10 and len(set(gen[-10:])) <= 3: break
        if neg_window > 0:
            neg_window -= 1


    return gen

# Strip for clean text

In [11]:
def trim_to_last_period(s: str) -> str:
    s = s.rstrip()
    if s.endswith('.'):
        return s
    i = s.rfind('.')
    return "" if i == -1 else s[:i+1]

# API

In [12]:
import time

def _fallback_answer(info_text: str, user_text: str) -> str:
    # try to honor name if provided in INFO block
    m = re.search(r"(?i)\bmy name is\s+([\w .'-]{2,64})", info_text or "")
    name = m.group(1).strip() if m else None
    if re.search(r"(?i)\bwho am i\b|\bwhat'?s my name\b|\bwho am i to you\b", user_text or ""):
        if name:
            return f"You're {name}."
        return "You're the user chatting with me."
    # generic safe fallback
    return "I'm a small chat model. How can I help in a sentence or two?"

def _first_person_bias() -> dict[int, float]:
    # include leading-space variants (because we're using ByteLevel BPE)
    seeds = [" I", " I'm", " I’m", " I can", "I"]
    bias = {}
    for s in seeds:
        tid = bpe_decode.token_to_id(s)
        if tid is not None and tid >= 0:
            bias[tid] = 0.8 # small, not domineering
    return bias

def _score_answer(user_text: str, out: str) -> float:
    score = 0.0
    sents = _count_sents(out)
    if 1 <= sents <= 3: score += 2.0
    elif sents == 0:    score -= 2.0
    else:               score -= 1.0

    ut = user_text.lower(); ol = out.lower()
    # encourage first-person & capability verbs for openers like "Who are you / what can you do?"
    if re.search(r"(?i)\bwho\s+are\s+you\b|\bwhat\s+can\s+you\s+do\b|\bhow\s+can\s+you\s+help\b", user_text):
        if " i " in f" {ol} " or "i'm" in ol or "i’m" in ol: score += 1.2
        if " can " in f" {ol} ": score += 0.8

    # discourage common boilerplate
    for bp in ("i don't have access to the latest news", "news or news", "business business"):
        if bp in ol: score -= 1.5

    # numbered list unless explicitly requested
    if not re.search(r"(?i)\b(list|steps|bullets|outline|numbered)\b", user_text):
        if re.search(r"(?m)^\s*\d+\.\s+", out): score -= 1.0

    # very long tail
    if len(out) > 400: score -= 0.5
    return score

# generate(info_text, user_text) -> str
# args:
#     info_text: anything you'd like to pass inside <|INFOSTART|>...<|INFOEND|>
#     user_text: the user's prompt
#     greedy: kept for API compatibility
#     clean_style: ban emoji/links/hashtags tokens during decode
# returns:
#     string, trimmed to a few sentences; never includes tags

def generate(info_text: str, user_text: str, greedy: bool=False, clean_style: bool=True) -> str:
    model = load_with_adapter(BEST, merge=False)
    prompt_ids = build_prompt_ids(user_text=user_text, info_text=info_text, system_text=SYSTEM_PROMPT_DEFAULT)

    user_words = {w for w in re.findall(r"\w+", user_text.lower()) if len(w) >= 3}

    # decode settings
    decode = dict(
        max_new_tokens=40, # maximum tokens
        temperature=0.60, # crisp wording level (0-1)
        top_p=1.0, # disabled nucleus (we're completely using top_k since it's a small model)
        top_k=2, # tiny randomness (so, like choose 1 between 2 top probabilities)
        repetition_penalty=1.16, # penalty for repetition
        no_repeat_ngram_size=4, # reduce short phrase echoes
        seed=None, # randomness (but g = generator if generator is not None else (torch.Generator(device=dev).manual_seed(int(seed)) if seed is not None else None) handles that)
        eos_token_id=END_ID, # END token
        pad_token_id=PAD_ID, # PAD token
    )



    # tiny budget for very short prompts
    if len(user_text.strip()) <= 20:
        decode["max_new_tokens"] = min(decode["max_new_tokens"], 80)

    ban_ids = _build_ban_ids() if clean_style else set()
    first_boost = _first_person_bias()

    cands = []
    for _ in range(3):
        g = torch.Generator(device=DEVICE).manual_seed(int(time.time_ns()) & 0xFFFFFFFF)
        with contextlib.redirect_stdout(io.StringIO()):
            gen_ids = generate_ids(
                model, prompt_ids, user_words,
                ban_token_ids=ban_ids, generator=g, first_token_boost=first_boost, **decode
            )
        cands.append(bpe_decode.decode(gen_ids).split("<|END|>")[0])

    best_raw = max(cands, key=lambda t: _score_answer(user_text, t))
    out = best_raw  # already a string; do not re-decode last gen_ids

    # # detokenize + cleanup (clip at END, strip leaked tags)
    out = re.sub(r"\s+", " ", out)
    # strip any tag tokens that leaked as text
    out = re.sub(r"<\|[^>]+?\|>", "", out).strip().strip('"')
    out = re.sub(r"\b(?:START|END|SYSTEM|USER|ASSISTANT|INFOSTART|INFOEND)\b[:\"]?", "", out, flags=re.I)
    # remove obvious scaffold headers if any slipped through
    out = re.sub(r"(?i)\bstep\s*\d+\s*:\s*", "", out)
    out = re.sub(r"(?i)(identify the (given )?sentence|the sentence is)\.?", "", out)
    # dedupe repeated words
    out = re.sub(r"\b(\w+)(?:\s+\1){2,}\b", r"\1", out)
    out = re.sub(r"\s{2,}", " ", out).strip()
    # strip numbered lists unless explicitly requested
    if not re.search(r"(?i)\b(list|steps|bullets|outline|numbered)\b", user_text):
        out = re.sub(r"(?m)^\s*\d+\.\s+", "", out)


    # final trim & fallback if too weird/empty
    out = _trim_sents(out, SENTENCE_LIMIT)
    if (not out) or len(out) < 3 or _STEP_PAT.search(out) or _SENT_PAT.search(out) or any(t in out for t in ("FOST", "Pplanator")):
        out = _fallback_answer(info_text, user_text)

    out = trim_to_last_period(out)
    return out

# Sanity check & generation

In [14]:
if __name__ == "__main__":
    if SFT_IDS.exists() and SFT_MSK.exists():
        ids = np.load(SFT_IDS, mmap_mode="r"); msk = np.load(SFT_MSK, mmap_mode="r")
        ones = float(msk.sum())/msk.size if msk.size else 0.0
        print(f"[DATA] ids: {ids.shape} | mask ones%: {ones:.3%}")
        # Peek what actually follows <|ASSISTANT|> in row 0:
        txt0 = bpe_decode.decode([t for t in ids[0].tolist() if t != PAD_ID])
        j0 = txt0.find("<|ASSISTANT|>"); post = txt0[j0+len('<|ASSISTANT|>'): j0+len('<|ASSISTANT|>')+12] if j0!=-1 else ""
        print(f"[ASSIST_PREFIX detected: {repr(ASSIST_PREFIX)}] peek row0 after tag: {repr(post)}")
        if "<|END|>" in post:
            print("[warn] Row0 assistant begins near EOS; many rows might be ultra-short. Decoding will lean on fallbacks.")

    for i in range(2):
        print(f"\n\nGenerated #{i+1}:")
        print(generate(info_text="You are helpful artificial intelligence", user_text="Who are you and what can you do?"))


[DATA] ids: (1161960, 512) | mask ones%: 29.003%
[ASSIST_PREFIX detected: ' '] peek row0 after tag: '  Sure, here'


Generated #1:
[load] state dict loaded cleanly.
[PEFT] attached → LoRA tensors: 96
I am a professional player who has been a prominent player in the field of tennis. He is known for his skills and skills, and he has also played an important role in winning tennis.


Generated #2:
[load] state dict loaded cleanly.
[PEFT] attached → LoRA tensors: 96
I am a person who is very interested in learning about the world. You can learn from a book by the author "The World of Life" by George G. D.
