In [5]:
# Cell 1: Imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Cell 2: Fast Data Loading (No SMILES validation)
def load_data_fast(ground_truth_csv, ts_results_csv, k=10, maximize=True):
    """Fast loading without SMILES validation"""
    
    print("Loading ground truth...")
    gt = pd.read_csv(ground_truth_csv)
    print(f"Ground truth: {len(gt):,} molecules")
    
    print("Loading TS results...")
    ts = pd.read_csv(ts_results_csv)
    print(f"TS results: {len(ts):,} molecules")
    
    # Simple cleanup - just remove NaNs and "FAIL"
    gt = gt.dropna(subset=['score']).copy()
    gt = gt[gt['SMILES'] != 'FAIL']
    print(f"Clean ground truth: {len(gt):,} molecules")
    
    ts = ts.dropna(subset=['score']).copy() 
    ts = ts[ts['SMILES'] != 'FAIL']
    print(f"Clean TS results: {len(ts):,} molecules")
    
    # Get top k from ground truth
    top_k_gt = gt.nlargest(k, 'score') if maximize else gt.nsmallest(k, 'score')
    top_k_smiles = set(top_k_gt['SMILES'])
    top_k_avg = top_k_gt['score'].mean()
    
    print(f"\\nTop {k} ground truth average: {top_k_avg:.4f}")
    print(f"Score range: {top_k_gt['score'].min():.4f} to {top_k_gt['score'].max():.4f}")
    
    return {
        'ts_results': ts,
        'top_k_smiles': top_k_smiles, 
        'top_k_avg': top_k_avg,
        'k': k,
        'maximize': maximize
    }

# Cell 3: Load Data
ground_truth_file = "runs/scored_enumerated_TS_example_molecules.csv"
ts_results_file = "runs/results_example.csv"

data = load_data_fast(
    ground_truth_csv=ground_truth_file,
    ts_results_csv=ts_results_file,
    k=10,
    maximize=True
)
# Cell 4: Fast Analysis Using Actual Batch Numbers
def analyze_fast(data):
    """Fast batch analysis using actual batch numbers from CSV"""
    
    ts = data['ts_results']
    top_k_smiles = data['top_k_smiles']
    k = data['k']
    maximize = data['maximize']
    
    # Group by actual batch numbers
    unique_batches = sorted(ts['batch'].unique())
    batch_nums = []
    recovery = []
    avg_scores = []
    found_smiles = set()
    
    for batch_num in unique_batches:
        # Get all data up to and including this batch
        current_data = ts[ts['batch'] <= batch_num].copy()
        
        # Add new SMILES from this batch to found set
        batch_smiles = set(ts[ts['batch'] == batch_num]['SMILES'])
        found_smiles.update(batch_smiles)
        
        # Count recoveries
        recovery.append(len(found_smiles & top_k_smiles))
        
        # Current top k average from all data so far
        if maximize:
            current_top_k = current_data.nlargest(k, 'score')
        else:
            current_top_k = current_data.nsmallest(k, 'score')
        avg_scores.append(current_top_k['score'].mean())
        
        batch_nums.append(batch_num)
    
    return batch_nums, recovery, avg_scores

# Update the function call
batch_nums, recovery, avg_scores = analyze_fast(data)

# Cell 5: Plot Recovery
plt.figure(figsize=(10, 5))
plt.plot(batch_nums, recovery, 'b-', linewidth=2, marker='o', markersize=3)
plt.axhline(y=data['k'], color='red', linestyle='--', label=f'Perfect ({data["k"]})')
plt.xlabel('Batch Number')
plt.ylabel(f'Top {data["k"]} Recovered')
plt.title(f'Recovery Over Time')
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

# Cell 6: Plot Quality
plt.figure(figsize=(10, 5))
plt.plot(batch_nums, avg_scores, 'g-', linewidth=2, marker='s', markersize=3, label='TS Average')
plt.axhline(y=data['top_k_avg'], color='black', linestyle='--', linewidth=2, label='Ground Truth')
plt.xlabel('Batch Number')
plt.ylabel('Average Score')
plt.title(f'Quality Over Time')
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

# Cell 7: Results
print(f"Final recovery: {recovery[-1]}/{data['k']} ({100*recovery[-1]/data['k']:.1f}%)")
print(f"Final TS average: {avg_scores[-1]:.4f}")
print(f"Ground truth average: {data['top_k_avg']:.4f}")
print(f"Efficiency: {avg_scores[-1]/data['top_k_avg']:.3f}")
print(f"Batches analyzed: {len(batch_nums)} (batch 0 = warmup, batch 1+ = search)")

Loading ground truth...
Ground truth: 3,590,000 molecules
Loading TS results...


ParserError: Error tokenizing data. C error: Expected 3 fields in line 27698, saw 4
