diff --git a/keras_model.cc b/keras_model.cc index 24d5563..8ee5b3a 100644 --- a/keras_model.cc +++ b/keras_model.cc @@ -63,6 +63,15 @@ bool KerasLayerActivation::LoadLayer(std::ifstream* file) break; case kSoftPlus: activation_type_ = kSoftPlus; + break; + case kHardSigmoid: + activation_type_ = kHardSigmoid; + break; + case kSigmoid: + activation_type_ = kSigmoid; + break; + case kTanh: + activation_type_ = kTanh; break; default: KASSERT(false, "Unsupported activation type %d", activation); @@ -97,6 +106,38 @@ bool KerasLayerActivation::Apply(Tensor* in, Tensor* out) out->data_[i] = std::log(1.0 + std::exp(out->data_[i])); } break; + case kHardSigmoid: + for (size_t i = 0; i < out->data_.size(); i++) + { + float x = (out->data_[i] * 0.2) + 0.5; + + if ( x <= 0 ) + out->data_[i] = 0.0; + else if ( x >= 1 ) + out->data_[i] = 1.0; + else + out->data_[i] = x; + } + break; + case kSigmoid: + for (size_t i = 0; i < out->data_.size(); i++) + { + float & x = out->data_[i]; + + if ( x >= 0 ) + out->data_[i] = 1.0/(1.0 + std::exp(-x)); + else { + float z = std::exp(x); + out->data_[i] = z/(1.0+z); + } + } + break; + case kTanh: + for (size_t i = 0; i < out->data_.size(); i++) + { + out->data_[i] = std::tanh(out->data_[i]); + } + break; default: break; } @@ -357,6 +398,251 @@ bool KerasLayerMaxPooling2d::Apply(Tensor* in, Tensor* out) } +bool KerasLayerLSTM::LoadLayer(std::ifstream* file) +{ + KASSERT(file, "Invalid file stream"); + + unsigned int wi_rows = 0; + KASSERT(ReadUnsignedInt(file, &wi_rows), "Expected Wi rows"); + KASSERT(wi_rows > 0, "Invalid Wi # rows"); + + unsigned int wi_cols = 0; + KASSERT(ReadUnsignedInt(file, &wi_cols), "Expected Wi cols"); + KASSERT(wi_cols > 0, "Invalid Wi shape"); + + unsigned int ui_rows = 0; + KASSERT(ReadUnsignedInt(file, &ui_rows), "Expected Ui rows"); + KASSERT(ui_rows > 0, "Invalid Ui # rows"); + + unsigned int ui_cols = 0; + KASSERT(ReadUnsignedInt(file, &ui_cols), "Expected Ui cols"); + KASSERT(ui_cols > 0, "Invalid Ui shape"); + + unsigned int bi_shape = 0; + KASSERT(ReadUnsignedInt(file, &bi_shape), "Expected bi shape"); + KASSERT(bi_shape > 0, "Invalid bi shape"); + + unsigned int wf_rows = 0; + KASSERT(ReadUnsignedInt(file, &wf_rows), "Expected Wf rows"); + KASSERT(wf_rows > 0, "Invalid Wf # rows"); + + unsigned int wf_cols = 0; + KASSERT(ReadUnsignedInt(file, &wf_cols), "Expected Wf cols"); + KASSERT(wf_cols > 0, "Invalid Wf shape"); + + unsigned int uf_rows = 0; + KASSERT(ReadUnsignedInt(file, &uf_rows), "Expected Uf rows"); + KASSERT(uf_rows > 0, "Invalid Uf # rows"); + + unsigned int uf_cols = 0; + KASSERT(ReadUnsignedInt(file, &uf_cols), "Expected Uf cols"); + KASSERT(uf_cols > 0, "Invalid Uf shape"); + + unsigned int bf_shape = 0; + KASSERT(ReadUnsignedInt(file, &bf_shape), "Expected bf shape"); + KASSERT(bf_shape > 0, "Invalid bf shape"); + + unsigned int wc_rows = 0; + KASSERT(ReadUnsignedInt(file, &wc_rows), "Expected Wc rows"); + KASSERT(wc_rows > 0, "Invalid Wc # rows"); + + unsigned int wc_cols = 0; + KASSERT(ReadUnsignedInt(file, &wc_cols), "Expected Wc cols"); + KASSERT(wc_cols > 0, "Invalid Wc shape"); + + unsigned int uc_rows = 0; + KASSERT(ReadUnsignedInt(file, &uc_rows), "Expected Uc rows"); + KASSERT(uc_rows > 0, "Invalid Uc # rows"); + + unsigned int uc_cols = 0; + KASSERT(ReadUnsignedInt(file, &uc_cols), "Expected Uc cols"); + KASSERT(uc_cols > 0, "Invalid Uc shape"); + + unsigned int bc_shape = 0; + KASSERT(ReadUnsignedInt(file, &bc_shape), "Expected bc shape"); + KASSERT(bc_shape > 0, "Invalid bc shape"); + + unsigned int wo_rows = 0; + KASSERT(ReadUnsignedInt(file, &wo_rows), "Expected Wo rows"); + KASSERT(wo_rows > 0, "Invalid Wo # rows"); + + unsigned int wo_cols = 0; + KASSERT(ReadUnsignedInt(file, &wo_cols), "Expected Wo cols"); + KASSERT(wo_cols > 0, "Invalid Wo shape"); + + unsigned int uo_rows = 0; + KASSERT(ReadUnsignedInt(file, &uo_rows), "Expected Uo rows"); + KASSERT(uo_rows > 0, "Invalid Uo # rows"); + + unsigned int uo_cols = 0; + KASSERT(ReadUnsignedInt(file, &uo_cols), "Expected Uo cols"); + KASSERT(uo_cols > 0, "Invalid Uo shape"); + + unsigned int bo_shape = 0; + KASSERT(ReadUnsignedInt(file, &bo_shape), "Expected bo shape"); + KASSERT(bo_shape > 0, "Invalid bo shape"); + + + + /* Load Input Weights and Biases */ + Wi_.Resize(wi_rows, wi_cols); + KASSERT(ReadFloats(file, Wi_.data_.data(), wi_rows * wi_cols), "Expected Wi weights"); + + Ui_.Resize(ui_rows, ui_cols); + KASSERT(ReadFloats(file, Ui_.data_.data(), ui_rows * ui_cols), "Expected Ui weights"); + + bi_.Resize(1, bi_shape); + KASSERT(ReadFloats(file, bi_.data_.data(), bi_shape), "Expected bi biases"); + + + /* Load Forget Weights and Biases */ + Wf_.Resize(wf_rows, wf_cols); + KASSERT(ReadFloats(file, Wf_.data_.data(), wf_rows * wf_cols), "Expected Wf weights"); + + Uf_.Resize(uf_rows, uf_cols); + KASSERT(ReadFloats(file, Uf_.data_.data(), uf_rows * uf_cols), "Expected Uf weights"); + + bf_.Resize(1, bf_shape); + KASSERT(ReadFloats(file, bf_.data_.data(), bf_shape), "Expected bf biases"); + + + /* Load State Weights and Biases */ + Wc_.Resize(wc_rows, wc_cols); + KASSERT(ReadFloats(file, Wc_.data_.data(), wc_rows * wc_cols), "Expected Wc weights"); + + Uc_.Resize(uc_rows, uc_cols); + KASSERT(ReadFloats(file, Uc_.data_.data(), uc_rows * uc_cols), "Expected Uc weights"); + + bc_.Resize(1, bc_shape); + KASSERT(ReadFloats(file, bc_.data_.data(), bc_shape), "Expected bc biases"); + + + /* Load Output Weights and Biases */ + Wo_.Resize(wo_rows, wo_cols); + KASSERT(ReadFloats(file, Wo_.data_.data(), wo_rows * wo_cols), "Expected Wo weights"); + + Uo_.Resize(uo_rows, uo_cols); + KASSERT(ReadFloats(file, Uo_.data_.data(), uo_rows * uo_cols), "Expected Uo weights"); + + bo_.Resize(1, bo_shape); + KASSERT(ReadFloats(file, bo_.data_.data(), bo_shape), "Expected bo biases"); + + + KASSERT(innerActivation_.LoadLayer(file), "Failed to load inner activation"); + KASSERT(activation_.LoadLayer(file), "Failed to load activation"); + + unsigned int return_sequences = 0; + KASSERT(ReadUnsignedInt(file, &return_sequences), "Expected return_sequences param"); + returnSequences = return_sequences; + + return true; +} + +bool KerasLayerLSTM::Apply(Tensor* in, Tensor* out) +{ + /*lets assume bo always keeps the output shape and we will always recive one single sample */ + int outputDim = bo_.dims_[1]; + Tensor ht_1 = Tensor(1, outputDim); + Tensor ct_1 = Tensor(1, outputDim); + + K::fill(&ht_1, 0.0); + K::fill(&ct_1, 0.0); + + int steps = in->dims_[0]; + + Tensor outputs, lastOutput; + + if ( returnSequences ){ + outputs.dims_ = {steps, outputDim}; + outputs.data_.reserve(steps*outputDim); + } + + for ( int s = 0; s < steps; s++ ){ + Tensor x = K::select(in, s); + +// bool success = + KASSERT(step(&x, &lastOutput, &ht_1, &ct_1), "Failed to execute step"); + + if ( returnSequences ){ + outputs.data_.insert(outputs.data_.end(), lastOutput.data_.begin(), lastOutput.data_.end()); + } + } + + if (returnSequences) + *out = outputs; + else + *out = lastOutput; + + return true; +} + +bool KerasLayerEmbedding::LoadLayer(std::ifstream* file) +{ + KASSERT(file, "Invalid file stream"); + + unsigned int weights_rows = 0; + KASSERT(ReadUnsignedInt(file, &weights_rows), "Expected weight rows"); + KASSERT(weights_rows > 0, "Invalid weights # rows"); + + unsigned int weights_cols = 0; + KASSERT(ReadUnsignedInt(file, &weights_cols), "Expected weight cols"); + KASSERT(weights_cols > 0, "Invalid weights shape"); + + weights_.Resize(weights_rows, weights_cols); + KASSERT(ReadFloats(file, weights_.data_.data(), weights_rows * weights_cols), "Expected weights"); + + return true; +} + +bool KerasLayerEmbedding::Apply(Tensor* in, Tensor* out) +{ + int output_rows = in->dims_[1]; + int output_cols = weights_.dims_[1]; + out->dims_ = {output_rows, output_cols}; + out->data_.reserve(output_rows*output_cols); + + std::for_each(in->data_.begin(), in->data_.end(), [=](float i){ + std::vector::const_iterator first = this->weights_.data_.begin() + (i*output_cols); + std::vector::const_iterator last = this->weights_.data_.begin() + (i+1)*output_cols; + + out->data_.insert(out->data_.end(), first, last); + }); + + + return true; +} + + +bool KerasLayerLSTM::step(Tensor* x, Tensor* out, Tensor* ht_1, Tensor* ct_1) +{ + Tensor xi = K::add(K::dot(*x, Wi_), bi_); + Tensor xf = K::add(K::dot(*x, Wf_), bf_); + Tensor xc = K::add(K::dot(*x, Wc_), bc_); + Tensor xo = K::add(K::dot(*x, Wo_), bo_); + + Tensor i_ = K::add(xi, K::dot(*ht_1, Ui_)); + Tensor f_ = K::add(xf, K::dot(*ht_1, Uf_)); + Tensor c_ = K::add(xc, K::dot(*ht_1, Uc_)); + Tensor o_ = K::add(xo, K::dot(*ht_1, Uo_)); + + + Tensor i, f, cc, o; + + KASSERT(innerActivation_.Apply(&i_, &i), "Failed to apply inner activation on i"); + KASSERT(innerActivation_.Apply(&f_, &f), "Failed to apply inner activation on f"); + KASSERT(activation_.Apply(&c_, &cc), "Failed to apply activation on c_"); + KASSERT(innerActivation_.Apply(&o_, &o), "Failed to apply inner activation on o"); + + *ct_1 = K::add(K::mult(f, *ct_1), K::mult(i, cc)); + + + KASSERT(activation_.Apply(ct_1, &cc), "Failed to apply activation on c"); + *out = *ht_1 = K::mult(o, cc); + + return true; +} + + bool KerasModel::LoadModel(const std::string& filename) { std::ifstream file(filename.c_str(), std::ios::binary); @@ -392,6 +678,12 @@ bool KerasModel::LoadModel(const std::string& filename) case kMaxPooling2D: layer = new KerasLayerMaxPooling2d(); break; + case kLSTM: + layer = new KerasLayerLSTM(); + break; + case kEmbedding: + layer = new KerasLayerEmbedding(); + break; default: break; } diff --git a/keras_model.h b/keras_model.h index 6efc1c8..34ecf0d 100644 --- a/keras_model.h +++ b/keras_model.h @@ -11,6 +11,8 @@ #include #include #include +#include +#include #define KASSERT(x, ...) \ if (!(x)) { \ @@ -111,6 +113,15 @@ class Tensor { return data_[dims_[1] * i + j]; } + + inline float operator()(int i, int j) const + { + KDEBUG(dims_.size() == 2, "Invalid indexing for tensor"); + KDEBUG(i < dims_[0] && i >= 0, "Invalid i: %d (max %d)", i, dims_[0]); + KDEBUG(j < dims_[1] && j >= 0, "Invalid j: %d (max %d)", j, dims_[1]); + + return data_[dims_[1] * i + j]; + } inline float& operator()(int i, int j, int k) { @@ -196,11 +207,82 @@ class Tensor { printf(")\n"); } - std::vector dims_; std::vector data_; }; +namespace K { + inline void fill(Tensor* tensor, float value) { + std::fill(tensor->data_.begin(), tensor->data_.end(), value); + } + + inline Tensor unpack(const Tensor * tensor, int row) { + KASSERT(tensor->dims_.size() >= 2, "Invalid tensor"); + std::vector pack_dims = std::vector(tensor->dims_.begin()+1, tensor->dims_.end()); + int pack_size = std::accumulate(pack_dims.begin(), pack_dims.end(), 0); + + std::vector::const_iterator first = tensor->data_.begin() + (row*pack_size); + std::vector::const_iterator last = tensor->data_.begin() + (row+1)*pack_size; + + Tensor x = Tensor(); + x.dims_ = pack_dims; + x.data_ = std::vector(first, last); + + return x; + } + + inline Tensor select(const Tensor * tensor, int row) { + Tensor x = unpack(tensor, row); + x.dims_.insert(x.dims_.begin(), 1); + + return x; + } + + inline Tensor dot(const Tensor & a, const Tensor & b) { + KDEBUG(a.dims_.size() == 2, "Invalid tensor"); + KDEBUG(b.dims_.size() == 2, "Invalid tensor"); + KASSERT(a.dims_[1] == b.dims_[0], "Cannot multiply with different inner dimensions"); + + Tensor tmp(a.dims_[0], b.dims_[1]); + K::fill(&tmp, 0.0); + + for ( int i = 0; i < a.dims_[0]; i++ ){ + for ( int j = 0; j < b.dims_[1]; j++) { + for ( int k = 0; k < a.dims_[1]; k++ ) { + tmp(i, j) += a(i,k) * b(k,j); + } + } + } + + return tmp; + } + + + inline Tensor add(const Tensor & a, const Tensor & b) { + KASSERT(a.dims_== b.dims_, "Cannot add elements with different dimensions"); + + Tensor result; + result.dims_ = a.dims_; + result.data_.reserve(a.data_.size()); + + std::transform(a.data_.begin(), a.data_.end(), b.data_.begin(), std::back_inserter(result.data_), [](float x, float y){return x+y;}); + + return result; + } + + inline Tensor mult(const Tensor & a, const Tensor & b) { + KASSERT(a.dims_== b.dims_, "Cannot multipy elements with different dimensions"); + + Tensor result; + result.dims_ = a.dims_; + result.data_.reserve(a.data_.size()); + + std::transform(a.data_.begin(), a.data_.end(), b.data_.begin(), std::back_inserter(result.data_), [](float x, float y){return x*y;}); + + return result; + } +} + class KerasLayer { public: @@ -219,7 +301,10 @@ class KerasLayerActivation : public KerasLayer { { kLinear = 1, kRelu = 2, - kSoftPlus = 3 + kSoftPlus = 3, + kSigmoid = 4, + kTanh = 5, + kHardSigmoid = 6 }; KerasLayerActivation() @@ -323,6 +408,52 @@ class KerasLayerMaxPooling2d : public KerasLayer { unsigned int pool_size_k_; }; +class KerasLayerLSTM : public KerasLayer { +public: + KerasLayerLSTM() {} + + virtual ~KerasLayerLSTM() {} + + virtual bool LoadLayer(std::ifstream* file); + + virtual bool Apply(Tensor* in, Tensor* out); + +private: + bool step(Tensor * x, Tensor *out, Tensor * ht_1, Tensor * ct_1); + + Tensor Wi_; + Tensor Ui_; + Tensor bi_; + Tensor Wf_; + Tensor Uf_; + Tensor bf_; + Tensor Wc_; + Tensor Uc_; + Tensor bc_; + Tensor Wo_; + Tensor Uo_; + Tensor bo_; + + KerasLayerActivation innerActivation_; + KerasLayerActivation activation_; + bool returnSequences; +}; + +class KerasLayerEmbedding : public KerasLayer { +public: + KerasLayerEmbedding() {} + + virtual ~KerasLayerEmbedding() {} + + virtual bool LoadLayer(std::ifstream* file); + + virtual bool Apply(Tensor* in, Tensor* out); + +private: + + Tensor weights_; +}; + class KerasModel { public: @@ -333,13 +464,15 @@ class KerasModel { kFlatten = 3, kElu = 4, kActivation = 5, - kMaxPooling2D = 6 + kMaxPooling2D = 6, + kLSTM = 7, + kEmbedding = 8 }; KerasModel() {} - ~KerasModel() + virtual ~KerasModel() { for (unsigned int i = 0; i < layers_.size(); i++) { @@ -347,9 +480,9 @@ class KerasModel { } } - bool LoadModel(const std::string& filename); + virtual bool LoadModel(const std::string& filename); - bool Apply(Tensor* in, Tensor* out); + virtual bool Apply(Tensor* in, Tensor* out); private: std::vector layers_; diff --git a/keras_model_test.cc b/keras_model_test.cc index 684c809..2b85283 100644 --- a/keras_model_test.cc +++ b/keras_model_test.cc @@ -15,11 +15,18 @@ #include "test_elu_10.h" #include "test_relu_10.h" #include "test_dense_relu_10.h" +#include "test_dense_tanh_10.h" #include "test_conv_softplus_2x2.h" +#include "test_conv_hard_sigmoid_2x2.h" +#include "test_conv_sigmoid_2x2.h" #include "test_maxpool2d_1x1.h" #include "test_maxpool2d_2x2.h" #include "test_maxpool2d_3x2x2.h" #include "test_maxpool2d_3x3x3.h" +#include "test_lstm_simple_7x20.h" +#include "test_lstm_simple_stacked20x9.h" +#include "test_lstm_stacked150x83.h" +#include "test_embedding64.h" #include "test_benchmark.h" bool tensor_test() @@ -88,6 +95,42 @@ bool tensor_test() } } } + + { + Tensor a(2, 2); + Tensor b(2, 2); + + a.data_ = {1.0, 2.0, 3.0, 5.0}; + b.data_ = {2.0, 5.0, 4.0, 1.0}; + + Tensor result = K::add(a, b); + + KASSERT( result.data_ == std::vector({3.0, 7.0, 7.0, 6.0}), "Vector add failed" ); + } + + { + Tensor a(1, 2); + Tensor b(2, 1); + + a.data_ = {1.0, 2.0}; + b.data_ = {2.0, 5.0}; + + Tensor result = K::dot(a, b); + + KASSERT( result.data_ == std::vector({12.0}), "Vector mult failed" ); + } + + { + Tensor a(2, 1); + Tensor b(1, 2); + + a.data_ = {1.0, 2.0}; + b.data_ = {2.0, 5.0}; + + Tensor result = K::dot(a, b); + + KASSERT( result.data_ == std::vector({2.0, 5.0, 4.0, 10.0}), "Vector mult failed" ); + } return true; } @@ -133,7 +176,16 @@ int main() if (!test_dense_relu_10(&load_time, &apply_time)) return 1; - if (!test_conv_softplus_2x2(&load_time, &apply_time)) + if (!test_dense_tanh_10(&load_time, &apply_time)) + return 1; + + if (!test_conv_softplus_2x2(&load_time, &apply_time)) + return 1; + + if (!test_conv_hard_sigmoid_2x2(&load_time, &apply_time)) + return 1; + + if (!test_conv_sigmoid_2x2(&load_time, &apply_time)) return 1; if (!test_maxpool2d_1x1(&load_time, &apply_time)) @@ -147,6 +199,18 @@ int main() if (!test_maxpool2d_3x3x3(&load_time, &apply_time)) return 1; + + if (!test_lstm_simple_7x20(&load_time, &apply_time)) + return 1; + + if (!test_lstm_simple_stacked20x9(&load_time, &apply_time)) + return 1; + + if (!test_lstm_stacked150x83(&load_time, &apply_time)) + return 1; + + if (!test_embedding64(&load_time, &apply_time)) + return 1; // Run benchmark 5 times and report duration. double total_load_time = 0.0; diff --git a/kerasify.py b/kerasify.py index 30d039c..f70f4ce 100644 --- a/kerasify.py +++ b/kerasify.py @@ -7,10 +7,15 @@ LAYER_ELU = 4 LAYER_ACTIVATION = 5 LAYER_MAXPOOLING2D = 6 +LAYER_LSTM = 7 +LAYER_EMBEDDING = 8 ACTIVATION_LINEAR = 1 ACTIVATION_RELU = 2 ACTIVATION_SOFTPLUS = 3 +ACTIVATION_SIGMOID = 4 +ACTIVATION_TANH = 5 +ACTIVATION_HARD_SIGMOID = 6 def write_floats(file, floats): ''' @@ -37,13 +42,20 @@ def write_activation(activation): f.write(struct.pack('I', ACTIVATION_RELU)) elif activation == 'softplus': f.write(struct.pack('I', ACTIVATION_SOFTPLUS)) + elif activation == 'tanh': + f.write(struct.pack('I', ACTIVATION_TANH)) + elif activation == 'sigmoid': + f.write(struct.pack('I', ACTIVATION_SIGMOID)) + elif activation == 'hard_sigmoid': + f.write(struct.pack('I', ACTIVATION_HARD_SIGMOID)) else: assert False, "Unsupported activation type: %s" % activation - num_layers = len(model.layers) + model_layers = [l for l in model.layers if type(l).__name__ not in ['Dropout']] + num_layers = len(model_layers) f.write(struct.pack('I', num_layers)) - for layer in model.layers: + for layer in model_layers: layer_type = type(layer).__name__ if layer_type == 'Dense': @@ -111,7 +123,94 @@ def write_activation(activation): f.write(struct.pack('I', LAYER_MAXPOOLING2D)) f.write(struct.pack('I', pool_size[0])) f.write(struct.pack('I', pool_size[1])) - + + elif layer_type == 'LSTM': + inner_activation = layer.get_config()['inner_activation'] + activation = layer.get_config()['activation'] + return_sequences = int(layer.get_config()['return_sequences']) + + weights = layer.get_weights() + W_i = weights[0] + U_i = weights[1] + b_i = weights[2] + + W_c = weights[3] + U_c = weights[4] + b_c = weights[5] + + W_f = weights[6] + U_f = weights[7] + b_f = weights[8] + + W_o = weights[9] + U_o = weights[10] + b_o = weights[11] + + f.write(struct.pack('I', LAYER_LSTM)) + f.write(struct.pack('I', W_i.shape[0])) + f.write(struct.pack('I', W_i.shape[1])) + f.write(struct.pack('I', U_i.shape[0])) + f.write(struct.pack('I', U_i.shape[1])) + f.write(struct.pack('I', b_i.shape[0])) + + f.write(struct.pack('I', W_f.shape[0])) + f.write(struct.pack('I', W_f.shape[1])) + f.write(struct.pack('I', U_f.shape[0])) + f.write(struct.pack('I', U_f.shape[1])) + f.write(struct.pack('I', b_f.shape[0])) + + f.write(struct.pack('I', W_c.shape[0])) + f.write(struct.pack('I', W_c.shape[1])) + f.write(struct.pack('I', U_c.shape[0])) + f.write(struct.pack('I', U_c.shape[1])) + f.write(struct.pack('I', b_c.shape[0])) + + f.write(struct.pack('I', W_o.shape[0])) + f.write(struct.pack('I', W_o.shape[1])) + f.write(struct.pack('I', U_o.shape[0])) + f.write(struct.pack('I', U_o.shape[1])) + f.write(struct.pack('I', b_o.shape[0])) + + W_i = W_i.flatten() + U_i = U_i.flatten() + b_i = b_i.flatten() + W_f = W_f.flatten() + U_f = U_f.flatten() + b_f = b_f.flatten() + W_c = W_c.flatten() + U_c = U_c.flatten() + b_c = b_c.flatten() + W_o = W_o.flatten() + U_o = U_o.flatten() + b_o = b_o.flatten() + + write_floats(f, W_i) + write_floats(f, U_i) + write_floats(f, b_i) + write_floats(f, W_f) + write_floats(f, U_f) + write_floats(f, b_f) + write_floats(f, W_c) + write_floats(f, U_c) + write_floats(f, b_c) + write_floats(f, W_o) + write_floats(f, U_o) + write_floats(f, b_o) + + write_activation(inner_activation) + write_activation(activation) + f.write(struct.pack('I', return_sequences)) + + elif layer_type == 'Embedding': + weights = layer.get_weights()[0] + + f.write(struct.pack('I', LAYER_EMBEDDING)) + f.write(struct.pack('I', weights.shape[0])) + f.write(struct.pack('I', weights.shape[1])) + + weights = weights.flatten() + + write_floats(f, weights) else: assert False, "Unsupported layer type: %s" % layer_type diff --git a/make_tests.py b/make_tests.py index 153775e..f7370ac 100644 --- a/make_tests.py +++ b/make_tests.py @@ -2,8 +2,10 @@ import pprint from keras.models import Sequential -from keras.layers import Convolution2D, Dense, Flatten, Activation, MaxPooling2D +from keras.layers import Convolution2D, Dense, Flatten, Activation, MaxPooling2D, Dropout +from keras.layers.recurrent import LSTM from keras.layers.advanced_activations import ELU +from keras.layers.embeddings import Embedding from kerasify import export_model @@ -64,11 +66,11 @@ def c_array(a): ''' def output_testcase(model, test_x, test_y, name, eps): - print "Processing %s" % name + print("Processing %s" % name) model.compile(loss='mean_squared_error', optimizer='adamax') model.fit(test_x, test_y, nb_epoch=1, verbose=False) predict_y = model.predict(test_x).astype('f') - print model.summary() + print(model.summary()) export_model(model, 'test_%s.model' % name) @@ -182,6 +184,16 @@ def output_testcase(model, test_x, test_y, name, eps): output_testcase(model, test_x, test_y, 'dense_relu_10', '1e-6') +''' Dense relu ''' +test_x = np.random.rand(1, 10).astype('f') +test_y = np.random.rand(1, 10).astype('f') +model = Sequential() +model.add(Dense(10, input_dim=10, activation='tanh')) +model.add(Dense(10, input_dim=10, activation='tanh')) +model.add(Dense(10, input_dim=10, activation='tanh')) + +output_testcase(model, test_x, test_y, 'dense_tanh_10', '1e-6') + ''' Conv softplus ''' test_x = np.random.rand(10, 1, 2, 2).astype('f') test_y = np.random.rand(10, 1).astype('f') @@ -193,6 +205,27 @@ def output_testcase(model, test_x, test_y, name, eps): output_testcase(model, test_x, test_y, 'conv_softplus_2x2', '1e-6') +''' Conv hardsigmoid ''' +test_x = np.random.rand(10, 1, 2, 2).astype('f') +test_y = np.random.rand(10, 1).astype('f') +model = Sequential() +model.add(Convolution2D(1, 2, 2, input_shape=(1, 2, 2), activation='hard_sigmoid')) +model.add(Flatten()) +model.add(Dense(1)) + +output_testcase(model, test_x, test_y, 'conv_hard_sigmoid_2x2', '1e-6') + +''' Conv sigmoid ''' +test_x = np.random.rand(10, 1, 2, 2).astype('f') +test_y = np.random.rand(10, 1).astype('f') +model = Sequential() +model.add(Convolution2D(1, 2, 2, input_shape=(1, 2, 2), activation='sigmoid')) +model.add(Flatten()) +model.add(Dense(1)) + +output_testcase(model, test_x, test_y, 'conv_sigmoid_2x2', '1e-6') + + ''' Maxpooling2D 1x1''' test_x = np.random.rand(10, 1, 10, 10).astype('f') test_y = np.random.rand(10, 1).astype('f') @@ -233,6 +266,48 @@ def output_testcase(model, test_x, test_y, name, eps): output_testcase(model, test_x, test_y, 'maxpool2d_3x3x3', '1e-6') +''' LSTM simple 7x20 ''' +test_x = np.random.rand(10, 7, 20).astype('f') +test_y = np.random.rand(10, 3).astype('f') +model = Sequential() +model.add(LSTM(3, return_sequences=False, input_shape=(7, 20))) + +output_testcase(model, test_x, test_y, 'lstm_simple_7x20', '1e-6') + + +''' LSTM simple stacked 20x9 ''' +test_x = np.random.rand(10, 20, 9).astype('f') +test_y = np.random.rand(10, 1).astype('f') +model = Sequential() +model.add(LSTM(32, return_sequences=False, input_shape=(20, 9))) +model.add(Dense(3, input_dim=32, activation='tanh')) +model.add(Dense(1)) + +output_testcase(model, test_x, test_y, 'lstm_simple_stacked20x9', '1e-6') + +''' LSTM stacked 150x83 ''' +test_x = np.random.rand(10, 150, 83).astype('f') +test_y = np.random.rand(10, 1).astype('f') +model = Sequential() +model.add(LSTM(32, return_sequences=True, input_shape=(150, 83))) +model.add(LSTM(32, return_sequences=False)) +model.add(Dense(1, activation='sigmoid')) + +output_testcase(model, test_x, test_y, 'lstm_stacked150x83', '1e-6') + + +''' Embedding 64 ''' +np.random.seed(10) +test_x = np.random.randint(100, size=(32, 10)).astype('f') +test_y = np.random.rand(32, 20).astype('f') +model = Sequential() +model.add(Embedding(100, 64, input_length=10)) +model.add(Flatten()) +#model.add(Dropout(0.5)) +model.add(Dense(20, activation='sigmoid')) + +output_testcase(model, test_x, test_y, 'embedding64', '1e-6') + ''' Benchmark ''' test_x = np.random.rand(1, 3, 128, 128).astype('f')