# Weakly-supervised Semantic Segmentation

In [None]:
import sys
import os
# Make the project modules one directory up (config, util, loss_custom, ...) importable.
sys.path.extend(["../"])

In [None]:
# IPython setup: automatically reload edited project modules before running
# code, and render matplotlib figures inline in the notebook.
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
from pathlib import Path
import re

from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *

In [None]:
from config import (IMAGE_DATA_DIR, TILES_DIR, 
                    LABELS, RED, BLACK, N1, N2, N_validation, CODES, 
                    BASE_DIR, MODEL_DIR
                   )
from loss_custom import WeakCrossEntropy
from metrics_custom import acc_weakly, acc_satellite
from parameters import IMG_SIZE_RATIO
from util import set_seed, is_in_set_n1_or_nvalidation, is_in_set_nvalidation, REGEX_IMG_FILE_NAME_WITH_LABEL_VECTOR, is_in_set_n2_or_nvalidation, get_y_colors, has_a_valid_color, show_prediction_vs_actual
# Fix the random seed via util.set_seed, then show the free GPU memory reported by fastai.
set_seed(seed=42)
free = gpu_mem_get_free_no_cache()
free

In [None]:
# Sorted so that image_tiles_fnames[0] (opened below, and whose shape determines
# the training input size) is the same on every machine: os.listdir returns
# entries in arbitrary, filesystem-dependent order.
image_tiles_fnames = sorted(os.listdir(BASE_DIR / TILES_DIR))

In [None]:
# Pick the first tile file name as a sample and display it.
fname = image_tiles_fnames[0]
fname

In [None]:
# Open the sample tile as a fastai Image and display it inline.
img = open_image(
    BASE_DIR / TILES_DIR / fname
)
img

### Load data

In [None]:
# Label classes for the weakly-supervised data: just LABELS.
# (Alternative kept for reference: codes = LABELS+[RED, BLACK])
codes = LABELS

# Training size = sample tile's spatial size (shape[1:], presumably (H, W))
# scaled by IMG_SIZE_RATIO and truncated to ints.
src_size = np.asarray(img.shape[1:])
size = np.multiply(src_size, IMG_SIZE_RATIO).astype(int)

# Weakly-labelled data pipeline:
#   - tiles are kept if util's is_in_set_n2_or_nvalidation accepts their file name
#     and has_a_valid_color(tile) is true;
#   - tiles accepted by is_in_set_nvalidation form the validation split;
#   - each tile's target is produced by get_y_colors.
item_list = (
    ImageList.from_folder(BASE_DIR / TILES_DIR)
    .filter_by_func(
        partial(is_in_set_n2_or_nvalidation,
                regex_obj=REGEX_IMG_FILE_NAME_WITH_LABEL_VECTOR)
    )
    .filter_by_func(has_a_valid_color)
    .split_by_valid_func(
        partial(is_in_set_nvalidation,
                regex_obj=REGEX_IMG_FILE_NAME_WITH_LABEL_VECTOR)
    )
    .label_from_func(get_y_colors, classes=codes)
    .transform(get_transforms(), size=size)
)

In [None]:
# Display the train/valid LabelLists built above.
item_list

In [None]:
# DEBUG: inspect the target of the first training item (produced by get_y_colors).
item_list.train.y[0]

In [None]:
# item_list.train.items[0]
# Number of classes fastai inferred for the training labels; compare with the
# hard-coded data.c assigned after the DataBunch is created.
item_list.train.c

In [None]:
bs = 64  # batch size
data = item_list.databunch(bs=bs)
data = data.normalize(imagenet_stats)  # normalize inputs with ImageNet channel stats
# NOTE(review): output-channel count is hard-coded to 5 rather than derived
# from the labels (item_list.train.c) or len(CODES); confirm it matches what
# WeakCrossEntropy(CODES) expects.
data.c = 5

In [None]:
# First (up to) 10 class names and the total number of classes.
data.classes[:10], len(data.classes)

In [None]:
# Summary of the DataBunch.
data

