In [None]:
import os
# Auto-generated portability setup: choose the base directory for the run.
# Colab (data assumed mounted or downloaded into the current directory) and
# local execution both resolve to the current working directory.
BASE_DIR = os.getcwd() if 'google.colab' in str(get_ipython()) else os.getcwd()


In [23]:
# Cell 1: Generate 5000+ Casual Math Expressions
import random

def generate_casual_math_dataset(num_samples=5000):
    """Generate short, casually phrased math queries labelled ``"math"``.

    The requested total is split across seven categories in fixed proportions
    (algebra 30%, calculus 20%, geometry 16%, arithmetic 14%,
    percentages/fractions 10%, trigonometry 6%, exponents/logarithms 4%).
    With the default of 5000 this yields 1500/1000/800/700/500/300/200.

    Args:
        num_samples: Total number of queries to generate (non-negative int).

    Returns:
        list[dict]: ``{"text": <query>, "label": "math"}`` records, grouped by
        category in the order listed above (the list is not shuffled).

    Raises:
        ValueError: If ``num_samples`` is negative.
    """
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    # Relative category sizes; for num_samples=5000 these are the exact counts.
    category_weights = (1500, 1000, 800, 700, 500, 300, 200)
    weight_total = sum(category_weights)
    # Floor every share except the last one, which absorbs the rounding
    # remainder so the counts always add up to exactly num_samples.
    category_counts = [num_samples * w // weight_total for w in category_weights[:-1]]
    category_counts.append(num_samples - sum(category_counts))
    (n_algebra, n_calculus, n_geometry, n_arithmetic,
     n_percent, n_trig, n_exp) = category_counts

    casual_math = []

    # 1. ALGEBRA
    algebra_templates = [
        "solve {eq}",
        "what is {eq}",
        "find x when {eq}",
        "solve for x: {eq}",
        "whats x in {eq}",
        "{eq} solve this",
        "help me solve {eq}",
        "how to solve {eq}",
    ]

    for _ in range(n_algebra):
        a, b = random.randint(1, 20), random.randint(1, 20)
        equations = [
            f"x + {a} = {b}",
            f"{a}x = {b}",
            f"x - {a} = {b}",
            f"{a}x + {b} = {a+b}",
            f"x^2 = {a}",
            f"2x + {a} = {b}",
            f"{a}x - {b} = 0",
        ]
        eq = random.choice(equations)
        template = random.choice(algebra_templates)
        casual_math.append({"text": template.format(eq=eq), "label": "math"})

    # 2. CALCULUS
    calculus_templates = [
        "whats the integral of {func}",
        "integrate {func}",
        "find integral of {func}",
        "derivative of {func}",
        "whats d/dx of {func}",
        "differentiate {func}",
        "find derivative {func}",
        "calculate integral {func}",
        "what is the derivative of {func}",
        "integral {func}",
    ]

    functions = [
        "x^2", "x^3", "x^4", "2x", "3x^2", "x^2 + x", "sin(x)", "cos(x)",
        "e^x", "ln(x)", "tan(x)", "x^2 - 1", "2x + 3", "x^3 - x", "sqrt(x)",
        "1/x", "x^5", "4x^3", "sin(2x)", "e^(2x)", "x squared", "x cubed"
    ]

    for _ in range(n_calculus):
        func = random.choice(functions)
        template = random.choice(calculus_templates)
        casual_math.append({"text": template.format(func=func), "label": "math"})

    # 3. GEOMETRY
    geometry_templates = [
        "area of circle radius {r}",
        "whats area of circle with radius {r}",
        "find area circle r={r}",
        "volume of sphere radius {r}",
        "area of rectangle {l} by {w}",
        "perimeter of rectangle {l}x{w}",
        "area of triangle base {b} height {h}",
        "circumference of circle radius {r}",
        "volume of cube side {s}",
        "surface area sphere radius {r}",
        "find volume of sphere r={r}",
    ]

    for _ in range(n_geometry):
        # All six dimensions are drawn every time; each template picks the
        # ones it needs and str.format ignores the unused keywords.
        r, l, w, b, h, s = [random.randint(1, 20) for _ in range(6)]
        template = random.choice(geometry_templates)
        text = template.format(r=r, l=l, w=w, b=b, h=h, s=s)
        casual_math.append({"text": text, "label": "math"})

    # 4. BASIC ARITHMETIC
    arithmetic_templates = [
        "what is {a} + {b}",
        "calculate {a} * {b}",
        "whats {a} - {b}",
        "{a} divided by {b}",
        "multiply {a} by {b}",
        "{a} plus {b}",
        "{a} times {b}",
        "{a} minus {b}",
        "add {a} and {b}",
        "{a} * {b}",
    ]

    for _ in range(n_arithmetic):
        a, b = random.randint(1, 1000), random.randint(1, 100)
        template = random.choice(arithmetic_templates)
        casual_math.append({"text": template.format(a=a, b=b), "label": "math"})

    # 5. PERCENTAGES & FRACTIONS
    percent_templates = [
        "what is {p}% of {n}",
        "find {p} percent of {n}",
        "calculate {p}% of {n}",
        "{p}% of {n}",
        "whats {a}/{b}",
        "simplify {a}/{b}",
        "{a} divided by {b} as fraction",
        "{p} percent of {n}",
    ]

    for _ in range(n_percent):
        p, n = random.randint(1, 100), random.randint(10, 1000)
        a, b = random.randint(1, 20), random.randint(1, 20)
        template = random.choice(percent_templates)
        casual_math.append({"text": template.format(p=p, n=n, a=a, b=b), "label": "math"})

    # 6. TRIGONOMETRY
    trig_templates = [
        "what is sin({angle})",
        "calculate cos({angle})",
        "find tan({angle})",
        "sin of {angle} degrees",
        "cos({angle})",
        "tangent of {angle}",
        "whats sin({angle})",
    ]

    angles = [0, 30, 45, 60, 90, 180, 270, 360]
    for _ in range(n_trig):
        angle = random.choice(angles)
        template = random.choice(trig_templates)
        casual_math.append({"text": template.format(angle=angle), "label": "math"})

    # 7. EXPONENTS & LOGARITHMS
    exp_templates = [
        "what is {a}^{b}",
        "calculate {a} to the power of {b}",
        "{a} raised to {b}",
        "log of {n}",
        "natural log of {n}",
        "ln({n})",
        "log base {b} of {n}",
        "{a} to the {b}",
    ]

    for _ in range(n_exp):
        a, b, n = random.randint(2, 10), random.randint(2, 5), random.randint(10, 100)
        template = random.choice(exp_templates)
        casual_math.append({"text": template.format(a=a, b=b, n=n), "label": "math"})

    print(f"✓ Generated {len(casual_math)} casual math expressions")
    return casual_math

# Build the casual math corpus, then preview five randomly chosen queries
print("Generating casual math dataset...")
casual_math_data = generate_casual_math_dataset(5000)

print("\n📝 Sample math queries:")
for sample_row in random.sample(casual_math_data, 5):
    print("  - " + sample_row['text'])


In [25]:
import random

def generate_diverse_other_data(num_samples=3000):
    """Generate greetings, chitchat, and general-knowledge queries labelled ``"other"``.

    The requested total is split across six categories in fixed proportions
    (greetings 20%, geography 20%, history ~16.7%, science 20%, people ~13.3%,
    general knowledge 10%); the default of 3000 yields 600/600/500/600/400/300.

    Args:
        num_samples: Total number of queries to generate (non-negative int).

    Returns:
        list[dict]: ``{"text": <query>, "label": "other"}`` records, grouped by
        category in the order listed above (the list is not shuffled).

    Raises:
        ValueError: If ``num_samples`` is negative.
    """
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    # Relative category sizes; for num_samples=3000 these are the exact counts.
    category_weights = (600, 600, 500, 600, 400, 300)
    weight_total = sum(category_weights)
    # Floor every share except the last one, which absorbs the rounding
    # remainder so the counts always add up to exactly num_samples.
    category_counts = [num_samples * w // weight_total for w in category_weights[:-1]]
    category_counts.append(num_samples - sum(category_counts))
    n_greetings, n_geo, n_history, n_science, n_people, n_general = category_counts

    other_data = []

    # 1. GREETINGS & CHITCHAT
    greetings_base = [
        "hi", "hello", "hey", "hi there", "hey there", "hello there",
        "good morning", "good afternoon", "good evening", "good night",
        "whats up", "sup", "wassup", "how are you", "how are you doing",
        "hows it going", "how you doing", "nice to meet you",
        "pleased to meet you", "goodbye", "bye", "see you later",
        "talk to you later", "thanks", "thank you", "thanks a lot",
        "appreciate it", "yes", "no", "okay", "ok", "sure", "alright",
        "cool", "nice", "great", "awesome", "perfect",
        "help", "help me", "can you help", "i need help", "help please",
    ]

    # Four surface forms per greeting: as-is, capitalized, with "!" and with "?".
    greeting_variants = []
    for greeting in greetings_base:
        greeting_variants.extend(
            [greeting, greeting.capitalize(), greeting + "!", greeting + "?"]
        )

    if n_greetings <= len(greeting_variants):
        chosen_greetings = random.sample(greeting_variants, n_greetings)
    else:
        # Keep every variant once, then pad with random repeats so the quota is
        # actually met (there are fewer distinct variants than the quota).
        extra = n_greetings - len(greeting_variants)
        chosen_greetings = greeting_variants + random.choices(greeting_variants, k=extra)
    other_data.extend({"text": text, "label": "other"} for text in chosen_greetings)

    # 2. COUNTRY/GEOGRAPHY
    geo_templates = [
        "capital of {country}",
        "what is the capital of {country}",
        "whats the capital of {country}",
        "population of {country}",
        "where is {country}",
        "largest city in {country}",
        "official language of {country}",
        "{country} capital",
    ]

    countries = [
        "france", "usa", "china", "india", "brazil", "russia", "japan",
        "germany", "uk", "canada", "australia", "italy", "spain", "mexico",
        "south korea", "egypt", "turkey", "argentina", "poland", "netherlands"
    ]

    for _ in range(n_geo):
        template = random.choice(geo_templates)
        country = random.choice(countries)
        other_data.append({"text": template.format(country=country), "label": "other"})

    # 3. HISTORY
    history_templates = [
        "when did {event}",
        "what year did {event}",
        "what caused {event}",
        "who {action}",
    ]

    events = [
        "world war 2 end", "world war 1 start", "the titanic sink",
        "dinosaurs go extinct", "the roman empire fall", "america get founded",
        "the cold war begin", "berlin wall fall", "moon landing happen"
    ]

    actions = [
        "invented the telephone", "discovered america", "painted the mona lisa",
        "wrote hamlet", "invented the lightbulb", "discovered penicillin",
        "founded microsoft", "invented the airplane", "discovered gravity"
    ]

    # Partition the templates once, outside the loop, by the placeholder they use.
    event_templates = [t for t in history_templates if '{event}' in t]
    action_templates = [t for t in history_templates if '{action}' in t]

    for _ in range(n_history):
        # Roughly half the history queries are about events, half about people/actions.
        if random.random() > 0.5:
            query = random.choice(event_templates).format(event=random.choice(events))
        else:
            query = random.choice(action_templates).format(action=random.choice(actions))
        other_data.append({"text": query, "label": "other"})

    # 4. SCIENCE & NATURE
    # Each query is listed once so that all of them are sampled with equal weight.
    science_queries = [
        "what is photosynthesis", "how does gravity work", "what is dna",
        "explain evolution", "what is a black hole", "largest planet in solar system",
        "smallest planet in solar system", "what is climate change",
        "how do vaccines work", "what is quantum physics",
        "how does the sun work", "what are atoms", "what is electricity",
        "how do magnets work", "what causes earthquakes", "what is lightning",
        "why is the sky blue", "what causes rain", "what is thunder"
    ]

    for _ in range(n_science):
        other_data.append({"text": random.choice(science_queries), "label": "other"})

    # 5. PEOPLE & CULTURE
    people_templates = [
        "who is {person}",
        "who was {person}",
        "tell me about {person}",
        "what did {person} do",
    ]

    people = [
        "einstein", "newton", "shakespeare", "mozart", "picasso",
        "gandhi", "tesla", "darwin", "galileo", "aristotle",
        "cleopatra", "napoleon", "beethoven", "da vinci", "plato"
    ]

    for _ in range(n_people):
        template = random.choice(people_templates)
        person = random.choice(people)
        other_data.append({"text": template.format(person=person), "label": "other"})

    # 6. GENERAL KNOWLEDGE
    general = [
        "what is the internet", "what is democracy", "what is capitalism",
        "what is the meaning of life", "what is artificial intelligence",
        "what is blockchain", "what is cryptocurrency", "what is the universe",
        "what is time", "what is consciousness", "what is love", "what is art"
    ]

    for _ in range(n_general):
        other_data.append({"text": random.choice(general), "label": "other"})

    print(f"✓ Generated {len(other_data)} diverse 'other' queries")
    return other_data

# Build the "other" corpus, then preview ten randomly chosen queries
print("Generating diverse 'other' dataset...")
diverse_other_data = generate_diverse_other_data(3000)

print("\n📝 Sample 'other' queries:")
for sample_row in random.sample(diverse_other_data, 10):
    print("  - " + sample_row['text'])

In [26]:
# Cell 3: Download datasets
from datasets import load_dataset

print("Downloading datasets...")
print("-" * 50)

# NOTE: trust_remote_code is deliberately not passed. That flag opts in to
# executing loader code shipped inside the dataset repository.
# NOTE(review): assumes none of these three repos relies on a loading script;
# confirm on the Hub if a load fails.

# 1. GSM8K for word problems
print("📊 Downloading GSM8K (Math word problems)...")
gsm8k = load_dataset("openai/gsm8k", "main")
print(f"✓ GSM8K loaded: {len(gsm8k['train'])} examples")

# 2. Python Code Instructions
print("\n💻 Downloading Python Code Instructions...")
code_dataset = load_dataset("iamtarun/python_code_instructions_18k_alpaca")
print(f"✓ Code dataset loaded: {len(code_dataset['train'])} examples")

# 3. SQuAD for general questions
print("\n🌐 Downloading SQuAD...")
squad = load_dataset("rajpurkar/squad")
print(f"✓ SQuAD loaded: {len(squad['train'])} examples")

print("\n" + "=" * 50)
print("✅ All datasets downloaded!")


In [27]:
# Cell 4: Create balanced dataset with all improvements
from datasets import Dataset, concatenate_datasets

print("Creating balanced dataset...")
print("-" * 50)

# 1. MATH: GSM8K word problems (first 5,000) + generated casual expressions
print("🔢 Creating balanced math dataset...")
gsm8k_data = [
    {"text": example["question"], "label": "math"}
    # zip with a range caps the iteration at the first 5,000 training examples
    for _, example in zip(range(5000), gsm8k["train"])
]

balanced_math_data = gsm8k_data + casual_math_data
print(f"  - GSM8K word problems: {len(gsm8k_data)}")
print(f"  - Casual expressions: {len(casual_math_data)}")
print(f"  ✓ Total math: {len(balanced_math_data)} samples")

# 2. CODE: Python instructions (drawn from the first 10,000 examples)
print("\n💻 Preparing code dataset...")
code_data = []
for _, example in zip(range(10000), code_dataset['train']):
    # Use the first non-empty field among instruction / prompt / input.
    text = example.get('instruction') or example.get('prompt') or example.get('input', '')
    if text:  # Only add if text exists
        code_data.append({"text": text, "label": "code"})
print(f"  ✓ Code: {len(code_data)} samples")

# 3. OTHER: SQuAD questions (first 7,000) + generated diverse queries
print("\n🌐 Preparing other/general dataset...")
squad_data = [
    {"text": example['question'], "label": "other"}
    for _, example in zip(range(7000), squad['train'])
]

# SQuAD questions first, then the generated greetings/geography/etc. queries
other_data = squad_data + diverse_other_data

# Report the number of SQuAD rows actually collected instead of a fixed figure.
print(f"  - SQuAD questions: {len(squad_data):,}")
print(f"  - Greetings/geography/science/diverse: {len(diverse_other_data)}")
print(f"  ✓ Total other: {len(other_data)} samples")

# 4. Combine and shuffle
print("\n🔀 Combining and shuffling...")
math_dataset = Dataset.from_list(balanced_math_data)
code_dataset_final = Dataset.from_list(code_data)
other_dataset_final = Dataset.from_list(other_data)

final_dataset = concatenate_datasets([math_dataset, code_dataset_final, other_dataset_final])
final_dataset = final_dataset.shuffle(seed=42)

# Derive the composition percentages from the real list sizes so the summary
# stays truthful if any source yields more or fewer rows than planned.
math_word_pct = 100 * len(gsm8k_data) / max(len(balanced_math_data), 1)
squad_pct = 100 * len(squad_data) / max(len(other_data), 1)

print("\n" + "=" * 50)
print("📊 FINAL DATASET STATISTICS")
print("=" * 50)
print(f"Total samples: {len(final_dataset)}")
print(
    f"  - Math: {len(balanced_math_data)} "
    f"({math_word_pct:.0f}% word problems + {100 - math_word_pct:.0f}% casual)"
)
print(f"  - Code: {len(code_data)}")
print(
    f"  - Other: {len(other_data)} "
    f"({squad_pct:.0f}% SQuAD + {100 - squad_pct:.0f}% greetings/diverse)"
)

# Save
final_dataset.save_to_disk("./moe_router_dataset_v2")
print(f"\n✓ Saved to: ./moe_router_dataset_v2")


In [28]:
# Cell 5: Load DistilBERT and tokenize
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from datasets import load_from_disk

# Reload the dataset written to disk earlier and hold out 10% for validation
dataset = load_from_disk("./moe_router_dataset_v2")
split_dataset = dataset.train_test_split(test_size=0.1, seed=42)

n_train, n_val = len(split_dataset['train']), len(split_dataset['test'])
print(f"Training: {n_train} samples")
print(f"Validation: {n_val} samples")

# Class names <-> integer ids; the reverse map is derived from the forward one
label2id = {"math": 0, "code": 1, "other": 2}
id2label = {class_id: class_name for class_name, class_id in label2id.items()}

# Pre-trained DistilBERT with a fresh classification head (one output per class)
model_name = "distilbert-base-uncased"
print(f"\nLoading {model_name}...")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(label2id),
    id2label=id2label,
    label2id=label2id,
)

print(f"✓ Model loaded: {model.num_parameters():,} parameters")

# Tokenization
def preprocess(examples):
    """Tokenize one record's text and replace its string label with the class id.

    The text is truncated/padded to a fixed length of 128 tokens. The function
    is mapped without batching, so ``examples["label"]`` is a single string.
    """
    encoded = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=128,
    )
    encoded["label"] = label2id[examples["label"]]
    return encoded

