In [3]:
import torch
import torch.nn as nn
import memcnn

class DoubleConv(nn.Module):
    """Two stacked 3x3 Conv2d -> BatchNorm2d -> ReLU stages.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution. Any falsy
            value (``None`` by default) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        stages = []
        # Build (conv, bn, relu) twice: in -> hidden, then hidden -> out.
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv/BN/ReLU stages."""
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a DoubleConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        """Pool ``x`` and pass the result through the double convolution."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder step: upsample, pad to the skip tensor's size, concatenate, DoubleConv.

    Args:
        in_channels: channel count the DoubleConv receives after concatenating the
            skip tensor with the upsampled tensor. In the ConvTranspose2d path it is
            also the channel count of ``x1`` before upsampling.
        out_channels: channels produced by the DoubleConv.
        bilinear: if True, upsample with ``nn.Upsample`` (bilinear, no parameters)
            and let the DoubleConv's first conv reduce to ``in_channels // 2``;
            otherwise upsample with a ConvTranspose2d that halves the channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it spatially with ``x2``, concatenate, and convolve.

        Args:
            x1: deeper feature map to be upsampled.
            x2: skip-connection feature map; its spatial size is the target size.

        Returns:
            The DoubleConv output for ``cat([x2, upsampled x1], dim=1)``.
        """
        x1 = self.up(x1)
        # Tensors are NCHW: dim 2 is height, dim 3 is width.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # Pad x1 to x2's spatial size. The pad list is (left, right, top, bottom).
        # Use nn.functional directly: no `F` alias is imported in this notebook.
        x1 = nn.functional.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                                    diff_y // 2, diff_y - diff_y // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    """Final 1x1 convolution projecting features to ``out_channels`` channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        """Apply the 1x1 projection to ``x``."""
        projection = self.conv
        return projection(x)

In [4]:
class UNet(nn.Module):
    """U-Net with a DoubleConv stem, four Down steps, four Up steps and a 1x1 head.

    Args:
        n_channels: channels of the input tensor.
        n_classes: channels of the output logits.
        bilinear: passed to every ``Up`` block. When True, the bottleneck and the
            first three decoder outputs are built half as wide (``factor = 2``).
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        # Bilinear Up blocks do not halve channels before concatenation, so the
        # deeper layers are built narrower to keep the concatenations consistent.
        factor = 2 if bilinear else 1

        # Encoder.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)
        # Decoder.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the ``n_classes``-channel logits produced by ``outc`` for ``x``."""
        skip1 = self.inc(x)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        skip4 = self.down3(skip3)
        out = self.down4(skip4)
        # Walk back up, pairing each decoder step with its mirrored encoder output.
        for up_block, skip in ((self.up1, skip4), (self.up2, skip3),
                               (self.up3, skip2), (self.up4, skip1)):
            out = up_block(out, skip)
        return self.outc(out)

In [5]:
# Build a U-Net with 3 input channels and 10 output channels (bilinear upsampling by default).
# NOTE(review): the fastMRI batches loaded further down have shape (1, 320, 320) with
# no channel axis, so they would not fit n_channels=3 as-is — confirm intent.
unet_model = UNet(n_channels=3, n_classes=10)

In [6]:
from collections import OrderedDict
import copy

