In [None]:
import sys
import os
import json
from pathlib import Path

# Add src to path
import logging

import datasets
logger = logging.getLogger(__name__)
from src.core.factory import RetrieverFactory
from src.core.base import CodeExample
from src.retrievers.dense.database import CodeExampleDatabase

# Configuration
TRAINING_DATA_PATH = "src/data/database/high-resource/method2test/reformat_test.jsonl"
BENCHMARK_REPO = "Tessera2025/Tessera2025"
OUTPUT_DIR = "src/data/constructed_prompt"
DATABASE_SAVE_PATH = "src/data/database/unixcoder/database_index.pkl"
EMBEDDER_NAME="unixcoder"


In [2]:
data=datasets.load_dataset(BENCHMARK_REPO,trust_remote_code=True)

In [6]:
def _register_builtin_implementations():
    """Register built-in retriever implementations."""
    try:
        from src.retrievers.dense.embedder import UniXcoderEmbedder
        RetrieverFactory.register_embedder(EMBEDDER_NAME, UniXcoderEmbedder)
    except ImportError:
        logger.warning("Could not register UniXcoderEmbedder")
    
    try:
        from src.retrievers.dense.database import CodeExampleDatabase
        RetrieverFactory.register_database("dense_vector", CodeExampleDatabase)
    except ImportError:
        logger.warning("Could not register CodeExampleDatabase")
    
    try:
        from src.retrievers.fewshot_pipeline import FewShotTestGenerationPipeline
        RetrieverFactory.register_pipeline("few_shot", FewShotTestGenerationPipeline)
    except ImportError:
        logger.warning("Could not register FewShotTestGenerationPipeline")


# Register on module import
_register_builtin_implementations()

In [7]:
def load_training_data(file_path: str, max_examples: int = None) -> list:
    """
    Load training data from JSONL file.
    
    Args:
        file_path: Path to JSONL file
        max_examples: Maximum number of examples to load (None = load all)
        
    Returns:
        List of CodeExample objects
    """
    examples = []
    print(f"Loading training data from: {file_path}")
    
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                if max_examples and i >= max_examples:
                    break
                
                try:
                    data = json.loads(line.strip())
                    
                    # Extract focal_method and unit_test
                    # Adjust field names based on your JSONL structure
                    focal_method = data.get('focal_method') 
                    unit_test = data.get('unit_test') 
                    
                    if focal_method and unit_test:
                        example = CodeExample(
                            focal_method=focal_method,
                            unit_test=unit_test,
                            metadata=data.get('metadata', {})
                        )
                        examples.append(example)
                    
                    # Progress indicator
                    if (i + 1) % 1000 == 0:
                        print(f"  Loaded {i + 1} examples...")
                        
                except json.JSONDecodeError:
                    print(f"  Warning: Skipping invalid JSON at line {i + 1}")
                    continue
    
    except FileNotFoundError:
        print(f"  Error: File not found: {file_path}")
        return []
    
    print(f"✓ Loaded {len(examples)} training examples")
    return examples


def load_benchmark(repo_path: str) -> list:
    """
    Load benchmark test cases from JSON file.
    
    Args:
        file_path: Path to benchmark JSON file
        
    Returns:
        List of benchmark dictionaries
    """
    
    try:
        dataset = datasets.load_dataset(repo_path)
        benchmark_rust=dataset["rust"].to_list()
        benchmark_go=dataset["go"].to_list()
        benchmark_julia=dataset["julia"].to_list()

        
        return benchmark_rust, benchmark_go, benchmark_julia
    
    except FileNotFoundError:
        print(f"  Error: File not found: {repo_path}")
        return []
    except json.JSONDecodeError as e:
        print(f"  Error: Invalid JSON: {e}")
        return []

def to_jsonable(obj):
    if isinstance(obj, list):
        return [to_jsonable(o) for o in obj]
    if isinstance(obj, dict):
        return {k: to_jsonable(v) for k, v in obj.items()}
    if hasattr(obj, "to_dict"):
        return obj.to_dict()
    return obj

def save_benchmark(benchmark, output_path: str):
    """
    Save a generated prompt to file.
    
    Args:
        prompt: The prompt text
        output_path: Path to save the prompt
    """
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    jsonable_benchmark = [to_jsonable(d) for d in benchmark]


    with open(output_path, 'w', encoding='utf-8') as f:
        for item in jsonable_benchmark:
            f.write(json.dumps(item) + "\n")


In [8]:
x=RetrieverFactory()
x.list_available_methods()

{'embedders': ['unixcoder'],
 'databases': ['dense_vector'],
 'pipelines': ['few_shot']}

In [9]:
pipeline = RetrieverFactory.create_full_pipeline(
    method=EMBEDDER_NAME,
    db_type="dense_vector",
    pipeline_type="few_shot",
    pipeline_kwargs={
        "top_k": 3,
        "similarity_threshold": 0
    }
)

In [10]:
# Step 2-4: Load or Build database index
print("Steps 2-4: Loading or Building retrieval database...")
print("-" * 80)

