In [None]:
# default_exp datasets

# Datasets

> The challenge has two parts:
- instance segmentation of dolphins in the photo and
- recognition of an individual dolphin from the photo. 

In [None]:
# export

from pathlib import Path
from typing import *

In [None]:
# exporti


import numpy as np
import shutil
from datetime import datetime
import torch
import torch.utils.data
from torch.hub import download_url_to_file
import torchvision
import PIL
from PIL import Image
from zipfile import ZipFile

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor
from tempfile import TemporaryDirectory

import git

In [None]:
#exporti

# Download TorchVision repo to use some files from
# references/detection

# Reference scripts from torchvision's references/detection that this module
# imports (engine, transforms, utils) or that they import (coco_utils).
_REFERENCE_FILES = ("engine.py", "transforms.py", "coco_utils.py", "utils.py")

if not all(Path(name).exists() for name in _REFERENCE_FILES):

    # The reference scripts on the default branch may rely on a newer
    # torchvision API than the installed one, so check out the release tag that
    # matches the installed version (a local suffix such as "+cu101" is dropped).
    _vision_tag = "v" + torchvision.__version__.split("+")[0]

    with TemporaryDirectory() as d:
        vision_root = Path(d) / "vision"

        try:
            git.repo.base.Repo.clone_from(
                url="https://github.com/pytorch/vision.git",
                to_path=vision_root,
                branch=_vision_tag,
                depth=1,
            )
        except git.exc.GitCommandError:
            # e.g. a nightly/dev build with no matching tag: use the default branch.
            print(f"Could not clone tag {_vision_tag}; cloning the default branch instead")
            shutil.rmtree(vision_root, ignore_errors=True)
            git.repo.base.Repo.clone_from(
                url="https://github.com/pytorch/vision.git",
                to_path=vision_root,
                depth=1,
            )
        assert vision_root.exists() and vision_root.is_dir()

        detection_root = vision_root / "references" / "detection"
        py_files = list(detection_root.glob("**/*.py"))
        assert len(py_files) >= 7

        # Copy the scripts flat into the working directory so they are importable.
        for f_src in py_files:
            f_dst = Path(".") / f_src.name
            print(f"Copy: {f_src.resolve()} -> {f_dst.resolve()}")
            shutil.copy(f_src, f_dst)
            assert f_dst.exists() and not f_dst.is_dir()

assert Path("engine.py").exists()
assert Path("transforms.py").exists()
assert Path("coco_utils.py").exists()
assert Path("utils.py").exists()

# imports
from engine import train_one_epoch, evaluate
import transforms as T
import utils


In [None]:
# hide

for module in (np, torch, torchvision, PIL):
    name, version = module.__name__, module.__version__
    print(f"{name:12}: {version}")

numpy       : 1.18.5
torch       : 1.7.1
torchvision : 0.8.2
PIL         : 7.2.0


In [None]:
#exporti

# Location of the dolphins_200 instance-segmentation archive; it is fetched by
# `_download_data_if_needed` when the local zip file is missing.
dataset_url = "https://s3.eu-central-1.amazonaws.com/ai-league.cisex.org/2020-2021/dolphins-instance-segmentation/dolphins_200.zip"

In [None]:
#exporti

dataset_root = Path("./data/dolphins_200")
dataset_zip = dataset_root.parent / "dolphins_200.zip"

def _download_data_if_needed():

    dataset_zip.parent.mkdir(parents=True, exist_ok=True)

    if not dataset_zip.exists():
        torch.hub.download_url_to_file(
            dataset_url,
            dataset_zip,
            hash_prefix=None,
            progress=True,
        )


    with ZipFile(dataset_zip, 'r') as zip_ref:
        zip_ref.extractall(dataset_root)


In [None]:
# hide

# Fetch / unpack the archive (if required) before looking inside it.
_download_data_if_needed()
assert dataset_root.exists() and dataset_root.is_dir()

