In [1]:
import os
from functools import reduce
import logging
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from easydict import EasyDict
import torch
from torchvision import models
import numpy as np
from skimage import io
from pytorch_toolbelt.utils.rle import rle_decode
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from datetime import datetime
from dotenv import load_dotenv

from src.utils import annotation2mask, get_box
from src.postprocessing import remove_overlapping_pixels
from src.dataset import CellDataset
from src.visualization import plot_two_masks

# Seed every random number generator this notebook relies on, so that
# Restart & Run All gives repeatable results. Seeding NumPy alone is not
# enough: torch has its own generator, which randomly initialised model
# layers draw from.
np.random.seed(0)
torch.manual_seed(0)

# Load variables from a local .env file into os.environ; the config cell
# reads os.environ["dataset_path"] from there.
load_dotenv()

True

In [2]:
# Project root, expressed relative to the directory the notebook is started
# from (in my case this is sartorius_instance_segmentation).
current_dir = Path()
current_dir.absolute()

PosixPath('/workspaces/sartorius_instance_segmentation')

In [3]:
# Route INFO-and-above log records to train.log, each line prefixed with a timestamp.
logging.basicConfig(
    format='%(asctime)s %(message)s',
    level=logging.INFO,
    filename='train.log',
)

## Links to tutorials
- Torchvision Mask R-CNN input format (fine-tuning tutorial): [pytorch.org](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
- Kaggle baseline notebook: [kaggle.com](https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-273)


In [4]:
# Global run configuration shared by the dataset, the dataloaders and the
# training loop. These are not tunable hyper-parameters.
config = EasyDict(
    # Dataset root; must be provided through the `dataset_path` environment
    # variable (a missing variable raises KeyError here).
    dataset_path=Path(os.environ["dataset_path"]),
    # Fall back to the CPU so the notebook still runs without a CUDA device.
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    # presumably the fraction of samples held out for validation — consumed by CellDataset, TODO confirm
    val_size=0.2,
    batch_size=6,
    # Never ask for more loader processes than the machine has CPU cores.
    num_workers=min(30, os.cpu_count() or 1),
    max_epochs=40,
)


In [5]:
# Both pipelines normalise the image with a one-element mean/std and convert it
# to a torch tensor. No geometric augmentation is applied (a ShiftScaleRotate
# step with shift_limit=0.8 and a constant border was tried and is disabled).
#
# The validation pipeline also carries bounding boxes through the transform.
# They are given in the pascal_voc layout, i.e. (xmin, ymin, xmax, ymax), and
# their labels travel in the `category_ids` field.
valid_transform = A.Compose(
    [
        A.Normalize(mean=(0.485,), std=(0.229,)),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(
        format='pascal_voc',
        label_fields=['category_ids'],
    ),
)

# The test pipeline has no bounding boxes to transform.
test_transform = A.Compose(
    [
        A.Normalize(mean=(0.485,), std=(0.229,)),
        ToTensorV2(),
    ]
)

In [6]:
# Build the training split and pull a single sample out of it as a quick
# smoke test of the dataset + transform pipeline.
dataset = CellDataset(
    cfg=config,
    transform=valid_transform,
    mode='train',
)
image, data = dataset[1]

  'masks': torch.as_tensor(masks),


In [7]:
def collate_batch(batch):
    """Turn a list of (image, target) samples into an (images, targets) pair of tuples.

    Do not change this function - it was taken from the torchvision detection
    tutorial. Samples are kept as tuples rather than stacked into one tensor.
    It is a single named function so both loaders share one definition instead
    of two copies of the same lambda.
    """
    return tuple(zip(*batch))


train_dataloader = DataLoader(
    dataset=CellDataset(cfg=config, mode='train', transform=valid_transform),
    num_workers=config.num_workers,
    batch_size=config.batch_size,
    # Visit the training samples in a new random order every epoch.
    shuffle=True,
    collate_fn=collate_batch,
)

# Validation keeps the dataset order so that per-epoch results are comparable.
val_dataloader = DataLoader(
    dataset=CellDataset(cfg=config, mode='val', transform=valid_transform),
    num_workers=config.num_workers,
    batch_size=config.batch_size,
    collate_fn=collate_batch,
)

In [8]:
# Device used for the model and for every batch, taken from the global config.
device = config["device"]

In [9]:
from torchvision.ops import nms
def visualize_masks_from_loader(model, dataloader, name):
    """Save the first predicted masks for the first batch of `dataloader`.

    The model is switched to eval mode and run on the first batch only. Up to
    three mask maps predicted for the first image of that batch are drawn side
    by side and written to ``imagelogs/<name>.jpg``.

    Args:
        model: detection model whose eval-mode output is a list of dicts
            holding a 'masks' entry (torchvision Mask R-CNN style).
        dataloader: iterable yielding ``(images, targets)`` batches; the
            targets are ignored.
        name: file stem of the image written to the ``imagelogs`` directory.

    Returns:
        None. `model` is left in eval mode. Nothing is drawn when the
        dataloader yields no batch.
    """
    os.makedirs("imagelogs", exist_ok=True)
    model.eval()

    # Only the first batch is visualised, so fetch just that one.
    first_batch = next(iter(dataloader), None)
    if first_batch is None:
        return
    images, _ = first_batch

    # Inputs have to live on the same device as the model's weights.
    model_device = next(model.parameters()).device
    images = [image.to(model_device) for image in images]
    with torch.no_grad():  # inference only: do not build an autograd graph
        output = model(images)

    # torchvision returns masks as (N, 1, H, W). Squeeze the channel axis only:
    # a bare squeeze() would also drop N when exactly one instance is predicted.
    output_mask = output[0]['masks'].squeeze(1).detach().cpu().numpy()

    n_panels = 3
    plt.figure(figsize=(20, 9))
    # The model may return fewer than n_panels instances (possibly none).
    for i in range(min(n_panels, len(output_mask))):
        plt.subplot(1, n_panels, i + 1)
        plt.imshow(output_mask[i])
    plt.savefig('imagelogs/' + name + '.jpg', dpi=200)

def predict_masks(
    image: torch.Tensor,
    model,
    score_threshold: float = 0.0,
    nms_threshold: float = 0.1,
    mask_threshold: float = 0.5,
) -> np.ndarray:
    """Predicts non-overlapping binary instance masks for the given single image.

    Args:
        image: image tensor; it is moved to the model's device before inference.
        model: detection model, already in eval mode, that returns for each
            input image a dict with 'scores', 'masks' and 'boxes'
            (torchvision Mask R-CNN style).
        score_threshold: detections scoring below this value are dropped. With
            the default of 0.0 every prediction is counted, even low-scoring ones.
        nms_threshold: IoU threshold for box NMS; overlapping instances are
            dropped, and the lower the value, the less overlap is permitted.
        mask_threshold: soft masks are cut to binary masks at this value.

    Returns:
        The kept masks after `remove_overlapping_pixels`, one mask per
        instance along axis 0. When no detection survives the filtering, an
        empty ``(0, H, W)`` array is returned.

    Raises:
        AssertionError: if a pixel is still claimed by more than one mask.
    """
    device = next(model.parameters()).device
    with torch.no_grad():
        output = model([image.to(device)])[0]

    scores = output['scores'].detach().cpu()
    boxes = output['boxes'].detach().cpu()
    # torchvision returns masks as (N, 1, H, W). Squeeze the channel axis only:
    # a bare squeeze() would also drop N when exactly one instance is detected,
    # after which the per-instance indexing below no longer lines up.
    masks = output['masks'].squeeze(1).detach().cpu()

    masks = (masks >= mask_threshold).int()

    # Thresholding can leave some masks empty (all zeros); exclude those, along
    # with low-scoring detections. The mask is boolean even for zero
    # detections, so the indexing below stays valid for N == 0.
    keep = (masks.flatten(start_dim=1).sum(dim=1) > 0) & (scores >= score_threshold)
    masks, boxes, scores = masks[keep], boxes[keep], scores[keep]

    # nms returns the indices of the surviving boxes, sorted by decreasing score.
    keep = nms(boxes, scores, nms_threshold)
    masks = masks[keep]

    if len(masks) == 0:
        # Nothing to de-overlap. remove_overlapping_pixels is defined outside
        # this notebook, so do not rely on how it treats an empty stack.
        return masks.numpy()

    answer_masks = remove_overlapping_pixels(masks.numpy())
    assert np.max(np.sum(answer_masks, axis=0)) <= 1, "Masks overlap"
    return answer_masks
    
def train(model, optimizer):
    """Train `model` for `config.max_epochs` epochs, then save its weights.

    Every epoch consists of:
      1. one optimisation pass over `train_dataloader`;
      2. one loss-only pass over `val_dataloader` (no gradients, no updates);
      3. a qualitative check: masks predicted for the first validation image
         are plotted next to the ground truth in ``imagelogs/val{epoch}.jpg``.
    Mean total loss and mean mask loss of both passes are written to the log.

    Relies on the notebook globals `config`, `device`, `train_dataloader`,
    `val_dataloader` and `current_dir`.

    Args:
        model: detection model that, in train mode, returns a dict of losses
            (including 'loss_mask') when called as ``model(images, targets)``.
        optimizer: optimizer over the model's parameters.

    Returns:
        None. The weights are saved to ``weights/maskrcnn-<timestamp>.ckpt``
        under `current_dir`.
    """
    # The per-epoch validation image is written here; make sure it exists.
    os.makedirs("imagelogs", exist_ok=True)

    for epoch in range(config.max_epochs):
        # train
        model.train()
        losses = []
        mask_losses = []
        for images, targets in tqdm(train_dataloader, total=len(train_dataloader)):
            # In train mode the model returns a dict of losses:
            #   loss_classifier  - classification loss of the box head
            #   loss_box_reg     - Faster R-CNN bounding box regression loss
            #   loss_mask        - loss of the mask head
            #   loss_objectness  - RPN loss for scoring anchors as object vs. background
            #   loss_rpn_box_reg - RPN loss for regressing the candidate boxes
            #                      it proposes to Faster R-CNN
            optimizer.zero_grad()
            images = [image.to(device) for image in images]
            targets = [{key: value.to(device) for key, value in target.items()} for target in targets]

            output = model(images, targets)
            loss = sum(output.values())

            loss.backward()
            optimizer.step()

            losses.append(loss.item())
            mask_losses.append(output['loss_mask'].item())
        logging.info(f"Epoch {epoch}: Mean train epoch loss is {np.mean(losses)}, mask loss is {np.mean(mask_losses)}")

        # val
        # The model only returns its losses in train mode, so it deliberately
        # stays in train mode here; no_grad() guarantees nothing is learned.
        # NOTE(review): layers with train-time behaviour (e.g. non-frozen
        # BatchNorm) would still behave as in training - confirm for this model.
        model.train()
        losses = []
        mask_losses = []
        first_val_batch = None
        with torch.no_grad():
            for images, targets in tqdm(val_dataloader, total=len(val_dataloader)):
                if first_val_batch is None:
                    # Keep the batch as yielded by the loader (before the move
                    # to `device`) for the visualization below. This avoids
                    # starting a second loader iteration for a single batch.
                    first_val_batch = (images, targets)
                images = [image.to(device) for image in images]
                targets = [{key: value.to(device) for key, value in target.items()} for target in targets]

                output = model(images, targets)
                loss = sum(output.values())

                losses.append(loss.item())
                mask_losses.append(output['loss_mask'].item())
        logging.info(f"Epoch {epoch}: Mean validation  loss is {np.mean(losses)}, mask loss is {np.mean(mask_losses)}")

        # val visualization
        model.eval()
        if first_val_batch is not None:
            with torch.no_grad():
                images, target = first_val_batch
                image = images[0]
                masks = predict_masks(image, model)
                plot_two_masks(image, masks, target[0]['masks'], filename=f'imagelogs/val{epoch}.jpg')

    weights_dir = current_dir / "weights"
    weights_dir.mkdir(exist_ok=True)
    # Build the path once, so the message names exactly the file that was written.
    weights_path = weights_dir / f"maskrcnn-{datetime.now()}.ckpt"
    torch.save(model.state_dict(), weights_path)
    print(f"saved the weights in {weights_path}!")

In [10]:
# Mask R-CNN with two output classes: 0 is background, 1 is cell.
model = models.detection.maskrcnn_resnet50_fpn(progress=False, num_classes=2)
model = model.to(device)
# Keeps the original cell output: a single blank line instead of the model repr.
print()




In [None]:
# Adam with its default hyper-parameters over all model parameters,
# followed by the full training run.
optimizer = torch.optim.Adam(params=model.parameters())
train(model, optimizer)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100% 81/81 [03:13<00:00,  2.39s/it] 
100% 21/21 [01:07<00:00,  3.23s/it]
  1% 1/81 [00:38<51:24, 38.56s/it]