<img src="unet_arch.png" title="U-net arch.">

In [1]:
import os
import natsort
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from datetime import datetime

import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import Compose, Normalize, ToTensor, Resize
from torch.utils.tensorboard import SummaryWriter

import torch.nn as nn
import torch.nn.functional as F

In [2]:
device = 'cpu'#'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using {device} device")

Using cpu device


In [3]:
#parameters
H = 128#256
W = 128#256
size = (H,W)
batch_size = 2
num_epochs = 50
learn_rate = 1.e-4

In [4]:
#Set seed for reproducibility
def seeding(seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

In [5]:
seeding(123)

## Load data

In [6]:
class CustomDataset(Dataset):
    def __init__(self, main_dir, img_transform, msk_transform):
        self.main_dir = main_dir
        self.img_transform = img_transform
        self.msk_transform = msk_transform
        self.img_dir = main_dir+'images/'
        self.msk_dir = main_dir+'masks/'
        all_images = os.listdir(self.img_dir)[:20] #OJO!!!!!
        all_masks = os.listdir(self.msk_dir)[:20] #OJO!!!!!
        self.images = natsort.natsorted(all_images)
        self.masks = natsort.natsorted(all_masks)
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        img_loc = os.path.join(self.img_dir, self.images[idx])
        img = Image.open(img_loc).convert('RGB')
        tensor_img = self.img_transform(img)
        msk_loc = os.path.join(self.msk_dir, self.masks[idx])
        msk = Image.open(msk_loc).convert('L') 
        tensor_msk = self.msk_transform(msk)
        tensor_msk /= 255 #Target must be between 0 and 1 
        return tensor_img, tensor_msk

In [7]:
#Define a transform object that takes the data to pytorch tensor form and normalizes it
img_transform = Compose( [Resize(size), ToTensor(), Normalize(mean=(0.5,),std=(0.5,))] );
msk_transform = Compose( [Resize(size), ToTensor()] );

In [8]:
train_set = CustomDataset(main_dir='./data/train/',img_transform=img_transform,msk_transform=msk_transform)
valid_set = CustomDataset(main_dir='./data/test/',img_transform=img_transform,msk_transform=msk_transform)

In [9]:
print("Dataset size:\nTrain: {0} Validation: {1}".format(len(train_set),len(valid_set)))

Dataset size:
Train: 20 Validation: 20


In [10]:
img, msk = train_set[0]
img.shape, msk.shape

(torch.Size([3, 128, 128]), torch.Size([1, 128, 128]))

In [11]:
train_loader = DataLoader(
    dataset = train_set,
    batch_size = batch_size,
    shuffle = True,
    num_workers = 2
)
valid_loader = DataLoader(
    dataset = valid_set,
    batch_size = batch_size,
    shuffle = False,
    num_workers = 2
)

## The model

In [12]:
#Block with 2 succesive convolutions (with same padding in this case to keep the borders)
class ConvBlock(nn.Module):
    #Receives number of input and output channels
    def __init__(self, in_c, out_c):
        super(ConvBlock,self).__init__()
        self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1) #same
        self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1) #same
        self.bn = nn.BatchNorm2d(out_c) #Normalize each batch (zero mean->no bias)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        #First convolution
        x = self.conv1(x)
        x = self.bn(x)
        x = self.relu(x)
        #Second convolution
        x = self.conv2(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
    

#Block that performs the pool operation reducing the size in half.
#It also returns the original input to do skip-layer connection
# to the decoder.
class EncoderBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super(EncoderBlock,self).__init__()
        self.conv = ConvBlock(in_c, out_c)
        self.pool = nn.MaxPool2d(2)
        
    def forward(self, x):
        x = self.conv(x)
        p = self.pool(x)
        return x, p

    
#Block that performs a transpose convolution to upsample the input.
#Use instead of pre-defined interpolation so that parameter
# learning also takes place.
#It also adds a crop of the corresponding encoder output to
# make the skip connection. In this case no cropping is needed,
# as the layers have the same size because of same padding.
class DecoderBlock(nn.Module):
    def __init__(self, in_c, out_c):
        super(DecoderBlock,self).__init__()
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = ConvBlock(2*out_c, out_c) #in channels are x2 because of the concatenation
        
    def forward(self, x, skip):
        x = self.up(x)
        x = torch.cat([x, skip], axis=1) #Concatenate along the channel dimension
        x = self.conv(x)
        return x
    

#Our network will take an rgb image
class Unet(nn.Module):
    def __init__(self):
        super(Unet,self).__init__()
        #Encoder
        self.e1 = EncoderBlock(3, 64) #Input rgb image and use 64 filters
        self.e2 = EncoderBlock(64, 128) #reduce image size by half and duplicate no. of filters
        self.e3 = EncoderBlock(128, 256) #reduce image size by half and duplicate no. of filters
        self.e4 = EncoderBlock(256, 512) #reduce image size by half and duplicate no. of filters
        
        #Bottleneck
        self.b = ConvBlock(512, 1024) #Convolution without pooling
        
        #Decoder
        self.d1 = DecoderBlock(1024, 512) #upscale the image and reduce by half the no. of filters
        self.d2 = DecoderBlock(512, 256) #upscale the image and reduce by half the no. of filters
        self.d3 = DecoderBlock(256, 128) #upscale the image and reduce by half the no. of filters
        self.d4 = DecoderBlock(128, 64) #upscale the image and reduce by half the no. of filters
        
        #Classifier
        self.out = nn.Conv2d(64, 1, kernel_size=1, padding=0) #output channels is the number of output classes
        
    def forward(self, x):
        s1, x = self.e1(x)
        s2, x = self.e2(x)
        s3, x = self.e3(x)
        s4, x = self.e4(x)
        
        x = self.b(x)
        
        x = self.d1(x, s4)
        x = self.d2(x, s3)
        x = self.d3(x, s2)
        x = self.d4(x, s1)
       
        x = self.out(x)
        return x

In [13]:
model = Unet().to(device)

## Loss function

In [14]:
#From https://github.com/nikhilroxtomar/Retina-Blood-Vessel-Segmentation-in-PyTorch/blob/main/UNET/loss.py
class DiceBCELoss(nn.Module):
    def __init__(self, weight=None, size_average=True):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):

        #comment out if your model contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)

        #flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)

        intersection = (inputs * targets).sum()
        dice_loss = 1 - (2.*intersection + smooth)/(inputs.sum() + targets.sum() + smooth)
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss

        return Dice_BCE

