In [1]:
# Load the variables specific to this dataset.
# On brahma, SNPs have been split according to the deep-learning CV split.

# Per-fold SNP tables: <snp_prefix>/<all_snps_basename>.<fold>
snp_prefix="/srv/scratch/annashch/gecco/manuscript_loci/splits"
# Per-fold trained models: <model_prefix>.<fold>
model_prefix="/srv/scratch/annashch/deeplearning/gecco/crossvalid/v4/gecco.classification.SummitWithin200bpCenter"
test_set_prefix="/srv/scratch/annashch/deeplearning/gecco/crossvalid/v4/predictions"
n_folds=10
all_snps_basename="AnnotationList.labeled.collapsed.txt"
# Index into model.layers whose output load_keras_model exposes (-2 = second-to-last layer).
target_layer_idx=-2
# Model output tasks, in output-column order (task_idx i <-> tasks[i]).
tasks=['DNASEC','DNASEV','SW480','HCT116','COLO205']
# Derived from `tasks` so the count can never drift out of sync with the task list.
num_tasks=len(tasks)
outf_name="SNP_effect_predictions.txt"
# NOTE(review): not referenced by any cell visible here; the plotting loop writes SVGs to the CWD.
interpretation_prefix="/srv/scratch/annashch/gecco/manuscript_loci/interpretation"
ref_fasta="/mnt/data/annotations/by_release/hg19.GRCh37/hg19.genome.fa"

In [2]:
#load the dependencies
%reload_ext autoreload
%autoreload 2
%matplotlib inline

import warnings
warnings.filterwarnings('ignore')
import pandas as pd 
import pickle 
from abstention.calibration import PlattScaling, IsotonicRegression
from keras.models import load_model, Model 

from kerasAC.metrics import recall, specificity, fpr, fnr, precision, f1
from kerasAC.custom_losses import get_ambig_binary_crossentropy
from kerasAC.generators import * 
from kerasAC.predict import get_model_layer_functor, get_layer_outputs


Using TensorFlow backend.


In [3]:
#dragonn imports
from dragonn.interpret import *
from dragonn.vis import * 

In [4]:
def load_keras_model(fold,model_prefix,target_layer_idx):
    """Load one CV fold's trained Keras model plus a truncated view of it.

    The model is read from ``<model_prefix>.<fold>``. The kerasAC metric
    functions and the loss built by ``get_ambig_binary_crossentropy()`` are
    registered as custom objects so ``load_model`` can resolve them by name.

    Args:
        fold: CV fold index, appended to ``model_prefix`` with a ".".
        model_prefix: path prefix shared by the saved per-fold models.
        target_layer_idx: index into ``model.layers`` whose output the second
            model exposes (negative values count from the last layer).

    Returns:
        (model, preact_model): the full model, and a ``Model`` sharing its input
        whose output is ``model.layers[target_layer_idx].output`` (referred to
        as "preacts" elsewhere in this notebook).
    """
    ambig_loss = get_ambig_binary_crossentropy()
    # "sensitivity" is an alias for the same recall metric.
    custom_objects = dict(recall=recall,
                          sensitivity=recall,
                          specificity=specificity,
                          fpr=fpr,
                          fnr=fnr,
                          precision=precision,
                          f1=f1,
                          ambig_binary_crossentropy=ambig_loss)

    model_path = ".".join([model_prefix, str(fold)])
    model = load_model(model_path, custom_objects=custom_objects)
    print("loaded model")

    target_output = model.layers[target_layer_idx].output
    preact_model = Model(inputs=model.input, outputs=target_output)
    print("loaded preact model")
    return model, preact_model

In [5]:
def get_snp_generators(fold,snp_prefix,all_snps_basename,ref_fasta):
    """Load one CV fold's SNP table and build ref- and alt-allele generators.

    The SNP file is ``<snp_prefix>/<all_snps_basename>.<fold>``, read as a
    tab-separated table whose first line is the header.

    Args:
        fold: CV fold index, appended to the file name with a ".".
        snp_prefix: directory containing the per-fold SNP files.
        all_snps_basename: SNP file name without the fold suffix.
        ref_fasta: reference genome FASTA path handed to SNPGenerator.

    Returns:
        (snps, snp_ref_generator, snp_alt_generator): the SNP DataFrame and two
        SNPGenerator instances (batch_size=1) over the same file, built with
        allele_col="Ref" and allele_col="Alt" respectively.

    Raises:
        ValueError: the SNP file has a header but no data rows.
        FileNotFoundError, pandas.errors.EmptyDataError: propagated from
            ``pd.read_csv`` when the file is missing or completely empty.
    """
    snp_file='/'.join([snp_prefix,all_snps_basename+"."+str(fold)])
    snps=pd.read_csv(snp_file,header=0,sep='\t')
    if snps.shape[0]==0:
        # Raise instead of returning a sentinel: a short sentinel tuple cannot be
        # unpacked into the three names callers use, so the failure surfaced as a
        # confusing unpacking error far from its cause.
        raise ValueError("no SNPs in "+snp_file)
    snp_ref_generator=SNPGenerator(snp_file,
                                   batch_size=1,
                                   ref_fasta=ref_fasta,
                                   allele_col="Ref")
    snp_alt_generator=SNPGenerator(snp_file,
                                   batch_size=1,
                                   ref_fasta=ref_fasta,
                                   allele_col="Alt")
    return snps, snp_ref_generator,snp_alt_generator

