# Алгоритм приоритетного поиска с глобальными ограничениями для контроля генерации LLM 

модели и общие функции.

In [2]:
import pandas as pd, re
from sklearn.model_selection import train_test_split
from typing import List, Dict, Generator
import time, statistics, torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import datasets
from profiling import time_avg, get_profile
import torch
import pandas as pd
from tqdm import tqdm
from profiling import time_avg, get_profile
import numpy as np

model_name = "Qwen/Qwen3-4B"
tok  = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, padding_side='left')
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16, 
    trust_remote_code=True,
    #memory_efficient_attention=True
).eval()

BATCH_SIZE = 8

def batches(tasks: pd.Series, prompts: List[Dict[str, str]], 
            user_message_template: str, batch_size: int = BATCH_SIZE) -> Generator[torch.tensor, None, None]:
    all_prompts = [prompts + [{"role": "user", "content": user_message_template.format(task) }] for task in tasks]

    chatml = [tok.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True, 
                                       enable_thinking=False) for prompt in all_prompts]

    for i in range(0, len(chatml), batch_size):
        results = chatml[i:i+batch_size]
        enc = tok(results, return_tensors="pt", padding=True).to(model.device)
        yield enc, tasks.iloc[i:i+batch_size]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Создаем задачи для демонстрации

In [5]:
df = pd.read_csv("data/sudoku.csv")
train_sudoku, test_sudoku = train_test_split(df, test_size=10, random_state=42, shuffle=True)

sudoku_prompt = [
{
    "role": "system",
    "content": """
You are a Sudoku-solver.  Follow the rules exactly.

Input description:
    
- One string of 81 characters (row-major order).
- `0` = empty, your task to fill in them.
- `1-9` = fixed digits, you should not change them.
- Each 3x3 Block groups of neighbor rows and columns (nine digits) - is a **Block**
- To find the block for any position idx (0-80):

```
row    r = idx // 9         
column c = idx % 9   
block_row = r // 3   
block_col = c // 3
block_id  = block_row * 3 + block_col   
```

There is a mapping Input sequence (position shown in table cell) to Columns, Rows and Blocks:

      columns →
        0   1   2  │ 3   4   5  │ 6   7   8
      ┌────────────┼────────────┼────────────┐
rows 0│ 00  01  02 │ 03  04  05 │ 06  07  08 │
↓    1│ 09  10  11 │ 12  13  14 │ 15  16  17 │
     2│ 18  19  20 │ 21  22  23 │ 24  25  26 │
      ├────────────┼────────────┼────────────┤
     3│ 27  28  29 │ 30  31  32 │ 33  34  35 │
     4│ 36  37  38 │ 39  40  41 │ 42  43  44 │
     5│ 45  46  47 │ 48  49  50 │ 51  52  53 │
      ├────────────┼────────────┼────────────┤
     6│ 54  55  56 │ 57  58  59 │ 60  61  62 │
     7│ 63  64  65 │ 66  67  68 │ 69  70  71 │
     8│ 72  73  74 │ 75  76  77 │ 78  79  80 │
      └────────────┴────────────┴────────────┘


Rules to fill in 0-placeholders:

- **Each Row** must contain digits **1-9 exactly once**.  
- **Each Column** must contain digits **1-9 exactly once**.  
- **Each 3 x 3 Block** must contain digits **1-9 exactly once**.  
- **Fixed digits stay fixed** – every non-zero input digit must appear unchanged in the same position in your output.

Output format:

- Return **exactly one line of 81 digits**, no spaces or extra text.
- Order: row 1 (left→right) → row 2 → … → rdf = pd.DataFrame(datasets.load_dataset('allenai/common_gen', split='test'))

prompts = [
    tok.apply_chat_template(
        [{"role": "user",
          "content": f"write a sentence with the following words: [ {', '.join(c)} ]. Don't change words, tenses and include all words in the sentence."}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )
    for c in df["concepts"]
]ow 9 (the same as Input).
- Replace every `0` with the correct digit. 
- Keep all other digits unchanged!

Some examples:
""".strip()
  },

  {
    "role": "user",
    "content": "070000043040009610800634900094052000358460020000800530080070091902100005007040802"
  },
  {
    "role": "assistant",
    "content": "279518643543279618861634957694752318358461729712893534486927391932185465157346892"
  },

  {
    "role": "user",
    "content": "301086504046521070500000001400800002080347900009050038004090200008734090007208103"
  },
  {
    "role": "assistant",
    "content": "371986524846521379592473861463819752285347916719652438634195287128734695957268143"
  }
]
 
sdoku_batches = batches(test_sudoku['puzzle'], sudoku_prompt, "{}")

common_gen_df = pd.DataFrame(datasets.load_dataset('allenai/common_gen', split='validation'))
common_gen_batches = batches(common_gen_df['concepts'].map(lambda sl: ', '.join(sl)), [], 
                             "write a sentence with the following words: [ {} ]. Don't change words, tenses and include all words in the sentence.")

def is_include_all_words(str, concepts: list) -> bool:
    return all(word in str for word in concepts)

test_k_points = np.cumprod(np.arange(1, 6, 1))

## Наш сэмплер

In [6]:
import random
import heapq

