# Loading card data

In [3]:
import json, jsonschema
import time
import tqdm
import numpy as np
import openai
import asyncio
from cards import CardDataset

In [5]:
cards = CardDataset.from_file("oracle-cards-20250405210637.json")

INFO 2025-05-01 23:59:01 [cards:22] Double-faced cards not yet supported, omitting.
INFO 2025-05-01 23:59:01 [cards:132] Imported 30918 of 34458 cards from oracle-cards-20250405210637.json


# [Optional][WIP] Use an LLM to expand the card descriptions

In [None]:
def prepare_postprocessing_prompts(card_descriptions):
    prompts = []
    for text in card_descriptions:
        prompt = f"""
## ROLE
You are an expert Magic: the Gathering rules analyst.

## TASK
Generate compact, retrieval-oriented annotations for the card below.  
Return ONLY the JSON object described in *Output schema* (inside a ```json block).  
Do **not** repeat the card's rules text or name.

## INPUT
{text}
""" + """
## OUTPUT SCHEMA
```json
{
  "mechanics": ["<up to 7 MTG keywords or shorthand different from the card description, e.g. \"ETB\", \"dies trigger\", \"lifegain\" >"],
  "roles": ["<card roles: ramp, removal, finisher, toolbox, etc.>"],
  "strategies": ["<decks or archetypes it fits: aristocrats, blink, etc.>"],
  "synergies": ["<key tribes, card types, or mechanics it combines with>"],
  "power_band":"<one of: weak, fair, strong, broken>",
}
```
"""
        prompts.append(prompt.strip())
    return prompts

postprocessing_prompts = prepare_postprocessing_prompts(formatted_cards)

## Using Qwen locally with vLLM

In [8]:
from vllm import LLM, SamplingParams

sampling_params = SamplingParams(temperature=0.3, top_p=0.95, max_tokens=512)
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct", quantization="fp8")
outputs = llm.chat(messages=[{"role": "user", "content": prompt} for prompt in postprocessing_prompts], sampling_params=sampling_params)

  from .autonotebook import tqdm as notebook_tqdm


INFO 04-23 15:51:07 [__init__.py:239] Automatically detected platform cuda.


2025-04-23 15:51:07,749	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


INFO 04-23 15:51:11 [config.py:689] This model supports multiple tasks: {'score', 'embed', 'generate', 'classify', 'reward'}. Defaulting to 'generate'.
INFO 04-23 15:51:11 [config.py:1901] Chunked prefill is enabled with max_num_batched_tokens=8192.
INFO 04-23 15:51:12 [core.py:61] Initializing a V1 LLM engine (v0.8.4) with config: model='Qwen/Qwen2.5-7B-Instruct', speculative_config=None, tokenizer='Qwen/Qwen2.5-7B-Instruct', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=32768, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=fp8, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoi

Loading safetensors checkpoint shards:   0% Completed | 0/4 [00:00<?, ?it/s]
Loading safetensors checkpoint shards:  25% Completed | 1/4 [00:00<00:01,  2.85it/s]
Loading safetensors checkpoint shards:  50% Completed | 2/4 [00:00<00:00,  2.78it/s]
Loading safetensors checkpoint shards:  75% Completed | 3/4 [00:01<00:00,  2.66it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:01<00:00,  2.63it/s]
Loading safetensors checkpoint shards: 100% Completed | 4/4 [00:01<00:00,  2.67it/s]



