For producing the figures you will need the conda environment:

conda env create -f eval.yml

Before running the cells below, first run:

python evaluation_scenarios.py

In [None]:
import numpy as np
import pandas as pd
pd.set_option('future.no_silent_downcasting', True)
import matplotlib.pyplot as plt
from upsetplot import UpSet
from matplotlib_venn import venn3, venn2
from sklearn.metrics import average_precision_score, precision_recall_curve
from scipy.stats import wilcoxon
from itertools import combinations
import pickle
import os
from pyfaidx import Fasta
import re
# Directory layout; every path is derived from main_dir.
main_dir = "."  # TODO set working directory
data_dir = main_dir + "/data"
pred_dir = data_dir + '/predictions'
out_dir = main_dir + "/out"
stats_dir = out_dir + "/stats"
stats_dir_ag = out_dir + "/stats_alphagenome"
plt_dir = out_dir + "/plots"
# When True, AlphaGenome stats from stats_dir_ag are used where such a file exists.
include_ag_results = True
# All figures below are written into plt_dir.
os.makedirs(plt_dir, exist_ok=True)

In [None]:
# Figure 2: Distribution of JCC prediction scores for Scenario 3 “Predicting hard-to-find junctions” in Real-world use case
# One panel per STAR gold standard (unfiltered / Illumina filtered / cutoff filtered): JCC score histograms
# of negative (label 0) and positive (label 1) junctions for sample J26675-L1_S1, subsample '0'.
in_dir = f'{stats_dir}/scenario_3_real_world'
aligner="star"
tool = "jcc"
fig, axs = plt.subplots(1, 3, figsize=(12, 3.5))  # 3 horizontal subplots
for i, gt_confidence in enumerate(["unfiltered","illumina","cutoff"]):
	file_sj_50 = f'{in_dir}/tool_sjs_50_{aligner}_vs_GT_{aligner}_{gt_confidence}.pkl'
	# Display name only ("illumina" -> "Illumina"); the file path above already used the raw name.
	gt_confidence = gt_confidence.capitalize() if gt_confidence == 'illumina' else gt_confidence
	with open(file_sj_50, 'rb') as f:
		sjs_50 = pd.read_pickle(f)
		# Keys are (<unused>, sample_id, run_id) tuples; each value is a table with 'pred' and 'label' columns.
		for ann, sj in sjs_50[tool].items():
			_, sample_id, run_id = ann
			# run_id is compared as the string '0' here (other cells convert it with int()).
			if sample_id == 'J26675-L1_S1' and run_id == '0':
					ax = axs[i]
					# Shared bin edges so the negative and positive histograms are directly comparable.
					bins = np.histogram_bin_edges(sj['pred'], bins=25)
					ax.hist(sj[sj['label'] == 0]['pred'], bins=bins, alpha=0.6, label='Negative (label=0)')
					ax.hist(sj[sj['label'] == 1]['pred'], bins=bins, alpha=0.6, label='Positive (label=1)')
					ax.set_xlabel(f'{tool.upper()} prediction score', fontsize=14)
					ax.set_ylabel('Number of junctions', fontsize=14)
					ax.set_yscale('log')
					ax.tick_params(labelsize=14)
					ax.set_xlim(-0.02, 1.02)
					# Legend only on the right-most (cutoff) panel.
					if i == 2:
						ax.legend(fontsize=14)
					ax.grid(True)
					ax.title.set_text(f'{aligner.upper()} {gt_confidence}')
					ax.title.set_fontsize(16)
plt.tight_layout()
# NOTE(review): the file is named figure_2s.tif although this cell is labelled Figure 2 — confirm the intended name.
plt.savefig(f'{plt_dir}/figure_2s.tif', dpi=600, transparent=True, pil_kwargs={"compression": "tiff_lzw"})
plt.close()

In [None]:
# Supplementary Figure S1: UpSet plot of all ground truths.
# Sets: the GRCh38.106 annotation (chr1-22, X, Y only) plus every aligner/filter gold standard,
# each pooled over the four 500M samples.

labels = ["GRCh38.106 reference genome"]
reference_sj = pd.read_csv(f"{data_dir}/annotated.sj",sep='\t', header=None, names=["chr", "start", "end", "strand"], dtype={"chr": str, "start": int, "end": int, "strand": int})
# keep only chr1-22, X, Y
reference_sj = reference_sj[reference_sj["chr"].isin([str(i) for i in range(1,23)] + ["X","Y"])]
reference_sj = reference_sj.drop_duplicates()

# sets[k] is the junction table for labels[k]
sets = [reference_sj]

# HISAT2 only has an unfiltered gold standard; STAR has unfiltered, Illumina- and cutoff-filtered ones.
for aligner in ["hisat","star"]:
	for filter_type in (["unfiltered","illumina","cutoff"] if aligner == "star" else ["unfiltered"]):
		sample_sets = []
		for sample in ["J26675-L1_S1","J26676-L1_S2","J26677-L1_S3","J26678-L1_S4"]:
			sample_sets.append(pd.read_csv(f"{data_dir}/500M/{sample}_{aligner}_{filter_type}.sj", usecols=[0,1,2,3], dtype={"chr": str, "start": int, "end": int, "strand": int}, sep="\t", header=None, names=["chr","start","end","strand"]))
		# Human-readable label, e.g. "STAR Illumina filtered gold standard" or "HISAT2 unfiltered gold standard".
		filter_type_ = filter_type.capitalize() if filter_type == "illumina" else filter_type
		filter_type_ = filter_type_ if filter_type_ == "unfiltered" else filter_type_+" filtered"
		aligner_ = aligner.upper()+'2' if aligner == "hisat" else aligner.upper()
		labels.append(aligner_+" "+filter_type_+" gold standard")
		# Union of the four samples' junctions.
		combined_sj = pd.concat(sample_sets, ignore_index=True).drop_duplicates()
		sets.append(combined_sj)

