To produce the figures, first create the plotting conda environment:

conda env create -f eval.yml

Before running the cells below, you also need to run:

python evaluation_scenarios.py

In [73]:
import os
# Directory layout used by every figure cell below; all paths are derived from main_dir.
main_dir = "." # TODO set working directory
# Input data (.sj junction files are read from here by the figure cells).
data_dir = f"{main_dir}/data"
# Tool prediction CSVs, located inside the data directory.
pred_dir = f'{data_dir}/predictions'
# Root for everything under "out"; stats are read from it, plots are written to it.
out_dir = f"{main_dir}/out"
# Pickled per-scenario junction tables read by the figure cells.
stats_dir = f"{out_dir}/stats"
# Destination of the generated .tif figures.
plt_dir = f"{out_dir}/plots"
# Only the plot directory is created here (no error if it already exists).
os.makedirs(plt_dir, exist_ok=True)

In [None]:
# Figure 2: Distribution of JCC prediction scores for Scenario 3 “Predicting hard-to-find junctions” in Real-world use case
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
# Pickled junction tables of scenario 3 (real-world use case), one file per gold-standard variant.
in_dir = f'{stats_dir}/scenario_3_real_world'
aligner = "star"
tool = "jcc"
# Only this sample / subsample is plotted as an example.
example_sample = 'J26675-L1_S1'
example_run = 0
fig, axs = plt.subplots(1, 3, figsize=(14, 4))  # 3 horizontal subplots
for i, gt_confidence in enumerate(["unfiltered", "illumina", "cutoff"]):
	file_sj_50 = f'{in_dir}/tool_sjs_50_{aligner}_vs_GT_{aligner}_{gt_confidence}.pkl'
	# Panel title spelling: "Illumina" is capitalised, the other names are shown as-is.
	gt_label = gt_confidence.capitalize() if gt_confidence == 'illumina' else gt_confidence
	# NOTE(review): unpickling can execute arbitrary code; only load files produced by this project.
	# The file is closed as soon as the table is loaded instead of staying open while plotting.
	with open(file_sj_50, 'rb') as f:
		sjs_50 = pd.read_pickle(f)
	for ann, sj in sjs_50[tool].items():
		_, sample_id, run_id = ann
		# int() accepts run ids stored either as '0' or as 0, like the other figure cells do;
		# the previous string comparison silently skipped integer run ids.
		if sample_id != example_sample or int(run_id) != example_run:
			continue
		ax = axs[i]
		# Shared bin edges so that the two label groups are directly comparable.
		bins = np.histogram_bin_edges(sj['pred'], bins=25)
		ax.hist(sj[sj['label'] == 0]['pred'], bins=bins, alpha=0.6, label='Negative (label=0)')
		ax.hist(sj[sj['label'] == 1]['pred'], bins=bins, alpha=0.6, label='Positive (label=1)')
		ax.set_xlabel(f'{tool.upper()} prediction score', fontsize=16)
		ax.set_ylabel('Number of junctions', fontsize=16)
		ax.set_yscale('log')
		ax.tick_params(labelsize=16)
		ax.set_xlim(-0.02, 1.02)
		# The legend is drawn once, on the last panel.
		if i == 2:
			ax.legend(fontsize=16)
		ax.grid(True)
		ax.set_title(f'{aligner.upper()} {gt_label}', fontsize=18)
plt.tight_layout()
plt.savefig(f'{plt_dir}/figure_2.tif', dpi=600)
plt.close()

In [None]:
# Supplementary Figure 1: UpSet plot of all ground truths (reference annotation and gold standards)
from upsetplot import UpSet
import pandas as pd
pd.set_option('future.no_silent_downcasting', True)
import matplotlib.pyplot as plt

# Every junction table is keyed by these four columns.
sj_cols = ["chr", "start", "end", "strand"]
sj_dtypes = {"chr": str, "start": int, "end": int, "strand": int}
canonical_chroms = [str(n) for n in range(1, 23)] + ["X", "Y"]

# First category: junctions of the reference annotation,
# restricted to chr1-22, X, Y and de-duplicated.
reference_sj = pd.read_csv(f"{data_dir}/annotated.sj", sep='\t', header=None, names=sj_cols, dtype=sj_dtypes)
reference_sj = reference_sj[reference_sj["chr"].isin(canonical_chroms)].drop_duplicates()

