In [None]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import json, numpy as np, torch, re, io, contextlib, sys, warnings
from pathlib import Path
from tokenizers import ByteLevelBPETokenizer
from peft import PeftModel
from numbers import Real
import tempfile, shutil
import time

# When a CUDA device is present, allow TF32 for matmuls and cuDNN ops.
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# --- Model artefacts -------------------------------------------------------
# Base checkpoint read by _load_base_only(); adapter files are handed to
# load_with_adapter_from_files() by get_model().
BASE = Path("./for_llm_app/best_base.pt")
ADAPTER_CONFIG = Path("./for_llm_app/adapter_config.json")
ADAPTER_WEIGHTS = Path("./for_llm_app/adapter_model.safetensors")

# Model config JSON (parsed into `cfg`) and the BPE vocab/merges files used to
# build `bpe_tokenizer`.
CONF = Path("./for_llm_app/model_config_124M_sft.json")
VOCAB = Path("./for_llm_app/bpe_model-vocab.json")
MERGES = Path("./for_llm_app/bpe_model-merges.txt")

# --- Sampling settings forwarded by generate() to generate_ids() -----------
TEMPERATURE = 0.25
TOP_P = 0.95
TOP_K = 0  # generate_ids() only applies top-k filtering when this is > 0
NUM_CANDIDATES = 3  # number of candidates generate() samples and then ranks
TYPICAL_TAU = 0.9  # `tau` passed to typical_filter() inside generate_ids()

REPETITION_PENALTY = 1.1
NO_REPEAT_NGRAM_SIZE = 3

# --- Refusal steering used in generate_ids() -------------------------------
# Penalty for tokens completing a _NEG_CONT_PHRASES entry, number of steps the
# "negation seen" window stays open, and bonus for completing a _POS_PIVOTS
# entry.
NEG_CONT_PENALTY = 1.6
NEG_WINDOW_TOKENS = 12
POS_PIVOT_BONUS = 0.9

# --- Length control --------------------------------------------------------
SENTENCE_LIMIT = 3  # generation stops / output is trimmed at this many sentences
MIN_BEFORE_EOS = 14  # the stop token is masked while fewer tokens were generated
EOS_BOOST = 0.2  # added to the stop-token logit once stopping is allowed

ASSIST_PREFIX = " "  # text appended after the <|ASSISTANT|> tag in the prompt
DISABLE_SHAPING = False  # NOTE(review): not referenced anywhere in the visible code

# Inference device and autocast dtype: bf16 when supported, else fp16 on CUDA;
# fp32 on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    AMP_DTYPE = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    AMP_DTYPE = torch.float32

# Inference only: turn autograd off globally.
torch.set_grad_enabled(False)

# Byte-level BPE tokenizer built from the vocab/merges files above.
bpe_tokenizer = ByteLevelBPETokenizer(str(VOCAB), str(MERGES), lowercase=False, add_prefix_space=True)

# Control tags used by the prompt format. The order matters: it is reused
# further down to unpack the *_ID constants.
SPECIALS = [
    "<|PAD|>", "<|UNKNOWN|>", "<|START|>", "<|END|>",
    "<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>", "<|EOT|>",
    "<|INFOSTART|>", "<|INFOEND|>",
]
# Tag -> token id as reported by the tokenizer (None when the tag is absent).
SID = {tok: bpe_tokenizer.token_to_id(tok) for tok in SPECIALS}

# Fail fast if the vocabulary lacks any control tag.
_missing = [t for t, i in SID.items() if i is None]
assert not _missing, f"Missing special tokens in vocab: {_missing}"

# Ids the control tags are required to have; checked against SID below.
_EXPECTED = {
    "<|PAD|>": 0,
    "<|UNKNOWN|>": 1,
    "<|START|>": 2,
    "<|END|>": 3,
    "<|SYSTEM|>": 4,
    "<|USER|>": 5,
    "<|ASSISTANT|>": 6,
    "<|EOT|>": 7,
    "<|INFOSTART|>": 8,
    "<|INFOEND|>": 9
}

# Single-slot cache: the loaded model plus the `merge` flag it was built with.
_MODEL_CACHE = {"model": None, "merge": None}


def get_model(merge: bool = False):
    """Return the adapter-equipped model, loading it on first use.

    The model is cached at module level. The cache remembers which ``merge``
    flag it was built with; asking for the other variant reloads the model and
    replaces the cached one instead of silently returning the wrong variant.

    Args:
        merge: Forwarded to ``load_with_adapter_from_files``.

    Returns:
        The cached (or freshly loaded) model.
    """
    want = bool(merge)
    cached = _MODEL_CACHE.get("model")
    built_with = _MODEL_CACHE.get("merge")
    # `built_with is None` means the slot was filled without recording a flag
    # (e.g. assigned by hand); trust such an entry rather than reloading.
    stale = built_with is not None and built_with != want
    if cached is None or stale:
        cached = load_with_adapter_from_files(ADAPTER_CONFIG, ADAPTER_WEIGHTS, merge=want)
        _MODEL_CACHE["model"] = cached
        _MODEL_CACHE["merge"] = want
    return cached

# Abort if any control tag maps to an id other than the one in _EXPECTED.
_mismatch = {t: (SID[t], want) for t, want in _EXPECTED.items() if SID[t] != want}
assert not _mismatch, f"Special token ID mismatch: {_mismatch}"

# Named id constants, unpacked in SPECIALS order.
(
    PAD_ID, UNKNOWN_ID, START_ID, END_ID,
    SYSTEM_ID, USER_ID, ASSISTANT_ID, EOT_ID,
    INFOSTART_ID, INFOEND_ID
) = (SID[t] for t in SPECIALS)

_CAP = re.compile(r"\b[A-Z][a-z]+(?:[-'][A-Z][a-z]+)*\b")

def _cap_words(s: str) -> set[str]:
    common = {"I","You","We","It","The","A","An"}
    months = {"January","February","March","April","May","June","July","August","September","October","November","December"}
    words = {w for w in _CAP.findall(s)}
    return {w for w in words if w not in common | months}

# Every control tag must have its own id.
assert len({SID[t] for t in SPECIALS}) == len(SPECIALS), "Special token ID collision detected"
# Make the app directory importable so the model class can be loaded from it.
import sys
sys.path.append("./for_llm_app")
from Dummy_Model_sft import DummyModel

