[feat] Add pytorchvideo encoder wrapper (#1156)
Summary:
Pull Request resolved: #1156

Adds an encoder class that constructs any pytorchvideo model
from config and uses that model for its forward pass.

Depending on the config, the encoder can load pretrained weights
or build a randomly initialized model.
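
For illustration, a minimal usage sketch that mirrors the unit tests below (it assumes pytorchvideo is installed; the SlowFast input format follows the slowfast_r50 test):

```
import torch
from omegaconf import OmegaConf
from mmf.modules import encoders

# The default config builds a pretrained slowfast_r50 from Torch Hub.
config = OmegaConf.structured(encoders.PytorchVideoEncoder.Config())
encoder = encoders.PytorchVideoEncoder(config)

# SlowFast expects a [slow_pathway, fast_pathway] pair of clips.
slow = torch.rand((1, 3, 8, 224, 224))
fast = torch.rand((1, 3, 32, 224, 224))
features = encoder([slow, fast])  # shape: (batch_size, 2304)
```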

Test Plan:
Tested through unit tests on slowfast_r50 and mvit.

Will be tested end-to-end once datasets and transforms are available in mmf.

```
(torchvideo) ryanjiang@learnfair5083:~/copy/mmf$ pytest tests/models/test_mmf_transformer.py
================================================== test session starts ==================================================
platform linux -- Python 3.7.11, pytest-6.2.5, py-1.10.0, pluggy-1.0.0
rootdir: /private/home/ryanjiang/copy/mmf
plugins: forked-1.3.0, timeout-1.4.2, hydra-core-1.1.1, xdist-2.4.0, dash-2.0.0
collected 15 items

tests/models/test_mmf_transformer.py ...............                                                              [100%]

(torchvideo) ryanjiang@learnfair5083:~/copy/mmf$ pytest tests/modules/test_encoders.py
================================================== test session starts ==================================================
platform linux -- Python 3.7.11, pytest-6.2.5, py-1.10.0, pluggy-1.0.0
rootdir: /private/home/ryanjiang/copy/mmf
plugins: forked-1.3.0, timeout-1.4.2, hydra-core-1.1.1, xdist-2.4.0, dash-2.0.0
collected 12 items

tests/modules/test_encoders.py ............                                                                       [100%]
```

Reviewed By: apsdehal

Differential Revision: D32631207

Pulled By: Ryan-Qiyu-Jiang

fbshipit-source-id: 6b549162f7ae9ccea162563e48ed910618a6da54
Ryan-Qiyu-Jiang authored and facebook-github-bot committed Dec 16, 2021
1 parent ee19bd9 commit 68add70
Showing 4 changed files with 217 additions and 4 deletions.
91 changes: 89 additions & 2 deletions mmf/modules/encoders.py
@@ -1,10 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates.
import importlib
import logging
import os
import pickle
import re
from collections import OrderedDict
from copy import deepcopy
from dataclasses import dataclass
from dataclasses import asdict, dataclass
from enum import Enum
from typing import Any

@@ -25,13 +27,15 @@
from transformers.configuration_auto import AutoConfig
from transformers.modeling_auto import AutoModel


try:
    from detectron2.modeling import ShapeSpec, build_resnet_backbone
except ImportError:
    pass


logger = logging.getLogger()


class Encoder(nn.Module):
    @dataclass
    class Config:
@@ -688,6 +692,89 @@ def forward(self, x: Tensor) -> Tensor:
        return out


@registry.register_encoder("pytorchvideo")
class PytorchVideoEncoder(Encoder):
"""A thin wrapper around pytorchvideo models.
This class is responsible for integrating pytorchvideo models as encoders.
THis class attempts to construct a pytorchvideo model from torch hub.
If this fails for a random weight model, and pytorchvideo package is available,
build the model with random weights from pytorchvideo.models.
Config:
name (str): Always 'pytorchvideo' Used for builder_encoder()
random_init (bool): Flag to load pretrained weights
model_name (str): Name of the pytorchvideo model to use
drop_last_n_layers (int):
<=0 value for the number of layers to drop off the end
pooler_name (str): Name of pooler used on model output
Raises:
ImportError:
The constructor raises an ImportError if pytorchvideo is not installed.
"""

@dataclass
class Config(Encoder.Config):
name: str = "pytorchvideo"
random_init: bool = False
model_name: str = "slowfast_r50"
drop_last_n_layers: int = -1
pooler_name: str = "identity"

PYTORCHVIDEO_REPO = "facebookresearch/pytorchvideo:main"

def __init__(self, config: Config):
super().__init__()
config = OmegaConf.create({**asdict(self.Config()), **config})
if config.random_init:
params = dict(**OmegaConf.to_container(config))
params = {
k: v
for k, v in params.items()
if k not in PytorchVideoEncoder.Config().__dict__
}
try:
model = torch.hub.load(
PytorchVideoEncoder.PYTORCHVIDEO_REPO,
model=config.model_name,
pretrained=False,
**params,
)
except BaseException as err:
pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
if pytorchvideo_spec is None:
raise err
import pytorchvideo.models.hub as hub

model_create_fn = getattr(hub, config.model_name)
model = model_create_fn(pretrained=False, **params)
else:
# load weights from TorchHub
model = torch.hub.load(
PytorchVideoEncoder.PYTORCHVIDEO_REPO,
model=config.model_name,
pretrained=True,
)
encoder_list = []
if config.drop_last_n_layers == 0:
encoder_list += [model]
else:
modules_list = list(model.children())
if len(modules_list) == 1:
modules_list = list(modules_list[0].children())
modules = modules_list[: config.drop_last_n_layers]
encoder_list += modules

pooler = registry.get_pool_class(config.pooler_name)()
encoder_list += [pooler]
self.encoder = nn.Sequential(*encoder_list)

def forward(self, *args, **kwargs):
# pass along input to model
# assumes caller obeys the dynamic model signature
return self.encoder(*args, **kwargs)


@registry.register_encoder("r2plus1d_18")
class R2Plus1D18VideoEncoder(PooledEncoder):
"""
Expand Down
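
To make the layer-dropping and pooling behavior concrete, here is a toy sketch of what the constructor does (the three-layer model is hypothetical, standing in for a pytorchvideo network):

```
import torch
from torch import nn

# Hypothetical stand-in for a pytorchvideo model that ends in a head.
model = nn.Sequential(
    nn.Linear(16, 32),  # backbone
    nn.ReLU(),
    nn.Linear(32, 10),  # classification head
)

# drop_last_n_layers=-1 slices off the head and keeps backbone features;
# the "identity" pooler then leaves those features untouched.
modules = list(model.children())[:-1]
encoder = nn.Sequential(*modules, nn.Identity())

features = encoder(torch.rand(2, 16))
print(features.shape)  # torch.Size([2, 32])
```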
61 changes: 60 additions & 1 deletion tests/models/test_mmf_transformer.py
@@ -21,7 +21,9 @@
from mmf.utils.configuration import Configuration
from mmf.utils.env import setup_imports, teardown_imports
from omegaconf import OmegaConf

from tests.test_utils import (
    skip_if_no_pytorchvideo,
)

BERT_VOCAB_SIZE = 30255
ROBERTA_VOCAB_SIZE = 50265
@@ -444,6 +446,63 @@ def test_preprocessing_with_resnet_encoder(self):
test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())

    @skip_if_no_pytorchvideo
    def test_preprocessing_with_mvit_encoder(self):
        encoder_config = OmegaConf.create(
            {
                "name": "pytorchvideo",
                "model_name": "mvit_base_32x3",
                "random_init": True,
                "drop_last_n_layers": 0,
                "pooler_name": "cls",
                "spatial_size": 224,
                "temporal_size": 8,
                "head": None,
                "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
                "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
                "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
                "pool_kv_stride_adaptive": [1, 8, 8],
                "pool_kvq_kernel": [3, 3, 3],
            }
        )
        self._image_modality_config = MMFTransformerModalityConfig(
            type="image",
            key="image",
            embedding_dim=768,
            position_dim=1,
            segment_id=0,
            encoder=encoder_config,
        )
        modalities_config = [self._image_modality_config, self._text_modality_config]
        config = MMFTransformer.Config(modalities=modalities_config, num_labels=2)
        mmft = build_model(config)

        sample_list = SampleList()
        sample_list.image = torch.rand((2, 3, 8, 224, 224))
        sample_list.text = torch.randint(0, 512, (2, 128))

        transformer_input = mmft.preprocess_sample(sample_list)
        input_ids = transformer_input["input_ids"]
        self.assertEqual(input_ids["image"].dim(), 3)
        # the cls-pooled MViT features form a single 768-dim "token" per clip
        self.assertEqual(list(input_ids["image"].size()), [2, 1, 768])

        self.assertEqual(input_ids["text"].dim(), 2)
        self.assertEqual(list(input_ids["text"].size()), [2, 128])

        position_ids = transformer_input["position_ids"]
        test_utils.compare_tensors(position_ids["image"], torch.tensor([[0], [0]]))
        test_utils.compare_tensors(
            position_ids["text"], torch.arange(0, 128).unsqueeze(0).expand((2, 128))
        )

        masks = transformer_input["masks"]
        test_utils.compare_tensors(masks["image"], torch.tensor([[1], [1]]))
        test_utils.compare_tensors(masks["text"], torch.ones((2, 128)).long())

        segment_ids = transformer_input["segment_ids"]
        test_utils.compare_tensors(segment_ids["image"], torch.tensor([[0], [0]]))
        test_utils.compare_tensors(segment_ids["text"], torch.ones((2, 128)).long())

    def test_tie_mlm_head_weight_to_encoder(self):
        self._text_modality_config = MMFTransformerModalityConfig(
            type="text",
62 changes: 61 additions & 1 deletion tests/modules/test_encoders.py
@@ -6,7 +6,11 @@
import torch
from mmf.modules import encoders
from omegaconf import OmegaConf
from tests.test_utils import setup_proxy, skip_if_old_transformers
from tests.test_utils import (
    setup_proxy,
    skip_if_old_transformers,
    skip_if_no_pytorchvideo,
)
from torch import nn


@@ -102,3 +106,59 @@ def test_vit_encoder(self):
        x = torch.rand(32, 197, 768)
        output, _ = encoder(x)
        self.assertEqual(output.size(-1), config.out_dim)

    @skip_if_no_pytorchvideo
    def test_pytorchvideo_slowfast_r50_encoder(self):
        # instantiate a video encoder from pytorchvideo;
        # the default model is slowfast_r50
        config = OmegaConf.structured(encoders.PytorchVideoEncoder.Config())
        encoder = encoders.PytorchVideoEncoder(config)
        fast = torch.rand((1, 3, 32, 224, 224))
        slow = torch.rand((1, 3, 8, 224, 224))
        output = encoder([slow, fast])
        # check that the output tensor has the expected feature dim size
        # (bs, feature_dim)
        self.assertEqual(output.size(1), 2304)

    @skip_if_no_pytorchvideo
    def test_mvit_encoder(self):
        config = {
            "name": "pytorchvideo",
            "model_name": "mvit_base_32x3",
            "random_init": True,
            "drop_last_n_layers": 0,
            "pooler_name": "cls",
            "spatial_size": 224,
            "temporal_size": 8,
            "head": None,
            "embed_dim_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
            "atten_head_mul": [[1, 2.0], [3, 2.0], [14, 2.0]],
            "pool_q_stride_size": [[1, 1, 2, 2], [3, 1, 2, 2], [14, 1, 2, 2]],
            "pool_kv_stride_adaptive": [1, 8, 8],
            "pool_kvq_kernel": [3, 3, 3],
        }
        # test bert-style cls pooler
        encoder = encoders.PytorchVideoEncoder(OmegaConf.create(config))
        x = torch.rand((1, 3, 8, 224, 224))
        output = encoder(x)
        # check that the output tensor has the expected feature dim size
        # based on the pooled attention configs; for more details consult
        # https://arxiv.org/pdf/2104.11227
        # and https://github.com/facebookresearch/pytorchvideo/
        # (bs, feature_dim)
        self.assertEqual(output.shape, torch.Size([1, 768]))

        # test avg pooler
        encoder = encoders.PytorchVideoEncoder(
            OmegaConf.create(dict(config, pooler_name="avg"))
        )
        output = encoder(x)
        self.assertEqual(output.shape, torch.Size([1, 768]))

        # test no pooling
        encoder = encoders.PytorchVideoEncoder(
            OmegaConf.create(dict(config, pooler_name="identity"))
        )
        output = encoder(x)
        # (bs, num_features, feature_dim)
        self.assertEqual(output.shape, torch.Size([1, 197, 768]))
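
The three poolers exercised above differ only in how they reduce the (bs, num_features, feature_dim) model output; as a toy sketch of the idea (not mmf's actual pooler implementations):

```
import torch

x = torch.rand(1, 197, 768)  # (bs, num_features, feature_dim)

cls_pooled = x[:, 0]        # "cls": keep the first (class) token -> (1, 768)
avg_pooled = x.mean(dim=1)  # "avg": average over tokens -> (1, 768)
identity = x                # "identity": no pooling -> (1, 197, 768)
```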
7 changes: 7 additions & 0 deletions tests/test_utils.py
@@ -102,6 +102,13 @@ def wrap(testfn, reason="Requires newer version of transformers"):
    return wrap


def skip_if_no_pytorchvideo(testfn, reason="Requires pytorchvideo"):
    import importlib

    pytorchvideo_spec = importlib.util.find_spec("pytorchvideo")
    return unittest.skipIf(pytorchvideo_spec is None, reason)(testfn)


def compare_state_dicts(a, b):
    same = True
    same = same and (list(a.keys()) == list(b.keys()))
