In [None]:
import os

# Switch the working directory to the project root (presumably one level above the
# notebook) exactly once per kernel. A bare os.chdir("..") would climb one more
# directory every time this cell is re-run, breaking all relative paths below.
if "PROJECT_ROOT" not in globals():
    os.chdir("..")
    PROJECT_ROOT = os.getcwd()
else:
    os.chdir(PROJECT_ROOT)

import json
import random
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torch.utils.data import DataLoader

from glob import glob

from lib import *

import albumentations as albu

import segmentation_models_pytorch as smp
from utils.trainer import Trainer
from utils.data.datasets import CustomDataset
from utils.data.datasets import get_preprocessing
from utils.scorer import DiceLoss

In [None]:
# Set to True to train with the augmentation pipeline below; False trains on raw images
# (transforms = None), which is what this run currently uses.
use_augmentations = False

augmentations = [albu.HorizontalFlip(),
                 # NOTE(review): IAA* transforms are deprecated/removed in newer albumentations
                 # releases — confirm the installed version still provides IAAAdditiveGaussianNoise.
                 albu.OneOf([albu.IAAAdditiveGaussianNoise(), 
                             albu.GaussNoise()], p=0.2),
                 albu.OneOf([albu.MotionBlur(p=0.2), 
                             albu.MedianBlur(blur_limit=3, p=0.1), 
                             albu.Blur(blur_limit=3, p=0.1)], p=0.2),
                 albu.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.2, rotate_limit=30, p=0.15),
                 albu.ToGray(p=0.1)]

# Explicit toggle instead of building the pipeline and then silently overwriting it with None.
transforms = albu.Compose(augmentations, p=0.5) if use_augmentations else None

In [None]:
train_path = "./data/train"
val_path = "./data/valid"
test_path = "./data/test"
style_transfer_path = None  # alternative: val_path

encoder = 'resnet152'
encoder_weights = 'imagenet'
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 1e-4
mixup_proba = None  # alternative: 0.1
batch_size = 16
n_epoch = 15
num_workers = 0
seed = 42

# Seed every RNG the pipeline may draw from (model init, shuffling, augmentation)
# so that Restart & Run All reproduces the same training run.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

model_name = f'unet_{encoder}_trnsfrm={int(transforms is not None)}_mxp={int(mixup_proba is not None)}_stl={int(style_transfer_path is not None)}'

In [None]:
# U-Net segmentation model on the configured encoder backbone, initialised from `encoder_weights`.
model = smp.Unet(
    encoder_name=encoder,
    encoder_weights=encoder_weights,
)
# Input-normalisation function matching that encoder/weights pair; fed to get_preprocessing().
preprocessing_fn = smp.encoders.get_preprocessing_fn(
    encoder,
    encoder_weights,
)

In [None]:
# Augmentations (if enabled) are for training only; validation must see unaltered
# images so its metrics are comparable across runs and augmentation settings.
train_dataset = CustomDataset(data_path=train_path, transforms=transforms,
                              preprocessing=get_preprocessing(preprocessing_fn), mixup_proba=mixup_proba,
                              style_transfer_path=style_transfer_path)
val_dataset = CustomDataset(data_path=val_path, transforms=None,
                            preprocessing=get_preprocessing(preprocessing_fn), style_transfer_path=None)

# Shuffle training batches every epoch (otherwise each epoch sees an identical batch
# sequence); keep validation order fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [None]:
# Training wrapper around the U-Net; model_name tags this experiment's outputs.
trainer = Trainer(model,
                  device=device,
                  model_name=model_name)

In [None]:
# All training options in one place: Adam without weight decay, Dice loss, no LR
# scheduler, early stopping with max_gap=2, and no checkpoint directory.
train_kwargs = dict(
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    n_epoch=n_epoch,
    optim=torch.optim.Adam,
    weight_decay=0.0,
    schedul=None,
    loss=DiceLoss(),
    lr=lr,
    show_results=True,
    saved_models_dir=None,
    verbose=True,
    early_stopping=True,
    max_gap=2,
    gamma=None,
)
history = trainer.train(**train_kwargs)

In [None]:
# Export the summary metrics of this run, one row per metric, one column for the model.
# The 'loss_train' / 'loss_val' entries are excluded by building a filtered copy rather
# than popping from `history`: pop() mutated training state and raised KeyError if the
# cell was re-run.
excluded_keys = {'loss_train', 'loss_val'}
history_summary = {key: value for key, value in history.items() if key not in excluded_keys}

