# Query generation

This task is about generating a list of queries to maximise recall of the specific documents cited by each original article.


In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from pprint import pprint
from joblib import Memory
import jsonlines
import os
import re
import wandb

cwd = os.getcwd()
os.sys.path.append(cwd)

from inference import ModelInference
import gym
from tqdm import tqdm
import logging
import util

logging.basicConfig(level=logging.INFO)
memory = Memory("data/cache", verbose=0)
base_mistral_model = "mistralai/Mistral-7B-Instruct-v0.2"

  from .autonotebook import tqdm as notebook_tqdm


## Generate SFT data

We can get our training off the ground by generating a set of data for supervised fine-tuning. We will generate seed samples, evaluate them, and augment them with some heuristics.


### Generate seed samples


In [3]:
max_tokens = 512
inference = ModelInference(
    model_path=base_mistral_model,
)

INFO:root:Loading model...
Loading checkpoint shards: 100%|██████████| 3/3 [00:12<00:00,  4.13s/it]


In [4]:
@memory.cache
def generate(*args, **kwargs):
    response = inference.generate_response(*args, **kwargs)
    return response

In [7]:
"""
Generate samples of query sets for literature review titles
"""

logging.info("Loading documents...")
doc_ids = gym.db.get_doc_ids_with_citations()
doc_ids = sorted(doc_ids)
logging.info(f"Loaded {len(doc_ids)} document IDs")

samples_path = "data/query_generation_samples.jsonl"

doc_ids = doc_ids[:1000]
logging.info(f"Loading {len(doc_ids)} documents...")
docs = gym.db.get_documents(doc_ids)
docs = util.dedup_list(docs, key=lambda x: x["doc_id"])
print(f"Loaded {len(docs)} documents")

samples_per_title = 1
logging.info("Generating samples...")
samples = []
for doc in tqdm(docs):
    doc_id, title = doc["doc_id"], doc["title"]
    logging.info(f"Title: {title}, Doc ID: {doc_id}")
    prompt = f"Generate a set of short keyword queries to find papers related to the following title: {title}\n\nQueries:\n\n"
    for _ in range(samples_per_title):
        # print(f"\nPrompt: {prompt}")
        response = generate(
            prompt=prompt,
            max_tokens=max_tokens,
            temperature=0.9,
        )
        response = response.split(prompt)[-1].strip()
        # print(f"Response: \n{response}")
        samples.append(
            {
                "task": "query_generation",
                "doc_id": doc_id,
                "title": title,
                "prompt": prompt,
                "response": response,
            }
        )

# Save results
with open(samples_path, "w") as f:
    writer = jsonlines.Writer(f)
    writer.write_all(samples)
    writer.close()

INFO:root:Loading documents...


INFO:root:Loaded 1684 document IDs
INFO:root:Loading 1000 documents...
INFO:root:Generating samples...


Loaded 1000 documents


  0%|          | 0/1000 [00:00<?, ?it/s]INFO:root:Title: FKBP51 and FKBP52 in signaling and disease, Doc ID: 01bd7816303809e8a5e58f2e39277f410f4202bd
INFO:root:Title: Cell Membrane Coating Nanotechnology, Doc ID: 01bea211a9470f6ab5404d189c59499427e72ee3
INFO:root:Title: Specification and epigenetic programming of the human germ line, Doc ID: 01bf9aaa2322162544b7893de8cad57640c86cf6
INFO:root:Title: From Chemical Topology to Molecular Machines (Nobel Lecture)., Doc ID: 01bff8d7631df9be02d03e097e3044aa4271f934
INFO:root:Title: The Lens of Intrinsic Skill Atoms: A Method for Gameful Design, Doc ID: 01c022a88a709eddb7eeade0899bdd3bd3c0cf00
INFO:root:Title: The Basis of Oncoimmunology, Doc ID: 01c0be9ac999ed6e313711d265c49190521e5d08
INFO:root:Title: Advances in the science and technology of carbon nanotubes and their composites: a review, Doc ID: 01c23f080f47379e915d31cd734448f1b6c7b9cc
INFO:root:Title: Hydrogen sulfide (H2S) releasing agents: chemistry and biological applications., Doc ID

### Training wheels

Use a set of heuristics to augment the query generation samples.

- Generate query variants (remove stop words, drop words)
- Remove poorly performing ones
- Alter format. E.g., remove list numbering which creates unnecessary tokens.


In [15]:
eval_queries = memory.cache(gym.evals.eval_queries)