def load_state_dict_safely(model: torch.nn.Module, raw_state, strict: bool=False, verbose: bool=True):
    """Load a checkpoint into *model*, tolerating common wrapper layouts.

    The weights are taken from ``raw_state["model"]`` if present, else from
    ``raw_state["state_dict"]``, else *raw_state* itself is used. A leading
    ``"module."`` and then a leading ``"model."`` are removed from every key
    before loading.

    Args:
        model: Module receiving the weights.
        raw_state: Mapping as described above.
        strict: Passed through to ``model.load_state_dict``.
        verbose: When true, print up to five missing / unexpected keys.

    Returns:
        The same *model* instance.
    """
    if "model" in raw_state:
        weights = raw_state["model"]
    elif "state_dict" in raw_state:
        weights = raw_state["state_dict"]
    else:
        weights = raw_state

    renamed = {
        name.removeprefix("module.").removeprefix("model."): tensor
        for name, tensor in weights.items()
    }

    missing, unexpected = model.load_state_dict(renamed, strict=strict)
    if verbose and missing:
        print(f"[load] missing ({len(missing)}): {missing[:5]} ...")
    if verbose and unexpected:
        print(f"[load] unexpected ({len(unexpected)}): {unexpected[:5]} ...")
    return model

# Model configuration parsed from CONF; passed to DummyModel and consulted for
# "vocab_size" further down.
cfg = json.loads(Path(CONF).read_text())


def _safe_torch_load(path: Path):
    """Load a checkpoint onto DEVICE, preferring ``weights_only=True``.

    If ``torch.load`` rejects the ``weights_only`` keyword with a TypeError,
    the FutureWarning about ``weights_only=False`` is silenced and the file is
    loaded again without that keyword.

    NOTE(review): the fallback performs a full (pickle-based) load, so it
    should only ever see trusted checkpoint files.
    """
    common = {"map_location": DEVICE}
    try:
        loaded = torch.load(path, weights_only=True, **common)
    except TypeError:
        warnings.filterwarnings(
            "ignore",
            message="You are using `torch.load` with `weights_only=False`",
            category=FutureWarning,
        )
        loaded = torch.load(path, **common)
    return loaded


def _load_base_only() -> torch.nn.Module:
    """Build a DummyModel from ``cfg`` on DEVICE in eval mode and load BASE into it.

    Weights are loaded non-strictly; missing/unexpected keys are printed.
    """
    model = DummyModel(cfg)
    model = model.to(DEVICE)
    model = model.eval()
    checkpoint = _safe_torch_load(BASE)
    load_state_dict_safely(model, checkpoint, strict=False, verbose=True)
    return model


def load_with_adapter_from_files(config_json: Path, weights_file: Path, merge: bool=False) -> torch.nn.Module:
    """Load the base model and attach the adapter stored in two loose files.

    The two files are copied into a temporary directory under the names
    ``adapter_config.json`` and ``adapter_model.safetensors`` so that
    ``PeftModel.from_pretrained`` can be pointed at that directory.

    Args:
        config_json: Path to the adapter configuration JSON.
        weights_file: Path to the adapter weights (safetensors).
        merge: When true, fold the adapter weights into the base model and
            return the merged model instead of the PEFT wrapper.

    Returns:
        The model on DEVICE in eval mode.
    """
    base = _load_base_only()

    with tempfile.TemporaryDirectory() as tmpdir:
        staging = Path(tmpdir)
        shutil.copyfile(config_json, staging / "adapter_config.json")
        shutil.copyfile(weights_file, staging / "adapter_model.safetensors")

        model = PeftModel.from_pretrained(base, str(staging), is_trainable=False).to(DEVICE).eval()
        if merge:
            # merge_and_unload() hands back the merged model; the result must
            # be kept, otherwise the caller still receives the PEFT wrapper.
            model = model.merge_and_unload().eval()
        return model

# Default system message placed after the <|SYSTEM|> tag by build_prompt_ids().
# This text is model input: editing it changes what the model is conditioned on.
SYSTEM_PROMPT_DEFAULT = (
    "Be a helpful, concise assistant with a light, friendly tone. "
    "Answer directly in 1–3 sentences. Don’t use steps or bullet lists unless the user asks. "
    "Use the content between INFOSTART/INFOEND only as context. Do not mention it, 'memory', or any internal tags. "
    "Avoid speculation and say when unsure. Keep replies safe, accurate, and on-topic. "
    "Write in first person. Don’t mention being a language model, training data, or lack of browsing unless asked."
)

def build_prompt_ids(user_text: str, info_text: str, system_text: str = SYSTEM_PROMPT_DEFAULT, assist_prefix: str = " "):
    """Assemble the tagged prompt and return its token ids.

    Layout: START + SYSTEM section, INFOSTART/INFOEND section, USER section
    (user text stripped), then the ASSISTANT tag followed by *assist_prefix*.

    Args:
        user_text: The user's message.
        info_text: Context placed between the INFO tags; falsy means empty.
        system_text: System message.
        assist_prefix: Text appended right after the ASSISTANT tag.

    Returns:
        Token ids from ``bpe_tokenizer.encode(..., add_special_tokens=False)``.
    """
    pieces = [
        "<|START|><|SYSTEM|>", system_text, "\n",
        "<|INFOSTART|>", info_text or "", "<|INFOEND|>\n",
        "<|USER|>", user_text.strip(), "\n",
        "<|ASSISTANT|>", assist_prefix,
    ]
    prompt = "".join(pieces)
    encoding = bpe_tokenizer.encode(prompt, add_special_tokens=False)
    return encoding.ids

_STEP_PAT = re.compile(r"(?i)\bstep\s*\d+")
_SENT_PAT = re.compile(r"(?i)(identify the (given )?sentence|the sentence is)")
_SENT_END_RE = re.compile(r'[.!?](?=\s|$)')

def _count_sents(s: str) -> int:
    return len(_SENT_END_RE.findall(s))


def _trim_sents(s: str, limit: int) -> str:
    if limit <= 0: return ""
    if _count_sents(s) <= limit: return s.strip()
    idx = 0; kept = 0
    for m in _SENT_END_RE.finditer(s):
        kept += 1; idx = m.end()
        if kept >= limit: break
    return s[:idx].strip()