os.makedirs('./results', exist_ok=True)
# .xlsx instead of .xls: pandas >= 2.0 dropped the xlwt engine, so writing .xls fails.
pd.DataFrame.from_dict(history_summary, orient='index', columns=[model_name]).to_excel(f'./results/{model_name}.xlsx')

In [None]:
# Test images: masks=False (presumably no ground-truth masks are available), no augmentation.
test_dataset = CustomDataset(
    data_path=test_path,
    masks=False,
    transforms=None,
    preprocessing=get_preprocessing(preprocessing_fn),
)
# One image per batch, unshuffled, so batch index i corresponds to the i-th dataset item.
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=0)

In [None]:
# NOTE(review): the threshold is applied directly to the model output; smp.Unet is built
# here without an `activation`, so these are presumably raw logits — confirm intent.
th = 0.6
model.to(device)
model.eval()
# NOTE(review): assumes CustomDataset yields images in os.listdir order — confirm.
files = os.listdir(test_path)
pred_masks = []
pred_masks_rle = []
paths_to_imgs = []
# Inference only: no_grad avoids building an autograd graph for every image.
with torch.no_grad():
    for img_id, src in enumerate(test_loader):
        img_path = f"{test_path}/{files[img_id]}"
        # Context manager closes the underlying file handle PIL keeps open.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img)

        output = model(src.to(device))
        # batch_size=1: drop singleton batch/channel axes instead of hard-coding 320x320.
        mask = np.squeeze(output.cpu().numpy())
        assert mask.ndim == 2, f"expected a single-channel 2-D mask, got shape {mask.shape}"
        mask = (mask >= th).astype('uint8') * 255

        # Back to the original image size (as the validation export does); interpolation
        # produces intermediate values, so re-binarise to {0, 255}.
        mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
        mask[mask > 0] = 255

        pred_masks.append(mask)
        pred_masks_rle.append(encode_rle(mask))
        paths_to_imgs.append(img_path)

        show_img_with_mask(img, mask)

In [None]:
# Validation images loaded like test data (masks=False, no augmentation) for prediction export.
valid_dataset = CustomDataset(
    data_path=val_path,
    masks=False,
    transforms=None,
    preprocessing=get_preprocessing(preprocessing_fn),
)
# One image per batch, fixed order, so batch index i corresponds to the i-th dataset item.
valid_loader = DataLoader(dataset=valid_dataset, batch_size=1, shuffle=False, num_workers=0)

In [None]:
# NOTE(review): the threshold is applied directly to the model output; smp.Unet is built
# here without an `activation`, so these are presumably raw logits — confirm intent.
th = 0.6
model.to(device)
model.eval()
# NOTE(review): assumes CustomDataset yields images in os.listdir order — confirm.
files = os.listdir(val_path)
pred_masks = []
pred_masks_rle = []
paths_to_imgs = []
# Inference only: no_grad avoids building an autograd graph for every image.
with torch.no_grad():
    for img_id, src in enumerate(valid_loader):
        img_path = f"{val_path}/{files[img_id]}"
        # Context manager closes the underlying file handle PIL keeps open.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img)

        output = model(src.to(device))
        # batch_size=1: drop singleton batch/channel axes instead of hard-coding 320x320.
        mask = np.squeeze(output.cpu().numpy())
        assert mask.ndim == 2, f"expected a single-channel 2-D mask, got shape {mask.shape}"
        mask = (mask >= th).astype('uint8') * 255

        # Back to the original image size; interpolation produces intermediate values,
        # so re-binarise to {0, 255}.
        mask = cv2.resize(mask, (img.shape[1], img.shape[0]))
        mask[mask > 0] = 255

        pred_masks.append(mask.astype('uint8'))
        pred_masks_rle.append(encode_rle(mask))
        paths_to_imgs.append(img_path)

        show_img_with_mask(img, mask)

In [None]:
# Numeric image id = file name up to the first '.'; rows pair each id with its RLE mask.
# Relies on `files` and `pred_masks_rle` sharing the same order.
ids = [int(file_name.split('.')[0]) for file_name in files]
df = pd.DataFrame({'id': ids, 'rle_mask': pred_masks_rle})

In [None]:
# Persist the validation predictions (id, rle_mask) without the pandas row index.
df.to_csv(path_or_buf='pred_valid.csv', index=False)

In [None]:
# Build the visual report of image/mask pairs; the return value is discarded.
_ = get_html(
    paths_to_imgs,
    pred_masks,
    path_to_save="./results/example",
)