labels = ["GRCh38.106 reference genome"]
sets = [reference_sj]

# Remaining categories: one per gold standard. Only the unfiltered variant is used
# for HISAT2, all three filter variants for STAR.
filters_by_aligner = {
	"hisat": ["unfiltered"],
	"star": ["unfiltered", "illumina", "cutoff"],
}
sample_ids = ["J26675-L1_S1", "J26676-L1_S2", "J26677-L1_S3", "J26678-L1_S4"]
for aligner, filter_types in filters_by_aligner.items():
	aligner_ = "HISAT2" if aligner == "hisat" else aligner.upper()
	for filter_type in filter_types:
		# A gold standard is the union of the junctions of all four samples.
		sample_sets = [
			pd.read_csv(
				f"{data_dir}/500M/{sample}_{aligner}_{filter_type}.sj",
				usecols=[0, 1, 2, 3],
				dtype=sj_dtypes,
				sep="\t",
				header=None,
				names=sj_cols,
			)
			for sample in sample_ids
		]
		if filter_type == "unfiltered":
			filter_type_ = filter_type
		elif filter_type == "illumina":
			filter_type_ = "Illumina filtered"
		else:
			filter_type_ = f"{filter_type} filtered"
		labels.append(f"{aligner_} {filter_type_} gold standard")
		sets.append(pd.concat(sample_sets, ignore_index=True).drop_duplicates())

# Stack all tables, tagging each row with the category it came from.
tagged = [df.assign(source=name) for name, df in zip(labels, sets)]
merged_df = pd.concat(tagged, ignore_index=True)

# One row per junction, one 0/1 membership column per category ...
binary_merged = merged_df.pivot_table(index=['chr','start','end','strand'], columns='source', aggfunc='size', fill_value=0).clip(upper=1)
# ... then count how many junctions share each membership pattern.
g = binary_merged.groupby(binary_merged.columns.tolist()).size()
g = g.reorder_levels(labels)

upset = UpSet(g, show_counts=True, sort_categories_by='-input', sort_by='-degree')
upset.plot()
plt.subplots_adjust(right=1.25)
out_file = os.path.join(plt_dir, "supplementary_figure_1.tif")
plt.savefig(out_file, dpi=600, bbox_inches='tight')
plt.close()

In [None]:
from matplotlib_venn import venn3
# plot venn diagram for three sets a,b,c
def venn3_region_sizes(a, b, c):
	"""Count the unique junctions in each of the seven regions of a 3-set Venn diagram.

	a, b, c: DataFrames that each contain the columns 'chr', 'start', 'end', 'strand'.
	A junction is identified by the combination of these four values; further columns
	are ignored and a junction listed several times in one table is counted once.

	Returns a dict keyed by the three-character membership codes '100', '010', '110',
	'001', '101', '011', '111', where position 1/2/3 refers to a/b/c and '1' means
	"contained in that set".
	"""
	keys = ['chr', 'start', 'end', 'strand']
	# Plain Python sets of key tuples: set algebra cannot multiply rows the way chained
	# merges do when a table contains a junction more than once.
	set_a, set_b, set_c = (
		set(df[keys].itertuples(index=False, name=None)) for df in (a, b, c)
	)
	return {
		'101': len((set_a & set_c) - set_b),
		'111': len(set_a & set_b & set_c),
		'011': len((set_b & set_c) - set_a),
		'110': len((set_a & set_b) - set_c),
		'100': len(set_a - set_b - set_c),
		'010': len(set_b - set_a - set_c),
		'001': len(set_c - set_a - set_b),
	}


def prep_for_venn3(a, b, c, colored = True):
	"""Prepare the inputs of a 3-set Venn diagram for the junction tables a, b, c.

	a, b, c: DataFrames with the columns 'chr', 'start', 'end', 'strand'.
	colored: if True the three sets get the first three tab10 colours,
		otherwise all three sets get the first tab10 colour.

	Returns (three_sets, colors): the region sizes as computed by venn3_region_sizes
	and a list with one colour per set.
	"""
	# Imported here so the function does not depend on a global `plt` left by another cell.
	import matplotlib.pyplot as plt
	three_sets = venn3_region_sizes(a, b, c)
	tab10 = plt.cm.tab10.colors
	colors = [tab10[0], tab10[1], tab10[2]] if colored else [tab10[0], tab10[0], tab10[0]]
	return three_sets, colors

