In [1]:
! pip freeze

actionlib==1.12.0
angles==1.9.12
attrs==19.3.0
backcall==0.1.0
bleach==1.5.0
bondpy==1.8.3
camera-calibration==1.14.0
camera-calibration-parsers==1.11.13
catkin==0.7.20
controller-manager==0.16.0
controller-manager-msgs==0.16.0
cv-bridge==1.13.0
decorator==4.4.2
defusedxml==0.6.0
diagnostic-analysis==1.9.3
diagnostic-common-diagnostics==1.9.3
diagnostic-updater==1.9.3
dynamic-reconfigure==1.6.0
entrypoints==0.3
gazebo-plugins==2.8.6
gazebo-ros==2.8.6
gencpp==0.6.2
geneus==2.2.6
genlisp==0.4.16
genmsg==0.5.12
gennodejs==2.0.1
genpy==0.6.9
html5lib==0.9999999
image-geometry==1.13.0
importlib-metadata==1.5.0
interactive-markers==1.11.4
ipykernel==5.1.4
ipython==7.13.0
ipython-genutils==0.2.0
ipywidgets==7.5.1
jedi==0.16.0
Jinja2==2.11.1
joint-state-publisher==1.12.14
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==6.0.0
jupyter-console==6.1.0
jupyter-core==4.6.3
kdl-parser-py==1.13.1
Keras==2.0.8
laser-geometry==1.6.4
Markdown==3.2.1
MarkupS

# Load torch models

In [2]:
# Switch for the optional PyTorch models: while this is False, the
# `if load_torch:` cell that follows imports and loads nothing.
load_torch = False

