Add PyTorch hub support (#2)
Models can now be loaded with

```python
import torch.hub
repo = 'epic-kitchens/action-models'

class_counts = (125, 352)
segment_count = 8
base_model = 'resnet50'
tsn = torch.hub.load(repo, 'TSN', class_counts, segment_count, "RGB", base_model=base_model, pretrained='epic-kitchens')
trn = torch.hub.load(repo, 'TRN', class_counts, segment_count, "RGB", base_model=base_model, pretrained='epic-kitchens')
mtrn = torch.hub.load(repo, 'MTRN', class_counts, segment_count, "RGB", base_model=base_model, pretrained='epic-kitchens')
tsm = torch.hub.load(repo, 'TSM', class_counts, segment_count, "RGB", base_model=base_model, pretrained='epic-kitchens')
```

All classes have help docstrings describing their constructor args.

There are now `TRN` and `MTRN` classes that preconfigure their parent `TSN` class with the correct consensus function.
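
As a rough illustration of what "preconfigure" means here, the sketch below uses a stand-in `TSN` class (not the repository's implementation) whose subclasses pin the `consensus_type` argument. The consensus names `"TRN"` and `"TRNMultiscale"` are the ones that appear in `model_loader.py`; the stand-in class body and its attributes are assumptions made for the example.

```python
class TSN:
    """Stand-in for the real TSN class; it only records its configuration."""

    def __init__(self, class_counts, segment_count, modality,
                 consensus_type="avg", **kwargs):
        self.class_counts = class_counts
        self.segment_count = segment_count
        self.modality = modality
        self.consensus_type = consensus_type
        self.extra = kwargs  # e.g. base_model, dropout


class TRN(TSN):
    """TSN with the single-scale temporal relation consensus fixed."""

    def __init__(self, class_counts, segment_count, modality, **kwargs):
        super().__init__(class_counts, segment_count, modality,
                         consensus_type="TRN", **kwargs)


class MTRN(TSN):
    """TSN with the multi-scale temporal relation consensus fixed."""

    def __init__(self, class_counts, segment_count, modality, **kwargs):
        super().__init__(class_counts, segment_count, modality,
                         consensus_type="TRNMultiscale", **kwargs)


trn = TRN((125, 352), 8, "RGB", base_model="resnet50")
mtrn = MTRN((125, 352), 8, "RGB", base_model="resnet50")
print(trn.consensus_type, mtrn.consensus_type)
```

With this shape, callers no longer pass `consensus_type` themselves; in the sketch, passing it to `TRN` or `MTRN` would raise a `TypeError` because the argument would be supplied twice.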

The `download.py` script has been removed in favour of loading checkpoints through `torch.hub`. Users who do not use `torch.hub` are expected to download checkpoints manually.
willprice committed Aug 5, 2019
1 parent ede98a0 commit 4a396f8
Showing 13 changed files with 575 additions and 174 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -164,3 +164,6 @@ fabric.properties
.idea

**/__pycache__

/checkpoints*
/results*
108 changes: 86 additions & 22 deletions README.md
@@ -8,7 +8,75 @@ This is a set of models trained for EPIC-Kitchens baselines. We support:

Many thanks to the authors of these repositories.

## Set up
You can use the code provided here in one of two ways:

1. [PyTorch hub](#pytorch-hub) (**recommended**)
1. [Local installation](#local-installation)

## PyTorch Hub

[PyTorch Hub](https://pytorch.org/hub) is a way to easily share models with
others. Using our models via hub is as simple as

```python
import torch.hub
repo = 'epic-kitchens/action-models'

class_counts = (125, 352)
segment_count = 8
base_model = 'resnet50'
tsn = torch.hub.load(repo, 'TSN', class_counts, segment_count, 'RGB',
                     base_model=base_model,
                     pretrained='epic-kitchens', force_reload=True)
trn = torch.hub.load(repo, 'TRN', class_counts, segment_count, 'RGB',
                     base_model=base_model,
                     pretrained='epic-kitchens')
mtrn = torch.hub.load(repo, 'MTRN', class_counts, segment_count, 'RGB',
                      base_model=base_model,
                      pretrained='epic-kitchens')
tsm = torch.hub.load(repo, 'TSM', class_counts, segment_count, 'RGB',
                     base_model=base_model,
                     pretrained='epic-kitchens')

# Show all entrypoints and their help strings
for entrypoint in torch.hub.list(repo):
    print(entrypoint)
    print(torch.hub.help(repo, entrypoint))

batch_size = 1
segment_count = 8
snippet_length = 1 # Number of frames composing the snippet, 1 for RGB, 5 for optical flow
snippet_channels = 3 # Number of channels in a frame, 3 for RGB, 2 for optical flow
height, width = 224, 224

inputs = torch.randn(
[batch_size, segment_count, snippet_length, snippet_channels, height, width]
)
# The segment and snippet length and channel dimensions are collapsed into the channel
# dimension
# Input shape: N x TC x H x W
inputs = inputs.reshape((batch_size, -1, height, width))
for model in [tsn, trn, mtrn, tsm]:
    # You can get features out of the models
    features = model.features(inputs)
    # and then classify those features
    verb_logits, noun_logits = model.logits(features)

    # or just call the object to classify inputs in a single forward pass
    verb_logits, noun_logits = model(inputs)
    print(verb_logits.shape, noun_logits.shape)
```
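
To make the `N x TC x H x W` convention concrete, here is a small sketch that computes the collapsed input shape for both modalities. The helper name `collapsed_input_shape` is hypothetical and not part of the repository; the snippet lengths and channel counts are the ones given in the comments of the example above.

```python
def collapsed_input_shape(batch_size, segment_count, snippet_length,
                          snippet_channels, height, width):
    """Fold the segment, snippet-length and channel axes into one channel axis."""
    collapsed_channels = segment_count * snippet_length * snippet_channels
    return (batch_size, collapsed_channels, height, width)


# RGB: 8 segments, 1 frame per snippet, 3 channels per frame
rgb_shape = collapsed_input_shape(1, 8, 1, 3, 224, 224)
# Optical flow: 8 segments, 5 frames per snippet, 2 channels per frame
flow_shape = collapsed_input_shape(1, 8, 5, 2, 224, 224)

print(rgb_shape)   # (1, 24, 224, 224)
print(flow_shape)  # (1, 80, 224, 224)
```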

NOTE: We are dependent upon a [fork of Remi Cadene's pretrained
models](https://github.com/wpwei/pretrained-models.pytorch/tree/vision_bug_fix)
that brings `DataParallel` support to PyTorch 1+.
Install this with:

```
$ pip install git+https://github.com/wpwei/pretrained-models.pytorch.git@vision_bug_fix
```

## Local Installation

We provide an `environment.yml` file to create a conda environment. Sadly not all of the
set up can be encapsulated in this file, so you have to perform some steps yourself
@@ -118,35 +186,31 @@ model = load_checkpoint('path/to/checkpoint.pth.tar')


## Checkpoints
-You can download checkpoints using the tool provided at `checkpoints/download.py`,
-simply call it with the model variant, modality, and architecture that you wish to
-download, e.g. `python checkpoints/download.py mtrn --arch BNInception --modality
-Flow`. The checkpoint will be downloaded to the `checkpoints` directory.

The checkpoints accompanying this repository score the following on the test set
when using 10 crop evaluation.

-| Checkpoint path | Seen V@1 | Seen N@1 | Seen A@1 | Unseen V@1 | Unseen N@1 | Unseen A@1 |
-|----------------------------------------------------------|-----------|-----------|-----------|------------|------------|------------|
-| `TSN_arch=BNInception_modality=RGB_segments=8.pth.tar` | 47.97 | 38.85 | 22.39 | 36.46 | 22.64 | 22.39 |
-| `TSN_arch=BNInception_modality=Flow_segments=8.pth.tar` | 51.68 | 26.82 | 16.76 | 47.35 | 21.20 | 13.49 |
-| `TRN_arch=BNInception_modality=RGB_segments=8.pth.tar` | 58.26 | 36.32 | 25.46 | 47.29 | 22.91 | 15.06 |
-| `TRN_arch=BNInception_modality=Flow_segments=8.pth.tar` | 55.20 | 23.95 | 16.03 | 50.32 | 19.02 | 12.77 |
-| `MTRN_arch=BNInception_modality=RGB_segments=8.pth.tar` | 55.76 | 37.94 | 26.62 | 45.41 | 23.90 | 15.57 |
-| `MTRN_arch=BNInception_modality=Flow_segments=8.pth.tar` | 55.92 | 24.88 | 16.78 | 51.38 | 20.69 | 14.00 |
-| `TSN_arch=resnet50_modality=RGB_segments=8.pth.tar` | 49.71 | 39.85 | 23.97 | 36.70 | 23.11 | 12.77 |
-| `TSN_arch=resnet50_modality=Flow_segments=8.pth.tar` | 53.14 | 27.76 | 20.28 | 47.56 | 20.28 | 13.11 |
-| `TRN_arch=resnet50_modality=RGB_segments=8.pth.tar` | 58.82 | 37.27 | 26.62 | 47.32 | 23.69 | 15.71 |
-| `TRN_arch=resnet50_modality=Flow_segments=8.pth.tar` | 55.16 | 23.19 | 15.77 | 50.39 | 18.50 | 12.02 |
-| `MTRN_arch=resnet50_modality=RGB_segments=8.pth.tar` | **60.16** | 38.36 | **28.23** | 46.94 | **24.41** | **16.32** |
-| `MTRN_arch=resnet50_modality=Flow_segments=8.pth.tar` | 56.79 | 25.00 | 17.24 | 50.36 | 20.28 | 13.42 |
-| `TSM_arch=resnet50_modality=RGB_segments=8.pth.tar` | 57.88 | **40.84** | **28.22** | 43.50 | 23.32 | 14.99 |
-| `TSM_arch=resnet50_modality=Flow_segments=8.pth.tar` | 58.08 | 27.49 | 19.14 | **52.68** | 20.83 | 14.27 |
+| Variant | Arch | Modality | # Segments | Seen V@1 | Seen N@1 | Seen A@1 | Unseen V@1 | Unseen N@1 | Unseen A@1 |
+|---------|--------------|----------|------------|-----------|-----------|-----------|------------|------------|------------|
+| TSN | BN-Inception | RGB | 8 | 47.97 | 38.85 | 22.39 | 36.46 | 22.64 | 22.39 |
+| TSN | BN-Inception | Flow | 8 | 51.68 | 26.82 | 16.76 | 47.35 | 21.20 | 13.49 |
+| TRN | BN-Inception | RGB | 8 | 58.26 | 36.32 | 25.46 | 47.29 | 22.91 | 15.06 |
+| TRN | BN-Inception | Flow | 8 | 55.20 | 23.95 | 16.03 | 50.32 | 19.02 | 12.77 |
+| M-TRN | BN-Inception | RGB | 8 | 55.76 | 37.94 | 26.62 | 45.41 | 23.90 | 15.57 |
+| M-TRN | BN-Inception | Flow | 8 | 55.92 | 24.88 | 16.78 | 51.38 | 20.69 | 14.00 |
+| TSN | ResNet-50 | RGB | 8 | 49.71 | 39.85 | 23.97 | 36.70 | 23.11 | 12.77 |
+| TSN | ResNet-50 | Flow | 8 | 53.14 | 27.76 | 20.28 | 47.56 | 20.28 | 13.11 |
+| TRN | ResNet-50 | RGB | 8 | 58.82 | 37.27 | 26.62 | 47.32 | 23.69 | 15.71 |
+| TRN | ResNet-50 | Flow | 8 | 55.16 | 23.19 | 15.77 | 50.39 | 18.50 | 12.02 |
+| M-TRN | ResNet-50 | RGB | 8 | **60.16** | 38.36 | **28.23** | 46.94 | **24.41** | **16.32** |
+| M-TRN | ResNet-50 | Flow | 8 | 56.79 | 25.00 | 17.24 | 50.36 | 20.28 | 13.42 |
+| TSM | ResNet-50 | RGB | 8 | 57.88 | **40.84** | **28.22** | 43.50 | 23.32 | 14.99 |
+| TSM | ResNet-50 | Flow | 8 | 58.08 | 27.49 | 19.14 | **52.68** | 20.83 | 14.27 |


## Extracting features

-Both classes `TSN` and `TSM` include `features` and `logits` methods, mimicking the
+Classes include `features` and `logits` methods, mimicking the
[`pretrainedmodels`](https://github.com/Cadene/pretrained-models.pytorch) API. Simply
create a model instance `model = TSN(...)` and call `model.features(input)` to
obtain base-model features. To transform these to logits, call
1 change: 0 additions & 1 deletion checkpoints/.gitignore

This file was deleted.

Empty file removed checkpoints/.gitkeep
Empty file.
75 changes: 0 additions & 75 deletions checkpoints/download.py

This file was deleted.

2 changes: 1 addition & 1 deletion hubconf.py
@@ -1,4 +1,4 @@
 dependencies = ['torch', 'torchvision', 'pretrainedmodels']

-from tsn import TSN
+from tsn import TSN, TRN, MTRN
 from tsm import TSM
10 changes: 4 additions & 6 deletions model_loader.py
@@ -4,7 +4,7 @@
 import torch

 from tsm import TSM
-from tsn import TSN
+from tsn import TSN, TRN, MTRN

 verb_class_count, noun_class_count = 125, 352
 class_count = (verb_class_count, noun_class_count)
@@ -25,18 +25,17 @@ def make_tsn(settings):
 def make_trn(settings):
     model_type = settings["model_type"]
     if model_type == "trn":
-        consensus_type = "TRN"
+        cls = TRN
     elif model_type == "mtrn":
-        consensus_type = "TRNMultiscale"
+        cls = MTRN
     else:
         raise ValueError(f"Unknown model_type '{model_type}' for TRN")
-    return TSN(
+    return cls(
         class_count,
         settings["segment_count"],
         settings["modality"],
         base_model=settings["arch"],
         new_length=settings["flow_length"] if settings["modality"] == "Flow" else 1,
-        consensus_type=consensus_type,
         img_feature_dim=settings["img_feature_dim"],
         dropout=settings["dropout"],
     )
@@ -52,7 +51,6 @@ def make_tsm(settings):
         new_length=settings["flow_length"] if settings["modality"] == "Flow" else 1,
         consensus_type="avg",
         dropout=settings["dropout"],
-        is_shift=True,
         shift_div=settings["shift_div"],
         shift_place=settings["shift_place"],
         temporal_pool=settings["temporal_pool"],
3 changes: 1 addition & 2 deletions ops/non_local.py
@@ -223,6 +223,5 @@ def make_non_local(net, n_segment):
 img = torch.zeros(*([2, 3] + [20] * dim))
 print("img shape: {}".format(img.shape))
 net = fn(3, sub_sample=sub_sample, bn_layer=bn_layer)
-out = net(img)
+out = net(img) # type: ignore
 print(out.size())

50 changes: 50 additions & 0 deletions pretrained_settings.py
@@ -0,0 +1,50 @@
from collections import namedtuple

__all__ = ["urls", "ModelConfig"]


ModelConfig = namedtuple(
    "ModelConfig",
    ["variant", "base_model", "modality", "num_segments", "consensus_type"],
)
_epic_url_base = (
    "https://wp-research-public.s3-eu-west-1.amazonaws.com/epic-models-checkpoints/"
)


urls = {
    "epic-kitchens": {
        ModelConfig("TSN", "resnet50", "RGB", 8, "avg"): _epic_url_base
        + "TSN_arch=resnet50_modality=RGB_segments=8-3ecf904f.pth.tar",
        ModelConfig("TSN", "resnet50", "Flow", 8, "avg"): _epic_url_base
        + "TSN_arch=resnet50_modality=Flow_segments=8-4317bc4a.pth.tar",
        ModelConfig("TSN", "BNInception", "RGB", 8, "avg"): _epic_url_base
        + "TSN_arch=BNInception_modality=RGB_segments=8-efb96e64.pth.tar",
        ModelConfig("TSN", "BNInception", "Flow", 8, "avg"): _epic_url_base
        + "TSN_arch=BNInception_modality=Flow_segments=8-4c720ee3.pth.tar",
        ModelConfig("TRN", "BNInception", "RGB", 8, "TRN"): _epic_url_base
        + "TRN_arch=BNInception_modality=RGB_segments=8-a770bfbd.pth.tar",
        ModelConfig("TRN", "BNInception", "Flow", 8, "TRN"): _epic_url_base
        + "TRN_arch=BNInception_modality=Flow_segments=8-4f84b178.pth.tar",
        ModelConfig("TRN", "resnet50", "RGB", 8, "TRN"): _epic_url_base
        + "TRN_arch=resnet50_modality=RGB_segments=8-c8176b38.pth.tar",
        ModelConfig("TRN", "resnet50", "Flow", 8, "TRN"): _epic_url_base
        + "TRN_arch=resnet50_modality=Flow_segments=8-c0a2821c.pth.tar",
        ModelConfig("MTRN", "BNInception", "RGB", 8, "TRNMultiscale"): _epic_url_base
        + "MTRN_arch=BNInception_modality=RGB_segments=8-8933f99e.pth.tar",
        ModelConfig("MTRN", "BNInception", "Flow", 8, "TRNMultiscale"): _epic_url_base
        + "MTRN_arch=BNInception_modality=Flow_segments=8-c0cea7e1.pth.tar",
        ModelConfig("MTRN", "resnet50", "RGB", 8, "TRNMultiscale"): _epic_url_base
        + "MTRN_arch=resnet50_modality=RGB_segments=8-46337796.pth.tar",
        ModelConfig("MTRN", "resnet50", "Flow", 8, "TRNMultiscale"): _epic_url_base
        + "MTRN_arch=resnet50_modality=Flow_segments=8-6667f285.pth.tar",
        ModelConfig("TSM", "resnet50", "RGB", 8, "avg"): _epic_url_base
        + "TSM_arch=resnet50_modality=RGB_segments=8-cfc93918.pth.tar",
        ModelConfig("TSM", "resnet50", "Flow", 8, "avg"): _epic_url_base
        + "TSM_arch=resnet50_modality=Flow_segments=8-e09c2d3a.pth.tar",
    }
}


class InvalidPretrainError(Exception):
    pass
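
The diff does not show how `urls` is consumed, so the following is only a sketch of one plausible lookup, under the assumption that a checkpoint is selected by building a `ModelConfig` and indexing the nested dictionary. The helper `lookup_checkpoint_url` and the placeholder URL are hypothetical; `ModelConfig` and `InvalidPretrainError` are redefined here so the snippet runs on its own.

```python
from collections import namedtuple

ModelConfig = namedtuple(
    "ModelConfig",
    ["variant", "base_model", "modality", "num_segments", "consensus_type"],
)


class InvalidPretrainError(Exception):
    pass


# Placeholder table: this URL is illustrative, not a real checkpoint location.
urls = {
    "epic-kitchens": {
        ModelConfig("TSN", "resnet50", "RGB", 8, "avg"):
            "https://example.invalid/tsn-rgb.pth.tar",
    }
}


def lookup_checkpoint_url(pretrained, config):
    """Hypothetical helper: return the URL for a config or raise a clear error."""
    try:
        return urls[pretrained][config]
    except KeyError:
        raise InvalidPretrainError(
            f"No '{pretrained}' checkpoint available for {config}"
        ) from None


config = ModelConfig("TSN", "resnet50", "RGB", 8, "avg")
print(lookup_checkpoint_url("epic-kitchens", config))
```

Because namedtuples compare and hash by value, a freshly constructed `ModelConfig` with the same fields finds the entry, which is what makes them convenient dictionary keys here.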
