In [1]:
import os
from functools import reduce
import logging
from pathlib import Path


from torch.utils.data import Dataset, DataLoader
from easydict import EasyDict
import torch
from torchvision import models
import numpy as np
from time import time
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
from datetime import datetime
from dotenv import load_dotenv

from src.dataset import CellDataset
from src.postprocessing import postprocess_predictions
from src.iou_metric import fast_iou, iou_map
# Run-wide setup: log to train.log, seed the RNGs, and load variables from .env
# (e.g. `dataset_path`, read by the config cell).
logging.basicConfig(filename='train.log', level=logging.INFO, format='%(asctime)s %(message)s')
SEED = 0
np.random.seed(SEED)
# torch keeps its own generator: it drives weight initialization and DataLoader
# shuffling, so seeding numpy alone does not make runs reproducible.
torch.manual_seed(SEED)

load_dotenv()

True

In [2]:
# Working directory of the kernel; artifacts such as weights/ are created
# relative to it (expected to be the repository root — adjust if not).
current_dir = Path()
current_dir.absolute()

PosixPath('/workspaces/sartorius_instance_segmentation')

In [3]:
# Experiment tracking: the wandb.log calls made during training go to this run.
import wandb

# The run name is also embedded in the saved checkpoint file name.
experiment_name = "first_baseline_with_metric"
wandb.init(
    project="sartorius_instance_segmentation",
    entity="implausible_denyability",
    name=experiment_name,
)

