In [1]:
import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
import os
import multiprocessing as mp

from code.models import basicunet, resnetunet
from code.datasets import TGSAugDataset
from code.configs import *
from code.train import *
from code.losses import FocalRobustLoss
from code.metrics import *
from code.augmentations import *
from code.utils import *
from torch.utils.data import DataLoader
from IPython.display import clear_output
from code.inference import *

%matplotlib inline

In [2]:
def train_augment(image, mask):
    """Randomly augment one training (image, mask) pair and bring it to network size.

    The steps run in a fixed order, and every random decision is drawn from the
    global ``np.random`` state, so the order of the draws below matters for
    reproducibility under a fixed seed:

    1. horizontal flip with probability 0.5;
    2. geometric distortion, chosen differently for empty and non-empty masks;
    3. with probability 0.3, exactly one intensity change;
    4. resize to 202x202 and pad via ``do_center_pad_to_factor2(..., factor=64)``.

    Args:
        image: input image, passed through the ``do_*`` helpers.
        mask: mask belonging to ``image``; must support ``.sum()``.

    Returns:
        The ``(image, mask)`` tuple after augmentation, resize and padding.
    """
    # Step 1: mirror image and mask together half of the time.
    if np.random.rand() < 0.5:
        image, mask = do_horizontal_flip2(image, mask)

    if mask.sum() == 0:
        # Mask has no positive pixels: each of the four geometric transforms is
        # applied independently with probability 0.5, so several may stack.
        if np.random.rand() < 0.5:
            image, mask = do_elastic_transform2(image, mask, grid=10,
                                                distort=np.random.uniform(0, 0.15))
        if np.random.rand() < 0.5:
            image, mask = do_random_shift_scale_crop_pad2(image, mask, 0.2)
        if np.random.rand() < 0.5:
            # Rotation angle is drawn from [0, 15); the scale is derived from the angle.
            angle = np.random.uniform(0, 15)
            scale = compute_scale_from_angle(angle * np.pi / 180)
            image, mask = do_shift_scale_rotate2(image, mask, dx=0, dy=0, scale=scale,
                                                 angle=angle)
        if np.random.rand() < 0.5:
            image, mask = do_random_perspective2(image, mask, 0.3)
    else:
        # Mask has positive pixels: exactly one of the four geometric transforms
        # is applied, picked uniformly.
        c = np.random.choice(4)
        if c == 0:
            image, mask = do_elastic_transform2(image, mask, grid=10,
                                                distort=np.random.uniform(0, 0.15))
        elif c == 1:
            image, mask = do_random_shift_scale_crop_pad2(image, mask, 0.2)
        elif c == 2:
            # Narrower angle range than in the empty-mask branch: [0, 10).
            angle = np.random.uniform(0, 10)
            scale = compute_scale_from_angle(angle * np.pi / 180)
            image, mask = do_shift_scale_rotate2(image, mask, dx=0, dy=0, scale=scale,
                                                 angle=angle)
        elif c == 3:
            image, mask = do_random_perspective2(image, mask, 0.3)
            
    # Step 3: intensity changes touch the image only; the mask is left as is.
    if np.random.rand() < 0.3:
        c = np.random.choice(3)
        if c == 0:
            image = do_brightness_shift(image, np.random.uniform(-0.1, 0.1))  # 0.05
        elif c == 1:
            image = do_brightness_multiply(image, np.random.uniform(1 - 0.08, 1 + 0.08))  # 0.05
        elif c == 2:
            image = do_gamma(image, np.random.uniform(1 - 0.08, 1 + 0.08))  # 0.05
    
    # Step 4: same final resize + pad as flip_augment and test_augment.
    # NOTE(review): presumably 202 is padded up to 256 (next multiple of 64) - confirm.
    image, mask = do_resize2(image, mask, 202, 202)
    image, mask = do_center_pad_to_factor2(image, mask, factor=64)
    
    return image, mask


def flip_augment(image, mask):
    """Light augmentation: optional horizontal flip, then resize and pad.

    With probability 0.5 the pair is mirrored by ``do_horizontal_flip2``; it is
    then resized to 202x202 and padded by ``do_center_pad_to_factor2`` with
    ``factor=64``.

    Returns:
        The processed ``(image, mask)`` tuple.
    """
    should_flip = np.random.rand() < 0.5
    if should_flip:
        image, mask = do_horizontal_flip2(image, mask)

    resized_image, resized_mask = do_resize2(image, mask, 202, 202)
    padded_image, padded_mask = do_center_pad_to_factor2(resized_image, resized_mask, factor=64)
    return padded_image, padded_mask


