In [1]:
import plotly.graph_objects as go
import plotly.io as pio
import plotly.express as px
from plotly.subplots import make_subplots

In [15]:
from os import path
pio.orca.config.use_xvfb=True

def snap(fig, filename, figuredir = '../figures') :
    fig.update_layout(width=700, height=300,
                      margin = dict(l=1, r=1, t=20, b=20))
    try :
        fig.write_image(path.join(figuredir, filename))
    except:
        pio.orca.shutdown_server()
        fig.write_image(path.join(figuredir, filename))
    return fig

In [2]:
import pandas as pd
import numpy as np
import pickle
from glob import glob
from os.path import basename

from tqdm.notebook import tqdm

from sklearn.model_selection import cross_validate, LeaveOneGroupOut
from sklearn.metrics import precision_recall_curve, auc

import os
import json

os.environ["CUDA_VISIBLE_DEVICES"]="2"

from tensorflow import keras

import xarray as xa

from os import path

In [66]:
%run ../scripts/utils_keras.py

# Loading model parameters

In [4]:
with open('../models/models.json', 'r') as f:
    model_params = json.load(f)

# Loading model scores

In [10]:
model_path = f'../models/Apis_mellifera/exome/'

scores = dict()
scores_std = dict()

for file in glob(f'{model_path}/r*_onehot_convo_*_scores.pkl') :
    if file.find('negative')>0 or file.find('methylome')>0:
        continue

    radius, intyp, topo, i= file[len(model_path):-len('_scores.pkl')].split('_')

    radius = int(radius[1:])
    modelno = int(i)
    modelname = '_'.join([intyp, topo, i])

    if modelname not in scores :
            scores[modelname] = dict()
            scores_std[modelname] = dict()

    with open(file, 'rb') as f :
        s = np.array(pickle.load(f))
        scores[modelname][radius] = s.mean()
        scores_std[modelname][radius] = s.std()
        

scores = pd.DataFrame(model_params).set_index('name').join(pd.DataFrame(scores).T)

scores_top10 = scores.sort_values(500, ascending=False).head(10)
scores_top10_std = pd.DataFrame(scores_std).T.round(2).reindex(scores_top10.index)
scores_top10.round(2).join(scores_top10_std, rsuffix='_std')

Unnamed: 0_level_0,conv_nfilters,conv_ksize,conv_psize,dense_units,dropout,300,500,300_std,500_std
name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
onehot_convo_19,128,16,32,128,0.0,0.92,0.92,0.02,0.02
onehot_convo_6,128,8,16,128,0.25,0.92,0.92,0.02,0.02
onehot_convo_23,128,16,32,128,0.5,0.92,0.92,0.02,0.02
onehot_convo_18,256,8,32,32,0.0,0.91,0.92,0.02,0.02
onehot_convo_12,64,8,32,32,0.0,0.9,0.92,0.02,0.02
onehot_convo_11,256,16,16,128,0.5,0.91,0.92,0.02,0.02
onehot_convo_1,256,32,32,64,0.5,0.91,0.92,0.02,0.02
onehot_convo_2,256,32,32,64,0.25,0.91,0.92,0.02,0.02
onehot_convo_17,256,8,8,128,0.25,0.91,0.92,0.03,0.02
onehot_convo_22,64,32,16,64,0.0,0.92,0.92,0.02,0.02


# Loading data

In [None]:
X_exome,y_exome,seqid_exome = get_data(f'../data/Apis_mellifera/exome/r500_onehot.cdf')

# Testing top 10 models on genome

In [None]:
X_genome,y_genome,seqid_genome=get_data(f'../data/Apis_mellifera/genome/r500_onehot.cdf')

scores_genome = dict()
r=500

for model_name in scores_top10.index : 
    model_path = f'../models/Apis_mellifera/exon/r{r}_{model_name}',
    scores_genome[model_name] = test_cv(model_path, (X_genome,y_genome,seqid_genome))

Apis mellifera
genome/r500_onehot.cdf: size: 3057492, pos: 70592, positive_ratio: 0.02, dummy_mean: 0.51, dummy_std: 0.00
Bombus terrestris


In [None]:
X_genome_300,y_genome_300,seqid_genome_300=get_data(f'../data/Apis_mellifera/genome/r500_onehot.cdf', radius=300)

scores_genome_300 = dict()
r=300

