# Template Bank Pipeline (Example-Driven Categories)

Use example-driven category classification, then match SQL templates within the winning category.

In [1]:
# Import libraries
import json
from pathlib import Path

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer


In [2]:
# Load the 'all-MiniLM-L6-v2' SentenceTransformer once for the whole notebook.
# Later cells call model.encode(..., normalize_embeddings=True) to embed the
# example prompts, the template names, and each incoming user query.
model = SentenceTransformer('all-MiniLM-L6-v2')

## 1) Load example prompts for category classification

In [3]:
# Resolve the prompt data directory: ./data/prompts when it exists,
# otherwise fall back to ../data/prompts (used as-is, even if missing).
base = Path.cwd()
data_root = base / 'data' / 'prompts'
data_root = data_root if data_root.exists() else base.parent / 'data' / 'prompts'

examples_path = data_root.joinpath('prompt_examples.json')
labels_path = data_root.joinpath('prompt_labels.json')

# Example prompts per category: read them from JSON when available, otherwise
# warn and continue with a tiny built-in set (two prompts per category).
if not examples_path.exists():
    print('WARNING: prompt_examples.json not found. Using minimal fallback examples.')
    prompt_examples = {
        'frequency': ['How many runs this week?', 'How often is she running?'],
        'volume': ['Total distance last 21 days?', 'How many calories burned?'],
        'intensity': ['What is her average pace?', 'How is her heart rate?'],
        'progression': ['Is she improving over time?', 'Show her pace trend'],
        'performance': ['Show me the last workout', 'What was her fastest time?'],
        'distribution': ['What cardio types does she do?', 'How varied are her workouts?'],
        'recovery': ['How many rest days?', 'Average rest between sessions?'],
    }
else:
    prompt_examples = json.loads(examples_path.read_text(encoding='utf-8'))
    print('Loaded prompt examples from', examples_path)

# Optional label map; stays empty when the file is absent.
prompt_labels_map = {}
if labels_path.exists():
    prompt_labels_map = json.loads(labels_path.read_text(encoding='utf-8'))
    print('Loaded prompt labels from', labels_path)

# Flatten into two parallel lists: category_labels[i] is the category of
# category_examples[i]. Insertion order of prompt_examples is preserved.
category_examples = []
category_labels = []

for category, examples in prompt_examples.items():
    for example in examples:
        category_examples.append(example)
        category_labels.append(category)

print('Total example prompts:', len(category_examples))


Loaded prompt examples from /Users/evanlee/Desktop/FibuVerse/CardioChatBot/data/prompts/prompt_examples.json
Total example prompts: 240


## 2) Build example embeddings + FAISS index

In [4]:
# Create embeddings for all example prompts (normalize for cosine similarity).
# category_examples and category_labels are parallel lists, so row i of
# example_embeddings is expected to belong to category_labels[i].
example_embeddings = model.encode(category_examples, normalize_embeddings=True)

# Build FAISS index with cosine similarity (inner product on normalized vectors).
# The index width is taken from the embedding matrix's second axis.
# NOTE(review): in the cells visible in this notebook, index_ip is not queried;
# classify_category_by_examples computes similarities directly with np.dot on
# example_embeddings. Confirm whether the index is still needed.
dimension = example_embeddings.shape[1]
index_ip = faiss.IndexFlatIP(dimension)
index_ip.add(example_embeddings)

# Summary of what was indexed, for a quick sanity check when re-running.
print('Created FAISS IP index with', len(category_examples), 'example prompts')
print('Embedding dimension:', dimension)
print('Categories available:', list(prompt_examples.keys()))


Created FAISS IP index with 240 example prompts
Embedding dimension: 384
Categories available: ['frequency', 'volume', 'intensity', 'progression', 'performance', 'distribution', 'recovery', 'other/unknown']


## 3) Example-driven category classification

