# Baseline

In [None]:
import numpy as np
import pandas as pd

from keras.preprocessing.image import load_img
from keras import Model
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.models import load_model
from keras.optimizers import Adam
from keras.utils.vis_utils import plot_model
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, concatenate, Dropout

from random import randint

import matplotlib.pyplot as plt
plt.style.use('seaborn-white')
import seaborn as sns
sns.set_style("white")

from sklearn.model_selection import train_test_split
from skimage.transform import resize

from tqdm import tqdm_notebook

from fastai.conv_learner import *
from fastai.dataset import *
#from fastai.models.resnet import vgg_resnet50

import json

from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *

from pathlib import Path


In [None]:
# Sanity check: show whether PyTorch can see a CUDA device (cell output).
torch.cuda.is_available()

In [None]:
# Sanity check: show whether the cuDNN backend is enabled (cell output).
torch.backends.cudnn.enabled

In [None]:
def show_img(im, figsize=None, ax=None, alpha=None):
    """Draw ``im`` on a matplotlib axis and hide the axis decorations.

    Args:
        im: anything accepted by ``Axes.imshow``.
        figsize: figure size, used only when a new figure has to be created.
        ax: axis to draw on; when falsy (e.g. None) a fresh figure/axis is made.
        alpha: blending value forwarded to ``imshow``.

    Returns:
        The axis that was drawn on.
    """
    canvas = ax
    if not canvas:
        canvas = plt.subplots(figsize=figsize)[1]
    canvas.imshow(im, alpha=alpha)
    canvas.set_axis_off()
    return canvas

In [None]:
PATH = Path("..")
MASKS_FN = "train.csv"
META_FN = "depths.csv"
# Competition data directory (resolves to ../input/tgs-salt-identification-challenge).
DATA_DIR = PATH/'input'/'tgs-salt-identification-challenge'
# train.csv: one row per training image; its 'id' column names the image/mask files.
masks_csv = pd.read_csv(DATA_DIR/MASKS_FN)
# Per-image metadata file named by META_FN. Previously this re-read train.csv
# by mistake; currently unused below.
meta_csv = pd.read_csv(DATA_DIR/META_FN)

In [None]:
def show_img(im, figsize=None, ax=None, alpha=None):
    """Render an image with the axis turned off and return the axis used.

    A new figure of size ``figsize`` is created only when ``ax`` is falsy;
    ``alpha`` is passed through to ``imshow``.
    """
    target_ax = ax if ax else plt.subplots(figsize=figsize)[1]
    target_ax.imshow(im, alpha=alpha)
    target_ax.set_axis_off()
    return target_ax

In [None]:
# Training image and mask directories, relative to the data root passed as
# `path` to the datasets (presumably PATH).
TRAIN_DN = 'input/tgs-salt-identification-challenge/train/images'
MASKS_DN = 'input/tgs-salt-identification-challenge/train/masks'
# sz: image size handed to tfms_from_model; bs: batch size;
# nw: number of data-loader worker processes.
sz, bs, nw = 128, 64, 16


In [None]:
class MatchedFilesDataset(FilesDataset):
    """fastai ``FilesDataset`` whose targets are image files paired 1:1 with inputs.

    ``fnames[i]`` is the input image and ``y[i]`` the matching target image
    file (e.g. a mask); both are resolved relative to ``path``.
    """

    def __init__(self, fnames, y, transform, path):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(fnames) != len(y):
            raise ValueError(
                f'fnames and y must have the same length, got {len(fnames)} and {len(y)}')
        # Assigned before the base initializer runs, as in the original code
        # (presumably in case FilesDataset.__init__ already touches the targets).
        self.y = y
        super().__init__(fnames, transform, path)

    def get_y(self, i):
        """Open the target image for item ``i`` (file joined onto ``self.path``)."""
        return open_image(os.path.join(self.path, self.y[i]))

    def get_c(self):
        """Report 0 as the number of classes (targets are images, not labels)."""
        return 0

    

