## Evaluation of AlphaGenome

In [None]:
import pandas as pd
pd.set_option('future.no_silent_downcasting', True)
import os
import numpy as np
from sklearn.metrics import average_precision_score, f1_score, precision_score, recall_score, confusion_matrix
import math
import pickle as pkl
from itertools import combinations
import csv
import logging
# Log line layout: "<month-day hour:minute> - <level> - <message>".
formatter = logging.Formatter(
	fmt='%(asctime)s - %(levelname)s - %(message)s',
	datefmt="%m-%d %H:%M",
)

# Records are written to a log file in the current working directory.
file_handler = logging.FileHandler('evaluation_alphagenome.log')
file_handler.setFormatter(formatter)
file_handler.setLevel(logging.INFO)

# Both the logger and its handler are set to INFO, so DEBUG records are dropped.
logger = logging.getLogger('main')
logger.addHandler(file_handler)
logger.setLevel(logging.INFO)

# Evaluation settings: aligner names, sample identifiers and chromosome names.
aligners = ["star"]
samples = ["J26675-L1_S1"]
# Chromosomes 1-22 followed by X and Y, all as strings.
chromosomes = [str(autosome) for autosome in range(1, 23)] + ["X", "Y"]

def log_round_mean_std(mean, std, tool, aligner, gt_confidence, score):
	"""Log ``mean ± std`` of a score, rounded to the precision of the std. dev.

	The standard deviation is rounded *up* to two significant digits and the
	mean is rounded to the same decimal place. A rounded standard deviation of
	at most 0.009 is printed in E notation. The message is emitted at DEBUG
	level on the module-level ``logger``.

	Args:
		mean: Mean of the score.
		std: Standard deviation of the score.
		tool: Tool name used in the message.
		aligner: Aligner name used in the message.
		gt_confidence: Ground-truth confidence label used in the message.
		score: Name of the score (e.g. ``'AUPRC'``) used in the message.
	"""
	from decimal import ROUND_CEILING, Decimal

	std = float(std)
	if not math.isfinite(std):
		# NaN/inf cannot be rounded to significant digits; log the values as they are.
		logger.debug(f"{tool} {aligner} {gt_confidence} {score}: {mean} ± {std}")
		return
	# Signed decimal exponent of std written with two significant digits,
	# e.g. 0.0123 -> '1.2e-02' -> -2. Two significant digits then end at
	# decimal place (1 - exponent), which is zero or negative for std >= 10.
	exponent = int(f"{std:.1e}".split('e')[1])
	nr_nks_std = 1 - exponent
	# Round up with Decimal: float arithmetic such as ceil(0.07 * 100) is
	# subject to binary representation artefacts and may overshoot by one unit.
	quantum = Decimal(1).scaleb(-nr_nks_std)
	std = float(Decimal(repr(std)).quantize(quantum, rounding=ROUND_CEILING))
	mean = round(mean, nr_nks_std)
	if std <= 0.009:
		# more than 3 digits after the decimal point: use E notation for the std. dev.
		logger.debug(f"{tool} {aligner} {gt_confidence} {score}: {mean} ± {std:.1E}")
	else:
		logger.debug(f"{tool} {aligner} {gt_confidence} {score}: {mean} ± {std}")


def calc_no_skill_auprc(tool_sjs, score_dir, aligner, gt_confidence):
	"""Append the mean No-Skill AUPRC ± 1 std. dev. to ``No_Skill_auprcs.csv``.

	The fraction of rows with ``label == 1`` is collected for every table of
	every tool in ``tool_sjs``; mean and standard deviation of these fractions
	are logged and appended as one CSV row. The header is written first when
	the file does not exist yet.

	Args:
		tool_sjs: Mapping tool -> mapping annotation -> table with a ``label`` column.
		score_dir: Directory that holds ``No_Skill_auprcs.csv``.
		aligner: Aligner name written to the row.
		gt_confidence: Ground-truth confidence label written to the row.
	"""
	positive_fractions = [
		len(table[table['label']==1])/len(table)
		for per_sample in tool_sjs.values()
		for table in per_sample.values()
	]
	# A No-Skill classifier has constant precision (the share of positives), and the
	# recall range is 1, so its AUPRC equals that constant precision.
	no_skill_auprc = np.mean(positive_fractions)
	no_skill_std = np.std(positive_fractions)
	log_round_mean_std(no_skill_auprc, no_skill_std, 'No Skill', aligner, gt_confidence, 'AUPRC')

	csv_path = f'{score_dir}/No_Skill_auprcs.csv'
	header_missing = not os.path.isfile(csv_path)
	with open(csv_path, 'a') as handle:
		if header_missing:
			handle.write('aligner,gt_confidence,tool,mean_auprc,std_auprc\n')
		handle.write(f"{aligner},{gt_confidence},No Skill,{no_skill_auprc},{no_skill_std}\n")