def test_augment(image, mask):
    """Deterministic preprocessing for evaluation: resize and pad only.

    The pair is resized to 202x202 and then padded by
    ``do_center_pad_to_factor2`` with ``factor=64``; no random draws are made.

    Returns:
        The processed ``(image, mask)`` tuple.
    """
    resized_image, resized_mask = do_resize2(image, mask, 202, 202)
    padded_image, padded_mask = do_center_pad_to_factor2(resized_image, resized_mask, factor=64)
    return padded_image, padded_mask

In [3]:
# Build the datasets for cross-validation fold 1.
# The training split uses the random `train_augment`; the validation split uses
# `test_augment`, which only resizes and pads.
train_ds = TGSAugDataset(augmenter=train_augment, path=os.path.join(PATH_TO_SALT_CV, "fold-1/train"), 
                         path_to_depths=PATH_TO_DEPTHS, progress_bar=True)
valid_ds = TGSAugDataset(augmenter=test_augment, path=os.path.join(PATH_TO_SALT_CV, "fold-1/valid"), 
                         path_to_depths=PATH_TO_DEPTHS, progress_bar=True)

HBox(children=(IntProgress(value=0, max=3136), HTML(value='')))




HBox(children=(IntProgress(value=0, max=784), HTML(value='')))




In [4]:
def seed_numpy_worker(worker_id):
    """Give each DataLoader worker process its own NumPy random stream.

    The augmenters draw from the global ``np.random`` state. Forked workers
    inherit an identical copy of that state, so without re-seeding they can all
    produce the same "random" augmentations. ``torch.initial_seed()`` differs per
    worker, so it is used as the seed source (reduced to NumPy's 32-bit range).

    Args:
        worker_id: index of the worker, supplied by the DataLoader (unused).
    """
    np.random.seed(torch.initial_seed() % 2 ** 32)


# Training batches are shuffled every epoch (DataLoader defaults to shuffle=False)
# and each worker gets a distinct NumPy seed for the random augmentations.
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4,
                      worker_init_fn=seed_numpy_worker)
# Validation keeps a fixed order; its augmenter makes no random draws.
valid_dl = DataLoader(valid_ds, batch_size=16, num_workers=4)

In [5]:
class UNetResNet34Wrapped(resnetunet.UNetResNet34):
    """``UNetResNet34`` that lives on a given device and returns a dict of logits."""

    def __init__(self, device):
        """Build the parent network, remember ``device`` and move the module onto it."""
        super().__init__()
        self.device = device
        self.to(device)

    def __call__(self, image, **kwargs):
        """Run the parent network on ``image``; extra keyword arguments are ignored.

        Returns:
            ``{"logits": ...}`` holding index 0 along dimension 1 of the parent's output.
        """
        raw_output = super().__call__(image)
        first_channel = raw_output[:, 0]
        return {"logits": first_channel}


class BasicUNetWrapped(basicunet.BasicUNet):
    """``BasicUNet`` that lives on a given device and returns a dict of logits."""

    def __init__(self, device):
        """Build the parent network, remember ``device`` and move the module onto it."""
        super().__init__()
        self.device = device
        self.to(device)

    def __call__(self, image, **kwargs):
        """Insert a new dimension at position 1 of ``image`` and run the parent network.

        Extra keyword arguments are ignored.

        Returns:
            ``{"logits": ...}`` holding index 0 along dimension 1 of the parent's output.
        """
        with_channel_dim = image.unsqueeze(1)
        raw_output = super().__call__(with_channel_dim)
        return {"logits": raw_output[:, 0]}


