Merge pull request #50 from tasostefas/skeleton_based_action_recognition
Skeleton based action recognition
negarhdr committed Sep 22, 2021
2 parents 0e5933e + a57db62 commit 4ea8a86
Showing 32 changed files with 5,299 additions and 1 deletion.
1 change: 1 addition & 0 deletions .github/workflows/tests_sources.yml
@@ -64,6 +64,7 @@ jobs:
- perception/object_detection_3d
- perception/pose_estimation
- perception/speech_recognition
- perception/skeleton_based_action_recognition
# - perception/object_tracking_3d
runs-on: ${{ matrix.os }}
steps:
Empty file added __init__.py
Empty file.
37 changes: 37 additions & 0 deletions docs/reference/engine-data.md
@@ -164,3 +164,40 @@ The [PointCloudWithCalibration](#class_engine.data.PointCloudWithCalibration) cl
#### numpy()
Return a [NumPy](https://numpy.org)-compatible representation of data.
Given that the *data* argument is already internally stored in a [NumPy](https://numpy.org)-compatible format, this method is equivalent to `data()`.


### class engine.data.SkeletonSequence
Bases: `engine.data.Data`

A class used for representing a sequence of body skeletons in a video.

The [SkeletonSequence](#class_engine.data.SkeletonSequence) class has the following public methods:
#### SkeletonSequence(data=None)
Construct a new [SkeletonSequence](#class_engine.data.SkeletonSequence) object based on *data*.
*data* is expected to be a 5-D array that can be cast into a 5-D [NumPy](https://numpy.org) array.
The array's dimensions are defined as follows:

`N, C, T, V, M = array.shape`,

- `N` is the number of samples,
- `C` is the number of channels for each body joint,
- `T` is the number of frames (skeletons) in each sequence,
- `V` is the number of body joints in each skeleton,
- `M` is the number of persons (skeletons) in each frame.

Accordingly, an array of size `[10, 3, 300, 18, 2]` contains `10` samples, each holding a sequence of `300` skeleton
frames; each frame contains `2` persons, each person has `18` body joints, and each joint is described by `3` channels.
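
For illustration, a minimal construction sketch (the random array merely stands in for real skeleton data):

```python
import numpy as np
from opendr.engine.data import SkeletonSequence

# 10 samples, 3 channels, 300 frames, 18 joints, 2 persons per frame
skeleton_array = np.random.rand(10, 3, 300, 18, 2)
skeleton_sequence = SkeletonSequence(skeleton_array)
print(skeleton_sequence.numpy().shape)  # (10, 3, 300, 18, 2)
```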

#### data()
Return the *data* argument.
The return type is a float32 5-D [NumPy](https://numpy.org) array.

#### data(data)
Set the internal *data* argument.
*data* is expected to be a 5-D array that can be cast into a 5-D [NumPy](https://numpy.org) array, where the
dimensions are organized as, e.g., (num_samples, channels, frames, joints, persons).

#### numpy()
Return a [NumPy](https://numpy.org)-compatible representation of data.
Given that the *data* argument is already internally stored in a [NumPy](https://numpy.org)-compatible format, this method is equivalent to `data()`.
1 change: 1 addition & 0 deletions docs/reference/index.md
@@ -33,4 +33,5 @@ Neither the copyright holder nor any applicable licensor will be liable for any
- [object_tracking_2d_fair_mot Module](object-tracking-2d-fair-mot.md)
- [object_tracking_3d_ab3dmot Module](object-tracking-3d-ab3dmot.md)
- [multilinear_compressive_learning Module](multilinear_compressive_learning.md)
- [skeleton_based_action_recognition Module](skeleton_based_action_recognition.md)
- [ROSBridge Package](rosbridge.md)
840 changes: 840 additions & 0 deletions docs/reference/skeleton_based_action_recognition.md

66 changes: 66 additions & 0 deletions src/opendr/engine/data.py
@@ -472,3 +472,69 @@ def __str__(self):
:rtype: str
"""
return "Points: " + str(self.data) + "\nCalib:" + str(self.calib)


class SkeletonSequence(Data):
"""
A class used for representing a sequence of body skeletons in a video.
This class provides abstract methods for:
- returning a NumPy compatible representation of data (numpy())
"""

def __init__(self, data=None):
super().__init__(data)

if data is not None:
self.data = data

@property
def data(self):
"""
        Getter of data. The SkeletonSequence class returns a float32 5-D NumPy array.
:return: the actual data held by the object
:rtype: A float32 5-D NumPy array
"""
if self._data is None:
raise ValueError("SkeletonSequence is empty")

return self._data

@data.setter
def data(self, data):
"""
Setter for data.
:param: data to be used for creating a skeleton sequence
"""
# Convert input data to a NumPy array
        # Note that this will also fail for non-numeric data (which is expected)
data = np.asarray(data, dtype=np.float32)

        # Check if the supplied array is 5-D, e.g. (num_samples, channels, frames, joints, persons)
        if data.ndim != 5:
            raise ValueError(
                "Only 5-D arrays are supported by SkeletonSequence. Please supply a data object that can be cast "
                "into a 5-D NumPy array.")

self._data = data

def numpy(self):
"""
Returns a NumPy-compatible representation of data.
:return: a NumPy-compatible representation of data
:rtype: numpy.ndarray
"""
# Since this class stores the data as NumPy arrays, we can directly return the data
return self.data

def __str__(self):
"""
Returns a human-friendly string-based representation of the data.
:return: a human-friendly string-based representation of the data
:rtype: str
"""
return str(self.data)
57 changes: 57 additions & 0 deletions src/opendr/perception/skeleton_based_action_recognition/README.md
@@ -0,0 +1,57 @@
# Skeleton-based Human Action Recognition
Python implementation of the baseline method ST-GCN [[1]](#1) and the proposed methods
TA-GCN [[2]](#2), ST-BLN [[3]](#3) and PST-GCN [[4]](#4) for skeleton-based human
action recognition.
The ST-GCN, TA-GCN and ST-BLN methods can be run and evaluated with `spatio_temporal_gcn_learner` by specifying the model name.
The PST-GCN method can be run and evaluated with `progressive_spatio_temporal_gcn_learner`.

This implementation is adapted from the [OpenMMLAB toolbox](
https://github.com/open-mmlab/mmskeleton/tree/b4c076baa9e02e69b5876c49fa7c509866d902c7).
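
As a quick orientation, a minimal usage sketch follows. The module path, constructor arguments and method names here are assumptions inferred from the learner names above and the common OpenDR learner interface, not a verified API; see the reference documentation (`docs/reference/skeleton_based_action_recognition.md`) for the authoritative one.

```python
# Hypothetical sketch: the names below are assumptions, not a verified API.
from opendr.perception.skeleton_based_action_recognition.spatio_temporal_gcn_learner import \
    SpatioTemporalGCNLearner

learner = SpatioTemporalGCNLearner(method_name='stgcn')  # or 'tagcn' / 'stbln'
# train_dataset and val_dataset are placeholders for preprocessed skeleton datasets
learner.fit(dataset=train_dataset, val_dataset=val_dataset)
results = learner.eval(val_dataset)
```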

## Datasets
### NTU-RGB+D-60
The NTU-RGB+D [[5]](#5) dataset is the largest indoor-captured action recognition dataset, with multiple data
modalities including 3D skeletons captured by a Kinect-v2 camera. It contains 56,000 action clips from 60 different
action classes; each clip is captured by 3 cameras from 3 different views. The dataset provides two benchmarks,
cross-view (CV) and cross-subject (CS).
In this dataset, each skeleton has 25 joints, and each sample is a sequence of 300 skeletons with 3
channels per joint.
### Kinetics-400
The Kinetics-Skeleton [[6]](#6) dataset is a widely used action recognition dataset containing the skeleton data of
300,000 video clips of 400 different actions collected from YouTube. Each skeleton in a sequence has 18
joints, estimated by the OpenPose toolbox [[7]](#7), and each joint is described by its 2D coordinates and a
confidence score. We use the preprocessed data provided by [[1]](#1), which can be downloaded from [here](
https://drive.google.com/file/d/103NOL9YYZSW1hLoWmYnv5Fs8mK-Ij7qb/view).
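
Since the learners consume the preprocessed arrays directly, it can help to inspect them first; a minimal sketch (the file names are placeholders, not the actual names in the downloads):

```python
import numpy as np

# Placeholder file names; use the files obtained from the links above.
ntu_data = np.load('ntu_train_data_joint.npy', mmap_mode='r')
kinetics_data = np.load('kinetics_train_data_joint.npy', mmap_mode='r')

print(ntu_data.shape)       # (N, 3, 300, 25, 2): samples, channels, frames, joints, persons
print(kinetics_data.shape)  # (N, 3, 300, 18, 2): x, y and confidence per joint
```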

## References

<a id="1">[1]</a>
[Yan, S., Xiong, Y., & Lin, D. (2018). Spatial temporal graph convolutional networks for skeleton-based action
recognition. In Proceedings of the AAAI Conference on Artificial Intelligence (Vol. 32, No. 1).](
https://arxiv.org/abs/1801.07455)

<a id="2">[2]</a>
[Heidari, N., & Iosifidis, A. (2020). Temporal Attention-Augmented Graph Convolutional Network for Efficient Skeleton-
Based Human Action Recognition. arXiv preprint arXiv:2010.12221.](https://arxiv.org/abs/2010.12221)

<a id="3">[3]</a>
[Heidari, N., & Iosifidis, A. (2020). On the spatial attention in Spatio-Temporal Graph Convolutional Networks for
skeleton-based human action recognition. arXiv preprint arXiv:2011.03833.](https://arxiv.org/abs/2011.03833)

<a id="4">[4]</a>
[Heidari, N., & Iosifidis, A. (2020). Progressive Spatio-Temporal Graph Convolutional Network for Skeleton-Based Human
Action Recognition. arXiv preprint arXiv:2011.05668.](https://arxiv.org/pdf/2011.05668.pdf)

<a id="5">[5]</a>
[Shahroudy, A., Liu, J., Ng, T. T., & Wang, G. (2016). NTU RGB+D: A large scale dataset for 3D human activity analysis.
In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 1010-1019).](
https://openaccess.thecvf.com/content_cvpr_2016/html/Shahroudy_NTU_RGBD_A_CVPR_2016_paper.html)

<a id="6">[6]</a>
[Kay, W., Carreira, J., Simonyan, K., Zhang, B., Hillier, C., Vijayanarasimhan, S., ... & Zisserman, A. (2017).
The Kinetics human action video dataset. arXiv preprint arXiv:1705.06950.](https://arxiv.org/pdf/1705.06950.pdf)

<a id="7">[7]</a>
[Cao, Z., Simon, T., Wei, S. E., & Sheikh, Y. (2017). Realtime multi-person 2D pose estimation using part affinity
fields. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (pp. 7291-7299).](
https://openaccess.thecvf.com/content_cvpr_2017/html/Cao_Realtime_Multi-Person_2D_CVPR_2017_paper.html)
Empty file.
@@ -0,0 +1,201 @@
"""
Modified based on: https://github.com/open-mmlab/mmskeleton
"""

import numpy as np
import pickle
from torch.utils.data import Dataset
import random
from tqdm import tqdm


def auto_padding(data_numpy, size, random_pad=False):
    C, T, V, M = data_numpy.shape
    if T < size:
        begin = random.randint(0, size - T) if random_pad else 0
        data_numpy_padded = np.zeros((C, size, V, M))
        data_numpy_padded[:, begin:begin + T, :, :] = data_numpy
        return data_numpy_padded
    else:
        return data_numpy


def random_choose(data_numpy, size, auto_pad=True):
C, T, V, M = data_numpy.shape
if T == size:
return data_numpy
elif T < size:
if auto_pad:
            return auto_padding(data_numpy, size, random_pad=True)
else:
return data_numpy
else:
begin = random.randint(0, T - size)
return data_numpy[:, begin:begin + size, :, :]


def random_move(data_numpy,
                angle_candidate=(-10., -5., 0., 5., 10.),
                scale_candidate=(0.9, 1.0, 1.1),
                transform_candidate=(-0.2, -0.1, 0.0, 0.1, 0.2),
                move_time_candidate=(1,)):
# input: C,T,V,M
C, T, V, M = data_numpy.shape
move_time = random.choice(move_time_candidate)
node = np.arange(0, T, T * 1.0 / move_time).round().astype(int)
node = np.append(node, T)
num_node = len(node)

A = np.random.choice(angle_candidate, num_node)
S = np.random.choice(scale_candidate, num_node)
T_x = np.random.choice(transform_candidate, num_node)
T_y = np.random.choice(transform_candidate, num_node)

a = np.zeros(T)
s = np.zeros(T)
t_x = np.zeros(T)
t_y = np.zeros(T)

# linspace
for i in range(num_node - 1):
a[node[i]:node[i + 1]] = np.linspace(
A[i], A[i + 1], node[i + 1] - node[i]) * np.pi / 180
s[node[i]:node[i + 1]] = np.linspace(S[i], S[i + 1],
node[i + 1] - node[i])
t_x[node[i]:node[i + 1]] = np.linspace(T_x[i], T_x[i + 1],
node[i + 1] - node[i])
t_y[node[i]:node[i + 1]] = np.linspace(T_y[i], T_y[i + 1],
node[i + 1] - node[i])

theta = np.array([[np.cos(a) * s, -np.sin(a) * s],
[np.sin(a) * s, np.cos(a) * s]])

# perform transformation
for i_frame in range(T):
xy = data_numpy[0:2, i_frame, :, :]
new_xy = np.dot(theta[:, :, i_frame], xy.reshape(2, -1))
new_xy[0] += t_x[i_frame]
new_xy[1] += t_y[i_frame]
data_numpy[0:2, i_frame, :, :] = new_xy.reshape(2, V, M)

return data_numpy


def random_shift(data_numpy):
C, T, V, M = data_numpy.shape
data_shift = np.zeros(data_numpy.shape)
valid_frame = (data_numpy != 0).sum(axis=3).sum(axis=2).sum(axis=0) > 0
begin = valid_frame.argmax()
end = len(valid_frame) - valid_frame[::-1].argmax()

size = end - begin
bias = random.randint(0, T - size)
data_shift[:, bias:bias + size, :, :] = data_numpy[:, begin:end, :, :]

return data_shift


class Feeder(Dataset):
def __init__(self, data_path, label_path,
random_choose=False, random_shift=False, random_move=False,
window_size=-1, normalization=False, use_mmap=True, skeleton_data_type='joint', data_name='nturgbd'):
"""
:param data_path:
:param label_path:
:param random_choose: If true, randomly choose a portion of the input sequence
:param random_shift: If true, randomly pad zeros at the begining or end of sequence
:param random_move:
:param window_size: The length of the output sequence
:param normalization: If true, normalize input sequence
:param use_mmap: If true, use mmap mode to load data, which can save the running memory
:param skeleton_data_type: The type of features that is needed to be generated.
:param data_name: The name of dataset
"""

self.data_path = data_path
self.label_path = label_path
self.random_choose = random_choose
self.random_shift = random_shift
self.random_move = random_move
self.window_size = window_size
self.normalization = normalization
self.use_mmap = use_mmap
self.skeleton_data_type = skeleton_data_type
self.data_name = data_name
self.load_data()
if normalization:
self.get_mean_map()

def load_data(self):
        # data: N C T V M
try:
with open(self.label_path) as f:
self.sample_name, self.label = pickle.load(f)
except Exception:
            # for pickle files created with Python 2
with open(self.label_path, 'rb') as f:
self.sample_name, self.label = pickle.load(f, encoding='latin1')

if self.skeleton_data_type not in ['joint', 'bone', 'motion']:
            raise ValueError("skeleton_data_type must be one of: 'joint', 'bone' or 'motion'")
# load joint data
if self.use_mmap:
self.data = np.load(self.data_path, mmap_mode='r')
else:
self.data = np.load(self.data_path)

# if we need bone or motion data instead of joints
if self.data_name == 'nturgbd':
joint_pairs = ((0, 1), (1, 20), (2, 20), (3, 2), (4, 20), (5, 4), (6, 5), (7, 6), (8, 20), (9, 8),
(10, 9), (11, 10), (12, 0), (13, 12), (14, 13), (15, 14), (16, 0), (17, 16), (18, 17),
(19, 18), (21, 22), (20, 20), (22, 7), (23, 24), (24, 11))
        elif self.data_name == 'kinetics':
            joint_pairs = ((0, 0), (1, 0), (2, 1), (3, 2), (4, 3), (5, 1), (6, 5), (7, 6), (8, 2), (9, 8), (10, 9),
                           (11, 5), (12, 11), (13, 12), (14, 0), (15, 0), (16, 14), (17, 15))
        else:
            raise ValueError("data_name should be either 'nturgbd' or 'kinetics'")
N, C, T, V, M = self.data.shape
if self.skeleton_data_type == 'bone':
bones = np.zeros((N, C, T, V, M))
for v1, v2 in tqdm(joint_pairs):
bones[:, :, :, v1, :] = self.data[:, :, :, v1, :] - self.data[:, :, :, v2, :]
self.data = bones
elif self.skeleton_data_type == 'motion':
motion = np.zeros((N, C, T, V, M))
for t in tqdm(range(T - 1)):
motion[:, :, t, :, :] = self.data[:, :, t + 1, :, :] - self.data[:, :, t, :, :]
motion[:, :, T - 1, :, :] = 0
self.data = motion

def get_mean_map(self):
data = self.data
N, C, T, V, M = data.shape
self.mean_map = data.mean(axis=2, keepdims=True).mean(axis=4, keepdims=True).mean(axis=0)
self.std_map = data.transpose((0, 2, 4, 1, 3)).reshape((N * T * M, C * V)).std(axis=0).reshape((C, 1, V, 1))

def __len__(self):
return len(self.label)

    def __iter__(self):
        # Iterate over the samples one by one via __getitem__
        return (self[i] for i in range(len(self)))

def __getitem__(self, index):
data_numpy = self.data[index]
label = self.label[index]
data_numpy = np.array(data_numpy)

if self.normalization:
data_numpy = (data_numpy - self.mean_map) / self.std_map
if self.random_shift:
data_numpy = random_shift(data_numpy)
if self.random_choose:
data_numpy = random_choose(data_numpy, self.window_size)
elif self.window_size > 0:
            data_numpy = auto_padding(data_numpy, self.window_size)
if self.random_move:
data_numpy = random_move(data_numpy)

return data_numpy, label, index

    def top_k(self, score, top_k):
        # Fraction of samples whose true label is among the top_k highest-scored classes
        rank = score.argsort()
        hit_top_k = [lbl in rank[i, -top_k:] for i, lbl in enumerate(self.label)]
        return sum(hit_top_k) * 1.0 / len(hit_top_k)
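
if __name__ == '__main__':
    # Minimal usage sketch (illustrative only): the file paths below are
    # hypothetical placeholders for preprocessed NTU-RGB+D joint data and labels.
    from torch.utils.data import DataLoader

    feeder = Feeder(data_path='train_data_joint.npy', label_path='train_label.pkl',
                    window_size=300, skeleton_data_type='joint', data_name='nturgbd')
    loader = DataLoader(feeder, batch_size=16, shuffle=True)
    data, label, index = next(iter(loader))
    print(data.shape)  # e.g. (16, 3, 300, 25, 2): batch, channels, frames, joints, persons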
