In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch as t
import torch.nn as nn
from pathlib import Path
# from DiagnosisAI.models.AutoEncoder_BRAIN import Segmenter
import pickle
import numpy as np
import nibabel as nib
from sklearn.model_selection import train_test_split
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from PIL import Image
import cv2 as cv

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def convert_bw_RGB(image):
    """Scale a 2-D grayscale slice by its maximum and return a (1, H, W) float32 tensor.

    Despite the name, no RGB conversion is performed: the slice only receives a
    single leading channel axis. (An earlier cv.cvtColor / permute based RGB
    expansion was disabled in favour of this one-channel layout.)

    Args:
        image: 2-D numpy array of pixel intensities.

    Returns:
        torch.Tensor of dtype float32 with shape (1, H, W). Values are divided
        by the slice maximum; a slice whose maximum is 0 is returned unscaled.
    """
    temp_img = image.astype(np.float32)
    peak = temp_img.max()
    # A blank slice has max == 0; dividing by it would turn every pixel into
    # NaN (0 * inf), so scaling is skipped in that case.
    if peak != 0:
        temp_img *= 1 / peak
    # Add the channel axis expected by Conv2d: (H, W) -> (1, H, W).
    temp_img = temp_img[np.newaxis, :, :]
    temp_img = t.tensor(temp_img)

    return temp_img

def conv_mask(mask):
    """Return `mask` as a float32 array with a new leading (channel) axis."""
    with_channel_axis = mask[None, :, :]
    return with_channel_axis.astype(np.float32)

In [4]:
# Load one BraTS 2021 case: the FLAIR volume and its segmentation volume
# (file names below), each read into a floating-point numpy array.
file_path = Path("../datasets/brain/Brats2021_training_df/BraTS2021_00000/BraTS2021_00000_flair.nii.gz")
seg_path = Path("../datasets/brain/Brats2021_training_df/BraTS2021_00000/BraTS2021_00000_seg.nii.gz")
file_test = nib.load(file_path).get_fdata()
label_test = nib.load(seg_path).get_fdata()

# Smoke-test data only: every split repeats the SAME slice (index 80 along the
# last axis) of the same case, so train/val/test differ only in length
# (50 / 10 / 15 samples).
imgs_train = [convert_bw_RGB(file_test[:, :, 80]) for _ in range(50)]
labels_train = [conv_mask(label_test[:, :, 80]) for _ in range(50)]

imgs_val = [convert_bw_RGB(file_test[:, :, 80]) for _ in range(10)]
labels_val = [conv_mask(label_test[:, :, 80]) for _ in range(10)]

imgs_test = [convert_bw_RGB(file_test[:, :, 80]) for _ in range(15)]
labels_test = [conv_mask(label_test[:, :, 80]) for _ in range(15)]


In [5]:
# Wrap the in-memory (image, mask) pairs in DataLoaders, one sample per batch.
train_loader = t.utils.data.DataLoader(list(zip(imgs_train, labels_train)), batch_size=1, num_workers=8)
val_loader = t.utils.data.DataLoader(list(zip(imgs_val, labels_val)), batch_size=1, num_workers=8)
# Pair the test images with `labels_test` (the prepared masks). The previous
# `label_test` is the raw 3-D segmentation volume, which silently zipped each
# image with an unrelated slice along the volume's first axis.
test_loader = t.utils.data.DataLoader(list(zip(imgs_test, labels_test)), batch_size=1, num_workers=8)

======================== test above

In [6]:
# TODO:
# 1. U-Net takes a 3-channel RGB image by default, while our data has a single channel with no explicit
# channel axis. Can we add the axis artificially, or load the slice as RGB, i.e. repeat the same data on 3 channels?
# 2. Should the segmentation mask be binary with one channel per class, or can it be a single image whose pixel values encode the classes?

In [7]:
# Pull a single (image, mask) batch from the training loader for shape inspection.
img, mask = next(iter(train_loader))

In [8]:
# Both tensors should be laid out as (batch, channel, height, width).
img.shape, mask.shape

(torch.Size([1, 1, 240, 240]), torch.Size([1, 1, 240, 240]))

