# Training
Based on the TorchVision Object Detection Finetuning Tutorial.

In [1]:
import os
import numpy as np
import torch
import torchvision
from PIL import Image
import transforms as T
from engine import train_one_epoch, evaluate
# import natsort
import utils

In [2]:
# Sanity check: ask PyTorch whether a CUDA device is usable in this session.
torch.cuda.is_available()

True

In [3]:
# Ask CUDA to enumerate devices by PCI bus ID.
# NOTE(review): both variables are set after `import torch` and after the
# `torch.cuda.is_available()` call in the previous cell; they are normally
# expected to be set before CUDA is initialised -- confirm they take effect.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
# Make devices 0-3 visible to this process. The code visible in this file
# moves the model to a single `cuda` device only.
os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3"

In [4]:
# Dataset files are recognised purely by their file-name extension
# (the comparison is case-sensitive).
EXTENSIONS_LABEL = ['.npy']
EXTENSIONS_IMAGE = ['.png']


def is_label(filename):
    """Return True when *filename* ends with one of EXTENSIONS_LABEL."""
    return filename.endswith(tuple(EXTENSIONS_LABEL))


def is_image(filename):
    """Return True when *filename* ends with one of EXTENSIONS_IMAGE."""
    return filename.endswith(tuple(EXTENSIONS_IMAGE))

In [5]:
class TrainDataset(object):
    """Instance-segmentation dataset pairing image files with ``.npy`` label maps.

    Both directory trees are walked recursively.  Image paths and label paths
    are each sorted lexicographically and then paired by position, so sample
    ``i`` is ``(sorted images[i], sorted labels[i])``.  The pairing is only
    meaningful when both trees contain the same number of files with
    matching names; this is not checked.

    Each label array is treated as an instance-id map: its smallest value is
    taken to be background and every other distinct value is one object of
    the single foreground class (label ``1``).
    """

    def __init__(self, imgpath, labelpath, transforms):
        """Collect and sort the image and label file paths.

        Args:
            imgpath: Root directory searched recursively for image files
                (``~`` is expanded).
            labelpath: Root directory searched recursively for label files
                (``~`` is expanded).
            transforms: Callable invoked as ``transforms(img, target)`` that
                returns the transformed pair, or ``None`` to disable it.
        """
        self.transforms = transforms

        image_files = [os.path.join(dp, f) for dp, dn, fn in os.walk(
            os.path.expanduser(imgpath)) for f in fn if is_image(f)]

        label_files = [os.path.join(dp, f) for dp, dn, fn in os.walk(
            os.path.expanduser(labelpath)) for f in fn if is_label(f)]

        # Plain lexicographic sort ("1, 10, 11, ..." rather than the natural
        # "1, 2, 3, ..."); applying the same ordering to both lists is what
        # keeps images and labels aligned by index.
        self.imgs = sorted(image_files)
        self.masks = sorted(label_files)

    def __getitem__(self, idx):
        """Load sample *idx* and build its detection/segmentation target.

        Args:
            idx: Position of the sample in the sorted file lists.

        Returns:
            A pair ``(img, target)``.  ``img`` is the image converted to RGB
            and ``target`` is a dict with the keys ``boxes`` (float32,
            ``[xmin, ymin, xmax, ymax]`` per object), ``labels`` (int64, all
            ones), ``masks`` (uint8, one binary mask per object),
            ``image_id`` (``tensor([idx])``), ``area`` and ``iscrowd``
            (int64 zeros).  When ``self.transforms`` is set, both values are
            whatever that callable returns.

        Raises:
            IndexError: If *idx* is outside either file list.
        """
        img = Image.open(self.imgs[idx]).convert("RGB")
        # NOTE(review): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only load label files that come from a trusted source.
        mask = np.load(self.masks[idx], allow_pickle=True)

        # np.unique returns sorted values; the first (smallest) one is
        # treated as background and dropped.
        obj_ids = np.unique(mask)[1:]
        num_objs = len(obj_ids)

        # Split the id-encoded map into one boolean mask per object
        # (broadcast comparison; (N, H, W) for a 2-D label array).
        masks = mask == obj_ids[:, None, None]

        # Tight bounding box around each object's pixels.
        boxes = []
        for instance_mask in masks:
            pos = np.where(instance_mask)
            boxes.append([np.min(pos[1]), np.min(pos[0]),
                          np.max(pos[1]), np.max(pos[0])])

        # reshape(-1, 4) keeps a well-formed (0, 4) tensor when the label map
        # contains no objects; without it the column indexing used for
        # `area` below raises IndexError on the resulting 1-D empty tensor.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        # There is only one foreground class.
        labels = torch.ones((num_objs,), dtype=torch.int64)
        masks = torch.as_tensor(masks, dtype=torch.uint8)
        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        return img, target

    def __len__(self):
        """Return the number of image files found."""
        return len(self.imgs)