# The three dataset folders read by the dataset class.
images_path, instance_path, class_path = (
    dataset_root / folder
    for folder in ("JPEGImages", "SegmentationObject", "SegmentationClass")
)

# Every folder must hold at least 200 entries.
image_files, instance_files, class_files = (
    sorted(folder.glob("**/*")) for folder in (images_path, instance_path, class_path)
)
for file_list in (image_files, instance_files, class_files):
    assert len(file_list) >= 200

In [None]:
# exporti


def _enumerate_colors_for_fname(fname: Path) -> Tuple[int, int, int]:
    """Finds all colors in the image"""
    img = Image.open(fname)
    colors = [y for x, y in img.getcolors()]
    return colors

In [None]:
# hide

test_file = class_path / "070624_6_1_0022.png"

actual = _enumerate_colors_for_fname(test_file)

# The order of getcolors() results is not documented, so compare order-free.
expected = [(255, 0, 0), (0, 0, 0)]
assert sorted(actual) == sorted(expected), f"{actual} != {expected}"

In [None]:
# exporti


def _enumerate_colors_for_fnames(fnames: List[Path]) -> Dict[Tuple[int, int, int], int]:
    """This function is used to pin (0, 0, 0) color to the front of palette"""
    colors = np.array([_enumerate_colors_for_fname(fname) for fname in fnames]).reshape(
        -1, 3
    )
    colors = set([tuple(x) for x in colors.tolist() if tuple(x) != (0, 0, 0)])
    colors = [(0, 0, 0)] + list(colors)
    return {x: i for i, x in enumerate(colors)}

In [None]:
# hide

actual = _enumerate_colors_for_fnames(class_files)

# Black is pinned to 0; red is the only other color in the class masks.
expected = {(0, 0, 0): 0, (255, 0, 0): 1}

assert actual == expected, f"{actual} != {expected}"

In [None]:
# exporti


def _substitute_values(xs: np.array, x, y):
    """Not sure I understand what this does"""
    ix_x = xs == x
    ix_y = xs == y
    xs[ix_x] = y
    xs[ix_y] = x

In [None]:
# exporti


def _enumerate_image_for_instances(
    im: Image, force_black_to_zero: bool = True, max_colors=16
) -> np.array:
    """convert rgb image mask to enumerated image mask"""
    pallete_mask = im.convert("P", palette=Image.ADAPTIVE, colors=max_colors)

    xs = np.array(pallete_mask)

    if force_black_to_zero:
        _substitute_values(xs, 0, xs.max())

    return xs

In [None]:
# hide

with Image.open(instance_files[0]) as instance_img:
    enum_instance = _enumerate_image_for_instances(instance_img)
    instance_rgb = np.array(instance_img.convert("RGB"))
assert enum_instance.shape == (500, 750)

unique_enum_instance = np.unique(enum_instance)
assert sorted(unique_enum_instance) == [0, 1, 2]  # for this particular photo

# force_black_to_zero contract: every black pixel of the source mask gets id 0.
black_pixels = (instance_rgb == 0).all(axis=-1)
assert black_pixels.any() and (enum_instance[black_pixels] == 0).all()

In [None]:
# exporti


def _enumerate_image_for_classes(
    im: Image,
    colors: Dict[Tuple[int], int] = None,
) -> np.array:
    """Enumerates classes from the rbg format"""
    xs = np.array(im)
    xs = [
        ((xs == color).all(axis=-1)).astype(int) * code
        for color, code in colors.items()
    ]
    xs_sum = xs[0]
    for i in range(1, len(xs)):
        xs_sum = xs_sum + xs[i]
    return xs_sum.astype("uint8")

In [None]:
# hide

class_palette = _enumerate_colors_for_fnames(class_files)
# Black must be pinned to 0; red is the only other color in the class masks.
assert class_palette == {(0, 0, 0): 0, (255, 0, 0): 1}, f"{class_palette}"

img = Image.open(class_files[0])

enum_classes = _enumerate_image_for_classes(img, class_palette)
assert enum_classes.shape == (500, 750)