In [16]:
def parse_query_response(text: str):
    items = [line.strip() for line in text.split("\n") if line.strip()]
    items = [re.sub(r"\d+\.", "", item) for item in items]
    items = [item.strip() for item in items]
    items = list(set(items))
    return items


def get_query_results(query, results):
    for result in results["queries"]:
        if result["query"] == query:
            return result
    return None


def filter_non_additive(queries, results):
    # Filter out queries that don't return any additional true pos
    filtered = []
    true_pos_accum = set()
    for query in queries:
        q_results = get_query_results(query, results)
        if q_results is None:
            continue
        new = set(q_results["true_pos"]) - true_pos_accum
        # print(f"Query: {query}, New: {new}")
        if new:
            filtered.append(query)
            true_pos_accum.update(new)
    return filtered


def remove_stopwords(query):
    stopwords = [
        "and",
        "or",
        "the",
        "a",
        "an",
        "of",
        "in",
        "on",
        "for",
        "to",
        "by",
        "as",
        "with",
        "from",
        "at",
        "is",
    ]
    new = " ".join([w for w in query.split() if w not in stopwords])
    return new


def token_combination_variants(query: str):
    tokens = query.split()
    variants = []
    from itertools import combinations

    for i in range(1, len(tokens) + 1):
        for combo in combinations(tokens, i):
            variants.append(" ".join(combo))
    return variants


def try_eval(doc_id, queries):
    # try:
    return eval_queries(doc_id, queries)
    # except Exception as e:
    #     logging.error(f"Error evaluating queries: {e}")
    #     return {}


def write_jsonl(items, path):
    with open(path, "w") as f:
        writer = jsonlines.Writer(f)
        writer.write_all(items)
        writer.close()
    print(f"Wrote {len(items)} items to {path}")

In [18]:
variant_expansion = True
do_filter_non_additive = True
do_remove_stopwords = True
sort_by_metric = None
max_queries = 20
use_wandb = False

if use_wandb:
    run = wandb.init(
        project="query-gen",
        config={
            "model": base_mistral_model,
            "variant_expansion": variant_expansion,
            "filter_non_additive": do_filter_non_additive,
            "remove_stopwords": do_remove_stopwords,
            "sort_by_metric": sort_by_metric,
            "max_queries": max_queries,
        },
    )


for sample in tqdm(samples):
    doc_id = sample["doc_id"]
    response = sample["response"]
    queries = parse_query_response(response)
    sample["parsed"] = queries
    queries = [q for q in queries if len(q.split()) <= 10]
    if do_remove_stopwords:
        queries = [remove_stopwords(q) for q in queries]
    if variant_expansion:
        variants = []
        for query in queries:
            variants.extend(token_combination_variants(query))
        queries = list(set(variants))
    results = try_eval(doc_id, queries)
    if sort_by_metric:
        queries = sorted(
            queries,
            key=lambda q: get_query_results(q, results)[sort_by_metric],
            reverse=True,
        )
    if do_filter_non_additive:
        queries = filter_non_additive(queries, results)
        new_results = try_eval(doc_id, queries)  # Re-evaluate with filtered queries
        # Assert that recall is not lower, because we only removed queries that didn't add any new true pos
        # print(new_results)
        # old_recall, new_recall = results["recall"], new_results["recall"]
        # print(old_recall, new_recall)
        # assert new_results["recall"] >= results["recall"]
        results = new_results
    # if max_queries:
    #     queries = queries[:max_queries]
    sample["processed_queries"] = queries
    sample["formatted_response"] = "\n".join(queries)
    sample["results"] = results


samples_with_results = [sample for sample in samples if sample.get("results")]
f1s = [sample["results"]["f1"] for sample in samples_with_results]
precisions = [sample["results"]["precision"] for sample in samples_with_results]
recalls = [sample["results"]["recall"] for sample in samples_with_results]
avg_f1 = sum(f1s) / len(f1s)
avg_precision = sum(precisions) / len(precisions)
avg_recall = sum(recalls) / len(recalls)
print(f"Average F1: {avg_f1}")
print(f"Average Precision: {avg_precision}")
print(f"Average Recall: {avg_recall}")
if use_wandb:
    wandb.log(
        {
            "f1": avg_f1,
            "precision": avg_precision,
            "recall": avg_recall,
        }
    )


write_jsonl(samples_with_results, "data/query_generation_samples.jsonl")

  0%|          | 0/1000 [00:00<?, ?it/s]