# Vocabulary size: the config's "vocab_size" when present, otherwise the
# tokenizer's own count.
VOCAB_SIZE = int(cfg.get("vocab_size", bpe_tokenizer.get_vocab_size()))
# Token id -> decoded text for every id below VOCAB_SIZE; used by the
# per-token heuristics so they need not call the tokenizer repeatedly.
TOKEN_STR = {i: bpe_tokenizer.decode([i]) for i in range(VOCAB_SIZE)}

def _has_emoji_or_symbol(s: str) -> bool:
    return any(
        (0x1F300 <= ord(ch) <= 0x1FAFF) or
        (0x2600 <= ord(ch) <= 0x26FF) or
        (0x2700 <= ord(ch) <= 0x27BF)
        for ch in s
    )

def _dedupe_close(texts, tol=0.92):
    out = []
    for t in texts:
        if not any(len(set(t.lower().split()) & set(u.lower().split())) /
                   max(1, len(set(t.lower().split()) | set(u.lower().split()))) > tol
                   for u in out):
            out.append(t)
    return out

def _build_ban_ids() -> set[int]:
    """Collect token ids that must never be sampled.

    A token is banned when its decoded text is non-empty and is a literal
    control tag, contains an emoji/symbol character, or is a lone "<" or ">"
    after stripping. END_ID and PAD_ID are always removed from the result.
    """
    literal_specials = {
        "<|PAD|>", "<|UNKNOWN|>", "<|START|>", "<|END|>",
        "<|SYSTEM|>", "<|USER|>", "<|ASSISTANT|>", "<|EOT|>",
        "<|INFOSTART|>", "<|INFOEND|>"
    }

    def _is_banned(piece: str) -> bool:
        return (
            piece in literal_specials
            or _has_emoji_or_symbol(piece)
            or piece.strip() in {"<", ">"}
        )

    banned = {
        tid
        for tid in range(VOCAB_SIZE)
        if TOKEN_STR[tid] and _is_banned(TOKEN_STR[tid])
    }
    banned -= {END_ID, PAD_ID}
    return banned



def _violates_ngram(seq, cand, n):
    if n <= 1 or len(seq) < n-1: return False
    s = seq + [cand]; last = tuple(s[-n:])
    for i in range(len(s)-n):
        if tuple(s[i:i+n]) == last:
            return True
    return False

# Negation words/contractions (straight or curly apostrophe), case-insensitive.
# generate_ids() searches the recent output with it to open the negation window.
_NEG_HIT_RE = re.compile(r"\b(?:not|no|never|cannot|can['’]t|won['’]t|don['’]t|doesn['’]t|didn['’]t|isn['’]t|aren['’]t|wasn['’]t|weren['’]t)\b", re.I)
# Refusal-style phrases; tokens that would complete one are penalised during
# generation. The leading space is part of each phrase as encoded.
_NEG_CONT_PHRASES = [
    " I can't", " I cannot", " can't", " cannot", " unable to", " I won't", " I do not",
    " I don't", " I am not able", " unfortunately", " sorry", " I cannot help",
    " I don't have the ability", " I don't have the ability to", " I do not have the ability to",
    " I don't have access to the latest", " I do not have access to the latest",
    " I don't have access to real-time", " I do not have access to real-time",
    " I cannot browse", " I can't browse"
]

# Words excluded when generate() extracts `user_words` from the user's text.
STOPWORDS = {
    "a","an","the","and","or","but","if","then","so","to","of","in","on","for","with",
    "is","are","am","was","were","be","been","being",
    "i","you","he","she","it","we","they","me","him","her","us","them",
    "my","your","his","hers","our","their",
    "what","which","who","whom","whose","when","where","why","how",
    "do","does","did","can","could","should","would","may","might","will","shall"
}

# Phrases whose completion receives a bonus during generation.
_POS_PIVOTS = [
    " but", " however", " still", " instead", " here's", " here is",
    " I can", " we can", " let's", " try", " you can",
]

def _ids_for_phrases(phrases):
    """Encode each phrase to token ids, skipping phrases that encode to nothing.

    Returns:
        List of id lists, in the order of *phrases*.
    """
    encoded = (
        bpe_tokenizer.encode(phrase, add_special_tokens=False).ids
        for phrase in phrases
    )
    return [ids for ids in encoded if ids]

# Token-id sequences for the refusal / pivot phrases, plus the longest length
# of each (used to size the context window in _would_complete_any()).
NEG_CONT_IDS = _ids_for_phrases(_NEG_CONT_PHRASES)
POS_PIVOT_IDS = _ids_for_phrases(_POS_PIVOTS)
MAX_NEG_CONT = max((len(s) for s in NEG_CONT_IDS), default=0)
MAX_POS_PVT  = max((len(s) for s in POS_PIVOT_IDS), default=0)
# Sets of token ids whose stripped (and, where .lower() is applied,
# lower-cased) decoded text matches the given words; generate_ids() adjusts
# the logits of these ids.
COLON_TIDS = {tid for tid, s in TOKEN_STR.items() if s and s.strip() == ':'}
TUTORY = {"here's", "here", "how", "let", "explain"}
TUTORY_TIDS = {tid for tid, s in TOKEN_STR.items() if s and s.strip().lower() in TUTORY}
PERIOD_TIDS = {tid for tid, s in TOKEN_STR.items() if s and s.strip() in {".", "!", "?"}}
PHRASEY = {"identify", "key", "elements", "use", "information"}
PHRASEY_TIDS = {tid for tid, s in TOKEN_STR.items() if s and s.strip().lower() in PHRASEY}
DRIFT_WORDS = {"human","body","biology","animal","animals","cell","cells"}
DRIFT_TIDS  = {tid for tid, s in TOKEN_STR.items() if s and s.strip().lower() in DRIFT_WORDS}

