
Merge pull request #8 from reyoung/feature/refactorize_framework_proto
Catch-up with develop branch
reyoung authored Aug 10, 2017
2 parents c7e8c1a + 7fab7dd commit 5ac3641
Showing 18 changed files with 356 additions and 14 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
@@ -36,8 +36,8 @@ include(simd)
################################ Configurations #######################################
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." ${AVX_FOUND})
option(WITH_MKLML "Compile PaddlePaddle with mklml package." ${AVX_FOUND})
option(WITH_MKLDNN "Compile PaddlePaddle with mkl-dnn support." OFF)
option(WITH_MKLML "Compile PaddlePaddle with mklml package." OFF)
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
option(WITH_TESTING "Compile PaddlePaddle with unit testing" ON)
option(WITH_SWIG_PY "Compile PaddlePaddle with inference api" ON)
1 change: 1 addition & 0 deletions paddle/framework/CMakeLists.txt
@@ -46,5 +46,6 @@ cc_library(paddle_pybind SHARED
cross_entropy_op
recurrent_op
uniform_random_op
gaussian_random_op
fill_zeros_like_op)
endif(WITH_PYTHON)
2 changes: 1 addition & 1 deletion paddle/framework/attribute.h
@@ -14,7 +14,6 @@ limitations under the License. */

#pragma once

#include <boost/variant.hpp>
#include <functional>
#include <string>
#include <unordered_map>
@@ -23,6 +22,7 @@ limitations under the License. */

#include "paddle/framework/framework.pb.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/variant.h"

namespace paddle {
namespace framework {
2 changes: 1 addition & 1 deletion paddle/framework/ddim.h
@@ -14,12 +14,12 @@ limitations under the License. */

#pragma once

#include <boost/variant.hpp>
#include <initializer_list>
#include <stdexcept>
#include <vector>
#include "paddle/framework/dim.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/variant.h"
#include "unsupported/Eigen/CXX11/Tensor"

namespace paddle {
2 changes: 1 addition & 1 deletion paddle/framework/operator.h
@@ -15,7 +15,6 @@ limitations under the License. */
#pragma once

#include <algorithm>
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>
@@ -26,6 +25,7 @@ limitations under the License. */
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/platform/variant.h"
#include "paddle/utils/Error.h"

namespace paddle {
2 changes: 2 additions & 0 deletions paddle/framework/pybind.cc
@@ -40,7 +40,9 @@ USE_OP(softmax);
USE_OP(rowwise_add);
USE_OP(fill_zeros_like);
USE_OP_WITHOUT_KERNEL(recurrent_op);
USE_OP(gaussian_random);
USE_OP(uniform_random);

namespace paddle {
namespace framework {

4 changes: 2 additions & 2 deletions paddle/framework/tensor.h
@@ -79,11 +79,11 @@ class Tensor {
inline const DDim& dims() const;

/*! Resize the dimensions of the memory block. */
inline void Resize(const DDim& dims);
inline Tensor& Resize(const DDim& dims);

/*! The internal of two tensors share the same memory block. */
template <typename T>
inline void ShareDataWith(const Tensor& src);
inline Tensor& ShareDataWith(const Tensor& src);

/**
* @brief Copy the content of external tensor to a new place.
16 changes: 11 additions & 5 deletions paddle/framework/tensor_impl.h
@@ -23,9 +23,11 @@ template <typename T>
inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE(holder_->size(), product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.");
PADDLE_ENFORCE_GE(
holder_->size(), product(dims_) * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.\n"
"or maybe the required data-type mismatches the data already stored.");
}

template <typename T>
@@ -78,9 +80,10 @@ inline T* Tensor::mutable_data(platform::Place place) {
}

template <typename T>
inline void Tensor::ShareDataWith(const Tensor& src) {
inline Tensor& Tensor::ShareDataWith(const Tensor& src) {
src.check_memory_size<T>();
*this = src;
return *this;
}

template <typename T>
@@ -136,7 +139,10 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
return dst;
}

inline void Tensor::Resize(const DDim& dims) { dims_ = dims; }
inline Tensor& Tensor::Resize(const DDim& dims) {
dims_ = dims;
return *this;
}

inline const DDim& Tensor::dims() const { return dims_; }

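With Resize() and ShareDataWith() now returning Tensor&, the two calls can be chained. A minimal usage sketch (not part of this commit; the function and variable names are illustrative):

#include "paddle/framework/tensor.h"

// Sketch only: share src's memory, then view it as a flat 1-D tensor.
void FlattenViewSketch(const paddle::framework::Tensor& src) {
  paddle::framework::Tensor dst;
  int numel = static_cast<int>(paddle::framework::product(src.dims()));
  // Both member functions return *this, so they compose in one statement.
  dst.ShareDataWith<float>(src).Resize(paddle::framework::make_ddim({numel}));
}
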
7 changes: 6 additions & 1 deletion paddle/gserver/tests/test_KmaxSeqScore.cpp
@@ -96,6 +96,11 @@ TEST(Layer, kmaxSeqScoreLayer) {
MatrixPtr inValue =
Matrix::create(subSeqStartPosition.back(), 1, false, false);

std::vector<bool> mode = {false};
#ifndef PADDLE_ONLY_CPU
mode.push_back(true);
#endif

for (auto hasSubseq : {false, true}) {
vector<vector<int>> groundTruth;
inValue->randomizeUniform();
@@ -104,7 +109,7 @@
hasSubseq ? subSeqStartPosition : seqStartPosition,
beamSize);

for (auto useGpu : {false, true}) {
for (auto useGpu : mode) {
TestConfig config;
config.layerConfig.set_type("kmax_seq_score");
config.layerConfig.set_beam_size(beamSize);
3 changes: 3 additions & 0 deletions paddle/operators/CMakeLists.txt
@@ -41,6 +41,8 @@ function(op_library TARGET)
endif()
endfunction()

cc_test(gather_test SRCS gather_test.cc DEPS tensor)

cc_library(net_op SRCS net_op.cc DEPS op_registry)
cc_test(net_op_test SRCS net_op_test.cc DEPS net_op)

@@ -53,6 +55,7 @@ op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc)

op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu)
op_library(softmax_op SRCS softmax_op.cc softmax_op.cu)
op_library(gaussian_random_op SRCS gaussian_random_op.cc gaussian_random_op.cu)
op_library(cross_entropy_op SRCS cross_entropy_op.cc cross_entropy_op.cu)
op_library(fill_zeros_like_op SRCS fill_zeros_like_op.cc fill_zeros_like_op.cu)

73 changes: 73 additions & 0 deletions paddle/operators/gather.h
@@ -0,0 +1,73 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <memory.h>
#include <cstring>

#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"

namespace paddle {
namespace operators {

// Implementation of CPU copy
template <typename T>
void CPUGather(const T* params, const int* indices, const int slice_size,
const int index_size, T* output) {
const size_t slice_bytes = slice_size * sizeof(T);

for (size_t i = 0; i < index_size; ++i) {
int index_ = indices[i];
memcpy(output + i * slice_size, params + index_ * slice_size, slice_bytes);
}
}

// Implementation of GPU copy:
template <typename T>
void GPUGather(const T* src, const int* index, const int slice_size,
const int index_size, T* output);

/**
* Return a new tensor from source tensor, gathered according to index
* input[src]: type-T source Tensor
* input[index]: type-int index Tensor (1-D)
* return: output tensor
*/
template <typename T>
void Gather(const platform::Place& place, const paddle::framework::Tensor* src,
const paddle::framework::Tensor* index,
paddle::framework::Tensor* output) {
// check index of shape 1-D
PADDLE_ENFORCE(index->dims().size() == 1);
int index_size = index->dims()[0];

auto src_dims = src->dims();
paddle::framework::DDim output_dims(src_dims);
output_dims[0] = index_size;

// slice size
int slice_size = 1;
for (size_t i = 1; i < src_dims.size(); ++i) slice_size *= src_dims[i];

// Gathering
if (platform::is_cpu_place(place)) {
CPUGather<T>(src->data<T>(), index->data<int>(), slice_size, index_size,
output->data<T>());
}
}

} // namespace operators
} // namespace paddle
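
A worked sketch of CPUGather on raw buffers (it mirrors the Tensor-level unit test below; the array names and sizes here are illustrative, not part of this commit):

#include "paddle/operators/gather.h"

// Sketch only: gather rows 1 and 0 of a 3 x 4 matrix.
void GatherRowsSketch() {
  int src[12];
  for (int i = 0; i < 12; ++i) src[i] = i;  // 3 x 4 matrix holding 0..11
  int index[2] = {1, 0};                    // copy row 1 first, then row 0
  int out[8];                               // 2 x 4 result
  // slice_size = elements per row, index_size = number of rows gathered.
  paddle::operators::CPUGather<int>(src, index, /*slice_size=*/4,
                                    /*index_size=*/2, out);
  // out now holds {4, 5, 6, 7, 0, 1, 2, 3}.
}
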
48 changes: 48 additions & 0 deletions paddle/operators/gather_test.cc
@@ -0,0 +1,48 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/gather.h"
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/place.h"

#include <gtest/gtest.h>
#include <iostream>
#include <string>

TEST(Gather, GatherData) {
using namespace paddle::framework;
using namespace paddle::platform;
using namespace paddle::operators;

Tensor* src = new Tensor();
Tensor* index = new Tensor();
Tensor* output = new Tensor();

int* p_src = nullptr;
int* p_index = nullptr;
p_src = src->mutable_data<int>(make_ddim({3, 4}), CPUPlace());
p_index = index->mutable_data<int>(make_ddim({2}), CPUPlace());

for (size_t i = 0; i < 12; ++i) p_src[i] = i;
p_index[0] = 1;
p_index[1] = 0;

int* p_output = output->mutable_data<int>(make_ddim({2, 4}), CPUPlace());

Gather<int>(CPUPlace(), src, index, output);

for (size_t i = 0; i < 4; ++i) EXPECT_EQ(p_output[i], i + 4);
for (size_t i = 4; i < 8; ++i) EXPECT_EQ(p_output[i], i - 4);
}
82 changes: 82 additions & 0 deletions paddle/operators/gaussian_random_op.cc
@@ -0,0 +1,82 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <random>
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T>
class GaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op_.GetAttr<float>("mean");
float std = context.op_.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>(0);
T* data = tensor->mutable_data<T>(context.GetPlace());

// TODO(dzh): attribute does not support unsigned int.
// And we need a global random seed configuration.
int seed = context.op_.GetAttr<int>("seed");
if (seed == 0) {
seed = std::random_device()();
}
std::mt19937 g(seed);
std::normal_distribution<T> distribution(mean, std);
ssize_t size = framework::product(tensor->dims());
for (int i = 0; i < size; ++i) {
data[i] = distribution(g);
}
}
};

class GaussianRandomOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>(0);
auto dims = GetAttr<std::vector<int>>("dims");
PADDLE_ENFORCE(dims.size() > 0UL,
"dims can be one int or array. dims must be set.");
tensor->Resize(framework::make_ddim(dims));
}
};

class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GaussianRandomOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "output matrix of random op");
AddComment(R"DOC(
GaussianRandom operator.
Used to initialize a tensor with a Gaussian random generator.
)DOC");

AddAttr<std::vector<int>>("dims", "The dimension of random tensor.");
AddAttr<float>("mean", "mean value of random.").SetDefault(.0f);
AddAttr<float>("std", "minimum value of random value.").SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed of generator."
"0 means use system wide seed")
.SetDefault(0);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(gaussian_random, ops::GaussianRandomOp, ops::GaussianRandomOpMaker);
REGISTER_OP_CPU_KERNEL(gaussian_random, ops::GaussianRandomKernel<float>);
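
The kernel's sampling logic, isolated as a standard-library-only sketch (not part of this commit; the helper name and parameters are illustrative):

#include <random>
#include <vector>

// Sketch only: mirrors GaussianRandomKernel::Compute. A seed of 0 falls back to
// a non-deterministic seed from std::random_device, as in the kernel above.
std::vector<float> SampleGaussian(int seed, float mean, float stddev, int size) {
  if (seed == 0) {
    seed = std::random_device()();
  }
  std::mt19937 g(seed);
  std::normal_distribution<float> distribution(mean, stddev);
  std::vector<float> data(size);
  for (int i = 0; i < size; ++i) {
    data[i] = distribution(g);
  }
  return data;
}
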