From bbad466ce7cfe7676ef461630cf693645cffe1de Mon Sep 17 00:00:00 2001 From: gurkirt Date: Sun, 3 Dec 2017 14:05:42 +0000 Subject: [PATCH] initial commit --- .gitignore | 5 + LICENSE | 23 + README.md | 328 +++++++++++++- data/__init__.py | 22 + data/config.py | 54 +++ data/ucf24.py | 234 ++++++++++ layers/__init__.py | 2 + layers/box_utils.py | 244 ++++++++++ layers/functions/__init__.py | 5 + layers/functions/prior_box.py | 96 ++++ layers/modules/__init__.py | 4 + layers/modules/l2norm.py | 23 + layers/modules/multibox_loss.py | 116 +++++ online-tubes/.gitignore | 10 + online-tubes/I01onlineTubes.m | 151 +++++++ online-tubes/I02genFusedTubes.m | 155 +++++++ online-tubes/actionpath/actionPaths.m | 137 ++++++ online-tubes/actionpath/fusedActionPaths.m | 229 ++++++++++ online-tubes/actionpath/incremental_linking.m | 270 +++++++++++ online-tubes/actionpath/nms.m | 74 +++ .../eval/compute_spatio_temporal_iou.m | 92 ++++ online-tubes/eval/get_PR_curve.m | 154 +++++++ online-tubes/eval/xVOCap.m | 10 + online-tubes/frameAp.m | 294 ++++++++++++ online-tubes/gentube/PARactionPathSmoother.m | 131 ++++++ online-tubes/gentube/convert2eval.m | 57 +++ online-tubes/gentube/dpEM_max.m | 93 ++++ online-tubes/gentube/readALLactionPaths.m | 47 ++ online-tubes/utils/createdires.m | 20 + online-tubes/utils/initDatasetOpts.m | 60 +++ online-tubes/utils/initDatasetOptsFused.m | 74 +++ ssd.py | 205 +++++++++ test-ucf24.py | 223 +++++++++ train-ucf24.py | 412 +++++++++++++++++ utils/__init__.py | 16 + utils/augmentations.py | 425 ++++++++++++++++++ utils/evaluation.py | 155 +++++++ 37 files changed, 4648 insertions(+), 2 deletions(-) create mode 100644 .gitignore create mode 100644 LICENSE create mode 100644 data/__init__.py create mode 100644 data/config.py create mode 100644 data/ucf24.py create mode 100644 layers/__init__.py create mode 100644 layers/box_utils.py create mode 100644 layers/functions/__init__.py create mode 100644 layers/functions/prior_box.py create mode 100644 layers/modules/__init__.py create mode 100644 layers/modules/l2norm.py create mode 100644 layers/modules/multibox_loss.py create mode 100644 online-tubes/.gitignore create mode 100644 online-tubes/I01onlineTubes.m create mode 100644 online-tubes/I02genFusedTubes.m create mode 100644 online-tubes/actionpath/actionPaths.m create mode 100644 online-tubes/actionpath/fusedActionPaths.m create mode 100644 online-tubes/actionpath/incremental_linking.m create mode 100644 online-tubes/actionpath/nms.m create mode 100644 online-tubes/eval/compute_spatio_temporal_iou.m create mode 100644 online-tubes/eval/get_PR_curve.m create mode 100644 online-tubes/eval/xVOCap.m create mode 100644 online-tubes/frameAp.m create mode 100644 online-tubes/gentube/PARactionPathSmoother.m create mode 100644 online-tubes/gentube/convert2eval.m create mode 100644 online-tubes/gentube/dpEM_max.m create mode 100644 online-tubes/gentube/readALLactionPaths.m create mode 100644 online-tubes/utils/createdires.m create mode 100644 online-tubes/utils/initDatasetOpts.m create mode 100644 online-tubes/utils/initDatasetOptsFused.m create mode 100644 ssd.py create mode 100644 test-ucf24.py create mode 100644 train-ucf24.py create mode 100644 utils/__init__.py create mode 100644 utils/augmentations.py create mode 100644 utils/evaluation.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c321df8 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ + +*.log +*.pyc +*.pyo +__pycache__/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..ff2856f --- 
/dev/null
+++ b/LICENSE
@@ -0,0 +1,23 @@
+MIT License
+
+Copyright (c) 2017 Gurkirt Singh
+This is an adaptation of Max deGroot and Ellis Brown's original SSD code for the VOC dataset
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
diff --git a/README.md b/README.md
index 7a084ea..6d0a113 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,326 @@
-# realtime-action-detection
-This repository will host the code of our work on [Online Real-time Multiple Spatiotemporal Action Localisation and Prediction](https://arxiv.org/abs/1611.08563)
+# Real-time online Action Detection: ROAD
+An implementation of our work ([Online Real-time Multiple Spatiotemporal Action Localisation and Prediction](https://arxiv.org/pdf/1611.08563.pdf)) published in ICCV 2017.
+
+Originally, we used the [Caffe](https://github.com/weiliu89/caffe/tree/ssd) implementation of [SSD-V2](https://arxiv.org/abs/1512.02325)
+for the publication. I have forked the version of [SSD-CAFFE](https://github.com/gurkirt/caffe/tree/ssd) that I used to generate the results for the paper; you can use that repo if you prefer Caffe, otherwise I would recommend using this version.
+This implementation is still a bit off from the original work: it works slightly better at lower IoU thresholds and slightly worse at higher ones.
+The tube generation part is the same as in the original implementation. I found that this implementation of SSD is slightly worse at IoU greater than or equal to 0.5 in the context of the UCF24 dataset.
+
+I decided to release the code with a [PyTorch](http://pytorch.org/) implementation of SSD,
+because it is easier to reuse than the Caffe version (where installation itself can be a big issue).
+We build on the PyTorch [implementation](https://github.com/amdegroot/ssd.pytorch) of SSD by Max deGroot, Ellis Brown.
+We made a few changes (e.g. a different learning rate for bias and weight parameters during optimisation; an illustrative sketch is given at the end of the Training section below) and simplified some parts to
+accommodate the ucf24 dataset.
+
+### Table of Contents
+- Installation
+- Datasets
+- Training SSD
+- Building Tubes
+- Performance
+- Extras
+- TODO
+- Citation
+- Reference
+
+## Installation
+- Install [PyTorch](http://pytorch.org/) by selecting your environment on the website and running the appropriate command.
+- Please install cv2 as well. I recommend using Anaconda with Python 3.6 and its OpenCV package.
+- You will also need MATLAB. If you have a distributed (parallel) computing toolbox license it will be faster; otherwise it will still be fine.
+Just replace `parfor` with a simple `for` in the MATLAB scripts. I would be happy to accept a PR for a Python version of this part.
+- Clone this repository.
+  * Note: We currently only support Python 3+ on Linux systems.
+- We currently only support [UCF24](http://www.thumos.info/download.html) with the [revised annotations](https://github.com/gurkirt/corrected-UCF101-Annots) released with our paper. We will try to add [JHMDB21](http://jhmdb.is.tue.mpg.de/) as soon as possible (but can't promise); you can check out our [BMVC2016 code](https://bitbucket.org/sahasuman/bmvc2016_code) to get your experiments on JHMDB21 started.
+- To simulate the same training and evaluation setup we provide extracted `rgb` images from the videos along with optical flow images (both `brox flow` and `real-time flow`) computed for the ucf24 dataset.
+You can download them from my [google drive link](https://drive.google.com/file/d/1o2l6nYhd-0DDXGP-IPReBP4y1ffVmGSE/view?usp=sharing).
+- We also support [Visdom](https://github.com/facebookresearch/visdom) for visualization of the loss and of frameAP on a test subset during training!
+  * To use Visdom in the browser:
+  ```Shell
+  # First install Python server and client
+  pip install visdom
+  # Start the server (probably in a screen or tmux)
+  python -m visdom.server --port=8097
+  ```
+  * Then (during training) navigate to http://localhost:8097/ (see the Training section below for more details).
+
+## Dataset
+To make things easy, we provide extracted `rgb` images from the videos along with optical flow images (both `brox flow` and `real-time flow`) computed for the ucf24 dataset;
+you can download them from my [google drive link](https://drive.google.com/file/d/1o2l6nYhd-0DDXGP-IPReBP4y1ffVmGSE/view?usp=sharing).
+It is an almost 6GB tarball; download it and extract it wherever you are going to store your experiments.
+
+`UCF24Detection` is a dataset loader class in `data/ucf24.py` that inherits `torch.utils.data.Dataset`, making it fully compatible with the `torchvision.datasets` [API](http://pytorch.org/docs/torchvision/datasets.html). A minimal usage sketch is given at the end of the Training section below.
+
+
+## Training SSD
+- Requires the fc-reduced [VGG-16](https://arxiv.org/abs/1409.1556) model weights;
+the weights are already included in the dataset tarball under the `train_data` subfolder.
+- By default, we assume that you have downloaded that dataset.
+- To train SSD using the train script, simply specify the parameters listed in `train-ucf24.py` as flags or change them manually.
+
+Let's assume that you extracted the dataset in the `/home/user/ucf24/` directory; then your train command from the root directory of this repo is going to be:
+
+```Shell
+CUDA_VISIBLE_DEVICES=0 python3 train-ucf24.py --data_root=/home/user/ucf24/ --save_root=/home/user/ucf24/
+--visdom=True --input_type=rgb --stepvalues=70000,90000 --max_iter=120000
+```
+
+To train on flow inputs:
+```Shell
+CUDA_VISIBLE_DEVICES=0 python3 train-ucf24.py --data_root=/home/user/ucf24/ --save_root=/home/user/ucf24/
+--visdom=True --input_type=brox --stepvalues=70000,90000 --max_iter=120000
+```
+
+Different parameters in `train-ucf24.py` will result in different performance.
+
+- Note:
+  * The network occupies almost 9.2GB of VRAM on a GPU; we used a 1080Ti for training, and normal training takes about 32-40 hrs.
+  * For instructions on Visdom usage/installation, see the Installation section. By default it is off.
+  * If you don't want to use Visdom, you can always keep track of training using the logfile saved under the `save_root` directory.
+  * During training a checkpoint is saved every 10K iterations, and its frame-level `frame-mean-ap` is logged on a subset of the 22k test images.
+  * We recommend training for 120K iterations for all the input types.
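+For reference, here is a minimal usage sketch of the `UCF24Detection` loader, `BaseTransform` and `detection_collate` described in the Dataset section, wired into a standard PyTorch `DataLoader`. This is not code copied from `train-ucf24.py`; the mean pixel values and batch size below are assumptions, so check the defaults in that script.
+
+```Python
+import torch.utils.data as data
+from data import UCF24Detection, AnnotationTransform, detection_collate, BaseTransform
+
+data_root = '/home/user/ucf24/'  # wherever you extracted the dataset tarball
+dataset = UCF24Detection(data_root, 'train',
+                         transform=BaseTransform(300, (104, 117, 123)),  # assumed mean values
+                         target_transform=AnnotationTransform(),
+                         input_type='rgb')
+loader = data.DataLoader(dataset, batch_size=32, shuffle=True,
+                         collate_fn=detection_collate)
+# images: [32, 3, 300, 300]; targets: a list of [num_boxes, 5] tensors (x1, y1, x2, y2, label)
+images, targets, image_ids = next(iter(loader))
+```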
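+The "different learning rate for bias and weights" change mentioned in the introduction follows the common SSD convention of putting biases and weights into separate optimizer parameter groups. The snippet below is only an illustrative sketch of that idea; the exact multipliers, momentum and weight-decay values used in `train-ucf24.py` may differ.
+
+```Python
+import torch.optim as optim
+
+def build_optimizer(net, base_lr=0.0005, momentum=0.9, weight_decay=5e-4):
+    # Separate parameter groups so biases can get their own learning rate
+    # (here: 2x the base lr and no weight decay, the usual Caffe-SSD convention).
+    weights, biases = [], []
+    for name, param in net.named_parameters():
+        (biases if name.endswith('.bias') else weights).append(param)
+    return optim.SGD([{'params': weights, 'lr': base_lr, 'weight_decay': weight_decay},
+                      {'params': biases, 'lr': 2 * base_lr, 'weight_decay': 0}],
+                     lr=base_lr, momentum=momentum)
+```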
+
+## Building Tubes
+To generate the tubes and evaluate them, you will first need frame-level detections; then you can navigate to `online-tubes` to generate tubes using `I01onlineTubes` and `I02genFusedTubes`.
+
+##### Produce frame-level detections
+Once you have a trained network, you can use `test-ucf24.py` to generate frame-level detections.
+To evaluate SSD using the test script, simply specify the parameters listed in `test-ucf24.py` as flags or change them manually, e.g.:
+```Shell
+CUDA_VISIBLE_DEVICES=0 python3 test-ucf24.py --data_root=/home/user/ucf24/ --save_root=/home/user/ucf24/
+--input_type=rgb --eval_iter=120000
+```
+
+To evaluate the optical flow models:
+
+```Shell
+CUDA_VISIBLE_DEVICES=0 python3 test-ucf24.py --data_root=/home/user/ucf24/ --save_root=/home/user/ucf24/
+--input_type=brox --eval_iter=120000
+```
+
+- Note:
+  * By default it will compute frame-level detections and store them, and compute frame-mean-AP, for the models saved at the 90k and 120k iterations.
+  * A log file is created for each iteration's frame-level evaluation.
+
+##### Build tubes
+You will need the frame-level detections, and you will need to navigate to `online-tubes`.
+
+Step-1: specify `data_root`, `save_root` and `iteration_num_*` in `I01onlineTubes` and `I02genFusedTubes`;
+
+Step-2: run `I01onlineTubes` and `I02genFusedTubes` in MATLAB; this prints out video-mean-AP and saves the results in a `.mat` file.
+
+Results are saved in `save_root/results.mat`. Additionally, `action-path` and `action-tubes` are also stored under the `save_root/ucf24/*` folders.
+
+* NOTE: `I01onlineTubes` and `I02genFusedTubes` not only produce video-level mAP; they also produce video-level classification accuracy on the 24 classes of UCF24.
+##### frame-meanAP
+To compute frame-mAP you can use the `frameAP.m` script. You will need to specify `data_root` and `save_root`.
+Use this script (rather than the Python one) to produce results for publication; both are almost identical,
+but their AP computation from precision and recall is slightly different.
+
+## Performance
+##### UCF24 Test
+The table below is similar to [Table 1 in our paper](https://arxiv.org/pdf/1611.08563.pdf). It contains more information than the one in the paper, mostly about this implementation.
+
+| IoU Threshold = | 0.20 | 0.50 | 0.75 | 0.5:0.95 | frame-mAP@0.5 | accuracy(%) |
+|:---|:---:|:---:|:---:|:---:|:---:|:---:|
+| Peng et al [3] RGB+BroxFLOW | 73.67 | 32.07 | 00.85 | 07.26 | -- | -- |
+| Saha et al [2] RGB+BroxFLOW | 66.55 | 36.37 | 07.94 | 14.37 | -- | -- |
+| Singh et al [4] RGB+FastFLOW | 70.20 | 43.00 | 14.10 | 19.20 | -- | -- |
+| Singh et al [4] RGB+BroxFLOW | 73.50 | 46.30 | 15.00 | 20.40 | -- | 91.12 |
+| This implementation[4] RGB | 71.71 | 39.36 | 14.57 | 17.95 | 64.12 | 88.68 |
+| This implementation[4] FastFLOW | 73.50 | 67.63 | 03.57 | 11.56 | 46.33 | 85.60 |
+| This implementation[4] BroxFLOW | 44.62 | 14.43 | 00.12 | 03.42 | 21.94 | 70.55 |
+| This implementation[4] RGB+FastFLOW (boost-fusion) | 70.61 | 40.18 | 11.42 | 17.03 | 64.40 | 89.01 |
+| This implementation[4] RGB+FastFLOW (union-set) | 72.80 | 43.23 | 13.14 | 18.51 | 60.70 | 89.89 |
+| This implementation[4] RGB+FastFLOW (mean fusion) | 74.34 | 44.27 | 13.50 | 18.96 | 60.70 | 91.54 |
+| This implementation[4] RGB+BroxFLOW (boost-fusion) | 73.58 | 43.76 | 12.60 | 18.60 | 67.60 | 91.10 |
+| This implementation[4] RGB+BroxFLOW (union-set) | 74.88 | 45.14 | 13.93 | 19.73 | 64.36 | 92.64 |
+| This implementation[4] RGB+BroxFLOW (mean fusion) | 76.91 | 47.56 | 15.14 | 20.66 | 67.01 | 93.08 |
+| Kalogeiton et al. [5] RGB+BroxFLOW (stack of flow images) (mean fusion) | 76.50 | 49.20 | 19.70 | 23.40 | 69.50 | -- |
+
+##### Discussion:
+`Effect of training iterations:`
+There is an effect of the learning rate and the number of iterations
+the model is trained for.
+If you train SSD at the initial learning rate for
+many iterations, it performs better at
+lower IoU thresholds, which is what is done in this case.
+In the original work using the Caffe implementation of SSD,
+I trained SSD with a 0.0005 learning rate for the first 30K
+iterations, then dropped the learning rate by a factor of 5
+(divided by 5) and trained for only 45K iterations in total.
+In this implementation all the models are trained for 120K
+iterations; the initial learning rate is 0.0005 and the learning rate is dropped by
+a factor of 5 after 70K and 90K iterations.
+
+`Kalogeiton et al. [5]` make use of mean fusion, so I thought we could try it in our pipeline; it was very easy to incorporate.
+It is evident from the table above that mean fusion performs better than the other fusion techniques.
+Also, their method relies on multiple frames as input, in addition to post-processing of
+bounding box coordinates at the tubelet level.
+
+##### Real-time aspect:
+
+This implementation is mainly focused on producing the best numbers; it can be modified to run fast.
+A few aspects would need changes:
+ - NMS is performed once in Python and then again in MATLAB; it should be done only once, on the GPU, in Python
+ - Most of the time spent during tube generation is taken by disk operations, which can be eliminated completely
+ - IoU computation during action-path building is done multiple times just to keep the code clean; that can be handled more smartly
+
+Contact me if you want to implement the real-time version.
+A proper real-time version would require converting the MATLAB part into Python.
+I presented the timing of the individual components in the paper, which still holds.
+
+## Extras
+To use a pre-trained model, download the pretrained weights from the links given below and make changes in `test-ucf24.py` to accept the downloaded weights.
+
+##### Download pre-trained networks
+- Currently, we provide the following PyTorch models:
+  * SSD300 trained on ucf24; available from my [google drive](https://drive.google.com/drive/folders/1Z42S8fQt4Amp1HsqyBOoHBtgVKUzJuJ8?usp=sharing)
+    - appearance model trained on rgb-images (named `rgb-ssd300_ucf24_120000`)
+    - accurate flow model trained on brox-images (named `brox-ssd300_ucf24_120000`)
+    - real-time flow model trained on fastOF-images (named `fastOF-ssd300_ucf24_120000`)
+- These models can be used to reproduce the table above, which is almost identical to the one in our [paper](https://arxiv.org/pdf/1611.08563.pdf)
+
+## TODO
+ - Incorporate JHMDB-21 dataset
+ - Convert matlab part into python
+
+## Citation
+If this work has been helpful in your research, please consider citing [1] and [4]
+
+     @inproceedings{singh2016online,
+       title={Online Real-time Multiple Spatiotemporal Action Localisation and Prediction},
+       author={Singh, Gurkirt and Saha, Suman and Sapienza, Michael and Torr, Philip and Cuzzolin, Fabio},
+       booktitle={ICCV},
+       year={2017}
+     }
+
+## References
+- [1] Wei Liu, et al. SSD: Single Shot MultiBox Detector. [ECCV2016](http://arxiv.org/abs/1512.02325).
+- [2] S. Saha, G. Singh, M. Sapienza, P. H. S. Torr, and F. Cuzzolin. Deep learning for detecting multiple space-time action tubes in videos. BMVC 2016.
+- [3] X. Peng and C. Schmid. Multi-region two-stream R-CNN for action detection. ECCV 2016.
+- [4] G. Singh, S. Saha, M. Sapienza, P. H. S. Torr and F. Cuzzolin. Online Real-time Multiple Spatiotemporal Action Localisation and Prediction. ICCV, 2017.
+- [5] Kalogeiton, V., Weinzaepfel, P., Ferrari, V.
and Schmid, C., 2017. Action Tubelet Detector for Spatio-Temporal Action Localization. ICCV, 2017. +- [Original SSD Implementation (CAFFE)](https://github.com/weiliu89/caffe/tree/ssd) +- A huge thank to Max deGroot, Ellis Brown for Pytorch implementation of [SSD](https://github.com/amdegroot/ssd.pytorch) + diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..8a78ef5 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,22 @@ +#from .voc0712 import VOCDetection, AnnotationTransform, detection_collate, VOC_CLASSES +from .ucf24 import UCF24Detection, AnnotationTransform, detection_collate, CLASSES +from .config import * +import cv2 +import numpy as np + + +def base_transform(image, size, mean): + x = cv2.resize(image, (size, size)).astype(np.float32) + # x = cv2.resize(np.array(image), (size, size)).astype(np.float32) + x -= mean + x = x.astype(np.float32) + return x + + +class BaseTransform: + def __init__(self, size, mean): + self.size = size + self.mean = np.array(mean, dtype=np.float32) + + def __call__(self, image, boxes=None, labels=None): + return base_transform(image, self.size, self.mean), boxes, labels diff --git a/data/config.py b/data/config.py new file mode 100644 index 0000000..63e5ec7 --- /dev/null +++ b/data/config.py @@ -0,0 +1,54 @@ +# config.py +""" SSD network configs + +Original author: Ellis Brown, Max deGroot for VOC dataset +https://github.com/amdegroot/ssd.pytorch + +""" + +#SSD300 CONFIGS +# newer version: use additional conv11_2 layer as last layer before multibox layers +v2 = { + 'feature_maps' : [38, 19, 10, 5, 3, 1], + + 'min_dim' : 300, + + 'steps' : [8, 16, 32, 64, 100, 300], + + 'min_sizes' : [30, 60, 111, 162, 213, 264], + + 'max_sizes' : [60, 111, 162, 213, 264, 315], + + # 'aspect_ratios' : [[2, 1/2], [2, 1/2, 3, 1/3], [2, 1/2, 3, 1/3], + # [2, 1/2, 3, 1/3], [2, 1/2], [2, 1/2]], + 'aspect_ratios' : [[2], [2, 3], [2, 3], [2, 3], [2], [2]], + + 'variance' : [0.1, 0.2], + + 'clip' : True, + + 'name' : 'v2', +} + +# use average pooling layer as last layer before multibox layers +v1 = { + 'feature_maps' : [38, 19, 10, 5, 3, 1], + + 'min_dim' : 300, + + 'steps' : [8, 16, 32, 64, 100, 300], + + 'min_sizes' : [30, 60, 114, 168, 222, 276], + + 'max_sizes' : [-1, 114, 168, 222, 276, 330], + + # 'aspect_ratios' : [[2], [2, 3], [2, 3], [2, 3], [2, 3], [2, 3]], + 'aspect_ratios' : [[1,1,2,1/2],[1,1,2,1/2,3,1/3],[1,1,2,1/2,3,1/3], + [1,1,2,1/2,3,1/3],[1,1,2,1/2,3,1/3],[1,1,2,1/2,3,1/3]], + + 'variance' : [0.1, 0.2], + + 'clip' : True, + + 'name' : 'v1', +} diff --git a/data/ucf24.py b/data/ucf24.py new file mode 100644 index 0000000..6156b3e --- /dev/null +++ b/data/ucf24.py @@ -0,0 +1,234 @@ +"""UCF24 Dataset Classes + +Author: Gurkirt Singh for ucf101-24 dataset + +""" + +import os +import os.path +import torch +import torch.utils.data as data +import cv2, pickle +import numpy as np + +CLASSES = ( # always index 0 + 'Basketball', 'BasketballDunk', 'Biking', 'CliffDiving', 'CricketBowling', 'Diving', 'Fencing', + 'FloorGymnastics', 'GolfSwing', 'HorseRiding', 'IceDancing', 'LongJump', 'PoleVault', 'RopeClimbing', + 'SalsaSpin','SkateBoarding', 'Skiing', 'Skijet', 'SoccerJuggling', + 'Surfing', 'TennisSwing', 'TrampolineJumping', 'VolleyballSpiking', 'WalkingWithDog') + + +class AnnotationTransform(object): + """ + Same as original + Transforms a VOC annotation into a Tensor of bbox coords and label index + Initilized with a dictionary lookup of classnames to indexes + Arguments: + class_to_ind (dict, optional): dictionary lookup of classnames 
-> indexes + (default: alphabetic indexing of UCF24's 24 classes) + keep_difficult (bool, optional): keep difficult instances or not + (default: False) + height (int): height + width (int): width + """ + + def __init__(self, class_to_ind=None, keep_difficult=False): + self.class_to_ind = class_to_ind or dict( + zip(CLASSES, range(len(CLASSES)))) + self.ind_to_class = dict(zip(range(len(CLASSES)),CLASSES)) + + def __call__(self, bboxs, labels, width, height): + res = [] + for t in range(len(labels)): + bbox = bboxs[t,:] + label = labels[t] + '''pts = ['xmin', 'ymin', 'xmax', 'ymax']''' + bndbox = [] + for i in range(4): + cur_pt = max(0,int(bbox[i]) - 1) + scale = width if i % 2 == 0 else height + cur_pt = min(scale, int(bbox[i])) + cur_pt = float(cur_pt) / scale + bndbox.append(cur_pt) + bndbox.append(label) + res += [bndbox] # [xmin, ymin, xmax, ymax, label_ind] + # img_id = target.find('filename').text[:-4] + return res # [[xmin, ymin, xmax, ymax, label_ind], ... ] + + +def readsplitfile(splitfile): + with open(splitfile, 'r') as f: + temptrainvideos = f.readlines() + trainvideos = [] + for vid in temptrainvideos: + vid = vid.rstrip('\n') + trainvideos.append(vid) + return trainvideos + + +def make_lists(rootpath, imgtype, split=1, fulltest=False): + imagesDir = rootpath + imgtype + '/' + splitfile = rootpath + 'splitfiles/trainlist{:02d}.txt'.format(split) + trainvideos = readsplitfile(splitfile) + trainlist = [] + testlist = [] + + with open(rootpath + 'splitfiles/pyannot.pkl','rb') as fff: + database = pickle.load(fff) + + train_action_counts = np.zeros(len(CLASSES), dtype=np.int32) + test_action_counts = np.zeros(len(CLASSES), dtype=np.int32) + + ratios = np.asarray([1.1,0.8,4.7,1.4,0.9,2.6,2.2,3.0,3.0,5.0,6.2,2.7,3.5,3.1,4.3,2.5,4.5,3.4,6.7,3.6,1.6,3.4,0.6,4.3]) + # ratios = np.ones_like(ratios) #TODO:uncomment this line and line 155, 156 to compute new ratios might be useful for JHMDB21 + video_list = [] + for vid, videoname in enumerate(sorted(database.keys())): + video_list.append(videoname) + actidx = database[videoname]['label'] + istrain = True + step = ratios[actidx] + numf = database[videoname]['numf'] + lastf = numf-1 + if videoname not in trainvideos: + istrain = False + step = ratios[actidx]*2.0 + if fulltest: + step = 1 + lastf = numf + + annotations = database[videoname]['annotations'] + num_tubes = len(annotations) + + tube_labels = np.zeros((numf,num_tubes),dtype=np.int16) # check for each tube if present in + tube_boxes = [[[] for _ in range(num_tubes)] for _ in range(numf)] + for tubeid, tube in enumerate(annotations): + # print('numf00', numf, tube['sf'], tube['ef']) + for frame_id, frame_num in enumerate(np.arange(tube['sf'], tube['ef'], 1)): # start of the tube to end frame of the tube + label = tube['label'] + assert actidx == label, 'Tube label and video label should be same' + box = tube['boxes'][frame_id, :] # get the box as an array + box = box.astype(np.float32) + box[2] += box[0] #convert width to xmax + box[3] += box[1] #converst height to ymax + tube_labels[frame_num, tubeid] = label+1 # change label in tube_labels matrix to 1 form 0 + tube_boxes[frame_num][tubeid] = box # put the box in matrix of lists + + possible_frame_nums = np.arange(0, lastf, step) + # print('numf',numf,possible_frame_nums[-1]) + for frame_num in possible_frame_nums: # loop from start to last possible frame which can make a legit sequence + frame_num = np.int32(frame_num) + check_tubes = tube_labels[frame_num,:] + + if np.sum(check_tubes>0)>0: # check if there aren't any semi 
overlapping tubes + all_boxes = [] + labels = [] + image_name = imagesDir + videoname+'/{:05d}.jpg'.format(frame_num+1) + label_name = rootpath + 'labels/' + videoname + '/{:05d}.txt'.format(frame_num + 1) + + assert os.path.isfile(image_name), 'Image does not exist'+image_name + for tubeid, tube in enumerate(annotations): + if tube_labels[frame_num, tubeid]>0: + box = np.asarray(tube_boxes[frame_num][tubeid]) + all_boxes.append(box) + labels.append(tube_labels[frame_num, tubeid]) + + if istrain: # if it is training video + trainlist.append([vid, frame_num+1, np.asarray(labels)-1, np.asarray(all_boxes)]) + train_action_counts[actidx] += len(labels) + else: # if test video and has micro-tubes with GT + testlist.append([vid, frame_num+1, np.asarray(labels)-1, np.asarray(all_boxes)]) + test_action_counts[actidx] += len(labels) + elif fulltest and not istrain: # if test video with no ground truth and fulltest is trues + testlist.append([vid, frame_num+1, np.asarray([9999]), np.zeros((1,4))]) + + for actidx, act_count in enumerate(train_action_counts): # just to see the distribution of train and test sets + print('train {:05d} test {:05d} action {:02d} {:s}'.format(act_count, test_action_counts[actidx] , int(actidx), CLASSES[actidx])) + + # newratios = train_action_counts/4000 + # print('new ratios', newratios) + # print('older ratios', ratios) + print('Trainlistlen', len(trainlist), ' testlist ', len(testlist)) + + return trainlist, testlist, video_list + + +class UCF24Detection(data.Dataset): + """UCF24 Action Detection Dataset + to access input images and target which is annotation + """ + + def __init__(self, root, image_set, transform=None, target_transform=None, + dataset_name='ucf24', input_type='rgb', full_test=False): + + self.input_type = input_type + input_type = input_type+'-images' + self.root = root + self.CLASSES = CLASSES + self.image_set = image_set + self.transform = transform + self.target_transform = target_transform + self.name = dataset_name + self._annopath = os.path.join(root, 'labels/', '%s.txt') + self._imgpath = os.path.join(root, input_type) + self.ids = list() + + trainlist, testlist, video_list = make_lists(root, input_type, split=1, fulltest=full_test) + self.video_list = video_list + if self.image_set == 'train': + self.ids = trainlist + elif self.image_set == 'test': + self.ids = testlist + else: + print('spacify correct subset ') + + def __getitem__(self, index): + im, gt, img_index = self.pull_item(index) + + return im, gt, img_index + + def __len__(self): + return len(self.ids) + + def pull_item(self, index): + annot_info = self.ids[index] + frame_num = annot_info[1] + video_id = annot_info[0] + videoname = self.video_list[video_id] + img_name = self._imgpath + '/{:s}/{:05d}.jpg'.format(videoname, frame_num) + # print(img_name) + img = cv2.imread(img_name) + height, width, channels = img.shape + + target = self.target_transform(annot_info[3], annot_info[2], width, height) + + + if self.transform is not None: + target = np.array(target) + img, boxes, labels = self.transform(img, target[:, :4], target[:, 4]) + img = img[:, :, (2, 1, 0)] + # img = img.transpose(2, 0, 1) + target = np.hstack((boxes, np.expand_dims(labels, axis=1))) + # print(height, width,target) + return torch.from_numpy(img).permute(2, 0, 1), target, index + # return torch.from_numpy(img), target, height, width + + +def detection_collate(batch): + """Custom collate fn for dealing with batches of images that have a different + number of associated object annotations (bounding boxes). 
+ Arguments: + batch: (tuple) A tuple of tensor images and lists of annotations + Return: + A tuple containing: + 1) (tensor) batch of images stacked on their 0 dim + 2) (list of tensors) annotations for a given image are stacked on 0 dim + """ + + targets = [] + imgs = [] + image_ids = [] + for sample in batch: + imgs.append(sample[0]) + targets.append(torch.FloatTensor(sample[1])) + image_ids.append(sample[2]) + return torch.stack(imgs, 0), targets, image_ids diff --git a/layers/__init__.py b/layers/__init__.py new file mode 100644 index 0000000..53a3f4b --- /dev/null +++ b/layers/__init__.py @@ -0,0 +1,2 @@ +from .functions import * +from .modules import * diff --git a/layers/box_utils.py b/layers/box_utils.py new file mode 100644 index 0000000..be1e922 --- /dev/null +++ b/layers/box_utils.py @@ -0,0 +1,244 @@ +""" Bounding box utilities + +Original author: Ellis Brown, Max deGroot for VOC dataset +https://github.com/amdegroot/ssd.pytorch + +""" + +import torch + +def point_form(boxes): + """ Convert prior_boxes to (xmin, ymin, xmax, ymax) + representation for comparison to point form ground truth data. + Args: + boxes: (tensor) center-size default boxes from priorbox layers. + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat((boxes[:, :2] - boxes[:, 2:]/2, # xmin, ymin + boxes[:, :2] + boxes[:, 2:]/2), 1) # xmax, ymax + + +def center_size(boxes): + """ Convert prior_boxes to (cx, cy, w, h) + representation for comparison to center-size form ground truth data. + Args: + boxes: (tensor) point_form boxes + Return: + boxes: (tensor) Converted xmin, ymin, xmax, ymax form of boxes. + """ + return torch.cat((boxes[:, 2:] + boxes[:, :2])/2, # cx, cy + boxes[:, 2:] - boxes[:, :2], 1) # w, h + + +def intersect(box_a, box_b): + """ We resize both tensors to [A,B,2] without new malloc: + [A,2] -> [A,1,2] -> [A,B,2] + [B,2] -> [1,B,2] -> [A,B,2] + Then we compute the area of intersect between box_a and box_b. + Args: + box_a: (tensor) bounding boxes, Shape: [A,4]. + box_b: (tensor) bounding boxes, Shape: [B,4]. + Return: + (tensor) intersection area, Shape: [A,B]. + """ + A = box_a.size(0) + B = box_b.size(0) + max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), + box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) + min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), + box_b[:, :2].unsqueeze(0).expand(A, B, 2)) + inter = torch.clamp((max_xy - min_xy), min=0) + return inter[:, :, 0] * inter[:, :, 1] + + +def jaccard(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap + is simply the intersection over union of two boxes. Here we operate on + ground truth boxes and default boxes. 
+ E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4] + box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4] + Return: + jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)] + """ + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2]-box_a[:, 0]) * + (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B] + area_b = ((box_b[:, 2]-box_b[:, 0]) * + (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] + + +def match(threshold, truths, priors, variances, labels, loc_t, conf_t, idx): + """Match each prior box with the ground truth box of the highest jaccard + overlap, encode the bounding boxes, then return the matched indices + corresponding to both confidence and location preds. + Args: + threshold: (float) The overlap threshold used when mathing boxes. + truths: (tensor) Ground truth boxes, Shape: [num_obj, num_priors]. + priors: (tensor) Prior boxes from priorbox layers, Shape: [n_priors,4]. + variances: (tensor) Variances corresponding to each prior coord, + Shape: [num_priors, 4]. + labels: (tensor) All the class labels for the image, Shape: [num_obj]. + loc_t: (tensor) Tensor to be filled w/ endcoded location targets. + conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. + idx: (int) current batch index + Return: + The matched indices corresponding to 1)location and 2)confidence preds. + """ + # jaccard index + overlaps = jaccard( + truths, + point_form(priors) + ) + # (Bipartite Matching) + # [1,num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + # [1,num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx, 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): + best_truth_idx[best_prior_idx[j]] = j + matches = truths[best_truth_idx] # Shape: [num_priors,4] + conf = labels[best_truth_idx] + 1 # Shape: [num_priors] + conf[best_truth_overlap < threshold] = 0 # label as background + loc = encode(matches, priors, variances) + loc_t[idx] = loc # [num_priors,4] encoded offsets to learn + conf_t[idx] = conf # [num_priors] top class label for each prior + + +def encode(matched, priors, variances): + """Encode the variances from the priorbox layers into the ground truth boxes + we have matched (based on jaccard overlap) with the prior boxes. + Args: + matched: (tensor) Coords of ground truth for each prior in point-form + Shape: [num_priors, 4]. + priors: (tensor) Prior boxes in center-offset form + Shape: [num_priors,4]. 
+ variances: (list[float]) Variances of priorboxes + Return: + encoded boxes (tensor), Shape: [num_priors, 4] + """ + + # dist b/t match center and prior's center + g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2] + # encode variance + g_cxcy /= (variances[0] * priors[:, 2:]) + # match wh / prior wh + g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:] + g_wh = torch.log(g_wh) / variances[1] + # return target for smooth_l1_loss + return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4] + + +# Adapted from https://github.com/Hakuyume/chainer-ssd +def decode(loc, priors, variances): + """Decode locations from predictions using priors to undo + the encoding we did for offset regression at train time. + Args: + loc (tensor): location predictions for loc layers, + Shape: [num_priors,4] + priors (tensor): Prior boxes in center-offset form. + Shape: [num_priors,4]. + variances: (list[float]) Variances of priorboxes + Return: + decoded bounding box predictions + """ + + boxes = torch.cat(( + priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:], + priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1) + boxes[:, :2] -= boxes[:, 2:] / 2 + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def log_sum_exp(x): + """Utility function for computing log_sum_exp while determining + This will be used to determine unaveraged confidence loss across + all examples in a batch. + Args: + x (Variable(tensor)): conf_preds from conf layers + """ + x_max = x.data.max() + return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max + + +# Original author: Francisco Massa: +# https://github.com/fmassa/object-detection.torch +# Ported to PyTorch by Max deGroot (02/01/2017) +def nms(boxes, scores, overlap=0.5, top_k=200): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + overlap: (float) The overlap thresh for suppressing unnecessary boxes. + top_k: (int) The Maximum number of box preds to consider. + Return: + The indices of the kept boxes with respect to num_priors. + """ + + keep = scores.new(scores.size(0)).zero_().long() + if boxes.numel() == 0: + return keep + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + area = torch.mul(x2 - x1, y2 - y1) + v, idx = scores.sort(0) # sort in ascending order + # I = I[v >= 0.01] + idx = idx[-top_k:] # indices of the top-k largest vals + xx1 = boxes.new() + yy1 = boxes.new() + xx2 = boxes.new() + yy2 = boxes.new() + w = boxes.new() + h = boxes.new() + + # keep = torch.Tensor() + count = 0 + while idx.numel() > 0: + i = idx[-1] # index of current largest val + # keep.append(i) + keep[count] = i + count += 1 + if idx.size(0) == 1: + break + idx = idx[:-1] # remove kept element from view + # load bboxes of next highest vals + torch.index_select(x1, 0, idx, out=xx1) + torch.index_select(y1, 0, idx, out=yy1) + torch.index_select(x2, 0, idx, out=xx2) + torch.index_select(y2, 0, idx, out=yy2) + # store element-wise max with next highest score + xx1 = torch.clamp(xx1, min=x1[i]) + yy1 = torch.clamp(yy1, min=y1[i]) + xx2 = torch.clamp(xx2, max=x2[i]) + yy2 = torch.clamp(yy2, max=y2[i]) + w.resize_as_(xx2) + h.resize_as_(yy2) + w = xx2 - xx1 + h = yy2 - yy1 + # check sizes of xx1 and xx2.. 
after each iteration + w = torch.clamp(w, min=0.0) + h = torch.clamp(h, min=0.0) + inter = w*h + # IoU = i / (area(a) + area(b) - i) + rem_areas = torch.index_select(area, 0, idx) # load remaining areas) + union = (rem_areas - inter) + area[i] + IoU = inter/union # store result in iou + # keep only elements with an IoU <= overlap + idx = idx[IoU.le(overlap)] + return keep, count diff --git a/layers/functions/__init__.py b/layers/functions/__init__.py new file mode 100644 index 0000000..a178cb4 --- /dev/null +++ b/layers/functions/__init__.py @@ -0,0 +1,5 @@ + +from .prior_box import PriorBox + + +__all__ = ['PriorBox'] diff --git a/layers/functions/prior_box.py b/layers/functions/prior_box.py new file mode 100644 index 0000000..932c060 --- /dev/null +++ b/layers/functions/prior_box.py @@ -0,0 +1,96 @@ +""" Generates prior boxes for SSD netowrk + +Original author: Ellis Brown, Max deGroot for VOC dataset +https://github.com/amdegroot/ssd.pytorch + +""" + +import torch +from math import sqrt as sqrt +from itertools import product as product + +class PriorBox(object): + """Compute priorbox coordinates in center-offset form for each source + feature map. + Note: + This 'layer' has changed between versions of the original SSD + paper, so we include both versions, but note v2 is the most tested and most + recent version of the paper. + + """ + def __init__(self, cfg): + super(PriorBox, self).__init__() + # self.type = cfg.name + self.image_size = cfg['min_dim'] + # number of priors for feature map location (either 4 or 6) + self.num_priors = len(cfg['aspect_ratios']) + self.variance = cfg['variance'] or [0.1] + self.feature_maps = cfg['feature_maps'] + self.min_sizes = cfg['min_sizes'] + self.max_sizes = cfg['max_sizes'] + self.steps = cfg['steps'] + self.aspect_ratios = cfg['aspect_ratios'] + self.clip = cfg['clip'] + self.version = cfg['name'] + for v in self.variance: + if v <= 0: + raise ValueError('Variances must be greater than 0') + + def forward(self): + mean = [] + # TODO merge these + if self.version == 'v2': + for k, f in enumerate(self.feature_maps): + for i, j in product(range(f), repeat=2): + f_k = self.image_size / self.steps[k] + # unit center x,y + cx = (j + 0.5) / f_k + cy = (i + 0.5) / f_k + + # aspect_ratio: 1 + # rel size: min_size + s_k = self.min_sizes[k]/self.image_size + mean += [cx, cy, s_k, s_k] + + # aspect_ratio: 1 + # rel size: sqrt(s_k * s_(k+1)) + s_k_prime = sqrt(s_k * (self.max_sizes[k]/self.image_size)) + mean += [cx, cy, s_k_prime, s_k_prime] + + # rest of aspect ratios + for ar in self.aspect_ratios[k]: + mean += [cx, cy, s_k*sqrt(ar), s_k/sqrt(ar)] + mean += [cx, cy, s_k/sqrt(ar), s_k*sqrt(ar)] + + else: + # original version generation of prior (default) boxes + for i, k in enumerate(self.feature_maps): + step_x = step_y = self.image_size/k + for h, w in product(range(k), repeat=2): + c_x = ((w+0.5) * step_x) + c_y = ((h+0.5) * step_y) + c_w = c_h = self.min_sizes[i] / 2 + s_k = self.image_size # 300 + # aspect_ratio: 1, + # size: min_size + mean += [(c_x-c_w)/s_k, (c_y-c_h)/s_k, + (c_x+c_w)/s_k, (c_y+c_h)/s_k] + if self.max_sizes[i] > 0: + # aspect_ratio: 1 + # size: sqrt(min_size * max_size)/2 + c_w = c_h = sqrt(self.min_sizes[i] * + self.max_sizes[i])/2 + mean += [(c_x-c_w)/s_k, (c_y-c_h)/s_k, + (c_x+c_w)/s_k, (c_y+c_h)/s_k] + # rest of prior boxes + for ar in self.aspect_ratios[i]: + if not (abs(ar-1) < 1e-6): + c_w = self.min_sizes[i] * sqrt(ar)/2 + c_h = self.min_sizes[i] / sqrt(ar)/2 + mean += [(c_x-c_w)/s_k, (c_y-c_h)/s_k, + (c_x+c_w)/s_k, 
(c_y+c_h)/s_k] + # back to torch land + output = torch.Tensor(mean).view(-1, 4) + if self.clip: + output.clamp_(max=1, min=0) + return output diff --git a/layers/modules/__init__.py b/layers/modules/__init__.py new file mode 100644 index 0000000..4218da4 --- /dev/null +++ b/layers/modules/__init__.py @@ -0,0 +1,4 @@ +from .l2norm import L2Norm +from .multibox_loss import MultiBoxLoss + +__all__ = ['L2Norm', 'MultiBoxLoss'] diff --git a/layers/modules/l2norm.py b/layers/modules/l2norm.py new file mode 100644 index 0000000..f344064 --- /dev/null +++ b/layers/modules/l2norm.py @@ -0,0 +1,23 @@ +import torch +import torch.nn as nn +from torch.autograd import Function +from torch.autograd import Variable +import torch.nn.init as init + +class L2Norm(nn.Module): + def __init__(self,n_channels, scale): + super(L2Norm,self).__init__() + self.n_channels = n_channels + self.gamma = scale or None + self.eps = 1e-10 + self.weight = nn.Parameter(torch.Tensor(self.n_channels)) + self.reset_parameters() + + def reset_parameters(self): + init.constant(self.weight,self.gamma) + + def forward(self, x): + norm = x.pow(2).sum(dim=1, keepdim=True).sqrt()+self.eps + x /= norm + out = self.weight.unsqueeze(0).unsqueeze(2).unsqueeze(3).expand_as(x) * x + return out diff --git a/layers/modules/multibox_loss.py b/layers/modules/multibox_loss.py new file mode 100644 index 0000000..023b547 --- /dev/null +++ b/layers/modules/multibox_loss.py @@ -0,0 +1,116 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from data import v2 as cfg +from ..box_utils import match, log_sum_exp + +class MultiBoxLoss(nn.Module): + """SSD Weighted Loss Function + Compute Targets: + 1) Produce Confidence Target Indices by matching ground truth boxes + with (default) 'priorboxes' that have jaccard index > threshold parameter + (default threshold: 0.5). + 2) Produce localization target by 'encoding' variance into offsets of ground + truth boxes and their matched 'priorboxes'. + 3) Hard negative mining to filter the excessive number of negative examples + that comes with using a large number of default bounding boxes. + (default negative:positive ratio 3:1) + Objective Loss: + L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N + Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss + weighted by α which is set to 1 by cross val. + Args: + c: class confidences, + l: predicted boxes, + g: ground truth boxes + N: number of matched default boxes + See: https://arxiv.org/pdf/1512.02325.pdf for more details. + """ + + def __init__(self, num_classes, overlap_thresh, prior_for_matching, + bkg_label, neg_mining, neg_pos, neg_overlap, encode_target, + use_gpu=True): + super(MultiBoxLoss, self).__init__() + self.use_gpu = use_gpu + self.num_classes = num_classes + self.threshold = overlap_thresh + self.background_label = bkg_label + self.encode_target = encode_target + self.use_prior_for_matching = prior_for_matching + self.do_neg_mining = neg_mining + self.negpos_ratio = neg_pos + self.neg_overlap = neg_overlap + self.variance = cfg['variance'] + + def forward(self, predictions, targets): + """Multibox Loss + Args: + predictions (tuple): A tuple containing loc preds, conf preds, + and prior boxes from SSD net. + conf shape: torch.size(batch_size,num_priors,num_classes) + loc shape: torch.size(batch_size,num_priors,4) + priors shape: torch.size(num_priors,4) + + ground_truth (tensor): Ground truth boxes and labels for a batch, + shape: [batch_size,num_objs,5] (last idx is the label). 
+ """ + loc_data, conf_data, priors = predictions + num = loc_data.size(0) + priors = priors[:loc_data.size(1), :] + num_priors = (priors.size(0)) + num_classes = self.num_classes + + # match priors (default boxes) and ground truth boxes + loc_t = torch.Tensor(num, num_priors, 4) + conf_t = torch.LongTensor(num, num_priors) + for idx in range(num): + truths = targets[idx][:, :-1].data + labels = targets[idx][:, -1].data + defaults = priors.data + match(self.threshold, truths, defaults, self.variance, labels, + loc_t, conf_t, idx) + if self.use_gpu: + loc_t = loc_t.cuda() + conf_t = conf_t.cuda() + # wrap targets + loc_t = Variable(loc_t, requires_grad=False) + conf_t = Variable(conf_t, requires_grad=False) + + pos = conf_t > 0 + num_pos = pos.sum(keepdim=True) + + # Localization Loss (Smooth L1) + # Shape: [batch,num_priors,4] + pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) + loc_p = loc_data[pos_idx].view(-1, 4) + loc_t = loc_t[pos_idx].view(-1, 4) + loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False) + + # Compute max conf across batch for hard negative mining + batch_conf = conf_data.view(-1, self.num_classes) + + loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1)) + + # Hard Negative Mining + loss_c[pos] = 0 # filter out pos boxes for now + loss_c = loss_c.view(num, -1) + _, loss_idx = loss_c.sort(1, descending=True) + _, idx_rank = loss_idx.sort(1) + num_pos = pos.long().sum(1, keepdim=True) + num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) + neg = idx_rank < num_neg.expand_as(idx_rank) + + # Confidence Loss Including Positive and Negative Examples + pos_idx = pos.unsqueeze(2).expand_as(conf_data) + neg_idx = neg.unsqueeze(2).expand_as(conf_data) + conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1, self.num_classes) + targets_weighted = conf_t[(pos+neg).gt(0)] + loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False) + + # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N + + N = num_pos.data.sum() + loss_l /= N + loss_c /= N + return loss_l, loss_c diff --git a/online-tubes/.gitignore b/online-tubes/.gitignore new file mode 100644 index 0000000..cad12ce --- /dev/null +++ b/online-tubes/.gitignore @@ -0,0 +1,10 @@ +*.ods# +*.m~ +*.prototxt~ +*.txt~ +*.xml~ +*.log +*.txt +*.txt~ +*~ +/results diff --git a/online-tubes/I01onlineTubes.m b/online-tubes/I01onlineTubes.m new file mode 100644 index 0000000..e3847f4 --- /dev/null +++ b/online-tubes/I01onlineTubes.m @@ -0,0 +1,151 @@ +% --------------------------------------------------------- +% Copyright (c) 2017, Gurkirt Singh +% This code and is available +% under the terms of MIT License provided in LICENSE. +% Please retain this notice and LICENSE if you use +% this file (or any portion of it) in your project. +% --------------------------------------------------------- +%% This is main script to build tubes and evaluate them %% + +function I01onlineTubes() + +data_root = '/mnt/mars-fast/datasets'; +save_root = '/mnt/mars-gamma/ssd-work'; +iteration_num_rgb = [120000]; % you can also evaluate on multiple iertations +iteration_num_flow = [120000]; % you can also evaluate on multiple iertations + +% add subfolder to matlab paths +addpath(genpath('gentube/')); +addpath(genpath('actionpath/')); +addpath(genpath('eval/')); +addpath(genpath('utils/')); +model_type = 'CONV'; + +completeList = {... + {'ucf24','01',{'rgb'},iteration_num_rgb,{'score'}},... + {'ucf24','01',{'brox'},iteration_num_flow,{'score'}}... 
+ {'ucf24','01',{'fastOF'},iteration_num_flow,{'score'}}... + }; + + +alldopts = cell(2,1); +count = 1; +gap=3; + +for setind = 1:length(completeList) + [dataset, listid, imtypes, iteration_nums, costTypes] = enumurateList(completeList{setind}); + for ct = 1:length(costTypes) + costtype = costTypes{ct}; + for imtind = 1:length(imtypes) + imgType = imtypes{imtind}; + for iteration = iteration_nums + for iouthresh=0.1 + %% generate directory sturcture based on the options + dopts = initDatasetOpts(data_root,save_root,dataset,imgType,model_type,listid,iteration,iouthresh,costtype, gap); + if exist(dopts.detDir,'dir') + alldopts{count} = dopts; + count = count+1; + end + end + end + end + end +end + +results = cell(2,1); + +%% For each option type build tubes and evaluate them +for index = 1:count-1 + opts = alldopts{index}; + if exist(opts.detDir,'dir') + fprintf('Video List %02d :: %s\nAnnotFile :: %s\nImage Dir :: %s\nDetection Dir:: %s\nActionpath Dir:: %s\nTube Dir:: %s\n',... + index, opts.vidList, opts.annotFile, opts.imgDir, opts.detDir, opts.actPathDir, opts.tubeDir); + %% Build action paths given frame level detections + actionPaths(opts); + %% Perform temproal labelling and evaluate; results saved in results cell + result_cell = gettubes(opts); + results{index,1} = result_cell; + results{index,2} = opts; + rm = result_cell{1}; + fprintf('\nmAP@0.2:%0.4f mAP@0.5:%0.4f mAP@0.75:%0.4f AVGmAP:%0.4f clsAcc:%0.4f\n',... + rm(1,5),rm(2,5),rm(7,5),mean(rm(2:end,5)),rm(1,6)); + end +end + +%% save results +save_dir = [save_root,'/results/']; +if ~isdir(save_dir) + mkdir(save_dir) +end + +save([save_dir,'online_tubes_results.mat'],'results') + +%% Function to enumrate options +function [dataset,listnum,imtypes,weights,costTypes] = enumurateList(sublist) +dataset = sublist{1}; listnum = sublist{2}; imtypes = sublist{3}; +weights = sublist{4};costTypes = sublist{5}; + +%% Facade function for smoothing tubes and evaluating them +function results = gettubes(dopts) + +numActions = length(dopts.actions); +results = zeros(300,6); +counter=1; +class_aps = cell(2,1); +% save file name to save result for eah option type +saveName = sprintf('%stubes-results.mat',dopts.tubeDir); +if ~exist(saveName,'file') + + annot = load(dopts.annotFile); + annot = annot.annot; + testvideos = getVideoNames(dopts.vidList); + for alpha = 3 + fprintf('alpha %03d ',alpha); + tubesSaveName = sprintf('%stubes-alpha%04d.mat',dopts.tubeDir,uint16(alpha*100)); + if ~exist(tubesSaveName,'file') + % read action paths + actionpaths = readALLactionPaths(dopts.vidList,dopts.actPathDir,1); + %% perform temporal trimming + smoothedtubes = parActionPathSmoother(actionpaths,alpha*ones(numActions,1),numActions); + save(tubesSaveName,'smoothedtubes','-v7.3'); + else + load(tubesSaveName) + end + + min_num_frames = 8; kthresh = 0.0; topk = 40; + xmldata = convert2eval(smoothedtubes, min_num_frames, kthresh*ones(numActions,1), topk,testvideos); + + %% Do the evaluation + for iou_th =[0.2,[0.5:0.05:0.95]] + [tmAP,tmIoU,tacc,AP] = get_PR_curve(annot, xmldata, testvideos, dopts.actions, iou_th); + % pritn outs iou_threshold, meanAp, sm, classifcation accuracy + fprintf('%.2f %0.3f %0.3f N ',iou_th,tmAP, tacc); + results(counter,:) = [iou_th,alpha,alpha,tmIoU,tmAP,tacc]; + class_aps{counter} = AP; + counter = counter+1; + end + fprintf('\n'); + end + + results(counter:end,:) = []; + result = cell(2,1); + result{2} = class_aps; + result{1} = results; + results = result; + fprintf('results saved in %s\n',saveName); + save(saveName,'results'); 
+else + load(saveName) +end + +function videos = getVideoNames(split_file) +% ------------------------------------------------------------------------- +fid = fopen(split_file,'r'); +data = textscan(fid, '%s'); +videos = cell(1); +count = 0; +for i=1:length(data{1}) + filename = cell2mat(data{1}(i,1)); + count = count +1; + videos{count} = filename; +end diff --git a/online-tubes/I02genFusedTubes.m b/online-tubes/I02genFusedTubes.m new file mode 100644 index 0000000..af71407 --- /dev/null +++ b/online-tubes/I02genFusedTubes.m @@ -0,0 +1,155 @@ + +function I02genFusedTubes() + +data_root = '/mnt/mars-fast/datasets'; +save_root = '/mnt/mars-gamma/ssd-work'; +iteration_num_rgb = 90000; % you can also evaluate on multiple iertations +iteration_num_flow = 120000; % you can also evaluate on multiple iertations + +addpath(genpath('actionpath/')); +addpath(genpath('gentube/')); +addpath(genpath('eval/')); +addpath(genpath('utils/')); + +completeList = {... + {'ucf24','01',{'rgb','brox'},[90000,120000],{'cat','nwsum-plus','mean'}, 0.25},... + {'ucf24','01',{'rgb','brox'},[120000,120000],{'cat','nwsum-plus','mean'}, 0.25},... + {'ucf24','01',{'rgb','fastOF'},[90000,120000],{'cat','nwsum-plus','mean'}, 0.25},... + {'ucf24','01',{'rgb','fastOF'},[120000,120000],{'cat','nwsum-plus','mean'}, 0.25},... + }; +model_type = 'CONV'; +costtype = 'score'; +iouthresh = 0.1; +gap = 3; +alldopts = cell(2,1); +count = 1; +for setind = [2,4] %1:length(completeList) + [dataset,listid,imtypes,iteration_nums,fusiontypes,fuseiouths] = enumurateList(completeList{setind}); + for ff =1:length(fusiontypes) + fusiontype = fusiontypes{ff}; + if strcmp(fusiontype,'cat') || strcmp(fusiontype,'mean') + tempfuseiouths = 0; + else + tempfuseiouths = fuseiouths; + end + for fuseiouth = tempfuseiouths + for iouWeight = 1 + dopts = initDatasetOptsFused(data_root,save_root,dataset,imtypes,model_type, ... + listid,iteration_nums,iouthresh,costtype,gap,fusiontype,fuseiouth); + if exist(dopts.basedetDir,'dir') && exist(dopts.topdetDir,'dir') + alldopts{count} = dopts; + count = count+1; + end + end + end + end +end + +fprintf('\n\n\n\n %d \n\n\n\n',count) + +% sets = {1:12,13:24,25:36,49:64}; +% parpool('local',16); %length(set)); + +results = cell(2,1); +for setid = 1 + for index = 1:count-1 + opts = alldopts{index}; + if exist(opts.basedetDir,'dir') && exist(opts.topdetDir,'dir') + fprintf('Video List :: %s\n \nDetection basedetDir:: %s\nActionpath Dir:: %s\nTube Dir:: %s\n',... + opts.vidList,opts.basedetDir,opts.actPathDir,opts.tubeDir); + + %% Build action paths given frame level detections + fusedActionPaths(opts); + %% Perform temproal labelling and evaluate; results saved in results cell + result_cell = gettubes(opts); + results{index,1} = result_cell; + results{index,2} = opts; + rm = result_cell{1}; + fprintf('\nmAP@0.2:%0.4f mAP@0.5:%0.4f mAP@0.75:%0.4f AVGmAP:%0.4f clsAcc:%0.4f\n',... 
+ rm(1,5),rm(2,5),rm(7,5),mean(rm(2:end,5)),rm(1,6)); + end + end +end + +%% save results +save_dir = [save_root,'/results/']; +if ~isdir(save_dir) + mkdir(save_dir) +end + +save([save_dir,'online_fused_tubes_results.mat'],'results') + + +function [dataset,listnum,imtypes,weights,fusiontypes,fuseiouths] = enumurateList(sublist) + +dataset = sublist{1}; listnum = sublist{2}; imtypes = sublist{3}; +weights = sublist{4}; +fusiontypes = sublist{5}; +fuseiouths = sublist{6}; + +%% Facade function for smoothing tubes and evaluating them +function results = gettubes(dopts) + +numActions = length(dopts.actions); +results = zeros(300,6); +counter=1; +class_aps = cell(2,1); +% save file name to save result for eah option type +saveName = sprintf('%stubes-results.mat',dopts.tubeDir); +if ~exist(saveName,'file') + + annot = load(dopts.annotFile); + annot = annot.annot; + testvideos = getVideoNames(dopts.vidList); + for alpha = 3 + fprintf('alpha %03d ',alpha); + tubesSaveName = sprintf('%stubes-alpha%04d.mat',dopts.tubeDir,uint16(alpha*100)); + if ~exist(tubesSaveName,'file') + % read action paths + actionpaths = readALLactionPaths(dopts.vidList,dopts.actPathDir,1); + %% perform temporal trimming + smoothedtubes = parActionPathSmoother(actionpaths,alpha*ones(numActions,1),numActions); + save(tubesSaveName,'smoothedtubes','-v7.3'); + else + load(tubesSaveName) + end + + min_num_frames = 8; kthresh = 0.0; topk = 40; + % strip off uncessary parts and remove very small actions less than + % 8 frames; not really necessary but for speed at eval time + xmldata = convert2eval(smoothedtubes, min_num_frames, kthresh*ones(numActions,1), topk,testvideos); + + %% Do the evaluation + for iou_th =[0.2,[0.5:0.05:0.95]] + [tmAP,tmIoU,tacc,AP] = get_PR_curve(annot, xmldata, testvideos, dopts.actions, iou_th); + % pritn outs iou_threshold, meanAp, sm, classifcation accuracy + fprintf('%.2f %0.3f %0.3f N ',iou_th,tmAP, tacc); + results(counter,:) = [iou_th,alpha,alpha,tmIoU,tmAP,tacc]; + class_aps{counter} = AP; + counter = counter+1; + end + fprintf('\n'); + end + + results(counter:end,:) = []; + result = cell(2,1); + result{2} = class_aps; + result{1} = results; + results = result; + fprintf('results saved in %s\n',saveName); + save(saveName,'results'); +else + load(saveName) +end + +function videos = getVideoNames(split_file) +% ------------------------------------------------------------------------- +fid = fopen(split_file,'r'); +data = textscan(fid, '%s'); +videos = cell(1); +count = 0; +for i=1:length(data{1}) + filename = cell2mat(data{1}(i,1)); + count = count +1; + videos{count} = filename; +end diff --git a/online-tubes/actionpath/actionPaths.m b/online-tubes/actionpath/actionPaths.m new file mode 100644 index 0000000..c3820b0 --- /dev/null +++ b/online-tubes/actionpath/actionPaths.m @@ -0,0 +1,137 @@ +% --------------------------------------------------------- +function actionPaths(dopts) +% --------------------------------------------------------- +% Copyright (c) 2017, Gurkirt Singh +% This code and is available +% under the terms of MID License provided in LICENSE. +% Please retain this notice and LICENSE if you use +% this file (or any portion of it) in your project. 
+% --------------------------------------------------------- + +detresultpath = dopts.detDir; +costtype = dopts.costtype; +gap = dopts.gap; +videolist = dopts.vidList; +actions = dopts.actions; +saveName = dopts.actPathDir; +iouth = dopts.iouThresh; +numActions = length(actions); +nms_thresh = 0.45; +videos = getVideoNames(videolist); +NumVideos = length(videos); + +for vid=1:NumVideos + tic; + videoID = videos{vid}; + pathsSaveName = [saveName,videoID,'-actionpaths.mat']; + + videoDetDir = [detresultpath,videoID,'/']; + + if ~exist(pathsSaveName,'file') + fprintf('computing tubes for vide [%d out of %d] video ID = %s\n',vid,NumVideos, videoID); + + %% loop over all the frames of the video + fprintf('Reading detections '); + + frames = readDetections(videoDetDir); + + fprintf('\nDone reading detections\n'); + + fprintf('Gernrating action paths ...........\n'); + + %% parllel loop over all action class and genrate paths for each class + allpaths = cell(1); + parfor a=1:numActions + allpaths{a} = genActionPaths(frames, a, nms_thresh, iouth, costtype,gap); + end + + fprintf('results are being saved in::: %s for %d classes\n',pathsSaveName,length(allpaths)); + save(pathsSaveName,'allpaths'); + fprintf('All Done in %03d Seconds\n',round(toc)); + end + +end + +disp('done computing action paths'); + +end + +function paths = genActionPaths(frames,a,nms_thresh,iouth,costtype,gap) +action_frames = struct(); + +for f=1:length(frames) + [boxes,scores,allscores] = dofilter(frames,a,f,nms_thresh); + action_frames(f).boxes = boxes; + action_frames(f).scores = scores; + action_frames(f).allScores = allscores; +end + +paths = incremental_linking(action_frames,iouth,costtype, gap, gap); + +end + +%-- filter out least likkey detections for actions --- +function [boxes,scores,allscores] = dofilter(frames, a, f, nms_thresh) + scores = frames(f).scores(:,a); + pick = scores>0.001; + scores = scores(pick); + boxes = frames(f).boxes(pick,:); + allscores = frames(f).scores(pick,:); + [~,pick] = sort(scores,'descend'); + to_pick = min(50,size(pick,1)); + pick = pick(1:to_pick); + scores = scores(pick); + boxes = boxes(pick,:); + allscores = allscores(pick,:); + pick = nms([boxes scores], nms_thresh); + pick = pick(1:min(10,length(pick))); + boxes = boxes(pick,:); + scores = scores(pick); + allscores = allscores(pick,:); +end + +%-- list the files in directory and sort them ---------- +function list = sortdirlist(dirname) +list = dir(dirname); +list = sort({list.name}); +end + +% ------------------------------------------------------------------------- +function [videos] = getVideoNames(split_file) +% ------------------------------------------------------------------------- +fprintf('Get both lis is %s\n',split_file); +fid = fopen(split_file,'r'); +data = textscan(fid, '%s'); +videos = cell(1); +count = 0; + +for i=1:length(data{1}) + filename = cell2mat(data{1}(i,1)); + count = count +1; + videos{count} = filename; + % videos(i).vid = str2num(cell2mat(data{1}(i,1))); +end +end + + +function frames = readDetections(detectionDir) + +detectionList = sortdirlist([detectionDir,'*.mat']); +frames = struct([]); +numframes = length(detectionList); +scores = 0; +loc = 0; +for f = 1 : numframes + filename = [detectionDir,detectionList{f}]; + load(filename); % loads loc and scores variable + loc = [loc(:,1)*320, loc(:,2)*240, loc(:,3)*320, loc(:,4)*240]; + loc(loc(:,1)<0,1) = 0; + loc(loc(:,2)<0,2) = 0; + loc(loc(:,3)>319,3) = 319; + loc(loc(:,4)>239,4) = 239; + loc = loc + 1; + frames(f).boxes = loc; + frames(f).scores = 
[scores(:,2:end),scores(:,1)]; +end + +end diff --git a/online-tubes/actionpath/fusedActionPaths.m b/online-tubes/actionpath/fusedActionPaths.m new file mode 100644 index 0000000..95daac0 --- /dev/null +++ b/online-tubes/actionpath/fusedActionPaths.m @@ -0,0 +1,229 @@ +function fusedActionPaths(dopts) +% AUTORIGHTS +% --------------------------------------------------------- +% Copyright (c) 2016, Gurkirt Singh +% +% This code and is available +% under the terms of the Simplified BSD License provided in +% LICENSE. Please retain this notice and LICENSE if you use +% this file (or any portion of it) in your project. +% --------------------------------------------------------- + +detresultpathBase = dopts.basedetDir; +detresultpathTop = dopts.topdetDir; +videolist = dopts.vidList; +actions = dopts.actions; +saveName = dopts.actPathDir; +iouth = dopts.iouThresh; +numActions = length(actions); +costtype = dopts.costtype; +gap = dopts.gap; +nms_thresh = 0.45; +videos = getVideoNames(videolist); + +NumVideos = length(videos); +timimngs = zeros(NumVideos,1); + +for vid=1:NumVideos + tt = tic; + videoID = videos{vid}; + pathsSaveName = [saveName,videoID,'-actionpaths.mat']; + videoDetDirBase = [detresultpathBase,videoID,'/']; + videoTopDirBase = [detresultpathTop,videoID,'/']; + if ~exist(pathsSaveName,'file') + fprintf('computing tubes for vide [%d out of %d] video ID = %s\n',vid,NumVideos, videoID); + + fprintf('Reading detection files searlially '); + frames = readDetections(videoDetDirBase,videoTopDirBase); + fprintf('\nDone reading detection files \n'); + fprintf('Gernrating action paths ...........\n'); + + %% parllel loop over all action class and genrate paths for each class + thpath = tic; + allpaths = cell(1); + for a=1:numActions + allpaths{a} = genActionPaths(frames,a,nms_thresh,dopts.fuseiouth,dopts.fusiontype,iouth,costtype,gap); + end + timimngs(vid) = toc(thpath); + %% + fprintf('Completed linking \n'); + fprintf('results are being saved in::: %s\n',pathsSaveName); + save(pathsSaveName,'allpaths'); + fprintf('All Done in %03d Seconds\n',round(toc(tt))); + end +end + +% save('ucf101timing.mat','numfs','timimngs') +disp('done computing action paths'); +end + +% --------------------------------------------------------- +% function to gather the detection box and nms them and pass it to linking script +function paths = genActionPaths(frames,a,nms_thresh,fuseiouth,fusiontype,iouth,costtype,gap) +% --------------------------------------------------------- +action_frames = struct(); +for f=1:length(frames) + + baseBoxes = frames(f).baseBoxes; + baseAllScores = frames(f).baseScores; + topBoxes = frames(f).topBoxes; + topAllScores = frames(f).topScores; + meanScores = frames(f).meanScores; + [boxes, allscores] = fuseboxes(baseBoxes,topBoxes,baseAllScores,topAllScores,meanScores,fuseiouth,fusiontype,a,nms_thresh); + + action_frames(f).allScores = allscores; + action_frames(f).boxes = boxes(:,1:4); + action_frames(f).scores = boxes(:,5); +end + +paths = incremental_linking(action_frames,iouth,costtype,gap, gap); +end + +% --------------------------------------------------------- +function [boxes,allscores] = fuseboxes(baseBoxes,topBoxes,baseAllScores,topAllScores,meanScores,fuseiouth,fusiontype,a,nms_thresh) +% --------------------------------------------------------- + +if strcmp(fusiontype,'mean') + [boxes,allscores] = dofilter(baseBoxes,meanScores,a,nms_thresh); +elseif strcmp(fusiontype,'nwsum-plus') + [baseBoxes,baseAllScores] = dofilter(baseBoxes,baseAllScores,a,nms_thresh); + 
[topBoxes,topAllScores] = dofilter(topBoxes,topAllScores,a,nms_thresh); + [boxes,allscores] = boost_fusion(baseBoxes,topBoxes,baseAllScores,topAllScores,fuseiouth,a); + pick = nms(boxes,nms_thresh); + boxes = boxes(pick(1:min(10,length(pick))),:); + allscores = allscores(pick(1:min(10,length(pick))),:); + +else %% fusion type is cat // union-set fusion + [baseBoxes,baseAllScores] = dofilter(baseBoxes,baseAllScores,a,nms_thresh); + [topBoxes,topAllScores] = dofilter(topBoxes,topAllScores,a,nms_thresh); + boxes = [baseBoxes;topBoxes]; + allscores = [baseAllScores;topAllScores]; + pick = nms(boxes,nms_thresh); + boxes = boxes(pick(1:min(10,length(pick))),:); + allscores = allscores(pick(1:min(10,length(pick))),:); +end + +end + + +function [boxes,allscores] = dofilter(boxes, allscores,a,nms_thresh) + scores = allscores(:,a); + pick = scores>0.001; + scores = scores(pick); + boxes = boxes(pick,:); + allscores = allscores(pick,:); + [~,pick] = sort(scores,'descend'); + to_pick = min(50,size(pick,1)); + pick = pick(1:to_pick); + scores = scores(pick); + boxes = boxes(pick,:); + allscores = allscores(pick,:); + pick = nms([boxes scores], nms_thresh); + pick = pick(1:min(10,length(pick))); + boxes = [boxes(pick,:),scores(pick,:)]; + allscores = allscores(pick,:); +end + +% --------------------------------------------------------- +function [sb,ss] = boost_fusion(sb, fb,ss,fs,fuseiouth,a) % bs - boxes_spatial bf-boxes_flow +% --------------------------------------------------------- + +nb = size(sb,1); % num boxes +box_spatial = [sb(:,1:2) sb(:,3:4)-sb(:,1:2)+1]; +box_flow = [fb(:,1:2) fb(:,3:4)-fb(:,1:2)+1]; +coveredboxes = []; + +for i=1:nb + ovlp = inters_union(box_spatial(i,:), box_flow); % ovlp has 1x5 or 5x1 dim + if ~isempty(ovlp) + [movlp, maxind] = max(ovlp); + + if movlp>=fuseiouth && isempty(ismember(coveredboxes,maxind)) + ms = ss(i,:) + fs(maxind,:)*movlp; + ms = ms/sum(ms); + sb(i,5) = ms(a); + ss(i,:) = ms; + coveredboxes = [coveredboxes;maxind]; + end + end +end + +nb = size(fb,1); + +for i=1:nb + if ~ismember(coveredboxes,i) + sb = [sb;fb(i,:)]; + ss = [ss;fs(i,:)]; + end +end +end + + +function iou = inters_union(bounds1,bounds2) +% ------------------------------------------------------------------------ +inters = rectint(bounds1,bounds2); +ar1 = bounds1(:,3).*bounds1(:,4); +ar2 = bounds2(:,3).*bounds2(:,4); +union = bsxfun(@plus,ar1,ar2')-inters; +iou = inters./(union+0.001); +end + +% ------------------------------------------------------------------------- +function list = sortdirlist(dirname) +list = dir(dirname); +list = sort({list.name}); +end + +% ------------------------------------------------------------------------- +function [videos] = getVideoNames(split_file) +% ------------------------------------------------------------------------- +fprintf('Get both lis %s\n',split_file); +fid = fopen(split_file,'r'); +data = textscan(fid, '%s'); +videos = cell(1); +count = 0; + +for i=1:length(data{1}) + filename = cell2mat(data{1}(i,1)); + count = count +1; + videos{count} = filename; + % videos(i).vid = str2num(cell2mat(data{1}(i,1))); +end + +end + +function frames = readDetections(detectionDir,top_detectionDir ) + +detectionList = sortdirlist([detectionDir,'*.mat']); +frames = struct([]); +numframes = length(detectionList); +scores = 0; +loc = 0; +for f = 1 : numframes + filename = [detectionDir,detectionList{f}]; + load(filename); % load loc and scores variable + loc = [loc(:,1)*320, loc(:,2)*240, loc(:,3)*320, loc(:,4)*240]; + loc(loc(:,1)<0,1) = 0; + loc(loc(:,2)<0,2) = 
0; + loc(loc(:,3)>319,3) = 319; + loc(loc(:,4)>239,4) = 239; + loc = loc + 1; + frames(f).baseBoxes = loc; + frames(f).baseScores = [scores(:,2:end),scores(:,1)]; + + filename = [top_detectionDir,detectionList{f}]; + load(filename); % load loc and scores variable + loc = [loc(:,1)*320, loc(:,2)*240, loc(:,3)*320, loc(:,4)*240]; + loc(loc(:,1)<0,1) = 0; + loc(loc(:,2)<0,2) = 0; + loc(loc(:,3)>319,3) = 319; + loc(loc(:,4)>239,4) = 239; + loc = loc + 1; + frames(f).topBoxes = loc; + frames(f).topScores = [scores(:,2:end),scores(:,1)]; + frames(f).meanScores = (frames(f).topScores + frames(f).baseScores)/2.0; +end + +end + + diff --git a/online-tubes/actionpath/incremental_linking.m b/online-tubes/actionpath/incremental_linking.m new file mode 100644 index 0000000..c715e4e --- /dev/null +++ b/online-tubes/actionpath/incremental_linking.m @@ -0,0 +1,270 @@ +% ------------------------------------------------------------------------- +function live_paths = incremental_linking(frames,iouth,costtype,jumpgap,threhgap) +% ------------------------------------------------------------------------- + +num_frames = length(frames); + +%% online path building + +live_paths = struct(); %% Stores live paths +dead_paths = struct(); %% Store the paths that has been terminated +dp_count = 0; +for t = 1:num_frames + num_box = size(frames(t).boxes,1); + if t==1 + for b = 1 : num_box + live_paths(b).boxes = frames(t).boxes(b,:); + live_paths(b).scores = frames(t).scores(b); + live_paths(b).allScores(t,:) = frames(t).allScores(b,:); + live_paths(b).pathScore = frames(t).scores(b); + live_paths(b).foundAT(t) = 1; + live_paths(b).count = 1; + live_paths(b).lastfound = 0; %less than 5 mean yes + end + else + lp_count = getPathCount(live_paths); + + % fprintf(' %d ', t); + edge_scores = zeros(lp_count,num_box); + + for lp = 1 : lp_count + edge_scores(lp,:) = score_of_edge(live_paths(lp),frames(t),iouth,costtype); + end + + + dead_count = 0 ; + coverd_boxes = zeros(1,num_box); + path_order_score = zeros(1,lp_count); + for lp = 1 : lp_count + if live_paths(lp).lastfound < jumpgap %less than 5 mean yes + box_to_lp_score = edge_scores(lp,:); + if sum(box_to_lp_score)>0 %%checking if atleast there is one match + [m_score,maxInd] = max(box_to_lp_score); + live_paths(lp).count = live_paths(lp).count + 1; + lpc = live_paths(lp).count; + live_paths(lp).boxes(lpc,:) = frames(t).boxes(maxInd,:); + live_paths(lp).scores(lpc) = frames(t).scores(maxInd); + live_paths(lp).allScores(lpc,:) = frames(t).allScores(maxInd,:); + live_paths(lp).pathScore = live_paths(lp).pathScore + m_score; + live_paths(lp).foundAT(lpc) = t; + live_paths(lp).lastfound = 0; + edge_scores(:,maxInd) = 0; + coverd_boxes(maxInd) = 1; + else + live_paths(lp).lastfound = live_paths(lp_count).lastfound +1; + end + + scores = sort(live_paths(lp).scores,'ascend'); + num_sc = length(scores); + path_order_score(lp) = mean(scores(max(1,num_sc-jumpgap):num_sc)); + + else + dead_count = dead_count + 1; + end + end + + % Sort the path based on scoe of the boxes and terminate dead path + + [live_paths,dead_paths,dp_count] = sort_live_paths(live_paths,.... 
+ path_order_score,dead_paths,dp_count,jumpgap); + lp_count = getPathCount(live_paths); + % start new paths using boxes that are not assigned + if sum(coverd_boxes)0 + path_order_score = zeros(1,lp_count); + + for lp = 1 : length(live_paths) + scores = sort(live_paths(lp).scores,'descend'); + num_sc = length(scores); + path_order_score(lp) = mean(scores(1:min(20,num_sc))); + end + + [~,ind] = sort(path_order_score,'descend'); + for lpc = 1 : length(live_paths) + olp = ind(lpc); + sorted_live_paths(lpc).start = live_paths(olp).start; + sorted_live_paths(lpc).end = live_paths(olp).end; + sorted_live_paths(lpc).boxes = live_paths(olp).boxes; + sorted_live_paths(lpc).scores = live_paths(olp).scores; + sorted_live_paths(lpc).allScores = live_paths(olp).allScores; + sorted_live_paths(lpc).pathScore = live_paths(olp).pathScore; + sorted_live_paths(lpc).foundAT = live_paths(olp).foundAT; + sorted_live_paths(lpc).count = live_paths(olp).count; + sorted_live_paths(lpc).lastfound = live_paths(olp).lastfound; + end +end + +% ------------------------------------------------------------------------- +function gap_filled_paths = fill_gaps(paths,gap) +% ------------------------------------------------------------------------- +gap_filled_paths = struct(); +if isfield(paths,'boxes') + g_count = 1; + + for lp = 1 : getPathCount(paths) + if length(paths(lp).foundAT)>gap + gap_filled_paths(g_count).start = paths(lp).foundAT(1); + gap_filled_paths(g_count).end = paths(lp).foundAT(end); + gap_filled_paths(g_count).pathScore = paths(lp).pathScore; + gap_filled_paths(g_count).foundAT = paths(lp).foundAT; + gap_filled_paths(g_count).count = paths(lp).count; + gap_filled_paths(g_count).lastfound = paths(lp).lastfound; + count = 1; + i = 1; + while i <= length(paths(lp).scores) + diff_found = paths(lp).foundAT(i)-paths(lp).foundAT(max(i-1,1)); + if count == 1 || diff_found == 1 + gap_filled_paths(g_count).boxes(count,:) = paths(lp).boxes(i,:); + gap_filled_paths(g_count).scores(count) = paths(lp).scores(i); + gap_filled_paths(g_count).allScores(count,:) = paths(lp).allScores(i,:); + i = i + 1; + count = count + 1; + else + for d = 1 : diff_found + gap_filled_paths(g_count).boxes(count,:) = paths(lp).boxes(i,:); + gap_filled_paths(g_count).scores(count) = paths(lp).scores(i); + gap_filled_paths(g_count).allScores(count,:) = paths(lp).allScores(i,:); + count = count + 1; + end + i = i + 1; + end + end + g_count = g_count + 1; + end + end +end + + +% ------------------------------------------------------------------------- +function [sorted_live_paths,dead_paths,dp_count] = sort_live_paths(live_paths,... 
+ path_order_score,dead_paths,dp_count,gap) +% ------------------------------------------------------------------------- + +sorted_live_paths = struct(); +[~,ind] = sort(path_order_score,'descend'); +lpc = 0; +for lp = 1 : getPathCount(live_paths) + olp = ind(lp); + if live_paths(ind(lp)).lastfound < gap + lpc = lpc + 1; + sorted_live_paths(lpc).boxes = live_paths(olp).boxes; + sorted_live_paths(lpc).scores = live_paths(olp).scores; + sorted_live_paths(lpc).allScores = live_paths(olp).allScores; + sorted_live_paths(lpc).pathScore = live_paths(olp).pathScore; + sorted_live_paths(lpc).foundAT = live_paths(olp).foundAT; + sorted_live_paths(lpc).count = live_paths(olp).count; + sorted_live_paths(lpc).lastfound = live_paths(olp).lastfound; + else + dp_count = dp_count + 1; + dead_paths(dp_count).boxes = live_paths(olp).boxes; + dead_paths(dp_count).scores = live_paths(olp).scores; + dead_paths(dp_count).allScores = live_paths(olp).allScores; + dead_paths(dp_count).pathScore = live_paths(olp).pathScore; + dead_paths(dp_count).foundAT = live_paths(olp).foundAT; + dead_paths(dp_count).count = live_paths(olp).count; + dead_paths(dp_count).lastfound = live_paths(olp).lastfound; + + end +end + + + + +% ------------------------------------------------------------------------- +function score = score_of_edge(v1,v2,iouth,costtype) +% ------------------------------------------------------------------------- + +N2 = size(v2.boxes,1); +score = zeros(1,N2); + +% try +bounds1 = [v1.boxes(end,1:2) v1.boxes(end,3:4)-v1.boxes(end,1:2)+1]; +% catch +% fprintf('catch here') +% end +bounds2 = [v2.boxes(:,1:2) v2.boxes(:,3:4)-v2.boxes(:,1:2)+1]; +iou = inters_union(bounds1,bounds2); + +for i = 1 : N2 + + if iou(i)>=iouth + + scores2 = v2.scores(i); + scores1 = v1.scores(end); + score_similarity = sqrt(sum((v1.allScores(end,:) - v2.allScores(i,:)).^2)); + if strcmp(costtype, 'score') + score(i) = scores2; + elseif strcmp(costtype, 'scrSim') + score(i) = 1-score_similarity; + elseif strcmp(costtype, 'scrMinusSim') + score(i) = scores2 + (1 - score_similarity); + end + + end + +end + +% ------------------------------------------------------------------------- +function lp_count = getPathCount(live_paths) +% ------------------------------------------------------------------------- + +if isfield(live_paths,'boxes') + lp_count = length(live_paths); +else + lp_count = 0; +end + +% ------------------------------------------------------------------------- +function iou = inters_union(bounds1,bounds2) +% ------------------------------------------------------------------------- + +inters = rectint(bounds1,bounds2); +ar1 = bounds1(:,3).*bounds1(:,4); +ar2 = bounds2(:,3).*bounds2(:,4); +union = bsxfun(@plus,ar1,ar2')-inters; + +iou = inters./(union+eps); diff --git a/online-tubes/actionpath/nms.m b/online-tubes/actionpath/nms.m new file mode 100644 index 0000000..27c49a3 --- /dev/null +++ b/online-tubes/actionpath/nms.m @@ -0,0 +1,74 @@ +function pick = nms(boxes, overlap) +% Non-maximum suppression. +% pick = nms(boxes, overlap) +% +% Greedily select high-scoring detections and skip detections that are +% significantly covered by a previously selected detection. +% +% Return value +% pick Indices of locally maximal detections +% +% Arguments +% boxes Detection bounding boxes (see pascal_test.m) +% overlap Overlap threshold for suppression +% For a selected box Bi, all boxes Bj that are covered by +% more than overlap are suppressed. 
Note that 'covered' is +% is |Bi \cap Bj| / |Bj|, not the PASCAL intersection over +% union measure. + +% AUTORIGHTS +% ------------------------------------------------------- +% Copyright (C) 2011-2012 Ross Girshick +% Copyright (C) 2008, 2009, 2010 Pedro Felzenszwalb, Ross Girshick +% Copyright (C) 2007 Pedro Felzenszwalb, Deva Ramanan +% +% This file is part of the voc-releaseX code +% (http://people.cs.uchicago.edu/~rbg/latent/) +% and is available under the terms of an MIT-like license +% provided in COPYING. Please retain this notice and +% COPYING if you use this file (or a portion of it) in +% your project. +% ------------------------------------------------------- + +if isempty(boxes) + pick = []; +else + x1 = boxes(:,1); + y1 = boxes(:,2); + x2 = boxes(:,3); + y2 = boxes(:,4); + s = boxes(:,end); + area = (x2-x1) .* (y2-y1); + %area = (x2-x1+1) .* (y2-y1+1); + + [vals, I] = sort(s); + pick = []; + while ~isempty(I) + last = length(I); + i = I(last); + pick = [pick; i]; + suppress = [last]; + for pos = 1:last-1 + j = I(pos); + xx1 = max(x1(i), x1(j)); + yy1 = max(y1(i), y1(j)); + xx2 = min(x2(i), x2(j)); + yy2 = min(y2(i), y2(j)); + w = xx2-xx1; + h = yy2-yy1; + +% w = xx2-xx1+1; +% h = yy2-yy1+1; + + if w > 0 && h > 0 + % compute overlap + inter = w*h; + o = inter / (area(j) + area(i) - inter); + if o > overlap + suppress = [suppress; pos]; + end + end + end + I(suppress) = []; + end +end diff --git a/online-tubes/eval/compute_spatio_temporal_iou.m b/online-tubes/eval/compute_spatio_temporal_iou.m new file mode 100644 index 0000000..344553e --- /dev/null +++ b/online-tubes/eval/compute_spatio_temporal_iou.m @@ -0,0 +1,92 @@ + +% ###################################################################################################################################################################################### +% We are here talking about spatio-temporal detections, i.e. a set of ground-truth bounding boxes that +% I will denote by g_t, with t between t_g^b and t_g^e (beginning and end time of the ground-truth) +% versus a detection which is also a set of bounding boxes, denoted by d_t, with t between t_d^e et t_d^e. +% +% a) temporal iou = T_i / T_u +% this is the intersection over union between the timing of the the tubes, +% ie mathematically T_i / T_u with +% the intersection T_i = max(0, max(t_g^b,t_d^b)-min(t_d^e,t_g^e) ) +% and the union T_u = min(t_g^b,t_d^b)-max(t_d^e,t_g^e) +% +% b) for each t between max(tgb,tdb)-min(tde,tge), we compute the IoU between g_t and d_t, and average them +% +% Multiplying (a) and (b) is the same as computed the average of the spatial iou over all frames in T_u of the two tubes, with a spatial iou of 0 for frames where only one box exists. 
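A minimal Python sketch (illustrative only, not part of the patch) of the spatio-temporal overlap described in the comment above: the temporal IoU of the two tubes multiplied by the mean per-frame spatial IoU over the frames where both tubes exist. It assumes both tubes cover contiguous frame ranges, as the MATLAB function below does, and that boxes are given as [x, y, w, h].

```python
import numpy as np

def spatial_iou(b1, b2):
    # boxes are [x, y, w, h]
    x1, y1 = max(b1[0], b2[0]), max(b1[1], b2[1])
    x2 = min(b1[0] + b1[2], b2[0] + b2[2])
    y2 = min(b1[1] + b1[3], b2[1] + b2[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    union = b1[2] * b1[3] + b2[2] * b2[3] - inter
    return inter / (union + 1e-6)

def st_iou(gt_frames, gt_boxes, dt_frames, dt_boxes):
    """Frames are assumed contiguous; box rows are aligned with the frame lists."""
    t_beg = max(gt_frames[0], dt_frames[0])
    t_end = min(gt_frames[-1], dt_frames[-1])
    if t_end < t_beg:
        return 0.0  # no temporal overlap
    t_inter = t_end - t_beg + 1
    t_union = max(gt_frames[-1], dt_frames[-1]) - min(gt_frames[0], dt_frames[0]) + 1
    ious = [spatial_iou(gt_boxes[f - gt_frames[0]], dt_boxes[f - dt_frames[0]])
            for f in range(t_beg, t_end + 1)]
    return (t_inter / t_union) * float(np.mean(ious))
```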
+% c) as this is standard in detection problem, if there are multiple detections for the same groundtruth detection, the first one is counted as positive and the other ones as negatives +% ###################################################################################################################################################################################### +%{ +gt_fnr = 1xn doube +gt_bb = nx4 doubld - [x y w h] +dt_fnr = 1xm double +dt_bb = mx4 double - [x y w h] +%} +% ------------------------------------------------------------------------- +function st_iou = compute_spatio_temporal_iou(gt_fnr, gt_bb, dt_fnr, dt_bb) +% ------------------------------------------------------------------------- + +% time gt begin +tgb = gt_fnr(1); +% time gt end +tge = gt_fnr(end); +%time dt begin +tdb = dt_fnr(1); +tde = dt_fnr(end); +% temporal intersection +T_i = double(max(0, min(tge,tde)-max(tgb,tdb))); + +if T_i>0 + T_i = T_i +1; + % temporal union + T_u = double(max(tge,tde) - min(tgb,tdb)+1); + %temporal IoU + T_iou = T_i/T_u; + % intersect frame numbers + int_fnr = max(tgb,tdb):min(tge,tde); + + % find the ind of the intersected frames in the detected frames + [~,int_find_dt] = ismember(int_fnr, dt_fnr); + [~,int_find_gt] = ismember(int_fnr, gt_fnr); + + assert(length(int_find_dt)==length(int_find_gt)); + + iou = zeros(length(int_find_dt),1); + for i=1:length(int_find_dt) + if int_find_gt(i)<1 +% fprintf('error ') + pf = pf; + else + pf = i; + end + + gt_bound = gt_bb(int_find_gt(pf),:); + dt_bound = dt_bb(int_find_dt(pf),:)+1; + + % gt_bound = [gt_bound(:,1:2) gt_bound(:,3:4)-gt_bound(:,1:2)]; + % dt_bound = [dt_bound(:,1:2) dt_bound(:,3:4)-dt_bound(:,1:2)]; + iou(i) = inters_union(double(gt_bound),double(dt_bound)); + end + % finalspatio-temporal IoU threshold + st_iou = T_iou*mean(iou); +else + st_iou =0; +end +% % iou_thresh = 0.2,...,0.6 % 'Learing to track paper' takes 0.2 for UCF101 and 0.5 for JHMDB +% if delta >= iou_thresh +% % consider this tube as valid detection +% end + +end + +% ------------------------------------------------------------------------- +function iou = inters_union(bounds1,bounds2) +% ------------------------------------------------------------------------- + +inters = rectint(bounds1,bounds2); +ar1 = bounds1(:,3).*bounds1(:,4); +ar2 = bounds2(:,3).*bounds2(:,4); +union = bsxfun(@plus,ar1,ar2')-inters; + +iou = inters./(union+eps); + +end diff --git a/online-tubes/eval/get_PR_curve.m b/online-tubes/eval/get_PR_curve.m new file mode 100644 index 0000000..c08e18e --- /dev/null +++ b/online-tubes/eval/get_PR_curve.m @@ -0,0 +1,154 @@ +%%################################################################################################################################################## + +%% Author: Gurkirt Singh +%% Release date: 26th January 2017 +% STEP-1: loop over the videos present in the predicited Tubes +% STEP-2: for each video get the GT Tubes +% STEP-3: Compute the spatio-temporal overlap bwtween GT tube and predicited +% tubes +% STEP-4: then label tp 1 or fp 0 to each predicted tube +% STEP-5: Compute PR and AP for each class using scores, tp and fp in allscore +%################################################################################################################################################## + +function [mAP,mAIoU,acc,AP] = get_PR_curve(annot, xmldata, testlist, actions, iou_th) +% load(xmlfile) +num_vid = length(testlist); +num_actions = length(actions); +AP = zeros(num_actions,1); +averageIoU = zeros(num_actions,1); + +cc = 
zeros(num_actions,1); +for a=1:num_actions + allscore{a} = zeros(10000,2,'single'); +end + +total_num_gt_tubes = zeros(num_actions,1); +% count all the gt tubes from all the vidoes for label a +% total_num_detection = zeros(num_actions,1); + +preds = zeros(num_vid,1) - 1; +gts = zeros(num_vid,1); +annotNames = {annot.name}; +dtNames = {xmldata.videoName}; +for vid=1:num_vid + maxscore = -10000; + [action,~] = getActionName(testlist{vid}); %%get action name to which this video belongs to + [~,action_id] = find(strcmp(action, actions)); %% process only the videos from current action a + [~,gtVidInd] = find(strcmp(annotNames,testlist{vid})); + [~,dtVidInd] = find(strcmp(dtNames,testlist{vid})); + + dt_tubes = sort_detection(xmldata(dtVidInd)); + gt_tubes = annot(gtVidInd).tubes; + + num_detection = length(dt_tubes.class); + num_gt_tubes = length(gt_tubes); + + % total_num_detection = total_num_detection + num_detection; + for gtind = 1:num_gt_tubes + action_id = gt_tubes(gtind).class; + total_num_gt_tubes(action_id) = total_num_gt_tubes(action_id) + 1; + end + gts(vid) = action_id; + dt_labels = dt_tubes.class; + covered_gt_tubes = zeros(num_gt_tubes,1); + for dtind = 1:num_detection + dt_fnr = dt_tubes.framenr(dtind).fnr; + dt_bb = dt_tubes.boxes(dtind).bxs; + dt_label = dt_labels(dtind); + if dt_tubes.score(dtind)>maxscore + preds(vid) = dt_label; + maxscore = dt_tubes.score(dtind); + end + cc(dt_label) = cc(dt_label) + 1; + + ioumax=-inf;maxgtind=0; + for gtind = 1:num_gt_tubes + action_id = gt_tubes(gtind).class; + if ~covered_gt_tubes(gtind) && dt_label == action_id + gt_fnr = gt_tubes(gtind).sf:gt_tubes(gtind).ef; +% if isempty(gt_fnr) +% continue +% end + gt_bb = gt_tubes(gtind).boxes; + iou = compute_spatio_temporal_iou(gt_fnr, gt_bb, dt_fnr, dt_bb); + if iou>ioumax + ioumax=iou; + maxgtind=gtind; + end + end + end + + if ioumax>iou_th + covered_gt_tubes(maxgtind) = 1; + allscore{dt_label}(cc(dt_label),:) = [dt_tubes.score(dtind),1]; + averageIoU(dt_label) = averageIoU(dt_label) + ioumax; + else + allscore{dt_label}(cc(dt_label),:) = [dt_tubes.score(dtind),0]; + end + + end +end + +for a=1:num_actions + allscore{a} = allscore{a}(1:cc(a),:); + scores = allscore{a}(:,1); + labels = allscore{a}(:,2); + [~, si] = sort(scores,'descend'); + % scores = scores(si); + labels = labels(si); + fp=cumsum(labels==0); + tp=cumsum(labels==1); + cdet =0; + if ~isempty(tp)>0 + cdet = tp(end); + averageIoU(a) = (averageIoU(a)+0.000001)/(tp(end)+0.00001); + end + + recall=tp/total_num_gt_tubes(a); + precision=tp./(fp+tp); + AP(a) = xVOCap(recall,precision); + draw = 0; + if draw + % plot precision/recall + plot(recall,precision,'-'); + grid; + xlabel 'recall' + ylabel 'precision' + title(sprintf('class: %s, AP = %.3f',actions{a},AP(a))); + end + % fprintf('Action %02d AP = %0.5f and AIOU %0.5f GT %03d total det %02d correct det %02d %s\n', a, AP(a),averageIoU(a),total_num_gt_tubes(a),length(tp),cdet,actions{a}); + +end +acc = mean(preds==gts); +AP(isnan(AP)) = 0; +mAP = mean(AP); +averageIoU(isnan(averageIoU)) = 0; +mAIoU = mean(averageIoU); + + +%% ------------------------------------------------------------------------------------------------------------------------------------------------ +function [action,vidID] = getActionName(str) +%------------------------------------------------------------------------------------------------------------------------------------------------ +indx = strsplit(str, '/'); +action = indx{1}; +vidID = indx{2}; +%% +function sorted_tubes = sort_detection(dt_tubes) + 
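For reference, a minimal Python sketch (illustrative only, not part of the patch) of how the per-class AP above is obtained from the accumulated (score, tp/fp) pairs: sort by score, accumulate tp/fp counts, convert to precision/recall, and apply the same VOC-style interpolation as `xVOCap` further below.

```python
import numpy as np

def average_precision(scores, is_tp, num_gt):
    """scores, is_tp: per-detection score and 1/0 tp label; num_gt: ground-truth count."""
    order = np.argsort(-np.asarray(scores))
    labels = np.asarray(is_tp, dtype=np.float64)[order]
    tp = np.cumsum(labels == 1)
    fp = np.cumsum(labels == 0)
    recall = tp / max(num_gt, 1)
    precision = tp / np.maximum(tp + fp, 1e-12)
    # VOC-style interpolation: make precision monotonically non-increasing,
    # then sum the area under the precision/recall curve.
    mrec = np.concatenate(([0.0], recall, [1.0]))
    mpre = np.concatenate(([0.0], precision, [0.0]))
    for i in range(len(mpre) - 2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i + 1])
    idx = np.where(mrec[1:] != mrec[:-1])[0] + 1
    return float(np.sum((mrec[idx] - mrec[idx - 1]) * mpre[idx]))
```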
+sorted_tubes = dt_tubes; + +if ~isempty(dt_tubes.class) + + num_detection = length(dt_tubes.class); + scores = dt_tubes.score; + [~,indexs] = sort(scores,'descend'); + for dt = 1 : num_detection + dtind = indexs(dt); + sorted_tubes.framenr(dt).fnr = dt_tubes.framenr(dtind).fnr; + sorted_tubes.boxes(dt).bxs = dt_tubes.boxes(dtind).bxs; + sorted_tubes.class(dt) = dt_tubes.class(dtind); + sorted_tubes.score(dt) = dt_tubes.score(dtind); + sorted_tubes.nr(dt) = dt; + end +end +%% diff --git a/online-tubes/eval/xVOCap.m b/online-tubes/eval/xVOCap.m new file mode 100644 index 0000000..bad027d --- /dev/null +++ b/online-tubes/eval/xVOCap.m @@ -0,0 +1,10 @@ +function ap = xVOCap(rec,prec) +% From the PASCAL VOC 2011 devkit + +mrec=[0 ; rec ; 1]; +mpre=[0 ; prec ; 0]; +for i=numel(mpre)-1:-1:1 + mpre(i)=max(mpre(i),mpre(i+1)); +end +i=find(mrec(2:end)~=mrec(1:end-1))+1; +ap=sum((mrec(i)-mrec(i-1)).*mpre(i)); \ No newline at end of file diff --git a/online-tubes/frameAp.m b/online-tubes/frameAp.m new file mode 100644 index 0000000..ca660b6 --- /dev/null +++ b/online-tubes/frameAp.m @@ -0,0 +1,294 @@ +% --------------------------------------------------------- +% Copyright (c) 2017, Gurkirt Singh +% This code and is available +% under the terms of MIT License provided in LICENSE. +% Please retain this notice and LICENSE if you use +% this file (or any portion of it) in your project. +% --------------------------------------------------------- + +%% This is main script to compute frame mean AP %% +%% this code is very new so hasn't been tested a lot +% Input: Detection directory; annotation file path; split file path +% Output: computes frame AP for all the detection directories +% It should produce results almost identical to test_ucf24.py + +function frameAP() + +addpath(genpath('eval/')); +addpath(genpath('utils/')); +addpath(genpath('actionpath/')); +data_root = '/mnt/mars-fast/datasets'; +save_root = '/mnt/mars-fast/ssd-work'; +iou_th = 0.5; +model_type = 'CONV'; +dataset = 'ucf24'; +list_id = '01'; +split_file = sprintf('%s/%s/splitfiles/testlist%s.txt',data_root,dataset,list_id); +annotfile = sprintf('%s/%s/splitfiles/annots.mat',data_root,dataset); +annot = load(annotfile); +annot = annot.annot; +testlist = getVideoNames(split_file); +num_vid = length(testlist); +num_actions = 24; + +logfile = fopen('frameAP.log','w'); % open log file + +imgType = 'rgb'; iteration_num = 120000; +det_dirs1 = sprintf('%s/%s/detections/%s-%s-%s-%06d/',save_root,dataset,model_type,imgType,list_id,iteration_num); +imgType = 'brox'; iteration_num = 120000; +det_dirs2 = sprintf('%s/%s/detections/%s-%s-%s-%06d/',save_root,dataset,model_type,imgType,list_id,iteration_num); +imgType = 'fastOF'; iteration_num = 120000; +det_dirs3 = sprintf('%s/%s/detections/%s-%s-%s-%06d/',save_root,dataset,model_type,imgType,list_id,iteration_num); + +combinations = {{det_dirs1},{det_dirs2},{det_dirs3},... + {det_dirs1,det_dirs3,'boost'},{det_dirs1,det_dirs2,'boost'},... + {det_dirs1,det_dirs3,'cat'},{det_dirs1,det_dirs2,'cat'},... 
+ {det_dirs1,det_dirs3,'mean'},{det_dirs1,det_dirs2,'mean'}}; + +for c=1:length(combinations) + comb = combinations{c}; + line = comb{1}; + if length(comb)>1 + fusion_type = comb{3}; + line = [line,' ',comb{2},' \n\n fusion type: ',fusion_type,'\n\n']; + + else + fusion_type = 'none'; + end + + line = sprintf('Evaluation for %s\n',line); + fprintf('%s',line) + fprintf(logfile,'%s',line); + AP = zeros(num_actions,1); + cc = zeros(num_actions,1); + for a=1:num_actions + allscore{a} = zeros(24*20*160000,2,'single'); + end + + total_num_gt_boxes = zeros(num_actions,1); + annotNames = {annot.name}; + + for vid=1:num_vid + video_name = testlist{vid}; + [~,gtVidInd] = find(strcmp(annotNames, testlist{vid})); + gt_tubes = annot(gtVidInd).tubes; + numf = annot(gtVidInd).num_imgs; + num_gt_tubes = length(gt_tubes); + if mod(vid,5) == 0 + fprintf('Done procesing %d videos out of %d %s\n', vid, num_vid, video_name) + end + for nf = 1:numf + gt_boxes = get_gt_boxes(gt_tubes,nf); + dt_boxes = get_dt_boxes(comb, video_name, nf, num_actions, fusion_type); + num_gt_boxes = size(gt_boxes,1); + for g = 1:num_gt_boxes + total_num_gt_boxes(gt_boxes(g,5)) = total_num_gt_boxes(gt_boxes(g,5)) + 1; + end + covered_gt_boxes = zeros(num_gt_boxes,1); + for d = 1 : size(dt_boxes,1) + dt_score = dt_boxes(d,5); + dt_label = dt_boxes(d,6); + cc(dt_label) = cc(dt_label) + 1; + ioumax=-inf; maxgtind=0; + if num_gt_boxes>0 && any(gt_boxes(:,5) == dt_label) + for g = 1:num_gt_boxes + if ~covered_gt_boxes(g) && any(dt_label == gt_boxes(:,5)) + iou = compute_spatial_iou(gt_boxes(g,1:4), dt_boxes(d,1:4)); + if iou>ioumax + ioumax=iou; + maxgtind=g; + end + end + end + end + + if ioumax>=iou_th + covered_gt_boxes(maxgtind) = 1; + allscore{dt_label}(cc(dt_label),:) = [dt_score,1]; % tp detection + else + allscore{dt_label}(cc(dt_label),:) = [dt_score,0]; % fp detection + end + + end + + end + end + % Sort scores and then reorder tp fp labels in result precision and recall for each action + for a=1:num_actions + allscore{a} = allscore{a}(1:cc(a),:); + scores = allscore{a}(:,1); + labels = allscore{a}(:,2); + [~, si] = sort(scores,'descend'); + % scores = scores(si); + labels = labels(si); + fp=cumsum(labels==0); + tp=cumsum(labels==1); + recall=tp/total_num_gt_boxes(a); + precision=tp./(fp+tp); + AP(a) = xVOCap(recall,precision); + line = sprintf('Action %02d AP = %0.5f \n', a, AP(a)); + fprintf('%s',line); + fprintf(logfile,'%s',line); + end + + AP(isnan(AP)) = 0; + mAP = mean(AP); + line = sprintf('\nMean AP::=> %.5f\n\n',mAP); + fprintf('%s',line); + fprintf(logfile,'%s',line); +end +end + + +% ------------------------------------------------------------------------- +function [videos] = getVideoNames(split_file) +% ------------------------------------------------------------------------- +fprintf('Get both lis is %s\n',split_file); +fid = fopen(split_file,'r'); +data = textscan(fid, '%s'); +videos = cell(1); +count = 0; + +for i=1:length(data{1}) + filename = cell2mat(data{1}(i,1)); + count = count +1; + videos{count} = filename; + % videos(i).vid = str2num(cell2mat(data{1}(i,1))); +end +end + +function gt_boxes = get_gt_boxes(gt_tubes,nf) +gt_boxes = []; +gt_tubes; +for t = 1:length(gt_tubes) + if nf >= gt_tubes(t).sf && nf <= gt_tubes(t).ef + b_ind = nf - gt_tubes(t).sf + 1; + box = [gt_tubes(t).boxes(b_ind,:), gt_tubes(t).class]; + gt_boxes = [gt_boxes;box]; + end +end +end + +function dt_boxes = get_dt_boxes(detection_dir, video_name, nf, num_actions, fusion_type) +dt_boxes = []; +%% apply nms per class +[boxes,scores] = 
read_detections(detection_dir, video_name, nf); +for a = 1 : num_actions + cls_boxes = get_cls_detection(boxes,scores,a,fusion_type); + dt_boxes = [dt_boxes; cls_boxes]; +end +end + +function cls_boxes = get_cls_detection(boxes,scores,a,fusion_type) + +if strcmp(fusion_type,'none') + cls_boxes = dofilter(boxes(1).b,scores(1).s,a); +elseif strcmp(fusion_type,'mean') + cls_boxes = dofilter(boxes(1).b,(scores(1).s+scores(2).s)/2.0,a); +elseif strcmp(fusion_type,'cat') + cls_boxes_base = dofilter(boxes(1).b,scores(1).s,a); + cls_boxes_top = dofilter(boxes(2).b,scores(2).s,a); + all_boxes = [cls_boxes_base;cls_boxes_top]; + pick = nms(all_boxes(:,1:5),0.45); + cls_boxes = all_boxes(pick,:); +elseif strcmp(fusion_type,'boost') + cls_boxes_base = dofilter(boxes(1).b,scores(1).s,a); + cls_boxes_top = dofilter(boxes(2).b,scores(2).s,a); + all_boxes = boost_boxes(cls_boxes_base,cls_boxes_top); + pick = nms(all_boxes(:,1:5),0.45); + cls_boxes = all_boxes(pick,:); +else + error('Spacify correct fusion technique'); +end + +end + +function cls_boxes_base = boost_boxes(cls_boxes_base,cls_boxes_top) + +box_spatial = [cls_boxes_base(:,1:2) cls_boxes_base(:,3:4)-cls_boxes_base(:,1:2)+1]; +box_flow = [cls_boxes_top(:,1:2) cls_boxes_top(:,3:4)-cls_boxes_top(:,1:2)+1]; +coveredboxes = []; +nb = size(cls_boxes_base,1); % num boxes +for i=1:nb + ovlp = inters_union(box_spatial(i,:), box_flow); % ovlp has 1x5 or 5x1 dim + if ~isempty(ovlp) + [movlp, maxind] = max(ovlp); + if movlp>=0.3 && isempty(ismember(coveredboxes,maxind)) + cls_boxes_base(i,5) = cls_boxes_base(i,5) + cls_boxes_top(maxind,5)*movlp; + coveredboxes = [coveredboxes;maxind]; + end + end +end + +nb = size(cls_boxes_top,1); +for i=1:nb + if ~ismember(coveredboxes,i) + cls_boxes_base = [cls_boxes_base; cls_boxes_top(i,:)]; + end +end + +end + +function [bxs, sc] = read_detections(detection_dir, video_name, nf) +detection_dir1 = detection_dir{1}; +det_file = sprintf('%s%s/%05d.mat', detection_dir1, video_name, nf); +load(det_file); % loads loc and scores variable +boxes = [loc(:,1)*320, loc(:,2)*240, loc(:,3)*320, loc(:,4)*240] + 1; +boxes(boxes(:,1)<1,1) = 1; boxes(boxes(:,2)<1,2) = 1; +boxes(boxes(:,3)>320,3) = 320; boxes(boxes(:,4)>240,4) = 240; +scores = [scores(:,2:end),scores(:,1)]; +bxs = struct(); +sc = struct(); +bxs(1).b = boxes; +sc(1).s = scores; +if length(detection_dir)>1 + detection_dir1 = detection_dir{2}; + det_file = sprintf('%s%s/%05d.mat', detection_dir1, video_name, nf); + load(det_file); % loads loc and scores variable + boxes = [loc(:,1)*320, loc(:,2)*240, loc(:,3)*320, loc(:,4)*240] + 1; + boxes(boxes(:,1)<1,1) = 1; boxes(boxes(:,2)<1,2) = 1; + boxes(boxes(:,3)>320,3) = 320; boxes(boxes(:,4)>240,4) = 240; + scores = [scores(:,2:end),scores(:,1)]; + bxs(2).b = boxes; + sc(2).s = scores; +end + +end + + +function boxes = dofilter(boxes,scores,a) +scores = scores(:,a); +pick = scores>0.01; +scores = scores(pick); +boxes = boxes(pick,:); +[~,pick] = sort(scores,'descend'); +to_pick = min(50,size(pick,1)); +pick = pick(1:to_pick); +scores = scores(pick); +boxes = boxes(pick,:); +pick = nms([boxes scores],0.45); +pick = pick(1:min(20,length(pick))); +boxes = boxes(pick,:); +scores = scores(pick); +cls = scores*0 + a; +boxes = [boxes,scores, cls]; +end + +function iou = inters_union(bounds1,bounds2) +% ------------------------------------------------------------------------ +inters = rectint(bounds1,bounds2); +ar1 = bounds1(:,3).*bounds1(:,4); +ar2 = bounds2(:,3).*bounds2(:,4); +union = bsxfun(@plus,ar1,ar2')-inters; +iou = 
inters./(union+0.001); +end + + +function iou = compute_spatial_iou(gt_box, dt_box) +dt_box = [dt_box(1:2), dt_box(3:4)-dt_box(1:2)+1]; +inter = rectint(gt_box,dt_box); +ar1 = gt_box(3)*gt_box(4); +ar2 = dt_box(3)*dt_box(4); +union = ar1 + ar2 - inter; +iou = inter/union; +end \ No newline at end of file diff --git a/online-tubes/gentube/PARactionPathSmoother.m b/online-tubes/gentube/PARactionPathSmoother.m new file mode 100644 index 0000000..7647484 --- /dev/null +++ b/online-tubes/gentube/PARactionPathSmoother.m @@ -0,0 +1,131 @@ +% --------------------------------------------------------- +% Copyright (c) 2017, Gurkirt Singh +% This code and is available +% under the terms of MIT License provided in LICENSE. +% Please retain this notice and LICENSE if you use +% this file (or any portion of it) in your project. +% --------------------------------------------------------- + + +function final_tubes = parActionPathSmoother(actionpaths,alpha,num_action) + +% load data +% fprintf('Number of video intest set %d \n', actionpath,alpha,num_action,calpha,useNeg +% alpha = 1; + +final_tubes = struct('starts',[],'ts',[],'te',[],'label',[],'path_total_score',[],... + 'dpActionScore',[],'dpPathScore',[],... + 'path_boxes',cell(1),'path_scores',cell(1),'video_id',cell(1)); + + +alltubes = cell(length(actionpaths),1); + +parfor t = 1 : length(actionpaths) + % fprintf('[%03d/%03d] calpha %04d\n',t,length(tubes),uint16(calpha*100)); + % fprintf('.'); + video_id = actionpaths(t).video_id; + % fprintf('[doing for %s %d out of %d]\n',video_id,t,length(tubes)); + alltubes{t} = actionPathSmoother4oneVideo(actionpaths(t).paths,alpha,num_action,video_id) ; +end + +action_count = 1; +for t = 1 : length(actionpaths) + vid_tubes = alltubes{t}; + for k=1:length(vid_tubes.ts) + final_tubes.starts(action_count) = vid_tubes.starts(k); + final_tubes.ts(action_count) = vid_tubes.ts(k); + final_tubes.video_id{action_count} = vid_tubes.video_id{k}; + final_tubes.te(action_count) = vid_tubes.te(k); + final_tubes.dpActionScore(action_count) = vid_tubes.dpActionScore(k); + final_tubes.label(action_count) = vid_tubes.label(k); + final_tubes.dpPathScore(action_count) = vid_tubes.dpPathScore(k); + final_tubes.path_total_score(action_count) = vid_tubes.path_total_score(k); + final_tubes.path_boxes{action_count} = vid_tubes.path_boxes{k}; + final_tubes.path_scores{action_count} = vid_tubes.path_scores{k}; + action_count = action_count + 1; + end + +end +end + +function final_tubes = actionPathSmoother4oneVideo(video_paths,alpha,num_action,video_id) +action_count =1; +final_tubes = struct('starts',[],'ts',[],'te',[],'label',[],'path_total_score',[],... + 'dpActionScore',[],'dpPathScore',[],'vid',[],... 
+ 'path_boxes',cell(1),'path_scores',cell(1),'video_id',cell(1)); + +if ~isempty(video_paths) + %gt_ind = find(strcmp(video_id,annot.videoName)); + %number_frames = length(video_paths{1}(1).idx); +% alpha = alpha-3.2; + for a = 1 : num_action + action_paths = video_paths{a}; + num_act_paths = getPathCount(action_paths); + for p = 1 : num_act_paths + M = action_paths(p).allScores(:,1:num_action)'; %(:,1:num_action)'; + %M = normM(M); + %M = [M(a,:),1-M(a,:)]; + M = M +20; + + [pred_path,time,D] = dpEM_max(M,alpha(a)); + [ Ts, Te, Scores, Label, DpPathScore] = extract_action(pred_path,time,D,a); + for k = 1 : length(Ts) + final_tubes.starts(action_count) = action_paths(p).start; + final_tubes.ts(action_count) = Ts(k); + final_tubes.video_id{action_count} = video_id; + % final_tubes.vid(action_count) = vid_num; + final_tubes.te(action_count) = Te(k); + final_tubes.dpActionScore(action_count) = Scores(k); + final_tubes.label(action_count) = Label(k); + final_tubes.dpPathScore(action_count) = DpPathScore(k); + final_tubes.path_total_score(action_count) = mean(action_paths(p).scores); + final_tubes.path_boxes{action_count} = action_paths(p).boxes; + final_tubes.path_scores{action_count} = action_paths(p).scores; + action_count = action_count + 1; + end + + end + + end +end +end + +function M = normM(M) +for i = 1: size(M,2) + M(:,i) = M(:,i)/sum(M(:,i)); +end +end +function [ts,te,scores,label,total_score] = extract_action(p,q,D,action) +% p(1:1) = 1; +indexs = find(p==action); + +if isempty(indexs) + ts = []; te = []; scores = []; label = []; total_score = []; + +else + indexs_diff = [indexs,indexs(end)+1] - [indexs(1)-2,indexs]; + ts = find(indexs_diff>1); + + if length(ts)>1 + te = [ts(2:end)-1,length(indexs)]; + else + te = length(indexs); + end + ts = indexs(ts); + te = indexs(te); + scores = (D(action,q(te)) - D(action,q(ts)))./(te-ts); + label = ones(length(ts),1)*action; + total_score = ones(length(ts),1)*D(p(end),q(end))/length(p); +end +end + +% ------------------------------------------------------------------------- +function lp_count = getPathCount(live_paths) +% ------------------------------------------------------------------------- + +if isfield(live_paths,'boxes') + lp_count = length(live_paths); +else + lp_count = 0; +end +end diff --git a/online-tubes/gentube/convert2eval.m b/online-tubes/gentube/convert2eval.m new file mode 100644 index 0000000..643acf6 --- /dev/null +++ b/online-tubes/gentube/convert2eval.m @@ -0,0 +1,57 @@ +% --------------------------------------------------------- +% Copyright (c) 2017, Gurkirt Singh +% This code and is available +% under the terms of MIT License provided in LICENSE. +% Please retain this notice and LICENSE if you use +% this file (or any portion of it) in your project. 
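The segment extraction done by `extract_action` above can be summarised with the following Python sketch (illustrative only, not part of the patch): given the per-frame label sequence produced by the dynamic-programming pass, return the contiguous runs assigned to one class.

```python
def contiguous_segments(labels, action):
    """Return (start, end) index pairs (inclusive) of runs where labels == action."""
    segments = []
    start = None
    for i, lab in enumerate(labels):
        if lab == action and start is None:
            start = i
        elif lab != action and start is not None:
            segments.append((start, i - 1))
            start = None
    if start is not None:
        segments.append((start, len(labels) - 1))
    return segments

# e.g. contiguous_segments([2, 2, 5, 5, 5, 2, 5], 5) -> [(2, 4), (6, 6)]
```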
+% --------------------------------------------------------- +% Input: smoothed tubes +% Output: filtered out tubes with proper scoring + +function xmld = convert2eval(final_tubes,min_num_frames,kthresh,topk,vids) + +xmld = struct([]); +v= 1; + +for vv = 1 : length(vids) + action_indexes = find(strcmp(final_tubes.video_id,vids{vv})); + videoName = vids{vv}; + xmld(v).videoName = videoName; + actionscore = final_tubes.dpActionScore(action_indexes); + path_scores = final_tubes.path_scores(1,action_indexes); + + ts = final_tubes.ts(action_indexes); + starts = final_tubes.starts(action_indexes); + te = final_tubes.te(action_indexes); + act_nr = 1; + + for a = 1 : length(ts) + act_ts = ts(a); + act_te = te(a); +% act_dp_score = actionscore(a); %% only useful on JHMDB + act_path_scores = cell2mat(path_scores(a)); + + %----------------------------------------------------------- + act_scores = sort(act_path_scores(act_ts:act_te),'descend'); + %save('test.mat', 'act_scores'); pause; + + topk_mean = mean(act_scores(1:min(topk,length(act_scores)))); + + bxs = final_tubes.path_boxes{action_indexes(a)}(act_ts:act_te,:); + + bxs = [bxs(:,1:2), bxs(:,3:4)-bxs(:,1:2)]; + + label = final_tubes.label(action_indexes(a)); + + if topk_mean > kthresh(label) && (act_te-act_ts) > min_num_frames + xmld(v).score(act_nr) = topk_mean; + xmld(v).nr(act_nr) = act_nr; + xmld(v).class(act_nr) = label; + xmld(v).framenr(act_nr).fnr = (act_ts:act_te) + starts(a)-1; + xmld(v).boxes(act_nr).bxs = bxs; + act_nr = act_nr+1; + end + end + v = v + 1; + +end diff --git a/online-tubes/gentube/dpEM_max.m b/online-tubes/gentube/dpEM_max.m new file mode 100644 index 0000000..a0001b4 --- /dev/null +++ b/online-tubes/gentube/dpEM_max.m @@ -0,0 +1,93 @@ +% --------------------------------------------------------- +% Original code comes from https://team.inria.fr/perception/research/skeletalquads/ +% Copyright (c) 2014, Georgios Evangelidis and Gurkirt Singh, +% This code and is available +% under the terms of MIT License provided in LICENSE. +% Please retain this notice and LICENSE if you use +% this file (or any portion of it) in your project. 
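The tube scoring and filtering in `convert2eval` above boils down to the following Python sketch (illustrative only, not part of the patch); the parameter names mirror the MATLAB call (`min_num_frames`, `kthresh`, `topk`), and the exact length test is simplified here.

```python
import numpy as np

def score_and_filter(frame_scores, min_num_frames=8, kthresh=0.0, topk=40):
    """frame_scores: per-frame scores of one trimmed tube; returns a score or None."""
    if len(frame_scores) <= min_num_frames:
        return None  # tube too short, dropped
    top = np.sort(np.asarray(frame_scores, dtype=float))[::-1][:topk]
    tube_score = float(top.mean())  # mean of the top-k per-frame scores
    return tube_score if tube_score > kthresh else None
```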
+% --------------------------------------------------------- + +% M = <10xnum_frames> +% r = 10 (action labels) +% c = frame indices in a video + +function [p,q,D] = dpEM_max(M,alpha) + +% transition cost for the smoothness term +% V(L1,L2) = 0, if L1=L2 +% V(L1,L2) = alpha, if L1~=L2 + + + +[r,c] = size(M); + + + +% costs +D = zeros(r, c+1); % add an extra column +D(:,1) = 0; % put the maximum cost +D(:, 2:(c+1)) = M; + +v = [1:r]'; + + +%D = M; +phi = zeros(r,c); + +%test = struct([]); +for j = 2:c+1; % c = 1230 + for i = 1:r; % r = 10 + +% test(j).D = D(:, j-1); % fetching prev column 10 rows +% test(j).alpha = alpha*(v~=i); % switching each row for each class +% test(j).D_alpha = [D(:, j-1)-alpha*(v~=i)]; +% test(j).max = max([D(:, j-1)-alpha*(v~=i)]); % for ith class taking the max score + + + [dmax, tb] = max([D(:, j-1)-alpha*(v~=i)]); + %keyboard; + D(i,j) = D(i,j)+dmax; + phi(i,j-1) = tb; + end +end + +% Note: +% the outer loop (j) is to visit one by one each frames +% the inner loop (i) is to get the max score for each action label +% the -alpha*(v~=i) term is to add a penalty by subtracting alpha from the +% data term for all other class labels other than i, for ith class label +% it adds zero penalty; +% (v~=i) will return a logical array consists of 10 elements, in the ith +% location it is 0 (false becuase the condition v~=i is false) and all other locations +% returns 1, thus for ith calss it multiplies 0 +% with alpha and for the rest of the classes multiplies 1; +% for each iteration of ith loop we get a max value which we add to the +% data term d(i,j), in this way the 10 max values for 10 different action +% labels are stored to the jth column (or for the jth frame): D(1,j), D(2,j),...,D(10,j), + +% save('test.mat','r','c','M', 'phi'); +% pause; + +% Traceback from last frame +D = D(:,2:(c+1)); + +% best of the last column +q = c; % frame inidces +[~,p] = max(D(:,c)); + + + +i = p; % index of max element in last column of D, +j = q; % frame indices + +while j>1 % loop over frames in a video + tb = phi(i,j); % i -> index of max element in last column of D, j-> last frame index or last column of D + p = [tb,p]; + q = [j-1,q]; + j = j-1; + i = tb; +end + +% +% phi(i,j) stores all the max indices in the forward pass +% during the backward pass , a predicited path is constructed using these indices values diff --git a/online-tubes/gentube/readALLactionPaths.m b/online-tubes/gentube/readALLactionPaths.m new file mode 100644 index 0000000..d5b8505 --- /dev/null +++ b/online-tubes/gentube/readALLactionPaths.m @@ -0,0 +1,47 @@ +% --------------------------------------------------------- +% Copyright (c) 2017, Gurkirt Singh +% This code and is available +% under the terms of MIT License provided in LICENSE. +% Please retain this notice and LICENSE if you use +% this file (or any portion of it) in your project. 
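A minimal Python sketch (illustrative only, not part of the patch) of the Viterbi-style labelling that `dpEM_max` above implements: maximise the summed class scores along a path of per-frame labels while paying a penalty `alpha` whenever the label changes between consecutive frames, then trace back the best path.

```python
import numpy as np

def dp_label_path(M, alpha):
    """M: (num_classes, num_frames) score matrix. Returns one label per frame."""
    M = np.asarray(M, dtype=float)
    r, c = M.shape
    D = M.copy()                          # accumulated scores
    back = np.zeros((r, c), dtype=int)    # back-pointers
    for j in range(1, c):
        for i in range(r):
            prev = D[:, j - 1] - alpha * (np.arange(r) != i)  # switch penalty
            back[i, j] = int(np.argmax(prev))
            D[i, j] += prev[back[i, j]]
    labels = np.zeros(c, dtype=int)
    labels[-1] = int(np.argmax(D[:, -1]))  # best label at the last frame
    for j in range(c - 1, 0, -1):
        labels[j - 1] = back[labels[j], j]
    return labels
```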
+% --------------------------------------------------------- + +function actionpath = readALLactionPaths(videolist,actionPathDir,step) + +videos = getVideoNames(videolist); +NumVideos = length(videos); + +actionpath = struct([]); +fprintf('Loading action paths of %d videos\n',NumVideos); +count = 1; +for vid=1:step:NumVideos + + videoID = videos(vid).video_id; + pathsSaveName = [actionPathDir,videoID,'-actionpaths.mat']; + + if ~exist(pathsSaveName,'file') + error('Action path does not exist please genrate actin path', pathsSaveName) + else +% fprintf('loading vid %d %s \n',vid,pathsSaveName); + load(pathsSaveName); + actionpath(count).video_id = videos(vid).video_id; + actionpath(count).paths = allpaths; + count = count+1; + end +end +end + +function [videos] = getVideoNames(split_file) +% ------------------------------------------------------------------------- +fid = fopen(split_file,'r'); +data = textscan(fid, '%s'); +videos = struct(); +for i=1:length(data{1}) + filename = cell2mat(data{1}(i,1)); + videos(i).video_id = filename; + % videos(i).vid = str2num(cell2mat(data{1}(i,1))); + +end +count = length(data{1}); + +end diff --git a/online-tubes/utils/createdires.m b/online-tubes/utils/createdires.m new file mode 100644 index 0000000..777d79f --- /dev/null +++ b/online-tubes/utils/createdires.m @@ -0,0 +1,20 @@ +% --------------------------------------------------------- +% Copyright (c) 2017, Gurkirt Singh +% This code and is available +% under the terms of MIT License provided in LICENSE. +% Please retain this notice and LICENSE if you use +% this file (or any portion of it) in your project. +% --------------------------------------------------------- + + +function createdires(basedirs,actions) +for s = 1: length(basedirs) + savename = basedirs{s}; + for action = actions + saveNameaction = [savename,action{1}]; + if ~isdir(saveNameaction) + mkdir(saveNameaction); + end + end +end +end \ No newline at end of file diff --git a/online-tubes/utils/initDatasetOpts.m b/online-tubes/utils/initDatasetOpts.m new file mode 100644 index 0000000..fada11f --- /dev/null +++ b/online-tubes/utils/initDatasetOpts.m @@ -0,0 +1,60 @@ +% --------------------------------------------------------- +% Copyright (c) 2017, Gurkirt Singh +% This code and is available +% under the terms of MIT License provided in LICENSE. +% Please retain this notice and LICENSE if you use +% this file (or any portion of it) in your project. +% --------------------------------------------------------- + +function opts = initDatasetOpts(data_root,baseDir,dataset,imgType,model_type,listid,iteration_num,iouthresh,costtype,gap) + +opts = struct(); +opts.imgType = imgType; +opts.costtype = costtype; +opts.gap = gap; +opts.baseDir = baseDir; +opts.imgType = imgType; +opts.dataset = dataset; +opts.iouThresh = iouthresh; +opts.weight = iteration_num; +opts.listid = listid; + +testlist = ['testlist',listid]; +opts.vidList = sprintf('%s/%s/splitfiles/%s.txt',data_root,dataset,testlist); + +if strcmp(dataset,'ucf24') + opts.actions = {'Basketball','BasketballDunk','Biking','CliffDiving','CricketBowling',... + 'Diving','Fencing','FloorGymnastics','GolfSwing','HorseRiding','IceDancing',... + 'LongJump','PoleVault','RopeClimbing','SalsaSpin','SkateBoarding','Skiing',... + 'Skijet','SoccerJuggling','Surfing','TennisSwing','TrampolineJumping',... + 'VolleyballSpiking','WalkingWithDog'}; +elseif strcmp(dataset,'JHMDB') + opts.actions = {'brush_hair','catch','clap','climb_stairs','golf','jump',... 
+ 'kick_ball','pick','pour','pullup','push','run','shoot_ball','shoot_bow',... + 'shoot_gun','sit','stand','swing_baseball','throw','walk','wave'}; +elseif strcmp(dataset,'LIRIS') + opts.actions = {'discussion', 'give_object_to_person','put_take_obj_into_from_box_desk',... + 'enter_leave_room_no_unlocking','try_enter_room_unsuccessfully','unlock_enter_leave_room',... + 'leave_baggage_unattended','handshaking','typing_on_keyboard','telephone_conversation'}; +end + +opts.imgDir = sprintf('%s/%s/%s-images/',data_root,dataset,imgType); + +opts.detDir = sprintf('%s/%s/detections/%s-%s-%s-%06d/',baseDir,dataset,model_type,imgType,listid,iteration_num); +opts.annotFile = sprintf('%s/%s/splitfiles/annots.mat',data_root,dataset); + +opts.actPathDir = sprintf('%s/%s/actionPaths/%s-%s-%s-%06d-%s-%d-%04d/',baseDir,dataset,model_type,imgType,listid,iteration_num,costtype,gap,iouthresh*100); +opts.tubeDir = sprintf('%s/%s/actionTubes/%s-%s-%s-%06d-%s-%d-%04d/',baseDir,dataset,model_type,imgType,listid,iteration_num,costtype,gap,iouthresh*100); + +if exist(opts.detDir,'dir') + if ~isdir(opts.actPathDir) + fprintf('Creating %s\n',opts.actPathDir); + mkdir(opts.actPathDir) + end + if ~isdir(opts.tubeDir) + mkdir(opts.tubeDir) + end + if strcmp(dataset,'ucf24') || strcmp(dataset,'JHMDB') + createdires({opts.actPathDir},opts.actions) + end +end diff --git a/online-tubes/utils/initDatasetOptsFused.m b/online-tubes/utils/initDatasetOptsFused.m new file mode 100644 index 0000000..bf13e49 --- /dev/null +++ b/online-tubes/utils/initDatasetOptsFused.m @@ -0,0 +1,74 @@ +% --------------------------------------------------------- +% Copyright (c) 2017, Gurkirt Singh +% This code and is available +% under the terms of MIT License provided in LICENSE. +% Please retain this notice and LICENSE if you use +% this file (or any portion of it) in your project. +% --------------------------------------------------------- + +function opts = initDatasetOptsFused(data_root,baseDir,dataset,imtypes,model_type, ... + listid,iteration_nums,iouthresh,costtype,gap,fusiontype,fuseiouth) +%% data_root,baseDir,dataset,imgType,model_type,listid,iteration_num,iouthresh,costtype,gap + +opts = struct(); +imgType = [imtypes{1},'-',imtypes{2}]; +opts.imgType = imgType; +opts.costtype = costtype; +opts.gap = gap; +opts.baseDir = baseDir; +opts.imgType = imgType; +opts.dataset = dataset; +opts.iouThresh = iouthresh; +opts.iteration_nums = iteration_nums; +opts.listid = listid; +opts.fusiontype = fusiontype; +opts.fuseiouth = fuseiouth; +testlist = ['testlist',listid]; +opts.data_root = data_root; +opts.vidList = sprintf('%s/%s/splitfiles/%s.txt',data_root,dataset,testlist); + +if strcmp(dataset,'ucf24') + opts.actions = {'Basketball','BasketballDunk','Biking','CliffDiving','CricketBowling',... + 'Diving','Fencing','FloorGymnastics','GolfSwing','HorseRiding','IceDancing',... + 'LongJump','PoleVault','RopeClimbing','SalsaSpin','SkateBoarding','Skiing',... + 'Skijet','SoccerJuggling','Surfing','TennisSwing','TrampolineJumping',... + 'VolleyballSpiking','WalkingWithDog'}; +elseif strcmp(dataset,'JHMDB') + opts.actions = {'brush_hair','catch','clap','climb_stairs','golf','jump',... + 'kick_ball','pick','pour','pullup','push','run','shoot_ball','shoot_bow',... + 'shoot_gun','sit','stand','swing_baseball','throw','walk','wave'}; +elseif strcmp(dataset,'LIRIS') + opts.actions = {'discussion', 'give_object_to_person','put_take_obj_into_from_box_desk',... + 'enter_leave_room_no_unlocking','try_enter_room_unsuccessfully','unlock_enter_leave_room',... 
+ 'leave_baggage_unattended','handshaking','typing_on_keyboard','telephone_conversation'}; +end + +opts.imgDir = sprintf('%s/%s/%s-images/',data_root,dataset,imtypes{1}); + +opts.basedetDir = sprintf('%s/%s/detections/%s-%s-%s-%06d/',baseDir,dataset,model_type,imtypes{1},listid,iteration_nums(1)); +opts.topdetDir = sprintf('%s/%s/detections/%s-%s-%s-%06d/',baseDir,dataset,model_type,imtypes{2},listid,iteration_nums(2)); + +opts.annotFile = sprintf('%s/%s/splitfiles/annots.mat',data_root,dataset); + +opts.actPathDir = sprintf('%s/%s/actionPaths/%s/%s-%s-%s-%s-%d-%d-%s-%d-%04d-fiou%03d/',baseDir,dataset,fusiontype,model_type,imtypes{1},imtypes{2},... + listid,iteration_nums(1),iteration_nums(2),costtype,gap,iouthresh*100,uint16(fuseiouth*100)); +opts.tubeDir = sprintf('%s/%s/actionTubes/%s/%s-%s-%s-%s-%d-%d-%s-%d-%04d-fiou%03d/',baseDir,dataset,fusiontype,model_type,imtypes{1},imtypes{2},... + listid,iteration_nums(1),iteration_nums(2),costtype,gap,iouthresh*100,uint16(fuseiouth*100)); + +if exist(opts.basedetDir,'dir') + if ~isdir(opts.actPathDir) + fprintf('Creating %s\n',opts.actPathDir); + mkdir(opts.actPathDir) + end + + if ~isdir(opts.tubeDir) + mkdir(opts.tubeDir) + end + + if strcmp(dataset,'ucf24') || strcmp(dataset,'JHMDB') + createdires({opts.actPathDir},opts.actions) + end +end + +%fprintf('Video List :: %s\nImage Dir :: %s\nDetection Dir:: %s\nActionpath Dir:: %s\nTube Dir:: %s\n',... + % opts.vidList,opts.imgDir,opts.detDir,opts.actPathDir,opts.tubeDir) diff --git a/ssd.py b/ssd.py new file mode 100644 index 0000000..6bedc0a --- /dev/null +++ b/ssd.py @@ -0,0 +1,205 @@ + +""" SSD network Classes + +Original author: Ellis Brown, Max deGroot for VOC dataset +https://github.com/amdegroot/ssd.pytorch + +Updated by Gurkirt Singh for ucf101-24 dataset +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from layers import * +from data import v2 +import os + + +class SSD(nn.Module): + """Single Shot Multibox Architecture + The network is composed of a base VGG network followed by the + added multibox conv layers. Each multibox layer branches into + 1) conv2d for class conf scores + 2) conv2d for localization predictions + 3) associated priorbox layer to produce default bounding + boxes specific to the layer's feature map size. + See: https://arxiv.org/pdf/1512.02325.pdf for more details. + + Args: + base: VGG16 layers for input, size of either 300 or 500 + extras: extra layers that feed to multibox loc and conf layers + head: "multibox head" consists of loc and conf conv layers + """ + + def __init__(self, base, extras, head, num_classes): + super(SSD, self).__init__() + + self.num_classes = num_classes + # TODO: implement __call__ in PriorBox + self.priorbox = PriorBox(v2) + self.priors = Variable(self.priorbox.forward(), volatile=True) + self.num_priors = self.priors.size(0) + self.size = 300 + + # SSD network + self.vgg = nn.ModuleList(base) + # Layer learns to scale the l2 normalized features from conv4_3 + self.L2Norm = L2Norm(512, 20) + self.extras = nn.ModuleList(extras) + + self.loc = nn.ModuleList(head[0]) + self.conf = nn.ModuleList(head[1]) + + self.softmax = nn.Softmax().cuda() + # self.detect = Detect(num_classes, 0, 200, 0.001, 0.45) + + def forward(self, x): + + """Applies network layers and ops on input image(s) x. + + Args: + x: input image or batch of images. Shape: [batch,3*batch,300,300]. 
+ + Return: + Depending on phase: + test: + Variable(tensor) of output class label predictions, + confidence score, and corresponding location predictions for + each object detected. Shape: [batch,topk,7] + + train: + list of concat outputs from: + 1: confidence layers, Shape: [batch*num_priors,num_classes] + 2: localization layers, Shape: [batch,num_priors*4] + 3: priorbox layers, Shape: [2,num_priors*4] + """ + + sources = list() + loc = list() + conf = list() + + # apply vgg up to conv4_3 relu + for k in range(23): + x = self.vgg[k](x) + + s = self.L2Norm(x) + sources.append(s) + + # apply vgg up to fc7 + for k in range(23, len(self.vgg)): + x = self.vgg[k](x) + sources.append(x) + + # apply extra layers and cache source layer outputs + for k, v in enumerate(self.extras): + x = F.relu(v(x), inplace=True) + if k % 2 == 1: + sources.append(x) + + # apply multibox head to source layers + for (x, l, c) in zip(sources, self.loc, self.conf): + loc.append(l(x).permute(0, 2, 3, 1).contiguous()) + conf.append(c(x).permute(0, 2, 3, 1).contiguous()) + + loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1) + conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1) + output = (loc.view(loc.size(0), -1, 4), + conf.view(conf.size(0), -1, self.num_classes), + self.priors + ) + return output + + def load_weights(self, base_file): + other, ext = os.path.splitext(base_file) + if ext == '.pkl' or '.pth': + print('Loading weights into state dict...') + self.load_state_dict(torch.load(base_file, map_location=lambda storage, loc: storage)) + print('Finished!') + else: + print('Sorry only .pth and .pkl files supported.') + + +# This function is derived from torchvision VGG make_layers() +# https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py +def vgg(cfg, i, batch_norm=False): + layers = [] + in_channels = i + for v in cfg: + if v == 'M': + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + elif v == 'C': + layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)] + else: + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1) + conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6) + conv7 = nn.Conv2d(1024, 1024, kernel_size=1) + layers += [pool5, conv6, + nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)] + return layers + + +def add_extras(cfg, i, batch_norm=False): + # Extra layers added to VGG for feature scaling + layers = [] + in_channels = i + flag = False + for k, v in enumerate(cfg): + if in_channels != 'S': + if v == 'S': + layers += [nn.Conv2d(in_channels, cfg[k + 1], + kernel_size=(1, 3)[flag], stride=2, padding=1)] + else: + layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])] + flag = not flag + in_channels = v + return layers + + +def multibox(vgg, extra_layers, cfg, num_classes): + loc_layers = [] + conf_layers = [] + vgg_source = [24, -2] + for k, v in enumerate(vgg_source): + loc_layers += [nn.Conv2d(vgg[v].out_channels, + cfg[k] * 4, kernel_size=3, padding=1)] + conf_layers += [nn.Conv2d(vgg[v].out_channels, + cfg[k] * num_classes, kernel_size=3, padding=1)] + for k, v in enumerate(extra_layers[1::2], 2): + loc_layers += [nn.Conv2d(v.out_channels, cfg[k] + * 4, kernel_size=3, padding=1)] + conf_layers += [nn.Conv2d(v.out_channels, cfg[k] + * num_classes, kernel_size=3, padding=1)] + return vgg, extra_layers, (loc_layers, 
conf_layers) + + +base = { + '300': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M', + 512, 512, 512], + '512': [], +} +extras = { + '300': [256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256], + '512': [], +} +mbox = { + '300': [4, 6, 6, 6, 4, 4], # number of boxes per feature map location + '512': [], +} + + +def build_ssd(size=300, num_classes=21): + + if size != 300: + print("Error: Sorry only SSD300 is supported currently!") + return + + return SSD(*multibox(vgg(base[str(size)], 3), + add_extras(extras[str(size)], 1024), + mbox[str(size)], num_classes), num_classes) diff --git a/test-ucf24.py b/test-ucf24.py new file mode 100644 index 0000000..e30dc02 --- /dev/null +++ b/test-ucf24.py @@ -0,0 +1,223 @@ +"""Adapted from: + @longcw faster_rcnn_pytorch: https://github.com/longcw/faster_rcnn_pytorch + @rbgirshick py-faster-rcnn https://github.com/rbgirshick/py-faster-rcnn + Which was adopated by: Ellis Brown, Max deGroot + https://github.com/amdegroot/ssd.pytorch + + Further: + Updated by Gurkirt Singh for ucf101-24 dataset + Licensed under The MIT License [see LICENSE for details] + +""" + +import torch +import torch.backends.cudnn as cudnn +from torch.autograd import Variable +from data import AnnotationTransform, UCF24Detection, BaseTransform, CLASSES, detection_collate, v2 +from ssd import build_ssd +import torch.utils.data as data +from layers.box_utils import decode, nms +from utils.evaluation import evaluate_detections +import os, time +import argparse +import numpy as np +import pickle +import scipy.io as sio # to save detection as mat files +cfg = v2 + +def str2bool(v): + return v.lower() in ("yes", "true", "t", "1") + +parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training') +parser.add_argument('--version', default='v2', help='conv11_2(v2) or pool6(v1) as last layer') +parser.add_argument('--basenet', default='vgg16_reducedfc.pth', help='pretrained base model') +parser.add_argument('--dataset', default='ucf24', help='pretrained base model') +parser.add_argument('--ssd_dim', default=300, type=int, help='Input Size for SSD') # only support 300 now +parser.add_argument('--input_type', default='rgb', type=str, help='INput tyep default rgb can take flow as well') +parser.add_argument('--jaccard_threshold', default=0.5, type=float, help='Min Jaccard index for matching') +parser.add_argument('--batch_size', default=32, type=int, help='Batch size for training') +parser.add_argument('--resume', default=None, type=str, help='Resume from checkpoint') +parser.add_argument('--num_workers', default=0, type=int, help='Number of workers used in dataloading') +parser.add_argument('--max_iter', default=90000, type=int, help='Number of training iterations') +parser.add_argument('--eval_iter', default='50000,70000,90000', type=str, help='Number of training iterations') +parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model') +parser.add_argument('--ngpu', default=1, type=str2bool, help='Use cuda to train model') +parser.add_argument('--lr', '--learning-rate', default=0.001, type=float, help='initial learning rate') +parser.add_argument('--visdom', default=False, type=str2bool, help='Use visdom to for loss visualization') +parser.add_argument('--data_root', default='/mnt/mars-fast/datasets/', help='Location of VOC root directory') +parser.add_argument('--save_root', default='/mnt/mars-gamma/ssd-work/', help='Location to save checkpoint models') +parser.add_argument('--iou_thresh', default=0.5, type=float, 
help='Evaluation threshold') +parser.add_argument('--conf_thresh', default=0.01, type=float, help='Confidence threshold for evaluation') +parser.add_argument('--nms_thresh', default=0.45, type=float, help='NMS threshold') +parser.add_argument('--topk', default=50, type=int, help='topk for evaluation') + +args = parser.parse_args() + +if args.cuda and torch.cuda.is_available(): + torch.set_default_tensor_type('torch.cuda.FloatTensor') +else: + torch.set_default_tensor_type('torch.FloatTensor') + + +def test_net(net, save_root, exp_name, input_type, dataset, iteration, num_classes, thresh=0.5 ): + """ Test a SSD network on an Action image database. """ + + val_data_loader = data.DataLoader(dataset, args.batch_size, num_workers=args.num_workers, + shuffle=False, collate_fn=detection_collate, pin_memory=True) + image_ids = dataset.ids + save_ids = [] + val_step = 250 + num_images = len(dataset) + video_list = dataset.video_list + det_boxes = [[] for _ in range(len(CLASSES))] + gt_boxes = [] + print_time = True + batch_iterator = None + count = 0 + torch.cuda.synchronize() + ts = time.perf_counter() + num_batches = len(val_data_loader) + det_file = save_root + 'cache/' + exp_name + '/detection-'+str(iteration).zfill(6)+'.pkl' + print('Number of images ', len(dataset),' number of batchs', num_batches) + frame_save_dir = save_root+'detections/CONV-'+input_type+'-'+args.listid+'-'+str(iteration).zfill(6)+'/' + print('\n\n\nDetections will be store in ',frame_save_dir,'\n\n') + for val_itr in range(len(val_data_loader)): + if not batch_iterator: + batch_iterator = iter(val_data_loader) + + torch.cuda.synchronize() + t1 = time.perf_counter() + + images, targets, img_indexs = next(batch_iterator) + batch_size = images.size(0) + height, width = images.size(2), images.size(3) + + if args.cuda: + images = Variable(images.cuda(), volatile=True) + output = net(images) + + loc_data = output[0] + conf_preds = output[1] + prior_data = output[2] + + if print_time and val_itr%val_step == 0: + torch.cuda.synchronize() + tf = time.perf_counter() + print('Forward Time {:0.3f}'.format(tf - t1)) + for b in range(batch_size): + gt = targets[b].numpy() + gt[:, 0] *= width + gt[:, 2] *= width + gt[:, 1] *= height + gt[:, 3] *= height + gt_boxes.append(gt) + decoded_boxes = decode(loc_data[b].data, prior_data.data, cfg['variance']).clone() + conf_scores = net.softmax(conf_preds[b]).data.clone() + index = img_indexs[b] + annot_info = image_ids[index] + + frame_num = annot_info[1]; video_id = annot_info[0]; videoname = video_list[video_id] + output_dir = frame_save_dir+videoname + if not os.path.isdir(output_dir): + os.makedirs(output_dir) + + output_file_name = output_dir+'/{:05d}.mat'.format(int(frame_num)) + save_ids.append(output_file_name) + sio.savemat(output_file_name, mdict={'scores':conf_scores.cpu().numpy(),'loc':decoded_boxes.cpu().numpy()}) + + for cl_ind in range(1, num_classes): + scores = conf_scores[:, cl_ind].squeeze() + c_mask = scores.gt(args.conf_thresh) # greater than minmum threshold + scores = scores[c_mask].squeeze() + # print('scores size',scores.size()) + if scores.dim() == 0: + # print(len(''), ' dim ==0 ') + det_boxes[cl_ind - 1].append(np.asarray([])) + continue + boxes = decoded_boxes.clone() + l_mask = c_mask.unsqueeze(1).expand_as(boxes) + boxes = boxes[l_mask].view(-1, 4) + # idx of highest scoring and non-overlapping boxes per class + ids, counts = nms(boxes, scores, args.nms_thresh, args.topk) # idsn - ids after nms + scores = scores[ids[:counts]].cpu().numpy() + boxes = 
boxes[ids[:counts]].cpu().numpy() + # print('boxes sahpe',boxes.shape) + boxes[:, 0] *= width + boxes[:, 2] *= width + boxes[:, 1] *= height + boxes[:, 3] *= height + + for ik in range(boxes.shape[0]): + boxes[ik, 0] = max(0, boxes[ik, 0]) + boxes[ik, 2] = min(width, boxes[ik, 2]) + boxes[ik, 1] = max(0, boxes[ik, 1]) + boxes[ik, 3] = min(height, boxes[ik, 3]) + + cls_dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=True) + det_boxes[cl_ind - 1].append(cls_dets) + + count += 1 + if val_itr%val_step == 0: + torch.cuda.synchronize() + te = time.perf_counter() + print('im_detect: {:d}/{:d} time taken {:0.3f}'.format(count, num_images, te - ts)) + torch.cuda.synchronize() + ts = time.perf_counter() + if print_time and val_itr%val_step == 0: + torch.cuda.synchronize() + te = time.perf_counter() + print('NMS stuff Time {:0.3f}'.format(te - tf)) + print('Evaluating detections for itration number ', iteration) + + #Save detection after NMS along with GT + with open(det_file, 'wb') as f: + pickle.dump([gt_boxes, det_boxes, save_ids], f, pickle.HIGHEST_PROTOCOL) + + return evaluate_detections(gt_boxes, det_boxes, CLASSES, iou_thresh=thresh) + + +def main(): + + means = (104, 117, 123) # only support voc now + + exp_name = 'CONV-SSD-{}-{}-bs-{}-{}-lr-{:05d}'.format(args.dataset, args.input_type, + args.batch_size, args.basenet[:-14], int(args.lr * 100000)) + + args.save_root += args.dataset+'/' + args.data_root += args.dataset+'/' + args.listid = '01' ## would be usefull in JHMDB-21 + print('Exp name', exp_name, args.listid) + for iteration in [int(itr) for itr in args.eval_iter.split(',')]: + log_file = open(args.save_root + 'cache/' + exp_name + "/testing-{:d}.log".format(iteration), "w", 1) + log_file.write(exp_name + '\n') + trained_model_path = args.save_root + 'cache/' + exp_name + '/ssd300_ucf24_' + repr(iteration) + '.pth' + log_file.write(trained_model_path+'\n') + num_classes = len(CLASSES) + 1 #7 +1 background + net = build_ssd(300, num_classes) # initialize SSD + net.load_state_dict(torch.load(trained_model_path)) + net.eval() + if args.cuda: + net = net.cuda() + cudnn.benchmark = True + print('Finished loading model %d !' 
% iteration) + # Load dataset + dataset = UCF24Detection(args.data_root, 'test', BaseTransform(args.ssd_dim, means), AnnotationTransform(), + input_type=args.input_type, full_test=True) + # evaluation + torch.cuda.synchronize() + tt0 = time.perf_counter() + log_file.write('Testing net \n') + mAP, ap_all, ap_strs = test_net(net, args.save_root, exp_name, args.input_type, dataset, iteration, num_classes) + for ap_str in ap_strs: + print(ap_str) + log_file.write(ap_str + '\n') + ptr_str = '\nMEANAP:::=>' + str(mAP) + '\n' + print(ptr_str) + log_file.write(ptr_str) + + torch.cuda.synchronize() + print('Complete set time {:0.2f}'.format(time.perf_counter() - tt0)) + log_file.close() + +if __name__ == '__main__': + main() diff --git a/train-ucf24.py b/train-ucf24.py new file mode 100644 index 0000000..6b67758 --- /dev/null +++ b/train-ucf24.py @@ -0,0 +1,412 @@ + +""" Adapted from: + @longcw faster_rcnn_pytorch: https://github.com/longcw/faster_rcnn_pytorch + @rbgirshick py-faster-rcnn https://github.com/rbgirshick/py-faster-rcnn + Which was adopated by: Ellis Brown, Max deGroot + https://github.com/amdegroot/ssd.pytorch + + Further: + Updated by Gurkirt Singh for ucf101-24 dataset + Licensed under The MIT License [see LICENSE for details] +""" + +import os +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.init as init +import argparse +from torch.autograd import Variable +import torch.utils.data as data +from data import v2, UCF24Detection, AnnotationTransform, detection_collate, CLASSES, BaseTransform +from utils.augmentations import SSDAugmentation +from layers.modules import MultiBoxLoss +from ssd import build_ssd +import numpy as np +import time +from utils.evaluation import evaluate_detections +from layers.box_utils import decode, nms +from utils import AverageMeter +from torch.optim.lr_scheduler import MultiStepLR + +def str2bool(v): + return v.lower() in ("yes", "true", "t", "1") + + +parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training') +parser.add_argument('--version', default='v2', help='conv11_2(v2) or pool6(v1) as last layer') +parser.add_argument('--basenet', default='vgg16_reducedfc.pth', help='pretrained base model') +parser.add_argument('--dataset', default='ucf24', help='pretrained base model') +parser.add_argument('--ssd_dim', default=300, type=int, help='Input Size for SSD') # only support 300 now +parser.add_argument('--input_type', default='rgb', type=str, help='INput tyep default rgb options are [rgb,brox,fastOF]') +parser.add_argument('--jaccard_threshold', default=0.5, type=float, help='Min Jaccard index for matching') +parser.add_argument('--batch_size', default=32, type=int, help='Batch size for training') +parser.add_argument('--resume', default=None, type=str, help='Resume from checkpoint') +parser.add_argument('--num_workers', default=0, type=int, help='Number of workers used in dataloading') +parser.add_argument('--max_iter', default=90000, type=int, help='Number of training iterations') +parser.add_argument('--man_seed', default=123, type=int, help='manualseed for reproduction') +parser.add_argument('--cuda', default=True, type=str2bool, help='Use cuda to train model') +parser.add_argument('--ngpu', default=1, type=str2bool, help='Use cuda to train model') +parser.add_argument('--lr', '--learning-rate', default=0.0005, type=float, help='initial learning rate') +parser.add_argument('--momentum', default=0.9, type=float, help='momentum') +parser.add_argument('--stepvalues', default='70000,120000', 
type=str, help='iter number wher elearing rate to be dropped') +parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD') +parser.add_argument('--gamma', default=0.2, type=float, help='Gamma update for SGD') +parser.add_argument('--log_iters', default=True, type=bool, help='Print the loss at each iteration') +parser.add_argument('--visdom', default=False, type=str2bool, help='Use visdom to for loss visualization') +parser.add_argument('--data_root', default='/mnt/mars-fast/datasets/', help='Location of VOC root directory') +parser.add_argument('--save_root', default='/mnt/mars-gamma/ssd-work/', help='Location to save checkpoint models') +parser.add_argument('--iou_thresh', default=0.5, type=float, help='Evaluation threshold') +parser.add_argument('--conf_thresh', default=0.01, type=float, help='Confidence threshold for evaluation') +parser.add_argument('--nms_thresh', default=0.45, type=float, help='NMS threshold') +parser.add_argument('--topk', default=50, type=int, help='topk for evaluation') + +## Parse arguments +args = parser.parse_args() +## set random seeds +np.random.seed(args.man_seed) +torch.manual_seed(args.man_seed) +if args.cuda: + torch.cuda.manual_seed_all(args.man_seed) + +if args.cuda and torch.cuda.is_available(): + torch.set_default_tensor_type('torch.cuda.FloatTensor') +else: + torch.set_default_tensor_type('torch.FloatTensor') + + +def main(): + args.cfg = v2 + args.train_sets = 'train' + args.means = (104, 117, 123) + num_classes = len(CLASSES) + 1 + args.num_classes = num_classes + args.stepvalues = [int(val) for val in args.stepvalues.split(',')] + args.loss_reset_step = 30 + args.eval_step = 10000 + args.print_step = 10 + + ## Define the experiment Name will used to same directory and ENV for visdom + args.exp_name = 'CONV-SSD-{}-{}-bs-{}-{}-lr-{:05d}'.format(args.dataset, + args.input_type, args.batch_size, args.basenet[:-14], int(args.lr*100000)) + + args.save_root += args.dataset+'/' + args.save_root = args.save_root+'cache/'+args.exp_name+'/' + + if not os.path.isdir(args.save_root): + os.makedirs(args.save_root) + + net = build_ssd(300, args.num_classes) + if args.input_type == 'fastOF': + print('Download pretrained brox flow trained model weights and place them at:::=> ',args.data_root + 'ucf24/train_data/brox_wieghts.pth') + pretrained_weights = args.data_root + 'ucf24/train_data/brox_wieghts.pth' + print('Loading base network...') + net.load_state_dict(torch.load(pretrained_weights)) + else: + vgg_weights = torch.load(args.data_root +'ucf24/train_data/' + args.basenet) + print('Loading base network...') + net.vgg.load_state_dict(vgg_weights) + + args.data_root += args.dataset + '/' + + if args.cuda: + net = net.cuda() + + def xavier(param): + init.xavier_uniform(param) + + def weights_init(m): + if isinstance(m, nn.Conv2d): + xavier(m.weight.data) + m.bias.data.zero_() + + + print('Initializing weights for extra layers and HEADs...') + # initialize newly added layers' weights with xavier method + net.extras.apply(weights_init) + net.loc.apply(weights_init) + net.conf.apply(weights_init) + + parameter_dict = dict(net.named_parameters()) # Get parmeter of network in dictionary format wtih name being key + params = [] + + #Set different learning rate to bias layers and set their weight_decay to 0 + for name, param in parameter_dict.items(): + if name.find('bias') > -1: + print(name, 'layer parameters will be trained @ {}'.format(args.lr*2)) + params += [{'params': [param], 'lr': args.lr*2, 'weight_decay': 0}] + else: + 
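+            # non-bias parameters keep the base learning rate and the configured
+            # weight decay; the bias terms above get 2x the base lr and no weight decay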
print(name, 'layer parameters will be trained @ {}'.format(args.lr)) + params += [{'params':[param], 'lr': args.lr, 'weight_decay':args.weight_decay}] + + optimizer = optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + criterion = MultiBoxLoss(args.num_classes, 0.5, True, 0, True, 3, 0.5, False, args.cuda) + scheduler = MultiStepLR(optimizer, milestones=args.stepvalues, gamma=args.gamma) + train(args, net, optimizer, criterion, scheduler) + + +def train(args, net, optimizer, criterion, scheduler): + log_file = open(args.save_root+"training.log", "w", 1) + log_file.write(args.exp_name+'\n') + for arg in vars(args): + print(arg, getattr(args, arg)) + log_file.write(str(arg)+': '+str(getattr(args, arg))+'\n') + log_file.write(str(net)) + net.train() + + # loss counters + batch_time = AverageMeter() + losses = AverageMeter() + loc_losses = AverageMeter() + cls_losses = AverageMeter() + + print('Loading Dataset...') + train_dataset = UCF24Detection(args.data_root, args.train_sets, SSDAugmentation(args.ssd_dim, args.means), + AnnotationTransform(), input_type=args.input_type) + val_dataset = UCF24Detection(args.data_root, 'test', BaseTransform(args.ssd_dim, args.means), + AnnotationTransform(), input_type=args.input_type, + full_test=False) + epoch_size = len(train_dataset) // args.batch_size + print('Training SSD on', train_dataset.name) + + if args.visdom: + + import visdom + viz = visdom.Visdom() + viz.port = 8097 + viz.env = args.exp_name + # initialize visdom loss plot + lot = viz.line( + X=torch.zeros((1,)).cpu(), + Y=torch.zeros((1, 6)).cpu(), + opts=dict( + xlabel='Iteration', + ylabel='Loss', + title='Current SSD Training Loss', + legend=['REG', 'CLS', 'AVG', 'S-REG', ' S-CLS', ' S-AVG'] + ) + ) + # initialize visdom meanAP and class APs plot + legends = ['meanAP'] + for cls in CLASSES: + legends.append(cls) + val_lot = viz.line( + X=torch.zeros((1,)).cpu(), + Y=torch.zeros((1,args.num_classes)).cpu(), + opts=dict( + xlabel='Iteration', + ylabel='Mean AP', + title='Current SSD Validation mean AP', + legend=legends + ) + ) + + + batch_iterator = None + train_data_loader = data.DataLoader(train_dataset, args.batch_size, num_workers=args.num_workers, + shuffle=True, collate_fn=detection_collate, pin_memory=True) + val_data_loader = data.DataLoader(val_dataset, args.batch_size, num_workers=args.num_workers, + shuffle=False, collate_fn=detection_collate, pin_memory=True) + itr_count = 0 + torch.cuda.synchronize() + t0 = time.perf_counter() + for iteration in range(args.max_iter + 1): + if (not batch_iterator) or (iteration % epoch_size == 0): + # create batch iterator + batch_iterator = iter(train_data_loader) + + # load train data + images, targets, img_indexs = next(batch_iterator) + if args.cuda: + images = Variable(images.cuda()) + targets = [Variable(anno.cuda(), volatile=True) for anno in targets] + else: + images = Variable(images) + targets = [Variable(anno, volatile=True) for anno in targets] + # forward + out = net(images) + # backprop + optimizer.zero_grad() + + loss_l, loss_c = criterion(out, targets) + loss = loss_l + loss_c + loss.backward() + optimizer.step() + scheduler.step() + loc_loss = loss_l.data[0] + conf_loss = loss_c.data[0] + # print('Loss data type ',type(loc_loss)) + loc_losses.update(loc_loss) + cls_losses.update(conf_loss) + losses.update((loc_loss + conf_loss)/2.0) + + + if iteration % args.print_step == 0 and iteration>0: + if args.visdom: + losses_list = [loc_losses.val, cls_losses.val, losses.val, loc_losses.avg, 
cls_losses.avg, losses.avg] + viz.line(X=torch.ones((1, 6)).cpu() * iteration, + Y=torch.from_numpy(np.asarray(losses_list)).unsqueeze(0).cpu(), + win=lot, + update='append') + + + torch.cuda.synchronize() + t1 = time.perf_counter() + batch_time.update(t1 - t0) + + print_line = 'Itration {:06d}/{:06d} loc-loss {:.3f}({:.3f}) cls-loss {:.3f}({:.3f}) ' \ + 'average-loss {:.3f}({:.3f}) Timer {:0.3f}({:0.3f})'.format( + iteration, args.max_iter, loc_losses.val, loc_losses.avg, cls_losses.val, + cls_losses.avg, losses.val, losses.avg, batch_time.val, batch_time.avg) + + torch.cuda.synchronize() + t0 = time.perf_counter() + log_file.write(print_line+'\n') + print(print_line) + + # if args.visdom and args.send_images_to_visdom: + # random_batch_index = np.random.randint(images.size(0)) + # viz.image(images.data[random_batch_index].cpu().numpy()) + itr_count += 1 + + if itr_count % args.loss_reset_step == 0 and itr_count > 0: + loc_losses.reset() + cls_losses.reset() + losses.reset() + batch_time.reset() + print('Reset accumulators of ', args.exp_name,' at', itr_count*args.print_step) + itr_count = 0 + + if (iteration % args.eval_step == 0 or iteration == 5000) and iteration>0: + torch.cuda.synchronize() + tvs = time.perf_counter() + print('Saving state, iter:', iteration) + torch.save(net.state_dict(), args.save_root+'ssd300_ucf24_' + + repr(iteration) + '.pth') + + net.eval() # switch net to evaluation mode + mAP, ap_all, ap_strs = validate(args, net, val_data_loader, val_dataset, iteration, iou_thresh=args.iou_thresh) + + for ap_str in ap_strs: + print(ap_str) + log_file.write(ap_str+'\n') + ptr_str = '\nMEANAP:::=>'+str(mAP)+'\n' + print(ptr_str) + log_file.write(ptr_str) + + if args.visdom: + aps = [mAP] + for ap in ap_all: + aps.append(ap) + viz.line( + X=torch.ones((1, args.num_classes)).cpu() * iteration, + Y=torch.from_numpy(np.asarray(aps)).unsqueeze(0).cpu(), + win=val_lot, + update='append' + ) + net.train() # Switch net back to training mode + torch.cuda.synchronize() + t0 = time.perf_counter() + prt_str = '\nValidation TIME::: {:0.3f}\n\n'.format(t0-tvs) + print(prt_str) + log_file.write(ptr_str) + + log_file.close() + + +def validate(args, net, val_data_loader, val_dataset, iteration_num, iou_thresh=0.5): + """Test a SSD network on an image database.""" + print('Validating at ', iteration_num) + num_images = len(val_dataset) + num_classes = args.num_classes + + det_boxes = [[] for _ in range(len(CLASSES))] + gt_boxes = [] + print_time = True + batch_iterator = None + val_step = 100 + count = 0 + torch.cuda.synchronize() + ts = time.perf_counter() + + for val_itr in range(len(val_data_loader)): + if not batch_iterator: + batch_iterator = iter(val_data_loader) + + torch.cuda.synchronize() + t1 = time.perf_counter() + + images, targets, img_indexs = next(batch_iterator) + batch_size = images.size(0) + height, width = images.size(2), images.size(3) + + if args.cuda: + images = Variable(images.cuda(), volatile=True) + output = net(images) + + loc_data = output[0] + conf_preds = output[1] + prior_data = output[2] + + if print_time and val_itr%val_step == 0: + torch.cuda.synchronize() + tf = time.perf_counter() + print('Forward Time {:0.3f}'.format(tf-t1)) + for b in range(batch_size): + gt = targets[b].numpy() + gt[:,0] *= width + gt[:,2] *= width + gt[:,1] *= height + gt[:,3] *= height + gt_boxes.append(gt) + decoded_boxes = decode(loc_data[b].data, prior_data.data, args.cfg['variance']).clone() + conf_scores = net.softmax(conf_preds[b]).data.clone() + + for cl_ind in range(1, 
num_classes): + scores = conf_scores[:, cl_ind].squeeze() + c_mask = scores.gt(args.conf_thresh) # greater than minmum threshold + scores = scores[c_mask].squeeze() + # print('scores size',scores.size()) + if scores.dim() == 0: + # print(len(''), ' dim ==0 ') + det_boxes[cl_ind - 1].append(np.asarray([])) + continue + boxes = decoded_boxes.clone() + l_mask = c_mask.unsqueeze(1).expand_as(boxes) + boxes = boxes[l_mask].view(-1, 4) + # idx of highest scoring and non-overlapping boxes per class + ids, counts = nms(boxes, scores, args.nms_thresh, args.topk) # idsn - ids after nms + scores = scores[ids[:counts]].cpu().numpy() + boxes = boxes[ids[:counts]].cpu().numpy() + # print('boxes sahpe',boxes.shape) + boxes[:,0] *= width + boxes[:,2] *= width + boxes[:,1] *= height + boxes[:,3] *= height + + for ik in range(boxes.shape[0]): + boxes[ik, 0] = max(0, boxes[ik, 0]) + boxes[ik, 2] = min(width, boxes[ik, 2]) + boxes[ik, 1] = max(0, boxes[ik, 1]) + boxes[ik, 3] = min(height, boxes[ik, 3]) + + cls_dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=True) + + det_boxes[cl_ind-1].append(cls_dets) + count += 1 + if val_itr%val_step == 0: + torch.cuda.synchronize() + te = time.perf_counter() + print('im_detect: {:d}/{:d} time taken {:0.3f}'.format(count, num_images, te-ts)) + torch.cuda.synchronize() + ts = time.perf_counter() + if print_time and val_itr%val_step == 0: + torch.cuda.synchronize() + te = time.perf_counter() + print('NMS stuff Time {:0.3f}'.format(te - tf)) + print('Evaluating detections for itration number ', iteration_num) + return evaluate_detections(gt_boxes, det_boxes, CLASSES, iou_thresh=iou_thresh) + + +if __name__ == '__main__': + main() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..4d2f6d4 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,16 @@ +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count \ No newline at end of file diff --git a/utils/augmentations.py b/utils/augmentations.py new file mode 100644 index 0000000..0e4134a --- /dev/null +++ b/utils/augmentations.py @@ -0,0 +1,425 @@ + +""" Agumentation code for SSD network + +Original author: Ellis Brown, Max deGroot for VOC dataset +https://github.com/amdegroot/ssd.pytorch + +""" + +import torch +from torchvision import transforms +import cv2 +import numpy as np +import types +from numpy import random + + +def intersect(box_a, box_b): + max_xy = np.minimum(box_a[:, 2:], box_b[2:]) + min_xy = np.maximum(box_a[:, :2], box_b[:2]) + inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf) + return inter[:, 0] * inter[:, 1] + + +def jaccard_numpy(box_a, box_b): + """Compute the jaccard overlap of two sets of boxes. The jaccard overlap + is simply the intersection over union of two boxes. 
+ E.g.: + A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) + Args: + box_a: Multiple bounding boxes, Shape: [num_boxes,4] + box_b: Single bounding box, Shape: [4] + Return: + jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]] + """ + inter = intersect(box_a, box_b) + area_a = ((box_a[:, 2]-box_a[:, 0]) * + (box_a[:, 3]-box_a[:, 1])) # [A,B] + area_b = ((box_b[2]-box_b[0]) * + (box_b[3]-box_b[1])) # [A,B] + union = area_a + area_b - inter + return inter / union # [A,B] + + +class Compose(object): + """Composes several augmentations together. + Args: + transforms (List[Transform]): list of transforms to compose. + Example: + >>> augmentations.Compose([ + >>> transforms.CenterCrop(10), + >>> transforms.ToTensor(), + >>> ]) + """ + + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, img, boxes=None, labels=None): + for t in self.transforms: + img, boxes, labels = t(img, boxes, labels) + return img, boxes, labels + + +class Lambda(object): + """Applies a lambda as a transform.""" + + def __init__(self, lambd): + assert isinstance(lambd, types.LambdaType) + self.lambd = lambd + + def __call__(self, img, boxes=None, labels=None): + return self.lambd(img, boxes, labels) + + +class ConvertFromInts(object): + def __call__(self, image, boxes=None, labels=None): + return image.astype(np.float32), boxes, labels + + +class SubtractMeans(object): + def __init__(self, mean): + self.mean = np.array(mean, dtype=np.float32) + + def __call__(self, image, boxes=None, labels=None): + image = image.astype(np.float32) + image -= self.mean + return image.astype(np.float32), boxes, labels + + +class ToAbsoluteCoords(object): + def __call__(self, image, boxes=None, labels=None): + height, width, channels = image.shape + boxes[:, 0] *= width + boxes[:, 2] *= width + boxes[:, 1] *= height + boxes[:, 3] *= height + + return image, boxes, labels + + +class ToPercentCoords(object): + def __call__(self, image, boxes=None, labels=None): + height, width, channels = image.shape + boxes[:, 0] /= width + boxes[:, 2] /= width + boxes[:, 1] /= height + boxes[:, 3] /= height + + return image, boxes, labels + + +class Resize(object): + def __init__(self, size=300): + self.size = size + + def __call__(self, image, boxes=None, labels=None): + image = cv2.resize(image, (self.size, + self.size)) + return image, boxes, labels + + +class RandomSaturation(object): + def __init__(self, lower=0.5, upper=1.5): + self.lower = lower + self.upper = upper + assert self.upper >= self.lower, "contrast upper must be >= lower." + assert self.lower >= 0, "contrast lower must be non-negative." 
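+    # note: this transform scales the saturation channel, so it expects an HSV
+    # image; PhotometricDistort further below applies ConvertColor(transform='HSV')
+    # before it and converts back to BGR afterwards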
+ + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + image[:, :, 1] *= random.uniform(self.lower, self.upper) + + return image, boxes, labels + + +class RandomHue(object): + def __init__(self, delta=18.0): + assert delta >= 0.0 and delta <= 360.0 + self.delta = delta + + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + image[:, :, 0] += random.uniform(-self.delta, self.delta) + image[:, :, 0][image[:, :, 0] > 360.0] -= 360.0 + image[:, :, 0][image[:, :, 0] < 0.0] += 360.0 + return image, boxes, labels + + +class RandomLightingNoise(object): + def __init__(self): + self.perms = ((0, 1, 2), (0, 2, 1), + (1, 0, 2), (1, 2, 0), + (2, 0, 1), (2, 1, 0)) + + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + swap = self.perms[random.randint(len(self.perms))] + shuffle = SwapChannels(swap) # shuffle channels + image = shuffle(image) + return image, boxes, labels + + +class ConvertColor(object): + def __init__(self, current='BGR', transform='HSV'): + self.transform = transform + self.current = current + + def __call__(self, image, boxes=None, labels=None): + if self.current == 'BGR' and self.transform == 'HSV': + image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + elif self.current == 'HSV' and self.transform == 'BGR': + image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR) + else: + raise NotImplementedError + return image, boxes, labels + + +class RandomContrast(object): + def __init__(self, lower=0.5, upper=1.5): + self.lower = lower + self.upper = upper + assert self.upper >= self.lower, "contrast upper must be >= lower." + assert self.lower >= 0, "contrast lower must be non-negative." + + # expects float image + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + alpha = random.uniform(self.lower, self.upper) + image *= alpha + return image, boxes, labels + + +class RandomBrightness(object): + def __init__(self, delta=32): + assert delta >= 0.0 + assert delta <= 255.0 + self.delta = delta + + def __call__(self, image, boxes=None, labels=None): + if random.randint(2): + delta = random.uniform(-self.delta, self.delta) + image += delta + return image, boxes, labels + + +class ToCV2Image(object): + def __call__(self, tensor, boxes=None, labels=None): + return tensor.cpu().numpy().astype(np.float32).transpose((1, 2, 0)), boxes, labels + + +class ToTensor(object): + def __call__(self, cvimage, boxes=None, labels=None): + return torch.from_numpy(cvimage.astype(np.float32)).permute(2, 0, 1), boxes, labels + + +class RandomSampleCrop(object): + """Crop + Arguments: + img (Image): the image being input during training + boxes (Tensor): the original bounding boxes in pt form + labels (Tensor): the class labels for each bbox + mode (float tuple): the min and max jaccard overlaps + Return: + (img, boxes, classes) + img (Image): the cropped image + boxes (Tensor): the adjusted bounding boxes in pt form + labels (Tensor): the class labels for each bbox + """ + def __init__(self): + self.sample_options = ( + # using entire original input image + None, + # sample a patch s.t. 
MIN jaccard w/ obj in .1,.3,.4,.7,.9 + (0.1, None), + (0.3, None), + (0.7, None), + (0.9, None), + # randomly sample a patch + (None, None), + ) + + def __call__(self, image, boxes=None, labels=None): + height, width, _ = image.shape + while True: + # randomly choose a mode + mode = random.choice(self.sample_options) + if mode is None: + return image, boxes, labels + + min_iou, max_iou = mode + if min_iou is None: + min_iou = float('-inf') + if max_iou is None: + max_iou = float('inf') + + # max trails (50) + for _ in range(50): + current_image = image + + w = random.uniform(0.3 * width, width) + h = random.uniform(0.3 * height, height) + + # aspect ratio constraint b/t .5 & 2 + if h / w < 0.5 or h / w > 2: + continue + + left = random.uniform(width - w) + top = random.uniform(height - h) + + # convert to integer rect x1,y1,x2,y2 + rect = np.array([int(left), int(top), int(left+w), int(top+h)]) + + # calculate IoU (jaccard overlap) b/t the cropped and gt boxes + overlap = jaccard_numpy(boxes, rect) + + # is min and max overlap constraint satisfied? if not try again + if overlap.min() < min_iou and max_iou < overlap.max(): + continue + + # cut the crop from the image + current_image = current_image[rect[1]:rect[3], rect[0]:rect[2], + :] + + # keep overlap with gt box IF center in sampled patch + centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0 + + # mask in all gt boxes that above and to the left of centers + m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1]) + + # mask in all gt boxes that under and to the right of centers + m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1]) + + # mask in that both m1 and m2 are true + mask = m1 * m2 + + # have any valid boxes? try again if not + if not mask.any(): + continue + + # take only matching gt boxes + current_boxes = boxes[mask, :].copy() + + # take only matching gt labels + current_labels = labels[mask] + + # should we use the box left and top corner or the crop's + current_boxes[:, :2] = np.maximum(current_boxes[:, :2], + rect[:2]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, :2] -= rect[:2] + + current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:], + rect[2:]) + # adjust to crop (by substracting crop's left,top) + current_boxes[:, 2:] -= rect[:2] + + return current_image, current_boxes, current_labels + + +class Expand(object): + def __init__(self, mean): + self.mean = mean + + def __call__(self, image, boxes, labels): + if random.randint(2): + return image, boxes, labels + + height, width, depth = image.shape + ratio = random.uniform(1, 4) + left = random.uniform(0, width*ratio - width) + top = random.uniform(0, height*ratio - height) + + expand_image = np.zeros( + (int(height*ratio), int(width*ratio), depth), + dtype=image.dtype) + expand_image[:, :, :] = self.mean + expand_image[int(top):int(top + height), + int(left):int(left + width)] = image + image = expand_image + + boxes = boxes.copy() + boxes[:, :2] += (int(left), int(top)) + boxes[:, 2:] += (int(left), int(top)) + + return image, boxes, labels + + +class RandomMirror(object): + def __call__(self, image, boxes, classes): + _, width, _ = image.shape + if random.randint(2): + image = image[:, ::-1] + boxes = boxes.copy() + boxes[:, 0::2] = width - boxes[:, 2::-2] + return image, boxes, classes + + +class SwapChannels(object): + """Transforms a tensorized image by swapping the channels in the order + specified in the swap tuple. 
+ Args: + swaps (int triple): final order of channels + eg: (2, 1, 0) + """ + + def __init__(self, swaps): + self.swaps = swaps + + def __call__(self, image): + """ + Args: + image (Tensor): image tensor to be transformed + Return: + a tensor with channels swapped according to swap + """ + # if torch.is_tensor(image): + # image = image.data.cpu().numpy() + # else: + # image = np.array(image) + image = image[:, :, self.swaps] + return image + + +class PhotometricDistort(object): + def __init__(self): + self.pd = [ + RandomContrast(), + ConvertColor(transform='HSV'), + RandomSaturation(), + RandomHue(), + ConvertColor(current='HSV', transform='BGR'), + RandomContrast() + ] + self.rand_brightness = RandomBrightness() + self.rand_light_noise = RandomLightingNoise() + + def __call__(self, image, boxes, labels): + im = image.copy() + im, boxes, labels = self.rand_brightness(im, boxes, labels) + if random.randint(2): + distort = Compose(self.pd[:-1]) + else: + distort = Compose(self.pd[1:]) + im, boxes, labels = distort(im, boxes, labels) + return self.rand_light_noise(im, boxes, labels) + + +class SSDAugmentation(object): + def __init__(self, size=300, mean=(104, 117, 123)): + self.mean = mean + self.size = size + self.augment = Compose([ + ConvertFromInts(), + ToAbsoluteCoords(), + PhotometricDistort(), + Expand(self.mean), + RandomSampleCrop(), + RandomMirror(), + ToPercentCoords(), + Resize(self.size), + SubtractMeans(self.mean) + ]) + + def __call__(self, img, boxes, labels): + return self.augment(img, boxes, labels) diff --git a/utils/evaluation.py b/utils/evaluation.py new file mode 100644 index 0000000..1bd81f6 --- /dev/null +++ b/utils/evaluation.py @@ -0,0 +1,155 @@ + +""" Evaluation code based on VOC protocol + +Original author: Ellis Brown, Max deGroot for VOC dataset +https://github.com/amdegroot/ssd.pytorch + +Updated by Gurkirt Singh for ucf101-24 dataset + +""" + +import os +import numpy as np + +def voc_ap(rec, prec, use_07_metric=False): + """ ap = voc_ap(rec, prec, [use_07_metric]) + Compute VOC AP given precision and recall. + If use_07_metric is true, uses the + VOC 07 11 point method (default:False). + """ + # print('voc_ap() - use_07_metric:=' + str(use_07_metric)) + if use_07_metric: + # 11 point metric + ap = 0. + for t in np.arange(0., 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11. + else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.], rec, [1.])) + mpre = np.concatenate(([0.], prec, [0.])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + + +def get_gt_of_cls(gt_boxes, cls): + cls_gt_boxes = [] + for i in range(len(gt_boxes)): + if gt_boxes[i,-1] == cls: + cls_gt_boxes.append(gt_boxes[i, :-1]) + return np.asarray(cls_gt_boxes) + + +def compute_iou(cls_gt_boxes, box): + ious = np.zeros(cls_gt_boxes.shape[0]) + + for m in range(ious.shape[0]): + gtbox = cls_gt_boxes[m] + + xmin = max(gtbox[0],box[0]) + ymin = max(gtbox[1], box[1]) + xmax = min(gtbox[2], box[2]) + ymax = min(gtbox[3], box[3]) + iw = np.maximum(xmax - xmin, 0.) + ih = np.maximum(ymax - ymin, 0.) 
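+        # iw/ih are the width/height of the intersection rectangle (zero when the
+        # boxes do not overlap); the IoU below is intersection / (area_gt + area_det - intersection)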
+ if iw>0 and ih>0: + intsc = iw*ih + else: + intsc = 0.0 + # print (intsc) + union = (gtbox[2] - gtbox[0]) * (gtbox[3] - gtbox[1]) + (box[2] - box[0]) * (box[3] - box[1]) - intsc + ious[m] = intsc/union + + return ious + +def evaluate_detections(gt_boxes, det_boxes, CLASSES=[], iou_thresh=0.5): + + ap_strs = [] + num_frames = len(gt_boxes) + print('Evaluating for ', num_frames, 'frames') + ap_all = np.zeros(len(CLASSES), dtype=np.float32) + for cls_ind, cls in enumerate(CLASSES): # loop over each class 'cls' + scores = np.zeros(num_frames * 220) + istp = np.zeros(num_frames * 220) + det_count = 0 + num_postives = 0.0 + for nf in range(num_frames): # loop over each frame 'nf' + # if len(gt_boxes[nf])>0 and len(det_boxes[cls_ind][nf]): + frame_det_boxes = np.copy(det_boxes[cls_ind][nf]) # get frame detections for class cls in nf + cls_gt_boxes = get_gt_of_cls(np.copy(gt_boxes[nf]), cls_ind) # get gt boxes for class cls in nf frame + num_postives += cls_gt_boxes.shape[0] + if frame_det_boxes.shape[0]>0: # check if there are dection for class cls in nf frame + argsort_scores = np.argsort(-frame_det_boxes[:,-1]) # sort in descending order + for i, k in enumerate(argsort_scores): # start from best scoring detection of cls to end + box = frame_det_boxes[k, :-1] # detection bounfing box + score = frame_det_boxes[k,-1] # detection score + ispositive = False # set ispostive to false every time + if cls_gt_boxes.shape[0]>0: # we can only find a postive detection + # if there is atleast one gt bounding for class cls is there in frame nf + iou = compute_iou(cls_gt_boxes, box) # compute IOU between remaining gt boxes + # and detection boxes + maxid = np.argmax(iou) # get the max IOU window gt index + if iou[maxid] >= iou_thresh: # check is max IOU is greater than detection threshold + ispositive = True # if yes then this is ture positive detection + cls_gt_boxes = np.delete(cls_gt_boxes, maxid, 0) # remove assigned gt box + scores[det_count] = score # fill score array with score of current detection + if ispositive: + istp[det_count] = 1 # set current detection index (det_count) + # to 1 if it is true postive example + det_count += 1 + if num_postives<1: + num_postives =1 + scores = scores[:det_count] + istp = istp[:det_count] + argsort_scores = np.argsort(-scores) # sort in descending order + istp = istp[argsort_scores] # reorder istp's on score sorting + fp = np.cumsum(istp == 0) # get false positives + tp = np.cumsum(istp == 1) # get true positives + fp = fp.astype(np.float64) + tp = tp.astype(np.float64) + recall = tp / float(num_postives) # compute recall + precision = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) # compute precision + cls_ap = voc_ap(recall, precision) # compute average precision using voc2007 metric + ap_all[cls_ind] = cls_ap + # print(cls_ind,CLASSES[cls_ind], cls_ap) + ap_str = str(CLASSES[cls_ind]) + ' : ' + str(num_postives) + ' : ' + str(det_count) + ' : ' + str(cls_ap) + ap_strs.append(ap_str) + + # print ('mean ap ', np.mean(ap_all)) + return np.mean(ap_all), ap_all, ap_strs + + +def save_detection_framewise(det_boxes, image_ids, iteration): + det_save_dir = '/mnt/mars-beta/gur-workspace/use-ssd-data/UCF101/detections/RGB-01-{:06d}/'.format(iteration) + print('Saving detections to', det_save_dir) + num_images = len(image_ids) + for idx in range(num_images): + img_id = image_ids[idx] + save_path = det_save_dir+img_id[:-5] + if not os.path.isdir(save_path): + os.system('mkdir -p '+save_path) + fid = open(det_save_dir+img_id+'.txt','w') + for cls_ind in 
range(len(det_boxes)): + frame_det_boxes = det_boxes[cls_ind][idx] + for d in range(len(frame_det_boxes)): + line = str(cls_ind+1) + for k in range(5): + line += ' {:f}'.format(frame_det_boxes[d,k]) + line += '\n' + fid.write(line) + fid.close() +
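
For quick reference, below is a minimal sketch of how the pieces added in this commit are intended to be wired together for single-frame inference; it mirrors the per-frame loop in test-ucf24.py. The checkpoint name, frame path, mean values and thresholds are illustrative placeholders, not values fixed by this patch.

import cv2
import numpy as np
import torch
from torch.autograd import Variable

from data import CLASSES, v2
from layers.box_utils import decode, nms
from ssd import build_ssd

torch.set_default_tensor_type('torch.cuda.FloatTensor')  # the scripts assume CUDA

num_classes = len(CLASSES) + 1                            # 24 UCF101-24 actions + background
net = build_ssd(300, num_classes)
net.load_state_dict(torch.load('ssd300_ucf24_120000.pth'))  # placeholder checkpoint path
net = net.cuda()
net.eval()

img = cv2.imread('frame_00001.jpg')                       # placeholder rgb frame
height, width = img.shape[:2]
x = cv2.resize(img, (300, 300)).astype(np.float32)
x -= np.array([104, 117, 123], dtype=np.float32)          # same means as the train/test scripts
x = torch.from_numpy(x).permute(2, 0, 1).unsqueeze(0).cuda()

loc, conf, priors = net(Variable(x, volatile=True))
decoded_boxes = decode(loc[0].data, priors.data, v2['variance'])
conf_scores = net.softmax(conf[0]).data

for cl_ind in range(1, num_classes):                      # class 0 is background
    scores = conf_scores[:, cl_ind].squeeze()
    c_mask = scores.gt(0.01)                              # confidence threshold
    scores = scores[c_mask].squeeze()
    if scores.dim() == 0:
        continue
    l_mask = c_mask.unsqueeze(1).expand_as(decoded_boxes)
    boxes = decoded_boxes[l_mask].view(-1, 4)
    ids, count = nms(boxes, scores, 0.45, 50)             # nms threshold and topk
    dets = boxes[ids[:count]].cpu().numpy() * np.array([width, height, width, height])
    print(CLASSES[cl_ind - 1], ':', dets.shape[0], 'detections above threshold')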