# Merge all dataframes with an additional 'source' column
merged_df = pd.concat(
	[df.assign(source=name) for name, df in zip(labels, sets)],
	ignore_index=True
)

# get counts for each group
# One row per unique junction, one 0/1 column per source (1 = junction present in that source).
binary_merged = merged_df.pivot_table(index=['chr','start','end','strand'], columns='source', aggfunc='size', fill_value=0).clip(upper=1)
# Number of junctions for every presence/absence combination across the sources.
g = binary_merged.groupby(binary_merged.columns.tolist()).size()
g = g.reorder_levels(labels)

upset = UpSet(g, show_counts=True, sort_categories_by='-input', sort_by='-degree')
upset.plot()
plt.subplots_adjust(right=1.25)
plt.savefig(f'{plt_dir}/supplementary_figure_1.tif', dpi=600, bbox_inches='tight')
plt.close()

In [None]:
# Venn-diagram helpers for three junction tables a, b, c (a junction = one chr/start/end/strand tuple)
def venn3_region_counts(a, b, c, keys=('chr', 'start', 'end', 'strand')):
	"""Count the unique junctions in each of the seven regions of a three-set Venn diagram.

	a, b, c: DataFrames that contain the *keys* columns (extra columns are ignored).
	Junctions are compared as tuples of the key columns, so duplicated rows in an input
	are counted once. The earlier merge-based counting multiplied such rows.

	Returns a dict keyed by matplotlib_venn region codes, where digit k says whether the
	region lies inside set k: '100' = only in a, '110' = in a and b but not c, '111' = in all three.
	"""
	def as_set(df):
		return set(df[list(keys)].itertuples(index=False, name=None))

	set_a, set_b, set_c = as_set(a), as_set(b), as_set(c)
	return {
		'100': len(set_a - set_b - set_c),
		'010': len(set_b - set_a - set_c),
		'001': len(set_c - set_a - set_b),
		'110': len((set_a & set_b) - set_c),
		'101': len((set_a & set_c) - set_b),
		'011': len((set_b & set_c) - set_a),
		'111': len(set_a & set_b & set_c),
	}


def prep_for_venn3(a, b, c, colored = True):
	"""Return (region counts, set colours) ready to pass to matplotlib_venn.venn3.

	colored=True colours the sets with the first three tab10 colours; colored=False uses
	the first tab10 colour for all three sets.
	"""
	three_sets = venn3_region_counts(a, b, c)
	tab10 = plt.cm.tab10.colors
	colors = [tab10[0], tab10[1], tab10[2]] if colored else [tab10[0], tab10[0], tab10[0]]
	return three_sets, colors

In [None]:
# Supplementary Figure S2
# plot overlap STAR 500M with different filtering in VENN diagrams with 3 sets as example sample_id J26675-L1_S1
sample = 'J26675-L1_S1'
aligner = 'star'
# Display label -> junction table of that STAR 500M gold standard; insertion order gives Venn sets A, B, C.
sjs = {}
for filter_type in ["unfiltered","illumina","cutoff"]:
	filter_type_ = filter_type.capitalize() if filter_type == "illumina" else filter_type
	sjs[f'{aligner.upper()} {filter_type_}'] = pd.read_csv(f"{data_dir}/500M/{sample}_{aligner}_{filter_type}.sj", usecols=[0,1,2,3], dtype={"chr": str, "start": int, "end": int, "strand": int}, sep="\t", header=None, names=["chr","start","end","strand"])
# colored=False: all three circles get the same colour.
three_sets, colors = prep_for_venn3(*sjs.values(), colored=False)
fig, ax = plt.subplots(figsize=(5, 5))
venn3(subsets=three_sets, set_labels = sjs.keys(), set_colors=colors, ax=ax)
plt.savefig(f'{plt_dir}/supplementary_figure_2.tif', dpi=600, bbox_inches='tight')
plt.close()

In [None]:
# Supplementary Figure S3
# plot overlap 50M with 500M with reference genome gtf in Venn diagrams with 3 sets as example subsample 0 of sample_id J26675-L1_S1
# Left panel: STAR, right panel: HISAT2; both use the unfiltered 500M junctions.
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, aligner in zip(axes,["star","hisat"]):
	sample = 'J26675-L1_S1'
	# subsample index of the 50M data
	i=0
	filter_type = 'unfiltered'
	title =''
	full = pd.read_csv(f"{data_dir}/500M/{sample}_{aligner}_{filter_type}.sj", usecols=[0,1,2,3], dtype={"chr": str, "start": int, "end": int, "strand": int}, sep="\t", header=None, names=["chr","start","end","strand"])
	subsampled = pd.read_csv(f"{data_dir}/50M/{aligner}/{sample}_50M_{i}.sj", usecols=[0,1,2,3], dtype={"chr": str, "start": int, "end": int, "strand": int}, sep="\t", header=None, names=["chr","start","end","strand"])
	# NOTE(review): unlike Supplementary Figure S1, reference_sj is neither restricted to chr1-22/X/Y nor
	# de-duplicated here — confirm this is intended.
	reference_sj = pd.read_csv(f"{data_dir}/annotated.sj",sep='\t', header=None, names=["chr", "start", "end", "strand"], dtype={"chr": str, "start": int, "end": int, "strand": int})
	aligner_ = aligner.upper()+'2' if aligner == "hisat" else aligner.upper()
	three_sets, colors = prep_for_venn3(full, subsampled, reference_sj)
	venn = venn3(subsets=three_sets, set_labels = [f'{aligner_} 500M', f'{aligner_} 50M', 'reference genome'], set_colors=colors, ax=ax)
	# make set labels colored
	# (reuses the name i; harmless because i is reset at the top of each outer iteration)
	for i, text in enumerate(venn.set_labels):
		text.set_color(plt.cm.tab10.colors [i])