In [6]:
def get_transform(train):
    """Build the (image, target) transform pipeline.

    The pipeline always starts with ``T.ToTensor()``; when *train* is truthy
    a ``T.RandomHorizontalFlip(0.5)`` step is appended as augmentation.
    """
    pipeline = [T.ToTensor()]
    if train:
        pipeline.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(pipeline)

In [7]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

In [8]:
def get_model_instance_segmentation(num_classes):
    """Return a pre-trained Mask R-CNN whose heads predict *num_classes* classes.

    The ResNet-50 FPN Mask R-CNN is loaded with ``pretrained=True`` and both
    its box predictor and its mask predictor are replaced by freshly
    constructed heads sized for *num_classes*.
    """
    net = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    heads = net.roi_heads

    # Box head: keep the input width of the existing classifier layer.
    box_in_features = heads.box_predictor.cls_score.in_features
    heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Mask head: keep the input channel count of the existing mask predictor
    # and use a 256-channel hidden layer.
    mask_in_channels = heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_channels = 256
    heads.mask_predictor = MaskRCNNPredictor(
        mask_in_channels, mask_hidden_channels, num_classes)

    return net

In [9]:
# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [10]:
# Two output classes (the dataset labels every object as class 1, leaving
# index 0 for background); the model is placed on the selected device.
model = get_model_instance_segmentation(num_classes=2).to(device)

In [11]:
imgpath = "/mnt/nas3/ssm/algae_dataset/algae_total/train/resized_image"
labpath = "/mnt/nas3/ssm/algae_dataset/algae_total/train/npy"

# Two views over the same files: the random-flip augmentation is enabled
# only for the training view.
dataset = TrainDataset(imgpath, labpath, get_transform(train=True))
dataset_test = TrainDataset(imgpath, labpath, get_transform(train=False))

# Random split (no seed is set in the code shown here): the last 150
# shuffled indices are held out for evaluation, the rest are used to train.
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-150])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-150:])

# One image per batch; utils.collate_fn is used to assemble each batch.
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=8,
    collate_fn=utils.collate_fn,
)
data_loader_test = torch.utils.data.DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=False,
    num_workers=8,
    collate_fn=utils.collate_fn,
)

In [12]:
# Optimise only the parameters that require gradients.
params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)

# Multiply the learning rate by 0.1 every 3 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=3,
    gamma=0.1,
)

In [13]:
num_epochs = 5

for epoch in range(num_epochs):
    # Train for one epoch, printing progress every 10 iterations.
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # Update the learning rate.
    lr_scheduler.step()
    # Evaluate on the held-out test split.
    evaluate(model, data_loader_test, device=device)



