In [1]:
import os

# Set before TensorFlow is imported so TF allocates GPU memory on demand rather than
# reserving it up front (this process also runs torch models on the same GPU).
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

import picrispr
from picrispr import TrainResult
from load_data import normaliseCF
from encoding import oneHotSingleNuclTargetMismatchType, oneHotSingleNuclTargetMismatch, oneHotSingleNucl, FeatureEncoding
from models import mySequential, vecToMatEncoder, vecToMatEncoding
import pickle
import torch
import tensorflow as tf
from tensorflow import keras
from sklearn.metrics import precision_recall_curve, roc_curve, auc, roc_auc_score, confusion_matrix
from scipy.stats import spearmanr, pearsonr
import numpy as np
import pandas as pd
import xgboost as xgb
# No GUI backend is forced here: figures render inline (Jupyter's default backend).
# A desktop backend such as TkAgg fails on headless machines.
import matplotlib
import matplotlib.pyplot as plt

In [2]:
# Directory holding the trained models, and the zipped off-target data CSV.
home, csv_path = "models", "offtarget_260520_nuc.csv.zip"

# True: load a previously saved data matrix from disk; False: rebuild it from the CSV.
do_load = False
# True: DeepCRISPR-style testing scenario; False: 80/20 crisprSQL testing scenario.
compare_deepcrispr = True
# Load the trained models onto "gpu:0" (True) or the CPU (False).
use_gpu = True

In [None]:
# Model registry: entry i of `filenames`, `modelnames`, `args` and `kwargs` all describe the same model.
filenames = ["xgboost_interface_type_epi", "torch_eng", "torch_engnuc", "torch_interface_type_nuc", "tf_eng", "tf_engnuc"]
modelnames =["XGB_S3E2", "CNN_S2E0", "CNN_S4E0", "CNN_S5E2", "RNN_S2E3", "RNN_S4E3"]

# Feature groups, written as comma-separated database field names.
seq_features = "target_sequence, grna_target_sequence, "
energy_features = "energy_4*(energy_3/energy_2), energy_2-energy_4*(energy_3/energy_2), energy_1-energy_4*(energy_3/energy_2), "
epigen_features = "epigen_ctcf, epigen_dnase, epigen_rrbs, epigen_h3k4me3, epigen_drip, "
nuc_features = "StrongWeakBDM, WSScore, NuPoP_Occup_147_human, "
numBpWise = 3
regression = True

# Every feature set ends with experiment_id. The evaluation loop drops that last column before predicting.
seq_energy_feat         = seq_features + energy_features + "experiment_id"
seq_energy_epi_feat     = seq_features + energy_features + epigen_features + "experiment_id"
seq_energy_nuc_feat     = seq_features + nuc_features + energy_features + "experiment_id"
seq_energy_epi_nuc_feat = seq_features + nuc_features + energy_features + epigen_features + "experiment_id"

# Per-model tuples. Fields read by the evaluation loop:
#   [2] framework ("xgboost" / "torch" / "tf"), [3] mismatchType, [4] regression,
#   [5] feature-set suffix of the saved file name, [6] interfaceMode.
# NOTE(review): fields [0] and [1] are not read in this notebook.
args = [("", home, "xgboost", True,  regression, "epi",    True),  # XGB_S3E2
        ("", home, "torch",   False, regression, "eng",    False), # CNN_S2E0
        ("", home, "torch",   False, regression, "engnuc", False), # CNN_S4E0
        ("", home, "torch",   True,  regression, "nuc",    True),  # CNN_S5E2
        ("", home, "tf",      False, regression, "eng",    False), # RNN_S2E3
        ("", home, "tf",      False, regression, "engnuc", False)  # RNN_S4E3
        ]

# Per-model data-matrix options: database fields and number of base-pair-wise (nucleotide) features.
kwargs = [{'dbFields': seq_energy_epi_feat,    'numBpWise': 0},
          {'dbFields': seq_energy_feat,        'numBpWise': 0},
          {'dbFields': seq_energy_nuc_feat,    'numBpWise': numBpWise},
          {'dbFields': seq_energy_epi_nuc_feat,'numBpWise': numBpWise},
          {'dbFields': seq_energy_feat,        'numBpWise': 0},
          {'dbFields': seq_energy_nuc_feat,    'numBpWise': numBpWise}
          ]

# The registry lists are indexed in parallel; catch a missing/extra entry early.
assert len(filenames) == len(modelnames) == len(args) == len(kwargs), "model registry lists must have equal length"

## Load the data matrix from CSV and evaluate each model on the external test set

In [None]:
%matplotlib inline

dataset = picrispr.CSVDataset(csv_path)

