# Plan

Задача: бинарная сегментация

На вход двумерное одноканальное (grayscale) изображение, на выходе бинарная маска.

Про отключение dropout/batch normalization для инференса https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615

# 1. Prettify data loader

In [1]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import pandas as pd

In [13]:
class BraTSDataset(Dataset):
    def __init__(self, source_folder: [str, Path], transform=None):
        if isinstance(source_folder, str):
            source_folder = Path(source_folder)
            
        self.images = sorted(list(source_folder.glob('*')))
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, i):
        
        file_name = self.images[i].stem.split('_')[-1]
        slices = list(self.images[i].glob('*'))
        slices = [s.stem for s in slices if 'mask' not in s.stem]
        np.random.shuffle(slices)
        j = slices[0]
        
        image = np.load(self.images[i] / f'{j}.npy', allow_pickle=True)
        mask = np.load(self.images[i] / f'{j}_mask.npy', allow_pickle=True)
        sample = image, mask
        
        if self.transform:
            sample = self.transform(sample)

        return sample

In [14]:
data_folder = Path('/home/anvar/work/data/brats_slices/')
data = BraTSDataset(data_folder)

In [15]:
len(data)

370

In [4]:
%%timeit

data[100]

2.92 ms ± 114 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [5]:
import os
from tqdm.notebook import tqdm

## Create a csv file with important metadata

In [6]:
# data_folder = Path('/home/anvar/work/data/brats_slices/')
# df = []

# for path, _, files in tqdm(os.walk(data_folder)):
#     for file in files:
        
#         subject_id = path.split('/')[-1].split('_')[-1]
#         slice_id = file.split('.')[0].split('_')[0]
#         sample_id = f"{subject_id}_{slice_id}" # SubjectID_SliceIndex
#         is_mask = 'mask' in file
#         if is_mask:
#             mask = np.load(Path(path) / file, allow_pickle=True)
#             is_nonzero_mask =  np.any(mask)
#         else:
#             is_nonzero_mask = np.nan
        
#         df.append([Path(Path(path).stem) / file, sample_id, is_mask, subject_id, is_nonzero_mask])
        
# df = pd.DataFrame(df, columns = ['relative_path', 'sample_id', 'is_mask', 'subject_id', 'is_nonzero_mask'])
# print(df.is_nonzero_mask.value_counts())

# df.to_csv(data_folder / 'meta.csv')

> Важное преимущество такого метода в том что вы убираете из класса описывающего ваш датасет
низкоуровневую работу со структурой ваших папок на диске, у вас просто есть таблица, в которой для каждого
файла указан путь к нему, и набор его id полей (например номер слайса и номер пациента, но могут быть и другие поля).

In [7]:
df = pd.read_csv(data_folder / 'meta.csv', index_col=0)

In [8]:
df.head()

Unnamed: 0,relative_path,sample_id,is_mask,subject_id,is_nonzero_mask
0,BraTS20_Training_274/25.npy,274_25,False,274,
1,BraTS20_Training_274/71_mask.npy,274_71,True,274,True
2,BraTS20_Training_274/35_mask.npy,274_35,True,274,False
3,BraTS20_Training_274/46.npy,274_46,False,274,
4,BraTS20_Training_274/3_mask.npy,274_3,True,274,False


In [9]:
df.shape

(101688, 5)

## Edit BratSDataset class

In [16]:
class BraTSDataset(Dataset):
    def __init__(self, meta: pd.DataFrame, source_folder: [str, Path], transform=None):
        if isinstance(source_folder, str):
            source_folder = Path(source_folder)
            
        self.source_folder = source_folder
        self.meta_images = meta.query('is_mask == False').sort_values(by='sample_id').reset_index(drop=True)
        self.meta_masks = meta.query('is_mask == True').sort_values(by='sample_id').reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return self.meta_images.shape[0]

    def __getitem__(self, i):
        image = np.load(self.source_folder / self.meta_images.iloc[i]['relative_path'], allow_pickle=True)
        mask = np.load(self.source_folder / self.meta_masks.iloc[i]['relative_path'], allow_pickle=True)
        sample = image, mask
        
        if self.transform:
            sample = self.transform(sample)

        return sample