Epoch: [0]  [  0/319]  eta: 0:03:43  lr: 0.000021  loss: 3.1239 (3.1239)  loss_classifier: 0.6607 (0.6607)  loss_box_reg: 0.0841 (0.0841)  loss_mask: 2.3467 (2.3467)  loss_objectness: 0.0288 (0.0288)  loss_rpn_box_reg: 0.0035 (0.0035)  time: 0.6994  data: 0.4588  max mem: 966
Epoch: [0]  [ 10/319]  eta: 0:01:33  lr: 0.000178  loss: 2.1744 (2.9177)  loss_classifier: 0.5131 (0.5248)  loss_box_reg: 0.2317 (0.2662)  loss_mask: 1.2537 (1.9761)  loss_objectness: 0.0527 (0.1372)  loss_rpn_box_reg: 0.0056 (0.0134)  time: 0.3038  data: 0.0725  max mem: 1303
Epoch: [0]  [ 20/319]  eta: 0:01:20  lr: 0.000335  loss: 1.8445 (2.1124)  loss_classifier: 0.3454 (0.3863)  loss_box_reg: 0.2046 (0.2804)  loss_mask: 1.0416 (1.3383)  loss_objectness: 0.0249 (0.0961)  loss_rpn_box_reg: 0.0054 (0.0113)  time: 0.2467  data: 0.0180  max mem: 1504
Epoch: [0]  [ 30/319]  eta: 0:01:12  lr: 0.000492  loss: 0.8008 (1.7099)  loss_classifier: 0.1852 (0.3246)  loss_box_reg: 0.1945 (0.2786)  loss_mask: 0.3396 (1.0257)  