In [3]:
if load_torch:
    # Imported lazily so the rest of the notebook still runs when PyTorch or
    # the project-local `models` / `utils` packages are unavailable.
    import torch

    from models.cls_model import ClassifierNet
    from models.dla import get_pose_net
    from utils.transforms import decode_results

    # NOTE(review): torch.load unpickles the file, so these paths must only
    # ever point at trusted checkpoints.
    # map_location="cpu" remaps tensors that were saved from a GPU, so loading
    # also works on a machine without CUDA; both models run on the CPU here.
    model = get_pose_net(34, heads={'hm': 1, 'wh': 2}, head_conv=-1)
    state_dict = torch.load("../models/last512_map039350.pth", map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.cpu()
    model.eval()  # inference mode

    clsnet = ClassifierNet()
    state_dict = torch.load("../models/best_cls.pth", map_location="cpu")
    clsnet.load_state_dict(state_dict)
    clsnet = clsnet.cpu()
    clsnet.eval()  # inference mode


# Keras approach

In [4]:
from keras.layers import Conv2D, BatchNormalization, Input
from keras.activations import relu
from keras.models import Sequential, Model
from keras.layers.merge import add
from keras.optimizers import SGD

from keras import backend as K
from keras.layers import Layer
    
import numpy as np

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [5]:
class ReLU(Layer):
    """Weight-free Keras layer that applies the backend ReLU to its input."""

    def __init__(self, **kwargs):
        # The layer has no options of its own; all keywords go to ``Layer``.
        super().__init__(**kwargs)

    def build(self, input_shape):
        # No weights to create; defer entirely to the base implementation.
        super().build(input_shape)

    def call(self, x):
        """Return ``K.relu(x)``."""
        activated = K.relu(x)
        return activated

    def compute_output_shape(self, input_shape):
        """Report the output shape, which is the input shape unchanged."""
        return input_shape
    
class HSwish(Layer):
    """Weight-free Keras layer computing hard-swish.

    The output is ``x * min(relu(x + 3), 6) / 6``.
    """

    def __init__(self, **kwargs):
        # The layer has no options of its own; all keywords go to ``Layer``.
        super(HSwish, self).__init__(**kwargs)

    def build(self, input_shape):
        # No weights to create; defer entirely to the base implementation.
        super(HSwish, self).build(input_shape)

    def call(self, x):
        """Apply hard-swish to ``x`` and return the result."""
        # ``max_value=6`` lets the backend clamp the ReLU at 6 directly, which
        # avoids building a constant tensor of sixes the same size as ``x``.
        return x * K.relu(x + 3, max_value=6) / 6

    def compute_output_shape(self, input_shape):
        """Report the output shape, which is the input shape unchanged."""
        return input_shape

In [6]:
class DlaKeras():
    """Builder for the first stages of a DLA-style backbone in Keras.

    Only the stem (``base_layer``), ``level0`` and ``level1`` are created, so
    only the first two entries of ``planes`` and ``levels`` are used.

    Args:
        input_shape: Shape of one input sample (without the batch axis). It
            must already be laid out according to ``data_format``.
        levels: Number of conv blocks per level.
        planes: Number of filters per level.
        data_format: ``"NCHW"`` for channels-first; any other value is treated
            as channels-last (``"NHWC"``).
        activation_function: ``"hswish"`` selects ``HSwish``; any other value
            falls back to ``ReLU``.
    """

    def __init__(self, input_shape=(512, 512, 3),
                 levels=(1, 1, 1, 2, 2, 1),
                 planes=(16, 32, 64, 128, 256, 512),
                 data_format="NHWC", activation_function="relu"):
        channels_first = data_format == "NCHW"
        self.norm_axis = 1 if channels_first else -1
        # The convolutions must use the same layout as the batch-norm axis;
        # otherwise "NCHW" would normalise one axis and convolve over another.
        self.conv_data_format = "channels_first" if channels_first else "channels_last"
        self.input = Input(shape=input_shape)
        self.activation_function = activation_function

        self.base_layer = self._make_simple_block(planes[0], kernel_size=7)
        self.level0 = self._make_conv_level(planes[0], levels[0], stride=1)
        self.level1 = self._make_conv_level(planes[1], levels[1], stride=2)

    def _make_simple_block(self, planes, kernel_size=3, stride=1):
        """Return ``[conv, batch_norm, activation]`` as a list of new layers."""
        conv = Conv2D(filters=planes, kernel_size=kernel_size, padding='same',
                      strides=(stride, stride), use_bias=False,
                      data_format=self.conv_data_format)
        bn = BatchNormalization(axis=self.norm_axis)
        if self.activation_function == "hswish":
            activation = HSwish()
        else:
            # Any unrecognised name deliberately falls back to ReLU.
            activation = ReLU()

        return [conv, bn, activation]

    def _make_conv_level(self, planes, levels, stride=1):
        """Return the flat layer list for ``levels`` stacked simple blocks.

        ``stride`` is applied by the first block only, so a level changes the
        resolution once regardless of how many blocks it contains.
        """
        layers = []
        for i in range(levels):
            block_stride = stride if i == 0 else 1
            layers.extend(self._make_simple_block(planes, stride=block_stride))
        return layers

    def _build_list(self, x, layers):
        """Feed ``x`` through ``layers`` in order and return the result."""
        for l in layers:
            x = l(x)
        return x

    def build(self):
        """Wire stem, level0 and level1 together and return a ``Model``."""
        inputs = self.input

        x = self._build_list(inputs, self.base_layer)
        x = self._build_list(x, self.level0)
        x = self._build_list(x, self.level1)

        return Model(inputs=inputs, outputs=x)

# Testing inference

In [7]:
# Smoke test: build the network, run one forward pass and show its structure.
m = DlaKeras(activation_function="hswish").build()
optimizer = SGD(lr=0.1)
m.compile(optimizer=optimizer, loss='mse')

dummy = np.zeros((1,512,512,3), dtype=float)
dummy_shape = m.predict(dummy).shape
print(dummy_shape)
# summary() prints the table itself and returns None, so wrapping it in
# print() would only add a stray "None" line.
m.summary()

def dummy_train(model: Model, n=50, epochs=2):
    """Fit ``model`` on synthetic data to check that training runs.

    One random single-channel image is copied ``n`` times along the batch axis
    and across all input channels; the targets are all zeros.

    Args:
        model: Compiled Keras model with one input and one output.
            Assumes a channels-last input — TODO confirm for other layouts.
        n: Number of (identical) samples to generate.
        epochs: Number of epochs passed to ``fit``.
    """
    # Take the shapes from the model itself rather than from notebook globals,
    # so the helper works for whatever model is passed in.
    _, height, width, channels = model.input_shape
    rand = np.random.rand(1, height, width, 1)
    dummy_inputs = np.repeat(np.repeat(rand, n, axis=0), channels, axis=-1)
    dummy_outputs = np.zeros((n,) + tuple(model.output_shape[1:]))
    model.fit(dummy_inputs, dummy_outputs, batch_size=4, epochs=epochs)

print("Training:")
dummy_train(m)

(1, 256, 256, 32)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 512, 512, 3)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 512, 512, 16)      2352      
_________________________________________________________________
batch_normalization_1 (Batch (None, 512, 512, 16)      64        
_________________________________________________________________
h_swish_1 (HSwish)           (None, 512, 512, 16)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 512, 512, 16)      2304      
_________________________________________________________________
batch_normalization_2 (Batch (None, 512, 512, 16)      64        
_________________________________________________________________
h_swish_2 (HSwish)           (None, 512, 512, 16)      0  