# U-Net

In [2]:
import glob
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn as nn
from PIL import Image
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision import transforms

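# Model
References:
- Kaggle notebook, U-Net segmentation on the Carvana dataset: https://www.kaggle.com/code/hychim/unet-segmentation-on-carvana-dataset
- U-Net PyTorch implementation walkthrough (amaarora.github.io, 2020-09-13): https://amaarora.github.io/2020/09/13/unet.html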

In [15]:
class Block(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages, the basic U-Net unit.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is preserved and only the channel count changes. The convolutions
    have no bias because the BatchNorm right after each one already learns a
    per-channel shift.

    This is a plain ``nn.Module``: it is only a building block inside the
    Lightning model and needs none of the LightningModule machinery.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute name and layer order are kept so state_dict keys
        # (``conv.0.weight`` ...) stay compatible with existing checkpoints.
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the activated feature map."""
        return self.conv(x)

class UNET(pl.LightningModule):
    """U-Net encoder/decoder for per-pixel prediction.

    Encoder: one ``Block`` per entry of ``features``, each followed by 2x2
    max-pooling. A middle ``Block`` doubles the deepest channel count. Decoder:
    for every encoder level (deepest first) a 2x2 transposed convolution,
    concatenation with the matching encoder output (skip connection) and a
    ``Block``. A final 1x1 convolution maps ``features[0]`` channels to
    ``out_channels`` logits.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the output logits. With ``1`` the model is
            trained as a binary segmenter (BCE with logits); with more it is
            trained with softmax cross-entropy over classes.
        features: Encoder channel widths, shallowest first. Any sequence is
            accepted; a list copy is stored in ``self.features``.
        learning_rate: Adam learning rate used by ``configure_optimizers``.
    """

    def __init__(self, in_channels=3, out_channels=1, features=(64, 128, 256, 512), learning_rate=1e-3):
        super().__init__()
        # Copy so the caller's sequence is never aliased (the default used to be a shared list).
        features = list(features)
        self.learning_rate = learning_rate
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.features = features

        # Metrics for logging. They are registered as submodules, so they follow the model's device.
        # NOTE(review): torchmetrics >= 0.11 requires a ``task=`` argument here -- confirm
        # against the installed version.
        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()
        self.test_acc = Accuracy()

        self.down = nn.ModuleList()
        self.up = nn.ModuleList()
        self.pool = nn.MaxPool2d(2, 2)

        channels = in_channels
        for feature in features:
            self.down.append(Block(channels, feature))
            channels = feature

        self.middle_block = Block(features[-1], features[-1] * 2)

        # self.up alternates [ConvTranspose2d, Block, ConvTranspose2d, Block, ...], deepest level first.
        for feature in reversed(features):
            self.up.append(nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2))
            # The Block sees the skip connection concatenated with the upsampled map: 2 * feature channels.
            self.up.append(Block(feature * 2, feature))

        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return per-pixel logits with ``out_channels`` channels for the image batch ``x``.

        Each decoder output is resized to its skip connection's spatial size.
        As a result, the logits keep the input's height and width even when
        max-pooling floored an odd size.
        """
        skips = []
        for down in self.down:
            x = down(x)
            skips.append(x)  # kept for the decoder's skip connection at this depth
            x = self.pool(x)  # halve the spatial resolution

        x = self.middle_block(x)

        # self.up holds (upconv, Block) pairs; pair i belongs to the i-th deepest skip.
        for depth, skip in enumerate(reversed(skips)):
            upconv, block = self.up[2 * depth], self.up[2 * depth + 1]
            x = upconv(x)
            if x.shape[2:] != skip.shape[2:]:
                # Pooling floors odd sizes, so the upsampled map can be smaller than the skip.
                x = nn.functional.interpolate(x, size=skip.shape[2:], mode="bilinear", align_corners=False)
            x = torch.cat((skip, x), dim=1)
            x = block(x)

        return self.final_conv(x)

    def _evaluate_batch(self, batch, metric):
        """Run the model on ``batch``, update ``metric`` with this batch and return the loss."""
        x, y = batch
        logits = self(x)
        if self.out_channels == 1:
            # Softmax cross-entropy over a single channel is identically 0, so a
            # 1-channel output is trained as binary segmentation instead.
            # Assumes y has the same shape as logits with values in {0, 1} -- TODO confirm.
            target = y.float()
            loss = nn.functional.binary_cross_entropy_with_logits(logits, target)
            preds = (logits > 0).long()
            metric(preds, target.long())
        else:
            loss = nn.functional.cross_entropy(logits, y)
            preds = torch.argmax(logits, dim=1)
            metric(preds, y)
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning training hook: log and return the training loss for one batch."""
        loss = self._evaluate_batch(batch, self.train_acc)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, prog_bar=True)
        return loss

    # Backward-compatible alias for the original (misspelled) name, which Lightning never called.
    trainning_step = training_step

    def validation_step(self, batch, batch_idx):
        """Lightning validation hook: log and return the validation loss for one batch."""
        loss = self._evaluate_batch(batch, self.valid_acc)
        self.log("valid_loss", loss, prog_bar=True)
        self.log("valid_acc", self.valid_acc, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        """Lightning test hook: log and return the test loss for one batch."""
        loss = self._evaluate_batch(batch, self.test_acc)
        self.log("test_loss", loss)
        self.log("test_acc", self.test_acc)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

In [17]:
# Full-depth layer summary of a default-configured UNET (max_depth=-1 lists every nested layer).
model_summary = pl.utilities.model_summary.summarize(UNET(), max_depth=-1)
model_summary

   | Name                | Type            | Params
---------------------------------------------------------
0  | train_acc           | Accuracy        | 0     
1  | valid_acc           | Accuracy        | 0     
2  | test_acc            | Accuracy        | 0     
3  | down                | ModuleList      | 4.7 M 
4  | down.0              | Block           | 38.8 K
5  | down.0.conv         | Sequential      | 38.8 K
6  | down.0.conv.0       | Conv2d          | 1.7 K 
7  | down.0.conv.1       | BatchNorm2d     | 128   
8  | down.0.conv.2       | ReLU            | 0     
9  | down.0.conv.3       | Conv2d          | 36.9 K
10 | down.0.conv.4       | BatchNorm2d     | 128   
11 | down.0.conv.5       | ReLU            | 0     
12 | down.1              | Block           | 221 K 
13 | down.1.conv         | Sequential      | 221 K 
14 | down.1.conv.0       | Conv2d          | 73.7 K
15 | down.1.conv.1       | BatchNorm2d     | 256   
16 | down.1.conv.2       | ReLU            | 0     
17 | d

# Data loading

In [None]:
class SegmentationDataset(torch.utils.data.Dataset):
    """Image/mask pairs for segmentation, read from two directories.

    Every ``*.jpg`` in ``image_path`` is paired with the ``.png`` of the same
    base name in ``mask_path``. Mask pixels equal to 255 are mapped to 1.

    Args:
        image_path: Directory containing the ``.jpg`` input images.
        mask_path: Directory containing the ``.png`` masks.
        transforms: Callable invoked as ``transforms(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` entries
            (albumentations-style -- presumably ending in a tensor conversion;
            confirm against the pipeline actually used).
    """

    def __init__(self, image_path, mask_path, transforms):
        # Sorted because glob order is filesystem dependent. A stable index -> file
        # mapping makes seeded random splits reproducible.
        self.images = sorted(glob.glob(os.path.join(image_path, '*.jpg')))
        self.image_path = image_path
        self.mask_path = mask_path
        self.transforms = transforms

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        The mask is a float32 tensor with a leading channel dimension added.
        """
        image_file = self.images[idx]
        # splitext only swaps the final extension; str.replace would also rewrite
        # any '.jpg' occurring earlier in the file name.
        stem = os.path.splitext(os.path.basename(image_file))[0]
        mask_file = os.path.join(self.mask_path, stem + '.png')

        # Context managers close the file handles immediately instead of at GC time.
        with Image.open(image_file) as im:
            img = np.array(im.convert('RGB'))
        with Image.open(mask_file) as im:
            mask = np.array(im)
        mask[mask == 255] = 1

        augmentations = self.transforms(image=img, mask=mask)
        image = augmentations["image"]
        mask = augmentations["mask"]
        if not torch.is_tensor(mask):
            # The transform may return a numpy mask (no tensor-conversion step);
            # copy to a contiguous array because flips can leave negative strides.
            mask = torch.from_numpy(np.ascontiguousarray(mask))
        mask = torch.unsqueeze(mask, 0).to(torch.float32)
        return image, mask

In [None]:
class SegmentationDataModule(pl.LightningDataModule):
    """LightningDataModule that randomly splits one SegmentationDataset into train/validation.

    The test dataloader reuses the validation split.

    Args:
        image_path: Image directory, forwarded to ``SegmentationDataset``.
        mask_path: Mask directory, forwarded to ``SegmentationDataset``.
        transform: Joint image/mask transform, forwarded to ``SegmentationDataset``.
        train_size: Fraction of samples used for training (floored); the rest is validation.
        batch_size: Batch size of every dataloader.
        num_workers: Loader worker processes per dataloader (0 = load in the main process).
        seed: Optional seed for the train/validation split. ``None`` uses torch's
            global RNG, so the split is not reproducible across runs.
    """

    def __init__(self, image_path, mask_path, transform, train_size=0.90, batch_size: int = 9,
                 num_workers: int = 2, seed=None):
        super().__init__()
        self.image_path = image_path
        self.mask_path = mask_path
        self.batch_size = batch_size
        self.transform = transform
        self.train_size = train_size
        self.num_workers = num_workers
        self.seed = seed
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage=None):
        """Build the train/validation split once; fit, validate and test all need it."""
        if stage not in (None, 'fit', 'validate', 'test'):
            return
        if self.train_dataset is not None:
            # Keep the existing split so repeated setup calls (e.g. fit then test)
            # never move validation samples into training.
            return
        ds = SegmentationDataset(self.image_path, self.mask_path, self.transform)
        train_size = math.floor(len(ds) * self.train_size)
        val_size = len(ds) - train_size
        if self.seed is None:
            splits = torch.utils.data.random_split(ds, [train_size, val_size])
        else:
            generator = torch.Generator().manual_seed(self.seed)
            splits = torch.utils.data.random_split(ds, [train_size, val_size], generator=generator)
        self.train_dataset, self.val_dataset = splits

    def _make_loader(self, dataset, shuffle=False):
        """Build a DataLoader with this module's batch size and worker settings."""
        # persistent_workers=True is rejected by DataLoader when num_workers == 0.
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        return self._make_loader(self.val_dataset)

# SegmentationDataModule has three required arguments; calling it with none raises TypeError.
# The values below are placeholders -- TODO: point them at the real image/mask directories
# and pass a joint transform called as transform(image=..., mask=...) before training.
IMAGE_PATH = os.path.join("data", "images")
MASK_PATH = os.path.join("data", "masks")
TRANSFORM = None  # TODO: replace with the real augmentation pipeline
segmentation_dm = SegmentationDataModule(IMAGE_PATH, MASK_PATH, TRANSFORM)
mnist_dm = segmentation_dm  # original (misleading) name kept for any later cells that use it

# Training