Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Models can now be loaded with ```python import torch.hub repo = 'epic-kitchens/action-models' class_counts = (125, 352) segment_count = 8 base_model = 'resnet50' tsn = torch.hub.load(repo, 'TSN', class_counts, segment_count, "RGB", base_model=base_model, pretrained='epic-kitchens') trn = torch.hub.load(repo, 'TRN', class_counts, segment_count, "RGB", base_model=base_model, pretrained='epic-kitchens') mtrn = torch.hub.load(repo, 'MTRN', class_counts, segment_count, "RGB", base_model=base_model, pretrained='epic-kitchens') tsm = torch.hub.load(repo, 'TSM', class_counts, segment_count, "RGB", base_model=base_model, pretrained='epic-kitchens') ``` All classes have help docstrings describing their constructor args. There are now `TRN` and `MTRN` classes that preconfigure their parent `TSN` class with the correct consensus function. The `download.py` script has been removed in favour of people using torch.hub. Otherwise users are expected to manually download checkpoints.
- Loading branch information
Showing
13 changed files
with
575 additions
and
174 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -164,3 +164,6 @@ fabric.properties | |
.idea | ||
|
||
**/__pycache__ | ||
|
||
/checkpoints* | ||
/results* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Empty file.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
dependencies = ['torch', 'torchvision', 'pretrainedmodels'] | ||
|
||
from tsn import TSN | ||
from tsn import TSN, TRN, MTRN | ||
from tsm import TSM |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from collections import namedtuple | ||
|
||
__all__ = ["urls", "ModelConfig"] | ||
|
||
|
||
ModelConfig = namedtuple( | ||
"ModelConfig", | ||
["variant", "base_model", "modality", "num_segments", "consensus_type"], | ||
) | ||
_epic_url_base = ( | ||
"https://wp-research-public.s3-eu-west-1.amazonaws.com/epic-models-checkpoints/" | ||
) | ||
|
||
|
||
urls = { | ||
"epic-kitchens": { | ||
ModelConfig("TSN", "resnet50", "RGB", 8, "avg"): _epic_url_base | ||
+ "TSN_arch=resnet50_modality=RGB_segments=8-3ecf904f.pth.tar", | ||
ModelConfig("TSN", "resnet50", "Flow", 8, "avg"): _epic_url_base | ||
+ "TSN_arch=resnet50_modality=Flow_segments=8-4317bc4a.pth.tar", | ||
ModelConfig("TSN", "BNInception", "RGB", 8, "avg"): _epic_url_base | ||
+ "TSN_arch=BNInception_modality=RGB_segments=8-efb96e64.pth.tar", | ||
ModelConfig("TSN", "BNInception", "Flow", 8, "avg"): _epic_url_base | ||
+ "TSN_arch=BNInception_modality=Flow_segments=8-4c720ee3.pth.tar", | ||
ModelConfig("TRN", "BNInception", "RGB", 8, "TRN"): _epic_url_base | ||
+ "TRN_arch=BNInception_modality=RGB_segments=8-a770bfbd.pth.tar", | ||
ModelConfig("TRN", "BNInception", "Flow", 8, "TRN"): _epic_url_base | ||
+ "TRN_arch=BNInception_modality=Flow_segments=8-4f84b178.pth.tar", | ||
ModelConfig("TRN", "resnet50", "RGB", 8, "TRN"): _epic_url_base | ||
+ "TRN_arch=resnet50_modality=RGB_segments=8-c8176b38.pth.tar", | ||
ModelConfig("TRN", "resnet50", "Flow", 8, "TRN"): _epic_url_base | ||
+ "TRN_arch=resnet50_modality=Flow_segments=8-c0a2821c.pth.tar", | ||
ModelConfig("MTRN", "BNInception", "RGB", 8, "TRNMultiscale"): _epic_url_base | ||
+ "MTRN_arch=BNInception_modality=RGB_segments=8-8933f99e.pth.tar", | ||
ModelConfig("MTRN", "BNInception", "Flow", 8, "TRNMultiscale"): _epic_url_base | ||
+ "MTRN_arch=BNInception_modality=Flow_segments=8-c0cea7e1.pth.tar", | ||
ModelConfig("MTRN", "resnet50", "RGB", 8, "TRNMultiscale"): _epic_url_base | ||
+ "MTRN_arch=resnet50_modality=RGB_segments=8-46337796.pth.tar", | ||
ModelConfig("MTRN", "resnet50", "Flow", 8, "TRNMultiscale"): _epic_url_base | ||
+ "MTRN_arch=resnet50_modality=Flow_segments=8-6667f285.pth.tar", | ||
ModelConfig("TSM", "resnet50", "RGB", 8, "avg"): _epic_url_base | ||
+ "TSM_arch=resnet50_modality=RGB_segments=8-cfc93918.pth.tar", | ||
ModelConfig("TSM", "resnet50", "Flow", 8, "avg"): _epic_url_base | ||
+ "TSM_arch=resnet50_modality=Flow_segments=8-e09c2d3a.pth.tar", | ||
} | ||
} | ||
|
||
|
||
class InvalidPretrainError(Exception): | ||
pass |
Oops, something went wrong.