In [None]:
# Input image paths and the matching mask paths: one '<id>.png' per row of
# train.csv, in the same order in both arrays.
x_names, y_names = (
    np.array([Path(img_dir)/f'{o}.png' for o in masks_csv['id']])
    for img_dir in (TRAIN_DN, MASKS_DN)
)

In [None]:
# Hold out only index 0, so nearly all images are used for training
# (the original note "1008" presumably refers to range(1008) for a full split).
val_idxs = [0]
(val_x, trn_x), (val_y, trn_y) = split_by_idx(val_idxs, x_names, y_names)

In [None]:
# Training-time augmentations. Each is built with tfm_y=TfmType.CLASS so that
# fastai also transforms the target mask (class-style target; fastai 0.7 semantics).
aug_tfms = [
    RandomRotate(4, tfm_y=TfmType.CLASS),
    RandomFlip(tfm_y=TfmType.CLASS),
    RandomLighting(0.05, 0.05, tfm_y=TfmType.CLASS),
]

In [None]:
# Test image file names: the regular-file entries of the test directory,
# kept in os.listdir order.
folder = "../input/tgs-salt-identification-challenge/test/images"
onlyfiles = list(filter(lambda name: os.path.isfile(os.path.join(folder, name)),
                        os.listdir(folder)))
testfile_name = pd.DataFrame({'img': onlyfiles})
print(len(onlyfiles))

In [None]:
TEST_DN = 'input/tgs-salt-identification-challenge/test/images'
# Relative test image paths, one per file name in testfile_name (same order).
t_names = np.array([Path(TEST_DN) / str(fname) for fname in testfile_name["img"]])
# (x, y) pair for the test images; the images double as placeholder targets
# (presumably because no test masks are available).
test_name = (t_names, t_names)

In [None]:
# fastai transforms for a resnet34 backbone at size `sz`, without cropping;
# tfm_y=TfmType.CLASS applies the augmentations to the target masks as well.
tfms = tfms_from_model(resnet34, sz, crop_type=CropType.NO, tfm_y=TfmType.CLASS, aug_tfms=aug_tfms)
# NOTE: the test images are deliberately placed in the *validation* slot (and
# test=None), so that learn.predict() -- which uses the validation loader by
# default in fastai 0.7 -- yields test-set predictions. Validation metrics are
# therefore meaningless. The commented line is the conventional setup.
datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), test_name, tfms, test=None, path=PATH)
#datasets = ImageData.get_ds(MatchedFilesDataset, (trn_x,trn_y), (val_x,val_y), tfms, test = test_name, path= PATH)
# Use the configured worker count `nw` instead of repeating the literal.
md = ImageData(PATH, datasets, bs, num_workers=nw, classes=None)
# Presumably so learner artifacts are written under the working directory
# rather than the (read-only) data root.
md.path = Path('.')
denorm = md.trn_ds.denorm

In [None]:
# Pull one training batch to sanity-check the data pipeline.
x,y = next(iter(md.trn_dl))

In [None]:
# Show the input and target batch shapes (cell output).
x.shape,y.shape

In [None]:
# Encoder architecture plus its fastai metadata: `cut` is where the backbone is
# truncated (used by get_base via cut_model), `lr_cut` is where the kept layers
# are split into layer groups (used by UpsampleModel.get_layer_groups).
f = resnet34
cut,lr_cut = model_meta[f]

def get_base():
    """Return the encoder: ``f(True)`` truncated at ``cut``, wrapped in a Sequential.

    ``f(True)`` presumably requests pretrained weights (fastai model constructor).
    """
    return nn.Sequential(*cut_model(f(True), cut))

def dice(pred, targs, smooth=1e-7):
    """Smoothed Dice coefficient between thresholded logits and a binary target.

    Args:
        pred: raw model outputs (logits); values > 0 count as foreground.
        targs: float tensor of 0/1 target values, same shape as ``pred``.
        smooth: small constant added to numerator and denominator. Without it
            a batch where prediction and target are both empty gives 0/0 = NaN,
            which then poisons the averaged metric; with it that case scores 1.

    Returns:
        Scalar tensor ``(2*|P & T| + smooth) / (|P| + |T| + smooth)``.
    """
    pred = (pred > 0).float()
    inter = (pred * targs).sum()
    total = (pred + targs).sum()
    return (2. * inter + smooth) / (total + smooth)


