In [2]:
# %%
# Step 0: Install necessary libraries
#!pip install requests pandas pyyaml tqdm

# %%
import requests
import pandas as pd
import yaml
import random
from google.colab import files
import io
from tqdm.notebook import tqdm
import time
from typing import List, Dict, Any, Set, Tuple

In [6]:
# --- Configuration ---
# Target number of triplets per concept. TripletGenerator splits this budget
# across the three popularity tiers (45% / 35% / 20%).
SAMPLES_PER_CONCEPT = 100
# Probability that a triplet first tries to take its negative from the
# hard-negative pool (parent/orthogonal tags) before falling back to easy ones.
HARD_NEGATIVE_RATIO = 0.8
# Among hard-negative draws, probability of first trying an orthogonal tag that
# the positive card itself carries before falling back to any hard negative.
ORTHOGONAL_NEGATIVE_PRIORITY = 0.7 # 70% of hard negatives will try to be from a matching orthogonal group

# ==============================================================================
# MODULE 1: Scryfall API Client
# ==============================================================================
class ScryfallAPI:
    """A client to fetch card data from the Scryfall API.

    Args:
        request_delay: Seconds to sleep after each successfully processed
            page, to stay polite towards the API. Defaults to 0.1.
        timeout: Per-request timeout in seconds handed to ``requests`` so a
            stalled connection cannot hang the notebook. Defaults to 30.
    """

    SEARCH_URL = "https://api.scryfall.com/cards/search"

    def __init__(self, request_delay: float = 0.1, timeout: float = 30.0):
        self.request_delay = request_delay
        self.timeout = timeout
        self.session = requests.Session()
        # NOTE(review): this is a spoofed browser User-Agent. Scryfall's API
        # guidelines ask clients to identify themselves; consider replacing
        # it with an application-specific value.
        self.session.headers.update({
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
        })

    def get_cards_for_query(self, query: str) -> List[Dict[str, Any]]:
        """Fetches card data (name and edhrec_rank) for a given Scryfall query.

        Follows the ``next_page`` links of the search response until they run
        out. A 404 response ends the loop and is treated as "no (more) cards".

        Args:
            query: A Scryfall search expression, e.g. ``'otag:counterspell'``.

        Returns:
            One dict per card with keys ``'name'`` (``'Unknown'`` when the
            payload has no name) and ``'edhrec_rank'`` (``None`` when absent).
            If a request fails part-way, the cards collected so far are
            returned and the failure is printed instead of being swallowed.
        """
        print(f"    [API] Fetching cards for query: '{query}'")
        url = self.SEARCH_URL
        params = {'q': query}
        card_data_list = []

        page_count = 0
        while url:
            page_count += 1
            try:
                response = self.session.get(url, params=params, timeout=self.timeout)
                if response.status_code == 404:
                    break
                response.raise_for_status()
                data = response.json()
            except (requests.RequestException, ValueError) as exc:
                # Best effort: keep what we already have, but make the failure
                # visible so a truncated pool is not mistaken for a full one.
                print(
                    f"    [API] WARNING: request failed on page {page_count} "
                    f"for query '{query}': {exc!r}. Results may be incomplete."
                )
                break

            for card_data in data.get('data', []):
                card_data_list.append({
                    'name': card_data.get('name', 'Unknown'),
                    'edhrec_rank': card_data.get('edhrec_rank')
                })

            url = data.get('next_page')
            # next_page is a complete URL that already embeds the query.
            params = None
            if url and page_count == 1:
                print("    [API] Found multiple pages, fetching more...")
            time.sleep(self.request_delay)

        print(f"    [API] Found {len(card_data_list)} cards for query.")
        return card_data_list