EOS_IDS = {
    tok.eos_token_id,
    tok.convert_tokens_to_ids("<|im_end|>"),
    tok.convert_tokens_to_ids("<|endoftext|>")
}

def tiny_noise(eps: float = 1e-7) -> float:
    """
    Возвращает случайное число в диапазоне (0, eps).
    По умолчанию eps = 1e-7, что на несколько порядков
    меньше типичных разниц -log P.
    """
    return random.random() * eps


def is_solution(ids: torch.Tensor) -> bool:
    return ids[-1].item() in EOS_IDS                

def canon_key(ids: torch.Tensor) -> int:              # анти-дубликатор
    return hash(tok.decode(ids, skip_special_tokens=False).strip())

def variants(
    prefix: torch.Tensor,
    batch_eval: int = 8,
    max_new: int = 200
):
    beam_width = batch_eval
    base_len   = prefix.size(0)

    frontier: list[tuple[int, float, torch.Tensor]] = [ (-base_len, 0.0, prefix.cpu()) ]
    visited = {canon_key(prefix.cpu())}

    while frontier:
        batch = [heapq.heappop(frontier) for _ in range(min(batch_eval, len(frontier)))]

        seqs_cpu, scores = [], []
        for neg_len, score, ids in batch:
            if is_solution(ids) or (ids.size(0) - base_len) >= max_new:
                yield ids, -score
            else:
                seqs_cpu.append(ids)
                scores.append(score)
        if not seqs_cpu:
            continue

        seqs_gpu = [s.to(model.device, non_blocking=True) for s in seqs_cpu]
        with torch.inference_mode():
            padded = torch.nn.utils.rnn.pad_sequence(
                seqs_gpu, batch_first=True, padding_value=tok.pad_token_id
            )
            last_logits = model(padded, use_cache=False).logits[
                torch.arange(len(seqs_gpu), device=padded.device),
                [len(s) - 1 for s in seqs_gpu]
            ]
        top_lp, top_tid = torch.topk(torch.log_softmax(last_logits, -1),
                                     beam_width, dim=-1)

        # ── 3. расширяем: greedy + альтернативы ──────────────────────────
        for ids_cpu, base_sc, lps, tids in zip(seqs_cpu, scores, top_lp, top_tid):
            for rank, (lp, tid) in enumerate(zip(lps, tids)):      # 0-й = greedy
                child_ids = torch.cat([ids_cpu, tid.view(1).cpu()])
                key       = canon_key(child_ids)
                if key in visited:
                    continue

                visited.add(key)
                child_score = base_sc - lp.item() + tiny_noise()

                # приоритет: сначала длина, потом лог-скор
                heapq.heappush(frontier, (-child_ids.size(0), child_score, child_ids))

## Генерация нашим методом

In [8]:
def our_gen(k, common_gen_df):
    print('testing K=', k)
    data = list(batches(common_gen_df['concepts'].map(lambda sl: ', '.join(sl)), [], 
                                "write a sentence with the following words: [ {} ]. Don't change words, tenses and include all words in the sentence."))

    @time_avg
    def test_our_gen(data, steps=k):
        results = []
        for batch, concepts_batch in tqdm(data):
            for seq, concepts in zip(batch.input_ids, concepts_batch):
                for i, (res_tok, neg_logp) in enumerate(variants(seq, max_new=steps)):
                    text = tok.decode(res_tok, skip_special_tokens=True)
                    text = text.split('</think>')[1].strip()
                    if i > steps:
                        results.append("")
                        break
                    if is_include_all_words(text, concepts):
                        results.append(text)
                        break
        return results

    df_key_suffix = f'_our_gen_{k}'
    common_gen_df['result' + df_key_suffix] = test_our_gen(data)

    is_ok = []
    for _, row in common_gen_df.iterrows():
        concepts = row['concepts']
        is_ok.append(is_include_all_words(row['result' + df_key_suffix], concepts))

    common_gen_df['is_ok' + df_key_suffix] = is_ok

    print(f'Total time {k}:', get_profile())
    print(f'ACC ({k}):', common_gen_df['is_ok' + df_key_suffix].mean())

for k in test_k_points:
    our_gen(k, common_gen_df)

testing K= 1


100%|██████████| 503/503 [01:48<00:00,  4.63it/s]


Total time 1: [('our_gen.<locals>.test_our_gen', 2, 109809.639001498)]
ACC (1): 0.0
testing K= 2


100%|██████████| 503/503 [06:46<00:00,  1.24it/s]


Total time 2: [('our_gen.<locals>.test_our_gen', 1, 406763.78514198586)]
ACC (2): 0.0
testing K= 6


100%|██████████| 503/503 [27:53<00:00,  3.33s/it]


Total time 6: [('our_gen.<locals>.test_our_gen', 1, 1673665.8450830146)]
ACC (6): 0.004728720756595321
testing K= 24


100%|██████████| 503/503 [1:43:19<00:00, 12.33s/it]


Total time 24: [('our_gen.<locals>.test_our_gen', 1, 6199684.3787840335)]
ACC (24): 0.5373320059731209
testing K= 120


100%|██████████| 503/503 [2:03:25<00:00, 14.72s/it]  

Total time 120: [('our_gen.<locals>.test_our_gen', 1, 7405134.013422998)]
ACC (120): 0.6070184171229467