In [9]:
def conv2d_block(input_tensor, in_channels, nfilter, kernel_size = 3, batchnorm = True):
    """Run `input_tensor` through two Conv2d -> (BatchNorm2d) -> ReLU stages.

    The layers are created anew (randomly initialised) on every call, so this
    is a shape-prototyping helper rather than a trainable module.

    Args:
        input_tensor: tensor of shape (batch, in_channels, H, W).
        in_channels: number of channels of `input_tensor`.
        nfilter: number of output channels of both convolutions.
        kernel_size: convolution kernel size; 'same' padding keeps H and W.
        batchnorm: when True, a BatchNorm2d follows each convolution.

    Returns:
        Tensor of shape (batch, nfilter, H, W).
    """
    stages = []
    # First stage maps in_channels -> nfilter, second stage nfilter -> nfilter.
    for stage_in in (in_channels, nfilter):
        stages.append(
            nn.Conv2d(in_channels=stage_in, out_channels=nfilter, kernel_size=kernel_size, padding='same')
        )
        if batchnorm:
            stages.append(nn.BatchNorm2d(num_features=nfilter))
        stages.append(nn.ReLU())

    return nn.Sequential(*stages)(input_tensor)

In [10]:
# Settings for the step-by-step U-Net prototype in the following cells.
# The sample batch fetched from train_loader is the network input.
input_img = img
# Channel count of the input; also reused as the channel count of the final 1x1 conv.
in_channels = 1
# Base filter count; each encoder level multiplies it (x1, x2, x4, x8, x16).
n_filters = 16
# Whether conv2d_block inserts BatchNorm2d after each convolution.
batchnorm = True
# Dropout probability applied after each pooling / concatenation step.
dropout = 0.1

In [26]:
# Encoder (contracting path): each level is a double-conv block followed by
# 2x2 max-pooling and dropout; c1..c4 are kept for the decoder skip connections.
c1 = conv2d_block(input_img, in_channels, n_filters, kernel_size=3, batchnorm=batchnorm)  # (1, 16, 240, 240)
p1 = nn.Dropout(p=dropout)(nn.MaxPool2d(kernel_size=2)(c1))  # (1, 16, 120, 120)

c2 = conv2d_block(p1, n_filters, n_filters * 2, kernel_size=3, batchnorm=batchnorm)  # (1, 32, 120, 120)
p2 = nn.Dropout(p=dropout)(nn.MaxPool2d(kernel_size=2)(c2))  # (1, 32, 60, 60)

c3 = conv2d_block(p2, n_filters * 2, n_filters * 4, kernel_size=3, batchnorm=batchnorm)  # (1, 64, 60, 60)
p3 = nn.Dropout(p=dropout)(nn.MaxPool2d(kernel_size=2)(c3))  # (1, 64, 30, 30)

c4 = conv2d_block(p3, n_filters * 4, n_filters * 8, kernel_size=3, batchnorm=batchnorm)  # (1, 128, 30, 30)
p4 = nn.Dropout(p=dropout)(nn.MaxPool2d(kernel_size=2)(c4))  # (1, 128, 15, 15)

# Bottleneck: no pooling after the deepest block.
c5 = conv2d_block(p4, n_filters * 8, n_filters * 16, kernel_size=3, batchnorm=batchnorm)  # (1, 256, 15, 15)

In [27]:
# Bottleneck feature map; expected shape is (1, 256, 15, 15).
c5.shape

torch.Size([1, 256, 15, 15])

In [28]:
# First decoder step in isolation: halve the channels with a size-preserving
# transposed convolution, then double the spatial resolution.
u6 = nn.ConvTranspose2d(n_filters * 16, n_filters * 8, kernel_size=3, padding=1)(c5) # (1, 128, 15, 15)
# nn.Upsample(mode='nearest') replaces the deprecated nn.UpsamplingNearest2d;
# the original note here ("blinear") presumably marks bilinear as an alternative to try.
u6 = nn.Upsample(scale_factor=2, mode='nearest')(u6) # (1, 128, 30, 30)