Epoch: [0]  [300/319]  eta: 0:00:04  lr: 0.004733  loss: 0.2695 (0.6408)  loss_classifier: 0.0524 (0.1311)  loss_box_reg: 0.0696 (0.1740)  loss_mask: 0.1426 (0.3015)  loss_objectness: 0.0057 (0.0283)  loss_rpn_box_reg: 0.0007 (0.0059)  time: 0.2164  data: 0.0035  max mem: 1507
Epoch: [0]  [310/319]  eta: 0:00:01  lr: 0.004890  loss: 0.3660 (0.6366)  loss_classifier: 0.0566 (0.1300)  loss_box_reg: 0.0785 (0.1717)  loss_mask: 0.1498 (0.2975)  loss_objectness: 0.0307 (0.0315)  loss_rpn_box_reg: 0.0021 (0.0059)  time: 0.2152  data: 0.0036  max mem: 1507
Epoch: [0]  [318/319]  eta: 0:00:00  lr: 0.005000  loss: 0.3654 (0.6266)  loss_classifier: 0.0463 (0.1276)  loss_box_reg: 0.0517 (0.1686)  loss_mask: 0.1382 (0.2936)  loss_objectness: 0.0290 (0.0310)  loss_rpn_box_reg: 0.0020 (0.0058)  time: 0.1971  data: 0.0034  max mem: 1507
Epoch: [0] Total time: 0:01:09 (0.2176 s / it)
creating index...
index created!
Test:  [  0/150]  eta: 0:01:21  model_time: 0.0814 (0.0814)  evaluator_time: 0.0188 (0

Epoch: [1]  [170/319]  eta: 0:00:32  lr: 0.005000  loss: 0.2522 (0.4030)  loss_classifier: 0.0491 (0.0810)  loss_box_reg: 0.0541 (0.1038)  loss_mask: 0.1541 (0.1987)  loss_objectness: 0.0077 (0.0142)  loss_rpn_box_reg: 0.0020 (0.0053)  time: 0.2107  data: 0.0035  max mem: 1507
Epoch: [1]  [180/319]  eta: 0:00:29  lr: 0.005000  loss: 0.2844 (0.4017)  loss_classifier: 0.0587 (0.0807)  loss_box_reg: 0.0556 (0.1037)  loss_mask: 0.1576 (0.1977)  loss_objectness: 0.0066 (0.0142)  loss_rpn_box_reg: 0.0020 (0.0053)  time: 0.2054  data: 0.0034  max mem: 1507
Epoch: [1]  [190/319]  eta: 0:00:27  lr: 0.005000  loss: 0.2844 (0.3967)  loss_classifier: 0.0587 (0.0797)  loss_box_reg: 0.0669 (0.1028)  loss_mask: 0.1479 (0.1952)  loss_objectness: 0.0042 (0.0138)  loss_rpn_box_reg: 0.0020 (0.0052)  time: 0.2049  data: 0.0034  max mem: 1507
Epoch: [1]  [200/319]  eta: 0:00:25  lr: 0.005000  loss: 0.2495 (0.3924)  loss_classifier: 0.0432 (0.0784)  loss_box_reg: 0.0487 (0.1014)  loss_mask: 0.1479 (0.1939) 

Epoch: [2]  [ 40/319]  eta: 0:01:03  lr: 0.005000  loss: 0.3204 (0.3619)  loss_classifier: 0.0450 (0.0739)  loss_box_reg: 0.0622 (0.1024)  loss_mask: 0.1515 (0.1708)  loss_objectness: 0.0023 (0.0106)  loss_rpn_box_reg: 0.0019 (0.0042)  time: 0.2076  data: 0.0035  max mem: 1507
Epoch: [2]  [ 50/319]  eta: 0:01:00  lr: 0.005000  loss: 0.1955 (0.3514)  loss_classifier: 0.0266 (0.0713)  loss_box_reg: 0.0398 (0.0994)  loss_mask: 0.1033 (0.1662)  loss_objectness: 0.0023 (0.0103)  loss_rpn_box_reg: 0.0021 (0.0042)  time: 0.1996  data: 0.0037  max mem: 1507
Epoch: [2]  [ 60/319]  eta: 0:00:57  lr: 0.005000  loss: 0.2041 (0.3454)  loss_classifier: 0.0313 (0.0710)  loss_box_reg: 0.0485 (0.0980)  loss_mask: 0.0980 (0.1628)  loss_objectness: 0.0026 (0.0095)  loss_rpn_box_reg: 0.0016 (0.0042)  time: 0.2118  data: 0.0037  max mem: 1507
Epoch: [2]  [ 70/319]  eta: 0:00:56  lr: 0.005000  loss: 0.2306 (0.3530)  loss_classifier: 0.0390 (0.0721)  loss_box_reg: 0.0527 (0.1035)  loss_mask: 0.1366 (0.1645) 

Test:  [100/150]  eta: 0:00:05  model_time: 0.0598 (0.0720)  evaluator_time: 0.0132 (0.0206)  time: 0.0979  data: 0.0042  max mem: 1507
Test:  [149/150]  eta: 0:00:00  model_time: 0.0601 (0.0744)  evaluator_time: 0.0125 (0.0229)  time: 0.0979  data: 0.0045  max mem: 1507
Test: Total time: 0:00:16 (0.1078 s / it)
Averaged stats: model_time: 0.0601 (0.0744)  evaluator_time: 0.0125 (0.0229)
Accumulating evaluation results...
DONE (t=0.02s).
Accumulating evaluation results...
DONE (t=0.02s).
IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.545
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.777
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.641
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.000
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.156
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.602
 Average R

Epoch: [3]  [210/319]  eta: 0:00:24  lr: 0.000500  loss: 0.2194 (0.2831)  loss_classifier: 0.0407 (0.0562)  loss_box_reg: 0.0369 (0.0716)  loss_mask: 0.1338 (0.1471)  loss_objectness: 0.0032 (0.0052)  loss_rpn_box_reg: 0.0024 (0.0030)  time: 0.2213  data: 0.0033  max mem: 1507
Epoch: [3]  [220/319]  eta: 0:00:21  lr: 0.000500  loss: 0.1982 (0.2799)  loss_classifier: 0.0287 (0.0552)  loss_box_reg: 0.0369 (0.0702)  loss_mask: 0.1159 (0.1464)  loss_objectness: 0.0025 (0.0051)  loss_rpn_box_reg: 0.0019 (0.0030)  time: 0.2123  data: 0.0034  max mem: 1507
Epoch: [3]  [230/319]  eta: 0:00:19  lr: 0.000500  loss: 0.1603 (0.2749)  loss_classifier: 0.0267 (0.0537)  loss_box_reg: 0.0315 (0.0683)  loss_mask: 0.0978 (0.1449)  loss_objectness: 0.0009 (0.0050)  loss_rpn_box_reg: 0.0010 (0.0029)  time: 0.1988  data: 0.0036  max mem: 1507
Epoch: [3]  [240/319]  eta: 0:00:17  lr: 0.000500  loss: 0.1534 (0.2720)  loss_classifier: 0.0142 (0.0531)  loss_box_reg: 0.0233 (0.0673)  loss_mask: 0.0977 (0.1438) 

Epoch: [4]  [ 80/319]  eta: 0:00:52  lr: 0.000500  loss: 0.1433 (0.2195)  loss_classifier: 0.0187 (0.0401)  loss_box_reg: 0.0311 (0.0481)  loss_mask: 0.1016 (0.1249)  loss_objectness: 0.0014 (0.0041)  loss_rpn_box_reg: 0.0007 (0.0022)  time: 0.2203  data: 0.0036  max mem: 1507
Epoch: [4]  [ 90/319]  eta: 0:00:50  lr: 0.000500  loss: 0.1343 (0.2172)  loss_classifier: 0.0120 (0.0396)  loss_box_reg: 0.0302 (0.0480)  loss_mask: 0.0935 (0.1238)  loss_objectness: 0.0009 (0.0037)  loss_rpn_box_reg: 0.0005 (0.0021)  time: 0.2255  data: 0.0038  max mem: 1507
Epoch: [4]  [100/319]  eta: 0:00:47  lr: 0.000500  loss: 0.1597 (0.2209)  loss_classifier: 0.0195 (0.0392)  loss_box_reg: 0.0312 (0.0492)  loss_mask: 0.1043 (0.1269)  loss_objectness: 0.0005 (0.0034)  loss_rpn_box_reg: 0.0007 (0.0021)  time: 0.2177  data: 0.0040  max mem: 1507
Epoch: [4]  [110/319]  eta: 0:00:45  lr: 0.000500  loss: 0.2385 (0.2263)  loss_classifier: 0.0359 (0.0399)  loss_box_reg: 0.0437 (0.0522)  loss_mask: 0.1407 (0.1287) 

In [14]:
# Save only the model weights (its state_dict) to 'checkpoints.pth'; the
# relative path is resolved against the current working directory.
torch.save(model.state_dict(), 'checkpoints.pth')

# Inference

In [15]:
# Restore the weights from 'checkpoints.pth'. map_location=device remaps the
# stored tensors onto the device chosen earlier, so a checkpoint written
# during a GPU run can also be loaded when only the CPU is available.
model.load_state_dict(torch.load('checkpoints.pth', map_location=device))
# Switch the model to evaluation mode before running inference.
model.eval()

MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256)
          (relu): ReLU(inplace=True)
          (downsample): 

