# Data processing

In [1]:
#|code-fold: true
#|output: false
from IPython.display import display, HTML
display(HTML("<style>.container { width:80% !important; }</style>"))

import numpy as np
from pathlib import Path
import torch
import torch.optim as optim
import torch.nn as nn
import platform
from PIL import Image
import datetime
import matplotlib.pyplot as plt
from matplotlib import cm
from torch.utils.data import DataLoader, Dataset, random_split, WeightedRandomSampler, SubsetRandomSampler
from torchvision.transforms import Compose, ToTensor, Normalize, ToPILImage, RandomHorizontalFlip, Resize

from step_by_step import StepByStep
plt.style.use('fivethirtyeight')

%load_ext autoreload
%autoreload 2

We have accumulated labels in the `data/all_labels` folder. We need to:
- load all of them into NCHW tensors,
- split them into train and valid sets (test labels will be kept separately),
- create temporary Datasets and a normalizer,
- create the real Datasets and DataLoaders.

## Generate tensors

Let's first create `x_tensor` and `y_tensor` from everything in `all_labels`:

In [85]:
# The notebook is expected to run from a sub-folder of the project, so the
# project root is the parent of the current working directory.
proj_dir = Path('.').resolve().parent
data_dir = proj_dir.joinpath('data')
train_imgs_dir = data_dir.joinpath('all_labels/type_1/imgs')
train_masks_dir = data_dir.joinpath('all_labels/type_1/masks')

# Both listings are sorted so that images and masks can later be paired by position.
image_paths = sorted(train_imgs_dir.glob('*.png'))
mask_paths = sorted(train_masks_dir.glob('*.png'))

In [86]:
len(image_paths)

162

In [87]:
# Every image must have a mask with the same file name (and vice versa);
# otherwise pairing the two sorted path lists by position would mix up samples.
a = {path.name for path in image_paths}
b = {path.name for path in mask_paths}
# Report the offending names so a failure can be fixed without extra digging.
assert a == b, f"image/mask file names differ: {sorted(a ^ b)}"

To go from a path to a tensor, use `torchvision.transforms.ToTensor()`:

In [88]:
# Load the first image from disk and turn it into a tensor to inspect its shape.
first_image_path = image_paths[0]
image_tensor = ToTensor()(Image.open(first_image_path))
image_tensor.shape

torch.Size([3, 256, 256])

