Skip to content
Permalink
master
Switch branches/tags
Go to file
 
 
Cannot retrieve contributors at this time
"""
Snakefile for the TF-binding model evaluation.

Creates one conda environment per model group, fetches the FactorNet DNase
tracks, runs the model predictions and runtime measurements, and gathers the
metrics and plots under `processed/tfbinding/eval/`.
"""
# Switch the working directory two levels up; the relative paths used by the
# rules below (e.g. `data/...`, `src/...`) are resolved from there.
workdir:
    "../../"

ROOT = "../.."
# NOTE: this name is rebound further down by the `DATA` imported from
# m_kipoi.exp.tfbinding.config.
DATA = "data/"
import matplotlib
matplotlib.use('Agg')
import sys
from kipoi.cli.env import get_env_name
from glob import glob
# Folder structure for data:
# - raw: raw/tfbinding/eval/<task>/files...
# - processed:
# - processed/tfbinding/eval/preds/
# - processed/tfbinding/eval/metrics/
# - processed/tfbinding/eval/plots/
# --------------------------------------------
# Config
from m_kipoi.utils import get_env_executable
from m_kipoi.exp.tfbinding.config import (DATA, TF_C_pairs, TFS, CELL_TYPES,
NUM_FASTA_FILE, CHR_FASTA_FILE,
HOLDOUT_CHR, TF2CT, SINGLE_TASK_MODELS,
get_dl_kwargs)
# Conda environments to build, keyed by environment name. Each value is the
# Kipoi model handed to `kipoi env create` when that environment is created.
# In the new API, we could also use a single name for the model.
ENVIRONMENTS = {
    "DeepSEA": "DeepSEA/predict",
    "pwm_HOCOMOCO": "pwm_HOCOMOCO/human/CEBPB",
    "DeepBind": "DeepBind/Homo_sapiens/TF/D00317.009_ChIP-seq_CEBPB",
    "FactorNet": "FactorNet/JUND/meta_Unique35_DGF",
    "lsgkm-SVM": "lsgkm-SVM/Tfbs/Cebpb/Helas3/Sydh_Iggrab",
}
# --------------------------------------------
rule all:
    """Default target: request every artefact produced by this workflow."""
    input:
        # Raw data (TFS and CELL_TYPES are paired element-wise via `zip`)
        expand(
            DATA + "raw/tfbinding/eval/tf-DREAM/"
            + "chr8_wide_bin101_flank0_stride101.{tf}.{ctype}.intervals_file.tsv.gz",
            zip,
            tf=TFS,
            ctype=CELL_TYPES,
        ),
        # FactorNet DNase
        expand(DATA + "raw/tfbinding/eval/DNASE/FactorNet/{ctype}.1x.bw",
               ctype=CELL_TYPES),
        # Environments
        [get_env_executable(env_name) for env_name in ENVIRONMENTS],
        # predictions: one file per (tf, model group) that has a model assigned
        [
            DATA + "processed/tfbinding/eval/preds/{tf}/{model}.h5".format(tf=tf, model=group)
            for tf, groups in SINGLE_TASK_MODELS.items()
            for group, kipoi_model in groups.items()
            if kipoi_model is not None
        ],
        # runtime evaluation: only for JUND, at batch size 256
        [
            DATA + "processed/tfbinding/eval/runtimes/{tf}/{model}.{batch_size}.json".format(
                tf='JUND', batch_size=256, model=group)
            for group, kipoi_model in SINGLE_TASK_MODELS['JUND'].items()
            if kipoi_model is not None
        ],
        # metrics
        "data/processed/tfbinding/eval/metrics/all_models.chr8.csv",
        # plots
        expand("data/processed/tfbinding/eval/plots/all_models.chr8.{fmt}",
               fmt=['pdf', 'png'])
rule fasta_numchr2chr:
    """Create a chr<N> based fasta file.

    Every line of NUM_FASTA_FILE that starts with `>` (a fasta header) gets
    `chr` inserted right after the `>`; the result is written to
    CHR_FASTA_FILE.
    """
    input:
        f = NUM_FASTA_FILE
    output:
        f = CHR_FASTA_FILE
    # sed reads the file itself, so no `cat |` is needed. The pattern is
    # anchored at the line start and can match at most once per line, hence
    # no `g` flag.
    shell:
        """
        sed 's/^>/>chr/' {input.f} > {output.f}
        """