def typical_filter(logits, tau=0.9):
    """Apply locally-typical filtering along the last dimension.

    Tokens are ranked by how close their surprisal ``-log p`` is to the
    entropy of the distribution. Starting from the closest, tokens are kept
    while the cumulative probability stays ``<= tau``; the closest token is
    always kept. Works row-wise for any leading batch shape (including a plain
    1-D vector), not only for the first row.

    Args:
        logits: Tensor of unnormalised scores; the vocabulary is the last dim.
        tau: Cumulative probability mass to keep.

    Returns:
        Tensor shaped like *logits* holding ``log(p + 1e-12)`` for kept tokens
        and ``-inf`` elsewhere.
    """
    probs = torch.softmax(logits, dim=-1)
    surprisal = -torch.log(probs + 1e-12)
    entropy = (probs * surprisal).sum(dim=-1, keepdim=True)

    # Ascending distance between each token's surprisal and the entropy.
    order = torch.argsort(torch.abs(surprisal - entropy), dim=-1)
    ranked_p = probs.gather(-1, order)

    keep = torch.cumsum(ranked_p, dim=-1) <= tau
    keep[..., 0] = True  # never return an empty candidate set

    neg_inf = torch.full_like(ranked_p, float("-inf"))
    ranked_scores = torch.where(keep, torch.log(ranked_p + 1e-12), neg_inf)
    # Put the ranked scores back at their original vocabulary positions.
    return torch.full_like(probs, float("-inf")).scatter(-1, order, ranked_scores)

def _would_complete_any(prev_ids: list[int], cand_id: int, seqs: list[list[int]], max_len: int) -> bool:
    if not seqs:
        return False
    ctx = prev_ids[-(max_len-1):] if max_len > 1 else []
    test = ctx + [cand_id]
    for seq in seqs:
        L = len(seq)
        if len(test) >= L and test[-L:] == seq:
            return True
    return False

def _soft_tail_trim(s: str) -> str:
    s = s.strip()
    cut = re.search(r'[.!?](?!.*[.!?])', s)
    return s if not cut else s[:cut.end()]

# A control tag written out as text, with or without the "|" delimiters.
TAG_TOKEN_RE = re.compile(
    r"<\|?(PAD|UNKNOWN|START|END|SYSTEM|USER|ASSISTANT|EOT|INFOSTART|INFOEND)\|?>"
)


def _unsafe_tail_would_appear(seq, candidate_id):
    """Return True if *seq* plus *candidate_id* decodes to text whose last
    128 characters contain a control tag."""
    decoded = bpe_tokenizer.decode([*seq, candidate_id])
    return TAG_TOKEN_RE.search(decoded[-128:]) is not None