## Жадный алгоритм CommonGen

In [None]:
results = []
data = list(batches(common_gen_df['concepts'].map(lambda sl: ', '.join(sl)), [], 
                             "write a sentence with the following words: [ {} ]. Don't change words, tenses and include all words in the sentence."))

@time_avg
def test_free_gen(data):
    results = []
    is_oks = []
    for batch, concepts in tqdm(data):
        with torch.no_grad():
            gen = model.generate(batch.input_ids, attention_mask=batch.attention_mask, do_sample=False, top_p=None, top_k=None, temperature=None)
        texts = tok.batch_decode(gen, skip_special_tokens=True)
        texts = [text.split('</think>')[1].strip() for text in texts]
        results.extend(texts)
        is_oks.append([is_include_all_words(text, concept) for text, concept in zip(texts, concepts)])
    return results, is_oks

results, is_oks = test_free_gen(data)

common_gen_df['free_gen_result'] = results
common_gen_df['is_ok_free_gen'] = [x for xs in is_oks for x in xs]

print(f'Total time: {get_profile()}')
print('ACC:', common_gen_df['is_ok_free_gen'].mean())

100%|██████████| 503/503 [04:33<00:00,  1.84it/s]

Total time: [('test_free_gen', 1, 273475.7408319856)]
ACC: 0.35739173718267797





## Nuclius sampling + проверка

In [6]:
results = []
data = list(batches(common_gen_df['concepts'].map(lambda sl: ', '.join(sl)), [], 
                             "write a sentence with the following words: [ {} ]. Don't change words, tenses and include all words in the sentence."))

@time_avg
@torch.inference_mode()
def test_free_gen_nucleus(
    data_loader,
    max_tries: int = 40,
    top_p: float = 0.95,
    temperature: float = 0.5,
    max_new_tokens: int | None = None,
):
    all_results = []

    for batch, concepts in tqdm(data_loader):
        # 👉 превратим Series в список с 0-based позициями
        concepts = list(concepts)               # или concepts.reset_index(drop=True)

        batch_size     = batch.input_ids.size(0)
        unfinished     = torch.ones(batch_size, dtype=torch.bool, device=batch.input_ids.device)
        batch_results  = [None] * batch_size

        for _ in range(max_tries):
            if not unfinished.any():
                break

            idx = unfinished.nonzero(as_tuple=False).flatten()

            gen = model.generate(
                batch.input_ids[idx],
                attention_mask=batch.attention_mask[idx],
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                pad_token_id=tok.pad_token_id,
            )

            decoded = tok.batch_decode(gen, skip_special_tokens=True)

            for i, text in zip(idx.tolist(), decoded):
                text = text.split('</think>')[1].strip() if '</think>' in text else text
                if is_include_all_words(text, concepts[i]):
                    batch_results[i] = text
                    unfinished[i]   = False

        # те, что не прошли 40 попыток
        batch_results = [s or "" for s in batch_results]
        all_results.extend(batch_results)

    return all_results


common_gen_df['nuc_gen_result'] = test_free_gen_nucleus(data)

is_ok = []
for _, row in common_gen_df.iterrows():
    concepts = row['concepts']
    is_ok.append(is_include_all_words(row['nuc_gen_result'], concepts))

common_gen_df['is_ok_nuc_gen'] = is_ok

print(f'Total time (nucleus): {get_profile()}')
print('ACC (nucles):', common_gen_df['is_ok_nuc_gen'].mean())

100%|██████████| 503/503 [2:02:36<00:00, 14.62s/it]  

Total time (nucleus): [('test_free_gen_nucleus', 1, 7356091.35289496)]
ACC (nucles): 0.226231956197113





K==100, p=0

In [8]:
results = []
data = list(batches(common_gen_df['concepts'].map(lambda sl: ', '.join(sl)), [], 
                             "write a sentence with the following words: [ {} ]. Don't change words, tenses and include all words in the sentence."))

@time_avg
@torch.inference_mode()
def test_free_gen_nucleus(
    data_loader,
    max_tries: int = 100,
    top_p: float = 1.0,
    temperature: float = None,
    max_new_tokens: int | None = None,
):
    all_results = []

    for batch, concepts in tqdm(data_loader):
        # 👉 превратим Series в список с 0-based позициями
        concepts = list(concepts)               # или concepts.reset_index(drop=True)

        batch_size     = batch.input_ids.size(0)
        unfinished     = torch.ones(batch_size, dtype=torch.bool, device=batch.input_ids.device)
        batch_results  = [None] * batch_size

        for _ in range(max_tries):
            if not unfinished.any():
                break

            idx = unfinished.nonzero(as_tuple=False).flatten()

            gen = model.generate(
                batch.input_ids[idx],
                attention_mask=batch.attention_mask[idx],
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                pad_token_id=tok.pad_token_id,
            )

            decoded = tok.batch_decode(gen, skip_special_tokens=True)

            for i, text in zip(idx.tolist(), decoded):
                text = text.split('</think>')[1].strip() if '</think>' in text else text
                if is_include_all_words(text, concepts[i]):
                    batch_results[i] = text
                    unfinished[i]   = False

        # те, что не прошли 40 попыток
        batch_results = [s or "" for s in batch_results]
        all_results.extend(batch_results)

    return all_results


