In [1]:
%load_ext autoreload
%autoreload 2

import functools
import inspect
import pathlib
import os

import argparse
import flax
from flax import jax_utils
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import torch
import tensorflow as tf
import tqdm

import init2winit
import fastmri

import i2w

import pytorch_lightning as pl
from pytorch_lightning import strategies
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus

from fastmri.models import unet as t_unet
from fastmri.pl_modules import data_module
from fastmri.pl_modules import unet_module
from fastmri.data.transforms import UnetDataTransform
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri_examples.unet import train_unet_demo

from init2winit.model_lib import unet as f_unet
from init2winit.dataset_lib import fastmri_dataset
from init2winit.dataset_lib import data_utils
from init2winit.optimizer_lib import optimizers
from init2winit.optimizer_lib import transform

jax.devices()

Extension horovod.torch has not been built: /opt/conda/lib/python3.7/site-packages/horovod/torch/mpi_lib/_mpi_lib.cpython-37m-x86_64-linux-gnu.so not found
If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.


[CpuDevice(id=0)]

# Train step with lightning

In [2]:
# Grabbed from `train_unet_demo.build_args()`.

args = argparse.Namespace(
    accelerations=[4],
    accelerator='gpu',  # Should be `ddp`, but not available in interactive mode
    accumulate_grad_batches=None,
    amp_backend='native',
    amp_level=None,
    auto_lr_find=False,
    auto_scale_batch_size=False,
    auto_select_gpus=False,
    batch_size=8,
    benchmark=False,
#     callbacks=[<pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint object at 0x7f409795e090>],
    center_fractions=[0.08],
    challenge='singlecoil',
    chans=32,
    check_val_every_n_epoch=1,
    checkpoint_callback=None,
    combine_train_val=False,
    data_path=pathlib.PosixPath('/home/dsuo'),
    default_root_dir=pathlib.PosixPath('unet/unet_demo'),
    detect_anomaly=False, deterministic=True,
    devices=None, drop_prob=0.0,
    enable_checkpointing=True,
    enable_model_summary=True,
    enable_progress_bar=True,
    fast_dev_run=False,
    flush_logs_every_n_steps=None,
    gpus=8,
    gradient_clip_algorithm=None,
    gradient_clip_val=None,
    in_chans=1,
    ipus=None,
    limit_predict_batches=1.0,
    limit_test_batches=1.0,
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    log_every_n_steps=50,
    log_gpu_memory=None,
    logger=True,
    lr=0.001,
    lr_gamma=0.1,
    lr_step_size=40,
    mask_type=None,  # Should be `random`, but tying out without
    max_epochs=50,
    max_steps=-1,
    max_time=None,
    min_epochs=None,
    min_steps=None,
    mode='train',
    move_metrics_to_cpu=False,
    multiple_trainloader_mode='max_size_cycle',
    num_log_images=16,
    num_nodes=1,
    num_pool_layers=4,
    num_processes=1,
    num_sanity_val_steps=2,
    num_workers=4,
    out_chans=1,
    overfit_batches=0.0,
    plugins=None,
    precision=32,
    prepare_data_per_node=None,
    process_position=0,
    profiler=None,
    progress_bar_refresh_rate=None,
    reload_dataloaders_every_epoch=False,
    reload_dataloaders_every_n_epochs=0,
    replace_sampler_ddp=False,
    resume_from_checkpoint=None,
    sample_rate=None,
    seed=42,
    stochastic_weight_avg=False,
    strategy='dp',  # This should be None
    sync_batchnorm=False,
    terminate_on_nan=None,
    test_path=None,
    test_sample_rate=None,
    test_split='test',
    test_volume_sample_rate=None,
    tpu_cores=None,
    track_grad_norm=-1,
    use_dataset_cache_file=True,
    val_check_interval=1.0,
    val_sample_rate=None,
    val_volume_sample_rate=None,
    volume_sample_rate=None,
    weight_decay=0.0,
    weights_save_path=None,
    weights_summary='top')

In [3]:
train_transform = UnetDataTransform(args.challenge, mask_func=args.mask_type, use_seed=False)
val_transform = UnetDataTransform(args.challenge, mask_func=args.mask_type)
test_transform = UnetDataTransform(args.challenge)

dm = data_module.FastMriDataModule(
        data_path=args.data_path,
        challenge=args.challenge,
        train_transform=train_transform,
        val_transform=val_transform,
        test_transform=test_transform,
        test_split=args.test_split,
        test_path=args.test_path,
        sample_rate=args.sample_rate,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        distributed_sampler=(args.accelerator in ("ddp", "ddp_cpu")),
    )

## `fastmri`

In [4]:
t_model = unet_module.UnetModule()

## Comparison

In [5]:
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(t_model, datamodule=dm)

  rank_zero_warn("more than one device specific flag has been set")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(limit_test_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(limit_predict_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name             | Type                 | Params
----------------------------------------------------------
0 | NMSE             | DistributedMetricSum | 0     
1 | SSIM             | DistributedMetricSum | 0     
2 | PSNR             | DistributedMetricSum 

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [6]:
plt.plot(np.array(data))

NameError: name 'data' is not defined

# Notes
After Epoch 0:
- We see train/valid loss of 0.012/0.026 in this notebook (widget gets deleted)
- We see train/valid loss of 0.258/0.300 with `python train_unet_demo.py --challenge singlecoil --mask_type random`
- We see train/valid loss of 0.197/0.318 with `python train_unet_demo.py --challenge singlecoil --mask_type random --accelerator gpu --batch-size 8`