@torch.no_grad()
def generate_ids(model, prompt_ids, user_words: set[str] | None = None, min_tokens: int = 40,
                 max_new_tokens=120, temperature=0.0, top_p=1.0, top_k=0,
                 repetition_penalty=1.07, no_repeat_ngram_size=3, seed=None,
                 ban_token_ids: set[int] | None = None,
                 eos_token_id=None, pad_token_id=None, generator=None,
                 known_caps: set[str] | None = None):
    """Generate token ids one step at a time with heuristic logit shaping.

    Each step runs the model on the whole sequence so far, adjusts the logits
    of the next token with a series of penalties/bonuses (list scaffolding,
    prompt echo, refusals, repetition, off-topic words, ...), filters the
    distribution (top-k / top-p for the first 12 tokens, typical filtering
    afterwards), vetoes candidates that would spell a control tag or repeat an
    n-gram, and then samples or takes the argmax.

    Args:
        model: Callable as ``model(input_ids=x)`` returning ``(logits, _)``.
        prompt_ids: Token ids of the prompt.
        user_words: Lower-cased content words of the user message; ``None`` is
            treated as an empty set.
        min_tokens: The stop token is masked, and the stop checks skipped,
            until this many tokens were generated.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Logits are divided by this when it is > 0.
        top_p: Nucleus threshold (applied only when 0 < top_p < 1).
        top_k: Top-k cut-off (applied only when > 0).
        repetition_penalty: Applied to tokens seen in the last 64 outputs.
        no_repeat_ngram_size: Candidates that would repeat an n-gram of this
            size are vetoed; falsy disables the check.
        seed: Seeds a fresh generator when *generator* is not given.
        ban_token_ids: Ids forced to -inf (the stop id is never banned).
        eos_token_id: Stop id; defaults to END_ID.
        pad_token_id: Padding id; defaults to PAD_ID.
        generator: ``torch.Generator`` used for sampling.
        known_caps: Capitalised words considered "known"; enables the
            prompt-echo penalty before the first sentence ends and the
            novel-capitalised-word penalty after two sentences.

    Returns:
        List of generated token ids (the prompt is not included).
    """
    dev  = next(model.parameters()).device
    STOP = int(eos_token_id if eos_token_id is not None else END_ID)
    PAD  = int(pad_token_id if pad_token_id is not None else PAD_ID)
    g    = generator if generator is not None else (
        torch.Generator(device=dev).manual_seed(int(seed)) if seed is not None else None
    )

    # The membership tests below need a real container; the default None used
    # to raise a TypeError in the drift check.
    user_words = user_words if user_words is not None else set()

    x = torch.tensor([prompt_ids], dtype=torch.long, device=dev)
    gen: list[int] = []
    neg_window = 0
    min_before_eos = MIN_BEFORE_EOS
    eos_boost = EOS_BOOST

    def _tids_for(words, fold_case=False):
        """Ids of tokens whose stripped (optionally lower-cased) text is in *words*."""
        found = []
        for tid, s in TOKEN_STR.items():
            if not s:
                continue
            key = s.strip().lower() if fold_case else s.strip()
            if key in words:
                found.append(tid)
        return found

    # ---- Everything below is loop-invariant and therefore computed once. ----
    core_tids = _tids_for({"step", "Step", "•", "-", "Identify", "keywords", "phrases", "Use", "Here's", "how"})
    soft_tids = _tids_for({":", "1", "2", "3", "4", "5"})
    qa_scaff_tids = _tids_for({"Q","Q:","Question","Question:","A","A:","Answer","Answer:"})
    boiler_tids = _tids_for({"that","which","uses","language","words"}, fold_case=True)
    apology_tids = _tids_for({"sorry","unfortunately"}, fold_case=True)
    refusal_tids = _tids_for({"don't", "cannot", "can't", "unable", "sorry", "unfortunately"}, fold_case=True)
    colon_tids = list(COLON_TIDS)
    tutory_tids = list(TUTORY_TIDS)
    phrasey_tids = list(PHRASEY_TIDS)
    period_tids = list(PERIOD_TIDS)
    drift_tids = list(DRIFT_TIDS)

    # Lower-cased words of the decoded prompt, used by the echo penalties.
    prompt_words_lower = {
        w.lower() for w in re.findall(r"[A-Za-z][\w'-]*", bpe_tokenizer.decode(prompt_ids) or "")
    }
    uw_lower = {w.lower() for w in user_words}

    # Did the user explicitly ask for a list / step-by-step answer?
    hints = {"steps","step","bullet","list","instructions","how-to","procedure","outline"}
    want_steps = any(h in user_words for h in hints)

    # Penalty on off-topic "drift" tokens unless the user mentioned such a word.
    drift_penalty = 0.0
    if drift_tids and not any(w in user_words for w in {"human","biology","body","animal","animals","cell","cells"}):
        drift_penalty = 0.6
        if "humanity" in user_words:
            drift_penalty += 0.25

    ban_idx = [t for t in ban_token_ids if t != STOP] if ban_token_ids else []

    for _ in range(max_new_tokens):
        with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=AMP_DTYPE, enabled=(dev.type=="cuda")):
            logits, _ = model(input_ids=x)
        # Force a fresh float32 copy: tensors created under inference_mode may
        # not be modified in place outside of it, and a plain .float() is a
        # no-op when the logits are already float32.
        next_logits = logits[:, -1, :].to(dtype=torch.float32, copy=True)
        vocab_dim = next_logits.size(-1)

        # ---- discourage list / step scaffolding unless requested ----
        if not want_steps:
            if core_tids:
                next_logits[0, core_tids] -= 0.15
            if len(gen) < 18 and soft_tids:
                next_logits[0, soft_tids] -= 0.20

            recent = (bpe_tokenizer.decode(gen[-64:]) if gen else "").lower()
            looks_listy = (
                "step" in recent[-24:]
                or "•" in recent[-24:]
                or re.search(r"(^|\s)-\s?$", recent[-24:])
                or re.search(r"\bhere\s+(are|is|\'?s)\b", recent)
                or "keyword" in recent[-32:] or "phrase" in recent[-32:]
            )
            if looks_listy and soft_tids:
                next_logits[0, soft_tids] -= 0.30

        partial_text = bpe_tokenizer.decode(gen)
        n_sents = _count_sents(partial_text)
        sents_done = n_sents >= 1

        # ---- before the first sentence ends: do not parrot prompt words ----
        if not sents_done and known_caps is not None:
            _, cand = torch.topk(next_logits, k=min(64, vocab_dim), dim=-1)
            for tid in cand[0].tolist():
                if tid in (END_ID, PAD_ID):
                    continue
                ts = (TOKEN_STR.get(tid) or "").strip().lower()
                if ts.isalpha() and ts in prompt_words_lower:
                    next_logits[0, tid] -= 0.35

        if not sents_done:
            tail = (partial_text[-40:] if partial_text else "").lower()
            if re.search(r"\bit\s+is\s+a\s+$", tail):
                if boiler_tids:
                    next_logits[0, boiler_tids] -= 0.4

        # ---- after the first sentence: damp tutorial-style continuations ----
        if sents_done:
            if colon_tids:
                next_logits[0, colon_tids] -= 1.2
            if tutory_tids:
                next_logits[0, tutory_tids] -= 0.8
            if not want_steps and phrasey_tids:
                next_logits[0, phrasey_tids] -= 0.15

            if prompt_words_lower:
                _, cand = torch.topk(next_logits, k=min(48, vocab_dim), dim=-1)
                for tid in cand[0].tolist():
                    if tid in (END_ID, PAD_ID):
                        continue
                    ts = (TOKEN_STR.get(tid) or "").strip().lower()
                    if ts and ts in prompt_words_lower:
                        next_logits[0, tid] -= 0.12

            if apology_tids:
                next_logits[0, apology_tids] -= 0.2

        if n_sents >= 2 and period_tids:
            next_logits[0, period_tids] += 0.10

        # ---- small bonus for the user's own content words ----
        if user_words:
            _, cand = torch.topk(next_logits, k=min(64, vocab_dim), dim=-1)
            for tid in cand[0].tolist():
                ts = (TOKEN_STR.get(tid) or "").strip().lower()
                if len(ts) >= 3 and ts.isalpha() and ts in uw_lower:
                    next_logits[0, tid] += 0.2

        if drift_penalty:
            next_logits[0, drift_tids] -= drift_penalty

        # ---- do not open the answer with a refusal ----
        if len(gen) < 12:
            if refusal_tids:
                next_logits[0, refusal_tids] -= 1.2

        # ---- hard masks ----
        next_logits[0, PAD] = float("-inf")
        if ban_idx:
            next_logits[0, ban_idx] = float("-inf")

        if len(gen) < max(min_before_eos, min_tokens):
            next_logits[0, STOP] = float("-inf")
        else:
            next_logits[0, STOP] += eos_boost

        if len(gen) < 3 and qa_scaff_tids:
            next_logits[0, qa_scaff_tids] -= 1.5

        # ---- repetition control ----
        if gen and repetition_penalty and repetition_penalty != 1.0:
            recent_ids = gen[-64:]
            for tid in set(recent_ids):
                val = next_logits[0, tid]
                next_logits[0, tid] = torch.where(val > 0, val / repetition_penalty, val * repetition_penalty)

        if len(gen) >= 8:
            window = gen[-8:]
            uniq, counts = np.unique(window, return_counts=True)
            for tid_i, cnt in zip(uniq.tolist(), counts.tolist()):
                if tid_i in (STOP, PAD):
                    continue
                next_logits[0, tid_i] -= 0.03 * cnt

        if len(gen) >= 2:
            last1, last2 = gen[-1], gen[-2]
            _, cand = torch.topk(next_logits, k=min(64, vocab_dim), dim=-1)
            for tid in cand[0].tolist():
                if tid == last1 == last2:
                    next_logits[0, tid] -= 0.6
                if tid == last1:
                    next_logits[0, tid] -= 0.12

        do_sample = (temperature and temperature > 0) or (top_p and 0 < top_p < 1.0) or (top_k and top_k > 0)

        if temperature and temperature > 0:
            next_logits /= float(temperature)

        # ---- late in the answer: damp capitalised words not seen before ----
        if known_caps and n_sents >= 2:
            _, cand = torch.topk(next_logits, k=min(64, vocab_dim), dim=-1)
            for tid in cand[0].tolist():
                ts = TOKEN_STR.get(tid) or ""
                if ts[:1].isupper() and ts.isalpha() and ts not in known_caps and tid not in (STOP, PAD):
                    next_logits[0, tid] -= 0.35

        # ---- distribution filtering ----
        use_typical = (len(gen) >= 12)

        if use_typical:
            filt = typical_filter(next_logits, tau=TYPICAL_TAU)
            if not torch.isneginf(filt).all():
                next_logits = filt
        else:
            if top_k is not None and top_k > 0:
                v, _ = torch.topk(next_logits, k=min(top_k, vocab_dim))
                next_logits[next_logits < v[:, [-1]]] = float("-inf")
            if top_p is not None and 0 < top_p < 1.0:
                probs = torch.softmax(next_logits, dim=-1)
                sp, si = torch.sort(probs, descending=True)
                csum = torch.cumsum(sp, dim=-1)
                keep = csum <= top_p
                keep[..., 0] = True
                filt = torch.full_like(sp, float("-inf"))
                filt[keep] = torch.log(sp[keep] + 1e-12)
                next_logits = torch.full_like(next_logits, float("-inf"))
                next_logits[0, si[0]] = filt[0]

        # ---- encourage stopping when the model is about to start a list ----
        if sents_done:
            # k is clamped so that a vocabulary smaller than 8 cannot raise.
            _, top8 = torch.topk(next_logits, k=min(8, vocab_dim), dim=-1)
            top8_ids = top8[0].tolist()
            if any((TOKEN_STR.get(t) or '').strip() == ':' for t in top8_ids):
                if len(gen) >= max(min_before_eos, min_tokens):
                    next_logits[0, STOP] += 0.9

        # ---- encourage stopping when looping or when entropy is low ----
        loop12 = (len(gen) >= 12 and len(set(gen[-12:])) <= 4)
        loop10 = (len(gen) >= 10 and len(set(gen[-10:])) <= 3)

        if loop12:
            next_logits[0, STOP] += 1.6
        elif loop10:
            next_logits[0, STOP] += 0.8
        else:
            p = torch.softmax(next_logits, dim=-1)
            ent = float(-(p * torch.log(p + 1e-12)).sum())
            if len(gen) >= 6 and ent < 3.4:
                next_logits[0, STOP] += 0.8

        if len(gen) >= max(min_before_eos, min_tokens) - 4 and period_tids:
            next_logits[0, period_tids] += 0.45

        # A negation in the recent output (re)opens the negation window.
        running_tail = bpe_tokenizer.decode(gen[-64:]) if gen else ""
        if _NEG_HIT_RE.search((running_tail or "").lower()):
            neg_window = max(neg_window, NEG_WINDOW_TOKENS)

        # ---- per-candidate vetoes and phrase steering on the top 256 ----
        probs = torch.softmax(next_logits, dim=-1)
        _, cand_idx = torch.sort(probs, descending=True)
        tmp_logits = next_logits.clone()
        max_scan = min(256, cand_idx.size(-1))
        for j in range(max_scan):
            cid = int(cand_idx[0, j])
            if cid == STOP:
                continue

            cand_str = TOKEN_STR.get(cid) or bpe_tokenizer.decode([cid])

            # An unfinished "<..." in the tail must not be continued into a tag.
            if ANGLE_NOISE_PREFIX_RE.search(running_tail) and re.match(r"[A-Za-z0-9/|>]", cand_str):
                tmp_logits[0, cid] = float("-inf")
                continue

            if _unsafe_tail_would_appear(gen, cid) or (
                no_repeat_ngram_size and _violates_ngram(gen, cid, no_repeat_ngram_size)
            ):
                tmp_logits[0, cid] = float("-inf")
                continue

            if _would_complete_any(gen, cid, NEG_CONT_IDS, MAX_NEG_CONT):
                penalty = NEG_CONT_PENALTY * (1.25 if neg_window > 0 else 1.0)
                tmp_logits[0, cid] -= penalty
            if _would_complete_any(gen, cid, POS_PIVOT_IDS, MAX_POS_PVT):
                bonus = POS_PIVOT_BONUS * (1.25 if neg_window > 0 else 1.0)
                tmp_logits[0, cid] += bonus

        # ---- choose the next token ----
        if torch.isneginf(tmp_logits).all():
            # Everything was vetoed: fall back to the unvetoed logits.
            fallback = next_logits.clone()
            fallback[0, PAD] = float("-inf")
            if len(gen) < min_before_eos:
                fallback[0, STOP] = float("-inf")
            next_id = int(torch.argmax(fallback, dim=-1))
        else:
            if do_sample:
                samp_p = torch.softmax(tmp_logits, dim=-1)
                next_id = int(torch.multinomial(samp_p[0], num_samples=1, generator=g))
            else:
                next_id = int(torch.argmax(tmp_logits, dim=-1))

        # Last line of defence: swap out a token that would spell a control tag.
        if _unsafe_tail_would_appear(gen, next_id):
            ok = False
            for j2 in range(min(32, cand_idx.size(-1))):
                alt = int(cand_idx[0, j2])
                if alt == next_id or alt == PAD:
                    continue
                if len(gen) < MIN_BEFORE_EOS and alt == STOP:
                    continue
                if not _unsafe_tail_would_appear(gen, alt):
                    next_id = alt; ok = True; break
            if not ok:
                fb = next_logits.clone()
                fb[0, PAD] = float("-inf")
                if len(gen) < min_before_eos:
                    fb[0, STOP] = float("-inf")
                next_id = int(torch.argmax(fb, dim=-1))

        gen.append(next_id)
        x = torch.cat([x, torch.tensor([[next_id]], device=dev)], dim=1)

        # ---- stop conditions ----
        text_so_far = bpe_tokenizer.decode(gen)
        if len(gen) >= min_tokens:
            if next_id == STOP: break
            if "<|END|>" in text_so_far or "<|EOT|>" in text_so_far: break
            if _count_sents(text_so_far) >= SENTENCE_LIMIT: break
            if _STEP_PAT.search(text_so_far) or _SENT_PAT.search(text_so_far):
                # Scaffolding detected: make stopping much more attractive.
                min_before_eos = 0
                eos_boost = 3.0
        if len(gen) >= max(min_tokens, 10) and len(set(gen[-10:])) <= 3:
            break
        if neg_window > 0:
            neg_window -= 1

    return gen


