In [1]:
%load_ext autoreload
%autoreload 2

import functools
import inspect
import pathlib
import os

import argparse
import flax
from flax import jax_utils
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import torch
import tensorflow as tf
import tqdm

import init2winit
import fastmri

import i2w

import pytorch_lightning as pl
from pytorch_lightning import strategies
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus

from fastmri.models import unet as t_unet
from fastmri.pl_modules import data_module
from fastmri.pl_modules import unet_module
from fastmri.data.transforms import UnetDataTransform
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri_examples.unet import train_unet_demo

from init2winit.model_lib import unet as f_unet
from init2winit.dataset_lib import fastmri_dataset
from init2winit.dataset_lib import data_utils
from init2winit.optimizer_lib import optimizers
from init2winit.optimizer_lib import transform

# List the devices JAX can use. The recorded output shows only a CPU device, so
# the flax side of this comparison runs on CPU while torch runs on the GPUs.
jax.devices()

Extension horovod.torch has not been built: /opt/conda/lib/python3.7/site-packages/horovod/torch/mpi_lib/_mpi_lib.cpython-37m-x86_64-linux-gnu.so not found
If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.


[CpuDevice(id=0)]

# Train step with Lightning: tying out fastMRI (PyTorch Lightning) against init2winit (flax)

In [2]:
# Hand-built stand-in for `train_unet_demo.build_args()` (values grabbed from it),
# so the notebook does not need to parse command-line arguments. Entries marked
# "Should be ..." deliberately deviate from what the demo would use.

args = argparse.Namespace(
    accelerations=[4],
    accelerator='gpu',  # Should be `ddp`, but not available in interactive mode
    accumulate_grad_batches=None,
    amp_backend='native',
    amp_level=None,
    auto_lr_find=False,
    auto_scale_batch_size=False,
    auto_select_gpus=False,
    batch_size=8,
    benchmark=False,
    # `callbacks` is omitted: the demo passes a ModelCheckpoint instance, whose
    # repr (kept below for reference) is not valid Python.
#     callbacks=[<pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint object at 0x7f409795e090>],
    center_fractions=[0.08],
    challenge='singlecoil',
    chans=32,
    check_val_every_n_epoch=1,
    checkpoint_callback=None,
    combine_train_val=False,
    # NOTE(review): machine-specific absolute path; point this at your own data
    # directory (it is passed to `FastMriDataModule(data_path=...)`).
    data_path=pathlib.PosixPath('/home/dsuo'),
    default_root_dir=pathlib.PosixPath('unet/unet_demo'),
    detect_anomaly=False, deterministic=True,
    devices=None, drop_prob=0.0,
    enable_checkpointing=True,
    enable_model_summary=True,
    enable_progress_bar=True,
    fast_dev_run=False,
    flush_logs_every_n_steps=None,
    # NOTE(review): `gpus` together with `accelerator` is the likely source of the
    # "more than one device specific flag has been set" warning in the Trainer output.
    gpus=8,
    gradient_clip_algorithm=None,
    gradient_clip_val=None,
    in_chans=1,
    ipus=None,
    limit_predict_batches=1.0,
    limit_test_batches=1.0,
    limit_train_batches=1.0,
    limit_val_batches=1.0,
    log_every_n_steps=50,
    log_gpu_memory=None,
    logger=True,
    lr=0.001,
    lr_gamma=0.1,
    lr_step_size=40,
    mask_type=None,  # Should be `random`; None for now so the tie-out runs without a mask function (this value feeds `mask_func` below)
    max_epochs=50,
    max_steps=-1,
    max_time=None,
    min_epochs=None,
    min_steps=None,
    mode='train',
    move_metrics_to_cpu=False,
    multiple_trainloader_mode='max_size_cycle',
    num_log_images=16,
    num_nodes=1,
    num_pool_layers=4,
    num_processes=1,
    num_sanity_val_steps=2,
    num_workers=4,
    out_chans=1,
    overfit_batches=0.0,
    plugins=None,
    precision=32,
    prepare_data_per_node=None,
    process_position=0,
    profiler=None,
    progress_bar_refresh_rate=None,
    reload_dataloaders_every_epoch=False,
    reload_dataloaders_every_n_epochs=0,
    replace_sampler_ddp=False,
    resume_from_checkpoint=None,
    sample_rate=None,
    seed=42,  # RNG seed value
    stochastic_weight_avg=False,
    strategy='dp',  # This should be None (note: the comparison loop presumably relies on DP wrapping when it unwraps `trainer.strategy.model.module.module`)
    sync_batchnorm=False,
    terminate_on_nan=None,
    test_path=None,
    test_sample_rate=None,
    test_split='test',
    test_volume_sample_rate=None,
    tpu_cores=None,
    track_grad_norm=-1,
    use_dataset_cache_file=True,
    val_check_interval=1.0,
    val_sample_rate=None,
    val_volume_sample_rate=None,
    volume_sample_rate=None,
    weight_decay=0.0,
    weights_save_path=None,
    weights_summary='top')

