Skip to content

Commit

Permalink
Merge pull request #339 from Manojkumarmuru/efficientpose
Browse files Browse the repository at this point in the history
Updated unit tests
  • Loading branch information
oarriaga committed May 13, 2024
2 parents 18b0962 + 8b7f67b commit 1d00695
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 13 deletions.
60 changes: 52 additions & 8 deletions docs/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,15 @@
image.calculate_image_center,
image.get_affine_transform,
image.get_scaling_factor,
image.scale_resize
image.scale_resize,
image.compute_resizing_shape,
image.pad_image,
image.equalize_histogram,
image.invert_colors,
image.posterize,
image.solarize,
image.cutout,
image.add_gaussian_noise,
],
},

Expand Down Expand Up @@ -284,7 +292,8 @@
standard.max_pooling_2d,
standard.predict,
standard.predict_with_nones,
standard.weighted_average
standard.weighted_average,
standard.compute_common_row_indices,
],
},

Expand Down Expand Up @@ -346,7 +355,15 @@
{
'page': 'models/pose_estimation.md',
'functions': [
models.HigherHRNet
models.HigherHRNet,
models.EfficientPosePhi0,
models.EfficientPosePhi1,
models.EfficientPosePhi2,
models.EfficientPosePhi3,
models.EfficientPosePhi4,
models.EfficientPosePhi5,
models.EfficientPosePhi6,
models.EfficientPosePhi7,
],
},

Expand All @@ -358,7 +375,10 @@
models.layers.Conv2DNormalization,
models.layers.SubtractScalar,
models.layers.ExpectedValue2D,
models.layers.ExpectedDepth
models.layers.ExpectedDepth,
models.layers.ReduceMean,
models.layers.Sigmoid,
models.layers.Add,
],
},

Expand Down Expand Up @@ -451,7 +471,16 @@
processors.FlipLeftRightImage,
processors.DivideStandardDeviationImage,
processors.ScaledResize,
processors.BufferImages
processors.BufferImages,
processors.PadImage,
processors.EqualizeHistogram,
processors.InvertColors,
processors.Posterize,
processors.Solarize,
processors.SharpenImage,
processors.Cutout,
processors.AddGaussianNoise,

]
},

Expand Down Expand Up @@ -580,7 +609,16 @@
'classes': [
processors.SolvePNP,
processors.SolveChangingObjectPnPRANSAC,
processors.Translation3DFromBoxWidth
processors.Translation3DFromBoxWidth,
processors.MatchPoses,
processors.RotationMatrixToAxisAngle,
processors.ConcatenatePoses,
processors.ConcatenateScale,
processors.AugmentPose6D,
processors.ToPose6D,
processors.BoxesWithOneHotVectorsToPose6D,
processors.BoxesToPose6D,
processors.BoxesWithClassArgToPose6D,
]
},

Expand Down Expand Up @@ -632,7 +670,8 @@
processors.PrintTopics,
processors.FloatToBoolean,
processors.NoneConverter,
processors.AveragePredictions
processors.AveragePredictions,
processors.ComputeCommonRowIndices,
]
},

Expand Down Expand Up @@ -731,7 +770,12 @@
pipelines.HeadPoseKeypointNet2D32,
pipelines.SingleInstancePIX2POSE6D,
pipelines.MultiInstancePIX2POSE6D,
pipelines.MultiInstanceMultiClassPIX2POSE6D
pipelines.MultiInstanceMultiClassPIX2POSE6D,
pipelines.AugmentColor,
pipelines.AugmentEfficientPose,
pipelines.EfficientDetPreprocess,
pipelines.EfficientDetPostprocess,
pipelines.EstimateEfficientPose,
]
},

Expand Down
113 changes: 110 additions & 3 deletions paz/models/classification/xception.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras import Model
from tensorflow.keras.regularizers import l2
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import get_file
from keras import layers


URL = 'https://github.com/oarriaga/altamira-data/releases/download/v0.6/'
Expand Down Expand Up @@ -84,6 +84,112 @@ def build_xception(
return model


def build_minixception(input_shape, num_classes, l2_reg=0.01):
"""Function for instantiating an Mini-Xception model.
# Arguments
input_shape: List corresponding to the input shape
of the model.
num_classes: Integer.
l2_reg. Float. L2 regularization used
in the convolutional kernels.
# Returns
Tensorflow-Keras model.
"""

regularization = l2(l2_reg)

# base
img_input = Input(input_shape)
x = Conv2D(5, (3, 3), strides=(1, 1), kernel_regularizer=regularization,
use_bias=False)(img_input)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = Conv2D(8, (3, 3), strides=(1, 1), kernel_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)

# module 1
residual = Conv2D(16, (1, 1), strides=(2, 2),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)

x = SeparableConv2D(16, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(16, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)

x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])