To stack many images, use `torch.stack` (not the most memory efficient but it's ok):

In [90]:
tensorizer = ToTensor()

# zip() silently truncates to the shorter list, so fail loudly on a mismatch.
if len(image_paths) != len(mask_paths):
    raise ValueError(
        f"got {len(image_paths)} images but {len(mask_paths)} masks; they must pair one-to-one"
    )

# Collect per-sample tensors in separate lists so that x_tensor / y_tensor
# only ever refer to the final stacked tensors.
image_tensors = []
mask_tensors = []

for image_path, mask_path in zip(image_paths, mask_paths):
    # The context manager closes the underlying files as soon as the pixels
    # have been copied into tensors. Both inputs are forced to RGB so every
    # sample has the same number of channels, which torch.stack requires.
    with Image.open(image_path) as image, Image.open(mask_path) as mask:
        image_tensor = tensorizer(image.convert('RGB'))
        mask_tensor = tensorizer(mask.convert('RGB'))

    image_tensors.append(image_tensor)
    mask_tensors.append(mask_tensor)

# Everything is held in memory at once: hopefully not too big.
x_tensor = torch.stack(image_tensors)
y_tensor = torch.stack(mask_tensors)

In [91]:
# One shape per line: images first, then masks.
print(x_tensor.shape, y_tensor.shape, sep='\n')

torch.Size([162, 3, 256, 256])
torch.Size([162, 3, 256, 256])


## Split into train and valid

We'll use `torch.utils.data.random_split`:

In [92]:
# Seed the global RNG first so the random split is the same on every run.
torch.manual_seed(13)

# 80 % of the samples go to training, the remainder to validation.
N = len(x_tensor)
n_train = int(0.8 * N)
n_val = N - n_train

# random_split is used only to draw the indices; the same indices are
# applied to both x_tensor and y_tensor afterwards.
train_subset, val_subset = random_split(x_tensor, [n_train, n_val])
train_idx, val_idx = train_subset.indices, val_subset.indices

print(train_idx, val_idx, sep='\n')

[22, 101, 76, 11, 44, 97, 18, 51, 86, 123, 125, 59, 0, 63, 92, 111, 114, 41, 95, 27, 67, 36, 110, 83, 62, 10, 127, 144, 69, 145, 143, 133, 117, 55, 58, 42, 89, 7, 94, 50, 38, 150, 70, 153, 137, 105, 155, 96, 1, 102, 138, 140, 93, 131, 4, 5, 141, 71, 21, 91, 35, 149, 124, 30, 151, 147, 85, 16, 160, 57, 32, 103, 115, 156, 74, 104, 77, 87, 6, 129, 80, 139, 60, 119, 75, 90, 61, 78, 46, 98, 134, 116, 128, 25, 154, 148, 108, 2, 15, 24, 79, 84, 135, 126, 107, 120, 37, 82, 52, 100, 68, 53, 54, 19, 118, 9, 132, 40, 31, 29, 146, 73, 161, 72, 152, 136, 48, 49, 8]
[99, 66, 157, 121, 12, 64, 142, 130, 3, 14, 106, 33, 23, 65, 112, 88, 39, 45, 56, 13, 122, 47, 159, 81, 17, 28, 20, 34, 113, 43, 158, 26, 109]


In [97]:
# Apply the same index lists to images and masks so the pairs stay aligned.
x_train_tensor, y_train_tensor = x_tensor[train_idx], y_tensor[train_idx]
x_val_tensor, y_val_tensor = x_tensor[val_idx], y_tensor[val_idx]

In [98]:
# One shape per line: training images first, then validation images.
print(x_train_tensor.shape, x_val_tensor.shape, sep='\n')

torch.Size([129, 3, 256, 256])
torch.Size([33, 3, 256, 256])


## Temporary Datasets

Our very simple dataset with transform:

In [99]:
class TransformedTensorDataset(Dataset):
    """Dataset over paired sequences ``x`` / ``y`` with an optional transform on ``x``.

    Args:
        x: indexable collection of inputs (e.g. a tensor whose first dimension
            enumerates the samples).
        y: indexable collection of targets with the same length as ``x``.
        transform: optional callable applied to each input when it is fetched.
            It is applied to ``x`` only; the target is returned untouched, so a
            geometric augmentation passed here would not be mirrored on ``y``.

    Raises:
        ValueError: if ``x`` and ``y`` do not have the same length.
    """

    def __init__(self, x, y, transform=None):
        # A length mismatch would otherwise surface much later as an
        # IndexError, or silently ignore the surplus targets.
        if len(x) != len(y):
            raise ValueError(
                f"x and y must have the same length, got {len(x)} and {len(y)}"
            )
        self.x = x
        self.y = y
        self.transform = transform

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair at ``index``, transforming the input if requested."""
        x = self.x[index]

        # Compare against None explicitly: a callable that defines __len__ or
        # __bool__ could evaluate as falsy and be skipped by a truthiness test.
        if self.transform is not None:
            x = self.transform(x)

        return x, self.y[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.x)

Let's first create a **temporary** `Dataset` and `DataLoader` to extract the normalization parameters from the training images:

In [100]:
# Un-normalized view of the training split, used only to build the normalizer.
temp_dataset = TransformedTensorDataset(x=x_train_tensor, y=y_train_tensor)
temp_loader = DataLoader(dataset=temp_dataset, batch_size=32)

# The normalizer is built from the training loader only; the validation split is not passed in.
normalizer = StepByStep.make_normalizer(temp_loader)
normalizer

Normalize(mean=tensor([0.1902, 0.2077, 0.1599]), std=tensor([0.1060, 0.1060, 0.1071]))

## Real Datasets and Loaders

Let's now create **real** `Datasets` and `DataLoaders`:

In [101]:
# Both pipelines only normalize for now; train_composer will have augmentations later.
train_composer = Compose([normalizer])
val_composer = Compose([normalizer])

train_dataset = TransformedTensorDataset(x=x_train_tensor, y=y_train_tensor, transform=train_composer)
val_dataset = TransformedTensorDataset(x=x_val_tensor, y=y_val_tensor, transform=val_composer)

# Only the training loader shuffles; the validation loader keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32)

# Training prep

Let's now define the model, the optimizer, and the loss.

## Model

In [321]:
# Logistic-regression baseline: flatten -> single linear unit -> sigmoid.
# The linear layer takes 25 input features.
model_logistic = nn.Sequential()
model_logistic.add_module(name='flatten', module=nn.Flatten())
model_logistic.add_module(name='output', module=nn.Linear(in_features=25, out_features=1, bias=True))
model_logistic.add_module(name='sigmoid', module=nn.Sigmoid())

In [324]:
list(model_logistic.named_parameters())

[('output.weight',
  Parameter containing:
  tensor([[ 0.0078,  0.0652, -0.1866, -0.1366, -0.0493,  0.0314,  0.0079, -0.1319,
           -0.1687, -0.1602,  0.1702,  0.0634,  0.0076, -0.1402,  0.1080,  0.1755,
            0.0316, -0.0058, -0.0456,  0.0117, -0.1726,  0.1206,  0.1004,  0.0990,
            0.1439]], requires_grad=True)),
 ('output.bias',
  Parameter containing:
  tensor([-0.1549], requires_grad=True))]

In [339]:
nn.BatchNorm2d?

In [336]:
def get_padding(block_input, kernel_size, stride):
    """Return the per-side padding used for "same"-style convolutions.

    The value is ``(kernel_size - 1) // 2``. For odd kernel sizes this makes a
    convolution produce ``ceil(input_size / stride)`` outputs per spatial
    dimension.

    A stride-aware formula,
    ``((output_size - 1) * stride - input_size + kernel_size) // 2``
    (https://stackoverflow.com/questions/48491728/what-is-the-behavior-of-same-padding-when-stride-is-greater-than-1),
    was tried but did not work here, so the simpler rule is used for now.

    Args:
        block_input: unused; kept so existing call sites keep working. ``None``
            may be passed.
        kernel_size: size of the (square) convolution kernel.
        stride: unused; kept for the same reason as ``block_input``.

    Returns:
        The integer padding to apply on each side.
    """
    return (kernel_size - 1) // 2


class _SegnetResidualBlock(nn.Module):
    """Two-convolution main branch concatenated with a 1x1-convolution shortcut.

    The two branches are concatenated along the channel dimension (not added),
    so the block outputs ``2 * depth`` channels.

    Args:
        in_channels: number of channels of the block input.
        depth: number of channels produced by each branch.
        kernel_size: kernel size of the two main-branch convolutions.
        stride: stride of the first main-branch convolution and of the shortcut.
    """

    def __init__(self, in_channels, depth, kernel_size, stride):
        super().__init__()
        padding = get_padding(None, kernel_size, stride)

        # Only the first convolution is strided. Striding both would downsample
        # the main branch twice while the shortcut is downsampled once, and the
        # concatenation would fail for any stride > 1.
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, depth, kernel_size, stride, padding),
            nn.BatchNorm2d(depth),
            nn.ELU(),
            nn.Conv2d(depth, depth, kernel_size, 1, padding),
            nn.BatchNorm2d(depth),
        )
        # kernel_size=1 here, so the padding is 0.
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channels, depth, 1, stride, get_padding(None, 1, stride)),
            nn.BatchNorm2d(depth),
        )
        self.activation = nn.ELU()
        self.out_channels = 2 * depth

    def forward(self, block_input):
        """Run both branches on ``block_input`` and return their activated concatenation."""
        merged = torch.cat([self.main(block_input), self.shortcut(block_input)], 1)
        return self.activation(merged)


def make_segnet_residual_block(block_input, depth, kernel_size, stride):
    """Apply a freshly initialised residual block to ``block_input``.

    Kept for backward compatibility. Every call builds new, randomly
    initialised layers that are not registered on any model, so their weights
    cannot be trained. Models should hold a ``_SegnetResidualBlock`` as a
    submodule instead.

    Args:
        block_input: tensor whose dimension 1 is the channel dimension.
        depth: number of channels produced by each branch.
        kernel_size: kernel size of the main-branch convolutions.
        stride: stride of the first main-branch convolution and of the shortcut.

    Returns:
        Tensor with ``2 * depth`` channels.
    """
    block = _SegnetResidualBlock(block_input.shape[1], depth, kernel_size, stride)
    return block.to(block_input.device)(block_input)


class Segnet(nn.Module):
    """Small encoder/decoder segmentation network with two skip connections.

    The encoder halves the spatial resolution three times (32, 64 and 128
    channels). The decoder doubles it three times with transposed
    convolutions, concatenating the residual-block features computed at 1/4
    and 1/2 resolution. All layers are created in ``__init__`` so they are
    registered as submodules: they show up in ``parameters()``, follow
    ``.to(device)`` and keep their weights between calls.

    Input height and width must be divisible by 8 so that the upsampled
    tensors match the skip tensors.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output channels (one score map per class).
    """

    def __init__(self, n_channels=3, n_classes=2):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes

        padding = get_padding(None, 5, 2)

        # Encoder stage 1: full -> 1/2 resolution.
        self.conv_1 = nn.Conv2d(n_channels, 32, 5, 2, padding)
        self.batch_norm_1 = nn.BatchNorm2d(32)
        self.elu_1 = nn.ELU()
        self.res_block_1 = _SegnetResidualBlock(32, 32, 3, 1)

        # Encoder stage 2: 1/2 -> 1/4 resolution.
        self.conv_2 = nn.Conv2d(32, 64, 5, 2, padding)
        self.batch_norm_2 = nn.BatchNorm2d(64)
        self.elu_2 = nn.ELU()
        self.res_block_2 = _SegnetResidualBlock(64, 64, 3, 1)

        # Encoder stage 3: 1/4 -> 1/8 resolution.
        # NOTE(review): the previous forward() also computed a residual block at
        # this resolution ("div8") but never used its output, so it is omitted.
        # Confirm whether the decoder was meant to start from it.
        self.conv_3 = nn.Conv2d(64, 128, 5, 2, padding)
        self.batch_norm_3 = nn.BatchNorm2d(128)
        self.elu_3 = nn.ELU()

        # Decoder. output_padding=1 makes each transposed convolution exactly
        # double the spatial size (2n instead of 2n - 1), which is required for
        # the concatenation with the skip tensors.
        self.up_conv_1 = nn.ConvTranspose2d(128, 64, 5, 2, padding, output_padding=1)
        self.up_batch_norm_1 = nn.BatchNorm2d(64)
        self.up_elu_1 = nn.ELU()

        self.up_conv_2 = nn.ConvTranspose2d(
            64 + self.res_block_2.out_channels, 32, 5, 2, padding, output_padding=1
        )
        self.up_batch_norm_2 = nn.BatchNorm2d(32)
        self.up_elu_2 = nn.ELU()

        self.up_conv_3 = nn.ConvTranspose2d(
            32 + self.res_block_1.out_channels, n_classes, 5, 2, padding, output_padding=1
        )

    def forward(self, x):
        """Map a batch of images ``(N, n_channels, H, W)`` to scores ``(N, n_classes, H, W)``."""
        # Encoder: the trunk continues from the strided convolutions; the
        # residual blocks only feed the skip connections.
        x = self.elu_1(self.batch_norm_1(self.conv_1(x)))
        div2 = self.res_block_1(x)

        x = self.elu_2(self.batch_norm_2(self.conv_2(x)))
        div4 = self.res_block_2(x)

        x = self.elu_3(self.batch_norm_3(self.conv_3(x)))

        # Decoder with skip connections.
        x = self.up_batch_norm_1(self.up_conv_1(x))
        x = self.up_elu_1(torch.cat([x, div4], 1))

        x = self.up_batch_norm_2(self.up_conv_2(x))
        x = self.up_elu_2(torch.cat([x, div2], 1))

        return self.up_conv_3(x)

In [337]:
model = Segnet()

In [338]:
model

Segnet()

# Appendix


In [157]:
# torch.cat joins tensors along an existing dimension: two 1-element vectors
# become one 2-element vector.
a, b = torch.Tensor([1]), torch.Tensor([2])
c = torch.cat((a, b), dim=0)
print(c)

tensor([1., 2.])


In [None]:
# NOTE(review): this cell repeats the torch.cat demo from the previous cell
# and has never been executed; consider deleting it.
a = torch.Tensor([1])
b = torch.Tensor([2])
pieces = [a, b]
c = torch.cat(pieces, dim=0)
print(c)

In [158]:
# torch.stack inserts a NEW dimension at `dim`: stacking two shape-(1,) tensors
# gives shape (2, 1) for dim=0 and shape (1, 2) for dim=1.
a, b = torch.Tensor([1]), torch.Tensor([2])
c = torch.stack((a, b), dim=0)
d = torch.stack((a, b), dim=1)
print(c, d, sep='\n')

tensor([[1.],
        [2.]])
tensor([[1., 2.]])


In [None]:
""" Full assembly of the parts to form the complete network """

from .unet_parts import *


class UNet(nn.Module):
    """U-Net assembled from ``DoubleConv``, ``Down``, ``Up`` and ``OutConv`` parts.

    The building blocks are expected to be in scope already (this cell imports
    them from ``unet_parts``).

    Args:
        n_channels: number of channels of the input image.
        n_classes: number of output channels produced by the final ``OutConv``.
        bilinear: forwarded to every ``Up`` block; when true, the channel counts
            of the deepest stages are halved (``factor = 2``).
    """

    # Class-level default so the flag always exists, even on instances that
    # were created without running __init__ (e.g. when unpickling).
    _use_checkpointing = False

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Contracting path.
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        # Expanding path; each Up block also receives the matching encoder output.
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def _run_stage(self, stage, *inputs):
        """Call ``stage(*inputs)``, through activation checkpointing if enabled."""
        if self._use_checkpointing:
            # Imported here so the class does not rely on the submodule having
            # been loaded as a side effect of ``import torch``.
            from torch.utils.checkpoint import checkpoint

            # The non-reentrant variant still propagates gradients to the
            # stage's parameters when the input itself does not require grad,
            # which is the case for the first stage.
            return checkpoint(stage, *inputs, use_reentrant=False)
        return stage(*inputs)

    def forward(self, x):
        """Return the logits produced by the final ``OutConv`` for input batch ``x``."""
        x1 = self._run_stage(self.inc, x)
        x2 = self._run_stage(self.down1, x1)
        x3 = self._run_stage(self.down2, x2)
        x4 = self._run_stage(self.down3, x3)
        x5 = self._run_stage(self.down4, x4)
        x = self._run_stage(self.up1, x5, x4)
        x = self._run_stage(self.up2, x, x3)
        x = self._run_stage(self.up3, x, x2)
        x = self._run_stage(self.up4, x, x1)
        logits = self._run_stage(self.outc, x)
        return logits

    def use_checkpointing(self):
        """Enable activation checkpointing for every stage of the network.

        ``torch.utils.checkpoint`` is a module, not a callable, so the previous
        ``torch.utils.checkpoint(self.inc)`` raised ``TypeError``. Checkpointing
        also has to wrap each forward call rather than replace the submodule,
        so this method only sets a flag that ``forward`` honours.
        """
        self._use_checkpointing = True