# Theme Extraction & Labeling for Quote Fine-tuning

This notebook implements theme extraction from motivational quotes using Ollama and generates training data pairs for LLM fine-tuning.

**Processing Pipeline:**
1. Load unified quotes dataset (33,697 quotes)
2. Extract themes using Ollama (llama3.2:latest)
3. Generate instruction-output pairs with variations
4. Export as TorchTune-compatible JSON
5. Generate quality metrics and summary

**Target:** 50,000-100,000 training examples

## Setup & Dependencies

In [None]:
import pandas as pd
import json
import requests
import time
import logging
import os
from typing import List, Dict, Optional, Tuple
from tqdm import tqdm
from datetime import datetime
import re

In [None]:
# Configuration
BATCH_SIZE = 500  # Process 500 quotes per batch for memory efficiency
OLLAMA_URL = "http://localhost:11434/api/generate"
OLLAMA_MODEL = "llama3.2:latest"
MAX_RETRIES = 3
TIMEOUT = 30
FALLBACK_THEMES = ["motivation", "inspiration", "wisdom"]

# Paths
DATA_PATH = "../data/processed/unified_quotes_dataset.csv"
OUTPUT_DIR = "../data/training"
CACHE_FILE = os.path.join(OUTPUT_DIR, "theme_extraction_cache.json")

# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Configuration loaded:")
print(f"- Batch size: {BATCH_SIZE}")
print(f"- Ollama URL: {OLLAMA_URL}")
print(f"- Model: {OLLAMA_MODEL}")
print(f"- Output directory: {OUTPUT_DIR}")

In [None]:
# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(os.path.join(OUTPUT_DIR, 'theme_extraction.log')),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

print("Logging configured")

## Instruction Template Variations

In [None]:
# Define instruction template variations
INSTRUCTION_TEMPLATES = [
    "Give me a quote about {theme}",
    "Share wisdom on {theme}",
    "What's an inspiring quote about {theme}?",
    "Provide motivation for {theme}",
    "Give me advice about {theme}",
    "Share an inspirational message about {theme}",
    "What would you say to motivate someone about {theme}?",
    "Inspire me with words about {theme}"
]

print(f"Instruction templates defined ({len(INSTRUCTION_TEMPLATES)} variations):")
for i, template in enumerate(INSTRUCTION_TEMPLATES, 1):
    print(f"  {i}. {template}")

## Data Loading & Preparation

In [None]:
# Load the unified quotes dataset
def load_quotes_dataset(file_path: str) -> pd.DataFrame:
    """
    Load the unified quotes dataset.
    """
    try:
        df = pd.read_csv(file_path)
        logger.info(f"Loaded {len(df)} quotes from {file_path}")
        
        # Display dataset info
        print(f"\nDataset Overview:")
        print(f"- Total quotes: {len(df):,}")
        print(f"- Columns: {list(df.columns)}")
        print(f"- Source distribution:")
        print(df['source_dataset'].value_counts().to_string())
        print(f"\nSample quotes:")
        print(df[['quote_text', 'author', 'source_dataset']].head().to_string())
        
        return df
        
    except Exception as e:
        logger.error(f"Error loading dataset: {e}")
        raise

# Load the dataset
quotes_df = load_quotes_dataset(DATA_PATH)
print(f"\nDataset loaded successfully: {len(quotes_df)} quotes")

In [None]:
# Prepare batches for processing
def create_batches(df: pd.DataFrame, batch_size: int) -> List[pd.DataFrame]:
    """
    Split the DataFrame into batches for processing.
    """
    batches = []
    for i in range(0, len(df), batch_size):
        batch = df.iloc[i:i+batch_size].copy()
        batch.reset_index(drop=True, inplace=True)
        batches.append(batch)
    
    logger.info(f"Created {len(batches)} batches of size {batch_size}")
    return batches

# Create batches
quote_batches = create_batches(quotes_df, BATCH_SIZE)
print(f"Created {len(quote_batches)} batches for processing")
print(f"Batch sizes: {[len(batch) for batch in quote_batches[:5]]}{'...' if len(quote_batches) > 5 else ''}")

## Ollama Integration & Theme Extraction

