"""
Snakefile for the splicing model
"""
workdir:
    "../../"
ROOT = "../.."
DATA = "data/"
import matplotlib
matplotlib.use('Agg')
import sys
from kipoi.cli.env import get_env_name
from glob import glob
# Folder structure for data:
# - raw: raw/tfbinding/eval/<task>/files...
# - processed:
#   - processed/tfbinding/eval/preds/
#   - processed/tfbinding/eval/metrics/
#   - processed/tfbinding/eval/plots/
# --------------------------------------------
# Config
from m_kipoi.utils import get_env_executable
from m_kipoi.exp.tfbinding.config import (DATA, TF_C_pairs, TFS, CELL_TYPES,
                                          NUM_FASTA_FILE, CHR_FASTA_FILE,
                                          HOLDOUT_CHR, TF2CT, SINGLE_TASK_MODELS,
                                          get_dl_kwargs)
# in the new API, we could also use a single name for the model
ENVIRONMENTS = {
    "DeepSEA": "DeepSEA/predict",
    "pwm_HOCOMOCO": "pwm_HOCOMOCO/human/CEBPB",
    "DeepBind": "DeepBind/Homo_sapiens/TF/D00317.009_ChIP-seq_CEBPB",
    "FactorNet": "FactorNet/JUND/meta_Unique35_DGF",
    "lsgkm-SVM": "lsgkm-SVM/Tfbs/Cebpb/Helas3/Sydh_Iggrab",
}
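# Each entry above points to one concrete model from the kipoi model zoo whose
# dependencies define the shared conda environment for its model group;
# `get_env_executable` then resolves the kipoi binary inside that environment.
# A hypothetical sketch of the resolution (the path layout is illustrative,
# not guaranteed):
#   get_env_executable("DeepSEA")  # -> e.g. ~/anaconda3/envs/DeepSEA/bin/kipoi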
# --------------------------------------------
rule all:
    input:
        # Raw data
        expand(DATA + "raw/tfbinding/eval/tf-DREAM/" +
               "chr8_wide_bin101_flank0_stride101.{tf}.{ctype}.intervals_file.tsv.gz",
               zip, tf=TFS, ctype=CELL_TYPES),
        # FactorNet DNase
        expand(DATA + "raw/tfbinding/eval/DNASE/FactorNet/{ctype}.1x.bw", ctype=CELL_TYPES),
        # Environments
        [get_env_executable(env) for env in ENVIRONMENTS],
        # Predictions
        [DATA + "processed/tfbinding/eval/preds/{tf}/{model}.h5".format(tf=tf, model=model)
         for tf, value in SINGLE_TASK_MODELS.items()
         for model, model_name in value.items()
         if model_name is not None],
        # Runtime evaluation
        [DATA + "processed/tfbinding/eval/runtimes/{tf}/{model}.{batch_size}.json".format(tf='JUND', batch_size=256, model=model)
         for model, model_name in SINGLE_TASK_MODELS['JUND'].items()
         if model_name is not None],
        # Metrics
        "data/processed/tfbinding/eval/metrics/all_models.chr8.csv",
        # Plots
        expand("data/processed/tfbinding/eval/plots/all_models.chr8.{fmt}",
               fmt=['pdf', 'png'])
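# Note: `rule all` only declares the final targets; Snakemake works backwards
# from these files and runs whichever of the rules below are needed to
# produce them.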
rule fasta_numchr2chr:
    """Create a chr<N>-based fasta file
    """
    input:
        f = NUM_FASTA_FILE
    output:
        f = CHR_FASTA_FILE
    shell:
        """
        sed 's/^>/>chr/g' {input.f} > {output.f}
        """
# --------------------------------------------
# Get all the evaluation data
EVAL_FILES = ['chr8_wide_bin101_flank0_stride101.CEBPB.HeLa-S3.intervals_file.tsv.gz',
              'chr8_wide_bin101_flank0_stride101.JUND.HepG2.intervals_file.tsv.gz',
              'chr8_wide_bin101_flank0_stride101.MAFK.K562.intervals_file.tsv.gz',
              'chr8_wide_bin101_flank0_stride101.NANOG.H1-hESC.intervals_file.tsv.gz']
synapse_DNASE = {
    "HeLa-S3": "syn8073618",
    "K562": "syn8073483",
    "HepG2": "syn8073517",
    "H1-hESC": "syn8073583",
}
rule unzip_tsv:
    input:
        f = DATA + "raw/tfbinding/eval/tf-DREAM/chr8_wide_bin101_flank0_stride101.{tf}.{ctype}.intervals_file.tsv.gz"
    output:
        f = DATA + "raw/tfbinding/eval/tf-DREAM/chr8_wide_bin101_flank0_stride101.{tf}.{ctype}.intervals_file.tsv"
    shell:
        "zcat {input.f} > {output.f}"
rule get_DNase_data_factornet:
    """Get the DNase bigWig tracks required by FactorNet"""
    output:
        dnase = DATA + "raw/tfbinding/eval/DNASE/FactorNet/{ctype}.1x.bw"
    params:
        syn = lambda w: synapse_DNASE[w.ctype]
    shell:
        "synapse get {params.syn} --downloadLocation $(dirname {output.dnase})"
# --------------------------------------------
# Set up the environments
rule create_conda_env:
    """Create one conda environment per model group"""
    output:
        env = get_env_executable("{env}")
    params:
        env_name = lambda w: ENVIRONMENTS[w.env]
    shell:
        "kipoi env create {params.env_name} -e {wildcards.env}"  # add --gpu to leverage the GPU
# --------------------------------------------
# Evaluate the models
rule predict:
    """Run model prediction"""
    input:
        kipoi = get_env_executable("{model}"),
        intervals = lambda w: get_dl_kwargs(w.tf)['intervals_file'],
        fasta = lambda w: get_dl_kwargs(w.tf)['fasta_file'],
        dnase_file = lambda w: get_dl_kwargs(w.tf)['dnase_file'],
    output:
        preds = DATA + "processed/tfbinding/eval/preds/{tf}/{model}.h5"
    params:
        model = lambda w: SINGLE_TASK_MODELS[w.tf][w.model],
        dl_kwargs = lambda w: get_dl_kwargs(w.tf),
        batch_size = lambda w: 1024 if w.model.startswith("lsgkm-SVM") else 256
    threads: 16
    shell:
        """
        echo batch size: {params.batch_size}
        echo model: {wildcards.model}
        {input.kipoi} predict {params.model} \
            --dataloader_args="{params.dl_kwargs}" \
            --batch_size={params.batch_size} \
            -n 16 \
            -o {output.preds}
        """
rule measure_execution_time:
    """Measure model prediction time
    """
    input:
        env_bin = get_env_executable("{model}"),
        intervals = lambda w: get_dl_kwargs(w.tf)['intervals_file'],
        fasta = lambda w: get_dl_kwargs(w.tf)['fasta_file'],
        dnase_file = lambda w: get_dl_kwargs(w.tf)['dnase_file'],
    output:
        json_fname = DATA + "processed/tfbinding/eval/runtimes/{tf}/{model}.{batch_size}.json"
    params:
        model = lambda w: SINGLE_TASK_MODELS[w.tf][w.model],
        dl_kwargs = lambda w: get_dl_kwargs(w.tf)
    threads: 16
    shell:
        """
        source activate {wildcards.model}
        python src/tf-binding/execution_time.py \
            --model={params.model} \
            --model_group={wildcards.model} \
            --dl_kwargs="{params.dl_kwargs}" \
            --batch_size={wildcards.batch_size} \
            --num_workers={threads} \
            --num_runs=10 \
            --tf={wildcards.tf} \
            --output={output.json_fname}
        """
rule evaluate_models:
    """Gather model predictions and compute auPRC"""
    input:
        preds = [DATA + "processed/tfbinding/eval/preds/{tf}/{model}.h5".format(tf=tf, model=model)
                 for tf, value in SINGLE_TASK_MODELS.items()
                 for model, model_name in value.items()
                 if model_name is not None]
    output:
        csv = "data/processed/tfbinding/eval/metrics/all_models.chr8.csv"
    run:
        import pandas as pd
        from joblib import Parallel, delayed
        from tqdm import tqdm
        from m_kipoi.exp.tfbinding.eval import eval_model
        from m_kipoi.metrics import classification_metrics

        MODELS = ['pwm_HOCOMOCO', 'DeepBind', 'lsgkm-SVM', 'DeepSEA', 'FactorNet']
        # Evaluate every (model, TF) pair, both genome-wide and restricted to
        # DNase-accessible regions
        df = pd.DataFrame(Parallel(n_jobs=32)(delayed(eval_model)(tf, model, classification_metrics,
                                                                  filter_dnase=filter_dnase)
                                              for model in tqdm(MODELS)
                                              for tf in tqdm(TFS)
                                              for filter_dnase in [True, False]))
        # Make a nice column description (use .loc to avoid chained assignment)
        df['dataset'] = "Chromosome-wide (chr8)"
        df.loc[df.filter_dnase, 'dataset'] = "Only accessible regions (chr8)"
        df.to_csv(output.csv)
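# `classification_metrics` is assumed to be a collection of sklearn-style
# metric functions; the auPRC column gathered above would correspond to
# something along these lines (hypothetical sketch):
#   from sklearn.metrics import average_precision_score
#   auprc = average_precision_score(y_true, y_pred)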
rule plot:
    """Create the plots"""
    input:
        csv = "data/processed/tfbinding/eval/metrics/all_models.chr8.csv"
    output:
        files_all = expand("data/processed/tfbinding/eval/plots/all_models.chr8.{fmt}",
                           fmt=['pdf', 'png']),
        files_dnase = expand("data/processed/tfbinding/eval/plots/all_models.chr8.DNASE-peaks.{fmt}",
                             fmt=['pdf', 'png'])
    run:
        import matplotlib
        matplotlib.use('Agg')
        # Embed fonts as TrueType so the PDF/PS text stays editable
        matplotlib.rcParams['pdf.fonttype'] = 42
        matplotlib.rcParams['ps.fonttype'] = 42
        import pandas as pd
        import seaborn as sns
        sns.set_context("talk")

        df = pd.read_csv(input.csv)
        palette = ['#8dd3c7', '#fdb462', '#bebada', '#fb8072', '#80b1d3']
        df = df.rename(columns={"model": "Model"})

        # filter_dnase=False: chromosome-wide evaluation
        g = sns.factorplot(x="tf", y="auPR", hue="Model", data=df[~df.filter_dnase],
                           size=5, kind="bar", palette=palette)
        g.set_xlabels("Transcription Factor")
        g.set_ylabels("auPRC")
        for fname in output.files_all:
            g.savefig(fname)

        # filter_dnase=True: only DNase-accessible regions
        g = sns.factorplot(x="tf", y="auPR", hue="Model", data=df[df.filter_dnase],
                           size=5, kind="bar", palette=palette)
        g.set_xlabels("Transcription Factor")
        g.set_ylabels("auPRC")
        for fname in output.files_dnase:
            g.savefig(fname)
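# Note: `sns.factorplot(..., size=...)` is the pre-0.9 seaborn API; on seaborn
# >= 0.9 the equivalent call would be `sns.catplot(..., height=...)`.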