In [1]:
import matplotlib.pyplot as plt
from config import *
import pandas
from scipy.stats import spearmanr, pearsonr
import os
import numpy as np
from matplotlib.ticker import FormatStrFormatter
from matplotlib.colors import LinearSegmentedColormap
import matplotlib as mpl
from scipy.stats import gaussian_kde
import random
from sklearn.metrics import average_precision_score, precision_recall_curve
import json
from sklearn.metrics import auc, roc_auc_score

In [2]:
# constant
DPI=300
species_dict = {
    "hg19": "Human",
    "panTro5": "Chimp",
    "rheMac10": "Rhesus",
    "calJac4": "Marmoset",
    "bosTau9": "Cow",
    "susScr11": "Pig",
    "mm10": "Mouse",
    "rn6": "Rat"
}

method_dict = {
    # color, shape, name
    "spliceai": ["#B89330", "dashed", "o", "SpliceAI", "spliceai.csv", ],
    "pangolin": ["#F6C343", "dashdot", ">", "Pangolin", "pangolin.csv"],
    "hpangolin": ["#F6C343", "dashdot", ">", "Pangolin (Homo sapiens)", "hpangolin.csv"],
    "refsplice": ["#375392", "solid", "D", "DeltaSplice", "our.txt"],
    "hrefsplice": ["#95AAD9", "solid", "D", "DeltaSplice (Homo sapiens)", "hour.txt"],
    "arefsplice": ["orange", "solid", "D", "DeltaSplice (Single)", "hour.txt"],
    "mmsplice": ["#89B0B2", "dotted", "*", "MMSplice", "mmsplice.tsv"]
}

OrderedMethod = ["refsplice", "hrefsplice",
                 "arefsplice", "spliceai", "pangolin",  "mmsplice",]
NumPreserv=2
figsize=[10, 8]
merge_table_list=set()

In [3]:
def load_pred_from_format_bed(path):
    pred = pandas.read_csv(path, sep="\t")
    Yt = pred["Yt"]
    Yp = pred["Yp"]
    species = pred["#species"]
    ret = []
    for a, b, c in zip(species, Yt, Yp):
        if not np.isnan(b):
            ret.append([a, b, c])
    return ret


def scatter(gt, pred, gaussian_kernel=None, textpos=None, xlim=None, ylim=None, xlabel=None, ylabel=None, clim=None):
    plt.figure(figsize=figsize)
    if len(gt)>1000:
        sample_idx = np.array(random.sample(range(len(gt)), 1000))
    else:
        sample_idx=np.arange(len(gt))
    gt = np.array(gt)
    pred = np.array(pred)
    if gaussian_kernel is None:
        gaussian_kernel = gaussian_kde(
            np.stack([gt[sample_idx], pred[sample_idx]]))
        
    z = gaussian_kernel(
        np.stack([gt, pred]))
    if clim is None:
        clim=(min(z), max(z))
    density = plt.scatter(gt, pred,
                          c=z, edgecolor=['none'], s=10)
    plt.clim(*clim)
    if textpos is not None:
        pr = round(pearsonr(gt, pred)[0], 2)
        spr = round(spearmanr(gt, pred)[0], 2)
        plt.text(*textpos, "PCor={}, SPCor={}".format(
        pr, spr), ha="center", va="center")
    if xlim is not None:
        plt.xlim(*xlim)
    if ylim is not None:
        plt.ylim(*ylim)
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    plt.colorbar(density)
    return density, gaussian_kernel, clim


def load_mut_deltasplice(path):
    with open(path, "r") as f:
        content = f.readlines()
    content = [list(map(float, x.strip().split("\t"))) for x in content]
    start_usage,  pred_delta, gt_delta, = [x[0] for x in content], [
        x[1] for x in content], [x[2] for x in content]
    return start_usage, gt_delta, pred_delta


def load_mmsplice(mpath):
    with open(mpath, "r") as f:  # mmsplice results
        MPred = f.readlines()
    retMPred = [float(x.split("\t")[1].replace("\n", ""))
             for x in MPred]  # delta logits if refvalue is Nan else delta prob
    
    return  retMPred


def load_pred(path):
    with open(path, "r") as f:  # spliceai/pangolin results
        Pred = f.readlines()
    Pred = [x.replace("\n", "").split("\t") for x in Pred if x[0] != "#"]
    retPred, retGT = [float(x[3]) for x in Pred], [float(
        x[1]) for x in Pred]  # same order as the input
    retPpsi, retGTpsi = [float(x[2]) for x in Pred], [float(
        x[0]) for x in Pred]  # same order as the input
    return retPpsi, retGT, retPred


def enrich_plot(pred, gt, plot, color, name, NUMPOS, NUMNEG):
    x, y = [0], [0]
    for i in range(len(pred)):
        if gt[i] == 1:
            y.append(y[-1]+1)
        else:
            y.append(y[-1])
        x.append(x[-1]+1)
    x = x[1:]
    y = [b/NUMPOS/(max(a-b, 1)/NUMNEG) for a, b in zip(x, y[1:])]
    plot.plot(x, y, label=name, c=color)
    

def load_json(path):
    with open(path, "r") as f:
        content = json.load(f)
    ref = [(sum(_["label"][0][1])+sum(_["label"][1][1]))/2. for _ in content]
    gt = [(sum(_["mutlabel"][0][1])+sum(_["mutlabel"][1][1]))/2. for _ in content]
    refname = [_["name"] for _ in content]
   
    return ref, gt, refname


def load_multiple_pred(path, threshold):
    with open(path, "r") as f:
        content = f.readlines()
    content = [list(map(float, x.strip().split("\t"))) for x in content]
    end_pred, end_gt = [x[2] for x in content], [x[3] for x in content]
    hg19_pred, hg19_gt = [x[0] for x in content], [x[1] for x in content]
    if threshold > 0:
        idx = [i for i in range(len(hg19_gt)) if abs(
            hg19_gt[i]-end_gt[i]) >= threshold]
        end_pred = [end_pred[_]-hg19_pred[_] for _ in idx]
        end_gt = [end_gt[_] - hg19_gt[_] for _ in idx]
    elif threshold < -2:
        return end_pred, end_gt, hg19_pred, hg19_gt
    return end_pred, end_gt



In [4]:
# plot figure 2
import shutil
def write_csv_bar_fig2(cfunc, funcname, read_prefix, saveprefidx, file_dict):
    TD = []
    TD19 = []
    TD28 = []
    TN = []
    maxvalue = 0

    for idx, m in zip(range(len(OrderedMethod)), OrderedMethod):
        if m in file_dict:
            acceptor=file_dict[m][-1]["acceptor"]
            donor = file_dict[m][-1]["donor"]
        else:
            continue

        _, _, _, name, _ = file_dict[m]
        acc = load_pred_from_format_bed(os.path.join(read_prefix, acceptor))
        don = load_pred_from_format_bed(os.path.join(read_prefix, donor))
        shutil.copy(os.path.join(read_prefix, acceptor), os.path.join(save_prefix, name+"_acceptor.bed"))
        shutil.copy(os.path.join(read_prefix, donor), os.path.join(save_prefix, name+"_donor.bed"))
        Data = acc+don
        D = []
        D19 = []
        D28 = []
        
        TN.append(name)
        for species in species_dict.keys():
            data = [v for v in Data if v[0] == species]
            data19 = [x for x in data if 0.9 >= x[1] > 0.1]
            data28 = [x for x in data if 0.8 >= x[1] > 0.2]
            maxvalue = max(maxvalue, cfunc([x[1] for x in data], [x[2]
                                                                  for x in data])[0])
            r = round(cfunc([x[1] for x in data], [x[2]
                      for x in data])[0], NumPreserv)
            r19 = round(cfunc([x[1] for x in data19], [x[2]
                        for x in data19])[0], NumPreserv)
            r28 = round(cfunc([x[1] for x in data28], [x[2]
                        for x in data28])[0], NumPreserv)
            D.append(r)
            D19.append(r19)
            D28.append(r28)

        Data19 = [x for x in Data if 0.9 >= x[1] > 0.1]
        Data28 = [x for x in Data if 0.8 >= x[1] > 0.2]
        r = round(cfunc([x[1] for x in Data], [x[2]
                  for x in Data])[0], NumPreserv)
        r19 = round(cfunc([x[1] for x in Data19], [x[2]
                    for x in Data19])[0], NumPreserv)
        r28 = round(cfunc([x[1] for x in Data28], [x[2]
                    for x in Data28])[0], NumPreserv)
        D.append(r)
        D19.append(r19)
        D28.append(r28)
        TD.append(D)
        TD19.append(D19)
        TD28.append(D28)

    # write csv file
    with open(os.path.join(saveprefidx, "{}.csv".format(
            funcname)), "w") as f:
        f.writelines("# predicted results on test data; single sequence mode is used in DeltaSplice; the re-trained (on human data) version of pangolin is used; the original version of spliceai is used; for all methods the input length is 35000\n")
        f.writelines("Method,Species,Cor of [0-1],Cor of (0.1-0.9]\n")
        assert len(TD)>0 and len(TD19)>0 and len(TN)>0
        for  td,  td19, tn in zip( TD, TD19, TN):
            for i, m in enumerate(species_dict.keys()):
                f.writelines("{},{},{},{}\n".format(
                    tn, species_dict[m], td[i], td19[i]))
    merge_table_list.add(os.path.join(saveprefidx, "{}.csv".format(
            funcname)))
   


