In [None]:
# clone repo and checkout locally

# load python files from directory

# split into documents

# embed and store in vectorstore
# metadata: name (e.g. MyClass.my_function) 

In [1]:
from datasets import load_dataset

DATASET_ID = "princeton-nlp/SWE-bench_Lite"
SPLIT = "test"
INSTANCE_ID = "sympy__sympy-20442"
RUN_ID = "v0"

dataset = load_dataset(DATASET_ID, split=SPLIT)
instance_details = [r for r in dataset if r["instance_id"] == INSTANCE_ID][0]

In [2]:
import git
import os
from typing import List, Tuple, Dict

def clone_and_checkout(instance_details: dict) -> str:
    local_path = os.path.join("../.repos/", instance_details["repo"].split("/")[-1])
    repo_url = f"https://github.com/{instance_details['repo']}.git"
    commit_sha = instance_details["base_commit"]
    # Clone the repository if it doesn't exist
    if not os.path.exists(local_path):
        print(f"Cloning repository from {repo_url}...")
        repo = git.Repo.clone_from(repo_url, local_path)
        print("Repository cloned successfully!")
    else:
        print(f"Repository already exists at {local_path}")
        repo = git.Repo(local_path)

    # Fetch all remote branches
    print("Fetching all remote branches...")
    repo.remotes.origin.fetch()

    # Checkout the specific commit
    repo.git.checkout(commit_sha)
    print(f"Successfully checked out commit {commit_sha}")
    return local_path

repo_path = clone_and_checkout(instance_details)

Repository already exists at ../.repos/sympy
Fetching all remote branches...
Successfully checked out commit 1abbc0ac3e552cb184317194e5d5c5b9dd8fb640


In [3]:
from tree_sitter import Language, Parser
import tree_sitter_python as tspython

import os
from pydantic import BaseModel
from typing import Optional, List
import re
import hashlib

class CodeBlock(BaseModel):
    name: str  # Full path (e.g., "module.class.function")
    type: str  # "class" or "function"
    code: str  # Complete code block
    docstring: Optional[str]
    file_path: str
    start_line: int
    end_line: int
    parent: Optional[str]  # Parent class/module name
    category: str  # tests / src
    id: str  # hash of the code

def setup_parser():
    """Initialize the tree-sitter parser for Python."""
    PY_LANGUAGE = Language(tspython.language())
    parser = Parser(PY_LANGUAGE)
    return parser

def extract_docstring(node, source_code):
    """Extract docstring from a class or function node."""
    for child in node.children:
        if child.type == 'block':
            for block_child in child.children:
                if block_child.type == 'expression_statement':
                    for expr_child in block_child.children:
                        if expr_child.type == 'string':
                            return source_code[expr_child.start_byte:expr_child.end_byte].strip('\"\'')
    return None

def get_node_source(node, source_code):
    """Get the source code for a node."""
    return source_code[node.start_byte:node.end_byte]

def process_file(file_path: str, parser: Parser) -> List[CodeBlock]:
    """Process a single Python file and extract all code blocks."""
    with open(file_path, 'r', encoding='utf-8') as f:
        source_code = f.read()

    tree = parser.parse(bytes(source_code, "utf8"))
    blocks = []
    
    def process_node(node, parent_name=None):

        code = get_node_source(node, source_code)
        id = hashlib.md5(code.encode('utf-8')).hexdigest()
        category = "tests" if "test" in file_path else "src"

        if node.type == 'class_definition':
            name_node = next(child for child in node.children if child.type == 'identifier')
            class_name = source_code[name_node.start_byte:name_node.end_byte]
            
            full_name = f"{parent_name}.{class_name}" if parent_name else class_name
            
            blocks.append(CodeBlock(
                name=full_name,
                type="class",
                code=code,
                docstring=extract_docstring(node, source_code),
                file_path=file_path,
                start_line=node.start_point[0] + 1,
                end_line=node.end_point[0] + 1,
                parent=parent_name,
                category=category,
                id=id,
            ))
            
            # Process methods within the class
            for child in node.children:
                if child.type == 'block':
                    for block_child in child.children:
                        process_node(block_child, full_name)
                        
        elif node.type == 'function_definition':
            name_node = next(child for child in node.children if child.type == 'identifier')
            func_name = source_code[name_node.start_byte:name_node.end_byte]
            
            full_name = f"{parent_name}.{func_name}" if parent_name else func_name
            
            blocks.append(CodeBlock(
                name=full_name,
                type="function",
                code=code,
                docstring=extract_docstring(node, source_code),
                file_path=file_path,
                start_line=node.start_point[0] + 1,
                end_line=node.end_point[0] + 1,
                parent=parent_name,
                category=category,
                id=id,
            ))

    # Start processing from the root
    for node in tree.root_node.children:
        process_node(node)
        
    return blocks

def analyze_directory(directory_path: str) -> List[CodeBlock]:
    """
    Analyze all Python files in a directory and extract code blocks.
    
    Args:
        directory_path: Path to the directory to analyze
        
    Returns:
        List of CodeBlock objects containing the extracted information
    """
    parser = setup_parser()
    all_blocks = []
    
    for root, _, files in os.walk(directory_path):
        for file in files:
            if file.endswith('.py'):
                file_path = os.path.join(root, file)
                try:
                    blocks = process_file(file_path, parser)
                    all_blocks.extend(blocks)
                except Exception as e:
                    print(f"Error processing {file_path}: {str(e)}")
                    
    return all_blocks


code_blocks = analyze_directory(repo_path)

In [11]:
import json