print("\nTokenizing dataset...")
# The raw "text" column is dropped once it has been tokenized
tokenized_dataset = split_dataset.map(preprocess, remove_columns=["text"])
print("✓ Tokenization complete!")


In [29]:
# Cell 6: Setup training
from transformers import TrainingArguments, Trainer
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
import numpy as np
import os

# Turn tokenizer parallelism off (silences the tokenizers warning)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

def compute_metrics(eval_pred):
    """Compute accuracy plus macro, weighted and per-class F1 for an evaluation pass.

    Args:
        eval_pred: A ``(predictions, labels)`` pair; the predicted class is the
            argmax of ``predictions`` along axis 1.

    Returns:
        dict with keys ``accuracy``, ``f1_macro``, ``f1_weighted``,
        ``f1_math``, ``f1_code`` and ``f1_other``.
    """
    scores, labels = eval_pred
    predictions = np.argmax(scores, axis=1)

    # Index 2 of the result tuple is the per-class F1 array, ordered by class
    # id (0=math, 1=code, 2=other); precision, recall and support are unused.
    per_class_f1 = precision_recall_fscore_support(
        labels, predictions, average=None, labels=[0, 1, 2]
    )[2]

    metrics = {
        'accuracy': accuracy_score(labels, predictions),
        'f1_macro': f1_score(labels, predictions, average='macro'),
        'f1_weighted': f1_score(labels, predictions, average='weighted'),
    }
    for class_id, class_name in enumerate(('math', 'code', 'other')):
        metrics[f'f1_{class_name}'] = per_class_f1[class_id]
    return metrics