unique_enum_classes = np.unique(enum_classes)
assert set(unique_enum_classes) == {0, 1}, f"{unique_enum_classes}"

In [None]:
# exporti


class DolphinsInstanceSegmentationDataset(torch.utils.data.Dataset):
    """Instance segmentation dataset over the dolphins folder layout.

    `root` must contain three sibling folders whose files pair up one-to-one
    by file stem (checked at construction):

    - `JPEGImages/`: the photos,
    - `SegmentationClass/`: color-coded class masks,
    - `SegmentationObject/`: color-coded instance masks (one color per
      instance, black background).

    Each item is `(image, target)`, where `target` holds `boxes` (N x 4,
    xmin/ymin/xmax/ymax, float32), `labels` (all 1: a single foreground
    class), `masks` (N x H x W, uint8), `image_id`, `area` and `iscrowd`.
    If `transforms` is given, it is called as `transforms(image, target)`.
    """

    def __init__(self, root: Path, transforms=None):
        self.root = Path(root)
        self.transforms = transforms
        # Sort every listing so that index i refers to the same sample in all
        # three folders; `_check_alignment` verifies this.
        self.img_paths = sorted((self.root / "JPEGImages").glob("*.*"))
        self.label_paths = sorted((self.root / "SegmentationClass").glob("*.*"))
        self.mask_paths = sorted((self.root / "SegmentationObject").glob("*.*"))
        self._check_alignment()

        # Color -> class code palette shared by all class masks.
        self.class_colors = _enumerate_colors_for_fnames(self.label_paths)

    def _check_alignment(self) -> None:
        """Raise ValueError unless all three folders list the same file stems."""
        img_stems = [path.stem for path in self.img_paths]
        for folder, paths in (
            ("SegmentationClass", self.label_paths),
            ("SegmentationObject", self.mask_paths),
        ):
            stems = [path.stem for path in paths]
            if stems != img_stems:
                differing = sorted(set(stems).symmetric_difference(img_stems))
                raise ValueError(
                    f"{folder} does not match JPEGImages file-by-file under {self.root}: "
                    f"{len(stems)} vs {len(img_stems)} files, differing stems {differing[:5]}"
                )

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        label_path = self.label_paths[idx]
        mask_path = self.mask_paths[idx]

        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        # The instance mask is not converted to RGB: it is quantized to palette
        # ids, one id per instance color, with the background forced to 0.
        with Image.open(mask_path) as mask_img:
            mask = _enumerate_image_for_instances(mask_img)

        # The first (smallest) id is the background, so drop it.
        obj_ids = np.unique(mask)[1:]

        # Split the id-encoded mask into an (N, H, W) stack of binary masks.
        masks = mask == obj_ids[:, None, None]

        with Image.open(label_path) as label_img:
            label_array = _enumerate_image_for_classes(label_img, self.class_colors)

        num_objs = len(obj_ids)
        boxes = []
        for i in range(num_objs):
            rows, cols = np.where(masks[i])
            boxes.append([np.min(cols), np.min(rows), np.max(cols), np.max(rows)])

            # Sanity check: inside one instance the class map may contain at
            # most one class besides the zeros outside the instance.
            class_ids = np.unique(label_array * masks[i])
            assert class_ids.shape[0] <= 2, (
                f"instance {i} of {mask_path.name} spans classes {class_ids.tolist()}"
            )

        # reshape keeps a (0, 4) shape when the mask contains no instances.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # Single foreground class: every instance is labelled 1.
        labels = torch.ones((num_objs,), dtype=torch.int64)

        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        # All instances are assumed not to be crowd regions.
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["masks"] = masks
        target["image_id"] = image_id
        target["area"] = area
        target["iscrowd"] = iscrowd

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        return len(self.img_paths)

In [None]:
# hide

# Build the raw dataset (no transforms) and look at its first sample.
dataset = DolphinsInstanceSegmentationDataset(root=dataset_root, transforms=None)
dataset[0]