for model_name in scores_top10.index : 
    model_path = f'../models/Apis_mellifera/exon/r{r}_{model_name}',
    scores_genome[model_name] = test_cv(model_path, (X_genome,y_genome,seqid_genome))

Apis mellifera
genome/r500_onehot.cdf: size: 3057492, pos: 70592, positive_ratio: 0.02, dummy_mean: 0.51, dummy_std: 0.00
Bombus terrestris


In [320]:
scores_top10 = scores_top10.join(pd.Series(scores_genome_300, name='300_genome'))

In [321]:
scores_top10 = scores_top10.join(pd.Series(scores_genome, name='500_genome'))

In [5]:
scores_top10.sort_values('500_genome', ascending=False, inplace=True)
scores_top10.round(2)

NameError: name 'scores_top10' is not defined

In [316]:
scores_top10.to_latex('top10_hyperparameters.tex', 
                      index=False, 
                      float_format="%.2f", 
                      caption='Top 5 hyperparameters and resulting scores')

In [197]:
#with open('../models/models_top5.json', 'w') as f: 
    #json.dump([params for params in model_params if params['name'] in scores_top10.head(5).index],f, indent=1)

# Testing the models trained on whole genome 

In [331]:
for r in [300,500] :

    for model_name in ['onehot_convo_19','onehot_convo_22'] :
        
        try :
            
            with open(f'../models/Apis_mellifera/genome/r{r}_{model_name}_scores.pkl', 'rb') as f :
                                s = np.array(pickle.load(f))
                                score_genome = s.mean()
                        
            score_exome = test_cv(r, model_name, 
                        f'../models/Apis_mellifera/genome', 
                        f'../data/Apis_mellifera/exon/r500_onehot.cdf')
            
            print( r, model_name, round(score_genome,2), round(score_exome, 2))

        except:
            pass

398244 samples.
358084 samples after filtering.
56102.0 positive samples (15.7%).