In [None]:
def test_ollama_connection() -> bool:
    """
    Test connection to Ollama and verify the model is available.
    """
    try:
        # Test basic connection
        response = requests.post(
            OLLAMA_URL,
            json={
                "model": OLLAMA_MODEL,
                "prompt": "Test connection. Respond with 'OK'.",
                "stream": False
            },
            timeout=10
        )
        
        if response.status_code == 200:
            result = response.json()
            logger.info(f"Ollama connection successful. Response: {result.get('response', '')[:50]}")
            return True
        else:
            logger.error(f"Ollama connection failed. Status: {response.status_code}")
            return False
            
    except Exception as e:
        logger.error(f"Error testing Ollama connection: {e}")
        return False

# Test connection
if test_ollama_connection():
    print("✅ Ollama connection successful")
else:
    print("❌ Ollama connection failed - please check if Ollama is running")
    print("Run: ollama serve")

In [None]:
def extract_themes_from_quote(quote_text: str, retries: int = MAX_RETRIES) -> List[str]:
    """
    Extract exactly 2 themes from a quote using Ollama.
    """
    prompt = f"""Without thinking, analyze this motivational quote below and extract exactly 2 core themes. Return only the themes as a comma-separated list and nothing else.

Quote: "{quote_text}"

Return only the themes."""
    
    for attempt in range(retries):
        try:
            response = requests.post(
                OLLAMA_URL,
                json={
                    "model": OLLAMA_MODEL,
                    "prompt": prompt,
                    "stream": False
                },
                timeout=TIMEOUT
            )
            
            if response.status_code == 200:
                result = response.json()
                themes_text = result.get('response', '').strip()
                
                # Parse themes from response
                themes = parse_themes_response(themes_text)
                
                if themes:
                    return themes
                    
            logger.warning(f"Attempt {attempt + 1} failed for quote: {quote_text[:50]}...")
            
        except Exception as e:
            logger.warning(f"Attempt {attempt + 1} error: {e}")
            
        if attempt < retries - 1:
            time.sleep(1)  # Brief delay before retry
    
    # Return fallback themes if all attempts failed
    logger.error(f"All attempts failed for quote: {quote_text[:50]}... Using fallback themes.")
    return FALLBACK_THEMES[:2]  # Return 2 fallback themes


def parse_themes_response(response_text: str) -> List[str]:
    """
    Parse themes from Ollama response text, limiting to exactly 2 themes.
    """
    if not response_text:
        return []
    
    # Clean the response
    response_text = response_text.strip()
    
    # Split by comma and clean each theme
    themes = []
    for theme in response_text.split(','):
        theme = theme.strip().lower()
        # Remove quotes, periods, and extra whitespace
        theme = re.sub(r'^["\']|["\']$', '', theme)
        theme = re.sub(r'[.!?]$', '', theme)
        theme = re.sub(r'\s+', ' ', theme)
        
        # Filter out themes with more than 10 words (conversational responses)
        word_count = len(theme.split())
        
        if theme and len(theme) > 2 and word_count <= 10:  # Valid theme: >2 chars, ≤10 words
            themes.append(theme)
            
        if len(themes) >= 2:  # Limit to exactly 2 themes
            break
    
    return themes

# Test theme extraction with a sample quote
sample_quote = "The only impossible journey is the one you never begin."
sample_themes = extract_themes_from_quote(sample_quote)
print(f"Sample extraction:")
print(f"Quote: {sample_quote}")
print(f"Themes: {sample_themes}")
print(f"Number of themes: {len(sample_themes)}")

## Batch Processing with Caching

In [None]:
def load_cache() -> Dict[str, List[str]]:
    """
    Load theme extraction cache to avoid reprocessing.
    """
    if os.path.exists(CACHE_FILE):
        try:
            with open(CACHE_FILE, 'r') as f:
                cache = json.load(f)
            logger.info(f"Loaded cache with {len(cache)} entries")
            return cache
        except Exception as e:
            logger.warning(f"Error loading cache: {e}")
    
    return {}


def save_cache(cache: Dict[str, List[str]]) -> None:
    """
    Save theme extraction cache.
    """
    try:
        with open(CACHE_FILE, 'w') as f:
            json.dump(cache, f, indent=2)
        logger.info(f"Saved cache with {len(cache)} entries")
    except Exception as e:
        logger.error(f"Error saving cache: {e}")