In [3]:
# `UnetDataTransform` takes a mask *function*, not the mask-type string stored in
# `args.mask_type`, so build one from the mask settings. With `mask_type=None`
# (the current tie-out setting) there is nothing to build and no mask is applied,
# exactly as before.
if args.mask_type is None:
    mask_func = None
else:
    mask_func = create_mask_for_mask_type(
        args.mask_type, args.center_fractions, args.accelerations
    )

train_transform = UnetDataTransform(args.challenge, mask_func=mask_func, use_seed=False)
val_transform = UnetDataTransform(args.challenge, mask_func=mask_func)
test_transform = UnetDataTransform(args.challenge)

# Distributed (DDP) runs need the datamodule's own distributed sampler because
# `replace_sampler_ddp=False` in `args`. This config selects the distribution
# mode through `strategy` (not only the legacy `accelerator` value), so check
# both; with the current `accelerator='gpu'`, `strategy='dp'` this stays False.
_distributed_modes = ("ddp", "ddp_cpu")
dm = data_module.FastMriDataModule(
    data_path=args.data_path,
    challenge=args.challenge,
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=test_transform,
    test_split=args.test_split,
    test_path=args.test_path,
    sample_rate=args.sample_rate,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    distributed_sampler=(
        args.accelerator in _distributed_modes or args.strategy in _distributed_modes
    ),
)

## `fastmri`

In [4]:
# Seed the global RNGs from `args.seed` so the UNet's random initialisation
# (which `f_params` is converted from in the next section) is reproducible
# across Restart & Run All.
pl.seed_everything(args.seed)

# Configure the Lightning module from `args` instead of relying on its defaults,
# so the torch model and the flax optimizer (which reads args.lr, lr_step_size
# and lr_gamma) are always driven by the same values.
t_model = unet_module.UnetModule(
    in_chans=args.in_chans,
    out_chans=args.out_chans,
    chans=args.chans,
    num_pool_layers=args.num_pool_layers,
    drop_prob=args.drop_prob,
    lr=args.lr,
    lr_step_size=args.lr_step_size,
    lr_gamma=args.lr_gamma,
    weight_decay=args.weight_decay,
)

## `init2winit`

In [5]:
f_model = f_unet.UNetModel(f_unet.DEFAULT_HPARAMS, {}, 'mean_absolute_error', 'image_reconstruction_metrics')
# Presumably copies the torch UNet's weights into flax params so both sides start
# from the same initialisation (the "Abs sum diff" output below is printed here).
f_params = i2w.convert_params(t_model.cpu().unet, f_model.flax_module)

# LR schedule matching the torch side.
# NOTE(review): assumes fastmri's `UnetModule.configure_optimizers` uses RMSprop
# with an epoch-interval `StepLR(lr_step_size, lr_gamma)`, i.e.
# lr * gamma ** (epoch // lr_step_size) -- confirm against the installed version.
# optax schedules count optimizer *updates*, so express that staircase in steps:
# decay by `lr_gamma` every `lr_step_size` epochs' worth of batches. (The previous
# piecewise-constant schedule decayed once, after `lr_step_size` *batches*, which
# diverges from torch as soon as an epoch has more than 40 batches.)
f_steps_per_epoch = len(dm.train_dataloader())
f_lr_schedule = optax.exponential_decay(
    init_value=-args.lr,  # negative: scale_by_schedule turns the update into a descent step
    transition_steps=args.lr_step_size * f_steps_per_epoch,
    decay_rate=args.lr_gamma,
    staircase=True,
)
f_opt_init, f_opt_update = f_opt = optax.chain(
    transform.precondition_by_rms(decay=0.99),  # presumably mirrors torch RMSprop's alpha=0.99
    optax.scale_by_schedule(f_lr_schedule),
)
f_opt_state = f_opt_init(f_params)


