In [None]:
from pathlib import Path

# Project root, resolved relative to the kernel's working directory, and the
# data directory underneath it that every later cell reads from / writes to.
root = Path("..")
data = root.joinpath("data")

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
import shutil
from functools import partial

import numpy as np
import torch
from box import ConfigBox
from dvclive import Live
from dvclive.fastai import DVCLiveCallback
from fastai.data.all import Normalize, get_files
from fastai.metrics import DiceMulti
from fastai.vision.all import (Resize, SegmentationDataLoaders,
                               SegmentationInterpretation, aug_transforms,
                               imagenet_stats, models, unet_learner)
from ruamel.yaml import YAML

# Download data

In [None]:
# NOTE(review): wget is given no URL here, so as written this cell has nothing
# to download. TODO: supply the dataset source. The cells below expect images
# under data / "pool_data" / "images" (*.jpg) and masks with the same stem
# under data / "pool_data" / "masks" (*.png).
!wget 

### Load data and split it into train/test

In [None]:
test_pct = 0.25
# Fixed seed so the train/test split is identical on every Restart & Run All.
split_seed = 42
split_rng = np.random.default_rng(split_seed)

# Sort so the split does not depend on the filesystem's listing order.
img_fpaths = sorted(get_files(data / "pool_data" / "images", extensions=".jpg"))

train_data_dir = data / "train_data"
test_data_dir = data / "test_data"
# Start from empty output directories. Files left over from an earlier run
# with a different split would otherwise remain, so an image could end up in
# both train and test (leakage).
for split_dir in (train_data_dir, test_data_dir):
    shutil.rmtree(split_dir, ignore_errors=True)
    split_dir.mkdir(parents=True)

# Each image/mask pair goes to the test set with probability `test_pct`.
# The pair always travels together, so masks never get separated from their images.
for img_path in img_fpaths:
    msk_path = data / "pool_data" / "masks" / f"{img_path.stem}.png"
    dest_dir = test_data_dir if split_rng.uniform() < test_pct else train_data_dir
    shutil.copy(img_path, dest_dir)
    shutil.copy(msk_path, dest_dir)

### Create a data loader

In [None]:
def get_mask_path(x, train_data_dir):
    """Return the mask file that belongs to image ``x``.

    The mask lives directly in ``train_data_dir`` and shares the image's
    stem, with a ``.png`` extension.

    Args:
        x: Path (or str) of an image file; only its stem is used.
        train_data_dir: Directory containing the mask files.

    Returns:
        pathlib.Path pointing at the corresponding mask.
    """
    mask_name = Path(x).stem + ".png"
    return Path(train_data_dir).joinpath(mask_name)

In [None]:
bs = 8
valid_pct = 0.20
img_size = 512
# Seed for fastai's random train/valid split, so every session (and every
# learning-rate run below) is scored on the same validation images.
valid_seed = 42

data_loader = SegmentationDataLoaders.from_label_func(
    path=train_data_dir,
    # Sorted so the seeded split does not depend on filesystem listing order.
    fnames=sorted(get_files(train_data_dir, extensions=".jpg")),
    label_func=partial(get_mask_path, train_data_dir=train_data_dir),
    # Class names; presumably mask pixel value 0 = "not-pool", 1 = "pool" (confirm).
    codes=["not-pool", "pool"],
    bs=bs,
    valid_pct=valid_pct,
    seed=valid_seed,
    item_tfms=Resize(img_size),
    batch_tfms=[
        *aug_transforms(size=img_size),
        Normalize.from_stats(*imagenet_stats),
    ],
)

### Review a sample batch of data

In [None]:
# Display a sample batch from the data loader; alpha=0.7 controls the
# overlay transparency of the drawn masks.
data_loader.show_batch(
    alpha=0.7,
)

### Train multiple models with different learning rates using `DVCLiveCallback`

In [None]:
train_arch = 'resnet18'
epochs = 8

# Track the best learner (highest validation DiceMulti). The review cells at
# the end of the notebook then inspect the best model instead of whichever
# learning rate happened to be trained last.
best_learn, best_dice = None, float("-inf")

for base_lr in [0.001, 0.005, 0.01]:
    # A fresh DVCLive logger per run; save_dvc_exp=True saves each run as a
    # DVC experiment, so the runs can be compared with `dvc exp show`.
    live = Live(dir=str(root / "results" / "train"),
                report="md",
                save_dvc_exp=True)
    live.log_param("base_lr", base_lr)
    learn = unet_learner(data_loader,
                         arch=getattr(models, train_arch),
                         metrics=DiceMulti)
    learn.fine_tune(epochs=epochs,
                    base_lr=base_lr,
                    cbs=[DVCLiveCallback(live=live)])
    # validate() returns [valid_loss, *metrics]; with a single metric,
    # index 1 is the DiceMulti score on the validation split.
    valid_dice = float(learn.validate()[1])
    if valid_dice > best_dice:
        best_learn, best_dice = learn, valid_dice

# Keep the last-trained learner if no run produced a comparable score (e.g. all NaN).
if best_learn is not None:
    learn = best_learn

In [None]:
# Compare experiments: tabulate the DVC experiments saved by the training runs.
# --only-changed presumably limits the table to params/metrics that differ
# between experiments (here base_lr and the logged metrics).
!dvc exp show --only-changed

In [None]:
# Apply best performing experiment to the workspace.
# Takes the first CSV column (the experiment name) of the last row after
# sorting by dice_multi. This assumes DVC sorts ascending by default, so the
# last row has the highest score (TODO confirm). `dvc exp apply` only updates
# files in the workspace; Python objects in this kernel (e.g. `learn`) are
# not changed by it.
!EXP=$(dvc exp show --csv --sort-by dice_multi | tail -n 1 | cut -d , -f 1) && dvc exp apply $EXP

### Review sample predictions vs. ground truth

In [None]:
# Predictions vs. ground truth for up to six samples (max_n) from the
# in-memory `learn`; alpha=0.7 controls the overlay transparency.
learn.show_results(
    alpha=0.7,
    max_n=6,
)

### Review the samples with the highest loss (i.e., where the model is most likely wrong)

In [None]:
# Gather per-sample losses and predictions for the in-memory `learn`,
# so the highest-loss samples can be plotted in the next cell.
interp = SegmentationInterpretation.from_learner(learn)

In [None]:
# Plot the k=5 samples with the highest loss; alpha=0.7 controls the
# overlay transparency.
interp.plot_top_losses(
    alpha=0.7,
    k=5,
)