# TorchVision Instance Segmentation Finetuning Tutorial

For this tutorial, we will be finetuning a pre-trained [Mask R-CNN](https://arxiv.org/abs/1703.06870) model in the [*Penn-Fudan Database for Pedestrian Detection and Segmentation*](https://www.cis.upenn.edu/~jshi/ped_html/). It contains 170 images with 345 instances of pedestrians, and we will use it to illustrate how to use the new features in torchvision in order to train an instance segmentation model on a custom dataset.

First, we need to install `pycocotools`. This library will be used for computing the evaluation metrics following the COCO metric for intersection over union.

In [1]:
# !pip install cython
# # Install pycocotools, the version by default in Colab
# # has a bug fixed in https://github.com/cocodataset/cocoapi/pull/354
# !pip install git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI

## Defining the Dataset

The [torchvision reference scripts for training object detection, instance segmentation and person keypoint detection](https://github.com/pytorch/vision/tree/v0.3.0/references/detection) allows for easily supporting adding new custom datasets.
The dataset should inherit from the standard `torch.utils.data.Dataset` class, and implement `__len__` and `__getitem__`.

The only specificity that we require is that the dataset `__getitem__` should return:

* image: a PIL Image of size (H, W)
* target: a dict containing the following fields
    * `boxes` (`FloatTensor[N, 4]`): the coordinates of the `N` bounding boxes in `[x0, y0, x1, y1]` format, ranging from `0` to `W` and `0` to `H`
    * `labels` (`Int64Tensor[N]`): the label for each bounding box
    * `image_id` (`Int64Tensor[1]`): an image identifier. It should be unique between all the images in the dataset, and is used during evaluation
    * `area` (`Tensor[N]`): The area of the bounding box. This is used during evaluation with the COCO metric, to separate the metric scores between small, medium and large boxes.
    * `iscrowd` (`UInt8Tensor[N]`): instances with `iscrowd=True` will be ignored during evaluation.
    * (optionally) `masks` (`UInt8Tensor[N, H, W]`): The segmentation masks for each one of the objects
    * (optionally) `keypoints` (`FloatTensor[N, K, 3]`): For each one of the `N` objects, it contains the `K` keypoints in `[x, y, visibility]` format, defining the object. `visibility=0` means that the keypoint is not visible. Note that for data augmentation, the notion of flipping a keypoint is dependent on the data representation, and you should probably adapt `references/detection/transforms.py` for your new keypoint representation

If your model returns the above methods, they will make it work for both training and evaluation, and will use the evaluation scripts from pycocotools.


One note on the labels. The model considers class 0 as background. If your dataset does not contain the background class, you should not have 0 in your labels. For example, assuming you have just two classes, cat and dog, you can define 1 (not 0) to represent cats and 2 to represent dogs. So, for instance, if one of the images has both classes, your labels tensor should look like [1,2].

Additionally, if you want to use aspect ratio grouping during training (so that each batch only contains images with similar aspect ratio), then it is recommended to also implement a `get_height_and_width` method, which returns the height and the width of the image. If this method is not provided, we query all elements of the dataset via `__getitem__` , which loads the image in memory and is slower than if a custom method is provided.


### Writing a custom dataset for Penn-Fudan

Let's write a dataset for the Penn-Fudan dataset.

First, let's download and extract the data, present in a zip file at https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip

In [2]:
import os
import glob
import numpy as np
import torch
import torch.utils.data
from PIL import Image

class SMDDataset(torch.utils.data.Dataset):
    def __init__(self, yolo_dir, transform=None, target_transform=None):
        self.yolo_dir = yolo_dir
        self.images_dir = glob.glob(yolo_dir + "*.jpg")
        self.labels_dir = glob.glob(yolo_dir + "*.txt")
        
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.images_dir)
    
    def read_yolo(self, label_path, height, width):
        with open(label_path, 'r') as f:
            data = f.readlines()
        data = [dat.split() for dat in data]
        data = np.array(data, dtype=float)
        labels = data[:, 1].astype('int64') # 0 - moving; 1 - category; 2 - distance
#         labels = data[:, :3].astype('int64') # 0 - moving; 1 - category; 2 - distance
        bboxes = data[:, -4:]
        bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] / 2
        bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] / 2
        bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2]
        bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3]
        
        return torch.from_numpy(labels), torch.from_numpy(bboxes)

    def __getitem__(self, idx):
        img_path = self.images_dir[idx]
        image = Image.open(img_path).convert("RGB")
        height, width = image.size
        labels, bboxes = self.read_yolo(self.labels_dir[idx], height, width)

        target = dict()
        target["labels"] = labels
        target["boxes"] = bboxes

        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            target = self.target_transform(target)

        return image, target