100%|██████████| 1000/1000 [1:11:07<00:00,  4.27s/it]


Average F1: 0.20095605345753564
Average Precision: 0.1835396681804817
Average Recall: 0.27203198511679
Wrote 996 items to data/query_generation_samples.jsonl
 

         452669184 function calls (393590636 primitive calls) in 4269.691 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
   423459 3025.094    0.007 3025.094    0.007 {method 'fetchall' of 'sqlite3.Cursor' objects}
   423459  525.840    0.001  525.840    0.001 {method 'execute' of 'sqlite3.Cursor' objects}
   422561  293.025    0.001  293.025    0.001 db.py:53(<listcomp>)
   422561  140.828    0.000 4020.217    0.010 search.py:7(keyword_search)
   422561   51.682    0.000 3879.389    0.009 db.py:41(document_fts)
27190625/7394   39.502    0.000  142.513    0.019 pickle.py:535(save)
   422561   23.054    0.000   23.416    0.000 evals.py:10(get_true_pos)
   423459   22.046    0.000   22.046    0.000 {built-in method _sqlite3.connect}
25853431/1798   15.498    0.000  133.364    0.074 numpy_pickle.py:322(save)
   449479   13.476    0.000   13.476    0.000 3403170591.py:9(get_query_results)
 27192423   10.402    0.000   13.279  

In [19]:
# Show highest performing samples
samples_with_results = [sample for sample in samples if sample["results"]]
accs = [sample["results"].get("recall", 0) for sample in samples_with_results]
sorted_samples = sorted(samples_with_results, key=lambda x: x["results"].get("f1", 0))
print("Highest performing samples:")
for sample in sorted_samples[-10:]:
    print(f"Title: {sample['title']}")
    print(f"Doc ID: {sample['doc_id']}")
    print(f"True pos: {sample['results']['n_true_pos']}")
    print(f"Recall: {sample['results']['recall']}")
    print(f"Precision: {sample['results']['precision']}")
    print(f"N queries: {len(sample['processed_queries'])}")
    print(f"Top queries:")
    top_queries = sorted(
        sample["results"]["queries"], key=lambda q: q["recall"], reverse=True
    )
    for query in top_queries[:5]:
        print(
            f"Query: {query['query']}, Recall: {query['recall']}, Precision: {query['precision']}"
        )
    print("\n")
# print("Lowest performing samples:")
# for sample in sorted_samples[:10]:
#     print(f"Title: {sample['title']}")
#     print(f"Doc ID: {sample['doc_id']}")
#     print(f"Response: \n{sample['formatted_response']}")
#     print(f"Results: \n{sample['filtered_results']}")
#     print("\n")

Highest performing samples:
Title: Symposium on the Monetary Transmission Mechanism
Doc ID: 0223f8670e31337559e98101249713fb4312e3fa
True pos: 410
Recall: 0.3992210321285373
Precision: 0.7706766917148369
N queries: 15
Top queries:
Query: policy transmission, Recall: 0.09737098344598472, Precision: 0.9999999999000001
Query: Monetary Transmission Mechanism, Recall: 0.09444985394260517, Precision: 0.969999999903
Query: Monetary Transmission, Recall: 0.09444985394260517, Precision: 0.969999999903
Query: Transmission Mechanism, Recall: 0.09152872443922563, Precision: 0.939999999906
Query: interest rates, Recall: 0.08666017526692639, Precision: 0.8899999999110001


Title: Spondylolysis: a critical review
Doc ID: 0468241a78bff8ce98e45f83cca29e687002914e
True pos: 100
Recall: 0.37037037035665293
Precision: 0.9900990098029605
N queries: 3
Top queries:
Query: Spondylolysis, Recall: 0.3666666666530864, Precision: 0.9899999999010001
Query: Spondylolysis review, Recall: 0.007407407407133059, Precis

## Training

### Prepare data

In [25]:
# Format results for training

for sample in samples:
    sample["formatted_response"] = "\n".join(sample["processed_queries"])

samples_path = "data/query_generation_samples.jsonl"
write_jsonl(samples, samples_path)

# Write 10 out to separate files for inspection
for i, sample in enumerate(sorted_samples[-10:]):
    text = f"{sample['prompt'].strip()}\n{sample['formatted_response']}"
    with open(f"data/query_generation_sample_{i}.md", "w") as f:
        f.write(text)

Wrote 1000 items to data/query_generation_samples.jsonl


In [3]:
from datasets import Dataset

samples_path = "data/query_generation_samples.jsonl"