class FocalLoss(torch.nn.Module):
    """Focal loss evaluated on the un-padded centre of the prediction.

    The augmenters resize to 202 px and pad to a multiple of 64 (256 px), which
    is 27 px per side assuming the padding is symmetric. The crop previously
    used here was 26 px, which left a one-pixel ring of padding inside the loss
    and disagreed with the 27 px crop used by the evaluation metrics.

    Args:
        crop: number of border pixels removed from each side of the last two
            dimensions before the loss is computed. Defaults to 27.
    """

    def __init__(self, crop=27):
        super().__init__()
        self.crop = crop
        self.loss_fn = FocalRobustLoss(gamma=2.0, alpha=0.05)

    def forward(self, logits, mask, **kwargs):
        """Crop ``logits`` and ``mask`` identically and return ``loss_fn`` on the crops.

        Both inputs are indexed as ``[batch, height, width]``. Extra keyword
        arguments are ignored.
        """
        border = self.crop
        # Explicit end indices (instead of ``-border``) keep ``crop=0`` meaningful.
        logits = logits[:, border:logits.shape[1] - border,
                        border:logits.shape[2] - border].contiguous()
        mask = mask[:, border:mask.shape[1] - border,
                    border:mask.shape[2] - border].contiguous()
        return self.loss_fn(logits, mask)
    
    
class LovaszLoss(torch.nn.Module):
    """Lovasz (ELU variant) loss evaluated on the un-padded centre of the prediction.

    The augmenters resize to 202 px and pad to a multiple of 64 (256 px), which
    is 27 px per side assuming the padding is symmetric. The crop previously
    used here was 26 px, which left a one-pixel ring of padding inside the loss
    and disagreed with the 27 px crop used by the evaluation metrics.

    Args:
        crop: number of border pixels removed from each side of the last two
            dimensions before the loss is computed. Defaults to 27.
    """

    def __init__(self, crop=27):
        super().__init__()
        self.crop = crop
        self.loss_fn = lovasz_elu

    def forward(self, logits, mask, **kwargs):
        """Crop ``logits`` and ``mask`` identically and return ``loss_fn`` on the crops.

        Both inputs are indexed as ``[batch, height, width]``. Extra keyword
        arguments are ignored.
        """
        border = self.crop
        # Explicit end indices (instead of ``-border``) keep ``crop=0`` meaningful.
        logits = logits[:, border:logits.shape[1] - border,
                        border:logits.shape[2] - border].contiguous()
        mask = mask[:, border:mask.shape[1] - border,
                    border:mask.shape[2] - border].contiguous()
        return self.loss_fn(logits, mask)
    
    
def lossWrapped(logits, mask, **kwargs):
    """Evaluation-time loss: ``lovasz_elu`` on the centre crop, as a Python number.

    A 27-pixel border is removed from the last two dimensions of both inputs
    before the loss is computed. Extra keyword arguments are ignored.
    """
    centre = (slice(None), slice(27, -27), slice(27, -27))
    cropped_logits = logits[centre].contiguous()
    cropped_mask = mask[centre].contiguous()
    return lovasz_elu(cropped_logits, cropped_mask).item()
    
    
def sigmoid(logits):
    """Logistic function ``1 / (1 + exp(-logits))`` computed with NumPy."""
    denominator = 1 + np.exp(-logits)
    return 1. / denominator
    

def meanAPWrapped(logits, mask, treashold, **kwargs):
    """Return ``meanAP2d`` for the centre crop of a batch.

    A 27-pixel border is removed from the last two dimensions, both tensors are
    converted to NumPy arrays on the CPU, and the logits go through ``sigmoid``
    before being handed to ``meanAP2d`` together with ``treashold``.
    Extra keyword arguments are ignored.
    """
    centre = (slice(None), slice(27, -27), slice(27, -27))
    logits_np = logits[centre].cpu().detach().numpy()
    mask_np = mask[centre].cpu().detach().numpy()
    probabilities = sigmoid(logits_np)
    return meanAP2d(probabilities, mask_np, treashold)


def meanIoUWrapped(logits, mask, treashold, **kwargs):
    """Return ``meanIoU2d`` for the centre crop of a batch.

    A 27-pixel border is removed from the last two dimensions, both tensors are
    converted to NumPy arrays on the CPU, and the logits go through ``sigmoid``
    before being handed to ``meanIoU2d`` together with ``treashold``.
    Extra keyword arguments are ignored.
    """
    centre = (slice(None), slice(27, -27), slice(27, -27))
    logits_np = logits[centre].cpu().detach().numpy()
    mask_np = mask[centre].cpu().detach().numpy()
    probabilities = sigmoid(logits_np)
    return meanIoU2d(probabilities, mask_np, treashold)

###################################################################

