diff --git a/.gitmodules b/.gitmodules
index 2f2130f..7c12992 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,3 +1,6 @@
 [submodule "Etaler/3rdparty/pcg-cpp"]
 	path = Etaler/3rdparty/pcg-cpp
 	url = https://github.com/imneme/pcg-cpp
+[submodule "Etaler/3rdparty/half_precision"]
+	path = Etaler/3rdparty/half_precision
+	url = https://github.com/marty1885/half_precision
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 3e2c247..4ab87da 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1,4 +1,6 @@
 cmake_minimum_required(VERSION 3.0)
+project(Etaler)
+
 SET(BUILD_SHARED_LIBS ON)
diff --git a/Etaler/3rdparty/half_precision b/Etaler/3rdparty/half_precision
new file mode 160000
index 0000000..82cdf3a
--- /dev/null
+++ b/Etaler/3rdparty/half_precision
@@ -0,0 +1 @@
+Subproject commit 82cdf3a4cf414f2c862232fa948e0208a760af6c
diff --git a/Etaler/Algorithms/SpatialPooler.hpp b/Etaler/Algorithms/SpatialPooler.hpp
index aee57f8..b96d0d8 100644
--- a/Etaler/Algorithms/SpatialPooler.hpp
+++ b/Etaler/Algorithms/SpatialPooler.hpp
@@ -71,7 +71,7 @@ struct ETALER_EXPORT SpatialPooler
 	{
 		return to(connections_.backend());
 	}
-protected:
+//protected:
 	float permanence_inc_ = 0.1;
 	float permanence_dec_ = 0.1;
 	float connected_permanence_ = 0.21;
diff --git a/Etaler/Backends/CPUBackend.cpp b/Etaler/Backends/CPUBackend.cpp
index 4f9bab8..bf954e4 100644
--- a/Etaler/Backends/CPUBackend.cpp
+++ b/Etaler/Backends/CPUBackend.cpp
@@ -1,6 +1,7 @@
 #include "CPUBackend.hpp"
 #include "Etaler/Core/Views.hpp"
 #include "Etaler/Core/Random.hpp"
+#include "Etaler/Core/TypeList.hpp"
 
 #include
 #include
@@ -54,31 +55,43 @@ void* CPUBuffer::data() const
 	return std::visit([](const auto& v){return (void*)v;}, storage_);
 }
 
-std::shared_ptr<TensorImpl> CPUBackend::cellActivity(const TensorImpl* x, const TensorImpl* connections, const TensorImpl* permeances
-	, float connected_permeance, size_t active_threshold, bool has_unconnected_synapse)
+template <typename TypeList = type_list_t<float, half>, typename Func = void>
+inline void dispatch(DType dtype, Func f)
+{
+	static_assert(std::is_same_v<Func, void> == false); //void is just a dummy value
+	if constexpr(std::is_same_v<TypeList, null_t> == false) {
+		using T = typename TypeList::head;
+		if(typeToDType<T>() == dtype) {
+			f(T());
+			return;
+		}
+		dispatch<typename TypeList::tail>(dtype, f);
+	}
+	else
+		throw EtError("Cannot dispatch such dtype: " + to_ctype_string(dtype));
+}
+
+namespace et::detail
+{
+template <typename PermType>
+static std::shared_ptr<TensorImpl> cellActivity(const TensorImpl* x, const TensorImpl* connections, const TensorImpl* permeances
+	, float connected_permeance, size_t active_threshold, bool has_unconnected_synapse, CPUBackend* backend)
 {
 	//Checks the input are sane
-	et_assert(x->backend() == this);
-	et_assert(connections->backend() == this);
-	et_assert(permeances->backend() == this);
-	et_assert(x->iscontiguous());
-	et_assert(connections->iscontiguous());
-	et_assert(permeances->iscontiguous());
-
-	et_assert(x->dtype() == DType::Bool);
-	et_assert(connections->dtype() == DType::Int32);
-	et_assert(permeances->dtype() == DType::Float);
+	requireProperties(x, backend, DType::Bool, IsContingous());
+	requireProperties(connections, backend, DType::Int32, IsContingous());
+	requireProperties(permeances, backend, typeToDType<PermType>(), IsContingous());
 	et_assert(connections->shape() == permeances->shape());
 	et_assert(connections->dimentions() >= 2);
 
 	Shape s = connections->shape();
 	s.pop_back();
-	auto y = createTensor(s, DType::Int32);
+	auto y = backend->createTensor(s, DType::Int32);
 
 	const bool* input = (const bool*)x->data();
 	const int32_t* synapses = (const int32_t*)connections->data();
-	const float* synapse_strengths = (float*)permeances->data();
+	const PermType* synapse_strengths = (PermType*)permeances->data();
 	int32_t* result = (int32_t*)y->data();
 
 	size_t max_connections_per_cell = connections->shape().back();
@@ -101,7 +114,7 @@ std::shared_ptr<TensorImpl> CPUBackend::cellActivity(const TensorImpl* x, const
 
 			if(input[target] == false)
 				continue;
 
-			float strength = synapse_strengths[index];
+			PermType strength = synapse_strengths[index];
 			if(strength > connected_permeance)
 				sum += 1;
 		}
@@ -115,29 +128,22 @@ std::shared_ptr<TensorImpl> CPUBackend::cellActivity(const TensorImpl* x, const
 	return y;
 }
 
-void CPUBackend::learnCorrilation(const TensorImpl* x, const TensorImpl* learn, const TensorImpl* connections, TensorImpl* permeances
-	, float perm_inc, float perm_dec, bool has_unconnected_synapse)
+template <typename PermType>
+void learnCorrilation(const TensorImpl* x, const TensorImpl* learn, const TensorImpl* connections, TensorImpl* permeances
+	, float perm_inc, float perm_dec, bool has_unconnected_synapse, CPUBackend* backend)
 {
-	et_assert(x->backend() == this);
-	et_assert(connections->backend() == this);
-	et_assert(permeances->backend() == this);
-	et_assert(learn->backend() == this);
-	et_assert(x->iscontiguous());
-	et_assert(learn->iscontiguous());
-	et_assert(connections->iscontiguous());
-	et_assert(permeances->iscontiguous());
+	requireProperties(x, backend, DType::Bool, IsContingous());
+	requireProperties(learn, backend, DType::Bool, IsContingous());
+	requireProperties(connections, backend, DType::Int32, IsContingous());
+	requireProperties(permeances, backend, typeToDType<PermType>(), IsContingous());
 	et_assert(connections->shape() == permeances->shape());
 	et_assert(x->shape() == learn->shape());
-	et_assert(x->dtype() == DType::Bool);
-	et_assert(learn->dtype() == DType::Bool);
-	et_assert(connections->dtype() == DType::Int32);
-	et_assert(permeances->dtype() == DType::Float);
 
 	const bool* input = (const bool*)x->data();
 	const bool* learning = (const bool*)learn->data();
 	const int32_t* synapses = (const int32_t*)connections->data();
-	float* synapse_strengths = (float*)permeances->data();
+	PermType* synapse_strengths = (PermType*)permeances->data();
 
 	size_t max_connections_per_cell = connections->shape().back();
 	size_t num_cells = connections->size()/max_connections_per_cell;
@@ -153,23 +159,183 @@ void CPUBackend::learnCorrilation(const TensorImpl* x, const TensorImpl* learn,
 				break;
 			ASSERT((size_t)connection < num_cells);
 
-			float& perm = synapse_strengths[idx];
+			PermType& perm = synapse_strengths[idx];
 			if(input[connection] == true)
 				perm += perm_inc;
 			else
 				perm -= perm_dec;
 
-			perm = std::max(std::min(perm, 1.f), 0.f);
+			perm = std::max(std::min(perm, PermType(1)), PermType(0));
 		}
 	});
 }
 
-std::shared_ptr<TensorImpl> CPUBackend::globalInhibition(const TensorImpl* x, float fraction)
+template <typename PermType>
+void sortSynapse(TensorImpl* connections, TensorImpl* permeances, CPUBackend* backend)
+{
+	requireProperties(connections, backend, DType::Int32, IsContingous());
+	requireProperties(permeances, backend, typeToDType<PermType>(), IsContingous());
+	et_assert(connections->shape() == permeances->shape());
+
+	size_t max_synapse_per_cell = connections->shape().back();
+	size_t num_cells = connections->size()/max_synapse_per_cell;
+
+	uint32_t* conns = (uint32_t*)connections->data(); //HACK: -1s should be at the end of the arrays.
+	PermType* perms = (PermType*)permeances->data();
+
+	tbb::parallel_for(size_t(0), num_cells, [&](size_t i) {
+		size_t start_index = i*max_synapse_per_cell;
+		size_t end_index = (i+1)*max_synapse_per_cell;
+
+		std::vector<size_t> sort_indices(max_synapse_per_cell);
+		std::iota(sort_indices.begin(), sort_indices.end(), 0);
+		std::sort(sort_indices.begin(), sort_indices.end(),
+			[&](size_t i, size_t j)->bool {
+				return conns[i+start_index] < conns[j+start_index];
+			});
+		apply_permutation_in_place(conns+start_index, conns+end_index, sort_indices);
+		apply_permutation_in_place(perms+start_index, perms+end_index, sort_indices);
+	});
+}
+
+template <typename PermType>
+void growSynapses(const TensorImpl* x, const TensorImpl* y, TensorImpl* connections
+	, TensorImpl* permeances, float initial_perm, CPUBackend* backend)
+{
+	requireProperties(x, backend, DType::Bool, IsContingous());
+	requireProperties(y, backend, DType::Bool, IsContingous());
+	requireProperties(connections, backend, DType::Int32, IsContingous());
+	requireProperties(permeances, backend, typeToDType<PermType>(), IsContingous());
+
+	et_assert(connections->shape() == permeances->shape());
+	Shape s = connections->shape();
+	s.pop_back();
+	et_assert(s == y->shape());
+
+	size_t max_synapses_per_cell = connections->shape().back();
+	size_t input_cell_count = x->size();
+
+	const bool* in = (const bool*) x->data();
+	const bool* out = (const bool*) y->data();
+	int32_t* conns = (int32_t*)connections->data();
+	PermType* perms = (PermType*)permeances->data();
+
+	std::vector<uint32_t> on_bits;
+	on_bits.reserve(input_cell_count*0.1);
+	for(size_t i=0;i<input_cell_count;i++) {
+		if(in[i] == true)
+			on_bits.push_back(i);
+	}
+
+	size_t block_size = std::min(size_t(128), connections->shape().back());
+	tbb::parallel_for(tbb::blocked_range<size_t>(size_t(0), y->size(), block_size), [&](const auto& r) {
+		for(size_t i=r.begin();i!=r.end();i++) {
+			if(out[i] == 0)
+				continue;
+
+			uint32_t* synapses = (uint32_t*)conns+i*max_synapses_per_cell;
+			PermType* strengths = perms+i*max_synapses_per_cell;
+			uint32_t* end = synapses+max_synapses_per_cell;
+
+			if(synapses[max_synapses_per_cell-1] != uint32_t(-1)) //If there is no space for new synapse. Ignore
+				continue;
+
+			uint32_t* it = std::lower_bound(synapses, end, uint32_t(-1));
+			size_t used_space = it - synapses;
+
+			size_t write_idx = it - synapses;
+			size_t read_idx = 0;
+
+			for(size_t j=0;write_idx!=max_synapses_per_cell && j < on_bits.size();j++) {
+				bool connected = false;
+				for(;read_idx<used_space;read_idx++) {
+					if(synapses[read_idx] == on_bits[j])
+						connected = true;
+					else if(synapses[read_idx] > on_bits[j])
+						break;
+				}
+
+				if(connected == false) {
+					synapses[write_idx] = on_bits[j];
+					strengths[write_idx] = initial_perm;
+					write_idx++;
+				}
+			}
+
+			std::vector<size_t> sort_indices(write_idx);
+			std::iota(sort_indices.begin(), sort_indices.begin()+write_idx, 0);
+			std::sort(sort_indices.begin(), sort_indices.begin()+write_idx,
+				[&](size_t i, size_t j)->bool {
+					return ((uint32_t*)synapses)[i] < ((uint32_t*)synapses)[j];
+				});
+			apply_permutation_in_place(synapses, synapses+write_idx, sort_indices);
+			apply_permutation_in_place(strengths, strengths+write_idx, sort_indices);
+		}
+	});
+}
+
+template <typename PermType>
+void decaySynapses(TensorImpl* connections, TensorImpl* permeances, float threshold, CPUBackend* backend)
+{
+	requireProperties(connections, backend, DType::Int32, IsContingous());
+	requireProperties(permeances, backend, typeToDType<PermType>(), IsContingous());
+	et_assert(connections->shape() == permeances->shape());
+
+	PermType* perms = (PermType*)permeances->data();
+	uint32_t* conns = (uint32_t*)connections->data();
+
+	size_t max_synapses_per_cell = connections->shape().back();
+	size_t input_cell_count = connections->size()/max_synapses_per_cell;
+
+	tbb::parallel_for(size_t(0), input_cell_count, [&](size_t i) {
+		uint32_t* synapses = (uint32_t*)conns+i*max_synapses_per_cell;
+		PermType* strengths = perms+i*max_synapses_per_cell;
+		uint32_t* end = synapses+max_synapses_per_cell;
+
+		uint32_t* it = std::lower_bound(synapses, end, uint32_t(-1));
+		size_t used_space = it - synapses;
+
+		for(size_t j=0;j<used_space;j++) {
+			if(strengths[j] < threshold)
+				synapses[j] = uint32_t(-1);
+		}
+
+		std::vector<size_t> sort_indices(used_space);
+		std::iota(sort_indices.begin(), sort_indices.begin()+used_space, 0);
+		std::sort(sort_indices.begin(), sort_indices.begin()+used_space,
+			[&](size_t i, size_t j)->bool {
+				return ((uint32_t*)synapses)[i] < ((uint32_t*)synapses)[j];
+			});
+		apply_permutation_in_place(synapses, synapses+used_space, sort_indices);
+		apply_permutation_in_place(strengths, strengths+used_space, sort_indices);
+	});
+}
+
+}
+
+std::shared_ptr<TensorImpl> CPUBackend::cellActivity(const TensorImpl* x, const TensorImpl* connections, const TensorImpl* permeances
+	, float connected_permeance, size_t active_threshold, bool has_unconnected_synapse)
 {
-	et_assert(x->backend() == this);
-	et_assert(x->iscontiguous());
+	std::shared_ptr<TensorImpl> res;
+	dispatch<type_list_t<float, half>>(permeances->dtype(), [&](auto v){
+		res = detail::cellActivity<decltype(v)>(x, connections, permeances, connected_permeance, active_threshold, has_unconnected_synapse, this);
+	});
+	return res;
+}
 
-	et_assert(x->dtype() == DType::Int32);
+void CPUBackend::learnCorrilation(const TensorImpl* x, const TensorImpl* learn, const TensorImpl* connections, TensorImpl* permeances
+	, float perm_inc, float perm_dec, bool has_unconnected_synapse)
+{
+	dispatch<type_list_t<float, half>>(permeances->dtype(), [&](auto v){
+		detail::learnCorrilation<decltype(v)>(x, learn, connections, permeances, perm_inc, perm_dec, has_unconnected_synapse, this);
+	});
+}
+
+std::shared_ptr<TensorImpl> CPUBackend::globalInhibition(const TensorImpl* x, float fraction)
+{
+	requireProperties(x, this, DType::Int32, IsContingous());
 
 	auto y = createTensor(x->shape(), DType::Bool);
 
@@ -204,6 +370,8 @@ static Ret run(const CPUBuffer& t, Op op)
 		return op((int32_t*)t.data());
 	else if(t.dtype() == DType::Float)
 		return op((float*)t.data());
+	else if(t.dtype() == DType::Half)
+		return op((half*)t.data());
 	else
 		throw EtError("Cannot cast");
 }
@@ -216,8 +384,7 @@ static std::vector<To> castData(const From* ptr, size_t n)
 
 std::shared_ptr<TensorImpl> CPUBackend::cast(const TensorImpl* x, DType toType)
 {
-	et_assert(x->backend() == this);
-	et_assert(x->iscontiguous());
+	requireProperties(x, this, IsContingous());
 	const CPUBuffer* p = dynamic_cast<const CPUBuffer*>(x->buffer().get());
 	const CPUBuffer& t = *p;
 	return run<std::shared_ptr<TensorImpl>>(t, [&x, toType, this](const auto* ptr){
@@ -234,6 +401,10 @@ std::shared_ptr<TensorImpl> CPUBackend::cast(const TensorImpl* x, DType toType)
 		auto casted_data = castData<float>(ptr, x->size());
 		return createTensor(x->shape(), toType, casted_data.data());
 	}
+	else if(toType == DType::Half){
+		auto casted_data = castData<half>(ptr, x->size());
+		return createTensor(x->shape(), toType, casted_data.data());
+	}
 	else
 		throw EtError("Cannot cast");
 	});
@@ -241,57 +412,27 @@ std::shared_ptr<TensorImpl> CPUBackend::cast(const TensorImpl* x, DType toType)
 
 void CPUBackend::copyToHost(const TensorImpl* t, void* ptr)
 {
-	et_assert(points_to(t->buffer().get()));
-	et_assert(t->iscontiguous());
+	requireProperties(t, this, IsContingous());
 	memcpy(ptr, t->data(), t->size()*dtypeToSize(t->dtype()));
 }
 
 std::shared_ptr<TensorImpl> CPUBackend::copy(const TensorImpl* x)
 {
-	et_assert(x->backend() == this);
-	et_assert(x->iscontiguous());
+	requireProperties(x, this, IsContingous());
 	return createTensor(x->shape(), x->dtype(), x->data());
 }
 
 void CPUBackend::sortSynapse(TensorImpl* connections, TensorImpl* permeances)
 {
-	et_assert(connections->shape() == permeances->shape());
-	et_assert(connections->backend() == this);
-	et_assert(permeances->backend() == this);
-	et_assert(connections->dtype() == DType::Int32);
-	et_assert(permeances->dtype() == DType::Float);
-	et_assert(connections->iscontiguous());
-	et_assert(permeances->iscontiguous());
-
-	size_t max_synapse_per_cell = connections->shape().back();
-	size_t num_cells = connections->size()/max_synapse_per_cell;
-
-	uint32_t* conns = (uint32_t*)connections->data(); //HACK: -1s should be at the end of the arrays.
-	float* perms = (float*)permeances->data();
-
-	tbb::parallel_for(size_t(0), num_cells, [&](size_t i) {
-		size_t start_index = i*max_synapse_per_cell;
-		size_t end_index = (i+1)*max_synapse_per_cell;
-
-		std::vector<size_t> sort_indices(max_synapse_per_cell);
-		std::iota(sort_indices.begin(), sort_indices.end(), 0);
-		std::sort(sort_indices.begin(), sort_indices.end(),
-			[&](size_t i, size_t j)->bool {
-				return conns[i+start_index] < conns[j+start_index];
-			});
-		apply_permutation_in_place(conns+start_index, conns+end_index, sort_indices);
-		apply_permutation_in_place(perms+start_index, perms+end_index, sort_indices);
+	dispatch<type_list_t<float, half>>(permeances->dtype(), [&](auto v) {
+		detail::sortSynapse<decltype(v)>(connections, permeances, this);
 	});
 }
 
 std::shared_ptr<TensorImpl> CPUBackend::burst(const TensorImpl* x, const TensorImpl* s)
 {
-	et_assert(x->backend() == this);
-	et_assert(s->backend() == this);
-	et_assert(x->dtype() == DType::Bool);
-	et_assert(s->dtype() == DType::Bool);
-	et_assert(x->iscontiguous());
-	et_assert(s->iscontiguous());
+	requireProperties(x, this, DType::Bool, IsContingous());
+	requireProperties(s, this, DType::Bool, IsContingous());
 
 	Shape shape = s->shape();
 	shape.pop_back();
@@ -319,9 +460,7 @@ std::shared_ptr<TensorImpl> CPUBackend::burst(const TensorImpl* x, const TensorI
 
 std::shared_ptr<TensorImpl> CPUBackend::reverseBurst(const TensorImpl* x)
 {
-	et_assert(x->backend() == this);
-	et_assert(x->dtype() == DType::Bool);
-	et_assert(x->iscontiguous());
+	requireProperties(x, this, DType::Bool, IsContingous());
 
 	size_t cells_per_column = x->shape().back();
 	size_t num_columns = x->size()/cells_per_column;
@@ -348,86 +487,8 @@ std::shared_ptr<TensorImpl> CPUBackend::reverseBurst(const TensorImpl* x)
 void CPUBackend::growSynapses(const TensorImpl* x, const TensorImpl* y, TensorImpl* connections
 	, TensorImpl* permeances, float initial_perm)
 {
-	et_assert(x->backend() == this);
-	et_assert(y->backend() == this);
-	et_assert(connections->backend() == this);
-	et_assert(permeances->backend() == this);
-	et_assert(x->iscontiguous());
-	et_assert(y->iscontiguous());
-	et_assert(connections->iscontiguous());
-	et_assert(permeances->iscontiguous());
-
-	et_assert(x->dtype() == DType::Bool);
-	et_assert(y->dtype() == DType::Bool);
-	et_assert(connections->dtype() == DType::Int32);
-	et_assert(permeances->dtype() == DType::Float);
-
-	et_assert(connections->shape() == permeances->shape());
-	Shape s = connections->shape();
-	s.pop_back();
-	et_assert(s == y->shape());
-
-	size_t max_synapses_per_cell = connections->shape().back();
-	size_t input_cell_count = x->size();
-
-	const bool* in = (const bool*) x->data();
-	const bool* out = (const bool*) y->data();
-	int32_t* conns = (int32_t*)connections->data();
-	float* perms = (float*)permeances->data();
-
-	std::vector<uint32_t> on_bits;
-	on_bits.reserve(input_cell_count*0.1);
-	for(size_t i=0;i<input_cell_count;i++) {
-		if(in[i] == true)
-			on_bits.push_back(i);
-	}
-
-	size_t block_size = std::min(size_t(128), connections->shape().back());
-	tbb::parallel_for(tbb::blocked_range<size_t>(size_t(0), y->size(), block_size), [&](const auto& r) {
-		for(size_t i=r.begin();i!=r.end();i++) {
-			if(out[i] == 0)
-				continue;
-
-			uint32_t* synapses = (uint32_t*)conns+i*max_synapses_per_cell;
-			float* strengths = perms+i*max_synapses_per_cell;
-			uint32_t* end = synapses+max_synapses_per_cell;
-
-			if(synapses[max_synapses_per_cell-1] != uint32_t(-1)) //If there is no space for new synapse. Ignore
-				continue;
-
-			uint32_t* it = std::lower_bound(synapses, end, uint32_t(-1));
-			size_t used_space = it - synapses;
-
-			size_t write_idx = it - synapses;
-			size_t read_idx = 0;
-
-			for(size_t j=0;write_idx!=max_synapses_per_cell && j < on_bits.size();j++) {
-				bool connected = false;
-				for(;read_idx<used_space;read_idx++) {
-					if(synapses[read_idx] == on_bits[j])
-						connected = true;
-					else if(synapses[read_idx] > on_bits[j])
-						break;
-				}
-
-				if(connected == false) {
-					synapses[write_idx] = on_bits[j];
-					strengths[write_idx] = initial_perm;
-					write_idx++;
-				}
-			}
-
-			std::vector<size_t> sort_indices(write_idx);
-			std::iota(sort_indices.begin(), sort_indices.begin()+write_idx, 0);
-			std::sort(sort_indices.begin(), sort_indices.begin()+write_idx,
-				[&](size_t i, size_t j)->bool {
-					return ((uint32_t*)synapses)[i] < ((uint32_t*)synapses)[j];
-				});
-			apply_permutation_in_place(synapses, synapses+write_idx, sort_indices);
-			apply_permutation_in_place(strengths, strengths+write_idx, sort_indices);
-		}
+	dispatch<type_list_t<float, half>>(permeances->dtype(), [&](auto v) {
+		detail::growSynapses<decltype(v)>(x, y, connections, permeances, initial_perm, this);
 	});
 }
 
@@ -440,19 +501,6 @@ const T* getPtrToValue(size_t parent_idx, const TensorImpl* t)
 	return ((const T*)t->data())+offset;
 }
 
-template <typename Func>
-void dispatch(DType dtype, Func f)
-{
-	if(dtype == DType::Int32)
-		f(int32_t{});
-	else if(dtype == DType::Float)
-		f(float{});
-	else if(dtype == DType::Bool)
-		f(bool{});
-	else
-		throw EtError("Cannot realize such dtype");
-}
-
 template <typename T1, typename T2>
 void write(T2* ptr, T1 v)
 {
@@ -470,7 +518,9 @@ static std::shared_ptr<TensorImpl> uniaryOp(const TensorImpl* src, Op op)
 		auto res = op(*ptr);
 		using ResType = decltype(res);
 		//We don't have support to double percition now. Cast it to float
-		using StoreType = typename std::conditional<std::is_same<ResType, double>::value, float, ResType>::type;
+		using StoreType = typename std::conditional_t<std::is_same_v<ResType, bool>, bool
+			, typename std::conditional_t<std::is_same_v<ResType, half>, half
+			, typename std::conditional_t<std::is_same_v<ResType, double>, float, ResType>>>;
 		if(i == 0)
 			dest = src->backend()->createTensor(src->shape(), typeToDType<StoreType>());
@@ -515,7 +565,7 @@ static std::shared_ptr<TensorImpl> binaryOp(const TensorImpl* src, const TensorI
 
 std::shared_ptr<TensorImpl> CPUBackend::realize(const TensorImpl* x)
 {
-	et_assert(x->backend() == this);
+	requireProperties(x, this);
 	et_assert(x->data() != nullptr);
 
 	auto res = createTensor(x->shape(), x->dtype());
@@ -532,8 +582,8 @@ std::shared_ptr<TensorImpl> CPUBackend::realize(const TensorImpl* x)
 
 void CPUBackend::assign(TensorImpl* dest, const TensorImpl* src)
 {
-	et_assert(dest->backend() == this);
-	et_assert(src->backend() == this);
+	requireProperties(dest, this);
+	requireProperties(src, this);
 
 	if(dest->shape() != src->shape())
 		throw EtError("Shape mismatch in tensor assignment. Shape "
@@ -556,9 +606,8 @@ void CPUBackend::assign(TensorImpl* dest, const TensorImpl* src)
 
 std::shared_ptr<TensorImpl> CPUBackend::sum(const TensorImpl* x, size_t chunk_size, DType dtype)
 {
-	et_assert(x->backend() == this);
+	requireProperties(x, this, IsContingous());
 	et_assert(x->size() % chunk_size == 0);
-	et_assert(x->iscontiguous());
 
 	DType result_dtype = dtype;
 
@@ -567,6 +616,8 @@ std::shared_ptr<TensorImpl> CPUBackend::sum(const TensorImpl* x, size_t chunk_si
 		DType dtype = x->dtype();
 		if(dtype == DType::Bool || dtype == DType::Int32)
 			return DType::Int32;
+		else if(dtype == DType::Half)
+			return DType::Half;
 		else
 			return DType::Float;
 	}();
@@ -593,41 +644,8 @@ std::shared_ptr<TensorImpl> CPUBackend::sum(const TensorImpl* x, size_t chunk_si
 
 void CPUBackend::decaySynapses(TensorImpl* connections, TensorImpl* permeances, float threshold)
 {
-	et_assert(connections->shape() == permeances->shape());
-	et_assert(connections->backend() == this);
-	et_assert(permeances->backend() == this);
-	et_assert(connections->dtype() == DType::Int32);
-	et_assert(permeances->dtype() == DType::Float);
-	et_assert(permeances->iscontiguous());
-	et_assert(connections->iscontiguous());
-
-	float* perms = (float*)permeances->data();
-	uint32_t* conns = (uint32_t*)connections->data();
-
-	size_t max_synapses_per_cell = connections->shape().back();
-	size_t input_cell_count = connections->size()/max_synapses_per_cell;
-
-	tbb::parallel_for(size_t(0), input_cell_count, [&](size_t i) {
-		uint32_t* synapses = (uint32_t*)conns+i*max_synapses_per_cell;
-		float* strengths = perms+i*max_synapses_per_cell;
-		uint32_t* end = synapses+max_synapses_per_cell;
-
-		uint32_t* it = std::lower_bound(synapses, end, uint32_t(-1));
-		size_t used_space = it - synapses;
-
-		for(size_t j=0;j<used_space;j++) {
-			if(strengths[j] < threshold)
-				synapses[j] = uint32_t(-1);
-		}
-
-		std::vector<size_t> sort_indices(used_space);
-		std::iota(sort_indices.begin(), sort_indices.begin()+used_space, 0);
-		std::sort(sort_indices.begin(), sort_indices.begin()+used_space,
-			[&](size_t i, size_t j)->bool {
-				return ((uint32_t*)synapses)[i] < ((uint32_t*)synapses)[j];
-			});
-		apply_permutation_in_place(synapses, synapses+used_space, sort_indices);
-		apply_permutation_in_place(strengths, strengths+used_space, sort_indices);
+	dispatch<type_list_t<float, half>>(permeances->dtype(), [&](auto v) {
+		detail::decaySynapses<decltype(v)>(connections, permeances, threshold, this);
 	});
 }
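Note on the dispatch/type_list mechanism introduced in CPUBackend.cpp above: dispatch() walks a compile-time list of candidate permanence types and invokes the functor with a default-constructed value of the first type whose typeToDType() matches the runtime DType, so the lambda's auto parameter carries the concrete type into the detail:: implementations. Below is a minimal self-contained sketch of the same pattern, with stand-in names only; Etaler's real definitions live in Etaler/Core/TypeList.hpp and Etaler/Core/DType.hpp.

    #include <type_traits>
    #include <stdexcept>
    #include <iostream>

    enum class DType { Float, Half };
    struct half { float v; };                       // stand-in for half_precision::half

    struct null_t {};
    template <typename T, typename U> struct node { using head = T; using tail = U; };
    using list = node<float, node<half, null_t>>;   // ~ type_list_t<float, half>

    template <typename T> constexpr DType typeToDType();
    template <> constexpr DType typeToDType<float>() { return DType::Float; }
    template <> constexpr DType typeToDType<half>()  { return DType::Half; }

    template <typename TL, typename Func>
    void dispatch(DType dtype, Func f)
    {
        if constexpr(std::is_same_v<TL, null_t> == false) {
            using T = typename TL::head;
            if(typeToDType<T>() == dtype) { f(T()); return; }
            dispatch<typename TL::tail>(dtype, f);  // recurse into the tail of the list
        }
        else
            throw std::runtime_error("Cannot dispatch such dtype");
    }

    int main()
    {
        // The lambda body is instantiated once per listed type; decltype(v) selects the branch.
        dispatch<list>(DType::Half, [](auto v){
            std::cout << (std::is_same_v<decltype(v), half> ? "half\n" : "float\n");
        });
    }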
diff --git a/Etaler/Backends/CPUBackend.hpp b/Etaler/Backends/CPUBackend.hpp
index 7777270..d5dd946 100644
--- a/Etaler/Backends/CPUBackend.hpp
+++ b/Etaler/Backends/CPUBackend.hpp
@@ -23,6 +23,8 @@ struct CPUBuffer : public BufferImpl
 			storage_ = new int32_t[shape.volume()];
 		else if(dtype == DType::Float)
 			storage_ = new float[shape.volume()];
+		else if(dtype == DType::Half)
+			storage_ = new half[shape.volume()];
 		else
 			std::cerr << "Critical Warning: CPUBuffer Initialize failed. Unknown DType" << std::endl;
 	}
@@ -42,7 +44,7 @@ struct CPUBuffer : public BufferImpl
 	virtual void* data() const override;
 
 protected:
-	std::variant<bool*, int32_t*, float*> storage_;
+	std::variant<bool*, int32_t*, float*, half*> storage_;
 };
 
 struct CPUBackend : public Backend
diff --git a/Etaler/Backends/OpenCLBackend.cpp b/Etaler/Backends/OpenCLBackend.cpp
index 56d9787..54e2dbe 100644
--- a/Etaler/Backends/OpenCLBackend.cpp
+++ b/Etaler/Backends/OpenCLBackend.cpp
@@ -104,6 +104,8 @@ void OpenCLBackend::init(cl::Context context, cl::Platform platform, cl::Device
 	std::string device_name = device_.getInfo<CL_DEVICE_NAME>();
 	et_assert(isExtentionSupported("cl_khr_local_int32_base_atomics"), "cl_khr_local_int32_base_atomics is not supported by " + device_name);
 	et_assert(isExtentionSupported("cl_khr_local_int32_extended_atomics"), "cl_khr_local_int32_extended_atomics is not supported by " + device_name);
+
+	have_fp16_ = isExtentionSupported("cl_khr_fp16");
 }
 
 std::shared_ptr<TensorImpl> OpenCLBackend::createTensor(const Shape& shape, DType dtype, const void* data)
@@ -128,6 +130,8 @@ std::shared_ptr<TensorImpl> OpenCLBackend::createTensor(const Shape& shape, DTyp
 
 std::shared_ptr<TensorImpl> OpenCLBackend::createTensor(const Shape& shape, DType dtype, cl::Buffer buf)
 {
+	if(dtype == DType::Half && have_fp16_ == false)
+		throw EtError("Creating a half(fp16) tensor but the device has no fp16 capability.");
 	auto ptr = std::shared_ptr<OpenCLBuffer>(new OpenCLBuffer(shape, dtype, buf, shared_from_this()), [this](OpenCLBuffer* ptr){releaseTensor(ptr);});
 	return std::make_shared<TensorImpl>(ptr, shape, shapeToStride(shape));
 }
@@ -161,6 +165,7 @@ std::string OpenCLBackend::deviceInfo() const
 	res += "Local memory size: " + std::to_string(localMemorySize()/1024) + " KB\n";
 	res += "Local memory type: " + local_type[localMemoryType()] + "\n";
 	res += "Prefered work group size: " + std::to_string(kernel_manager_.kernel("__etaler_dummy__").getWorkGroupInfo<CL_KERNEL_PREFERRED_WORK_GROUP_SIZE_MULTIPLE>(device_)) + "\n";
+	res += "Half precision: " + std::string(isExtentionSupported("cl_khr_fp16") ? "Yes" : "No");
 
 	return res;
 }
@@ -219,17 +224,17 @@ void KernelManager::compileKernel(const std::vector<std::string>& srcs, const st
 }
 
 void KernelManager::compileFromFile(const std::string& path, const std::string& program_name, const std::vector<std::string>& kernel_names
-	, bool force_override, const std::string& flags)
+	, bool force_override, const std::string& flags, const std::string& prepend)
 {
-	compileFromFile(std::vector<std::string>{path}, program_name, kernel_names, force_override, flags);
+	compileFromFile(std::vector<std::string>{path}, program_name, kernel_names, force_override, flags, prepend);
 }
 
 void KernelManager::compileFromFile(const std::vector<std::string>& paths, const std::string& program_name, const std::vector<std::string>& kernel_names
-	, bool force_override, const std::string& flags)
+	, bool force_override, const std::string& flags, const std::string& prepend)
 {
 	std::vector<std::string> sources;
 	for(const auto& path : paths)
-		sources.emplace_back(readKernel(path));
+		sources.emplace_back(prepend + (prepend!=""?"\n":"") + readKernel(path));
 	compileKernel(sources, program_name, kernel_names, force_override, flags);
 }
 
@@ -259,16 +264,9 @@ void KernelManager::addSearchPath(const std::string& path)
 std::shared_ptr<TensorImpl> OpenCLBackend::cellActivity(const TensorImpl* x, const TensorImpl* connections, const TensorImpl* permeances
 	, float connected_permeance, size_t active_threshold, bool has_unconnected_synapse)
 {
-	et_assert(x->backend() == this);
-	et_assert(connections->backend() == this);
-	et_assert(permeances->backend() == this);
-	et_assert(x->iscontiguous());
-	et_assert(connections->iscontiguous());
-	et_assert(permeances->iscontiguous());
-
-	et_assert(x->dtype() == DType::Bool);
-	et_assert(connections->dtype() == DType::Int32);
-	et_assert(permeances->dtype() == DType::Float);
+	requireProperties(x, this, DType::Bool, IsContingous());
+	requireProperties(connections, this, DType::Int32, IsContingous());
+	requireProperties(permeances, this, IsDType{DType::Float, DType::Half}, IsContingous());
 	et_assert(connections->shape() == permeances->shape());
 	et_assert(connections->dimentions() >= 2);
 
@@ -276,16 +274,18 @@ std::shared_ptr<TensorImpl> OpenCLBackend::cellActivity(const TensorImpl* x, con
 	s.pop_back();
 	auto y = createTensor(s, DType::Int32);
 
-	auto args = "-DINPUT_SIZE="+str(x->size())+" -DMAX_SYNAPSE_PER_CELL="+str(connections->shape().back())+" -DNO_UNUSED_SYNAPSE=" + str(!has_unconnected_synapse);
-	auto hash = hash_string(args);
+	auto args = "-DINPUT_SIZE="+str(x->size())+" -DMAX_SYNAPSE_PER_CELL="+str(connections->shape().back())+" -DNO_UNUSED_SYNAPSE=" + str(!has_unconnected_synapse)
+		+ " -DPERM_TYPE="+to_ctype_string(permeances->dtype());
+	auto prepend = (permeances->dtype()==DType::Half?"#pragma OPENCL EXTENSION cl_khr_fp16 : enable":"");
+	auto hash = hash_string(args + prepend);
 	auto program_name = "overlapScore"+hash;
 	if(x->size() < localMemorySize() && localMemoryType() == CL_LOCAL)
-		kernel_manager_.compileFromFile("cellActivity.cl", program_name, {"cellActivity"}, false, args);
+		kernel_manager_.compileFromFile("cellActivity.cl", program_name, {"cellActivity"}, false, args, prepend);
 	else if(x->size() < localMemorySize()*8-8 && localMemoryType() == CL_LOCAL)
-		kernel_manager_.compileFromFile("cellActivity_compressed_local.cl", program_name, {"cellActivity"}, false, args);
+		kernel_manager_.compileFromFile("cellActivity_compressed_local.cl", program_name, {"cellActivity"}, false, args, prepend);
 	else
-		kernel_manager_.compileFromFile("cellActivity_global.cl", program_name, {"cellActivity"}, false, args);
+		kernel_manager_.compileFromFile("cellActivity_global.cl", program_name, {"cellActivity"}, false, args, prepend);
 
 	cl::Kernel k = kernel_manager_.kernel(program_name, "cellActivity");
 	k.setArg(0, std::static_pointer_cast<OpenCLBuffer>(x->buffer())->buffer());
@@ -307,10 +307,7 @@ std::shared_ptr<TensorImpl> OpenCLBackend::cellActivity(const TensorImpl* x, con
 
 std::shared_ptr<TensorImpl> OpenCLBackend::globalInhibition(const TensorImpl* x, float fraction)
 {
-	et_assert(x->backend() == this);
-	et_assert(x->iscontiguous());
-
-	et_assert(x->dtype() == DType::Int32);
+	requireProperties(x, this, DType::Int32, IsContingous());
 
 	auto y = createTensor(x->shape(), DType::Bool);
 
@@ -342,9 +339,9 @@ std::shared_ptr<TensorImpl> OpenCLBackend::globalInhibition(const TensorImpl* x,
 
 std::shared_ptr<TensorImpl> OpenCLBackend::cast(const TensorImpl* x, DType toType)
 {
-	et_assert(x->backend() == this);
-	et_assert(x->iscontiguous());
-	auto args = "-DInType="+to_ctype_string(x->dtype())+" -DOutType="+to_ctype_string(toType);
+	requireProperties(x, this, IsContingous());
+	auto args = "-DInType="+to_ctype_string(x->dtype())+" -DOutType="+to_ctype_string(toType)
+		+ (x->dtype() == DType::Half || toType == DType::Half ? " -DHalfSupport" : "");
 	auto hash = hash_string(args);
 	auto program_name = "cast"+hash;
 	kernel_manager_.compileFromFile("cast.cl", program_name, {"cast"}, false, args);
@@ -368,8 +365,7 @@ void OpenCLBackend::sync() const
 
 std::shared_ptr<TensorImpl> OpenCLBackend::copy(const TensorImpl* x)
 {
-	et_assert(x->backend() == this);
-	et_assert(x->iscontiguous());
+	requireProperties(x, this, IsContingous());
 	size_t buf_size = x->size()*dtypeToSize(x->dtype());
 	cl::Buffer buf = allocBuffer(buf_size);
 	const cl::Buffer& src = std::static_pointer_cast<OpenCLBuffer>(x->buffer())->buffer();
@@ -383,29 +379,26 @@ std::shared_ptr<TensorImpl> OpenCLBackend::copy(const TensorImpl* x)
 
 void OpenCLBackend::learnCorrilation(const TensorImpl* x, const TensorImpl* learn, const TensorImpl* connections, TensorImpl* permeances
 	, float perm_inc, float perm_dec, bool has_unconnected_synapse)
 {
-	et_assert(x->backend() == this);
-	et_assert(connections->backend() == this);
-	et_assert(permeances->backend() == this);
-	et_assert(learn->backend() == this);
+	requireProperties(x, this, DType::Bool, IsContingous());
+	requireProperties(learn, this, DType::Bool, IsContingous());
+	requireProperties(connections, this, DType::Int32, IsContingous());
+	requireProperties(permeances, this, IsDType{DType::Float, DType::Half}, IsContingous());
 	et_assert(connections->shape() == permeances->shape());
 	et_assert(x->shape() == learn->shape());
-	et_assert(x->dtype() == DType::Bool);
-	et_assert(learn->dtype() == DType::Bool);
-	et_assert(connections->dtype() == DType::Int32);
-	et_assert(permeances->dtype() == DType::Float);
 
 	auto args = "-DINPUT_SIZE="+str(x->size())+" -DMAX_SYNAPSE_PER_CELL="+str(connections->shape().back())+" -DNO_UNUSED_SYNAPSE="+str(!has_unconnected_synapse)
-		+" -DOUTPUT_SIZE="+str(learn->size());
-	auto hash = hash_string(args);
+		+" -DOUTPUT_SIZE="+str(learn->size()) + " -DPERM_TYPE="+to_ctype_string(permeances->dtype());
+	auto prepend = (permeances->dtype()==DType::Half?"#pragma OPENCL EXTENSION cl_khr_fp16 : enable":"");
+	auto hash = hash_string(args+prepend);
 	auto program_name = "learnCorrilation"+hash;
 	if(x->size() < localMemorySize() && localMemoryType() == CL_LOCAL)
-		kernel_manager_.compileFromFile("learnCorrilation.cl", program_name, {"learnCorrilation"}, false, args);
+		kernel_manager_.compileFromFile("learnCorrilation.cl", program_name, {"learnCorrilation"}, false, args, prepend);
 	else if(x->size() < localMemorySize()*8-8 && localMemoryType() == CL_LOCAL)
-		kernel_manager_.compileFromFile("learnCorrilation_compressed_local.cl", program_name, {"learnCorrilation"}, false, args);
+		kernel_manager_.compileFromFile("learnCorrilation_compressed_local.cl", program_name, {"learnCorrilation"}, false, args, prepend);
 	else
-		kernel_manager_.compileFromFile("learnCorrilation_global.cl", program_name, {"learnCorrilation"}, false, args);
+		kernel_manager_.compileFromFile("learnCorrilation_global.cl", program_name, {"learnCorrilation"}, false, args, prepend);
 
 	cl::Kernel k = kernel_manager_.kernel(program_name, "learnCorrilation");
 	k.setArg(0, std::static_pointer_cast<OpenCLBuffer>(x->buffer())->buffer());
@@ -425,17 +418,14 @@ void OpenCLBackend::learnCorrilation(const TensorImpl* x, const TensorImpl* lear
 
 void OpenCLBackend::sortSynapse(TensorImpl* connections, TensorImpl* permeances)
 {
+	requireProperties(connections, this, DType::Int32, IsContingous());
+	requireProperties(permeances, this, IsDType{DType::Float, DType::Half}, IsContingous());
 	et_assert(connections->shape() == permeances->shape());
-	et_assert(connections->backend() == this);
-	et_assert(permeances->backend() == this);
-	et_assert(connections->dtype() == DType::Int32);
-	et_assert(permeances->dtype() == DType::Float);
-	et_assert(connections->iscontiguous());
-	et_assert(permeances->iscontiguous());
-
-	auto args = "-DMAX_SYNAPSE_PER_CELL="+str(connections->shape().back());
-	auto program_name = "sortSynapse"+hash_string(args);
-	kernel_manager_.compileFromFile("sort.cl", program_name, {"sortSynapse"}, false, args);
+
+	auto args = "-DMAX_SYNAPSE_PER_CELL="+str(connections->shape().back()) + " -DPERM_TYPE="+to_ctype_string(permeances->dtype());
+	auto prepend = (permeances->dtype()==DType::Half?"#pragma OPENCL EXTENSION cl_khr_fp16 : enable":"");
+	auto program_name = "sortSynapse"+hash_string(args+prepend);
+	kernel_manager_.compileFromFile("sort.cl", program_name, {"sortSynapse"}, false, args, prepend);
 
 	cl::Kernel k = kernel_manager_.kernel(program_name, "sortSynapse");
 	int num_cells = connections->size()/connections->shape().back();
@@ -457,12 +447,8 @@ void OpenCLBackend::sortSynapse(TensorImpl* connections, TensorImpl* permeances)
 
 std::shared_ptr<TensorImpl> OpenCLBackend::burst(const TensorImpl* x, const TensorImpl* s)
 {
-	et_assert(x->backend() == this);
-	et_assert(s->backend() == this);
-	et_assert(x->dtype() == DType::Bool);
-	et_assert(s->dtype() == DType::Bool);
-	et_assert(x->iscontiguous());
-	et_assert(s->iscontiguous());
+	requireProperties(x, this, DType::Bool, IsContingous());
+	requireProperties(s, this, DType::Bool, IsContingous());
 
 	Shape shape = s->shape();
 	shape.pop_back();
@@ -489,9 +475,7 @@ std::shared_ptr<TensorImpl> OpenCLBackend::burst(const TensorImpl* x, const Tens
 
 std::shared_ptr<TensorImpl> OpenCLBackend::reverseBurst(const TensorImpl* x)
 {
-	et_assert(x->backend() == this);
-	et_assert(x->dtype() == DType::Bool);
-	et_assert(x->iscontiguous());
+	requireProperties(x, this, DType::Bool, IsContingous());
 
 	size_t cells_per_column = x->shape().back();
 	size_t num_columns = x->size()/cells_per_column;
@@ -519,21 +503,11 @@ std::shared_ptr<TensorImpl> OpenCLBackend::reverseBurst(const TensorImpl* x)
 void OpenCLBackend::growSynapses(const TensorImpl* x, const TensorImpl* y, TensorImpl* connections
 	, TensorImpl* permeances, float initial_perm)
 {
-	et_assert(x->backend() == this);
-	et_assert(y->backend() == this);
-	et_assert(connections->backend() == this);
-	et_assert(permeances->backend() == this);
-	et_assert(x->iscontiguous());
-	et_assert(y->iscontiguous());
-	et_assert(connections->iscontiguous());
-	et_assert(permeances->iscontiguous());
-
-	et_assert(x->dtype() == DType::Bool);
-	et_assert(y->dtype() == DType::Bool);
-	et_assert(connections->dtype() == DType::Int32);
-	et_assert(permeances->dtype() == DType::Float);
-
-	et_assert(x->shape() == y->shape());
+	requireProperties(x, this, DType::Bool, IsContingous());
+	requireProperties(y, this, DType::Bool, IsContingous());
+	requireProperties(connections, this, DType::Int32, IsContingous());
+	requireProperties(permeances, this, IsDType{DType::Float, DType::Half}, IsContingous());
+	et_assert(connections->shape() == permeances->shape());
 
 	Shape s = connections->shape();
 	s.pop_back();
@@ -542,9 +516,11 @@ void OpenCLBackend::growSynapses(const TensorImpl* x, const TensorImpl* y, Tenso
 	size_t max_synapses_per_cell = connections->shape().back();
 	size_t input_cell_count = x->size();
 
-	auto args = "-DNUM_CELLS="+str(y->size())+" -DNUM_INPUT_BITS="+str(x->size())+" -DMAX_SYNAPSE_PER_CELL="+str(max_synapses_per_cell);
-	auto program_name = "growSynapses"+hash_string(args);
-	kernel_manager_.compileFromFile("growSynapses.cl", program_name, {"growSynapses"}, false, args);
+	auto args = "-DNUM_CELLS="+str(y->size())+" -DNUM_INPUT_BITS="+str(x->size())+" -DMAX_SYNAPSE_PER_CELL="+str(max_synapses_per_cell)
+		+" -DPERM_TYPE="+to_ctype_string(permeances->dtype());
+	auto prepend = (permeances->dtype()==DType::Half?"#pragma OPENCL EXTENSION cl_khr_fp16 : enable":"");
+	auto program_name = "growSynapses"+hash_string(args+prepend);
+	kernel_manager_.compileFromFile("growSynapses.cl", program_name, {"growSynapses"}, false, args, prepend);
 
 	cl::Kernel k = kernel_manager_.kernel(program_name, "growSynapses");
 	size_t local_size = 32;
@@ -578,9 +554,7 @@ void OpenCLBackend::growSynapses(const TensorImpl* x, const TensorImpl* y, Tenso
 
 std::optional<cl::Buffer> OpenCLBackend::toSparse(const TensorImpl* x)
 {
-	et_assert(x->backend() == this);
-	et_assert(x->dtype() == DType::Bool);
-	et_assert(x->iscontiguous());
+	requireProperties(x, this, DType::Bool, IsContingous());
 
 	auto args = "-DINPUT_SIZE="+str(x->size());
 	auto program_name = "toSparse"+hash_string(args);
@@ -701,7 +675,7 @@ kernel void copy(global Type* restrict x, global Type* restrict y)
 
 std::shared_ptr<TensorImpl> OpenCLBackend::realize(const TensorImpl* x)
 {
-	et_assert(x->backend() == this);
+	requireProperties(x, this);
 
 	if(x->iscontiguous() == true)
 		return copy(x);
@@ -730,8 +704,8 @@ std::shared_ptr<TensorImpl> OpenCLBackend::realize(const TensorImpl* x)
 
 void OpenCLBackend::assign(TensorImpl* dest, const TensorImpl* src)
 {
-	et_assert(dest->backend() == this);
-	et_assert(src->backend() == this);
+	requireProperties(dest, this);
+	requireProperties(src, this);
 
 	if(dest->shape() != src->shape())
 		throw EtError("Shape mismatch in tensor assignment. Shape "
@@ -760,9 +734,8 @@ void OpenCLBackend::assign(TensorImpl* dest, const TensorImpl* src)
 
 std::shared_ptr<TensorImpl> OpenCLBackend::sum(const TensorImpl* x, size_t chunk_size, DType dtype)
 {
-	et_assert(x->backend() == this);
+	requireProperties(x, this, IsContingous());
 	et_assert(x->size() % chunk_size == 0);
-	et_assert(x->iscontiguous());
 
 	DType result_dtype = dtype;
 	if(dtype == DType::Unknown) {
@@ -770,6 +743,8 @@ std::shared_ptr<TensorImpl> OpenCLBackend::sum(const TensorImpl* x, size_t chunk
 		DType dtype = x->dtype();
 		if(dtype == DType::Bool || dtype == DType::Int32)
 			return DType::Int32;
+		else if(dtype == DType::Half)
+			return DType::Half;
 		else
 			return DType::Float;
 	}();
@@ -778,10 +753,13 @@ std::shared_ptr<TensorImpl> OpenCLBackend::sum(const TensorImpl* x, size_t chunk
 
 	DType intermid_type = [](DType in, DType out) {
 		if(in == DType::Float)
 			return DType::Float;
+		if(out == DType::Half)
+			return DType::Half;
 		return DType::Int32;
 	}(x->dtype(), result_dtype);
 
-	std::string args = "-DInType=" + to_ctype_string(x->dtype()) + " -DOutType=" + to_ctype_string(result_dtype) + " -DIntermidType=" + to_ctype_string(intermid_type);
+	std::string args = "-DInType=" + to_ctype_string(x->dtype()) + " -DOutType=" + to_ctype_string(result_dtype) + " -DIntermidType=" + to_ctype_string(intermid_type)
+		+ (intermid_type==DType::Half? " -DIntermidIsHalf" : "");
 	std::string program_name = "sum" + hash_string(args);
 	kernel_manager_.compileFromFile("sum.cl", program_name, {"sum"}, false, args);
@@ -804,20 +782,17 @@ std::shared_ptr<TensorImpl> OpenCLBackend::sum(const TensorImpl* x, size_t chunk
 
 void OpenCLBackend::decaySynapses(TensorImpl* connections, TensorImpl* permeances, float threshold)
 {
+	requireProperties(connections, this, DType::Int32, IsContingous());
+	requireProperties(permeances, this, IsDType{DType::Float, DType::Half}, IsContingous());
 	et_assert(connections->shape() == permeances->shape());
-	et_assert(connections->backend() == this);
-	et_assert(permeances->backend() == this);
-	et_assert(connections->dtype() == DType::Int32);
-	et_assert(permeances->dtype() == DType::Float);
-	et_assert(connections->iscontiguous());
-	et_assert(permeances->iscontiguous());
 
 	size_t max_synapses_per_cell = connections->shape().back();
 	size_t input_cell_count = connections->size()/max_synapses_per_cell;
 
-	auto args = "-DNUM_CELLS="+str(input_cell_count) + " -DMAX_SYNAPSE_PER_CELL="+str(max_synapses_per_cell);
-	std::string program_name = "sum" + hash_string(args);
-	kernel_manager_.compileFromFile("decaySynapses.cl", program_name, {"decaySynapses"}, false, args);
+	auto args = "-DNUM_CELLS="+str(input_cell_count) + " -DMAX_SYNAPSE_PER_CELL="+str(max_synapses_per_cell) + " -DPERM_TYPE="+to_ctype_string(permeances->dtype());
+	auto prepend = (permeances->dtype()==DType::Half?"#pragma OPENCL EXTENSION cl_khr_fp16 : enable":"");
+	std::string program_name = "sum" + hash_string(args+prepend);
+	kernel_manager_.compileFromFile("decaySynapses.cl", program_name, {"decaySynapses"}, false, args, prepend);
 
 	cl::Kernel k = kernel_manager_.kernel(program_name, "decaySynapses");
 
@@ -848,7 +823,11 @@ kernel void op(global T0* restrict x, global ResType* restrict y)
 )";
 
-	std::string res = f + "\n" + jitStridedView(x, 0) + "\n" + kernel;
+	std::string extention_decl;
+	if(x->dtype() == DType::Half)
+		extention_decl = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable";
+
+	std::string res = extention_decl + "\n" + f + "\n" + jitStridedView(x, 0) + "\n" + kernel;
 	replaceAll(res, "$SIZE", std::to_string(x->size()));
 	return res;
 }
@@ -869,14 +848,18 @@ kernel void op(global T0* restrict x1, global T1* restrict x2, global ResType* r
 }
 )";
 
-	std::string res = f + "\n" + jitStridedView(x1, 0) + "\n" + jitStridedView(x2, 1) + "\n" + kernel;
+	std::string extention_decl;
+	if(x1->dtype() == DType::Half || x2->dtype() == DType::Half)
+		extention_decl = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable";
+
+	std::string res = extention_decl + "\n" + f + "\n" + jitStridedView(x1, 0) + "\n" + jitStridedView(x2, 1) + "\n" + kernel;
 	replaceAll(res, "$SIZE", std::to_string(x1->size()));
 	return res;
 }
 
 std::shared_ptr<TensorImpl> OpenCLBackend::applyUnaryOp(const TensorImpl* x, std::string f, DType resType)
 {
-	et_assert(x->backend() == this);
+	requireProperties(x, this);
 
 	std::string args = "-DT0="+to_ctype_string(x->dtype())+" -DResType="+to_ctype_string(resType);
 	std::string program_name = f+hash_string(args)+std::to_string(x->offset())+to_string(x->shape())+to_string(x->stride());
@@ -902,8 +885,8 @@ std::shared_ptr<TensorImpl> OpenCLBackend::applyUnaryOp(const TensorImpl* x, std
 
 std::shared_ptr<TensorImpl> OpenCLBackend::applyBinaryOp(const TensorImpl* x1, const TensorImpl* x2, std::string f, DType resType)
 {
-	et_assert(x1->backend() == this);
-	et_assert(x2->backend() == this);
+	requireProperties(x1, this);
+	requireProperties(x2, this);
 	et_assert(x1->shape() == x2->shape());
 
 	auto to_str = [](auto x){
@@ -935,22 +918,26 @@ std::shared_ptr<TensorImpl> OpenCLBackend::applyBinaryOp(const TensorImpl* x1, c
 
 std::shared_ptr<TensorImpl> OpenCLBackend::exp(const TensorImpl* x)
 {
-	return applyUnaryOp(x, "#define f(x) (exp((float)x))", DType::Float);
+	DType result_type = x->dtype() == DType::Half ? DType::Half : DType::Float;
+	return applyUnaryOp(x, "#define f(x) (exp((float)x))", result_type);
 }
 
 std::shared_ptr<TensorImpl> OpenCLBackend::negate(const TensorImpl* x)
 {
-	return applyUnaryOp(x, "#define f(x) (-x)", x->dtype()==DType::Float? DType::Float : DType::Int32);
+	DType result_type = x->dtype() == DType::Bool ? DType::Int32 : x->dtype();
+	return applyUnaryOp(x, "#define f(x) (-x)", result_type);
 }
 
 std::shared_ptr<TensorImpl> OpenCLBackend::inverse(const TensorImpl* x)
 {
-	return applyUnaryOp(x, "#define f(x) (1.0f/(float)x)", DType::Float);
+	DType result_type = x->dtype() == DType::Half ? DType::Half : DType::Float;
+	return applyUnaryOp(x, "#define f(x) (1.0f/(float)x)", result_type);
 }
 
 std::shared_ptr<TensorImpl> OpenCLBackend::log(const TensorImpl* x)
 {
-	return applyUnaryOp(x, "#define f(x) (log((float)x))", DType::Float);
+	DType result_type = x->dtype() == DType::Half ? DType::Half : DType::Float;
+	return applyUnaryOp(x, "#define f(x) (log((float)x))", result_type);
 }
 
 std::shared_ptr<TensorImpl> OpenCLBackend::logical_not(const TensorImpl* x)
@@ -962,6 +949,8 @@ static DType solveBinaryOpDType(DType t1, DType t2)
 {
 	if(t1 == DType::Float || t2 == DType::Float)
 		return DType::Float;
+	else if(t1 == DType::Half || t2 == DType::Half)
+		return DType::Half;
 	return DType::Int32;
 }
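Why a prepend string (rather than only another -D flag) is threaded through KernelManager::compileFromFile above: `#pragma OPENCL EXTENSION cl_khr_fp16 : enable` must appear in the kernel source itself before the first use of half, and a preprocessing directive cannot be produced by a -D macro definition. cast.cl instead uses an in-file #ifdef guard; prepending lets the other kernels stay untouched. A sketch of the source assembly the new code performs (hypothetical one-line kernel, for illustration only):

    #include <string>
    #include <iostream>

    int main()
    {
        bool use_half = true;
        // Type selection still travels through build flags...
        std::string flags = std::string("-DPERM_TYPE=") + (use_half ? "half" : "float");
        // ...but the extension pragma is prepended to the source text, as compileFromFile now does.
        std::string prepend = use_half ? "#pragma OPENCL EXTENSION cl_khr_fp16 : enable" : "";
        std::string kernel  = "kernel void scale(global PERM_TYPE* p){ p[get_global_id(0)] *= 2; }";

        std::string source = prepend + (prepend != "" ? "\n" : "") + kernel;
        std::cout << "build flags: " << flags << "\n" << source << "\n";
    }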
diff --git a/Etaler/Backends/OpenCLBackend.hpp b/Etaler/Backends/OpenCLBackend.hpp
index ca64106..85f7c35 100644
--- a/Etaler/Backends/OpenCLBackend.hpp
+++ b/Etaler/Backends/OpenCLBackend.hpp
@@ -52,9 +52,9 @@ struct KernelManager
 	void compileKernel(const std::vector<std::string>& srcs, const std::string& program_name, const std::vector<std::string>& kernel_names
 		, bool force_override=false, const std::string& flags="");
 	void compileFromFile(const std::string& paths, const std::string& program_name, const std::vector<std::string>& kernel_names
-		, bool force_override=false, const std::string& flags="");
+		, bool force_override=false, const std::string& flags="", const std::string& prepend="");
 	void compileFromFile(const std::vector<std::string>& paths, const std::string& program_name, const std::vector<std::string>& kernel_names
-		, bool force_override=false, const std::string& flags="");
+		, bool force_override=false, const std::string& flags="", const std::string& prepend="");
 	inline bool exists(const std::string& program_name, const std::string& kernel_name)
 	{
 		auto it = apps_.find(program_name);
@@ -138,7 +138,7 @@ struct ETALER_EXPORT OpenCLBackend : public Backend
 
 	inline cl::Context context() {return context_;}
 
-	inline bool isExtentionSupported(std::string ext)
+	inline bool isExtentionSupported(std::string ext) const
 	{
 		return (std::find(supported_extentions_.begin(), supported_extentions_.end(), ext)
 			!= supported_extentions_.end());
@@ -187,6 +187,7 @@ struct ETALER_EXPORT OpenCLBackend : public Backend
 	cl_uint num_compute_units_;
 
 	std::vector<std::string> supported_extentions_;
+	bool have_fp16_ = false;
 };
 
 }
diff --git a/Etaler/CMakeLists.txt b/Etaler/CMakeLists.txt
index 24d01a4..6d6f2ea 100644
--- a/Etaler/CMakeLists.txt
+++ b/Etaler/CMakeLists.txt
@@ -15,6 +15,7 @@ generate_export_header(Etaler EXPORT_FILE_NAME Etaler_export.h)
 
 target_include_directories(Etaler PRIVATE 3rdparty/pcg-cpp/include)
 target_include_directories(Etaler PRIVATE 3rdparty/cereal/include)
+target_include_directories(Etaler PRIVATE 3rdparty/half_precision)
 target_include_directories(Etaler PRIVATE 3rdparty)
 
 if(${TBB_FOUND})
diff --git a/Etaler/Core/DType.hpp b/Etaler/Core/DType.hpp
index a201924..1d05445 100644
--- a/Etaler/Core/DType.hpp
+++ b/Etaler/Core/DType.hpp
@@ -4,6 +4,8 @@
 #include
 #include
 
+#include "Half.hpp"
+
 namespace et
 {
 
@@ -13,6 +15,7 @@ enum class DType
 	Bool = 0,
 	Int32,
 	Float,
+	Half,
 };
 
 template <typename T>
@@ -24,11 +27,13 @@ constexpr inline DType typeToDType()
 		return DType::Int32;
 	else if constexpr(std::is_same<T, bool>::value || std::is_same<T, uint8_t>::value)
 		return DType::Bool;
+	else if constexpr(std::is_same<T, half>::value)
+		return DType::Half;
 	else
 		return DType::Unknown;
 }
 
-inline size_t dtypeToSize(DType dtype)
+inline constexpr size_t dtypeToSize(DType dtype)
 {
 	if(dtype == DType::Bool)
 		return sizeof(bool);
@@ -36,6 +41,8 @@ inline size_t dtypeToSize(DType dtype)
 		return sizeof(int32_t);
 	else if(dtype == DType::Float)
 		return sizeof(float);
+	else if(dtype == DType::Half)
+		return sizeof(float16);
 	return std::numeric_limits<size_t>::max();
 }
 
@@ -47,6 +54,8 @@ inline std::string to_ctype_string(DType dtype)
 		return "int";
 	else if(dtype == DType::Float)
 		return "float";
+	else if(dtype == DType::Half)
+		return "half";
 	return "Unknown";
 }
diff --git a/Etaler/Core/DefaultBackend.hpp b/Etaler/Core/DefaultBackend.hpp
index 73ac070..0a21ad5 100644
--- a/Etaler/Core/DefaultBackend.hpp
+++ b/Etaler/Core/DefaultBackend.hpp
@@ -8,7 +8,7 @@ namespace et
 {
 
 extern ETALER_EXPORT Backend* g_default_backend;
-extern std::shared_ptr<Backend> g_default_backend_hold;
+extern ETALER_EXPORT std::shared_ptr<Backend> g_default_backend_hold;
 
 inline void setDefaultBackend(Backend* backend) {g_default_backend = backend;}
 inline void setDefaultBackend(std::shared_ptr<Backend> backend) {g_default_backend_hold = backend; g_default_backend = backend.get();}
diff --git a/Etaler/Core/Half.hpp b/Etaler/Core/Half.hpp
new file mode 100644
index 0000000..8404891
--- /dev/null
+++ b/Etaler/Core/Half.hpp
@@ -0,0 +1,13 @@
+#pragma once
+
+#include
+#include
+
+namespace et
+{
+
+using half = half_precision::half;
+
+using float16 = half;
+
+}
\ No newline at end of file
diff --git a/Etaler/Core/Serialize.cpp b/Etaler/Core/Serialize.cpp
index b0b04e3..b920128 100644
--- a/Etaler/Core/Serialize.cpp
+++ b/Etaler/Core/Serialize.cpp
@@ -30,6 +30,15 @@ void load(Archive & archive , Shape & s)
 	s = Shape(vec.begin(), vec.end());
 }
 
+template <class Archive>
+void serialize(Archive & archive,
+	half & m)
+{
+	archive(m.storage_);
+}
+
+
+
 template <class Archive>
 void save(Archive & archive, Tensor const & t)
 {
@@ -40,6 +49,8 @@ void save(Archive & archive, Tensor const & t)
 			return "float";
 		if(t.dtype() == DType::Int32)
 			return "int32";
+		if(t.dtype() == DType::Half)
+			return "half";
 		throw EtError("Cannot handle such dtype()");
 	}();
 
@@ -59,6 +70,10 @@ void save(Archive & archive, Tensor const & t)
 		std::vector<int32_t> arr = t.toHost<int32_t>();
 		archive(make_nvp("data", arr));
 	}
+	else if(t.dtype() == DType::Half) {
+		std::vector<half> arr = t.toHost<half>();
+		archive(make_nvp("data", arr));
+	}
 }
 
 template <class Archive>
@@ -85,6 +100,11 @@ void load(Archive & archive, Tensor & t)
 		archive(make_nvp("data", d));
 		t = Tensor(s, d.data());
 	}
+	else if(dtype == "half") {
+		std::vector<half> d(s.volume());
+		archive(make_nvp("data", d));
+		t = Tensor(s, d.data());
+	}
 }
 
 template <class Archive>
@@ -137,6 +157,8 @@ void save(Archive & archive ,StateDict const & item)
 			types.push_back("std::vector<int32_t>");
 		else if(v.type() == typeid(std::vector<float>))
 			types.push_back("std::vector<float>");
+		else if(v.type() == typeid(std::vector<half>))
+			types.push_back("std::vector<half>");
 		else
 			throw EtError("Cannot save (mangled name:) type " + std::string(v.type().name()) + ", key " + k);
 	}
@@ -163,6 +185,8 @@ void save(Archive & archive ,StateDict const & item)
 			archive(std::any_cast<std::vector<int32_t>>(v));
 		else if(v.type() == typeid(std::vector<float>))
 			archive(std::any_cast<std::vector<float>>(v));
+		else if(v.type() == typeid(std::vector<half>))
+			archive(std::any_cast<std::vector<half>>(v));
 		else
 			throw EtError("Cannot save type " + std::string(typeid(decltype(v)).name())
 				+ ", key " + k);
@@ -211,6 +235,8 @@ void load(Archive & archive ,StateDict & item)
 			read_archive<std::vector<int32_t>>(archive, item, key);
 		else if(type == "std::vector<float>")
 			read_archive<std::vector<float>>(archive, item, key);
+		else if(type == "std::vector<half>")
+			read_archive<std::vector<half>>(archive, item, key);
 		else
 			throw EtError("Cannot serealize type " + type);
 
diff --git a/Etaler/Core/Tensor.cpp b/Etaler/Core/Tensor.cpp
index 1346707..eae8a52 100644
--- a/Etaler/Core/Tensor.cpp
+++ b/Etaler/Core/Tensor.cpp
@@ -114,6 +114,8 @@ static void printNDArray(std::ostream& os, const void* ptr, Shape shape, DType d
 		prettyPrintTensor(os, (int32_t*)ptr, shape, 0, shape.size(), 0, truncate);
 	else if(dtype == DType::Bool)
 		prettyPrintTensor(os, (bool*)ptr, shape, 0, shape.size(), 0, truncate);
+	else if(dtype == DType::Half)
+		prettyPrintTensor(os, (half*)ptr, shape, 0, shape.size(), 0, truncate);
 	else
 		throw EtError("Printing tensor of this type is not supported.");
 }
@@ -220,6 +222,8 @@ Tensor et::zeros(const Shape& shape, DType dtype, Backend* backend)
 		return constant(shape, 0, backend);
 	else if(dtype == DType::Float)
 		return constant(shape, 0, backend);
+	else if(dtype == DType::Half)
+		return constant(shape, half(0), backend);
 	else
 		throw EtError("Cannot creatr a tensor of zeros of type " + to_ctype_string(dtype));
 }
@@ -232,6 +236,8 @@ Tensor et::ones(const Shape& shape, DType dtype, Backend* backend)
 		return constant(shape, 1, backend);
 	else if(dtype == DType::Float)
 		return constant(shape, 1, backend);
+	else if(dtype == DType::Half)
+		return constant(shape, half(1), backend);
 	else
 		throw EtError("Cannot creatr a tensor of ones of type " + to_ctype_string(dtype));
 }
diff --git a/Etaler/Core/TensorImpl.hpp b/Etaler/Core/TensorImpl.hpp
index 5b34853..62ff353 100644
--- a/Etaler/Core/TensorImpl.hpp
+++ b/Etaler/Core/TensorImpl.hpp
@@ -1,8 +1,10 @@
 #pragma once
 
+#include
 #include "Shape.hpp"
 #include "DType.hpp"
 #include "Backend.hpp"
+#include "TypeHelpers.hpp"
 
 #include
 
@@ -52,4 +54,69 @@ struct TensorImpl : public std::enable_shared_from_this<TensorImpl>
 	size_t offset_;
 };
 
+struct IsContingous {};
+
+template <typename Storage>
+struct IsDType
+{
+	Storage types;
+};
+
+template <typename _Tp, typename... _Up>
+	IsDType(_Tp, _Up...)
+		-> IsDType<std::array<std::enable_if_t<(std::is_same_v<_Tp, _Up> && ...), _Tp>,
+			1 + sizeof...(_Up)>>;
+
+
+template <typename T>
+bool checkProperty(const TensorImpl* x, const T& value)
+{
+	if constexpr(std::is_base_of_v<Backend, std::remove_pointer_t<std::decay_t<T>>>)
+		return x->backend() == value;
+	else if constexpr(std::is_same_v<T, DType>)
+		return x->dtype() == value;
+	else if constexpr(std::is_same_v<T, IsContingous>)
+		return x->iscontiguous();
+	else if constexpr(is_specialization<std::remove_cv_t<std::decay_t<T>>, IsDType>::value)
+		return (std::find(value.types.begin(), value.types.end(), x->dtype()) != value.types.end());
+	else
+		et_assert(false, "a non-supported value is passed into checkProperty");
+	return false;
+}
+
+template <typename T>
+void requireProperty(const TensorImpl* x, const T value, const std::string& line, const std::string& v_name)
+{
+	if(checkProperty(x, value) == true)
+		return;
+
+	//Otherwise assertion failed
+	const std::string msg = line + " Tensor property requirement not met. Expecting " + v_name;
+	if constexpr(std::is_base_of_v<Backend, std::remove_pointer_t<std::decay_t<T>>>)
+		throw EtError(msg + ".backend() == " + value->name());
+	else if constexpr(std::is_same_v<T, DType>)
+		throw EtError(msg + ".dtype() == " + to_ctype_string(value));
+	else if constexpr(std::is_same_v<T, IsContingous>)
+		throw EtError(msg + ".iscontiguous() == true");
+	else if constexpr(is_specialization<std::remove_cv_t<std::decay_t<T>>, IsDType>::value) {
+		throw EtError(msg + ".dtype() is in {" + std::accumulate(value.types.begin(), value.types.end(), std::string()
+			, [](auto v, auto a){return v + to_ctype_string(a) + ", ";}));
+	}
+}
+
+template <typename... Args>
+bool checkProperties(const TensorImpl* x, Args... args)
+{
+	return (checkProperty(x, args) && ...);
+}
+
+template <typename... Args>
+void requirePropertiesInternal(const TensorImpl* x, const std::string& line, const std::string& v_name, Args... args)
+{
+	(requireProperty(x, args, line, v_name), ...);
+}
+
 }
+
+#define requireProperties(x, ...) (requirePropertiesInternal(x, std::string(__FILE__)+":"+std::to_string(__LINE__)\
+	+":"+std::string(__func__)+"():", #x, __VA_ARGS__))
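What the checkProperty/requireProperties machinery above buys the backends: one call replaces a block of et_asserts, each property kind selects its own check via constexpr-if, and a parameter pack folds the checks together; the real requireProperties is a macro that additionally captures __FILE__, __LINE__ and __func__ for the error message. A stripped-down, runnable sketch of the same pattern (toy Tensor type and simplified names; the real checks live in TensorImpl.hpp as added above):

    #include <type_traits>
    #include <stdexcept>
    #include <iostream>

    enum class DType { Bool, Int32, Float, Half };
    struct IsContiguous {};                 // spelled "IsContingous" in the actual patch
    struct Tensor { DType dtype; bool contiguous; };

    template <typename T>
    bool checkProperty(const Tensor& t, const T& value)
    {
        if constexpr(std::is_same_v<T, DType>)
            return t.dtype == value;
        else if constexpr(std::is_same_v<T, IsContiguous>)
            return t.contiguous;
        else
            return false;                   // the real code rejects unknown kinds via et_assert
    }

    template <typename... Args>
    void requireProperties(const Tensor& t, Args... args)
    {
        if(!(checkProperty(t, args) && ...))  // fold-expression over all requested properties
            throw std::runtime_error("Tensor property requirement not met");
    }

    int main()
    {
        Tensor t{DType::Half, true};
        requireProperties(t, DType::Half, IsContiguous());   // passes
        try { requireProperties(t, DType::Float); }          // throws
        catch(const std::exception& e) { std::cout << e.what() << "\n"; }
    }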
diff --git a/Etaler/Core/TypeHelpers.hpp b/Etaler/Core/TypeHelpers.hpp
index f74c7f2..1103448 100644
--- a/Etaler/Core/TypeHelpers.hpp
+++ b/Etaler/Core/TypeHelpers.hpp
@@ -15,9 +15,20 @@
 		return r;\
 	}()
 
+
+
 namespace et
 {
 
 std::string ETALER_EXPORT demangle(const char* name);
+
+template <typename Test, template <typename...> class Ref>
+struct is_specialization : std::false_type {};
+
+template <template <typename...> class Ref, typename... Args>
+struct is_specialization<Ref<Args...>, Ref>: std::true_type {};
+
+template <template <typename...> class Ref, typename... Args>
+const bool is_specialization_v = is_specialization<Ref<Args...>, Ref>::value;
+
 }
\ No newline at end of file
diff --git a/Etaler/Core/TypeList.hpp b/Etaler/Core/TypeList.hpp
new file mode 100644
index 0000000..57d9f75
--- /dev/null
+++ b/Etaler/Core/TypeList.hpp
@@ -0,0 +1,30 @@
+#include
+
+namespace et
+{
+//Using STL coding style in this file
+struct null_t {};
+
+template <typename T, typename U>
+struct type_list_node
+{
+	using head = T;
+	using tail = U;
+};
+
+template <typename... T> struct type_list;
+
+// Case: Normal recursion. Consume one type per call.
+template <typename Head, typename... Tail>
+struct type_list<Head, Tail...> {
+	using type = type_list_node<Head, typename type_list<Tail...>::type>;
+};
+
+// Case: Recursion abort, because the list of types ran empty
+template <>
+struct type_list<> { using type = null_t; };
+
+template <typename... Ts>
+using type_list_t = typename type_list<Ts...>::type;
+
+}
\ No newline at end of file
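For reference, the recursion in TypeList.hpp above expands type_list_t<float, half> into one type_list_node per element, terminated by null_t; this is the head/tail structure the new dispatch() in CPUBackend.cpp walks. A compile-time check of that expansion (definitions restated from the header so the sketch stands alone):

    #include <type_traits>

    struct null_t {};
    template <typename T, typename U> struct type_list_node { using head = T; using tail = U; };
    template <typename... T> struct type_list;
    template <typename Head, typename... Tail>
    struct type_list<Head, Tail...> { using type = type_list_node<Head, typename type_list<Tail...>::type>; };
    template <> struct type_list<> { using type = null_t; };
    template <typename... Ts> using type_list_t = typename type_list<Ts...>::type;

    struct half;  // placeholder for half_precision::half

    // type_list_t<float, half> unrolls into nested nodes terminated by null_t:
    static_assert(std::is_same_v<type_list_t<float, half>,
        type_list_node<float, type_list_node<half, null_t>>>);

    int main() {}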
diff --git a/examples/spbench.cpp b/examples/spbench.cpp
index 61331c3..30b2a76 100644
--- a/examples/spbench.cpp
+++ b/examples/spbench.cpp
@@ -12,6 +12,7 @@ using namespace et;
 float benchmarkSpatialPooler(const Shape& out_shape, const std::vector<Tensor>& x, size_t num_epoch)
 {
 	SpatialPooler sp(x[0].shape(), out_shape);
+	sp.permanences_ = sp.permanences_.cast(DType::Half);
 
 	//To make the OpenCL backend pre-compile the kernels
 	Tensor t = zeros(x[0].shape(), DType::Bool);
diff --git a/kernels/cast.cl b/kernels/cast.cl
index 912149d..6105f35 100644
--- a/kernels/cast.cl
+++ b/kernels/cast.cl
@@ -6,6 +6,10 @@
 	#error OutType not defined
 #endif
 
+#ifdef HalfSupport
+	#pragma OPENCL EXTENSION cl_khr_fp16 : enable
+#endif
+
 //InType: Input Type
 //OutType: Output Type
 //global_size: arbitrary
diff --git a/kernels/cellActivity.cl b/kernels/cellActivity.cl
index 3a865eb..e77ebe6 100644
--- a/kernels/cellActivity.cl
+++ b/kernels/cellActivity.cl
@@ -6,6 +6,10 @@
 	#error "MAX_SYNAPSE_PER_CELL not defined"
 #endif
 
+#ifndef PERM_TYPE
+	#error "PERM_TYPE not defined"
+#endif
+
 #ifndef NO_UNUSED_SYNAPSE
 	#define NO_UNUSED_SYNAPSE false
 #endif
@@ -15,7 +19,7 @@
 //INPUT_SIZE: The size of the input SDR, must be smaller than CL_DEVICE_LOCAL_MEMORY_SIZE
 //NO_UNUSED_SYNAPSE: Whether unused synapses exist. Useful for the spatial pooler; accelerates the kernel by ~30%
 kernel void cellActivity(global bool* restrict x, global int* restrict synapses
-	, global float* restrict permeances, global int* restrict y
+	, global PERM_TYPE* restrict permeances, global int* restrict y
 	, float connected_perm, int active_threshold, int output_size)
 {
 	//Load input state into local memory for faster access
diff --git a/kernels/cellActivity_compressed_local.cl b/kernels/cellActivity_compressed_local.cl
index 7042a24..f869647 100644
--- a/kernels/cellActivity_compressed_local.cl
+++ b/kernels/cellActivity_compressed_local.cl
@@ -6,6 +6,10 @@
 	#error "MAX_SYNAPSE_PER_CELL not defined"
 #endif
 
+#ifndef PERM_TYPE
+	#error "PERM_TYPE not defined"
+#endif
+
 #ifndef NO_UNUSED_SYNAPSE
 	#define NO_UNUSED_SYNAPSE false
 #endif
@@ -20,7 +24,7 @@ int up_round(int v, int mul)
 //INPUT_SIZE: The size of the input SDR, must be smaller than CL_DEVICE_LOCAL_MEMORY_SIZE*8-8
 //NO_UNUSED_SYNAPSE: Whether unused synapses exist. Useful for the spatial pooler; accelerates the kernel by ~30%
 kernel void cellActivity(global bool* restrict x, global int* restrict synapses
-	, global float* restrict permeances, global int* restrict y
+	, global PERM_TYPE* restrict permeances, global int* restrict y
 	, float connected_perm, int active_threshold, int output_size)
 {
 	//Load input state into local memory for faster access
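
Every kernel touched by this patch now fails its preprocessor check unless PERM_TYPE is defined, so the OpenCL backend has to inject the permanence type, plus HalfSupport on devices that report cl_khr_fp16, into the program build options. A hedged host-side sketch (kernelOptions is a hypothetical helper, not code from this patch):

	#include <string>

	std::string kernelOptions(bool use_half, int max_synapse_per_cell, int input_size)
	{
		std::string opts = "-DMAX_SYNAPSE_PER_CELL=" + std::to_string(max_synapse_per_cell)
			+ " -DINPUT_SIZE=" + std::to_string(input_size);
		if(use_half)
			opts += " -DPERM_TYPE=half -DHalfSupport"; //turns on the cl_khr_fp16 pragma
		else
			opts += " -DPERM_TYPE=float";
		return opts; //handed to clBuildProgram()/cl::Program::build()
	}
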
diff --git a/kernels/cellActivity_global.cl b/kernels/cellActivity_global.cl
index d45cb4d..d7b11e6 100644
--- a/kernels/cellActivity_global.cl
+++ b/kernels/cellActivity_global.cl
@@ -6,6 +6,10 @@
 	#error "MAX_SYNAPSE_PER_CELL not defined"
 #endif
 
+#ifndef PERM_TYPE
+	#error "PERM_TYPE not defined"
+#endif
+
 #ifndef NO_UNUSED_SYNAPSE
 	#define NO_UNUSED_SYNAPSE false
 #endif
@@ -15,7 +19,7 @@
 //INPUT_SIZE: The size of the input SDR, must be smaller than CL_DEVICE_LOCAL_MEMORY_SIZE
 //NO_UNUSED_SYNAPSE: Whether unused synapses exist. Useful for the spatial pooler; accelerates the kernel by ~30%
 kernel void cellActivity(global bool* restrict x, global int* restrict synapses
-	, global float* restrict permeances, global int* restrict y
+	, global PERM_TYPE* restrict permeances, global int* restrict y
 	, float connected_perm, int active_threshold, int output_size)
 {
 	int global_size = get_global_size(0);
diff --git a/kernels/decaySynapses.cl b/kernels/decaySynapses.cl
index 6304c92..be9240e 100644
--- a/kernels/decaySynapses.cl
+++ b/kernels/decaySynapses.cl
@@ -6,8 +6,12 @@
 	#error "MAX_SYNAPSE_PER_CELL is not defined"
 #endif
 
+#ifndef PERM_TYPE
+	#error "PERM_TYPE not defined"
+#endif
+
 
-kernel void decaySynapses(global int* restrict connections, global float* restrict permeances, float threshold)
+kernel void decaySynapses(global int* restrict connections, global PERM_TYPE* restrict permeances, float threshold)
 {
 	int global_size = get_global_size(0);
 	int global_id = get_global_id(0);
diff --git a/kernels/growSynapses.cl b/kernels/growSynapses.cl
index 113db48..27ef15d 100644
--- a/kernels/growSynapses.cl
+++ b/kernels/growSynapses.cl
@@ -10,6 +10,10 @@
 	#error "MAX_SYNAPSE_PER_CELL is not defined"
 #endif
 
+#ifndef PERM_TYPE
+	#error "PERM_TYPE not defined"
+#endif
+
 //NOTE: The old version of this kernel might perform better with more cells
 
 //global_size: Arbitrary
@@ -20,7 +24,7 @@
 //x: The input SDR **IN SPARSE FORMAT**
 //aux: temporary buffer for storage, must be of size NUM_INPUT_BITS*global_size[0]
 kernel void growSynapses(global int* restrict x, global bool* restrict y, global int* restrict connections
-	, global float* restrict permeances, float initial_perm, int num_input_on_bits, global bool* restrict aux)
+	, global PERM_TYPE* restrict permeances, float initial_perm, int num_input_on_bits, global bool* restrict aux)
 {
 	int global_size = get_global_size(0);
 	int global_id = get_global_id(0);
diff --git a/kernels/learnCorrilation.cl b/kernels/learnCorrilation.cl
index e7b3b13..4c01251 100644
--- a/kernels/learnCorrilation.cl
+++ b/kernels/learnCorrilation.cl
@@ -10,6 +10,10 @@
 	#error "OUTPUT_SIZE not defined"
 #endif
 
+#ifndef PERM_TYPE
+	#error "PERM_TYPE not defined"
+#endif
+
 #ifndef NO_UNUSED_SYNAPSE
 	#define NO_UNUSED_SYNAPSE false
 #endif
@@ -19,7 +23,7 @@
 //INPUT_SIZE: The size of the input SDR, must be smaller than CL_DEVICE_LOCAL_MEMORY_SIZE
 //NO_UNUSED_SYNAPSE: Whether unused synapses exist. Useful for the spatial pooler; accelerates the kernel by ~30%
 kernel void learnCorrilation(global bool* restrict x, global bool* restrict y
-	, global int* restrict synapses, global float* restrict permeances
+	, global int* restrict synapses, global PERM_TYPE* restrict permeances
 	, float permeance_inc, float permeance_dec)
 {
 	local char xl[INPUT_SIZE];
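
One numeric caveat worth keeping in mind with fp16 permanences (an aside, not from the patch): IEEE half has a 10-bit mantissa, so the gap between representable values at 1.0 is 2^-10 ≈ 0.000977. Once a synapse saturates near 1.0, any permeance_inc/permeance_dec smaller than that is rounded away entirely; the default 0.1 steps are safely above it.

	#include <cstdio>
	#include <cmath>

	int main()
	{
		float ulp_at_one = std::ldexp(1.0f, -10); //2^-10, half's ULP at 1.0
		std::printf("half ULP at 1.0: %g\n", ulp_at_one);
		std::printf("step 0.1 survives: %d\n", 0.1f >= ulp_at_one);       //prints 1
		std::printf("step 0.0005 survives: %d\n", 0.0005f >= ulp_at_one); //prints 0
	}
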
diff --git a/kernels/learnCorrilation_compressed_local.cl b/kernels/learnCorrilation_compressed_local.cl
index b1b5943..a8cfb5e 100644
--- a/kernels/learnCorrilation_compressed_local.cl
+++ b/kernels/learnCorrilation_compressed_local.cl
@@ -10,6 +10,10 @@
 	#error "OUTPUT_SIZE not defined"
 #endif
 
+#ifndef PERM_TYPE
+	#error "PERM_TYPE not defined"
+#endif
+
 #ifndef NO_UNUSED_SYNAPSE
 	#define NO_UNUSED_SYNAPSE false
 #endif
@@ -19,7 +23,7 @@
 //INPUT_SIZE: The size of the input SDR, must be smaller than CL_DEVICE_LOCAL_MEMORY_SIZE*8-8
 //NO_UNUSED_SYNAPSE: Whether unused synapses exist. Useful for the spatial pooler; accelerates the kernel by ~30%
 kernel void learnCorrilation(global bool* restrict x, global bool* restrict y
-	, global int* restrict synapses, global float* restrict permeances
+	, global int* restrict synapses, global PERM_TYPE* restrict permeances
 	, float permeance_inc, float permeance_dec)
 {
 	local char xl[INPUT_SIZE/8+1];
diff --git a/kernels/learnCorrilation_global.cl b/kernels/learnCorrilation_global.cl
index c08d300..dd2086c 100644
--- a/kernels/learnCorrilation_global.cl
+++ b/kernels/learnCorrilation_global.cl
@@ -10,6 +10,10 @@
 	#error "OUTPUT_SIZE not defined"
 #endif
 
+#ifndef PERM_TYPE
+	#error "PERM_TYPE not defined"
+#endif
+
 #ifndef NO_UNUSED_SYNAPSE
 	#define NO_UNUSED_SYNAPSE false
 #endif
@@ -19,7 +23,7 @@
 //INPUT_SIZE: The size of the input SDR, must be smaller than CL_DEVICE_LOCAL_MEMORY_SIZE
 //NO_UNUSED_SYNAPSE: Whether unused synapses exist. Useful for the spatial pooler; accelerates the kernel by ~30%
 kernel void learnCorrilation(global bool* restrict x, global bool* restrict y
-	, global int* restrict synapses, global float* restrict permeances
+	, global int* restrict synapses, global PERM_TYPE* restrict permeances
 	, float permeance_inc, float permeance_dec)
 {
 	int global_size = get_global_size(0);
diff --git a/kernels/sort.cl b/kernels/sort.cl
index 03638de..7b23a7d 100644
--- a/kernels/sort.cl
+++ b/kernels/sort.cl
@@ -3,16 +3,20 @@
 	#error "MAX_SYNAPSE_PER_CELL not defined"
 #endif
 
-void merge(global unsigned int* restrict a1, global float* restrict a2, int l, int m, int r
-	, global unsigned int* restrict aux_buffer1, global float* restrict aux_buffer2)
+#ifndef PERM_TYPE
+	#error "PERM_TYPE not defined"
+#endif
+
+void merge(global unsigned int* restrict a1, global PERM_TYPE* restrict a2, int l, int m, int r
+	, global unsigned int* restrict aux_buffer1, global PERM_TYPE* restrict aux_buffer2)
 {
 	int n1 = m - l + 1;
 	int n2 = r - m;
 	global unsigned int* restrict L = aux_buffer1;
 	global unsigned int* restrict R = aux_buffer1+n1;
-	global float* restrict L2 = aux_buffer2;
-	global float* restrict R2 = aux_buffer2+n1;
+	global PERM_TYPE* restrict L2 = aux_buffer2;
+	global PERM_TYPE* restrict R2 = aux_buffer2+n1;
 	for (int i=0;i<n1;i++)
...
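
sort.cl (truncated above) merge-sorts each cell's synapse indices while permuting the permanence array in lockstep; because the indices are handled as unsigned int, the -1 empty-slot markers (0xFFFFFFFF) sort to the end. A host-side equivalent, for illustration only (this is not the kernel):

	#include <algorithm>
	#include <numeric>
	#include <vector>
	#include <cstdint>

	void sortSynapsesHost(std::vector<uint32_t>& conns, std::vector<float>& perms)
	{
		//Sort an index permutation by synapse target, then apply it to both arrays
		std::vector<size_t> order(conns.size());
		std::iota(order.begin(), order.end(), 0);
		std::sort(order.begin(), order.end()
			, [&](size_t a, size_t b){ return conns[a] < conns[b]; });

		std::vector<uint32_t> c(conns.size());
		std::vector<float> p(perms.size());
		for(size_t i = 0; i < order.size(); i++) {
			c[i] = conns[order[i]];
			p[i] = perms[order[i]];
		}
		conns = std::move(c);
		perms = std::move(p);
	}
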
diff --git a/tests/...
...
+	SECTION("typeToDType") {
+		STATIC_REQUIRE(typeToDType<int32_t>() == DType::Int32);
+		STATIC_REQUIRE(typeToDType<float>() == DType::Float);
+		STATIC_REQUIRE(typeToDType<bool>() == DType::Bool);
+		STATIC_REQUIRE(typeToDType<half>() == DType::Half);
+	}
+
+	SECTION("type of Tensor operations") {
+		bool support_fp16 = [&](){
+			try {ones({1}, DType::Half);}
+			catch(const EtError&) {return false;}
+			return true;
+		}();
+
+		SECTION("exp") {
+			CHECK(exp(ones({1}, DType::Bool)).dtype() == DType::Float);
+			CHECK(exp(ones({1}, DType::Int32)).dtype() == DType::Float);
+			CHECK(exp(ones({1}, DType::Float)).dtype() == DType::Float);
+			if(support_fp16)
+				CHECK(exp(ones({1}, DType::Half)).dtype() == DType::Half);
+		}
+
+		SECTION("negation") {
+			CHECK((-ones({1}, DType::Bool)).dtype() == DType::Int32);
+			CHECK((-ones({1}, DType::Int32)).dtype() == DType::Int32);
+			CHECK((-ones({1}, DType::Float)).dtype() == DType::Float);
+			if(support_fp16)
+				CHECK((-ones({1}, DType::Half)).dtype() == DType::Half);
+		}
+
+		SECTION("inverse") {
+			CHECK(inverse(ones({1}, DType::Bool)).dtype() == DType::Float);
+			CHECK(inverse(ones({1}, DType::Int32)).dtype() == DType::Float);
+			CHECK(inverse(ones({1}, DType::Float)).dtype() == DType::Float);
+			if(support_fp16)
+				CHECK(inverse(ones({1}, DType::Half)).dtype() == DType::Half);
+		}
+
+		SECTION("log") {
+			CHECK(log(ones({1}, DType::Bool)).dtype() == DType::Float);
+			CHECK(log(ones({1}, DType::Int32)).dtype() == DType::Float);
+			CHECK(log(ones({1}, DType::Float)).dtype() == DType::Float);
+			if(support_fp16)
+				CHECK(log(ones({1}, DType::Half)).dtype() == DType::Half);
+		}
+
+		SECTION("logical_not") {
+			CHECK(logical_not(ones({1}, DType::Bool)).dtype() == DType::Bool);
+			CHECK(logical_not(ones({1}, DType::Int32)).dtype() == DType::Bool);
+			CHECK(logical_not(ones({1}, DType::Float)).dtype() == DType::Bool);
+			if(support_fp16)
+				CHECK(logical_not(ones({1}, DType::Half)).dtype() == DType::Bool);
+		}
+
+		SECTION("sum") {
+			CHECK(ones({1}, DType::Bool).sum().dtype() == DType::Int32);
+			CHECK(ones({1}, DType::Int32).sum().dtype() == DType::Int32);
+			CHECK(ones({1}, DType::Float).sum().dtype() == DType::Float);
+			if(support_fp16)
+				CHECK(ones({1}, DType::Half).sum().dtype() == DType::Half);
+		}
+
+		auto solve_binary_op_type = [](DType self, DType other)->DType {
+			// Implement a C++-like type promotion rule
+			if(other == DType::Float || self == DType::Float)
+				return DType::Float;
+			else if(other == DType::Half || self == DType::Half)
+				return DType::Half;
+			return DType::Int32; //Even bool is promoted to int in arithmetic
+		};
+
+		std::vector<DType> types = {DType::Int32, DType::Bool, DType::Float, DType::Half};
+		SECTION("add") {
+			for(auto t1 : types) {
+				for(auto t2 : types) {
+					if((t1 == DType::Half || t2 == DType::Half) && support_fp16 == false)
+						continue;
+					Tensor t = ones({1}, t1);
+					Tensor q = ones({1}, t2);
+					CHECK((t+q).dtype() == solve_binary_op_type(t1, t2));
+				}
+			}
+		}
+
+		SECTION("subtract") {
+			for(auto t1 : types) {
+				for(auto t2 : types) {
+					if((t1 == DType::Half || t2 == DType::Half) && support_fp16 == false)
+						continue;
+					Tensor t = ones({1}, t1);
+					Tensor q = ones({1}, t2);
+					CHECK((t-q).dtype() == solve_binary_op_type(t1, t2));
+				}
+			}
+		}
+
+		SECTION("mul") {
+			for(auto t1 : types) {
+				for(auto t2 : types) {
+					if((t1 == DType::Half || t2 == DType::Half) && support_fp16 == false)
+						continue;
+					Tensor t = ones({1}, t1);
+					Tensor q = ones({1}, t2);
+					CHECK((t*q).dtype() == solve_binary_op_type(t1, t2));
+				}
+			}
+		}
+
+		SECTION("div") {
+			for(auto t1 : types) {
+				for(auto t2 : types) {
+					if((t1 == DType::Half || t2 == DType::Half) && support_fp16 == false)
+						continue;
+					Tensor t = ones({1}, t1);
+					Tensor q = ones({1}, t2);
+					CHECK((t/q).dtype() == solve_binary_op_type(t1, t2));
+				}
+			}
+		}
+	}
+}
+
 // TEST_CASE("Serealize")
 // {
 // 	using namespace et;
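
The support_fp16 probe in the tests above doubles as a useful user-side pattern: fp16 availability depends on the backend and device, so try creating a Half tensor and fall back to Float when the backend throws. A hedged sketch (the umbrella header path and helper name are assumptions, not code from this patch):

	#include <Etaler/Etaler.hpp>
	using namespace et;

	DType preferredPermanenceType()
	{
		try {ones({1}, DType::Half);}           //throws EtError if fp16 is unsupported
		catch(const EtError&) {return DType::Float;}
		return DType::Half;
	}

	//e.g. sp.permanences_ = sp.permanences_.cast(preferredPermanenceType());
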