common_gen_df['nuc_gen_result'] = test_free_gen_nucleus(data)

is_ok = []
for _, row in common_gen_df.iterrows():
    concepts = row['concepts']
    is_ok.append(is_include_all_words(row['nuc_gen_result'], concepts))

common_gen_df['is_ok_nuc_gen'] = is_ok

print(f'Total time (nucleus): {get_profile()}')
print('ACC (nucles):', common_gen_df['is_ok_nuc_gen'].mean())

100%|██████████| 503/503 [4:00:44<00:00, 28.72s/it]  

Total time (nucleus): [('test_free_gen_nucleus', 1, 14444126.722427027)]
ACC (nucles): 0.35714285714285715





K=1000

In [None]:
results = []
data = list(batches(common_gen_df['concepts'].map(lambda sl: ', '.join(sl)), [], 
                             "write a sentence with the following words: [ {} ]. Don't change words, tenses and include all words in the sentence."))

@time_avg
@torch.inference_mode()
def test_free_gen_nucleus(
    data_loader,
    max_tries: int = 1000,
    top_p: float = 1.0,
    temperature: float = None,
    max_new_tokens: int | None = None,
):
    all_results = []

    for batch, concepts in tqdm(data_loader):
        # 👉 превратим Series в список с 0-based позициями
        concepts = list(concepts)               # или concepts.reset_index(drop=True)

        batch_size     = batch.input_ids.size(0)
        unfinished     = torch.ones(batch_size, dtype=torch.bool, device=batch.input_ids.device)
        batch_results  = [None] * batch_size

        for _ in range(max_tries):
            if not unfinished.any():
                break

            idx = unfinished.nonzero(as_tuple=False).flatten()

            gen = model.generate(
                batch.input_ids[idx],
                attention_mask=batch.attention_mask[idx],
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                pad_token_id=tok.pad_token_id,
            )

            decoded = tok.batch_decode(gen, skip_special_tokens=True)

            for i, text in zip(idx.tolist(), decoded):
                text = text.split('</think>')[1].strip() if '</think>' in text else text
                if is_include_all_words(text, concepts[i]):
                    batch_results[i] = text
                    unfinished[i]   = False

        # те, что не прошли 40 попыток
        batch_results = [s or "" for s in batch_results]
        all_results.extend(batch_results)

    return all_results


common_gen_df['nuc_gen_result'] = test_free_gen_nucleus(data)

is_ok = []
for _, row in common_gen_df.iterrows():
    concepts = row['concepts']
    is_ok.append(is_include_all_words(row['nuc_gen_result'], concepts))

common_gen_df['is_ok_nuc_gen'] = is_ok

print(f'Total time (nucleus): {get_profile()}')
print('ACC (nucles):', common_gen_df['is_ok_nuc_gen'].mean())

 15%|█▌        | 77/503 [5:17:56<16:53:53, 142.80s/it]

# Анасамбль слабых правил

## Голосование простым большинством

Правила

In [None]:
import random
from enum import Enum, auto


class Verdicts(Enum):
    ACCEPT = auto()
    REJECT = auto()
    ABSTAIN = auto()


def is_ok(text: str, word1: str, word2: str, word3: str, banned_word: str) -> bool:
    # Правила 1–3: проверка наличия обязательных слов
    def check_word(word: str, text: str) -> Verdicts:
        """Возвращает вердикт для правила, проверяющего наличие слова в тексте."""
        if word == "" or word is None or not (word in text):
            return Verdicts.ABSTAIN      # Слово не задано
        return Verdicts.ACCEPT
    
    verdicts = []
    # Проверяем первые три слова по правилам 1-3
    verdicts.append(check_word(word1, text))
    verdicts.append(check_word(word2, text))
    verdicts.append(check_word(word3, text))
    
    def random_verdict(_: str, text: str) -> Verdicts:
        rand_value = random.random()  # получаем число [0.0, 1.0)
        if rand_value < 0.3:
            return Verdicts.ACCEPT
        elif rand_value < 0.7:
            return Verdicts.REJECT
        else:
            return Verdicts.ABSTAIN
    
    verdicts.append(random_verdict("", text))
    
    def is_banned(banned_word: str, text: str) -> Verdicts:
        if banned_word is not None and banned_word != "" and banned_word in text:
            return Verdicts.REJECT
        else:
            return Verdicts.ABSTAIN
    

    verdicts.append(is_banned(banned_word, text))
    
    # Подсчёт голосов "принять" и "отказать"
    accept_count = verdicts.count(Verdicts.ACCEPT)
    reject_count = verdicts.count(Verdicts.REJECT)
 
    if reject_count > accept_count:
        return False
    return True


In [None]:
from typing import List