def calc_f1_score_at_threshold(sjs, score_dir, aligner, gt_confidence, threshold=0.5):
	"""Append the mean ± std. dev. F1 score at a fixed threshold to ``mean_std_f1_scores.csv``.

	For every entry of ``sjs`` the predictions are binarised with
	``pred >= threshold`` and scored against ``label`` with ``f1_score``.
	Mean and standard deviation over all entries are appended as one CSV row
	(the header is written first if the file does not exist) and logged.

	Args:
		sjs: Non-empty mapping from annotation 3-tuples, whose first element is
			the tool name, to tables with ``label`` and ``pred`` columns.
		score_dir: Directory that holds ``mean_std_f1_scores.csv``.
		aligner: Aligner name written to the row.
		gt_confidence: Ground-truth confidence label written to the row.
		threshold: Decision threshold; rounded to one decimal place.

	Raises:
		ValueError: If ``sjs`` is empty.
	"""
	if not sjs:
		raise ValueError('calc_f1_score_at_threshold: sjs is empty, nothing to score')
	# Multiplier applied to the threshold for DeepSplice entries; the nominal
	# threshold is what is reported in the CSV.
	deepsplice_scale = 1.7052
	threshold = round(threshold, 1) # round threshold to 1 decimal place
	f1_scores = {}
	for ann, sj in sjs.items():
		tool, _, _ = ann
		# Derive the effective threshold per entry so that the DeepSplice
		# scaling never carries over to entries of another tool.
		effective_threshold = threshold * deepsplice_scale if tool == 'deepsplice' else threshold
		f1_scores[ann] = f1_score(sj['label'], sj['pred'] >= effective_threshold)
	f1_values = list(f1_scores.values())
	mean_f1_score = np.mean(f1_values)
	std_f1_score = np.std(f1_values)
	# if out_mean_file does not exist, create it and write header
	out_mean_file = f'{score_dir}/mean_std_f1_scores.csv'
	if not os.path.isfile(out_mean_file):
		with open(out_mean_file, 'w') as file:
			file.write('aligner,gt_confidence,tool,mean_f1_score,std_f1_score,threshold\n')
	# `tool` is the tool name of the last entry of sjs
	with open(out_mean_file, 'a') as file:
		file.write(f'{aligner},{gt_confidence},{tool},{mean_f1_score},{std_f1_score},{threshold}\n')
	log_round_mean_std(mean_f1_score, std_f1_score, tool, aligner, gt_confidence, f'F1 Score at {threshold}')