def loss(params, image, y):
    """Mean absolute error between the flax UNet's prediction for `image` and `y`."""
    return jnp.abs(f_model.flax_module.apply(params, image) - y).mean()


# Build the value-and-grad transform (w.r.t. `params`, the first argument) once,
# rather than on every call to `loop`.
loss_and_grad = jax.value_and_grad(loss)


def loop(t_batch, f_params, f_opt_state):
    """Run one flax training step on a torch batch.

    Args:
        t_batch: batch yielded by the Lightning train dataloader.
        f_params: current flax parameters.
        f_opt_state: current optax optimizer state.

    Returns:
        (f_loss, f_out, f_params, f_opt_state): the loss computed with the
        parameters *before* the update, the model output on the same inputs
        computed with the *updated* parameters, and the updated params/state.
    """
    f_batch = i2w.create_batch(t_batch)
    # NOTE(review): `.squeeze()` drops every size-1 axis, including the batch
    # axis when batch_size == 1 -- confirm the expected layout of create_batch.
    inputs = f_batch['inputs'].squeeze()
    targets = f_batch['targets'].squeeze()

    f_loss, grad = loss_and_grad(f_params, inputs, targets)
    updates, f_opt_state = f_opt_update(grad, f_opt_state, f_params)
    f_params = optax.apply_updates(f_params, updates)
    f_out = f_model.flax_module.apply(f_params, inputs)

    return f_loss, f_out, f_params, f_opt_state

Abs sum diff: 0.16495502


## Comparison

In [None]:
# Manually replay the call chain of `Trainer.fit` up to the start of the first
# training batch. Each `# \`...\`` header names the Lightning method whose body
# the following lines were copied from, so the batch loop in the next section
# can interleave the flax step with Lightning's own torch step.
# NOTE(review): this relies heavily on private Lightning attributes/methods
# (underscore-prefixed) and will likely break on other pytorch_lightning versions.
trainer = pl.Trainer.from_argparse_args(args)

# `Trainer.fit`
trainer.strategy.model = t_model

# `Trainer._fit_impl`
trainer.state.fn = TrainerFn.FITTING
trainer.state.status = TrainerStatus.RUNNING
trainer.training = True
trainer._last_train_dl_reload_epoch = float("-inf")
trainer._last_val_dl_reload_epoch = float("-inf")

# Wire the fastMRI datamodule `dm` into the trainer for `t_model`.
trainer._data_connector.attach_data(t_model, datamodule=dm)

# `Trainer._run`
trainer.strategy.connect(t_model)
trainer._callback_connector._attach_model_callbacks()
trainer._callback_connector._attach_model_logging_functions()

pl.trainer.trainer.verify_loop_configurations(trainer)

trainer._data_connector.prepare_data()
trainer.strategy.setup_environment()
trainer._call_setup_hook()
trainer._call_configure_sharded_model()

# `Trainer._run` TRAIN
# Presumably where the model is moved to the GPU(s) and wrapped for the `dp`
# strategy; the comparison loop later calls the UNet with `.cuda()` inputs.
trainer.strategy.setup(trainer)

# `Trainer._run_stage`
trainer.strategy.barrier("run-stage")
trainer.strategy.dispatch(trainer)

#  `Trainer._run_train`
trainer._pre_training_routine()
trainer.model.train()
torch.set_grad_enabled(True)

trainer.fit_loop.trainer = trainer

# `Trainer.fit_loop.run`
trainer.fit_loop.reset()
trainer.fit_loop.on_run_start()
trainer.fit_loop.on_advance_start()

# `Trainer.fit_loop.advance`
# Point the fit loop's data fetcher at the train dataloader; batches are moved to
# the device through the strategy's `batch_to_device` hook.
trainer.fit_loop._data_fetcher.setup(trainer.train_dataloader, batch_to_device=functools.partial(trainer._call_strategy_hook, 'batch_to_device', dataloader_idx=0))