def prepare_code_blocks(code_blocks: List[CodeBlock]) -> Tuple[List[str], List[str], List[Dict]]:
    # texts, ids, metadatas
    texts = ["\n".join([c.name, c.type, c.code]) for c in code_blocks]
    ids = [hashlib.md5(json.dumps(c.model_dump()).encode("utf-8")).hexdigest() for c in code_blocks]
    metadatas = [{
        "file_path": c.file_path,
        "start_line": c.start_line,
        "end_line": c.end_line,
        "category": c.category,
        "type": c.type,
        "name": c.name,
    } for c in code_blocks]
    return texts, ids, metadatas


texts, ids, metadatas = prepare_code_blocks(code_blocks)

In [None]:
import boto3
import json
from typing import List, Optional
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm


MAX_CHARACTERS_PER_DOC = 2048

def embed_cohere(texts: List[str], model_id: str = "cohere.embed-english-v3", input_type: str = "search_query") -> List[List[float]]:
    runtime_client = boto3.client('bedrock-runtime', region_name='us-east-1')
        
    # `truncate` parameter does not seem to do anything
    payload = {
        "texts": [text[:MAX_CHARACTERS_PER_DOC] for text in texts],
        "input_type": input_type,
        "truncate": "END"
    }
    
    response = runtime_client.invoke_model(
        body=json.dumps(payload),
        modelId=model_id,
    )
    
    return json.loads(response['body'].read().decode())["embeddings"]


def process_batch(batch_with_index, model_id: str) -> tuple[int, List[List[float]]]:
    """
    Process a single batch of texts and return embeddings with the batch index.
    
    Args:
        batch_with_index (tuple): Tuple of (batch_index, texts)
        endpoint_name (str): Name of the SageMaker endpoint
    
    Returns:
        tuple: (batch_index, embeddings)
    """
    batch_index, batch_texts = batch_with_index
    try:
        batch_embeddings = embed_cohere(batch_texts, model_id, "search_document") 
        return batch_index, batch_embeddings
        
    except Exception as e:
        print(f"Error processing batch {batch_index}: {str(e)}")
        return batch_index, None

def get_embeddings_parallel(
    texts: List[str],
    model_id: str = "cohere.embed-english-v3",
    batch_size: int = 96,
    max_workers: int = 16,
    show_progress: bool = True
) -> Optional[List[List[float]]]:
    """
    Get embeddings for a list of texts using parallel processing.
    
    Args:
        texts (List[str]): List of strings to get embeddings for
        endpoint_name (str): Name of the SageMaker endpoint
        batch_size (int): Number of texts to process in each batch
        max_workers (int): Maximum number of parallel threads
        show_progress (bool): Whether to show progress bar
    
    Returns:
        Optional[List[List[float]]]: List of embeddings in the same order as input texts
    """
    # Create batches with their indices
    batches = [
        (batch_idx, texts[i:i + batch_size]) 
        for batch_idx, i in enumerate(range(0, len(texts), batch_size))
    ]
    
    # Initialize results storage
    results = {}
    all_embeddings = []
    
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all batches to the executor
        future_to_batch = {
            executor.submit(process_batch, batch, model_id): batch[0] 
            for batch in batches
        }
        
        # Process completed futures with optional progress bar
        futures_iterator = as_completed(future_to_batch)
        if show_progress:
            futures_iterator = tqdm(
                futures_iterator, 
                total=len(batches),
                desc="Processing batches"
            )
        
        # Collect results
        for future in futures_iterator:
            batch_index, batch_embeddings = future.result()
            if batch_embeddings is None:
                return None
            results[batch_index] = batch_embeddings

    # Combine results in correct order
    for i in range(0, len(batches)):
        all_embeddings.extend(results[i])
    
    return all_embeddings
        
    
embeddings = get_embeddings_parallel(texts)

In [7]:
from langchain_chroma import Chroma
from langchain_aws import BedrockEmbeddings

db_path = f"../.vectors/{instance_details['instance_id']}"


db = Chroma(
    collection_name=instance_details["instance_id"],
    persist_directory=db_path,
)

In [19]:
db._collection.count()

32476

In [17]:
os.path.exists(db_path)

True

In [12]:
db._collection.add(
    ids=ids,
    embeddings=embeddings,
    metadatas=metadatas,
)

In [16]:
query = "separability matrix"

query_vec = embed_cohere([query])

db._collection.query(query_vec)["metadatas"]

[[{'category': 'src',
   'end_line': 952,
   'file_path': '../.repos/sympy/sympy/matrices/sparse.py',
   'name': 'MutableSparseMatrix.copyin_matrix',
   'start_line': 928,
   'type': 'function'},
  {'category': 'tests',
   'end_line': 181,
   'file_path': '../.repos/sympy/sympy/printing/tests/test_mathematica.py',
   'name': 'test_matrices',
   'start_line': 150,
   'type': 'function'},
  {'category': 'src',
   'end_line': 130,
   'file_path': '../.repos/sympy/sympy/matrices/expressions/blockmatrix.py',
   'name': 'BlockMatrix.__new__',
   'start_line': 79,
   'type': 'function'},
  {'category': 'src',
   'end_line': 188,
   'file_path': '../.repos/sympy/sympy/matrices/immutable.py',
   'name': 'ImmutableSparseMatrix',
   'start_line': 123,
   'type': 'class'},
  {'category': 'src',
   'end_line': 300,
   'file_path': '../.repos/sympy/sympy/matrices/sparse.py',
   'name': 'SparseMatrix.__getitem__',
   'start_line': 258,
   'type': 'function'},
  {'category': 'src',
   'end_line': 116,