In [1]:
import fastmri
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import pathlib
import pytorch_lightning as pl

import os
from torch.utils.data import DataLoader
from src.subsample import create_mask_for_mask_type, RandomMaskFunc
from src import transforms as T
from src.mri_data import fetch_dir
from src.data_module import FastMriDataModule, AnnotatedSliceDataset
from src.unet.unet_module import UnetModule
from argparse import ArgumentParser

from pytorch_lightning import loggers as pl_loggers
tensorboard = pl_loggers.TensorBoardLogger('./')

In [2]:
path_config = pathlib.Path("fastmri_dirs.yaml")

version_name = "unet_roi" # one of "unet", "unet_attn", "unet_ssim", "unet_attn_ssim"

if version_name == "unet":
    attn_layer = False
    metric = "l1"
    use_roi = False
elif version_name == "unet_roi":
    attn_layer = False
    metric = "l1"
    use_roi = True
elif version_name == "unet_attn":
    attn_layer = True
    metric = "l1"
    use_roi = True
elif version_name == "unet_ssim":
    attn_layer = False
    metric = "ssim"
    use_roi = True
elif version_name == "unet_attn_ssim":
    attn_layer = True
    metric = "ssim"
    use_roi = True

configs = dict(
    challenge="singlecoil",
    num_gpus=1,
    backend="mps",
    batch_size=1,
    data_path=fetch_dir("knee_path", path_config),
    default_root_dir=fetch_dir("log_path", path_config) / "unet" / version_name,
    mode="train",  # "train" or "test"
    mask_type="random",  # "random" or "equispaced_fraction"
    center_fractions=[0.08],  # number of center lines to use in the mask
    accelerations=[4],  # acceleration rates to use for the mask
    # model parameters
    in_chans=1,
    out_chans=1,
    chans=32,
    # chans=256,
    num_pool_layers=4,
    drop_prob=0.0,
    lr=0.01,
    lr_step_size=40,
    lr_gamma=0.1,
    weight_decay=0.0,
    max_epochs=10,
    metric=metric,
    roi_weight=0.25,
    attn_layer=attn_layer,
    use_roi=use_roi,
)

pl.seed_everything(42)

# mask for transforming the input data
mask = create_mask_for_mask_type(
    configs['mask_type'], configs['center_fractions'], configs['accelerations']
)

# random masks for train, fixed masks for val
train_transform = T.UnetDataTransform(configs['challenge'], mask_func=mask, use_seed=False)
val_transform = T.UnetDataTransform(configs['challenge'], mask_func=mask)
test_transform = T.UnetDataTransform(configs['challenge'])

# create a data module
data_module = FastMriDataModule(
    data_path=configs['data_path'],
    challenge=configs['challenge'],
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=test_transform,
    test_path=None,
    batch_size=configs['batch_size'],
    num_workers=10,
)

# create a model
model = UnetModule(
    in_chans=configs['in_chans'],
    out_chans=configs['out_chans'],
    chans=configs['chans'],
    num_pool_layers=configs['num_pool_layers'],
    drop_prob=configs['drop_prob'],
    lr=configs['lr'],
    lr_step_size=configs['lr_step_size'],
    lr_gamma=configs['lr_gamma'],
    weight_decay=configs['weight_decay'],
    metric=configs['metric'],
    roi_weight=configs['roi_weight'],
    attn_layer=configs['attn_layer'],
    use_roi=configs['use_roi'],
)

callbacks = [
    pl.callbacks.ModelCheckpoint(
        # dirpath=configs['default_root_dir'],
        dirpath=configs['default_root_dir'] / "checkpoints",
        # monitor="val_loss",
        monitor="validation_loss",
        mode="min",
        save_top_k=1,
        save_last=True,
        verbose=True,
    ),
    pl.callbacks.LearningRateMonitor(logging_interval="epoch"),
]

# create a trainer
trainer = pl.Trainer(
    devices=configs['num_gpus'],
    max_epochs=configs['max_epochs'],
    default_root_dir=configs['default_root_dir'],
    accelerator=configs['backend'],
    callbacks=callbacks,
    logger=tensorboard,
)

#if args.resume_from_checkpoint is None:
#    ckpt_list = sorted(checkpoint_dir.glob("*.ckpt"), key=os.path.getmtime)
  #  if ckpt_list:
 #       args.resume_from_checkpoint = str(ckpt_list[-1])

Global seed set to 42
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Run the line of code to follow the training process of the model. The training process will be displayed in the TensorBoard:

tensorboard --logdir ./lightning_logs

annotations generate multiple images: 48450 train samples
annotations generate a single image: 41877 train samples


In [3]:
trainer.fit(model, datamodule=data_module)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name             | Type                 | Params
----------------------------------------------------------
0 | NMSE             | DistributedMetricSum | 0     
1 | SSIM             | DistributedMetricSum | 0     
2 | PSNR             | DistributedMetricSum | 0     
3 | ValLoss          | DistributedMetricSum | 0     
4 | TotExamples      | DistributedMetricSum | 0     
5 | TotSliceExamples | DistributedMetricSum | 0     
6 | unet             | Unet                 | 7.8 M 
----------------------------------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.024    Total estimated model params size (MB)


Epoch 0: 100%|██████████| 41877/41877 [48:13<00:00, 14.47it/s, loss=0.473, v_num=15, validation_loss=0.392]

Epoch 0, global step 34742: 'validation_loss' reached 0.39212 (best 0.39212), saving model to 'logs/unet/unet_roi/checkpoints/epoch=0-step=34742-v1.ckpt' as top 1


Epoch 1: 100%|██████████| 41877/41877 [48:50<00:00, 14.29it/s, loss=0.369, v_num=15, validation_loss=0.401] 

Epoch 1, global step 69484: 'validation_loss' was not in top 1


Epoch 2: 100%|██████████| 41877/41877 [48:52<00:00, 14.28it/s, loss=0.366, v_num=15, validation_loss=0.369] 

Epoch 2, global step 104226: 'validation_loss' reached 0.36913 (best 0.36913), saving model to 'logs/unet/unet_roi/checkpoints/epoch=2-step=104226.ckpt' as top 1


Epoch 3: 100%|██████████| 41877/41877 [48:57<00:00, 14.26it/s, loss=0.318, v_num=15, validation_loss=0.368] 

Epoch 3, global step 138968: 'validation_loss' reached 0.36751 (best 0.36751), saving model to 'logs/unet/unet_roi/checkpoints/epoch=3-step=138968.ckpt' as top 1


Epoch 4: 100%|██████████| 41877/41877 [48:53<00:00, 14.27it/s, loss=0.399, v_num=15, validation_loss=0.364] 

Epoch 4, global step 173710: 'validation_loss' reached 0.36416 (best 0.36416), saving model to 'logs/unet/unet_roi/checkpoints/epoch=4-step=173710.ckpt' as top 1


Epoch 5: 100%|██████████| 41877/41877 [48:52<00:00, 14.28it/s, loss=0.404, v_num=15, validation_loss=0.364] 

Epoch 5, global step 208452: 'validation_loss' reached 0.36398 (best 0.36398), saving model to 'logs/unet/unet_roi/checkpoints/epoch=5-step=208452.ckpt' as top 1


Epoch 6: 100%|██████████| 41877/41877 [48:55<00:00, 14.27it/s, loss=0.459, v_num=15, validation_loss=0.363] 

Epoch 6, global step 243194: 'validation_loss' reached 0.36291 (best 0.36291), saving model to 'logs/unet/unet_roi/checkpoints/epoch=6-step=243194.ckpt' as top 1


Epoch 7: 100%|██████████| 41877/41877 [48:55<00:00, 14.27it/s, loss=0.34, v_num=15, validation_loss=0.362]  

Epoch 7, global step 277936: 'validation_loss' reached 0.36249 (best 0.36249), saving model to 'logs/unet/unet_roi/checkpoints/epoch=7-step=277936.ckpt' as top 1


Epoch 8: 100%|██████████| 41877/41877 [48:55<00:00, 14.26it/s, loss=0.351, v_num=15, validation_loss=0.362] 

Epoch 8, global step 312678: 'validation_loss' reached 0.36216 (best 0.36216), saving model to 'logs/unet/unet_roi/checkpoints/epoch=8-step=312678.ckpt' as top 1


Epoch 9: 100%|██████████| 41877/41877 [49:13<00:00, 14.18it/s, loss=0.264, v_num=15, validation_loss=0.361] 