INFO 04-23 15:51:15 [loader.py:458] Loading weights took 1.54 seconds
INFO 04-23 15:51:15 [gpu_model_runner.py:1291] Model loading took 8.1372 GiB and 1.889585 seconds
INFO 04-23 15:51:19 [backends.py:416] Using cache directory: /home/benchislett/.cache/vllm/torch_compile_cache/fdfad17ec6/rank_0_0 for vLLM's torch.compile
INFO 04-23 15:51:19 [backends.py:426] Dynamo bytecode transform time: 3.94 s
INFO 04-23 15:51:21 [backends.py:132] Cache the graph of shape None for later use
INFO 04-23 15:51:32 [backends.py:144] Compiling a graph for general shape takes 12.07 s
INFO 04-23 15:51:43 [monitor.py:33] torch.compile takes 16.00 s in total
INFO 04-23 15:51:44 [kv_cache_utils.py:634] GPU KV cache size: 50,512 tokens
INFO 04-23 15:51:44 [kv_cache_utils.py:637] Maximum concurrency for 32,768 tokens per request: 1.54x
INFO 04-23 15:52:21 [gpu_model_runner.py:1626] Graph capturing finished in 37 secs, took 0.49 GiB
INFO 04-23 15:52:21 [core.py:163] init engine (profile, create kv cache, warmup 

Processed prompts: 100%|██████████| 1/1 [00:02<00:00,  2.56s/it, est. speed input: 4342.67 toks/s, output: 27.37 toks/s]


In [16]:
outputs = llm.chat(messages=[[{"role": "user", "content": prompt}] for prompt in postprocessing_prompts], sampling_params=sampling_params)

Processed prompts: 100%|██████████| 37/37 [00:02<00:00, 13.17it/s, est. speed input: 4260.25 toks/s, output: 1113.38 toks/s]


In [36]:
def parse_output(model_output):
    """Parse the model output into a JSON object."""
    try:
        model_output = model_output.strip("```").strip("json").strip()
        return json.loads(model_output)
    except json.JSONDecodeError as e:
        print(f"Failed to parse JSON: {e}")
        return None

print(formatted_cards[2])
print('\n\n\n')
parse_output(outputs[2].outputs[0].text)

The following is a card from the game Magic: The Gathering.

Name: Daze
Mana cost: {1}{U}
Converted mana cost: 2.0
Type Line: Instant
Oracle text: You may return an Island you control to its owner's hand rather than pay this spell's mana cost.
Counter target spell unless its controller pays {1}.
Power: None
Toughness: None
Loyalty: None






{'mechanics': ['Instant', 'Counter', 'Return'],
 'roles': ['Removal', 'Ramp'],
 'strategies': ['Control', 'Blue Weenie'],
 'synergies': ['Counterspells', 'Islands', 'Blue Spells'],
 'power_band': 'medium',
 'why_pick': 'Flexible removal and mana ramp in a single card.'}

In [45]:
print(postprocessing_prompts[2])

## ROLE
You are an expert Magic: the Gathering rules analyst.

## TASK
Generate compact, retrieval-oriented annotations for the card below.  
Return ONLY the JSON object described in *Output schema* (inside a ```json block).  
Do **not** repeat the card's rules text or name.

## INPUT
The following is a card from the game Magic: The Gathering.

Name: Daze
Mana cost: {1}{U}
Converted mana cost: 2.0
Type Line: Instant
Oracle text: You may return an Island you control to its owner's hand rather than pay this spell's mana cost.
Counter target spell unless its controller pays {1}.
Power: None
Toughness: None
Loyalty: None

## OUTPUT SCHEMA
```json
{
  "mechanics": ["<up to 7 MTG keywords or shorthand, e.g. "ETB", "dies trigger", "lifegain" >"],
  "roles":    ["<card roles: ramp, removal, finisher, toolbox, etc.>"],
  "strategies":["<decks or archetypes it fits: aristocrats, blink, etc.>"],
  "synergies":["<key tribes, card types, or mechanics it combines with>"],
  "power_band":"<one of: lo

# Embed the card descriptions

In [13]:
len(formatted_cards)

31095

In [19]:
embedding_client = openai.Client(base_url="http://localhost:30000/v1", api_key="None")

In [None]:
all_embeddings = []

batch_size = 100
for i in tqdm.tqdm(range(0, len(formatted_cards), batch_size)):
    batch = formatted_cards[i:i + batch_size]
    response = embedding_client.embeddings.create(
        model="Alibaba-NLP/gte-Qwen2-7B-instruct",
        input=batch,
        user="user"
    )
    all_embeddings.extend(map(lambda s: s.embedding, response.data))

len(all_embeddings), len(all_embeddings[0])

100%|██████████| 311/311 [23:46<00:00,  4.59s/it]


(31095, 3584)

In [40]:
all_embeddings_np = np.array(all_embeddings)
all_embeddings_np.shape, all_embeddings_np.dtype

((31095, 3584), dtype('float64'))

In [41]:
np.save("formatted_cards_embeddings.npy", all_embeddings_np)

In [12]:
all_embeddings = np.load("formatted_cards_embeddings.npy")

In [13]:
import textwrap
import chromadb
import pandas as pd
from IPython.display import Markdown
from chromadb import Documents, EmbeddingFunction, Embeddings

In [14]:
chroma_client = chromadb.Client()
db = chroma_client.create_collection(
    "MTGCardsDatabase",
)

In [17]:
assert len(all_embeddings) == len(formatted_cards)
batch_size = 100
for batch in tqdm.tqdm(range(0, len(formatted_cards), batch_size)):
    batch_documents = formatted_cards[batch:batch + batch_size]
    batch_embeddings = all_embeddings[batch:batch + batch_size]
    batch_ids = sample_cards[batch:batch + batch_size]
    db.add(
        documents=batch_documents,
        embeddings=batch_embeddings,
        ids=batch_ids,
    )
# db.add(
#     documents=formatted_cards,
#     embeddings=all_embeddings,
#     ids=sample_cards,
# )

  0%|          | 0/311 [00:00<?, ?it/s]

100%|██████████| 311/311 [00:25<00:00, 12.12it/s]


In [40]:
test_query = "colours Black, Green elf draw a card"
query_embedding = response = embedding_client.embeddings.create(
    model="Alibaba-NLP/gte-Qwen2-7B-instruct",
    input=test_query,
    user="user"
)

In [41]:
query_embedding = np.array(query_embedding.data[0].embedding)
query_embedding

array([-0.01178741,  0.01060486,  0.01307678, ..., -0.02017212,
       -0.01441193, -0.00382042], shape=(3584,))

In [50]:
res = db.query(
    query_embeddings=[query_embedding],
    n_results=20,
)

In [51]:
res["documents"]

[['The following is a card from the game Magic: The Gathering.\n\nName: Young Necromancer\nMana cost: {4}{B}\nConverted mana cost: 5.0\nType Line: <type> Creature — Human Warlock </type>\nColors: <colors> Black </colors>\nOracle text: <oracle_text> When this creature enters, you may exile two cards from your graveyard. When you do, return target creature card from your graveyard to the battlefield. </oracle_text>\nPower: 2\nToughness: 3',
  'The following is a card from the game Magic: The Gathering.\n\nName: Skemfar Shadowsage\nMana cost: {3}{B}\nConverted mana cost: 4.0\nType Line: <type> Creature — Elf Cleric </type>\nColors: <colors> Black </colors>\nOracle text: <oracle_text> When this creature enters, choose one —\n• Each opponent loses X life, where X is the greatest number of creatures you control that have a creature type in common.\n• You gain X life, where X is the greatest number of creatures you control that have a creature type in common. </oracle_text>\nPower: 2\nToughne

In [52]:
# format the documents with the query for re-ranking
documents_string = ""
for i, doc in enumerate(res["documents"][0]):
    documents_string += f"[{i + 1}]: {doc}\n"

rerank_prompt = f"""
You are a language model responsible for re-ranking search findings in an application for finding Magic: The Gathering cards.
Your purpose is to rank documents based on their relevance to the user's query. Consider the query details, the content of the card descriptions, and the context of the game.
The following is a user query: <query> {test_query} </query>. 

Please output only a JSON object with the schema described below. Do not include any other text or explanations. Ensure that the output is valid JSON, and that all documents are included in the ranking exactly once.

Here is an example of the output format assuming there are 10 input documents and they are already in the correct order:
```json
{"{"}
  "ranking": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
{"}"}
```

Here are the documents:\n\n""" + documents_string

llm_client = openai.Client(base_url="http://localhost:30001/v1", api_key="None")
rerank_response = llm_client.chat.completions.create(
    model="google/gemma-3-27b-it",
    messages=[{"role": "user", "content": rerank_prompt}],
)


In [53]:
print(rerank_response)

ChatCompletion(id='8c579c06394545f48325f1bb4a17e7f6', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Okay, I understand. You want me to re-rank Magic: The Gathering card results based on a user\'s query. I will act as the re-ranking engine, focusing on relevance to the query, card content, and game context.\n\nHowever, you\'ve provided a very strange and fragmented set of "documents" and a "query." It looks like a lot of text has been chopped up and concatenated. I\'m going to assume you want me to still demonstrate how I would work, so I\'ll treat each of the sections delimited by "text" and "Colors" as a potential card or document. I will also try to infer the user query based on the overall content.\n\n**Inferred User Query:** "Find me cards that interact with or are related to \'life,\' \'sacrifice,\' \'control,\' \'draw,\' \'the,\' \'a,\' \'and,\' \'or,\' \'the,\' \'to,\' \'with,\' and \'this,\' as these words appear repeatedly

In [54]:
def extract_rerank_response(rerank_response):
    """Extract the rerank response from the model output."""
    try:
        rerank_response = rerank_response.choices[0].message.content
        start_index = rerank_response.index("```json") + len("```json") + 1
        # end index is at the ``` after the JSON
        end_index = start_index + rerank_response[start_index:].index("```")
        data = json.loads(rerank_response[start_index:end_index])
        ranking = data["ranking"]
        if not isinstance(ranking, list):
            raise ValueError("Ranking is not a list")
        # if len(ranking) != len(res["documents"][0]):
        #     raise ValueError("Ranking length does not match number of documents")
        # if not len(set(ranking)) == len(ranking):
        #     raise ValueError("Ranking contains duplicates")
        return [i - 1 for i in ranking]  # convert to 0-indexed
    except ValueError as e:
        print(f"Failed to extract JSON: {e}")
        return None

print(extract_rerank_response(rerank_response))

Failed to extract JSON: substring not found
None


In [55]:
[res["documents"][0][i] for i in extract_rerank_response(rerank_response)]

Failed to extract JSON: substring not found


TypeError: 'NoneType' object is not iterable

# [WIP] ColBERT Embeddings

In [9]:
from ragatouille import RAGPretrainedModel
RAG = RAGPretrainedModel.from_pretrained("colbert-ir/colbertv2.0")

  from .autonotebook import tqdm as notebook_tqdm
  self.scaler = torch.cuda.amp.GradScaler()


In [14]:
documents = formatted_cards
document_ids = sample_cards
index_path = RAG.index(
    index_name="mtg_cards",
    collection=documents,
    document_ids=document_ids,
    split_documents=False
)

This is a behaviour change from RAGatouille 0.8.0 onwards.
This works fine for most users and smallish datasets, but can be considerably slower than FAISS and could cause worse results in some situations.
If you're confident with FAISS working on your machine, pass use_faiss=True to revert to the FAISS-using behaviour.
--------------------


[Apr 26, 21:56:09] #> Creating directory .ragatouille/colbert/indexes/mtg_cards 


[Apr 26, 21:56:10] [0] 		 #> Encoding 30907 passages..


  self.scaler = torch.cuda.amp.GradScaler()
  return torch.cuda.amp.autocast() if self.activated else NullContextManager()


[Apr 26, 21:56:33] [0] 		 avg_doclen_est = 70.08810424804688 	 len(local_sample) = 30,907
[Apr 26, 21:56:33] [0] 		 Creating 16,384 partitions.
[Apr 26, 21:56:33] [0] 		 *Estimated* 2,179,389 embeddings.
[Apr 26, 21:56:33] [0] 		 #> Saving the indexing plan to .ragatouille/colbert/indexes/mtg_cards/plan.json ..
used 20 iterations (0.904s) to cluster 2116213 items into 16384 clusters
[Apr 26, 21:56:34] Loading decompress_residuals_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


[Apr 26, 21:57:04] Loading packbits_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)...


If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].


[0.02, 0.019, 0.021, 0.019, 0.018, 0.02, 0.02, 0.018, 0.019, 0.018, 0.018, 0.02, 0.019, 0.02, 0.019, 0.019, 0.017, 0.019, 0.019, 0.019, 0.019, 0.018, 0.019, 0.02, 0.019, 0.018, 0.018, 0.02, 0.02, 0.019, 0.019, 0.02, 0.019, 0.018, 0.019, 0.019, 0.021, 0.018, 0.02, 0.022, 0.02, 0.019, 0.02, 0.019, 0.019, 0.019, 0.018, 0.021, 0.022, 0.017, 0.019, 0.018, 0.018, 0.02, 0.019, 0.019, 0.019, 0.018, 0.022, 0.018, 0.02, 0.018, 0.019, 0.022, 0.02, 0.022, 0.02, 0.02, 0.018, 0.019, 0.02, 0.02, 0.018, 0.018, 0.018, 0.019, 0.019, 0.019, 0.018, 0.021, 0.02, 0.02, 0.019, 0.019, 0.018, 0.02, 0.019, 0.019, 0.018, 0.021, 0.02, 0.02, 0.019, 0.02, 0.019, 0.019, 0.02, 0.019, 0.021, 0.018, 0.018, 0.018, 0.018, 0.019, 0.02, 0.018, 0.019, 0.019, 0.021, 0.019, 0.021, 0.019, 0.019, 0.018, 0.019, 0.019, 0.02, 0.019, 0.018, 0.019, 0.019, 0.019, 0.02, 0.02, 0.02, 0.02, 0.019, 0.019]


0it [00:00, ?it/s]

[Apr 26, 21:57:35] [0] 		 #> Encoding 25000 passages..


  return torch.cuda.amp.autocast() if self.activated else NullContextManager()
1it [00:20, 20.57s/it]

[Apr 26, 21:57:56] [0] 		 #> Encoding 6095 passages..


2it [00:25, 12.54s/it]
100%|██████████| 2/2 [00:00<00:00, 608.71it/s]

[Apr 26, 21:58:01] #> Optimizing IVF to store map from centroids to list of pids..
[Apr 26, 21:58:01] #> Building the emb2pid mapping..
[Apr 26, 21:58:01] len(emb2pid) = 2179473



100%|██████████| 16384/16384 [00:00<00:00, 257999.57it/s]

[Apr 26, 21:58:01] #> Saved optimized IVF to .ragatouille/colbert/indexes/mtg_cards/ivf.pid.pt





Done indexing!


In [18]:
query = "golgari elves that draw cards"
RAG.search(query, k=400)

  return torch.cuda.amp.autocast() if self.activated else NullContextManager()


[{'content': 'The following is a card from the game Magic: The Gathering.\n\nGolgari Grave-Troll.\nCreature — Troll Skeleton.\nMana cost: {4}{G}\nConverted mana cost: 5.0\nColors: Green.\nOracle text: This creature enters with a +1/+1 counter on it for each creature card in your graveyard.\n{1}, Remove a +1/+1 counter from this creature: Regenerate this creature.\nDredge 6 (If you would draw a card, you may mill six cards instead. If you do, return this card from your graveyard to your hand.)\nPower: 0\nToughness: 0\nKeywords: Dredge, Mill',
  'score': 21.0625,
  'rank': 1,
  'document_id': 'Golgari Grave-Troll',
  'passage_id': 12639},
 {'content': 'The following is a card from the game Magic: The Gathering.\n\nGolgari Cluestone.\nArtifact.\nMana cost: {3}\nConverted mana cost: 3.0\nColors: Colorless.\nOracle text: {T}: Add {B} or {G}.\n{B}{G}, {T}, Sacrifice this artifact: Draw a card.',
  'score': 20.578125,
  'rank': 2,
  'document_id': 'Golgari Cluestone',
  'passage_id': 28282},


# Create a synthetic dataset for fine-tuning ColBERTv2

In [10]:
def prepare_synthetic_querygen_prompt(formatted_card: str) -> str:
    prompt = """You are an expert AI assisting in creating a high-quality, diverse synthetic dataset to train Information Retrieval systems for Magic: The Gathering cards.
Your task is to analyse the card description below and generate a set of rich, high-quality potential queries for which the given card would rank very highly.

The output queries should be diverse, covering various aspects of the card, including its mechanics, roles, strategies, and synergies. The queries should be in natural language and should not be too similar to each other.
They should feature a variety of keywords and phrases that a user might use when searching for cards like this one, including specific jargons or slang used in the Magic: The Gathering community. Avoid naming the card directly or repeating its exact text.
You should submit about 5-10 brief queries. Include at least two queries that are short collections of keywords or phrases, and at least two queries that are full sentences.
Your output should be a JSON object with the same schema as the following example:
<schema>
{
    "hypothetical_queries": ["<query1>", "<query2>", "<query3>", "<query4>", "<query5>", "<query6>"]
}
</schema>""".strip() + f"""
<input>
{formatted_card}
</input>
"""
    return prompt.strip()
    
def postprocess_synthetic_querygen_response(model_output: str) -> list[str]:
    try:
        model_output = model_output.strip().strip("```").strip().strip("json").strip()
        data = json.loads(model_output)
        schema = { "type": "object", "properties": { "hypothetical_queries": { "type": "array", "items": { "type": "string" } } }, "required": ["hypothetical_queries"] }
        jsonschema.validate(data, schema)
        return data["hypothetical_queries"]
    except json.JSONDecodeError as e:
        # print(f"Failed to parse JSON: {e}")
        return None
    except jsonschema.ValidationError as e:
        # print(f"Validation error: {e}")
        return None

synthetic_querygen_prompts = [(card_name, prepare_synthetic_querygen_prompt(card)) for card_name, card in cards.formatted_cards.items()]

# TODO: Evaluate the model against the base model on a test subset of the card/query pairs

In [18]:
len(cards.card_data)

30497

In [None]:
import os


OPENROUTER_API_KEY = os.environ.get("OPENROUTER_API_KEY")

In [14]:
large_llm_client = openai.AsyncClient(base_url="https://openrouter.ai/api/v1", api_key=OPENROUTER_API_KEY)

In [19]:
initial_sample_prompt_idxs = [i for i in range(len(synthetic_querygen_prompts))]
initial_sample_prompt_pool = [synthetic_querygen_prompts[i] for i in initial_sample_prompt_idxs]
print(len(initial_sample_prompt_pool))

30497


In [36]:
# Load the previous batch from the output file
queries_data = {}
with open("synthetic_queries.jsonl", "r") as f:
    for line in f:
        data: dict = json.loads(line)
        if not data.get("queries", []) or data.get("id") is None:
            continue
        queries_data[data["card_name"]] = data["queries"]
print(f"Loaded {len(queries_data)} queries from file")

Loaded 24876 queries from file


In [24]:
queries_data

{'Nissa, Worldsoul Speaker': ['Green landfall cards that generate energy',
  'Elf Druid commander with energy synergy',
  'What cards let me cast permanents for energy instead of mana?',
  'Best energy counter generators in MTG',
  'Landfall, energy, green, elf',
  'How to build around a green energy commander?',
  'MTG cards that reduce permanent spell costs with energy',
  'High synergy green landfall commanders',
  '3/3 green creature landfall energy',
  'Pay energy instead of mana for spells'],
 'Static Orb': ['artifact that limits untapping permanents',
  'stax piece for untap denial',
  'how to lock opponents out of untapping',
  'best cards for untap restriction in EDH',
  'static orb effect without naming it',
  'untap hate, artifact, stax',
  'what artifact stops multiple untaps per turn?',
  'competitive stax pieces under 3 mana',
  'colorless lock pieces for vintage',
  'permanent untap restriction MTG'],
 'Sensory Deprivation': ['blue aura that weakens creatures',
  'cheap 

In [37]:
batch_size = 49
all_responses = []

In [38]:
len(all_responses)

0

In [39]:
for i in tqdm.tqdm(range(0, len(initial_sample_prompt_pool), batch_size)):
    batch = initial_sample_prompt_pool[i:i + batch_size]
    batch_names, batch = zip(*batch)
    # Check if the batch is already in the output file

    valid_names = [name for name in batch_names if name not in queries_data]
    valid_positions = [j for j, name in enumerate(batch_names) if name not in queries_data]
    if not valid_positions:
        continue
    else:
        batch = [batch[j] for j in valid_positions]

    start_time = time.time()
    
    batch_futures = [large_llm_client.chat.completions.create(
        model="deepseek/deepseek-chat-v3-0324:cost",
        messages=[{"role": "user", "content": prompt}],
        max_tokens=512,
        n=1,
        stream=False,
        temperature=0.3,
    ) for prompt in batch]
    batch_responses = await asyncio.gather(*batch_futures)
    decoded_responses = [postprocess_synthetic_querygen_response(response.choices[0].message.content) for response in batch_responses]
    all_responses.extend(batch_responses)
    with open("synthetic_queries.jsonl", "a") as f:
        for j, response in enumerate(decoded_responses):
            query_data = {
                "card_name": valid_names[j],
                "queries": response
            }
            f.write(json.dumps(query_data) + "\n")
            queries_data[valid_names[j]] = query_data
    
    end_time = time.time()
    elapsed_time = end_time - start_time
    # if less than 1 second, wait for 1 second
    if elapsed_time < 1:
        time.sleep(1 - elapsed_time + 0.1)

  0%|          | 0/623 [00:00<?, ?it/s]INFO 2025-05-02 00:26:47 [httpx:1740] HTTP Request: POST https://openrouter.ai/api/v1/chat/completions "HTTP/1.1 200 OK"
INFO 2025-05-02 00:26:47 [httpx:1740] HTTP Request: POST https://openrouter.ai/api/v1/chat/completions "HTTP/1.1 200 OK"
INFO 2025-05-02 00:26:47 [httpx:1740] HTTP Request: POST https://openrouter.ai/api/v1/chat/completions "HTTP/1.1 200 OK"
INFO 2025-05-02 00:26:47 [httpx:1740] HTTP Request: POST https://openrouter.ai/api/v1/chat/completions "HTTP/1.1 200 OK"
INFO 2025-05-02 00:26:47 [httpx:1740] HTTP Request: POST https://openrouter.ai/api/v1/chat/completions "HTTP/1.1 200 OK"
INFO 2025-05-02 00:26:47 [httpx:1740] HTTP Request: POST https://openrouter.ai/api/v1/chat/completions "HTTP/1.1 200 OK"
INFO 2025-05-02 00:26:47 [httpx:1740] HTTP Request: POST https://openrouter.ai/api/v1/chat/completions "HTTP/1.1 200 OK"
INFO 2025-05-02 00:26:47 [httpx:1740] HTTP Request: POST https://openrouter.ai/api/v1/chat/completions "HTTP/1.1 2

In [71]:
batch_responses = await asyncio.gather(*batch_futures)

In [None]:
for i in range(50):
    idx = initial_sample_prompt_idxs[i]
    # postprocess_synthetic_querygen_response(batch_responses[i].choices[0].message.content)
    print(f"Card  {i}: {sample_cards[idx]}")
    print(f"Description {i}: {formatted_cards[idx]}")
    print(f"Response {i}: {postprocess_synthetic_querygen_response(batch_responses[i].choices[0].message.content)}")

In [29]:
sum([batch_responses[i].usage.completion_tokens + batch_responses[i].usage.prompt_tokens for i in range(len(batch_responses))]) / len(batch_responses)

480.0

## Fine-tuning

In [42]:
from ragatouille import RAGTrainer
trainer = RAGTrainer(model_name="MTGColBERTv_0_0", pretrained_model_name="colbert-ir/colbertv2.0", language_code="en")

  from .autonotebook import tqdm as notebook_tqdm


In [18]:
train_pairs = []
test_pairs = []
train_pairs_idxs = []
test_pairs_idxs = []
idxs_for_train = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
idxs_for_test = []
for idx, queries_for_card in queries_data.items():
    doc = formatted_cards[idx]
    num_queries = len(queries_for_card)
    for i, query in enumerate(queries_for_card):
        if i in idxs_for_train:
            train_pairs.append((query, doc))
            train_pairs_idxs.append((query, idx))
        elif i in idxs_for_test:
            test_pairs.append((query, doc))
            test_pairs_idxs.append((query, idx))
        else:
            print(f"Invalid index {i} for query {query} for card {idx}")

## Mine Negatives

In [22]:
all_cards_by_name

{'Nissa, Worldsoul Speaker': {'name': 'Nissa, Worldsoul Speaker',
  'mana_cost': '{3}{G}',
  'cmc': 4.0,
  'type': 'Legendary Creature — Elf Druid',
  'text': 'Landfall — Whenever a land you control enters, you get {E}{E} (two energy counters).\nYou may pay eight {E} rather than pay the mana cost for permanent spells you cast.',
  'power': '3',
  'toughness': '3',
  'loyalty': None,
  'colors': ['G'],
  'keywords': ['Landfall']},
 'Static Orb': {'name': 'Static Orb',
  'mana_cost': '{3}',
  'cmc': 3.0,
  'type': 'Artifact',
  'text': "As long as this artifact is untapped, players can't untap more than two permanents during their untap steps.",
  'power': None,
  'toughness': None,
  'loyalty': None,
  'colors': [],
  'keywords': []},
 'Sensory Deprivation': {'name': 'Sensory Deprivation',
  'mana_cost': '{U}',
  'cmc': 1.0,
  'type': 'Enchantment — Aura',
  'text': 'Enchant creature\nEnchanted creature gets -3/-0.',
  'power': None,
  'toughness': None,
  'loyalty': None,
  'colors': [

In [24]:
formatted_cards[0], sample_cards[0]

('Nissa, Worldsoul Speaker.\nLegendary Creature — Elf Druid.\nMana cost: {3}{G}\nConverted mana cost: 4.0\nColors: Green.\nOracle text: Landfall — Whenever a land you control enters, you get {E}{E} (two energy counters).\nYou may pay eight {E} rather than pay the mana cost for permanent spells you cast.\nPower: 3\nToughness: 3\nKeywords: Landfall',
 'Nissa, Worldsoul Speaker')

In [30]:
# Easy negatives: rejection-sample random cards from the database as long as at least one of the key attributes is different
easy_negatives = {}
num_negatives_per_card = 1000
for idx in tqdm.tqdm(range(len(formatted_cards))):
    doc_data = all_cards_by_name[sample_cards[idx]]
    # sample 2 * num_negatives_per_card cards from the database
    sample_idxs = np.random.choice(range(len(sample_cards)), size=int(num_negatives_per_card * 1.5), replace=False)
    sample_card_names = [sample_cards[i] for i in sample_idxs]
    sample_cards_data = [all_cards_by_name[name] for name in sample_card_names]
    # filter out cards that match the current card on all key attributes
    filtered_card_names = [card["name"] for card in sample_cards_data if not (
        set(card["colors"]) == set(doc_data["colors"]) and
        card["type"] == doc_data["type"] and
        card["cmc"] == doc_data["cmc"] and
        set(card["keywords"]) == set(doc_data["keywords"])
    )]
    # sample num_negatives_per_card cards from the filtered cards, with replacement if needed
    if len(filtered_card_names) > num_negatives_per_card:
        filtered_card_names = filtered_card_names[:num_negatives_per_card]
    else:
        filtered_card_names = np.random.choice(filtered_card_names, size=num_negatives_per_card, replace=True)
    easy_negatives[doc_data["name"]] = filtered_card_names

100%|██████████| 31095/31095 [00:59<00:00, 522.32it/s]


In [33]:
# save easy negatives to file
with open("easy_negatives.jsonl", "w") as f:
    for card_name, negatives in easy_negatives.items():
        data = {
            "card_name": card_name,
            "negatives": negatives
        }
        f.write(json.dumps(data) + "\n")

In [19]:
# load the embeddings from file
with open("easy_negatives.jsonl", "r") as f:
    easy_negatives = {}
    for line in f:
        data = json.loads(line)
        easy_negatives[data["card_name"]] = data["negatives"]

In [2]:
# Medium negatives: embed the cards using a dense vector embedding model and find the nearest neighbors in the embedding space
from sklearn.neighbors import NearestNeighbors
import numpy as np
num_neighbors = 100
embeddings = np.load("formatted_cards_embeddings.npy")
nbrs = NearestNeighbors(n_neighbors=num_neighbors, algorithm='auto', n_jobs=-1).fit(embeddings)
# find the nearest neighbors for each card
indices = nbrs.kneighbors(embeddings, return_distance=False)

In [24]:
medium_negatives = {}
for idx, card_name in enumerate(sample_cards):
    # get the indices of the nearest neighbors
    neighbor_indices = indices[idx]
    # get the names of the nearest neighbors
    neighbor_names = [sample_cards[i] for i in neighbor_indices]
    # filter out the current card from the neighbors
    neighbor_names = [name for name in neighbor_names if name != card_name]
    medium_negatives[card_name] = neighbor_names

In [27]:
# save the medium negatives to file
with open("medium_negatives.jsonl", "w") as f:
    for card_name, negatives in medium_negatives.items():
        data = {
            "card_name": card_name,
            "negatives": negatives
        }
        f.write(json.dumps(data) + "\n")

In [29]:
# load the medium negatives from file
with open("medium_negatives.jsonl", "r") as f:
    medium_negatives = {}
    for line in f:
        data = json.loads(line)
        medium_negatives[data["card_name"]] = data["negatives"]

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

idxs = queries_data.keys()
num_queries = [len(queries_data[idx]) for idx in idxs]
plt.figure(figsize=(10, 6))
sns.histplot(num_queries)
plt.xlabel("Number of queries per card")
plt.ylabel("Number of cards")
plt.title("Distribution of number of queries per card")
plt.show()


## Training

In [37]:
train_triples = []
for query, doc_idx in tqdm.tqdm(train_pairs_idxs):
    # sample 2 easy negatives and 8 medium negatives
    easy_negatives_sample = np.random.choice(easy_negatives[sample_cards[idx]], size=2, replace=False)
    medium_negatives_sample = np.random.choice(medium_negatives[sample_cards[idx]], size=8, replace=False)
    # add the negatives to the triples
    for negative in [*easy_negatives_sample, *medium_negatives_sample]:
        train_triples.append((query, formatted_cards[doc_idx], formatted_cards[sample_cards.index(negative)]))


100%|██████████| 242695/242695 [04:31<00:00, 894.38it/s]


In [None]:
len(train_triples), len(train_triples) / 32

In [40]:
# save the triples with easy and medium negatives to file
with open("train_triples_easy_med_10x.jsonl", "w") as f:
    for triple in train_triples:
        data = {
            "query": triple[0],
            "positive": triple[1],
            "negative": triple[2],
        }
        f.write(json.dumps(data) + "\n")

In [None]:
# load the triples from file
with open("train_triples_easy_med.jsonl", "r") as f:
    train_triples = []
    for line in f:
        data = json.loads(line)
        train_triples.append((data["query"], data["positive"], data["negative"]))

In [43]:
trainer.prepare_training_data(raw_data=train_triples, data_out_path="./train_data_v0_0/", all_documents=formatted_cards, mine_hard_negatives=False)

'./train_data_v0_0/'

In [44]:
trainer.train(batch_size=32,
              nbits=4, # How many bits will the trained model use when compressing indexes
              maxsteps=100_000, # Maximum steps hard stop
              use_ib_negatives=True, # Use in-batch negative to calculate loss
              dim=128, # How many dimensions per embedding. 128 is the default and works well.
              learning_rate=5e-6, # Learning rate, small values ([3e-6,3e-5] work best if the base model is BERT-like, 5e-6 is often the sweet spot)
              doc_maxlen=256, # Maximum document length. Because of how ColBERT works, smaller chunks (128-256) work very well.
              use_relu=False, # Disable ReLU -- doesn't improve performance
              warmup_steps="auto", # Defaults to 10%
             )

#> Starting...
{
    "query_token_id": "[unused0]",
    "doc_token_id": "[unused1]",
    "query_token": "[Q]",
    "doc_token": "[D]",
    "ncells": null,
    "centroid_score_threshold": null,
    "ndocs": null,
    "load_index_with_mmap": false,
    "index_path": null,
    "index_bsize": 64,
    "nbits": 4,
    "kmeans_niters": 20,
    "resume": false,
    "pool_factor": 1,
    "clustering_mode": "hierarchical",
    "protected_tokens": 0,
    "similarity": "cosine",
    "bsize": 32,
    "accumsteps": 1,
    "lr": 5e-6,
    "maxsteps": 100000,
    "save_every": 14726,
    "warmup": 14726,
    "warmup_bert": null,
    "relu": false,
    "nway": 2,
    "use_ib_negatives": true,
    "reranker": false,
    "distillation_alpha": 1.0,
    "ignore_scores": false,
    "model_name": "MTGColBERTv_0_0",
    "query_maxlen": 32,
    "attend_to_mask_tokens": false,
    "interaction": "colbert",
    "dim": 128,
    "doc_maxlen": 256,
    "mask_punctuation": true,
    "checkpoint": "colbert-ir\/colber

Process Process-1:
Traceback (most recent call last):
  File "/opt/homebrew/Cellar/python@3.12/3.12.10/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/opt/homebrew/Cellar/python@3.12/3.12.10/Frameworks/Python.framework/Versions/3.12/lib/python3.12/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/Users/benchislett/Repos/Hexanomicon/.venv/lib/python3.12/site-packages/colbert/infra/launcher.py", line 134, in setup_new_process
    return_val = callee(config, *args)
                 ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/benchislett/Repos/Hexanomicon/.venv/lib/python3.12/site-packages/colbert/training/training.py", line 55, in train
    colbert = torch.nn.parallel.DistributedDataParallel(colbert, device_ids=[config.rank],
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/benchislett/Repos/Hexanomicon/.

KeyboardInterrupt: 

In [47]:
len([i for i in queries_data.keys() if i in cards.card_data.keys()])

30497

In [48]:
synthetic_queries = {name: data for name, data in queries_data.items() if name in cards.card_data.keys()}
# save the synthetic queries to file
with open("synthetic_queries.json", "w") as f:
    f.write(json.dumps(synthetic_queries, indent=2))

In [2]:
import json
with open("synthetic_queries.json", "r") as f:
    synthetic_queries = json.load(f)

In [15]:
new_queries = {}
for name, data in synthetic_queries.items():
    if type(data) == dict:
        if data.get("queries") is None:
            print(f"Card {name} has no queries")
            continue
        new_queries[name] = data["queries"]
    else:
        assert type(data) == list
        new_queries[name] = data


Card Acquisitions Expert has no queries
Card Thieving Sprite has no queries
Card Fiery Gambit has no queries


In [14]:
min(len(x) for x in new_queries.values()), max(len(x) for x in new_queries.values()), sum(len(x) for x in new_queries.values()) / len(new_queries)

(6, 10, 9.72312586082508)

In [13]:
len(new_queries)

30494

In [16]:
for i in new_queries:
    assert new_queries[i]
    assert len(new_queries[i]) > 5

In [17]:
with open("synthetic_queries.json", "w") as f:
    f.write(json.dumps(new_queries, indent=2))