def trim_to_last_period(s: str) -> str:
    """Strip *s* and cut it after its last "." unless it already ends a sentence.

    Text that is empty, already ends with ".", "!" or "?", or contains no "."
    at all is returned stripped but otherwise unchanged.
    """
    text = s.strip()
    if not text or text[-1] in ".!?":
        return text
    head, dot, _rest = text.rpartition(".")
    return head + dot if dot else text

# Text ending in an unfinished "<..." fragment: "<" followed by up to 24
# letters, "/" or "|" at the very end. Used by generate_ids() at call time.
ANGLE_NOISE_PREFIX_RE = re.compile(r"<[A-Za-z/|]{0,24}$")

# Refusal / disclaimer wording, case-insensitive: "I don't / cannot / can't /
# won't / am not able" (but not "can't wait"), apology words, "as an AI /
# language model", and mentions of latest data, real-time access or browsing.
_REFUSAL_RE = re.compile(
    r"(?i)\b(i\s+(?:don['’]t|cannot|can['’]?t(?!\s+wait)|won['’]?t|am\s+not\s+able))\b|"
    r"\b(unfortunately|sorry|unable to)\b|"
    r"\b(as an?\s+(?:ai|language model))\b|"
    r"access to the latest|real[-\s]?time|browsing"
)