In [None]:
# Supplementary Figure 2
import pandas as pd

# Venn diagram of the STAR 500M junctions under the three filtering variants,
# shown for the example sample J26675-L1_S1.
sample = 'J26675-L1_S1'
aligner = 'star'
sj_cols = ["chr", "start", "end", "strand"]
sj_dtypes = {"chr": str, "start": int, "end": int, "strand": int}

# Insertion order of this dict fixes which table becomes set 1, 2 and 3 of the diagram.
sjs = {}
for filter_type in ["unfiltered", "illumina", "cutoff"]:
	shown_name = "Illumina" if filter_type == "illumina" else filter_type
	sj_path = f"{data_dir}/500M/{sample}_{aligner}_{filter_type}.sj"
	sjs[f'{aligner.upper()} {shown_name}'] = pd.read_csv(
		sj_path,
		sep="\t",
		header=None,
		names=sj_cols,
		usecols=[0, 1, 2, 3],
		dtype=sj_dtypes,
	)

# colored=False: all three circles share one colour.
three_sets, colors = prep_for_venn3(*sjs.values(), colored=False)
out_file = f'{plt_dir}/supplementary_figure_2.tif'
fig, ax = plt.subplots(figsize=(5, 5))
venn3(subsets=three_sets, set_labels = sjs.keys(), set_colors=colors, ax=ax)
plt.savefig(out_file, dpi=600, bbox_inches='tight')
plt.close()

In [None]:
# Supplementary Figure 3
import pandas as pd
		
# Overlap of the 500M junctions, one 50M subsample and the reference annotation as 3-set
# Venn diagrams (one panel per aligner), shown for subsample 0 of sample J26675-L1_S1.
sample = 'J26675-L1_S1'
subsample = 0
filter_type = 'unfiltered'
sj_cols = ["chr", "start", "end", "strand"]
sj_dtypes = {"chr": str, "start": int, "end": int, "strand": int}

# The reference annotation is the same for both aligners, so it is read once. It is
# de-duplicated (as for Supplementary Figure 1) so that each annotated junction counts once.
reference_sj = pd.read_csv(f"{data_dir}/annotated.sj", sep='\t', header=None, names=sj_cols, dtype=sj_dtypes).drop_duplicates()

fig, axes = plt.subplots(1, 2, figsize=(8, 4))
for ax, aligner in zip(axes, ["star", "hisat"]):
	full = pd.read_csv(f"{data_dir}/500M/{sample}_{aligner}_{filter_type}.sj", usecols=[0,1,2,3], dtype=sj_dtypes, sep="\t", header=None, names=sj_cols)
	subsampled = pd.read_csv(f"{data_dir}/50M/{aligner}/{sample}_50M_{subsample}.sj", usecols=[0,1,2,3], dtype=sj_dtypes, sep="\t", header=None, names=sj_cols)
	aligner_ = aligner.upper()+'2' if aligner == "hisat" else aligner.upper()
	three_sets, colors = prep_for_venn3(full, subsampled, reference_sj)
	venn = venn3(subsets=three_sets, set_labels = [f'{aligner_} 500M', f'{aligner_} 50M', 'reference genome'], set_colors=colors, ax=ax)
	# Colour each set label with the tab10 colour of the same index
	# (a dedicated index name avoids clobbering the subsample number).
	for label_idx, text in enumerate(venn.set_labels):
		text.set_color(plt.cm.tab10.colors[label_idx])
out_file = f'{plt_dir}/supplementary_figure_3.tif'
plt.tight_layout()
plt.savefig(out_file, dpi=600, bbox_inches='tight')
plt.close()

