demo script for using pretrained model
zhoubolei committed Jan 9, 2018
1 parent 36e7c3b commit aa309c3
Showing 6 changed files with 47 additions and 40 deletions.
22 changes: 14 additions & 8 deletions README.md
@@ -45,7 +45,7 @@ python test_models.py something RGB model/TRN_something_RGB_BNInception_TRNmulti
--arch BNInception --crop_fusion_type TRNmultiscale --test_segments 8
```

-### Pretrained models (working on it)
+### Pretrained models and demo code

* Download pretrained models on [Something-Something](https://www.twentybn.com/datasets/something-something), [Jester](https://www.twentybn.com/datasets/jester), and [Moments in Time](http://moments.csail.mit.edu/)

@@ -54,28 +54,34 @@ cd pretrain
./download_models.sh
```

-* Download sample video and extracted frames
+* Download sample video and extracted frames. There will be an mp4 video file and a folder containing the RGB frames for that video.

```bash
cd sample_data
./download_sample_data.sh
```

-* Test pretrained model on mp4 video file
+* Test pretrained model trained on Something-Something

```bash
-python test_video.py --video_file sample_data/juggling.mp4 --rendered_output sample_data/predicted_video.mp4
+python test_video.py --frame_folder sample_data/juggling_frames --weight pretrain/TRN_something_RGB_BNInception_TRNmultiscale_segment8_best.pth.tar --arch BNInception --dataset something
```

-The command above uses `ffmpeg` to extract frames from the supplied video `--video_file` and optionally generates a new video `--rendered_output` from the frames used to make the prediction with the predicted category in the top-left corner.
+* Test pretrained model trained on [Moments in Time](http://moments.csail.mit.edu/)

-* Alternatively, if you wish to extract video frames yourself, you can test a pretrained model using a text file specifying the path to the individual frames of the video.
+```bash
+python test_video.py --frame_folder sample_data/juggling_frames --weight pretrain/TRN_moments_RGB_InceptionV3_TRNmultiscale_segment8_best.pth.tar --arch InceptionV3 --dataset moments
+```

+* Test pretrained model on mp4 video file

```bash
-# Make prediction using list of extracted video frames.
-python test_video.py --frame_list sample_data/juggling_frame_list.txt
+#python test_video.py --video_file sample_data/juggling.mp4 --rendered_output sample_data/predicted_video.mp4 --weight pretrain/TRN_moments_RGB_InceptionV3_TRNmultiscale_segment8_best.pth.tar --arch InceptionV3 --dataset moments
```

+The command above uses `ffmpeg` to extract frames from the supplied video `--video_file` and optionally generates a new video `--rendered_output` from the frames used to make the prediction with the predicted category in the top-left corner.

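If you prefer to extract the frames yourself before running the demo, something like the following should work (the exact frame rate and output naming used internally by `test_video.py` may differ; this is only an illustrative invocation):

```bash
# Illustrative manual extraction (assumed settings; adjust fps and the output pattern as needed)
mkdir -p sample_data/juggling_frames_manual
ffmpeg -i sample_data/juggling.mp4 -vf fps=8 sample_data/juggling_frames_manual/%06d.jpg
# the resulting folder can then be passed to the demo via --frame_folder
```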

### TODO

* TODO: Web-cam demo script
8 changes: 3 additions & 5 deletions TRNmodule.py
@@ -11,7 +11,7 @@ class RelationModule(torch.nn.Module):
# this is the naive implementation of the n-frame relation module, as num_frames == num_frames_relation
def __init__(self, img_feature_dim, num_frames, num_class):
super(RelationModule, self).__init__()
-self.num_frames = num_frames
+self.num_frames = num_frames
self.num_class = num_class
self.img_feature_dim = img_feature_dim
self.classifier = self.fc_fusion()
@@ -21,7 +21,7 @@ def fc_fusion(self):
classifier = nn.Sequential(
nn.ReLU(),
nn.Linear(self.num_frames * self.img_feature_dim, num_bottleneck),
-nn.ReLU(),
+nn.ReLU(),
nn.Linear(num_bottleneck,self.num_class),
)
return classifier
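As a point of reference, the computation this naive relation module performs can be sketched as follows (a standalone illustration with assumed sizes, not the module's own code): per-frame features are flattened across time and classified by the fused MLP in one pass.

```python
import torch
import torch.nn as nn

# assumed sizes for this sketch only
batch, num_frames, img_feature_dim, num_class = 2, 8, 256, 174
num_bottleneck = 512

# per-frame features as they would arrive from the base CNN
features = torch.randn(batch, num_frames, img_feature_dim)

classifier = nn.Sequential(
    nn.ReLU(),
    nn.Linear(num_frames * img_feature_dim, num_bottleneck),
    nn.ReLU(),
    nn.Linear(num_bottleneck, num_class),
)

# flatten the time and feature dimensions, then score the whole clip at once
logits = classifier(features.view(batch, num_frames * img_feature_dim))  # (batch, num_class)
```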
@@ -61,9 +61,7 @@ def __init__(self, img_feature_dim, num_frames, num_class):

self.fc_fusion_scales += [fc_fusion]

-# maybe we put another fc layer after the summed up results???
-print('Multi-Scale Temporal Relation Network Module in use')
-print(['%d-frame relation' % i for i in self.scales])
+print('Multi-Scale Temporal Relation Network Module in use', ['%d-frame relation' % i for i in self.scales])

def forward(self, input):
# the first one is the largest scale
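The remainder of the forward pass is not shown in this hunk. As a rough mental model only (a hand-written sketch with its own simplifying assumptions, e.g. a single sampled frame subset per scale, whereas the actual module samples several subsets per scale), a multi-scale relation can be thought of as summing per-scale MLP scores computed on ordered subsets of frames:

```python
import torch
import torch.nn as nn

def multiscale_relation(features, fusion_mlps, scales):
    """Sketch: sum relation scores over one ordered frame subset per scale.

    features: (batch, num_frames, img_feature_dim)
    fusion_mlps: one MLP per scale, expecting scale * img_feature_dim inputs
    scales: e.g. [8, 7, ..., 2], largest scale first
    """
    batch, num_frames, dim = features.shape
    logits = 0
    for mlp, scale in zip(fusion_mlps, scales):
        # pick `scale` distinct frames and keep them in temporal order
        idx = sorted(torch.randperm(num_frames)[:scale].tolist())
        subset = features[:, idx, :].reshape(batch, scale * dim)
        logits = logits + mlp(subset)
    return logits
```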
24 changes: 12 additions & 12 deletions models.py
@@ -11,7 +11,7 @@ def __init__(self, num_class, num_segments, modality,
base_model='resnet101', new_length=None,
consensus_type='avg', before_softmax=True,
dropout=0.8,img_feature_dim=256,
-crop_num=1, partial_bn=True):
+crop_num=1, partial_bn=True, print_spec=True):
super(TSN, self).__init__()
self.modality = modality
self.num_segments = num_segments
@@ -28,17 +28,17 @@ def __init__(self, num_class, num_segments, modality,
self.new_length = 1 if modality == "RGB" else 5
else:
self.new_length = new_length

print(("""
Initializing TSN with base model: {}.
TSN Configurations:
input_modality: {}
num_segments: {}
new_length: {}
consensus_module: {}
dropout_ratio: {}
img_feature_dim: {}
""".format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim)))
if print_spec == True:
print(("""
Initializing TSN with base model: {}.
TSN Configurations:
input_modality: {}
num_segments: {}
new_length: {}
consensus_module: {}
dropout_ratio: {}
img_feature_dim: {}
""".format(base_model, self.modality, self.num_segments, self.new_length, consensus_type, self.dropout, self.img_feature_dim)))

self._prepare_base_model(base_model)

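With the new flag, the banner can be silenced when the network is built programmatically, as the demo script does. A minimal usage sketch (assuming `models.py` from this repository is on the import path; the class count 174 is Something-Something's and serves only as an example):

```python
from models import TSN

net = TSN(174, 8, 'RGB',
          base_model='BNInception',
          consensus_type='TRNmultiscale',
          img_feature_dim=256,
          print_spec=False)  # suppress the configuration printout
```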
1 change: 0 additions & 1 deletion sample_data/download_sample_data.sh
@@ -1,4 +1,3 @@
echo 'Downloading sample test video and extracted frames'
wget http://relation.csail.mit.edu/data/juggling.mp4
wget -r -nH --cut-dirs=1 --no-parent --reject="index.html*" http://relation.csail.mit.edu/data/juggling_frames/
-ls -1 $PWD/juggling_frames/* > juggling_frame_list.txt
24 changes: 14 additions & 10 deletions test_video.py
@@ -74,7 +74,7 @@ def render_frames(frames, prediction):
parser = argparse.ArgumentParser(description="test TRN on a single video")
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument('--video_file', type=str, default=None)
-group.add_argument('--frame_list', type=str, default=None)
+group.add_argument('--frame_folder', type=str, default=None)
parser.add_argument('--modality', type=str, default='RGB',
choices=['RGB', 'Flow', 'RGBDiff'], )
parser.add_argument('--dataset', type=str, default='moments',
@@ -84,6 +84,9 @@ def render_frames(frames, prediction):
parser.add_argument('--input_size', type=int, default=224)
parser.add_argument('--test_segments', type=int, default=8)
parser.add_argument('--img_feature_dim', type=int, default=256)
+parser.add_argument('--consensus_type', type=str, default='TRNmultiscale')
+parser.add_argument('--weight', type=str)

args = parser.parse_args()

# Get dataset categories.
@@ -96,13 +99,12 @@ def render_frames(frames, prediction):
args.test_segments,
args.modality,
base_model=args.arch,
-consensus_type='TRNmultiscale',
-img_feature_dim=args.img_feature_dim)
+consensus_type=args.consensus_type,
+img_feature_dim=args.img_feature_dim, print_spec=False)

-weights = 'pretrain/TRN_{}_RGB_{}_TRNmultiscale_segment8_best.pth.tar'.format(
-    args.dataset, args.arch)
+weights = args.weight
checkpoint = torch.load(weights)
-print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1']))
+#print("model epoch {} best prec@1: {}".format(checkpoint['epoch'], checkpoint['best_prec1']))

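# the first name component of each checkpoint key (typically the 'module.' prefix
# left by nn.DataParallel during training) is stripped so the weights load into the un-wrapped model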
base_dict = {'.'.join(k.split('.')[1:]): v for k, v in list(checkpoint['state_dict'].items())}
net.load_state_dict(base_dict)
@@ -118,9 +120,11 @@ def render_frames(frames, prediction):
])

# Obtain video frames
-if args.frame_list is not None:
-    print('Loading frames listed in text file...')
-    frame_paths = [line.rstrip() for line in open(args.frame_list, 'r').readlines()]
+if args.frame_folder is not None:
+    print('Loading frames in %s' % args.frame_folder)
+    import glob
+    # make sure the sorted frame paths are in the correct temporal order
+    frame_paths = sorted(glob.glob(os.path.join(args.frame_folder, '*.jpg')))
frames = load_frames(frame_paths)
else:
print('Extracting frames using ffmpeg...')
Expand All @@ -136,7 +140,7 @@ def render_frames(frames, prediction):
probs, idx = h_x.sort(0, True)

# Output the prediction.
-video_name = args.frame_list if args.frame_list is not None else args.video_file
+video_name = args.frame_folder if args.frame_folder is not None else args.video_file
print('RESULT ON ' + video_name)
for i in range(0, 5):
print('{:.3f} -> {}'.format(probs[i], categories[idx[i]]))
8 changes: 4 additions & 4 deletions test_video.sh
@@ -1,5 +1,5 @@
-# Make prediction from mp4 video file.
-python test_video.py --video_file sample_data/juggling.mp4 --rendered_output sample_data/predicted_video.mp4
+# Make prediction from an mp4 video file (ffmpeg is required)
+#python test_video.py --video_file sample_data/juggling.mp4 --rendered_output sample_data/predicted_video.mp4 --weight pretrain/TRN_moments_RGB_InceptionV3_TRNmultiscale_segment8_best.pth.tar --arch InceptionV3 --dataset moments

-# Make prediction using list of extracted video frames.
-python test_video.py --frame_list sample_data/juggling_frame_list.txt
+# Make prediction from a folder of extracted RGB frames
+python test_video.py --frame_folder sample_data/juggling_frames --weight pretrain/TRN_moments_RGB_InceptionV3_TRNmultiscale_segment8_best.pth.tar --arch InceptionV3 --dataset moments
