In [1]:
import io
import os
import shutil
import tempfile
import logging
import zipfile
from typing import List, Tuple
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing

import kaggle_evaluation.konwinski_prize_inference_server
import numpy as np

# Root logging configuration: INFO and above, each record prefixed with a
# timestamp and its level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Total number of instances; stays None until get_number_of_instances() sets it.
instance_count = None

def get_number_of_instances(num_instances: int) -> None:
    """Remember how many instances there are and log the count.

    Args:
        num_instances: total instance count; stored verbatim in the
            module-level ``instance_count``.
    """
    global instance_count
    instance_count = num_instances
    message = f"Total number of instances: {num_instances}"
    logger.info(message)

def predict(problem_statement: str, repo_archive: io.BytesIO) -> str:
    """Return a unified-diff patch addressing *problem_statement*, or "" to skip.

    The bytes remaining in *repo_archive* are written to ``repo_archive.tar``
    inside a temporary directory and unpacked into ``repo/`` there
    (``shutil.unpack_archive`` picks the format from the ``.tar`` extension).
    The unpacked tree is then handed to generate_patch(). The temporary
    directory, including the checkout, is removed before this returns.

    Any exception is logged with its traceback and turned into "" so the
    issue is skipped.
    """
    try:
        logger.info("Starting prediction for new issue")

        with tempfile.TemporaryDirectory() as temp_dir:
            repo_path = os.path.join(temp_dir, 'repo')
            archive_path = os.path.join(temp_dir, 'repo_archive.tar')
            with open(archive_path, 'wb') as f:
                f.write(repo_archive.read())
            # NOTE(review): tar members with absolute or '..' paths can be
            # written outside extract_dir on Python versions without a default
            # tar extraction filter; confirm the archive source is trusted.
            shutil.unpack_archive(archive_path, extract_dir=repo_path)

            # Analyze the problem statement and generate a patch while the
            # checkout still exists.
            patch = generate_patch(repo_path, problem_statement)

        # Only report success when there is actually something to submit.
        if patch:
            logger.info("Patch generated successfully")
        else:
            logger.info("No patch generated; skipping this issue")
        return patch
    except Exception as e:
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Error in predict function: {str(e)}")
        return ""  # Return empty string in case of error to skip the issue

def _to_repo_relative(patch: str, file_path: str, repo_path: str) -> str:
    """Rewrite a single-file diff's headers to git-style ``a/<rel>`` / ``b/<rel>``.

    process_file() labels both sides of its diff with the absolute path of the
    file inside the temporary checkout. predict() deletes that checkout, so
    those paths cannot be used to apply the patch. Only the exact header
    ``--- <file_path>\\n+++ <file_path>\\n`` is rewritten; any other patch is
    returned unchanged.
    """
    header = f"--- {file_path}\n+++ {file_path}\n"
    if not patch.startswith(header):
        return patch
    rel_path = os.path.relpath(file_path, repo_path).replace(os.sep, '/')
    return f"--- a/{rel_path}\n+++ b/{rel_path}\n" + patch[len(header):]


def generate_patch(repo_path: str, problem_statement: str) -> str:
    """Build one multi-file patch for *problem_statement* against *repo_path*.

    Steps:
    1. Classify the issue with analyze_issue().
    2. Pick candidate files with find_relevant_files().
    3. Run process_file() on each candidate in a thread pool.
    4. Concatenate the non-empty per-file diffs in the order
       find_relevant_files() returned them, so the output does not depend on
       thread scheduling.

    File headers are made relative to *repo_path* (``a/...`` / ``b/...``).
    Errors raised while processing one file are logged and that file is
    skipped.
    """
    # Analyze the problem statement
    issue_type = analyze_issue(problem_statement)

    # Find relevant files
    relevant_files = find_relevant_files(repo_path, issue_type)
    if not relevant_files:
        return ""

    patches = []
    with ThreadPoolExecutor(max_workers=multiprocessing.cpu_count()) as executor:
        # Keep futures in submission order: as_completed() yields them in
        # completion order, which made the file order of the patch vary
        # between runs.
        futures = [
            (path, executor.submit(process_file, path, issue_type, problem_statement))
            for path in relevant_files
        ]
        for path, future in futures:
            try:
                patch = future.result()
            except Exception as e:
                logger.error(f"Error processing file {path}: {str(e)}")
                continue
            if patch:
                patches.append(_to_repo_relative(patch, path, repo_path))

    return ''.join(patches)

