# Mask R-CNN - Train on Pantograph Dataset


This notebook shows how to train Mask R-CNN on your own dataset. Here we use a pantograph dataset with keypoint annotations and three object classes (`front_bar`, `middle_bar`, `rear_bar`). You'd still need a GPU, though, because the network backbone is a Resnet101, which would be too slow to train on a CPU. The original tutorial this notebook is adapted from used a synthetic dataset of shapes (squares, triangles, and circles), where a GPU gives okay-ish results in a few minutes and good results in less than an hour; those timings have not been re-measured for the pantograph dataset.

In [1]:
import os
import sys
import random
import math
import re
import time
import numpy as np
import cv2
import matplotlib
import matplotlib.pyplot as plt

# Project root: the parent of the current working directory
ROOT_DIR = os.path.abspath("../")

# Ensure the project root is on sys.path before the project imports below
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

# Mask R-CNN library
# from mrcnn.config import Config
import mrcnn.model as modellib
from mrcnn import visualize
from mrcnn.model import log

# Pantograph module (provides the config and dataset classes used below)
from dev import pantograph

# Where logs and trained models are written
MODEL_DIR = os.path.join(ROOT_DIR, "models")

# Root of the image data
DATA_DIR = os.path.join(ROOT_DIR, "datasets/pantograph")

# Report the resolved locations so a wrong working directory is obvious
for _label, _location in (
    ("Using Root dir:", ROOT_DIR),
    ("Using Model dir:", MODEL_DIR),
    ("Using Data dir:", DATA_DIR),
):
    print(_label, _location)

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
Using TensorFlow backend.


I've been imported
Using Root dir: /Users/jessedecker/projects/rail_segmentation
Using Model dir: /Users/jessedecker/projects/rail_segmentation/models
Using Data dir: /Users/jessedecker/projects/rail_segmentation/datasets/pantograph


In [None]:
# Report the TensorFlow version in use.
# The import must be live: with it commented out, `tf` is undefined on a
# fresh kernel and this cell raises NameError.
import tensorflow as tf
print(tf.__version__)

## Load Configuration

In [2]:
# Instantiate the pantograph configuration and print its settings
config = pantograph.PantographConfig()
config.display()


Configurations Superlee:
BACKBONE                       resnet101
BACKBONE_SHAPES                [[256 256]
 [128 128]
 [ 64  64]
 [ 32  32]
 [ 16  16]]
BACKBONE_STRIDES               [4, 8, 16, 32, 64]
BATCH_SIZE                     2
BBOX_STD_DEV                   [0.1 0.1 0.2 0.2]
COMPUTE_BACKBONE_SHAPE         None
DETECTION_MAX_INSTANCES        50
DETECTION_MIN_CONFIDENCE       0.9
DETECTION_NMS_THRESHOLD        0.3
GPU_COUNT                      1
IMAGES_PER_GPU                 2
IMAGE_MAX_DIM                  1024
IMAGE_MIN_DIM                  512
IMAGE_MIN_SCALE                0.5
IMAGE_PADDING                  True
IMAGE_RESIZE_MODE              square
IMAGE_SHAPE                    [1024 1024    3]
KEYPOINT_MASK_POOL_SIZE        7
KEYPOINT_MASK_SHAPE            [56, 56]
KEYPOINT_THRESHOLD             0.005
LEARNING_MOMENTUM              0.9
LEARNING_RATE                  0.002
MASK_POOL_SIZE                 14
MASK_SHAPE                     [28, 28]
MAX_GT_INSTANCES        

## Dataset

Load the pantograph dataset

To use your own data, extend the Dataset class and add a method to load it (here `load_pantograph()`, which replaces the `load_shapes()` of the original shapes tutorial), and override the following methods:

* load_image()
* load_mask()
* image_reference()

In [3]:
# Load dataset
# Guard against running this cell with a configuration for another dataset.
assert config.NAME == "pantograph", "expected the pantograph configuration"


def load_pantograph_split(subset):
    """Create, load and prepare a PantographDataset for one subset.

    subset: name of the subset under DATA_DIR, e.g. "train" or "val".
    Returns the prepared dataset object.
    """
    dataset = pantograph.PantographDataset()
    dataset.load_pantograph(DATA_DIR, subset)
    dataset.prepare()
    return dataset


def describe_split(label, dataset):
    """Print the image count, class count and class names of a dataset.

    label: prefix used in the printed lines, e.g. "Train" or "Val".
    dataset: a prepared dataset exposing image_ids, num_classes and class_info.
    """
    print("{} Keypoints Image Count: {}".format(label, len(dataset.image_ids)))
    print("{} Keypoints Class Count: {}".format(label, dataset.num_classes))
    for index, class_entry in enumerate(dataset.class_info):
        print("{:3}. {:50}".format(index, class_entry['name']))