In [15]:
#Binary cross entropy
loss_fn = DiceBCELoss()

## Optimizer

In [16]:
optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, verbose=False)

## Training epoch loop

In [17]:
def train_one_epoch(epoch_index, tb_writer):
    running_loss = 0
    last_loss = 0
    
    #Use an enumeration instead of an iterator to track the batch index and do reporting
    for i, data in enumerate(train_loader):
        #training instances are input + mask pairs
        inputs, masks = data
        inputs, masks = inputs.to(device, dtype=torch.float32), masks.to(device, dtype=torch.float32)
        
        #zero the gradients
        optimizer.zero_grad()
        
        #make predictions for this batch
        outputs = model(inputs)
        
        ##*****************************************
        #for layer in model.children():
        #    if isinstance(layer, nn.Conv2d):
        #        print(torch.isfinite(layer.state_dict()['weight'].grad))
        ##*****************************************
        
        #compute loss and its gradient
        loss = loss_fn(outputs, masks)
        loss.backward()
        
        #adjust the weights
        optimizer.step()
        
        #gather data and report
        running_loss += loss.item()
        if i % 10 == 9:
            last_loss = running_loss / 10 #loss per batch
            print(' batch {} loss: {}'.format(i+1, last_loss))
            tb_x = epoch_index * len(train_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            running_loss = 0
        
        del inputs
        del masks
        torch.cuda.empty_cache()
        
    return last_loss

## Main loop

In [18]:
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/fashion_trainer_{}'.format(timestamp))

In [19]:
epoch_number = 0
best_vloss = 1_000_000.
losslist = []
vlosslist = []

torch.autograd.set_detect_anomaly(True)

for epoch in range(num_epochs):
    print('EPOCH {}'.format(epoch_number + 1))
    
    #Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)
    
    #To do reporting gradients do not need to be on
    model.train(False)
    
    #Pull batches fron validation data to validate
    running_vloss = 0
    for i, vdata in enumerate(valid_loader):
        vinputs, vmasks = vdata
        vinputs, vmasks = vinputs.to(device, dtype=torch.float32), vmasks.to(device, dtype=torch.float32)
        voutputs = model(vinputs)
        vloss = loss_fn(voutputs, vmasks)
        running_vloss += vloss
    
    avg_vloss = running_vloss / (i+1)
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))
    #losslist.append(avg_loss)
    #vlosslist.append(avg_vloss)
    
    
    #Log the runnig loss averaged per batch for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                      {'Training': avg_loss, 'Validation': avg_vloss},
                      epoch_number + 1)
    writer.flush()
    
    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)
        
    epoch_number += 1
    del vinputs
    del vmasks
    torch.cuda.empty_cache()

EPOCH 1
 batch 10 loss: 1.5884505033493042
LOSS train 1.5884505033493042 valid 1.7120850086212158
EPOCH 2
 batch 10 loss: 1.5169926524162292
LOSS train 1.5169926524162292 valid 2.0070128440856934
EPOCH 3
 batch 10 loss: 1.444407308101654
LOSS train 1.444407308101654 valid 2.1096577644348145
EPOCH 4
 batch 10 loss: 1.385660457611084
LOSS train 1.385660457611084 valid 2.438966751098633
EPOCH 5
 batch 10 loss: 1.3515211343765259
LOSS train 1.3515211343765259 valid 1.5657553672790527
EPOCH 6
 batch 10 loss: 1.3362690091133118
LOSS train 1.3362690091133118 valid 2.3851513862609863
EPOCH 7
 batch 10 loss: 1.3271088123321533
LOSS train 1.3271088123321533 valid 1.6018577814102173
EPOCH 8
 batch 10 loss: 1.3113896608352662
LOSS train 1.3113896608352662 valid 1.5519403219223022
EPOCH 9
 batch 10 loss: 1.3009114861488342
LOSS train 1.3009114861488342 valid 1.5311729907989502
EPOCH 10
 batch 10 loss: 1.2927854776382446
LOSS train 1.2927854776382446 valid 1.5210115909576416
EPOCH 11
 batch 10 loss:

In [20]:
#model.cpu()