In [None]:
import os
import random
from functools import reduce
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from easydict import EasyDict
import torch
from torchvision import models
import numpy as np
from skimage import io
from pytorch_toolbelt.utils.rle import rle_decode
import albumentations as A
from albumentations.pytorch import ToTensorV2

from scripts.utils import annotation2mask, get_box
from scripts.dataset import CellDataset

# Seed every RNG this notebook relies on, not only NumPy:
# - albumentations draws random augmentation parameters (ShiftScaleRotate) from
#   Python's `random` and/or NumPy, depending on its version;
# - the Mask R-CNN built below without pretrained weights is randomly initialised by torch.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Links to tutorials
- Torchvision Mask R-CNN fine-tuning tutorial (describes the expected model input format): [pytorch.org](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
- Kaggle baseline notebook: [kaggle.com](https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-273)


In [None]:
# Global dataset/loader settings for this notebook (fixed, not model hyperparameters).
# The dataset root defaults to /data/kaggle_data/ but can be overridden through the
# DATASET_PATH environment variable, so the notebook also runs where the data lives elsewhere.
config = EasyDict(
    dataset_path=Path(os.environ.get("DATASET_PATH", "/data/kaggle_data/")),
    val_size=0.2,  # presumably the validation fraction used by CellDataset — confirm
    batch_size=2,
    num_workers=0,
)

In [None]:
# Per-sample albumentations pipeline. Boxes are passed in pascal_voc format
# (x_min, y_min, x_max, y_max); `category_ids` holds the per-box labels that are
# carried through the transform alongside the boxes.
# NOTE(review): despite its name, this pipeline contains a random geometric
# augmentation (ShiftScaleRotate), so it is not deterministic — confirm intended.
# NOTE(review): single-channel mean/std assume grayscale input; torchvision's
# Mask R-CNN also normalises internally, so this may normalise twice — verify.
valid_transform = A.Compose(
    transforms=[
        A.Normalize(mean=(0.485,), std=(0.229,)),
        A.ShiftScaleRotate(shift_limit=0.8),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(
        format='pascal_voc',
        label_fields=['category_ids'],
    ),
)

In [None]:
plt.rcParams.update({"figure.figsize": (30, 30)})  # large default canvas for the full-image plots below

def visualize_dataset(im, boxes, masks):
    """
    Visualization instance of dataset after augmentation
    """
    # torch image preprocessed for plotting with matplotlib
    im = np.transpose(im.numpy(), axes=(1, 2, 0))
    im = (im - im.min()) / (im.max() - im.min())
    im = np.asarray(im * 256., dtype=np.uint8)

    image_mask = reduce(lambda x, y: x + y, masks)

    image_mask[image_mask > 1] = 1
    yellow_mask = np.stack([image_mask, image_mask, np.zeros_like(image_mask)], axis=2)

    image_with_mask = np.array(im + 50 * yellow_mask, dtype=np.uint8)

    # Drawing red rectangle for each instances
    red_color = (255, 0, 0)
    for x1, y1, x2, y2 in boxes:
        image_with_mask = cv2.rectangle(image_with_mask.copy(), pt1=(int(x1), int(y1)), pt2=(int(x2), int(y2)), color=red_color, thickness=2)
    
    plt.imshow(image_with_mask)
    plt.xticks([])
    plt.yticks([])
    plt.show()

In [None]:
# Training split with the augmentation pipeline; draw sample #1 with its targets.
dataset = CellDataset(
    cfg=config,
    mode='train',
    transform=valid_transform,
)

image, data = dataset[1]
visualize_dataset(
    image,
    masks=data['masks'],
    boxes=data['boxes'],
)

In [None]:
# Do not change the collate function - it was taken from the torchvision tutorial.
# It regroups a batch of (image, target) pairs into (images, targets) tuples, so
# per-image targets with a variable number of boxes are never stacked.
# The dataset built in the cell above is reused rather than constructed a second time.
dataloader = DataLoader(
    dataset=dataset,
    num_workers=config.num_workers,
    batch_size=config.batch_size,
    collate_fn=lambda batch: tuple(zip(*batch)),
)

# 2 classes: 0 - background, 1 - cell
model = models.detection.maskrcnn_resnet50_fpn(num_classes=2, progress=False)
model.train()

# Smoke test: a single forward pass on one batch. In train mode torchvision
# detection models take lists of images/targets and return a dict of losses.
images, targets = next(iter(dataloader))
loss_dict = model(list(images), list(targets))
{name: loss.item() for name, loss in loss_dict.items()}