def _score_answer(user_text: str, out: str, info_text: str = "") -> float:
    """Heuristically score a candidate answer; higher is better.

    The score rewards short, direct, first-person answers that reuse words
    from the user/context text, and penalises refusals, list or step
    scaffolding, boilerplate phrasing, repeated words, leaked control tags and
    capitalised words that appear in neither the user text nor the context.

    Args:
        user_text: The user's message.
        out: Candidate answer text.
        info_text: Optional context text.

    Returns:
        The score as a float.
    """
    score = 0.0
    ol = out.strip()
    ll = ol.lower()

    if _REFUSAL_RE.search(ll): score -= 3.5

    # Sentence count: 1-3 is ideal, none or very many is penalised.
    sents = len(re.findall(r'[.!?](?=\s|$)', ol))
    if   1 <= sents <= 3: score += 2.0
    elif sents == 0: score -= 2.0
    elif sents >= 6: score -= 0.8

    # Length in characters.
    nchar = len(ol)
    if 40 <= nchar <= 320: score += 0.6
    elif nchar < 25: score -= 0.6
    elif nchar > 600: score -= 0.6

    # Structure and phrasing signals.
    if re.search(r':\s*\d+\.\s*', ol): score -= 0.6
    if re.search(r"(?m)^\s*\d+\.\s+", ol): score -= 1.0
    if re.search(r"\b(\w+)(?:\s+\1){2,}\b", ol): score -= 0.7
    if re.search(r"\b(i\s+can\s+(help|do|explain|summarize|guide))\b", ll): score += 0.6
    if re.search(r'(?i)\b(identify|keyword|phrase|here(?:\'|)s how)\b', ol): score -= 1.0
    # The word boundary belongs to "step" only: placed in front of the whole
    # group it also preceded "^\s*[-•]", which can never follow a word
    # boundary, so bullet lines were never penalised.
    if re.search(r'(?i)(?:\bstep\s*\d+|^\s*[-•]\s+)', ol, flags=re.M): score -= 1.0
    if re.search(r"(?i)\bi am\b|\bi’m\b|\bi can\b", ol): score += 0.4
    if re.search(r":\s*$", ol): score -= 0.8
    if "<|END|>" in ol or "<|EOT|>" in ol: score -= 1.0
    if re.search(r'(?i)\b(use this information|identify (the )?key (points|elements)|here(?:\'|)s how)\b', ol):
        score -= 1.2
    if re.search(r'(?i)\b(language|words)\b.*\b(use[s]?\b|\bused\b)', ol):
        score -= 0.9

    # Small reward per word shared with the user text or the context.
    u_tokens = set(re.findall(r"[a-zA-Z][\w'-]*", user_text))
    i_tokens = set(re.findall(r"[a-zA-Z][\w'-]*", info_text))
    a_tokens = set(re.findall(r"[a-zA-Z][\w'-]*", ol))
    overlap = len((u_tokens | i_tokens) & a_tokens)
    score += 0.02 * overlap

    # Penalty per capitalised word found in neither the user text nor the context.
    known_caps = _cap_words(user_text) | _cap_words(info_text)
    out_caps   = _cap_words(ol)
    novel_caps = [w for w in out_caps if w not in known_caps]
    score -= 0.6 * len(novel_caps)

    return float(score)


def _salvage_if_all_refusals(cands: list[str]) -> str | None:
    """If every candidate matches ``_REFUSAL_RE``, return the least refusing one.

    "Least refusing" means the fewest ``_REFUSAL_RE`` matches (first such
    candidate on ties). Returns None as soon as one candidate is not a refusal.
    """
    for cand in cands:
        if not _REFUSAL_RE.search(cand):
            return None
    return min(cands, key=lambda text: len(_REFUSAL_RE.findall(text)))

