In [None]:
import pandas as pd
import numpy as np
from openai import OpenAI
import os
from dotenv import load_dotenv
import time
from tqdm import tqdm
import json
from pathlib import Path
import concurrent.futures
import asyncio
import aiohttp
import backoff

# Populate os.environ from a local .env file (if present) so the DeepSeek
# key can be read below without exporting it in the shell.
load_dotenv()
deepseek_api = os.environ.get("DEEPSEEK_API_KEY")
# NOTE(review): `client` is not referenced elsewhere in this cell; the
# async prediction path below talks to the API directly via aiohttp.
client = OpenAI(base_url="https://api.deepseek.com", api_key=deepseek_api)

# On-disk JSON file holding previously obtained predictions.
CACHE_FILE = 'prediction_cache.json'
# Number of test rows submitted concurrently per process_batch_async call.
BATCH_SIZE = 50

def load_cache(path=None):
    """Load the prediction cache from disk.

    Args:
        path: Cache file to read. Defaults to the module-level ``CACHE_FILE``
            (looked up at call time, as before).

    Returns:
        The decoded dict of cache key -> prediction. Returns ``{}`` when the
        file does not exist, is not valid JSON (e.g. truncated by an
        interrupted write), or does not hold a JSON object.
    """
    cache_path = Path(CACHE_FILE if path is None else path)
    try:
        with cache_path.open('r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        return {}
    except json.JSONDecodeError as e:
        # A corrupt cache should not abort the whole run; start fresh and let
        # the next save_cache overwrite the bad file.
        print(f"Ignoring unreadable cache file {cache_path}: {e}")
        return {}
    # Callers index and assign into the cache, so only a dict is usable.
    return data if isinstance(data, dict) else {}

def save_cache(cache, path=None):
    """Persist ``cache`` to disk as JSON without risking a half-written file.

    The data is first written to a sibling ``<name>.tmp`` file and then moved
    over the target with ``os.replace``. If the write is interrupted, the
    previous cache file is left intact instead of being truncated.

    Args:
        cache: JSON-serializable dict of cache key -> prediction.
        path: Destination file. Defaults to the module-level ``CACHE_FILE``
            (looked up at call time, as before).
    """
    target = Path(CACHE_FILE if path is None else path)
    tmp = target.with_name(target.name + '.tmp')
    try:
        with tmp.open('w', encoding='utf-8') as f:
            json.dump(cache, f)
        os.replace(tmp, target)
    except BaseException:
        # Don't leave a stray temp file behind; propagate the original error.
        if tmp.exists():
            tmp.unlink()
        raise

def create_prompt(row):
    """Render the good/bad classification prompt for one unit-test row.

    Args:
        row: Record indexable by column name (process_batch_async passes
            rows from ``DataFrame.iterrows``) providing
            ``function_python_code``, ``function_string_call``,
            ``function_call_parameters``, ``function_call_output``,
            ``unit_test_parameters``, ``unit_test_output`` and
            ``unit_test_assertion``.

    Returns:
        The prompt text, which instructs the model to answer with only
        ``1`` (good test) or ``0`` (bad test).
    """
    source_code = row['function_python_code']
    call_expr = row['function_string_call']
    call_params = row['function_call_parameters']
    call_output = row['function_call_output']
    test_params = row['unit_test_parameters']
    test_expected = row['unit_test_output']
    test_assertion = row['unit_test_assertion']
    return f"""Analyze this unit test and determine if it's a good test case.
    
Function being tested:
{source_code}

Original function call:
- Call: {call_expr}
- Parameters: {call_params}
- Output: {call_output}

Unit test:
- Test parameters: {test_params}
- Expected output: {test_expected}
- Assertion: {test_assertion}

Is this a good test case? Consider:
1. Do the test parameters match the function's requirements?
2. Is the expected output correct for these parameters?
3. Does the assertion correctly validate the function?

Respond with only 1 (good test) or 0 (bad test).
"""

def _is_permanent_http_error(exc):
    """Return True for HTTP errors retrying cannot fix (4xx other than 429)."""
    return (
        isinstance(exc, aiohttp.ClientResponseError)
        and 400 <= exc.status < 500
        and exc.status != 429
    )


@backoff.on_exception(backoff.expo,
                      (aiohttp.ClientError, asyncio.TimeoutError),
                      max_tries=5,
                      giveup=_is_permanent_http_error)
async def _request_chat_completion(prompt, session):
    """POST ``prompt`` to the DeepSeek chat-completions endpoint; return the JSON body.

    Network errors, timeouts and non-2xx statuses are raised (not caught) so
    the backoff decorator can actually retry them.
    """
    async with session.post(
        "https://api.deepseek.com/v1/chat/completions",
        json={
            "model": "deepseek-chat",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0,
            "max_tokens": 1
        },
        headers={"Authorization": f"Bearer {deepseek_api}"}
    ) as response:
        # Without this, a 429/5xx error payload would be parsed as if it were
        # a completion and silently turned into a 0 prediction.
        response.raise_for_status()
        return await response.json()


async def get_llm_prediction_async(prompt, session):
    """Ask the model to classify one prompt and return its 0/1 answer.

    Transient HTTP/network failures are retried (up to 5 attempts with
    exponential backoff) by ``_request_chat_completion``.

    Args:
        prompt: Prompt text, as produced by create_prompt.
        session: Open ``aiohttp.ClientSession`` used for the request.

    Returns:
        The integer parsed from the model's reply. Returns 0 (after printing
        the error) if the request still fails after retries or the reply
        cannot be parsed as an integer.
    """
    try:
        result = await _request_chat_completion(prompt, session)
    except (aiohttp.ClientError, asyncio.TimeoutError) as e:
        print(f"Error getting prediction: {e}")
        return 0
    try:
        return int(result['choices'][0]['message']['content'].strip())
    except (KeyError, IndexError, TypeError, AttributeError, ValueError) as e:
        # Unexpected response shape or a non-integer reply.
        print(f"Error getting prediction: {e}")
        return 0

async def process_batch_async(df, cache):
    """Return one prediction per row of ``df``, in row order.

    Rows whose key is already in ``cache`` are answered from it. The remaining
    rows are sent to the model concurrently over one aiohttp session, and
    their results are written back into ``cache`` (mutated in place).

    Args:
        df: DataFrame (slice) with the columns create_prompt reads.
        cache: dict mapping ``"<function_string_call>_<unit_test_parameters>"``
            to a prediction.

    Returns:
        list of predictions aligned positionally with the rows of ``df``.
        Previously, cached answers were emitted before fresh ones, which
        misaligned predictions with rows whenever a batch mixed both.
    """
    predictions = [None] * len(df)
    pending = []  # (position in df, cache key, row) for rows needing an API call

    for pos, (_, row) in enumerate(df.iterrows()):
        # NOTE(review): the key omits unit_test_output / unit_test_assertion,
        # although both appear in the prompt, so rows differing only in those
        # fields share one cached answer. Changing it would invalidate the
        # existing cache file, so it is left as-is here.
        cache_key = f"{row['function_string_call']}_{row['unit_test_parameters']}"
        if cache_key in cache:
            predictions[pos] = cache[cache_key]
        else:
            pending.append((pos, cache_key, row))

    if pending:
        async with aiohttp.ClientSession() as session:
            fresh = await asyncio.gather(*(
                get_llm_prediction_async(create_prompt(row), session)
                for _, _, row in pending
            ))
        # gather preserves argument order, so fresh[i] belongs to pending[i].
        for (pos, cache_key, _), pred in zip(pending, fresh):
            cache[cache_key] = pred
            predictions[pos] = pred

    return predictions

async def main_async():
    """Predict labels for the test split and write ``submission.csv``.

    Reads ``test.csv`` from the competition data directory and classifies
    rows in chunks of ``BATCH_SIZE`` via process_batch_async, reusing and
    updating the on-disk prediction cache. Writes a CSV with columns ``id``
    and ``y_pred_unit_test_parameters_match``.
    """
    # The training split was previously loaded here but never used; it is no
    # longer read.
    test_df = pd.read_csv('python-code-unit-test-assertion-quality-prediction/test.csv')
    cache = load_cache()

    print("Processing test data...")
    all_predictions = []

    try:
        for start in tqdm(range(0, len(test_df), BATCH_SIZE)):
            batch = test_df.iloc[start:start + BATCH_SIZE]
            all_predictions.extend(await process_batch_async(batch, cache))

            # Checkpoint after the first batch and every 5th batch thereafter.
            if start % (BATCH_SIZE * 5) == 0:
                save_cache(cache)
    finally:
        # Always persist what was fetched, even if a batch raised, so a rerun
        # resumes from the cache instead of re-paying for those API calls.
        save_cache(cache)

    submission_df = pd.DataFrame({
        'id': test_df['id'],
        'y_pred_unit_test_parameters_match': all_predictions
    })

    submission_df.to_csv('submission.csv', index=False)
    print("Submission saved!")

# For Jupyter notebook execution: top-level `await` is accepted by IPython/
# Jupyter, which already run an event loop. In a plain Python script this
# line is a SyntaxError ("'await' outside function"); use
# `asyncio.run(main_async())` there instead.
await main_async()

Processing test data...


100%|██████████| 20/20 [02:21<00:00,  7.10s/it]

Submission saved!