# Training dataset
train_dataset_keypoints = load_pantograph_split("train")

# Validation dataset
val_dataset_keypoints = load_pantograph_split("val")

# Summaries are printed after both splits are loaded, train first
describe_split("Train", train_dataset_keypoints)
describe_split("Val", val_dataset_keypoints)

loading annotations into memory...
Done (t=0.34s)
creating index...
index created!
Skeleton: (5, 2)
Keypoint names: (6,)
loading annotations into memory...
Done (t=0.04s)
creating index...
index created!
Skeleton: (5, 2)
Keypoint names: (6,)
Train Keypoints Image Count: 100
Train Keypoints Class Count: 4
  0. BG                                                
  1. front_bar                                         
  2. middle_bar                                        
  3. rear_bar                                          
Val Keypoints Image Count: 10
Val Keypoints Class Count: 4
  0. BG                                                
  1. front_bar                                         
  2. middle_bar                                        
  3. rear_bar                                          


## Create Model

In [4]:
# Create model in training mode
# Create model in training mode
# model_dir is the directory under which this run's checkpoints are written
# (see the "Checkpoint Path" line printed when training starts).
model = modellib.MaskRCNN(mode="training", 
                          config=config,
                          model_dir=MODEL_DIR)


Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Instructions for updating:
box_ind is deprecated, use box_indices instead


In [6]:
# Local path to starting weights file
MODEL_PATH = os.path.join(MODEL_DIR, "mask_rcnn_coco.h5")

MODEL_PATH

'/Users/jessedecker/projects/rail_segmentation/models/mask_rcnn_coco.h5'

In [7]:
# Which weights to start with?
# Supported here: "imagenet" or "coco". Any other value (including "last",
# which this cell does not implement) now raises instead of silently leaving
# the model with randomly initialised weights.
init_with = "coco"

if init_with == "imagenet":
    model.load_weights(model.get_imagenet_weights(), by_name=True)
elif init_with == "coco":
    # Load weights trained on MS COCO, but skip layers that
    # are different due to the different number of classes
    # See README for instructions to download the COCO weights
    model.load_weights(MODEL_PATH, by_name=True,
                       exclude=["mrcnn_class_logits", "mrcnn_bbox_fc",
                                "mrcnn_bbox", "mrcnn_mask"])
else:
    raise ValueError(
        "Unsupported init_with value {!r}; expected 'imagenet' or 'coco'".format(init_with)
    )


# print("Loading weights from ", MODEL_PATH)

# model.keras_model.summary()

## Training

Train in two stages:
1. Only the heads. Here we're freezing all the backbone layers and training only the randomly initialized layers (i.e. the ones that we didn't use pre-trained weights from MS COCO). To train only the head layers, pass `layers='heads'` to the `train()` function.

2. Fine-tune all layers. For this simple example it's not necessary, but we're including it to show the process. Simply pass `layers="all"` to train all layers. Note that the training cell below runs only this stage.

In [8]:
# Fine tune all layers
# Passing layers="all" trains all layers. You can also
# pass a regular expression to select which layers to
# train by name pattern.
# NOTE(review): the recorded run of this cell stopped during epoch 1 with
# MaybeEncodingError ("'i' format requires -2147483648 <= number <= 2147483647")
# while a batch was being sent back as a result. Presumably one batch is too
# large to transfer between worker processes — TODO confirm, then shrink the
# batch payload (e.g. mask size / MAX_GT_INSTANCES) or disable multiprocessing
# in the data generator.
model.train(train_dataset_keypoints, val_dataset_keypoints, 
            learning_rate=config.LEARNING_RATE,
            epochs=20, 
            layers="all")


Starting at epoch 0. LR=0.002

Checkpoint Path: /Users/jessedecker/projects/rail_segmentation/models/pantograph20200225T1412/mask_rcnn_pantograph_{epoch:04d}.h5
Selecting layers to train
conv1                  (Conv2D)
bn_conv1               (BatchNorm)
res2a_branch2a         (Conv2D)
bn2a_branch2a          (BatchNorm)
res2a_branch2b         (Conv2D)
bn2a_branch2b          (BatchNorm)
res2a_branch2c         (Conv2D)
res2a_branch1          (Conv2D)
bn2a_branch2c          (BatchNorm)
bn2a_branch1           (BatchNorm)
res2b_branch2a         (Conv2D)
bn2b_branch2a          (BatchNorm)
res2b_branch2b         (Conv2D)
bn2b_branch2b          (BatchNorm)
res2b_branch2c         (Conv2D)
bn2b_branch2c          (BatchNorm)
res2c_branch2a         (Conv2D)
bn2c_branch2a          (BatchNorm)
res2c_branch2b         (Conv2D)
bn2c_branch2b          (BatchNorm)
res2c_branch2c         (Conv2D)
bn2c_branch2c          (BatchNorm)
res3a_branch2a         (Conv2D)
bn3a_branch2a          (BatchNorm)
res3a_br

  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "









