In [13]:
import os
from functools import reduce
import logging
from pathlib import Path


from torch.utils.data import Dataset, DataLoader
from easydict import EasyDict
import torch
from torchvision import models
import numpy as np
from time import time
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from datetime import datetime
from dotenv import load_dotenv

from src.dataset import CellDataset
from src.postprocessing import postprocess_predictions
from src.iou_metric import iou_map
logging.basicConfig(filename='train.log', level=logging.INFO, format='%(asctime)s %(message)s')
# Seed every RNG the notebook draws from: numpy, and torch (used by DataLoader
# shuffling and by model weight initialisation). Seeding only numpy is not enough
# for reproducible runs.
np.random.seed(0)
torch.manual_seed(0)

load_dotenv()

True

In [14]:
# Relative path to the notebook's working directory (here the project root,
# e.g. sartorius_instance_segmentation); checkpoints are written beneath it.
current_dir = Path()
current_dir.absolute()

PosixPath('/home/shamil/PycharmProjects/sartorius_instance_segmentation')

In [15]:
import wandb

# `train` calls `wandb.log` on every step, which fails unless a run has been started,
# so the run must be initialised here for "Restart & Run All" to work.
# Set WANDB_MODE=offline (or WANDB_MODE=disabled) in the environment to avoid syncing.
wandb.init(project="sartorius_instance_segmentation", entity="implausible_denyability", name="first_baseline")