Epoch 9, global step 347420: 'validation_loss' reached 0.36090 (best 0.36090), saving model to 'logs/unet/unet_roi/checkpoints/epoch=9-step=347420.ckpt' as top 1
`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 41877/41877 [49:13<00:00, 14.18it/s, loss=0.264, v_num=15, validation_loss=0.361]


In [10]:
ckpt_list = sorted(configs['default_root_dir'].glob("*.ckpt"), key=os.path.getmtime)
if ckpt_list:
    resume_from_checkpoint = str(ckpt_list[-1])

In [11]:
trainer = pl.Trainer(
    devices=configs['num_gpus'],
    max_epochs=configs['max_epochs'],
    default_root_dir=configs['default_root_dir'],
    accelerator=configs['backend'],
    callbacks=callbacks,
    logger=tensorboard,
    resume_from_checkpoint=resume_from_checkpoint,
)

trainer.fit(model, datamodule=data_module)

  rank_zero_deprecation(
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [4]:
# take one sample from data_module
#data_module.setup()
sample = next(iter(data_module.val_dataloader()))

In [5]:
sample.annotation

{'fname': [''],
 'slice': [''],
 'study_level': [''],
 'x': tensor([-1]),
 'y': tensor([-1]),
 'width': tensor([-1]),
 'height': tensor([-1]),
 'label': ['']}

In [6]:
sample.annotation['x'].device

device(type='cpu')

In [7]:
shape = sample.target.shape
annotation = {'fname': 'file1000001',
    'slice': 15,
    'study_level': 'No',
    'x': 117,
    'y': 146,
    'width': 20,
    'height': 12,
    'label': 'Bone- Subchondral edema'}

In [8]:
mask = torch.ones(shape)
x, y, w, h = annotation['x'], annotation['y'], annotation['width'], annotation['height']
if x >= 0 and y >= 0 and w > 0 and h > 0:
    mask[..., y:y+h, x:x+w] = 2

In [9]:
mask

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])

In [30]:
(y, y+h, x, x+w)

(146, 158, 117, 137)

In [16]:
data_path = fetch_dir("knee_path", path_config) / f"{configs['challenge']}_train"
dataset = AnnotatedSliceDataset(
    root=data_path,
    transform=val_transform,
    challenge=configs['challenge'],
    use_dataset_cache=False,
    raw_sample_filter=None,
    subsplit='knee',
    multiple_annotation_policy='all',
)

NameError: name 'AnnotatedSliceDataset' is not defined

In [14]:
s = dataset[0]

In [27]:
a = dataset.__getitem__(16)

In [28]:
a.annotation

{'fname': 'file1000001',
 'slice': 15,
 'study_level': 'No',
 'x': 117,
 'y': 146,
 'width': 20,
 'height': 12,
 'label': 'Bone- Subchondral edema'}

In [27]:
type(dataset)

src.mri_data.AnnotatedSliceDataset

In [3]:
trainer.fit(model, datamodule=data_module)

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name             | Type                 | Params
----------------------------------------------------------
0 | NMSE             | DistributedMetricSum | 0     
1 | SSIM             | DistributedMetricSum | 0     
2 | PSNR             | DistributedMetricSum | 0     
3 | ValLoss          | DistributedMetricSum | 0     
4 | TotExamples      | DistributedMetricSum | 0     
5 | TotSliceExamples | DistributedMetricSum | 0     
6 | unet             | Unet                 | 7.8 M 
----------------------------------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.024    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [8]:
import pickle
with open("./dataset_cache.pkl", "rb") as f:
    data = pickle.load(f)

In [9]:
data

{PosixPath('data/singlecoil_train'): [FastMRIRawDataSample(fname=PosixPath('data/singlecoil_train/file1000001.h5'), slice_ind=0, metadata={'padding_left': 19, 'padding_right': 354, 'encoding_size': (640, 372, 1), 'recon_size': (320, 320, 1), 'acquisition': 'CORPDFS_FBK', 'max': 0.000851878253624366, 'norm': 0.0596983310320022, 'patient_id': '0beb8905d9b7fad304389b9d4263c57d5b069257ea0fdc5bf7f2675608a47406'}),
  FastMRIRawDataSample(fname=PosixPath('data/singlecoil_train/file1000001.h5'), slice_ind=1, metadata={'padding_left': 19, 'padding_right': 354, 'encoding_size': (640, 372, 1), 'recon_size': (320, 320, 1), 'acquisition': 'CORPDFS_FBK', 'max': 0.000851878253624366, 'norm': 0.0596983310320022, 'patient_id': '0beb8905d9b7fad304389b9d4263c57d5b069257ea0fdc5bf7f2675608a47406'}),
  FastMRIRawDataSample(fname=PosixPath('data/singlecoil_train/file1000001.h5'), slice_ind=2, metadata={'padding_left': 19, 'padding_right': 354, 'encoding_size': (640, 372, 1), 'recon_size': (320, 320, 1), 'acq