def calc_auprc(sjs, score_dir, aligner, gt_confidence):
	"""Compute AUPRC (area under the precision-recall curve) per entry and save summaries.

	Three files in ``score_dir`` are touched:

	* ``{tool}_{aligner}_{gt_confidence}_auprcs.csv`` is rewritten from scratch
	  with one AUPRC per entry of ``sjs``;
	* ``mean_std_auprcs.csv`` gets one appended row with mean ± 1 std. dev. of the AUPRCs;
	* ``mean_nr_positives_negatives.csv`` gets one appended row with the
	  floored mean and std. dev. of the number of positives and negatives.

	The two summary files receive a header when they do not exist yet.

	Args:
		sjs: Non-empty mapping from annotation tuples ``(tool, sample_id, run_nr)``
			to tables with ``label`` and ``pred`` columns.
		score_dir: Output directory.
		aligner: Aligner name written to the rows.
		gt_confidence: Ground-truth confidence label written to the rows.

	Raises:
		ValueError: If ``sjs`` is empty.
	"""
	if not sjs:
		raise ValueError('calc_auprc: sjs is empty, nothing to score')
	scores = {}
	nr_positives = []
	nr_negatives = []
	for ann, sj in sjs.items():
		tool, _, _ = ann
		scores[ann] = average_precision_score(sj['label'], sj['pred'])
		nr_positives.append(len(sj[sj["label"]==1]))
		nr_negatives.append(len(sj[sj["label"]==0]))
	auprcs = list(scores.values())
	mean_auprc = np.mean(auprcs)
	auprc_std = np.std(auprcs)
	# The per-entry file is always overwritten. `tool` is the tool name of the last entry.
	out_file = f'{score_dir}/{tool}_{aligner}_{gt_confidence}_auprcs.csv'
	with open(out_file, 'w') as file:
		file.write('tool,sample_id,run_nr,auprc\n')
		for ann, score in scores.items():
			# str() each part so that non-string annotation fields (e.g. an int run number) work
			file.write(','.join(map(str, ann))+','+str(score)+'\n')
	# if out_mean_file does not exist, create it and write header
	out_mean_file = f'{score_dir}/mean_std_auprcs.csv'
	if not os.path.isfile(out_mean_file):
		with open(out_mean_file, 'w') as file:
			file.write('aligner,gt_confidence,tool,mean_auprc,std_auprc\n')
	with open(out_mean_file, 'a') as file:
		file.write(f'{aligner},{gt_confidence},{tool},{mean_auprc},{auprc_std}\n')
	log_round_mean_std(mean_auprc, auprc_std, tool, aligner, gt_confidence, 'AUPRC')
	# if out_nr_file does not exist, create it and write header
	out_nr_file = f'{score_dir}/mean_nr_positives_negatives.csv'
	if not os.path.isfile(out_nr_file):
		with open(out_nr_file, 'w') as file:
			file.write('aligner,gt_confidence,tool,mean_nr_positives,std_nr_positives,mean_nr_negatives,std_nr_negatives\n')
	with open(out_nr_file, 'a') as file:
		file.write(f'{aligner},{gt_confidence},{tool},{math.floor(np.mean(nr_positives))},{math.floor(np.std(nr_positives))},{math.floor(np.mean(nr_negatives))},{math.floor(np.std(nr_negatives))}\n')


def _summarize_bootstrap_deltas(deltas):
    """Return (mean, 2.5th percentile, 97.5th percentile) of bootstrapped deltas."""
    ci_lower, ci_upper = np.percentile(deltas, [2.5, 97.5])
    return np.mean(deltas), ci_lower, ci_upper


