In [1]:
%load_ext autoreload
%autoreload 2

import argparse
import pathlib
import os
from typing import Dict, Tuple

import flax
from flax import jax_utils
import jax
import jax.numpy as jnp
import h5py
import matplotlib.pyplot as plt
import numpy as np
import optax
import torch
import tensorflow as tf
import tqdm

import init2winit
import fastmri

import i2w

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

from fastmri.models import unet as t_unet
from fastmri.pl_modules import data_module
from fastmri.pl_modules import unet_module
from fastmri.data.mri_data import SliceDataset
from fastmri.data.transforms import *
from fastmri.data.subsample import *

from init2winit.model_lib import unet as f_unet
from init2winit.dataset_lib import fastmri_dataset
from init2winit.dataset_lib import data_utils
from init2winit.optimizer_lib import optimizers
from init2winit.optimizer_lib import transform
from init2winit.model_lib import metrics

# Display the devices visible to JAX. The captured output for this cell lists a
# single CpuDevice, i.e. JAX had no accelerator in the recorded session.
jax.devices()

Extension horovod.torch has not been built: /opt/conda/lib/python3.7/site-packages/horovod/torch/mpi_lib/_mpi_lib.cpython-37m-x86_64-linux-gnu.so not found
If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.


[CpuDevice(id=0)]

In [2]:
# Settings copied from `train_unet_demo.build_args()`, spelled out as a plain
# mapping so every trainer / model / data option is visible in one place.
# Key order matches the original dump.
args = argparse.Namespace(**{
    'accelerations': [4],
    # Should be `ddp`, but that is not available in interactive mode.
    'accelerator': 'gpu',
    'accumulate_grad_batches': None,
    'amp_backend': 'native',
    'amp_level': None,
    'auto_lr_find': False,
    'auto_scale_batch_size': False,
    'auto_select_gpus': False,
    'batch_size': 8,
    'benchmark': False,
    # `callbacks` is intentionally left out: the dumped value was a live
    # ModelCheckpoint object
    # (<pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint object at 0x7f409795e090>)
    # which has no literal form.
    'center_fractions': [0.08],
    'challenge': 'singlecoil',
    'chans': 32,
    'check_val_every_n_epoch': 1,
    'checkpoint_callback': None,
    'combine_train_val': False,
    'data_path': pathlib.PosixPath('/home/dsuo'),
    'default_root_dir': pathlib.PosixPath('unet/unet_demo'),
    'detect_anomaly': False,
    'deterministic': True,
    'devices': None,
    'drop_prob': 0.0,
    'enable_checkpointing': True,
    'enable_model_summary': True,
    'enable_progress_bar': True,
    'fast_dev_run': False,
    'flush_logs_every_n_steps': None,
    'gpus': 8,
    'gradient_clip_algorithm': None,
    'gradient_clip_val': None,
    'in_chans': 1,
    'ipus': None,
    'limit_predict_batches': 1.0,
    'limit_test_batches': 1.0,
    'limit_train_batches': 1.0,
    'limit_val_batches': 1.0,
    'log_every_n_steps': 50,
    'log_gpu_memory': None,
    'logger': True,
    'lr': 0.001,
    'lr_gamma': 0.1,
    'lr_step_size': 40,
    # Should be `random`, but trying out without a mask.
    'mask_type': 'random',
    'max_epochs': 50,
    'max_steps': -1,
    'max_time': None,
    'min_epochs': None,
    'min_steps': None,
    'mode': 'train',
    'move_metrics_to_cpu': False,
    'multiple_trainloader_mode': 'max_size_cycle',
    'num_log_images': 16,
    'num_nodes': 1,
    'num_pool_layers': 4,
    'num_processes': 1,
    'num_sanity_val_steps': 2,
    'num_workers': 4,
    'out_chans': 1,
    'overfit_batches': 0.0,
    'plugins': None,
    'precision': 32,
    'prepare_data_per_node': None,
    'process_position': 0,
    'profiler': None,
    'progress_bar_refresh_rate': None,
    'reload_dataloaders_every_epoch': False,
    'reload_dataloaders_every_n_epochs': 0,
    'replace_sampler_ddp': False,
    'resume_from_checkpoint': None,
    'sample_rate': None,
    'seed': 42,
    'stochastic_weight_avg': False,
    # This should be None.
    'strategy': 'dp',
    'sync_batchnorm': False,
    'terminate_on_nan': None,
    'test_path': None,
    'test_sample_rate': None,
    'test_split': 'test',
    'test_volume_sample_rate': None,
    'tpu_cores': None,
    'track_grad_norm': -1,
    'use_dataset_cache_file': True,
    'val_check_interval': 1.0,
    'val_sample_rate': None,
    'val_volume_sample_rate': None,
    'volume_sample_rate': None,
    'weight_decay': 0.0,
    'weights_save_path': None,
    'weights_summary': 'top',
})

