In [5]:
%load_ext autoreload
%autoreload 2

import pathlib
import os

import flax
from flax import jax_utils
import jax
import jax.numpy as jnp
import h5py
import matplotlib.pyplot as plt
import numpy as np
import optax
import torch
import tensorflow as tf
import tqdm

import init2winit
import fastmri

import i2w

from fastmri.models import unet as t_unet
from fastmri.pl_modules import data_module
from fastmri.pl_modules import unet_module
from fastmri.data.transforms import UnetDataTransform
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.subsample import RandomMaskFunc

from init2winit.model_lib import unet as f_unet
from init2winit.dataset_lib import fastmri_dataset
from init2winit.dataset_lib import data_utils
from init2winit.optimizer_lib import optimizers
from init2winit.optimizer_lib import transform

jax.devices()

Extension horovod.torch has not been built: /opt/conda/lib/python3.7/site-packages/horovod/torch/mpi_lib/_mpi_lib.cpython-37m-x86_64-linux-gnu.so not found
If this is not expected, reinstall Horovod with HOROVOD_WITH_PYTORCH=1 to debug the build error.


[GpuDevice(id=0, process_index=0),
 GpuDevice(id=1, process_index=0),
 GpuDevice(id=2, process_index=0),
 GpuDevice(id=3, process_index=0),
 GpuDevice(id=4, process_index=0),
 GpuDevice(id=5, process_index=0),
 GpuDevice(id=6, process_index=0),
 GpuDevice(id=7, process_index=0)]

In [69]:
def get_prob(kspace, target, seed=None, center_fraction=0.8, acceleration=4.0):
    """Return the acceleration-mask sampling probability for `kspace`.

    Uses the same `prob` formula as `tf_acceleration_mask`. It is printed as a
    diagnostic when the TF and PyTorch masks disagree.

    Args:
      kspace: array-like with a `.shape`; the mask dimension is `shape[0]`.
      target: unused; kept so the signature matches the mask functions.
      seed: unused; kept so the signature matches the mask functions.
      center_fraction: fraction of columns treated as low frequencies.
      acceleration: target undersampling factor.

    Returns:
      The probability as a float. It can fall outside [0, 1] (e.g. -2.75 for
      the default 0.8 / 4.0 settings with 640 columns).

    Raises:
      ZeroDivisionError: if every column is a low-frequency column.
    """
    num_cols = kspace.shape[0]
    # The mask functions use an integer low-frequency count (fastMRI rounds);
    # using the unrounded float would report a probability neither mask used.
    num_low_frequencies = round(num_cols * center_fraction)
    return (num_cols / acceleration - num_low_frequencies) / (
        num_cols - num_low_frequencies)

In [70]:
def tf_acceleration_mask(kspace, target, seed=None, center_fraction=0.8,
                         acceleration=4.0):
    """TF version of the center and acceleration masks along `kspace.shape[0]`.

    Args:
      kspace: array-like with a `.shape`; masks are built along `shape[0]`.
      target: unused; kept so the signature matches `pytorch_acceleration_mask`.
      seed: seed for `tf.random.stateless_uniform` (TF expects a shape-[2]
        integer tensor). If None, a non-deterministic `tf.random.uniform`
        draw is used instead.
      center_fraction: fraction of columns kept as the contiguous center block.
      acceleration: target undersampling factor used to compute `prob`.

    Returns:
      (center_mask, acceleration_mask, num_low_frequencies): both masks are
      float32 tensors of shape (num_cols, 1); the count is an int32 scalar.
    """
    # sample_mask
    num_cols = kspace.shape[0]
    num_cols_float = tf.cast(num_cols, dtype=tf.float32)

    # choose_acceleration
    center_fraction = tf.convert_to_tensor(center_fraction, dtype=tf.float32)
    acceleration = tf.convert_to_tensor(acceleration, dtype=tf.float32)

    # Round (like the PyTorch reference's `round`) rather than truncate: a bare
    # int cast turns e.g. 297.6 into 297 while the reference uses 298.
    num_low_frequencies = tf.cast(
        tf.round(num_cols_float * center_fraction), dtype=tf.int32)

    # calculate_center_mask: ones over [pad, pad + num_low_frequencies).
    pad = (num_cols - num_low_frequencies + 1) // 2
    mask = tf.tensor_scatter_nd_update(
        tf.zeros(num_cols, dtype=tf.float32),
        tf.reshape(tf.range(pad, pad + num_low_frequencies), (-1, 1)),
        tf.ones(num_low_frequencies))

    # reshape_mask
    center_mask = tf.reshape(mask, (num_cols, 1))

    # calculate_acceleration_mask
    num_low_frequencies_float = tf.cast(num_low_frequencies, dtype=tf.float32)
    prob = (num_cols_float / acceleration - num_low_frequencies_float) / (
        num_cols_float - num_low_frequencies_float)

    # stateless_uniform rejects seed=None, so fall back to a stateful draw.
    if seed is None:
        uniform = tf.random.uniform((num_cols,))
    else:
        uniform = tf.random.stateless_uniform((num_cols,), seed)
    acceleration_mask = tf.reshape(
        tf.cast(uniform < prob, dtype=tf.float32), (num_cols, 1))

    return center_mask, acceleration_mask, num_low_frequencies