def process_file(file: str, issue_type: str, problem_statement: str) -> str:
    """Return a unified diff of the heuristic edit for *file*, or "" if none.

    Reads *file*, applies modify_content() and diffs the result with
    create_diff(). Returns "" when the content is unchanged or when anything
    fails (the failure is logged with its traceback). Files that are not valid
    UTF-8 therefore produce no patch.
    """
    try:
        # newline='' keeps the file's own line endings (e.g. CRLF), so the
        # diff's context lines match the file on disk. Explicit UTF-8 avoids
        # depending on the platform's default encoding.
        with open(file, 'r', encoding='utf-8', newline='') as f:
            content = f.read()

        new_content = modify_content(content, issue_type, problem_statement)

        if new_content == content:
            return ""
        return create_diff(file, content, new_content)
    except Exception as e:
        logger.exception(f"Error processing file {file}: {str(e)}")
        return ""

def analyze_issue(problem_statement: str) -> str:
    """Classify an issue as 'bug', 'feature', 'performance' or 'other'.

    Each category scores one point per keyword that occurs anywhere (as a
    plain substring) in the lower-cased statement. The highest score wins.
    Ties go to the category listed first ('bug', then 'feature', then
    'performance'). A statement with no keyword hits is 'other'.

    NOTE: substring matching means e.g. 'add' also matches "address".
    """
    text = problem_statement.lower()
    category_keywords = (
        ('bug', ('bug', 'error', 'fix', 'issue', 'problem', 'crash')),
        ('feature', ('feature', 'add', 'implement', 'new', 'enhance')),
        ('performance', ('performance', 'slow', 'speed', 'optimize', 'efficient')),
    )

    best_category, best_score = 'other', 0
    for category, words in category_keywords:
        score = len([word for word in words if word in text])
        # Strict '>' keeps the earliest category on ties.
        if score > best_score:
            best_category, best_score = category, score
    return best_category

def find_relevant_files(repo_path: str, issue_type: str, max_files: int = 5) -> List[str]:
    """Return up to *max_files* ``.py`` files under *repo_path* mentioning *issue_type*.

    A file matches when *issue_type* occurs as a substring of its lower-cased
    text. There is no ranking: the result is the first *max_files* matches in
    a sorted directory walk, so the same tree always gives the same list.
    Files that cannot be read as UTF-8 are logged and skipped.

    Args:
        repo_path: root directory to search.
        issue_type: keyword to look for (compared against lower-cased text).
        max_files: maximum number of paths to return; non-positive gives [].

    Returns:
        Paths built with ``os.path.join`` from the walked directories.
    """
    relevant_files: List[str] = []
    if max_files <= 0:
        return relevant_files

    for root, dirs, files in os.walk(repo_path):
        # os.walk yields entries in filesystem order; sorting (dirs in place,
        # which steers the top-down walk) makes the truncation deterministic.
        dirs.sort()
        for name in sorted(files):
            if not name.endswith('.py'):
                continue
            file_path = os.path.join(root, name)
            try:
                with open(file_path, 'r', encoding='utf-8') as f:
                    content = f.read()
            except (OSError, ValueError) as e:  # ValueError covers UnicodeDecodeError
                logger.error(f"Error reading file {file_path}: {str(e)}")
                continue
            if issue_type in content.lower():
                relevant_files.append(file_path)
                # Stop reading files as soon as the limit is reached.
                if len(relevant_files) >= max_files:
                    return relevant_files
    return relevant_files

def _as_comment_lines(text: str) -> str:
    """Join the lines of *text*, prefixing every line after the first with ``# ``.

    The caller puts the result after its own ``# ...`` marker, so a
    multi-line problem statement stays entirely inside comments. Without
    this, raw text spills into the file, usually as a syntax error.
    """
    return '\n# '.join(text.splitlines())


def _compiles(source: str) -> bool:
    """Return True if *source* compiles as a Python module (nothing is executed)."""
    try:
        compile(source, '<modify_content>', 'exec', dont_inherit=True)
    except (SyntaxError, ValueError):
        # ValueError: NUL bytes in the source on older Python versions.
        return False
    return True


def _indent_width(line: str) -> int:
    """Return the number of leading whitespace characters in *line*."""
    return len(line) - len(line.lstrip())


def _wrap_first_function_in_try(content: str) -> str:
    """Wrap the body of the first ``def`` in *content* in try/except.

    The body is the lines after the ``def`` line that are blank or indented
    deeper than it; trailing blank lines stay outside. The ``try:`` is placed
    one level below the ``def``'s own indentation, so methods work too.
    Returns *content* unchanged when there is no ``def`` or its body is empty.
    """
    lines = content.split('\n')
    for def_index, def_line in enumerate(lines):
        if not def_line.strip().startswith('def '):
            continue
        def_indent = _indent_width(def_line)
        end = def_index + 1
        while end < len(lines) and (
            not lines[end].strip() or _indent_width(lines[end]) > def_indent
        ):
            end += 1
        while end > def_index + 1 and not lines[end - 1].strip():
            end -= 1
        body = lines[def_index + 1:end]
        if not body:
            return content  # one-line def such as ``def f(): ...``
        block_indent = def_line[:def_indent] + '    '
        wrapped = [def_line, block_indent + 'try:']
        wrapped.extend('    ' + line if line.strip() else line for line in body)
        wrapped.append(block_indent + 'except Exception as e:')
        # NOTE(review): the inserted handler calls ``logger``, which the target
        # module may not define; it also swallows the exception.
        wrapped.append(block_indent + '    logger.error(f"Error: {str(e)}")')
        return '\n'.join(lines[:def_index] + wrapped + lines[end:])
    return content