@time_avg
def gen_batch(model_generate, puzzles: List[str]) -> List[str]:
    """
    puzzles : список строк-судоку
    model_generate : функция-обёртка вокруг model.generate (т.к. вы
                     передаёте lambda input: model.generate(...))
    ↩︎  список строк-ответов (только генерация)
    """
    # 1) строим список промптов (один на каждый пазл)
    prompt_texts = [
        tok.apply_chat_template(
            prompt + [{"role": "user", "content": p}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False
        )
        for p in puzzles
    ]

    # 2) токенизируем как настоящий batch
    enc = tok(prompt_texts, return_tensors="pt", padding=True).to(model.device)

    # 3) вызываем генератор (он уже внутри lambda)
    with torch.no_grad():
        out_ids = model_generate(enc)          # shape: [B, L_gen]

    out_ids = out_ids.cpu()
    answers = []
    for i, prompt_len in enumerate((enc.input_ids != tok.pad_token_id).sum(1)):
        gen_part = out_ids[i, prompt_len:]     # обрезаем префикс-промпт
        answers.append(tok.decode(gen_part, skip_special_tokens=True).strip())
    return answers


# ── обёртка для удобного вызова в pandas-цикле  ─────────────────────
def calc_all(df, model):
    BATCH = 2
    results = []
    reset_profile()

    generator = lambda enc: model.generate(
        enc.input_ids,
        attention_mask=enc.attention_mask,
        top_p=0.9,
        temperature=0.8,
        max_new_tokens=300
    )

    for i in tqdm(range(0, len(df), BATCH)):
        puzzles = df["puzzle"].iloc[i:i+BATCH].tolist()
        results.extend(gen_batch(generator, puzzles))

    return results


# ── запуск  ─────────────────────────────────────────────────────────
test["free_gen_solution"] = calc_all(test, model)
print_profile()
test[test["solution"] == test["free_gen_solution"]]


100%|██████████| 5/5 [05:40<00:00, 68.06s/it]

function                                             calls     avg ms
gen_batch                                                5  68057.332





Unnamed: 0,puzzle,solution,cls_gen_solution,free_gen_solution


In [None]:
import re, torch, torch.nn.functional as F
from enum   import Enum, auto
from typing import List, Dict, Tuple, Callable

# ---- Verdict enum -------------------------------------------------
class Verdict(Enum):
    REJECT   = auto()
    CONTINUE = auto()
    ACCEPT   = auto()

_puzzle_re = re.compile(r"[0-9]{81}")   # ищем строку-пазл
_digit_re  = re.compile(r"[0-9]")       # одиночные цифры

def sudoku_rule(text: str, history: List[Dict[str, str]]) -> Verdict:
    """
    • text    – текущий вывод LLM (префикс).
    • history – диалог; в последнем user-сообщении ищем 81-символьный пазл.

    Правило:
    1. В text (после strip()) должны быть **только цифры 0-9**.  
       Любой другой символ – сразу REJECT.
    2. Кол-во цифр ≤ 81.  >81 → REJECT.
    3. Заданные цифры из puzzle неизменны.
    4. Нет дубликатов в строках, столбцах, блоках.
    5. Если все 81 ячейка заполнена и п. 4 выполнен → ACCEPT,
       иначе (цифр <81) → CONTINUE.
    """
    puzzle = None
    for msg in reversed(history):
        if msg["role"] == "user":
            hit = _puzzle_re.findall(msg.get("content", ""))
            if hit:
                puzzle = hit[-1]
                break
    if puzzle is None:
        return Verdict.REJECT

    stripped = text.strip()
    if not stripped.isdigit():          # любой не-digit → reject
        return Verdict.REJECT

    digits = list(stripped)
    if len(digits) > 81:                # длиннее сетки
        return Verdict.REJECT

    for p, d in zip(puzzle, digits):
        if p != "0" and p != d:
            return Verdict.REJECT

    digits += ["0"] * (81 - len(digits))
    rows = [digits[i*9:(i+1)*9] for i in range(9)]
    cols = list(zip(*rows))
    blks = [digits[r*27+c*3:r*27+c*3+27:9] +
            digits[r*27+c*3+1:r*27+c*3+28:9] +
            digits[r*27+c*3+2:r*27+c*3+29:9]
            for r in range(3) for c in range(3)]

    for unit in rows + [list(c) for c in cols] + blks:
        nums = [x for x in unit if x != "0"]
        if len(nums) != len(set(nums)):
            return Verdict.REJECT

    return Verdict.ACCEPT if "0" not in digits else Verdict.CONTINUE


# ---- helpers ------------------------------------------------------
def _pad(batch: List[List[int]], pad_id, device):
    L = max(map(len, batch))
    t = torch.full((len(batch), L), pad_id, dtype=torch.long, device=device)
    for i, seq in enumerate(batch):
        t[i, :len(seq)] = torch.tensor(seq, device=device)
    return t

def in_think(txt: str) -> bool:
    o = txt.rfind("<think>"); c = txt.rfind("</think>")
    return o != -1 and (c == -1 or c < o)

# ---- beam-search --------------------------------------------------
def beam_generate(
    model,
    tok,
    prompts: List[str],
    *,
    beam_size: int = 16,
    max_new_tokens: int = 120,
    sub_bs: int = 2,
    rule: Callable[[str, List[Dict[str,str]]], Verdict] = sudoku_rule,
) -> List[str]:

    dev   = next(model.parameters()).device
    eos   = tok.eos_token_id
    pad   = tok.pad_token_id
    vocab = model.config.vocab_size             # ← фикс

    p_ids = [tok.encode(p, add_special_tokens=False) for p in prompts]
    batch = len(prompts)

    # histories per prompt
    histories = [[{"role":"user","content":_puzzle_re.findall(p)[-1]}]
                 for p in prompts]

    # digit mask on GPU
    mask = torch.full((vocab,), -1e9, device=dev)
    for d in "123456789":
        mask[tok.convert_tokens_to_ids(d)] = 0.0
    mask[eos] = 0.0

    def score(ids_lists: List[List[int]]) -> torch.Tensor:
        outs = []
        for i in range(0, len(ids_lists), sub_bs):
            ids = _pad(ids_lists[i:i+sub_bs], pad, dev)
            att = (ids != pad).long()
            with torch.no_grad():
                logits = model(ids, attention_mask=att,
                               logits_to_keep=1, use_cache=False).logits[:, -1]
            outs.append(torch.log_softmax(logits + mask, -1))
        return torch.cat(outs, 0)

    first_lp = score(p_ids)
    beams  = [[] for _ in range(batch)]   # (ids, txt, score)
    closed = [[] for _ in range(batch)]

    for qi, lp_row in enumerate(first_lp):
        lp, idx = torch.topk(lp_row, beam_size)
        for sc, tid in zip(lp.tolist(), idx.tolist()):
            s = tok.convert_ids_to_tokens([tid])[0]
            if tid == eos:
                if rule("", histories[qi]) is Verdict.ACCEPT:
                    closed[qi].append(([], "", sc))
            else:
                v = Verdict.CONTINUE if in_think(s) else rule(s, histories[qi])
                if v is not Verdict.REJECT:
                    beams[qi].append(([tid], s, sc))

    for _ in range(max_new_tokens - 1):
        active, back = [], []
        for qi, bl in enumerate(beams):
            for bi, (ids, txt, sc) in enumerate(bl):
                active.append(p_ids[qi] + ids)
                back.append((qi, bi))
        if not active: break

        lp_all = score(active)
        new_pool = [[] for _ in range(batch)]

        for row, (qi, bi) in enumerate(back):
            base_ids, base_txt, base_sc = beams[qi][bi]
            lp, idx = torch.topk(lp_all[row], beam_size)

            for add, tid in zip(lp.tolist(), idx.tolist()):
                s     = tok.convert_ids_to_tokens([tid])[0]
                ids2  = base_ids + [tid]
                txt2  = base_txt + s
                sc2   = base_sc + add

                if tid == eos:
                    if in_think(base_txt) or \
                       rule(base_txt, histories[qi]) is Verdict.ACCEPT:
                        closed[qi].append((base_ids, base_txt, sc2))
                    continue

                v = Verdict.CONTINUE if in_think(txt2) else rule(txt2, histories[qi])
                if v is Verdict.REJECT:
                    continue
                new_pool[qi].append((ids2, txt2, sc2))

        beams = []
        for pool in new_pool:
            pool.sort(key=lambda x: x[2], reverse=True)
            beams.append(pool[:beam_size])
        if all(not bl for bl in beams): break

    outs = []
    for qi in range(batch):
        cand = closed[qi] if closed[qi] else beams[qi]
        best_txt = max(cand, key=lambda x: x[2])[1] if cand else ""
        outs.append(best_txt)
    return outs

def calc_all(df):
    BATCH = 2
    res = []
    for i in tqdm(range(0, len(df), BATCH)):
        puzzles = df["puzzle"].iloc[i:i+BATCH].tolist()
        prompts = tok.apply_chat_template(
            [prompt + [{"role":"user","content":p}] for p in puzzles],
            tokenize=False, add_generation_prompt=True, enable_thinking=False
        )
        res.extend(beam_generate(model, tok, prompts))
    return res
    
prompt_texts = [
        tok.apply_chat_template(
            prompt + [{"role": "user", "content": p}],
            tokenize=False,
        )
        for p in puzzles
    ]

    # 2) токенизируем как настоящий batch
    enc = tok(prompt_texts, return_tensors="pt", padding=True).to(model.device)_gen_solution"] = calc_all(test)

100%|██████████| 5/5 [1:06:41<00:00, 800.29s/it]


In [None]:
EOS_IDS   = {tok.eos_token_id, tok.convert_tokens_to_ids("<|im_end|>"),tok.convert_tokens_to_ids("<|endoftext|>") } 

import heapq, itertools, torch
from typing import Generator, List, Tuple, Callable
from enum import Enum, auto

def tiny_noise(ids: torch.Tensor, eps: float = 1e-7) -> float:
    h = hash(ids.numpy().tobytes()) & 0xFFFF    # 16 бит
    return eps * h / 65535.0

# ──  условие «решение» (по умолчанию: любой EOS)  ────────────────────────
def is_solution(ids: torch.Tensor) -> bool:
    return ids[-1].item() in EOS_IDS

# ──  глубина-сначала + best-first  ───────────────────────────────────────
def variants(
    prefix: torch.Tensor,
    beam_width: int = 8,
    batch_eval: int = 4,
    max_new:   int = 128,
) -> Generator[Tuple[torch.Tensor, float], None, None]:
    """
    Генератор (ids, log_prob):
      • приоритет (-len, score)  →  всегда раскрываем самую длинную ветку;
      • visited-set предотвращает дубли;
      • поиск полон, пока max_new конечен.
    """
    base_len = prefix.size(0)                   # длина исходного префикса
    visited = {hash(prefix.cpu().numpy().tobytes())}

    # frontier на CPU: (-len, score+noise, ids_CPU)
    frontier: List[Tuple[int, float, torch.Tensor]] = [
        (-base_len, 0.0, prefix.cpu())
    ]
    heapq.heapify(frontier)

    while frontier:
        # ── 1. формируем батч самых длинных гипотез ──
        batch = [heapq.heappop(frontier)
                 for _ in range(min(batch_eval, len(frontier)))]

        # распаковываем
        seqs_cpu, scores = [], []
        for neg_len, score, ids in batch:
            if is_solution(ids) or (ids.size(0) - base_len) >= max_new:
                yield ids, -score
            else:
                seqs_cpu.append(ids)
                scores.append(score)

        if not seqs_cpu:         # весь батч уже дал решения
            continue

        # ── 2. один forward на GPU  ─────────────────────────────
        seqs_gpu = [s.to(model.device, non_blocking=True) for s in seqs_cpu]
        with torch.inference_mode():
            padded = torch.nn.utils.rnn.pad_sequence(
                seqs_gpu, batch_first=True, padding_value=tok.pad_token_id
            )
            logits = model(padded, use_cache=False).logits
            last = logits[
                torch.arange(len(seqs_gpu), device=logits.device),
                [len(s) - 1 for s in seqs_gpu]
            ]
        logp    = torch.log_softmax(last, -1)
        top_lp, top_tid = torch.topk(logp, beam_width, dim=-1)

        # ── 3. расширяем: greedy + альтернативы ─────────────────
        for ids_cpu, base_score, best_lp, best_tid, alt_lps, alt_tids in zip(
                seqs_cpu, scores,
                top_lp[:, 0], top_tid[:, 0],
                top_lp[:, 1:], top_tid[:, 1:]):

            # (а) альтернативы
            for lp, tid in zip(alt_lps, alt_tids):
                child = torch.cat([ids_cpu, tid.view(1).cpu()])
                key = canon_key(child)
                if key not in visited:
                    visited.add(key)
                    child_score = base_score - lp.item() + tiny_noise(child)
                    heapq.heappush(frontier,
                                   (-child.size(0), child_score, child))

            # (б) greedy-первый (рассматривается раньше из-за DFS)
            best = torch.cat([ids_cpu, best_tid.view(1).cpu()])
            key = hash(best.numpy().tobytes())
            if key not in visited:
                visited.add(key)
                best_score = base_score - best_lp.item() + tiny_noise(best)
                heapq.heappush(frontier, (-best.size(0), best_score, best))



class Verdict(Enum):
    REJECT   = auto()
    CONTINUE = auto()
    ACCEPT   = auto()

_puzzle_re = re.compile(r"[0-9]{81}")   # ищем строку-пазл
_digit_re  = re.compile(r"[0-9]")       # одиночные цифры

def sudoku_rule(text: str) -> Verdict:
    stripped = text.strip()
    if not stripped.isdigit():          # любой не-digit → reject
        return Verdict.REJECT

    digits = list(stripped)
    if len(digits) > 81:                # длиннее сетки
        return Verdict.REJECT

    digits += ["0"] * (81 - len(digits))
    rows = [digits[i*9:(i+1)*9] for i in range(9)]
    cols = list(zip(*rows))
    blks = [digits[r*27+c*3:r*27+c*3+27:9] +
            digits[r*27+c*3+1:r*27+c*3+28:9] +
            digits[r*27+c*3+2:r*27+c*3+29:9]
            for r in range(3) for c in range(3)]

    for unit in rows + [list(c) for c in cols] + blks:
        nums = [x for x in unit if x != "0"]
        if len(nums) != len(set(nums)):
            return Verdict.REJECT

    return Verdict.ACCEPT if "0" not in digits else Verdict.CONTINUE



enc = tok([tok.apply_chat_template(
            prompt + [{"role": "user", "content": p}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False
        ) for p in test['puzzle'].tolist()], return_tensors="pt", padding=True).to(model.device).input_ids

res = []
for pr, p, sol in zip(enc, test['puzzle'].tolist(), test['solution'].tolist()): 
    print('Puzzle', p)
    print('Soolution', sol)
    for seq, neg_logp in variants(pr):
        text = tok.decode(seq, skip_special_tokens=True)#.split('</think>')[1].strip()
        if '</think>' in text:
            text = text.split('</think>')[1].strip()
        else:
            print('not finished thoths', text)
            print()
        if sudoku_rule(text) == Verdict.ACCEPT:
            res.append(text)
            print("\tAccepted solution:", text)
            break
        else:
            print('\tRejected solution:', text)

Puzzle 910700000065800000078004695020401007081000209006000400003670908000108030040509176
Soolution 914756823265893714378214695529481367481367259736925481153672948697148532842539176
	Rejected solution: 912753864865324197378915624234687519751469382693172458182596734476238915549813267
	Rejected solution: 912753864365814297783269514249317685817645329634928751591476238428531976176893452
	Rejected solution: 912753864365814297783269514249317685817645329654928731136879452528431976971586342
	Rejected solution: 912753864365814297783269514249317685817645329654928731136879452528431976971586243
	Rejected solution: 9127538648653241973789156242346875197514693826931724581825967344762389155498132676
	Rejected solution: 9127538643658142977832695142493176858176453296349287515914762384285319761768934523
	Rejected solution: 912753864865324197378915624234687519751469382693172458182596734476238915549813267
	Rejected solution: 912753864365814297783269514249317685817645329634928751591476238428531976176893452
	

KeyboardInterrupt: 

In [None]:
test['solution'].iloc[0]

'914756823265893714378214695529481367481367259736925481153672948697148532842539176'

In [2]:
import datasets
import torch


df = pd.DataFrame(datasets.load_dataset('allenai/common_gen', split='test'))

prompts = [
    tok.apply_chat_template(
        [{"role": "user",
          "content": f"write a sentence with the following words: [ {', '.join(c)} ]. Don't change words, tenses and include all words in the sentence."}],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )
    for c in df["concepts"]
]

batch_size = 32          # подбери под свою видеокарту (8-32 обычно ок)
results = []

for start in tqdm(range(0, len(prompts), batch_size)):
    batch_prompts = prompts[start : start + batch_size]

    enc = tok(
        batch_prompts,
        padding=True,
        return_tensors="pt"
    ).to(model.device)                            # токены на GPU

    with torch.no_grad():                   # экономим память, убираем градиенты
        gen = model.generate(
            input_ids=enc.input_ids,
            attention_mask=enc.attention_mask,
            top_p=0.9,
            temperature=0.8,
            max_new_tokens=300
        )

    results.extend(tok.batch_decode(gen, skip_special_tokens=True))

    # подчистить, чтобы освободить VRAM
    del enc, gen
    torch.cuda.empty_cache()

for s in [s.split('</think>')[1] for s in results]:
    print(s)




100%|██████████| 47/47 [01:17<00:00,  1.66s/it]



The team practiced a new drill on the field.


The player took a shot at the goal.


I threw the frisbee to my dog, and he caught it right away.


She sat at the front of the table to eat the food.


She sat at the front of the stage, holding a guitar and pointing toward the microphone.


I used a metal tool to cut the piece of metal.


The dog walked on the sidewalk with a leash.


She stepped onto the stage, ready to perform her carefully rehearsed routine, the music filling the air with energy and rhythm.


I will use the machine to sew and demonstrate how to operate it.


I used the stove to cook food in the pan.


The player wore the jersey on the field.


The door was open, but the refrigerator was closed.


I push the lawn mower to mow the grass on the lawn.


He held the marshmallow over the fire, carefully roasting it until it was golden and sticky, then carefully stuck it on a skewer to finish cooking.


The couple sat at the table for dinner.


She put on her lipstick and 




In [3]:
df['free_gen_result'] = [s.split('</think>')[1] for s in results]

is_ok = []

def is_include_all_words(str, concepts: list) -> bool:
    return all(word in str for word in concepts)

for _, row in df.iterrows():
    ok = is_include_all_words(row['free_gen_result'], row['concepts'])
    is_ok.append(ok)

df['is_free_gen_ok'] = is_ok

df

Unnamed: 0,concept_set_idx,concepts,target,free_gen_result,is_free_gen_ok
0,0,"[team, run, drill, field]",,\n\nThe team practiced a new drill on the field.,False
1,1,"[goal, player, take, shot]",,\n\nThe player took a shot at the goal.,False
2,2,"[dog, frisbee, throw, catch]",,"\n\nI threw the frisbee to my dog, and he caug...",False
3,3,"[food, front, table, sit]",,\n\nShe sat at the front of the table to eat t...,False
4,4,"[front, sit, guitar, microphone]",,"\n\nShe sat at the front of the stage, holding...",False
...,...,...,...,...,...
1492,1492,"[worker, street, brick, stand, attempt]",,"\n\nThe worker stood on the street, trying to ...",True
1493,1493,"[pant, wear, golfer, club, jacket]",,\n\nThe golfer wore a stylish jacket and a pai...,False
1494,1494,"[mirror, gear, hold, picture, take]",,"\n\nI will take a picture of the mirror, makin...",True
1495,1495,"[tie, exercise, rope, wall, wave]",,\n\nShe tied a rope to the wall and used it to...,True


In [4]:
df['is_free_gen_ok'].mean()

np.float64(0.541750167000668)

In [None]:
import random
import heapq, itertools, torch






enc = tok(prompts, return_tensors="pt", padding=True).to(model.device).input_ids

results = []
for input_ids, c in tqdm(zip(enc, df['concepts'].tolist()), total=len(df)): 
    for i, (seq, neg_logp) in enumerate(variants(input_ids)):
        text = tok.decode(seq, skip_special_tokens=True)
        if '</think>' in text:
            text = text.split('</think>')[1].strip()
        else:
            continue
        if is_include_all_words(text, c):
            results.append(text)
            break
        if i > 40:
            results.append("")
            break
            

100%|██████████| 1497/1497 [56:56<00:00,  2.28s/it]  


In [22]:
df['contrained_gen_result'] = results

is_ok = []

for _, row in df.iterrows():
    ok = is_include_all_words(row['contrained_gen_result'], row['concepts'])
    is_ok.append(ok)

df['is_contrained_gen_ok'] = is_ok

In [23]:
df['is_contrained_gen_ok'].mean()

np.float64(0.9799599198396793)