def plot_scatter(read_prefix, saveprefidx, file_dict):
    for m in OrderedMethod:
        if m in file_dict:
            acceptor = file_dict[m][-1]["acceptor"]
            donor = file_dict[m][-1]["donor"]
        else:
            continue

        _, _, _, name, _ = file_dict[m]
        acc = load_pred_from_format_bed(os.path.join(read_prefix, acceptor))
        don = load_pred_from_format_bed(os.path.join(read_prefix, donor))
        
        Data = acc+don
        # scatter total data
        gaussian_kernel, clim=None, None
        for jdx, species in enumerate(species_dict.keys()):
            data = [v for v in Data if v[0] == species]
            # scatter species data
            _,_, clim= scatter([_[1] for _ in data], [_[2]
                                           for _ in data], xlim=[0, 1], ylim=[0, 1], xlabel="Experimental SSU", ylabel=name, gaussian_kernel=gaussian_kernel, clim=clim)
            plt.savefig(os.path.join(saveprefidx, "sep_usage_{}_{}.pdf".format(species, name)),
                        dpi=DPI, bbox_inches='tight')
            plt.close()
        
        scatter([_[1] for _ in Data], [_[2]
                                       for _ in Data], xlim=[0, 1], ylim=[0, 1], xlabel="Experimental SSU", ylabel="Total", gaussian_kernel=gaussian_kernel, clim=clim)
        plt.savefig(os.path.join(saveprefidx, "total_usage_{}.pdf".format(name)),
                    dpi=DPI, bbox_inches='tight')
        plt.close()
        
        
save_prefix="figures/figure2"
read_prefix = "experiments/1_evaluate_on_test_and_val/test_results"
if not os.path.exists(save_prefix):
    os.mkdir(save_prefix)
method_dict_eval = {
    # color, shape, name
    "spliceai": ["#B89330", "dashed", "o", "SpliceAI", {"acceptor": "test_acceptor", "donor": "test_donor"}, ],
    "pangolin": ["#F6C343", "dashdot", ">", "Pangolin", {"acceptor": "data_Npy_test_data.json+evalpangolin_human_models_model.ckpt-rep1_acceptor", "donor": "data_Npy_test_data.json+evalpangolin_human_models_model.ckpt-rep1_donor"}],
    "refsplice": ["#375392", "solid", "D", "DeltaSplice", {"acceptor": "data_Npy_test_data.json+evalRefSplice_models_model.ckpt-2_acceptor", "donor": "data_Npy_test_data.json+evalRefSplice_models_model.ckpt-2_donor"}],
    "hrefsplice": ["#95AAD9", "solid", "D", "DeltaSplice (Homo sapiens)", {"acceptor": "data_Npy_test_data.json+evalRefSplice_models_human_model.ckpt-2_acceptor", "donor": "data_Npy_test_data.json+evalRefSplice_models_human_model.ckpt-2_donor"}],
}
write_csv_bar_fig2(pearsonr, "Pearson correlation", read_prefix, save_prefix, method_dict_eval)
write_csv_bar_fig2(spearmanr, "Spearman correlation",
                   read_prefix,save_prefix, method_dict_eval)
plot_scatter(read_prefix, save_prefix, method_dict_eval)

In [5]:
# plot figure 3

def plot_multiple(cfunc, funcname,read_prefix, save_prefix, prefixdict, multiplethreshold):
    Multiplespecies = []
    idx = 0
    TC, TN, MC_delta, MC_raw = [], [], [], []
    file_delta = open(os.path.join(save_prefix, "{}_bar_res_of_deltassu.csv".format(funcname)), "w")
    file_raw = open(os.path.join(
        save_prefix, "{}_bar_res_of_ssu.csv".format(funcname)), "w")
    merge_table_list.add(os.path.join(
        save_prefix, "{}_bar_res_of_ssu.csv".format(funcname)))
    merge_table_list.add(os.path.join(
        save_prefix, "{}_bar_res_of_deltassu.csv".format(funcname)))
    file_delta.writelines("# delta ssu between human and other species predicted by different methods; double sequence mode of deltasplice is used (human as reference); the retrained (on human) version of pangolin is used; the original version of spliceai is used.\n")
    file_delta.writelines("Method,Species,Threshold, Cor, AUROC\n")
    file_raw.writelines("# ssu of other species predicted by different methods; double sequence mode of deltasplice is used (human as reference); the retrained (on human) version of pangolin is used; the original version of spliceai is used.\n")
    file_raw.writelines("Method,Species,Threshold, Cor, AUROC\n")
    gaussian_kernel=None
    clim=None
    for m in OrderedMethod:
        if m not in prefixdict:
            continue
        color, _, _, name, _ = method_dict[m]
        mc_delta, mc_raw = [], []
        TC.append(color)
        TN.append(name)
        mprefix = prefixdict[m]
        
        for species in species_dict.keys():
            multiplepath = os.path.join(read_prefix, mprefix+species+"_pred")
            if os.path.exists(multiplepath):
                mpred, mgt = load_multiple_pred(multiplepath, 1e-5)
                _, _, clim=scatter(mgt, mpred, textpos=[-0.4, 0.7],xlim= [-1, 1], ylim=[-1, 1],xlabel= "Experimental $\Delta$SSU of {}".format(
                    species_dict[species].lower()),ylabel= name, gaussian_kernel=gaussian_kernel, clim=clim)
                mc_delta.append(
                    round(cfunc(*(load_multiple_pred(multiplepath, 1e-5)[:2]))[0], NumPreserv))
                mc_raw.append(
                    round(cfunc(*(load_multiple_pred(multiplepath, -5)[:2]))[0], NumPreserv))
                if idx == 0:
                    Multiplespecies.append(species_dict[species])
                plt.savefig(os.path.join(save_prefix, "{}_{}_{}.pdf".format(m, species, funcname) ),dpi=DPI, bbox_inches='tight')
                plt.close()
        MC_raw.append(mc_raw)
        MC_delta.append(mc_delta)
        idx += 1
    barwidth=0.2
    plt.figure(figsize=figsize)
    f = plt.subplot(1, 1, 1)
    xs = np.array(range(len(MC_delta[0])))
    total_length = barwidth*len(TC)
    Xs = []
    for i in range(len(TC)):
        Xs.append(xs-total_length//2+i*barwidth-barwidth)
    for x,  tc, tn, mc_delta, mc_raw in zip(Xs,  TC, TN, MC_delta, MC_raw):
        plt.bar(x, mc_delta, color=tc, label=tn, width=barwidth, alpha=0.8)
        for s, c_delta, c_raw in zip(Multiplespecies, mc_delta, mc_raw):
            file_delta.writelines("{},{},0,{}, -\n".format(
                tn, s, c_delta))
            file_raw.writelines("{},{},0,{}, -\n".format(
                tn, s, c_raw))
        # print(mc)

    plt.ylabel("{} correlation".format(funcname))
    f.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
    f.set_xticks(range(len(mc_delta)))
    f.set_xticklabels(Multiplespecies,  ha="center")
    f.set_ylim(0., max(mc_delta)+0.1)
    f.legend(loc='upper right', frameon=False)
    plt.savefig(os.path.join(save_prefix, "{}_multiple.pdf".format(funcname)),
                bbox_inches='tight', dpi=DPI)
    plt.close()

    fig, ax = plt.subplots(2, 3, figsize=[figsize[0], 7])
    sidx = 0
    for species in species_dict.keys():
        f1 = plt.subplot(2, 3, sidx+1)
        idx = 0
        if not os.path.exists(os.path.join("data/Hg19VsOthers/", "{}.json".format(species))):
            print("{} not exist".format(species))
            continue
        for m in OrderedMethod:
            if m not in prefixdict:
                continue
            mprefix = prefixdict[m]
            color, _, _, name, _ = method_dict[m]
            r2 = []
            for midx, t in enumerate(multiplethreshold):
                if os.path.exists(os.path.join(
                        read_prefix, mprefix+species+"_pred")):
                    pred, gt = load_multiple_pred(os.path.join(
                        read_prefix, mprefix+species+"_pred"), t)
                    gt = [int(_ >= 0) for _ in gt]
                    r = roc_auc_score(gt, pred)
                    r2.append(r)
            if len(r2) > 0:
                f1.bar([midx-total_length//2+idx*barwidth-barwidth for midx in range(
                    len(multiplethreshold))], r2, color=color, label=name, width=barwidth, alpha=0.8)
                idx += 1
                for v, r in zip(multiplethreshold, r2):
                    file_raw.writelines("{},{},{},-, {}\n".format(
                        name, species, v, r))
                    file_delta.writelines("{},{},{},-, {}\n".format(
                        name, species, v, r))
        sidx += 1
        # plt.legend()
        f1.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        f1.set_xticks(range(len(multiplethreshold)))
        f1.set_xticklabels(multiplethreshold, ha="center")

        f1.set_xlabel("Thresholds of $\Delta$SSU")
        if sidx % 3 == 1:
            f1.set_ylabel("AUROC")
        f1.set_ylim(0.4, 1.)
        f1.text(1., 0.8, species_dict[species], ha="center", va="center")

    plt.savefig(os.path.join(save_prefix, "{}_usage_multiple.pdf".format(funcname)),
                bbox_inches='tight', dpi=DPI)
    plt.close()
    file_delta.close()
    file_raw.close()