plt.tight_layout()
plt.savefig(f'{plt_dir}/supplementary_figure_3.tif', dpi=600, bbox_inches='tight')
plt.close()

In [None]:
# Supplementary Figure S4: Precision-Recall curves of the tools over the different evaluation scenarios for STAR unfiltered, STAR Illumina filtered, STAR cutoff filtered and HISAT unfiltered gold standards.
# Rows = scenarios, columns = gold standards. Each tool's curve pools the junctions of all its samples/subsamples.
TOOL_COLORS = {'deepsplice':0, 'spliceai':1, 'jcc':2, 'baseline':3, 'alphagenome':4}
tool2tool = {'alphagenome':'AlphaGenome','spliceai':'SpliceAI','deepsplice':'DeepSplice','jcc':'JCC'}
fig, axes = plt.subplots(5, 4, figsize=(4*4.5, 5*4), squeeze=False)
for row_idx, scenario in enumerate(['scenario_1a','scenario_1b','scenario_2_hypothetical','scenario_2_real_world','scenario_3_hypothetical']):
	for col_idx, (aligner, gt_confidence) in enumerate([('star','unfiltered'),('star','illumina'),('star','cutoff'),('hisat','unfiltered')]):
		ax = axes[row_idx][col_idx]
		# No HISAT2 panel is drawn for real-world scenarios.
		if (('real_world' in scenario) and (aligner == 'hisat')):
			ax.axis('off')
		else:
			sj_file = f'{stats_dir}/{scenario}/tool_sjs_50_{aligner}_vs_GT_{aligner}_{gt_confidence}.pkl'
			sj_file_ag = f'{stats_dir_ag}/{scenario}/tool_sjs_50_{aligner}_vs_GT_{aligner}_{gt_confidence}.pkl'
			# Scenario 3 uses the stats file that also contains AlphaGenome, when enabled and present.
			if (scenario == 'scenario_3_hypothetical') and include_ag_results and os.path.exists(sj_file_ag):
				sj_file = sj_file_ag
			with open(sj_file,'rb') as f:
				tool_sjs_50 = pickle.load(f)
			for tool, sjs in tool_sjs_50.items():
				y_true = np.concatenate([sj["label"] for sj in sjs.values()])
				y_scores = np.concatenate([sj["pred"] for sj in sjs.values()])
				precision, recall, _ = precision_recall_curve(y_true, y_scores)
				ax.plot(recall, precision, label=tool2tool[tool], color=plt.get_cmap('tab10')(TOOL_COLORS[tool]))
			if row_idx == 0:
				ax.set_title(f"{aligner.upper()}{'2' if aligner=='hisat' else ''} {gt_confidence}",fontweight='bold', fontsize=12)
			ax.set_xlabel('Recall')
			ax.set_ylim(0, 1.05)
			ax.set_ylabel('Precision')
			if col_idx == 0:
				scenario_name = scenario.capitalize().replace('_', ' ')
				ax.text(-0.25, 0.5, scenario_name,transform=ax.transAxes,fontsize=12,fontweight='bold',va='center',rotation=90)
			ax.legend()
plt.savefig(f"{plt_dir}/supplementary_figure_4.tif", dpi=600, bbox_inches='tight')
plt.close()

In [None]:
# Supplementary Figure S5: Boxplots of the distribution of the tools' AUPRC scores over the different evaluation scenarios with Wilcoxon test.
# Color mapping as specified
TOOL_COLORS = {'deepsplice':0, 'spliceai':1, 'jcc':2, 'baseline':3, 'alphagenome':4}

# Display name mapping for plotting
TOOL_PLOT_NAME = {
	'alphagenome': 'AlphaGenome',
	'spliceai': 'SpliceAI',
	'deepsplice': 'DeepSplice',
	'jcc': 'JCC',
	'baseline': 'No-Skill'
}


def _auprc_per_sample(sjs, no_skill=False):
	"""Return [((sample, idx), AUPRC), ...], one entry per item of *sjs*, in iteration order.

	sjs maps (<unused>, sample, idx) keys to tables with 'label' and 'pred' columns.
	With no_skill=True every junction gets the constant score 1, giving the no-skill baseline.
	"""
	pairs = []
	for (_, sample, idx), s in sjs.items():
		if no_skill:
			scores = pd.Series([1] * len(s['pred']), index=s['pred'].index)
		else:
			scores = s['pred']
		pairs.append(((sample, idx), average_precision_score(s['label'], scores)))
	return pairs


def _paired_wilcoxon_p(pairs_a, pairs_b, name_a, name_b):
	"""Return the two-sided Wilcoxon signed-rank p-value over the (sample, idx) keys shared by both inputs.

	Pairing by key instead of by list position keeps the test paired correctly when the
	two tools were evaluated on different samples. Returns NaN when no sample is shared.
	"""
	auprc_a, auprc_b = dict(pairs_a), dict(pairs_b)
	shared = [key for key in auprc_a if key in auprc_b]
	if len(shared) != len(auprc_a) or len(shared) != len(auprc_b):
		print(f'Warning: {name_a} and {name_b} were not evaluated on the same samples; testing the {len(shared)} shared sample(s) only.')
	if not shared:
		return np.nan
	_, p_value = wilcoxon([auprc_a[k] for k in shared], [auprc_b[k] for k in shared], alternative='two-sided')
	return p_value


def _annotate_bracket(ax, x1, x2, y, p_value):
	"""Draw a horizontal bracket from x1 to x2 at height y, labelled with significance stars and the p-value."""
	significant = "N.S." if ((p_value > 0.05) or np.isnan(p_value)) else ("*" if p_value > 0.01 else ("**" if p_value > 0.001 else "***"))
	ax.plot([x1, x2], [y, y], lw=1, c='k')
	ax.text((x1 + x2) * 0.5, y, f'{significant} (p='+(f'{p_value:.2e})' if p_value < 0.01 else f'{p_value:.2f})'), ha='center', va='bottom', color='k', fontsize=8)