def write_effectsize_ci(tool_sjs, stats_dir, aligner, gt_confidence, n_bootstrap=1000, seed=None):
    """Append bootstrapped AUPRC effect sizes with 95% intervals to a CSV file.

    For every (sample, run_id) the rows are resampled with replacement
    ``n_bootstrap`` times; the same resampling indices are used for all tools
    and for the baseline. For every tool vs. the baseline (a constant
    prediction of 1) and for every pair of tools, the mean AUPRC difference
    and its 2.5th/97.5th percentiles are appended to
    ``{stats_dir}/effect_sizes_new_bootstrap.csv``. The header is written
    when the file is missing or empty.

    Args:
        tool_sjs: Mapping tool key -> mapping ``(tool, sample, run_id)`` ->
            table with ``label`` and ``pred`` columns. Only the keys
            'alphagenome', 'spliceai', 'deepsplice' and 'jcc' are used.
            NOTE(review): the shared indices assume that all tools hold the
            same (sample, run_id) keys with equally long tables - confirm.
        stats_dir: Output directory.
        aligner: Aligner name written to each row.
        gt_confidence: Ground-truth confidence label written to each row.
        n_bootstrap: Number of bootstrap resamples per (sample, run_id).
        seed: Optional seed for reproducible resampling. With ``None`` the
            global NumPy random state is used.

    Raises:
        ValueError: If ``tool_sjs`` contains none of the known tools.
    """
    TOOL_PLOT_NAME = {
        'alphagenome': 'AlphaGenome',
        'spliceai': 'SpliceAI',
        'deepsplice': 'DeepSplice',
        'jcc': 'JCC',
        'baseline': 'No-Skill'
    }
    baseline_name = TOOL_PLOT_NAME['baseline']
    # Tools to use (excluding baseline)
    tool_names = [key for key in TOOL_PLOT_NAME if key != 'baseline' and key in tool_sjs]
    if not tool_names:
        raise ValueError('write_effectsize_ci: tool_sjs contains none of the known tools')

    # Tables of one tool define the resampling indices and the baseline labels.
    # The second listed tool is used when available (historical behaviour);
    # with a single tool, that tool is used instead of failing.
    reference_tool = tool_names[1] if len(tool_names) > 1 else tool_names[0]
    reference_sjs = tool_sjs[reference_tool]

    rng = np.random if seed is None else np.random.RandomState(seed)

    # Precompute bootstrap indices for every sample/run
    bootstrap_indices = {}
    for (_, sample, run_id), s in reference_sjs.items():
        n = s.shape[0]
        bootstrap_indices[(sample, run_id)] = rng.randint(0, n, size=(n_bootstrap, n))

    # Precompute per-tool per-(sample,run) AUPRCs for each bootstrap
    per_tool_auprcs = {}
    for t in tool_names:
        per_tool_auprcs[t] = {}
        for (_, sample, run_id), s in tool_sjs[t].items():
            indices = bootstrap_indices[(sample, run_id)]
            labels = s['label'].values
            preds = s['pred'].values
            per_tool_auprcs[t][(sample, run_id)] = np.array([
                average_precision_score(labels[idx], preds[idx]) for idx in indices
            ])

    # Precompute baseline AUPRCs on the same resamples
    per_baseline_auprcs = {}
    for (_, sample, run_id), s in reference_sjs.items():
        indices = bootstrap_indices[(sample, run_id)]
        labels = s['label'].values
        # Baseline: predict all 1's
        preds_baseline = np.ones_like(labels)
        per_baseline_auprcs[(sample, run_id)] = np.array([
            average_precision_score(labels[idx], preds_baseline[idx]) for idx in indices
        ])

    results_rows = []

    # Compare each tool to baseline
    for t in tool_names:
        t_name = TOOL_PLOT_NAME[t]
        for (sample, run_id), tool_auprcs in per_tool_auprcs[t].items():
            deltas = tool_auprcs - per_baseline_auprcs[(sample, run_id)]
            mean_delta, ci_lower, ci_upper = _summarize_bootstrap_deltas(deltas)
            results_rows.append([aligner, gt_confidence, sample, run_id, t_name, baseline_name, mean_delta, ci_lower, ci_upper])

    # Compare all pairs of tools (excluding baseline)
    for t1, t2 in combinations(tool_names, 2):
        n1 = TOOL_PLOT_NAME[t1]
        n2 = TOOL_PLOT_NAME[t2]
        for (sample, run_id), auprc1 in per_tool_auprcs[t1].items():
            deltas = auprc1 - per_tool_auprcs[t2][(sample, run_id)]
            mean_delta, ci_lower, ci_upper = _summarize_bootstrap_deltas(deltas)
            results_rows.append([aligner, gt_confidence, sample, run_id, n1, n2, mean_delta, ci_lower, ci_upper])

    # Write to CSV
    csv_file_path = f'{stats_dir}/effect_sizes_new_bootstrap.csv'
    write_header = not os.path.exists(csv_file_path) or os.stat(csv_file_path).st_size == 0
    with open(csv_file_path, mode='a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        if write_header:
            headers = ['aligner', 'gt_confidence', 'sample', 'run_id', 'tool1', 'tool2', 'mean_delta', 'ci_lower', 'ci_upper']
            writer.writerow(headers)
        writer.writerows(results_rows)


def write_operating_points(y_true, y_scores, stats_dir, tool, aligner, gt_confidence):
	"""Append precision, recall and confusion counts at fixed thresholds to ``operating_points.csv``.

	For each threshold in 0.0, 0.1, ..., 0.9 the scores are binarised with
	``y_scores >= threshold`` and one CSV row is appended. The header is
	written first when the file does not exist.

	Args:
		y_true: Binary ground-truth labels encoded as 0/1.
		y_scores: Prediction scores, same length as ``y_true``.
		stats_dir: Directory that holds ``operating_points.csv``.
		tool: Tool name written to each row.
		aligner: Aligner name written to each row.
		gt_confidence: Ground-truth confidence label written to each row.
	"""
	y_scores = np.asarray(y_scores)
	# At fixed thresholds
	# NOTE(review): calc_f1_score_at_threshold multiplies the threshold by 1.7052 for
	# 'deepsplice'; no such scaling is applied here - confirm whether that is intended.
	thresholds = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])
	summary_file = f"{stats_dir}/operating_points.csv"
	if not os.path.isfile(summary_file):
		with open(summary_file, 'w') as file:
			file.write('tool,aligner,gt_confidence,threshold,precision,recall,TP,FP,FN,TN\n')
	with open(summary_file, 'a') as f:
		for threshold in thresholds:
			pred_at_thresh = (y_scores >= threshold).astype(int)
			precision_fixed = precision_score(y_true, pred_at_thresh, zero_division=0)
			recall_fixed = recall_score(y_true, pred_at_thresh, zero_division=0)
			# labels=[0, 1] forces a 2x2 matrix; without it a single class in both labels
			# and predictions yields a 1x1 matrix and the four-way unpacking fails.
			tn, fp, fn, tp = confusion_matrix(y_true, pred_at_thresh, labels=[0, 1]).ravel()
			f.write(f"{tool},{aligner},{gt_confidence},{threshold},{precision_fixed},{recall_fixed},{tp},{fp},{fn},{tn}\n")


