# Tooth and Cavity Instance Segmentation

In [40]:
import torch
from torch.utils.data import Dataset, DataLoader

import torchvision
torchvision.disable_beta_transforms_warning()
from torchvision.datasets import (
    CocoDetection,
    wrap_dataset_for_transforms_v2
)
from torchvision.transforms import v2

from types import SimpleNamespace
from pathlib import Path
from pprint import pprint

In [39]:
# Central registry of every filesystem location used by the notebook.
# All entries hang off the single (relative) dataset root.
FPATHS = SimpleNamespace(data=Path("../data/"))

# One directory per dataset split.
FPATHS.data_train = FPATHS.data.joinpath("train")
FPATHS.data_valid = FPATHS.data.joinpath("valid")
FPATHS.data_test = FPATHS.data.joinpath("test")

# Every split directory carries its annotation file under the same file name.
FPATHS.data_annotations_fname = "_annotations.coco.json"
FPATHS.data_train_annotations = FPATHS.data_train.joinpath(FPATHS.data_annotations_fname)
FPATHS.data_valid_annotations = FPATHS.data_valid.joinpath(FPATHS.data_annotations_fname)
FPATHS.data_test_annotations = FPATHS.data_test.joinpath(FPATHS.data_annotations_fname)


loading annotations into memory...
Done (t=0.67s)
creating index...
index created!
loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [65]:
# Transform pipeline handed to the training dataset only; the validation and
# test datasets are built without any transforms.
# NOTE(review): RandomPhotometricDistort is stochastic and no seed is set in
# the visible cells, so repeated reads of the same training index can differ.
transforms = v2.Compose(
    [
        v2.ToImagePIL(),
        v2.RandomPhotometricDistort(p=1),
        v2.SanitizeBoundingBox(),
    ]
)

# Build the three splits in a fixed order (train, valid, test). Each row is
# (split name, image directory, annotation file, transforms). `None` is
# presumably CocoDetection's default for `transforms` (confirm against the
# torchvision docs), making it equivalent to omitting the argument.
datasets = SimpleNamespace(
    **{
        split: wrap_dataset_for_transforms_v2(
            CocoDetection(image_dir, annotation_file, transforms=split_transforms)
        )
        for split, image_dir, annotation_file, split_transforms in (
            ("train", FPATHS.data_train, FPATHS.data_train_annotations, transforms),
            ("valid", FPATHS.data_valid, FPATHS.data_valid_annotations, None),
            ("test", FPATHS.data_test, FPATHS.data_test_annotations, None),
        )
    }
)

loading annotations into memory...
Done (t=0.35s)
creating index...
index created!
loading annotations into memory...
Done (t=0.02s)
creating index...
index created!
loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [66]:
def plot_coco_image(dataset: Dataset, index: int = 0):
    """Fetch one ``(image, target)`` sample from ``dataset``.

    The function is currently a stub: it retrieves and unpacks the sample but
    does not draw or return anything yet.

    Args:
        dataset: Indexable dataset whose items unpack into exactly two values.
        index: Position of the sample to fetch. Defaults to the first sample.

    Raises:
        ValueError: If the fetched item is a sequence that does not hold
            exactly two values.
    """
    # TODO: render the image together with its annotations.
    sample = dataset[index]
    img, annotations = sample

In [67]:
# Peek at the first training sample: an (image, target) pair. The repr of the
# target prints every segmentation polygon coordinate, so the output is long.
datasets.train[0]

(<PIL.Image.Image image mode=RGB size=640x640>,
 {'id': [0, 1, 2, 3, 4, 5],
  'image_id': 0,
  'category_id': [4, 1, 2, 4, 1, 2],
  'bbox': [[349, 50, 232.234, 487.087],
   [509, 243, 14.144, 18.162],
   [52, 287, 222.69, 342.896],
   [0, 18, 356.087, 615.413],
   [446, 302, 17.865, 12.73],
   [282, 143, 68.479, 243.354]],
  'area': [113118.174, 256.896, 76359.346, 219140.715, 227.421, 16664.615],
  'segmentation': [[[349.703,
     316.878,
     351.78,
     363.852,
     361.898,
     405.69,
     389.281,
     475.745,
     412.205,
     503.366,
     452.103,
     529.204,
     495.873,
     537.474,
     532.802,
     507.852,
     561.842,
     457.608,
     568.252,
     409.742,
     580.765,
     365.525,
     576.165,
     267.335,
     571.858,
     199.145,
     572.903,
     181.874,
     547.303,
     128.793,
     501.901,
     77.793,
     468.554,
     59.847,
     436.098,
     50.387,
     406.319,
     66.388,
     391.576,
     89.388,
     384.873,
     107.253,
  