In [1]:
# ==============================================================================
# Supervised Fine-Tuning Notebook
# ==============================================================================
# This notebook is designed to be run step-by-step in an environment like
# Google Colab or a local Jupyter notebook.

# %%
# CELL 1: Install Dependencies
# Run this cell first to make sure all necessary libraries are installed.
# ==============================================================================
!pip install sentence-transformers pandas pyarrow torch orjson tqdm scikit-learn

In [10]:
# %%
# CELL 2: Imports and Configuration
# This cell imports all required modules and sets up the main configuration
# variables for the training run.
# ==============================================================================
import pandas as pd
import polars as pl
import orjson
from sentence_transformers import SentenceTransformer, InputExample, losses
from sentence_transformers.evaluation import TripletEvaluator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import os
import shutil
import zipfile
import torch
from sklearn.model_selection import train_test_split



Configuration set. Ready for the next step.


In [25]:
# --- Configuration ---
NUM_EPOCHS = 1
BATCH_SIZE = 16
LEARNING_RATE = 2e-5
OUTPUT_PATH = './mtg_supervised'
BASE_MODEL_ZIP_PATH = './gte-mtg-base.zip'
BASE_MODEL_EXTRACT_PATH = './gte-mtg-base'
TRIPLET_CSV_PATH = './generated_training_triplets.csv'
CARD_DATA_PARQUET_PATH = './mtg_data.parquet'

print("Configuration set. Ready for the next step.")

Configuration set. Ready for the next step.


In [5]:
# %%
# CELL 3: Unzip Base Model
# This cell checks for your zipped base model and unzips it if necessary.
# In Colab, you would upload 'gte-mtg-base.zip' before running this.
# ==============================================================================
if os.path.exists(BASE_MODEL_ZIP_PATH) and not os.path.exists(BASE_MODEL_EXTRACT_PATH):
    print(f"Found '{BASE_MODEL_ZIP_PATH}'. Unzipping...")
    with zipfile.ZipFile(BASE_MODEL_ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall('./')
    print(f"Successfully unzipped model to '{BASE_MODEL_EXTRACT_PATH}'")
elif os.path.exists(BASE_MODEL_EXTRACT_PATH):
    print(f"Base model directory '{BASE_MODEL_EXTRACT_PATH}' already exists. Skipping unzip.")
else:
    print(f"ERROR: Base model not found at '{BASE_MODEL_EXTRACT_PATH}' or as a zip file.")
    # In a notebook, we'd stop here.
    # exit()

Base model directory './gte-mtg-base' already exists. Skipping unzip.


In [7]:
# %%
# CELL 4: Load Data Files
# This cell loads your training triplets and the full card data from local files.
# In Colab, you would upload these files before running this cell.
# ==============================================================================
print("\nLoading data files from local paths...")
try:
    df_triplets = pd.read_csv(TRIPLET_CSV_PATH)
    df_full_data = pl.read_parquet(CARD_DATA_PARQUET_PATH)
    print("All data loaded successfully.")
    print(f"Loaded {len(df_triplets)} triplets.")
    print(f"Loaded {len(df_full_data)} cards.")
except FileNotFoundError as e:
    print(f"Error: Could not find a required data file. {e}")
    print("Please ensure 'generated_training_triplets.csv' and 'mtg_data.parquet' are in the same directory.")
except Exception as e:
    print(f"Error reading data files: {e}")


Loading data files from local paths...
All data loaded successfully.
Loaded 100 triplets.
Loaded 32722 cards.


In [8]:
# %%
# CELL 5: Prepare Data Structures
# This cell builds the lookup map that converts card names to their full JSON representation.
# ==============================================================================
print("\nBuilding card name to JSON lookup map...")
card_name_to_json = {}
for row in df_full_data.iter_rows(named=True):
    name = row.get("name")
    if name:
        row_dict = {k: v for k, v in row.items() if v is not None and k != "scryfallId"}
        row_str = orjson.dumps(row_dict, option=orjson.OPT_INDENT_2).decode("utf-8")
        card_name_to_json[name] = row_str
print(f"Created lookup map with {len(card_name_to_json)} entries.")


Building card name to JSON lookup map...
Created lookup map with 31879 entries.


In [26]:
# %%
# CELL 6: Construct Training & Evaluation Examples
# This cell loads the base model, constructs the rich training examples using the
# lookup map, and splits them into training and development sets.
# ==============================================================================
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"\nUsing device: {device}")
print(f"Loading base model from '{BASE_MODEL_EXTRACT_PATH}'...")
model = SentenceTransformer(BASE_MODEL_EXTRACT_PATH, device=device)