In [None]:
# Supplementary Figure 4
# JCC, scenario_2_real_world, positives only: score distribution of junctions that are present in the 50M data vs. junctions that are absent from the 50M data
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np  # needed in this cell for the shared histogram bin edges
pd.set_option('future.no_silent_downcasting', True)
tab10 = plt.cm.tab10.colors
aligner = "star"
tool = 'jcc'
# Only this sample / subsample is plotted as an example.
example_sample = "J26675-L1_S1"
example_run = 0
sj_cols = ['chr', 'start', 'end', 'strand']
fig, axes = plt.subplots(1, 3, figsize=(14, 4))  # 3 horizontal subplots
for i, gt_confidence in enumerate(["unfiltered", "illumina", "cutoff"]):
	ax = axes[i]
	file_sj_50 = f'{stats_dir}/scenario_2_real_world/tool_sjs_50_{aligner}_vs_GT_{aligner}_{gt_confidence}.pkl'
	# NOTE(review): unpickling can execute arbitrary code; only load files produced by this project.
	with open(file_sj_50, 'rb') as f:
		tool_sjs_50 = pd.read_pickle(f)
	for ann_50, sj_50 in tool_sjs_50[tool].items():
		_, sample_id, run_id = ann_50
		run_id = int(run_id)
		if run_id != example_run or sample_id != example_sample:
			continue
		positives = sj_50[sj_50.label == 1] # positives
		# read in sj in 50M input
		input_data = pd.read_csv(f"{data_dir}/50M/{aligner}/{sample_id}_50M_{run_id}.sj", usecols=[0,1,2,3], dtype={"chr": str, "start": int, "end": int, "strand": int}, sep="\t", header=None, names=sj_cols)
		# A left merge keeps every positive; '_merge' == 'both' marks those also found in the 50M input.
		positives = positives.merge(input_data, how='left', on=sj_cols, indicator=True)
		in_input = positives['_merge'] == 'both'
		# Both histograms share one set of bin edges; with bins=50 per call each group would
		# get its own bin width and the overlaid counts would not be comparable.
		bins = np.histogram_bin_edges(positives['pred'], bins=50)
		ax.hist(positives.loc[in_input, 'pred'], bins=bins, label='in 50M input data', alpha=0.5, color=tab10[1])
		ax.hist(positives.loc[~in_input, 'pred'], bins=bins, label='not in 50M input data', alpha=0.5, color=tab10[0])
		ax.set_xlabel('JCC prediction score', fontsize=16)
		ax.set_ylabel('Number of junctions', fontsize=16)
		ax.set_yscale('log')
		ax.tick_params(labelsize=16)
		ax.set_title(f'{aligner.upper()} {gt_confidence.capitalize() if gt_confidence == "illumina" else gt_confidence}', fontsize=16)
		if gt_confidence == 'unfiltered':
			ax.legend(fontsize=16) # the legend is shown once, on the first (unfiltered) panel
plt.tight_layout()
plt.savefig(f"{plt_dir}/supplementary_figure_4.tif", dpi=600, bbox_inches='tight')
plt.close()

In [None]:
# Supplementary Figure 5
# Distribution of prediction scores of JCC on gold standard data STAR unfiltered, STAR Illumina filtered, STAR cutoff filtered, HISAT2 unfiltered. 
# The score distribution of splice junctions that were not found in the filtered down 50M reads data (shown in blue) is compared to the score distribution of splice junctions that were not annotated in the reference genome (shown in orange). 
# In panel a. the prediction score distribution is visualized for Scenario 2: “Predicting junctions that could be detected with higher sequencing depth” Real-world setting, in panel b. 
# for Scenario 2: “Predicting junctions that could be detected with higher sequencing depth” Hypothetical setting. 
# This is shown on sample J26675-L1_S1 subsample 0 as an example, but is similar for the other three samples and their subsamples.
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
tool = 'jcc'
sample = 'J26675-L1_S1'
run = 0
aligner = 'star'
tab10 = plt.cm.tab10.colors
sj_cols = ['chr', 'start', 'end', 'strand']
sj_dtypes = {"chr": str, "start": int, "end": int, "strand": int}


def _not_in(sj, other):
	"""Return the rows of `sj` whose junction (chr, start, end, strand) does not occur in `other`."""
	merged = sj.merge(other, how='left', on=sj_cols, indicator=True)
	return merged[merged['_merge'] == 'left_only'].drop(columns=['_merge'])


def _plot_score_panel(ax, sj, sj_50, reference_sj, bins, title):
	"""Draw one panel: scores of junctions missing from the reference vs. missing from the 50M data."""
	ax.hist(_not_in(sj, reference_sj)['pred'], bins=bins, label='not in reference genome', alpha=0.5, color=tab10[2])
	ax.hist(_not_in(sj, sj_50)['pred'], bins=bins, label='not in 50M input data', alpha=0.5, color=tab10[0])
	ax.set_xlabel('Prediction score', fontsize=16)
	ax.set_ylabel('Number of junctions', fontsize=16)
	ax.set_yscale('log')
	ax.tick_params(labelsize=16)
	ax.set_title(title, fontsize=18)