gridspec_kw = {'height_ratios': [1, 1, 1, 1, 1.2]} if include_ag_results else {'height_ratios': [1, 1, 1, 1, 1]}
nr_cols = 4
nr_rows = 5
fig, axes = plt.subplots(nr_rows, nr_cols, figsize=(4*nr_cols, nr_rows*3.5), squeeze=False, gridspec_kw=gridspec_kw)
tab10 = plt.get_cmap('tab10')
# for hypothetial scenarios only 4 samples and wilcoxon less meaningful
for row_idx, scenario in enumerate(['scenario_1a','scenario_1b','scenario_2_hypothetical','scenario_2_real_world','scenario_3_hypothetical']):
	for col_idx, (aligner, gt_confidence) in enumerate([('star','unfiltered'),('star','illumina'),('star','cutoff'),('hisat','unfiltered')]):
		ax = axes[row_idx][col_idx]
		if (('real_world' in scenario) and (aligner == 'hisat')):
			ax.axis('off')
		else:
			sj_file = f'{stats_dir}/{scenario}/tool_sjs_50_{aligner}_vs_GT_{aligner}_{gt_confidence}.pkl'
			sj_file_ag = f'{stats_dir_ag}/{scenario}/tool_sjs_50_{aligner}_vs_GT_{aligner}_{gt_confidence}.pkl'
			if (scenario == 'scenario_3_hypothetical') and include_ag_results and os.path.exists(sj_file_ag):
				sj_file = sj_file_ag
			with open(sj_file,'rb') as f:
				tool_sjs = pickle.load(f)

			tool_names = [key for key in TOOL_PLOT_NAME if key in tool_sjs]
			plot_names = [TOOL_PLOT_NAME[n] for n in tool_names]
			n_tools = len(tool_names)
			# Per-sample AUPRCs of every tool, keyed by (sample, idx) so the tests below can pair them.
			tool_pairs = [_auprc_per_sample(tool_sjs[tool]) for tool in tool_names]
			# No-skill baseline, computed on the first tool's junction tables.
			baseline_pairs = _auprc_per_sample(next(iter(tool_sjs.values())), no_skill=True)
			# Box data: tools in tool_names order, baseline last.
			all_auprcs = [np.array([auprc for _, auprc in pairs]) for pairs in tool_pairs + [baseline_pairs]]
			plot_names.append(TOOL_PLOT_NAME['baseline'])

			# Boxplot
			if row_idx == 0:
				ax.set_title(f"{aligner.upper()}{'2' if aligner=='hisat' else ''} {gt_confidence}", fontweight='bold', fontsize=12)
			ax.set_ylim(0, 1.199)
			if row_idx == nr_rows-1 and include_ag_results:
				ax.set_ylim(0, 1.35)
			bp = ax.boxplot(all_auprcs,
							tick_labels=plot_names,
							patch_artist=True)
			for i, (box, median) in enumerate(zip(bp['boxes'], bp['medians'])):
				toolname = tool_names[i] if i<n_tools else 'baseline'
				color = tab10(TOOL_COLORS[toolname])
				box.set_facecolor((color[0], color[1], color[2], 0.25))
				median.set_color(color[:3])

			# Brackets are stacked above the highest AUPRC, one bracket_gaps step each.
			max_auprc = max(np.max(au) for au in all_auprcs)
			bracket_gaps = 0.05
			heights = []
			# Compare each tool to baseline
			for i in range(n_tools):
				x1, x2 = i+1, len(all_auprcs) # 1-indexed for plot
				y = max_auprc + bracket_gaps * (len(heights)+1)
				heights.append(y)
				p_value = _paired_wilcoxon_p(baseline_pairs, tool_pairs[i], 'baseline', tool_names[i])
				_annotate_bracket(ax, x1, x2, y, p_value)

			# Compare all pairs of tools excluding baseline
			for i, j in combinations(range(n_tools), 2):
				x1, x2 = i+1, j+1
				y = max_auprc + bracket_gaps * (len(heights)+1)
				heights.append(y)
				p_value = _paired_wilcoxon_p(tool_pairs[j], tool_pairs[i], tool_names[j], tool_names[i])
				_annotate_bracket(ax, x1, x2, y, p_value)
			ax.set_ylabel("AUPRC")
			if col_idx == 0:
				scenario_name = scenario.capitalize().replace('_', ' ')
				ax.text(-0.25, 0.5, scenario_name,transform=ax.transAxes,fontsize=12,fontweight='bold',va='center',rotation=90)
plt.tight_layout()
plt.savefig(f'{plt_dir}/supplementary_figure_5.tif', dpi=600, bbox_inches='tight')
plt.close()