def write_table(read_prefix,save_prefix, prefixdict):
    for species in species_dict.keys():
        if not os.path.exists(os.path.join("data/Hg19VsOthers/", "{}.json".format(species))):
            continue
        with open(os.path.join("data/Hg19VsOthers/", "{}.json".format(species)), "r") as f:
            evaldata = json.load(f)
        writelines = ["" for _ in evaldata]
        header = "{}_chrom,{}_coordinate,{}_SSU,hg19_chrom,hg19_coordinate,hg19_SSU".format(
            species, species, species)

        for idx, line in enumerate(evaldata):
            writelines[idx] = writelines[idx]+"{},{},{},{},{},{}".format(line["{}_chrom".format(species)], line["{}_idx".format(species)],
                                                                         sum(line["{}_label".format(
                                                                             species)]), line["hg19_chrom"],
                                                                         line["hg19_idx"], sum(line["hg19_label"]))
        for m in OrderedMethod:
            if m not in prefixdict:
                print("{} not exist".format(m))
                continue
            mprefix = prefixdict[m]
            multiplepath = os.path.join(read_prefix, mprefix+species+"_pred")
            if os.path.exists(multiplepath):
                header = header + \
                    ",{}_{}, hg19_{}".format(
                        species, method_dict[m][3], method_dict[m][3])
                end_pred, end_gt, hg19_pred, hg19_gt = load_multiple_pred(
                    multiplepath, -3)
                
                assert len(end_pred) == len(evaldata)
                for idx, ep, hp in zip(range(len(end_pred)), end_pred, hg19_pred):
                    writelines[idx] = writelines[idx]+",{},{}".format(ep, hp)
        with open(os.path.join(save_prefix, "{}vsHg19.csv".format(species)), "w") as f:
            f.writelines(header+"\n")
            for line in writelines:
                f.writelines(line+"\n")

prefixdict = {
        "spliceai": "spliceai_",
        "pangolin": "pangolin_human_",
        "refsplice": "",
        "arefsplice": "single_"
    }
multiplethreshold = [0.1,  0.2, 0.3, 0.5]
save_prefix = "figures/figure3/"
read_prefix = "experiments/0_eval_on_multiple_species/test_results"
if not os.path.exists( save_prefix):
    os.mkdir(save_prefix)
plot_multiple(pearsonr, "Pearson", read_prefix,save_prefix,
              prefixdict, multiplethreshold)
plot_multiple(spearmanr, "Spearman", read_prefix,save_prefix,
              prefixdict, multiplethreshold)
write_table(read_prefix,"figures/figure3/", prefixdict)

hg19 not exist
calJac4 not exist
hg19 not exist
calJac4 not exist
hrefsplice not exist
mmsplice not exist
hrefsplice not exist
mmsplice not exist
hrefsplice not exist
mmsplice not exist
hrefsplice not exist
mmsplice not exist
hrefsplice not exist
mmsplice not exist
hrefsplice not exist
mmsplice not exist


In [6]:
# plot figure 4
def plot_vexseq(save_prefix, read_prefix, filename_dict):
    ref, gt, allnames=load_json("data/vexseq/data.json")
    header = "name,experimental psi before mut, experimental deltapsi"
    writelines = ["{},{},{}".format(_, a, b) for _, a,b in zip(allnames, ref, gt)]
    gaussian_kernel=None
    clim=None
    for idx, method in enumerate(OrderedMethod):
        if method in filename_dict:
            color, _, _, name, File = filename_dict[method]
        else:
            continue

        if method == "refsplice":
            start_usage, gt_delta, pred_delta = load_mut_deltasplice(
                os.path.join(read_prefix, File))
        elif method == "mmsplice":
            pred_delta = load_mmsplice(os.path.join(
                read_prefix, File))
            start_usage=None
        else:
            start_usage, gt_delta, pred_delta = load_pred(
                os.path.join(read_prefix, File))
        if start_usage is not None:
            header = header+",ssu before mut predicted by {},deltassu predicted by {}".format(name, name)
            for i, p, r in zip(range(len(pred_delta)),pred_delta, start_usage):
                writelines[i] = writelines[i]+",{},{}".format(r, p)
        else:
            header = header+",deltassu predicted by {}".format(name, name)
            for i, p in enumerate(pred_delta):
                writelines[i] = writelines[i]+",{}".format(p)
        
        _, _, clim= scatter(gt_delta, pred_delta, gaussian_kernel=gaussian_kernel, textpos=[
                0, 0.6], xlim=[-0.9, 0.75], ylim=[-0.9, 0.75], xlabel="Experimental $\Delta$SSU", ylabel="Predicted $\Delta$SSU", clim=clim)
        
        plt.savefig(os.path.join(save_prefix, "vexseq_{}.pdf".format(
            name)),  dpi=DPI, bbox_inches='tight')
        plt.close()
    
    with open(os.path.join(save_prefix, "VexSeq.csv"), "w") as f:
        f.writelines("# predicted delta ssu; double sequence mode is used in DeltaSplice with experimental data as reference; the original versions of panglin spliceai mmsplice are used; for deltasplice pangolin spliceai the input length is 35000; for mmsplice delta ssu is computed with experimental start ssu.\n")
        f.writelines(header+"\n")
        for line in writelines:
            f.writelines(line+"\n")
    merge_table_list.add(os.path.join(save_prefix, "VexSeq.csv"))
    


