In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from scipy import stats
import os, json
import seaborn as sns

def load_and_process_data(benchmark_path: str, results_path: str):
    """Load the reference benchmark and a results file into two DataFrames.

    Args:
        benchmark_path: JSON file mapping PDB id -> record; every record becomes a row.
        results_path: JSON file mapping PDB id -> nested result record, read via the
            key paths listed in ``field_paths`` below.

    Returns:
        ``(df_benchmark, df_results)``, both indexed by PDB id. Results entries that
        lack any required key are reported with a warning and left out.
    """
    def _read_json(path):
        with open(path) as handle:
            return json.load(handle)

    benchmark_raw = _read_json(benchmark_path)
    results_raw = _read_json(results_path)

    df_benchmark = pd.DataFrame.from_dict(benchmark_raw, orient='index')

    # Output column -> key path inside a results record. The order fixes both the
    # DataFrame column order and which missing key is reported first.
    field_paths = (
        ('ba_val', ('ba_val',)),
        ('CC', ('contacts', 'CC')),
        ('CP', ('contacts', 'CP')),
        ('AC', ('contacts', 'AC')),
        ('PP', ('contacts', 'PP')),
        ('AP', ('contacts', 'AP')),
        ('AA', ('contacts', 'AA')),
        ('nis_p', ('nis', 'polar')),
        ('nis_a', ('nis', 'aliphatic')),
        ('nis_c', ('nis', 'charged')),
        ('execution_time', ('execution_time', 'seconds')),
    )

    flattened = {}
    for pdb_id, record in results_raw.items():
        row = {}
        try:
            for column, path in field_paths:
                value = record
                for key in path:
                    value = value[key]
                row[column] = value
        except KeyError as err:
            print(f"Warning: Missing data for {pdb_id}: {err}")
            continue
        flattened[pdb_id] = row

    df_results = pd.DataFrame.from_dict(flattened, orient='index')
    return df_benchmark, df_results

def calculate_correlations(df_benchmark: pd.DataFrame, df_results: pd.DataFrame):
    """Compare benchmark and result values per metric over the shared PDB entries.

    Only metrics present as columns in both frames are evaluated. For each metric,
    entries where either side is missing (NaN) are dropped, so Pearson r, its p-value
    and the RMSE are all computed over the same set of pairs.

    Args:
        df_benchmark: reference values indexed by PDB id, one column per metric.
        df_results: values to compare, same layout as ``df_benchmark``.

    Returns:
        DataFrame with columns ``Metric``, ``Pearson r``, ``p-value`` and ``RMSE``,
        one row per evaluated metric. Pearson r / p-value are NaN when fewer than two
        valid pairs exist (``scipy.stats.pearsonr`` needs at least two). RMSE is NaN
        when there are no valid pairs.
    """
    common_pdbs = sorted(set(df_benchmark.index) & set(df_results.index))
    print(f"Common PDB entries: {len(common_pdbs)}")

    metrics = {
        'ba_val': 'Binding Affinity',
        'CC': 'Charged-Charged contacts',
        'CP': 'Charged-Polar contacts',
        'AC': 'Aliphatic-Charged contacts',
        'PP': 'Polar-Polar contacts',
        'AP': 'Aliphatic-Polar contacts',
        'AA': 'Aliphatic-Aliphatic contacts',
        'nis_p': 'NIS Polar',
        'nis_a': 'NIS Aliphatic',
        'nis_c': 'NIS Charged'
    }

    correlations = []
    for metric, label in metrics.items():
        if metric not in df_benchmark.columns or metric not in df_results.columns:
            continue
        # Pair the two series and drop incomplete pairs, so r and RMSE share a sample.
        paired = pd.DataFrame({
            'bench': df_benchmark.loc[common_pdbs, metric],
            'result': df_results.loc[common_pdbs, metric],
        }).dropna()

        if len(paired) >= 2:
            pearson_r, p_value = stats.pearsonr(paired['bench'], paired['result'])
        else:
            pearson_r, p_value = float('nan'), float('nan')
        rmse = float(np.sqrt(np.mean((paired['bench'] - paired['result']) ** 2)))

        correlations.append({
            'Metric': label,
            'Pearson r': pearson_r,
            'p-value': p_value,
            'RMSE': rmse
        })

    # Explicit columns keep the schema stable even when no metric qualifies.
    return pd.DataFrame(correlations, columns=['Metric', 'Pearson r', 'p-value', 'RMSE'])

