-
Notifications
You must be signed in to change notification settings - Fork 7
/
Snakefile
139 lines (120 loc) · 5.33 KB
/
Snakefile
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
################################################################################
# SETUP
################################################################################
# Modules
# os and sys are imported explicitly instead of relying on names Snakemake may
# already have placed in the Snakefile's global namespace.
import os
import sys
from os.path import join
# Make src/ importable. `constants` is expected to supply SIGMA_NAME, MMM_NAME
# and MODEL_NAMES, which this Snakefile uses but does not define itself.
sys.path.append(os.path.join(os.getcwd(), 'src'))
from constants import *
# Configuration
# Every value may be overridden through a config file or `--config key=value`;
# otherwise the defaults below are written back into `config`.
config['signatures_file'] = config.get('signatures_file',
    'data/signatures/cosmic-signatures.tsv') # default: COSMIC
# Active signatures. A value given on the command line can arrive as a scalar
# (e.g. `3`, or a space-separated string such as "1 3 13"), and an empty YAML
# entry arrives as None. Normalize all of these to a list so the emptiness
# check and the join below treat each element as one signature (previously a
# string was split into single characters, turning "13" into "1 3").
_active_signatures = config.get('active_signatures',
    [1,2,3,5,6,8,13,17,18,20,26,30])
if _active_signatures is None or _active_signatures == '':
    _active_signatures = []
elif not isinstance(_active_signatures, (list, tuple)):
    _active_signatures = [_active_signatures]
config['active_signatures'] = list(_active_signatures)
config['run_name'] = config.get('run_name', 'ICGC-R22-BRCA')
config['mutations_file'] = config.get('mutations_file',
    'data/mutations/ICGC-BRCA-EU.RELEASE_22.SBS.renamed.sigma.json')
config['output_dir'] = OUTPUT_DIR = config.get('output_dir', join('output', config.get('run_name')))
# Samples: default to the 'samples' list stored in the mutations file; a single
# sample given on the command line arrives as a scalar, so wrap it in a list.
if 'samples' not in config:
    import json
    with open(config.get('mutations_file'), 'r') as IN:
        config['samples'] = json.load(IN).get('samples')
elif not isinstance(config['samples'], list):
    config['samples'] = [config['samples']]
config['random_seed'] = config.get('random_seed', 94781)
config['max_iter'] = config.get('max_iter', 100)
config['cloud_thresholds'] = config.get('cloud_thresholds', list(range(1000,10001,1000)))
config['chosen_cloud_threshold'] = config.get('chosen_cloud_threshold', 2000)
config['tolerance'] = config.get('tolerance', 1e-3)
# Command-line fragment handed to train_and_predict.py by the rules below;
# empty when no active signatures are configured.
if config['active_signatures']:
    ACTIVE_SIGNATURES_PARAM = '-as %s' % ' '.join(map(str, config['active_signatures']))
else:
    ACTIVE_SIGNATURES_PARAM = ''
# Directories
# Inputs are read from data/; everything a run produces goes under OUTPUT_DIR.
DATA_DIR = 'data'
SIGNATURES_DIR, MUTATIONS_DIR = join(DATA_DIR, 'signatures'), join(DATA_DIR, 'mutations')
LOOCV_DIR, TRAINED_MODEL_DIR = join(OUTPUT_DIR, 'loocv'), join(OUTPUT_DIR, 'models')
# Scripts invoked from the rules' shell commands
SRC_DIR = 'src'
TRAIN_AND_PREDICT_PY = join(SRC_DIR, 'train_and_predict.py')
CREATE_FIGURE_PY = join(SRC_DIR, 'create_fig.py')
# Files
# Input files, taken from the (already defaulted) configuration.
SIGNATURES_FILE = config.get('signatures_file')
MUTATIONS_FILE = config.get('mutations_file')
# Output path patterns. Only the directory and model-name parts are filled in
# here; {threshold}, {model} and {sample} remain Snakemake wildcards.
TRAINED_MODEL_FMT = '%s/%s{threshold}/{model}-{sample}.json' % (TRAINED_MODEL_DIR, SIGMA_NAME)
MMM_LOOCV_MODEL_FMT = '%s/%s/%s-{sample}.json' % (LOOCV_DIR, MMM_NAME, MMM_NAME)
SIGMA_LOOCV_MODEL_FMT = '%s/%s{threshold}/%s-{sample}.json' % (LOOCV_DIR, SIGMA_NAME, SIGMA_NAME)
LOOCV_FIGURE = join(LOOCV_DIR, '%s-%s-loocv-comparison.pdf' % (SIGMA_NAME, MMM_NAME))
################################################################################
# RULES
################################################################################
# Train the model
# Runs train_and_predict.py for one (model, sample, cloud threshold) combination
# and produces TRAINED_MODEL_FMT. The -od directory is built from SIGMA_NAME,
# the same constant TRAINED_MODEL_FMT uses, instead of a hard-coded 'sigma', so
# the declared output and the directory the script writes into stay in sync.
# (Assumes the script writes <-od>/<-mn>-<-sn>.json, as the output pattern
# implies; TODO confirm against train_and_predict.py.)
rule train_full:
    input:
        mutations=config.get('mutations_file'),
        signatures=config.get('signatures_file')
    params:
        max_iter=config.get('max_iter'),
        random_seed=config.get('random_seed'),
        tolerance=config.get('tolerance'),
        active_signatures=ACTIVE_SIGNATURES_PARAM
    output:
        TRAINED_MODEL_FMT
    shell:
        'python {TRAIN_AND_PREDICT_PY} -mf {input.mutations} -sf {input.signatures} '\
        '-od {TRAINED_MODEL_DIR}/{SIGMA_NAME}{wildcards.threshold} -mn {wildcards.model} '\
        '{params.active_signatures} -sn {wildcards.sample} -mi {params.max_iter} '\
        '-ct {wildcards.threshold} -rs {params.random_seed} -tol {params.tolerance}'