In [None]:
# Preview two rows of training samples.
data.show_batch(rows=2, figsize=(10, 7))

In [None]:
# Preview two rows of validation samples.
data.show_batch(rows=2, ds_type=DatasetType.Valid, figsize=(10, 7))

In [None]:
# item_list

# all tiles: 4497
# 935 / 4497 ~ 20.8%
# 390 / 4497 ~ 8.7%

# This seems to be the desired split.

## Model

Train and compare semantic segmentation networks using the following data: Task (i): N1 pixel-level labels.

In [None]:
# Weight decay for the weakly-supervised learner.
# TODO: also try wd = 0.1
wd = 1e-2

In [None]:
# U-Net with a ResNet-18 encoder, trained with the custom WeakCrossEntropy loss
# over CODES and evaluated with acc_weakly; checkpoints go to MODEL_DIR.
learn = unet_learner(
    data,
    models.resnet18,
    metrics=[acc_weakly],
    loss_func=WeakCrossEntropy(CODES, axis=1),
    model_dir=MODEL_DIR,
    wd=wd,
)

In [None]:
lr = 0.001  # maximum learning rate for the one-cycle schedule

In [None]:
# One epoch of one-cycle training; the LR ramps up for 90% of the cycle.
learn.fit_one_cycle(cyc_len=1, max_lr=slice(lr), pct_start=0.9)

In [None]:
# Loss and metric curves, skipping the first 50 recorded iterations, with a grid.
learn.recorder.plot_losses(show_grid=True, skip_start=50)
learn.recorder.plot_metrics(show_grid=True, skip_start=50)

In [None]:
# Checkpoint the weakly-supervised model as 'ws-stage-1' in the learner's model_dir (MODEL_DIR).
learn.save('ws-stage-1')

In [None]:
# Reload the checkpoint so later cells can run without retraining; ';' suppresses the cell output.
learn.load('ws-stage-1');

## Show Results

In [None]:
# Get all predictions
# Get all predictions
# Run the model over the validation set; unpacks into (predictions, targets).
pred, y = learn.get_preds(ds_type=DatasetType.Valid)

In [None]:
# Visual check of validation sample 0 (util helper).
show_prediction_vs_actual(0, pred, y)

In [None]:
# Visual check of validation sample 1 (util helper).
show_prediction_vs_actual(1, pred, y)

In [None]:
# Visual check of validation sample 2 (util helper).
show_prediction_vs_actual(2, pred, y)

## Calculate acc_satellite

In [None]:
# TODO: load the N_validation pixel-level labels


# TODO: predict for the N_validation tiles (unfinished note)

In [None]:
from config import GT_ADJ_TILES_DIR, IMAGE_DATA_TILES_DIR
# Directory of ground-truth mask tiles (from config); joined onto BASE_DIR by get_y_fn.
gt_tiles_dir = GT_ADJ_TILES_DIR

def get_y_fn(x):
    """Return the ground-truth mask path for the image tile path `x`.

    The mask shares the tile's file name and lives in ``BASE_DIR / gt_tiles_dir``.
    """
    mask_dir = BASE_DIR / gt_tiles_dir
    return mask_dir / x.name

In [None]:
# Probe one ground-truth mask to get the full-resolution tile size. Masks are
# looked up by the image tile's file name (same convention as get_y_fn).
# The listing is sorted so the probed tile, and hence src_size, is reproducible:
# os.listdir returns entries in arbitrary, filesystem-dependent order.
image_tiles_fnames = sorted(os.listdir(BASE_DIR / IMAGE_DATA_TILES_DIR))
fname = image_tiles_fnames[0]
mask = open_mask(BASE_DIR / gt_tiles_dir / fname)
src_size = np.array(mask.shape[1:])

In [None]:
# File-name pattern of the image tiles, e.g. "top_mosaic_09cm_area30_tile120.tif".
# Captures the mosaic area id and the tile id. The "." before "tif" is escaped
# so it matches a literal dot only (unescaped it matched any character).
base_path = "top_mosaic_09cm_area"
prog = re.compile(fr"{base_path}(?P<area_id>\d+)_tile(?P<tile_id>\d+)\.tif")