In [5]:
def classify_category_by_examples(user_prompt, k=5, min_confidence=0.6, temperature=0.07, other_category_count=3):
    """Classify a user prompt into a cardio analysis category via its nearest example prompts.

    The prompt is embedded with the notebook-level ``model`` and compared (dot
    product on normalized vectors, i.e. cosine similarity) against every row of
    ``example_embeddings``. The ``k`` most similar examples vote for their
    category from ``category_labels``, each vote weighted by a softmax over the
    top-k similarities.

    Args:
        user_prompt: The question to classify.
        k: Number of nearest examples that vote. Values <= 0 are treated as 1;
            values larger than the example set are clamped to its size.
        min_confidence: Rejection threshold applied to the best cosine
            similarity. ``None`` disables rejection.
        temperature: Softmax temperature for the vote weights (floored at 1e-6).
        other_category_count: Maximum number of entries in ``other_categories``.

    Returns:
        dict with keys:
            predicted_category: Winning category, or 'other/unknown' when rejected.
            confidence: The best cosine similarity (same value as top_similarity);
                this is what ``min_confidence`` is compared against.
            softmax_confidence: Share of the softmax vote mass won by the top
                category (computed before any rejection).
            reject: True when confidence < min_confidence, or when there are no
                examples to compare against.
            top_similarity: The best cosine similarity among the examples.
            other_categories: Candidate categories ranked by vote share, each as
                {'category', 'confidence'}. The list includes the winning
                category and is not altered by rejection.
    """
    # Without any examples there is nothing to vote; previously this crashed
    # inside argpartition / max() before the intended empty-guard was reached.
    if len(example_embeddings) == 0:
        return {
            'predicted_category': 'other/unknown',
            'confidence': 0.0,
            'softmax_confidence': 0.0,
            'reject': True,
            'top_similarity': 0.0,
            'other_categories': [],
        }

    prompt_embedding = model.encode([user_prompt], normalize_embeddings=True)

    # Cosine similarity via dot product on normalized vectors
    sims = np.dot(example_embeddings, prompt_embedding.T).reshape(-1)

    # Top-k indices by similarity: argpartition selects the k best in O(n),
    # then only those k are fully sorted (best first).
    top_n = min(k if k > 0 else 1, len(sims))
    topk_idx = np.argpartition(-sims, top_n - 1)[:top_n]
    topk_idx = topk_idx[np.argsort(-sims[topk_idx])]

    nearest_categories = [category_labels[idx] for idx in topk_idx]
    sims_arr = np.asarray(sims[topk_idx], dtype=float)

    # Distance-aware class scoring using softmax over similarities.
    # Subtracting the max keeps exp() stable; the best neighbour always gets
    # weight exp(0) == 1, so the total vote mass is never zero.
    temp = max(float(temperature), 1e-6)
    weights = np.exp((sims_arr - sims_arr.max()) / temp)

    scores = {}
    for cat, weight in zip(nearest_categories, weights):
        scores[cat] = scores.get(cat, 0.0) + float(weight)

    predicted_category = max(scores, key=scores.get)
    total_score = sum(scores.values())
    softmax_confidence = scores[predicted_category] / total_score

    # The reported confidence is the raw best similarity, not the vote share.
    top_similarity = float(sims_arr.max())
    confidence = top_similarity
    reject = bool(min_confidence is not None and confidence < min_confidence)

    ranked_categories = sorted(
        ({'category': cat, 'confidence': score / total_score} for cat, score in scores.items()),
        key=lambda entry: entry['confidence'],
        reverse=True,
    )
    other_categories = ranked_categories[:other_category_count]

    if reject:
        predicted_category = 'other/unknown'

    return {
        'predicted_category': predicted_category,
        'confidence': confidence,
        'softmax_confidence': softmax_confidence,
        'reject': reject,
        'top_similarity': top_similarity,
        'other_categories': other_categories,
    }


## 4) Test classification accuracy

In [25]:
# Hand-labelled smoke test: (query, expected category) pairs, two per category.
labeled_queries = [
    ('How many runs did Sarah do this week?', 'frequency'),
    ('How often is she running?', 'frequency'),
    ('Total distance last 21 days?', 'volume'),
    ('How many calories burned this month?', 'volume'),
    ('What is her average pace?', 'intensity'),
    ('How fast was her pace?', 'intensity'),
    ('Is she improving over time?', 'progression'),
    ('Show her pace trend for the last month', 'progression'),
    ('Show me the last workout', 'performance'),
    ('What was her fastest time?', 'performance'),
    ('What cardio types does she do?', 'distribution'),
    ('How varied are her workouts?', 'distribution'),
    ('How many rest days did she take?', 'recovery'),
    ('Average rest between sessions?', 'recovery'),
]

# Classify each query, print one OK/MISS line per query, and count the hits.
# 'conf' is the classifier's reported confidence (its best cosine similarity).
correct = 0
for query, expected in labeled_queries:
    result = classify_category_by_examples(query)
    pred, conf = result['predicted_category'], result['confidence']
    is_correct = pred == expected
    if is_correct:
        correct += 1
        status = 'OK'
    else:
        status = 'MISS'
    print(f'[{status}] {query} -> {pred} (exp: {expected}, conf: {conf:.3f})')

# Fraction of queries whose predicted category matched the label.
accuracy = correct / len(labeled_queries)
print(f"\nCategory accuracy: {accuracy:.3f} ({correct}/{len(labeled_queries)})")