That's all for the dataset. Let's see how the outputs are structured for this dataset

In [3]:
dataset = SMDDataset('SMD/NIR/data_yolo/')
dataset[0]

(<PIL.Image.Image image mode=RGB size=1920x1080 at 0x1A656143E88>,
 {'labels': tensor([3, 3, 3, 3, 3, 3, 3, 3, 3]),
  'boxes': tensor([[5.2083e-04, 2.5926e-01, 2.4479e-02, 3.5556e-01],
          [1.4063e-02, 2.7315e-01, 5.9375e-02, 3.2870e-01],
          [9.7656e-02, 2.8148e-01, 2.1901e-01, 3.1944e-01],
          [7.8906e-02, 2.9259e-01, 1.0911e-01, 3.1759e-01],
          [2.0651e-01, 2.6944e-01, 2.4870e-01, 3.3519e-01],
          [3.4818e-01, 2.7083e-01, 4.1432e-01, 3.3657e-01],
          [3.8229e-01, 2.7593e-01, 4.0833e-01, 3.3148e-01],
          [3.9714e-01, 2.6991e-01, 5.7734e-01, 3.3843e-01],
          [2.1406e-01, 2.4815e-01, 3.8385e-01, 4.6481e-01]], dtype=torch.float64)})

So we can see that by default, the dataset returns a `PIL.Image` and a dictionary
containing several fields, including `boxes`, `labels` and `masks`.

## Defining your model

