In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F

device = "cuda:3"
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct", token = '<token>').to(device)

print(model.lm_head)

# model.lm_head = Linear(
#     in_features=model.lm_head.in_features,
#     out_features=28,
#     bias=False
# ).to(device)

# print(model.lm_head)

model.eval()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.65it/s]


Linear(in_features=4096, out_features=128256, bias=False)


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [None]:
# # Example input
text = '''You are a cryptic crossword expert. You are given a clue for a cryptic crossword. Output only the answer. 
clue:
Bend down to king in Chesterfield [6]
output: CRO'''


inputs = tokenizer(text, return_tensors="pt").to(device)

# Forward pass
with torch.no_grad():
    outputs = model(**inputs)

    
logs = outputs.logits[0][-1]
top_k = torch.topk(logs, 50)
print([tokenizer.decode(top_k.indices[i]) for i in range(len(top_k.indices))])
break

['UCH', 'OK', 'ES', 'SB', 'ON', 'WE', 'ONS', 'USE', 'UC', 'CH', 'PS', 'SSL', 'QUE', 'BE', 'COD', 'OL', 'KER', 'ST', 'ATS', 'CK', 'SI', 'SO', 'O', 'QU', 'FT', 'SER', 'B', 'SES', 'OLS', 'uch', 'YS', 'OPS', 'US', 'AT', 'PE', 'SL', 'SION', 'OME', 'AK', 'ESS', 'ISS', 'CS', 'CE', 'UCE', 'IK', 'USB', 'ONY', 'SW', 'BS', 'UK']


In [43]:
import pandas as pd
import os
import numpy as np

dir_path = "/root/Cryptic-Crosswords/Crossword Clues/test"
test_sets = []
for file in os.listdir(dir_path):
    test_set_partial = pd.read_csv(os.path.join(dir_path,file))
    test_sets.append(test_set_partial)

# Combine all three DataFrames
df = pd.read_csv("/root/Cryptic-Crosswords/Crossword Clues/test/unique_clues_test.csv")

clues = list(df["Clue"])
answers = list(df["Answer"])
lengths = list(df["Length"])

print("Test Set Size:",len(clues))

import json
data = json.load(open('OtherTestSet/naive_random.json'))
test_data = data['test']
clues = [test_data[i]['clue'] for i in range(len(test_data))]
answers = [test_data[i]['soln'] for i in range(len(test_data))]
lengths = [test_data[i]['lengths'] for i in range(len(test_data))]

Test Set Size: 16404


In [192]:
import json
data = json.load(open('OtherTestSet/naive_random.json'))
test_data = data['test']
clues = [test_data[i]['clue'] for i in range(len(test_data))]
answers = [test_data[i]['soln'] for i in range(len(test_data))]
lengths = [test_data[i]['lengths'] for i in range(len(test_data))]
print("Test Set Size:",len(clues))
print(clues[0])
print(answers[0])
print(lengths[0])

Test Set Size: 28476
Achy shaking stopped by iodine, salt and kaolin
chinaclay
[5, 4]


In [None]:
def prompt_generator(clue,length, list_of_answer = [], repeat = 0):
    prompt = f"""You are a cryptic crossword expert. You are given a clue for a cryptic crossword. Output only the answer. 
clue:
{clue} {length}
output:"""
    return prompt



