In [None]:
# Need to install einops and timm for original omnivore model, and matplotlib for visualization
! pip install einops timm matplotlib

import torch
import torchvision.transforms as T
import torchmultimodal.models.omnivore as omnivore

from PIL import Image
import collections
import json
import matplotlib.pyplot as plt
import numpy as np


In [None]:
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Only parameters whose ``requires_grad`` flag is set contribute to the sum.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def custom_load_state_dict(model, pretrained_state_dict):
    """Load ``pretrained_state_dict`` into ``model``, matching entries by position.

    The i-th key of ``pretrained_state_dict`` is renamed to the i-th key of
    ``model.state_dict()`` and the renamed dict is handed to
    ``model.load_state_dict``. Key names are never compared, so both state
    dicts must list their entries in the same order.

    Args:
        model: module whose weights are overwritten.
        pretrained_state_dict: mapping of key -> value to copy from.

    Raises:
        ValueError: if the two state dicts do not have the same number of entries.
    """
    pretrained_keys = list(pretrained_state_dict.keys())
    model_keys = list(model.state_dict().keys())

    # A positional mapping only makes sense when both sides have the same
    # number of entries; fail early with a clear message rather than with an
    # IndexError / KeyError while building the mapping.
    if len(pretrained_keys) != len(model_keys):
        raise ValueError(
            "Cannot map state dict keys by position: pretrained state dict has "
            f"{len(pretrained_keys)} entries but the model has {len(model_keys)}"
        )

    updated_pretrained_state_dict = collections.OrderedDict(
        (model_key, pretrained_state_dict[pretrained_key])
        for pretrained_key, model_key in zip(pretrained_keys, model_keys)
    )
    model.load_state_dict(updated_pretrained_state_dict)

In [None]:
# Fetch the reference "omnivore_swinT" model from torch hub, switch it to eval
# mode and report its trainable parameter count

mhub = torch.hub.load("facebookresearch/omnivore:main", model="omnivore_swinT").eval()
print(count_parameters(mhub))

In [None]:
m = omnivore.omnivore_swin_t()

# The torchmultimodal model must have exactly as many trainable parameters as
# the torch hub reference; assert it instead of comparing printed numbers by eye
m_param_count = count_parameters(m)
print(m_param_count)
assert m_param_count == count_parameters(mhub), (
    "torchmultimodal and torch hub models have different trainable parameter counts"
)

In [None]:
# Copy the torch hub weights into the torchmultimodal model (entries are matched
# by position inside custom_load_state_dict), then switch the model to eval mode
custom_load_state_dict(m, mhub.state_dict())
m = m.eval()


# Inference test

In [None]:
# Download imagenet class and image
# Uncomment to download
!wget https://s3.amazonaws.com/deep-learning-models/image-models/imagenet_class_index.json -O imagenet_class_index.json
with open("imagenet_class_index.json", "r") as f:
    imagenet_classnames = json.load(f)

# Create an id to label name mapping
imagenet_id_to_classname = {}
for k, v in imagenet_classnames.items():
    imagenet_id_to_classname[k] = v[1] 

# Download the example image file
# Uncomment to download
!wget -O library.jpg https://upload.wikimedia.org/wikipedia/commons/thumb/c/c5/13-11-02-olb-by-RalfR-03.jpg/800px-13-11-02-olb-by-RalfR-03.jpg

image_path = "library.jpg"
image_pil = Image.open(image_path).convert("RGB")
plt.figure(figsize=(6, 6))
plt.imshow(image_pil)

In [None]:
# Preprocessing pipeline: resize, center-crop to 224x224, convert to a tensor,
# then normalize with the per-channel mean / std given below
preprocess_steps = [
    T.Resize(224),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
image_transform = T.Compose(preprocess_steps)

# C H W -> B C D H W: insert a batch axis in front and a time (D) axis right
# after the channel axis
image = image_transform(image_pil)[None, :, None]

In [None]:
def infer(model):
    """Run ``model`` on the global ``image`` and print its top-5 class names.

    The forward pass uses ``input_type="image"`` and runs under
    ``torch.no_grad()``. The top-5 indices of the first batch element are
    looked up (as strings) in the global ``imagenet_id_to_classname`` mapping.
    """
    with torch.no_grad():
        model_output = model(image, input_type="image")
        top5_indices = torch.topk(model_output, k=5).indices

    names = []
    for class_idx in top5_indices[0].tolist():
        names.append(imagenet_id_to_classname[str(class_idx)])
    print("Top 5 predicted labels: %s" % ", ".join(names))

In [None]:
# Run both models on the same image; the two printed top-5 label lists should be identical
for model_under_test in (m, mhub):
    infer(model_under_test)

# Make sure the outputs of the trunk / encoder are the same

In [None]:
# Compute the features from `m.encoder` and `mhub.trunk` on the same image.
# This is inference only, so run under no_grad to avoid building an autograd
# graph (consistent with infer()).
with torch.no_grad():
    m_feature = m.encoder(image)
    mhub_feature = mhub.trunk(image)

In [None]:
# Show the first 10 flattened feature values of each model side by side
torch.flatten(m_feature)[:10], torch.flatten(mhub_feature[0])[:10]

In [None]:
# Exact element-wise comparison; the result is True only if every feature value matches
features_match = m_feature == mhub_feature[0]
np.array(features_match).all()

# Test on randomly generated input

In [None]:
# Random video-like input, laid out as B C D H W like the image tensor above
mock_video = torch.randn(1, 3, 10, 112, 112)

# Inference only: skip autograd bookkeeping for both forward passes
with torch.no_grad():
    m_output = m(mock_video, input_type="video")
    mhub_output = mhub(mock_video, input_type="video")

# True only if every output value matches exactly
np.all(np.array(m_output == mhub_output[0]))

In [None]:
# Random 4-channel single-frame input for the "rgbd" input type
mock_depth = torch.randn(1, 4, 1, 112, 112)

# Feed the 4-channel tensor (not the 3-channel mock_video) to the rgbd path.
# Inference only, so no autograd graph is needed.
with torch.no_grad():
    m_output = m(mock_depth, input_type="rgbd")
    mhub_output = mhub(mock_depth, input_type="rgbd")

# True only if every output value matches exactly
np.all(np.array(m_output == mhub_output[0]))