In [None]:
import sys
import os
sys.path.append("../") 

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
from pathlib import Path
import re

from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *

In [None]:
from config import (IMAGE_DATA_DIR, GT_DIR, IMAGE_DATA_TILES_DIR, GT_TILES_DIR, 
                    GT_ADJ_TILES_DIR, TILES_DIR, LABELS, RED, BLACK, N1, N2, N_validation,
                    CODES
                   )
from loss_custom import WeakCrossEntropy
from metrics_custom import acc_satellite, acc_weakly

In [None]:
# Seed every RNG used during training so runs are reproducible:
# https://docs.fast.ai/dev/test.html#getting-reproducible-results

seed = 42

# python RNG
import random
random.seed(seed)

# pytorch RNGs
import torch
torch.manual_seed(seed)
# deterministic=True alone is not sufficient: with benchmark=True cuDNN benchmarks
# several convolution algorithms and may select a different one on each run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed)

# numpy RNG
import numpy as np
np.random.seed(seed)

In [None]:
# Project root: the parent of the current working directory (cf. sys.path.append("../") above)
BASE_DIR = Path.cwd().parent
BASE_DIR

# Fully-supervised (FS) Semantic Segmentation

In [None]:
# Ground-truth tile directory used as FS targets; GT_ADJ_TILES_DIR is chosen over GT_TILES_DIR (both come from config)
gt_tiles_dir = GT_ADJ_TILES_DIR

In [None]:
# Sorted so that the preview tile `image_tiles_fnames[0]` is the same on every run and
# machine: os.listdir returns entries in arbitrary order.
image_tiles_fnames = sorted(os.listdir(BASE_DIR / IMAGE_DATA_TILES_DIR))

In [None]:
# First entry of `image_tiles_fnames`; used as the preview tile in the next cells
fname = image_tiles_fnames[0]

In [None]:
# Preview the input image tile
open_image(BASE_DIR / IMAGE_DATA_TILES_DIR / fname)

In [None]:
# Ground-truth tile with the same file name, opened as a regular image for a visual check
open_image(BASE_DIR / gt_tiles_dir / fname)

In [None]:
# Same ground-truth tile, this time loaded as a segmentation mask
mask_path = BASE_DIR / gt_tiles_dir / fname
mask = open_mask(mask_path)
mask

In [None]:
# Spatial size of the mask (every dim after the first; assumes a channel-first shape — confirm)
src_size = np.array(mask.shape[1:])
src_size,mask.data

### Load data

In [None]:
def get_y_fn(x):
    """Return the ground-truth mask path for image tile `x` (same file name, under `gt_tiles_dir`).

    NOTE: `get_y_fn` is redefined further down for the WS task; re-run this cell
    before rebuilding `fs_item_list` out of order.
    """
    mask_dir = BASE_DIR / gt_tiles_dir
    return mask_dir / x.name

In [None]:
# Free GPU memory (fastai mem utils), checked before building the FS DataBunch / choosing `bs`
free = gpu_mem_get_free_no_cache(); free

In [None]:
base_path = "top_mosaic_09cm_area"
prog = re.compile(fr"{base_path}(?P<area_id>\d+)_tile(?P<tile_id>\d+).tif")

def is_in_set(x, N):
    fname = x.name  # e.g.: top_mosaic_09cm_area30_tile120.tif'

    match_result = prog.search(fname)
    area_id = match_result.group('area_id')
    tile_id = match_result.group('tile_id')
    image_fname = f"{base_path}{area_id}.tif"  # e.g.: top_mosaic_09cm_area30.tif'
    return image_fname in N

# Split predicates over source images (N1, N2 and N_validation lists come from config)
is_in_set_n1 = partial(is_in_set, N=N1)
is_in_set_n2 = partial(is_in_set, N=N2)
is_in_set_nvalidation = partial(is_in_set, N=N_validation)
is_in_set_n1_or_nvalidation = partial(is_in_set, N=N1+N_validation)

# Segmentation classes: the LABELS plus the RED and BLACK codes
codes = LABELS+[RED, BLACK]

src_size = np.array(mask.shape[1:])
size = src_size // 2  # half the tile resolution; TODO: revisit

# FS pipeline: keep N1 + validation tiles, hold out the validation areas, pair each
# tile with its GT mask and augment image and mask together (tfm_y=True).
fs_item_list = (
    SegmentationItemList.from_folder(BASE_DIR / IMAGE_DATA_TILES_DIR)
    .filter_by_func(is_in_set_n1_or_nvalidation)   # -> SegmentationItemList
    .split_by_valid_func(is_in_set_nvalidation)    # -> ItemLists (train, valid)
    .label_from_func(get_y_fn, classes=codes)      # -> LabelLists
    .transform(get_transforms(), size=size, tfm_y=True)
)

In [None]:
bs = 64  # batch size; NOTE(review): lower it if the free GPU memory checked above is insufficient
# Normalize with ImageNet statistics (presumably to match the ImageNet-pretrained resnet18 encoder)
fs_data = fs_item_list.databunch(bs=bs).normalize(imagenet_stats)

## Model

Train and compare semantic segmentation networks. Task (i): fully-supervised training on the N1 pixel-level labels (the N_validation areas are held out for validation).

In [None]:
wd=0.1  # weight decay passed to the FS unet_learner below

In [None]:
# To choose `lr`, run these once `fs_learn` exists (it is created in the next cell):
# lr_find(fs_learn)
# fs_learn.recorder.plot()

In [None]:
# U-Net with a resnet18 encoder for the FS task, scored with the satellite accuracy metric
fs_learn = unet_learner(fs_data,
                        models.resnet18,
                        metrics=acc_satellite,
                        wd=wd,
                       )

In [None]:
lr=3e-4  # max learning rate for the FS fit_one_cycle run below

In [None]:
# 20 epochs of one-cycle training; pct_start=0.9 is the fraction of the cycle spent increasing the LR
fs_learn.fit_one_cycle(20, slice(lr), pct_start=0.9)

In [None]:
# Checkpoint the trained FS model under the name 'mixed-stage-1'
fs_learn.save('mixed-stage-1')

In [None]:
# Reload the FS checkpoint (trailing `;` suppresses the learner repr)
fs_learn.load('mixed-stage-1');

In [None]:
# Training/validation loss and metric curves of the FS run
fs_learn.recorder.plot_losses()
fs_learn.recorder.plot_metrics()

# Weakly-supervised (WS) Semantic Segmentation

In [None]:
# Sorted so that the preview tile `image_tiles_fnames[0]` is the same on every run and
# machine: os.listdir returns entries in arbitrary order.
image_tiles_fnames = sorted(os.listdir(BASE_DIR / TILES_DIR))

In [None]:
# First WS tile name; its suffix encodes the image-level label vector
fname = image_tiles_fnames[0]; fname

In [None]:
# Preview one WS tile; its shape also gives `src_size` below
img = open_image(BASE_DIR / TILES_DIR / fname)
img

In [None]:
# Spatial size of the WS tile (every dim after the first)
src_size = np.array(img.shape[1:]); src_size

### Load data

In [None]:
base_path = "top_mosaic_09cm_area"
prog_label_vector = re.compile(base_path + r"\d+_tile\d+_(?P<label_vector>\d{5}).tif")

def get_y_fn(x):
    fname = x.name
    match_result = prog_label_vector.search(fname)
    label_vector = match_result.group('label_vector')
    label_vector_arr = torch.tensor(list(map(int,label_vector))) # NEW

    indexes = torch.where(label_vector_arr == 1)[0]
    colors = [LABELS[idx] for idx in indexes]
    
    assert 0<len(indexes)<6, (len(indexes), x)
    
    return colors

def has_a_valid_color(x):
    fname = x.name
    match_result = prog_label_vector.search(fname)
    label_vector = match_result.group('label_vector')
    label_vector_arr = torch.tensor(list(map(int,label_vector))) # NEW

    indexes = torch.where(label_vector_arr == 1)[0]
    if not (0<len(indexes)<6):
        print("not valid color", len(indexes), x)
        return False
    return True


In [None]:
# DEBUG: sanity-check get_y_fn on a single tile,
# e.g. top_mosaic_09cm_area27_tile154_11100.tif
# fpath = BASE_DIR / TILES_DIR / fname
# result = get_y_fn(fpath)
# type(result), result

In [None]:
# Free GPU memory (fastai mem utils), checked before building the WS DataBunch / choosing `bs`
free = gpu_mem_get_free_no_cache(); free

In [None]:
base_path = "top_mosaic_09cm_area"
prog_with_label_vector = re.compile(base_path + r"(?P<area_id>\d+)_tile(?P<tile_id>\d+)_(?P<label_vector>\d{5}).tif")

def is_in_set(x, N):
    fname = x.name  # e.g.: top_mosaic_09cm_area30_tile120.tif'

    match_result = prog_with_label_vector.search(fname)
    area_id = match_result.group('area_id')
    tile_id = match_result.group('tile_id')
    image_fname = f"{base_path}{area_id}.tif"  # e.g.: top_mosaic_09cm_area30.tif'
    return image_fname in N

# Split predicates, rebuilt on top of the WS `is_in_set` defined just above
is_in_set_n1 = partial(is_in_set, N=N1)
is_in_set_n2 = partial(is_in_set, N=N2)
is_in_set_nvalidation = partial(is_in_set, N=N_validation)
is_in_set_n1_or_nvalidation = partial(is_in_set, N=N1+N_validation)
is_in_set_n2_or_nvalidation = partial(is_in_set, N=N2+N_validation)

src_size = np.array(img.shape[1:])
size = src_size // 2  # half the tile resolution; TODO: revisit

# WS pipeline: keep N2 + validation tiles that carry at least one label, hold out the
# validation areas, and use the filename label vector as the (multi-label) target.
ws_item_list = (
    ImageList.from_folder(BASE_DIR / TILES_DIR)
    .filter_by_func(is_in_set_n2_or_nvalidation)  # -> ImageList
    .filter_by_func(has_a_valid_color)            # -> ImageList
    .split_by_valid_func(is_in_set_nvalidation)   # -> ItemLists (train, valid)
    .label_from_func(get_y_fn, classes=LABELS)    # -> LabelLists; y is a list of LABELS per tile
    .transform(get_transforms(), size=size)
)

In [None]:
bs = 64  # batch size for the WS DataBunch
ws_data = ws_item_list.databunch(bs=bs).normalize(imagenet_stats)
# Override the class count read by unet_learner (presumably the number of output channels).
# NOTE(review): hard-coded 5 matches the 5-digit label vector; confirm it equals len(LABELS).
ws_data.c = 5

## Model

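Train and compare semantic segmentation networks. Weakly-supervised task: train on the N2 image-level labels, i.e. the 5-digit label vectors encoded in the tile file names (the N_validation areas are held out for validation).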

In [None]:
# wd=0.1 was used for the FS model; TODO: compare against it for the WS model
wd=0.01

In [None]:
# WS U-Net trained with the project's weak cross-entropy loss on image-level labels.
# NOTE(review): this builds a fresh learner and does not load the 'mixed-stage-1'
# FS weights saved above — confirm that is intended for the "mixed" setup.
ws_learn = unet_learner(ws_data, 
                     models.resnet18, 
                     loss_func=WeakCrossEntropy(CODES, axis=1),
                     metrics=acc_weakly, 
                     wd=wd,
                    )

In [None]:
# LR range test for the WS model; the plot guides the choice of `lr` below
lr_find(ws_learn)
ws_learn.recorder.plot()

In [None]:
lr=1e-4  # max learning rate for the WS fit_one_cycle run below

In [None]:
# Same schedule as the FS run: 20 epochs, pct_start=0.9 of the cycle spent increasing the LR
ws_learn.fit_one_cycle(20, slice(lr), pct_start=0.9)

In [None]:
# WS training curves; skip_start=50 hides the first 50 recorded points
ws_learn.recorder.plot_losses(skip_start=50, show_grid=True)
ws_learn.recorder.plot_metrics(skip_start=50, show_grid=True)

In [None]:
# Checkpoint the trained WS model under the name 'mixed-stage-2'
ws_learn.save('mixed-stage-2')

In [None]:
# Reload the WS checkpoint (trailing `;` suppresses the learner repr)
ws_learn.load('mixed-stage-2');