In [None]:
!pip install git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
import random
import time
import warnings
warnings.simplefilter("ignore")

import os
from albumentations import *
from albumentations.pytorch import ToTensor
import cv2
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import segmentation_models_pytorch as smp
from sklearn.model_selection import KFold
import tifffile as tiff
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset, sampler
from tqdm import tqdm_notebook as tqdm
import gc

In [None]:
# --- Training hyper-parameters ---------------------------------------------
BATCH_SIZE = 16     # samples per DataLoader batch
NUM_WORKERS = 4     # worker processes per DataLoader
NUM_EPOCHS = 50     # number of passes of the training loop

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Dataset locations (relative to the notebook's working directory) ------
TRAIN = '../input/hubmap-256x256/train'   # directory HuBMAPDataset reads images from
MASKS = '../input/hubmap-256x256/masks'   # directory HuBMAPDataset reads masks from
DATA = '../input/hubmap-kidney-segmentation/test/'   # not referenced by the cells shown here

In [None]:
# Per-channel normalisation statistics (one value per channel of the RGB image).
# HuBMAPDataset subtracts `mean` and divides by `std` after scaling pixels by 1/255.
mean = np.array([0.65459856, 0.48386562, 0.69428385])
std = np.array([0.15167958, 0.23584107, 0.13146145])

def img2tensor(img, dtype:np.dtype=np.float32):
    """Convert a channels-last numpy image into a channels-first torch tensor.

    Args:
        img: numpy array of shape (H, W) or (H, W, C).
        dtype: numpy dtype the data is cast to before wrapping (float32 by default).

    Returns:
        A ``torch.Tensor`` of shape (C, H, W); a 2-D input yields C == 1.
    """
    if img.ndim == 2:
        # Single-channel input: append a trailing channel axis first.
        img = img[:, :, np.newaxis]
    channels_first = img.transpose(2, 0, 1)
    # copy=False avoids a copy when the array already has the requested dtype.
    return torch.from_numpy(channels_first.astype(dtype, copy=False))

class HuBMAPDataset(Dataset):
    """Dataset of image tiles paired with their segmentation masks.

    Every item is read from disk on access: the image from the ``TRAIN``
    directory and the mask with the same file name from the ``MASKS`` directory
    (both are module-level constants).

    Args:
        img_filenames: indexable collection of file names, relative to
            ``TRAIN`` / ``MASKS``.
        train: stored on the instance; no method of this class reads it.
        tfms: optional callable invoked as ``tfms(image=..., mask=...)`` that
            returns a mapping with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, img_filenames, train=True, tfms=None):
        self.img_filenames = img_filenames
        self.train = train
        self.tfms = tfms

    def __len__(self):
        """Return the number of file names in the dataset."""
        return len(self.img_filenames)

    @staticmethod
    def _imread(path, *flags):
        """Read ``path`` with OpenCV, raising instead of returning ``None``.

        ``cv2.imread`` signals a missing or undecodable file by returning
        ``None``; without this check the failure only shows up later as an
        unrelated-looking error, which is hard to trace from a DataLoader worker.
        """
        img = cv2.imread(path, *flags)
        if img is None:
            raise FileNotFoundError(f'Could not read image file: {path}')
        return img

    def __getitem__(self, idx):
        """Load, optionally augment, and tensorise the sample at ``idx``.

        Returns:
            (image, mask): the image as a float32 channels-first tensor,
            scaled by 1/255 and normalised with the module-level ``mean`` /
            ``std``; the mask as a float32 tensor produced by ``img2tensor``
            from the grayscale mask (a 2-D mask gains a leading channel axis).

        Raises:
            FileNotFoundError: if the image or the mask file cannot be read.
        """
        fname = self.img_filenames[idx]
        # OpenCV decodes colour images as BGR; convert to RGB to match mean/std order.
        imgs = cv2.cvtColor(self._imread(os.path.join(TRAIN, fname)), cv2.COLOR_BGR2RGB)
        masks = self._imread(os.path.join(MASKS, fname), cv2.IMREAD_GRAYSCALE)
        if self.tfms is not None:
            # Image and mask go through the transform together so they stay aligned.
            augmented = self.tfms(image=imgs, mask=masks)
            imgs, masks = augmented['image'], augmented['mask']
        return img2tensor((imgs/255.0-mean)/std), img2tensor(masks)


In [None]:
def get_augmentation_train(p=1.0):
    """Build the augmentation pipeline used for training samples.

    Args:
        p: probability handed to ``Compose`` for the pipeline as a whole.

    Returns:
        A ``Compose`` wrapping a ``HorizontalFlip`` and a ``VerticalFlip``,
        each constructed with its default arguments.
    """
    flips = [HorizontalFlip(), VerticalFlip()]
    return Compose(flips, p=p)

In [None]:
def dice_loss(pred, target, smooth = 1.):
    """Soft Dice loss, averaged over the entries left after spatial reduction.

    Args:
        pred: predicted values; dimensions 2 and 3 are summed out, so the
            tensor needs at least four dimensions
            (assumed (N, C, H, W) — TODO confirm against callers).
        target: ground-truth tensor, multiplied element-wise with ``pred``.
        smooth: constant added to numerator and denominator, which keeps the
            ratio defined when both inputs sum to zero.

    Returns:
        Scalar tensor: mean of ``1 - (2*intersection + smooth) / (sum(pred) + sum(target) + smooth)``.
    """
    def _spatial_sum(t):
        # Sum out dim 2 twice: the second call removes what was originally dim 3.
        return t.sum(dim=2).sum(dim=2)

    pred, target = pred.contiguous(), target.contiguous()
    overlap = _spatial_sum(pred * target)
    total = _spatial_sum(pred) + _spatial_sum(target)
    dice_score = (2.0 * overlap + smooth) / (total + smooth)
    return (1 - dice_score).mean()

In [None]:
def calc_loss(pred, target, bce_weight=0.5):
    """Weighted sum of binary cross-entropy and soft Dice loss.

    Args:
        pred: raw model outputs (logits); the sigmoid is applied here.
        target: ground-truth tensor with the same shape as ``pred``.
        bce_weight: weight of the BCE term; the Dice term gets ``1 - bce_weight``.

    Returns:
        Scalar loss tensor on the same device as the inputs.
    """
    # BCE is computed from the logits directly, before any sigmoid.
    bce = F.binary_cross_entropy_with_logits(pred, target)

    # Dice works on probabilities. torch.sigmoid is the supported spelling;
    # F.sigmoid is its deprecated alias.
    probs = torch.sigmoid(pred)
    dice = dice_loss(probs, target)

    # No explicit device move: the result already lives wherever pred/target do.
    return bce * bce_weight + dice * (1 - bce_weight)

In [None]:
def UnetDenseNet():
    """Create the segmentation network: an ``smp.Unet`` with a densenet201 encoder.

    The model is configured for 3 input channels and 1 output class, and
    requests the ``'imagenet'`` encoder weights.
    """
    unet_kwargs = dict(
        encoder_name='densenet201',
        encoder_weights='imagenet',
        in_channels=3,
        classes=1,
    )
    return smp.Unet(**unet_kwargs)

In [None]:
# os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting
# gives a stable starting order, so anything derived from this list (such as a
# seeded shuffle) comes out the same on every machine and run.
dataset_filenames = sorted(os.listdir(TRAIN))

In [None]:
# Number of file names found in the TRAIN directory listing (shown as the cell output).
len(dataset_filenames)

In [None]:
# First entry of the listing, shown as the cell output.
dataset_filenames[0]

In [None]:
# Number of shuffled files assigned to the training split; the remainder is
# used for validation.
NUM_TRAIN_FILES = 7664

# Shuffle a copy rather than the listing itself. Shuffling in place would make
# this cell non-idempotent: a second run would reshuffle an already-shuffled
# list and produce a different split despite the fixed seed.
random.seed(42)
shuffled_filenames = list(dataset_filenames)
random.shuffle(shuffled_filenames)

train_filenames = shuffled_filenames[:NUM_TRAIN_FILES]
valid_filenames = shuffled_filenames[NUM_TRAIN_FILES:]

# Fail fast: an empty split would otherwise only surface as a division by zero
# when the per-epoch loss is averaged.
if not train_filenames or not valid_filenames:
    raise ValueError(
        f'Empty split: {len(train_filenames)} training / {len(valid_filenames)} validation files '
        f'(split index {NUM_TRAIN_FILES}, {len(shuffled_filenames)} files listed)'
    )

In [None]:
def train_one_epoch(model, dataloader_train, dataloader_valid, optimizer, epoch=None):
    """Run one training pass followed by one validation pass.

    Args:
        model: network to optimise; put in train mode for the first pass and
            in eval mode for the second.
        dataloader_train: sized iterable of ``(imgs, masks)`` batches used for
            optimisation.
        dataloader_valid: sized iterable of ``(imgs, masks)`` batches used for
            evaluation only (no gradients, no optimizer steps).
        optimizer: optimizer stepped once per training batch.
        epoch: zero-based epoch index, used only in the progress message. When
            omitted, a module-level variable named ``epoch`` is used if one
            exists; otherwise the message shows ``?``.

    Returns:
        ``(train_loss, valid_loss)``: the batch losses averaged over the number
        of batches in each loader.

    Raises:
        ZeroDivisionError: if either loader has length 0.
    """
    if epoch is None:
        # Backward compatibility: existing callers rely on the loop variable
        # `epoch` being visible at module level. Resolve it up front so a
        # missing variable cannot raise NameError after a full pass of work.
        epoch = globals().get('epoch')

    # training phase
    model.train()
    train_loss = 0
    for imgs, masks in dataloader_train:
        imgs = imgs.to(DEVICE)
        masks = masks.to(DEVICE)
        # forward pass
        outputs = model(imgs)
        # compute the loss, then backpropagate and update the weights
        loss = calc_loss(outputs, masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(dataloader_train)

    # validation phase
    model.eval()
    valid_loss = 0
    with torch.no_grad():
        for imgs, masks in dataloader_valid:
            imgs = imgs.to(DEVICE)
            masks = masks.to(DEVICE)
            outputs = model(imgs)
            loss = calc_loss(outputs, masks)
            valid_loss += loss.item()
    valid_loss /= len(dataloader_valid)

    epoch_label = epoch + 1 if epoch is not None else '?'
    print(f'EPOCH: {epoch_label} - train loss: {train_loss} -  valid_loss: {valid_loss}')
    return train_loss, valid_loss

In [None]:
# Lowest validation loss seen so far. Starting at +inf means the first epoch
# always counts as an improvement, with no sentinel special case.
best_valid_loss = float('inf')

ds_t = HuBMAPDataset(train_filenames, train=True, tfms=get_augmentation_train())
ds_v = HuBMAPDataset(valid_filenames, train=False)
# Training samples are reshuffled every epoch so batches differ between epochs;
# validation keeps a fixed order.
dataloader_t = torch.utils.data.DataLoader(ds_t, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
dataloader_v = torch.utils.data.DataLoader(ds_v, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
model =  UnetDenseNet().to(DEVICE)
optimizer = torch.optim.Adam([
    {'params': model.parameters(), 'lr': 1e-3},
])

train_loss = 0
valid_loss = 0

# NOTE: keep the loop variable named `epoch` — train_one_epoch reads a
# module-level `epoch` for its progress message.
for epoch in tqdm(range(NUM_EPOCHS)):
    train_loss, valid_loss = train_one_epoch(model, dataloader_t, dataloader_v, optimizer)
    # Checkpoint whenever validation loss is at least as good as the best so far.
    if valid_loss <= best_valid_loss:
        best_valid_loss = valid_loss
        # NOTE(review): this pickles the whole module object, so loading it
        # requires the same class definitions to be importable. Saving
        # model.state_dict() is more portable, but would change the file
        # format for whatever loads 'best_unet_model.pth' — left unchanged.
        torch.save(model, 'best_unet_model.pth')

    gc.collect()