In this tutorial, we will be using [Mask R-CNN](https://arxiv.org/abs/1703.06870), which is based on top of [Faster R-CNN](https://arxiv.org/abs/1506.01497). Faster R-CNN is a model that predicts both bounding boxes and class scores for potential objects in the image.

![Faster R-CNN](https://raw.githubusercontent.com/pytorch/vision/temp-tutorial/tutorials/tv_image03.png)

Mask R-CNN adds an extra branch into Faster R-CNN, which also predicts segmentation masks for each instance.

![Mask R-CNN](https://raw.githubusercontent.com/pytorch/vision/temp-tutorial/tutorials/tv_image04.png)

There are two common situations where one might want to modify one of the available models in torchvision modelzoo.
The first is when we want to start from a pre-trained model, and just finetune the last layer. The other is when we want to replace the backbone of the model with a different one (for faster predictions, for example).

Let's go see how we would do one or another in the following sections.


### 1 - Finetuning from a pretrained model

Let's suppose that you want to start from a model pre-trained on COCO and want to finetune it for your particular classes. Here is a possible way of doing it:
```
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# load a model pre-trained pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# replace the classifier with a new one, that has
# num_classes which is user-defined
num_classes = 2  # 1 class (person) + background
# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) 
```

### 2 - Modifying the model to add a different backbone

Another common situation arises when the user wants to replace the backbone of a detection
model with a different one. For example, the current default backbone (ResNet-50) might be too big for some applications, and smaller models might be necessary.

Here is how we would go into leveraging the functions provided by torchvision to modify a backbone.

```
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator

# load a pre-trained model for classification and return
# only the features
backbone = torchvision.models.mobilenet_v2(pretrained=True).features
# FasterRCNN needs to know the number of
# output channels in a backbone. For mobilenet_v2, it's 1280
# so we need to add it here
backbone.out_channels = 1280

# let's make the RPN generate 5 x 3 anchors per spatial
# location, with 5 different sizes and 3 different aspect
# ratios. We have a Tuple[Tuple[int]] because each feature
# map could potentially have different sizes and
# aspect ratios 
anchor_generator = AnchorGenerator(sizes=((32, 64, 128, 256, 512),),
                                   aspect_ratios=((0.5, 1.0, 2.0),))

# let's define what are the feature maps that we will
# use to perform the region of interest cropping, as well as
# the size of the crop after rescaling.
# if your backbone returns a Tensor, featmap_names is expected to
# be [0]. More generally, the backbone should return an
# OrderedDict[Tensor], and in featmap_names you can choose which
# feature maps to use.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(featmap_names=[0],
                                                output_size=7,
                                                sampling_ratio=2)

# put the pieces together inside a FasterRCNN model
model = FasterRCNN(backbone,
                   num_classes=2,
                   rpn_anchor_generator=anchor_generator,
                   box_roi_pool=roi_pooler)
```

### An Instance segmentation model for PennFudan Dataset

In our case, we want to fine-tune from a pre-trained model, given that our dataset is very small. So we will be following approach number 1.

Here we want to also compute the instance segmentation masks, so we will be using Mask R-CNN:

That's it, this will make model be ready to be trained and evaluated on our custom dataset.

## Training and evaluation functions

In `references/detection/,` we have a number of helper functions to simplify training and evaluating detection models.
Here, we will use `references/detection/engine.py`, `references/detection/utils.py` and `references/detection/transforms.py`.

Let's copy those files (and their dependencies) in here so that they are available in the notebook



Let's write some helper functions for data augmentation / transformation, which leverages the functions in `refereces/detection` that we have just copied:


In [4]:
from engine import train_one_epoch, evaluate
import utils
from torchvision import transforms as T


def get_transform(train):
    transforms = []
    # converts the image, a PIL image, into a PyTorch Tensor
    transforms.append(T.ToTensor())
    if train:
        # during training, randomly flip the training images
        # and ground-truth for data augmentation
        transforms.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(transforms)

In [5]:
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

def get_frcnn_model(num_classes):
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model

#### Testing forward() method 

Before iterating over the dataset, it’s good to see what the model expects during training and inference time on sample data.


In [6]:
import torchvision
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
dataset = SMDDataset('SMD/NIR/data_yolo/', get_transform(train=True))
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=0,
    collate_fn=utils.collate_fn
)
# For Training
images,targets = next(iter(data_loader))
images = list(image for image in images)
targets = [{k: v for k, v in t.items()} for t in targets]
output = model(images,targets)   # Returns losses and detections
# For inference
model.eval()
x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
predictions = model(x)           # Returns predictions

### Putting everything together

We now have the dataset class, the models and the data transforms. Let's instantiate them

In [7]:
# use our dataset and defined transformations

data = SMDDataset("SMD/NIR/data_yolo/", transform=get_transform(True))

train_size = int(0.8 * len(data))
test_size = len(data) - train_size

dataset, dataset_test = torch.utils.data.random_split(data, [train_size, test_size])

# split the dataset in train and test set
torch.manual_seed(1)
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

# define training and validation data loaders
data_loader = torch.utils.data.DataLoader(
    dataset, batch_size=2, shuffle=True, num_workers=0,
    collate_fn=utils.collate_fn)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, shuffle=False, num_workers=0,
    collate_fn=utils.collate_fn)

Now let's instantiate the model and the optimizer

In [8]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# our dataset has two classes only - background and person
num_classes = 11

# get the model using our helper function
model = get_frcnn_model(num_classes)
# move model to the right device
model.to(device)

# construct an optimizer
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.005,
                            momentum=0.9, weight_decay=0.0005)

# and a learning rate scheduler which decreases the learning rate by
# 10x every 3 epochs
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                               step_size=3,
                                               gamma=0.1)

And now let's train the model for 10 epochs, evaluating at the end of every epoch.

In [10]:
# let's train it for 10 epochs
from torch.optim.lr_scheduler import StepLR
num_epochs = 1

for epoch in range(num_epochs):
    # train for one epoch, printing every 10 iterations
    train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the test dataset
    evaluate(model, data_loader_test, device=device)