def plot_fas(save_prefix, read_prefix, filename_dict):   
    ref, gt, allnames=load_json("data/FAS/data.json")
    header = "name,experimental psi before mut, experimental deltapsi"
    writelines = ["{},{},{}".format(_, a, b) for _, a,b in zip(allnames, ref, gt)]

    gaussian_kernel = None
    clim=None
    for idx, method in enumerate(OrderedMethod):
        if method in filename_dict:
            color, _, _, name, File = filename_dict[method]
        else:
            continue
        if method == "refsplice":
            start_usage, gt_delta, pred_delta = load_mut_deltasplice(
                os.path.join(read_prefix, File))
        elif method == "mmsplice":
            pred_delta = load_mmsplice(os.path.join(
                read_prefix, File))
            start_usage=None
        else:
            start_usage, gt_delta, pred_delta = load_pred(
                os.path.join(read_prefix, File))
        assert len(gt_delta)==len(pred_delta)
        if start_usage is not None:
            header = header+",ssu before mut predicted by {},deltassu predicted by {}".format(name, name)
            for i, p, r in zip(range(len(pred_delta)),pred_delta, start_usage):
                writelines[i] = writelines[i]+",{},{}".format(r, p)
        else:
            header = header+",deltassu predicted by {}".format(name, name)
            for i, p in enumerate(pred_delta):
                writelines[i] = writelines[i]+",{}".format(p)
        assert len(gt_delta)==len(pred_delta)
        
        _,_, clim= scatter(gt_delta, pred_delta, gaussian_kernel=gaussian_kernel, textpos= [
                0, -0.6], xlabel="Experimental $\Delta$SSU", ylabel="Predicted $\Delta$SSU", xlim=(-1, 1), ylim=(-1, 1), clim=clim)
        plt.savefig(os.path.join(save_prefix, "fas{}.pdf".format(
            name)),  dpi=DPI, bbox_inches='tight')
        plt.close()
        
    with open(os.path.join(save_prefix, "Fas.csv"), "w") as f:
        f.writelines("# predicted delta ssu; double sequence mode is used in DeltaSplice with experimental data as reference; the original versions of panglin & spliceai & mmsplice are used; for deltasplice & pangolin & spliceai the input length is 35000; for mmsplice delta ssu is computed with experimental start ssu.\n")
        f.writelines(header+"\n")
        for line in writelines:
            f.writelines(line+"\n")
    merge_table_list.add(os.path.join(save_prefix, "Fas.csv"))
            


save_prefix = "figures/figure4"
read_prefix = "experiments/2_eval_mut/test_results"
method_dict_vexseq = {
    # color, shape, name
    "spliceai": ["#B89330", "dashed", "o", "SpliceAI", "vexseq_spliceai.txt", ],
    "pangolin": ["#F6C343", "dashdot", ">", "Pangolin", "vexseq_pangolin.txt"],
    "refsplice": ["#375392", "solid", "D", "DeltaSplice", "Mutation_data_vexseq_data.json+RefSplice_models_model.ckpt-2.txt"],
    "mmsplice": ["#89B0B2", "dotted", "*", "MMSplice", "vexseq_mmsplice.txt"]
}
method_dict_fas = {
    # color, shape, name
    "spliceai": ["#B89330", "dashed", "o", "SpliceAI", "fas_spliceai.txt", ],
    "pangolin": ["#F6C343", "dashdot", ">", "Pangolin", "fas_pangolin.txt"],
    "refsplice": ["#375392", "solid", "D", "DeltaSplice", "Mutation_data_FAS_data.json+RefSplice_models_model.ckpt-2.txt"],
    "mmsplice": ["#89B0B2", "dotted", "*", "MMSplice", "fas_mmsplice.txt"]
}
if not os.path.exists(save_prefix):
    os.mkdir(save_prefix)
plot_vexseq(save_prefix, "experiments/2_eval_mut/test_results", method_dict_vexseq)
plot_fas(save_prefix, "experiments/3_eval_fas/test_results", method_dict_fas)

In [7]:
# plot figure 5
def load_name_file():
    name_path="data/mfass_name.csv"
    name_data=pandas.read_csv(name_path)
    ret={}
    for Id, ref, alt, Chr, pos in zip(name_data["id"], name_data["ref_allele"], name_data["alt_allele"], name_data["chr"], name_data["snp_position"]):
        key="{}_{}_{}_{}".format(Chr, pos, ref.upper(), alt.upper())
        assert key not in ret
        ret[key]=Id
        
    return ret

def plot_mfass(read_prefix, filename_dict, saveprefix):
    if not os.path.exists(saveprefix):
        os.mkdir(saveprefix)
    distancethreshold = 20
    linewidth = 2
    external_ids=load_name_file()

    with open( "data/MFASS.bed", "r") as f:
        content = f.readlines()
    distance = []

    
    for line in content:
        line = line.split("\t")
        name, start, end = line[0], int(line[4]), int(line[5])
        pos = int(name.split("_")[1])
        distance.append(min(abs(pos-start), abs(end-pos)))

    plt.figure(figsize=[20,5])
    f1 = plt.subplot(1, 3, 1)
    f2 = plt.subplot(1, 3, 2)
    f3 = plt.subplot(1, 3, 3)
    ref, gt, allnames=load_json("data/mfass/data.json")
    header = "name,external_ID, experimental psi before mut, experimental deltapsi"
    writelines = ["{},{},{},{}".format(_,external_ids["_".join(_.split("_")[1:])], a, b) for _, a,b in zip(allnames, ref, gt)]

    global_gt=None
    for method in OrderedMethod:
        if method in filename_dict:
            color, linestype, _, name, File = filename_dict[method]
        else:
            continue
        if method == "refsplice":
            start_usage, gt_delta, pred_delta = load_mut_deltasplice (
                os.path.join(read_prefix, File))
        elif method == "mmsplice":
            pred_delta = load_mmsplice(os.path.join(
                read_prefix, File))
            start_usage=None
        else:
            start_usage, gt_delta, pred_delta = load_pred(
                os.path.join(read_prefix, File))
        if start_usage is not None:
            header = header+",ssu before mut predicted by {},deltassu predicted by {}".format(name, name)
            for i, p, r in zip(range(len(pred_delta)),pred_delta, start_usage):
                writelines[i] = writelines[i]+",{},{}".format(r, p)
        else:
            header = header+",deltassu predicted by {}".format(name, name)
            for i, p in enumerate(pred_delta):
                writelines[i] = writelines[i]+",{}".format(p)
        pred_delta = [-_ for _ in pred_delta]
        if global_gt is None:
            gt_delta = [int(_ <= -0.5) for _ in gt_delta]
            global_gt=gt_delta
        else:
            gt_delta=global_gt
        assert len(pred_delta)==len(gt_delta)
        fullscore = str(
            round(average_precision_score(gt_delta, pred_delta), 3))
        precision, recall, thresholds = precision_recall_curve(
            gt_delta, pred_delta)
        f1.plot(precision, recall, label="{} = {}".format(
            name, fullscore), linewidth=linewidth, c=color, linestyle=linestype)
        f1.legend()

        assert len(gt_delta) == len(distance)
        cgt, fgt = [a for a, b in zip(gt_delta, distance) if b <= distancethreshold], [
            a for a, b in zip(gt_delta, distance) if b > distancethreshold]
        cpred, fpred = [a for a, b in zip(
            pred_delta, distance) if b <= distancethreshold], [a for a, b in zip(
                pred_delta, distance) if b > distancethreshold]

        cscore, fscore = str(round(average_precision_score(cgt, cpred), 3)), str(
            round(average_precision_score(fgt, fpred), 3))

        precision, recall, thresholds = precision_recall_curve(
            cgt, cpred)
        f2.plot(precision, recall, label="{} = {}".format(
            name, cscore), linewidth=linewidth, c=color, linestyle=linestype)
        f2.legend()

        precision, recall, thresholds = precision_recall_curve(
            fgt, fpred)
        f3.plot(precision, recall, label="{} = {}".format(
            name, fscore), linewidth=linewidth, c=color, linestyle=linestype)
        f3.legend()

    f1.set_xlabel("Precision")
    f1.set_ylabel("Recall")
    f2.set_xlabel("Precision")
    f2.set_ylabel("Recall")
    f3.set_xlabel("Precision")
    f3.set_ylabel("Recall")
    f1.set_title("All SNPs")
    f2.set_title("≤{} bp from splice sites".format(distancethreshold))
    f3.set_title(">{} bp from splice sites".format(distancethreshold))

    plt.savefig(os.path.join(saveprefix, "MFASS.pdf"), bbox_inches="tight")
    plt.close()
    with open(os.path.join(saveprefix, "MFASS.csv"), "w") as f:
        f.writelines("# predicted delta ssu; double sequence mode is used in DeltaSplice with experimental data as reference; the original versions of panglin & spliceai & mmsplice are used; for deltasplice & pangolin & spliceai the input length is 35000; for mmsplice delta ssu is computed with experimental start ssu.\n")
        f.writelines(header+"\n")
        for line in writelines:
            f.writelines(line+"\n")
    merge_table_list.add(os.path.join(saveprefix, "MFASS.csv"))
            