def _clean_answer(out: str, user_text: str) -> str:
    """Turn a raw decoded candidate into the final reply text.

    Removes control tags and angle-bracket noise, collapses whitespace, strips
    boilerplate lead-ins/sign-offs and list scaffolding, de-duplicates
    immediately repeated words, and finally trims to complete sentences
    (at most SENTENCE_LIMIT). The substitutions are order-dependent.
    """
    out = out.split("<|EOT|>")[0]
    out = re.sub(r"\s+", " ", out)  # from here on the text is a single line
    out = re.sub(r"[\x00-\x1F\x7F]", "", out)
    out = re.sub(r"<\|[^>]+?\|>", "", out).strip().strip('"')
    out = re.sub(
        r"<\|?(?:PAD|UNKNOWN|START|END|SYSTEM|USER|ASSISTANT|EOT|INFOSTART|INFOEND)\|?>",
        "",
        out
    )
    out = re.sub(r'\b(\w+)\s+and\s+\1\b', r'\1', out, flags=re.I)
    out = re.sub(r"<[A-Za-z0-9/|<>]{1,40}>?", "", out)
    out = re.sub(r":\s*(?:\d+\.\s*)+", ": ", out)
    out = re.sub(r"(?i)\s*here(?:'|)s\s+how.*?:\s*", "", out)
    out = re.sub(r"(?i)^\s*(the (answer|sentence) is)\s*[:\-]?\s*", "", out)
    out = re.sub(r"\b(\w+)(?:\s+\1){2,}\b", r"\1", out)
    out = re.sub(r"\s{2,}", " ", out).strip()
    out = re.sub(r"(?i)\s*(?:let me explain.*)$", "", out).strip()
    out = re.sub(r"(?i)\s*(?:i hope this (?:information )?helps!?).*$", "", out).strip()
    out = re.sub(r'(?i)\b(?:identify|use)\s+the\s+(?:key\s*)?(?:word|phrase)s?\b.*?:\s*', '', out)
    out = re.sub(r":\s*$", ".", out)
    out = re.sub(r':\s*(?:\d+\.\s*){1,6}', ': ', out)
    out = re.sub(r'(?i)\s*(?:i hope that helps!?|let me know if you have (?:any )?other questions.*)$', '', out)
    # Remove only a leading bullet marker. Because whitespace was collapsed
    # above, a "drop the whole list line" pattern would match the entire text
    # and leave an empty reply whenever the answer starts with a marker.
    out = re.sub(r'^\s*[-•]\s+', '', out)
    out = re.sub(r'\b(\w+)(?:\s*,\s*\1){1,}\b', r'\1', out)
    out = re.sub(r'\b(\w+)(?:\s+\1){1,}\b', r'\1', out)

    # A leading "1. " is kept only when the user asked for a list.
    if not re.search(r"(?i)\b(list|steps|bullets|outline|numbered)\b", user_text):
        out = re.sub(r"(?m)^\s*\d+\.\s+", "", out)

    out = _soft_tail_trim(out)
    out = _trim_sents(out, SENTENCE_LIMIT)
    return out


def generate(info_text: str, user_text: str, greedy: bool=False, clean_style: bool=True, seed: int | None = None) -> str:
    """Answer *user_text*, optionally grounded in *info_text*.

    Samples NUM_CANDIDATES continuations, drops near-duplicates, picks the
    best one (or the least refusing one when all candidates are refusals) and
    cleans it up.

    Args:
        info_text: Context placed between the INFO tags of the prompt.
        user_text: The user's message.
        greedy: Accepted for interface compatibility; not used by this
            function.
        clean_style: When true, ban the token ids from ``_build_ban_ids()``.
        seed: Base seed; candidate ``i`` uses ``seed + i``. When None, seeds
            are derived from the current time.

    Returns:
        The cleaned reply text.
    """
    model = get_model(merge=True)
    prompt_ids = build_prompt_ids(
        user_text=user_text,
        info_text=info_text,
        system_text=SYSTEM_PROMPT_DEFAULT,
        assist_prefix=ASSIST_PREFIX
    )

    user_words = {w for w in re.findall(r"\w+", user_text.lower()) if len(w) >= 3 and w not in STOPWORDS}

    decode = dict(
        max_new_tokens=124,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        top_k=TOP_K,
        repetition_penalty=REPETITION_PENALTY,
        no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
        seed=seed,
        eos_token_id=END_ID,
        pad_token_id=PAD_ID,
    )

    # Short questions get a shorter generation budget.
    if len(user_text.strip()) <= 20:
        decode["max_new_tokens"] = min(decode["max_new_tokens"], 80)

    ban_ids = _build_ban_ids() if clean_style else set()

    # Identical for every candidate, so computed once rather than per loop pass.
    known_caps = _cap_words(bpe_tokenizer.decode(prompt_ids)) | _cap_words(
        f"{info_text or ''} {user_text or ''}"
    )

    cands: list[str] = []
    for i in range(NUM_CANDIDATES):
        if seed is None:
            g = torch.Generator(device=DEVICE).manual_seed(
                (int(time.time_ns() & 0xFFFFFFFF) ^ 0xA5A5A5A5 ^ i) & 0xFFFFFFFF
            )
        else:
            g = torch.Generator(device=DEVICE).manual_seed(int(seed) + i)

        # Swallow anything printed while generating.
        with contextlib.redirect_stdout(io.StringIO()):
            gen_ids = generate_ids(
                model, prompt_ids, user_words,
                min_tokens=16,
                ban_token_ids=ban_ids,
                generator=g,
                known_caps=known_caps,
                **decode
            )
        cands.append(bpe_tokenizer.decode(gen_ids).split("<|END|>")[0])

    cands = _dedupe_close(cands)
    fallback = _salvage_if_all_refusals(cands)
    if fallback is not None:
        out = fallback
    else:
        out = max(cands, key=lambda t: _score_answer(user_text, t, info_text))

    return _clean_answer(out, user_text)

In [None]:
# Demo prompts as (user message, memory/context) pairs.
_DEMO_TURNS = [
    ("Who are you and what can you do?", "You are a supportive artificial intelligence."),
    ("Hello, I made you. It's been a month and half since I start this project making small LLM, and I'm so glad that I finished it!", ""),
    ("What's your thought on the humanity?", ""),
    ("Hi. What's my name?", "User's name is Insoo Son."),
]

for turn_no, (user_question, extra_information) in enumerate(_DEMO_TURNS, start=1):
    print(f"User typed: {user_question}")
    print(f"Memory has: {extra_information}")
    model_answer = generate(info_text=extra_information, user_text=user_question)
    # A blank line separates the turns; the final turn gets none.
    separator = "\n" if turn_no < len(_DEMO_TURNS) else ""
    print(f"Model answer: {model_answer}{separator}")

User typed: Who are you and what can you do?
Memory has: You are a supportive artificial intelligence.
Model answer: I am a supportive artificial intelligence. I can help you learn about the language of the people who are in need. Please provide me with information on what you are doing.

User typed: Hello, I made you. It's been a month and half since I start this project making small LLM, and I'm so glad that I finished it!
Memory has: 
Model answer: I'm glad you found it.

User typed: What's your thought on the humanity?
Memory has: 
Model answer: "I am a man." The humanity is a group of people who have been working for years in the world. They are not like people who work together and do not work together. They are like people who are different, but they are different.

User typed: Hi. What's my name?
Memory has: User's name is Insoo Son.
Model answer: Insoo Son. Insoo Son is a name that is used to describe a person's actions. It is a name that is used to describe a person's actions

NameError: name 's' is not defined