Epoch: [0]  [   0/4489]  eta: 1:46:26  lr: 0.000000  loss: 0.5433 (0.5433)  loss_classifier: 0.1577 (0.1577)  loss_box_reg: 0.0099 (0.0099)  loss_objectness: 0.0904 (0.0904)  loss_rpn_box_reg: 0.2853 (0.2853)  time: 1.4228  data: 0.1416  max mem: 2501
Epoch: [0]  [  10/4489]  eta: 1:39:04  lr: 0.000003  loss: 0.5805 (0.6083)  loss_classifier: 0.1456 (0.1652)  loss_box_reg: 0.0104 (0.0115)  loss_objectness: 0.1170 (0.1093)  loss_rpn_box_reg: 0.3052 (0.3223)  time: 1.3271  data: 0.1579  max mem: 2501
Epoch: [0]  [  20/4489]  eta: 1:36:46  lr: 0.000005  loss: 0.5805 (0.6033)  loss_classifier: 0.1379 (0.1628)  loss_box_reg: 0.0104 (0.0106)  loss_objectness: 0.1126 (0.1112)  loss_rpn_box_reg: 0.3054 (0.3186)  time: 1.2932  data: 0.1389  max mem: 2501
Epoch: [0]  [  30/4489]  eta: 1:35:52  lr: 0.000007  loss: 0.5764 (0.6009)  loss_classifier: 0.1342 (0.1606)  loss_box_reg: 0.0070 (0.0097)  loss_objectness: 0.0945 (0.1113)  loss_rpn_box_reg: 0.3100 (0.3192)  time: 1.2698  data: 0.1192  max me

Epoch: [0]  [ 330/4489]  eta: 2:39:39  lr: 0.000070  loss: 0.2280 (0.3937)  loss_classifier: 0.0494 (0.1003)  loss_box_reg: 0.0013 (0.0044)  loss_objectness: 0.0040 (0.0376)  loss_rpn_box_reg: 0.1673 (0.2514)  time: 3.9377  data: 0.0966  max mem: 2501
Epoch: [0]  [ 340/4489]  eta: 2:42:33  lr: 0.000072  loss: 0.2118 (0.3880)  loss_classifier: 0.0423 (0.0986)  loss_box_reg: 0.0013 (0.0043)  loss_objectness: 0.0033 (0.0366)  loss_rpn_box_reg: 0.1622 (0.2485)  time: 3.9297  data: 0.0959  max mem: 2501
Epoch: [0]  [ 350/4489]  eta: 2:47:19  lr: 0.000074  loss: 0.1983 (0.3832)  loss_classifier: 0.0395 (0.0970)  loss_box_reg: 0.0003 (0.0042)  loss_objectness: 0.0029 (0.0357)  loss_rpn_box_reg: 0.1527 (0.2462)  time: 4.4502  data: 0.0935  max mem: 2501
Epoch: [0]  [ 360/4489]  eta: 2:49:39  lr: 0.000076  loss: 0.1984 (0.3786)  loss_classifier: 0.0357 (0.0954)  loss_box_reg: 0.0001 (0.0042)  loss_objectness: 0.0029 (0.0350)  loss_rpn_box_reg: 0.1534 (0.2441)  time: 4.4152  data: 0.0940  max me

Epoch: [0]  [ 660/4489]  eta: 3:19:01  lr: 0.000139  loss: 0.1950 (0.2954)  loss_classifier: 0.0278 (0.0666)  loss_box_reg: 0.0001 (0.0027)  loss_objectness: 0.0009 (0.0207)  loss_rpn_box_reg: 0.1749 (0.2054)  time: 3.8845  data: 0.0960  max mem: 2501
Epoch: [0]  [ 670/4489]  eta: 3:19:13  lr: 0.000141  loss: 0.1856 (0.2935)  loss_classifier: 0.0219 (0.0659)  loss_box_reg: 0.0001 (0.0026)  loss_objectness: 0.0010 (0.0205)  loss_rpn_box_reg: 0.1455 (0.2044)  time: 3.8707  data: 0.0940  max mem: 2501
Epoch: [0]  [ 680/4489]  eta: 3:19:23  lr: 0.000143  loss: 0.1435 (0.2914)  loss_classifier: 0.0197 (0.0653)  loss_box_reg: 0.0001 (0.0026)  loss_objectness: 0.0010 (0.0202)  loss_rpn_box_reg: 0.1192 (0.2033)  time: 3.8694  data: 0.0960  max mem: 2501
Epoch: [0]  [ 690/4489]  eta: 3:19:32  lr: 0.000145  loss: 0.1490 (0.2897)  loss_classifier: 0.0222 (0.0647)  loss_box_reg: 0.0001 (0.0026)  loss_objectness: 0.0009 (0.0199)  loss_rpn_box_reg: 0.1301 (0.2025)  time: 3.8753  data: 0.0988  max me