In [17]:
data_folder = Path('/home/anvar/work/data/brats_slices/')
data = BraTSDataset(df, data_folder)

In [18]:
len(data)

50844

In [19]:
%%timeit

data[20000]

1.6 ms ± 66 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [21]:
50844 * 1.3 / 1000 / 60

1.10162

In [22]:
1 * 300

300

# 2. Define Unet architecture

https://arxiv.org/abs/1505.04597

![title](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)


с одним изменением, вместо кропов будем интерполировать (чтобы в результате на выходе получить маску того же размера что и вход)

In [1]:
import torch
import torch.nn as nn

In [7]:
def conv_3x3(in_c, out_c):
    return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3),
                nn.ReLU(inplace=True)
    )

class Unet(nn.Module):
    
    def __init__(self, ):
        super().__init__()
        
        self.max_pool2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.down_conv_1 = conv_3x3(1, 64)
        self.down_conv_2 = conv_3x3(64, 128)
        self.down_conv_3 = conv_3x3(128, 256)
        self.down_conv_4 = conv_3x3(256, 512)
        self.down_conv_5 = conv_3x3(512, 1024)
        
        self.up_transp_conv_1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.upsample_1 = nn.Upsample(64, mode='bilinear')
        self.up_conv_1 = conv_3x3(1024, 512)
        
        self.up_transp_conv_2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.up_conv_2 = conv_3x3(512, 256)
        self.upsample_2 = nn.Upsample(136, mode='bilinear')
        
        self.up_transp_conv_3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.up_conv_3 = conv_3x3(256, 128)
        self.upsample_3 = nn.Upsample(280, mode='bilinear')
        
        self.up_transp_conv_4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.up_conv_4 = conv_3x3(128, 64)
        self.upsample_4 = nn.Upsample(568, mode='bilinear')
        
        
    def forward(self, x):
        x1 = self.down_conv_1(x) #+
        x2 = self.max_pool2x2(x1) 
        
        x3 = self.down_conv_2(x2) #+
        print(x3.shape)
        x4 = self.max_pool2x2(x3) 
        
        x5 = self.down_conv_3(x4) #+
        x6 = self.max_pool2x2(x5) 
        
        x7 = self.down_conv_4(x6) #+
        x8 = self.max_pool2x2(x7) 

        x9 = self.down_conv_5(x8) 
    
        x = self.up_transp_conv_1(x9)
        x = self.up_conv_1(torch.cat([self.upsample_1(x), x7], axis=1))
        
        x = self.up_transp_conv_2(x)
        x = self.up_conv_2(torch.cat([self.upsample_2(x), x5], axis=1))
        
        x = self.up_transp_conv_3(x)
        x = self.up_conv_3(torch.cat([self.upsample_3(x), x3], axis=1))
        
        x = self.up_transp_conv_4(x)
        x = self.up_conv_4(torch.cat([self.upsample_4(x), x1], axis=1))
        
        print(x.shape)
#         return x5

In [8]:
model = Unet()

In [9]:
x = torch.rand(1, 1, 572, 572)

In [10]:
_ = model(x)

torch.Size([1, 128, 280, 280])




torch.Size([1, 64, 564, 564])


In [43]:
conv = nn.Conv2d(64, 128, kernel_size=3)

In [45]:
conv.weight.shape

torch.Size([128, 64, 3, 3])

In [33]:
ones = torch.ones(1, 2, 5, 5)

In [34]:
conv(ones)

tensor([[[[-0.6093, -0.6093, -0.6093],
          [-0.6093, -0.6093, -0.6093],
          [-0.6093, -0.6093, -0.6093]],

         [[ 0.2575,  0.2575,  0.2575],
          [ 0.2575,  0.2575,  0.2575],
          [ 0.2575,  0.2575,  0.2575]]]], grad_fn=<ThnnConv2DBackward>)

In [37]:
conv.weight.shape

torch.Size([128, 64, 3, 3])

In [30]:
ones2 = torch.ones(1,2,5,5)

In [31]:
conv(ones2)

RuntimeError: Given groups=1, weight of size [2, 1, 3, 3], expected input[1, 2, 5, 5] to have 1 channels, but got 2 channels instead