def plot_correlations(df_benchmark: pd.DataFrame, df_results: pd.DataFrame, output_dir: str):
    """Scatter-plot benchmark vs result values for each metric on a 2x5 grid.

    Only metrics present as columns in both frames are drawn, so the skip logic is
    the same as ``calculate_correlations``. Each panel shows the Pearson r, computed
    on pairs where both values are present, in its title. A dashed y = x reference
    line spans the combined data range. Unused grid cells are hidden. The figure is
    saved to ``<output_dir>/correlations.png``; the directory is created if needed.

    Args:
        df_benchmark: reference values ("Prodigy ORG") indexed by PDB id.
        df_results: compared values ("Prodigy JAX"), same layout.
        output_dir: directory that receives the PNG.
    """
    os.makedirs(output_dir, exist_ok=True)
    common_pdbs = sorted(set(df_benchmark.index) & set(df_results.index))

    metrics = {
        'ba_val': 'Binding Affinity',
        'CC': 'Charged-Charged contacts',
        'CP': 'Charged-Polar contacts',
        'AC': 'Aliphatic-Charged contacts',
        'PP': 'Polar-Polar contacts',
        'AP': 'Aliphatic-Polar contacts',
        'AA': 'Aliphatic-Aliphatic contacts',
        'nis_p': 'NIS Polar',
        'nis_a': 'NIS Aliphatic',
        'nis_c': 'NIS Charged'
    }
    available = [(metric, title) for metric, title in metrics.items()
                 if metric in df_benchmark.columns and metric in df_results.columns]

    fig, axes = plt.subplots(2, 5, figsize=(15, 10))
    axes = axes.ravel()

    # zip stops at the shorter sequence: extra metrics beyond 10 panels are not drawn.
    for ax, (metric, title) in zip(axes, available):
        paired = pd.DataFrame({
            'bench': df_benchmark.loc[common_pdbs, metric],
            'result': df_results.loc[common_pdbs, metric],
        }).dropna()

        # pearsonr requires at least two points.
        pearson = stats.pearsonr(paired['bench'], paired['result'])[0] if len(paired) >= 2 else float('nan')

        ax.scatter(paired['bench'], paired['result'], alpha=0.6)
        if not paired.empty:
            lo = min(paired['bench'].min(), paired['result'].min())
            hi = max(paired['bench'].max(), paired['result'].max())
            ax.plot([lo, hi], [lo, hi], 'r--')

        ax.set_xlabel('Prodigy ORG')
        ax.set_ylabel('Prodigy JAX')
        ax.set_title(f'{title}\nr = {pearson:.3f}')

    for ax in axes[len(available):]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, 'correlations.png'), dpi=300, bbox_inches='tight')
    plt.show()

