GitHub  
https://github.com/facebookresearch/omnivore  
論文  
https://arxiv.org/abs/2201.08377v1    

<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/omnivore_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 実行前準備
「ランタイム」→「ランタイムのタイプを変更」→「ハードウェアアクセラレータ」をGPUに変更
## 実行方法
「ランタイム」→「すべてのセルを実行」を選択

## 環境セットアップ
必要なライブラリのインストールとPyTorchの動作確認（PyTorchのバージョン変更が必要な場合は、該当行のコメントを外して実行）

In [None]:
# Check the preinstalled PyTorch: print its version, whether CUDA is visible, and the GPU count.
# The check runs in a fresh `python` subprocess (shell magic `!`), so it reports what is
# installed on disk rather than the copy already imported into this kernel.
import torch
!python -c "import torch; print(torch.__version__); print(torch.cuda.is_available()); print(torch.cuda.device_count());"

In [None]:
# Optional: pin PyTorch 1.9.0 (CUDA 11.1 build). Uncomment both lines below to do so.
#!pip uninstall torch -y
#!pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
# Extra dependencies: moviepy is used below to downscale the input video;
# einops/timm are presumably required by the Omnivore model code (TODO confirm).
!pip install einops timm moviepy

# Sanity-check the PyTorch install (version, CUDA availability, GPU count).
# Runs in a subprocess so it sees the on-disk install even if this kernel imported an older torch.
import torch
!python -c "import torch; print(torch.__version__); print(torch.cuda.is_available()); print(torch.cuda.device_count());"

In [None]:
# Install PyTorchVideo from GitHub (provides EncodedVideo and the video transforms used below).
!pip install "git+https://github.com/facebookresearch/pytorchvideo.git"
# Clone the Omnivore repo; later cells `%cd /content/omnivore/` and import its local `transforms` module.
!git clone https://github.com/facebookresearch/omnivore.git

## モジュールのインポート

In [None]:
%cd /content/omnivore/
import json
from typing import List

import torch
import torch.nn.functional as F
import torchvision.transforms as T
from PIL import Image
from pytorchvideo.data.encoded_video import EncodedVideo
from torchvision.transforms._transforms_video import NormalizeVideo
from transforms import SpatialCrop, TemporalCrop, DepthNorm

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import imageio
import matplotlib.animation as animation
from IPython.display import HTML


## モデルのロード
### Model Zoo 
今回はOmnivore Swin Bを使用


