<a href="https://colab.research.google.com/github/komazawa-deep-learning/komazawa-deep-learning.github.io/blob/master/2022notebooks/2022_0625DETR_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DETR を用いた領域切り出し

ここでは，バウンディングボックス (画像中の注目する領域を囲む矩形) の検出と，パノプティック切り出しを行う。

[DETR](https://arxiv.org/pdf/2005.12872) とは，トランスフォーマーを用いた符号化器-復号化器 (encoder-decoder) モデルである。


<center>
<img src="https://komazawa-deep-learning.github.io/2022assets/2020Carion_DETR_fig2ja.png" width="88%"><br/>
<!-- <img src="https://komazawa-deep-learning.github.io/2022assets/2020Carion_DETR_fig2.svg" width="88%"><br/> -->
</center>



## 準備作業

In [1]:
# Use the 'retina' figure format for inline matplotlib output.
%config InlineBackend.figure_format = 'retina'
import IPython
# True when the repr of the active IPython shell mentions 'google.colab'.
isColab = 'google.colab' in str(IPython.get_ipython())

from termcolor import colored
import platform
# Node name truncated at the first dot.
HOSTNAME = platform.node().split('.')[0]

import os
# NOTE(review): raises KeyError when the HOME environment variable is not set.
HOME = os.environ['HOME']

# Import ipynbname, installing it first (pip output discarded) if it is missing.
try:
    import ipynbname
except ImportError:
    !pip install ipynbname > /dev/null
import ipynbname
# Notebook path with every occurrence of "<HOME>/" removed from the string.
FILEPATH = str(ipynbname.path()).replace(HOME+'/','')

import pwd
# First field of the password-database entry for the effective user id.
USER=pwd.getpwuid(os.geteuid())[0]

from datetime import date
TODAY=date.today()

import torch
TORCH_VERSION = torch.__version__

# Print the collected environment information, each value in bold green.
color = 'green'
print('日付:',colored(f'{TODAY}', color=color, attrs=['bold']))
print('HOSTNAME:',colored(f'{HOSTNAME}', color=color, attrs=['bold']))
print('ユーザ名:',colored(f'{USER}', color=color, attrs=['bold']))
print('HOME:',colored(f'{HOME}', color=color,attrs=['bold']))
print('ファイル名:',colored(f'{FILEPATH}', color=color, attrs=['bold']))
print('torch.__version__:',colored(f'{TORCH_VERSION}', color=color, attrs=['bold']))

日付: [1m[32m2022-06-23[0m
HOSTNAME: [1m[32m71a312c6370d[0m
ユーザ名: [1m[32mroot[0m
HOME: [1m[32m/root[0m
ファイル名: [1m[32m/fileId=1r7e8KEIyinaaO4ADE8IHEIOsFc7fj07t[0m
torch.__version__: [1m[32m1.11.0+cu113[0m


In [3]:
from PIL import Image
import requests
import io
import math
import matplotlib.pyplot as plt
# Import japanize_matplotlib, installing it first if it is missing.
try:
    import japanize_matplotlib
except ImportError:    
    !pip install japanize_matplotlib
import japanize_matplotlib

import torch
from torch import nn
#from torchvision.models import resnet50
import torchvision.transforms as T
import numpy
# Disable gradient tracking globally for the rest of the notebook;
# the trailing ';' suppresses the cell's displayed result.
torch.set_grad_enabled(False);

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting japanize_matplotlib
  Downloading japanize-matplotlib-1.1.3.tar.gz (4.1 MB)
[K     |████████████████████████████████| 4.1 MB 8.5 MB/s 
Building wheels for collected packages: japanize-matplotlib
  Building wheel for japanize-matplotlib (setup.py) ... [?25l[?25hdone
  Created wheel for japanize-matplotlib: filename=japanize_matplotlib-1.1.3-py3-none-any.whl size=4120275 sha256=33e49a20c4bde2d919430ca734b0f25c742d726128d70a6b387341b5616e3f48
  Stored in directory: /root/.cache/pip/wheels/83/97/6b/e9e0cde099cc40f972b8dd23367308f7705ae06cd6d4714658
Successfully built japanize-matplotlib
Installing collected packages: japanize-matplotlib
Successfully installed japanize-matplotlib-1.1.3


パノプティック 切り出しのための API をインストール

In [4]:
# Import panopticapi, installing it from the cocodataset GitHub repository if it is missing.
try:
    import panopticapi
except ImportError:
    !pip install git+https://github.com/cocodataset/panopticapi.git --upgrade

import panopticapi
# rgb2id is used later in this notebook to decode the panoptic PNG into segment ids.
from panopticapi.utils import id2rgb, rgb2id

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/cocodataset/panopticapi.git
  Cloning https://github.com/cocodataset/panopticapi.git to /tmp/pip-req-build-jemrtiv4
  Running command git clone -q https://github.com/cocodataset/panopticapi.git /tmp/pip-req-build-jemrtiv4
Building wheels for collected packages: panopticapi
  Building wheel for panopticapi (setup.py) ... [?25l[?25hdone
  Created wheel for panopticapi: filename=panopticapi-0.1-py3-none-any.whl size=8306 sha256=54d4f44f2ca074899a3f809a80faf1ac7e10bdef1d7cf570a2839e8cafdae7ce
  Stored in directory: /tmp/pip-ephem-wheel-cache-e0tjmuxx/wheels/ad/89/b8/b66cce9246af3d71d65d72c85ab993fd28e7578e1b0ed197f1
Successfully built panopticapi
Installing collected packages: panopticapi
Successfully installed panopticapi-0.1


In [5]:
# MS COCO class labels. The position in this list is the class index
# (plot_results looks up CLASSES[argmax]); 'N/A' marks entries without a label.
CLASSES = [
    'N/A', '人', '自転車', '車', 'バイク', '飛行機', 'バス',
    '電車', 'トラック', 'ボート', '信号機', '消火栓', 'N/A',
    '停止サイン', '駐車メータ', 'ベンチ', '鳥', 'ネコ', 'イヌ', '馬',
    '羊', '牛', '象', '熊', 'シマウマ', 'キリン', 'N/A', 'バックパック',
    '傘', 'N/A', 'N/A', 'ハンドバッグ', 'ネクタイ', 'スーツケース', 'フリスビー', 'スキー',
    'スノーボード', 'sports ball', '凧', '野球のバット', '野球のグラブ',
    'スケートボード', 'サーフボード', 'テニスラケット', 'ボトル', 'N/A', 'ワイングラス',
    'カップ', 'フォーク', 'ナイフ', 'スプーン', 'ボウル', 'バナナ', 'りんご', 'サンドウィッチ',
    'オレンジ', 'ブロッコリ', 'ニンジン', 'ホットドッグ', 'ピザ', 'ドーナッツ', 'ケーキ',
    'イス', 'ソファ', '鉢植え', 'ベッド', 'N/A', 'ダイニングテーブル', 'N/A',
    'N/A', 'トイレ', 'N/A', 'テレビ', 'ラップトップ', 'マウス', 'リモコン', 'キーボード',
    '携帯電話', '電子レンジ', 'コンロ', 'トースター', '洗面台', '冷蔵庫', 'N/A',
    '本', '時計', '花瓶', 'ハサミ', 'テディベア', 'ドライヤー',
    '歯ブラシ'
]


# Colour definitions for visualization: six [r, g, b] triples with components in 0..1.
# plot_results cycles through them for the box outlines.
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

# Input preprocessing: resize with T.Resize(800), convert to a tensor, then
# normalize each channel with the given means and standard deviations.
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])


# Bounding-box preprocessing.
# Computes left, top, right and bottom coordinates from the box centre, width and height.
def box_cxcywh_to_xyxy(x):
    """Convert boxes from (cx, cy, w, h) to (x_min, y_min, x_max, y_max).

    Args:
        x: tensor whose last dimension has size 4 and holds
            (centre x, centre y, width, height). Any number of leading
            dimensions is accepted: a single box of shape (4,), a set of
            boxes of shape (N, 4), or a batch of shape (B, N, 4).

    Returns:
        Tensor with the same shape as ``x`` holding the corner coordinates.
    """
    # Split and re-stack along the LAST dimension so the coordinate axis is
    # found regardless of how many leading (batch) dimensions there are.
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)


def rescale_bboxes(out_bbox, size):
    """Convert (cx, cy, w, h) boxes to corner format and scale them to an image size.

    Args:
        out_bbox: tensor of boxes in (cx, cy, w, h) format. Assumed to be
            expressed as fractions of the image size -- TODO confirm against
            the model that produces them.
        size: ``(width, height)`` pair of the target image.

    Returns:
        Tensor of (x_min, y_min, x_max, y_max) boxes multiplied by
        ``[width, height, width, height]``.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Build the scale factors on the same device as the boxes so the
    # multiplication also works when the boxes live on a GPU.
    scale = torch.tensor([img_w, img_h, img_w, img_h],
                         dtype=torch.float32, device=b.device)
    return b * scale


def plot_results(pil_img, prob, boxes):
    """Show an image with one labelled rectangle per detection.

    Args:
        pil_img: image passed to ``plt.imshow``.
        prob: iterable of per-detection class score vectors; the arg-max of
            each vector selects the label from ``CLASSES``.
        boxes: tensor of (x_min, y_min, x_max, y_max) boxes, one per detection.
    """
    plt.figure(figsize=(16,10))
    plt.imshow(pil_img)
    axes = plt.gca()
    # Repeat the palette so that zip() has a colour for every detection.
    palette = COLORS * 100
    for scores, corners, edge_color in zip(prob, boxes.tolist(), palette):
        left, top, right, bottom = corners
        frame = plt.Rectangle((left, top), right - left, bottom - top,
                              fill=False, color=edge_color, linewidth=3)
        axes.add_patch(frame)
        # Label the box with the best-scoring class and its score.
        best = scores.argmax()
        caption = f'{CLASSES[best]}: {scores[best]:0.2f}'
        axes.text(left, top, caption, fontsize=15,
                  bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()
    
# Detectron2 defines the classes differently; this map absorbs the difference by
# sending each index in CLASSES to a contiguous index that skips the "N/A" entries.
coco2d2 = {}
count = 0
for i, c in enumerate(CLASSES):
    if c != "N/A":
        coco2d2[i] = count
        # Advance only for real classes; incrementing on every entry would
        # make the map the identity instead of a contiguous renumbering.
        count += 1

# Preprocessing pipeline: resize, convert to a tensor, then standardize by
# subtracting the mean and dividing by the standard deviation.
# NOTE(review): this re-binds `transform` to a pipeline identical to the one
# defined earlier in this cell; one of the two definitions is redundant.
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

## モデルハブから訓練済の結合係数を取得


In [6]:
# Fetch the pretrained detection model from the facebookresearch/detr hub repository.
detr_resnet = torch.hub.load(repo_or_dir='facebookresearch/detr',
                             model='detr_resnet50',
                             pretrained=True)
# Switch to evaluation mode so that train-time layers such as dropout are
# disabled and repeated runs on the same image give the same predictions.
detr_resnet.eval()

# Fetch the pretrained panoptic model together with its post-processor.
detr_resnet_panoptic, postprocessor0 = torch.hub.load(
    repo_or_dir='facebookresearch/detr',
    model='detr_resnet50_panoptic',
    pretrained=True,
    return_postprocessor=True,
    num_classes=250)
# Same as above; the trailing ';' keeps the model repr out of the cell output.
detr_resnet_panoptic.eval();


Downloading: "https://github.com/facebookresearch/detr/archive/main.zip" to /root/.cache/torch/hub/main.zip
Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

Downloading: "https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth" to /root/.cache/torch/hub/checkpoints/detr-r50-e632da11.pth


  0%|          | 0.00/159M [00:00<?, ?B/s]

Using cache found in /root/.cache/torch/hub/facebookresearch_detr_main
Downloading: "https://dl.fbaipublicfiles.com/detr/detr-r50-panoptic-00ce5173.pth" to /root/.cache/torch/hub/checkpoints/detr-r50-panoptic-00ce5173.pth


  0%|          | 0.00/164M [00:00<?, ?B/s]

COCO 検証用データセットから画像を取得

In [None]:
# Download the COCO val2017 sample image only when it is not already present locally.
if not os.path.exists('000000039769.jpg'):
    !wget http://images.cocodataset.org/val2017/000000039769.jpg
# `im` is the PIL image used by the detection cells below.
im = Image.open('000000039769.jpg')
# Last expression of the cell: display the image file inline.
IPython.display.Image('000000039769.jpg')

* 前処理を，各データに施し，予測をフィルタリング
* クラスの信頼度が しきい値 (0.9) よりも高い物体のみを保存する
 (非オブジェクトの予測は除外)。
* より多くの予測を得たい場合は，この閾値を下げる


In [8]:
# Apply the preprocessing transform (which ends with mean/std normalization)
# and add a leading batch dimension of size 1.
img = transform(im).unsqueeze(0)
outputs = detr_resnet(img) # run the batch through the model to get its outputs

# Only consider predictions whose confidence exceeds 0.9.
# The last softmax column is dropped before taking the per-query maximum.
probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9

# Scale the kept bounding boxes (rectangular regions of interest) to the size of `im`.
bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)

  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)


In [None]:
plot_results(im, probas[keep], bboxes_scaled) # draw the kept detections on the image

## パノプティック切り出し

In [None]:
import os
img_fname = '000000281759.jpg'
# Download the COCO val2017 sample image only when it is not already present locally.
if not os.path.exists(img_fname):
    !wget http://images.cocodataset.org/val2017/000000281759.jpg -O 000000281759.jpg
# Re-binds `im` to the image used by the panoptic cells below.
im = Image.open(img_fname)
# Last expression of the cell: display the image file inline.
IPython.display.Image(img_fname)

検出の実施

In [11]:
# Preprocess the image, add a batch dimension of size 1, and run the panoptic model.
img = transform(im).unsqueeze(0)
# `out` is indexed below by "pred_logits" and "pred_masks".
out = detr_resnet_panoptic(img)

  dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)


これは各クエリのマスクを返すので、信頼度の高いものを可視化してみましょう。
<!-- This returns a mask for each query, let us visualize the high confidence ones -->

In [None]:
# Compute the scores, excluding the "no-object" class (the last one).
scores = out["pred_logits"].softmax(-1)[..., :-1].max(-1)[0]
# Threshold the confidence.
keep = scores > 0.85

# Plot all the remaining masks on a grid with `ncols` columns.
ncols = 5
n_kept = int(keep.sum().item())
# Always request at least one row (plt.subplots rejects nrows=0), and pass
# squeeze=False so `axs` stays a 2-D array even when a single row is enough;
# otherwise the 2-D indexing below fails for five or fewer kept masks.
nrows = max(1, math.ceil(n_kept / ncols))
fig, axs = plt.subplots(ncols=ncols, nrows=nrows, figsize=(18, 10), squeeze=False)
# Hide the axes of every panel, including the ones that stay empty.
for a in axs.flat:
    a.axis('off')
for i, mask in enumerate(out["pred_masks"][keep]):
    ax = axs[i // ncols, i % ncols]
    ax.imshow(mask, cmap="cividis")
    ax.axis('off')
fig.tight_layout()

個々のマスクが揃ったので、予測を統合して統一されたパノプティックセグメンテーションにすることができます。
そのためにDETRのポストプロセッサーを使用します。
<!-- Now that we have the individual masks, we can merge the predictions into a unified panoptic segmentation. 
We use DETR's postprocessor for that. -->

In [13]:
# The post-processor expects as input the target size of the predictions.
# Here that is the (height, width) of the preprocessed tensor `img`, i.e. the
# image after `transform` has resized it, with a leading batch dimension added.
# [0] selects the result for the single image in the batch.
result = postprocessor0(out, torch.as_tensor(img.shape[-2:]).unsqueeze(0))[0]

簡単な結果の表示

In [None]:
import itertools
import seaborn as sns
# Endless iterator over seaborn's default palette, one colour per segment.
palette = itertools.cycle(sns.color_palette())

# The segmentation is stored in a special-format png
panoptic_seg = Image.open(io.BytesIO(result['png_string']))
# .copy() gives an array we own, since it is overwritten in place below.
panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8).copy()
# We retrieve the ids corresponding to each mask
panoptic_seg_id = rgb2id(panoptic_seg)

# Finally we color each mask individually
panoptic_seg[:, :, :] = 0
# The loop variable is deliberately not called `id`: at cell level that would
# shadow the builtin id() for the rest of the kernel session.
for segment_id in range(panoptic_seg_id.max() + 1):
    # Palette components are multiplied by 255 before being stored in the uint8 image.
    panoptic_seg[panoptic_seg_id == segment_id] = numpy.asarray(next(palette)) * 255
plt.figure(figsize=(15,15))
plt.imshow(panoptic_seg)
plt.axis('off')
plt.show()

## Detectron2 を用いた視覚化

In [None]:
# Import detectron2, installing it from the facebookresearch GitHub repository if it is missing.
try:
    import detectron2
except ImportError:
    !pip install 'git+https://github.com/facebookresearch/detectron2.git'
    #!pip install detectron2==0.1.3 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/torch1.5/index.html

import detectron2
# Last expression of the cell: show the detectron2 version.
detectron2.__version__

In [16]:
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
if isColab:
    from google.colab.patches import cv2_imshow

結果の表示

In [None]:
from copy import deepcopy

# We extract the segments info and the panoptic result from DETR's prediction
# (deep-copied because the category ids are rewritten in place below)
segments_info = deepcopy(result["segments_info"])

# Panoptic predictions are stored in a special format png
panoptic_seg = Image.open(io.BytesIO(result['png_string']))
final_w, final_h = panoptic_seg.size

# We convert the png into an segment id map
panoptic_seg = numpy.array(panoptic_seg, dtype=numpy.uint8)
panoptic_seg = torch.from_numpy(rgb2id(panoptic_seg))

    
# Detectron2 uses a different numbering of coco classes, here we convert the class ids accordingly
# ("thing" segments and all other segments use separate lookup tables)
meta = MetadataCatalog.get("coco_2017_val_panoptic_separated")
for i in range(len(segments_info)):
    c = segments_info[i]["category_id"]
    segments_info[i]["category_id"] = meta.thing_dataset_id_to_contiguous_id[c] if segments_info[i]["isthing"] else meta.stuff_dataset_id_to_contiguous_id[c]

    
# Finally we visualize the prediction
# The image is resized to the size of the panoptic map; [:, :, ::-1] reverses the channel order.
v = Visualizer(numpy.array(im.copy().resize((final_w, final_h)))[:, :, ::-1], meta, scale=1.0)
v._default_font_size = 20
# NOTE: `v` is re-bound here from the Visualizer to the object returned by the draw call.
v = v.draw_panoptic_seg_predictions(panoptic_seg, segments_info, area_threshold=0)
if isColab:
    cv2_imshow(v.get_image())
else:
    # Outside Colab, reverse the channel order again before handing the image to matplotlib.
    plt.figure(figsize=(20,20))
    plt.imshow(v.get_image()[:,:,::-1])
    plt.show()