# Explore a COCO dataset

## Arguments

In [None]:
# Dataset to explore; it must be in COCO format.
data_path = 'datasets/ws/'

# Data loading
workers = 4  # num_workers for both DataLoaders

# Drawing / display settings
DPI = 220  # written to matplotlib.rcParams['figure.dpi']
vert_size = 500  # Set to None for no scaling of images
line_width = 3  # box line width passed to utils.draw_boxes
draw_labels = True  # whether utils.draw_boxes writes label names on boxes
draw_threshold = 0.5  # NOTE(review): not read by any visible cell

## Code

In [None]:
import time
from coco_utils import get_coco  # get_coco_kp
from torchvision import transforms
import torchvision
import torchvision.models.detection
import transforms as T
import torch
import utils
from matplotlib.pyplot import figure, imshow, show
import matplotlib
import numpy as np


convert_to_pil = torchvision.transforms.ToPILImage()

def get_transform(train):
    transforms = []
    transforms.append(T.ToTensor())
    return T.Compose(transforms)

# Datasets. get_coco returns a 3-tuple. The class count and label names are
# taken from the train split; the val split's copies are discarded.
dataset_train, num_classes, label_names = get_coco(data_path, image_set='train')
dataset_test, _, _ = get_coco(data_path, image_set='val')

# Samplers: both splits are walked in their stored order, never shuffled.
train_sampler, test_sampler = (
    torch.utils.data.SequentialSampler(split)
    for split in (dataset_train, dataset_test)
)
train_batch_sampler = torch.utils.data.BatchSampler(
    train_sampler, batch_size=1, drop_last=True
)

# Loaders: one image per batch in both; worker count and collate_fn are shared.
_loader_kwargs = {'num_workers': workers, 'collate_fn': utils.collate_fn}
data_loader_train = torch.utils.data.DataLoader(
    dataset_train, batch_sampler=train_batch_sampler, **_loader_kwargs
)
data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, sampler=test_sampler, **_loader_kwargs
)

In [None]:
matplotlib.rcParams['figure.dpi'] = DPI  # This has to be run in a new cell for some reason

def print_dataset_samples(data_loader):
    images_evaluated = 0
    for image, targets in data_loader:
        image = list(img.to('cpu') for img in image)
        targets = [{k: v.to('cpu') for k, v in t.items()} for t in targets]

        boxes = targets[0]['boxes']
        labels = targets[0]['labels']
        image_with_boxes = utils.draw_boxes(
            image[0], boxes, labels, label_names, vert_size=vert_size,
            line_width=line_width, draw_label=draw_labels
        )
        print(f"\nImage number {images_evaluated} | Image size:{image[0].shape}")
        figure()
        imshow(np.asarray(convert_to_pil(image_with_boxes)))
        show()
        images_evaluated += 1


In [None]:
# Pick one split per run: data_loader_train or data_loader_test.
# print_dataset_samples moves each image to the CPU before drawing, so GPU
# memory is not involved. The cost is one rendered figure per image in the
# chosen loader.
loader_to_show = data_loader_train
print_dataset_samples(loader_to_show)