Epoch: [0]  [ 990/4489]  eta: 3:17:16  lr: 0.000208  loss: 0.1672 (0.2543)  loss_classifier: 0.0262 (0.0531)  loss_box_reg: 0.0004 (0.0020)  loss_objectness: 0.0007 (0.0145)  loss_rpn_box_reg: 0.1391 (0.1846)  time: 3.8771  data: 0.0939  max mem: 2501
Epoch: [0]  [1000/4489]  eta: 3:17:00  lr: 0.000210  loss: 0.1463 (0.2532)  loss_classifier: 0.0230 (0.0528)  loss_box_reg: 0.0002 (0.0020)  loss_objectness: 0.0008 (0.0144)  loss_rpn_box_reg: 0.1158 (0.1840)  time: 3.8808  data: 0.0949  max mem: 2501
Epoch: [0]  [1010/4489]  eta: 3:16:43  lr: 0.000210  loss: 0.1541 (0.2525)  loss_classifier: 0.0246 (0.0526)  loss_box_reg: 0.0001 (0.0020)  loss_objectness: 0.0008 (0.0142)  loss_rpn_box_reg: 0.1100 (0.1836)  time: 3.8791  data: 0.0960  max mem: 2501
Epoch: [0]  [1020/4489]  eta: 3:16:25  lr: 0.000210  loss: 0.1564 (0.2514)  loss_classifier: 0.0288 (0.0524)  loss_box_reg: 0.0003 (0.0020)  loss_objectness: 0.0005 (0.0141)  loss_rpn_box_reg: 0.1143 (0.1830)  time: 3.8783  data: 0.0952  max me

Epoch: [0]  [1320/4489]  eta: 3:06:05  lr: 0.000210  loss: 0.1366 (0.2311)  loss_classifier: 0.0258 (0.0463)  loss_box_reg: 0.0001 (0.0016)  loss_objectness: 0.0004 (0.0112)  loss_rpn_box_reg: 0.1080 (0.1719)  time: 3.8796  data: 0.0970  max mem: 2501
Epoch: [0]  [1330/4489]  eta: 3:05:38  lr: 0.000210  loss: 0.1366 (0.2305)  loss_classifier: 0.0255 (0.0461)  loss_box_reg: 0.0001 (0.0016)  loss_objectness: 0.0004 (0.0112)  loss_rpn_box_reg: 0.1124 (0.1715)  time: 3.8791  data: 0.0978  max mem: 2501
Epoch: [0]  [1340/4489]  eta: 3:05:11  lr: 0.000210  loss: 0.1525 (0.2299)  loss_classifier: 0.0230 (0.0459)  loss_box_reg: 0.0001 (0.0016)  loss_objectness: 0.0006 (0.0111)  loss_rpn_box_reg: 0.1266 (0.1713)  time: 3.8760  data: 0.0982  max mem: 2501
Epoch: [0]  [1350/4489]  eta: 3:04:44  lr: 0.000210  loss: 0.1500 (0.2294)  loss_classifier: 0.0230 (0.0458)  loss_box_reg: 0.0001 (0.0016)  loss_objectness: 0.0006 (0.0110)  loss_rpn_box_reg: 0.1296 (0.1710)  time: 3.8754  data: 0.0967  max me

