In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch 
from torchvision import transforms
from torch.utils.data import Dataset 
import numpy as np
from PIL import Image
import os
import nibabel as nib

In [3]:
np.random.seed(12345)

## Cuda Check

In [4]:
print('avaliable:\t', torch.cuda.is_available())
print('current:\t', torch.cuda.current_device())
print('device: \t', torch.cuda.device(0))
print('count:\t\t', torch.cuda.device_count())
print('name:\t\t', torch.cuda.get_device_name(0))

avaliable:	 True
current:	 0
device: 	 <torch.cuda.device object at 0x000002203F5EECC8>
count:		 1
name:		 NVIDIA GeForce GTX 1060 6GB


## Input Data

###### Jimi's Path

In [5]:
path = 'D:/BT_Classification/backup/train/'

## Data Augmenting

In [6]:
def load_image(path):
    """
    Parameters:
        path: path to file
    Return:
        img: image tensor
    """
    img_nii = nib.load(path)
    
    img_np = np.squeeze(img_nii.get_fdata(dtype=np.float32))
    
    return torch.from_numpy(img_np)

In [7]:
def check_crop_dim(full_size, crop_size):

    if full_size[0] < crop_size[0] or full_size[1] < crop_size[1] or full_size[2] < crop_size[2]:
        return True

    return False

In [8]:
def get_valid_crop(image_tensor, crop_size):
    print(image_tensor.shape)
    mask = np.where(image_tensor < torch.mean(image_tensor), 0, image_tensor)
    mask[mask > 0] = 1
    print(mask.shape)
    c_s, c_w, c_h = crop_size
    f_s, f_w, f_h = image_tensor.shape
    
    total = np.sum(mask)
    
    search = True
    while search:
        if f_s == c_s:
            x = 0
        else:
            x = np.random.randint(f_s - c_s)

        if f_w == c_w:
            y = 0
        else:
            y = np.random.randint(f_w - c_w)

        if f_h == c_h:
            z = 0
        else:
            z = np.random.randint(f_h - c_h)

        cropped = mask[x:x + c_s, y:y + c_w, z:z + c_h]
        
        c_total = np.sum(cropped)
        
        if (c_total / total) > 0.1:
            search = False
    
    return image_tensor[x:x + c_s, y:y + c_w, z:z + c_h]
        
            
                

In [9]:
def get_random_crops(path, crop_size, count):
    image_tensor = load_image(path)
    
    if check_crop_dim(image_tensor.shape, crop_size):
        print('Crop size too large. Image:{} \t crop:{}'.format(image_tensor.shape, crop_size))
        return
    
    cropped = []
    for i in range(count):
        cropped.append(get_valid_crop(image_tensor, crop_size))
        
    return cropped
    
    

In [10]:
crops = get_random_crops(path=path+'00000/FLAIR.nii', crop_size=(70, 70, 70), count=2)

torch.Size([359, 99, 288])
(359, 99, 288)
torch.Size([359, 99, 288])
(359, 99, 288)


### View Croppings

In [11]:
import matplotlib.pyplot as plt
import mpl_interactions.ipyplot as iplt

%matplotlib ipympl

In [12]:
def view(image):
    def axial(ax_slice):
        return image[ax_slice,:,:]

    def cor(cor_slice):
        return image[:,cor_slice,:]

    def sag(sag_slice):
        return image[:,:,sag_slice]

    # define layout
    fig, ax = plt.subplots(3,1)

    for a in ax:
        a.xaxis.set_visible(False)
        a.yaxis.set_visible(False)

    ctrl1 = iplt.imshow(axial, ax_slice=np.arange(image.shape[2]), aspect='auto', ax=ax[0], vmin=0, vmax=1, cmap ='gray')
    ctrl2 = iplt.imshow(sag, sag_slice=np.arange(image.shape[0]), aspect='auto', ax=ax[1], vmin=0, vmax=1, cmap ='gray')
    crtl3 = iplt.imshow(cor, cor_slice=np.arange(image.shape[1]), aspect='auto', ax=ax[2], vmin=0, vmax=1, cmap ='gray' )
    fig.tight_layout()


In [13]:
view(crops[0])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