In [None]:
# Supplementary Figure S6
# JCC scenario_2_real_world positives plot distribution score in 50m data vs not in 50m data
# One panel per STAR gold standard; sample J26675-L1_S1, subsample 0.
tab10 = plt.cm.tab10.colors 
fig, axes = plt.subplots(1, 3, figsize=(14, 4))  # 3 horizontal subplots
for i, gt_confidence in enumerate(["unfiltered","illumina","cutoff"]):
	ax = axes[i]
	aligner ="star"
	file_sj_50 = f'{stats_dir}/scenario_2_real_world/tool_sjs_50_{aligner}_vs_GT_{aligner}_{gt_confidence}.pkl'
	with open(file_sj_50, 'rb') as f:
		tool_sjs_50 = pd.read_pickle(f)
	tool = 'jcc'
	for ann_50, sj_50 in tool_sjs_50[tool].items():
		_, sample_id, run_id = ann_50
		run_id = int(run_id)
		if run_id == 0 and sample_id == "J26675-L1_S1":
			positives = sj_50[sj_50.label == 1] # positives
			# read in sj in 50M input
			input_data = pd.read_csv(f"{data_dir}/50M/{aligner}/{sample_id}_50M_{run_id}.sj", usecols=[0,1,2,3], dtype={"chr": str, "start": int, "end": int, "strand": int}, sep="\t", header=None, names=["chr","start","end","strand"])
			# add column in_input 1 or 0
			# (with how='left' the indicator is only 'both' or 'left_only'; the 'right_only' mapping is defensive)
			positives = positives.merge(input_data, how='left', on=['chr','start','end','strand'], indicator=True)
			positives['_merge'] = positives['_merge'].astype('object').replace({'both': 1, 'left_only': 0, 'right_only': 0}).astype(int)
			positives.rename(columns={'_merge': 'in_input'},inplace=True, errors='raise')
			# plot distribution of jcc score of sj that are in groundtruth - that are in in 50m data vs not in 50m data
			ax.hist(positives[positives.in_input == 1]['pred'], bins=50, label='in 50M input data', alpha=0.5, color=tab10[1])
			ax.hist(positives[positives.in_input == 0]['pred'], bins=50, label='not in 50M input data', alpha=0.5, color=tab10[0])
			ax.set_xlabel('JCC prediction score', fontsize=16)
			ax.set_ylabel('Number of junctions', fontsize=16)
			ax.set_yscale('log')
			ax.tick_params(labelsize=16)
			ax.set_title(f'{aligner.upper()} {gt_confidence.capitalize() if gt_confidence == "illumina" else gt_confidence}', fontsize=16)
			if gt_confidence == 'unfiltered':
				ax.legend(fontsize=16) # only show the legend on the first (unfiltered) panel
plt.tight_layout()
plt.savefig(f"{plt_dir}/supplementary_figure_6.tif", dpi=600, bbox_inches='tight')
plt.close()

In [None]:
# Supplementary Figure S7: splice-site motifs; the canonical motif (GT/AG) strongly increases confidence.
# Plot donor + acceptor motifs
def plot_bar_motifs(data, file_dir):
	"""Save side-by-side bar charts of donor-motif and acceptor-motif counts.

	data: DataFrame with 'donor_motif' and 'acceptor_motif' columns.
	file_dir: output image path handed to savefig (despite its name, a file path).
	"""
	# Count first, so a missing column fails before any figure is created.
	panels = [
		('Donor Motifs', data['donor_motif'].value_counts()),
		('Acceptor Motifs', data['acceptor_motif'].value_counts()),
	]
	fig, axes = plt.subplots(1, 2, figsize=(9, 4), squeeze=False)
	for panel_idx, (ax, (title, counts)) in enumerate(zip(axes.flatten(), panels)):
		ax.bar(counts.index, counts.values)
		ax.set_xlabel('Motif')
		# Only the left panel carries the y-axis label.
		if panel_idx == 0:
			ax.set_ylabel('Count')
		ax.set_title(title)
		ax.tick_params(axis='x', rotation=45)
	plt.savefig(file_dir, dpi=600, bbox_inches='tight')
	plt.close()


# Function to get reverse complement
def reverse_complement(seq):
	"""Return the reverse complement of a DNA string.

	A/C/G/T are complemented in either case (case is preserved). Any other character
	(e.g. N) is left unchanged apart from being moved by the reversal.
	"""
	complement = str.maketrans('ATCGatcg', 'TAGCtagc')
	return seq.translate(complement)[::-1]


# Extract motifs at junctions
def extract_junction_motifs(df, genome):
	"""Return the two-nucleotide donor and acceptor motif of every junction in *df*.

	df: DataFrame with 'chr', 'start', 'end' and 'strand' columns.
	genome: mapping from chromosome name to a sliceable sequence, e.g. the
		``Fasta(..., as_raw=True)`` object used in this notebook.

	The motifs are read as ``seq[start-1:start+1]`` and ``seq[end-2:end]``. For
	minus-strand junctions the two are swapped and reverse-complemented, so both are
	reported in transcript orientation. The minus strand is recognised both as '-'
	and as integer code 2. The .sj tables in this notebook are read with an integer
	strand column, so comparing only against '-' never matched. This assumes STAR's
	SJ.out.tab convention (0 = undefined, 1 = +, 2 = -) — confirm for these files.

	Returns a DataFrame with 'donor_motif' and 'acceptor_motif', indexed like df.
	"""
	minus_strand = {'-', 2, '2'}
	donor_seqs = []
	acceptor_seqs = []
	for chrom, start, end, strand in zip(df['chr'], df['start'], df['end'], df['strand']):
		chrom_seq = genome[chrom]
		donor_seq = chrom_seq[start-1:start+1]
		acceptor_seq = chrom_seq[end-2:end]
		if strand in minus_strand:
			# On the minus strand the donor lies at the genomic end and the acceptor at the start.
			donor_seq, acceptor_seq = reverse_complement(acceptor_seq), reverse_complement(donor_seq)
		donor_seqs.append(donor_seq)
		acceptor_seqs.append(acceptor_seq)

	return pd.DataFrame({
		'donor_motif': donor_seqs,
		'acceptor_motif': acceptor_seqs
	}, index=df.index)


def extract_gene_id(attribute_str):
	"""Return the gene_id value of a GTF-style attribute string without its version suffix.

	'gene_id "ENSG00000223972.5"; ...' -> 'ENSG00000223972'. Returns None when the
	string has no ``gene_id "..."`` entry.
	"""
	found = re.search(r'gene_id "([^"]+)"', attribute_str)
	if found is None:
		return None
	# Keep everything before the first '.', i.e. drop the ".<version>" suffix.
	return found.group(1).partition('.')[0]


