[feat] Add pytorchvideo encoder wrapper #1156

Closed
91 changes: 89 additions & 2 deletions mmf/modules/encoders.py
@@ -1,10 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import importlib
import logging
import os
import pickle
import re
from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any

@@ -25,13 +27,15 @@
from transformers.configuration_auto import AutoConfig
from transformers.modeling_auto import AutoModel


try:
from detectron2.modeling import ShapeSpec, build_resnet_backbone
except ImportError:
pass


logger = logging.getLogger()


class Encoder(nn.Module):
@dataclass
class Config:
@@ -688,6 +692,89 @@ def forward(self, x: Tensor) -> Tensor:
return out


@registry.register_encoder("pytorchvideo")
class PytorchVideoEncoder(Encoder):
"""A thin wrapper around pytorchvideo models.
This class is responsible for integrating pytorchvideo models as encoders.
THis class attempts to construct a pytorchvideo model from torch hub.
If this fails for a random weight model, and pytorchvideo package is available,
build the model with random weights from pytorchvideo.models.

Config:
name (str): Always 'pytorchvideo' Used for builder_encoder()
random_init (bool): Flag to load pretrained weights
model_name (str): Name of the pytorchvideo model to use
drop_last_n_layers (int):
<=0 value for the number of layers to drop off the end
pooler_name (str): Name of pooler used on model output

Raises:
ImportError:
The constructor raises an ImportError if pytorchvideo is not installed.
"""

@dataclass
class Config(Encoder.Config):
name: str = "pytorchvideo"
random_init: bool = False
model_name: str = "slowfast_r50"
drop_last_n_layers: int = -1
pooler_name: str = "identity"

PYTORCHVIDEO_REPO = "facebookresearch/pytorchvideo:main"

def __init__(self, config: Config):
super().__init__()
config = OmegaConf.create({**asdict(self.Config()), **config})
if config.random_init:
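# keep only keys that are not part of the base Config; they are forwarded
# to the model builder as extra kwargs (e.g. spatial_size for MViT)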
params = dict(**OmegaConf.to_container(config))
params = {
k: v
for k, v in params.items()
if k not in PytorchVideoEncoder.Config().__dict__
}
try:
model = torch.hub.load(
PytorchVideoEncoder.PYTORCHVIDEO_REPO,
model=config.model_name,
pretrained=False,
**params,
)
except BaseException as err:
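# torch hub may be unavailable (e.g. no network access); if pytorchvideo
# is installed locally, build the randomly initialized model from it instead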
pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
if pytorchvideo_spec is None:
raise err
import pytorchvideo.models.hub as hub

model_create_fn = getattr(hub, config.model_name)
model = model_create_fn(pretrained=False, **params)
else:
# load weights from TorchHub
model = torch.hub.load(
PytorchVideoEncoder.PYTORCHVIDEO_REPO,
model=config.model_name,
pretrained=True,
)
encoder_list = []
if config.drop_last_n_layers == 0:
encoder_list += [model]
else:
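# some models wrap all layers in a single child module; unwrap it
# before dropping the last n layers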
modules_list = list(model.children())
if len(modules_list) == 1:
modules_list = list(modules_list[0].children())
modules = modules_list[: config.drop_last_n_layers]
encoder_list += modules

pooler = registry.get_pool_class(config.pooler_name)()
encoder_list += [pooler]
self.encoder = nn.Sequential(*encoder_list)

def forward(self, *args, **kwargs):
# pass along input to model
# assumes caller obeys the dynamic model signature
return self.encoder(*args, **kwargs)


@registry.register_encoder("r2plus1d_18")
class R2Plus1D18VideoEncoder(PooledEncoder):
"""
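For reference, a minimal usage sketch of the PytorchVideoEncoder added above (not part of the diff), assuming pytorchvideo is installed and the default slowfast_r50 weights can be fetched from torch hub; the input follows the [slow, fast] pathway convention used in the tests below:

import torch
from omegaconf import OmegaConf
from mmf.modules import encoders

# Default Config: slowfast_r50 with pretrained weights, the last layer dropped
# (drop_last_n_layers=-1), and an identity pooler on the output.
config = OmegaConf.structured(encoders.PytorchVideoEncoder.Config())
encoder = encoders.PytorchVideoEncoder(config)

slow = torch.rand(1, 3, 8, 224, 224)   # slow pathway: 8 frames
fast = torch.rand(1, 3, 32, 224, 224)  # fast pathway: 32 frames
features = encoder([slow, fast])       # feature dim 2304, as checked in the tests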
61 changes: 60 additions & 1 deletion tests/models/test_mmf_transformer.py
@@ -21,7 +21,9 @@
from mmf.utils.configuration import Configuration
from mmf.utils.env import setup_imports, teardown_imports
from omegaconf import OmegaConf

from tests.test_utils import (
skip_if_no_pytorchvideo,
)

BERT_VOCAB_SIZE = 30255
ROBERTA_VOCAB_SIZE = 50265
@@ -444,6 +446,63 @@ def test_preprocessing_with_resnet_encoder(self):
test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())

@skip_if_no_pytorchvideo
def test_preprocessing_with_mvit_encoder(self):
encoder_config = OmegaConf.create(
{
"name": "pytorchvideo",
"model_name": "mvit_base_32x3",
"random_init": True,
"drop_last_n_layers": 0,
"pooler_name": "cls",
"spatial_size": 224,
"temporal_size": 8,
"head": None,
"embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
"atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
"pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
"pool_kv_stride_adaptive": [1, 8, 8],
"pool_kvq_kernel": [3, 3, 3],
}
)
self._image_modality_config = MMFTransformerModalityConfig(
type="image",
key="image",
embedding_dim=768,
position_dim=1,
segment_id=0,
encoder=encoder_config,
)
modalities_config = [self._image_modality_config, self._text_modality_config]
config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
mmft = build_model(config)

sample_list = SampleList()
sample_list.image = torch.rand((2, 3, 8, 224, 224))
sample_list.text = torch.randint(0, 512, (2, 128))

transformer_input = mmft.preprocess_sample(sample_list)
input_ids = transformer_input["input_ids"]
self.assertEqual(input_ids["image"].dim(), 3)
self.assertEqual(list(input_ids["image"].size()), [2, 1, 768])

self.assertEqual(input_ids["text"].dim(), 2)
self.assertEqual(list(input_ids["text"].size()), [2, 128])

position_ids = transformer_input["position_ids"]
test_utils.compare_tensors(position_ids["image"], torch.tensor([[0], [0]]))
test_utils.compare_tensors(
position_ids["text"], torch.arange(0, 128).unsqueeze(0).expand((2, 128))
)

masks = transformer_input["masks"]
test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

segment_ids = transformer_input["segment_ids"]
test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())

def test_tie_mlm_head_weight_to_encoder(self):
self._text_modality_config = MMFTransformerModalityConfig(
type="text",
62 changes: 61 additions & 1 deletion tests/modules/test_encoders.py
@@ -6,7 +6,11 @@
import torch
from mmf.modules import encoders
from omegaconf import OmegaConf
from tests.test_utils import setup_proxy, skip_if_old_transformers
from tests.test_utils import (
setup_proxy,
skip_if_old_transformers,
skip_if_no_pytorchvideo,
)
from torch import nn


@@ -102,3 +106,59 @@ def test_vit_encoder(self):
x = torch.rand(32, 197, 768)
output, _ = encoder(x)
self.assertEqual(output.size(-1), config.out_dim)

@skip_if_no_pytorchvideo
def test_pytorchvideo_slowfast_r50_encoder(self):
# instantiate video encoder from pytorchvideo
# default model is slowfast_r50
config = OmegaConf.structured(encoders.PytorchVideoEncoder.Config())
encoder = encoders.PytorchVideoEncoder(config)
fast = torch.rand((1, 3, 32, 224, 224))
slow = torch.rand((1, 3, 8, 224, 224))
output = encoder([slow, fast])
# check output tensor is the expected feature dim size
# (bs, feature_dim)
self.assertEqual(output.size(1), 2304)

@skip_if_no_pytorchvideo
def test_mvit_encoder(self):
config = {
"name": "pytorchvideo",
"model_name": "mvit_base_32x3",
"random_init": True,
"drop_last_n_layers": 0,
"pooler_name": "cls",
"spatial_size": 224,
"temporal_size": 8,
"head": None,
"embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
"atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
"pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
"pool_kv_stride_adaptive": [1, 8, 8],
"pool_kvq_kernel": [3, 3, 3],
}
# test bert cls pooler
encoder = encoders.PytorchVideoEncoder(OmegaConf.create(config))
x = torch.rand((1, 3, 8, 224, 224))
output = encoder(x)
# check output tensor is the expected feature dim size
# based on pooled attention configs
# for more details consult https://arxiv.org/pdf/2104.11227
# and https://github.com/facebookresearch/pytorchvideo/
# (bs, feature_dim) since the cls pooler returns a single token per sample
self.assertEqual(output.shape, torch.Size([1, 768]))

# test avg pooler
encoder = encoders.PytorchVideoEncoder(
OmegaConf.create(dict(config, pooler_name="avg"))
)
output = encoder(x)
self.assertEqual(output.shape, torch.Size([1, 768]))

# test no pooling
encoder = encoders.PytorchVideoEncoder(
OmegaConf.create(dict(config, pooler_name="identity"))
)
output = encoder(x)
# (bs, num_features, feature_dim)
self.assertEqual(output.shape, torch.Size([1, 197, 768]))
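A quick sanity check on the shapes asserted in test_mvit_encoder, written as a hypothetical back-of-the-envelope script; it assumes the mvit_base_32x3 defaults of a 96-dim patch embedding with stride (2, 4, 4), which are not spelled out in the diff:

# Where the 197 tokens and 768 features come from (assumed MViT defaults).
temporal_tokens = 8 // 2       # temporal_size=8, patch stride 2 -> 4
spatial_side = 224 // 4        # spatial_size=224, patch stride 4 -> 56
spatial_side //= 2 * 2 * 2     # pool_q_stride_size halves it at layers 1, 3, 14 -> 7
num_tokens = temporal_tokens * spatial_side ** 2 + 1  # + cls token -> 197
embed_dim = 96 * 2 * 2 * 2     # embed_dim_mul doubles the dim at the same layers -> 768
assert (num_tokens, embed_dim) == (197, 768)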
7 changes: 7 additions & 0 deletions tests/test_utils.py
@@ -102,6 +102,13 @@ def wrap(testfn, reason="Requires newer version of transformers"):
return wrap


def skip_if_no_pytorchvideo(testfn, reason="Requires pytorchvideo"):
import importlib

pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
return unittest.skipIf(pytorchvideo_spec is None, reason)(testfn)


def compare_state_dicts(a, b):
same = True
same = same and (list(a.keys()) == list(b.keys()))