class StdUpsample(nn.Module):
    """2x upsampling stage: ConvTranspose2d(kernel 2, stride 2) -> ReLU -> BatchNorm2d."""

    def __init__(self, nin, nout):
        super().__init__()
        # Attribute names become state_dict keys ("conv.*", "bn.*"); keep them stable.
        self.conv = nn.ConvTranspose2d(nin, nout, 2, stride=2)
        self.bn = nn.BatchNorm2d(nout)

    def forward(self, x):
        upsampled = self.conv(x)
        activated = F.relu(upsampled)
        return self.bn(activated)

class Upsample34(nn.Module):
    def __init__(self, rn):
        super().__init__()
        self.rn = rn
        self.features = nn.Sequential(
            rn, nn.ReLU(),
            StdUpsample(512,256),
            StdUpsample(256,256),
            StdUpsample(256,256),
            StdUpsample(256,256),
            nn.ConvTranspose2d(256, 1, 2, stride=2))
        
    def forward(self,x): return self.features(x)[:,0]


class UpsampleModel:
    """Adapter that exposes ``model`` to fastai together with its layer groups."""

    def __init__(self, model, name='upsample'):
        self.model = model
        self.name = name

    def get_layer_groups(self, precompute):
        """Return the encoder children split at ``lr_cut`` plus one decoder group.

        The decoder group is every child of ``model.features`` after index 0
        (index 0 is the encoder itself). ``precompute`` is accepted but unused.
        """
        encoder_groups = list(split_by_idxs(children(self.model.rn), [lr_cut]))
        decoder_layers = children(self.model.features)[1:]
        return [*encoder_groups, decoder_layers]

In [None]:
# Pre-seed the torch model cache with the resnet34 checkpoint from the attached
# dataset, under the file name presumably expected by the model-zoo loader.
import shutil

cache_dir = os.path.expanduser(os.path.join('~', '.torch'))
models_dir = os.path.join(cache_dir, 'models')
os.makedirs(models_dir, exist_ok=True)  # also creates cache_dir if missing
# Copy into the directory created above rather than a hard-coded
# /tmp/.torch/models, which only matches when HOME is /tmp.
shutil.copy('../input/resnet34/resnet34.pth',
            os.path.join(models_dir, 'resnet34-333f7ec4.pth'))

In [None]:
# Build the truncated resnet34 encoder (see get_base).
m_base = get_base()

In [None]:
# Attach the upsampling decoder to the encoder, move it to the GPU, and wrap it
# so fastai can query its layer groups.
m = to_gpu(Upsample34(m_base))
models = UpsampleModel(m)

In [None]:
# Wrap the model data and the adapted model in a fastai learner, then set the
# optimizer, loss and metrics.
learn = ConvLearner(md, models)
learn.opt_fn = optim.Adam
# The decoder ends in a plain ConvTranspose2d (no sigmoid), so the loss takes logits.
learn.crit = nn.BCEWithLogitsLoss()
# NOTE(review): if fastai 0.7's accuracy_thresh compares raw outputs to the
# threshold (no sigmoid), 0.5 here is a logit threshold, unlike dice's `> 0`.
learn.metrics = [accuracy_thresh(0.5), dice]

In [None]:
#learn.summary()

In [None]:
# Freeze the layer groups before index 1 (fastai 0.7 freeze_to semantics;
# presumably the early encoder layers), then run the learning-rate range test
# and plot loss against learning rate to choose `lr`.
learn.freeze_to(1)
learn.lr_find()
learn.sched.plot()

In [None]:
lr = 4e-2  # base learning rate (presumably read off the lr_find plot)
wd = 1e-7  # weight decay for the first training stage
# One learning rate per layer group, smallest for the earliest layers
# (lr/100, lr/10, lr), all halved.
lrs = np.divide([lr / 100, lr / 10, lr], 2)