if not os.path.exists(f'{data_dir}/Homo_sapiens.GRCh38.dna.primary_assembly.fa'):
	print('Downloading Homo_sapiens.GRCh38.dna.primary_assembly.fa ...')
	!wget https://ftp.ensembl.org/pub/release-104/fasta/homo_sapiens/dna/Homo_sapiens.GRCh38.dna.primary_assembly.fa.gz
	!gunzip Homo_sapiens.GRCh38.dna.primary_assembly.fa.gz
genome = Fasta(f'{data_dir}/Homo_sapiens.GRCh38.dna.primary_assembly.fa', as_raw=True)

for aligner in ['star']:
	perc_consensus_donor = 0
	perc_consensus_acceptor = 0
	all_nr_sj = 0
	all_htf_dfs = pd.read_pickle(f"{stats_dir}/scenario_3_hypothetical/tool_sjs_50_{aligner}_vs_GT_{aligner}_unfiltered.pkl")['spliceai']
	for (_, sample, _),htf_df in all_htf_dfs.items():
		if sample == 'J26675-L1_S1':
			len_bef = len(htf_df)
			htf_df['start'] = htf_df.start+2
			htf_df = htf_df[['start', 'end', 'strand', 'chr']]
			htf_df.loc[:,['donor_motif','acceptor_motif']] = extract_junction_motifs(htf_df, genome)
			nr_consensus_donor = len(htf_df[htf_df['donor_motif'].isin(['GT','CT'])])
			nr_consensus_acceptor = len(htf_df[htf_df['acceptor_motif'].isin(['AG','AC'])])
			nr_sj = len(htf_df)
			all_nr_sj += nr_sj
			perc_consensus_acceptor += nr_consensus_acceptor
			perc_consensus_donor += nr_consensus_donor
			plot_bar_motifs(htf_df, f'{plt_dir}/supplementary_figure_7.tif')
			print(f'{aligner.upper()} {sample}: {nr_consensus_donor}/{nr_sj} ({np.round((nr_consensus_donor/nr_sj)*100,2)}%) consensus donor motif, {nr_consensus_acceptor}/{nr_sj} ({np.round((nr_consensus_acceptor/nr_sj)*100,2)}%) consensus acceptor motif')

STAR J26675-L1_S1: 69359/69726 (99.47%) consensus donor, 69356/69726 (99.47%) consensus acceptor


In [None]:
# Supplementary Figure S8
# Distribution of prediction scores of JCC on gold standard data STAR unfiltered, STAR Illumina filtered, STAR cutoff filtered, HISAT2 unfiltered.
# The score distribution of splice junctions that were not found in the filtered down 50M reads data (shown in blue) is compared to the score distribution of splice junctions that were not annotated in the reference genome (shown in orange).
# In panel a. the prediction score distribution is visualized for Scenario 2: “Predicting junctions that could be detected with higher sequencing depth” Real-world setting, in panel b.
# for Scenario 2: “Predicting junctions that could be detected with higher sequencing depth” Hypothetical setting.
# This is shown on sample J26675-L1_S1 subsample 0 as an example, but is similar for the other three samples and their subsamples.
# NOTE(review): the code draws 'not in reference genome' with tab10[2] (green), not orange — confirm against the caption.
tool = 'jcc'
sample = 'J26675-L1_S1'
run = 0
aligner = 'star'
tab10 = plt.cm.tab10.colors
fig, axes = plt.subplots(3, 4, figsize=(19, 11))  # 3 rows x 4 columns of subplots
# read in sj in 50M input
sj_50 = pd.read_csv(f"{data_dir}/50M/star/{sample}_50M_{run}.sj", usecols=[0,1,2,3], dtype={"chr": str, "start": int, "end": int, "strand": int}, sep="\t", header=None, names=["chr","start","end","strand"])
# read in reference genome sj
reference_sj = pd.read_csv(f"{data_dir}/annotated.sj",sep='\t', header=None, names=["chr", "start", "end", "strand"], dtype={"chr": str, "start": int, "end": int, "strand": int})
# Row 0 (panel a.): JCC in the real-world setting, STAR gold standards only.
for i, filter_type in enumerate(["unfiltered","illumina","cutoff"]):
	filter_type_ = filter_type.capitalize() if filter_type == "illumina" else filter_type
	file_sj = f'{stats_dir}/scenario_2_real_world/tool_sjs_50_{aligner}_vs_GT_{aligner}_{filter_type}.pkl'
	with open(file_sj, 'rb') as f:
		tool_sjs = pd.read_pickle(f)
	for ann, sj in tool_sjs[tool].items():
		_, sample_id, run_id = ann
		run_id = int(run_id)
		if run_id == run and sample_id == sample:
			# plot score distribution of sj that are in groundtruth - that are not in 50m data vs not in reference genome
			sj_not_in_50 = sj.merge(sj_50, how='left', on=['chr','start','end','strand'], indicator=True)
			sj_not_in_50 = sj_not_in_50[sj_not_in_50['_merge'] == 'left_only'].drop(columns=['_merge'])
			sj_not_in_reference = sj.merge(reference_sj, how='left', on=['chr','start','end','strand'], indicator=True)
			sj_not_in_reference = sj_not_in_reference[sj_not_in_reference['_merge'] == 'left_only'].drop(columns=['_merge'])
			bins = np.linspace(0, 1, 50)
			ax = axes[0, i]
			ax.hist(sj_not_in_reference['pred'], bins=bins, label='not in reference genome', alpha=0.5, color=tab10[2])
			ax.hist(sj_not_in_50['pred'], bins=bins, label='not in 50M input data', alpha=0.5, color=tab10[0])
			ax.set_xlabel('Prediction score', fontsize=16)
			ax.set_ylabel('Number of junctions', fontsize=16)
			ax.set_yscale('log')
			ax.tick_params(labelsize=16)
			# Double-quoted outer f-string: reusing the same quote inside an f-string needs Python >= 3.12.
			ax.set_title(f"{aligner.upper()}{'2' if aligner=='hisat' else ''} {filter_type_}", fontsize=18)
			if i == 0:
				ax.text(-0.25, 1.2, 'a.', transform=ax.transAxes, fontsize=22, va='top', ha='right', weight='bold')
				ax.text(-0.25, 0.5, tool.upper(), transform=ax.transAxes, fontsize=20, va='center', ha='right', rotation=90)
