Skip to content

Commit

Permalink
Merge pull request #149 from marty1885/apichange
Browse files Browse the repository at this point in the history
misc
  • Loading branch information
marty1885 committed Jun 29, 2020
2 parents 323efc9 + 06834d1 commit db858e5
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 18 deletions.
15 changes: 8 additions & 7 deletions Etaler/Algorithms/Anomaly.hpp
Expand Up @@ -5,15 +5,16 @@
namespace et
{

static float anomaly(const Tensor& pred, const Tensor& real)
inline float anomaly(const Tensor& pred, const Tensor& real)
{
et_assert(real.dtype() == DType::Bool);
et_assert(pred.dtype() == DType::Bool);
et_assert(real.shape() == pred.shape());
checkProperties(real.pimpl(), DType::Bool);
checkProperties(pred.pimpl(), DType::Bool);
et_check(real.shape() == pred.shape()
, "The 1st and 2nd arguments should have the same shape");

Tensor should_predict = sum(real);
Tensor not_predicted = sum(!pred && real).cast(DType::Float);
return (not_predicted/should_predict).toHost<float>()[0];
int should_predict = sum(real).item<int>();
int not_predicted = sum((!pred) && real).item<int>();
return float(not_predicted)/should_predict;
}

}
5 changes: 5 additions & 0 deletions Etaler/Algorithms/TemporalMemory.cpp
Expand Up @@ -44,13 +44,18 @@ void TemporalMemory::loadState(const StateDict& states)
input_shape_ = std::any_cast<Shape>(states.at("input_shape"));
connections_ = std::any_cast<Tensor>(states.at("connections"));
permanences_ = std::any_cast<Tensor>(states.at("permanences"));

// Sort the synapses in case they are not pre-sorted.
// Presorting is a requirement for the GPU backend but not the CPU backend.
sortSynapse(connections_, permanences_);
}

TemporalMemory TemporalMemory::to(Backend* b) const
{
TemporalMemory tm = *this;
tm.connections_ = connections_.to(b);
tm.permanences_ = permanences_.to(b);
sortSynapse(tm.connections_, tm.permanences_);

return tm;
}
4 changes: 2 additions & 2 deletions Etaler/Core/Error.cpp
Expand Up @@ -19,7 +19,7 @@ ETALER_EXPORT bool et::getEnableTraceOnException()
return g_enable_trace_on_exception;
}

std::string et::genStackTrace(size_t skip)
ETALER_EXPORT std::string et::genStackTrace(size_t skip)
{
#ifndef BACKWARD_SYSTEM_UNKNOWN
std::stringstream ss;
Expand Down Expand Up @@ -48,4 +48,4 @@ ETALER_EXPORT EtError::EtError(const std::string &msg)
{
if(getEnableTraceOnException())
msg_ += "\n"+genStackTrace(1); // Skip the EtError ctor
}
}
8 changes: 4 additions & 4 deletions Etaler/Core/Error.hpp
Expand Up @@ -12,9 +12,9 @@
namespace et
{

std::string genStackTrace(size_t skip = 0);
ETALER_EXPORT std::string genStackTrace(size_t skip = 0);

class EtError : public std::exception
class ETALER_EXPORT EtError : public std::exception
{
public:
explicit EtError(const std::string &msg);
Expand All @@ -24,8 +24,8 @@ class EtError : public std::exception
std::string msg_;
};

void enableTraceOnException(bool enable);
bool getEnableTraceOnException();
ETALER_EXPORT void enableTraceOnException(bool enable);
ETALER_EXPORT bool getEnableTraceOnException();

}

Expand Down
4 changes: 2 additions & 2 deletions Etaler/Core/Serialize.cpp
Expand Up @@ -5,14 +5,14 @@
#include "Etaler/Core/Tensor.hpp"
#include "TypeHelpers.hpp"

using namespace et;

#include <cereal/cereal.hpp>
#include <cereal/types/vector.hpp>
#include <cereal/types/string.hpp>
#include <cereal/archives/json.hpp>
#include <cereal/archives/portable_binary.hpp>

using namespace et;

namespace cereal
{

Expand Down
7 changes: 4 additions & 3 deletions Etaler/Core/Tensor.cpp
Expand Up @@ -198,6 +198,7 @@ Tensor Tensor::view(const IndexList& rgs) const
// Compute the new shape and stride. Most of the code here exists to check for out-of-bounds access
offset.reserve(dimentions());
result_shape.reserve(dimentions());
result_stride.reserve(dimentions());
for(size_t i=0;i<dimentions();i++) { std::visit([&](auto index_range) { // <- make the code neater
const auto& r = index_range;
const intmax_t dim_size = shape()[i];
Expand Down Expand Up @@ -329,16 +330,16 @@ Tensor et::cat(const svector<Tensor>& tensors, intmax_t dim)
}

if(base_dtype != t.dtype())
throw EtError("DType mismatch when concatenate.");
throw EtError("Cannot concat tensors of different types.");

if(base_backend != t.backend())
throw EtError("Backend mismatch when concatenate.");
throw EtError("Cannot concat tensors on different backends.");

auto shape = t.shape();
assert((intmax_t)shape.size() > dim);
shape[dim] = base_shape[dim];
if(shape != base_shape)
throw EtError("Tensors must have the same shape along all axises besides the concatenating axis.");
throw EtError("Tensors must have the same shape along all dimensions besides the concatenating dimension.");
}

Shape res_shape = base_shape;
Expand Down
25 changes: 25 additions & 0 deletions tests/common_tests.cpp
Expand Up @@ -7,6 +7,7 @@
#include <Etaler/Encoders/GridCell2d.hpp>
#include <Etaler/Core/Serialize.hpp>
#include <Etaler/Algorithms/SDRClassifer.hpp>
#include <Etaler/Algorithms/Anomaly.hpp>

#include <numeric>

Expand Down Expand Up @@ -1019,6 +1020,30 @@ TEST_CASE("Type system")
}
}

TEST_CASE("Anomaly")
{
Tensor real = zeros({256}, DType::Bool);
real[{range(30, 40)}] = true;
SECTION("All zeros") {
Tensor pred = zeros({256}, DType::Bool);
CHECK(anomaly(pred, real) == 1);
}

SECTION("Totally correct") {
CHECK(anomaly(real, real) == 0);
}

SECTION("Totally wrong") {
CHECK(anomaly(!real, real) == 1);
}

SECTION("Partially correct") {
Tensor pred = zeros({256}, DType::Bool);
pred[{range(30, 35)}] = true;
CHECK(anomaly(pred, real) == 0.5);
}
}

// TODO: Should I count this as an integration test?
// This test checks all components of Tensor works together properly
TEST_CASE("Complex Tensor operations")
Expand Down

0 comments on commit db858e5

Please sign in to comment.