HBox(children=(FloatProgress(value=0.0, description='onehot_convo_19', max=16.0, style=ProgressStyle(descripti…


300 onehot_convo_19 0.83 0.91
398244 samples.
358084 samples after filtering.
56102.0 positive samples (15.7%).


HBox(children=(FloatProgress(value=0.0, description='onehot_convo_22', max=16.0, style=ProgressStyle(descripti…


300 onehot_convo_22 0.82 0.91
398244 samples.
358084 samples after filtering.
56102.0 positive samples (15.7%).


HBox(children=(FloatProgress(value=0.0, description='onehot_convo_19', max=16.0, style=ProgressStyle(descripti…


500 onehot_convo_19 0.85 0.93
398244 samples.
358084 samples after filtering.
56102.0 positive samples (15.7%).


HBox(children=(FloatProgress(value=0.0, description='onehot_convo_22', max=16.0, style=ProgressStyle(descripti…


500 onehot_convo_22 0.85 0.92


# Testing the dynamic methylome

In [18]:
model_name = 'onehot_convo_19' 
model_path = f'../models/Apis_mellifera/genome/r500_{model_name}'

X,y,seqid = get_data('../data/Apis_mellifera/exome/r500_onehot.cdf', mode='negative')

exome/r500_onehot.cdf: size: 341730, pos: 35634, baseline: 0.10


In [19]:
scores = test_cv(model_path, (X,y,seqid))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

mean: 0.60 std:0.06


In [20]:
X,y,seqid = get_data('../data/Apis_mellifera/genome/r500_onehot.cdf', mode='negative')

genome/r500_onehot.cdf: size: 3047948, pos: 53436, baseline: 0.02


In [21]:
scores = test_cv(model_path, (X,y,seqid))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

mean: 0.45 std:0.07


This is a model specifically trained for the negative task on the exons.

In [None]:
X,y,seqid = get_data('../data/Apis_mellifera/exome/r500_onehot.cdf', mode='negative')

In [57]:
model_name = 'onehot_convo_19' 
model_path = f'../models/Apis_mellifera/exome/r500_{model_name}_negative'

scores = test_cv(model_path, (X,y,seqid))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

mean: 0.61 std:0.06


## Precision-recall curves for exome and genome

We compare these to the logistic case

In [35]:
def test_pr(model_path, data, npoints=1000, batch_size=1024 ):
    
    X,y,seqid=data
    
    interp_threshold=np.linspace(0, 1, npoints)
    
    precisions, recalls, scores = [],[],[]

    for path in tqdm(list(glob(f'{model_path}_NC*.h5')), leave=False) :

        test_seqid = path[len(f'{model_path}_'):-3 ]

        model = keras.models.load_model(path)

        test = seqid == test_seqid
        X_test = X[test]
        y_test = y[test]

        y_pred = model.predict(X_test, batch_size=batch_size)

        precision, recall, threshold = precision_recall_curve(y_test, y_pred)

        interp_precision = np.interp(interp_threshold, threshold, precision[:-1])
        interp_recall = np.interp(interp_threshold, threshold, recall[:-1])

        recalls.append(interp_recall)
        precisions.append(interp_precision)
        
        scores.append(auc(recall, precision))

    precision_mean = np.array(precisions).mean(axis=0)
    precision_std  = np.array(precisions).std(axis=0)
    recall_mean  = np.array(recalls).mean(axis=0)
    recall_std   = np.array(recalls).std(axis=0)
    
    scores = np.array(scores)
    print(f'mean: {scores.mean():.2f} std:{scores.std():.2f}')
    
    return precision_mean, recall_mean, precision_std, recall_std

In [8]:
model_name = 'onehot_convo_19' 
model_path = f'../models/Apis_mellifera/genome/r500_{model_name}'

In [36]:
# For exome

X,y,seqid = get_data('../data/Apis_mellifera/exome/r500_onehot.cdf')

exome/r500_onehot.cdf: size: 358084, pos: 56102, baseline: 0.16


In [37]:
precision_exome, recall_exome, precision_std_exome, recall_std_exome = test_pr(model_path, (X,y,seqid))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

mean: 0.93 std:0.01


In [38]:
# For genome

X,y,seqid = get_data('../data/Apis_mellifera/genome/r500_onehot.cdf')

genome/r500_onehot.cdf: size: 3057492, pos: 70592, baseline: 0.02


In [39]:
precision_genome, recall_genome, precision_std_genome, recall_std_genome = test_pr(model_path, (X,y,seqid))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

mean: 0.86 std:0.04


In [27]:
markers = { s:dict(color=c) for s, c in zip (['convo', 'logit', 'baseline'], px.colors.qualitative.Set1) }

trace1 = go.Scatter(x=recall_exome, y=precision_exome, 
                   error_x=dict(array=recall_std_exome, type='data'), 
                   error_y=dict(array=precision_std_exome, type='data'),
                   marker=markers['convo'],
                   name ='convolutional')

trace2= go.Scatter(x=recall_genome, y=precision_genome, 
                   error_x=dict(array=recall_std_genome, type='data'), 
                   error_y=dict(array=precision_std_genome, type='data'),
                   marker=markers['convo'],
                   showlegend=False,
                   name = 'convolutional')

convo_pr=dict(exome=trace1, genome=trace2)

with open('convo_pr.pkl', 'wb') as f :
    pickle.dump(dict(exome=trace1, genome=trace2), f)

In [34]:
fig = make_subplots(rows = 1, cols=2, shared_yaxes=True, subplot_titles=['Test on exome', 'Test on genome'])

with open('logistic_pr.pkl', 'rb') as f :
    logistic_pr = pickle.load(f)

logistic_pr['exome'].update(name='logistic', marker=markers['logit'], showlegend =False)
logistic_pr['genome'].update(name='logistic', marker=markers['logit'])

fig.add_traces([logistic_pr['exome'], convo_pr['exome'], 
                logistic_pr['genome'], convo_pr['genome'] ], 
               rows=(1,1,1,1), cols = (1,1,2,2))

fig.update_traces(opacity=0.7)

fig.add_trace(go.Scatter(name='baseline', x=(0,1), y=(0.16, 0.16), marker = markers['baseline'], mode='lines', line_dash ='dash'), row=1, col=1 )
fig.add_trace(go.Scatter(showlegend=False, x=(0,1), y=(0.02, 0.02), marker = markers['baseline'], mode='lines', line_dash ='dash'), row=1, col=2 )

fig.update_xaxes(range=(0, 1), title ='Recall')
fig.update_yaxes(range=(0, 1))
fig.update_yaxes(title='Precision', row=1, col=1)


snap(fig, 'precision-recall.pdf')

# Precision-recall curves on the negative task

In [102]:
model_name = 'onehot_convo_19' 
model_path = f'../models/Apis_mellifera/genome/r500_{model_name}'

In [104]:
X,y,seqid = get_data('../data/Apis_mellifera/genome/r500_onehot.cdf', mode='negative')

genome/r500_onehot.cdf: size: 3047948, pos: 53436, baseline: 0.02


In [105]:
precision, recall, precision_std, recall_std = test_pr(model_path, (X,y,seqid))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

mean: 0.46 std:0.07


In [106]:
trace_neg  = go.Scatter(x=recall, y=precision, 
                   error_x=dict(array=recall_std, type='data'), 
                   error_y=dict(array=precision_std, type='data'),
                   #marker=markers['convo'],
                   name ='dynamic')

In [107]:
X,y,seqid = get_data('../data/Apis_mellifera/genome/r500_onehot.cdf', mode='normal')

genome/r500_onehot.cdf: size: 3057492, pos: 70592, baseline: 0.02


In [108]:
precision, recall, precision_std, recall_std = test_pr(model_path, (X,y,seqid))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

mean: 0.86 std:0.04


In [109]:
trace_pos  = go.Scatter(x=recall, y=precision, 
                   error_x=dict(array=recall_std, type='data'), 
                   error_y=dict(array=precision_std, type='data'),
                   #marker=markers['convo'],
                   name ='robust')

In [111]:
fig=go.Figure([trace_neg, trace_pos])

fig.add_trace(go.Scatter(name='baseline', x=(0,1), y=(0.02, 0.02), mode='lines', line_dash ='dash'))
#fig.add_trace(go.Scatter(name='baseline', x=(0,1), y=(0.16, 0.16), mode='lines', line_dash ='dash'))


fig.update_xaxes(range=(0, 1), title ='Recall')
fig.update_yaxes(range=(0, 1), title ='Precision')

In [67]:
def predict_cv(model_path, data, batch_size=128, verbose=False ) :
    X,y,seqid = data
    
    y_value, y_pred = [],[]
    
    for path in tqdm(list(glob(f'{model_path}_NC*.h5')), leave=False) :
        
        test_seqid = path[len(f'{model_path}_'):-3 ]
               
        model = keras.models.load_model(path)
        
        test = seqid == test_seqid
        X_test = X[test]
        
        y_value.append(y[test])
    
        y_pred.append( model.predict(X_test, batch_size=batch_size, verbose=verbose))
        
    return y_value, y_pred      

In [69]:
X,y, seqid = get_data('../data/Apis_mellifera/exome/r500_onehot.cdf', mode='continuous')

exome/r500_onehot.cdf: size: 92144, pos: 65368.2421875, baseline: 0.71


In [70]:
y_val, y_pred = predict_cv('../models/Apis_mellifera/exome/r500_onehot_convo_19', (X,y,seqid))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

In [74]:
y_val = np.concatenate(y_val)
y_pred = np.concatenate(y_pred)


In [81]:
y_pred = y_pred.flatten()

In [91]:
go.Figure(go.Histogram2d(x=y_val, y=y_pred, histnorm='probability', zmin=0.0, zmax=0.075))

# Precision-recall curves on methylome

In [58]:
X,y,seqid = get_data('../data/Apis_mellifera/methylome/r500_onehot.cdf', mode='methylome', threshold=0.75)

methylome/r500_onehot.cdf: size: 139614, pos: 71886, baseline: 0.51


In [59]:
model_name = 'onehot_convo_19' 
model_path = f'../models/Apis_mellifera/methylome/r500_{model_name}'

precision_met, recall_met, precision_std_met, recall_std_met = test_pr(model_path, (X,y,seqid))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

mean: 0.82 std:0.02


In [60]:
trace  = go.Scatter(x=recall_met, y=precision_met, 
                   error_x=dict(array=recall_std_met, type='data'), 
                   error_y=dict(array=precision_std_met, type='data'),
                   #marker=markers['convo'],
                   name ='convolutional')

fig=go.Figure(trace)

fig.add_trace(go.Scatter(name='baseline', x=(0,1), y=(0.51, 0.51), mode='lines', line_dash ='dash'))

fig.update_xaxes(range=(0, 1), title ='Recall')
fig.update_yaxes(range=(0, 1), title ='Precision')



In [61]:
model_name = 'onehot_convo_19' 
model_path = f'../models/Apis_mellifera/genome/r500_{model_name}'

precision_met, recall_met, precision_std_met, recall_std_met = test_pr(model_path, (X,y,seqid))

HBox(children=(FloatProgress(value=0.0, max=16.0), HTML(value='')))

mean: 0.80 std:0.02
