In [1]:
import fastmri
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pathlib
import pytorch_lightning as pl

import os
from torch.utils.data import DataLoader
from src.subsample import create_mask_for_mask_type, RandomMaskFunc
from src import transforms as T
from src.mri_data import fetch_dir
from src.data_module import FastMriDataModule
from src.unet.unet_module import UnetModule
from argparse import ArgumentParser



In [2]:
path_config = pathlib.Path("fastmri_dirs.yaml")

configs = dict(
    challenge="singlecoil",
    num_gpus=1,
    backend="mps",
    batch_size=1,
    data_path=fetch_dir("knee_path", path_config),
    default_root_dir=fetch_dir("log_path", path_config) / "unet" / "unet_demo",
    mode="train",  # "train" or "test"
    mask_type="random",  # "random" or "equispaced_fraction"
    center_fractions=[0.08],  # number of center lines to use in the mask
    accelerations=[4],  # acceleration rates to use for the mask
    # model parameters
    in_chans=1,
    out_chans=1,
    chans=32,
    num_pool_layers=4,
    drop_prob=0.0,
    lr=0.001,
    lr_step_size=40,
    lr_gamma=0.1,
    weight_decay=0.0,
    max_epochs=50
)

pl.seed_everything(42)

# mask for transforming the input data
mask = create_mask_for_mask_type(
    configs['mask_type'], configs['center_fractions'], configs['accelerations']
)

# random masks for train, fixed masks for val
train_transform = T.UnetDataTransform(configs['challenge'],
                                      mask_func=mask,
                                      use_seed=False)
val_transform = T.UnetDataTransform(configs['challenge'], mask_func=mask)
test_transform = T.UnetDataTransform(configs['challenge'])

# create a data module
data_module = FastMriDataModule(
    data_path=configs['data_path'],
    challenge=configs['challenge'],
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=test_transform,
    test_path=None,
    batch_size=configs['batch_size'],
    num_workers=4,
)

# create a model
model = UnetModule(
    in_chans=configs['in_chans'],
    out_chans=configs['out_chans'],
    chans=configs['chans'],
    num_pool_layers=configs['num_pool_layers'],
    drop_prob=configs['drop_prob'],
    lr=configs['lr'],
    lr_step_size=configs['lr_step_size'],
    lr_gamma=configs['lr_gamma'],
    weight_decay=configs['weight_decay'],
)

callbacks = [
    pl.callbacks.ModelCheckpoint(
        dirpath=configs['default_root_dir'],
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        save_last=True,
        verbose=True,
    ),
    pl.callbacks.LearningRateMonitor(logging_interval="epoch"),
]

# create a trainer
trainer = pl.Trainer(
    devices=configs['num_gpus'],
    max_epochs=configs['max_epochs'],
    default_root_dir=configs['default_root_dir'],
    accelerator=configs['backend'],
    callbacks=callbacks,
)



#if args.resume_from_checkpoint is None:
#    ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime)
  #  if ckpt_list:
 #       args.resume_from_checkpoint = str(ckpt_list[-1])

Seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/lsantos/Projects/fastMRI/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default


In [4]:
trainer.device_ids

[0]

In [5]:
trainer.accelerator

<pytorch_lightning.accelerators.mps.MPSAccelerator at 0x15fb67b50>

In [6]:
data_module.prepare_data()

In [7]:
data_module.setup(stage="fit")

In [8]:
data_module.transfer_batch_to_device(data_module.train_dataloader(), "mps", dataloader_idx=1)

<torch.utils.data.dataloader.DataLoader at 0x15901d870>

In [3]:
trainer.fit(model, datamodule=data_module)

/Users/lsantos/Projects/fastMRI/.venv/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:652: Checkpoint directory /Users/lsantos/Projects/fastMRI/logs/unet/unet_demo exists and is not empty.

  | Name             | Type                 | Params | Mode 
------------------------------------------------------------------
0 | NMSE             | DistributedMetricSum | 0      | train
1 | SSIM             | DistributedMetricSum | 0      | train
2 | PSNR             | DistributedMetricSum | 0      | train
3 | ValLoss          | DistributedMetricSum | 0      | train
4 | TotExamples      | DistributedMetricSum | 0      | train
5 | TotSliceExamples | DistributedMetricSum | 0      | train
6 | unet             | Unet                 | 7.8 M  | train
------------------------------------------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.024    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/lsantos/Projects/fastMRI/.venv/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.


TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

In [8]:
import pickle
with open("./dataset_cache.pkl", "rb") as f:
    data = pickle.load(f)

In [9]:
data

{PosixPath('data/singlecoil_train'): [FastMRIRawDataSample(fname=PosixPath('data/singlecoil_train/file1000001.h5'), slice_ind=0, metadata={'padding_left': 19, 'padding_right': 354, 'encoding_size': (640, 372, 1), 'recon_size': (320, 320, 1), 'acquisition': 'CORPDFS_FBK', 'max': 0.000851878253624366, 'norm': 0.0596983310320022, 'patient_id': '0beb8905d9b7fad304389b9d4263c57d5b069257ea0fdc5bf7f2675608a47406'}),
  FastMRIRawDataSample(fname=PosixPath('data/singlecoil_train/file1000001.h5'), slice_ind=1, metadata={'padding_left': 19, 'padding_right': 354, 'encoding_size': (640, 372, 1), 'recon_size': (320, 320, 1), 'acquisition': 'CORPDFS_FBK', 'max': 0.000851878253624366, 'norm': 0.0596983310320022, 'patient_id': '0beb8905d9b7fad304389b9d4263c57d5b069257ea0fdc5bf7f2675608a47406'}),
  FastMRIRawDataSample(fname=PosixPath('data/singlecoil_train/file1000001.h5'), slice_ind=2, metadata={'padding_left': 19, 'padding_right': 354, 'encoding_size': (640, 372, 1), 'recon_size': (320, 320, 1), 'acq