In [1]:
# Setup: standard-library imports, module logger, project imports, config.
# Nothing in this cell modifies sys.path: both the `src` package and the
# relative `data/...` paths below presumably resolve because the kernel's
# working directory is the repository root.
import json
import logging
import os
import sys
from pathlib import Path

# Created before the project imports so it exists even if one of them fails.
logger = logging.getLogger(__name__)

from src.core.factory import RetrieverFactory
from src.core.base import CodeExample
from src.retrievers.dense.database import CodeExampleDatabase

# Configuration (all paths are relative to the working directory)
TRAINING_DATA_PATH = "data/database/high-resource/method2test/reformat_test.jsonl"
BENCHMARK_PATH = "data/benchmark/example_benchmark.json"
OUTPUT_DIR = "data/constructed_prompt"
DATABASE_SAVE_PATH = "data/database_index.pkl"


Could not register UniXcoderEmbedder
Could not register CodeExampleDatabase
Could not register FewShotTestGenerationPipeline


ModuleNotFoundError: No module named 'torch'

In [2]:
def _register_builtin_implementations():
    """Register built-in retriever implementations."""
    try:
        from src.retrievers.dense.embedder import UniXcoderEmbedder
        RetrieverFactory.register_embedder("unixcoder", UniXcoderEmbedder)
    except ImportError:
        logger.warning("Could not register UniXcoderEmbedder")
    
    try:
        from src.retrievers.dense.database import CodeExampleDatabase
        RetrieverFactory.register_database("dense_vector", CodeExampleDatabase)
    except ImportError:
        logger.warning("Could not register CodeExampleDatabase")
    
    try:
        from src.retrievers.fewshot_pipeline import FewShotTestGenerationPipeline
        RetrieverFactory.register_pipeline("few_shot", FewShotTestGenerationPipeline)
    except ImportError:
        logger.warning("Could not register FewShotTestGenerationPipeline")


# Run the registration as soon as this cell executes, so the "unixcoder",
# "dense_vector" and "few_shot" keys are available to the cells below
# (each one only if its module imported successfully).
_register_builtin_implementations()

In [3]:
def load_training_data(file_path: str, max_examples: int = None) -> list:
    """
    Load training data from a JSONL file.

    Each non-blank line must be a JSON object. The focal method is taken from
    the first truthy value of 'focal_method', 'method' or 'code'. The unit
    test is taken from the first truthy value of 'unit_test', 'test' or
    'test_code'. Objects missing either field are skipped silently. Blank
    lines are ignored. Malformed or non-object lines are skipped with a
    warning.

    Args:
        file_path: Path to JSONL file
        max_examples: Maximum number of CodeExample objects to return
            (None = load all). Counts loaded examples, not lines read.

    Returns:
        List of CodeExample objects ([] if the file does not exist)
    """
    examples = []
    print(f"Loading training data from: {file_path}")

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                # Compare against examples actually kept, so skipped lines do
                # not eat into the budget; `is not None` makes 0 mean "none".
                if max_examples is not None and len(examples) >= max_examples:
                    break

                text = raw_line.strip()
                if not text:
                    continue  # tolerate blank / trailing lines

                try:
                    data = json.loads(text)
                except json.JSONDecodeError:
                    print(f"  Warning: Skipping invalid JSON at line {line_no}")
                    continue

                if not isinstance(data, dict):
                    print(f"  Warning: Skipping non-object JSON at line {line_no}")
                    continue

                # Extract focal_method and unit_test
                # Adjust field names based on your JSONL structure
                focal_method = data.get('focal_method') or data.get('method') or data.get('code')
                unit_test = data.get('unit_test') or data.get('test') or data.get('test_code')
                if not (focal_method and unit_test):
                    continue

                examples.append(CodeExample(
                    focal_method=focal_method,
                    unit_test=unit_test,
                    metadata=data.get('metadata', {}),
                ))

                # Progress indicator (based on examples loaded, not lines read)
                if len(examples) % 1000 == 0:
                    print(f"  Loaded {len(examples)} examples...")

    except FileNotFoundError:
        print(f"  Error: File not found: {file_path}")
        return []

    print(f"✓ Loaded {len(examples)} training examples")
    return examples