# module 2
residual = Conv2D(32, (1, 1), strides=(2, 2),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)

x = SeparableConv2D(32, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(32, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)

x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])

# module 3
residual = Conv2D(64, (1, 1), strides=(2, 2),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)

x = SeparableConv2D(64, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(64, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)

x = MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)
x = layers.add([x, residual])

# module 4
residual = Conv2D(128, (1, 1), strides=(1, 1),
padding='same', use_bias=False)(x)
residual = BatchNormalization()(residual)

x = SeparableConv2D(128, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)
x = Activation('relu')(x)
x = SeparableConv2D(128, (3, 3), padding='same',
depthwise_regularizer=regularization,
use_bias=False)(x)
x = BatchNormalization()(x)

x = layers.add([x, residual])

x = Conv2D(num_classes, (3, 3), padding='same')(x)
x = GlobalAveragePooling2D()(x)
output = Activation('softmax', name='predictions')(x)

model = Model(img_input, output)
return model


def MiniXception(input_shape, num_classes, weights=None):
"""Build MiniXception (see references).
Expand All @@ -101,9 +207,10 @@ def MiniXception(input_shape, num_classes, weights=None):
Gender Classification](https://arxiv.org/abs/1710.07557)
"""
if weights == 'FER':
filename = 'fer2013_mini_XCEPTION.119-0.65.hdf5'
filename = 'fer2013_mini_XCEPTION.hdf5'
path = get_file(filename, URL + filename, cache_subdir='paz/models')
model = load_model(path)
model = build_minixception(input_shape, num_classes)
model.load_weights(path)
else:
stem_kernels = [32, 64]
block_data = [128, 128, 256, 256, 512, 512, 1024]
Expand Down
13 changes: 13 additions & 0 deletions paz/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,13 @@ def compute_output_shape(self, input_shape):


class ReduceMean(Layer):
"""Wraps tensorflow's `reduce_mean` function into a keras layer.
# Arguments
axes: List of integers. Axes along which mean is to be calculated.
keepdims: Bool, whether to presere the dimension or not.
"""

def __init__(self, axes=[1, 2], keepdims=True):
self.axes = axes
self.keepdims = keepdims
Expand All @@ -130,6 +137,9 @@ def call(self, x):


class Sigmoid(Layer):
"""Wraps tensorflow's `sigmoid` function into a keras layer.
"""

def __init__(self):
super(Sigmoid, self).__init__()

Expand All @@ -138,6 +148,9 @@ def call(self, x):


class Add(Layer):
"""Wraps tensorflow's `add` function into a keras layer.
"""

def __init__(self):
super(Add, self).__init__()

Expand Down
3 changes: 3 additions & 0 deletions paz/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@
from .pose import PIX2YCBTools6D
from .pose import AugmentColor
from .pose import AugmentEfficientPose
from .pose import EfficientPosePreprocess
from .pose import EfficientPosePostprocess
from .pose import EstimateEfficientPose

from .masks import RGBMaskToImagePoints2D
from .masks import RGBMaskToObjectPoints3D
Expand Down
1 change: 0 additions & 1 deletion tests/paz/pipelines/classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ def labeled_emotion():
return 'happy'


@pytest.mark.skip()
def test_MiniXceptionFER(image_with_face, labeled_emotion, labeled_scores):
classifier = MiniXceptionFER()
inferences = classifier(image_with_face)
Expand Down
1 change: 0 additions & 1 deletion tests/paz/pipelines/detection_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,6 @@ def test_HaarCascadeFrontalFace(image_with_faces, boxes_HaarCascadeFace):
assert_inferences(detector, image_with_faces, boxes_HaarCascadeFace)


@pytest.mark.skip()
def test_DetectMiniXceptionFER(image_with_faces, boxes_MiniXceptionFER):
cv2.ocl.setUseOpenCL(False)
cv2.setNumThreads(1)
Expand Down

0 comments on commit 1d00695

Please sign in to comment.