In [11]:
import os
from functools import reduce
import logging
from pathlib import Path


from torch.utils.data import Dataset, DataLoader
from easydict import EasyDict
import torch
from torchvision import models
import numpy as np

import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from datetime import datetime
from dotenv import load_dotenv

from src.dataset import CellDataset
from src.postprocessing import postprocess_predictions
from src.iou_metric import iou_map
logging.basicConfig(filename='train.log', level=logging.INFO, format='%(asctime)s %(message)s')
np.random.seed(0)

load_dotenv()

True

In [12]:
current_dir = Path(".") # In my case, it is sartorius_instance_segmentation
current_dir.absolute()

PosixPath('/home/shamil/PycharmProjects/sartorius_instance_segmentation')

## Links to tutorials
- Torchvision maskrcnn inputting: [pytorch.org](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
- Kaggle baseline notebook: [kaggle.com](https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-273)


In [13]:
# Global config of dataset, not tunable parameters
config = EasyDict(
    dataset_path=Path(os.environ["dataset_path"]),
    device="cuda:0",
    val_size=0.2,
    batch_size=6,
    num_workers=30,
    max_epochs=40,
    mask_threshold=0.5,
    score_threshold=0.2,
    nms_threshold=None,

)

## This is for Shamil's local running
# config = EasyDict(
#     dataset_path=Path(os.environ["dataset_path"]),
#     device="cpu",
#     val_size=0.2,
#     batch_size=1,
#     num_workers=2,
#     max_epochs=1,
# )


In [14]:
# pascal_voc - inputting bbox coord in format (xmin, ymin, xmax, ymax)
valid_transform = A.Compose([
    A.Normalize(mean=(0.485,), std=(0.229,)),
    # A.ShiftScaleRotate(shift_limit=0.8, border_mode=cv2.BORDER_CONSTANT),
    ToTensorV2(),
], bbox_params=A.BboxParams(format='pascal_voc', label_fields=['category_ids']))

test_transform = A.Compose([
    A.Normalize(mean=(0.485,), std=(0.229,)),
    ToTensorV2(),
])

In [15]:
dataset = CellDataset(cfg=config, mode='train', transform=valid_transform)
image, data = dataset[1]

In [16]:
# Do not change collate function - it was takes from torchvision tutorials
train_dataloader = DataLoader(
    dataset=CellDataset(cfg=config, mode='train', transform=valid_transform),
    num_workers=config.num_workers,
    batch_size=config.batch_size,
    collate_fn=lambda x: tuple(zip(*x))
)

val_dataloader = DataLoader(
    dataset=CellDataset(cfg=config, mode='val', transform=valid_transform),
    num_workers=config.num_workers,
    batch_size=config.batch_size,
    collate_fn=lambda x: tuple(zip(*x))
)

In [17]:
device = config.device

In [18]:
def train(model, optimizer):
    for epoch in range(config.max_epochs):
        # train
        losses = []
        mask_losses = []
        model.train()
        for batch_idx, (images, targets) in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            break
            # output has keys 'loss classifier', 'loss_box_reg', 'loss_mask', 'loss_objectness', 'loss_rpn_box_reg'
            # second one is Faster R-CNN bounding box prediction loss
            # last one is Region Proposal Network loss, RPN proposes candidate object BBoxes for Faster R-CNN
            # 4th - I don't know what is it
            optimizer.zero_grad()
            images = list(image.to(device) for image in images)
            targets = [{key: value.to(device) for key, value in target.items()} for target in targets]

            output = model(images, targets)
            loss = sum(single_loss for single_loss in output.values())

            loss.backward()
            optimizer.step()
            
            losses.append(loss.item())
            mask_losses.append(output['loss_mask'].item())
        logging.info(f"Epoch {epoch}: Mean train epoch loss is {np.mean(losses)}, mask loss is {np.mean(mask_losses)}")
    
        # Calculating loss metrics on validation
        losses = []
        mask_losses = []
        model.train()
        for batch_idx, (images, targets) in tqdm(enumerate(val_dataloader), total=len(val_dataloader)):
            break
            images = list(image.to(device) for image in images)
            targets = [{key: value.to(device) for key, value in target.items()} for target in targets]

            # Calculating metrics on validation

            with torch.no_grad():
                output = model(images, targets)

            loss = sum(single_loss for single_loss in output.values())

            losses.append(loss.item())
            mask_losses.append(output['loss_mask'].item())

        logging.info(f"Epoch {epoch}: Mean validation  loss is {np.mean(losses)}, mask loss is {np.mean(mask_losses)}")
        
        # val visualization
        model.eval()
        iou_scores = []
        for batch_idx, (images, targets) in tqdm(enumerate(val_dataloader), total=len(val_dataloader), desc="Calculating map of validation"):
            images = list(image.to(device) for image in images)
            targets = [{key: value.to(device) for key, value in target.items()} for target in targets]

            with torch.no_grad():
                outputs = model(images)

            outputs = postprocess_predictions(
                outputs,
                mask_threshold=config.mask_threshold,
                score_threshold=config.score_threshold,
                nms_threshold=config.nms_threshold
            )

            # Iterating through each image in batch
            for output, ground_truth in zip(outputs, targets):
                pred_masks = output['masks']
                true_masks = ground_truth['masks'].cpu().numpy()
                score = iou_map(true_masks=true_masks, pred_masks=pred_masks)
                print(f"map score={score}")

                iou_scores.append(score)

        logging.info(f"Epoch: {epoch}: map score: {np.mean(iou_scores)}")


    weights_dir = current_dir / "weights"
    weights_dir.mkdir(exist_ok=True)
    torch.save(model.state_dict(), weights_dir / f"maskrcnn-{datetime.now().__str__()}.ckpt")
    print(f"saved the weights in weights/maskrcnn-{datetime.now().__str__()} folder!")

In [19]:
# 2 classes: 0 - background, 1 - cell
model = models.detection.maskrcnn_resnet50_fpn(num_classes=2, progress=False)
model.to(device)
print()




In [20]:
optimizer = torch.optim.Adam(model.parameters())
train(model=model, optimizer=optimizer)

  0%|          | 0/484 [00:00<?, ?it/s]
  0%|          | 0/122 [00:07<?, ?it/s]
Calculating map of validation:   0%|          | 0/122 [04:33<?, ?it/s]


KeyboardInterrupt: 