[34m[1mwandb[0m: Currently logged in as: [33mimplausible_denyability[0m (use `wandb login --relogin` to force relogin)


## Links to tutorials
- Torchvision Mask R-CNN finetuning tutorial (model input/target format): [pytorch.org](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html)
- Kaggle baseline notebook: [kaggle.com](https://www.kaggle.com/julian3833/sartorius-starter-torch-mask-r-cnn-lb-0-273)


In [4]:
# Global run configuration. Paths and hardware come from the environment
# (.env is loaded in the first cell); the remaining entries are training and
# post-processing hyper-parameters that can be tuned per experiment.
config = EasyDict(
    dataset_path=Path(os.environ["dataset_path"]),
    # Overridable via a `device` environment/.env entry so the notebook also runs
    # on machines without a second GPU; the default keeps the original choice.
    device=os.environ.get("device", "cuda:1"),
    val_size=0.2,  # presumably the validation fraction used by CellDataset — confirm
    batch_size=6,
    num_workers=30,  # DataLoader worker processes
    max_epochs=40,
    # Thresholds forwarded to postprocess_predictions during evaluation.
    mask_threshold=0.5,
    score_threshold=0.2,
    nms_threshold=None,
)



In [5]:
# Per-sample preprocessing. Despite its name, this pipeline is also passed to
# the validation dataset.
# NOTE(review): torchvision's Mask R-CNN normalizes inputs again inside its
# GeneralizedRCNNTransform, so images may end up normalized twice — confirm.
train_transform = A.Compose(
    [
        # One mean/std pair: assumes single-channel (grayscale) images — TODO confirm.
        A.Normalize(mean=(0.485,), std=(0.229,)),
        # TODO: geometric augmentation, e.g.
        #   A.ShiftScaleRotate(shift_limit=0.8, border_mode=cv2.BORDER_CONSTANT)
        # (requires `import cv2`, which this notebook does not import yet).
        ToTensorV2(),
    ],
    bbox_params=A.BboxParams(
        format='pascal_voc',
        label_fields=['category_ids'],
    ),
)

In [6]:
# Do not change the collate function's behavior - it is taken from the torchvision detection tutorial.
def collate_fn(batch):
    """Regroup a list of (image, target) samples into (images, targets) tuples.

    Samples are not stacked, because the model consumes a list of images.
    A named module-level function (unlike a lambda) can be pickled, which
    DataLoader worker processes need under the "spawn" start method.
    """
    return tuple(zip(*batch))


train_dataloader = DataLoader(
    dataset=CellDataset(cfg=config, mode='train', transform=train_transform),
    num_workers=config.num_workers,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)

# Validation needs no shuffling; a fixed order keeps the composition of each
# batch, and therefore any per-batch metric, reproducible between runs.
val_dataloader = DataLoader(
    dataset=CellDataset(cfg=config, mode='val', transform=train_transform),
    num_workers=config.num_workers,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)

In [7]:
# Target device for the model, images and targets; read once from the config.
device = config["device"]

In [8]:
def images_to_device(images):
    """Return a new list holding each image tensor moved to the global `device`."""
    return [img.to(device) for img in images]

def targets_to_device(targets):
    """Return new target dicts whose values are moved to the global `device`.

    Every value must support `.to(device)` (i.e. be a tensor); keys are kept.
    """
    moved = []
    for target in targets:
        moved.append({name: tensor.to(device) for name, tensor in target.items()})
    return moved

In [9]:
def train_batch(model, images, targets, optimizer):
    """Run a single optimization step on one batch.

    The model is called with (images, targets) and returns a dict of named
    losses; their sum is back-propagated and the optimizer is stepped.

    Returns:
        (total_loss, mask_loss) as Python floats, where mask_loss is the
        dict's 'loss_mask' entry.
    """
    images = images_to_device(images)
    targets = targets_to_device(targets)
    optimizer.zero_grad()
    loss_dict = model(images, targets)
    total_loss = sum(loss_dict.values())
    total_loss.backward()
    optimizer.step()
    return total_loss.item(), loss_dict['loss_mask'].item()

In [10]:
def eval_batch(model, images, targets):
    """Score one batch of predictions and return the mean of per-image scores.

    Raw model outputs are post-processed with the thresholds in `config`, then
    each image's predicted masks are compared to its ground-truth masks with
    `iou_map`. The model is called as-is, so switching it to eval mode is the
    caller's job. Targets stay where they are; their masks go through
    `.numpy()`, which requires CPU tensors.
    """
    device_images = images_to_device(images)
    with torch.no_grad():
        raw_outputs = model(device_images)
    processed = postprocess_predictions(
        raw_outputs,
        mask_threshold=config.mask_threshold,
        score_threshold=config.score_threshold,
        nms_threshold=config.nms_threshold,
    )
    per_image_scores = [
        iou_map(prediction['masks'], target['masks'].numpy())
        for prediction, target in zip(processed, targets)
    ]
    return np.mean(per_image_scores)

In [11]:
def train(model, optimizer, scheduler):
    """Train `model` for `config.max_epochs` epochs, then save its state_dict.

    The scheduler is stepped after every batch, as a per-iteration scheduler
    such as OneCycleLR requires. Losses and the current learning rate are logged
    to wandb at each step.

    Returns:
        pathlib.Path of the saved checkpoint under `current_dir / "weights"`.
    """
    for epoch in range(config.max_epochs):
        model.train()
        progress = tqdm(train_dataloader, desc=f"epoch {epoch + 1}/{config.max_epochs}")
        for images, targets in progress:
            loss, mask_loss = train_batch(model, images, targets, optimizer)
            # get_last_lr() is the rate used by the step just taken; the scheduler advances afterwards.
            wandb.log({"loss/train": loss, "mask_loss/train": mask_loss, "lr": scheduler.get_last_lr()[0]})
            scheduler.step()
        # TODO: validation (eval_batch over val_dataloader) is not run here yet.

    weights_dir = current_dir / "weights"
    weights_dir.mkdir(parents=True, exist_ok=True)
    # Take the timestamp once so the saved file and the reported path agree.
    # This format avoids the spaces and ':' that str(datetime) would put in the file name.
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    checkpoint_path = weights_dir / f"maskrcnn-{experiment_name}-{timestamp}.ckpt"
    torch.save(model.state_dict(), checkpoint_path)
    logging.info("Saved weights to %s", checkpoint_path)
    print(f"Saved the weights to {checkpoint_path}")
    return checkpoint_path

In [None]:
# Mask R-CNN (ResNet-50 FPN) with num_classes=2, presumably background + cell
# (confirm against CellDataset labels). No `weights` argument is passed, so the
# full COCO-pretrained detector is not loaded.
model = models.detection.maskrcnn_resnet50_fpn(num_classes=2, progress=False).to(device)

# OneCycleLR overwrites the optimizer's learning rate from construction onward
# (peaking at max_lr), so Adam's own default lr is never used directly.
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=1e-3,
    epochs=config.max_epochs,
    steps_per_epoch=len(train_dataloader),
)
train(model=model, optimizer=optimizer, scheduler=scheduler)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
 79% 64/81 [01:40<00:23,  1.38s/it]