[MISS] How many runs did Sarah do this week? -> volume (exp: frequency, conf: 0.880)
[MISS] How often is she running? -> volume (exp: frequency, conf: 0.621)
[OK] Total distance last 21 days? -> volume (exp: volume, conf: 0.941)
[OK] How many calories burned this month? -> volume (exp: volume, conf: 0.902)
[OK] What is her average pace? -> intensity (exp: intensity, conf: 0.699)
[OK] How fast was her pace? -> intensity (exp: intensity, conf: 0.673)
[MISS] Is she improving over time? -> other/unknown (exp: progression, conf: 0.462)
[OK] Show her pace trend for the last month -> progression (exp: progression, conf: 0.670)
[OK] Show me the last workout -> performance (exp: performance, conf: 0.955)
[OK] What was her fastest time? -> performance (exp: performance, conf: 0.678)
[OK] What cardio types does she do? -> distribution (exp: distribution, conf: 0.722)
[MISS] How varied are her workouts? -> frequency (exp: distribution, conf: 0.701)
[OK] How many rest days did she take? -> recovery

## 5) Define SQL templates and build template embeddings

In [6]:
# SQL templates per category.
# Structure: {category: {template_name: sql_string}}. Each SQL string is a
# str.format template with {client_id} and (for windowed queries) {days}
# placeholders. Template names double as the text that gets embedded for
# template matching, so keep them descriptive.
# The SQL targets SQLite (DATE('now', ...), julianday, window functions).
SQL_TEMPLATES = {
    "volume": {
        "total_distance": "SELECT SUM(distance) FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days')",
        "total_duration": "SELECT SUM(duration) FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days')",
        "total_calories": "SELECT SUM(calories_burned) FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days')",
        "avg_distance_per_session": "SELECT AVG(distance) FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days')",
    },

    "frequency": {
        "total_sessions": "SELECT COUNT(*) FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days')",
        "sessions_per_week": "SELECT COUNT(*)/({days}/7.0) FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days')",
        "rest_days": "SELECT {days} - COUNT(DISTINCT DATE(cardio_date)) FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days')",
    },

    "intensity": {
        "avg_heart_rate": "SELECT AVG(avg_heart_rate) FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days') AND avg_heart_rate IS NOT NULL",
        "max_heart_rate": "SELECT MAX(max_heart_rate) FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days') AND max_heart_rate IS NOT NULL",
        "avg_pace": "SELECT AVG(avg_pace) FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days') AND avg_pace IS NOT NULL",
        "avg_speed": "SELECT AVG(avg_speed) FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days') AND avg_speed IS NOT NULL",
    },

    "progression": {
        "pace_trend": "SELECT cardio_date, avg_pace FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days') AND avg_pace IS NOT NULL ORDER BY cardio_date",
        "distance_trend": "SELECT cardio_date, distance FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days') ORDER BY cardio_date",
        "fastest_pace": "SELECT MIN(avg_pace) as fastest_pace, cardio_date FROM Cardio WHERE client_id = {client_id} AND avg_pace IS NOT NULL",
    },

    "performance": {
        "last_workout": "SELECT * FROM Cardio WHERE client_id = {client_id} ORDER BY cardio_date DESC LIMIT 1",
        "elevation_gain": "SELECT SUM(elevation_gain) FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days')",
        "best_distance": "SELECT MAX(distance) as longest_run, cardio_date FROM Cardio WHERE client_id = {client_id}",
    },

    "distribution": {
        "cardio_type_breakdown": "SELECT cardio_type, COUNT(*) as count FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days') GROUP BY cardio_type",
        "workout_variety": "SELECT COUNT(DISTINCT cardio_type) as variety FROM Cardio WHERE client_id = {client_id} AND cardio_date >= DATE('now', '-{days} days')",
    },

    "recovery": {
        # A window function (LAG) cannot be nested inside an aggregate (AVG) in
        # the same SELECT, so the per-session gap is computed in a subquery and
        # averaged outside. The first session has a NULL gap, which AVG ignores.
        "avg_rest_between": (
            "SELECT AVG(gap_days) as avg_rest_days FROM ("
            "SELECT julianday(cardio_date) - LAG(julianday(cardio_date)) OVER (ORDER BY cardio_date) AS gap_days "
            "FROM Cardio WHERE client_id = {client_id})"
        ),
        # Longest run of consecutive calendar days with at least one session
        # (gaps-and-islands): for consecutive days, julianday(day) minus the
        # day's row number is constant, so grouping by it isolates each streak.
        "longest_streak": (
            "SELECT MAX(streak) FROM ("
            "SELECT COUNT(*) as streak FROM ("
            "SELECT julianday(workout_day) - ROW_NUMBER() OVER (ORDER BY workout_day) AS streak_group FROM ("
            "SELECT DISTINCT DATE(cardio_date) AS workout_day FROM Cardio WHERE client_id = {client_id})"
            ") GROUP BY streak_group)"
        ),
    }
}