def _cache_first_function(content: str) -> str:
    """Put ``@lru_cache(maxsize=None)`` on the first top-level ``def`` in *content*.

    ``from functools import lru_cache`` is prepended to the file. Returns
    *content* unchanged when there is no top-level ``def``.
    """
    lines = content.split('\n')
    for index, line in enumerate(lines):
        if line.startswith('def '):
            decorated = lines[:index] + ['@lru_cache(maxsize=None)'] + lines[index:]
            return 'from functools import lru_cache\n\n' + '\n'.join(decorated)
    return content


def modify_content(content: str, issue_type: str, problem_statement: str) -> str:
    """Apply a keyword-driven heuristic edit to Python source *content*.

    - 'bug': wrap the first function's body in try/except that logs the error.
    - 'feature': append a ``# TODO`` comment quoting *problem_statement*.
    - 'performance': decorate the first top-level function with ``lru_cache``.
    - anything else: prepend a comment quoting *problem_statement*.

    If *content* compiled but the edited text does not (e.g. a multi-line
    signature, or an import placed before ``from __future__``), the edit is
    dropped and *content* is returned unchanged.
    """
    if issue_type == 'bug':
        new_content = _wrap_first_function_in_try(content)
    elif issue_type == 'feature':
        new_content = content + f"\n\n# TODO: Implement new feature - {_as_comment_lines(problem_statement)}\n"
    elif issue_type == 'performance':
        new_content = _cache_first_function(content)
    else:
        new_content = f"# Addressing issue: {_as_comment_lines(problem_statement)}\n{content}"

    # Check the (usually passing) new text first so the original is only
    # compiled when the edit is suspect.
    if new_content != content and not _compiles(new_content) and _compiles(content):
        return content
    return new_content

def create_diff(file_path: str, old_content: str, new_content: str) -> str:
    """Return a unified diff that turns *old_content* into *new_content*.

    Both header lines are labelled with *file_path*. Returns "" when the
    contents are identical.

    The output is always newline-terminated. If a file's last line has no
    trailing newline, that line is terminated and followed by the standard
    ``\\ No newline at end of file`` marker, so several diffs can be
    concatenated safely.
    """
    import difflib

    def split_keepends(text):
        # Split on '\n' only. str.splitlines() also breaks on \f, \v,
        # \x1c-\x1e, \x85, \u2028 and \u2029, which patch tools treat as
        # ordinary characters, so such lines would be corrupted in the diff.
        pieces = text.split('\n')
        lines = [piece + '\n' for piece in pieces[:-1]]
        if pieces[-1]:
            lines.append(pieces[-1])
        return lines

    diff_lines = []
    for line in difflib.unified_diff(
        split_keepends(old_content),
        split_keepends(new_content),
        fromfile=file_path,
        tofile=file_path,
    ):
        if line.endswith('\n'):
            diff_lines.append(line)
        else:
            diff_lines.append(line + '\n\\ No newline at end of file\n')
    return ''.join(diff_lines)

def unpack_data(
    data_path: str = '/kaggle/input/konwinski-prize/data.a_zip',
    extract_path: str = '/kaggle/working/data',
) -> None:
    """Extract the zip archive at *data_path* into *extract_path* (best effort).

    The defaults are the original hard-coded Kaggle locations. Any failure
    (missing file, bad archive, ...) is logged with its traceback and
    swallowed, so callers are never interrupted by it.
    """
    try:
        with zipfile.ZipFile(data_path, 'r') as zip_ref:
            zip_ref.extractall(extract_path)
    except Exception as e:
        logger.exception(f"Error unpacking data: {str(e)}")
        return
    logger.info(f"Data extracted to {extract_path}")

# Register the two callbacks (instance-count notification, per-issue
# prediction) with the Kaggle inference server.
inference_server = kaggle_evaluation.konwinski_prize_inference_server.KPrizeInferenceServer(
    get_number_of_instances,
    predict,
)

if __name__ == "__main__":
    unpack_data()
    if not os.getenv('KAGGLE_IS_COMPETITION_RERUN'):
        # Variable unset or empty: drive the server through the local gateway
        # using the competition input and the freshly extracted data.
        inference_server.run_local_gateway(
            data_paths=(
                '/kaggle/input/konwinski-prize/',
                '/kaggle/working/data/',
            ),
        )
    else:
        # Presumably set by Kaggle for the competition rerun — serve requests.
        inference_server.serve()



Forcing micromamba reinstallation to mitigate issues with older images.