# ==============================================================================
# MODULE 2: Triplet Generation Logic
# ==============================================================================
class TripletGenerator:
    """Generates training triplets based on a concept curriculum.

    Each triplet is a dict with the keys ``anchor`` (one of the concept's
    query strings), ``positive`` (a card name from the concept's positive
    tag) and ``negative`` (a card name outside the positive set).
    """

    # Share of SAMPLES_PER_CONCEPT drawn from each popularity tier.
    _TIER_SHARES = (('tier1', 0.45), ('tier2', 0.35), ('tier3', 0.20))

    @staticmethod
    def get_tiered_pools(card_list: List[Dict]) -> Tuple[List[str], List[str], List[str]]:
        """Partitions a list of cards into three tiers based on EDHREC rank.

        Cards are sorted by ascending ``edhrec_rank`` with unranked cards last.

        Args:
            card_list: Dicts with at least the keys ``'name'`` and
                ``'edhrec_rank'``.

        Returns:
            ``(tier1, tier2, tier3)`` lists of card names: tier 1 is the top
            5% (at least 1, at most 10 cards), tier 2 the next 25% of what
            remains, tier 3 everything else. Three empty lists for no input.
        """
        if not card_list:
            return [], [], []
        df = pd.DataFrame(card_list).sort_values(by='edhrec_rank', ascending=True, na_position='last')
        names = df['name'].tolist()
        tier1_count = min(max(1, int(len(names) * 0.05)), 10)
        remaining = names[tier1_count:]
        tier2_count = int(len(remaining) * 0.25)
        return names[:tier1_count], remaining[:tier2_count], remaining[tier2_count:]

    @staticmethod
    def _card_names(tag_to_card_data_map: Dict, tag: str) -> Set[str]:
        """Returns the set of card names fetched for ``tag`` (empty if unknown)."""
        return {c['name'] for c in tag_to_card_data_map.get(tag, [])}

    def generate_triplets_for_concept(self, concept: Dict, tag_to_card_data_map: Dict, card_to_tags_map: Dict) -> List[Dict]:
        """Generates a list of triplets for a single concept using balanced quota-based sampling.

        Args:
            concept: Curriculum entry. Reads ``concept_name``, ``positive_tag``
                and ``queries``, plus the optional ``hard_negative_strategy``
                (``parent_tag`` / ``orthogonal_tags``) and
                ``easy_negative_tags``.
            tag_to_card_data_map: Maps a tag to its list of card dicts.
            card_to_tags_map: Maps a card name to the tags it was fetched under.

        Returns:
            A list of ``{"anchor", "positive", "negative"}`` dicts. Empty when
            the concept has no positive cards, no usable negatives, or no
            anchor queries.
        """
        concept_name = concept['concept_name']

        positive_cards_data = tag_to_card_data_map.get(concept['positive_tag'], [])
        if not positive_cards_data:
            print(f"  [Generator] Skipping '{concept_name}' - no positive cards found.")
            return []

        tier1, tier2, tier3 = self.get_tiered_pools(positive_cards_data)
        print(f"  [Generator] Tiered positive pools: {len(tier1)} (Icons), {len(tier2)} (Staples), {len(tier3)} (Niche)")

        positive_card_names = {c['name'] for c in positive_cards_data}
        # "or" guards against YAML keys that are present but left empty (None).
        strategy = concept.get('hard_negative_strategy') or {}
        orthogonal_tags = strategy.get('orthogonal_tags') or []

        parent_negatives = set()
        if 'parent_tag' in strategy:
            parent_negatives = self._card_names(tag_to_card_data_map, strategy['parent_tag']) - positive_card_names

        # Built once per tag; the sampling loop below only looks pools up.
        # Sorting keeps draws reproducible when the random module is seeded.
        ortho_pools = {
            tag: sorted(self._card_names(tag_to_card_data_map, tag) - positive_card_names)
            for tag in orthogonal_tags
        }
        orthogonal_negatives = set().union(*ortho_pools.values())

        easy_negatives = set()
        for easy_tag in concept.get('easy_negative_tags') or []:
            easy_negatives |= self._card_names(tag_to_card_data_map, easy_tag)
        # A positive card must never be served as a negative, so easy
        # negatives get the same filtering as hard negatives.
        easy_negatives -= positive_card_names

        hard_neg_list = sorted(parent_negatives | orthogonal_negatives)
        easy_neg_list = sorted(easy_negatives)

        print(f"  [Generator] Found {len(hard_neg_list)} unique hard negatives and {len(easy_neg_list)} unique easy negatives.")

        if not hard_neg_list and not easy_neg_list:
            print(f"  [Generator] Skipping '{concept_name}' - no negative examples could be found.")
            return []

        queries = concept.get('queries') or []
        if not queries:
            print(f"  [Generator] Skipping '{concept_name}' - no anchor queries defined.")
            return []

        # Balanced quota sampling: every tier contributes a fixed share of the
        # budget. An empty tier contributes nothing; its share is not
        # redistributed to the other tiers.
        tier_pools = {'tier1': tier1, 'tier2': tier2, 'tier3': tier3}
        all_positive_samples = []
        for tier_name, share in self._TIER_SHARES:
            pool = tier_pools[tier_name]
            if not pool:
                continue
            quota = int(SAMPLES_PER_CONCEPT * share)
            # Cycle through the pool so every card is used about equally,
            # e.g. quota 5 over [A, B] -> [A, B, A, B, A].
            all_positive_samples.extend((pool * (quota // len(pool) + 1))[:quota])

        random.shuffle(all_positive_samples)

        concept_triplets = []
        for positive in all_positive_samples:
            anchor = random.choice(queries)

            negative = None
            if random.random() < HARD_NEGATIVE_RATIO and hard_neg_list:
                if random.random() < ORTHOGONAL_NEGATIVE_PRIORITY:
                    # Prefer a negative that shares an orthogonal tag with the
                    # positive card: similar on that axis, wrong on the concept.
                    positive_card_tags = card_to_tags_map.get(positive, [])
                    matching_ortho_tags = [tag for tag in orthogonal_tags if tag in positive_card_tags]
                    if matching_ortho_tags:
                        ortho_group_negatives = ortho_pools[random.choice(matching_ortho_tags)]
                        if ortho_group_negatives:
                            negative = random.choice(ortho_group_negatives)

                if negative is None:  # Fallback to general hard negatives
                    negative = random.choice(hard_neg_list)

            if negative is None and easy_neg_list:  # Fallback to easy negatives
                negative = random.choice(easy_neg_list)

            if negative is None and hard_neg_list:  # Final fallback
                negative = random.choice(hard_neg_list)

            if positive and negative:
                concept_triplets.append({"anchor": anchor, "positive": positive, "negative": negative})

        print(f"  [Generator] Generated {len(concept_triplets)} triplets for '{concept_name}'.")
        return concept_triplets

# ==============================================================================
# MAIN EXECUTION SCRIPT
# ==============================================================================
def main():
    """Main function to run the data generation pipeline.

    Steps:
      1. Ask the user to upload the concept curriculum YAML (Colab widget).
      2. Fetch the cards for every tag the curriculum references, once each.
      3. Generate triplets for each concept.
      4. Shuffle all triplets and write them to
         ``generated_training_triplets.csv``.

    Returns early, after printing the reason, when no file is uploaded, the
    YAML cannot be parsed or is not a list of concepts, or no triplets are
    generated.
    """
    # 1. Load configuration
    print("Please upload your 'concept_curriculum.yml' file.")
    uploaded = files.upload()
    if not uploaded:
        # Upload dialog cancelled: next(iter(...)) below would raise.
        print("No file was uploaded. Exiting.")
        return
    config_filename = next(iter(uploaded))
    try:
        concepts = yaml.safe_load(io.BytesIO(uploaded[config_filename]))
    except yaml.YAMLError as e:
        print(f"Error reading YAML file: {e}")
        return
    if not isinstance(concepts, list):
        print(f"Error reading YAML file: expected a list of concepts, got {type(concepts).__name__}.")
        return
    print(f"Successfully loaded {len(concepts)} concepts from the curriculum.")

    # 2. Pre-fetch all necessary card data
    scryfall_client = ScryfallAPI()
    all_tags_to_fetch = set()
    for concept in concepts:
        all_tags_to_fetch.add(concept['positive_tag'])
        # "or" guards against keys that are present in the YAML but empty.
        strategy = concept.get('hard_negative_strategy') or {}
        if 'parent_tag' in strategy:
            all_tags_to_fetch.add(strategy['parent_tag'])
        all_tags_to_fetch.update(strategy.get('orthogonal_tags') or [])
        all_tags_to_fetch.update(concept.get('easy_negative_tags') or [])

    print(f"\nFound {len(all_tags_to_fetch)} unique tags to fetch from Scryfall...")

    # Sorted so the fetch order (and the log) is the same on every run.
    tag_to_card_data_map = {}
    for tag in tqdm(sorted(all_tags_to_fetch), desc="Fetching All Tag Data"):
        tag_to_card_data_map[tag] = scryfall_client.get_cards_for_query(tag)

    # Reverse index: card name -> tags it was fetched under.
    card_to_tags_map = {}
    for tag, cards in tag_to_card_data_map.items():
        for card in cards:
            card_to_tags_map.setdefault(card['name'], []).append(tag)

    # 3. Process each concept and generate triplets
    triplet_generator = TripletGenerator()
    all_triplets = []

    print("\nStarting training data generation...")
    for concept in tqdm(concepts, desc="Processing Concepts"):
        print(f"\n--- Processing concept: {concept['concept_name']} ---")

        triplets = triplet_generator.generate_triplets_for_concept(concept, tag_to_card_data_map, card_to_tags_map)
        all_triplets.extend(triplets)

    # 4. Save final results
    print("\nData generation complete. Creating final CSV file...")
    if not all_triplets:
        print("No triplets were generated. Exiting.")
        return

    # sample(frac=1) shuffles the rows so concepts are interleaved in the CSV.
    df_final = pd.DataFrame(all_triplets)
    df_final = df_final.sample(frac=1).reset_index(drop=True)

    output_filename = "generated_training_triplets.csv"
    df_final.to_csv(output_filename, index=False)
    print(f"Successfully generated {len(df_final)} triplets and saved to '{output_filename}'")

    print("\nTo download the file, run the following code in a new cell:")
    print("from google.colab import files")
    print(f"files.download('{output_filename}')")

# Run the pipeline when this file is executed directly (including as a
# notebook cell, where __name__ is "__main__"), but not when it is imported.
if __name__ == "__main__":
    main()



Please upload your 'concept_curriculum.yml' file.


Saving query_to_card_concepts_counters.yaml to query_to_card_concepts_counters (3).yaml
Successfully loaded 1 concepts from the curriculum.

Found 12 unique tags to fetch from Scryfall...


Fetching All Tag Data:   0%|          | 0/12 [00:00<?, ?it/s]

    [API] Fetching cards for query: 'otag:cycle-pcy-pitchspell'
    [API] Found 5 cards for query.
    [API] Fetching cards for query: 'otag:cycle-bok-shoal'
    [API] Found 5 cards for query.
    [API] Fetching cards for query: 'otag:cycle-mh3-flare'
    [API] Found 5 cards for query.
    [API] Fetching cards for query: 'otag:phyrexian-mana-cost'
    [API] Found 41 cards for query.
    [API] Fetching cards for query: 'otag:cycle-fut-pact'
    [API] Found 5 cards for query.
    [API] Fetching cards for query: 'otag:counterspell'
    [API] Found multiple pages, fetching more...
    [API] Found 534 cards for query.
    [API] Fetching cards for query: 'otag:board-wipe'
    [API] Found multiple pages, fetching more...
    [API] Found 884 cards for query.
    [API] Fetching cards for query: 'otag:mana-rock'
    [API] Found multiple pages, fetching more...
    [API] Found 361 cards for query.
    [API] Fetching cards for query: 'otag:cycle-c20-free-spell'
    [API] Found 5 cards for query.
 

Processing Concepts:   0%|          | 0/1 [00:00<?, ?it/s]


--- Processing concept: Free Counterspells ---
  [Generator] Tiered positive pools: 1 (Icons), 3 (Staples), 9 (Niche)
  [Generator] Found 588 unique hard negatives and 1241 unique easy negatives.
  [Generator] Generated 100 triplets for 'Free Counterspells'.

Data generation complete. Creating final CSV file...
Successfully generated 100 triplets and saved to 'generated_training_triplets.csv'

To download the file, run the following code in a new cell:
from google.colab import files
files.download('generated_training_triplets.csv')