# calculate performance for Scenario 2: "Predicting junctions that could be detected with higher sequencing depth" / Scenario 3: "Predicting hard-to-find junctions"
# for Real-world / Hypothetical setting
def scenario_2_3(aligner, gt_confidence, stats_dir, scenario, hard_to_find, file_sj_50):
	"""Evaluate all tools stored in a pickled prediction file and write the statistics.

	Per tool: operating points over all samples concatenated, AUPRC, and F1
	scores at thresholds 0.0, 0.1, ..., 1.0. Afterwards the No-Skill AUPRC and
	the bootstrapped effect sizes over all tools are written.

	Args:
		aligner: Aligner name passed on to the statistics writers.
		gt_confidence: Ground-truth confidence label passed on to the writers.
		stats_dir: Output directory; created if missing.
		scenario: Not used inside this function.
		hard_to_find: Not used inside this function.
		file_sj_50: Path of the pickle file holding a mapping
			tool -> mapping annotation -> table with ``label`` and ``pred`` columns.

	Raises:
		FileNotFoundError: If ``file_sj_50`` does not exist.
	"""
	os.makedirs(stats_dir, exist_ok=True)
	if not os.path.isfile(file_sj_50):
		logger.error(f'No file {file_sj_50}')
		# Raise instead of calling exit(): exit() is the interactive site helper and
		# would end the process with status 0 although the run failed.
		raise FileNotFoundError(f'No file {file_sj_50}')
	# pickle can execute arbitrary code while loading; only load files produced by this pipeline.
	with open(file_sj_50, 'rb') as f:
		tool_sjs_50 = pkl.load(f)
	logger.debug(f'Loaded {file_sj_50} from disk.')
	# 0.0, 0.1, ..., 1.0; linspace avoids the end-point ambiguity of a float step in arange
	f1_thresholds = np.linspace(0.0, 1.0, 11)
	for tool, sjs in tool_sjs_50.items():
		all_labels = np.concatenate([sj["label"] for sj in sjs.values()]) # append all samples
		all_preds = np.concatenate([sj["pred"] for sj in sjs.values()]) # append all samples
		write_operating_points(all_labels, all_preds, stats_dir, tool, aligner, gt_confidence)
		calc_auprc(sjs, stats_dir, aligner, gt_confidence)
		for threshold in f1_thresholds:
			calc_f1_score_at_threshold(sjs, stats_dir, aligner, gt_confidence, threshold=threshold)
	calc_no_skill_auprc(tool_sjs_50, stats_dir, aligner, gt_confidence)
	write_effectsize_ci(tool_sjs_50, stats_dir, aligner, gt_confidence)


# Directory layout below the project root.
main_dir = "."  # TODO adjust path
out_dir = main_dir + '/out'
data_dir = main_dir + '/data'
pred_dir = data_dir + '/predictions'
stats_dir = out_dir + '/stats_alphagenome'
os.makedirs(stats_dir, exist_ok=True)

# Run the evaluation for every configured scenario / ground-truth combination.
for hard_to_find in (True,):
	for scenario in ('hypothetical',):
		for aligner, gt_confidence in (
			('star', 'unfiltered'),
			('star', 'illumina'),
			('star', 'cutoff'),
		):
			# Scenario 3 covers hard-to-find junctions, scenario 2 otherwise.
			scenario_nr = 3 if hard_to_find else 2
			logger.info(f'--------- RUN Scenario {scenario_nr} {scenario} vs GT {aligner} {gt_confidence} ---------')
			stats_dir_scenario = f'{stats_dir}/scenario_{scenario_nr}_{scenario}'
			os.makedirs(stats_dir_scenario, exist_ok=True)
			# Pickled predictions are read from the scenario directory.
			file_sj_50 = f'{stats_dir_scenario}/tool_sjs_50_{aligner}_vs_GT_{aligner}_{gt_confidence}.pkl'
			# NOTE(review): results are written to stats_dir, not stats_dir_scenario, so
			# several scenarios would append to the same CSV files - confirm this is intended.
			scenario_2_3(aligner, gt_confidence, stats_dir, scenario, hard_to_find, file_sj_50)