In [19]:
# Take one held-out sample (position 100 of the test subset); `label` is the
# target dict that the dataset returns alongside the image.
img, label = dataset_test[100]

# Forward pass without gradient tracking; the model is called with a list
# holding the single image moved to `device`.
with torch.no_grad():
    prediction = model([img.to(device)])

In [21]:
# `plt` is used in this cell, but the file only imports matplotlib in the
# following cell; import it here so the cell also works on a top-to-bottom run.
import matplotlib.pyplot as plt

# Collapse all predicted instance masks of the first (only) image into one
# map by summing over the instance axis, then drop the leading size-1 axis.
# presumably `masks` is (N, 1, H, W) with scores in [0, 1] -- confirm against
# the torchvision Mask R-CNN output format.
pred=torch.sum(prediction[0]['masks'],dim=0).squeeze(0)
# Binarise the summed scores at 0.5.
# NOTE(review): thresholding the sum is not the same as thresholding each
# instance mask and taking the union; several weak masks can add up past 0.5.
pred[pred>=0.5]=1
pred[pred<0.5]=0
plt.imshow(pred.mul(255).cpu())
plt.show()

In [22]:
import matplotlib.pyplot as plt

# Display the input image. The axes are reordered (1, 2, 0) before plotting;
# presumably `img` is channel-first (C, H, W) and imshow wants (H, W, C).
plt.imshow(img.permute(1, 2, 0))
plt.show()