In [1]:
import os
from functools import reduce
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from easydict import EasyDict
import torch
from torchvision import models
import numpy as np
from skimage import io
from pytorch_toolbelt.utils.rle import rle_decode
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from datetime import datetime
from src.utils import annotation2mask, get_box
from src.dataset import CellDataset
from src.visualization import show_image
from dotenv import load_dotenv

np.random.seed(0)

load_dotenv()

True

In [2]:
current_dir = Path(".") # In my case, it is sartorius_instance_segmentation
current_dir.absolute()

PosixPath('/home/shamil/PycharmProjects/sartorius_instance_segmentation')

## Links to tutorials
- Torchvision maskrcnn inputting: [pytorch.org](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
- Kaggle baseline notebook: [kaggle.com](https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-273)


In [3]:
# Global config of dataset, not tunable parameters
config = EasyDict(
    dataset_path=Path(os.environ["dataset_path"]),
    device="cuda:0",
    val_size=0.2,
    batch_size=2,
    num_workers=8,
    max_epochs=40,
)


In [4]:
# pascal_voc - inputting bbox coord in format (xmin, ymin, xmax, ymax)
valid_transform = A.Compose([
    A.Normalize(mean=(0.485,), std=(0.229,)),
    # A.ShiftScaleRotate(shift_limit=0.8, border_mode=cv2.BORDER_CONSTANT),
    ToTensorV2(),
], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['category_ids']))

In [5]:
dataset = CellDataset(cfg=config, mode='train', transform=valid_transform)
image, data = dataset[1]
# show_image(image, boxes=data['boxes'], masks=data['masks'])

torch.Size([1, 520, 704])


In [6]:
# Do not change collate function - it was takes from torchvision tutorials
train_dataloader = DataLoader(
    dataset=CellDataset(cfg=config, mode='train', transform=valid_transform),
    num_workers=config.num_workers,
    batch_size=config.batch_size,
    collate_fn=lambda x: tuple(zip(*x))
)

val_dataloader = DataLoader(
    dataset=CellDataset(cfg=config, mode='val', transform=valid_transform),
    num_workers=config.num_workers,
    batch_size=config.batch_size,
    collate_fn=lambda x: tuple(zip(*x))
)

In [7]:
device = config.device

In [8]:
def train(model, optimizer):
    for epoch in range(config.max_epochs):

        model.train()
        for batch_idx, (images, targets) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            # output has keys 'loss classifier', 'loss_box_reg', 'loss_mask', 'loss_objectness', 'loss_rpn_box_reg'
            # second one is Faster R-CNN bounding box prediction loss
            # last one is Region Proposal Network loss, RPN proposes candidate object BBoxes for Faster R-CNN
            # 4th - I don't know what is it
            optimizer.zero_grad()
            images = list(image.to(device) for image in images)
            targets = [{key: value.to(device) for key, value in target.items()} for target in targets]

            output = model(images, targets)
            loss = output['loss_classifier'] + output['loss_box_reg'] + output['loss_mask'] + output[
                'loss_objectness'] + output['loss_rpn_box_reg']

            print(f"train loss, on iter: {batch_idx} is {loss.item()}")

            loss.backward()
            optimizer.step()

        # model.eval()
        # for batch_idx, (images, targets) in tqdm(enumerate(val_dataloader)):
        #     break

    weights_dir = current_dir / "weights"
    weights_dir.mkdir(exist_ok=True)
    torch.save(model.state_dict(), weights_dir / f"maskrcnn-{datetime.now().__str__()}.ckpt")
    print(f"saved the weights in weights/maskrcnn-{datetime.now().__str__()} folder!")


# 2 classes: 0 - background, 1 - cell
model = models.detection.maskrcnn_resnet50_fpn(num_classes=2, progress=False)
model.to(device)

optimizer = torch.optim.Adam(model.parameters())
train(model=model, optimizer=optimizer)

0it [00:00, ?it/s]

train loss, on iter: 0 is 5.175042152404785


0it [00:25, ?it/s]


Test images, making kaggle submission