def add_sequence_lengths(df: pd.DataFrame, pdb_folder: str):
    """Append per-chain residue counts for chains A and B, read from PDB files.

    For each index entry ``pdb_id`` of ``df``, the file ``<pdb_folder>/<pdb_id>.pdb``
    is scanned when it exists. A residue is counted once per distinct
    (residue sequence number, insertion code) pair among the ``ATOM`` records of its
    chain. This gives an exact count, unlike an estimate from atom totals, and
    repeated MODEL blocks are not double-counted. ``HETATM`` records are ignored.

    Args:
        df: frame indexed by PDB id.
        pdb_folder: directory containing ``<pdb_id>.pdb`` files.

    Returns:
        A new frame: ``df`` outer-joined on the index with ``chain_a_length`` and
        ``chain_b_length``. Entries whose file is missing or unreadable get NaN.
    """
    chain_columns = {'A': 'chain_a_length', 'B': 'chain_b_length'}
    lengths = {}
    for pdb_id in df.index:
        pdb_file = os.path.join(pdb_folder, f"{pdb_id}.pdb")
        if not os.path.exists(pdb_file):
            continue
        residues = {chain: set() for chain in chain_columns}
        try:
            with open(pdb_file, 'r') as f:
                for line in f:
                    if not line.startswith('ATOM'):
                        continue
                    # Fixed-column PDB layout: chainID is column 22, resSeq columns
                    # 23-26 and iCode column 27 (0-based slices below). Slicing
                    # never raises on short lines, unlike indexing.
                    chain = line[21:22]
                    if chain in residues:
                        residues[chain].add(line[22:27])
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error processing {pdb_id}: {e}")
            continue
        lengths[pdb_id] = {column: len(residues[chain]) for chain, column in chain_columns.items()}

    # Explicit columns guarantee both length columns exist even if no file was found.
    length_df = pd.DataFrame.from_dict(lengths, orient='index', columns=list(chain_columns.values()))
    return pd.concat([df, length_df], axis=1)

# Atom identity columns; both SASA tables are sorted by these so rows pair by position.
_SASA_ATOM_KEYS = ['chain', 'resname', 'resid', 'atom']


def _load_sorted_sasa(csv_path) -> pd.DataFrame:
    """Read a per-atom SASA CSV (chain, resname, resid, atom, sasa) in canonical atom order."""
    data = pd.read_csv(csv_path)
    data['resid'] = data['resid'].astype(int)
    return data.sort_values(_SASA_ATOM_KEYS)


