# Fully-supervised Semantic Segmentation

In [None]:
import sys
import os
sys.path.append("../") 

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
import re
from pathlib import Path

import torch
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *

In [None]:
from config import (IMAGE_DATA_DIR, GT_DIR, IMAGE_DATA_TILES_DIR, GT_TILES_DIR, 
                    GT_ADJ_TILES_DIR, TILES_DIR,
                    LABELS, RED, BLACK, N1, N2, N_validation, MODEL_DIR
                   )
from util import set_seed
set_seed(seed=42)

In [None]:
BASE_DIR = Path('').absolute().parent; BASE_DIR

In [None]:
image_tiles_fnames = os.listdir(BASE_DIR / IMAGE_DATA_TILES_DIR)

In [None]:
fname = image_tiles_fnames[0]

In [None]:
open_image( BASE_DIR / IMAGE_DATA_TILES_DIR / fname)

In [None]:
open_image(BASE_DIR / GT_ADJ_TILES_DIR / fname)

In [None]:
mask = open_mask(BASE_DIR / GT_ADJ_TILES_DIR / fname)
# mask.show(figsize=(5,5), alpha=1)
mask

In [None]:
src_size = np.array(mask.shape[1:])
src_size,mask.data

### Load data

In [None]:
def get_y_fn(x):
    """Return the ground-truth mask path for image tile ``x``.

    The mask carries the same file name as the tile and lives in
    ``GT_ADJ_TILES_DIR`` under ``BASE_DIR``.
    """
    mask_dir = BASE_DIR / GT_ADJ_TILES_DIR
    return mask_dir / x.name

In [None]:
# Free GPU memory as reported by fastai (without the caching allocator's share);
# displayed for reference, presumably to help pick `bs` below.
free = gpu_mem_get_free_no_cache(); free

In [None]:
base_path = "top_mosaic_09cm_area"
# Tile names look like "top_mosaic_09cm_area30_tile120.tif". The "." before the
# extension is escaped so it only matches a literal dot, not any character.
prog = re.compile(fr"{re.escape(base_path)}(?P<area_id>\d+)_tile(?P<tile_id>\d+)\.tif")

def is_in_set(x, N):
    """Return True if tile ``x`` was cut from one of the source images in ``N``.

    The area id is parsed from the tile name (``top_mosaic_09cm_area30_tile120.tif``)
    and mapped back to the parent image name (``top_mosaic_09cm_area30.tif``),
    which is then looked up in ``N``.

    Args:
        x: path-like object exposing ``.name`` (e.g. ``pathlib.Path``).
        N: collection of source-image file names.

    Returns:
        bool: whether the parent image is in ``N``. File names that do not
        follow the tile naming scheme return False instead of raising.
    """
    match_result = prog.search(x.name)
    if match_result is None:
        # Not a tile produced by the naming scheme above, so it cannot belong to N.
        return False
    area_id = match_result.group('area_id')
    image_fname = f"{base_path}{area_id}.tif"  # e.g.: top_mosaic_09cm_area30.tif
    return image_fname in N

# Tile predicates: each keeps tiles whose parent image is in the given split.
is_in_set_n1 = partial(is_in_set, N=N1)
is_in_set_n2 = partial(is_in_set, N=N2)
is_in_set_nvalidation = partial(is_in_set, N=N_validation)
is_in_set_n1_or_nvalidation = partial(is_in_set, N=N1+N_validation)

# Segmentation classes: the configured LABELS followed by the RED and BLACK codes.
codes = LABELS+[RED, BLACK]

# Spatial size of the sample mask; training uses half of it.
src_size = np.array(mask.shape[1:])
size = src_size // 2  # TODO: revisit the half-resolution choice

item_list = (SegmentationItemList.from_folder(BASE_DIR / IMAGE_DATA_TILES_DIR)
             .filter_by_func(is_in_set_n1_or_nvalidation)  # keep N1 + validation tiles only
             .split_by_valid_func(is_in_set_nvalidation)   # validation tiles -> valid set (ItemLists)
             .label_from_func(get_y_fn, classes=codes)     # masks from GT_ADJ_TILES_DIR (LabelLists)
             .transform(get_transforms(), size=size, tfm_y=True)  # same transforms applied to masks
            )