Epoch: [0]  [1650/4489]  eta: 2:50:22  lr: 0.000210  loss: 0.1423 (0.2155)  loss_classifier: 0.0235 (0.0420)  loss_box_reg: 0.0001 (0.0014)  loss_objectness: 0.0004 (0.0092)  loss_rpn_box_reg: 0.1136 (0.1628)  time: 3.8771  data: 0.0924  max mem: 2501
Epoch: [0]  [1660/4489]  eta: 2:49:51  lr: 0.000210  loss: 0.1528 (0.2152)  loss_classifier: 0.0182 (0.0419)  loss_box_reg: 0.0001 (0.0014)  loss_objectness: 0.0005 (0.0092)  loss_rpn_box_reg: 0.1219 (0.1628)  time: 3.8700  data: 0.0941  max mem: 2501
Epoch: [0]  [1670/4489]  eta: 2:49:19  lr: 0.000210  loss: 0.1511 (0.2150)  loss_classifier: 0.0154 (0.0417)  loss_box_reg: 0.0001 (0.0014)  loss_objectness: 0.0005 (0.0091)  loss_rpn_box_reg: 0.1381 (0.1627)  time: 3.8571  data: 0.0956  max mem: 2501
Epoch: [0]  [1680/4489]  eta: 2:48:47  lr: 0.000210  loss: 0.1847 (0.2148)  loss_classifier: 0.0191 (0.0416)  loss_box_reg: 0.0001 (0.0014)  loss_objectness: 0.0005 (0.0091)  loss_rpn_box_reg: 0.1588 (0.1627)  time: 3.8542  data: 0.0965  max me

Epoch: [0]  [1980/4489]  eta: 2:32:54  lr: 0.000210  loss: 0.1465 (0.2051)  loss_classifier: 0.0187 (0.0389)  loss_box_reg: 0.0001 (0.0012)  loss_objectness: 0.0006 (0.0079)  loss_rpn_box_reg: 0.1182 (0.1570)  time: 4.4169  data: 0.0936  max mem: 2501
Epoch: [0]  [1990/4489]  eta: 2:32:21  lr: 0.000210  loss: 0.1235 (0.2047)  loss_classifier: 0.0253 (0.0389)  loss_box_reg: 0.0001 (0.0012)  loss_objectness: 0.0005 (0.0078)  loss_rpn_box_reg: 0.1037 (0.1567)  time: 3.8714  data: 0.0931  max mem: 2501
Epoch: [0]  [2000/4489]  eta: 2:31:47  lr: 0.000210  loss: 0.1054 (0.2042)  loss_classifier: 0.0281 (0.0388)  loss_box_reg: 0.0002 (0.0012)  loss_objectness: 0.0004 (0.0078)  loss_rpn_box_reg: 0.0805 (0.1563)  time: 3.8751  data: 0.0929  max mem: 2501
Epoch: [0]  [2010/4489]  eta: 2:31:13  lr: 0.000210  loss: 0.1054 (0.2038)  loss_classifier: 0.0240 (0.0388)  loss_box_reg: 0.0001 (0.0012)  loss_objectness: 0.0003 (0.0078)  loss_rpn_box_reg: 0.0805 (0.1560)  time: 3.8762  data: 0.0946  max me

Epoch: [0]  [2310/4489]  eta: 2:14:04  lr: 0.000210  loss: 0.1107 (0.1939)  loss_classifier: 0.0223 (0.0371)  loss_box_reg: 0.0002 (0.0011)  loss_objectness: 0.0005 (0.0069)  loss_rpn_box_reg: 0.0821 (0.1488)  time: 3.8629  data: 0.0967  max mem: 2501
Epoch: [0]  [2320/4489]  eta: 2:13:29  lr: 0.000210  loss: 0.1180 (0.1937)  loss_classifier: 0.0223 (0.0370)  loss_box_reg: 0.0001 (0.0011)  loss_objectness: 0.0005 (0.0069)  loss_rpn_box_reg: 0.0898 (0.1486)  time: 3.8616  data: 0.0945  max mem: 2501
Epoch: [0]  [2330/4489]  eta: 2:12:53  lr: 0.000210  loss: 0.1180 (0.1934)  loss_classifier: 0.0232 (0.0370)  loss_box_reg: 0.0001 (0.0011)  loss_objectness: 0.0005 (0.0069)  loss_rpn_box_reg: 0.0945 (0.1484)  time: 3.8603  data: 0.0923  max mem: 2501
Epoch: [0]  [2340/4489]  eta: 2:12:18  lr: 0.000210  loss: 0.1152 (0.1932)  loss_classifier: 0.0215 (0.0369)  loss_box_reg: 0.0002 (0.0011)  loss_objectness: 0.0005 (0.0069)  loss_rpn_box_reg: 0.0889 (0.1482)  time: 3.8643  data: 0.0918  max me

