In [18]:
from fastmri.data.mri_data import fetch_dir
from fastmri.data.subsample import create_mask_for_mask_type
from fastmri.data.transforms import UnetDataTransform
from fastmri.pl_modules import FastMriDataModule, UnetModule
from fastmri.data import transforms, mri_data
import pathlib
import torch
import matplotlib.pyplot as plt
import numpy as np


CHALLENGE = 'singlecoil'
MASK_TYPE = 'random'
center_fractions = [0.08]
accelerations = [4]


mask = create_mask_for_mask_type(
        MASK_TYPE, center_fractions, accelerations
    )

train_transform = UnetDataTransform(CHALLENGE, mask_func=mask, use_seed=False)

dataset = mri_data.SliceDataset(
    root=pathlib.Path(
      './fastmri_data/singlecoil_val'
    ),
    transform=train_transform,
    challenge='singlecoil'
)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
image, target, _, _, _, _, _ = next(iter(dataloader))

In [1]:

import InvertCnnConverter
import torch
import Unet

In [32]:
import importlib
importlib.reload(InvertCnnConverter)

<module 'InvertCnnConverter' from '/home/hsyang/workspace/20210411_SYS/memcnn-unet/InvertCnnConverter.py'>

In [43]:
plain_model = Unet.UNet(n_channels=1, n_classes=1)

In [44]:
plain_model

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU()
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_s

In [45]:
#InvertCnnConverter.top_forward_to_checkpoint(plain_model, last_module_name='outc')
InvertCnnConverter.convert_module(plain_model, last_module_name='outc', inplace=True, option=['invertible', 'checkpoint'])
invert_model = plain_model

device = 'cuda:0'

invert_model = invert_model.to(device)
data = image.view(-1,1,320,320)
target = target.to(device)

print(invert_model)

UNet(
  (inc): CheckpointModule(
    (module): DoubleConv(
      (double_conv): Sequential(
        (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
      )
    )
  )
  (down1): CheckpointModule(
    (module): Down(
      (maxpool_conv): Sequential(
        (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (1): DoubleConv(
          (double_conv): Sequential(
            (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
            (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1),

In [20]:
criterion = torch.nn.MSELoss()
optim = torch.optim.Adam(invert_model.parameters(), lr=1e-3)
with torch.autograd.set_detect_anomaly(True) : 
    for epoch in range(10) : 
        result = invert_model(data)
        loss = criterion(result, target)
        optim.zero_grad()
        loss.backward()
        optim.step()
        print(epoch, loss.item())

0 0.9515864253044128
1 0.9432873725891113
2 0.9359803199768066
3 0.9295905828475952
4 0.9240349531173706
5 0.9191995859146118
6 0.9149433374404907
7 0.9111204743385315
8 0.9075976610183716
9 0.9042671918869019


In [37]:
device=torch.device("cuda:0")
torch.nn.Conv2d(1,1,3,1,1).to(device)(torch.ones(1,1,10,10).to(device))

def debug_memory(title, v=False):
    alloc = torch.cuda.memory_allocated(device)/1024/1024
    max_alloc = torch.cuda.max_memory_allocated(device)/1024/1024
    reserved = torch.cuda.memory_reserved(device)/1024/1024
    max_reserved = torch.cuda.max_memory_reserved(device)/1024/1024
    if v:
        print(f'[{title:>10s}] alloc={alloc:.0f} / {max_alloc:.0f} MB, reserved={reserved:.0f} / {max_reserved:.0f} MB ')
    return alloc, max_alloc

class Memlog:
    def __init__(self, name):
        self.name = name

    def __call__(self, fn):
        def wrap(*args, **kwargs):
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated(device)
            torch.cuda.reset_max_memory_cached(device)    
            a0, ma0 = debug_memory("")
            fn(*args, **kwargs)
            a1, ma1 = debug_memory("")
            torch.cuda.empty_cache()
            torch.cuda.reset_max_memory_allocated(device)
            torch.cuda.reset_max_memory_cached(device)
            # print(f'[{self.name:>10s}]: {ma1-ma0:.0f}')
            #print(f'{ma1-ma0:.0f}')
            return ma1-ma0
        return wrap

In [40]:
@Memlog("original")
def original(size, channel, device):   
    input = torch.ones(1,channel,size,size, device=device, requires_grad=True)
    plain_model = Unet.UNet(n_channels=channel, n_classes=1).to(device)
    out = plain_model(input)
    loss = out.sum()
    loss.backward()
    
@Memlog("converted")
def converted(size, channel, device, option):   
    input = torch.ones(1,channel,size,size, requires_grad=True)
    if 'stitchable' not in option : 
        input = input.to(device)
    plain_model = Unet.UNet(n_channels=channel, n_classes=1)
    InvertCnnConverter.convert_module(plain_model, last_module_name='outc', inplace=True, option=option)
    invert_model = plain_model.to(device)
    out = invert_model(input)
    loss = out.sum()
    loss.backward()

In [24]:
original(320,1,device)



767


767.029296875

In [48]:
for size in range(100, 10001, 100) : 
    for ch in [1] + list(range(10, 101, 10)) : 
        original_size = original(size,ch,device)
        convert_all_size = converted(size,ch,device, option=['checkpoint', 'invertible', 'stitchable'])
        convert_not_stitch_size = converted(size,ch,device, option=['checkpoint', 'invertible'])
        print("size:%dx%d, channel:%d, original:%.2f, convert all:%.2f, convert all but not stitch:%.2f" \
              % (size, size, ch, original_size, convert_all_size, convert_not_stitch_size))
        

size:100x100, channel:1, original:217.02, convert all:184.71, convert all but not stitch:185
size:100x100, channel:10, original:217.39, convert all:184.71, convert all but not stitch:185
size:100x100, channel:20, original:217.79, convert all:184.71, convert all but not stitch:186
size:100x100, channel:30, original:217.17, convert all:184.71, convert all but not stitch:186
size:100x100, channel:40, original:217.20, convert all:184.71, convert all but not stitch:186
size:100x100, channel:50, original:217.22, convert all:184.71, convert all but not stitch:186
size:100x100, channel:60, original:218.80, convert all:184.71, convert all but not stitch:187
size:100x100, channel:70, original:218.82, convert all:184.71, convert all but not stitch:187
size:100x100, channel:80, original:218.84, convert all:184.71, convert all but not stitch:188
size:100x100, channel:90, original:219.62, convert all:184.71, convert all but not stitch:189
size:100x100, channel:100, original:219.64, convert all:184.7

KeyboardInterrupt: 