In [None]:
# Batch size.
bs = 64
# Build the DataBunch and normalize inputs with ImageNet channel statistics.
data = item_list.databunch(bs=bs).normalize(imagenet_stats)

In [None]:
data

In [None]:
# Class names passed as `codes` to label_from_func.
data.classes

In [None]:
# Two rows of training samples with their masks.
data.show_batch(2, figsize=(10,7))

In [None]:
# Two rows of validation samples with their masks.
data.show_batch(2, figsize=(10,7), ds_type=DatasetType.Valid)

In [None]:
# Shows the train/valid split sizes.
item_list

All tiles: 4497

- 935 / 4497 ≈ 20.8%
- 390 / 4497 ≈ 8.7%

This appears to be the intended split.

## Model

Train and compare semantic segmentation networks using the following data — Task (i): N1 pixel-level labels.

In [None]:
# Class name -> class index, e.g. {WHITE: 0, BLUE: 1}.
name2id = dict(zip(codes, range(len(codes))))
# Indices of the RED and BLACK codes, excluded from the accuracy metric.
void_codes_red, void_codes_black = name2id[RED], name2id[BLACK]

def acc_satellite(input, target, void_codes=None):
    """Pixel accuracy that ignores void pixels.

    Args:
        input: model output scores; the predicted class is ``argmax`` over dim 1.
        target: ground-truth class indices with a singleton dim 1, which is
            squeezed away before comparison.
        void_codes: class indices to exclude from the score. Defaults to the
            RED and BLACK codes (``void_codes_red``, ``void_codes_black``).

    Returns:
        0-d float tensor: fraction of non-void pixels predicted correctly
        (NaN if every pixel is void).
    """
    if void_codes is None:
        void_codes = (void_codes_red, void_codes_black)
    target = target.squeeze(1)
    # A pixel is scored only if it matches none of the void codes; the
    # exclusions must be AND-ed together, not reassigned.
    mask = torch.ones_like(target, dtype=torch.bool)
    for code in void_codes:
        mask &= target != code
    return (input.argmax(dim=1)[mask] == target[mask]).float().mean()

In [None]:
# Weight decay passed to the learner.
wd=0.1

In [None]:
# Optional learning-rate finder; needs `learn`, which is created in the next cell.
# lr_find(learn)
# learn.recorder.plot()

In [None]:
# U-Net learner with a ResNet-18 encoder, tracking acc_satellite as its metric.
learn = unet_learner(data, models.resnet18, metrics=acc_satellite, wd=wd, 
#                      model_dir=MODEL_DIR # TODO
                    )

In [None]:
# learn

In [None]:
# Base learning rate, used before `learn.unfreeze()` further below.
lr=3e-4

In [None]:
# Short 2-epoch one-cycle run; pct_start=0.9 puts 90% of iterations in the LR ramp-up phase.
learn.fit_one_cycle(2, slice(lr), pct_start=0.9)

In [None]:
learn.recorder.plot_losses()
learn.recorder.plot_metrics()

In [None]:
learn.load('fs-stage-1');

In [None]:
learn.fit_one_cycle(10, slice(lr), pct_start=0.9)

In [None]:
learn.fit_one_cycle(10, slice(lr), pct_start=0.9)

In [None]:
learn.save('fs-stage-1')

In [None]:
learn.load('fs-stage-1');

In [None]:
learn.recorder.plot_losses()
learn.recorder.plot_metrics()

In [None]:
learn.show_results(rows=3, figsize=(8,9))

In [None]:
# Stage 2: make all layer groups trainable.
learn.unfreeze()

In [None]:
# Discriminative learning rates spread from lr/100 (earliest layer group) to lr (last group).
lrs = slice(lr/100,lr/1)

In [None]:
learn.fit_one_cycle(12, lrs, pct_start=0.8)

In [None]:
# Checkpoint the stage-2 weights.
learn.save('fs-stage-2');

In [None]:
learn.load('fs-stage-2');

In [None]:
# Show predictions next to ground truth for a few samples.
learn.show_results(rows=3, figsize=(8,9))