# `Trainer.fit_loop.epoch_loop.run`
trainer.fit_loop.epoch_loop.reset()
trainer.fit_loop.epoch_loop.on_run_start(trainer.fit_loop._data_fetcher)
trainer.fit_loop.epoch_loop.on_advance_start()

# `Trainer.fit_loop.epoch_loop.advance`
trainer.fit_loop.epoch_loop.val_loop.restarting = False
trainer.fit_loop.epoch_loop.batch_progress.is_last_batch = trainer.fit_loop._data_fetcher.done

# Per-batch tie-out log: one row of (torch loss, flax loss, mean |t_out - f_out|)
# per training batch, plotted in the last cell.
data = []

for batch_idx, batch in enumerate(trainer.fit_loop._data_fetcher):

    kwargs = {'batch': batch, 'batch_idx': batch_idx}

    # `Trainer.fit_loop.epoch_loop.advance` (per-batch part)
    trainer.fit_loop.epoch_loop.batch_progress.increment_ready()
    trainer._logger_connector.on_batch_start(batch, batch_idx)

    trainer.fit_loop.epoch_loop.batch_progress.increment_started()

    # Short alias for the nested Lightning batch loop (same object, re-read on
    # every iteration).
    batch_loop = trainer.fit_loop.epoch_loop.batch_loop

    # `Trainer.fit_loop.epoch_loop.batch_loop.run`
    batch_loop.reset()
    batch_loop.on_run_start(**kwargs)
    batch_loop.on_advance_start()

    # `Trainer.fit_loop.epoch_loop.batch_loop.advance`
    batch_loop.split_idx, kwargs["batch"] = batch_loop._remaining_splits.pop(0)
    trainer._logger_connector.on_train_split_start(batch_loop.split_idx)

    outputs = None

    # Named `active_optimizers` (not `optimizers`) so this cell does not shadow
    # the `init2winit.optimizer_lib.optimizers` module imported at the top.
    active_optimizers = pl.loops.utilities._get_active_optimizers(
        trainer.optimizers, trainer.optimizer_frequencies, kwargs.get("batch_idx", 0)
    )

    optimizer_loop = batch_loop.optimizer_loop

    # `Trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.run`
    optimizer_loop.reset()
    optimizer_loop.on_run_start(kwargs['batch'], active_optimizers, kwargs['batch_idx'])
    optimizer_loop.on_advance_start()

    # `...optimizer_loop.advance` -> `...optimizer_loop._run_optimization`:
    # presumably the torch forward/backward pass plus optimizer step for this batch.
    split_batch = kwargs['batch']
    batch_idx = kwargs['batch_idx']
    optimizer = optimizer_loop._optimizers[optimizer_loop.optim_progress.optimizer_position]
    opt_idx = optimizer_loop.optimizer_idx

    result = optimizer_loop._run_optimization(split_batch, batch_idx, optimizer, opt_idx)

    # Rest of `...optimizer_loop.advance`: store this optimizer's outputs, advance.
    if result.loss is not None:
        optimizer_loop._outputs[optimizer_loop.optimizer_idx] = result.asdict()
    optimizer_loop.optim_progress.optimizer_position += 1

    # Torch UNet output on this batch with the just-updated weights (presumably
    # unwrapping the DP and Lightning wrappers via `.module.module`). Only values
    # are needed, so skip building an autograd graph.
    with torch.no_grad():
        t_out = trainer.strategy.model.module.module.unet(batch.image.unsqueeze(1).cuda()).cpu().detach().numpy()

    # Same step on the flax side; `f_out` is likewise computed with updated params.
    f_loss, f_out, f_params, f_opt_state = loop(batch, f_params, f_opt_state)

    output_diff = np.mean(np.abs(t_out - f_out))
    t_loss = float('nan') if result.loss is None else float(result.loss)
    data.append((t_loss, float(f_loss), float(output_diff)))
    print(batch_idx, result.loss, f_loss, output_diff)

# NOTE(review): this section is at top level, i.e. OUTSIDE the `for` loop above,
# so it runs once after the whole epoch and the per-batch hooks below only see
# the last batch. In particular `batch_progress.increment_processed()` and
# `increment_completed()` are called once here instead of once per batch, and
# `outputs` is always None (the loop resets it to None and never reassigns it).
# The per-batch loss/output comparison is unaffected, but progress counters and
# logged metrics after this point probably do not match a real `Trainer.fit` run.