method_dict_mfass= {
    # color, shape, name
    "spliceai": ["#B89330", "dashed", "o", "SpliceAI", "mfass_spliceai.txt", ],
    "pangolin": ["#F6C343", "dashdot", ">", "Pangolin", "mfass_pangolin.txt"],
    "refsplice": ["#375392", "solid", "D", "DeltaSplice", "Mutation_data_mfass_data.json+RefSplice_models_model.ckpt-2.txt"],
    "mmsplice": ["#89B0B2", "dotted", "*", "MMSplice", "mfass_mmsplice.txt"]
}
plot_mfass("experiments/2_eval_mut/test_results", method_dict_mfass, "figures/figure5/")

  name_data=pandas.read_csv(name_path)


In [9]:
# plot figure 6, note that logits is used for mmsplice in this figure, as no reference values are avaiable

def load_dataidx(posfile, posjson, negfile, negjson):
    totalnames = []
    names = []
    allnames = []
    with open(posjson, "r") as f:
        content = json.load(f)
        totalnames.extend(["_".join(x["name"].split("_")[1:])
                          for x in content])
        names = ["_".join(x["name"].split("_")[1:]).split(
            "||")[1].replace("chr", "") for x in content]
        allnames.extend([x["name"] for x in content])

    with open(negjson, "r") as f:
        content = json.load(f)
        totalnames.extend(["_".join(x["name"].split("_")[1:])
                          for x in content])
        allnames.extend([x["name"] for x in content])
    totalnamedict = {a: b for b, a in enumerate(totalnames)}
    assert len(totalnamedict) == len(
        totalnames), "replicates exist in name list"
    retidx = []
    retref = []
    extronic_intronic = []

    with open(posjson, "r") as f:
        content = json.load(f)
        NUMPOS = len(content)
        for x in content:
            start, end = x["label"][0][0], x["label"][1][0]
            pos = x["mutpos"]
            if (pos-start)*(pos-end) < 0:
                extronic_intronic.append(1)  # extronic
            else:
                extronic_intronic.append(0)
            retidx.append(totalnamedict["_".join(x["name"].split("_")[1:])])
            retref.append(sum(x["label"][0][1])*0.5+sum(x["label"][1][1])*0.5)
    with open(negjson, "r") as f:
        content = json.load(f)
        NUMNEG = len(content)
        for x in content:
            start, end = x["label"][0][0], x["label"][1][0]
            pos = x["mutpos"]
            if (pos-start)*(pos-end) < 0:
                extronic_intronic.append(1)  # extronic
            else:
                extronic_intronic.append(0)
            retidx.append(totalnamedict["_".join(x["name"].split("_")[1:])])
            retref.append(sum(x["label"][0][1])*0.5+sum(x["label"][1][1])*0.5)

    return retidx, NUMPOS, NUMNEG, retref, names, allnames, extronic_intronic



def plot_sqtl(read_prefix, filename_dict, saveprefix,  explaination, filter_nan=False, endname=None):
    if not os.path.exists(saveprefix):
        os.mkdir(saveprefix)
    figsize = [16, 5]
    posfile = os.path.join("data/data/Brain_tissues.v8.sqtl_signifpairs_map_cass_intron.dist.200bp_dataset.txt")
    
    posjson = os.path.join("data/sQTLs/data.json")
    negfile = os.path.join(
        "data/k1g_phase3_within_hg38.cass.exon.ext200_snps2cass_byEST_nonredundant_dataset.txt")
    negjson = os.path.join("data/sQTLs_neg/data.json")
    
    _, NUMPOS, NUMNEG, refvalues, posnames, allnames, extronic = load_dataidx(
        posfile, posjson, negfile, negjson)
    
    GT = [1 for _ in range(NUMPOS)]+[0 for _ in range(NUMNEG)]

    NUMSAMPLES = NUMPOS//2

    plt.figure(figsize=figsize)
    plot = plt.subplot(2, 1, 1)
    NUMSAMPLES = NUMPOS//2
    assert len(allnames)==len(refvalues)
    writelines = ["{},{}".format(_, r) for _,r in zip(allnames, refvalues)]
    header = "name,experimental ssu"
    for m in OrderedMethod:
        if m in filename_dict:
            color, _, _, name, f_pos, f_neg = filename_dict[m]
        else:
            continue
        if m == "refsplice":
            start_usage, _, pred = load_mut_deltasplice(os.path.join(read_prefix, f_pos))
            start_usage_neg, _, pred_neg = load_mut_deltasplice(os.path.join(read_prefix, f_neg))
            start_usage=start_usage+start_usage_neg
        elif m == "mmsplice":
            pred = load_mmsplice(os.path.join(read_prefix, f_pos))
            pred_neg = load_mmsplice(os.path.join(read_prefix, f_neg))
            start_usage=None
        else:
            start_usage, _, pred = load_pred(os.path.join(read_prefix, f_pos))
            start_usage_neg, _, pred_neg = load_pred(os.path.join(read_prefix, f_neg))
            start_usage=start_usage+start_usage_neg
        pred = pred+pred_neg

        assert len(pred) == len(GT)

        #pred = [abs(_) for _ in pred]
        if start_usage is not None:
            header = header+",ssu before mut predicted by {},deltassu predicted by {}".format(name, name)
            for i, p, r in zip(range(len(pred)),pred, start_usage):
                writelines[i] = writelines[i]+",{},{}".format(r, p)
        else:
            header = header+",deltassu predicted by {}".format(name, name)
            for i, p in enumerate(pred):
                writelines[i] = writelines[i]+",{}".format(p)
        sortedidx = sorted(range(len(pred)), key=lambda x: -
                           abs(pred[x]))[:NUMSAMPLES]

        enrich_plot([pred[_] for _ in sortedidx], [GT[_]
                    for _ in sortedidx], plot, color, name, NUMPOS, NUMNEG)
        print("{}:auprc {}".format(
            name, average_precision_score(GT[:len(pred)], pred)))
    plot.legend()
    plot.set_xlabel("Topk")
    plot.set_ylabel("Enrichment of sQTLs")
    
    plt.savefig(os.path.join(saveprefix, "sQTL_enrichment.pdf"),
                bbox_inches="tight")
    plt.close()
    with open(os.path.join(saveprefix, "sQTL_{}.csv".format(endname)), "w") as f:
        f.writelines(explaination)
        f.writelines(header+"\n")
        for line, r in zip(writelines, refvalues):
            if not (filter_nan and np.isnan(r)):
                f.writelines(line+"\n")
            else:
                print(r)
      
    merge_table_list.add(os.path.join(saveprefix, "sQTL_{}.csv".format(endname)))
            

