In [1]:
import os
import torch.nn as nn
import numpy as np
from dloader import genDataLoader 
import sigpy as sp
import sigpy.plot as pl
from fastmri.data import transforms
from collections import Counter
import torch
from icecream import ic
import matplotlib.pyplot as plt

from collections import Counter
import numpy as np

import torch
import torch.nn as nn

from typing import List

import fastmri
from fastmri.models.varnet import NormUnet

import logging

%matplotlib inline

In [2]:
# Send all log records (DEBUG and above) to a text file so the per-epoch
# training losses logged below are kept on disk.
# `%H` (24-hour clock) is used because `%I` without `%p` cannot tell AM from PM.
# NOTE: logging.basicConfig does nothing if the root logger already has
# handlers (e.g. when this cell is re-run in the same kernel).
logging.basicConfig(filename="MTL_Loss_iteration.txt",
                    level=logging.DEBUG,
                    format='%(levelname)s: %(asctime)s %(message)s',
                    datefmt='%m/%d/%Y %H:%M:%S')

In [3]:
class VarNetBlockShared(nn.Module):
    """
    One unrolled VarNet cascade whose regularizer is shared by both contrasts.

    Only the soft data-consistency weight is contrast specific: ``eta1`` is
    used for contrast 0 and ``eta2`` for contrast 1. Hard-coded for only two
    contrasts.
    """

    def __init__(self, model: nn.Module):
        super().__init__()

        self.model = model
        # Learnable soft-DC step sizes, one per contrast (0 -> eta1, 1 -> eta2).
        self.eta1 = nn.Parameter(torch.ones(1))
        self.eta2 = nn.Parameter(torch.ones(1))

    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """F*S operator: multiply by coil sensitivities, then centered 2-D FFT."""
        return fastmri.fft2c(fastmri.complex_mul(x, sens_maps))

    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """S^H * F^H operator: centered 2-D IFFT, multiply by conj(sens), sum over dim 1."""
        x = fastmri.ifft2c(x)
        return fastmri.complex_mul(x, fastmri.complex_conj(sens_maps)).sum(
            dim=1, keepdim=True
        )

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
        int_contrast: int
    ) -> torch.Tensor:
        """
        Apply one cascade update.

        Args:
            current_kspace: current k-space estimate.
            ref_kspace: measured k-space the soft-DC term pulls towards.
            mask: converted to bool; only entries where it is True contribute
                to the soft-DC term (presumably the sampling mask).
            sens_maps: coil sensitivity maps for sens_expand / sens_reduce.
            int_contrast: integer contrast index (0 or 1), not the contrast
                name; the mapping from name to index is done by the caller.

        Returns:
            current_kspace - eta * masked(current_kspace - ref_kspace)
            - sens_expand(model(sens_reduce(current_kspace))).

        Raises:
            AssertionError: if int_contrast is not 0 or 1.
        """
        # Same guard as VarNetBlockUnshared: previously any non-zero index
        # silently fell through to eta2.
        assert int_contrast == 0 or int_contrast == 1, 'Only two contrasts are allowed'

        eta = self.eta1 if int_contrast == 0 else self.eta2

        mask = mask.bool()
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)
        soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * eta

        model_term = self.sens_expand(
            self.model(self.sens_reduce(current_kspace, sens_maps)),
            sens_maps
        )

        return current_kspace - soft_dc - model_term
    
    