def _is_coupling_compatible(conv):
    """Return True if ``conv`` can be replaced by a shape-preserving additive coupling."""
    if not isinstance(conv, torch.nn.Conv2d):
        return False
    channels = conv.in_channels
    # Both halves must be the same size and map back onto themselves, so the
    # channel count must be preserved and even.
    if channels != conv.out_channels or channels % 2 != 0:
        return False
    # Each half-conv keeps the original groups, so groups must divide a half.
    if (channels // 2) % conv.groups != 0:
        return False
    # The coupling adds each half-conv's output onto the other half, so the
    # convolution must not change the spatial size.
    if any(s != 1 for s in conv.stride):
        return False
    if isinstance(conv.padding, str):
        # 'same' keeps the size at stride 1; 'valid' only does for 1x1 kernels.
        return conv.padding == 'same' or all(k == 1 for k in conv.kernel_size)
    return all(d * (k - 1) == 2 * p
               for k, p, d in zip(conv.kernel_size, conv.padding, conv.dilation))


def _half_conv(conv, channels):
    """Build a fresh ``channels -> channels`` Conv2d mirroring ``conv``'s hyper-parameters."""
    return torch.nn.Conv2d(channels, channels, conv.kernel_size,
                           stride=conv.stride, padding=conv.padding,
                           dilation=conv.dilation, groups=conv.groups,
                           bias=conv.bias is not None,
                           padding_mode=conv.padding_mode)


def conv2d_to_invertible(block, inplace=True):
    """Replace eligible direct ``Conv2d`` children of ``block`` with memcnn invertible couplings.

    A child is converted only when ``_is_coupling_compatible`` accepts it. That
    requires equal, even in/out channels, groups dividing each half, and a
    size-preserving stride/padding/dilation. Each converted child becomes an
    ``InvertibleModuleWrapper`` (``keep_input=False``, ``keep_input_inverse=False``)
    around an ``AdditiveCoupling`` of two freshly initialised half-width convs.
    The original conv's weights are not carried over. Only direct children are
    inspected; ``dfs_conv2d_to_invertible`` walks the full tree.

    Children are swapped in place on ``block``. Unconverted children keep their
    identity; nothing is deep-copied.

    Args:
        block: ``nn.Module`` whose children are rewritten in place.
        inplace: accepted for backward compatibility; ``block`` is always modified in place.
    """
    # Snapshot the children so replacing entries cannot disturb the iteration.
    for name, child in list(block.named_children()):
        if not _is_coupling_compatible(child):
            continue
        half = child.in_channels // 2
        coupling = memcnn.AdditiveCoupling(
            Fm=_half_conv(child, half),
            Gm=_half_conv(child, half),
        )
        wrapper = memcnn.InvertibleModuleWrapper(
            fn=coupling, keep_input=False, keep_input_inverse=False)
        setattr(block, name, wrapper)

In [7]:
def dfs_conv2d_to_invertible(top_module, inplace=True):
    """Apply ``conv2d_to_invertible`` to ``top_module`` and, recursively, to its descendants.

    The walk is pre-order. ``top_module``'s own children are converted first, then
    each child that has sub-modules is visited. ``InvertibleModuleWrapper`` children
    (including those just created) are not descended into, so their inner convs
    are never wrapped again.

    Args:
        top_module: root ``nn.Module``; modified in place.
        inplace: forwarded to ``conv2d_to_invertible`` at every level.
    """
    conv2d_to_invertible(top_module, inplace=inplace)
    for child in top_module._modules.values():
        # Empty (None) slots and invertible wrappers have nothing to convert.
        if child is None or isinstance(child, memcnn.InvertibleModuleWrapper):
            continue
        if child._modules:
            dfs_conv2d_to_invertible(child, inplace=inplace)

In [8]:
# Display the model's module tree before the Conv2d -> invertible conversion.
unet_model

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, moment

In [9]:
# Replace eligible Conv2d layers throughout the model with memcnn invertible couplings (in place).
dfs_conv2d_to_invertible(unet_model, inplace=True)

In [10]:
# Display the module tree after conversion; e.g. the 64->64 Conv2d at
# inc.double_conv[3] is now an InvertibleModuleWrapper around an AdditiveCoupling.
unet_model

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): InvertibleModuleWrapper(
        (_fn): AdditiveCoupling(
          (Gm): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (Fm): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
      )
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stat

In [1]:
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""

import os
import pathlib
from argparse import ArgumentParser

from fastmri.data.mri_data import fetch_dir
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import UnetDataTransform
from fastmri.pl_modules import FastMriDataModule, UnetModule

In [1]:
import pathlib
from fastmri.data import subsample
from fastmri.data import transforms, mri_data

# Create a mask function
# Random undersampling mask function with two (center_fraction, acceleration)
# settings, presumably paired by position: (0.08, 4x) and (0.04, 8x).
mask_func = subsample.RandomMaskFunc(center_fractions=[0.08, 0.04], accelerations=[4, 8])

def data_transform(kspace, mask, target, data_attributes, filename, slice_num):
    """Apply ``mask_func`` to one slice's k-space and return the masked k-space.

    The six-argument signature is the transform interface handed to
    ``SliceDataset``. Only ``kspace`` is used; ``mask``, ``target``,
    ``data_attributes``, ``filename`` and ``slice_num`` are ignored.

    Returns:
        The masked k-space tensor (first element of ``transforms.apply_mask``'s result).
    """
    kspace = transforms.to_tensor(kspace)
    # apply_mask returns (masked, mask) in older fastMRI releases and
    # (masked, mask, num_low_frequencies) in newer ones; indexing [0] works with both.
    masked_kspace = transforms.apply_mask(kspace, mask_func)[0]
    return masked_kspace

# Single-coil validation slices, each mapped to masked k-space by `data_transform`.
dataset = mri_data.SliceDataset(
    root=pathlib.Path('./fastmri_data/singlecoil_val'),
    transform=data_transform,
    challenge='singlecoil',
)

# Placeholder loop: iterating the dataset loads and masks every slice, but the
# body does nothing yet. NOTE(review): this is a full pass over the validation
# data for no result — consider removing it or breaking after one sample.
for masked_kspace in dataset:
    # TODO: reconstruct an image from masked_kspace here.
    pass

In [8]:
from fastmri.data.mri_data import fetch_dir
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import UnetDataTransform
from fastmri.pl_modules import FastMriDataModule, UnetModule


In [9]:
# Data-pipeline configuration. Despite its name, `mask` holds the mask *function*
# built from these settings; it is passed as mask_func to UnetDataTransform below.
CHALLENGE = 'singlecoil'
MASK_TYPE = 'random'
center_fractions = [0.08]
accelerations = [4]

mask = create_mask_for_mask_type(MASK_TYPE, center_fractions, accelerations)

In [10]:
# U-Net sample transform for CHALLENGE using the mask function above.
# use_seed=False presumably disables per-file deterministic masks — confirm in the fastMRI docs.
train_transform = UnetDataTransform(CHALLENGE, mask_func=mask, use_seed=False)

In [14]:
# Validation slices for the U-Net pipeline, derived from CHALLENGE so the path and
# challenge stay in sync with the configuration cell.
# Note: this rebinds `dataset`, replacing the masked-k-space dataset built earlier.
dataset = mri_data.SliceDataset(
    root=pathlib.Path('./fastmri_data') / f'{CHALLENGE}_val',
    transform=train_transform,
    challenge=CHALLENGE,
)

In [15]:
# Batch-size-1 loader over the U-Net dataset defined above.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)

In [22]:
# Take the first batch and keep only its first two fields (image and target).
# Slicing instead of unpacking exactly seven fields keeps this cell working if
# the transform's sample gains or loses trailing fields.
image, target = next(iter(dataloader))[:2]

In [23]:
# Input batch shape; presumably (batch, height, width) — note there is no channel axis.
image.shape

torch.Size([1, 320, 320])

In [24]:
# Target batch shape; matches the input image shape in the recorded output.
target.shape

torch.Size([1, 320, 320])

In [3]:
import torch
# Loader over masked k-space, built explicitly from `data_transform` rather than
# from whatever `dataset` is currently bound. In a top-to-bottom run, `dataset` is
# by now the U-Net dataset, whose batches are sequences of fields with no `.shape`.
kspace_dataset = mri_data.SliceDataset(
    root=pathlib.Path('./fastmri_data/singlecoil_val'),
    transform=data_transform,
    challenge='singlecoil',
)
dataloader = torch.utils.data.DataLoader(kspace_dataset, batch_size=1)

In [4]:
# First batch of masked k-space from the loader above.
data_sample = next(iter(dataloader))

In [6]:
# Masked k-space batch shape. The trailing 2 is presumably the real/imaginary
# split produced by transforms.to_tensor — confirm.
data_sample.shape

torch.Size([1, 640, 368, 2])