In [71]:
def pytorch_acceleration_mask(kspace, target, seed=None, center_fraction=0.8,
                              acceleration=4.0):
    """Reference center/acceleration masks from fastMRI's `RandomMaskFunc`.

    The mask shape is built from the last three dims of `kspace`, with any
    leading dims collapsed to 1. The mask dimension is `shape[-2]`, which for
    a 2-D `kspace` is `kspace.shape[0]`.

    Args:
      kspace: array-like with a `.shape`.
      target: unused; kept so the signature matches `tf_acceleration_mask`.
      seed: seed passed to `RandomMaskFunc`.
      center_fraction: fraction of columns treated as low frequencies.
      acceleration: target undersampling factor.

    Returns:
      (center_mask, acceleration_mask, num_low_frequencies) as produced by
      `RandomMaskFunc.reshape_mask`, plus the integer low-frequency count.
    """
    shape = (1,) * len(kspace.shape[:-3]) + tuple(kspace.shape[-3:])
    # Use the same center_fraction/acceleration everywhere so the mask
    # function and the explicit calculations below cannot drift apart.
    mask_fn = RandomMaskFunc([center_fraction], [acceleration], seed=seed)

    num_cols = shape[-2]
    num_low_frequencies = round(num_cols * center_fraction)

    center_mask = mask_fn.reshape_mask(
        mask_fn.calculate_center_mask(shape, num_low_frequencies), shape)
    acceleration_mask = mask_fn.reshape_mask(
        mask_fn.calculate_acceleration_mask(
            num_cols, acceleration, None, num_low_frequencies),
        shape)
    return center_mask, acceleration_mask, num_low_frequencies

In [72]:
# Directory of single-coil training volumes. Defaults to the original location;
# set FASTMRI_SINGLECOIL_TRAIN to run the comparison elsewhere.
directory = os.environ.get('FASTMRI_SINGLECOIL_TRAIN',
                           '/home/dsuo/singlecoil_train')
# listdir order is unspecified; sort so the loop below is reproducible.
files = sorted(tf.io.gfile.listdir(directory))
paths = [os.path.join(directory, file) for file in files]

In [74]:
# Compare the TF and PyTorch masks slice by slice and report every slice where
# they disagree. The TF seed does not depend on the slice, so build it once.
tf_seed = tf.cast(jax.random.PRNGKey(0), tf.int64)
for path in tqdm.tqdm(paths):
    # Keep the GFile open for as long as h5py reads from it, and leave `path`
    # as a string so it can be printed on mismatch.
    with tf.io.gfile.GFile(path, 'rb') as gf, h5py.File(gf, 'r') as hf:
        kspace_volume = hf['kspace']
        target_volume = hf['reconstruction_esc']

        for slice_idx in range(kspace_volume.shape[0]):
            kspace = kspace_volume[slice_idx]
            target = target_volume[slice_idx]

            t_c, t_a, t_n = tf_acceleration_mask(kspace, target, tf_seed)
            p_c, p_a, p_n = pytorch_acceleration_mask(kspace, target, 0)

            c_close = np.allclose(t_c.numpy(), p_c.detach().numpy())
            a_close = np.allclose(t_a.numpy(), p_a.detach().numpy())
            n_equal = int(t_n) == int(p_n)

            if not (c_close and a_close and n_equal):
                print(os.path.basename(path), slice_idx,
                      get_prob(kspace, target), c_close, a_close, n_equal)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 973/973 [03:28<00:00,  4.67it/s]