method_dict_sqtl = {
    # color, shape, name
    "spliceai": ["#B89330", "dashed", "o", "SpliceAI", "sqtl_spliceai.txt", "neg_spliceai.txt", ],
    "pangolin": ["#F6C343", "dashdot", ">", "Pangolin", "sqtl_pangolin.txt", "neg_pangolin.txt"],
    "refsplice": ["#375392", "solid", "D", "DeltaSplice", "Mutation_data_sQTLs_data.json+RefSplice_models_model.ckpt-2.txt", "Mutation_data_sQTLs_neg_data.json+RefSplice_models_model.ckpt-2.txt"],
    "mmsplice": ["#89B0B2", "dotted", "*", "MMSplice", "sqtl_mmsplice.txt", "neg_mmsplice.txt"]
}
method_dict_sqtl_single = {
    # color, shape, name
    "spliceai": ["#B89330", "dashed", "o", "SpliceAI", "sqtl_spliceai.txt", "neg_spliceai.txt", ],
    "pangolin": ["#F6C343", "dashdot", ">", "Pangolin", "sqtl_pangolin.txt", "neg_pangolin.txt"],
    "refsplice": ["#375392", "solid", "D", "DeltaSplice", "Single_Mutation_data_sQTLs_data.json+RefSplice_models_model.ckpt-2.txt", "Single_Mutation_data_sQTLs_neg_data.json+RefSplice_models_model.ckpt-2.txt"],
    "mmsplice": ["#89B0B2", "dotted", "*", "MMSplice", "sqtl_mmsplice.txt", "neg_mmsplice.txt"]
}
plot_sqtl("experiments/4_eval_sqtls/test_results",
           method_dict_sqtl, "figures/figure6/", explaination="#predicted delta ssu; double sequence mode of deltasplice is used with experimental values from rnaseq is used; the input length for deltasplice & spliceai & pangolin is 35000; the original versions of spliceai & pangolin & mmsplice are used; for mmsplice delta logits is used.\n", filter_nan=True, endname="with_ref")
plot_sqtl("experiments/4_eval_sqtls/test_results",
           method_dict_sqtl_single, "figures/figure6/", explaination="#predicted delta ssu; single sequence mode of deltasplice is used; the input length for deltasplice & spliceai & pangolin is 35000; the original versions of spliceai & pangolin & mmsplice are used; for mmsplice delta logits is used.\n", filter_nan=False, endname="without_ref")

DeltaSplice:auprc 0.02220733477512832
SpliceAI:auprc 0.020378956325007688
Pangolin:auprc 0.021849410018677816
MMSplice:auprc 0.02730000419220341
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan

In [8]:
# plot figure 7, note that logits is used for mmsplice in this figure, as no reference values are avaiable
def write_autism(read_prefix, filename_dict, saveprefix, explaination, filter_nan=False, endname=None):
    exome_json_path, genome_json_path = "data/autism_exome/data.json", "data/autism_genome/data.json"
    exome_external_path, genome_external_path= "data/autism_exome/external_ids.txt", "data/autism_genome/external_ids.txt"
    if not os.path.exists(saveprefix):
        os.mkdir(saveprefix)
    for jsonfile, idx, dataname, external_path in zip([exome_json_path, genome_json_path], [0, 1], ["exome_", "genome_"], [exome_external_path, genome_external_path]):
        with open(external_path, "r") as f:
            content=f.readlines()
        external_dict={}
        for k in content:            
            key=k.replace("\n", "")
            assert key not in external_dict
            external_dict[key]=key
        
        ref, gt, refname = load_json(jsonfile)
    
        header = "name,external_ID, experimental ssu"
        
        ###
        assert len(ref)==len(refname)
        writelines = [_+","+external_dict["||".join(_.split("||")[:2])]+","+str(r) for _,r in zip(refname, ref)]
        for method in OrderedMethod:
            if method in filename_dict:
                color, _, _, name, f1, f2 = filename_dict[method]
                if idx==0:
                    f=f1
                else:
                    f=f2
            else:
                continue
            if method=="refsplice":
                start_usage, _, pred=load_mut_deltasplice(os.path.join(read_prefix, f))
            elif method=="mmsplice":
                pred = load_mmsplice(os.path.join(read_prefix, f))
                start_usage=None
            else:
                start_usage, _, pred = load_pred(os.path.join(read_prefix, f))
            assert len(pred) == len(refname)
            if start_usage is not None:
                header = header+",ssu before mut predicted by {},deltassu predicted by {}".format(name, name)
                for i, p, r in zip(range(len(pred)),pred, start_usage):
                    writelines[i] = writelines[i]+",{},{}".format(r, p)
            else:
                header = header+",deltassu predicted by {}".format(name, name)
                for i, p in enumerate(pred):
                    writelines[i] = writelines[i]+",{}".format(p)
            assert len(writelines) == len(pred)
           
       
        with open(os.path.join(saveprefix, "{}_{}.csv".format(dataname, endname)), "w") as f:
            f.writelines(explaination)
            f.writelines(header+"\n")
            for line, r in zip(writelines, ref):
                if not (filter_nan and np.isnan(r)):
                    f.writelines(line+"\n")
                else:
                    print(r)
        merge_table_list.add(os.path.join(saveprefix, "{}_{}.csv".format(dataname, endname)))

method_dict_autism = {
    # color, shape, name
    "spliceai": ["#B89330", "dashed", "o", "SpliceAI", "exome_spliceai.txt", "genome_spliceai.txt", ],
    "pangolin": ["#F6C343", "dashdot", ">", "Pangolin", "exome_pangolin.txt", "genome_pangolin.txt"],
    "refsplice": ["#375392", "solid", "D", "DeltaSplice", "Mutation_data_autism_exome_data.json+RefSplice_models_model.ckpt-2.txt", "Mutation_data_autism_genome_data.json+RefSplice_models_model.ckpt-2.txt"],
    "mmsplice": ["#89B0B2", "dotted", "*", "MMSplice", "exome_mmsplice.txt", "genome_mmsplice.txt"]
}
method_dict_autism_withoutref = {
    # color, shape, name
    "spliceai": ["#B89330", "dashed", "o", "SpliceAI", "exome_spliceai.txt", "genome_spliceai.txt", ],
    "pangolin": ["#F6C343", "dashdot", ">", "Pangolin", "exome_pangolin.txt", "genome_pangolin.txt"],
    "refsplice": ["#375392", "solid", "D", "DeltaSplice", "Single_Mutation_data_autism_exome_data.json+RefSplice_models_model.ckpt-2.txt", "Single_Mutation_data_autism_genome_data.json+RefSplice_models_model.ckpt-2.txt"],
    "mmsplice": ["#89B0B2", "dotted", "*", "MMSplice", "exome_mmsplice.txt", "genome_mmsplice.txt"]
}

write_autism("experiments/5_eval_autism/test_results",
          method_dict_autism, "figures/figure6/", explaination="#predicted delta ssu; double sequence mode of deltasplice is used with experimental values from rnaseq is used; the input length for deltasplice & spliceai & pangolin is 35000; the original versions of spliceai & pangolin & mmsplice are used; for mmsplice delta logits is used.\n", filter_nan=True, endname="with_ref")
write_autism("experiments/5_eval_autism/test_results",
          method_dict_autism_withoutref, "figures/figure6/", explaination="#predicted delta ssu; single sequence mode of deltasplice is used; the input length for deltasplice & spliceai & pangolin is 35000; the original versions of spliceai & pangolin & mmsplice are used; for mmsplice delta logits is used.\n", filter_nan=False, endname="without_ref")

nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan
nan