# Check if database already exists
if os.path.exists(DATABASE_SAVE_PATH):
    print(f"✓ Found existing database at: {DATABASE_SAVE_PATH}")
    print("Loading database index...")
    
    try:
        pipeline.database.load_index(DATABASE_SAVE_PATH)
        print(f"✓ Database loaded successfully with {pipeline.database.size} examples")
    except Exception as e:
        print(f"⚠ Error loading database: {e}")
        print("Building new database instead...")
        
        # # Load training data
        # training_examples = load_training_data(
        #     TRAINING_DATA_PATH,
        #     max_examples=None  # Change to None to load all
        # )
        
        # if not training_examples:
        #     print("⚠ No training data loaded. Please check the file path and format.")
        # else:
        #     # Build database
        #     print(f"Adding {len(training_examples)} examples to database...")
        #     pipeline.database.add_examples_bulk(training_examples)
            
        #     print("Building index (this may take a few minutes)...")
        #     pipeline.database.build_index()
            
        #     print(f"✓ Database built with {pipeline.database.size} examples")
            
        #     # Save database
        #     os.makedirs(os.path.dirname(DATABASE_SAVE_PATH), exist_ok=True)
        #     pipeline.database.save_index(DATABASE_SAVE_PATH)
        #     print(f"✓ Database saved to: {DATABASE_SAVE_PATH}")
else:
    print(f"✗ No existing database found at: {DATABASE_SAVE_PATH}")
    print("Building new database...")
    
    # Load training data
    training_examples = load_training_data(
        TRAINING_DATA_PATH,
        max_examples=None  # Change to None to load all
    )
    
    if not training_examples:
        print("⚠ No training data loaded. Please check the file path and format.")
    else:
        # Build database
        print(f"Adding {len(training_examples)} examples to database...")
        pipeline.database.add_examples_bulk(training_examples)
        
        print("Building index (this may take a few minutes)...")
        pipeline.database.build_index()
        
        print(f"✓ Database built with {pipeline.database.size} examples")
        
        # Save database
        print("Saving database index...")
        os.makedirs(os.path.dirname(DATABASE_SAVE_PATH), exist_ok=True)
        pipeline.database.save_index(DATABASE_SAVE_PATH)
        print(f"✓ Database saved to: {DATABASE_SAVE_PATH}")

print()

# Step 5: Load benchmark
print("Step 5: Loading benchmark...")
print("-" * 80)


Steps 2-4: Loading or Building retrieval database...
--------------------------------------------------------------------------------
✗ No existing database found at: data/database_index.pkl
Building new database...
Loading training data from: src/data/database/high-resource/method2test/reformat_test.jsonl
  Loaded 1000 examples...
  Loaded 2000 examples...
  Loaded 3000 examples...
  Loaded 4000 examples...
  Loaded 5000 examples...
  Loaded 6000 examples...
  Loaded 7000 examples...
  Loaded 8000 examples...
  Loaded 9000 examples...
  Loaded 10000 examples...
  Loaded 11000 examples...
  Loaded 12000 examples...
  Loaded 13000 examples...
  Loaded 14000 examples...
  Loaded 15000 examples...
  Loaded 16000 examples...
  Loaded 17000 examples...
  Loaded 18000 examples...
  Loaded 19000 examples...
  Loaded 20000 examples...
  Loaded 21000 examples...
  Loaded 22000 examples...
  Loaded 23000 examples...
  Loaded 24000 examples...
  Loaded 25000 examples...
  Loaded 26000 examples...

100%|██████████| 7889/7889 [32:43<00:00,  4.02it/s]



✓ Database built with 63107 examples
Saving database index...
✓ Database saved to: data/database_index.pkl

Step 5: Loading benchmark...
--------------------------------------------------------------------------------
✓ Database saved to: data/database_index.pkl

Step 5: Loading benchmark...
--------------------------------------------------------------------------------


In [11]:
benchmark_rust,benchmark_go,benchmark_julia = load_benchmark(BENCHMARK_REPO)
    
if not (benchmark_go and benchmark_rust and benchmark_julia):
    print("⚠ No benchmark cases loaded. Please check the file path.")


In [12]:
# Step 6: Generate prompts for each benchmark case
print("Step 6: Generating prompts for benchmark cases...")
print("-" * 80)

os.makedirs(OUTPUT_DIR, exist_ok=True)
prompts = []
for lang in ['rust','go','julia']:
    if lang=='rust':
        benchmark_cases=benchmark_rust
    elif lang=='go':
        benchmark_cases=benchmark_go
    else:
        benchmark_cases=benchmark_julia
    output_path = os.path.join(OUTPUT_DIR, EMBEDDER_NAME, lang,"data_with_fewshot.jsonl")
    results=[]
    for i, case in enumerate(benchmark_cases, 1):
        case_id = case.get('id', f'case_{i}')
        focal_method = case.get('focal_code')
        
        if not focal_method:
            print(f"  ⚠ Skipping case {case_id}: no focal_method found")
            continue
            
        # Generate prompt using pipeline
        prompt = pipeline.run(focal_method)

        result={"id":case_id,"retrieved_context":prompt}
        results.append(result)
    # Save prompt to file
    save_benchmark(results, output_path)

# print(f"      ✓ Saved to: {output_path}")

Step 6: Generating prompts for benchmark cases...
--------------------------------------------------------------------------------