# `Trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.run`
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.on_advance_end()
trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.on_run_end()

# `Trainer.fit_loop.epoch_loop.batch_loop.advance`
trainer.fit_loop.epoch_loop.batch_loop._outputs.append(outputs)

# `Trainer.fit_loop.epoch_loop.batch_loop.run`
trainer.fit_loop.epoch_loop.batch_loop.on_advance_end()
batch_output = trainer.fit_loop.epoch_loop.batch_loop.on_run_end()

# `Trainer.fit_loop.epoch_loop.advance`
trainer.fit_loop.epoch_loop.batch_progress.increment_processed()

# Update "step"-interval LR schedulers, and "epoch"-interval ones only when
# `_num_ready_batches_reached()` (presumably: the epoch's batches are exhausted).
# Plateau schedulers are excluded in both calls.
trainer.fit_loop.epoch_loop.update_lr_schedulers("step", update_plateau_schedulers=False)
if trainer.fit_loop.epoch_loop._num_ready_batches_reached():
    trainer.fit_loop.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=False)

batch_end_outputs = trainer.fit_loop.epoch_loop._prepare_outputs_training_batch_end(
    batch_output,
    lightning_module=trainer.lightning_module,
    num_optimizers=len(trainer.optimizers),
)

trainer._logger_connector.on_batch_end()
trainer.fit_loop.epoch_loop.batch_progress.increment_completed()
trainer._logger_connector.update_train_step_metrics()

# Remaining end-of-epoch / end-of-fit hooks, intentionally not run:
# `Trainer.epoch_loop.run`
# trainer.fit_loop.epoch_loop.on_advance_end()
# trainer.fit_loop.epoch_loop.on_run_end()

# `Trainer.fit_loop.run`
# trainer.fit_loop.on_advance_end()
# trainer.fit_loop.on_run_end()

  rank_zero_warn("more than one device specific flag has been set")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(limit_test_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(limit_predict_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Training: 0it [00:00, ?it/s]



0 tensor(0.8750, device='cuda:0') 0.8750014 0.6709483
1 tensor(0.7343, device='cuda:0') 0.734457 1.0378088
2 tensor(0.4051, device='cuda:0') 0.40467966 0.6539176
3 tensor(0.5527, device='cuda:0') 0.55281216 0.61515653
4 tensor(0.4809, device='cuda:0') 0.47948065 0.6559525
5 tensor(0.2388, device='cuda:0') 0.23760484 0.40521106
6 tensor(0.2503, device='cuda:0') 0.24525006 0.47057593
7 tensor(0.2891, device='cuda:0') 0.29334444 0.6718571
8 tensor(0.3699, device='cuda:0') 0.38205132 0.77186805
9 tensor(0.2551, device='cuda:0') 0.34926352 0.6612922
10 tensor(0.1113, device='cuda:0') 0.11888986 0.6342341
11 tensor(0.1729, device='cuda:0') 0.13970041 0.5351649
12 tensor(0.2931, device='cuda:0') 0.23028748 0.50846624
13 tensor(0.1233, device='cuda:0') 0.24191312 0.5855674
14 tensor(0.1547, device='cuda:0') 0.07574652 0.67522186
15 tensor(0.1507, device='cuda:0') 0.119001985 0.48846024
16 tensor(0.1584, device='cuda:0') 0.15898554 0.6265027
17 tensor(0.1120, device='cuda:0') 0.12637505 0.49574

In [None]:
# Plot the per-batch comparison collected in `data` during the training loop.
# NOTE(review): assumes each row is (torch loss, flax loss, mean |t_out - f_out|);
# rows of any other width are plotted with generic column labels, and an empty
# `data` yields an empty (but labelled) figure.
history = np.asarray(data, dtype=float)
fig, ax = plt.subplots()
if history.ndim == 2:
    series_labels = ('torch loss', 'flax loss', 'mean |t_out - f_out|')
    for col_idx, column in enumerate(history.T):
        if history.shape[1] == len(series_labels):
            label = series_labels[col_idx]
        else:
            label = f'column {col_idx}'
        ax.plot(column, label=label)
    ax.legend()
else:
    ax.plot(history)
ax.set(title='Lightning vs. init2winit train-step tie-out', xlabel='batch index', ylabel='value');