Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d269a33
commit 07cc9cf
Showing
40 changed files
with
4,057 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# Byte-compiled / optimized / DLL files | ||
__pycache__/ | ||
*.py[cod] | ||
*$py.class | ||
|
||
# C extensions | ||
*.so | ||
|
||
# Distribution / packaging | ||
.Python | ||
env/ | ||
build/ | ||
develop-eggs/ | ||
dist/ | ||
downloads/ | ||
eggs/ | ||
.eggs/ | ||
lib/ | ||
lib64/ | ||
parts/ | ||
sdist/ | ||
var/ | ||
wheels/ | ||
*.egg-info/ | ||
.installed.cfg | ||
*.egg | ||
|
||
# PyInstaller | ||
# Usually these files are written by a python script from a template | ||
# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
*.manifest | ||
*.spec | ||
|
||
# Installer logs | ||
pip-log.txt | ||
pip-delete-this-directory.txt | ||
|
||
# Unit test / coverage reports | ||
htmlcov/ | ||
.tox/ | ||
.coverage | ||
.coverage.* | ||
.cache | ||
nosetests.xml | ||
coverage.xml | ||
*.cover | ||
.hypothesis/ | ||
|
||
# Translations | ||
*.mo | ||
*.pot | ||
|
||
# Django stuff: | ||
*.log | ||
local_settings.py | ||
|
||
# Flask stuff: | ||
instance/ | ||
.webassets-cache | ||
|
||
# Scrapy stuff: | ||
.scrapy | ||
|
||
# Sphinx documentation | ||
docs/_build/ | ||
|
||
# PyBuilder | ||
target/ | ||
|
||
# Jupyter Notebook | ||
.ipynb_checkpoints | ||
|
||
# pyenv | ||
.python-version | ||
|
||
# celery beat schedule file | ||
celerybeat-schedule | ||
|
||
# SageMath parsed files | ||
*.sage.py | ||
|
||
# dotenv | ||
.env | ||
|
||
# virtualenv | ||
.venv | ||
venv/ | ||
ENV/ | ||
|
||
# Spyder project settings | ||
.spyderproject | ||
.spyproject | ||
|
||
# Rope project settings | ||
.ropeproject | ||
|
||
# mkdocs documentation | ||
/site | ||
|
||
# experiment bash file | ||
*exp*.sh | ||
|
||
# C++ executive | ||
build | ||
.build_release | ||
|
||
# Caffe | ||
*.caffemodel | ||
|
||
# python temp files | ||
*.npy | ||
*.npz | ||
|
||
# log file | ||
./*.txt | ||
*.log | ||
|
||
# pycharm | ||
.idea/ | ||
|
||
# backup codes | ||
*.bk | ||
|
||
# experiment-related folders | ||
*experiments* | ||
|
||
# JSON files | ||
*.json | ||
|
||
# backup files | ||
*backup* |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2019 Min-Hung Chen | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,142 @@ | ||
# TA3N | ||
PyTorch code for CVPRW 2019 paper: Temporal Attentive Alignment for Video Domain Adaptation | ||
# Temporal Attentive Alignment for Video Domain Adaptation | ||
This is the PyTorch implementation of our paper: | ||
|
||
**Temporal Attentive Alignment for Video Domain Adaptation** | ||
Min-Hung Chen, Zsolt Kira, Ghassan AlRegib | ||
CVPR Workshop (Learning from Unlabeled Videos), 2019 | ||
[arXiv (to be updated)]() | ||
|
||
<p align="center"> | ||
<img src="webpage/Overview.png?raw=true" width="60%"> | ||
</p> | ||
|
||
Although various image-based domain adaptation (DA) techniques have been proposed in recent years, domain shift in videos is still not well-explored. Most previous works only evaluate performance on small-scale datasets which are saturated. Therefore, we first propose a larger-scale dataset with larger domain discrepancy: UCF-HMDB_full. Second, we investigate different DA integration methods for videos, and show that simultaneously aligning and learning temporal dynamics achieves effective alignment even without sophisticated DA methods. Finally, we propose Temporal Attentive Adversarial Adaptation Network (TA3N), which explicitly attends to the temporal dynamics using domain discrepancy for more effective domain alignment, achieving state-of-the-art performance on three video DA datasets. | ||
|
||
<p align="center"> | ||
<img src="webpage/SOTA_small.png?raw=true" width="49%"> | ||
<img src="webpage/SOTA_large.png?raw=true" width="50%"> | ||
</p> | ||
|
||
--- | ||
## Contents | ||
* [Requirements](#requirements) | ||
* [Dataset Preparation](#dataset-preparation) | ||
* [Data Structure](#data-structure) | ||
* [File lists for training/validation](#file-lists-for-trainingvalidation) | ||
* [Usage](#usage) | ||
* [Training](#training) | ||
* [Testing](#testing) | ||
* [Video Demo](#video-demo) | ||
* [Options](#options) | ||
* [Domain Adaptation](#domain-adaptation) | ||
* [More options](#more-options) | ||
* [Contact](#contact) | ||
|
||
--- | ||
## Requirements | ||
* support Python 3.6, PyTorch 0.4, CUDA 9.0, CUDNN 7.1.4 | ||
* install all the library with: `pip install -r requirements.txt` | ||
|
||
--- | ||
## Dataset Preparation | ||
### Data Structure | ||
You need to extract frame-level features for each video to run the codes. To extract features, please check [`dataset_preparation/`](dataset_preparation/). | ||
|
||
Folder Structure: | ||
``` | ||
DATA_PATH/ | ||
DATASET/ | ||
list_DATASET_SUFFIX.txt | ||
RGB/ | ||
CLASS_01/ | ||
VIDEO_0001.mp4 | ||
VIDEO_0002.mp4 | ||
... | ||
CLASS_02/ | ||
... | ||
RGB-Feature/ | ||
VIDEO_0001/ | ||
img_00001.t7 | ||
img_00002.t7 | ||
... | ||
VIDEO_0002/ | ||
... | ||
``` | ||
`RGB-Feature/` contains all the feature vectors for training/testing. `RGB/` contains all the raw videos. | ||
|
||
There should be at least two `DATASET` folders: source training set and validation set. If you want to do domain adaption, you need to have another `DATASET`: target training set. | ||
|
||
The pre-trained feature representations will be released soon. ([`TODO`](TODO/)) | ||
|
||
### File lists for training/validation | ||
The file list `list_DATASET_SUFFIX.txt` is required for data feeding. Each line in the list contains the full path of the video folder, video frame number, and video class index. It looks like: | ||
``` | ||
DATA_PATH/DATASET/RGB-Feature/VIDEO_0001/ 100 0 | ||
DATA_PATH/DATASET/RGB-Feature/VIDEO_0002/ 150 1 | ||
...... | ||
``` | ||
To generate the file list, please check [`dataset_preparation/`](dataset_preparation/). | ||
|
||
--- | ||
## Usage | ||
* training/validation: Run `./script_train_val.sh` | ||
<!-- * demo video: Run `./script_demo_video.sh` --> | ||
|
||
All the commonly used variables/parameters have comments in the end of the line. Please check [Options](#options). | ||
|
||
#### Training | ||
All the outputs will be under the directory `exp_path`. | ||
* Outputs: | ||
* model weights: `checkpoint.pth.tar`, `model_best.pth.tar` | ||
* log files: `train.log`, `train_short.log`, `val.log`, `val_short.log` | ||
|
||
#### Testing | ||
You can choose one of model_weights for testing. All the outputs will be under the directory `exp_path`. | ||
|
||
* Outputs: | ||
* score_data: used to check the model output (`scores_XXX.npz`) | ||
* confusion matrix: `confusion_matrix_XXX.png` and `confusion_matrix_XXX-topK.txt` | ||
|
||
<!-- #### Video Demo | ||
`demo_video.py` overlays the predicted categories and confidence values on one video. Please see "Results". --> | ||
|
||
--- | ||
## Options | ||
#### Domain Adaptation | ||
<!-- In both `./script_train_val.sh` and `./script_demo_video.sh`, there are several options related to our Domain Adaptation approaches. --> | ||
In `./script_train_val.sh`, there are several options related to our DA approaches. | ||
* `use_target`: switch on/off the DA mode | ||
* `none`: not use target data (no DA) | ||
* `uSv`/`Sv`: use target data in a unsupervised/supervised way | ||
* options for the DA approaches: | ||
* discrepancy-based: DAN, JAN | ||
* adversarial-based: RevGrad | ||
* Normalization-based: AdaBN | ||
* Ensemble-based: MCD | ||
|
||
#### More options | ||
For more details of all the arguments, please check [opts.py](opts.py). | ||
|
||
#### Notes | ||
The options in the scripts have comments with the following types: | ||
* no comment: user can still change it, but NOT recommend (may need to change the code or have different experimental results) | ||
* comments with choices (e.g. `true | false`): can only choose from choices | ||
* comments as `depend on users`: totally depend on users (mostly related to data path) | ||
|
||
--- | ||
## Citation | ||
If you find this repository useful, please cite our paper: | ||
``` | ||
@article{chen2019taaan, | ||
title={Temporal Attentive Alignment for Video Domain Adaptation}, | ||
author={Chen, Min-Hung and Kira, Zsolt and AlRegib, Ghassan}, | ||
booktitle = {CVPR Workshop on Learning from Unlabeled Videos}, | ||
year={2019} | ||
} | ||
``` | ||
|
||
--- | ||
#### Contact | ||
[Min-Hung Chen](https://www.linkedin.com/in/chensteven) <br> | ||
cmhungsteve AT gatech DOT edu |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import torch | ||
import torch.nn as nn | ||
import numpy as np | ||
from math import ceil | ||
|
||
class RelationModule(torch.nn.Module): | ||
# this is the naive implementation of the n-frame relation module, as num_frames == num_frames_relation | ||
def __init__(self, img_feature_dim, num_bottleneck, num_frames): | ||
super(RelationModule, self).__init__() | ||
self.num_frames = num_frames | ||
self.img_feature_dim = img_feature_dim | ||
self.num_bottleneck = num_bottleneck | ||
self.classifier = self.fc_fusion() | ||
def fc_fusion(self): | ||
# naive concatenate | ||
classifier = nn.Sequential( | ||
nn.ReLU(), | ||
nn.Linear(self.num_frames * self.img_feature_dim, self.num_bottleneck), | ||
nn.ReLU(), | ||
) | ||
return classifier | ||
def forward(self, input): | ||
input = input.view(input.size(0), self.num_frames*self.img_feature_dim) | ||
input = self.classifier(input) | ||
return input | ||
|
||
class RelationModuleMultiScale(torch.nn.Module): | ||
# Temporal Relation module in multiply scale, suming over [2-frame relation, 3-frame relation, ..., n-frame relation] | ||
|
||
def __init__(self, img_feature_dim, num_bottleneck, num_frames): | ||
super(RelationModuleMultiScale, self).__init__() | ||
self.subsample_num = 3 # how many relations selected to sum up | ||
self.img_feature_dim = img_feature_dim | ||
self.scales = [i for i in range(num_frames, 1, -1)] # generate the multiple frame relations | ||
|
||
self.relations_scales = [] | ||
self.subsample_scales = [] | ||
for scale in self.scales: | ||
relations_scale = self.return_relationset(num_frames, scale) | ||
self.relations_scales.append(relations_scale) | ||
self.subsample_scales.append(min(self.subsample_num, len(relations_scale))) # how many samples of relation to select in each forward pass | ||
|
||
# self.num_class = num_class | ||
self.num_frames = num_frames | ||
self.fc_fusion_scales = nn.ModuleList() # high-tech modulelist | ||
for i in range(len(self.scales)): | ||
scale = self.scales[i] | ||
fc_fusion = nn.Sequential( | ||
nn.ReLU(), | ||
nn.Linear(scale * self.img_feature_dim, num_bottleneck), | ||
nn.ReLU(), | ||
) | ||
|
||
self.fc_fusion_scales += [fc_fusion] | ||
|
||
print('Multi-Scale Temporal Relation Network Module in use', ['%d-frame relation' % i for i in self.scales]) | ||
|
||
def forward(self, input): | ||
# the first one is the largest scale | ||
act_scale_1 = input[:, self.relations_scales[0][0] , :] | ||
act_scale_1 = act_scale_1.view(act_scale_1.size(0), self.scales[0] * self.img_feature_dim) | ||
act_scale_1 = self.fc_fusion_scales[0](act_scale_1) | ||
act_scale_1 = act_scale_1.unsqueeze(1) # add one dimension for the later concatenation | ||
act_all = act_scale_1.clone() | ||
|
||
for scaleID in range(1, len(self.scales)): | ||
act_relation_all = torch.zeros_like(act_scale_1) | ||
# iterate over the scales | ||
num_total_relations = len(self.relations_scales[scaleID]) | ||
num_select_relations = self.subsample_scales[scaleID] | ||
idx_relations_evensample = [int(ceil(i * num_total_relations / num_select_relations)) for i in range(num_select_relations)] | ||
|
||
#for idx in idx_relations_randomsample: | ||
for idx in idx_relations_evensample: | ||
act_relation = input[:, self.relations_scales[scaleID][idx], :] | ||
act_relation = act_relation.view(act_relation.size(0), self.scales[scaleID] * self.img_feature_dim) | ||
act_relation = self.fc_fusion_scales[scaleID](act_relation) | ||
act_relation = act_relation.unsqueeze(1) # add one dimension for the later concatenation | ||
act_relation_all += act_relation | ||
|
||
act_all = torch.cat((act_all, act_relation_all), 1) | ||
return act_all | ||
|
||
def return_relationset(self, num_frames, num_frames_relation): | ||
import itertools | ||
return list(itertools.combinations([i for i in range(num_frames)], num_frames_relation)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
0 climb | ||
1 fencing | ||
2 golf | ||
3 kick_ball | ||
4 pullup | ||
5 punch | ||
6 pushup | ||
7 ride_bike | ||
8 ride_horse | ||
9 shoot_ball | ||
10 shoot_bow | ||
11 walk |
Oops, something went wrong.