In [1]:
%load_ext autoreload
%autoreload 2

import itertools
import os
import pathlib

import flax
from flax import jax_utils
import jax
import jax.numpy as jnp
import numpy as np
import optax
import torch
import tensorflow as tf
import tqdm

import init2winit
import fastmri

import i2w

from fastmri.models import unet as t_unet
from fastmri.pl_modules import data_module
from fastmri.pl_modules import unet_module
from fastmri.data.transforms import UnetDataTransform
from fastmri.data.subsample import create_mask_for_mask_type

from init2winit.model_lib import unet as f_unet
from init2winit.dataset_lib import fastmri_dataset
from init2winit.dataset_lib import data_utils
from init2winit.optimizer_lib import optimizers

# Show the visible accelerators: the comparison loop below pmaps the Flax
# model over all of them, so this is the device count the batch is sharded to.
jax.devices()

[GpuDevice(id=0, process_index=0),
 GpuDevice(id=1, process_index=0),
 GpuDevice(id=2, process_index=0),
 GpuDevice(id=3, process_index=0),
 GpuDevice(id=4, process_index=0),
 GpuDevice(id=5, process_index=0),
 GpuDevice(id=6, process_index=0),
 GpuDevice(id=7, process_index=0)]

# Forward pass: compare fastmri (PyTorch) and init2winit (Flax) U-Net losses

In [2]:
# Dataset location, split and batching.
data_path = pathlib.Path('..', '..')
challenge = 'singlecoil'
test_split = 'test'
# Sharded across jax.devices() in the comparison loop; presumably needs to be
# divisible by the device count — confirm before changing.
batch_size = 8

# Undersampling-mask settings. They are only consumed by the (currently
# commented-out) create_mask_for_mask_type call in the data-module cell.
# Real training would use mask_type='random' (set it before re-enabling that
# call); masking is left off here so both implementations can be tied out on
# identical inputs.
mask_type = None
center_fractions = [0.08]
accelerations = [4]

In [3]:
def create_batch(t_batch, data_dir='../../singlecoil_train'):
    """Rebuild a fastmri (PyTorch) batch as a NumPy batch for init2winit.

    Each (fname, slice_num) pair in `t_batch` is re-read from disk with
    `i2w.get_slice`, passed through init2winit's
    `fastmri_dataset._process_example`, and the per-example outputs are stacked
    along a new leading batch axis.

    Args:
      t_batch: batch from the fastmri dataloader; only its `fname` and
        `slice_num` fields are read.
      data_dir: directory containing the files named by `t_batch.fname`.
        Defaults to the single-coil train directory used by this notebook.

    Returns:
      Dict mapping each key produced by `_process_example` to a NumPy array
      whose first axis has length `len(t_batch.slice_num)`.

    Raises:
      ValueError: if `t_batch` contains no examples.
    """
    num_examples = len(t_batch.slice_num)
    if num_examples == 0:
        raise ValueError('create_batch received an empty batch.')

    # Same fixed key for every example, so build it once instead of per example.
    # NOTE(review): fine for a deterministic tie-out; revisit if
    # _process_example uses it for random masking.
    rng = tf.cast(jax.random.PRNGKey(0), tf.int64)

    examples = []
    for fname, slice_num in zip(t_batch.fname, t_batch.slice_num):
        raw_slice = i2w.get_slice(os.path.join(data_dir, fname), slice_num)
        examples.append(fastmri_dataset._process_example(*raw_slice, rng))

    # tf.stack(axis=0) is equivalent to expand_dims(..., 0) on every example
    # followed by concat along axis 0.
    return {
        key: tf.stack([example[key] for example in examples], axis=0).numpy()
        for key in examples[0]
    }

In [4]:
# Undersampling is disabled so fastmri and init2winit see identical inputs.
# To re-enable it (after setting mask_type, e.g. to 'random'):
#   mask = create_mask_for_mask_type(mask_type, center_fractions, accelerations)
mask = None

# Per-split transforms: train and val receive the (disabled) mask, train also
# sets use_seed=False, and test uses the transform's defaults.
train_transform, val_transform, test_transform = (
    UnetDataTransform(challenge, mask_func=mask, use_seed=False),
    UnetDataTransform(challenge, mask_func=mask),
    UnetDataTransform(challenge),
)

dm8 = data_module.FastMriDataModule(
    data_path=data_path,
    challenge=challenge,
    batch_size=batch_size,
    num_workers=4,
    train_transform=train_transform,
    val_transform=val_transform,
    test_transform=test_transform,
    test_split=test_split,
)

# One persistent iterator: the comparison loop keeps pulling from it, so
# re-running that cell continues with the next batches rather than restarting.
dl8 = iter(dm8.train_dataloader())

## `fastmri`

In [5]:
# fastmri U-Net module (from fastmri.pl_modules) built with default constructor
# arguments; no checkpoint is loaded here. Its `.unet` weights are converted
# into the Flax model in the next cell, so both models share the same weights.
t_model = unet_module.UnetModule()

## `init2winit`

In [6]:
# init2winit's Flax U-Net with default hyperparameters, an L1
# ('mean_absolute_error') loss and image-reconstruction metrics; the loop below
# reads its 'l1_loss' metric.
f_model = f_unet.UNetModel(f_unet.DEFAULT_HPARAMS, {}, 'mean_absolute_error', 'image_reconstruction_metrics')
# Presumably copies the PyTorch U-Net's weights into a Flax variable dict for
# f_model.flax_module (confirm in i2w.convert_params); it also appears to print
# the "Abs sum diff" shown in the output. f_params['params'] feeds the loop below.
f_params = i2w.convert_params(t_model.unet, f_model.flax_module)

Abs sum diff: 0.104931116


In [7]:
# Tie out the fastmri (PyTorch) and init2winit (Flax) L1 losses on the same
# batches. Each printed value is t_loss - f_loss and should be ~0.
num_batches = 10

# Loop-invariant setup, done once: re-wrapping with jax.pmap on every iteration
# forces a retrace, and replicating the params copies them to every device again.
p_evaluate_batch = jax.pmap(f_model.evaluate_batch, axis_name='batch')
replicated_params = jax_utils.replicate(f_params['params'])
# Zero-filled tree shaped like the params, passed as evaluate_batch's second
# argument (presumably model state such as batch_stats — confirm).
replicated_state = jax_utils.replicate(
    jax.tree_util.tree_map(jnp.zeros_like, f_params['params']))

loss_diffs = []
# islice pulls exactly `num_batches` batches from the persistent dl8 iterator
# (a break-on-counter loop would fetch and silently discard one extra batch).
for i, t_batch in enumerate(itertools.islice(dl8, num_batches)):
    # Only the loss value is needed, so skip building an autograd graph.
    with torch.no_grad():
        t_loss = t_model.training_step(t_batch, i).detach().numpy()
    f_batch = data_utils.shard(create_batch(t_batch))
    f_out = jax_utils.unreplicate(
        p_evaluate_batch(replicated_params, replicated_state, f_batch)).compute()
    loss_diff = t_loss - f_out['l1_loss']
    loss_diffs.append(loss_diff)
    print(loss_diff)

-3.874302e-06
-3.8146973e-06
-3.6358833e-06
-3.695488e-06
-3.5762787e-06
-3.993511e-06
-3.7550926e-06
-3.8146973e-06
-3.874302e-06
-4.2915344e-06