In [3]:
pl.seed_everything(args.seed)

# ------------
# data
# ------------
# Build the k-space mask function described by the arguments.
kspace_mask = create_mask_for_mask_type(
    args.mask_type, args.center_fractions, args.accelerations
)

# Experiment switch. The original cell built the mask and then immediately
# overwrote it with `None`, which made the "random masks" comment misleading.
# The default (False) keeps that behaviour: the transforms receive no mask.
# Set to True to pass the configured mask to the train/val transforms.
USE_KSPACE_MASK = False
mask = kspace_mask if USE_KSPACE_MASK else None

# When a mask is enabled, the train transform is built with `use_seed=False`
# (random masks) and the val transform with the default seeding (fixed masks).
# With `mask` set to None, neither transform is given a mask function.
train_transform = UnetDataTransform(args.challenge, mask_func=mask, use_seed=False)
val_transform = UnetDataTransform(args.challenge, mask_func=mask)
test_transform = UnetDataTransform(args.challenge)

# Lightning data module: owns the train/val/test data loaders.
dm = data_module.FastMriDataModule(
    data_path=args.data_path,
    challenge=args.challenge,
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=test_transform,
    test_split=args.test_split,
    test_path=args.test_path,
    sample_rate=args.sample_rate,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    # Only the ddp accelerators get a distributed sampler.
    distributed_sampler=(args.accelerator in ("ddp", "ddp_cpu")),
)

# ------------
# model
# ------------
# PyTorch Lightning wrapper around the fastMRI U-Net, configured from `args`.
model = unet_module.UnetModule(
    in_chans=args.in_chans,
    out_chans=args.out_chans,
    chans=args.chans,
    num_pool_layers=args.num_pool_layers,
    drop_prob=args.drop_prob,
    lr=args.lr,
    lr_step_size=args.lr_step_size,
    lr_gamma=args.lr_gamma,
    weight_decay=args.weight_decay,
)

Global seed set to 42


In [4]:
# Flax U-Net from init2winit, built with the library's default hyperparameters,
# an empty dataset-metadata dict, and the named loss / metrics bundle.
f_model = f_unet.UNetModel(
    f_unet.DEFAULT_HPARAMS,
    {},
    'mean_absolute_error',
    'image_reconstruction_metrics',
)

# Port the PyTorch U-Net weights into parameters for the Flax module.
# `model.cpu()` moves the Lightning module to the CPU before its `unet`
# submodule is handed to the converter.
torch_unet = model.cpu().unet
f_params = i2w.convert_params(torch_unet, f_model.flax_module)

Abs sum diff: 0.14279632


In [11]:
# SSIM values reported by the Lightning module's validation step: {file: {slice: value}}.
ssim = {}
# SSIM values recomputed with init2winit's metric on the same predictions: {file: {slice: value}}.
our_ssim = {}


