Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[lab01] Update code to tf 2.0, separate demo maskrcnn.
- Loading branch information
Arthur Douillard
committed
Nov 17, 2019
1 parent
66a4aae
commit 134f6c1
Showing
9 changed files
with
265 additions
and
252 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Instance Detection and Segmentation with Mask-RCNN\n", | ||
"\n", | ||
"[Mask RCNN](https://arxiv.org/abs/1703.06870) is a refinement of the [Faster RCNN](https://arxiv.org/abs/1506.01497) **object detection** model to also add support for **instance segmentation**.\n", | ||
"\n", | ||
"The following shows how to use a [Keras based implementation](https://github.com/matterport/Mask_RCNN) provided by matterport.com along with model parameters pretrained on the [COCO object detection dataset](http://cocodataset.org/).\n", | ||
"\n", | ||
"**WARNING**: The following requires to execute the companion `data_download.ipynb` notebook first.\n", | ||
"\n", | ||
"**WARNING**: For this notebook (and only this one), you'll need a tensorflow version under 2.0.0; create a new virtualenv and install requirements_tensorflow1.txt." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import tensorflow\n", | ||
"import keras\n", | ||
"\n", | ||
"tensorflow.__version__, keras.__version__" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from maskrcnn import config\n", | ||
"from maskrcnn import model as modellib\n", | ||
"\n", | ||
"\n", | ||
"class InferenceCocoConfig(config.Config):\n", | ||
" # Give the configuration a recognizable name\n", | ||
" NAME = \"inference_coco\"\n", | ||
"\n", | ||
" # Number of classes (including background)\n", | ||
" NUM_CLASSES = 1 + 80 # COCO has 80 classes\n", | ||
"\n", | ||
" # Set batch size to 1 since we'll be running inference on\n", | ||
" # one image at a time. Batch size = GPU_COUNT * IMAGES_PER_GPU\n", | ||
" GPU_COUNT = 1\n", | ||
" IMAGES_PER_GPU = 1\n", | ||
"\n", | ||
"\n", | ||
"config = InferenceCocoConfig()\n", | ||
"model = modellib.MaskRCNN(mode=\"inference\", model_dir='maskrcnn/logs', config=config)\n", | ||
"\n", | ||
"# Load weights trained on MS-COCO\n", | ||
"coco_model_file = \"mask_rcnn_coco.h5\"\n", | ||
"model.load_weights(coco_model_file, by_name=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Class Names\n", | ||
"\n", | ||
"Index of the class in the list is its ID. For example, to get ID of the teddy bear class, use: `class_names.index('teddy bear')`\n", | ||
"\n", | ||
"`BG` stands for background." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# COCO Class names\n", | ||
"class_names = ['BG', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',\n", | ||
" 'bus', 'train', 'truck', 'boat', 'traffic light',\n", | ||
" 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird',\n", | ||
" 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear',\n", | ||
" 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie',\n", | ||
" 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',\n", | ||
" 'kite', 'baseball bat', 'baseball glove', 'skateboard',\n", | ||
" 'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup',\n", | ||
" 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',\n", | ||
" 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',\n", | ||
" 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed',\n", | ||
" 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',\n", | ||
" 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster',\n", | ||
" 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors',\n", | ||
" 'teddy bear', 'hair drier', 'toothbrush']" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Run Object Detection" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from skimage.io import imread\n", | ||
"\n", | ||
"image = imread('webcam_shot.jpeg')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"image.shape" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from maskrcnn import visualize\n", | ||
"import time\n", | ||
"\n", | ||
"# Run detection\n", | ||
"tic = time.time()\n", | ||
"results = model.detect([image], verbose=1)\n", | ||
"toc = time.time()\n", | ||
"print(\"Analyzed image in {:.3f}s\".format(toc - tic))\n", | ||
"\n", | ||
"# Visualize results\n", | ||
"r = results[0]\n", | ||
"for class_id, score in zip(r['class_ids'], r['scores']):\n", | ||
" print(\"{}:\\t{:0.3f}\".format(class_names[class_id], score))\n", | ||
"visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'], \n", | ||
" class_names, r['scores']);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import random\n", | ||
"\n", | ||
"# Load a random image from the images folder\n", | ||
"image_folder = 'maskrcnn/images'\n", | ||
"file_names = next(os.walk(image_folder))[2]\n", | ||
"image = imread(os.path.join(image_folder, random.choice(file_names)))\n", | ||
"\n", | ||
"# Run detection\n", | ||
"results = model.detect([image], verbose=1)\n", | ||
"\n", | ||
"# Visualize results\n", | ||
"r = results[0]\n", | ||
"visualize.display_instances(image, r['rois'], r['masks'], r['class_ids'], \n", | ||
" class_names, r['scores'])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.