Epoch 1/20


MaybeEncodingError: Error sending result: '([array([[[[-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         ...,
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9]],

        [[-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         ...,
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9]],

        [[-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         ...,
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9]],

        ...,

        [[-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         ...,
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9]],

        [[-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         ...,
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9]],

        [[-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         ...,
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9]]],


       [[[-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         ...,
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9]],

        [[-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         ...,
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9]],

        [[-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         ...,
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9]],

        ...,

        [[-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         ...,
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9]],

        [[-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         ...,
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9]],

        [[-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         ...,
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9],
         [-123.7, -116.8, -103.9]]]], dtype=float32), array([[  79, 1080, 1920,    3,  224,    0,  800, 1024,    1,    1,    1,
           1],
       [  93, 1080, 1920,    3,  224,    0,  800, 1024,    1,    1,    1,
           1]]), array([[[0],
        [0],
        [0],
        ...,
        [0],
        [0],
        [0]],

       [[0],
        [0],
        [0],
        ...,
        [0],
        [0],
        [0]]], dtype=int32), array([[[ 1.62966016,  2.32019413, -4.0207735 ,  2.73799668],
        [ 1.02199027,  4.10177176, -2.81937433,  3.48587721],
        [ 1.02199027,  2.33400481, -2.81937433,  3.48587721],
        ...,
        [ 0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ]],

       [[ 2.59640771,  1.97492714, -4.4726938 ,  2.46721682],
        [ 1.96111646,  3.78412613, -3.11946438,  3.26799313],
        [ 1.96111646,  2.01635918, -3.11946438,  3.26799313],
        ...,
        [ 0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        ,  0.        ]]]), array([[1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
      dtype=int32), array([[[565, 219, 646, 845],
        [607, 169, 710, 896],
        [623, 253, 695, 813],
        ...,
        [  0,   0,   0,   0],
        [  0,   0,   0,   0],
        [  0,   0,   0,   0]],

       [[522, 223, 596, 816],
        [563, 173, 660, 869],
        [582, 254, 647, 790],
        ...,
        [  0,   0,   0,   0],
        [  0,   0,   0,   0],
        [  0,   0,   0,   0]]], dtype=int32), array([[[[221., 644.,   2.],
         [302., 594.,   2.],
         [339., 589.,   2.],
         [726., 573.,   2.],
         [764., 576.,   2.],
         [841., 618.,   2.]],

        [[170., 712.,   2.],
         [263., 629.,   2.],
         [325., 623.,   2.],
         [742., 611.,   2.],
         [816., 620.,   2.],
         [894., 695.,   2.]],

        [[257., 692.,   2.],
         [321., 648.,   2.],
         [358., 643.,   1.],
         [710., 628.,   1.],
         [744., 632.,   2.],
         [811., 668.,   2.]],

        ...,

        [[  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.]],

        [[  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.]],

        [[  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.]]],


       [[[226., 594.,   2.],
         [299., 545.,   2.],
         [336., 542.,   2.],
         [705., 529.,   2.],
         [743., 532.,   2.],
         [815., 572.,   2.]],

        [[175., 661.,   2.],
         [255., 587.,   2.],
         [324., 578.,   2.],
         [720., 566.,   2.],
         [785., 572.,   2.],
         [868., 644.,   2.]],

        [[256., 645.,   2.],
         [319., 602.,   2.],
         [354., 597.,   1.],
         [689., 585.,   1.],
         [730., 591.,   2.],
         [788., 627.,   2.]],

        ...,

        [[  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.]],

        [[  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.]],

        [[  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.],
         [  0.,   0.,   0.]]]]), array([[[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]],


       [[[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]],

        [[0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         ...,
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.],
         [0., 0., 0., ..., 0., 0., 0.]]]])], [])'. Reason: 'error("'i' format requires -2147483648 <= number <= 2147483647",)'

In [None]:
# Save weights manually
# Typically not needed because callbacks save after every epoch.
# This cell is live (not commented out): running it writes the current
# weights to MODEL_DIR.
# NOTE(review): the file name still says "shapes" although this notebook
# trains on the pantograph dataset — confirm the intended name.
model_path = os.path.join(MODEL_DIR, "mask_rcnn_shapes_test_1.h5")
model.keras_model.save_weights(model_path)