class cb(Callback):
    """Validation callback that cross-checks fastMRI's SSIM against init2winit's.

    After every validation batch it re-runs the module on the batch images,
    undoes the per-sample normalisation, computes SSIM with
    `metrics.structural_similarity`, and stores the result in `our_ssim`.
    The values the validation step returned under `outputs['ssim_vals']` are
    copied into `ssim`, and both numbers are printed side by side.
    """

    def on_validation_batch_end(self, trainer, module, outputs, batch, batch_idx, dataloader_idx):
        """Record and print both SSIM variants for every slice in `batch`.

        Args:
            trainer: The running Trainer (unused).
            module: The LightningModule being validated; used for the forward pass.
            outputs: Validation-step output; must contain `'ssim_vals'`, a
                nested mapping of file -> slice -> SSIM value.
            batch: Validation batch exposing `image`, `target`, `mean`, `std`,
                `max_value`, `fname` and `slice_num`.
            batch_idx: Index of the batch (unused).
            dataloader_idx: Index of the dataloader (unused).
        """
        ssim_vals = outputs['ssim_vals']

        # Run the module handed to the hook rather than the notebook-global
        # `model`, on whatever device that module lives on (no hard-coded CUDA).
        with torch.no_grad():
            output = module(batch.image.to(module.device)).cpu()

        # Undo the per-sample normalisation. `.cpu()` is a no-op for tensors
        # already on the CPU and keeps the arithmetic valid otherwise.
        mean = batch.mean.cpu().unsqueeze(1).unsqueeze(2)
        std = batch.std.cpu().unsqueeze(1).unsqueeze(2)
        output = output * std + mean
        target = batch.target.cpu() * std + mean

        # Distinct loop names: the original reused `target` for the per-slice
        # item, shadowing the batch tensor it was iterating over.
        for pred_slice, target_slice, max_value, fname, slice_num in zip(
                output, target, batch.max_value, batch.fname, batch.slice_num):
            slice_key = slice_num.detach().cpu().numpy().item()
            our_ssim.setdefault(fname, {})[slice_key] = metrics.structural_similarity(
                pred_slice.detach().numpy(),
                target_slice.detach().numpy(),
                max_value.detach().cpu().numpy().item(),
            )

        for fname, per_slice in ssim_vals.items():
            file_ssim = ssim.setdefault(fname, {})
            for slice_key, value in per_slice.items():
                file_ssim[slice_key] = value
                print(our_ssim[fname][slice_key], value)


# Register the callback so `Trainer.from_argparse_args(args)` picks it up.
args.callbacks = [cb()]