In [None]:
# compare results of different reference
ref_withlabel_gt, gt, allname_with_label=load_json("data/autism_exome/data_withlabel.json")
ref, gt, allname_without_label=load_json("data/autism_exome/data.json")
ref_withlabel,_,  pred_res_withlabel=load_mut_deltasplice("experiments/5_eval_autism/test_results/Mutation_data_autism_exome_data_withlabel.json+RefSplice_models_model.ckpt-2.txt")
ref_withoutlabel,_,  pred_res_withoutlabel=load_mut_deltasplice("experiments/5_eval_autism/test_results/Mutation_data_autism_exome_data.json+RefSplice_models_model.ckpt-2.txt")
DWithlabel={}
D={}
for a,b, c,d in zip(allname_with_label, ref_withlabel, pred_res_withlabel, ref_withlabel_gt):
    a="_".join(a.split("_")[1:]).replace("chr", "")
   
    DWithlabel[a]=[b,c,d]

for a,b ,c in zip(allname_without_label, ref_withoutlabel, pred_res_withoutlabel):
    a=a.split("||")[1]
    D[a]=[b,c]
print(D.keys())
prw,prwo=[],[]
for k in DWithlabel:
    if k in D:
        prw.append(DWithlabel[k][0])
        prwo.append(D[k][0])
plt.figure()
plt.scatter(prw, prwo)
plt.show()
    


In [10]:
# plot figure 8 
def load_info(info):
    data = pandas.read_csv(info, sep="\t")
    concentrationdict = {}
    assaydict = {}
    study_dict = {}
    alldict = {"all": [[] for _ in range(4)]}

    index = 0
    Cpsi = []
    Dpsi = []
    allnames = []
    name2savedict=[{},{},{},{}]
    for dpsi, assay, concentration, asogroup, asotype, ctrlpsi,  aso1, aso2, aso3, aso4, study in zip(data["dPSI"], data["Assay"], data["Concentration"], data["ASO.group"], data["ASO2exon.type"], data["ctrl.psi"], data['ext0_mask.seq.name'], data['ext1_mask.seq.name'], data['ext2_mask.seq.name'], data['ext3_mask.seq.name'], data["Study"]):
        # assay.replace(".endogenous", "").upper())
        cname = "{}".format(concentration)
        aname = "{}".format(assay)
        sname = "{}_{}_{}".format(study, concentration, assay)
        if cname not in concentrationdict:
            concentrationdict[cname] = [[] for _ in range(4)]
        if aname not in assaydict:
            assaydict[aname] = [[] for _ in range(4)]
        if sname not in study_dict:
            study_dict[sname] = [[] for _ in range(4)]

        for i, aso in enumerate([aso1, aso2, aso3, aso4]):
            if isinstance(aso, str):
                concentrationdict[cname][i].append(index)
                assaydict[aname][i].append(index)
                study_dict[sname][i].append(index)
                alldict["all"][i].append(index)
                index += 1
                Cpsi.append(ctrlpsi/100.)
                Dpsi.append(dpsi/100.)
                allnames.append(aso)
                
                name2savedict[i][len(allnames)-1]=len(name2savedict[i])
                

    assaydict = {key: assaydict[key]
                 for key in assaydict if len(assaydict[key][0]) > 2}
    concentrationdict = {key: concentrationdict[key] for key in concentrationdict if len(
        concentrationdict[key][0]) > 2}
    study_dict = {key: study_dict[key] for key in study_dict if len(
        study_dict[key][0]) > 2}

    return Cpsi, Dpsi, concentrationdict, assaydict, study_dict,  alldict,  allnames, name2savedict