VBox(children=(HBox(children=(IntSlider(value=0, description='ax_slice', max=69, readout=False), Label(value='…

VBox(children=(HBox(children=(IntSlider(value=0, description='sag_slice', max=69, readout=False), Label(value=…

VBox(children=(HBox(children=(IntSlider(value=0, description='cor_slice', max=69, readout=False), Label(value=…

In [14]:
view(crops[1])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

VBox(children=(HBox(children=(IntSlider(value=0, description='ax_slice', max=69, readout=False), Label(value='…

VBox(children=(HBox(children=(IntSlider(value=0, description='sag_slice', max=69, readout=False), Label(value=…

VBox(children=(HBox(children=(IntSlider(value=0, description='cor_slice', max=69, readout=False), Label(value=…

## Building the 3D CNN and MLP

In [16]:
import torch.nn as nn

In [29]:
image_tensor = load_image(path+'00000/FLAIR.nii')

In [37]:
image_tensor.shape

torch.Size([359, 99, 288])

In [38]:
print(image_tensor[:2])

tensor([[[0.0278, 0.0278, 0.0278,  ..., 0.0278, 0.0278, 0.0278],
         [0.0278, 0.0278, 0.0278,  ..., 0.0278, 0.0278, 0.0278],
         [0.0278, 0.0278, 0.0278,  ..., 0.0278, 0.0278, 0.0278],
         ...,
         [0.0278, 0.0278, 0.0278,  ..., 0.0278, 0.0278, 0.0278],
         [0.0278, 0.0278, 0.0278,  ..., 0.0278, 0.0278, 0.0278],
         [0.0278, 0.0278, 0.0278,  ..., 0.0278, 0.0278, 0.0278]],

        [[0.0278, 0.0278, 0.0278,  ..., 0.0278, 0.0278, 0.0278],
         [0.0278, 0.0278, 0.0278,  ..., 0.0278, 0.0278, 0.0278],
         [0.0278, 0.0278, 0.0278,  ..., 0.0278, 0.0278, 0.0278],
         ...,
         [0.0278, 0.0278, 0.0278,  ..., 0.0278, 0.0278, 0.0278],
         [0.0278, 0.0278, 0.0278,  ..., 0.0278, 0.0278, 0.0278],
         [0.0278, 0.0278, 0.0278,  ..., 0.0278, 0.0278, 0.0278]]])


In [42]:
tensor = torch.unsqueeze(image_tensor, -1)
print(tensor.shape)
print(tensor[:2])

torch.Size([359, 99, 288, 1])
tensor([[[[0.0278],
          [0.0278],
          [0.0278],
          ...,
          [0.0278],
          [0.0278],
          [0.0278]],

         [[0.0278],
          [0.0278],
          [0.0278],
          ...,
          [0.0278],
          [0.0278],
          [0.0278]],

         [[0.0278],
          [0.0278],
          [0.0278],
          ...,
          [0.0278],
          [0.0278],
          [0.0278]],

         ...,

         [[0.0278],
          [0.0278],
          [0.0278],
          ...,
          [0.0278],
          [0.0278],
          [0.0278]],

         [[0.0278],
          [0.0278],
          [0.0278],
          ...,
          [0.0278],
          [0.0278],
          [0.0278]],

         [[0.0278],
          [0.0278],
          [0.0278],
          ...,
          [0.0278],
          [0.0278],
          [0.0278]]],


        [[[0.0278],
          [0.0278],
          [0.0278],
          ...,
          [0.0278],
          [0.0278],
          [0.027

In [46]:
test = nn.Conv3d(1, 5, 5, bias=False)

In [49]:
out = test(image_tensor.unsqueeze(0).unsqueeze(0))
print(out.shape)

torch.Size([1, 5, 355, 95, 284])


In [50]:
def make_nn_model():
    in_channel = 1
    base_out = 4
    return nn.Sequential(nn.Conv3d(in_channels=in_channel, out_channels=base_out, kernel_size=3, bias=False),
                         nn.LeakyReLU(),
                         nn.Dropout3d(p=0.3),
                         nn.Conv3d(base_out, base_out*2, 3, bias=False),
                         nn.InstanceNorm3d(base_out*2),
                         nn.LeakyReLU(),
                         nn.Dropout3d(p=0.3),
                         nn.Conv3d(base_out*2, base_out*4, 3, bias=False),
                         nn.InstanceNorm3d(base_out*4),
                         nn.LeakyReLU(),
                         nn.Conv3d(base_out*4, base_out*8, 3, bias=False),
                         nn.MaxPool3d(2),
                         nn.Dropout3d(p=0.3),
                         
                         
                         nn.Conv3d(base_out*8, base_out*8, 3, bias=False),
                         nn.InstanceNorm3d(base_out*8),
                         nn.LeakyReLU(),
                         nn.Dropout3d(p=0.3),
                         nn.Conv3d(base_out*8, base_out*16, 3, bias=False),
                         nn.LeakyReLU(),
                         nn.MaxPool3d(2),
                         nn.Dropout3d(p=0.3),
                         
                         
                         nn.Conv3d(base_out*16, base_out*16, 3, bias=False),
                         nn.InstanceNorm3d(base_out*16),
                         nn.LeakyReLU(),
                         nn.Dropout3d(p=0.3),
                         nn.Conv3d(base_out*16, base_out*32, 3, bias=False),
                         nn.InstanceNorm3d(256),
                         nn.LeakyReLU(),
                         nn.Dropout3d(p=0.3),
                        )

In [65]:
def make_lin_model():
    return nn.Sequential(nn.Linear(10616832, 64),
#                          nn.Dropout2d(p=0.2),
#                          nn.Linear(2048, 64),
                         nn.Dropout2d(p=0.2),
                         nn.Linear(64, 8),
                         nn.Dropout2d(p=0.2),
                         nn.Linear(8, 2)
                        )

In [51]:
nn_model = make_nn_model()

In [68]:
out = nn_model(image_tensor.unsqueeze(0).unsqueeze(0))

In [53]:
print(out.shape)

torch.Size([1, 128, 81, 16, 64])


In [69]:
out = torch.flatten(out,1)

In [55]:
print(out.shape)

torch.Size([1, 10616832])


In [66]:
lin_model = make_lin_model()

In [70]:
out = lin_model(out)

In [71]:
print(out.shape)

torch.Size([1, 2])


In [72]:
print(out)

tensor([[ 0.0144, -0.3824]], grad_fn=<AddmmBackward>)