In [178]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def beam_search_decoder(prompt, length_budget, model, tokenizer, initial_beam_width=50, beam_decay=0.7, max_steps=5, top_k=50):

    if len(length_budget) == 1:
        length_budget = length_budget[0]
        model.to(device)
        model.eval()

        BeamEntry = tuple[str, float]  # (generated_text, cumulative_logprob)
        beam = [("", 0.0)]
        answer_list = []

        for step in range(max_steps):
            current_beam_width = max(1, int(initial_beam_width * (beam_decay ** step)))
            new_beam = []

            for generated_text, cum_logprob in beam:
                input_text = prompt + generated_text
                input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

                with torch.no_grad():
                    outputs = model(input_ids)
                    logits = outputs.logits

                # Get top_k tokens from the last position
                probs = torch.nn.functional.log_softmax(logits[0, -1], dim=-1)
                topk_probs, topk_indices = torch.topk(probs, top_k)

                for i in range(top_k):
                    token_id = topk_indices[i].item()
                    token_logprob = topk_probs[i].item()
                    token_str = tokenizer.decode([token_id], skip_special_tokens=True)

                    if not token_str.strip().isalpha():
                        continue  # only alphanumeric tokens

                    # Check if adding this token exceeds the length budget
                    new_text = generated_text + token_str.strip()
                    if len(new_text.strip()) > length_budget:
                        continue
                    elif len(new_text.strip()) == length_budget:
                        answer_list.append((new_text, cum_logprob + token_logprob))
                    else:
                        new_beam.append((new_text, cum_logprob + token_logprob))

            # Select top-k from new_beam
            beam = sorted(new_beam, key=lambda x: x[1], reverse=True)[:current_beam_width]
            # print("Step", step, "Beam:", beam)
            if not beam:
                break

        if answer_list:
            final_answers = list(set([text.lower() for text, _ in sorted(answer_list, key=lambda x: x[1], reverse=True)]))
            final_answers.sort()
            return final_answers
        elif beam:
            return [beam[0][0]]
        else:
            return []

    else:
        model.to(device)
        model.eval()

        BeamEntry = tuple[list[str], int, float]  # ([generated_words], current_word_idx, cumulative_logprob)
        beam = [([], 0, 0.0)]
        answer_list = []

        for step in range(max_steps):
            current_beam_width = max(1, int(initial_beam_width * (beam_decay ** step)))
            new_beam = []

            for generated_words, word_idx, cum_logprob in beam:
                if word_idx >= len(length_budget):
                    continue  # Already done

                current_word = "" if word_idx >= len(generated_words) else generated_words[word_idx]
                input_text = prompt + " ".join(generated_words + ([current_word] if current_word else []))
                input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

                with torch.no_grad():
                    outputs = model(input_ids)
                    logits = outputs.logits

                probs = torch.nn.functional.log_softmax(logits[0, -1], dim=-1)
                topk_probs, topk_indices = torch.topk(probs, top_k)

                for i in range(top_k):
                    token_id = topk_indices[i].item()
                    token_logprob = topk_probs[i].item()
                    token_str = tokenizer.decode([token_id], skip_special_tokens=True).strip()

                    if not token_str.isalpha():
                        continue

                    new_word = current_word + token_str
                    if len(new_word) > length_budget[word_idx]:
                        continue

                    new_generated_words = generated_words.copy()
                    if len(new_word) == length_budget[word_idx]:
                        # Completed a word
                        if word_idx < len(generated_words):
                            new_generated_words[word_idx] = new_word
                        else:
                            new_generated_words.append(new_word)

                        if word_idx + 1 == len(length_budget):
                            # Full sequence complete
                            answer_list.append((" ".join(new_generated_words), cum_logprob + token_logprob))
                        else:
                            new_beam.append((new_generated_words, word_idx + 1, cum_logprob + token_logprob))
                    else:
                        # Still building current word
                        if word_idx < len(generated_words):
                            new_generated_words[word_idx] = new_word
                        else:
                            new_generated_words.append(new_word)
                        new_beam.append((new_generated_words, word_idx, cum_logprob + token_logprob))

            beam = sorted(new_beam, key=lambda x: x[2], reverse=True)[:current_beam_width]
            # print("Step", step, "Beam:", beam)

            if not beam:
                break

        if answer_list:
            final_answers = list(set([text.lower() for text, _ in sorted(answer_list, key=lambda x: x[1], reverse=True)]))
            final_answers.sort()
            return final_answers
        else:
            return []

In [193]:
from tqdm import tqdm
acc = 0
for idx in tqdm(range(len(clues[:1000]))):
    # idx = 7
    clue = clues[idx]
    answer = answers[idx]
    length = lengths[idx]
    # print(answer)
    # print(clue)
    # print(length)
    prompt = prompt_generator(clue,length)

    decoded = beam_search_decoder(
        prompt=prompt+" ",
        length_budget=length,
        model=model,
        tokenizer=tokenizer,
        initial_beam_width=20,
        top_k=10,
    )
    acc += answer in decoded
    # print("Generated:", decoded)
    # print(answer in decoded)
    # print(len(decoded))

    if idx%100==0:
        print(idx, acc*100/(idx+1))

  0%|          | 1/1000 [00:01<21:59,  1.32s/it]

0 0.0


 10%|█         | 101/1000 [01:55<13:34,  1.10it/s]

100 14.851485148514852


 20%|██        | 201/1000 [03:53<15:11,  1.14s/it]

200 12.437810945273633


 30%|███       | 301/1000 [05:52<13:35,  1.17s/it]

300 10.299003322259136


 40%|████      | 401/1000 [07:47<11:27,  1.15s/it]

400 10.972568578553616


 50%|█████     | 501/1000 [09:42<10:43,  1.29s/it]

500 10.578842315369261


 60%|██████    | 601/1000 [11:38<08:19,  1.25s/it]

600 10.8153078202995


 70%|███████   | 701/1000 [13:32<05:51,  1.18s/it]

700 11.126961483594865


 80%|████████  | 801/1000 [15:23<03:48,  1.15s/it]

800 10.986267166042447


 90%|█████████ | 901/1000 [17:24<02:03,  1.25s/it]

900 11.320754716981131


100%|██████████| 1000/1000 [19:19<00:00,  1.16s/it]


In [182]:
print(acc/1000)

0.305