print("\nConstructing training examples with full JSON...")
all_examples = []
for index, row in tqdm(df_triplets.iterrows(), total=df_triplets.shape[0], desc="Building Examples"):
    try:
        anchor_name = row['anchor']
        positive_name = row['positive']
        negative_name = row['negative']
        anchor_text = anchor_name if anchor_name not in card_name_to_json else card_name_to_json[anchor_name]
        positive_text = card_name_to_json[positive_name]
        negative_text = card_name_to_json[negative_name]
        all_examples.append(InputExample(texts=[anchor_text, positive_text, negative_name]))
    except KeyError as e:
        print(f"Warning: Skipping row {index} due to card not found in lookup map: {e}")

print(f"Successfully created {len(all_examples)} total examples.")

# Split data into training and evaluation sets
train_examples, dev_examples = train_test_split(all_examples, test_size=0.1, random_state=42)
print(f"Split data into {len(train_examples)} training examples and {len(dev_examples)} evaluation examples.")








Using device: cuda:0
Loading base model from './gte-mtg-base'...

Constructing training examples with full JSON...


Building Examples:   0%|          | 0/100 [00:00<?, ?it/s]

Successfully created 100 total examples.
Split data into 90 training examples and 10 evaluation examples.


In [27]:
# %%
# CELL 7: Run the Supervised Fine-Tuning Process
# This is the main training cell. It will fine-tune the model and save the
# best version to the specified output path.
# ==============================================================================
train_loss = losses.TripletLoss(model=model)
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)

# Create an evaluator to get metrics during training on an unseen portion of the data.
# This is the best way to confirm the model is learning general rules, not just memorizing.
evaluator = TripletEvaluator.from_input_examples(dev_examples, name='mtg-dev')

# Calculate logging steps to get a few logs per epoch
# This prevents the evaluation from running too often and slowing down training.
logging_steps = max(1, len(train_dataloader) // 4)
print(f"Logging & Evaluating every {logging_steps} steps.")

# Set a small, fixed number of warmup steps.
warmup_steps = int(len(train_dataloader) * 0.1) # 10% of steps in the first epoch

print("\nStarting supervised fine-tuning...")
# The .fit() method handles the entire training loop.
# It will periodically run the evaluator and print the accuracy.
# Watch for the 'Mtg-dev Cosine Accuracy' to increase over time.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=NUM_EPOCHS,
    evaluator=evaluator,
    evaluation_steps=logging_steps,
    output_path=OUTPUT_PATH,
    save_best_model=True, # Save the model checkpoint with the best evaluation score
    optimizer_params={'lr': LEARNING_RATE},
    warmup_steps=warmup_steps
)

print(f"\nFine-tuning complete. The best performing model was saved to: {OUTPUT_PATH}")




Logging & Evaluating every 1 steps.

Starting supervised fine-tuning...


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss,Validation Loss,Mtg-dev Cosine Accuracy
1,No log,No log,0.3
2,No log,No log,0.3
3,No log,No log,0.8
4,No log,No log,0.9
5,No log,No log,1.0
6,No log,No log,1.0



Fine-tuning complete. The best performing model was saved to: ./mtg_supervised


In [None]:
# %%
# CELL 8: Package and Download the Final Model
# This cell zips your final, trained model for easy download and distribution.
# ==============================================================================
print("\nZipping the final model for distribution...")
if os.path.exists(f"{OUTPUT_PATH}.zip"):
    os.remove(f"{OUTPUT_PATH}.zip")
shutil.make_archive(OUTPUT_PATH, 'zip', OUTPUT_PATH)
print(f"Successfully created '{OUTPUT_PATH}.zip'")

# In Colab, you would uncomment the following lines to download the file
# from google.colab import files
# files.download(f"{OUTPUT_PATH}.zip")