def generate_training_data():
    with open(samples_path, "r") as f:
        reader = jsonlines.Reader(f)
        for sample in reader:
            yield {
                "input": sample["prompt"],
                "output": sample["formatted_response"],
            }


samples = list(generate_training_data())
print(f"Loaded {len(samples)} samples")

ds = Dataset.from_list(samples)

ds.save_to_disk("data/query_generation_train_data")
ds

Loaded 1000 samples


Saving the dataset (1/1 shards): 100%|██████████| 1000/1000 [00:00<00:00, 313710.10 examples/s]


Dataset({
    features: ['input', 'output'],
    num_rows: 1000
})

In [4]:
split = ds.train_test_split(test_size=0.1)
train_data, test_data = split["train"], split["test"]
print(f"Train size: {len(train_data)}, Test size: {len(test_data)}")

Train size: 900, Test size: 100


### Load base model

In [3]:
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from peft import (
    AutoPeftModelForCausalLM,
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from trl import SFTTrainer
from transformers import TrainingArguments

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

def load_model():
    return AutoModelForCausalLM.from_pretrained(
        base_mistral_model,
        device_map="auto",
        quantization_config=nf4_config,
        use_cache=False,
    )

def load_tokenizer():
    tokenizer = AutoTokenizer.from_pretrained(
        base_mistral_model,
        padding_side="left",
        add_eos_token=True,
        add_bos_token=True,
    )
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer



In [4]:
model = load_model()
tokenizer = load_tokenizer()

INFO:accelerate.utils.modeling:We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).
Loading checkpoint shards: 100%|██████████| 3/3 [00:13<00:00,  4.55s/it]


In [7]:
def generate_response(prompt, model):
    encoded_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    model_inputs = encoded_input.to("cuda")
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=500,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded_output = tokenizer.batch_decode(generated_ids)
    output = decoded_output[0].replace(prompt, "")
    output = output.replace(bos_token, "")
    output = output.replace(eos_token, "")
    return output


prompt = "What is the capital of France?"
response = generate_response(prompt, model)

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


KeyboardInterrupt: 

In [5]:
def generate_response(prompt, model, use_chat_template=False):
    if use_chat_template:
        encoded = tokenizer.apply_chat_template(prompt, return_tensors="pt")
    else:
        encoded = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    encoded = encoded.to("cuda")
    generated = model.generate(
        **encoded,
        max_new_tokens=500,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded = tokenizer.batch_decode(generated)
    output = decoded[0].replace(prompt, "")
    # output = output.replace(bos_token, "")
    # output = output.replace(eos_token, "")
    return output

In [6]:
messages = [
    {"role": "user", "content": "What is your favourite condiment?"},
    {"role": "assistant", "content": "Well, I'm quite partial to a good squeeze of fresh lemon juice. It adds just the right amount of zesty flavour to whatever I'm cooking up in the kitchen!"},
    {"role": "user", "content": "Do you have mayonnaise recipes?"}
]

prompt = "What is the capital of France?"
print(generate_response(prompt, model, use_chat_template=False))

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


KeyboardInterrupt: 

### Prepare training

In [7]:
def make_train_prompt(sample):
    return f"{sample['input'].strip()}\n{sample['output']}"

print(make_train_prompt(train_data[0]))

Generate a set of short keyword queries to find papers related to the following title: Dimensional Comparison Theory

Queries:
Theory dimensional comparison
comparison paper
Dimensional comparison theory
Comparison theory
Dimensional comparison


In [8]:
peft_config = LoraConfig(
    lora_alpha=64, lora_dropout=0.05, r=32, bias="none", task_type="CAUSAL_LM"
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, peft_config)

### Train

In [10]:
output_dir = "data/models/query_generation"
checkpoint_dir = "data/models/query_generation/checkpoint-2000"

In [11]:
args = TrainingArguments(
    output_dir=output_dir,
    # overwrite_output_dir=False,
    # resume_from_checkpoint=checkpoint_dir,
    # num_train_epochs=5,
    max_steps=250,
    per_device_train_batch_size=1,
    warmup_steps=0.03,
    logging_steps=50,
    save_strategy="steps",
    save_steps=1000,
    # evaluation_strategy="epoch",
    evaluation_strategy="steps",
    eval_steps=50,  # comment out this line if you want to evaluate at the end of each epoch
    learning_rate=2e-4,
    bf16=True,
    lr_scheduler_type="constant",
    report_to="wandb",
)


trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    max_seq_length=512,
    tokenizer=tokenizer,
    packing=True,
    formatting_func=make_train_prompt,
    args=args,
    train_dataset=train_data,
    eval_dataset=test_data,
    dataset_batch_size=500,
)