# --------------------------------------------
# Get all the evaluation data
# File names of the chr8 evaluation interval files.
# NOTE(review): the second entry is ordered `HepG2.JUND` (cell type first),
# whereas the other entries and the rules in this file use `<tf>.<ctype>` -
# confirm against the actual file names before relying on this list.
EVAL_FILES = [
    'chr8_wide_bin101_flank0_stride101.{}.intervals_file.tsv.gz'.format(pair)
    for pair in ('CEBPB.HeLa-S3', 'HepG2.JUND', 'MAFK.K562', 'NANOG.H1-hESC')
]

# Synapse entity id of the DNase track, keyed by cell type.
synapse_DNASE = {
    "HeLa-S3": "syn8073618",
    "K562": "syn8073483",
    "HepG2": "syn8073517",
    "H1-hESC": "syn8073583",
}
rule unzip_tsv:
    """Decompress a gzipped intervals file into a plain .tsv next to it."""
    input:
        f = DATA + "raw/tfbinding/eval/tf-DREAM/chr8_wide_bin101_flank0_stride101.{tf}.{ctype}.intervals_file.tsv.gz"
    output:
        f = DATA + "raw/tfbinding/eval/tf-DREAM/chr8_wide_bin101_flank0_stride101.{tf}.{ctype}.intervals_file.tsv"
    # The redirect target previously opened a single quote that was never
    # closed, which made the shell reject the command.
    shell:
        "zcat {input.f} > {output.f}"
rule get_DNase_data_factornet:
    """Fetch the FactorNet DNase track of one cell type from Synapse.

    The Synapse entity id is looked up in `synapse_DNASE` by the `ctype`
    wildcard and downloaded into the directory of the expected output file.
    """
    output:
        dnase = DATA + "raw/tfbinding/eval/DNASE/FactorNet/{ctype}.1x.bw"
    params:
        syn = lambda wildcards: synapse_DNASE[wildcards.ctype]
    # NOTE(review): the command only sets the download directory; it presumably
    # relies on the Synapse entity being named `<ctype>.1x.bw` - confirm.
    shell:
        "synapse get {params.syn} --downloadLocation $(dirname {output.dnase})"
# --------------------------------------------
# Setup the environment
rule create_conda_env:
    """Build the conda environment `{env}` for its Kipoi model.

    The model to install is taken from `ENVIRONMENTS`; the environment is
    named after the `env` wildcard, and the rule's output is the path returned
    by `get_env_executable` for that environment.
    """
    output:
        env = get_env_executable("{env}")
    params:
        env_name = lambda wildcards: ENVIRONMENTS[wildcards.env]
    shell:
        "kipoi env create {params.env_name} -e {wildcards.env}"  # Add --gpu to leverage the GPU
# --------------------------------------------
# Evaluate the models
rule predict:
    """Run model prediction.

    Calls `<env executable> predict` for the Kipoi model registered in
    SINGLE_TASK_MODELS[tf][model], using the dataloader arguments returned by
    `get_dl_kwargs(tf)`, and writes the predictions to an HDF5 file.
    """
    input:
        # Executable of the model group's environment (see get_env_executable)
        kipoi = get_env_executable("{model}"),
        # Listed as inputs so the files referenced by the dataloader arguments
        # exist before the prediction starts.
        intervals = lambda w: get_dl_kwargs(w.tf)['intervals_file'],
        fasta = lambda w: get_dl_kwargs(w.tf)['fasta_file'],
        dnase_file = lambda w: get_dl_kwargs(w.tf)['dnase_file'],
    output:
        preds = DATA + "processed/tfbinding/eval/preds/{tf}/{model}.h5"
    params:
        model = lambda w: SINGLE_TASK_MODELS[w.tf][w.model],
        dl_kwargs = lambda w: get_dl_kwargs(w.tf),
        # lsgkm-SVM model groups use a batch size of 1024, all others 256
        batch_size = lambda w: 1024 if w.model.startswith("lsgkm-SVM") else 256
    threads: 16
    # `-n` follows {threads} instead of a hard-coded 16, so the number of
    # workers matches the cores Snakemake grants to this job.
    shell:
        """
        echo batch size: {params.batch_size}
        echo model: {wildcards.model}
        {input.kipoi} predict {params.model} \
        --dataloader_args="{params.dl_kwargs}" \
        --batch_size={params.batch_size} \
        -n {threads} \
        -o {output.preds}
        """