def process_batch(batch_df: pd.DataFrame, cache: Dict[str, List[str]], batch_num: int) -> Dict[str, List[str]]:
    """
    Process a batch of quotes for theme extraction.
    """
    logger.info(f"Processing batch {batch_num} with {len(batch_df)} quotes")
    
    batch_cache = cache.copy()
    processed_count = 0
    cached_count = 0
    
    for idx, row in tqdm(batch_df.iterrows(), total=len(batch_df), 
                        desc=f"Batch {batch_num}", leave=False):
        quote_text = row['quote_text']
        
        # Check cache first
        if quote_text in batch_cache:
            cached_count += 1
            continue
        
        # Extract themes
        themes = extract_themes_from_quote(quote_text)
        batch_cache[quote_text] = themes
        processed_count += 1
        
        # Brief pause to avoid overwhelming Ollama
        time.sleep(0.1)
    
    logger.info(f"Batch {batch_num} complete: {processed_count} processed, {cached_count} cached")
    return batch_cache

# Load existing cache
theme_cache = load_cache()
print(f"Loaded cache with {len(theme_cache)} existing entries")

In [None]:
# Process all batches
def process_all_batches(batches: List[pd.DataFrame], cache: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """
    Process all batches for theme extraction.
    """
    total_batches = len(batches)
    start_time = time.time()
    
    logger.info(f"Starting processing of {total_batches} batches")
    
    for batch_num, batch_df in enumerate(tqdm(batches, desc="Processing batches"), 1):
        cache = process_batch(batch_df, cache, batch_num)
        
        # Save cache after each batch
        save_cache(cache)
        logger.info(f"Saved cache after batch {batch_num}")
    
    # Final cache save
    save_cache(cache)
    
    elapsed_time = time.time() - start_time
    logger.info(f"All batches processed in {elapsed_time:.2f} seconds")
    logger.info(f"Final cache size: {len(cache)} entries")
    
    return cache

# Process all batches
print(f"\nStarting theme extraction for {len(quote_batches)} batches...")
print(f"This may take 1-2 hours for the full dataset")

theme_cache = process_all_batches(quote_batches, theme_cache)
print(f"\n✅ Theme extraction complete!")
print(f"Total themes extracted: {len(theme_cache)}")

## Training Data Generation

In [None]:
def generate_training_pairs(quote_text: str, themes: List[str]) -> List[Dict[str, str]]:
    """
    Generate training pairs for a quote and its themes.
    Now uses 1-2 random instruction templates per theme for more manageable dataset size.
    """
    pairs = []
    
    for theme in themes:
        # Select 1-2 random instruction templates per theme (instead of 2-3)
        import random
        num_templates = random.choice([1, 2])  # Randomly choose 1 or 2 templates
        selected_templates = random.sample(INSTRUCTION_TEMPLATES, 
                                         min(num_templates, len(INSTRUCTION_TEMPLATES)))
        
        for template in selected_templates:
            instruction = template.format(theme=theme)
            
            pair = {
                "input": instruction,
                "output": quote_text
            }
            pairs.append(pair)
    
    return pairs


def create_training_dataset(quotes_df: pd.DataFrame, theme_cache: Dict[str, List[str]]) -> List[Dict[str, str]]:
    """
    Create the complete training dataset.
    """
    training_pairs = []
    quotes_processed = 0
    quotes_skipped = 0
    
    logger.info("Generating training pairs...")
    
    for idx, row in tqdm(quotes_df.iterrows(), total=len(quotes_df), desc="Creating training pairs"):
        quote_text = row['quote_text']
        
        # Get themes from cache
        if quote_text in theme_cache:
            themes = theme_cache[quote_text]
            
            if themes:  # Only process if themes were extracted
                pairs = generate_training_pairs(quote_text, themes)
                training_pairs.extend(pairs)
                quotes_processed += 1
            else:
                quotes_skipped += 1
        else:
            quotes_skipped += 1
    
    logger.info(f"Training dataset created:")
    logger.info(f"- Quotes processed: {quotes_processed}")
    logger.info(f"- Quotes skipped: {quotes_skipped}")
    logger.info(f"- Training pairs generated: {len(training_pairs)}")
    
    return training_pairs

In [None]:
# Generate training dataset from current cache (first batch processed)
print("📚 Generating training dataset from current cache...")

# Load current cache
theme_cache = load_cache()
print(f"📊 Cache loaded: {len(theme_cache)} quotes processed ({len(theme_cache)/len(quotes_df)*100:.1f}% of dataset)")

# Create training dataset using the current theme cache
# Note: This will only process quotes that have themes in the cache
training_dataset = create_training_dataset(quotes_df, theme_cache)

print(f"✅ Training dataset created with {len(training_dataset)} pairs")
print(f"Average pairs per processed quote: {len(training_dataset)/len(theme_cache):.1f}")

# Show sample training pairs to verify the optimized format
print("\n🔍 Sample training pairs (showing optimized 2-theme system):")
for i, pair in enumerate(training_dataset[:8], 1):
    print(f"\n{i}. Input: {pair['input']}")
    print(f"   Output: {pair['output'][:80]}{'...' if len(pair['output']) > 80 else ''}")

In [None]:
# Generate training dataset from current cache
print("📚 Generating training dataset from current cache...")

# Create training dataset using the current theme cache
training_dataset = create_training_dataset(quotes_df, theme_cache)

print(f"✅ Training dataset created with {len(training_dataset)} pairs")

# Show sample training pairs
print("\nSample training pairs:")
for i, pair in enumerate(training_dataset[:10], 1):
    print(f"\n{i}. Input: {pair['input']}")
    print(f"   Output: {pair['output']}")

## Export & Validation

In [None]:
def validate_training_data(training_data: List[Dict[str, str]]) -> Dict[str, any]:
    """
    Validate the training dataset and return quality metrics.
    """
    if not training_data:
        return {"valid": False, "error": "Empty dataset"}
    
    # Validation checks
    valid_pairs = 0
    input_lengths = []
    output_lengths = []
    duplicate_pairs = set()
    
    for pair in training_data:
        # Check required keys
        if "input" not in pair or "output" not in pair:
            continue
            
        # Check for non-empty values
        if not pair["input"] or not pair["output"]:
            continue
            
        valid_pairs += 1
        input_lengths.append(len(pair["input"]))
        output_lengths.append(len(pair["output"]))
        
        # Check for duplicates
        pair_signature = f"{pair['input']}|{pair['output']}"
        if pair_signature in duplicate_pairs:
            continue
        duplicate_pairs.add(pair_signature)
    
    metrics = {
        "valid": True,
        "total_pairs": len(training_data),
        "valid_pairs": valid_pairs,
        "unique_pairs": len(duplicate_pairs),
        "avg_input_length": sum(input_lengths) / len(input_lengths) if input_lengths else 0,
        "avg_output_length": sum(output_lengths) / len(output_lengths) if output_lengths else 0,
        "min_input_length": min(input_lengths) if input_lengths else 0,
        "max_input_length": max(input_lengths) if input_lengths else 0,
        "min_output_length": min(output_lengths) if output_lengths else 0,
        "max_output_length": max(output_lengths) if output_lengths else 0
    }
    
    return metrics


def export_training_data(training_data: List[Dict[str, str]], output_path: str) -> bool:
    """
    Export training data as TorchTune-compatible JSON.
    """
    try:
        # Create backup
        backup_path = output_path.replace('.json', '_backup.json')
        
        with open(output_path, 'w') as f:
            json.dump(training_data, f, indent=2, ensure_ascii=False)
        
        with open(backup_path, 'w') as f:
            json.dump(training_data, f, indent=2, ensure_ascii=False)
        
        logger.info(f"Training data exported to {output_path}")
        logger.info(f"Backup created at {backup_path}")
        
        return True
        
    except Exception as e:
        logger.error(f"Error exporting training data: {e}")
        return False

# Validate training data
print("\nValidating training dataset...")
validation_results = validate_training_data(training_dataset)

if validation_results["valid"]:
    print("✅ Validation passed!")
    print(f"\nValidation Results:")
    for key, value in validation_results.items():
        if key != "valid":
            print(f"- {key.replace('_', ' ').title()}: {value:,.2f}" if isinstance(value, float) else f"- {key.replace('_', ' ').title()}: {value:,}")
else:
    print(f"❌ Validation failed: {validation_results.get('error', 'Unknown error')}")

In [None]:
# Export training data
output_file = os.path.join(OUTPUT_DIR, "theme_labeled_dataset.json")

print(f"\nExporting training dataset...")
if export_training_data(training_dataset, output_file):
    print(f"✅ Export successful!")
    print(f"Training data saved to: {output_file}")
    
    # Show file size
    file_size = os.path.getsize(output_file) / (1024 * 1024)  # MB
    print(f"File size: {file_size:.1f} MB")
else:
    print("❌ Export failed")

## Quality Metrics & Summary Report

In [None]:
def analyze_theme_distribution(theme_cache: Dict[str, List[str]]) -> Dict[str, int]:
    """
    Analyze the distribution of extracted themes.
    """
    theme_counts = {}
    
    for themes in theme_cache.values():
        for theme in themes:
            theme_counts[theme] = theme_counts.get(theme, 0) + 1
    
    # Sort by frequency
    sorted_themes = dict(sorted(theme_counts.items(), key=lambda x: x[1], reverse=True))
    
    return sorted_themes


def generate_summary_report(quotes_df: pd.DataFrame, theme_cache: Dict[str, List[str]], 
                           training_data: List[Dict[str, str]], validation_results: Dict[str, any]) -> Dict[str, any]:
    """
    Generate comprehensive summary report.
    """
    theme_distribution = analyze_theme_distribution(theme_cache)
    
    # Calculate success rates
    total_quotes = len(quotes_df)
    processed_quotes = len(theme_cache)
    success_rate = (processed_quotes / total_quotes) * 100 if total_quotes > 0 else 0
    
    # Theme statistics
    themes_per_quote = [len(themes) for themes in theme_cache.values() if themes]
    avg_themes_per_quote = sum(themes_per_quote) / len(themes_per_quote) if themes_per_quote else 0
    
    # Instruction template usage
    template_usage = {}
    for pair in training_data:
        input_text = pair["input"]
        for i, template in enumerate(INSTRUCTION_TEMPLATES):
            # Extract the theme part to check template usage
            template_pattern = template.replace("{theme}", ".*")
            if re.match(template_pattern, input_text):
                template_usage[f"Template {i+1}"] = template_usage.get(f"Template {i+1}", 0) + 1
                break
    
    report = {
        "processing_summary": {
            "total_quotes": total_quotes,
            "quotes_processed": processed_quotes,
            "success_rate_percent": round(success_rate, 2),
            "processing_timestamp": datetime.now().isoformat()
        },
        "theme_extraction": {
            "unique_themes_extracted": len(theme_distribution),
            "avg_themes_per_quote": round(avg_themes_per_quote, 2),
            "top_10_themes": dict(list(theme_distribution.items())[:10]),
            "theme_distribution_stats": {
                "most_common": max(theme_distribution.values()) if theme_distribution else 0,
                "least_common": min(theme_distribution.values()) if theme_distribution else 0,
                "median_frequency": sorted(theme_distribution.values())[len(theme_distribution)//2] if theme_distribution else 0
            }
        },
        "training_data": {
            "total_training_pairs": len(training_data),
            "pairs_per_quote_avg": round(len(training_data) / total_quotes, 2) if total_quotes > 0 else 0,
            "validation_results": validation_results,
            "template_usage": template_usage
        },
        "file_outputs": {
            "training_dataset": output_file,
            "cache_file": CACHE_FILE,
            "log_file": os.path.join(OUTPUT_DIR, 'theme_extraction.log')
        }
    }
    
    return report

# Generate summary report
print("\nGenerating summary report...")
summary_report = generate_summary_report(quotes_df, theme_cache, training_dataset, validation_results)

# Save summary report
summary_file = os.path.join(OUTPUT_DIR, "theme_extraction_summary.json")
with open(summary_file, 'w') as f:
    json.dump(summary_report, f, indent=2, ensure_ascii=False)

print(f"✅ Summary report saved to: {summary_file}")