def plot_aso_concentration(info, groupdict, endfix, cscatter, cbar, preidx, dname, read_prefix, filename_dict, save_prefix):
    ctrl_psi, global_gt, concentrationdict, assaydict, studydict, alldict, allnames,  name2savedict= info
    barwidth = 0.2
    header = "name,ctrl_psi"
    assert len(allnames)==len(ctrl_psi)
    writelines = [_+","+str(r) for _,r in zip(allnames, ctrl_psi)]

    for idx, key in enumerate(groupdict.keys()):
        barbias = -(barwidth*4)/2+barwidth
        midx = 0
        for method in OrderedMethod:
            if method in filename_dict:
                color, _, _, name, f = filename_dict[method]
            else:
                continue
            f=f[dname]
            Xs, Ys = [], []
            if method=="refsplice":
                start_usage, _, pred = load_mut_deltasplice(os.path.join(read_prefix, f))
            elif method=="mmsplice":
                pred = load_mmsplice(os.path.join(read_prefix, f))
                start_usage=None
            else:
                start_usage, _, pred = load_pred(os.path.join(read_prefix, f))
            
            if endfix == "assay" and idx == 0:
                if start_usage is not None:
                    header = header+",ssu before mut predicted by {},deltassu predicted by {}".format(name, name)
                    for i, p, r in zip(range(len(pred)),pred, start_usage):
                        writelines[i] = writelines[i]+",{},{}".format(r, p)
                else:
                    header = header+",deltassu predicted by {}".format(name, name)
                    for i, p in enumerate(pred):
                        writelines[i] = writelines[i]+",{}".format(p)
                assert len(pred) == len(writelines)
                
            assert len(pred) == len(global_gt), f"{len(pred)} {len(global_gt)}, {method}"
            for gidx, group in enumerate(groupdict[key]):
                group_gt = [global_gt[_] for _ in group]
                group_pred = [pred[_] for _ in group]
                if gidx == 0:
                    cscatter[idx+preidx][midx].scatter(group_pred, group_gt)
                    cscatter[idx+preidx][midx].set_ylabel(
                        "Experimental $\Delta$ PSI of {}".format(dname))
                    cscatter[idx+preidx][midx].set_xlabel(
                        "Predicted $\Delta$ PSI by {}".format(name))
                    cscatter[idx+preidx][midx].text(0.,
                                                    0., "{}={}".format(endfix, key))
                sr = spearmanr(group_gt, group_pred)[0]
                pr = pearsonr(group_gt, group_pred)[0]
                Xs.append(gidx+barbias)
                Ys.append(pr)
            barbias += barwidth
            cbar[(idx+preidx)//2][(idx+preidx) % 2].bar(Xs, Ys, color=color, label=name,
                                                        width=barwidth)
            midx += 1

        cbar[(idx+preidx)//2][(idx+preidx) %
                              2].set_ylabel("Pearson correlation")
        cbar[(idx+preidx)//2][(idx+preidx) % 2].legend()
        cbar[(idx+preidx)//2][(idx+preidx) % 2].set_xticks([0, 1, 2, 3])
        cbar[(idx+preidx)//2][(idx+preidx) % 2].set_xticklabels(
            ["Without extending", "Extend 1 bp", "Extend 2 bp", "Extend 3 bp"], rotation=8)
        cbar[(idx+preidx)//2][(idx+preidx) % 2].text(0.5,
                                                     0.5, "{}:{}={}".format(dname, endfix, key))

    if endfix == "assay":
        for i, D in enumerate(name2savedict):
            revD={v:k for k,v in D.items()}
            merge_table_list.add(os.path.join(save_prefix, dname+"_extend_{}.csv".format(i)))
            with open(os.path.join(save_prefix, dname+"_extend_{}.csv".format(i)), "w") as f:
                f.writelines("#predicted delta ssu; double sequence mode of deltasplice is used with experimental values as reference; the input length for deltasplice & spliceai & pangolin is 35000; the original versions of spliceai & pangolin & mmsplice are used; for mmsplice delta ssu is computed with experimental ssu before mutation.\n")
                f.writelines(header+"\n")
                for i in range(len(revD)):
                    f.writelines(writelines[revD[i]]+"\n")
            

def plot_aso(read_prefix, filename_dict, saveprefix):
    if not os.path.exists(save_prefix):
        os.mkdir(save_prefix)
    info_ikbkap=load_info("data/ASO/IKBKAP_exon20_measured.ASO_combined_info.txt")
    info_smn2=load_info("data/ASO/SMN2_exon7_measured.ASO_combined_info.txt")
    ctrl_psi_ikbkap, global_gt_ikbkap, concentrationdict_ikbkap, assaydict_ikbkap, studydict_ikbkap, alldict_ikbkap, allnames_ikbkap, _ = info_ikbkap
    ctrl_psi_smn2, global_gt_smn2, concentrationdict_smn2, assaydict_smn2, studydict_smn2, alldict_smn2, allnames_smn2, _ = info_smn2

    Ncolumn = 4
    for adict, bdict, endfix in zip([concentrationdict_ikbkap, assaydict_ikbkap, studydict_ikbkap, alldict_ikbkap], [concentrationdict_smn2, assaydict_smn2, studydict_smn2, alldict_smn2], ["concentration", "assay", "study", "all"]):
        _, cscatter = plt.subplots(len(
            adict)+len(bdict), Ncolumn, figsize=[24, 4*(len(adict)+len(bdict))])
        _, cbar = plt.subplots((len(adict)+len(bdict)+1)//2, 2,
                               figsize=[24, 4*(len(adict)+len(bdict)+1)//2])
        # print(cscatter, cbar)
        if len(adict)+len(bdict) <= 2:
            cbar = [cbar]
        plot_aso_concentration(info_ikbkap, adict, endfix,
                               cscatter, cbar, 0, "IKBKAP", read_prefix, filename_dict, saveprefix)
        plot_aso_concentration(info_smn2, bdict, endfix,
                               cscatter, cbar, len(adict), "SMN2", read_prefix, filename_dict, saveprefix)
        plt.savefig(os.path.join(saveprefix, "{}_bar.pdf".format(
            endfix)), bbox_inches='tight', dpi=DPI)
        plt.close()
        plt.savefig(os.path.join(saveprefix, "{}_scatter.pdf".format(
            endfix)), bbox_inches='tight', dpi=DPI)
        plt.close()


method_dict_aso = {
    # color, shape, name
    "spliceai": ["#B89330", "dashed", "o", "SpliceAI", {"IKBKAP": "IKBKAP_spliceai_withlabel.txt", "SMN2": "SMN2_spliceai_withlabel.txt"}],
    "pangolin": ["#F6C343", "dashdot", ">", "Pangolin", {"IKBKAP": "IKBKAP_pangolin_withlabel.txt", "SMN2": "SMN2_pangolin_withlabel.txt"}],
    "refsplice": ["#375392", "solid", "D", "DeltaSplice", {"IKBKAP":"Mutation_data_IKBKAP_data_withlabel.json+RefSplice_models_model.ckpt-2.txt", "SMN2":"Mutation_data_SMN2_data_withlabel.json+RefSplice_models_model.ckpt-2.txt"}],
    "mmsplice": ["#89B0B2", "dotted", "*", "MMSplice", {"IKBKAP": "IKBKAP_mmsplice_withlabel.txt", "SMN2": "SMN2_mmsplice_withlabel.txt"}]
}
plot_aso("experiments/7_eval_aso/test_results",
            method_dict_aso, "figures/figure8/")



In [11]:
# summary all aso predictions 

def write_aso_all(read_prefix, filename_dict, saveprefix):
    IKBKAP_json_path, SMN2_json_path = "data/IKBKAP/data.json", "data/SMN2/data.json"
    if not os.path.exists(saveprefix):
        os.mkdir(saveprefix)
    for jsonfile, idx, dataname in zip([IKBKAP_json_path, SMN2_json_path], [0, 1], ["INBKAP", "SMN2"]):
        ref, gt, refname= load_json(jsonfile)

        header = "name"
        ###
        writelines = [_ for _ in refname]
        for method in OrderedMethod:
            if method in filename_dict:
                color, _, _, name, f1, f2 = filename_dict[method]
                if idx == 0:
                    f = f1
                else:
                    f = f2
            else:
                continue
            if method == "refsplice":
                start_usage, _, pred = load_mut_deltasplice(
                    os.path.join(read_prefix, f))
            elif method == "mmsplice":
                pred = load_mmsplice(os.path.join(read_prefix, f))
                start_usage=None
            else:
                start_usage, _, pred = load_pred(os.path.join(read_prefix, f))
            assert len(pred) == len(refname)
            if start_usage is not None:
                assert len(start_usage) == len(writelines)
                header = header+",ssu before mut predicted by {},deltassu predicted by {}".format(name, name)
                for i, p, r in zip(range(len(pred)),pred, start_usage):
                    writelines[i] = writelines[i]+",{},{}".format(r, p)
            else:
                header = header+",deltassu predicted by {}".format(name, name)
                for i, p in enumerate(pred):
                    writelines[i] = writelines[i]+",{}".format(p)
            assert len(writelines) == len(pred)
           
        
   
        with open(os.path.join(saveprefix, "{}_all_prediction.csv".format(dataname)), "w") as f:
            f.writelines("#predicted delta ssu; double sequence mode of deltasplice is used with experimental data as reference; the input length for deltasplice & spliceai & pangolin is 35000; the original versions of spliceai & pangolin & mmsplice are used; for mmsplice delta ssu is computed with reference values.\n")
            f.writelines(header+"\n")
            for line in writelines:
                f.writelines(line+"\n")
        merge_table_list.add(os.path.join(
            saveprefix, "{}_all_prediction.csv".format(dataname)))


method_dict_aso = {
    # color, shape, name
    "spliceai": ["#B89330", "dashed", "o", "SpliceAI",  "IKBKAP_spliceai.txt", "SMN2_spliceai.txt"],
    "pangolin": ["#F6C343", "dashdot", ">", "Pangolin","IKBKAP_pangolin.txt", "SMN2_pangolin.txt"],
    "refsplice": ["#375392", "solid", "D", "DeltaSplice", "Mutation_data_IKBKAP_data.json+RefSplice_models_model.ckpt-2.txt",  "Mutation_data_SMN2_data.json+RefSplice_models_model.ckpt-2.txt"],
    "mmsplice": ["#89B0B2", "dotted", "*", "MMSplice", "IKBKAP_mmsplice.txt", "SMN2_mmsplice.txt"]
}
write_aso_all("experiments/7_eval_aso/test_results",
         method_dict_aso, "figures/figure8/")

In [12]:
# merge all table to one
merge_file_name="./experiments/10_merged_tables.xlsx"
writer=pandas.ExcelWriter(merge_file_name)
for idx, file in enumerate(merge_table_list):
    d=pandas.read_csv(file, dtype=str)
    print(file, len(d))
    d.to_excel(writer, file.split("/")[-1].replace(" ", "").replace(".csv", ""))

writer.close()
    

figures/figure8/SMN2_extend_0.csv 246
figures/figure3/Spearman_bar_res_of_ssu.csv 121
figures/figure8/SMN2_extend_3.csv 215
figures/figure8/IKBKAP_extend_2.csv 114
figures/figure8/IKBKAP_extend_0.csv 114
figures/figure6/genome__without_ref.csv 83923
figures/figure6/sQTL_with_ref.csv 408854
figures/figure4/Fas.csv 3072
figures/figure8/SMN2_all_prediction.csv 83862
figures/figure2/Pearson correlation.csv 33
figures/figure6/sQTL_without_ref.csv 589989
figures/figure3/Spearman_bar_res_of_deltassu.csv 121
figures/figure2/Spearman correlation.csv 33
figures/figure3/Pearson_bar_res_of_ssu.csv 121
figures/figure8/SMN2_extend_1.csv 246
figures/figure6/exome__with_ref.csv 15384
figures/figure8/IKBKAP_extend_3.csv 104
figures/figure6/exome__without_ref.csv 20550
figures/figure8/IKBKAP_extend_1.csv 114
figures/figure3/Pearson_bar_res_of_deltassu.csv 121
figures/figure8/SMN2_extend_2.csv 246
figures/figure4/VexSeq.csv 1961
figures/figure8/INBKAP_all_prediction.csv 83862
figures/figure5/MFASS.csv 27