def load_benchmark(file_path: str) -> list:
    """
    Read the benchmark definition (one JSON document) from disk.

    Args:
        file_path: Location of the benchmark JSON file

    Returns:
        The parsed JSON document (expected to be a list of case dicts), or an
        empty list when the file is missing or does not contain valid JSON.
    """
    print(f"\nLoading benchmark from: {file_path}")

    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            cases = json.load(handle)
    except FileNotFoundError:
        print(f"  Error: File not found: {file_path}")
        return []
    except json.JSONDecodeError as err:
        print(f"  Error: Invalid JSON: {err}")
        return []

    print(f"✓ Loaded {len(cases)} benchmark cases")
    return cases


def save_prompt(prompt: str, output_path: str):
    """
    Save a generated prompt to a text file, overwriting any existing file.

    Parent directories are created as needed. A bare filename (no directory
    component) is written relative to the current working directory.

    Args:
        prompt: The prompt text
        output_path: Path to save the prompt
    """
    parent_dir = os.path.dirname(output_path)
    # os.path.dirname() is '' for a bare filename, and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(prompt)


In [4]:
def load_training_data(file_path: str, max_examples: int = None) -> list:
    """
    Load training data from a JSONL file.

    NOTE(review): this redefines the previous cell's ``load_training_data``.
    Unlike that version, it reads only the 'focal_method' and 'unit_test'
    fields, with no 'method'/'code'/'test'/'test_code' fallbacks.

    Each non-blank line must be a JSON object. Objects missing either field
    are skipped silently. Blank lines are ignored. Malformed or non-object
    lines are skipped with a warning.

    Args:
        file_path: Path to JSONL file
        max_examples: Maximum number of CodeExample objects to return
            (None = load all). Counts loaded examples, not lines read.

    Returns:
        List of CodeExample objects ([] if the file does not exist)
    """
    examples = []
    print(f"Loading training data from: {file_path}")

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_no, raw_line in enumerate(f, start=1):
                # Compare against examples actually kept, so skipped lines do
                # not eat into the budget; `is not None` makes 0 mean "none".
                if max_examples is not None and len(examples) >= max_examples:
                    break

                text = raw_line.strip()
                if not text:
                    continue  # tolerate blank / trailing lines

                try:
                    data = json.loads(text)
                except json.JSONDecodeError:
                    print(f"  Warning: Skipping invalid JSON at line {line_no}")
                    continue

                if not isinstance(data, dict):
                    print(f"  Warning: Skipping non-object JSON at line {line_no}")
                    continue

                # Field names must match the JSONL schema exactly.
                focal_method = data.get('focal_method')
                unit_test = data.get('unit_test')
                if not (focal_method and unit_test):
                    continue

                examples.append(CodeExample(
                    focal_method=focal_method,
                    unit_test=unit_test,
                    metadata=data.get('metadata', {}),
                ))

                # Progress indicator (based on examples loaded, not lines read)
                if len(examples) % 1000 == 0:
                    print(f"  Loaded {len(examples)} examples...")

    except FileNotFoundError:
        print(f"  Error: File not found: {file_path}")
        return []

    print(f"✓ Loaded {len(examples)} training examples")
    return examples


def load_benchmark(file_path: str) -> list:
    """
    Read the benchmark definition (one JSON document) from disk.

    NOTE(review): identical to the ``load_benchmark`` defined in the previous
    cell; this later definition is the one in effect. Keep only one.

    Args:
        file_path: Location of the benchmark JSON file

    Returns:
        The parsed JSON document (expected to be a list of case dicts), or an
        empty list when the file is missing or does not contain valid JSON.
    """
    print(f"\nLoading benchmark from: {file_path}")

    try:
        with open(file_path, 'r', encoding='utf-8') as handle:
            cases = json.load(handle)
    except FileNotFoundError:
        print(f"  Error: File not found: {file_path}")
        return []
    except json.JSONDecodeError as err:
        print(f"  Error: Invalid JSON: {err}")
        return []

    print(f"✓ Loaded {len(cases)} benchmark cases")
    return cases


