In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from dataloader import MRIDataset
from residual3dunet.model import ResidualUNet3D, UNet3D
# from unet3d.model import UNet3D
import matplotlib.pyplot as plt
from utils import get_loaders
import torchvision.transforms.functional as F
import torchvision.transforms as T
import random
import h5py
import numpy as np
from ipywidgets import interact
from elastic_transform import RandomElastic
import nibabel as nib
#from residual3dunet.modelorig import ResidualUNet3D

import matplotlib.pyplot as plt

In [None]:
dataset = MRIDataset(train=True)
train, val = torch.utils.data.random_split(dataset, [40, 10])

dataloader = DataLoader(dataset = train, batch_size= 1, shuffle= True, num_workers=2)

In [None]:
# dataset = MRIDataset(train=True, transform=True)

train_kwargs = {'batch_size': 10}
cuda_kwargs = {'num_workers': 1, 'pin_memory': True,'shuffle': True}
train_kwargs.update(cuda_kwargs)

#train, val = random_split(dataset, [40, 10])
# dataloader = DataLoader(dataset = val, **train_kwargs)

In [None]:
dataloader, valloader = get_loaders(train=True, transform=True, **train_kwargs)

In [None]:
dataiter = iter(dataloader)
data = dataiter.next()
features, labels = data
print(features.shape)
print(labels.shape)

In [2]:
x = torch.rand((1,1,14,240,240))

In [16]:
model1 = ResidualUNet3D(in_channels=1, out_channels=1, f_maps=64, num_levels=4)
model2 = UNet3D(in_channels=1, out_channels=1, f_maps=64, num_levels=4)
# print(features.shape)

In [None]:
output = model2(x)

In [19]:
import torch.nn.functional as F

class UNet(nn.Module):
    def __init__(self, in_channel=1, out_channel=1, training=True):
        super(UNet, self).__init__()
        self.training = training
        self.encoder1 = nn.Sequential(
                nn.GroupNorm(num_groups=1, num_channels=in_channel),
                nn.Conv3d(in_channel, 32, 3, padding=1, bias=False),
                nn.ELU(inplace=True),
                nn.GroupNorm(num_groups=8, num_channels=32),
                nn.Conv3d(32, 64, 3, padding=1, bias=False),
                nn.ELU(inplace=True)
            )
        self.encoder2 = nn.Sequential(
                nn.GroupNorm(num_groups=8, num_channels=64),
                nn.Conv3d(64, 64, 3, padding=1, bias=False),
                nn.ELU(inplace=True),
                nn.GroupNorm(num_groups=8, num_channels=64),
                nn.Conv3d(64, 128, 3, padding=1, bias=False),
                nn.ELU(inplace=True)
            )
        self.encoder3 = nn.Sequential(
                nn.GroupNorm(num_groups=8, num_channels=128),
                nn.Conv3d(128, 128, 3, padding=1, bias=False),
                nn.ELU(inplace=True),
                nn.GroupNorm(num_groups=8, num_channels=128),
                nn.Conv3d(128, 256, 3, padding=1, bias=False),
                nn.ELU(inplace=True)
            )
        self.encoder4 = nn.Sequential(
                nn.GroupNorm(num_groups=8, num_channels=256),
                nn.Conv3d(256, 256, 3, padding=1, bias=False),
                # nn.ELU(inplace=True),
                nn.ELU(inplace=True),
                nn.GroupNorm(num_groups=8, num_channels=256),
                nn.Conv3d(256, 512, 3, padding=1, bias=False),
                # nn.ELU(inplace=True),
                nn.ELU(inplace=True)
            )
        
        self.decoder2 = nn.Sequential(
                nn.GroupNorm(num_groups=8, num_channels=768),
                nn.Conv3d(768, 256, 3, padding=1, bias=False),
                nn.ELU(inplace=True),
                nn.GroupNorm(num_groups=8, num_channels=256),
                nn.Conv3d(256, 256, 3, padding=1, bias=False),
                nn.ELU(inplace=True)
            )

        self.decoder3 = nn.Sequential(
                nn.GroupNorm(num_groups=8, num_channels=384),
                nn.Conv3d(384, 128, 3, padding=1, bias=False),
                nn.ELU(inplace=True),
                nn.GroupNorm(num_groups=8, num_channels=128),
                nn.Conv3d(128, 128, 3, padding=1, bias=False),
                nn.ELU(inplace=True)
            )

        self.decoder4 = nn.Sequential(
                nn.GroupNorm(num_groups=8, num_channels=192),
                nn.Conv3d(192, 64, 3, padding=1, bias=False),
                nn.ELU(inplace=True),
                nn.GroupNorm(num_groups=8, num_channels=64),
                nn.Conv3d(64, 64, 3, padding=1, bias=False),
                nn.ELU(inplace=True)
            )

        self.decoder5 = nn.Conv3d(64, 1, 1)
        self.maxpool = nn.MaxPool3d(kernel_size=(1,2,2))
        self.upsampling = InterpolateUpsampling('trilinear')
        
    
    def forward(self, x):

        print(x.shape)
        x1 = self.encoder1(x) # 1,64,14,240,240
        print(x1.shape)

        x2 = self.encoder2(self.maxpool(x1)) # 1,128,14,120,120
        print(x2.shape)

        x3 = self.encoder3(self.maxpool(x2)) # 1,256,14,60,60
        print(x3.shape)

        x4 = self.encoder4(self.maxpool(x3)) # 1,512,14,30,30
        print(x4.shape)   

        out = x3.size()[2:]
        x5 = torch.cat((x3, self.upsampling(x4, out)),dim=1)
        x5 = self.decoder2(x5)
        print(x5.shape)

        out = x2.size()[2:]
        x6 = torch.cat((x2, self.upsampling(x5, out)),dim=1)
        x6 = self.decoder3(x6)
        print(x6.shape)

        out = x1.size()[2:]
        x7 = torch.cat((x1, self.upsampling(x6, out)),dim=1)
        x7 = self.decoder4(x7)
        print(x7.shape)

        out = self.decoder5(x7)
        print(out)

        return out