To localize features more precisely, every decoder step uses a skip connection: the upsampled output of the transposed convolution is concatenated along the channel axis with the encoder feature map from the same level (below, `+` denotes concatenation, not addition):
u6 = u6 + c4
u7 = u7 + c3
u8 = u8 + c2
u9 = u9 + c1
After every concatenation we again apply two consecutive regular convolutions so that the model can learn to assemble a more precise output.

In [23]:
# Decoder (expanding path). Each level: size-preserving transposed conv that
# halves the channels -> 2x nearest-neighbour upsampling -> concatenation with
# the matching encoder feature map -> dropout -> double-conv block.
# nn.Upsample(mode='nearest') is used instead of the deprecated nn.UpsamplingNearest2d.
u6 = nn.ConvTranspose2d(n_filters * 16, n_filters * 8, kernel_size=3, padding=1)(c5) # (1, 128, 15, 15)
u6 = nn.Upsample(scale_factor=2, mode='nearest')(u6) # (1, 128, 30, 30)
u6 = t.cat([u6, c4], dim=1) # (1, 128, 30, 30) and (1, 128, 30, 30) -> (1, 256, 30, 30)
u6 = nn.Dropout(dropout)(u6)
c6 = conv2d_block(u6, n_filters * 16, n_filters * 8, kernel_size = 3, batchnorm = batchnorm) # (1, 128, 30, 30)

u7 = nn.ConvTranspose2d(n_filters * 8, n_filters * 4, kernel_size=3, padding = 1)(c6) # (1, 64, 30, 30)
u7 = nn.Upsample(scale_factor=2, mode='nearest')(u7) # (1, 64, 60, 60)
u7 = t.cat([u7, c3], dim=1) # (1, 64, 60, 60) and (1, 64, 60, 60) -> (1, 128, 60, 60)
u7 = nn.Dropout(dropout)(u7)
c7 = conv2d_block(u7, n_filters * 8, n_filters * 4, kernel_size = 3, batchnorm = batchnorm) # (1, 64, 60, 60)

u8 = nn.ConvTranspose2d(n_filters * 4, n_filters * 2, kernel_size= 3, padding = 1)(c7) # (1, 32, 60, 60)
u8 = nn.Upsample(scale_factor=2, mode='nearest')(u8) # (1, 32, 120, 120)
u8 = t.cat([u8, c2], dim=1) # (1, 32, 120, 120) and (1, 32, 120, 120) -> (1, 64, 120, 120)
u8 = nn.Dropout(dropout)(u8)
c8 = conv2d_block(u8, n_filters * 4, n_filters * 2, kernel_size = 3, batchnorm = batchnorm) # (1, 32, 120, 120)

u9 = nn.ConvTranspose2d(n_filters * 2, n_filters * 1, kernel_size= 3, padding = 1)(c8) # (1, 16, 120, 120)
u9 = nn.Upsample(scale_factor=2, mode='nearest')(u9) # (1, 16, 240, 240)
u9 = t.cat([u9, c1], dim=1) # (1, 16, 240, 240) and (1, 16, 240, 240) -> (1, 32, 240, 240)
u9 = nn.Dropout(dropout)(u9)

c9 = conv2d_block(u9, n_filters * 2, n_filters * 1, kernel_size = 3, batchnorm = batchnorm) # (1, 16, 240, 240)
# 1x1 conv head followed by a sigmoid, giving per-pixel values in (0, 1).
# NOTE(review): the output channel count reuses `in_channels` (1); a per-class
# output would need its own channel count - confirm against the mask encoding.
outputs = nn.Conv2d(n_filters * 1, in_channels, kernel_size=1)(c9)
outputs = nn.Sigmoid()(outputs)

RuntimeError: Given groups=1, weight of size [128, 128, 3, 3], expected input[1, 256, 30, 30] to have 128 channels, but got 256 channels instead

In [22]:
# Expected shape is (batch, 1, 240, 240).
# NOTE(review): the saved output of this cell reports a batch of 5 although the
# loaders use batch_size=1, and its execution count (22) precedes the decoder
# cell (23) - the output looks stale; verify with Restart & Run All.
outputs.shape

torch.Size([5, 1, 240, 240])

In [None]:
# Label values in the segmentation mask:
# value 1 - tumor core (necrotic)
# value 2 - invaded tissue (peritumoral edema)
# value 4 - enhancing tumor