In [None]:
# Stage 1: one cycle of 1 epoch at `lr` with weight decay `wd`.
# use_clr=(20,8) selects fastai 0.7's cyclical-LR schedule; the tuple's exact
# meaning (divisor / cycle fraction) should be confirmed against fastai docs.
learn.fit(lr,1, wds=wd, cycle_len=1,use_clr=(20,8))

In [None]:
# Checkpoint the stage-1 weights as 'tmp' and reload them straight away
# (effectively a no-op here; presumably so later cells can restart from this point).
learn.save('tmp')
learn.load('tmp')

In [None]:
# Make every layer group trainable for fine-tuning. bn_freeze(True) presumably
# keeps the BatchNorm layers' statistics fixed during this stage; confirm in fastai 0.7.
learn.unfreeze()
learn.bn_freeze(True)

In [None]:
# Stage 2: fine-tune all layer groups with the per-group rates `lrs` for one
# cycle of 4 epochs (no weight decay argument passed here).
learn.fit(lrs,1,cycle_len=4,use_clr=(20,8))

In [None]:
preds_test = learn.predict() 

In [None]:
def RLenc(img, order='F', format=True):
    """Run-length encode a binary mask.

    Args:
        img: 2-D mask array; any non-zero pixel counts as foreground.
        order: numpy traversal order used to flatten ``img``. The default
            'F' (Fortran / column-major) reads down each column, then moves
            right.
        format: if True, return the run string expected by the submission
            format; otherwise return the raw runs.

    Returns:
        If ``format`` is True, a string ``"start length start length ..."``
        with 1-based start positions ('' for an empty mask). Otherwise a list
        of ``(start, length)`` tuples of Python ints.
    """
    # Flatten in the requested order and binarize (non-zero -> foreground).
    # Vectorized with numpy instead of a per-pixel Python loop.
    fg = np.asarray(img).reshape(-1, order=order) != 0
    # Pad with background on both ends so every run has a rising and a falling edge.
    padded = np.concatenate(([False], fg, [False]))
    # 1-based positions where the value flips: alternating run start / run end + 1.
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    starts = edges[0::2]
    lengths = edges[1::2] - starts
    runs = list(zip(starts.tolist(), lengths.tolist()))

    if not format:
        return runs
    # str.join avoids quadratic repeated string concatenation.
    return ' '.join(f'{start} {length}' for start, length in runs)

In [None]:
# Split each test file name at its first '.' into the id and the extension,
# e.g. "abc.png" -> img="abc", png="png".
# `n` is passed by keyword: positional `n` to Series.str.split was deprecated
# and raises TypeError in pandas >= 2.0.
testfile_name = pd.DataFrame({'img':onlyfiles})
testfile_name = pd.DataFrame(testfile_name.img.str.split('.', n=1).tolist(), columns=['img','png'])

In [None]:
img_size_ori = 101     # original image side length
img_size_target = 128  # side length the predictions are produced at

def downsample(img):
    """Resize a prediction map back to (img_size_ori, img_size_ori).

    When the two sizes are equal, ``img`` is returned unchanged. Otherwise
    skimage ``resize`` is used with mode='constant' and preserve_range=True,
    so values keep their original range (e.g. logits stay logits).
    """
    if img_size_ori != img_size_target:
        img = resize(img, (img_size_ori, img_size_ori), mode='constant', preserve_range=True)
    return img

In [None]:
# Map each test id to its RLE mask string. preds_test holds raw logits (the
# model has no final sigmoid and is trained with BCEWithLogitsLoss), so
# foreground is logit > 0, i.e. probability > 0.5 -- the same rule the dice
# metric uses. Thresholding logits at 0.5 would mean probability > ~0.62.
# (The previous np.round over the boolean mask was redundant.)
pred_dict = {idx: RLenc(downsample(preds_test[i]) > 0) for i, idx in enumerate(tqdm_notebook(testfile_name["img"]))}

In [None]:
# Write the submission CSV: one row per test id (index named "id") with its
# RLE string in the "rle_mask" column.
sub = pd.DataFrame.from_dict(pred_dict, orient='index')
sub.columns = ['rle_mask']
sub.index.name = 'id'
sub.to_csv('submission.csv')