class InterpolateUpsampling(nn.Module):
    def __init__(self, mode):
        super(InterpolateUpsampling, self).__init__()

        self.mode = mode

    def forward(self, x, size):
        return F.interpolate(x,size=size, mode=self.mode)

In [20]:
model = UNet(1,1)
out = model(x)

torch.Size([1, 1, 14, 240, 240])
torch.Size([1, 64, 14, 240, 240])
torch.Size([1, 128, 14, 120, 120])
torch.Size([1, 256, 14, 60, 60])
torch.Size([1, 512, 14, 30, 30])


  "See the documentation of nn.Upsample for details.".format(mode)


torch.Size([1, 256, 14, 60, 60])
torch.Size([1, 128, 14, 120, 120])


RuntimeError: Unsupported memory format. Supports only ChannelsLast, Contiguous

In [None]:
dataset2 = MRIDataset(train=True, transform=T.Compose([
    T.ToTensor(), 
    T.RandomHorizontalFlip(), 
    T.RandomCrop((240,240), padding=50, pad_if_needed=True),
    ]), elastic=True)


In [None]:

test_loader = DataLoader(dataset = dataset2, batch_size=50, shuffle=False)

# dataiter = iter(test_loader)

# first = next(dataiter)
# second = next(dataiter)

# features1, labels1 = first
# features2, labels2 = second

for data, target in test_loader:
    print(data.shape)
    print(target.shape)



In [None]:
# h5ftrain = h5py.File('dataset/T2train.h5','r')
# h5ftrainmask = h5py.File('dataset/T2trainmask.h5','r')

# data = h5ftrain[f'T2data_2'][:]
# target = h5ftrainmask[f'T2maskdata_2'][:]

# print(data.shape)
# print(target.shape)

# image_path = './dataset/train/T1/MRI2_T1.nii.gz'
# image_obj = nib.load(image_path)
# # print(f'Type of the image {type(image_obj)}')

# # Extract data as numpy array
# image_data = image_obj.get_fdata()
# print(type(image_data))
# print(image_data.shape)



# image_data = np.pad(image_data, ((0,0),(0,0),(0,1)))
# image_data = np.moveaxis(image_data, 2, 0)
# image_data = np.moveaxis(image_data, 2, 1)
# image_data = torch.from_numpy(image_data)
# image_data = torch.unsqueeze(image_data, 0)
# print(image_data.shape)

# preprocess = RandomElastic(alpha=0, sigma=0.06)
# data1, target1 = preprocess(data, target)

# print(data1.shape)
# print(target1.shape)


In [None]:
# index = random.randint(0,49)
def explore_3d_image(layer):

    plt.figure(figsize=(5,10))
    # plt.imshow(data[0,layer,:,:],cmap='gray')
    
    plt.imshow(t1data[layer,:,:],cmap='gray')
    plt.imshow(t2mask[layer,:,:],cmap='gray', alpha=0.3)
    plt.title('Explore Layers of Kidney MRI')
    plt.axis('off')
    return layer

interact(explore_3d_image, layer=(0,13))