# Perform LOOCV for each sample
# Aggregation target with no output of its own. It requests the
# cross-validation result of every sample: for SIGMA_NAME at every configured
# cloud threshold, and once for MMM_NAME.
rule loocv:
    input:
        expand(SIGMA_LOOCV_MODEL_FMT,
               sample=config.get('samples'),
               threshold=config.get('cloud_thresholds')),
        expand(MMM_LOOCV_MODEL_FMT,
               sample=config.get('samples'))
# Cross-validated run of the SIGMA_NAME model for one sample at one cloud
# threshold (the script is called with --cross-validation-mode). The -od
# directory uses SIGMA_NAME, matching SIGMA_LOOCV_MODEL_FMT, rather than a
# hard-coded 'sigma', so the declared output and the write location agree.
rule sigma_loocv_full:
    input:
        mutations=config.get('mutations_file'),
        signatures=config.get('signatures_file')
    params:
        max_iter=config.get('max_iter'),
        random_seed=config.get('random_seed'),
        tolerance=config.get('tolerance'),
        active_signatures=ACTIVE_SIGNATURES_PARAM,
    output:
        SIGMA_LOOCV_MODEL_FMT
    shell:
        'python {TRAIN_AND_PREDICT_PY} -mf {input.mutations} -sf {input.signatures} '\
        '-od {LOOCV_DIR}/{SIGMA_NAME}{wildcards.threshold} -mn {SIGMA_NAME} -sn {wildcards.sample} '\
        '{params.active_signatures} -mi {params.max_iter} -ct {wildcards.threshold} '\
        '-rs {params.random_seed} -tol {params.tolerance} --cross-validation-mode'
# Cross-validated run of the MMM_NAME model for one sample. The cloud threshold
# is fixed at 0 for this model. The -od directory uses MMM_NAME, matching
# MMM_LOOCV_MODEL_FMT, rather than a hard-coded 'mmm', so the declared output
# and the write location agree.
rule mmm_loocv:
    input:
        mutations=config.get('mutations_file'),
        signatures=config.get('signatures_file')
    params:
        max_iter=config.get('max_iter'),
        random_seed=config.get('random_seed'),
        tolerance=config.get('tolerance'),
        active_signatures=ACTIVE_SIGNATURES_PARAM,
    output:
        MMM_LOOCV_MODEL_FMT
    shell:
        'python {TRAIN_AND_PREDICT_PY} -mf {input.mutations} -sf {input.signatures} '\
        '-od {LOOCV_DIR}/{MMM_NAME} -mn {MMM_NAME} -sn {wildcards.sample} '\
        '{params.active_signatures} -mi {params.max_iter} -ct 0 '\
        '-rs {params.random_seed} -tol {params.tolerance} --cross-validation-mode'
# General rules
# Draw the LOOCV comparison figure once every LOOCV result exists. The script
# only receives the LOOCV directory; the input list exists to enforce the
# dependency on the loocv target's files.
rule cv_figure:
    output:
        LOOCV_FIGURE
    input:
        rules.loocv.input
    shell:
        'python {CREATE_FIGURE_PY} -ld {LOOCV_DIR} -of {output}'
# Aggregation target: a trained model for every sample and every entry of
# MODEL_NAMES, at the single configured chosen_cloud_threshold.
rule train:
    input:
        expand(TRAINED_MODEL_FMT,
               sample=config.get('samples'),
               model=MODEL_NAMES,
               threshold=[config.get('chosen_cloud_threshold')])
# Top-level target: every trained model requested by `train`, plus the LOOCV
# comparison figure.
# NOTE(review): Snakemake treats the first rule in the file as the default
# target. Here the first rule is train_full, whose output has wildcards, so a
# bare `snakemake` call fails and this rule must be requested explicitly
# (`snakemake all`). Moving it to the top requires replacing rules.train.input
# with the expand() it refers to, because `rules.train` only exists after
# `train` is defined.
rule all:
    input:
        rules.train.input,
        LOOCV_FIGURE