| Name      | IN1k Top 1 | Kinetics400 Top 1     | SUN RGBD Top 1     | Model   |
| :---        |    :----   |          :--- | :--- |:--- |
| Omnivore Swin T      | 81.2       | 78.9   |62.3   | [weights](https://dl.fbaipublicfiles.com/omnivore/models/swinT_checkpoint.torch)   |
| Omnivore Swin S   | 83.4       | 82.2      |64.6  | [weights](https://dl.fbaipublicfiles.com/omnivore/models/swinS_checkpoint.torch)  |
| Omnivore Swin B      | 84.0       | 83.3   |65.4   | [weights](https://dl.fbaipublicfiles.com/omnivore/models/swinB_checkpoint.torch)   |
| Omnivore Swin B (IN21k)   | 85.3       | 84.0      |67.2   | [weights](https://dl.fbaipublicfiles.com/omnivore/models/swinB_In21k_checkpoint.torch)   |
| Omnivore Swin L (IN21k)      | 86.0       | 84.1   |67.1   | [weights](https://dl.fbaipublicfiles.com/omnivore/models/swinL_In21k_checkpoint.torch) |


In [None]:
# Device on which to run the model
# Set to cuda to load on GPU
# Choose where the model runs: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("set device->", device)

# Fetch the pretrained Omnivore Swin-B checkpoint through torch.hub, move it to the
# chosen device and put it in inference (eval) mode in one chain.
model_name = "omnivore_swinB"
model = (
    torch.hub.load("facebookresearch/omnivore:main", model=model_name)
    .to(device)
    .eval()
)

## 画像分類

### ID/ラベルのロード
予測時に出力されるクラスIDと、各IDが表すラベルの対応を記したImageNet-1Kデータセットのラベルファイルをロード


In [None]:
%cd /content/omnivore/
# IDとラベルの対応表(imagenet_class_index.json)をダウンロード
!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json

with open("/content/omnivore/imagenet_class_index.json", "r") as f:
    imagenet_classnames = json.load(f)

# idとlabelをマッピングした辞書型オブジェクトを生成
imagenet_id_to_classname = {}
for k, v in imagenet_classnames.items():
    imagenet_id_to_classname[k] = v[1] 

### 画像の取得


In [None]:
%cd /content/omnivore/
!mkdir images
%cd /content/omnivore/images/

latent_type='upload_from_pc' #@param ['dawnload_from_web', 'upload_from_pc']
if latent_type=='dawnload_from_web':
  # サンプル画像のダウンロード
  !wget -O library.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/13-11-02-olb-by-RalfR-03.jpg/800px-13-11-02-olb-by-RalfR-03.jpg

  image_path = "/content/omnivore/images/library.jpg"
else:
  # 画像アップロード
  from google.colab import files
  uploaded = files.upload()
  uploaded = list(uploaded.keys())
  file_name = uploaded[0]
  image_path = "/content/omnivore/images/" + file_name

image = Image.open(image_path).convert("RGB")
plt.figure(figsize=(6, 6))
plt.imshow(image)

### 画像のセットアップ
画像をモデルが求めるデータ形式に変換

In [None]:
# ImageNet-style preprocessing: resize the shorter side to 224, center-crop 224x224,
# convert the PIL image to a float tensor and normalize with the ImageNet mean/std.
image_transform = T.Compose([
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
image = image_transform(image)

# The model expects B x C x T x H x W: insert a batch axis in front and a
# length-1 time axis after the channels, then move the tensor to the model's device.
image = image.unsqueeze(0).unsqueeze(2).to(device)

### 予測
画像のラベルを予測します

In [None]:
# Classify the preprocessed image. No gradients are needed for inference, so run under
# torch.no_grad(): this avoids keeping the autograd graph (and its activations) alive,
# which matters here because GPU memory is tight.
with torch.no_grad():
    prediction = model(image, input_type="image")
    prediction = F.softmax(prediction, dim=1)
values, pred_classes = prediction.topk(k=5)

print("上位5番目までの検出ラベル")
for value, pred_class in zip(values[0], pred_classes[0]):
  print('ラベル:', imagenet_id_to_classname[str(pred_class.item())], ', Score:', value.item())

# GPU memory is tight: drop the input tensor and return cached blocks once prediction is done.
del image
torch.cuda.empty_cache()

## 動画分類

### ID/ラベルのロード
予測時に出力されるクラスIDと、各IDが表すラベルの対応を記したKinetics-400データセットのラベルファイルをロード


In [None]:
%cd /content/omnivore/
# IDとラベルの対応表(imagenet_class_index.json)をダウンロード
!wget https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json

with open("kinetics_classnames.json", "r") as f:
    kinetics_classnames = json.load(f)

# idとlabelをマッピングした辞書型オブジェクトを生成
kinetics_id_to_classname = {}
for k, v in kinetics_classnames.items():
    kinetics_id_to_classname[v] = str(k).replace('"', "")

### 動画の取得

In [None]:
%cd /content/omnivore/
!mkdir videos
%cd /content/omnivore/videos/

latent_type='upload_from_pc' #@param ['dawnload_from_web', 'upload_from_pc']

if latent_type=='dawnload_from_web':
  # サンプル動画のダウンロード
  !wget https://dl.fbaipublicfiles.com/omnivore/example_data/dance.mp4

  # 動画のロード
  video_path = "/content/omnivore/videos/dance.mp4"
else:
  # 動画アップロード
  from google.colab import files
  uploaded = files.upload()
  uploaded = list(uploaded.keys())
  file_name = uploaded[0]
  video_path = "/content/omnivore/videos/" + file_name  

# 動画が大きすぎるとRAMから溢れるためリサイズ
from moviepy.editor import *
clip = VideoFileClip(video_path)

# 高さ80pxにリサイズ
clip_resized = clip.resize(height=80)
clip_resized.write_videofile(video_path)

clip_resized.ipython_display()

### 動画のセットアップ
動画をモデルが求めるデータ形式に変換

In [None]:
# Clip sampling parameters; the trailing comments record the original default values.
frames_per_second = 160  # default 30
sampling_rate = 1  # default 2
num_frames = 160  # default 160

# Length in seconds of the window to decode: num_frames * sampling_rate frames at frames_per_second.
clip_duration = (num_frames * sampling_rate) / frames_per_second

# Preprocessing applied to the "video" entry of the decoded clip dict:
# subsample to num_frames frames, scale pixel values by 1/255, resize the shorter side to 224,
# normalize with the ImageNet mean/std, then apply the repo's TemporalCrop / SpatialCrop
# helpers (presumably 16-frame windows with stride 40, and 3 spatial 224px crops — TODO confirm).
video_transform = ApplyTransformToKey(
    key="video",
    transform=T.Compose([
        UniformTemporalSubsample(num_frames),
        T.Lambda(lambda frames: frames / 255.0),
        ShortSideScale(size=224),  # default 224
        NormalizeVideo(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        TemporalCrop(frames_per_clip=16, stride=40),  # default 32, 40
        SpatialCrop(crop_size=224, num_crops=3),  # default 224, 3
    ]),
)

In [None]:
# Select the duration of the clip to load by specifying the start and end duration
# The start_sec should correspond to where the action occurs in the video
# Select the window of the video to decode; start_sec should be where the action occurs.
start_sec = 0
end_sec = start_sec + clip_duration 
print('start_sec:', start_sec, ', end_sec:', end_sec)

# Decode only [start_sec, end_sec], and always release the underlying file handle afterwards.
video = EncodedVideo.from_path(video_path)
try:
    video_data = video.get_clip(start_sec=start_sec, end_sec=end_sec)
finally:
    video.close()

# get_clip can return video=None when no frames fall inside the window (e.g. a video shorter
# than clip_duration); fail here with a clear message instead of deep inside the transforms.
if video_data["video"] is None:
    raise ValueError(
        f"No video frames decoded between {start_sec}s and {end_sec}s of {video_path}."
    )

# Normalize / crop the decoded frames into the model's input format.
video_data = video_transform(video_data)
video_inputs = video_data["video"]

# Take the first clip produced by the transform and add a batch axis:
# the model expects inputs of shape B x C x T x H x W.
video_input = video_inputs[0][None, ...]

# Run the video on the CPU: it does not fit in GPU memory.
video_input = video_input.to('cpu')
model = model.to('cpu')

### 予測
動画のラベルを予測します

In [None]:
# Pass the input clip through the model 
prediction = model(video_input, input_type="video")

# Get the predicted classes 
prediction = F.softmax(prediction, dim=1)
values, pred_classes = prediction.topk(k=5)

print("上位5番目までの検出ラベル")
for value, pred_class in zip(values[0], pred_classes[0]):
  print('ラベル:', kinetics_id_to_classname[int(pred_class)], ', Score:', value.item())

## 深度マップ(RGBD画像)による画像分類

### ID/ラベルのロード
予測時に出力されるクラスIDと、各IDが表すラベルの対応を記したSUN RGB-Dデータセットのラベルファイルをロード


In [None]:
%cd /content/omnivore/
!wget https://dl.fbaipublicfiles.com/omnivore/sunrgbd_classnames.json

with open("sunrgbd_classnames.json", "r") as f:
  sunrgbd_id_to_classname = json.load(f)

### RGBD画像の取得

In [None]:
%cd /content/omnivore/
!mkdir rgbd
%cd /content/omnivore/rgbd/

# 画像と深度ファイルのダウンロード
!wget -O store.png https://upload.wikimedia.org/wikipedia/commons/thumb/f/f4/Interior_of_the_IKEA_B%C4%83neasa_33.jpg/791px-Interior_of_the_IKEA_B%C4%83neasa_33.jpg
!wget https://dl.fbaipublicfiles.com/omnivore/example_data/store_disparity.pt

image_path = "/content/omnivore/rgbd/store.png"
depth_path = "/content/omnivore/rgbd/store_disparity.pt"
image = Image.open(image_path).convert("RGB")
depth = torch.load(depth_path)[None, ...]

plt.figure(figsize=(20, 10))
plt.subplot(1, 2, 1)
plt.title("RGB", fontsize=20)
plt.imshow(image)
plt.subplot(1, 2, 2)
plt.imshow(depth.numpy().squeeze())
plt.title("Depth", fontsize=20)

### 画像のセットアップ
画像をモデルが求めるデータ形式に変換

In [None]:
# RGB-D preprocessing: the repo's DepthNorm on the depth channel (max_depth=75, clamped before
# scaling — exact semantics live in omnivore's transforms module), resize the shorter side to 224,
# center-crop 224x224, then normalize all four channels (ImageNet stats for RGB + depth stats).
rgbd_transform = T.Compose([
    DepthNorm(max_depth=75.0, clamp_max_before_scale=True),
    T.Resize(224),
    T.CenterCrop(224),
    T.Normalize(
        mean=[0.485, 0.456, 0.406, 0.0418],
        std=[0.229, 0.224, 0.225, 0.0295],
    ),
])

In [None]:
# Convert to tensor and transform
image = T.ToTensor()(image)
rgbd = torch.cat([image, depth], dim=0)
rgbd = rgbd_transform(rgbd)

# The model expects inputs of shape: B x C x T x H x W
rgbd_input = rgbd[None, :, None, ...]
rgbd_input = rgbd_input.to(device)
model = model.to(device)

### 予測

In [None]:
# Classify the RGB-D input without building an autograd graph.
with torch.no_grad():
    prediction = model(rgbd_input, input_type="rgbd")
    prediction = F.softmax(prediction, dim=1)
# Unpack BOTH the top-5 scores and indices from this prediction, so the printed Score
# belongs to the RGB-D result and not to a stale `values` left over from an earlier cell.
values, pred_classes = prediction.topk(k=5)

print("上位5番目までの検出ラベル")
for value, pred_class in zip(values[0], pred_classes[0]):
  print('ラベル:', sunrgbd_id_to_classname[str(pred_class.item())], ', Score:', value.item())