def _plot_sasa_comparison(summary_df: pd.DataFrame, all_sasa_df: pd.DataFrame,
                          rmse_threshold: float, output_path: str):
    """Plot per-protein RMSE vs size and per-atom CPU vs GPU SASA, then save to output_path."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # RMSE scatter plot; annotate only proteins above the threshold.
    ax1.scatter(summary_df['num_atoms'], summary_df['rmse'], alpha=0.6)
    for row in summary_df.itertuples(index=False):
        if row.rmse > rmse_threshold:
            ax1.annotate(row.protein, (row.num_atoms, row.rmse))
    ax1.set_xlabel('Number of Atoms')
    ax1.set_ylabel('RMSE (Å²)')
    ax1.set_title('RMSE GPU and CPU vs Structure Size')

    # SASA values, colour-coded by whether their protein exceeded the RMSE threshold.
    is_high = all_sasa_df['high_rmse']
    normal_points = all_sasa_df[~is_high]
    high_rmse_points = all_sasa_df[is_high]
    ax2.scatter(normal_points['sasa_cpu'], normal_points['sasa_gpu'], alpha=0.1, color='blue',
                label=f'RMSE <= {rmse_threshold:g} Å²')
    if not high_rmse_points.empty:
        ax2.scatter(high_rmse_points['sasa_cpu'], high_rmse_points['sasa_gpu'], alpha=0.3, color='red',
                    label=f'RMSE > {rmse_threshold:g} Å²')

    max_val = max(all_sasa_df['sasa_cpu'].max(), all_sasa_df['sasa_gpu'].max())
    ax2.plot([0, max_val], [0, max_val], 'k--', label='y = x')
    ax2.set_xlabel('CPU SASA (Å²)')
    ax2.set_ylabel('GPU SASA (Å²)')
    ax2.set_title('CPU vs GPU SASA Values')
    ax2.legend()

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.show()


def compare_sasa_results(gpu_dir: str, cpu_dir: str, rmse_threshold: float = 2.0,
                         output_path: str = 'sasa_comparison.png'):
    """Compare per-atom SASA values produced by a GPU run and a CPU run.

    Each sub-directory of ``gpu_dir`` is treated as one protein. The first CSV found
    under it (sorted order) is compared with the first CSV under
    ``cpu_dir/<protein>``. Proteins without a CSV on both sides, or whose tables
    differ in length, are skipped. Rows are paired by position after sorting both
    tables by atom identity; a warning is printed if identities do not line up.

    Args:
        gpu_dir: root directory of the GPU run.
        cpu_dir: root directory of the CPU run.
        rmse_threshold: proteins with RMSE above this (Å²) are flagged ``high_rmse``,
            annotated, and drawn in red.
        output_path: where the comparison figure is saved.

    Returns:
        ``(summary_df, all_sasa_df)``: one summary row per protein, and per-atom rows
        with a ``high_rmse`` flag. Both are empty (with their columns) when no protein
        could be compared; in that case no figure is drawn.
    """
    gpu_path = Path(gpu_dir)
    cpu_path = Path(cpu_dir)
    all_comparisons = []
    all_sasa_values = []

    for protein_dir in sorted(gpu_path.glob("*")):
        if not protein_dir.is_dir():
            continue

        protein_name = protein_dir.name
        gpu_csv = sorted(protein_dir.rglob("*.csv"))
        cpu_csv = sorted((cpu_path / protein_name).rglob("*.csv"))
        if not gpu_csv or not cpu_csv:
            continue

        gpu_data = _load_sorted_sasa(gpu_csv[0])
        cpu_data = _load_sorted_sasa(cpu_csv[0])

        if len(gpu_data) != len(cpu_data):
            print(f"Length mismatch in {protein_name}: GPU={len(gpu_data)}, CPU={len(cpu_data)}")
            continue

        if not gpu_data[_SASA_ATOM_KEYS].reset_index(drop=True).equals(
                cpu_data[_SASA_ATOM_KEYS].reset_index(drop=True)):
            print(f"Warning: atom identities differ between GPU and CPU in {protein_name}; "
                  f"rows are paired by sorted position")

        cpu_sasa = cpu_data['sasa'].to_numpy()
        gpu_sasa = gpu_data['sasa'].to_numpy()
        comparison = pd.DataFrame({
            'sasa_cpu': cpu_sasa,
            'sasa_gpu': gpu_sasa,
            'diff': np.abs(cpu_sasa - gpu_sasa),
            'protein': protein_name,
            'chain_gpu': gpu_data['chain'].values,
            'resname_gpu': gpu_data['resname'].values,
            'resid_gpu': gpu_data['resid'].values,
            'atom_gpu': gpu_data['atom'].values,
            'chain_cpu': cpu_data['chain'].values,
            'resname_cpu': cpu_data['resname'].values,
            'resid_cpu': cpu_data['resid'].values,
            'atom_cpu': cpu_data['atom'].values
        })
        all_sasa_values.append(comparison)

        rmse = float(np.sqrt(np.mean(comparison['diff'] ** 2)))
        # pearsonr needs at least two atoms.
        correlation = stats.pearsonr(cpu_sasa, gpu_sasa)[0] if len(comparison) >= 2 else float('nan')

        all_comparisons.append({
            'protein': protein_name,
            'rmse': rmse,
            'correlation': correlation,
            'mean_diff': comparison['diff'].mean(),
            'max_diff': comparison['diff'].max(),
            'num_atoms': len(comparison),
            'num_nonzero': int((comparison['sasa_gpu'] > 0).sum())
        })

    summary_columns = ['protein', 'rmse', 'correlation', 'mean_diff', 'max_diff', 'num_atoms', 'num_nonzero']
    if not all_sasa_values:
        print(f"No comparable proteins found between {gpu_dir} and {cpu_dir}")
        atom_columns = ['sasa_cpu', 'sasa_gpu', 'diff', 'protein',
                        'chain_gpu', 'resname_gpu', 'resid_gpu', 'atom_gpu',
                        'chain_cpu', 'resname_cpu', 'resid_cpu', 'atom_cpu', 'high_rmse']
        return pd.DataFrame(columns=summary_columns), pd.DataFrame(columns=atom_columns)

    summary_df = pd.DataFrame(all_comparisons, columns=summary_columns)
    # ignore_index gives the per-atom frame a unique index instead of repeated 0..n.
    all_sasa_df = pd.concat(all_sasa_values, ignore_index=True)

    high_rmse_proteins = set(summary_df.loc[summary_df['rmse'] > rmse_threshold, 'protein'])
    all_sasa_df['high_rmse'] = all_sasa_df['protein'].isin(high_rmse_proteins)

    _plot_sasa_comparison(summary_df, all_sasa_df, rmse_threshold, output_path)

    return summary_df, all_sasa_df

In [None]:
# Inputs / outputs for this comparison run.
benchmark = "./benchmark_af/dataset.json"
gpu_run_dir = "./benchmark_af/20250126_140901_gpu"
cpu_run_dir = "./benchmark_af/20250127_160612_cpu"
results = f"{gpu_run_dir}/combined_results.json"
output = "output_comp"
dataset = "./benchmark_af/PRODIGYdataset/"  # make sure you have the dataset
os.makedirs(output, exist_ok=True)

df_benchmark, df_results = load_and_process_data(benchmark, results)
correlations = calculate_correlations(df_benchmark, df_results)
print("\nCorrelation Analysis:")
print(correlations.to_string(index=False))
correlations.to_csv(os.path.join(output, 'correlations.csv'), index=False)
plot_correlations(df_benchmark, df_results, output)

# Save processed DataFrames
df_benchmark.to_csv(os.path.join(output, 'benchmark_processed.csv'))
df_results.to_csv(os.path.join(output, 'results_processed.csv'))

df = add_sequence_lengths(df_results, dataset)
df['total_length'] = df['chain_a_length'] + df['chain_b_length']

# Entries whose PDB file was not found have NaN lengths, and np.polyfit cannot
# fit through NaN. Sorting by length makes the trend line draw left to right.
timing = (
    df[['total_length', 'execution_time']]
    .dropna()
    .astype(float)
    .sort_values('total_length')
)

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(timing['total_length'], timing['execution_time'])

# Add trend line (a linear fit needs at least two points)
if len(timing) >= 2:
    z = np.polyfit(timing['total_length'], timing['execution_time'], 1)
    p = np.poly1d(z)
    ax.plot(timing['total_length'], p(timing['total_length']), "r--", alpha=0.8)

# Calculate correlation
corr = timing['total_length'].corr(timing['execution_time'])

ax.set_xlabel('Total Sequence Length (residues)')
ax.set_ylabel('Execution Time (s)')
ax.set_title(f'Execution Time vs Sequence Length\nCorrelation: {corr:.3f}')
ax.grid(True, alpha=0.3)
fig.savefig(os.path.join(output, 'execution_time_vs_length.png'), dpi=300, bbox_inches='tight')
plt.show()

summary_df, all_sasa_df = compare_sasa_results(gpu_run_dir, cpu_run_dir)


In [None]:
from bio_lib.custom_prodigy import predict_binding_affinity

# NOTE(review): the original pointed at an absolute path on the author's machine.
# This defaults to 3bzd.pdb in the notebook's working directory; set the
# PRODIGY_EXAMPLE_PDB environment variable to point elsewhere.
example_pdb = Path(os.environ.get("PRODIGY_EXAMPLE_PDB", "3bzd.pdb"))
if not example_pdb.exists():
    raise FileNotFoundError(f"Example structure not found: {example_pdb.resolve()}")

predict_binding_affinity(str(example_pdb), save_results=True)