In [None]:
import os
import sys

# Put the parent directory on the import path -- presumably so the project modules
# imported below (config, util, loss_custom, metrics_custom, parameters) resolve.
sys.path.append("../")

In [None]:
# Reload edited project modules automatically (autoreload mode 2).
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook.
%matplotlib inline

In [None]:
from fastai.vision import *
from fastai.callbacks.hooks import *
from fastai.utils.mem import *

In [None]:
from config import (IMAGE_DATA_DIR, GT_DIR, IMAGE_DATA_TILES_DIR, GT_TILES_DIR,
                    GT_ADJ_TILES_DIR, TILES_DIR,
                    LABELS, RED, BLACK, N1, N2, N_validation,
                    CODES,
                    BASE_DIR, MODEL_DIR
                   )
from loss_custom import WeakCrossEntropy
from metrics_custom import acc_satellite, acc_weakly
from parameters import IMG_SIZE_RATIO, BATCH_SIZE, WEIGHT_DECAY, LEARNING_RATE_FS, LEARNING_RATE_WS
# Each util name is imported exactly once (the previous list repeated two predicates).
from util import (set_seed, IMG_FILE_PREFIX,
                  REGEX_IMG_FILE_NAME, REGEX_IMG_FILE_NAME_WITH_LABEL_VECTOR,
                  is_in_set_n1_or_nvalidation, is_in_set_n2_or_nvalidation, is_in_set_nvalidation,
                  get_y_colors, has_a_valid_color, get_y_fn
                 )

# Fix the random seed via util.set_seed so runs are repeatable.
set_seed(seed=42)

# Show the free GPU memory reported by fastai before any model is built.
free = gpu_mem_get_free_no_cache()
free

# Fully-supervised (FS) Semantic Segmentation

In [None]:
# Image tiles for the fully-supervised task. os.listdir returns entries in arbitrary
# order, so sort them to make the example tile below the same on every run.
image_tiles_fnames = sorted(os.listdir(BASE_DIR / IMAGE_DATA_TILES_DIR))
fname = image_tiles_fnames[0]
fname

In [None]:
# Preview the example image tile.
fs_tile_path = BASE_DIR / IMAGE_DATA_TILES_DIR / fname
open_image(fs_tile_path)

In [None]:
# Preview the ground-truth tile with the same file name (from GT_ADJ_TILES_DIR) as a plain image.
gt_tile_path = BASE_DIR / GT_ADJ_TILES_DIR / fname
open_image(gt_tile_path)

In [None]:
# Load that same ground-truth tile as a segmentation mask; its shape is used to size the pipeline.
mask_path = BASE_DIR / GT_ADJ_TILES_DIR / fname
mask = open_mask(mask_path)
mask

### Load data

In [None]:
# Class names: LABELS followed by RED and BLACK.
codes = LABELS + [RED, BLACK]

# Target size: the mask shape without its leading axis (presumably height, width),
# scaled by IMG_SIZE_RATIO.
src_size = np.array(mask.shape[1:])
size = (src_size * IMG_SIZE_RATIO).astype(int)

# Keep only N1 + validation tiles, split the validation tiles off, label every image via
# get_y_fn, and apply the same augmentations to the masks (tfm_y=True).
keep_n1_or_valid = partial(is_in_set_n1_or_nvalidation, regex_obj=REGEX_IMG_FILE_NAME)
is_valid_fs_tile = partial(is_in_set_nvalidation, regex_obj=REGEX_IMG_FILE_NAME)

fs_item_list = (
    SegmentationItemList.from_folder(BASE_DIR / IMAGE_DATA_TILES_DIR)
    .filter_by_func(keep_n1_or_valid)
    .split_by_valid_func(is_valid_fs_tile)
    .label_from_func(get_y_fn, classes=codes)
    .transform(get_transforms(), size=size, tfm_y=True)
)

In [None]:
# Batch the FS item lists and normalise with ImageNet statistics (the encoder is pretrained).
fs_data = (fs_item_list
           .databunch(bs=BATCH_SIZE)
           .normalize(imagenet_stats))

## Model

Train a semantic segmentation network (U-Net with a ResNet-18 encoder) for Task (i), using the N1 tiles with pixel-level labels.

In [None]:
# Fully-supervised learner: U-Net with a ResNet-18 encoder; checkpoints go under MODEL_DIR.
fs_learn = unet_learner(fs_data,
                        models.resnet18,
                        metrics=acc_satellite,
                        wd=WEIGHT_DECAY,
                        model_dir=MODEL_DIR,
                       )

In [None]:
# Stage 1: fully-supervised training on the N1 pixel-level labels. This cell always runs;
# the next cell saves the result as 'mixed-stage-1', which every later stage loads.
FS_STAGE1_EPOCHS = 30
fs_learn.fit_one_cycle(FS_STAGE1_EPOCHS, slice(LEARNING_RATE_FS), pct_start=0.9)

In [None]:
# Checkpoint the stage-1 weights; the WS learner and the alternating loop start from them.
fs_learn.save('mixed-stage-1')

# Weakly-supervised (WS) Semantic Segmentation

In [None]:
# Weakly-labelled tiles; sorted because os.listdir order is arbitrary, keeping the example tile stable.
image_tiles_fnames = sorted(os.listdir(BASE_DIR / TILES_DIR))

In [None]:
# First weakly-labelled tile, used as the running example below.
fname = image_tiles_fnames[0]
fname

In [None]:
# Open and display the example weakly-labelled tile.
ws_tile_path = BASE_DIR / TILES_DIR / fname
img = open_image(ws_tile_path)
img

In [None]:
# Size of the example tile: its shape without the leading axis.
src_size = np.array(img.shape[1:])
src_size

### Load data

In [None]:
import re

# Pattern for weakly-labelled tile names, e.g. "top_mosaic_09cm_area1_tile3_01101.tif",
# capturing the 5-digit label vector. The '.' before "tif" is escaped so it matches a
# literal dot only. NOTE(review): the pipeline below uses
# util.REGEX_IMG_FILE_NAME_WITH_LABEL_VECTOR; this pattern is not used in this notebook.
base_path = "top_mosaic_09cm_area"
prog_label_vector = re.compile(re.escape(base_path) + r"\d+_tile\d+_(?P<label_vector>\d{5})\.tif")

In [None]:
# Re-check free GPU memory before building the weakly-supervised pipeline.
free = gpu_mem_get_free_no_cache()
free

In [None]:
# Target size from the example WS tile, scaled by IMG_SIZE_RATIO.
src_size = np.array(img.shape[1:])
size = (src_size * IMG_SIZE_RATIO).astype(int)

codes = LABELS + [RED, BLACK]

# N2 + validation tiles (file names carry a label vector); tiles rejected by
# has_a_valid_color are dropped. Labels come from get_y_colors as a MultiCategoryList,
# so only the images are transformed.
keep_n2_or_valid = partial(is_in_set_n2_or_nvalidation, regex_obj=REGEX_IMG_FILE_NAME_WITH_LABEL_VECTOR)
is_valid_ws_tile = partial(is_in_set_nvalidation, regex_obj=REGEX_IMG_FILE_NAME_WITH_LABEL_VECTOR)

ws_item_list = (
    ImageList.from_folder(BASE_DIR / TILES_DIR)       # ImageList
    .filter_by_func(keep_n2_or_valid)                 # ImageList
    .filter_by_func(has_a_valid_color)                # ImageList
    .split_by_valid_func(is_valid_ws_tile)            # ItemLists(ImageList, ImageList)
    .label_from_func(get_y_colors, classes=codes)     # LabelLists(ImageList, MultiCategoryList)
    .transform(get_transforms(), size=size)
)

In [None]:
# Batch the WS item lists and normalise with ImageNet statistics.
ws_data = (ws_item_list
           .databunch(bs=BATCH_SIZE)
           .normalize(imagenet_stats))

## Model

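Train the weakly-supervised network, initialized from the fully-supervised stage-1 weights, on the N2 tiles with image-level (multi-label) annotations. Afterwards, the fully- and weakly-supervised learners are fine-tuned in turn and compared on their validation sets.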

In [None]:
# Weakly-supervised learner: same U-Net/ResNet-18 as fs_learn, trained with WeakCrossEntropy
# and initialised from the fully-supervised stage-1 weights.
# https://forums.fast.ai/t/transfer-learning-twice/43699/5
ws_loss = WeakCrossEntropy(CODES, axis=1)
ws_learn = unet_learner(
    ws_data,
    models.resnet18,
    loss_func=ws_loss,
    metrics=acc_weakly,
    wd=WEIGHT_DECAY,
    model_dir=MODEL_DIR,
)
ws_learn.load('mixed-stage-1');

In [None]:
# Stage 2: weakly-supervised fine-tuning at one tenth of LEARNING_RATE_WS.
ws_stage2_lr = slice(LEARNING_RATE_WS / 10)
ws_learn.fit_one_cycle(20, ws_stage2_lr, pct_start=0.9)

In [None]:
# Stage-2 loss and metric curves, skipping the first 50 iterations.
for plot_fn in (ws_learn.recorder.plot_losses, ws_learn.recorder.plot_metrics):
    plot_fn(skip_start=50, show_grid=True)

In [None]:
# Checkpoint the weights after weakly-supervised stage 2.
stage2_checkpoint = 'mixed-stage-2'
ws_learn.save(stage2_checkpoint)

In [None]:
# Optional: reload the stage-2 checkpoint without re-training (the next cell reloads it anyway).
# ws_learn.load('mixed-stage-2');

In [None]:
# Weak-label accuracy of the stage-2 checkpoint on the WS validation set.
ws_learn.load('mixed-stage-2');
ws_val_results = ws_learn.validate(
    dl=ws_data.valid_dl,
    callbacks=None,
    metrics=[acc_weakly],
)
loss, accuracy_weakly = ws_val_results
accuracy_weakly

In [None]:
# Pixel accuracy of the stage-1 (fully-supervised) checkpoint on the FS validation set.
fs_learn.load('mixed-stage-1');
fs_val_results = fs_learn.validate(
    dl=fs_data.valid_dl,
    callbacks=None,
    metrics=[acc_satellite],
)
loss, accuracy_satellite = fs_val_results
accuracy_satellite

In [None]:
# Load the weakly-supervised stage-2 weights into the FS learner (same architecture)
# and measure their pixel accuracy on the FS validation set.
fs_learn.load('mixed-stage-2');
fs_val_results = fs_learn.validate(
    dl=fs_data.valid_dl,
    callbacks=None,
    metrics=[acc_satellite],
)
loss, accuracy_satellite = fs_val_results
accuracy_satellite

In [None]:
# Stage 3: fully-supervised fine-tuning starting from the stage-2 WS weights.
fs_stage3_lr = slice(LEARNING_RATE_FS)
fs_learn.fit_one_cycle(20, fs_stage3_lr, pct_start=0.9)

In [None]:
# Checkpoint the weights after fully-supervised stage 3.
stage3_checkpoint = 'mixed-stage-3'
fs_learn.save(stage3_checkpoint);

In [None]:
# Stage-3 loss and metric curves, skipping the first 50 iterations.
for plot_fn in (fs_learn.recorder.plot_losses, fs_learn.recorder.plot_metrics):
    plot_fn(skip_start=50, show_grid=True)

In [None]:
# Weak-label accuracy of the stage-3 (FS-fine-tuned) weights on the WS validation set.
ws_learn.load('mixed-stage-3');
ws_val_results = ws_learn.validate(
    dl=ws_data.valid_dl,
    callbacks=None,
    metrics=[acc_weakly],
)
loss, accuracy_weakly = ws_val_results
accuracy_weakly

In [None]:
# Learning rates for the alternating loop, matching the dedicated WS and FS stages above.
ws_lr_slice, fs_lr_slice = slice(LEARNING_RATE_WS / 10), slice(LEARNING_RATE_FS)

In [None]:
# Alternate one-epoch cycles of WS and FS training, handing the weights over through a
# shared checkpoint that starts as the FS stage-1 weights. Ten rounds starting with WS
# means the FS learner trains in the final round.
last_save_location = 'mixed-stage-loop'
should_train_supervised = False  # flipped after every round

ws_learn.load('mixed-stage-1');
ws_learn.save(last_save_location);

for i in range(10):
    print(i, "should_train_supervised: ", should_train_supervised)
    if should_train_supervised:
        current_learn, current_lr = fs_learn, fs_lr_slice
    else:
        current_learn, current_lr = ws_learn, ws_lr_slice

    # Resume from whatever the other learner saved last round, train one cycle, hand back.
    current_learn.load(last_save_location);
    current_learn.fit_one_cycle(1, current_lr, pct_start=0.9)
    current_learn.save(last_save_location);

    should_train_supervised = not should_train_supervised

In [None]:
# Pixel accuracy of the weights produced by the alternating loop on the FS validation set.
fs_learn.load(last_save_location);
fs_val_results = fs_learn.validate(
    dl=fs_data.valid_dl,
    callbacks=None,
    metrics=[acc_satellite],
)
loss, accuracy_satellite = fs_val_results
accuracy_satellite

In [None]:
# Final fully-supervised fine-tuning after the alternating loop. Use slice(LEARNING_RATE_FS)
# like every other FS run in this notebook: fastai gives the last layer group that rate and
# earlier groups a tenth of it, whereas the bare float applied it to all groups.
fs_learn.fit_one_cycle(20, slice(LEARNING_RATE_FS), pct_start=0.9)

In [None]:
# Loss and metric curves of the final fine-tuning run.
for plot_fn in (fs_learn.recorder.plot_losses, fs_learn.recorder.plot_metrics):
    plot_fn(show_grid=True)