training_args = TrainingArguments(
    **{
        # Where checkpoints and logs are written
        "output_dir": "Classifier/checkpoints",
        "logging_dir": "Classifier/logs",

        # Reproducibility
        "seed": 42,

        # Schedule and batch sizes
        "num_train_epochs": 3,
        "per_device_train_batch_size": 64,
        "per_device_eval_batch_size": 128,
        "gradient_accumulation_steps": 1,

        # Optimizer settings
        "learning_rate": 2e-5,
        "weight_decay": 0.01,
        "warmup_ratio": 0.1,

        # Mixed precision and data loading
        "fp16": True,
        "dataloader_num_workers": 4,

        # Evaluate and checkpoint on the same 200-step cadence; keep at most
        # three checkpoints and reload the best one (by macro F1) at the end
        "eval_strategy": "steps",
        "eval_steps": 200,
        "save_strategy": "steps",
        "save_steps": 200,
        "save_total_limit": 3,
        "load_best_model_at_end": True,
        "metric_for_best_model": "f1_macro",

        # Logging: every 50 steps, no external reporting integration
        "logging_steps": 50,
        "report_to": "none",
    }
)

print("✓ Training configuration ready!")


In [30]:
# Cell 7: Train
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    compute_metrics=compute_metrics,
)

# Look for checkpoints left in the trainer's output directory by an earlier run.
# The directory is read from training_args so it cannot drift from the trainer.
checkpoint_dir = training_args.output_dir
checkpoint_prefix = "checkpoint-"
checkpoints = []
if os.path.isdir(checkpoint_dir):
    checkpoints = sorted(
        (
            os.path.join(checkpoint_dir, entry)
            for entry in os.listdir(checkpoint_dir)
            # Accept only "checkpoint-<step>" directories. Anything else that
            # shares the prefix (stray files, "checkpoint-tmp", ...) would
            # break the numeric sort or be an invalid resume target.
            if entry.startswith(checkpoint_prefix)
            and entry[len(checkpoint_prefix):].isdecimal()
            and os.path.isdir(os.path.join(checkpoint_dir, entry))
        ),
        # Sort by step number, not lexicographically (checkpoint-1000 > checkpoint-200)
        key=lambda path: int(path.rsplit("-", 1)[-1]),
    )

