# Omnivore: A Single Model for Many Visual Modalities

Omnivore is a classification model, introduced in [this paper](https://arxiv.org/abs/2201.08377), that can accept several visual modalities as input: standard images (RGB), videos, or depth images (RGBD). It uses a Video Swin Transformer as the encoder and has a separate classification head for each visual modality.

In this notebook, we demonstrate how to use the Omnivore model, loaded with its pretrained weights, to classify inputs from each of these visual modalities.


In [None]:
import os
import json
from PIL import Image
import numpy as np

import torch
from torchvision.io import read_video, _read_video_from_file
import torchvision.transforms as T

import torchmultimodal.models.omnivore as omnivore
from examples.omnivore.data import presets

from IPython.display import Video
from matplotlib import pyplot as plt

In [None]:
# Build the Swin-T variant of Omnivore with its pretrained weights and put it in
# evaluation mode; nn.Module.eval() returns the module itself, so the calls chain.
model = omnivore.omnivore_swin_t(pretrained=True).eval()

## Inference on Image

In [None]:
# Downloading assets
os.makedirs("assets", exist_ok=True)
!wget "https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/imagenet_val_ringlet_butterfly_001.JPEG" -P "assets/"
!wget "https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/imagenet_class.json" -P "assets/"

In [None]:
# Sample image for the image-classification demo
pil_img = Image.open("assets/imagenet_val_ringlet_butterfly_001.JPEG")

# ImageNet class names, looked up later by predicted class index
with open("assets/imagenet_class.json", "r") as class_file:
    imagenet_classes = json.load(class_file)

In [None]:
# Show image
# Leaving the PIL image as the cell's last expression lets the notebook render it inline.
pil_img

In [None]:
# Preprocess with the ImageNet evaluation preset (224x224 crop), then prepend the
# batch dimension so the single image becomes a batch of one.
img_val_presets = presets.ImageNetClassificationPresetEval(crop_size=224)
input_img = img_val_presets(pil_img).unsqueeze(dim=0)

In [None]:
# Get top5 labels
# Inference only: torch.no_grad() stops autograd from recording the forward pass
# (and saving activations for a backward pass that never happens).
with torch.no_grad():
    preds = model(input_img, "image")
top5_values, top5_indices = preds[0].topk(5)
top5_labels = [imagenet_classes[index] for index in top5_indices.tolist()]
top5_labels

# The correct label is a ringlet butterfly, see: https://en.wikipedia.org/wiki/Ringlet

## Inference on Video

In [None]:
# Downloading assets
os.makedirs("assets", exist_ok=True)
!wget "https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/kinetics400_val_snowboarding_001.mp4" -P "assets/"
!wget "https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/kinetics400_class.json" -P "assets/"


In [None]:
# Read class list and video
with open("assets/kinetics400_class.json", "r") as fin:
    kinetics400_classes = json.load(fin)

# pts_unit="sec" avoids torchvision's "pts_unit 'pts' gives wrong results" warning;
# the whole clip is decoded (no start/end bound), so the returned frames are the same.
video, audio, info = read_video(
    "assets/kinetics400_val_snowboarding_001.mp4",
    pts_unit="sec",
    output_format="TCHW",
)

# The model was trained on clips sampled at 16 fps, so keep every `frame_stride`-th
# frame to bring the effective rate close to that. The stride is derived from the
# clip's own frame rate rather than assuming 30 fps (30 fps -> stride 2 -> 15 fps);
# 30 fps is still the fallback if the reader reports no rate.
train_fps = 16
frame_stride = max(1, round(info.get("video_fps", 30) / train_fps))
video = video[::frame_stride]

# Use first 50 frames
video = video[:50]


In [None]:
# Show video
# Displays the original .mp4 file at 512 px width (not the subsampled `video` tensor).
Video("assets/kinetics400_val_snowboarding_001.mp4", width=512)

In [None]:
# Preprocess with the video evaluation preset (resize_size=224, crop_size=224), then
# prepend the batch dimension so the clip becomes a batch of one.
video_val_presets = presets.VideoClassificationPresetEval(crop_size=224, resize_size=224)
input_video = video_val_presets(video).unsqueeze(dim=0)

In [None]:
# Get top5 labels
# Inference only: torch.no_grad() stops autograd from recording the forward pass
# (and saving activations for a backward pass that never happens).
with torch.no_grad():
    preds = model(input_video, "video")
top5_values, top5_indices = preds[0].topk(5)
top5_labels = [kinetics400_classes[index] for index in top5_indices.tolist()]
top5_labels

# The correct label is snowboarding

## Inference on Depth Image (RGBD)

In [None]:
# Downloading assets
os.makedirs("assets", exist_ok=True)
!wget "https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/sunrgbd_val_kitchen_depth_001.png" -P "assets/"
!wget "https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/sunrgbd_val_kitchen_image_001.jpg" -P "assets/"
!wget "https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/sunrgbd_val_kitchen_intrinsics_001.txt" -P "assets/"
!wget "https://download.pytorch.org/torchmultimodal/examples/omnivore/assets/sunrgbd_class.json" -P "assets/"


In [None]:
# SUN RGB-D scene class names, looked up later by predicted class index
with open("assets/sunrgbd_class.json", "r") as class_file:
    sunrgbd_classes = json.load(class_file)

In [None]:
# Read depth image
# The focal length is the first value on the first line of the intrinsics file
# (presumably fx of the 3x3 camera matrix -- confirm against the SUN RGB-D format).
with open("assets/sunrgbd_val_kitchen_intrinsics_001.txt", "r") as fin:
    lines = fin.readlines()
focal_length = float(lines[0].split()[0])

# Baseline of the kv2 sensor of SUN RGB-D (the sensor this depth image comes from)
baseline = 0.075

_to_tensor = T.ToTensor()

img_depth = Image.open("assets/sunrgbd_val_kitchen_depth_001.png")
tensor_depth = _to_tensor(img_depth)
# Depth -> disparity: baseline * focal_length / depth. The /1000.0 presumably converts
# the stored depth from millimetres to metres (TODO confirm); any zero-depth pixel
# yields an inf disparity here.
tensor_disparity = baseline * focal_length / (tensor_depth / 1000.0)

img_rgb = Image.open("assets/sunrgbd_val_kitchen_image_001.jpg")
tensor_rgb = _to_tensor(img_rgb)

# Stack RGB and disparity along dim 0, the channel axis of ToTensor's C x H x W output
tensor_rgbd = torch.cat((tensor_rgb, tensor_disparity), dim=0)

In [None]:
# Show depth image: the RGB frame on the left, the raw depth map (jet colormap) on the right
fig, (ax_rgb, ax_depth) = plt.subplots(1, 2, figsize=(16, 16))
ax_rgb.imshow(np.asarray(img_rgb))
ax_depth.imshow(np.asarray(img_depth), cmap="jet")

In [None]:
# Preprocess the RGBD tensor with the depth evaluation preset (resize_size=224,
# crop_size=224), then prepend the batch dimension so it becomes a batch of one.
depth_val_presets = presets.DepthClassificationPresetEval(crop_size=224, resize_size=224)
input_depth = depth_val_presets(tensor_rgbd).unsqueeze(dim=0)

In [None]:
# Get top5 predictions
# Inference only: torch.no_grad() stops autograd from recording the forward pass
# (and saving activations for a backward pass that never happens).
with torch.no_grad():
    preds = model(input_depth, "rgbd")
top5_values, top5_indices = preds[0].topk(5)
top5_labels = [sunrgbd_classes[index] for index in top5_indices.tolist()]
top5_labels

# The correct label is kitchen