# Weakly-supervised Semantic Segmentation

In [None]:
import sys
import os
sys.path.append("../") 

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *

In [None]:
from constants import (IMAGE_DATA_DIR, TILES_DIR, 
                       ALL_CLASSES, N1, N2, N_validation, 
                       BASE_DIR, MODEL_DIR
                      )
from loss_custom import WeakCrossEntropy
from metrics_custom import acc_weakly, acc_satellite
from parameters import IMG_SIZE_RATIO, BATCH_SIZE, WEIGHT_DECAY, LEARNING_RATE_WS, BACKBONE
from util import (set_seed, is_in_set_n1_or_nvalidation, is_in_set_nvalidation, 
                  REGEX_IMG_FILE_NAME_WITH_LABEL_VECTOR, is_in_set_n2_or_nvalidation, 
                  get_y_colors, has_a_valid_color, show_prediction_vs_actual
                 )
set_seed(seed=42)
free = gpu_mem_get_free_no_cache(); free

In [None]:
image_tiles_fnames = os.listdir(BASE_DIR / TILES_DIR)
fname = image_tiles_fnames[0]; fname

In [None]:
img = open_image(BASE_DIR / TILES_DIR / fname)
img

### Load data

In [None]:
src_size = np.array(img.shape[1:])
size = (src_size * IMG_SIZE_RATIO).astype(int)

item_list = (ImageList.from_folder(BASE_DIR / TILES_DIR)
             .filter_by_func(partial(is_in_set_n2_or_nvalidation, regex_obj=REGEX_IMG_FILE_NAME_WITH_LABEL_VECTOR))
             .filter_by_func(has_a_valid_color)
             .split_by_valid_func(partial(is_in_set_nvalidation, regex_obj=REGEX_IMG_FILE_NAME_WITH_LABEL_VECTOR))
             .label_from_func(get_y_colors, classes=ALL_CLASSES)
             .transform(get_transforms(), size=size)
            )

In [None]:
item_list

In [None]:
# DEBUG
item_list.train.y[1]

In [None]:
# item_list.train.items[0]
item_list.train.c

In [None]:
data = item_list.databunch(bs=BATCH_SIZE).normalize(imagenet_stats)

In [None]:
data.classes[:10], len(data.classes)

In [None]:
data

In [None]:
data.show_batch(2, figsize=(10,7))

In [None]:
data.show_batch(2, figsize=(10,7), ds_type=DatasetType.Valid)

## Model

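Train and compare semantic segmentation networks. This notebook covers Task (ii): training on the N2 data set, which provides only image-level labels (no pixel-level masks). N_validation serves as the validation split.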

In [None]:
learn = unet_learner(data, 
                     BACKBONE, 
                     loss_func=WeakCrossEntropy(),
                     metrics=[acc_weakly], 
                     wd=WEIGHT_DECAY,
                     model_dir=MODEL_DIR
                    )

In [None]:
learn.fit_one_cycle(20, slice(LEARNING_RATE_WS), pct_start=0.9)

In [None]:
learn.recorder.plot_losses(skip_start=50, show_grid=True)
learn.recorder.plot_metrics(skip_start=50, show_grid=True)

In [None]:
learn.save('ws-stage-1')

In [None]:
learn.load('ws-stage-1');

## Show Results

Here we show the original image and the predicted mask. (The ground-truth mask is not part of the N2 dataset, so it was not loaded and is not shown here.)

In [None]:
# Sample 0: show_prediction_vs_actual is a project helper from util; which
# dataset split the index refers to is not visible here — TODO confirm.
show_prediction_vs_actual(0, learn)

In [None]:
# Sample 1.
show_prediction_vs_actual(1, learn)

In [None]:
# Sample 2.
show_prediction_vs_actual(2, learn)

In [None]:
# Sample 3.
show_prediction_vs_actual(3, learn)

In [None]:
# Sample 4.
show_prediction_vs_actual(4, learn)

## Calculate acc_satellite

In [None]:
from constants import (IMAGE_DATA_DIR, GT_DIR, IMAGE_DATA_TILES_DIR, GT_TILES_DIR)
from util import REGEX_IMG_FILE_NAME, get_y_fn

In [None]:
fs_item_list = (SegmentationItemList.from_folder(BASE_DIR / IMAGE_DATA_TILES_DIR)
                .filter_by_func(partial(is_in_set_n1_or_nvalidation, regex_obj=REGEX_IMG_FILE_NAME))
                .split_by_valid_func(partial(is_in_set_nvalidation, regex_obj=REGEX_IMG_FILE_NAME))
                .label_from_func(get_y_fn, classes=ALL_CLASSES)
                .transform(get_transforms(), size=size, tfm_y=True)
               )
fs_data = fs_item_list.databunch(bs=BATCH_SIZE).normalize(imagenet_stats)
fs_learn = unet_learner(fs_data, 
                        BACKBONE, 
                        metrics=acc_satellite, 
                        wd=WEIGHT_DECAY, 
                        model_dir=MODEL_DIR
                       )
fs_learn.load('ws-stage-1');

In [None]:
predictions, labels = fs_learn.get_preds(); (predictions.shape, predictions.shape)

In [None]:
acc_satellite(predictions, labels)