In [7]:
def build_template_embeddings(sql_templates):
    """Embed every template name, grouped by category.

    Args:
        sql_templates: Mapping of category -> {template_name: sql_string}.
            Only the template names are used; the SQL text is not embedded.

    Returns:
        dict mapping category -> {template_name: embedding}, where each
        embedding is what ``model.encode`` returns (with
        ``normalize_embeddings=True``) for the name with underscores
        replaced by spaces.
    """
    def _embed_name(template_name):
        # 'total_distance' -> 'total distance' so the name reads like a phrase.
        phrase = template_name.replace('_', ' ')
        return model.encode([phrase], normalize_embeddings=True)[0]

    return {
        category: {name: _embed_name(name) for name in templates.keys()}
        for category, templates in sql_templates.items()
    }

# Embed all template names once up front; the total printed below is the
# number of templates summed over every category.
TEMPLATE_EMBEDDINGS = build_template_embeddings(SQL_TEMPLATES)
print('Built template embeddings for', sum(map(len, TEMPLATE_EMBEDDINGS.values())), 'templates')


Built template embeddings for 21 templates


## 6) End-to-end: category + template selection

In [9]:
def query_to_sql(query, client_id=123, days=21):
    """Turn a natural-language question into a formatted SQL statement.

    Steps: (1) classify the question into a category with
    ``classify_category_by_examples``; (2) within that category, pick the
    template whose embedded name is most similar to the question; (3) fill the
    template's ``{client_id}`` / ``{days}`` placeholders.

    Args:
        query: The user's question.
        client_id: Value substituted for ``{client_id}`` in the template.
        days: Value substituted for ``{days}`` (look-back window) in the template.

    Returns:
        dict with keys 'category', 'category_confidence', 'template',
        'template_confidence', 'sql', 'reject', 'top_similarity'.
        'template', 'template_confidence' and 'sql' are None when the category
        is 'other/unknown' or when the predicted category has no templates
        (e.g. the examples file contains a category missing from SQL_TEMPLATES).

    NOTE(review): values are interpolated with str.format, not bound as SQL
    parameters. Only pass trusted, numeric client_id / days values here.
    """
    # Step 1: classify category using examples
    cat_result = classify_category_by_examples(query)
    best_cat = cat_result['predicted_category']

    # Start from the "no template" result; the template fields are filled in
    # only when a template is actually selected.
    result = {
        'category': best_cat,
        'category_confidence': cat_result['confidence'],
        'template': None,
        'template_confidence': None,
        'sql': None,
        'reject': cat_result['reject'],
        'top_similarity': cat_result['top_similarity'],
    }

    # A category without templates used to raise KeyError; treat it like an
    # unmatched question instead.
    category_templates = TEMPLATE_EMBEDDINGS.get(best_cat)
    if best_cat == 'other/unknown' or not category_templates:
        return result

    # Step 2: match template within category (dot product of normalized
    # embeddings, i.e. cosine similarity)
    q_emb = model.encode([query], normalize_embeddings=True)[0]
    temp_scores = {
        name: float(np.dot(q_emb, emb))
        for name, emb in category_templates.items()
    }
    best_temp = max(temp_scores, key=temp_scores.get)

    # Step 3: format SQL
    result['template'] = best_temp
    result['template_confidence'] = temp_scores[best_temp]
    result['sql'] = SQL_TEMPLATES[best_cat][best_temp].format(client_id=client_id, days=days)
    return result


In [10]:
# Interactive query: read one question from the user and run the pipeline.
user_query = input("Ask a question: ")
result = query_to_sql(user_query)

# Print the selected template and SQL only for a known category whose
# confidence is not below 0.6; otherwise print the unknown/routing messages.
if result['category'] != 'other/unknown' and not result['category_confidence'] < 0.6:
    print('Predicted category: {} ({:.3f})'.format(result['category'], result['category_confidence']))
    print('Template: {} ({:.3f})'.format(result['template'], result['template_confidence']))
    print('SQL: {}'.format(result['sql']))
else:
    print('Other/unknown question')
    print('Routing prompt to LLMs')


Other/unknown question
Routing prompt to LLMs
