# Plan

Задача: бинарная сегментация

На вход двумерное одноканальное (grayscale) изображение, на выходе бинарная маска.

Про отключение dropout/batch normalization для инференса https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615

In [None]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
from pathlib import Path
import pandas as pd
import torch

In [27]:
class BraTSDataset(Dataset):
    def __init__(self, meta: pd.DataFrame, source_folder: [str, Path], transform=None):
        if isinstance(source_folder, str):
            source_folder = Path(source_folder)
            
        self.source_folder = source_folder
        self.meta_images = meta.query('is_mask == False').sort_values(by='sample_id').reset_index(drop=True)
        self.meta_masks = meta.query('is_mask == True').sort_values(by='sample_id').reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return self.meta_images.shape[0]

    def __getitem__(self, i):
        image = np.load(self.source_folder / self.meta_images.iloc[i]['relative_path'], allow_pickle=True)
        mask = np.load(self.source_folder / self.meta_masks.iloc[i]['relative_path'], allow_pickle=True)
        sample = image, mask
        
        if self.transform:
            image, mask = self.transform(sample)

        return torch.from_numpy(image).resize(1, 240, 240), torch.from_numpy(mask)

In [30]:
data_folder = Path('/home/anvar/work/data/brats_slices/')
data = BraTSDataset(df, data_folder)

In [31]:
len(data)

50844

In [33]:
data[10][0].shape

torch.Size([1, 240, 240])

# 2. Define Unet architecture

https://arxiv.org/abs/1505.04597

![title](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)


с одним изменением, вместо кропов будем интерполировать (чтобы в результате на выходе получить маску того же размера что и вход)

In [3]:
import torch.nn as nn

In [70]:
def conv_3x3(in_c, out_c):
    return nn.Sequential(
                nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
                nn.ReLU(inplace=True)
    )

def up_conv(in_c, out_c):
    return nn.Sequential(      
        nn.Upsample(scale_factor=2, mode='bilinear'),
        nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
    )

class Unet(nn.Module):
    
    def __init__(self, ):
        super().__init__()
        
        self.max_pool2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.down_conv_1 = conv_3x3(1, 8)
        self.down_conv_2 = conv_3x3(8, 16)
        self.down_conv_3 = conv_3x3(16, 32)
        self.down_conv_4 = conv_3x3(32, 64)
        self.down_conv_5 = conv_3x3(64, 128)
        
        self.upsample_1 = up_conv(128, 64)
        self.up_conv_1 = conv_3x3(128, 64)
        self.upsample_2 = up_conv(64, 32)
        self.up_conv_2 = conv_3x3(64, 32)
        self.upsample_3 = up_conv(32, 16)
        self.up_conv_3 = conv_3x3(32, 16)
        self.upsample_4 = up_conv(16, 8)
        self.up_conv_4 = conv_3x3(16, 8)
        
        self.segm = nn.Sequential(
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
            nn.Conv2d(8, 1, kernel_size=1)
        )
        
        
    def forward(self, x):
        
        # down/contracting
        x1 = self.down_conv_1(x) #
        x2 = self.max_pool2x2(x1)
        
        x3 = self.down_conv_2(x2)#
        x4 = self.max_pool2x2(x3)
        
        x5 = self.down_conv_3(x4)#
        x6 = self.max_pool2x2(x5)
        
        x7 = self.down_conv_4(x6)#
        print(x7.shape)
        x8 = self.max_pool2x2(x7)

        x9 = self.down_conv_5(x8)
    
        print(x9.shape)
        # up/expansive
        x = self.upsample_1(x9)
        x = self.up_conv_1(torch.cat([x, x7], axis=1))
        
        x = self.upsample_2(x)
        x = self.up_conv_2(torch.cat([x, x5], axis=1))
        
        x = self.upsample_3(x)
        x = self.up_conv_3(torch.cat([x, x3], axis=1))
        
        x = self.upsample_4(x)
        x = self.up_conv_4(torch.cat([x, x1], axis=1))
        
        # segm
        x = self.segm(x)
        
        return x

In [71]:
net = Unet()

In [72]:
x = torch.rand(1,1,240,240)

In [73]:
net(x)

torch.Size([1, 64, 30, 30])
torch.Size([1, 128, 15, 15])


ValueError: either size or scale_factor should be defined

In [69]:
data_folder = Path('/home/anvar/work/data/brats_slices/')
df = pd.read_csv(data_folder / 'meta.csv', index_col=0)
dataset = BraTSDataset(df, data_folder)

NameError: name 'Path' is not defined

In [33]:
dataset_loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=4, shuffle=True,
                                             num_workers=8)

In [34]:
import torch.optim as optim

In [35]:
device = ("cuda" if torch.cuda.is_available() else 'cpu')
model = Unet().to(device)

criterion = nn.BCEWithLogitsLoss() # 

optimizer = optim.Adam(model.parameters(), lr=0.01)

loader = dataset_loader

In [36]:
for epoch in range(5):
    epoch_loss = 0
    for X_batch, y_batch in loader:
        
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        y_pred = model(X_batch.float())
        
        loss = criterion(y_pred, y_batch.float().unsqueeze(1))        
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()        

    print(f'Epoch {epoch+0:03}: | Loss: {epoch_loss/len(dataset_loader):.5f}')

KeyboardInterrupt: 