def save_prompt(prompts: list, output_path: str):
    """
    Save generated prompts to a JSON Lines file (one JSON value per line).

    NOTE(review): this redefines the previous cell's ``save_prompt``, which
    wrote a single prompt string as plain text. Only this version is in
    effect once both cells have run.

    Every item is encoded with ``json.dumps``. Some objects cannot be encoded
    natively by the json module but provide ``tolist()``: numpy arrays and
    numpy scalars, such as the 'embedding' array in the pipeline results.
    These are converted to plain Python lists/numbers. Any other unsupported
    object raises ``TypeError``.

    All items are encoded before the file is opened, so an encoding failure
    leaves any existing file at ``output_path`` untouched. Parent directories
    are created as needed; the file is overwritten.

    Args:
        prompts: Prompt records to save (strings or dicts of pipeline output)
        output_path: Path of the .jsonl file to write

    Raises:
        TypeError: If an item contains a value that cannot be JSON-encoded.
    """
    def _to_jsonable(obj):
        # json.dumps only calls this hook for objects it cannot encode itself.
        to_list = getattr(obj, "tolist", None)
        if callable(to_list):
            return to_list()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    lines = [json.dumps(prompt, default=_to_jsonable) + "\n" for prompt in prompts]

    parent_dir = os.path.dirname(output_path)
    # os.path.dirname() is '' for a bare filename, and os.makedirs('') raises.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(lines)


In [6]:
# Sanity check: show which implementation keys RetrieverFactory knows about.
# The "unixcoder", "dense_vector" and "few_shot" keys used in the next cell
# should all be listed.
x=RetrieverFactory()
x.list_available_methods()

{'embedders': ['unixcoder'],
 'databases': ['dense_vector'],
 'pipelines': ['few_shot']}

In [7]:
# Assemble the end-to-end pipeline from the implementations registered under
# these keys: "unixcoder" (embedder), "dense_vector" (database) and
# "few_shot" (pipeline). The pipeline_kwargs are passed through to the
# few-shot pipeline. top_k is presumably the number of retrieved examples,
# and similarity_threshold the minimum score for keeping one (confirm in
# FewShotTestGenerationPipeline).
pipeline = RetrieverFactory.create_full_pipeline(
    pipeline_type="few_shot",
    method="unixcoder",
    db_type="dense_vector",
    pipeline_kwargs=dict(top_k=5, similarity_threshold=0.5),
)

  return torch._C._cuda_getDeviceCount() > 0


In [None]:
# Step 2: Load training examples (capped for quick iteration)
print("Step 2: Loading training data...")
print("-" * 80)

training_examples = load_training_data(
    TRAINING_DATA_PATH,
    max_examples=1000  # Change to None to load all
)

if not training_examples:
    # Fail fast. Continuing would build an empty index, and Step 4 would then
    # save it to DATABASE_SAVE_PATH, presumably replacing any index already
    # saved there.
    raise RuntimeError(
        f"No training data loaded from {TRAINING_DATA_PATH!r}. "
        "Please check the file path and format."
    )

print()

# Step 3: Build database index
# NOTE(review): re-running this cell presumably adds the same examples to
# pipeline.database a second time. Re-run the pipeline-creation cell first.
print("Step 3: Building retrieval database...")
print("-" * 80)

print(f"Adding {len(training_examples)} examples to database...")
pipeline.database.add_examples_bulk(training_examples)

print("Building index (this may take a few minutes)...")
pipeline.database.build_index()

print(f"✓ Database built with {pipeline.database.size} examples")
print()