def meanSoftIoUWrapped(logits, mask, **kwargs):
    """Return ``meanSoftIoU2d`` for a cropped batch.

    The crop drops 13 leading and 14 trailing pixels from each of the last two
    dimensions; both tensors are then converted to NumPy arrays on the CPU and
    the logits go through ``sigmoid``. Extra keyword arguments are ignored.
    """
    # NOTE(review): this 13/14 crop differs from the 27-pixel crop used by the
    # other metric wrappers in this notebook - confirm which input size it targets.
    window = (slice(None), slice(13, -14), slice(13, -14))
    logits_np = logits[window].cpu().detach().numpy()
    mask_np = mask[window].cpu().detach().numpy()
    probabilities = sigmoid(logits_np)
    return meanSoftIoU2d(probabilities, mask_np)


def meanAccuracyWrapped(logits, mask, **kwargs):
    """Return ``meanAccuracy2d`` for a cropped batch.

    The crop drops 13 leading and 14 trailing pixels from each of the last two
    dimensions; both tensors are then converted to NumPy arrays on the CPU and
    the logits go through ``sigmoid``. Extra keyword arguments are ignored.
    """
    # NOTE(review): this 13/14 crop differs from the 27-pixel crop used by the
    # other metric wrappers in this notebook - confirm which input size it targets.
    window = (slice(None), slice(13, -14), slice(13, -14))
    logits_np = logits[window].cpu().detach().numpy()
    mask_np = mask[window].cpu().detach().numpy()
    probabilities = sigmoid(logits_np)
    return meanAccuracy2d(probabilities, mask_np)

In [6]:
from torch.nn import functional as F
from torch.autograd import Variable


def lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz hinge loss
      logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: average the per-image losses instead of computing one loss per batch
      ignore: void class id
    """
    if not per_image:
        return lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore))

    # One loss per sample, each computed on a batch of size one, then averaged.
    per_sample_losses = (
        lovasz_hinge_flat(*flatten_binary_scores(sample_logits.unsqueeze(0),
                                                 sample_labels.unsqueeze(0), ignore))
        for sample_logits, sample_labels in zip(logits, labels)
    )
    return mean(per_sample_losses)


def lovasz_elu(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz loss, ELU variant (delegates to lovasz_elu_flat)
      logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: average the per-image losses instead of computing one loss per batch
      ignore: void class id
    """
    if not per_image:
        return lovasz_elu_flat(*flatten_binary_scores(logits, labels, ignore))

    # One loss per sample, each computed on a batch of size one, then averaged.
    per_sample_losses = (
        lovasz_elu_flat(*flatten_binary_scores(sample_logits.unsqueeze(0),
                                               sample_labels.unsqueeze(0), ignore))
        for sample_logits, sample_labels in zip(logits, labels)
    )
    return mean(per_sample_losses)


def lovasz_sigmoid(logits, labels, per_image=True, ignore=None):
    """
    Binary Lovasz sigmoid loss
      logits: [B, H, W] Variable, logits at each pixel (between -inf and +inf)
      labels: [B, H, W] Tensor, binary ground truth masks (0 or 1)
      per_image: average the per-image losses instead of computing one loss per batch
      ignore: void class id
    """
    if not per_image:
        return lovasz_sigmoid_flat(*flatten_binary_scores(logits, labels, ignore))

    # One loss per sample, each computed on a batch of size one, then averaged.
    per_sample_losses = (
        lovasz_sigmoid_flat(*flatten_binary_scores(sample_logits.unsqueeze(0),
                                                   sample_labels.unsqueeze(0), ignore))
        for sample_logits, sample_labels in zip(logits, labels)
    )
    return mean(per_sample_losses)


