Merge pull request #138 from marty1885/apichange
New batch of changes
marty1885 committed Apr 5, 2020
2 parents 076cdda + eb4b97a commit 7da19eb
Showing 5 changed files with 41 additions and 35 deletions.
4 changes: 4 additions & 0 deletions Etaler/Core/Backend.hpp
@@ -20,6 +20,10 @@ struct Backend;
struct ETALER_EXPORT Backend : public std::enable_shared_from_this<Backend>
{
virtual ~Backend() = default;
// Backends are not copyable
Backend() = default;
Backend(const Backend&) = delete;
Backend& operator=(const Backend&) = delete;
virtual std::shared_ptr<TensorImpl> createTensor(const Shape& shape, DType dtype, const void* data = nullptr) {throw notImplemented("createTensor");};

virtual void sync() const {} //Default empty implementation. For async backends
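
Deleting the copy constructor and copy assignment means a Backend can only be shared through pointers or references, which fits the std::enable_shared_from_this base. Below is a minimal sketch of the new constraint; the DummyBackend subclass is hypothetical, and the et namespace is assumed from the et:: usage elsewhere in the codebase.

#include <Etaler/Core/Backend.hpp>
#include <memory>

using namespace et; // namespace assumed from et::operator<< in Tensor.cpp

// Hypothetical backend for illustration only; the inherited virtuals keep
// their default (throwing or empty) implementations from Backend.
struct DummyBackend : public Backend {};

int main()
{
    auto backend = std::make_shared<DummyBackend>();
    std::shared_ptr<Backend> alias = backend; // fine: both point at the same instance
    // Backend copy = *backend;               // no longer compiles: copy constructor is deleted
    // *backend = DummyBackend();             // no longer compiles: copy assignment is deleted
}
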
47 changes: 23 additions & 24 deletions Etaler/Core/Tensor.cpp
@@ -9,7 +9,7 @@ size_t g_print_threshold = 1000;
size_t g_truncate_size = 3;

template <typename T>
static size_t prettyPrintTensor(std::ostream& os, const T* arr, Shape shape, size_t depth, size_t max_depth, size_t max_length=0, bool truncate=false) noexcept
static void prettyPrintTensor(std::ostream& os, const T* arr, const Shape& shape, size_t depth, size_t max_length=0, bool truncate=false) noexcept
{
// Not using std::to_string because std::to_string(0.f) returns "0.00000"
auto toStr = [](auto val) {
@@ -21,17 +21,17 @@ static size_t prettyPrintTensor(std::ostream& os, const T* arr, Shape shape, siz
//If at the first dimension
if(depth == 0) {
//Calculate the max number of characters needed to print a single element
for(int i=0;i<shape.volume();i++)
for(intmax_t i=0;i<shape.volume();i++)
max_length = std::max(max_length, toStr(arr[i]).size());
}

const std::string truncate_symbol = "....";

//If at the last dimension, print the content of the tensor
if(shape.size() == 1) {
if(depth+1 == shape.size()) {
os << "{ ";
intmax_t size = shape[0];
intmax_t max_line_content = intmax_t((80-depth*2-4)/(max_length+2));
intmax_t size = shape[depth];
intmax_t max_line_content = intmax_t((80-depth*2-truncate_symbol.size())/(max_length+2));

//Print the full content
if(size <= max_line_content || !truncate) {
@@ -62,70 +62,69 @@ static size_t prettyPrintTensor(std::ostream& os, const T* arr, Shape shape, siz
}

os << "}";
return 1;
return;
}

// Otherwise (we aren't in the last dimension)
// print the curly braces recursively
intmax_t size = shape[0];
shape.erase(shape.begin());
intmax_t vol = shape.volume();

size_t ret_depth = 0; // TODO: Do we really need this? Should be deterministic?
const intmax_t size = shape[0];
const intmax_t vol = std::accumulate(shape.begin()+depth+1, shape.end(), intmax_t(1), std::multiplies<intmax_t>());
const size_t remain_recursion = shape.size() - depth - 1;
const size_t done_recursion = depth + 1;
os << "{";

if(size < 2*intmax_t(g_truncate_size) || !truncate) {
//The full version
for(intmax_t i=0;i<size;i++) {
//Print the data recursively
ret_depth = prettyPrintTensor(os, arr+i*vol, shape, depth+1, max_depth, max_length, truncate);
prettyPrintTensor(os, arr+i*vol, shape, depth+1, max_length, truncate);
if(i != size-1)
os << ", " << std::string(ret_depth, '\n') << (i==size-1 ? std::string("") : std::string(max_depth-ret_depth, ' '));
os << ", " << std::string(remain_recursion, '\n') << (i==size-1 ? std::string("") : std::string(done_recursion, ' '));
}
}
else {
//The first half
for(intmax_t i=0;i<intmax_t(g_truncate_size);i++) {
//Print the data recursively
ret_depth = prettyPrintTensor(os, arr+i*vol, shape, depth+1, max_depth, max_length, truncate);
prettyPrintTensor(os, arr+i*vol, shape, depth+1, max_length, truncate);
if(i != size-1)
os << ", " << std::string(ret_depth, '\n') << std::string(max_depth-ret_depth, ' ');
os << ", " << std::string(remain_recursion, '\n') << std::string(done_recursion, ' ');
}

//separator
os << truncate_symbol << '\n' << std::string(max_depth-ret_depth, ' ');
os << truncate_symbol << '\n' << std::string(done_recursion, ' ');

//The second half
for(intmax_t i=size-intmax_t(g_truncate_size);i<size;i++) {
//Print the data recursively
ret_depth = prettyPrintTensor(os, arr+i*vol, shape, depth+1, max_depth, max_length, truncate);
prettyPrintTensor(os, arr+i*vol, shape, depth+1, max_length, truncate);
if(i != size-1)
os << ", " << std::string(ret_depth, '\n') << (i==size-1 ? std::string("") : std::string(max_depth-ret_depth, ' '));
os << ", " << std::string(remain_recursion, '\n') << (i==size-1 ? std::string("") : std::string(done_recursion, ' '));
}
}
os << "}";

return ret_depth+1;//return the current depth from the back
return;
}

static void printTensor(std::ostream& os, const void* ptr, const Shape& shape, DType dtype)
{
bool truncate = size_t(shape.volume()) > g_print_threshold;
if(dtype == DType::Float)
prettyPrintTensor(os, (float*)ptr, shape, 0, shape.size(), 0, truncate);
prettyPrintTensor(os, (float*)ptr, shape, 0, 0, truncate);
else if(dtype == DType::Int32)
prettyPrintTensor(os, (int32_t*)ptr, shape, 0, shape.size(), 0, truncate);
prettyPrintTensor(os, (int32_t*)ptr, shape, 0, 0, truncate);
else if(dtype == DType::Bool)
prettyPrintTensor(os, (bool*)ptr, shape, 0, shape.size(), 0, truncate);
prettyPrintTensor(os, (bool*)ptr, shape, 0, 0, truncate);
else if(dtype == DType::Half)
prettyPrintTensor(os, (half*)ptr, shape, 0, shape.size(), 0, truncate);
prettyPrintTensor(os, (half*)ptr, shape, 0, 0, truncate);
else
throw EtError("Printing tensor of this type is not supported.");
}

std::ostream& et::operator<< (std::ostream& os, const Tensor& t)
{
if(t.has_value() == false) {
if(t.has_value() == false || t.shape().size() == 0) {
os << "{}";
return os;
}
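
The refactored printer keeps the full Shape and walks it with a depth index instead of erasing the leading dimension on every recursion, computes the sub-tensor volume with std::accumulate, and derives the newline count and indentation directly from the depth, so the old ret_depth/max_depth bookkeeping disappears. The sketch below reproduces just that recursion pattern on a plain std::vector shape; it is not Etaler's actual printer (no element alignment, no truncation).

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

static void printRec(std::ostream& os, const float* arr,
                     const std::vector<intmax_t>& shape, size_t depth)
{
    if(depth + 1 == shape.size()) { // last dimension: print the elements
        os << "{ ";
        for(intmax_t i = 0; i < shape[depth]; i++)
            os << arr[i] << (i + 1 == shape[depth] ? " " : ", ");
        os << "}";
        return;
    }

    // Volume of one sub-tensor below the current dimension
    const intmax_t vol = std::accumulate(shape.begin() + depth + 1, shape.end(),
                                         intmax_t(1), std::multiplies<intmax_t>());
    const size_t remain_recursion = shape.size() - depth - 1; // newlines between sub-blocks
    const size_t done_recursion = depth + 1;                  // indentation of the next block
    os << "{";
    for(intmax_t i = 0; i < shape[depth]; i++) {
        printRec(os, arr + i * vol, shape, depth + 1);
        if(i + 1 != shape[depth])
            os << ", " << std::string(remain_recursion, '\n') << std::string(done_recursion, ' ');
    }
    os << "}";
}

int main()
{
    std::vector<float> data(4 * 4 * 2, 1.f); // same shape as the updated example1.cpp
    printRec(std::cout, data.data(), {4, 4, 2}, 0);
    std::cout << std::endl;
}
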
2 changes: 1 addition & 1 deletion LICENSE
@@ -1,4 +1,4 @@
Copyright 2019 Martin Chang
Copyright 2019-2020 Martin Chang

Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:

15 changes: 9 additions & 6 deletions examples/example1.cpp
@@ -12,13 +12,16 @@ int main()
{
//Create a SP that takes in 128 input bits and generates 32 bit representation
//setDefaultBackend(std::make_shared<OpenCLBackend>());
SpatialPooler sp({128}, {32});
// SpatialPooler sp({128}, {32});

//Encode the value 0.1 into a 32 bit SDR
Tensor x = encoder::scalar(0.1, 0, 1, 128, 12);
// //Encode the value 0.1 into a 32 bit SDR
// Tensor x = encoder::scalar(0.1, 0, 1, 128, 12);

std::cout << sp.compute(x) << std::endl;
// std::cout << sp.compute(x) << std::endl;

auto state = sp.states();
sp.loadState(state);
// auto state = sp.states();
// sp.loadState(state);

Tensor t = ones({4, 4, 2});
std::cout << t << std::endl;
}
8 changes: 4 additions & 4 deletions tests/common_tests.cpp
@@ -106,7 +106,7 @@ TEST_CASE("Testing Tensor", "[Tensor]")
SECTION("Create Tensor from vector") {
std::vector<int> v = {1, 2, 3, 4};
Tensor t = Tensor(v);
CHECK(t.size() == intmax_t(v.size()));
CHECK(t.size() == v.size());
CHECK(t.dtype() == DType::Int);
}

@@ -395,7 +395,7 @@ TEST_CASE("Testing Tensor", "[Tensor]")
num_iteration += 1;
}
CHECK(num_iteration == t.shape()[0]);
CHECK(t.sum().item<int>() == 42*t.size());
CHECK(t.sum().item<int>() == int(42*t.size()));
}

SECTION("swapping Tensor") {
@@ -438,7 +438,7 @@ TEST_CASE("Testing Encoders", "[Encoder]")
CHECK(t.size() == 32);
REQUIRE(t.dtype() == DType::Bool);
auto v = t.toHost<uint8_t>();
CHECK(std::accumulate(v.begin(), v.end(), 0) == num_on_bits);
CHECK(std::accumulate(v.begin(), v.end(), size_t(0)) == num_on_bits);
}

SECTION("Category Encoder") {
@@ -449,7 +449,7 @@ TEST_CASE("Testing Encoders", "[Encoder]")
CHECK(t.size() == num_categories*bits_per_category);
REQUIRE(t.dtype() == DType::Bool);
auto v = t.toHost<uint8_t>();
CHECK(std::accumulate(v.begin(), v.end(), 0) == bits_per_category);
CHECK(std::accumulate(v.begin(), v.end(), size_t(0)) == bits_per_category);

Tensor q = encoder::category(1, num_categories, bits_per_category);
auto u = q.toHost<uint8_t>();
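
The test changes to std::accumulate are about the init argument: std::accumulate deduces its accumulator type from that argument, so seeding with size_t(0) makes the sum a size_t and lets it compare against the (presumably size_t) num_on_bits and bits_per_category without a signed/unsigned mismatch. A small self-contained illustration of that behaviour, under the assumption that the counters are size_t:

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

int main()
{
    std::vector<uint8_t> v = {1, 0, 1, 1};
    size_t num_on_bits = 3;

    // The init value decides the accumulator type: int here, size_t below.
    int as_int = std::accumulate(v.begin(), v.end(), 0);
    size_t as_size_t = std::accumulate(v.begin(), v.end(), size_t(0));

    assert(as_size_t == num_on_bits);      // same types, no conversion warning
    assert(size_t(as_int) == num_on_bits); // the int version needs an explicit cast
}
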
