In [1]:
import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from data_loader import BraTSDataset

In [2]:
# Folder that holds `meta.csv` and the files BraTSDataset reads (presumably the
# BraTS slices, going by the folder name). Set the BRATS_DATA_DIR environment
# variable to point the notebook at another location without editing this cell;
# when it is unset the original path is used.
data_folder = Path(os.environ.get('BRATS_DATA_DIR', '/home/anvar/work/data/brats_slices/'))
# The first CSV column becomes the DataFrame index.
df = pd.read_csv(data_folder / 'meta.csv', index_col=0)
# Dataset built from the metadata table and the data folder.
data = BraTSDataset(df, data_folder)

In [3]:
# Number of items the dataset reports through len().
len(data)

50844

# 2. Unet modification

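1. Skip connections are aggregated by element-wise sum instead of channel concatenation.
2. Upsampling is done in `forward` with bilinear interpolation (followed by a 3x3 convolution) instead of an up-convolution layer.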

https://arxiv.org/abs/1505.04597

![U-Net architecture](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

In [4]:
import torch.nn as nn
from torch.nn import functional

In [5]:
def conv_3x3(in_c, out_c):
    """Build a double-convolution stage: (3x3 conv -> ReLU) applied twice.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved. The first convolution maps ``in_c`` to ``out_c`` channels and the
    second keeps ``out_c`` channels. The ReLUs operate in place.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.

    Returns:
        An ``nn.Sequential`` with four modules: Conv2d, ReLU, Conv2d, ReLU.
    """
    stage = []
    # First pass consumes `in_c` channels, second pass consumes the `out_c` it produced.
    for channels_in in (in_c, out_c):
        stage.append(nn.Conv2d(channels_in, out_c, kernel_size=3, padding=1))
        stage.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stage)

def up_conv(in_c, out_c):
    """Return a single size-preserving 3x3 convolution (``padding=1``).

    Despite the name, this layer does not change the spatial resolution; it only
    maps ``in_c`` channels to ``out_c`` channels.

    Args:
        in_c: number of input channels.
        out_c: number of output channels.

    Returns:
        An ``nn.Conv2d`` module.
    """
    conv_kwargs = dict(kernel_size=3, padding=1)
    return nn.Conv2d(in_channels=in_c, out_channels=out_c, **conv_kwargs)

class Unet(nn.Module):
    """Small three-level U-Net variant that returns raw per-pixel logits.

    Differences from the original U-Net (https://arxiv.org/abs/1505.04597):

    * Skip connections are merged by element-wise sum rather than channel
      concatenation, so each upsampling convolution emits exactly as many
      channels as the encoder feature map it is added to.
    * Upsampling is bilinear interpolation to the spatial size of the matching
      encoder feature map (done in ``forward``) followed by a 3x3 convolution.

    Args:
        in_channels: number of channels of the input tensor. Defaults to 1,
            the value that was previously hard-coded.
        out_channels: number of output maps. Defaults to 1, the value that was
            previously hard-coded.

    The output has the same spatial size as the input and no activation is
    applied to it.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        # One parameter-free pooling layer shared by all encoder levels.
        self.max_pool2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: the channel count doubles at each level.
        self.down_conv_1 = conv_3x3(in_channels, 16)
        self.down_conv_2 = conv_3x3(16, 32)
        self.down_conv_3 = conv_3x3(32, 64)
        self.bottleneck_conv = conv_3x3(64, 64)

        # Decoder: each `upsample_k` keeps the channel count so its output can be
        # summed with the encoder skip; the following `up_conv_k` then halves it.
        self.upsample_1 = up_conv(64, 64)
        self.up_conv_1 = conv_3x3(64, 32)
        self.upsample_2 = up_conv(32, 32)
        self.up_conv_2 = conv_3x3(32, 16)
        self.upsample_3 = up_conv(16, 16)
        self.up_conv_3 = conv_3x3(16, 8)

        # Segmentation head.
        # NOTE(review): there is no non-linearity between these two convolutions,
        # so together they form a single linear map - confirm this is intended.
        self.segm = nn.Sequential(
            nn.Conv2d(8, 8, kernel_size=3, padding=1),
            nn.Conv2d(8, out_channels, kernel_size=1)
        )

    @staticmethod
    def _resize_like(x, reference):
        """Bilinearly resize ``x`` to the spatial size (dims 2 and up) of ``reference``."""
        # align_corners=False is what bilinear mode uses when the argument is
        # omitted; spelling it out keeps the result identical and avoids the
        # warning some torch releases emit for the implicit default.
        return functional.interpolate(
            x, size=reference.shape[2:], mode='bilinear', align_corners=False
        )

    def forward(self, x):
        """Run the network on ``x`` and return the logits.

        ``x`` is indexed as (batch, channels, height, width): the convolutions
        are 2-D and the spatial size is taken from ``shape[2:]``.
        """
        # down/contracting: keep the pre-pooling feature maps for the skip sums
        skip_1 = self.down_conv_1(x)
        skip_2 = self.down_conv_2(self.max_pool2x2(skip_1))
        skip_3 = self.down_conv_3(self.max_pool2x2(skip_2))

        x = self.bottleneck_conv(self.max_pool2x2(skip_3))

        # up/expansive: resize to the skip's size, convolve, sum with the skip, refine
        x = self.upsample_1(self._resize_like(x, skip_3))
        x = self.up_conv_1(x + skip_3)

        x = self.upsample_2(self._resize_like(x, skip_2))
        x = self.up_conv_2(x + skip_2)

        x = self.upsample_3(self._resize_like(x, skip_1))
        x = self.up_conv_3(x + skip_1)

        # segm
        return self.segm(x)

In [6]:
# Instantiate the network with its default arguments.
# NOTE(review): `net` is not referenced again in the visible cells; the training
# cells below build a separate `model`.
net = Unet()

In [7]:
# Random tensor of shape (1, 1, 240, 240), drawn uniformly from [0, 1).
# NOTE(review): looks like a dummy input for a forward-pass check, but it is not
# used in the visible cells - confirm or remove.
x = torch.rand((1, 1, 240, 240))

In [8]:
# Cell In [2] already read meta.csv from this folder and built the dataset as
# `data`; reuse that object instead of re-reading the CSV and constructing a
# second, identical BraTSDataset. `data_folder` and `df` from that cell remain
# defined.
dataset = data

In [9]:
# Batches of 4 samples, reshuffled at every epoch, loaded by 8 worker processes.
dataset_loader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=True,
    num_workers=8,
)

In [10]:
import torch.optim as optim

In [11]:
# The training loop iterates over `loader`; it is the same object as `dataset_loader`.
loader = dataset_loader

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = 'cpu'

model = Unet().to(device)

# Adam over all model parameters with a learning rate of 0.01.
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

# Binary cross-entropy computed on raw logits (the sigmoid is applied inside
# the loss), matching the model, which returns un-activated outputs.
criterion = nn.BCEWithLogitsLoss()

In [12]:
from tqdm.notebook import tqdm

In [13]:
# Put the model in training mode explicitly so the loop does not depend on
# whatever mode an earlier cell may have left it in.
model.train()

for epoch in range(5):
    # Running sum of the per-batch losses for this epoch.
    epoch_loss = 0
    for X_batch, y_batch in tqdm(loader):

        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        y_pred = model(X_batch.float())

        # unsqueeze(1) inserts a channel axis into the target so it lines up
        # with the model output - assumes y_batch is (batch, H, W); TODO confirm.
        loss = criterion(y_pred, y_batch.float().unsqueeze(1))
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Epochs are reported 1-based (001..005); the mean is taken over the number
    # of batches of the loader that was actually iterated.
    print(f'Epoch {epoch+1:03}: | Loss: {epoch_loss/len(loader):.5f}')

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=12711.0), HTML(value='')))






KeyboardInterrupt: 