# choose which model to encode for
for modelnum in range(6):
    config = {"dbFields": kwargs[modelnum]['dbFields'],
              "numBpWise": kwargs[modelnum]['numBpWise'],
              "mode": args[modelnum][2],
              "mismatchType": args[modelnum][3],
              "interfaceMode": args[modelnum][6],
              "chooseSpecies": ["hg19", "hg38"],
              "regression": args[modelnum][4],
             }

    filenameAppendix  = "_interface" if config["interfaceMode"] else ""
    filenameAppendix += "_type"      if config["mismatchType"] and config["interfaceMode"]  else ""
    filenameAppendix += "_"+args[modelnum][5]
    
    print(config["mode"], filenameAppendix, "__________________________________")

    if config["mismatchType"]:
        oneHotFct = oneHotSingleNuclTargetMismatchType
        featurenames = ['A_match', 'T_match', 'C_match', 'G_match', 
                        'A_mismT', 'T_mismC', 'C_mismG', 'G_mismA', 
                        'A_mismC', 'T_mismG', 'C_mismA', 'G_mismT',
                        'A_mismG', 'T_mismA', 'C_mismT', 'G_mismC']
    elif config["interfaceMode"]: 
        oneHotFct = oneHotSingleNuclTargetMismatch
        featurenames = ['A', 'A_mism', 'T', 'T_mism', 'C', 'C_mism', 'G', 'G_mism']
    else:
        oneHotFct = oneHotSingleNucl
        featurenames = ['A',           'T',           'C',           'G']

    featurenames.extend(" ".join(config["dbFields"].split()).split(', ')[2:]) # append whatever database fields apart from guide and target sequence are used
    
    
    # Load from csv
    if do_load: 
        dm = dm.load('.', filenameAppendix)

    else:
        dm = dataset.getDataMatrix(config["dbFields"], oneHotFct, normaliseCF, chooseSpecies=config["chooseSpecies"], 
                                   featurenames=featurenames, mode=config["mode"], numBpWise=config["numBpWise"], test_size=0.2)

        dm.mode = config["mode"]
        dm.interfaceMode = config["interfaceMode"]
        dm.mismatchType = config["mismatchType"]
        dm.regression = config["regression"]

        dm.save('.', filenameAppendix)
        
    filenameAppendix += "_class" if not config["regression"] else ""
    
    # load model
    result = picrispr.TrainResult.load(home, config["mode"], filenameAppendix, device="gpu:0" if use_gpu else "cpu")
    
    # predict on ext set
    isHPC = True
    bs = int(7e4) if isHPC or dm.mode != "torch" else 35000

    # prepare dataset
    if compare_deepcrispr:
        deepcrispr = [11,4,1,12,3,10,2]
        dm.setExtExperimentIds(deepcrispr, removeFromTrainSet=False, onlyTestToExt=True)

        # choose arbitrary ratio of measured:augmented data points to match publication
        dm.weight_ext_augmented *= 250
        dm.balanceClasses = False
    
    # don't use experiment_id column
    dm.toDense()
    dm = dm.dropColumn(-1)
    
    trainDataset, validDataset, extDataset = dm.prepareDataset(cutoff_class=-4, addGaussian=False)
    dm.prepareDataloaders(trainDataset, validDataset, extDataset, bs, balanceClasses=dm.balanceClasses if hasattr(dm, 'balanceClasses') else False, ignoreExtSet=True)
    
    x_ext, y_ext, _ = next(iter(dm.extLoader))

    if (dm.mode == "torch"):
        model = result.model
        if not use_gpu: model.device = "cpu"
        x_ext = vecToMatEncoding(x_ext, seqDim=dm.encoding.seqDim, single=dm.interfaceMode, 
                                numBpWise=dm.numBpWise, setOR=dm.mismatchType and not dm.interfaceMode)
        ORencoding = dm.mismatchType and not dm.interfaceMode
        siamese = (dm.mode == "torch" and not config["interfaceMode"] and not ORencoding)

        model = model.to(model.device)

        if siamese:
            x_ext = list(x_ext)
            for i in range(len(x_ext)):
                x_ext[i] = x_ext[i].to(model.device)
            ypred = model(*x_ext).flatten()
        else:
            x_ext = x_ext.to(model.device)
            ypred = model(x_ext).flatten()
        y_ext, preds = y_ext.detach().cpu().data, ypred.detach().cpu().data.numpy()

    elif (dm.mode == "xgboost"):
        if result.mismatchType and not result.interfaceMode: x_ext = vecToOrEncoding(x_ext, result.encoding.seqDim) # OR encoding
        matrix_ext = xgb.DMatrix(x_ext.cpu().detach().numpy())
        y_ext = y_ext.detach().cpu().data

        preds = result.model.predict(matrix_ext).flatten()

    elif (dm.mode == "tf"):
        preds = result.model.predict(x_ext.numpy())


    plt.scatter(np.nan_to_num(preds), y_ext, marker='.', alpha=0.4)
    plt.xlabel("preds")
    plt.ylabel("y_ext")
    plt.show()
    
    preds = np.nan_to_num(preds)
    preds = preds[:, 0] if len(preds.shape) > 1 else preds
    
    if regression:
        print("Spearman r =", spearmanr(preds, y_ext.numpy())[0])
        print("Pearson r =", pearsonr(preds, y_ext.numpy())[0])
        
    else:
        cutoff = 0.05
        print(confusion_matrix(y_ext, preds < cutoff))

        fpr, tpr, _ = roc_curve(y_ext.numpy(), preds)
        plt.plot(fpr, tpr)
        plt.show()

        print("AUC ROC:", roc_auc_score(y_ext.numpy(), preds))

        precision, recall, _ = precision_recall_curve(y_ext, preds)
        plt.plot(recall, precision)
        plt.show()

        print("AUC PRC:", auc(recall, precision))

In [None]:
# NOTE(review): the module imported at the top is `picrispr`; confirm `pycrispr.py` is the intended script name.
! python pycrispr.py 