fig, axes = plt.subplots(3, 4, figsize=(19, 11))  # 3 rows x 4 columns
# read in sj in 50M input, once per aligner: each panel is compared against the 50M junctions
# of its own aligner (previously the HISAT2 panel was compared against the STAR 50M junctions)
sj_50_by_aligner = {
	name: pd.read_csv(f"{data_dir}/50M/{name}/{sample}_50M_{run}.sj", usecols=[0,1,2,3], dtype=sj_dtypes, sep="\t", header=None, names=sj_cols)
	for name in ("star", "hisat")
}
sj_50 = sj_50_by_aligner["star"]
# read in reference genome sj
reference_sj = pd.read_csv(f"{data_dir}/annotated.sj", sep='\t', header=None, names=sj_cols, dtype=sj_dtypes)

# Panel a (top row): JCC in the real-world setting, STAR only; the fourth cell stays empty.
for i, filter_type in enumerate(["unfiltered", "illumina", "cutoff"]):
	filter_type_ = filter_type.capitalize() if filter_type == "illumina" else filter_type
	file_sj = f'{stats_dir}/scenario_2_real_world/tool_sjs_50_{aligner}_vs_GT_{aligner}_{filter_type}.pkl'
	# NOTE(review): unpickling can execute arbitrary code; only load files produced by this project.
	with open(file_sj, 'rb') as f:
		tool_sjs = pd.read_pickle(f)
	for ann, sj in tool_sjs[tool].items():
		_, sample_id, run_id = ann
		if int(run_id) != run or sample_id != sample:
			continue
		ax = axes[0, i]
		_plot_score_panel(ax, sj, sj_50, reference_sj, np.linspace(0, 1, 50), f'{aligner.upper()} {filter_type_}')
		if i == 0:
			# panel letter and row label, placed left of the first column
			ax.text(-0.25, 1.2, 'a.', transform=ax.transAxes, fontsize=22, va='top', ha='right', weight='bold')
			ax.text(-0.25, 0.5, tool.upper(), transform=ax.transAxes, fontsize=20, va='center', ha='right', rotation=90)
axes[0,3].axis('off')

# Panel b (rows 2 and 3): DeepSplice and SpliceAI in the hypothetical setting.
tool_names = {'deepsplice': 'DeepSplice', 'spliceai': 'SpliceAI'}
panels_b = [("star", "unfiltered"), ("star", "illumina"), ("star", "cutoff"), ("hisat", "unfiltered")]
for col, (panel_aligner, filter_type) in enumerate(panels_b):
	filter_type_ = filter_type.capitalize() if filter_type == "illumina" else filter_type
	# Same spelling as in the other figures: the tool is called HISAT2.
	aligner_ = panel_aligner.upper()+'2' if panel_aligner == "hisat" else panel_aligner.upper()
	file_sj = f'{stats_dir}/scenario_2_hypothetical/tool_sjs_50_{panel_aligner}_vs_GT_{panel_aligner}_{filter_type}.pkl'
	with open(file_sj, 'rb') as f:
		tool_sjs = pd.read_pickle(f)
	for row, panel_tool in enumerate(['deepsplice', 'spliceai'], start=1):
		# SpliceAI is binned on [0, 1], DeepSplice on [0, 1.7].
		bins = np.linspace(0, 1 if panel_tool == 'spliceai' else 1.7, 50)
		for ann, sj in tool_sjs[panel_tool].items():
			_, sample_id, run_id = ann
			# NOTE(review): unlike panel a there is no filter on run_id here, so every entry of
			# this sample is drawn onto the same axes - confirm there is one entry per sample.
			if sample_id != sample:
				continue
			ax = axes[row, col]
			_plot_score_panel(ax, sj, sj_50_by_aligner[panel_aligner], reference_sj, bins, f'{aligner_} {filter_type_}')
			# the legend is drawn once, in the HISAT2 / SpliceAI panel
			if filter_type == 'unfiltered' and panel_tool == 'spliceai' and panel_aligner == 'hisat':
				ax.legend(fontsize=16)
			if col == 0:
				if row == 1:
					ax.text(-0.25, 1.2, 'b.', transform=ax.transAxes, fontsize=22, va='top', ha='right', weight='bold')
				ax.text(-0.25, 0.5, tool_names[panel_tool], transform=ax.transAxes, fontsize=20, va='center', ha='right', rotation=90)
