# Weakly-supervised Semantic Segmentation

In [None]:
import sys
import os
sys.path.append("../") 

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import os
from pathlib import Path
import re

from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *

In [None]:
from config import (IMAGE_DATA_DIR, TILES_DIR, 
                    LABELS, RED, BLACK, N1, N2, N_validation, CODES
                   )
from loss_custom import WeakCrossEntropy
from metrics_custom import acc_weakly, acc_satellite
from util import set_seed
set_seed(seed=42)

In [None]:
# Parent of the current working directory (presumably the repo root, matching
# the "../" added to sys.path above).
BASE_DIR = Path('').absolute().parent; BASE_DIR

In [None]:
# Filenames of all weak-label tiles.
image_tiles_fnames = os.listdir(BASE_DIR / TILES_DIR)

In [None]:
# Sample tile for inspection; os.listdir order is arbitrary, so which tile this
# is may differ between machines.
fname = image_tiles_fnames[0]; fname

In [None]:
img = open_image(BASE_DIR / TILES_DIR / fname)
img

In [None]:
# Tile size without the leading (presumably channel) dimension — assumed to be
# the same for every tile.
src_size = np.array(img.shape[1:]); src_size

### Load data

In [None]:
base_path = "top_mosaic_09cm_area"
# Weak-label tile filenames end with a 5-digit label vector, e.g.
# top_mosaic_09cm_area27_tile154_11100.tif. The dot before "tif" is escaped so
# that it only matches a literal "." (an unescaped "." matches any character).
prog_label_vector = re.compile(base_path + r"\d+_tile\d+_(?P<label_vector>\d{5})\.tif")

def get_y_fn(x):
    """Return the class labels encoded in a weak-label tile filename.

    The filename (``x.name``) must contain a 5-digit label vector as matched by
    ``prog_label_vector``, e.g. ``top_mosaic_09cm_area27_tile154_11100.tif``.
    Every digit position whose value is 1 selects the ``LABELS`` entry with the
    same index.

    Args:
        x: path-like object exposing ``.name`` (the tile file).

    Returns:
        list of the selected ``LABELS`` entries, in index order.

    Raises:
        ValueError: if the filename does not contain a label vector.
        AssertionError: if no label is set (only while asserts are enabled).
    """
    fname = x.name
    match_result = prog_label_vector.search(fname)
    if match_result is None:
        # Fail with the offending filename instead of "NoneType has no attribute 'group'".
        raise ValueError(f"no label vector found in tile filename: {fname!r}")
    label_vector = match_result.group('label_vector')
    label_vector_arr = torch.tensor(list(map(int, label_vector)))

    indexes = torch.where(label_vector_arr == 1)[0]
    colors = [LABELS[idx] for idx in indexes]

    # The label vector has 5 digits, so in practice this checks that at least
    # one label is set; has_a_valid_color filters such tiles out beforehand.
    assert 0 < len(indexes) < 6, (len(indexes), x)

    return colors

def has_a_valid_color(x):
    """Tell whether a weak-label tile's label vector marks between one and five labels.

    Used as an ItemList filter so that tiles which ``get_y_fn`` would reject
    are dropped before labelling. The filename must match
    ``prog_label_vector``.
    """
    label_digits = prog_label_vector.search(x.name).group('label_vector')
    n_set = sum(1 for digit in label_digits if int(digit) == 1)
    return 0 < n_set < 6


In [None]:
# # # DEBUG
# # # Example: top_mosaic_09cm_area27_tile154_11100.tif
# fpath = BASE_DIR / TILES_DIR / fname
# result = get_y_fn(fpath)
# type(result), result

In [None]:
# Free GPU memory, via the fastai.utils.mem helper star-imported above.
free = gpu_mem_get_free_no_cache(); free

In [None]:
base_path = "top_mosaic_09cm_area"
# Captures area id, tile id and the 5-digit label vector from weak-label tile
# filenames such as top_mosaic_09cm_area27_tile154_11100.tif. The dot before
# "tif" is escaped so it only matches a literal ".".
prog_with_label_vector = re.compile(base_path + r"(?P<area_id>\d+)_tile(?P<tile_id>\d+)_(?P<label_vector>\d{5})\.tif")

def is_in_set(x, N):
    """Return True if weak-label tile ``x`` was cut from a source image listed in ``N``.

    The area id is parsed from the tile filename with ``prog_with_label_vector``,
    e.g. ``top_mosaic_09cm_area27_tile154_11100.tif`` -> source image
    ``top_mosaic_09cm_area27.tif``, which is then looked up in ``N``.

    Args:
        x: path-like object exposing ``.name``.
        N: collection of source-image filenames (e.g. N1, N2, N_validation).

    Raises:
        ValueError: if the filename does not match the weak-label tile pattern.
    """
    fname = x.name
    match_result = prog_with_label_vector.search(fname)
    if match_result is None:
        # Name the offending file instead of failing on None.group(...).
        raise ValueError(f"unexpected tile filename: {fname!r}")
    area_id = match_result.group('area_id')
    image_fname = f"{base_path}{area_id}.tif"
    return image_fname in N

# Filename filters: keep tiles cut from the source images of the given split(s).
is_in_set_n1 = partial(is_in_set, N=N1)
is_in_set_n2 = partial(is_in_set, N=N2)
is_in_set_nvalidation = partial(is_in_set, N=N_validation)
is_in_set_n1_or_nvalidation = partial(is_in_set, N=N1+N_validation)
is_in_set_n2_or_nvalidation = partial(is_in_set, N=N2+N_validation)

# NOTE(review): recomputes the same value as the earlier src_size cell.
src_size = np.array(img.shape[1:])
# NOTE(review): not the last line of the cell, so IPython does not display this
# tuple — looks like leftover inspection code.
src_size,img.data
# Train at half the source tile resolution.
size = src_size // 2  # TODO

codes = LABELS
# codes = LABELS+[RED, BLACK]

# Weakly supervised data: training tiles from N2, validation tiles from
# N_validation; tiles whose label vector sets no class are dropped before
# labelling with get_y_fn.
item_list = (ImageList.from_folder(BASE_DIR / TILES_DIR)  #returns ImageList
             .filter_by_func(is_in_set_n2_or_nvalidation)  #returns ImageList
             .filter_by_func(has_a_valid_color)            #returns ImageList
             .split_by_valid_func(is_in_set_nvalidation)  #returns ItemLists(ImageList, ImageList)
             .label_from_func(get_y_fn, classes=codes)  #returns LabelLists(ImageList, MultiCategoryList)
             .transform(get_transforms(), size=size)
            )

In [None]:
item_list

In [None]:
# DEBUG
item_list.train.y[0]

In [None]:
# item_list.train.items[0]
item_list.train.c

In [None]:
bs = 64
data = item_list.databunch(bs=bs).normalize(imagenet_stats)
# NOTE(review): hard-coded class count; presumably equal to len(codes) (the
# label vector has 5 digits) — consider len(codes) so the two stay in sync.
data.c = 5

In [None]:
data.classes[:10], len(data.classes)

In [None]:
data

In [None]:
data.show_batch(2, figsize=(10,7))

In [None]:
data.show_batch(2, figsize=(10,7), ds_type=DatasetType.Valid)

In [None]:
# item_list

# all tiles: 4497
# 935 / 4497 ~ 20.8%
# 390 / 4497 ~ 8.7%

# This seems to be the desired split

## Model

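Train and compare semantic segmentation networks using the following data. Task (i): N1 pixel-level labels. (Note: the learner below is trained with image-level (weak) labels on N2 tiles and validated on N_validation tiles.)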

In [None]:
# wd=0.1 # TODO
wd=0.01

In [None]:
# U-Net with a ResNet-18 encoder, trained with the project's WeakCrossEntropy
# loss. NOTE(review): the loss receives CODES from config while the labels use
# codes = LABELS — confirm they describe the same classes.
learn = unet_learner(data, 
                     models.resnet18, 
                     loss_func=WeakCrossEntropy(CODES, axis=1),
                     metrics=[acc_weakly], 
                     wd=wd,
                     # NOTE(review): absolute, machine-specific path
                     model_dir='/home/jupyter/weakly-supervised-semseg/models'
                    )

In [None]:
# lr_find(learn)
# learn.recorder.plot()

In [None]:
lr=1e-3

In [None]:
# One epoch of the one-cycle schedule; pct_start=0.9 spends 90% of the
# iterations in the increasing-LR phase.
learn.fit_one_cycle(1, slice(lr), pct_start=0.9)

In [None]:
learn.recorder.plot_losses(skip_start=50, show_grid=True)
learn.recorder.plot_metrics(skip_start=50, show_grid=True)

In [None]:
learn.save('ws-stage-1')

In [None]:
learn.load('ws-stage-1');

## Show Results

In [None]:
# learn.show_results(rows=3, figsize=(8,9)) # throws RuntimeError: bool value of Tensor with more than one value is ambiguous

In [None]:
# item = data.valid_ds.x[0]
# pred = learn.predict(item) # throws RuntimeError: bool value of Tensor with more than one value is ambiguous

In [None]:
# Workaround for the failing show_results/predict above: fetch raw
# predictions and targets for the whole validation set.
pred, y = learn.get_preds(ds_type=DatasetType.Valid)  
pred.shape # shape (935, 5, 100, 100)

In [None]:
sample_idx = 0
# Argmax over the class dimension turns per-class scores into a label map.
sample_pred = pred[sample_idx].argmax(dim=0, keepdim=True)  # shape (1, 100, 100)
pred_image_segment = ImageSegment(sample_pred)
print(y[sample_idx])
pred_image_segment

In [None]:
# Weak (image-level) label and input tile for the same validation sample.
print(data.valid_ds.y[sample_idx].__repr__())
data.valid_ds.x[sample_idx]

In [None]:
sample_idx = 1
sample_pred = pred[sample_idx].argmax(dim=0, keepdim=True)  # shape (1, 100, 100)
pred_image_segment = ImageSegment(sample_pred)
print(y[sample_idx])
pred_image_segment

In [None]:
print(data.valid_ds.y[sample_idx].__repr__())
data.valid_ds.x[sample_idx]

In [None]:
sample_idx = 2
sample_pred = pred[sample_idx].argmax(dim=0, keepdim=True)  # shape (1, 100, 100)
pred_image_segment = ImageSegment(sample_pred)
print(y[sample_idx])
pred_image_segment

In [None]:
print(data.valid_ds.y[sample_idx].__repr__())
data.valid_ds.x[sample_idx]

## Calculate acc_satellite

In [None]:
# load N_validation pixel-level labels


# predict for N_validation (TODO: this note was left unfinished)

In [None]:
from config import GT_ADJ_TILES_DIR, IMAGE_DATA_TILES_DIR
gt_tiles_dir = GT_ADJ_TILES_DIR

def get_y_fn(x):
    """Map an image tile to its ground-truth mask: same filename, under the GT tiles directory."""
    return BASE_DIR.joinpath(gt_tiles_dir, x.name)

In [None]:
# Sample full-label tile: its GT mask has the same filename in gt_tiles_dir.
image_tiles_fnames = os.listdir(BASE_DIR / IMAGE_DATA_TILES_DIR)
fname = image_tiles_fnames[0]
mask = open_mask(BASE_DIR / gt_tiles_dir / fname)
src_size = np.array(mask.shape[1:])

In [None]:
base_path = "top_mosaic_09cm_area"
# Full-label tile filenames carry no label vector, e.g.
# top_mosaic_09cm_area30_tile120.tif. The dot before "tif" is escaped so it only
# matches a literal ".".
prog = re.compile(fr"{base_path}(?P<area_id>\d+)_tile(?P<tile_id>\d+)\.tif")

def is_in_set(x, N):
    """Return True if full-label tile ``x`` was cut from a source image listed in ``N``.

    The area id is parsed from the tile filename with ``prog``, e.g.
    ``top_mosaic_09cm_area30_tile120.tif`` -> source image
    ``top_mosaic_09cm_area30.tif``, which is then looked up in ``N``.

    Args:
        x: path-like object exposing ``.name``.
        N: collection of source-image filenames (e.g. N1, N2, N_validation).

    Raises:
        ValueError: if the filename does not match the tile pattern.
    """
    fname = x.name
    match_result = prog.search(fname)
    if match_result is None:
        # Name the offending file instead of failing on None.group(...).
        raise ValueError(f"unexpected tile filename: {fname!r}")
    area_id = match_result.group('area_id')
    image_fname = f"{base_path}{area_id}.tif"
    return image_fname in N

# Rebuilt on purpose: partial() captured the previous is_in_set function
# object, so the filters must be recreated to use the redefinition above.
is_in_set_n1 = partial(is_in_set, N=N1)
is_in_set_n2 = partial(is_in_set, N=N2)
is_in_set_nvalidation = partial(is_in_set, N=N_validation)
is_in_set_n1_or_nvalidation = partial(is_in_set, N=N1+N_validation)

# Pixel-level classes: the weak-label classes plus RED and BLACK from config.
codes = LABELS+[RED, BLACK]

# NOTE(review): repeats the src_size computed in the previous cell.
src_size = np.array(mask.shape[1:])
# NOTE(review): not the last line of the cell, so nothing is displayed.
src_size,mask.data
# size = src_size // 2  # TODO
# Full tile resolution here (the weak learner above used src_size // 2).
size = src_size

# Fully supervised data: training tiles from N1, validation tiles from
# N_validation, masks located by get_y_fn.
fs_item_list = (SegmentationItemList.from_folder(BASE_DIR / IMAGE_DATA_TILES_DIR)  #returns SegmentationItemList
             .filter_by_func(is_in_set_n1_or_nvalidation)  #returns SegmentationItemList
             .split_by_valid_func(is_in_set_nvalidation)  #returns ItemLists(SegmentationItemList, SegmentationItemList)
             .label_from_func(get_y_fn, classes=codes)  #returns LabelLists(LabelList, SegmentationItemList)
             .transform(get_transforms(), size=size, tfm_y=True)
            )

In [None]:
bs = 64
fs_data = fs_item_list.databunch(bs=bs).normalize(imagenet_stats)

# NOTE(review): differs from the wd=0.01 used for the weakly supervised learner.
wd=0.1

In [None]:
fs_learn = unet_learner(fs_data, 
                        models.resnet18, 
                        metrics=acc_satellite, 
                        wd=wd, 
                        model_dir='/home/jupyter/weakly-supervised-semseg/models'
                       )

In [None]:
# NOTE(review): 'ws-stage-1' was saved from a learner whose data had c = 5,
# while fs_data is labelled with codes = LABELS + [RED, BLACK]. If
# unet_learner sizes its final layer from data.c, this load fails with a
# weight-shape mismatch — confirm, and build fs_learn with a matching output size.
fs_learn.load('ws-stage-1');