# 03. Retrieval Testing

This notebook tests the retrieval accuracy of the YOLO Code Assistant.

In [None]:
# Setup
import sys
from pathlib import Path
import pandas as pd

sys.path.append(str(Path.cwd().parent))

from src.yolo_assistant.config import config
from src.yolo_assistant.storage import MongoDBVectorStore
from src.yolo_assistant.retrieval import CodeSearcher, ResultRanker
from src.yolo_assistant.indexer import CodeEmbedder

## 1. Connect to MongoDB and Check Status

In [None]:
# Build the retrieval pipeline: storage backend, embedder, searcher and ranker
vector_store = MongoDBVectorStore()
embedder = CodeEmbedder()
searcher = CodeSearcher(vector_store, embedder)
ranker = ResultRanker()

# Headline statistics to report, as (printed label, key in the stats dict) pairs
summary_fields = (
    ("Total chunks", "total_chunks"),
    ("Files indexed", "files_indexed"),
    ("Index status", "index_status"),
)

# Hints shown when anything inside the try block below fails
setup_hints = (
    "Make sure you have:",
    "  1. Set MONGODB_URI in .env",
    "  2. Run 'python main.py --index' to index the codebase",
)

# Open the connection and report what is currently indexed.
# Note: the handler catches every Exception raised in the block, so a missing
# statistics key is reported with the same "Error connecting" message as a
# real connection failure.
try:
    vector_store.connect()
    stats = vector_store.get_statistics()

    print("MongoDB Vector Store Statistics:")
    for label, key in summary_fields:
        print(f"  {label}: {stats[key]}")
    print("\nChunks by type:")
    for chunk_type, count in stats['chunks_by_type'].items():
        print(f"    {chunk_type}: {count}")
except Exception as exc:
    print(f"Error connecting to MongoDB: {exc}")
    for hint in setup_hints:
        print(hint)

## 2. Test Basic Search Queries

In [None]:
# Natural-language queries used to smoke-test end-to-end retrieval
test_queries = [
    "How to train a YOLO model?",
    "YOLO model architecture",
    "Data augmentation in YOLO",
    "Export model to ONNX",
    "Calculate mAP metric"
]

# Maximum number of docstring characters shown per result
DOCSTRING_PREVIEW_CHARS = 100

# For each query: search, re-rank, attach summaries, then print the top hits
for query in test_queries:
    print(f"\n{'='*60}")
    print(f"Query: {query}")
    print(f"{'='*60}")

    # Search
    results = searcher.search(query, limit=3)

    # Rank results
    ranked_results = ranker.rank_results(results, query)

    # Enhance with summaries
    enhanced_results = ranker.enhance_results_with_summary(ranked_results)

    # Display results
    if enhanced_results:
        for i, result in enumerate(enhanced_results, 1):
            print(f"\n[{i}] {result.get('summary', 'No summary')}")
            print(f"    Score: {result.get('search_score', 0):.4f}")
            docstring = result.get('docstring')
            if docstring:
                preview = docstring[:DOCSTRING_PREVIEW_CHARS]
                # Only mark the preview as truncated when text was actually cut,
                # so short docstrings are not shown with a misleading "..."
                suffix = "..." if len(docstring) > DOCSTRING_PREVIEW_CHARS else ""
                print(f"    Docstring: {preview}{suffix}")
    else:
        print("  No results found")

## 3. Test Specific Search Types

In [None]:
# Test function search: look up functions by (partial) name
print("Function Search Test:")
print("-" * 40)

function_results = searcher.search_by_function_name("train", limit=5)
print(f"Found {len(function_results)} functions with 'train' in name:")
for result in function_results:
    print(f"  - {result['name']} ({result['file_path'].split('/')[-1]}:{result['start_line']})")

# Test class search: look up classes by name and include their methods
print("\nClass Search Test:")
print("-" * 40)

class_results = searcher.search_by_class_name("YOLO", include_methods=True)
print(f"Found {len(class_results)} classes/methods with 'YOLO' in name:")

# Group results by class name -> {'methods': [method names]}
classes = {}
for result in class_results:
    if result['chunk_type'] == 'class':
        # setdefault (rather than plain assignment) keeps methods that were
        # already collected when a method chunk is returned before its class
        # chunk, or when the same class name appears more than once.
        classes.setdefault(result['name'], {'methods': []})
    elif result['chunk_type'] == 'method' and result.get('parent_class'):
        # Create the class entry on demand so methods are never dropped
        class_entry = classes.setdefault(result['parent_class'], {'methods': []})
        class_entry['methods'].append(result['name'])

