In [None]:
import shutil
from pathlib import Path

import cv2
from ultralytics import YOLO

# Root directory for everything this notebook reads and writes: the pulled
# pool data, the generated YOLO dataset, and the dataset YAML all live under it.
DATA = Path("datasets")

### Load data and split it into train/test

We have some [data in DVC](https://dvc.org/doc/start/data-management/data-versioning) that we can pull. 

This data includes:
* satellite images
* masks of the swimming pools in each satellite image

DVC can help connect your data to your repo, but it isn't necessary to have your data in DVC to start tracking experiments with DVC and DVCLive.

In [None]:
# Fetch the DVC-tracked data (satellite images and pool masks) into the working tree.
!dvc pull

### Convert to YOLO Dataset format

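Convert each mask into the [Ultralytics YOLO segmentation dataset format](https://docs.ultralytics.com/datasets/segment/): one `.txt` label file per image, with one line per object holding the class index followed by the polygon's normalized `x y` coordinates.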

In [None]:
def mask_to_yolo_annotation(mask):
    """Convert a segmentation mask into YOLO segmentation label text.

    Every external contour found in ``mask`` becomes one line of the form
    ``0 x1 y1 x2 y2 ...``, where ``0`` is the (single) class index and each
    coordinate is normalized to the mask size (x by width, y by height) and
    rounded to 3 decimals.

    Args:
        mask: 2-D single-channel image array accepted by ``cv2.findContours``
            (non-zero pixels are treated as foreground).

    Returns:
        str: One newline-terminated line per polygon; an empty string when the
        mask contains no contour with at least three points.
    """
    height, width = mask.shape[:2]
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    lines = []
    for contour in contours:
        # OpenCV contours have shape (N, 1, 2) and store points as (x, y),
        # i.e. (column, row). reshape() keeps the array 2-D even when N == 1,
        # where squeeze() would collapse it to shape (2,) and break unpacking.
        points = contour.reshape(-1, 2)
        if len(points) < 3:
            # Fewer than three points cannot describe a polygon.
            continue
        coords = " ".join(
            f"{round(float(x) / width, 3)} {round(float(y) / height, 3)}"
            for x, y in points
        )
        lines.append(f"0 {coords}\n")
    return "".join(lines)

In [None]:
# Images whose file name contains one of these markers go to the validation
# split; every other image goes to the training split.
test_regions = ["REGION_1-"]

train_data_dir = DATA / "yolo_dataset" / "train"
train_data_dir.mkdir(exist_ok=True, parents=True)
test_data_dir = DATA / "yolo_dataset" / "val"
test_data_dir.mkdir(exist_ok=True, parents=True)

# sorted() gives a stable processing order, so a failure is reproducible.
for img_path in sorted(DATA.glob("pool_data/images/*.jpg")):
    mask_path = DATA / "pool_data" / "masks" / f"{img_path.stem}.png"
    mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if mask is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; fail here with a message that names the file.
        raise FileNotFoundError(
            f"Missing or unreadable mask for {img_path.name}: {mask_path}"
        )
    yolo_annotation = mask_to_yolo_annotation(mask)

    # Match against the file name only, so a parent directory that happens to
    # contain a region marker cannot push every image into the same split.
    if any(region in img_path.name for region in test_regions):
        dst = test_data_dir / img_path.name
    else:
        dst = train_data_dir / img_path.name

    # Image and label share a stem: <name>.jpg next to <name>.txt.
    shutil.copy(img_path, dst)
    dst.with_suffix(".txt").write_text(yolo_annotation)

In [None]:
yolo_dataset_yaml = DATA / "yolo_dataset.yaml"

# Ultralytics resolves a relative `path` against its globally configured
# datasets directory, not against this file or the current working directory.
# Writing the resolved absolute dataset root keeps the notebook working
# regardless of that user-level setting.
yolo_dataset_root = (DATA / "yolo_dataset").resolve().as_posix()

yolo_dataset_yaml.write_text(
    f"""
path: {yolo_dataset_root}
train: train
val: val

names:
  0: pool
    """
)

### Train model
Set up model training, using DVCLive to capture the results of each experiment.

In [None]:
# Training parameters consumed by the training cell below.
imgsz = 512  # passed to yolo.train(imgsz=...)
epochs = 20  # passed to yolo.train(epochs=...)
model = "yolov8n-seg.pt"  # weights file handed to YOLO(...); presumably the pretrained nano segmentation checkpoint - confirm

In [None]:
# Build the model from the weights named by `model`.
yolo = YOLO(model)

# Ultralytics documents `data` as a string path to the dataset YAML, so the
# Path object is converted explicitly rather than relying on implicit handling.
yolo.train(data=str(yolo_dataset_yaml), epochs=epochs, imgsz=imgsz)