class VarNetBlockUnshared(nn.Module):
    """
    One unrolled VarNet cascade with a separate regularizer and soft-DC weight
    per contrast: contrast 0 uses (model1, eta1), contrast 1 uses (model2, eta2).
    """

    def __init__(self, model1: nn.Module, model2: nn.Module):
        super().__init__()

        # Registration order (model1, eta1, model2, eta2) is kept stable because
        # it determines parameter order and state_dict layout.
        self.model1 = model1
        self.eta1 = nn.Parameter(torch.ones(1))

        self.model2 = model2
        self.eta2 = nn.Parameter(torch.ones(1))

    def sens_expand(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """F*S operator: coil-weight the image, then take the centered 2-D FFT."""
        coil_images = fastmri.complex_mul(x, sens_maps)
        return fastmri.fft2c(coil_images)

    def sens_reduce(self, x: torch.Tensor, sens_maps: torch.Tensor) -> torch.Tensor:
        """S^H * F^H operator: centered 2-D IFFT, conj(sens) weighting, sum over dim 1."""
        coil_images = fastmri.ifft2c(x)
        weighted = fastmri.complex_mul(coil_images, fastmri.complex_conj(sens_maps))
        return weighted.sum(dim=1, keepdim=True)

    def forward(
        self,
        current_kspace: torch.Tensor,
        ref_kspace: torch.Tensor,
        mask: torch.Tensor,
        sens_maps: torch.Tensor,
        int_contrast: int
    ) -> torch.Tensor:
        """
        Apply one cascade update using the regularizer/eta of ``int_contrast``.

        ``int_contrast`` is the integer contrast index (0 or 1), not its name.
        Returns current_kspace - eta * masked(current_kspace - ref_kspace)
        - sens_expand(model(sens_reduce(current_kspace))).

        Raises:
            AssertionError: if int_contrast is not 0 or 1.
        """
        assert int_contrast == 0 or int_contrast == 1, 'Only two contrasts are allowed'

        mask = mask.bool()
        zero = torch.zeros(1, 1, 1, 1, 1).to(current_kspace)

        # Pick the contrast-specific pieces once, then run a single update path.
        if int_contrast == 0:
            regularizer, eta = self.model1, self.eta1
        else:
            regularizer, eta = self.model2, self.eta2

        soft_dc = torch.where(mask, current_kspace - ref_kspace, zero) * eta
        image = self.sens_reduce(current_kspace, sens_maps)
        model_term = self.sens_expand(regularizer(image), sens_maps)

        return current_kspace - soft_dc - model_term

class VarNet(nn.Module):
    """
    Multi-task variational network for two contrasts.

    Cascade order: ``num_shared`` pairs of (shared-regularizer block,
    per-contrast-regularizer block), followed by ``num_final_unshared``
    per-contrast "head" blocks. Every cascade applies soft data consistency
    plus a NormUnet regularizer. With the defaults this is 3 * 2 + 4 = 10
    cascades.

    Note: the interleaved per-contrast blocks are registered inside
    ``shared_cascades`` (as ``UnShared{i}``) next to ``Shared{i}``; only the
    final heads live in ``unshared_cascades``. These module names are part of
    the saved state_dict keys.
    """
    def __init__(
        self,
        num_shared: int = 3,  # pairs of (Shared, UnShared) blocks -> 2 * num_shared cascades
        num_final_unshared: int = 4,  # per-contrast head blocks at the end
        chans: int = 12,
        pools: int = 4,
    ):
        super().__init__()

        self.shared_cascades = nn.ModuleList()
        self.unshared_cascades = nn.ModuleList()

        def make_unet() -> nn.Module:
            # Fresh regularizer with the configured width and depth.
            return NormUnet(chans, pools)

        for idx in range(num_shared):
            self.shared_cascades.add_module(
                f'Shared{idx}', VarNetBlockShared(make_unet()))
            self.shared_cascades.add_module(
                f'UnShared{idx}', VarNetBlockUnshared(make_unet(), make_unet()))

        for idx in range(num_final_unshared):
            self.unshared_cascades.add_module(
                f'FinalUnShared{idx}', VarNetBlockUnshared(make_unet(), make_unet()))

    def forward(self, masked_kspace: torch.Tensor, mask: torch.Tensor,
                esp_maps: torch.Tensor, int_contrast:int) -> torch.Tensor:
        """
        Run all cascades, then coil-combine the final k-space estimate.

        Args:
            masked_kspace: measured k-space; also the starting estimate.
            mask: mask forwarded to every cascade's soft-DC term.
            esp_maps: coil sensitivity maps forwarded to every cascade.
            int_contrast: integer contrast index (0 or 1).

        Returns:
            IFFT of the final estimate multiplied by conj(esp_maps) and summed
            over dim 1 (keepdim=True).
        """
        estimate = masked_kspace.clone()

        for cascade in (*self.shared_cascades, *self.unshared_cascades):
            estimate = cascade(estimate, masked_kspace, mask, esp_maps, int_contrast)

        coil_images = fastmri.ifft2c(estimate)
        weighted = fastmri.complex_mul(coil_images, fastmri.complex_conj(esp_maps))
        return weighted.sum(dim=1, keepdim=True)

In [4]:
# Pin the run to physical GPU 1 (bus-ID ordering). These variables must be set
# before CUDA is first initialised; the first .cuda() call is in a later cell.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})

In [5]:
# Contrast names. The training loop maps 'div_coronal_pd_fs' to contrast id 0
# and every other name to 1, matching the order of this list.
datasets = [
    'div_coronal_pd_fs',
    'div_coronal_pd',
]

# Root directory with one sub-folder per contrast. Set the SUMMER_DSET_DIR
# environment variable to point elsewhere instead of editing the notebook.
datadir = os.environ.get('SUMMER_DSET_DIR', '/mnt/dense/vliu/summer_dset/')

In [6]:
# One directory per contrast, in the same order as `datasets`.
basedirs = [os.path.join(datadir, name) for name in datasets]
ic(basedirs)

ic| basedirs: ['/mnt/dense/vliu/summer_dset/div_coronal_pd_fs',
               '/mnt/dense/vliu/summer_dset/div_coronal_pd']


['/mnt/dense/vliu/summer_dset/div_coronal_pd_fs',
 '/mnt/dense/vliu/summer_dset/div_coronal_pd']

In [10]:
# Training loader over the Train split of every contrast (same path pattern as
# the Test loader further down). The second contrast previously pointed at its
# Val split, which would train on validation data.
train_dloader = genDataLoader(
    [f'{basedir}/Train' for basedir in basedirs],
    [0, 0],  # Let's not downsample anything for a start
    center_fractions=[0.05, 0.06, 0.07, 0.08],
    accelerations=[4, 5, 6],
    num_workers=16,
    shuffle=True,
    stratified=1,
    method='upsample',
)

In [11]:
# Small MTL VarNet for a quick check: 3 x (shared + per-contrast) interleaved
# cascades followed by 4 per-contrast head cascades = 10 unrolls in total.
torch.manual_seed(0)  # reproducible weight initialisation
varnet = VarNet(num_shared=3, num_final_unshared=4, chans=12, pools=4).cuda()
criterion = nn.L1Loss()  # L1 between the cropped prediction and the reference image
optimizer = torch.optim.Adam(varnet.parameters(), lr=2e-4)

In [None]:
# Train for 100 epochs and log the mean per-contrast training loss, to check
# that it decreases. Gradients from contrast-0 batches are accumulated and
# applied together with the next contrast-1 batch (the optimizer only steps on
# contrast 1).
# NOTE(review): if an epoch ends on contrast-0 batches, their gradients carry
# over into the next epoch's first step.
varnet.train()
for epoch in range(100):
    train_iter = iter(train_dloader[0])
    loss1, loss2 = 0.0, 0.0  # summed loss per contrast
    n1, n2 = 0, 0            # number of batches seen per contrast

    for kspace, mask, esp, im_fs, contrast in train_iter:
        kspace, mask = kspace.cuda(), mask.cuda()
        esp, im_fs = esp.cuda(), im_fs.cuda()

        # Only the first sample's contrast name is checked; presumably the
        # stratified loader yields single-contrast batches - TODO confirm.
        input_contrast = 0 if contrast[0] == 'div_coronal_pd_fs' else 1

        pred = varnet(kspace, mask, esp, input_contrast)
        pred = transforms.complex_center_crop(pred, tuple(im_fs.shape[2:4]))

        loss = criterion(pred, im_fs)
        loss.backward()

        if epoch >= 1:  # ignore only the first epoch (epoch 0) in the logs
            if input_contrast == 0:
                loss1 += loss.item()
                n1 += 1
            else:
                loss2 += loss.item()
                n2 += 1

        if input_contrast == 1:
            optimizer.step()
            optimizer.zero_grad()

    if epoch >= 1:
        # Average each contrast over its own batch count, not the total count.
        loss1 /= max(n1, 1)
        loss2 /= max(n2, 1)

        logging.info('Epoch {} Loss1: {}'.format(epoch, loss1))
        logging.info('Epoch {} Loss2: {}'.format(epoch, loss2))
        torch.save(varnet.state_dict(), 'MTL_weight.pt')

In [None]:
# Evaluation loader over the Test split of every contrast, with a single fixed
# center fraction and acceleration.
test_paths = [f'{basedir}/Test' for basedir in basedirs]
test_dloader = genDataLoader(
    test_paths,
    [0, 0],  # Let's not downsample anything for a start
    center_fractions=[0.06],
    accelerations=[6],
    num_workers=16,
    shuffle=True,
    stratified=1,
    method='upsample',
)

In [None]:
# Iterator over the first loader returned by genDataLoader (presumably the
# DataLoader itself - TODO confirm); the next cell pulls one batch from it.
first_test_loader = test_dloader[0]
testloader = iter(first_test_loader)

In [None]:
# Reconstruct one test batch. The test loader shuffles, so which contrast this
# batch belongs to varies between runs; `ic` reports the chosen index.
kspace1, mask1, esp1, im_fs1, contrast1 = next(testloader)
input_contrast = 0 if contrast1[0] == 'div_coronal_pd_fs' else 1
ic(input_contrast)

varnet.eval()  # inference mode for the evaluation pass
with torch.no_grad():
    kspace, mask = kspace1.cuda(), mask1.cuda()
    esp, im_fs = esp1.cuda(), im_fs1.cuda()

    pred = varnet(kspace, mask, esp, input_contrast)
    pred = transforms.complex_center_crop(pred, tuple(im_fs.shape[2:4]))
    print(pred.shape)

In [None]:
# L1 loss of the single reconstructed test batch against its reference image.
loss = criterion(input=pred, target=im_fs)
print(f"{loss.item()}")

In [None]:
# Side-by-side magnitude images: reconstruction vs. reference for the test batch.
pred1_numpy = transforms.tensor_to_complex_np(pred.cpu())
im_fs1_numpy = transforms.tensor_to_complex_np(im_fs1.cpu())

fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
ax1.imshow(np.abs(pred1_numpy.squeeze()), cmap='gray')
ax1.set_title(f'VarNet reconstruction (contrast {input_contrast})')
ax2.imshow(np.abs(im_fs1_numpy.squeeze()), cmap='gray')
ax2.set_title('Reference (im_fs)')
for ax in (ax1, ax2):
    ax.axis('off')

fig.tight_layout()