## Links to tutorials
- Torchvision Mask R-CNN input format: [pytorch.org](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
- Kaggle baseline notebook: [kaggle.com](https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-273)


In [16]:
# Global run configuration: data location, loader settings, training length and
# post-processing thresholds. `dataset_path` comes from the environment (populated
# by `load_dotenv()`).
config = EasyDict(
    dataset_path=Path(os.environ["dataset_path"]),
    # Fall back to CPU so the notebook still runs on machines without a GPU.
    device="cuda:0" if torch.cuda.is_available() else "cpu",
    val_size=0.2,  # presumably the held-out validation fraction; confirm in CellDataset
    batch_size=4,
    # Never request more DataLoader workers than there are CPU cores.
    num_workers=min(30, os.cpu_count() or 1),
    max_epochs=40,
    mask_threshold=0.5,
    score_threshold=0.2,
    nms_threshold=None,
)

# config = EasyDict(
#     dataset_path=Path(os.environ["dataset_path"]),
#     device="cpu",
#     val_size=0.2,
#     batch_size=2,
#     num_workers=2,
#     max_epochs=40,
#     mask_threshold=0.5,
#     score_threshold=0.2,
#     nms_threshold=None,
# )

In [17]:
from src.augmentations import eval_transform
# Training-time pipeline: normalise with a single mean/std value (presumably grayscale
# input; confirm), then convert to a tensor. Boxes travel through in pascal_voc
# (x_min, y_min, x_max, y_max) format, with their labels under `category_ids`.
# Disabled augmentation (would also need `import cv2`):
#   A.ShiftScaleRotate(shift_limit=0.8, border_mode=cv2.BORDER_CONSTANT),
train_transform = A.Compose(
    [
        A.Normalize(mean=(0.485,), std=(0.229,)),
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(
        format='pascal_voc',
        label_fields=['category_ids'],
    ),
)

In [18]:
def collate_fn(batch):
    """Transpose a list of (image, target) samples into an (images, targets) pair of tuples.

    Detection models take variable-sized images and per-image target dicts, so samples
    are not stacked into one tensor. This is the same collate logic as the torchvision
    detection tutorial. It is a named function rather than a lambda so that it can be
    pickled when DataLoader workers are started with the "spawn" method.
    """
    return tuple(zip(*batch))


train_dataloader = DataLoader(
    dataset=CellDataset(cfg=config, mode='train', transform=train_transform),
    num_workers=config.num_workers,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

# Validation uses the evaluation transform (no training-time augmentation) and a
# fixed order, so validation scores are comparable between epochs.
val_dataloader = DataLoader(
    dataset=CellDataset(cfg=config, mode='val', transform=eval_transform),
    num_workers=config.num_workers,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

In [19]:
# Device string used by the *_to_device helpers and for placing the model.
device = config["device"]

In [20]:
def images_to_device(images):
    """Return a new list containing each image tensor moved to the global `device`."""
    return [img.to(device) for img in images]

def targets_to_device(targets):
    """Return new target dicts with every value moved to the global `device`.

    Keys are kept unchanged; every value must support `.to(device)` (i.e. tensors).
    """
    moved_targets = []
    for target in targets:
        moved_targets.append(dict((key, tensor.to(device)) for key, tensor in target.items()))
    return moved_targets

In [21]:
def train_batch(model, images, targets, optimizer):
    """Run a single optimisation step on one batch.

    Moves the batch to `device` and calls the model with targets, so it returns a dict
    of loss tensors. All loss terms are summed and back-propagated, then the optimizer
    is stepped.

    Returns:
        (total_loss, mask_loss) as Python floats, where mask_loss is the 'loss_mask' entry.
    """
    optimizer.zero_grad()
    loss_dict = model(images_to_device(images), targets_to_device(targets))
    total_loss = sum(loss_dict.values())
    total_loss.backward()
    optimizer.step()
    return total_loss.item(), loss_dict['loss_mask'].item()

In [22]:
def eval_batch(model, images, targets):
    """Score one validation batch with the mean of per-image `iou_map` values.

    Runs `model` on the images without gradient tracking and post-processes the raw
    predictions with the thresholds in `config`. Each prediction's masks are then
    compared against the matching target's masks.

    Args:
        model: detection model. This function does not call `model.eval()`, so the
            caller must put the model in the mode that returns predictions.
        images: sequence of image tensors (moved to `device` here).
        targets: sequence of dicts whose "masks" entry is a CPU tensor (converted with .numpy()).

    Returns:
        np.mean of the per-image scores.
    """
    device_images = images_to_device(images)
    with torch.no_grad():
        raw_outputs = model(device_images)
    predictions = postprocess_predictions(
        raw_outputs,
        mask_threshold=config.mask_threshold,
        score_threshold=config.score_threshold,
        nms_threshold=config.nms_threshold,
    )
    per_image_scores = [
        iou_map(pred_masks=prediction['masks'], true_masks=target['masks'].numpy())
        for prediction, target in zip(predictions, targets)
    ]
    return np.mean(per_image_scores)

In [23]:
def train(model, optimizer, scheduler):
    """Train `model` for `config.max_epochs` epochs, then save its final weights.

    Every batch of `train_dataloader` goes through `train_batch`. The total loss, the
    mask loss and the current learning rate (as a scalar) are logged to wandb.

    The scheduler is stepped once per epoch, after that epoch's optimizer updates, so
    the first epoch runs at the schedule's initial learning rate.

    The final state_dict is written to `<current_dir>/weights/maskrcnn-<timestamp>.ckpt`.
    Validation is currently disabled.

    Args:
        model: detection model, expected to already be on `device`.
        optimizer: optimizer over `model`'s parameters.
        scheduler: epoch-based LR scheduler wrapping `optimizer`.

    Returns:
        pathlib.Path of the saved checkpoint.
    """
    for epoch in range(config.max_epochs):
        model.train()
        batches = tqdm(
            train_dataloader,
            total=len(train_dataloader),
            desc=f"epoch {epoch + 1}/{config.max_epochs}",
        )
        for images, targets in batches:
            loss, mask_loss = train_batch(model, images, targets, optimizer)
            # get_last_lr() returns one value per param group; log the first as a scalar.
            wandb.log({"loss/train": loss, "mask_loss/train": mask_loss, "lr": scheduler.get_last_lr()[0]})
        # Validation (disabled):
        # model.eval()
        # iou_scores = [eval_batch(model, images, targets) for images, targets in val_dataloader]

        # Step after this epoch's optimizer updates (PyTorch >= 1.1 ordering); stepping
        # before them would skip the first value of the LR schedule.
        scheduler.step()

    weights_dir = current_dir / "weights"
    weights_dir.mkdir(exist_ok=True)
    # Take the timestamp once so that the saved file and the printed path agree. Avoid
    # spaces and colons, which are awkward (or invalid on Windows) in file names.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    checkpoint_path = weights_dir / f"maskrcnn-{timestamp}.ckpt"
    torch.save(model.state_dict(), checkpoint_path)
    print(f"saved the weights to {checkpoint_path}")
    return checkpoint_path

In [24]:
# Mask R-CNN with a 2-class head (presumably background + cell; confirm with CellDataset labels).
model = models.detection.maskrcnn_resnet50_fpn(num_classes=2, progress=False).to(device)

optimizer = torch.optim.Adam(params=model.parameters())
# One cosine cycle spanning the whole run: T_0 = max_epochs and `train` steps once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=config.max_epochs)
train(model=model, optimizer=optimizer, scheduler=scheduler)

  1%|          | 3/242 [01:05<1:26:59, 21.84s/it]


KeyboardInterrupt: 