In [None]:
import numpy as np
from matplotlib import pyplot as plt

import simulations.simpleImages as simulation
import simulations.utils as utils

# Dataset

In [None]:
# Number of samples to generate and side length (in pixels) of each square image.
N = 1
M = 128
# Generate N random synthetic images together with their target masks.
input_images, target_masks = simulation.generate_random_data(M, M, count=N)

# Unpack both array shapes into named dimensions; n_samples, x_size and y_size
# are simply re-assigned by the second unpacking.
n_samples, x_size, y_size, n_channels = input_images.shape
n_samples, n_masks, x_size, y_size = target_masks.shape

In [None]:
# Show the array shapes of the generated images and of their masks.
print(input_images.shape)
print(target_masks.shape)

**input_images**: (1, M, M, C) -> 1 imagen de MxM con C canales

**target_masks**: (1, K, M, M) -> 1 imagen con K máscaras de MxM (una máscara por clase)

In [None]:
# Cast the images to uint8 and split the batch into a list of per-sample arrays for matplotlib
input_images_rgb = list(input_images.astype(np.uint8))

# Map each channel (i.e. class) to each color
target_masks_rgb = [utils.masks_to_colorimg(x) for x in target_masks]

In [None]:
# Hand the images and their color-coded masks to the project helper for display.
simulation.show_images_samples(input_images_rgb, target_masks_rgb)

# Dataset PyTorch

In [None]:
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets, models

from simulations.simpleImages import generate_random_data

'''
    This module contains pytorch dataset classes that generates 
    simulations data.
'''

class SimDataset(Dataset):
    '''
        PyTorch Dataset of random synthetic simple images. This dataset contains simple
        images with six figures (filled square, mesh square, circle, triangle...).

        Important:
            the dataset contains (H x W x C) images in the range [0, 255]; to use them
            as tensors they have to be converted to (C x H x W) in the range [0.0, 1.0].

            The easiest way is to pass torchvision.transforms.ToTensor() to the initializer.
            E.g. SimDataset(count, transform=torchvision.transforms.ToTensor())

        Args:
            count: number of samples to generate.
            transform: optional callable applied to each image (not to the mask)
                when an item is fetched.
            height, width: forwarded as the first two positional arguments of
                generate_random_data (presumably the image height and width in
                pixels - confirm against that helper). Both default to 192, the
                value that used to be hard-coded.
    '''

    def __init__(self, count, transform=None, height=192, width=192):
        # All samples are generated eagerly, once, at construction time.
        self.input_images, self.target_masks = generate_random_data(height, width, count=count)
        self.transform = transform

    def __len__(self):
        '''Number of generated samples.'''
        return len(self.input_images)

    def __getitem__(self, idx):
        '''Return [image, mask] for sample idx; only the image goes through transform.'''
        image = self.input_images[idx]
        mask = self.target_masks[idx]
        # Compare against None explicitly so a transform object that happens to be
        # falsy (e.g. an empty container) is still applied.
        if self.transform is not None:
            image = self.transform(image)

        return [image, mask]


In [None]:
from torchvision import transforms as torchTransforms
from torch.utils.data import DataLoader

