-
Notifications
You must be signed in to change notification settings - Fork 370
/
train_varnet_demo.py
197 lines (168 loc) · 5.91 KB
/
train_varnet_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import pathlib
from argparse import ArgumentParser
import pytorch_lightning as pl
from fastmri.data.mri_data import fetch_dir
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import VarNetDataTransform
from fastmri.pl_modules import FastMriDataModule, VarNetModule
def cli_main(args):
    """Build the data module, VarNet model and trainer, then train or test.

    Args:
        args: Namespace produced by ``build_args``. This function reads
            ``seed``; the mask settings (``mask_type``, ``center_fractions``,
            ``accelerations``); the data settings (``data_path``,
            ``challenge``, ``test_split``, ``test_path``, ``sample_rate``,
            ``batch_size``, ``num_workers``); the VarNet hyperparameters; the
            distributed backend name (``strategy``, or ``accelerator`` on
            older Lightning versions); ``mode``; and every Trainer flag
            consumed by ``pl.Trainer.from_argparse_args``.

    Raises:
        ValueError: If ``args.mode`` is neither ``"train"`` nor ``"test"``.
    """
    pl.seed_everything(args.seed)

    # ------------
    # data
    # ------------
    # this creates a k-space mask for transforming input data
    mask = create_mask_for_mask_type(
        args.mask_type, args.center_fractions, args.accelerations
    )
    # use random masks for train transform, fixed masks for val transform
    train_transform = VarNetDataTransform(mask_func=mask, use_seed=False)
    val_transform = VarNetDataTransform(mask_func=mask)
    # the test transform is built without a mask function
    test_transform = VarNetDataTransform()

    # build_args selects the distributed backend through ``strategy`` and
    # turns off Lightning's sampler replacement (replace_sampler_ddp=False).
    # The data module must therefore install the distributed sampler itself.
    # Older Lightning releases carried the backend name in ``accelerator``,
    # so accept either attribute.
    distributed_backends = ("ddp", "ddp_cpu")
    use_distributed_sampler = (
        getattr(args, "strategy", None) in distributed_backends
        or getattr(args, "accelerator", None) in distributed_backends
    )

    # ptl data module - this handles data loaders
    data_module = FastMriDataModule(
        data_path=args.data_path,
        challenge=args.challenge,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        test_split=args.test_split,
        test_path=args.test_path,
        sample_rate=args.sample_rate,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        distributed_sampler=use_distributed_sampler,
    )

    # ------------
    # model
    # ------------
    model = VarNetModule(
        num_cascades=args.num_cascades,
        pools=args.pools,
        chans=args.chans,
        sens_pools=args.sens_pools,
        sens_chans=args.sens_chans,
        lr=args.lr,
        lr_step_size=args.lr_step_size,
        lr_gamma=args.lr_gamma,
        weight_decay=args.weight_decay,
    )

    # ------------
    # trainer
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)

    # ------------
    # run
    # ------------
    if args.mode == "train":
        trainer.fit(model, datamodule=data_module)
    elif args.mode == "test":
        trainer.test(model, datamodule=data_module)
    else:
        raise ValueError(f"unrecognized mode {args.mode}")
def build_args():
    """Parse command-line arguments for the VarNet demo.

    Defaults come from the optional ``../../fastmri_dirs.yaml`` directory
    config (via ``fetch_dir``) and from the values set below. The function
    then creates ``<default_root_dir>/checkpoints`` and attaches a
    ``ModelCheckpoint`` callback that monitors ``validation_loss``. If no
    ``resume_from_checkpoint`` was given, it resumes from the most recently
    modified ``*.ckpt`` file in that directory, if one exists.

    Returns:
        argparse.Namespace: The data, model and trainer settings, plus
        ``callbacks`` (a one-element list) and possibly an updated
        ``resume_from_checkpoint``.
    """
    parser = ArgumentParser()

    # basic args
    path_config = pathlib.Path("../../fastmri_dirs.yaml")
    backend = "ddp"
    num_gpus = 2 if backend == "ddp" else 1
    batch_size = 1

    # set defaults based on optional directory config
    data_path = fetch_dir("knee_path", path_config)
    default_root_dir = fetch_dir("log_path", path_config) / "varnet" / "varnet_demo"

    # client arguments
    parser.add_argument(
        "--mode",
        default="train",
        choices=("train", "test"),
        type=str,
        help="Operation mode",
    )

    # data transform params
    parser.add_argument(
        "--mask_type",
        choices=("random", "equispaced_fraction"),
        default="equispaced_fraction",
        type=str,
        help="Type of k-space mask",
    )
    parser.add_argument(
        "--center_fractions",
        nargs="+",
        default=[0.08],
        type=float,
        help="Fraction of center lines to use in mask",
    )
    parser.add_argument(
        "--accelerations",
        nargs="+",
        default=[4],
        type=int,
        help="Acceleration rates to use for masks",
    )

    # data config
    parser = FastMriDataModule.add_data_specific_args(parser)
    parser.set_defaults(
        data_path=data_path,  # path to fastMRI data
        mask_type="equispaced_fraction",  # VarNet uses equispaced mask
        challenge="multicoil",  # only multicoil implemented for VarNet
        batch_size=batch_size,  # number of samples per batch
        test_path=None,  # path for test split, overwrites data_path
    )

    # module config
    parser = VarNetModule.add_model_specific_args(parser)
    parser.set_defaults(
        num_cascades=8,  # number of unrolled iterations
        pools=4,  # number of pooling layers for U-Net
        chans=18,  # number of top-level channels for U-Net
        sens_pools=4,  # number of pooling layers for sense est. U-Net
        sens_chans=8,  # number of top-level channels for sense est. U-Net
        lr=0.001,  # Adam learning rate
        lr_step_size=40,  # epoch at which to decrease learning rate
        lr_gamma=0.1,  # extent to which to decrease learning rate
        weight_decay=0.0,  # weight regularization strength
    )

    # trainer config
    parser = pl.Trainer.add_argparse_args(parser)
    parser.set_defaults(
        gpus=num_gpus,  # number of gpus to use
        replace_sampler_ddp=False,  # this is necessary for volume dispatch during val
        strategy=backend,  # what distributed version to use
        seed=42,  # random seed
        deterministic=True,  # makes things slower, but deterministic
        default_root_dir=default_root_dir,  # directory for logs and checkpoints
        max_epochs=50,  # max number of epochs
    )

    args = parser.parse_args()

    # The default is a Path, but a --default_root_dir passed on the command
    # line arrives as a plain str. Normalize before using the "/" operator.
    checkpoint_dir = pathlib.Path(args.default_root_dir) / "checkpoints"
    # exist_ok avoids a race between an existence check and the creation
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    args.callbacks = [
        pl.callbacks.ModelCheckpoint(
            dirpath=checkpoint_dir,
            save_top_k=1,  # keep only the best checkpoint
            verbose=True,
            monitor="validation_loss",
            mode="min",
        )
    ]

    # set default checkpoint if one exists in our checkpoint directory
    if args.resume_from_checkpoint is None:
        ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime)
        if ckpt_list:
            # newest by modification time
            args.resume_from_checkpoint = str(ckpt_list[-1])

    return args
def run_cli():
    """Command-line entry point: parse the arguments, then launch the run."""
    # ---------------------
    # RUN TRAINING
    # ---------------------
    cli_main(build_args())
if __name__ == "__main__":
    # only launch when executed as a script, not when imported
    run_cli()