diff --git a/example-get-started-cv/code/notebooks/TrainSegModel.ipynb b/example-get-started-cv/code/notebooks/TrainSegModel.ipynb index 01657aa4..dba479c2 100644 --- a/example-get-started-cv/code/notebooks/TrainSegModel.ipynb +++ b/example-get-started-cv/code/notebooks/TrainSegModel.ipynb @@ -39,7 +39,8 @@ "from fastai.metrics import DiceMulti\n", "from fastai.vision.all import (Resize, SegmentationDataLoaders, aug_transforms,\n", " imagenet_stats, models, unet_learner)\n", - "from ruamel.yaml import YAML" + "from ruamel.yaml import YAML\n", + "from PIL import Image" ] }, { @@ -100,7 +101,7 @@ "source": [ "bs = 8\n", "valid_pct = 0.20\n", - "img_size = 512\n", + "img_size = 256\n", "\n", "data_loader = SegmentationDataLoaders.from_label_func(\n", " path=train_data_dir,\n", @@ -142,6 +143,23 @@ "### Train multiple models with different learning rates using `DVCLiveCallback`" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def dice(mask_pred, mask_true, classes=[0, 1], eps=1e-6):\n", + " dice_list = []\n", + " for c in classes:\n", + " y_true = mask_true == c\n", + " y_pred = mask_pred == c\n", + " intersection = 2.0 * np.sum(y_true * y_pred)\n", + " dice = intersection / (np.sum(y_true) + np.sum(y_pred) + eps)\n", + " dice_list.append(dice)\n", + " return np.mean(dice_list)" + ] + }, { "cell_type": "code", "execution_count": null, @@ -154,7 +172,8 @@ " live = Live(dir=os.path.join('results', 'train'), \n", " report=\"md\", \n", " save_dvc_exp=True)\n", - " live.log_param(\"base_lr\", base_lr)\n", + " live.summary[\"base_lr\"] =base_lr\n", + " live.make_summary()\n", " learn = unet_learner(data_loader, \n", " arch=getattr(models, train_arch), \n", " metrics=DiceMulti)\n", @@ -164,7 +183,25 @@ " }\n", " learn.fine_tune(\n", " **fine_tune_args,\n", - " cbs=[DVCLiveCallback(live=live)])" + " cbs=[DVCLiveCallback(live=live)])\n", + " \n", + " test_img_fpaths = get_files(Path(\"data\") / \"test_data\", extensions=\".jpg\")\n", + " test_dl = learn.dls.test_dl(test_img_fpaths)\n", + " preds, _ = learn.get_preds(dl=test_dl)\n", + " masks_pred = np.array(preds[:, 1, :] > 0.5, dtype=int)\n", + " test_mask_fpaths = [\n", + " get_mask_path(fpath, Path(\"data\") / \"test_data\") for fpath in test_img_fpaths\n", + " ]\n", + " masks_true = [Image.open(mask_path) for mask_path in test_mask_fpaths]\n", + " masks_true = [\n", + " np.array(img.resize((img_size, img_size)), dtype=int) for img in masks_true\n", + " ]\n", + " with Live(\"results/evaluate\", report=\"md\") as live:\n", + " dice_multi = 0.0\n", + " for ii in range(len(masks_true)):\n", + " mask_pred, mask_true = masks_pred[ii], masks_true[ii]\n", + " dice_multi += dice(mask_true, mask_pred) / len(masks_true)\n", + " live.summary[\"dice_multi\"] = dice_multi" ] }, { @@ -183,8 +220,11 @@ "metadata": {}, "outputs": [], "source": [ + "%%bash\n", "# Apply best performing experiment to the workspace\n", - "!EXP=$(dvc exp show --csv --sort-by dice_multi | tail -n 1 | cut -d , -f 1) && dvc exp apply $EXP" + "BEST_EXP_NAME=$(dvc exp show --csv --sort-by dice_multi | tail -n 1 | cut -d , -f 1)\n", + "echo \"Applying $BEST_EXP_NAME\"\n", + "dvc exp apply \"$BEST_EXP_NAME\"" ] }, {