for class_name, info in classes.items():
    print(f"  Class: {class_name}")
    for method in info['methods'][:5]:  # Show first 5 methods
        print(f"    - {method}()")

## 4. Test Retrieval Performance

In [None]:
import time

# Performance test queries
perf_queries = [
    "train model",
    "data loader",
    "loss function",
    "model export",
    "validation metrics"
]


def _timed_search(query, use_embeddings):
    """Run one search and return ``(results, elapsed_ms)``.

    Uses ``time.perf_counter()``, a monotonic high-resolution clock intended
    for measuring short intervals. ``time.time()`` is a wall clock that can be
    adjusted while the measurement is running.
    """
    started = time.perf_counter()
    found = searcher.search(query, use_embeddings=use_embeddings)
    elapsed_ms = (time.perf_counter() - started) * 1000
    return found, elapsed_ms


# Measure performance: one row per query, comparing vector and text search.
# NOTE(review): vector search always runs first for a query, so any one-off
# warm-up cost (model load, connection setup) would be attributed to it — confirm.
performance_data = []

for query in perf_queries:
    vector_results, vector_ms = _timed_search(query, use_embeddings=True)
    text_results, text_ms = _timed_search(query, use_embeddings=False)

    performance_data.append({
        'Query': query,
        'Vector Search Time (ms)': vector_ms,
        'Vector Results': len(vector_results),
        'Text Search Time (ms)': text_ms,
        'Text Results': len(text_results)
    })

# Display results
df = pd.DataFrame(performance_data)
print("Retrieval Performance Comparison:")
print(df.to_string(index=False))

# Summary statistics
print("\nSummary:")
print(f"Average vector search time: {df['Vector Search Time (ms)'].mean():.2f} ms")
print(f"Average text search time: {df['Text Search Time (ms)'].mean():.2f} ms")

## 5. Test Result Ranking

In [None]:
# Compare the raw search order with the ranker's order for a single query
test_query = "How to train a YOLO model with custom dataset?"

# Weights used below to blend the two scores for display.
# NOTE(review): this blend is computed here for display only; whether it matches
# the weighting ResultRanker applies internally is not visible from this cell — confirm.
search_weight, relevance_weight = 0.7, 0.3

# Fetch candidates, then let the ranker reorder them
raw_results = searcher.search(test_query, limit=10)
ranked_results = ranker.rank_results(raw_results, test_query)

print(f"Query: {test_query}")

# Order as returned by the searcher
print("\nTop 5 results before ranking:")
for position, hit in enumerate(raw_results[:5], start=1):
    print(f"{position}. {hit['name']} (score: {hit.get('search_score', 0):.4f})")

# Order as returned by the ranker, with the per-component scores
print("\nTop 5 results after ranking:")
for position, hit in enumerate(ranked_results[:5], start=1):
    vector_score = hit.get('search_score', 0)
    relevance = hit.get('relevance_score', 0)
    combined_score = vector_score * search_weight + relevance * relevance_weight
    print(f"{position}. {hit['name']} (combined score: {combined_score:.4f})")
    print(f"   Vector: {vector_score:.4f}, Relevance: {relevance:.4f}")

## 6. Test Context Retrieval

In [None]:
# Retrieve matches together with their surrounding chunks
context_query = "model validation"
results_with_context = searcher.search_with_context(context_query, context_size=2)

# Header: the query and how many results came back
print(f"Query: {context_query}")
print(f"Found {len(results_with_context)} results\n")

# Show location details for the first two results only
for idx, hit in enumerate(results_with_context[:2], start=1):
    print(f"Result {idx}: {hit['name']}")
    print(f"  File: {hit['file_path']}")
    print(f"  Lines: {hit['start_line']}-{hit['end_line']}")

    # Context is optional: list it only when the key is present on the result
    if 'context_chunks' in hit:
        neighbours = hit['context_chunks']
        print(f"  Context ({len(neighbours)} surrounding chunks):")
        for neighbour in neighbours:
            print(f"    - {neighbour['chunk_type']} {neighbour['name']} (lines {neighbour['start_line']}-{neighbour['end_line']})")
    print()