In [12]:
# Build the Lightning trainer from the `args` namespace (including the SSIM
# comparison callback stored in `args.callbacks`).
trainer = pl.Trainer.from_argparse_args(args)
# NOTE(review): `_data_connector` is a private Trainer attribute, so this call may
# break across pytorch_lightning versions. `trainer.validate(model, dm)` in the
# next cell also receives the datamodule -- confirm whether this line is needed.
trainer._data_connector.attach_data(model, datamodule=dm)

  rank_zero_warn("more than one device specific flag has been set")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(limit_test_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(limit_predict_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..


In [13]:
# Run the validation loop; the registered callback fills `ssim` / `our_ssim` and
# prints one "<recomputed SSIM> <fastMRI SSIM>" line per slice. The recorded run
# below was stopped with a KeyboardInterrupt, so only some slices are covered.
trainer.validate(model, dm)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Validation: 0it [00:00, ?it/s]

0.6171981 tensor([0.6172], dtype=torch.float64)
0.54074556 tensor([0.5407], dtype=torch.float64)
0.48976827 tensor([0.4898], dtype=torch.float64)
0.45737946 tensor([0.4574], dtype=torch.float64)
0.45630378 tensor([0.4563], dtype=torch.float64)
0.44610286 tensor([0.4461], dtype=torch.float64)
0.432169 tensor([0.4322], dtype=torch.float64)
0.39749438 tensor([0.3975], dtype=torch.float64)
0.3542205 tensor([0.3542], dtype=torch.float64)
0.29675758 tensor([0.2968], dtype=torch.float64)
0.27975807 tensor([0.2798], dtype=torch.float64)
0.27084792 tensor([0.2708], dtype=torch.float64)
0.28663078 tensor([0.2866], dtype=torch.float64)
0.28662282 tensor([0.2866], dtype=torch.float64)
0.29736933 tensor([0.2974], dtype=torch.float64)
0.30722082 tensor([0.3072], dtype=torch.float64)
0.3196867 tensor([0.3197], dtype=torch.float64)
0.33244863 tensor([0.3324], dtype=torch.float64)
0.33592173 tensor([0.3359], dtype=torch.float64)
0.33921346 tensor([0.3392], dtype=torch.float64)
0.33476886 tensor([0.3348

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [19]:
# Directory holding the validation volumes. Derived from `args` instead of a
# hard-coded absolute path; with the values in this notebook it resolves to
# '/home/dsuo/singlecoil_val', the same location as before.
# NOTE(review): assumes the `<data_path>/<challenge>_val` layout used by the
# fastMRI data module -- confirm if the data is stored differently.
val_dir = pathlib.Path(args.data_path) / f'{args.challenge}_val'

# The seed tensor is identical for every slice, so build it once outside the loops.
example_seed = tf.cast(jax.random.PRNGKey(0), tf.int64)

# For every slice already scored by the PyTorch validation run, push the same
# slice through the init2winit preprocessing and the converted Flax U-Net, then
# print "<Flax SSIM> <fastMRI SSIM>" for comparison.
for file in ssim:
    path = os.path.join(val_dir, file)
    for slice_idx in ssim[file]:
        inputs = i2w.get_slice(path, slice_idx)
        data = fastmri_dataset._process_example(*inputs, seed=example_seed)
        f_out = f_model.flax_module.apply(f_params, data['inputs'].numpy())

        # Undo the normalisation on both prediction and target before scoring.
        std = data['std'].numpy()
        mean = data['mean'].numpy()
        # `inputs[-1]` is passed as the SSIM value range (presumably the volume's
        # max value -- verify against `i2w.get_slice`).
        ss = metrics.structural_similarity(
            f_out * std + mean,
            data['targets'].numpy() * std + mean,
            inputs[-1],
        )
        print(ss, ssim[file][slice_idx])

0.5535585 tensor([0.6172], dtype=torch.float64)
0.5451602 tensor([0.5407], dtype=torch.float64)
0.52114695 tensor([0.4898], dtype=torch.float64)
0.49573553 tensor([0.4574], dtype=torch.float64)
0.48832178 tensor([0.4563], dtype=torch.float64)
0.47254547 tensor([0.4461], dtype=torch.float64)
0.45443764 tensor([0.4322], dtype=torch.float64)
0.41712266 tensor([0.3975], dtype=torch.float64)
0.37166265 tensor([0.3542], dtype=torch.float64)
0.3166698 tensor([0.2968], dtype=torch.float64)
0.2916299 tensor([0.2798], dtype=torch.float64)
0.28306213 tensor([0.2708], dtype=torch.float64)
0.29582402 tensor([0.2866], dtype=torch.float64)
0.29419816 tensor([0.2866], dtype=torch.float64)
0.30037448 tensor([0.2974], dtype=torch.float64)
0.3047076 tensor([0.3072], dtype=torch.float64)
0.31222594 tensor([0.3197], dtype=torch.float64)
0.32618305 tensor([0.3324], dtype=torch.float64)
0.32425913 tensor([0.3359], dtype=torch.float64)
0.3262578 tensor([0.3392], dtype=torch.float64)
0.32289833 tensor([0.3348]

In [None]:
# Scratch inspection cell (never executed in the saved notebook): display the
# datamodule currently attached to the trainer.
trainer.datamodule

In [None]:
# Scratch inspection cell (never executed in the saved notebook): display the
# imported `fastmri.pl_modules.data_module` module object.
data_module

In [None]:
# NOTE(review): unexecuted draft of the Flax comparison loop. `ssim_vals` is only
# a local variable inside the validation callback in the cells shown, so this
# cell would raise NameError at notebook scope unless it is defined elsewhere.
# The executed comparison cell iterates over the global `ssim` dict instead.
# The unusual indentation suggests this was pasted from inside the callback.
for file in ssim_vals:
            # Location of the validation volume for this file.
            path = os.path.join('/home/dsuo/singlecoil_val', file)
            for slice_idx in ssim_vals[file]:
                inputs = i2w.get_slice(path, slice_idx)
                data = fastmri_dataset._process_example(*inputs, seed=tf.cast(jax.random.PRNGKey(0), tf.int64))
                # Unlike the executed cell, `data['inputs']` is passed without `.numpy()`.
                f_out = f_model.flax_module.apply(f_params, data['inputs'])