In [None]:
from pathlib import Path

# Repository root, expressed relative to the notebook's working directory.
ROOT = Path("../")
# All datasets (raw and derived) live under this directory.
DATA = ROOT.joinpath("data")

In [None]:
import warnings
# Suppress every Python warning for the rest of the kernel session
# (presumably to keep cell output readable). NOTE(review): this also hides
# warnings that may matter; narrow the filter if something behaves oddly.
warnings.filterwarnings("ignore")

In [None]:
import shutil
from functools import partial

import numpy as np
import torch
from box import ConfigBox
from dvclive import Live
from dvclive.fastai import DVCLiveCallback
from fastai.data.all import Normalize, get_files
from fastai.metrics import DiceMulti
from fastai.vision.all import (Resize, SegmentationDataLoaders,
                               imagenet_stats, models, unet_learner)
from ruamel.yaml import YAML
from PIL import Image

### Load data and split it into train/test

We have some [data in DVC](https://dvc.org/doc/start/data-management/data-versioning) that we can pull. 

This data includes:
* satellite images
* masks of the swimming pools in each satellite image

DVC can help connect your data to your repo, but it isn't necessary to have your data in DVC to start tracking experiments with DVC and DVCLive.

In [None]:
# Shell out to the DVC CLI to download the DVC-tracked data for this repo
# (requires `dvc` to be available on PATH).
!dvc pull

In [None]:
# Images whose path contains any of these substrings are held out for testing;
# everything else is used for training/validation.
test_regions = ["REGION_1-"]

img_fpaths = get_files(DATA / "pool_data" / "images", extensions=".jpg")
masks_dir = DATA / "pool_data" / "masks"

train_data_dir = DATA / "train_data"
test_data_dir = DATA / "test_data"
for split_dir in (train_data_dir, test_data_dir):
    # Start from an empty directory so that re-running this cell (for example
    # after editing `test_regions`) cannot leave a tile from a previous run in
    # both splits, which would leak test data into training.
    if split_dir.exists():
        shutil.rmtree(split_dir)
    split_dir.mkdir(parents=True)

for img_path in img_fpaths:
    # Each image has a mask with the same stem and a .png extension.
    msk_path = masks_dir / f"{img_path.stem}.png"
    is_test = any(region in str(img_path) for region in test_regions)
    dest_dir = test_data_dir if is_test else train_data_dir
    # Image and mask are copied side by side into the same split directory.
    shutil.copy(img_path, dest_dir)
    shutil.copy(msk_path, dest_dir)

### Create a data loader

Load and prepare the images and masks by creating a data loader.

In [None]:
def get_mask_path(x, train_data_dir):
    """Return the path of the PNG mask that belongs to image ``x``.

    The mask is expected to sit directly inside ``train_data_dir`` and to share
    the image's file stem, with a ``.png`` extension.

    Args:
        x: Path (or path string) of the image file.
        train_data_dir: Directory that contains the mask files.

    Returns:
        A ``Path`` pointing at ``<train_data_dir>/<stem of x>.png``.
    """
    mask_name = Path(x).stem + ".png"
    return Path(train_data_dir).joinpath(mask_name)

In [None]:
bs = 8            # batch size
valid_pct = 0.20  # fraction of the training images held out for validation
img_size = 256    # images and masks are resized to img_size x img_size
split_seed = 42   # fixes the random train/validation split across kernel restarts

data_loader = SegmentationDataLoaders.from_label_func(
    path=train_data_dir,
    # Sorted so that, together with the seed, the split is reproducible
    # regardless of the order in which the file system lists the files.
    fnames=sorted(get_files(train_data_dir, extensions=".jpg")),
    label_func=partial(get_mask_path, train_data_dir=train_data_dir),
    # Class names by mask value: index 0 is "not-pool", index 1 is "pool".
    codes=["not-pool", "pool"],
    bs=bs,
    valid_pct=valid_pct,
    seed=split_seed,
    item_tfms=Resize(img_size),
    batch_tfms=[
        Normalize.from_stats(*imagenet_stats),
    ],
)

### Review a sample batch of data

Below are some examples of the images overlaid with their masks.

In [None]:
# Render a sample batch from the data loader; `alpha` is presumably the
# opacity of the mask overlay (confirm against the fastai docs).
data_loader.show_batch(alpha=0.7)

### Train multiple models with different learning rates using `DVCLiveCallback`

Set up model training, using DVCLive to capture the results of each experiment.

In [None]:
def dice(mask_pred, mask_true, classes=(0, 1), eps=1e-6):
    """Return the Dice coefficient averaged over ``classes``.

    For each class ``c`` the score is
    ``2 * |pred == c AND true == c| / (|pred == c| + |true == c| + eps)``,
    and the per-class scores are averaged with equal weight. The metric is
    symmetric in ``mask_pred`` and ``mask_true``.

    Args:
        mask_pred: Array of predicted class labels.
        mask_true: Array of ground-truth class labels, same shape as
            ``mask_pred``.
        classes: Class labels to score. The default is an immutable tuple so
            it cannot be modified between calls.
        eps: Small constant added to the denominator to avoid division by zero.

    Returns:
        The mean of the per-class Dice scores. A class that is absent from
        both masks contributes a score of 0 (numerator 0, denominator ``eps``).
    """
    per_class_scores = []
    for c in classes:
        y_true = mask_true == c
        y_pred = mask_pred == c
        # Elementwise product of boolean masks counts pixels both agree on.
        overlap = 2.0 * np.sum(y_true * y_pred)
        score = overlap / (np.sum(y_true) + np.sum(y_pred) + eps)
        per_class_scores.append(score)
    return np.mean(per_class_scores)

def evaluate(learn):
    """Return the mean Dice score of ``learn`` over the held-out test images.

    Every ``.jpg`` under ``DATA / "test_data"`` is run through the learner.
    Each predicted mask is resized back to the resolution of its ground-truth
    mask and scored with ``dice``; the per-image scores are then averaged.

    Args:
        learn: A trained fastai learner whose ``dls`` can build a test
            data loader from a list of image paths.

    Returns:
        The average of the per-image Dice scores.
    """
    test_dir = DATA / "test_data"
    # Sorted so predictions and mask paths line up in a stable order.
    test_img_fpaths = sorted(get_files(test_dir, extensions=".jpg"))
    test_dl = learn.dls.test_dl(test_img_fpaths)
    preds, _ = learn.get_preds(dl=test_dl)
    # Channel 1 corresponds to the "pool" class; threshold it into a 0/1 mask.
    masks_pred = np.array(preds[:, 1, :] > 0.5, dtype=np.uint8)

    n_images = len(test_img_fpaths)
    dice_multi = 0.0
    for img_fpath, mask_pred in zip(test_img_fpaths, masks_pred):
        # Open each mask in a context manager so its file handle is released,
        # and read the dimensions from PIL's native `size` (width, height)
        # rather than relying on the `shape` property fastai patches onto PIL.
        with Image.open(get_mask_path(img_fpath, test_dir)) as mask_img:
            width, height = mask_img.size
            mask_true = np.array(mask_img, dtype=int)
        # Predictions come out at the training resolution; scale them back to
        # the ground-truth resolution before comparing.
        mask_pred = np.array(
            Image.fromarray(mask_pred).resize((width, height)),
            dtype=int,
        )
        dice_multi += dice(mask_pred, mask_true) / n_images
    return dice_multi

In [None]:
train_arch = 'shufflenet_v2_x2_0'
models_dir = ROOT / "models"
models_dir.mkdir(exist_ok=True)
results_dir = ROOT / "results" / "train"
# Every experiment exports to the same file; DVC versions it per experiment.
model_path = models_dir / "model.pkl"

# One DVC experiment per learning rate.
for base_lr in (0.001, 0.005, 0.01):
    fine_tune_args = dict(epochs=8, base_lr=base_lr)

    # Open a DVCLive session that writes to `results_dir` and saves the run
    # as a DVC experiment when the block exits.
    with Live(str(results_dir), save_dvc_exp=True, report="notebook") as live:
        # Record the hyperparameters: one single value, then a whole dict.
        live.log_param("train_arch", train_arch)
        live.log_params(fine_tune_args)

        learn = unet_learner(
            data_loader,
            arch=getattr(models, train_arch),
            metrics=DiceMulti,
        )

        # The callback streams training metrics into the DVCLive session.
        dvclive_cb = DVCLiveCallback(live=live)
        learn.fine_tune(**fine_tune_args, cbs=[dvclive_cb])

        learn.export(fname=model_path.absolute())

        # Post-training summary metric computed on the held-out test set.
        live.summary["evaluate/dice_multi"] = evaluate(learn)

        # Register the exported model file as a DVC artifact.
        live.log_artifact(
            str(model_path),
            type="model",
            name="pool-segmentation",
            desc="This is a Computer Vision (CV) model that's segmenting out swimming pools from satellite images.",
            labels=["cv", "segmentation", "satellite-images", "unet"],
        )


In [None]:
# Compare experiments: print the table of DVC experiments recorded by the
# training loop. `--only-changed` presumably hides columns whose values are
# identical across experiments (see `dvc exp show --help` to confirm).
!dvc exp show --only-changed

### Review sample predictions vs ground truth

Below are some examples of the predicted masks.

In [None]:
# `learn` is the learner left over from the final iteration of the training
# loop (base_lr=0.01), so these results reflect that model only.
learn.show_results(max_n=6, alpha=0.7)