axes[0,3].axis('off')
# Rows 1-2 (panel b.): DeepSplice and SpliceAI in the hypothetical setting, all four gold standards.
# NOTE(review): the 'not in 50M' split always uses the STAR 50M junctions (sj_50), also in the HISAT2 column — confirm.
for i, (aligner, filter_type) in enumerate([("star", "unfiltered"), ("star", "illumina"), ("star", "cutoff"), ("hisat", "unfiltered")]):
	filter_type_ = filter_type.capitalize() if filter_type == "illumina" else filter_type
	file_sj = f'{stats_dir}/scenario_2_hypothetical/tool_sjs_50_{aligner}_vs_GT_{aligner}_{filter_type}.pkl'
	with open(file_sj, 'rb') as f:
		tool_sjs = pd.read_pickle(f)
	for i_y, tool in enumerate(['deepsplice','spliceai'], start=1):
		for ann, sj in tool_sjs[tool].items():
			_, sample_id, run_id = ann
			if sample_id == sample:
				ax = axes[i_y, i]
				# plot score distribution of sj that are in groundtruth - that are not in 50m data vs not in reference genome
				sj_not_in_50 = sj.merge(sj_50, how='left', on=['chr','start','end','strand'], indicator=True)
				sj_not_in_50 = sj_not_in_50[sj_not_in_50['_merge'] == 'left_only'].drop(columns=['_merge'])
				sj_not_in_reference = sj.merge(reference_sj, how='left', on=['chr','start','end','strand'], indicator=True)
				sj_not_in_reference = sj_not_in_reference[sj_not_in_reference['_merge'] == 'left_only'].drop(columns=['_merge'])
				# DeepSplice scores here extend beyond 1, hence the wider range.
				bins = np.linspace(0, 1 if tool=='spliceai' else 1.7, 50)
				ax.hist(sj_not_in_reference['pred'], bins=bins, label='not in reference genome', alpha=0.5, color=tab10[2])
				ax.hist(sj_not_in_50['pred'], bins=bins, label='not in 50M input data', alpha=0.5, color=tab10[0])
				ax.set_xlabel('Prediction score', fontsize=16)
				ax.set_ylabel('Number of junctions', fontsize=16)
				ax.set_yscale('log')
				ax.tick_params(labelsize=16)
				ax.set_title(f"{aligner.upper()}{'2' if aligner=='hisat' else ''} {filter_type_}", fontsize=18)
				# Single legend, on the bottom-right (HISAT2 unfiltered, SpliceAI) panel.
				if (filter_type == 'unfiltered') and (tool == 'spliceai') and (aligner == 'hisat'):
					ax.legend(fontsize=16)
				if i == 0:
					tool_ = 'DeepSplice' if tool == 'deepsplice' else 'SpliceAI'
					if i_y == 1: ax.text(-0.25, 1.2, 'b.', transform=ax.transAxes, fontsize=22, va='top', ha='right', weight='bold')
					ax.text(-0.25, 0.5, tool_, transform=ax.transAxes, fontsize=20, va='center', ha='right', rotation=90)
plt.tight_layout()
plt.savefig(f"{plt_dir}/supplementary_figure_8.tif", dpi=600, bbox_inches='tight')
plt.close()

In [None]:
# Supplementary Figure S9: Overlap of splice junctions aligned with HISAT vs STAR for all samples.
sj_cols = ['chr','start','end','strand']
fig, axes = plt.subplots(2, 2, figsize=(2*5, 2*4.5), squeeze=False)
for ax, sample in zip(axes.flatten(), ['J26675-L1_S1','J26676-L1_S2','J26677-L1_S3','J26678-L1_S4']):
	# 500M unfiltered junction tables of both aligners (first four columns only).
	df_hisat, df_star = (
		pd.read_csv(f'{data_dir}/500M/{sample}_{aln}_unfiltered.sj', usecols=[0,1,2,3], names=sj_cols, delimiter='\t', dtype={'chr':str})
		for aln in ('hisat', 'star')
	)
	# Row-level share of HISAT2 junctions also reported by STAR (merge-based, so duplicate rows count).
	df_both = df_hisat.merge(df_star, on=['start', 'end', 'strand', 'chr'], how='inner')
	pct_both = round((len(df_both)/len(df_hisat))*100, 2)
	print(f'{pct_both}% of splice junctions from {sample} are found in both aligners 500M data.')

	# Venn regions are computed on unique (chr, start, end, strand) tuples.
	set_hisat, set_star = (set(zip(df['chr'], df['start'], df['end'], df['strand'])) for df in (df_hisat, df_star))
	only_hisat = len(set_hisat - set_star)
	only_star = len(set_star - set_hisat)
	both = len(set_hisat & set_star)
	ax.set_title(f'Sample {sample}', fontweight='bold')
	venn2(subsets=(only_hisat, only_star, both), set_labels=('HISAT2 500M unfiltered', 'STAR 500M unfiltered'), ax=ax)
plt.savefig(f'{plt_dir}/supplementary_figure_9.tif', dpi=600, bbox_inches='tight')
plt.close()

97.41% of splice junctions from J26675-L1_S1 are found in both aligners 500M data.
96.79% of splice junctions from J26676-L1_S2 are found in both aligners 500M data.
97.16% of splice junctions from J26677-L1_S3 are found in both aligners 500M data.
97.7% of splice junctions from J26678-L1_S4 are found in both aligners 500M data.


