In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from fastai.vision import *
from fastai.metrics import dice
from fastai.vision.interpret import SegmentationInterpretation
from torch import cuda as cd
import torch
from torch.nn import functional as F
from functools import partial

# TODO: cleanup accuracy/metrics
def acc_camvid(inputs, target):
    """Pixel accuracy that ignores pixels labelled ``void_code``.

    Args:
        inputs: torch.Tensor of per-class scores; per the original notes its
            shape is [B, C, H, W] (batch, classes, height, width). The
            predicted class is the argmax over dim 1.
        target: torch.Tensor of class ids; a size-1 dim 1 is squeezed away
            before comparison.

    Returns:
        The mean, as a float tensor, of (prediction == label) over every
        pixel whose label differs from the module-level ``void_code``.
    """
    # DONE: write a generic accuracy function (not camvid)
    labels = target.squeeze(1)
    keep = labels != void_code
    predicted = inputs.argmax(dim=1)
    hits = predicted[keep] == labels[keep]
    return hits.float().mean()

# Return Jaccard index, or Intersection over Union (IoU) value
def jaccard_loss(input: torch.Tensor, targs: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Compute the mean soft Jaccard index (IoU) between logits and labels.

    Despite its name, this returns the IoU score itself: a value in [0, 1]
    where higher is better. Nothing is negated or subtracted from 1, so it
    can be used directly as a metric; to minimise it as a training loss use
    ``1 - jaccard_loss(...)``.

    Args:
        input: raw model output (logits) of shape [B, C, H, W]. With C == 1
            the logits go through a sigmoid and are scored as foreground
            vs. background; otherwise a softmax is taken over dim 1.
        targs: integer class labels of shape [B, H, W] or [B, 1, H, W].
        eps: added to the denominator for numerical stability.

    Returns:
        A 0-dim tensor: the per-class IoU (accumulated over the batch and
        the spatial axes) averaged over the classes.
    """
    num_classes = input.shape[1]
    # Drop the channel axis only when targs actually has one; a blind
    # squeeze(1) would also collapse H of a [B, 1, W] target.
    labels = targs.squeeze(1) if targs.ndimension() == input.ndimension() else targs
    # Index tensors must be integral and live on the same device as the
    # lookup table they index into.
    labels = labels.long().to(input.device)
    if num_classes == 1:
        # Row-flipped identity: label 1 -> channel 0 (foreground) and
        # label 0 -> channel 1 (background), matching [pos, neg] below.
        lookup = torch.eye(2, device=input.device).flip(0)
        pos_prob = torch.sigmoid(input)
        probas = torch.cat([pos_prob, 1 - pos_prob], dim=1)
    else:
        lookup = torch.eye(num_classes, device=input.device)
        probas = F.softmax(input, dim=1)
    # [B, H, W, K] -> [B, K, H, W], in the dtype of the probabilities.
    true_1_hot = lookup[labels].permute(0, 3, 1, 2).to(dtype=input.dtype)
    # Reduce over batch and spatial axes, keeping the class axis. The axes
    # are taken from probas so a [B, H, W] target reduces over W as well.
    dims = (0,) + tuple(range(2, probas.ndimension()))
    intersection = torch.sum(probas * true_1_hot, dims)
    cardinality = torch.sum(probas + true_1_hot, dims)
    union = cardinality - intersection
    return (intersection / (union + eps)).mean()

# Release any GPU memory PyTorch is still caching
cd.empty_cache()

# Training data: fetch CamVid and locate its image / label folders
path = untar_data(URLs.CAMVID)
path_img = path / 'images'
path_lbl = path / 'labels'

# Class names are read from codes.txt
# (write codes.txt by hand for desired classes)
codes = np.loadtxt(path / 'codes.txt', dtype=str)

# Map class name -> integer id (its position in codes);
# the 'Void' class is used for background
name2id = {name: idx for idx, name in enumerate(codes)}
void_code = name2id['Void']

# List the files found in the image and label folders
fnames = get_image_files(path_img)
lbl_names = get_image_files(path_lbl)

# Show an example image
img_f = fnames[10]
img = open_image(img_f)
img.show(figsize=(5, 5))

# function to get label path from image filename
def get_y_fn(x):
    """Return the label (mask) path that belongs to image path ``x``.

    The label file lives in ``path_lbl`` and is named after the image:
    the image's stem, then ``_P``, then the image's suffix.

    Written as a named ``def`` rather than a lambda bound to a name so the
    function has a real ``__name__`` in tracebacks and can be pickled by
    reference.
    """
    return path_lbl/f'{x.stem}_P{x.suffix}'

# show image mask
# (get_y_fn maps the example image path to its label file)
mask = open_mask(get_y_fn(img_f))
mask.show(figsize=(5,5), alpha=1)

# Training parameters
# mask.shape is [C,H,W], e.g. torch.Size([1,720,960]), so [1:] keeps (H, W)
src_size = np.array(mask.shape[1:])
# torch.Size([1,720,960]) [C,H,W]
# size = src_size // 2 # SMALL round: half of the source resolution
size = src_size        # BIG round: full source resolution
bs = 4 # batch size; SMALL=12 : BIG=4

# Setup Data
# GET images -> SPLIT valid/train -> LABEL classes, as one chained expression.
# Types along the chain, in order:
#   fastai.vision.data.SegmentationItemList :: ItemList
#   fastai.data_block.ItemLists
#   fastai.data_block.LabelLists
src = (
    SegmentationItemList.from_folder(path_img)
    .split_by_fname_file(str(path / 'valid.txt'))
    .label_from_func(get_y_fn, classes=codes)
)

# TODO: replace split_by_fname_file with split_by_rand_pct
# TODO: augment dataset with apply_tfms

# Transform (tfm_y=True so the masks are transformed too), batch, normalize
data = src.transform(get_transforms(), size=size, tfm_y=True)
data = data.databunch(bs=bs)
data = data.normalize(imagenet_stats)
# fastai.basic_data.DataBunch

data.show_batch(rows=3, figsize=(12, 9))

# Create U-net Learner object
# (resnet34 backbone; jaccard_loss, acc_camvid and dice are passed as metrics)
learn = unet_learner(data,models.resnet34,metrics=[jaccard_loss,acc_camvid,dice])
# NOTE(review): dice comes from fastai.metrics; confirm it is meaningful for
# multi-class masks rather than only binary ones.
# Point the learner at the current directory
# (presumably so saved weights / exports are resolved from here; verify)
learn.path = Path(".")

# TODO: Custom U-net class to implement Refine-Net
# TODO: Optimize weight decay and percent start params (wd=0.01 and pct_start=0.3)

# Restore the weights saved as 'stage-2-big' (marked FINAL in the training
# notes below)
learn.load('stage-2-big')

# TODO: Visualize unet model architecture

# Train model (iterative uncomment)

# learn.load('stage-1-small')
# learn.load('stage-1-big')

# STAGE 1 : OPTIMIZE learning rate
# learn.lr_find()
# learn.recorder.plot() 
#     or fig = learn.recorder.plot(return_fig=True); fig.savefig('file.png')
# lr = 3e-03
# learn.fit_one_cycle(10,slice(lr))
# learn.save('stage-1-small')
# learn.save('stage-1-big')

# STAGE 2 : UNFREEZE
# learn.unfreeze()
# learn.lr_find()
# learn.recorder.plot()
# lower_lr_bound = 4e-05
# learn.fit_one_cycle(12,slice(lower_lr_bound,lr/5))
# learn.save('stage-2-small')
# learn.save('stage-2-big') # FINAL 

# Results
learn.show_results(rows=3, figsize=(8, 9))
# learn.recorder.plot_losses() # return_fig=True
# learn.recorder.plot_lr()     # return_fig=True

# Export model
# learn.export() # to learn.path/'export.pkl'

# Inference
# learn = load_learner(Path("."))
img = open_image("920x760.jpg")
res = learn.predict(img)
# res = tuple(ImageSegment,Tensor[1,720,960],Tensor[32,720,960])
#     = tuple(mask image, class pixel values, probabilities)
img.show()
res[0].show()
img.show(y=res[0])  # y=mask: overlay the predicted mask on the image

# Stuff that requires Learner to have data loaded
interp = SegmentationInterpretation.from_learner(learn)
# Plain copy of the class names; list() replaces the element-by-element
# comprehension and yields the same list.
classes = list(learn.data.classes)
interp._interp_show(res[0], classes)

# TODO: cleanup color/class output
# TODO: generate this output on export.pkl alone
# TODO: inference on TestBatch

# TODO: gather dataset
# TODO: label semantic segmentation using 3rd party tool