def lovasz_hinge_flat(logits, labels):
    """
    Binary Lovasz hinge loss on flattened predictions
      logits: [P] Variable, logits at each prediction (between -inf and +inf)
      labels: [P] Tensor, binary ground truth labels (0 or 1)
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.

    # Map labels {0, 1} to signs {-1, +1}; the hinge error is 1 - logit * sign.
    label_signs = labels.float() * 2. - 1.
    hinge_errors = 1. - logits * label_signs

    # Largest errors first; the labels are reordered with the same permutation.
    sorted_errors, order = torch.sort(hinge_errors, dim=0, descending=True)
    order = order.data
    sorted_labels = labels[order]

    weights = lovasz_grad(sorted_labels)
    return torch.dot(F.relu(sorted_errors), weights)


def lovasz_elu_flat(logits, labels):
    """
    Binary Lovasz loss on flattened predictions, using ELU(error) + 1 in place
    of the ReLU applied by lovasz_hinge_flat
      logits: [P] Tensor, logits at each prediction (between -inf and +inf)
      labels: [P] Tensor, binary ground truth labels (0 or 1)

    Unlike lovasz_hinge_flat there is no explicit early return for empty input.
    The deprecated ``Variable`` wrappers were removed; plain tensors are used
    directly, as in lovasz_hinge_flat.
    """
    # Map labels {0, 1} to signs {-1, +1}; the margin error is 1 - logit * sign.
    signs = 2. * labels.float() - 1.
    errors = 1. - logits * signs

    # Largest errors first; the labels are reordered with the same permutation.
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    gt_sorted = labels[perm]

    grad = lovasz_grad(gt_sorted)
    # ELU(x) + 1 is strictly positive, so every pixel contributes to the loss.
    return torch.dot(F.elu(errors_sorted) + 1, grad)


def lovasz_sigmoid_flat(logits, labels):
    """
    Binary Lovasz sigmoid loss on flattened predictions
      logits: [P] Tensor, logits at each prediction (between -inf and +inf)
      labels: [P] Tensor, binary ground truth labels (0 or 1)

    ``torch.sigmoid`` replaces the deprecated ``F.sigmoid`` and the deprecated
    ``Variable`` wrapper is dropped. Labels are cast to float before the
    subtraction so integer masks are accepted, as in lovasz_hinge_flat.
    """
    # Per-pixel error is the absolute gap between label and predicted probability.
    errors = (labels.float() - torch.sigmoid(logits)).abs()

    # Largest errors first; the labels are reordered with the same permutation.
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    gt_sorted = labels[perm]

    grad = lovasz_grad(gt_sorted)
    return torch.dot(errors_sorted, grad)


def flatten_binary_scores(scores, labels, ignore=None):
    """
    Flattens predictions in the batch (binary case)
    When 'ignore' is given, drops every position whose label equals it
    Returns the (scores, labels) pair as 1-D tensors
    """
    flat_scores, flat_labels = scores.view(-1), labels.view(-1)
    if ignore is not None:
        keep = flat_labels != ignore
        flat_scores, flat_labels = flat_scores[keep], flat_labels[keep]
    return flat_scores, flat_labels


def lovasz_grad(gt_sorted):
    """
    Computes gradient of the Lovasz extension w.r.t sorted errors
    See Alg. 1 in paper

    gt_sorted holds the ground-truth labels ordered by decreasing error.
    The result has one entry per position: the first is the Jaccard loss after
    the first element, each later one is the increment over the previous position.
    """
    n_pixels = len(gt_sorted)
    total_positive = gt_sorted.sum()
    positives_seen = gt_sorted.float().cumsum(0)
    negatives_seen = (1 - gt_sorted).float().cumsum(0)

    remaining_intersection = total_positive - positives_seen
    growing_union = total_positive + negatives_seen
    jaccard = 1. - remaining_intersection / growing_union

    if n_pixels > 1:  # a single pixel has nothing to difference against
        jaccard[1:n_pixels] = jaccard[1:n_pixels] - jaccard[0:-1]
    return jaccard


def mean(l, ignore_nan=False, empty=0):
    """
    nanmean compatible with generators.

      l: iterable of values supporting ``+`` and ``/`` (numbers, arrays, tensors)
      ignore_nan: skip values that are NaN
      empty: value returned for an empty input, or 'raise' to raise ValueError

    Returns the single value unchanged when there is exactly one, otherwise the
    sum divided by the count.
    """
    # Python 3 name of the Python 2 `ifilterfalse`, which was never imported here.
    from itertools import filterfalse

    values = iter(l)
    if ignore_nan:
        # NaN is the only value unequal to itself; unlike np.isnan this also
        # works for scalar tensors that require grad or live on the GPU.
        values = filterfalse(lambda v: v != v, values)
    try:
        acc = next(values)
    except StopIteration:
        if empty == 'raise':
            raise ValueError('Empty mean')
        return empty
    n = 1
    for n, v in enumerate(values, 2):
        # Out-of-place add: `acc += v` would modify the caller's first element
        # when it is a mutable array or tensor.
        acc = acc + v
    if n == 1:
        return acc
    return acc / n

In [7]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# cell does not fail outright on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = UNetResNet34Wrapped(device)
model.load_pretrain(PATH_TO_RESNET34)
# Cell output: number of parameter elements that are not part of `model.resnet`.
sum(p.nelement() for p in model.parameters()) - sum(p.nelement() for p in model.resnet.parameters())

5422145

In [8]:
# Adam over all model parameters (learning rate 1e-4, weight decay 1e-4).
optim = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
# First training objective: focal loss on the centre crop (see FocalLoss).
loss_fn = FocalLoss()

In [9]:
# Logger writing to "log1.txt"; the listed names match the keys passed to
# `logger.write` by the training loops (epoch plus train/valid loss, mAP, mIoU).
logger = CSVLogger("log1.txt", ["epoch", "train_loss", "valid_loss",
                                "train_mAP_0.5",
                                "train_mIoU_0.5",
                                "valid_mAP_0.5", 
                                "valid_mIoU_0.5"], 
                   log_time=True)

In [10]:
# Two checkpointers, one per validation metric; the training loops feed them
# the negated metric value.
# NOTE(review): presumably BestLastCheckpointer keeps the lowest score as "best" - confirm.
mAP_cp = BestLastCheckpointer("mAP1")
mIoU_cp = BestLastCheckpointer("mIoU1")

In [11]:
# Apply `freeze` to the `model.resnet` sub-module before the first training loop.
# NOTE(review): presumably this stops its weights from being updated - confirm in code.train/utils.
freeze(model.resnet)

In [12]:
# Reset the logger's clock right before training starts (the logger was created with log_time=True).
logger.resetClock()

In [13]:
# Training loop, 5 epochs, per-step losses written to "focal_loss.txt".
# Each epoch: train once over train_dl, evaluate loss / mAP@0.5 / mIoU@0.5 on
# both loaders, append a row to the logger, and update both checkpointers.
for epoch in range(5):  # 10
    print("Epoch:", epoch)
    
    # `loss` is not used afterwards in this cell.
    loss = train_epoch_fn(model, train_dl, optim, loss_fn, verbose=1, loss_file="focal_loss.txt")
    
    # The metric names must match the columns declared for the logger.
    train_metrics = eval_fn(model, train_dl, 
            {
                "train_loss": lossWrapped, 
                "train_mAP_0.5": lambda logits, mask, **kwargs: meanAPWrapped(logits, mask, 0.5), 
                "train_mIoU_0.5": lambda logits, mask, **kwargs: meanIoUWrapped(logits, mask, 0.5), 
            }, 
            verbose=1)
    valid_metrics = eval_fn(model, valid_dl, 
            {
                "valid_loss": lossWrapped, 
                "valid_mAP_0.5": lambda logits, mask, **kwargs: meanAPWrapped(logits, mask, 0.5), 
                "valid_mIoU_0.5": lambda logits, mask, **kwargs: meanIoUWrapped(logits, mask, 0.5), 
            }, 
                            verbose=1)
    
    logger.write(epoch=epoch, **train_metrics, **valid_metrics)
    
    # Metrics are negated before being passed as the checkpoint score.
    # NOTE(review): presumably the checkpointer treats lower as better - confirm.
    mAP_cp.update(-valid_metrics["valid_mAP_0.5"], 
                    model=model, optim=optim, epoch=epoch)
    mIoU_cp.update(-valid_metrics["valid_mIoU_0.5"],
                    model=model, optim=optim, epoch=epoch)
    
    # Wipe this epoch's notebook output so the cell does not grow without bound.
    clear_output()

In [14]:
# Apply `unfreeze` to the whole model before the next training loop.
# NOTE(review): presumably this re-enables updates for every sub-module - confirm.
unfreeze(model)

In [15]:
# Training loop, 10 epochs, per-step losses written to "focal_loss.txt".
# Each epoch: train once over train_dl, evaluate loss / mAP@0.5 / mIoU@0.5 on
# both loaders, append a row to the logger, and update both checkpointers.
# NOTE(review): `epoch` starts at 0 again here, so logged epoch numbers repeat
# those written by the other training loops in this notebook.
for epoch in range(10):  # 20
    print("Epoch:", epoch)
    
    # `loss` is not used afterwards in this cell.
    loss = train_epoch_fn(model, train_dl, optim, loss_fn, verbose=1, loss_file="focal_loss.txt")
    
    # The metric names must match the columns declared for the logger.
    train_metrics = eval_fn(model, train_dl, 
            {
                "train_loss": lossWrapped, 
                "train_mAP_0.5": lambda logits, mask, **kwargs: meanAPWrapped(logits, mask, 0.5), 
                "train_mIoU_0.5": lambda logits, mask, **kwargs: meanIoUWrapped(logits, mask, 0.5), 
            }, 
            verbose=1)
    valid_metrics = eval_fn(model, valid_dl, 
            {
                "valid_loss": lossWrapped, 
                "valid_mAP_0.5": lambda logits, mask, **kwargs: meanAPWrapped(logits, mask, 0.5), 
                "valid_mIoU_0.5": lambda logits, mask, **kwargs: meanIoUWrapped(logits, mask, 0.5), 
            }, 
                            verbose=1)
    
    logger.write(epoch=epoch, **train_metrics, **valid_metrics)
    
    # Metrics are negated before being passed as the checkpoint score.
    # NOTE(review): presumably the checkpointer treats lower as better - confirm.
    mAP_cp.update(-valid_metrics["valid_mAP_0.5"], 
                    model=model, optim=optim, epoch=epoch)
    mIoU_cp.update(-valid_metrics["valid_mIoU_0.5"],
                    model=model, optim=optim, epoch=epoch)
    
    # Wipe this epoch's notebook output so the cell does not grow without bound.
    clear_output()

In [16]:
# Start the next stage with a fresh Adam optimizer (same hyper-parameters as
# before, so any accumulated optimizer state is discarded) and switch the
# training objective to the Lovasz loss.
optim = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
loss_fn = LovaszLoss()

In [17]:
# Training loop, 60 epochs, per-step losses written to "lovasz_loss.txt".
# Each epoch: train once over train_dl, evaluate loss / mAP@0.5 / mIoU@0.5 on
# both loaders, append a row to the logger, and update both checkpointers.
# NOTE(review): `epoch` starts at 0 again here, so logged epoch numbers repeat
# those written by the other training loops in this notebook.
for epoch in range(60):  # 60
    print("Epoch:", epoch)
    
    # `loss` is not used afterwards in this cell.
    loss = train_epoch_fn(model, train_dl, optim, loss_fn, verbose=1, loss_file="lovasz_loss.txt")
    
    # The metric names must match the columns declared for the logger.
    train_metrics = eval_fn(model, train_dl, 
            {
                "train_loss": lossWrapped, 
                "train_mAP_0.5": lambda logits, mask, **kwargs: meanAPWrapped(logits, mask, 0.5), 
                "train_mIoU_0.5": lambda logits, mask, **kwargs: meanIoUWrapped(logits, mask, 0.5), 
            }, 
            verbose=1)
    valid_metrics = eval_fn(model, valid_dl, 
            {
                "valid_loss": lossWrapped, 
                "valid_mAP_0.5": lambda logits, mask, **kwargs: meanAPWrapped(logits, mask, 0.5), 
                "valid_mIoU_0.5": lambda logits, mask, **kwargs: meanIoUWrapped(logits, mask, 0.5), 
            }, 
                            verbose=1)
    
    logger.write(epoch=epoch, **train_metrics, **valid_metrics)
    
    # Metrics are negated before being passed as the checkpoint score.
    # NOTE(review): presumably the checkpointer treats lower as better - confirm.
    mAP_cp.update(-valid_metrics["valid_mAP_0.5"], 
                    model=model, optim=optim, epoch=epoch)
    mIoU_cp.update(-valid_metrics["valid_mIoU_0.5"],
                    model=model, optim=optim, epoch=epoch)
    
    # Wipe this epoch's notebook output so the cell does not grow without bound.
    clear_output()

In [18]:
# Fine-tuning setup: load the "best" checkpoint tracked by mAP_cp into the model
# and optimizer, lower the learning rate to 2e-5, and call `freeze` restricted
# to BatchNorm2d layers.
# NOTE(review): whether the frozen BatchNorm layers stay frozen depends on what
# train_epoch_fn does to the module's train/eval mode - confirm.
mAP_cp.load("best", model=model, optim=optim)
set_learning_rate(optim, 0.00002)
freeze(model, include=(torch.nn.BatchNorm2d,))

In [19]:
# Training loop, 30 epochs, per-step losses written to "lovasz_loss.txt".
# Each epoch: train once over train_dl, evaluate loss / mAP@0.5 / mIoU@0.5 on
# both loaders, append a row to the logger, and update both checkpointers.
# NOTE(review): `epoch` starts at 0 again here, so logged epoch numbers repeat
# those written by the other training loops in this notebook.
for epoch in range(30):  # 30
    print("Epoch:", epoch)
    
    # `loss` is not used afterwards in this cell.
    loss = train_epoch_fn(model, train_dl, optim, loss_fn, verbose=1, loss_file="lovasz_loss.txt")
    
    # The metric names must match the columns declared for the logger.
    train_metrics = eval_fn(model, train_dl, 
            {
                "train_loss": lossWrapped, 
                "train_mAP_0.5": lambda logits, mask, **kwargs: meanAPWrapped(logits, mask, 0.5), 
                "train_mIoU_0.5": lambda logits, mask, **kwargs: meanIoUWrapped(logits, mask, 0.5), 
            }, 
            verbose=1)
    valid_metrics = eval_fn(model, valid_dl, 
            {
                "valid_loss": lossWrapped, 
                "valid_mAP_0.5": lambda logits, mask, **kwargs: meanAPWrapped(logits, mask, 0.5), 
                "valid_mIoU_0.5": lambda logits, mask, **kwargs: meanIoUWrapped(logits, mask, 0.5), 
            }, 
                            verbose=1)
    
    logger.write(epoch=epoch, **train_metrics, **valid_metrics)
    
    # Metrics are negated before being passed as the checkpoint score.
    # NOTE(review): presumably the checkpointer treats lower as better - confirm.
    mAP_cp.update(-valid_metrics["valid_mAP_0.5"], 
                    model=model, optim=optim, epoch=epoch)
    mIoU_cp.update(-valid_metrics["valid_mIoU_0.5"],
                    model=model, optim=optim, epoch=epoch)
    
    # Wipe this epoch's notebook output so the cell does not grow without bound.
    clear_output()

In [20]:
# Final stage setup: swap the training augmenter for `flip_augment` (flip plus
# resize/pad only), load the "best" checkpoint tracked by mAP_cp again, and
# lower the learning rate to 1e-5.
# Note that this mutates `train_ds` in place, so `train_dl` is affected too.
train_ds.augmenter = flip_augment
mAP_cp.load("best", model=model, optim=optim)
set_learning_rate(optim, 0.00001)

In [23]:
# Training loop, 10 epochs, per-step losses written to "lovasz_loss.txt".
# Each epoch: train once over train_dl, evaluate loss / mAP@0.5 / mIoU@0.5 on
# both loaders, append a row to the logger, and update both checkpointers.
# NOTE(review): `epoch` starts at 0 again here, so logged epoch numbers repeat
# those written by the other training loops in this notebook.
for epoch in range(10):  # 5
    print("Epoch:", epoch)
    
    # `loss` is not used afterwards in this cell.
    loss = train_epoch_fn(model, train_dl, optim, loss_fn, verbose=1, loss_file="lovasz_loss.txt")
    
    # The metric names must match the columns declared for the logger.
    train_metrics = eval_fn(model, train_dl, 
            {
                "train_loss": lossWrapped, 
                "train_mAP_0.5": lambda logits, mask, **kwargs: meanAPWrapped(logits, mask, 0.5), 
                "train_mIoU_0.5": lambda logits, mask, **kwargs: meanIoUWrapped(logits, mask, 0.5), 
            }, 
            verbose=1)
    valid_metrics = eval_fn(model, valid_dl, 
            {
                "valid_loss": lossWrapped, 
                "valid_mAP_0.5": lambda logits, mask, **kwargs: meanAPWrapped(logits, mask, 0.5), 
                "valid_mIoU_0.5": lambda logits, mask, **kwargs: meanIoUWrapped(logits, mask, 0.5), 
            }, 
                            verbose=1)
    
    logger.write(epoch=epoch, **train_metrics, **valid_metrics)
    
    # Metrics are negated before being passed as the checkpoint score.
    # NOTE(review): presumably the checkpointer treats lower as better - confirm.
    mAP_cp.update(-valid_metrics["valid_mAP_0.5"], 
                    model=model, optim=optim, epoch=epoch)
    mIoU_cp.update(-valid_metrics["valid_mIoU_0.5"],
                    model=model, optim=optim, epoch=epoch)
    
    # Wipe this epoch's notebook output so the cell does not grow without bound.
    clear_output()