In [None]:
# Supplementary Figure S10: Calibration curves of the tools over the different evaluation scenarios for STAR unfiltered, STAR Illumina filtered, STAR cutoff filtered and HISAT unfiltered gold standards.
TOOL_COLORS = {'deepsplice':0, 'spliceai':1, 'jcc':2, 'baseline':3, 'alphagenome':4}
tool2tool = {'alphagenome':'AlphaGenome','spliceai':'SpliceAI','deepsplice':'DeepSplice','jcc':'JCC'}
n_bins = 10
# Equal-width score bins on [0, 1]; loop-invariant, so built once.
bins = np.linspace(0, 1, n_bins+1)
fig, axes = plt.subplots(5, 4, figsize=(4*4, 5*3.5), squeeze=False)
for row_idx, scenario in enumerate(['scenario_1a','scenario_1b','scenario_2_hypothetical','scenario_2_real_world','scenario_3_hypothetical']):
	for col_idx, (aligner, gt_confidence) in enumerate([('star','unfiltered'),('star','illumina'),('star','cutoff'),('hisat','unfiltered')]):
		ax = axes[row_idx][col_idx]
		if (('real_world' in scenario) and (aligner == 'hisat')):
			ax.axis('off')
		else:
			sj_file = f'{stats_dir}/{scenario}/tool_sjs_50_{aligner}_vs_GT_{aligner}_{gt_confidence}.pkl'
			sj_file_ag = f'{stats_dir_ag}/{scenario}/tool_sjs_50_{aligner}_vs_GT_{aligner}_{gt_confidence}.pkl'
			if (scenario == 'scenario_3_hypothetical') and include_ag_results and os.path.exists(sj_file_ag):
				sj_file = sj_file_ag
			with open(sj_file,'rb') as f:
				tool_sjs_50 = pickle.load(f)

			bin_preds = []
			bin_trues = []
			for tool, sjs in tool_sjs_50.items():
				# Gather all predictions/labels for scenario
				y_true = np.concatenate([sj["label"] for sj in sjs.values()])
				y_pred = np.concatenate([sj["pred"] for sj in sjs.values()])

				# np.digitize(...) - 1 maps scores in [0, 1) to 0..n_bins-1 but a score of exactly 1.0 to n_bins.
				# Fold 1.0 into the last bin so the most confident predictions are not silently dropped.
				# Scores outside [0, 1] still land outside 0..n_bins-1 and are ignored, as before.
				binids = np.digitize(y_pred, bins) - 1
				binids[y_pred == bins[-1]] = n_bins - 1

				bin_true = []
				bin_pred = []
				for i in range(n_bins):
					mask = (binids == i)
					if np.any(mask):
						bin_true.append(np.mean(y_true[mask]))
						bin_pred.append(np.mean(y_pred[mask]))
					else:
						# Empty bin: NaN fraction (gap in the curve), x placed at the bin centre.
						bin_true.append(np.nan)
						bin_pred.append((bins[i] + bins[i+1]) / 2)
				bin_trues.append(bin_true)
				bin_preds.append(bin_pred)

			for (tool, bin_pred, bin_true) in zip(tool_sjs_50.keys(),bin_preds,bin_trues):
				ax.plot(bin_pred, bin_true, "s-", label=f"{tool2tool[tool]} Calibration", zorder=2, color=plt.get_cmap('tab10')(TOOL_COLORS[tool]))
				ax.fill_between(bin_pred, 0, bin_true, alpha=0.1, color=plt.get_cmap('tab10')(TOOL_COLORS[tool]))
			ax.plot([0,1],[0,1], "--k", label="Perfect Calibration", zorder=1)
			ax.set_xlabel("Mean predicted probability")
			ax.set_ylabel("Empirical frequency (Fraction of positives)")
			ax.set_xlim(0,1)
			ax.set_ylim(0,1)
			if row_idx == 0:
				ax.set_title(f"{aligner.upper()}{'2' if aligner=='hisat' else ''} {gt_confidence}",fontweight='bold', fontsize=12)
			if col_idx == 0:
				scenario_name = scenario.capitalize().replace('_', ' ')
				ax.text(-0.25, 0.5, scenario_name, transform=ax.transAxes, fontsize=12, fontweight='bold', va='center', rotation=90)
			ax.legend()
			ax.grid(True, zorder=0)
plt.tight_layout()
plt.savefig(f"{plt_dir}/supplementary_figure_10.tif", dpi=600, bbox_inches='tight')
plt.close()

In [None]:
# Supplementary Figure S11: Distribution of the prediction scores of the three different tools DeepSplice, SpliceAI, and JCC, after running on the 500M reads RNA-seq data aligned with STAR. This is shown for sample J26675-L1_S1 as an example, but is similar for the other three samples.
s11_pred_dir = f'{pred_dir}/500M/J26675-L1_S1/star'
# One prediction table per tool, read in the order DeepSplice, SpliceAI, JCC.
ds_pred, sa_pred, jcc_pred = (
	pd.read_csv(f'{s11_pred_dir}/{tool_file}_pred.csv', low_memory=False)
	for tool_file in ('deepsplice', 'spliceai', 'jcc')
)
plt.figure(figsize=(6, 5))
# Overlaid, semi-transparent histograms (50 bins each) of the raw scores.
for preds, tool_label in ((ds_pred, 'DeepSplice'), (sa_pred, 'SpliceAI'), (jcc_pred, 'JCC')):
	plt.hist(preds['pred'], bins=50, label=tool_label, alpha=0.5)
plt.xlabel('Prediction score', fontsize=16)
plt.ylabel('Number of junctions', fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=14)
plt.savefig(f"{plt_dir}/supplementary_figure_11.tif", dpi=600, bbox_inches='tight')
plt.close()