if checkpoints:
    latest_checkpoint = checkpoints[-1]
    print(f"✓ Resuming from: {latest_checkpoint}")
    resume_from = latest_checkpoint
else:
    print("✓ Starting fresh training")
    resume_from = None

print("\n" + "=" * 50)
print("STARTING TRAINING")
print("=" * 50)

# Train (resume_from is None for a fresh run)
trainer.train(resume_from_checkpoint=resume_from)

print("\n✅ Training complete!")


In [31]:
# Cell 8: Evaluate and save
import json

# Evaluate the trained model once more on the held-out split
print("Running final evaluation...")
eval_results = trainer.evaluate()

separator = "=" * 50
print("\n" + separator)
print("FINAL EVALUATION RESULTS")
print(separator)
# Only float-valued entries are printed, to four decimal places
for metric_name in eval_results:
    metric_value = eval_results[metric_name]
    if not isinstance(metric_value, float):
        continue
    print(f"{metric_name}: {metric_value:.4f}")

# Persist model weights and tokenizer side by side
save_path = "Classifier"
print(f"\nSaving model to {save_path}...")

trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)

# Router metadata stored next to the model: label maps, metrics, data provenance
dataset_composition = {
    "math": "50% GSM8K word problems + 50% casual expressions",
    "code": "Python instruction dataset",
    "other": "70% SQuAD questions + 30% greetings/geography/science/diverse"
}
config_dict = {
    "model_type": "distilbert_classifier",
    "base_model": "distilbert-base-uncased",
    "num_labels": 3,
    "label2id": label2id,
    "id2label": id2label,
    "training_results": eval_results,
    "dataset_composition": dataset_composition,
}