trainer.train()
# trainer.train(checkpoint_dir)

trainer.save_model(output_dir)

merged_model = model.merge_and_unload()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcyclecycle[0m. Use [1m`wandb login --relogin`[0m to force relogin


You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
 20%|██        | 50/250 [01:49<07:10,  2.15s/it]

{'loss': 1.521, 'learning_rate': 0.0002, 'epoch': 0.06}


                                                
 20%|██        | 50/250 [01:58<07:10,  2.15s/it]

{'eval_loss': 1.4929746389389038, 'eval_runtime': 9.8648, 'eval_samples_per_second': 10.137, 'eval_steps_per_second': 1.318, 'epoch': 0.06}


 40%|████      | 100/250 [03:46<05:22,  2.15s/it]

{'loss': 1.312, 'learning_rate': 0.0002, 'epoch': 0.11}


                                                 
 40%|████      | 100/250 [03:56<05:22,  2.15s/it]

{'eval_loss': 1.3449474573135376, 'eval_runtime': 9.858, 'eval_samples_per_second': 10.144, 'eval_steps_per_second': 1.319, 'epoch': 0.11}


 60%|██████    | 150/250 [05:43<03:34,  2.15s/it]

{'loss': 1.2293, 'learning_rate': 0.0002, 'epoch': 0.17}


                                                 
 60%|██████    | 150/250 [05:53<03:34,  2.15s/it]

{'eval_loss': 1.353052020072937, 'eval_runtime': 9.8586, 'eval_samples_per_second': 10.143, 'eval_steps_per_second': 1.319, 'epoch': 0.17}


 80%|████████  | 200/250 [07:41<01:47,  2.15s/it]

{'loss': 1.2271, 'learning_rate': 0.0002, 'epoch': 0.22}


                                                 
 80%|████████  | 200/250 [07:51<01:47,  2.15s/it]

{'eval_loss': 1.3494257926940918, 'eval_runtime': 9.8578, 'eval_samples_per_second': 10.144, 'eval_steps_per_second': 1.319, 'epoch': 0.22}


100%|██████████| 250/250 [09:38<00:00,  2.15s/it]

{'loss': 1.2125, 'learning_rate': 0.0002, 'epoch': 0.28}


                                                 
100%|██████████| 250/250 [09:48<00:00,  2.35s/it]


{'eval_loss': 1.3381807804107666, 'eval_runtime': 9.8578, 'eval_samples_per_second': 10.144, 'eval_steps_per_second': 1.319, 'epoch': 0.28}
{'train_runtime': 589.7803, 'train_samples_per_second': 0.424, 'train_steps_per_second': 0.424, 'train_loss': 1.300392562866211, 'epoch': 0.28}


In [21]:
def generate_response(prompt, model):
    encoded_input = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
    model_inputs = encoded_input.to("cuda")
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=500,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    decoded_output = tokenizer.batch_decode(generated_ids)
    output = decoded_output[0].replace(prompt, "")
    # output = output.replace(bos_token, "")
    # output = output.replace(eos_token, "")
    return output

In [9]:
prompt = "What is the capital of France?"
# prompt = test_data[1]["input"]
# prompt

print(generate_response(prompt, model))

# model.eval()
# with torch.no_grad():
    # print(tokenizer.decode(model.generate(**model_input, max_new_tokens=512)[0], skip_special_tokens=True))

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


NameError: name 'bos_token' is not defined

### Load and test model

In [11]:
# N.B kill and restart the kernel before running this cell to clear existing model from memory before loading the merged model

from peft import PeftModel

model, tokenizer = load_model(), load_tokenizer()
ft_model = PeftModel.from_pretrained(model, checkpoint_dir)

In [16]:
prompt = test_data[1]["input"]
prompt

'Generate a set of short keyword queries to find papers related to the following title: The shift in stages of the nutrition transition in the developing world differs from past experiences!\n\nQueries:\n\n'

In [13]:
prompt = "What is the capital of France?"
# model_input = tokenizer(prompt, return_tensors="pt").to("cuda")

ft_model.eval()
with torch.no_grad():
    print(tokenizer.decode(ft_model.generate(**model_input, max_new_tokens=512)[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


What is the capital of France?


In [None]:
model = accelerator.prepare_model(model)