In [None]:
# For every CV fold: interpret the ref and alt allele of each SNP with
# multi_method_interpret (the logs show ISM, input-gradient and DeepLIFT scores)
# for each task, then save one ref-vs-alt SVG per SNP.

# Plot window and SNP position (in input-sequence coordinates) passed to
# plot_snp_interpretation. NOTE(review): assumes every generator sequence spans
# at least 600 bp with the SNP at position 501; confirm against SNPGenerator.
plot_xlim=(400,600)
plot_snp_pos=501

for fold in range(n_folds):
    try:
        snps, snp_ref_generator, snp_alt_generator = get_snp_generators(fold,
                                                                        snp_prefix,
                                                                        all_snps_basename,
                                                                        ref_fasta)
    except (ValueError, FileNotFoundError, pd.errors.EmptyDataError) as err:
        # Skip only folds with no usable SNP table. Anything else (bad FASTA path,
        # generator bugs, KeyboardInterrupt) now propagates instead of being
        # swallowed by a bare `except:` and mislabelled as an empty fold.
        print("fold:"+str(fold)+" has no usable SNPs, skipping ("+str(err)+")")
        continue
    print("Got snp ref and alt generators")
    # multi_method_interpret is given the model *path*, not a loaded model.
    model_string=".".join([model_prefix,str(fold)])

    num_snps=len(snp_ref_generator)
    if num_snps!=snps.shape[0]:
        # Row i of `snps` labels generator item i below; a length mismatch (e.g. if
        # the generator's chromosome filtering dropped rows) means labels may be wrong.
        print("WARNING fold:"+str(fold)+": generator has "+str(num_snps)
              +" entries but SNP table has "+str(snps.shape[0])+" rows")
    for i in range(num_snps):
        cur_ref_entry=snp_ref_generator[i]
        cur_alt_entry=snp_alt_generator[i]
        # Positional lookup, so the label matches generator index i whatever the frame's index.
        snp_row=snps.iloc[i]
        rsid=snp_row['Rsid']
        ref_allele=snp_row['Ref']
        alt_allele=snp_row['Alt']
        ref_interpretations=[]
        alt_interpretations=[]
        titles=[]
        for task_idx, task in enumerate(tasks):
            ref_interpretations.append(multi_method_interpret(model_string,cur_ref_entry,task_idx,generate_plots=False))
            alt_interpretations.append(multi_method_interpret(model_string,cur_alt_entry,task_idx,generate_plots=False))
            title='_'.join([rsid,ref_allele,alt_allele,task])
            titles.append(title)
            print("done:"+title)
        # NOTE(review): written to the current working directory; the config cell
        # defines interpretation_prefix but nothing uses it. Confirm intended location.
        out_fname_svg='.'.join([rsid,ref_allele,alt_allele,"svg"])
        plot_snp_interpretation(ref_interpretations,
                                alt_interpretations,
                                cur_ref_entry,
                                cur_alt_entry,
                                title=titles,
                                xlim=plot_xlim,
                                snp_pos=plot_snp_pos,
                                out_fname_svg=out_fname_svg)
        print("plotted!")



loaded labels
filtered on chroms_to_use
data.shape:(7, 8)
loaded labels
filtered on chroms_to_use
data.shape:(7, 8)
Got snp ref and alt generators
generator idx:0
generator idx:0
getting 'ism' value
ISM: task:0 sample:0
getting 'input_grad' value
getting 'deeplift' value
getting 'ism' value
ISM: task:0 sample:0
getting 'input_grad' value
getting 'deeplift' value
done:rs599134_C_G_DNASEC
getting 'ism' value
ISM: task:1 sample:0
getting 'input_grad' value
getting 'deeplift' value
getting 'ism' value
ISM: task:1 sample:0
getting 'input_grad' value
getting 'deeplift' value
done:rs599134_C_G_DNASEV
getting 'ism' value
ISM: task:2 sample:0
getting 'input_grad' value
getting 'deeplift' value
getting 'ism' value
ISM: task:2 sample:0
getting 'input_grad' value
getting 'deeplift' value
done:rs599134_C_G_SW480
getting 'ism' value
ISM: task:3 sample:0
getting 'input_grad' value
getting 'deeplift' value
getting 'ism' value
ISM: task:3 sample:0
getting 'input_grad' value
getting 'deeplift' value
don

KeyboardInterrupt: 