with open(f"{save_path}/router_config.json", "w") as f:
    json.dump(config_dict, f, indent=2)

print(f"✓ Model saved to: {save_path}")
print(f"✓ Config saved to: {save_path}/router_config.json")
print("\n✅ All done!")


In [32]:
# Cell 9: Comprehensive testing with 40+ queries
from transformers import pipeline
import random

# Load the saved classifier as a text-classification pipeline on device 0
router = pipeline("text-classification", model="Classifier", device=0)

# Hand-written probe queries, grouped by theme; each group carries the label
# the router is expected to predict for every query in it.
_probe_groups = [
    # MATH (10)
    ("math", [
        "whats the integral of x squared",
        "solve 2x + 5 = 15",
        "derivative of sin(x)",
        "area of circle radius 7",
        "what is 25% of 200",
        "calculate 456 + 789",
        "find x when x^2 = 16",
        "whats 15 times 8",
        "volume of sphere radius 5",
        "simplify 3x + 2x",
    ]),
    # CODE (8)
    ("code", [
        "how to reverse string in python",
        "debug null pointer exception",
        "write function to sort array",
        "make code to find prime numbers",
        "how do i loop through a list in python",
        "fix my javascript error",
        "create function that checks palindrome",
        "python code for binary search",
    ]),
    # OTHER - General Knowledge (7)
    ("other", [
        "who made the telephone",
        "capital of france",
        "when did world war 2 end",
        "who is the president of usa",
        "what is photosynthesis",
        "largest planet in solar system",
        "who wrote hamlet",
    ]),
    # OTHER - Greetings & Chitchat (8)
    ("other", [
        "hi how are you",
        "hello",
        "whats up",
        "good morning",
        "how are you doing",
        "hey there",
        "thanks",
        "bye",
    ]),
    # OTHER - Edge Cases (7)
    ("other", [
        "help",
        "what is python",
        "tell me about einstein",
        "i need help",
        "ok",
        "yes",
        "no",
    ]),
]