plt.tight_layout()
plt.savefig(f"{plt_dir}/supplementary_figure_5.tif", dpi=600, bbox_inches='tight')
plt.close()

In [None]:
# Supplementary Figure 6: Difference max and avg
import pandas as pd
import matplotlib.pyplot as plt
def _max_avg_diff(pred_subdir, tool):
	"""Load the max- and avg-combined predictions of `tool` and return them joined per junction.

	Reads `{tool}_pred_max.csv` and `{tool}_pred_avg.csv` from `{pred_dir}/{pred_subdir}`,
	inner-joins them on chr/start/end/strand and adds the column
	`pred_diff = pred_max - pred_avg`.
	"""
	maxi = pd.read_csv(f'{pred_dir}/{pred_subdir}/{tool}_pred_max.csv', low_memory=False)
	avgi = pd.read_csv(f'{pred_dir}/{pred_subdir}/{tool}_pred_avg.csv', low_memory=False)
	merged = maxi.merge(avgi, on=['chr','start','end','strand'], suffixes=('_max','_avg'))
	merged['pred_diff'] = merged['pred_max'] - merged['pred_avg']
	return merged


# Junctions of the reference annotation.
for tool in ['spliceai']:
	diff = _max_avg_diff('annotated', tool)
	print(f'for reference genome {len(diff[diff.pred_diff > 0.01])}/{len(diff)} predictions differ more than 0.01 between max and avg for {tool}')
# Junctions of the 500M STAR alignment of the example sample; tables are kept per tool
# so the plot below can pick the JCC one explicitly.
diffs_500m = {}
for tool in ['spliceai','jcc']:
	diff = _max_avg_diff('500M/J26675-L1_S1/star', tool)
	diffs_500m[tool] = diff
	print(f'for 500M STAR J26675-L1_S1 {len(diff[diff.pred_diff > 0.01])}/{len(diff)} predictions differ more than 0.01 between max and avg for {tool}')
# for SpliceAI whether to take the max or the avg is not highly important, the difference is not that big, only for 1 case it is more than 0.01
# for JCC the difference is more pronounced
# plot distribution of the differences (JCC, selected by name rather than by loop order)
jcc_diff = diffs_500m['jcc']
plt.figure(figsize=(7, 5))
plt.hist(jcc_diff['pred_diff'], bins=100, alpha=0.5)
plt.xlabel('Difference between maximum and mean combined JCC score', fontsize=16)
plt.ylabel('Number of junctions', fontsize=16)
plt.yscale('log')
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.savefig(f"{plt_dir}/supplementary_figure_6.tif", dpi=600, bbox_inches='tight')
plt.close()

In [74]:
# Supplementary Figure 7: Distribution of the prediction scores of the three different tools DeepSplice, SpliceAI, and JCC, after running on the 500M reads RNA-seq data aligned with STAR. This is shown for sample J26675-L1_S1 as an example, but is similar for the other three samples.
import pandas as pd
import matplotlib.pyplot as plt
# Prediction files of the example sample (500M reads, STAR), in legend order.
example_pred_dir = f'{pred_dir}/500M/J26675-L1_S1/star'
score_files = [
	('DeepSplice', 'deepsplice_pred.csv'),
	('SpliceAI', 'spliceai_pred_avg.csv'),
	('JCC', 'jcc_pred_avg.csv'),
]
# Load all three tables before the figure is opened.
scores_by_tool = [
	(tool_label, pd.read_csv(f'{example_pred_dir}/{file_name}', low_memory=False))
	for tool_label, file_name in score_files
]

plt.figure(figsize=(6, 5))
# One semi-transparent histogram per tool, each with its own 50 bins.
for tool_label, pred_table in scores_by_tool:
	plt.hist(pred_table['pred'], bins=50, label=tool_label, alpha=0.5)
plt.xlabel('Prediction score', fontsize=16)
plt.ylabel('Number of junctions', fontsize=16)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
plt.legend(fontsize=14)
plt.savefig(f"{plt_dir}/supplementary_figure_7.tif", dpi=600, bbox_inches='tight')
plt.close()