Epoch: [0]  [2640/4489]  eta: 1:54:33  lr: 0.000210  loss: 0.1080 (0.1855)  loss_classifier: 0.0227 (0.0355)  loss_box_reg: 0.0001 (0.0011)  loss_objectness: 0.0003 (0.0062)  loss_rpn_box_reg: 0.0849 (0.1427)  time: 3.8611  data: 0.0968  max mem: 2501
Epoch: [0]  [2650/4489]  eta: 1:53:57  lr: 0.000210  loss: 0.1206 (0.1853)  loss_classifier: 0.0252 (0.0355)  loss_box_reg: 0.0002 (0.0011)  loss_objectness: 0.0005 (0.0062)  loss_rpn_box_reg: 0.0881 (0.1425)  time: 3.8600  data: 0.0969  max mem: 2501
Epoch: [0]  [2660/4489]  eta: 1:53:21  lr: 0.000210  loss: 0.1011 (0.1850)  loss_classifier: 0.0254 (0.0355)  loss_box_reg: 0.0002 (0.0011)  loss_objectness: 0.0006 (0.0062)  loss_rpn_box_reg: 0.0788 (0.1423)  time: 3.8576  data: 0.0979  max mem: 2501
Epoch: [0]  [2670/4489]  eta: 1:52:52  lr: 0.000210  loss: 0.0968 (0.1847)  loss_classifier: 0.0195 (0.0354)  loss_box_reg: 0.0002 (0.0011)  loss_objectness: 0.0007 (0.0062)  loss_rpn_box_reg: 0.0753 (0.1420)  time: 4.4018  data: 0.0989  max me

KeyboardInterrupt: 

Now that training has finished, let's have a look at what it actually predicts in a test image

In [None]:
# pick one image from the test set
img, _ = dataset_test[0]
# put the model in evaluation mode
model.eval()
with torch.no_grad():
    prediction = model([img.to(device)])

Printing the prediction shows that we have a list of dictionaries. Each element of the list corresponds to a different image. As we have a single image, there is a single dictionary in the list.
The dictionary contains the predictions for the image we passed. In this case, we can see that it contains `boxes`, `labels`, `masks` and `scores` as fields.

In [None]:
prediction

Let's inspect the image and the predicted segmentation masks.

For that, we need to convert the image, which has been rescaled to 0-1 and had the channels flipped so that we have it in `[C, H, W]` format.

In [None]:
Image.fromarray(img.mul(255).permute(1, 2, 0).byte().numpy())

And let's now visualize the top predicted segmentation mask. The masks are predicted as `[N, 1, H, W]`, where `N` is the number of predictions, and are probability maps between 0-1.

In [None]:
Image.fromarray(prediction[0]['masks'][0, 0].mul(255).byte().cpu().numpy())

Looks pretty good!

## Wrapping up

In this tutorial, you have learned how to create your own training pipeline for instance segmentation models, on a custom dataset.
For that, you wrote a `torch.utils.data.Dataset` class that returns the images and the ground truth boxes and segmentation masks. You also leveraged a Mask R-CNN model pre-trained on COCO train2017 in order to perform transfer learning on this new dataset.

For a more complete example, which includes multi-machine / multi-gpu training, check `references/detection/train.py`, which is present in the [torchvision GitHub repo](https://github.com/pytorch/vision/tree/v0.8.2/references/detection). 