def is_in_set(x, N):
    """Return True if tile `x` was cut from one of the mosaic images listed in `N`.

    `x` is a path-like object whose ``.name`` follows the tile pattern matched
    by ``prog``, e.g. "top_mosaic_09cm_area30_tile120.tif". The source image
    name is rebuilt from the area id ("top_mosaic_09cm_area30.tif") and looked
    up in `N`. A file whose name does not match the tile pattern belongs to no
    set, so the result is False instead of an AttributeError on a failed match.
    """
    match_result = prog.search(x.name)
    if match_result is None:
        return False
    image_fname = f"{base_path}{match_result.group('area_id')}.tif"  # e.g.: top_mosaic_09cm_area30.tif
    return image_fname in N

# Split-membership predicates for the fully-supervised pipeline.
# NOTE(review): is_in_set_nvalidation and is_in_set_n1_or_nvalidation rebind
# the util helpers imported at the top of the notebook, which are called with a
# regex_obj keyword in the weakly-supervised data cell. After this cell runs,
# re-running that earlier cell will raise a TypeError (is_in_set takes no
# regex_obj). Consider giving these distinct names.
is_in_set_n1 = partial(is_in_set, N=N1)
is_in_set_n2 = partial(is_in_set, N=N2)
is_in_set_nvalidation = partial(is_in_set, N=N_validation)
is_in_set_n1_or_nvalidation = partial(is_in_set, N=N1+N_validation)

# Pixel-level classes: LABELS plus the extra RED and BLACK codes.
codes = LABELS+[RED, BLACK]

# Full-resolution input: spatial size of the probed mask tile
# (mask.shape[1:], presumably (H, W)). The former no-op expression
# `src_size,mask.data` was removed: it was not the last line of the cell,
# so the notebook never displayed it.
src_size = np.array(mask.shape[1:])
# size = src_size // 2  # TODO: try training at half resolution
size = src_size

# Fully-supervised data: image tiles from N1 + N_validation paired with their
# ground-truth masks (located by get_y_fn). Tiles in N_validation form the
# validation split; tfm_y=True applies the same transforms to the masks.
fs_item_list = (
    SegmentationItemList.from_folder(BASE_DIR / IMAGE_DATA_TILES_DIR)  # SegmentationItemList
    .filter_by_func(is_in_set_n1_or_nvalidation)                        # SegmentationItemList
    .split_by_valid_func(is_in_set_nvalidation)                         # ItemLists (train, valid)
    .label_from_func(get_y_fn, classes=codes)                           # LabelLists (train, valid)
    .transform(get_transforms(), size=size, tfm_y=True)
)

In [None]:
bs = 64  # batch size
fs_data = fs_item_list.databunch(bs=bs)
fs_data = fs_data.normalize(imagenet_stats)  # normalize inputs with ImageNet channel stats

wd = 1e-1  # weight decay for the fully-supervised learner

In [None]:
# Fully-supervised U-Net (ResNet-18 encoder), evaluated with acc_satellite.
# Uses the same MODEL_DIR the weakly-supervised learner saved 'ws-stage-1' to,
# instead of a hard-coded, machine-specific absolute path.
# NOTE(review): fastai resolves model_dir against each learner's data path
# unless it is absolute. learn and fs_learn read different data folders, so
# confirm MODEL_DIR is absolute for both to find the same checkpoint.
fs_learn = unet_learner(fs_data,
                        models.resnet18,
                        metrics=[acc_satellite],
                        wd=wd,
                        model_dir=MODEL_DIR,
                       )

In [None]:
# Initialise the fully-supervised learner from the weakly-supervised checkpoint.
# NOTE(review): learn was built with data.c = 5 output channels, while fs_data
# presumably has len(codes) classes; if these differ, the final-layer weights
# will not match and load() will fail. Confirm.
fs_learn.load('ws-stage-1');