# Flatten the groups into the list of {"query", "expected"} records used below
comprehensive_tests = [
    {"query": probe, "expected": expected_label}
    for expected_label, probes in _probe_groups
    for probe in probes
]

# Randomize the order so the categories are interleaved in the printed report
random.shuffle(comprehensive_tests)

rule = "=" * 80
print(rule)
print("COMPREHENSIVE ROUTER TESTING - 40 DIVERSE QUERIES")
print(rule)

# Overall tally plus one correct/total tally per expected category
total = len(comprehensive_tests)
correct = 0
results_by_category = {
    label: {"correct": 0, "total": 0} for label in ("math", "code", "other")
}

print("\n")
for i, test in enumerate(comprehensive_tests, 1):
    query, expected = test["query"], test["expected"]

    # The first entry of the pipeline output is used as the prediction
    result = router(query)[0]
    predicted = result['label'].lower()
    confidence = result['score']

    is_correct = (predicted == expected)
    category_tally = results_by_category[expected]
    category_tally["total"] += 1
    if is_correct:
        correct += 1
        category_tally["correct"] += 1

    status = "✅" if is_correct else "❌"
    # Long queries are cut at 45 characters and marked with an ellipsis
    shown_query = query[:45] + ('...' if len(query) > 45 else '')

    print(f'{i:2d}. {status} "{shown_query}"')
    print(f"     Expected: {expected.upper():5s} | Predicted: {predicted.upper():5s} ({confidence:.0%})")

    if not is_correct:
        print(f"     ⚠️  MISMATCH!")

    # Blank line after every fifth result to keep the listing readable
    if i % 5 == 0:
        print()

print("\n" + rule)
print("OVERALL RESULTS")
print(rule)
print(f"Total Accuracy: {correct}/{total} correct ({correct/total*100:.1f}%)")
print()

print("By Category:")
for category, stats in results_by_category.items():
    if stats["total"] <= 0:
        continue
    accuracy = (stats["correct"] / stats["total"]) * 100
    print(f"  {category.upper():6s}: {stats['correct']:2d}/{stats['total']:2d} correct ({accuracy:.0f}%)")

print("\n" + rule)

# First threshold met (highest first) decides the closing verdict
overall_ratio = correct / total
verdict_thresholds = (
    (0.95, "🎉 EXCELLENT! Router handles all query types perfectly!"),
    (0.90, "👍 VERY GOOD! Router works great with minimal issues."),
    (0.85, "✓ GOOD! Router works well overall."),
)
for threshold, verdict in verdict_thresholds:
    if overall_ratio >= threshold:
        print(verdict)
        break
else:
    print("⚠️  Needs improvement on some categories.")

print(rule)