# Pre-processing applied to every image: convert to a tensor, then normalize
# each of the three channels with the given mean / std.
trans = torchTransforms.Compose([
    torchTransforms.ToTensor(),
    torchTransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Independent random datasets for training and validation.
train_set = SimDataset(1200, transform=trans)
val_set = SimDataset(120, transform=trans)

image_datasets = {'train': train_set, 'val': val_set}

batch_size = 6

# One shuffled, single-process loader per phase.
dataloaders = {
    phase: DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    for phase, dataset in image_datasets.items()
}

dataset_sizes = {phase: len(dataset) for phase, dataset in image_datasets.items()}

print('Dataset Size: {}'.format(dataset_sizes))

# UNet 

In [None]:
import torch
import torch.nn as nn

## Encoder

In [None]:
class UNetEncode(nn.Module):
    """Encoder stage: two 3x3 convolutions (each followed by ReLU) and a 2x2 max-pool.

    The output of the most recent forward pass is cached on the module and can
    be read back with getResult(); it is None until forward has been called.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Cache for the last forward output.
        self._result = None

        stage_layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            # ceil_mode=True rounds odd spatial sizes up instead of down.
            nn.MaxPool2d(2, stride=2, ceil_mode=True),
        ]
        self.encode = nn.Sequential(*stage_layers)

    def forward(self, x):
        """Run the stage on x, remember the (pooled) output and return it."""
        pooled = self.encode(x)
        self._result = pooled
        return pooled

    def getResult(self):
        """Return the tensor produced by the last call to forward (or None)."""
        return self._result

## Decoder

In [None]:
class UNetDecode(nn.Module):
    """Decoder stage: two 3x3 convolutions (each followed by ReLU), optionally
    followed by a 2x bilinear upsampling.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by both convolutions.
        upsample: when True (default) the stage ends with a 2x bilinear
            upsampling; when False the spatial size is left unchanged.
        kaiming_initialization: when True the convolution weights are
            re-initialised with Kaiming-normal (ReLU gain) and the biases are
            zeroed; when False (default) PyTorch's default initialisation is kept.
    """

    def __init__(self, in_channels, out_channels, upsample=True, kaiming_initialization=False):
        super().__init__()

        stage_layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(inplace=True),
        ]
        # The upsampling layer has no parameters, so omitting it does not change
        # the state_dict keys of the stage.
        if upsample:
            stage_layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))

        self.decode = nn.Sequential(*stage_layers)

        if kaiming_initialization:
            for layer in self.decode:
                if isinstance(layer, nn.Conv2d):
                    nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
                    nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Apply the decoder stage to x."""
        return self.decode(x)

## Unet Implementation

In [None]:
class UNet(nn.Module):
    """UNet-style encoder/decoder built from five UNetEncode and five UNetDecode stages.

    Each decoder stage below the bottleneck receives the previous decoder output
    concatenated (along the channel axis) with the output of the encoder stage
    one level up; a final 1x1 convolution maps the 64 decoder channels to
    output_channels.

    Args:
        output_channels: number of channels of the returned tensor.
        input_channels: number of channels of the input tensor (default 3).
    """

    def __init__(self, output_channels, input_channels=3):
        super().__init__()

        # TODO: left open by the original author (no details given).
        # Encoder path: channel width doubles at every level.
        self.encode_lvl_0 = UNetEncode(input_channels, 64)
        self.encode_lvl_1 = UNetEncode(64, 128)
        self.encode_lvl_2 = UNetEncode(128, 256)
        self.encode_lvl_3 = UNetEncode(256, 512)
        self.encode_lvl_4 = UNetEncode(512, 1024)

        # Decoder path: input width = skip channels + channels coming from below.
        self.decode_lvl_4 = UNetDecode(1024, 1024)  # bottleneck
        self.decode_lvl_3 = UNetDecode(512 + 1024, 512)
        self.decode_lvl_2 = UNetDecode(256 + 512, 256)
        self.decode_lvl_1 = UNetDecode(128 + 256, 128)
        self.decode_lvl_0 = UNetDecode(64 + 128, 64)

        self.output = nn.Conv2d(64, output_channels, 1)

    def forward(self, x):
        """Run the full encoder/decoder on x and return the output-layer result."""
        # Encode, keeping every stage output for the skip connections.
        skip_0 = self.encode_lvl_0(x)
        skip_1 = self.encode_lvl_1(skip_0)
        skip_2 = self.encode_lvl_2(skip_1)
        skip_3 = self.encode_lvl_3(skip_2)
        bottom = self.encode_lvl_4(skip_3)

        # Decode: after the bottleneck, each stage consumes [decoder output, skip].
        x = self.decode_lvl_4(bottom)
        decode_plan = (
            (skip_3, self.decode_lvl_3),
            (skip_2, self.decode_lvl_2),
            (skip_1, self.decode_lvl_1),
            (skip_0, self.decode_lvl_0),
        )
        for skip, stage in decode_plan:
            x = stage(torch.cat([x, skip], dim=1))

        # Output
        return self.output(x)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One output channel per class.
n_classes = 6

# Build the network for 3-channel inputs and move it to the selected device.
model = UNet(n_classes, 3).to(device)
model

# Loss Function (DICE)

In [None]:
def dice_loss(pred, target, smooth = 1.):
    """Return the mean soft Dice loss between pred and target.

    Both tensors are summed over dimensions 2 and 3 (so they need at least four
    dimensions); the Dice coefficient is computed for every remaining entry and
    the loss is the mean of (1 - coefficient).

    Args:
        pred: predicted values.
        target: reference values, broadcast-compatible with pred.
        smooth: constant added to numerator and denominator to avoid 0 / 0.
    """
    pred = pred.contiguous()
    target = target.contiguous()

    def _sum_dims_2_3(tensor):
        # After the first reduction the old dimension 3 becomes dimension 2.
        return tensor.sum(dim=2).sum(dim=2)

    overlap = _sum_dims_2_3(pred * target)
    total = _sum_dims_2_3(pred) + _sum_dims_2_3(target)

    dice_coefficient = (2. * overlap + smooth) / (total + smooth)

    return (1 - dice_coefficient).mean()

# UNet Training

In [None]:
import torch.nn.functional as F
from collections import defaultdict

In [None]:
def calc_loss(pred, target, metrics, bce_weight=0.5):
    """Return the weighted sum of BCE-with-logits and Dice loss, updating metrics.

    Args:
        pred: raw model outputs (logits); sigmoid is applied here before the Dice term.
        target: reference tensor with the same shape as pred.
        metrics: mapping whose 'bce', 'dice' and 'loss' entries are incremented
            by the respective value multiplied by the batch size (target.size(0)).
        bce_weight: weight of the BCE term; the Dice term gets 1 - bce_weight.
    """
    bce = F.binary_cross_entropy_with_logits(pred, target)

    probabilities = torch.sigmoid(pred)
    dice = dice_loss(probabilities, target)

    loss = bce * bce_weight + dice * (1 - bce_weight)

    # Accumulate batch-size-weighted sums so the caller can average per sample.
    batch_len = target.size(0)
    for key, value in (('bce', bce), ('dice', dice), ('loss', loss)):
        metrics[key] += value.data.cpu().numpy() * batch_len

    return loss

def print_metrics(metrics, epoch_samples, phase):
    """Print every accumulated metric divided by epoch_samples, prefixed by phase."""
    summary = ", ".join(
        "{}: {:4f}".format(name, metrics[name] / epoch_samples)
        for name in metrics.keys()
    )

    print("{}: {}".format(phase, summary))
    
def train_model(model, optimizer, scheduler, num_epochs=25):
    """Train model for num_epochs epochs and return it loaded with the best weights.

    Every epoch runs a 'train' phase followed by a 'val' phase over the
    module-level `dataloaders` dict, moving batches to the module-level
    `device`. The weights with the lowest per-sample validation loss seen so
    far are kept and restored into the model before returning.

    Args:
        model: network to optimise (updated in place).
        optimizer: optimizer holding the model parameters.
        scheduler: learning-rate scheduler; stepped once per epoch, after the
            training phase.
        num_epochs: number of epochs to run.

    Returns:
        The same model object, with the best validation weights loaded.
    """
    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = 1e10

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        since = time.time()

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                for param_group in optimizer.param_groups:
                    print("LR", param_group['lr'])

                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            # Batch-size-weighted sums of each loss term for this phase.
            metrics = defaultdict(float)
            epoch_samples = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = calc_loss(outputs, labels, metrics)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                epoch_samples += inputs.size(0)

            print_metrics(metrics, epoch_samples, phase)
            epoch_loss = metrics['loss'] / epoch_samples

            # deep copy the model
            if phase == 'val' and epoch_loss < best_loss:
                print("saving best model")
                best_loss = epoch_loss
                best_model_wts = copy.deepcopy(model.state_dict())

            # Advance the LR schedule exactly once per epoch, after the optimizer
            # updates of the training phase. Stepping after the validation phase
            # as well would make the schedule run twice as fast as configured.
            if phase == 'train':
                scheduler.step()

        time_elapsed = time.time() - since
        print('{:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val loss: {:4f}'.format(best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [None]:
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
import time
import copy

In [None]:
# Observe that all parameters are being optimized
optimizer_ft = optim.Adam(model.parameters(), lr=1e-4)

exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=25, gamma=0.1)

model = train_model(model, optimizer_ft, exp_lr_scheduler, num_epochs=5)

# Test the model
Prepare a small test dataset just to show the results.

In [None]:
def reverse_transform(inp):
    """Turn a normalized (C, H, W) tensor back into an (H, W, C) uint8 image.

    The per-channel normalization (mean / std below) is undone, the result is
    clipped to [0, 1], scaled to [0, 255] and truncated to uint8.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])

    # Channels-last layout so the per-channel constants broadcast over the last axis.
    hwc = inp.numpy().transpose((1, 2, 0))
    restored = np.clip(channel_std * hwc + channel_mean, 0, 1)

    return (restored * 255).astype(np.uint8)

def masks_to_colorimg(masks):
    """Render a stack of per-class masks as a single RGB image.

    Args:
        masks: array of shape (channels, height, width); a channel counts as
            active at a pixel when its value is strictly greater than 0.5.
            The number of channels must match the six colors defined below.

    Returns:
        uint8 array of shape (height, width, 3). Every pixel is the mean color
        of its active channels, or white (255, 255, 255) when none is active.
    """
    colors = np.asarray([(201, 58, 64), (242, 207, 1), (0, 152, 75), (101, 172, 228),(56, 34, 132), (160, 194, 56)])

    active = np.asarray(masks) > 0.5              # (channels, height, width)
    active_count = active.sum(axis=0)             # (height, width)

    # Sum of the colors of the active channels at every pixel, computed for the
    # whole image at once instead of with a Python loop over each pixel.
    color_sum = np.tensordot(active.astype(colors.dtype), colors, axes=([0], [0]))

    colorimg = np.full(active.shape[1:] + (3,), 255, dtype=np.float32)
    hit = active_count > 0
    # Only pixels with at least one active channel are overwritten, which also
    # keeps the division away from zero counts.
    colorimg[hit] = color_sum[hit] / active_count[hit][:, None]

    return colorimg.astype(np.uint8)

class NormalizeInverse(transforms.Normalize):
    """
    Non-mutating transform that undoes a Normalize(mean, std) step and returns
    the reconstructed image in the original input domain.

    It is itself a Normalize configured with std' = 1 / std and
    mean' = -mean / std, and it always works on a clone of the given tensor so
    the caller's tensor is left untouched.
    """

    def __init__(self, mean, std):
        mean_t = torch.as_tensor(mean)
        std_t = torch.as_tensor(std)
        # The small epsilon guards against a division by zero for std == 0.
        inverse_std = 1 / (std_t + 1e-10)
        super().__init__(mean=-mean_t * inverse_std, std=inverse_std)

    def __call__(self, tensor):
        duplicate = tensor.clone()
        return super().__call__(duplicate)

In [None]:
# Six freshly generated samples, pre-processed with the same transform as the training data.
test_set = SimDataset(6, transform=trans)
# shuffle=False keeps the batches in dataset order.
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)

In [None]:
# Run the model over the test loader and collect its raw outputs, one entry per batch.
batch_predictions = []
with torch.no_grad():  # no gradient tracking during inference
    for inputs, _ in test_loader:
        inputs = inputs.to(device)
        pred = model(inputs)
        #pred = pred.detach() #Unnecessary inside torch.no_grad() block
        # Store each batch of outputs on the CPU.
        batch_predictions.append(pred.cpu())
        
# Release cached GPU memory that is no longer in use.
torch.cuda.empty_cache()

In [None]:
from random import randint
from matplotlib import pyplot as plt

# Inverse of the training normalization, ending in a PIL image.
inv_trans = torchTransforms.Compose([
    NormalizeInverse(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]),
    torchTransforms.ToPILImage()
])

# Pick a random batch and a random example inside it. The example index is
# bounded by the real length of that batch, since the last batch of a loader
# can be shorter than batch_size.
batch_idx = randint(0, len(batch_predictions)-1)
pred = batch_predictions[batch_idx]
example_idx = randint(0, len(pred)-1)

# Raw (pre-sigmoid) prediction of the chosen example, on the CPU, together with
# the matching dataset sample.
example = pred[example_idx].to('cpu')
example_input, masks = test_set[batch_idx*batch_size + example_idx]

# Fetch the inputs/labels of the SAME batch the predictions come from, so the
# images, ground-truth masks and predictions shown later line up.
for current_idx, (inputs, labels) in enumerate(test_loader):
    if current_idx == batch_idx:
        break
inputs = inputs.to(device)
labels = labels.to(device)

# Turn the raw outputs into per-class probabilities.
pred = torch.sigmoid(pred)

In [None]:
from matplotlib import pyplot as plt
# Large default figure size so the image grid stays readable.
plt.rcParams['figure.figsize'] = [30, 30]
from simulations import utils

# Convert inputs, ground-truth masks and predictions into uint8 RGB arrays.
input_images_rgb = [reverse_transform(x) for x in inputs.cpu()]
target_masks_rgb = [masks_to_colorimg(x) for x in labels.cpu().numpy()]
pred_rgb = [masks_to_colorimg(x) for x in pred.cpu().detach().numpy()]

# Interleave the three lists so consecutive triples are: input, target mask, prediction.
nSamples = len(input_images_rgb)
images = [None]*(3*nSamples)
images[::3] = input_images_rgb
images[1::3] = target_masks_rgb
images[2::3] = pred_rgb

# presumably rendered as an nSamples x 3 grid, one sample per row - confirm in simulations.utils
utils.show_images(images, nRow=nSamples, nCol=3)