(<PIL.Image.Image image mode=RGB size=750x500 at 0x7F8A491E2BA8>,
 {'boxes': tensor([[539., 236., 734., 320.],
          [301., 248., 554., 339.]]),
  'labels': tensor([1, 1]),
  'masks': tensor([[[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]],
  
          [[0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           ...,
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0],
           [0, 0, 0,  ..., 0, 0, 0]]], dtype=torch.uint8),
  'image_id': tensor([0]),
  'area': tensor([16380., 23023.]),
  'iscrowd': tensor([0, 0])})

In [None]:
# exporti

def _get_instance_segmentation_dataset(
    *,
    get_transform: Callable[[bool], Callable] = (lambda train: None),
    batch_size: int = 4,
    val_split: float = 0.2,
    num_workers: int = 4
) -> Tuple[torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader]:
    """Get dataset for instance segmentation. Make sure you define get_transform function."""

    # get data if needed
    _download_data_if_needed()
    root_path = Path(dataset_root)
    assert root_path.exists()
    assert root_path.is_dir()
    assert len(list(root_path.glob("**/*"))) >= 600

    # use our dataset and defined transformations
    dataset = DolphinsInstanceSegmentationDataset(
        dataset_root, get_transform(train=True)
    )
    dataset_test = DolphinsInstanceSegmentationDataset(
        dataset_root, get_transform(train=False)
    )

    n_val = max(1, round(val_split * len(dataset)))

    # split the dataset in train and test set
    torch.manual_seed(1)
    indices = torch.randperm(len(dataset)).tolist()
    dataset = torch.utils.data.Subset(dataset, indices[:-n_val])
    dataset_test = torch.utils.data.Subset(dataset_test, indices[-n_val:])

    # define training and validation data loaders
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=utils.collate_fn,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        collate_fn=utils.collate_fn,
    )
    
    return data_loader, data_loader_test

In [None]:
# hide

# Build the loaders with the default settings spelled out, then show the
# training loader.
(train_ds, val_ds) = _get_instance_segmentation_dataset(
    batch_size=4, val_split=0.2, num_workers=4
)
train_ds

<torch.utils.data.dataloader.DataLoader at 0x7f8a491d6588>

## Download and load dataset

In [None]:
# export


def get_dataset(
    name: str,
    *,
    get_transform: Callable[[bool], Callable] = (lambda train: None),
    batch_size: int = 4,
    val_split: float = 0.2,
    num_workers: int = 4,
) -> Tuple[
    torch.utils.data.dataloader.DataLoader, torch.utils.data.dataloader.DataLoader
]:
    """Get one of two datasets available. The parameter `name` can be one of 'segmentation' and 'classification'.

    Returns `(train_loader, val_loader)`. All keyword arguments are forwarded
    to the dataset-specific builder.

    Args:
        name: 'segmentation' or 'classification' (the latter is not
            implemented yet).
        get_transform: factory `get_transform(train: bool)` returning the
            transform applied to each `(image, target)` pair, or None.
        batch_size: batch size of both loaders.
        val_split: fraction of samples used for validation.
        num_workers: number of worker processes of both loaders.

    Raises:
        AssertionError: if `name` is not one of the supported names.
        NotImplementedError: for 'classification'.
    """

    assert name in [
        "segmentation",
        "classification",
    ], f"name should be either 'segmentation' or 'classification', but it is '{name}'."

    if name == "segmentation":
        return _get_instance_segmentation_dataset(
            get_transform=get_transform,
            batch_size=batch_size,
            val_split=val_split,
            num_workers=num_workers,
        )
    elif name == "classification":
        raise NotImplementedError("the 'classification' dataset is not available yet")

This is how to download the dataset and create the PyTorch training and validation data loaders:

In [None]:
# `get_dataset` returns a (train, validation) pair of DataLoaders.
train_ds, val_ds = get_dataset(name="segmentation")
train_ds

<torch.utils.data.dataloader.DataLoader at 0x7f8a49189a20>