rule measure_execution_time:
    """Time the prediction of one model at a given batch size.

    Activates the environment named after the `model` wildcard and runs
    `src/tf-binding/execution_time.py` with 10 runs; the result is written to
    a JSON file.
    """
    input:
        # Requires the model group's environment to exist before timing
        env_bin = get_env_executable("{model}"),
        intervals = lambda wildcards: get_dl_kwargs(wildcards.tf)['intervals_file'],
        fasta = lambda wildcards: get_dl_kwargs(wildcards.tf)['fasta_file'],
        dnase_file = lambda wildcards: get_dl_kwargs(wildcards.tf)['dnase_file'],
    output:
        json_fname = DATA + "processed/tfbinding/eval/runtimes/{tf}/{model}.{batch_size}.json"
    params:
        model = lambda wildcards: SINGLE_TASK_MODELS[wildcards.tf][wildcards.model],
        dl_kwargs = lambda wildcards: get_dl_kwargs(wildcards.tf)
    threads: 16
    shell:
        """
        source activate {wildcards.model}
        python src/tf-binding/execution_time.py \
        --model={params.model} \
        --model_group={wildcards.model} \
        --dl_kwargs="{params.dl_kwargs}" \
        --batch_size={wildcards.batch_size} \
        --num_workers={threads} \
        --num_runs=10 \
        --tf={wildcards.tf} \
        --output={output.json_fname}
        """
rule evaluate_models:
    """Gather model predictions and compute auPRC.

    Runs `eval_model` for every (model, tf, filter_dnase) combination in
    parallel, collects the results into one table, labels each row with a
    human-readable dataset description and writes the table to CSV.
    """
    input:
        # NOTE(review): not read explicitly in `run`; presumably eval_model
        # locates these prediction files itself - confirm.
        preds = [DATA + "processed/tfbinding/eval/preds/{tf}/{model}.h5".format(tf=tf, model=model)
                 for tf, value in SINGLE_TASK_MODELS.items()
                 for model, model_name in value.items()
                 if model_name is not None]
    output:
        csv = "data/processed/tfbinding/eval/metrics/all_models.chr8.csv"
    run:
        import pandas as pd
        from joblib import Parallel, delayed
        from m_kipoi.exp.tfbinding.eval import eval_model
        from tqdm import tqdm
        from m_kipoi.metrics import classification_metrics

        MODELS = ['pwm_HOCOMOCO', 'DeepBind', 'lsgkm-SVM', 'DeepSEA', 'FactorNet']
        df = pd.DataFrame(Parallel(n_jobs=32)(delayed(eval_model)(tf, model, classification_metrics,
                                                                  filter_dnase=filter_dnase)
                                              for model in tqdm(MODELS)
                                              for tf in tqdm(TFS)
                                              for filter_dnase in [True, False]))
        # Make a nice column description
        df['dataset'] = "Chromosome wide (chr8)"
        # Assign through .loc: the chained form `df['dataset'][mask] = ...`
        # may write to a temporary copy and leave `df` unchanged.
        # The explicit comparison keeps the original mask semantics whatever
        # the dtype of the column is.
        df.loc[df.filter_dnase == True, 'dataset'] = "Only accessible regions (chr8)"
        df.to_csv(output.csv)
rule plot:
    """Draw the per-TF auPRC bar plots from the metrics table.

    Two figures are produced, each saved in every requested format: one for
    the rows with `filter_dnase` False (`files_all`) and one for the rows with
    `filter_dnase` True (`files_dnase`).
    """
    input:
        csv = "data/processed/tfbinding/eval/metrics/all_models.chr8.csv"
    output:
        files_all = expand("data/processed/tfbinding/eval/plots/all_models.chr8.{fmt}",
                           fmt=['pdf', 'png']),
        files_dnase = expand("data/processed/tfbinding/eval/plots/all_models.chr8.DNASE-peaks.{fmt}",
                             fmt=['pdf', 'png'])
    run:
        import matplotlib
        matplotlib.use('Agg')
        # Font type 42 for both the PDF and the PS backend
        for backend_key in ('pdf.fonttype', 'ps.fonttype'):
            matplotlib.rcParams[backend_key] = 42
        import pandas as pd
        import seaborn as sns
        sns.set_context("talk")

        metrics = pd.read_csv(input.csv)
        colors = ['#8dd3c7', '#fdb462', '#bebada', '#fb8072', '#80b1d3']
        metrics = metrics.rename(columns={"model": "Model"})

        # (rows to plot, files to write), in the order: filter_dnase=False,
        # then filter_dnase=True
        panels = [
            (metrics[~metrics.filter_dnase], output.files_all),
            (metrics[metrics.filter_dnase], output.files_dnase),
        ]
        for subset, fnames in panels:
            grid = sns.factorplot(x="tf", y="auPR", hue="Model", data=subset,
                                  size=5, kind="bar", palette=colors)
            grid.set_xlabels("Transcription Factor")
            grid.set_ylabels("auPRC")
            for fname in fnames:
                grid.savefig(fname)