# Step 4: Save database index
print("Step 4: Saving database index...")
print("-" * 80)

# os.path.dirname() is '' for a bare filename, and os.makedirs('') raises.
os.makedirs(os.path.dirname(DATABASE_SAVE_PATH) or ".", exist_ok=True)
pipeline.database.save_index(DATABASE_SAVE_PATH)

print(f"✓ Database saved to: {DATABASE_SAVE_PATH}")
print()

# Step 5: Load benchmark
print("Step 5: Loading benchmark...")
print("-" * 80)

benchmark_cases = load_benchmark(BENCHMARK_PATH)

if not benchmark_cases:
    print("⚠ No benchmark cases loaded. Please check the file path.")



Loading training data from: data/database/high-resource/method2test/reformat_test.jsonl
  Loaded 1000 examples...
✓ Loaded 1000 training examples

Step 3: Building retrieval database...
--------------------------------------------------------------------------------
Adding 1000 examples to database...
Building index (this may take a few minutes)...


In [None]:
# Step 6: Generate prompts for each benchmark case
print("Step 6: Generating prompts for benchmark cases...")
print("-" * 80)

os.makedirs(OUTPUT_DIR, exist_ok=True)
prompts = []
for i, case in enumerate(benchmark_cases, 1):
    case_id = case.get('id', f'case_{i}')
    focal_method = case.get('focal_method')

    if focal_method:
        print(f"  [{i}/{len(benchmark_cases)}] Generating prompt for: {case_id}")
        # The pipeline output is kept in memory. As `prompts[0]` below shows,
        # it is a dict of pipeline stages rather than a plain string.
        prompt = pipeline.run(focal_method)
        prompts.append(prompt)
    else:
        print(f"  ⚠ Skipping case {case_id}: no focal_method found")

# NOTE(review): saving `prompts` to disk is currently disabled. An attempt to
# write them with save_prompt(prompts, os.path.join(OUTPUT_DIR,
# "test_output.jsonl")) failed with "TypeError: Object of type ndarray is not
# JSON serializable". Each result carries a numpy 'embedding' array under
# pipeline_stages -> query_processing.

Step 6: Generating prompts for benchmark cases...
--------------------------------------------------------------------------------
  [1/5] Generating prompt for: test_001
  [2/5] Generating prompt for: test_002
  [3/5] Generating prompt for: test_003
  [4/5] Generating prompt for: test_004
  [5/5] Generating prompt for: test_005


TypeError: Object of type ndarray is not JSON serializable

In [None]:
# Inspect the first pipeline result. It is a dict rather than a prompt
# string, and it includes intermediate stages such as the query embedding
# (a numpy array).
prompts[0]

{'focal_method': 'def calculate_sum(numbers: list) -> int:\n    """Calculate the sum of a list of numbers."""\n    return sum(numbers)',
 'pipeline_stages': {'query_processing': {'focal_method': 'def calculate_sum(numbers: list) -> int:\n    """Calculate the sum of a list of numbers."""\n    return sum(numbers)',
   'embedding': array([-9.64255512e-01, -1.52129984e+00,  4.54460651e-01, -2.63846546e-01,
          -3.74762535e-01, -1.08061695e+00,  2.92344183e-01, -6.54486239e-01,
           2.01330805e+00, -2.69911826e-01,  8.35599303e-01,  1.89650965e+00,
          -7.20902026e-01, -1.16389886e-01, -3.97118837e-01,  8.41765106e-01,
           7.74191201e-01,  2.69913226e-01,  4.83628720e-01,  4.68551904e-01,
          -1.09527183e+00,  1.88282800e+00, -2.23696008e-01, -1.47829485e+00,
          -1.99194002e+00,  7.59038806e-01, -1.63609460e-01,  1.78241789e+00,
          -1.39221298e-02, -1.60750747e-01, -3.63634109e+00,  1.19903493e+00,
           1.20163485e-01, -1.75594175e+00, -8.6