In [None]:
import os
import pathlib

import pytorch_lightning as pl

from fastmri.data.mri_data import fetch_dir
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import UnetDataTransform
from fastmri.pl_modules import FastMriDataModule, UnetModule

In [None]:
# Experiment configuration.
seed = 42  # global random seed for the whole notebook
# Seed Python, NumPy and PyTorch (and dataloader workers) right away, before the
# mask, transforms and model are created in later cells, so the run is reproducible.
pl.seed_everything(seed, workers=True)

mask_type = "equispaced"  # k-space undersampling pattern; "equispaced" is used for the brain data
center_fractions = [0.08, 0.04]  # fraction of central (low-frequency) k-space kept, paired by position with accelerations
accelerations = [4, 8]  # undersampling factors, paired by position with center_fractions

In [None]:
# Build the k-space undersampling mask function from the configuration cell.
# Arguments are positional, in (mask type, center fractions, accelerations) order.
mask = create_mask_for_mask_type(
    mask_type,
    center_fractions,
    accelerations,
)

In [None]:
# fastMRI challenge track to train on: either "singlecoil" or "multicoil".
challenge = "multicoil"
# To resume training, point this at a saved model checkpoint (MODEL = checkpoint path):
# resume_from_checkpoint = "MODEL"

# FastMriDataModule
# Args:
#     data_path: Path to root data directory. For example, if knee/path
#         is the root directory with subdirectories multicoil_train and
#         multicoil_val, you would input knee/path for data_path.
#     challenge: Name of challenge from ('multicoil', 'singlecoil').
#     train_transform: A transform object for the training split.
#     val_transform: A transform object for the validation split.
#     test_transform: A transform object for the test split.
#     combine_train_val: Whether to combine train and val splits into one
#         large train dataset. Use this for leaderboard submission.
#     test_split: Name of test split from ("test", "challenge").
#     test_path: An optional test path. Passing this overwrites data_path
#         and test_split.
#     sample_rate [optional]: Fraction of slices of the training data split to use.
#         Can be set to less than 1.0 for rapid prototyping. If not set,
#         it defaults to 1.0. To subsample the dataset either set
#         sample_rate (sample by slice) or volume_sample_rate (sample by
#         volume), but not both.
#     val_sample_rate: Same as sample_rate, but for val split.
#     test_sample_rate: Same as sample_rate, but for test split.
#     volume_sample_rate: Fraction of volumes of the training data split
#         to use. Can be set to less than 1.0 for rapid prototyping. If
#         not set, it defaults to 1.0. To subsample the dataset either
#         set sample_rate (sample by slice) or volume_sample_rate (sample
#         by volume), but not both.
#     val_volume_sample_rate: Same as volume_sample_rate, but for val
#         split.
#     test_volume_sample_rate: Same as volume_sample_rate, but for test
#         split.
#     train_filter: A callable which takes as input a training example
#         metadata, and returns whether it should be part of the training
#         dataset.
#     val_filter: Same as train_filter, but for val split.
#     test_filter: Same as train_filter, but for test split.
#     use_dataset_cache_file: Whether to cache dataset metadata. This is
#         very useful for large datasets like the brain data.
#     batch_size: Batch size.
#     num_workers: Number of workers for PyTorch dataloader.
#     distributed_sampler: Whether to use a distributed sampler. This
#         should be set to True if training with ddp.

# Data-module arguments. The paths are placeholders: point them at your local fastMRI data.
# Use the real types FastMriDataModule expects (Path / float / int / bool), not strings:
# the string "False" would still be truthy, and a string batch size is rejected by the DataLoader.
data_path = pathlib.Path("")  # TODO: root directory containing f"{challenge}_train" and f"{challenge}_val"
test_split = "test"  # test split to run on: either "test" or "challenge"
test_path = None  # optional explicit test directory; a non-None value overrides data_path and test_split
sample_rate = 1.0  # fraction of training slices to use (< 1.0 for rapid prototyping)
batch_size = 1  # fastMRI demo convention: 1 if backend == "ddp" else num_gpus
num_workers = 4  # number of PyTorch DataLoader worker processes
distributed_sampler = True  # should be True when training with DDP

In [None]:
# Per-split data transforms.
# Training passes use_seed=False so its masks are not tied to a fixed seed; validation keeps
# the transform's default use_seed (presumably a reproducible per-file mask — confirm);
# the test transform gets no mask function at all.
train_transform = UnetDataTransform(challenge, mask_func=mask, use_seed=False)
val_transform = UnetDataTransform(challenge, mask_func=mask)
test_transform = UnetDataTransform(challenge)

In [None]:
# Wire the data location, per-split transforms and loader settings into the data module.
data_module = FastMriDataModule(
    # data location and challenge track
    data_path=data_path,
    challenge=challenge,
    # transforms applied to each split
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=test_transform,
    # train on train+val together (leaderboard-submission setting); NOTE(review): validation
    # metrics will then presumably be computed on data that is also trained on
    combine_train_val=True,
    # where the test data lives
    test_split=test_split,
    test_path=test_path,
    # subsampling and DataLoader settings
    sample_rate=sample_rate,
    batch_size=batch_size,
    num_workers=num_workers,
    distributed_sampler=distributed_sampler,
)

In [None]:
# U-Net model. The hyperparameters below are the defaults of the fastMRI U-Net demo
# script (train_unet_demo.py); tune them for your experiment.
model = UnetModule(
    in_chans=1,  # number of input channels to the U-Net
    out_chans=1,  # number of output channels of the U-Net
    chans=32,  # number of top-level U-Net channels
    num_pool_layers=4,  # number of U-Net pooling (down/up-sampling) layers
    drop_prob=0.0,  # dropout probability
    lr=0.001,  # learning rate
    lr_step_size=40,  # step size (in epochs) of the learning-rate schedule
    lr_gamma=0.1,  # multiplicative learning-rate decay factor
    weight_decay=0.0,  # weight-decay regularization strength
)

In [None]:
# Trainer configuration:
#   gpus: number of GPUs to train on (passed to the Trainer as `devices`, not `num_nodes`,
#       which counts machines).
#   strategy: distributed training strategy; "ddp_notebook" is the DDP variant that can be
#       launched from an interactive notebook.
#   deterministic: if True, use deterministic algorithms (reproducible, potentially slower).
#   default_root_dir: directory for logs and checkpoints; "" means Lightning's default location.
#   max_epoch: maximum number of training epochs.
# The random seed (`seed`) is defined in the configuration cell at the top of the notebook.
#
# NOTE(review): FastMriDataModule builds its own samplers when distributed_sampler=True.
# fastMRI's own training scripts disable Lightning's automatic sampler replacement
# (`replace_sampler_ddp=False` in PL 1.x, `use_distributed_sampler=False` in PL 2.x);
# consider adding the flag that matches the installed Lightning version.

gpus = 1
strategy = "ddp_notebook"
deterministic = True
default_root_dir = ""
max_epoch = 100

trainer = pl.Trainer(
    accelerator="gpu",
    devices=gpus,
    strategy=strategy,
    deterministic=deterministic,
    # Map an empty string to None so Lightning applies its own default directory.
    default_root_dir=default_root_dir or None,
    max_epochs=max_epoch,
)