Skip to content

Commit

Permalink
Better compute tensor (#3)
Browse files Browse the repository at this point in the history
* better compute tensors is tough.

* Hard work but a lot faster and closer to optimal.

* finished up faster compute tensors.
  • Loading branch information
cnuernber committed Nov 14, 2020
1 parent ac70778 commit acf5703
Show file tree
Hide file tree
Showing 8 changed files with 559 additions and 144 deletions.
45 changes: 45 additions & 0 deletions java/tech/v3/datatype/DoubleTensorReader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package tech.v3.datatype;


import clojure.lang.Keyword;


public interface DoubleTensorReader extends NDBuffer {
default Object elemwiseDatatype() { return Keyword.intern(null, "float64"); }
default boolean ndReadBoolean(long idx) {
return ndReadDouble(idx) != 0.0;
}
default boolean ndReadBoolean(long y, long x) {
return ndReadDouble(y,x) != 0.0;
}
default boolean ndReadBoolean(long y, long x, long c) {
return ndReadDouble(y,x,c) != 0.0;
}
default long ndReadLong(long idx) {
return (long)ndReadDouble(idx);
}
default long ndReadLong(long row, long col) {
return (long)ndReadDouble(row,col);
}
default long ndReadLong(long height, long width, long chan) {
return (long)ndReadDouble(height,width,chan);
}
default Object ndReadObject(long c) {
if (1 != rank()) {
throw new RuntimeException("Tensor is not rank 1");
}
return ndReadDouble(c);
}
default Object ndReadObject(long y, long x) {
if (2 != rank()) {
throw new RuntimeException("Tensor is not rank 2");
}
return ndReadDouble(y,x);
}
default Object ndReadObject(long y, long x, long c) {
if (3 != rank()) {
throw new RuntimeException("Tensor is not rank 3");
}
return ndReadDouble(y,x,c);
}
}
50 changes: 50 additions & 0 deletions java/tech/v3/datatype/LongTensorReader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package tech.v3.datatype;


import clojure.lang.Keyword;


public interface LongTensorReader extends NDBuffer {
default Object elemwiseDatatype() { return Keyword.intern(null, "int64"); }
default boolean ndReadBoolean(long idx) {
return ndReadLong(idx) != 0;
}
default boolean ndReadBoolean(long y, long x) {
return ndReadLong(y,x) != 0;
}
default boolean ndReadBoolean(long y, long x, long c) {
return ndReadLong(y,x,c) != 0;
}
default double ndReadDouble(long idx) {
return (double)ndReadLong(idx);
}
default double ndReadDouble(long row, long col) {
return (double)ndReadLong(row,col);
}
default double ndReadDouble(long height, long width, long chan) {
return (double)ndReadLong(height,width,chan);
}

//These overloads are dangerous as the ndReadObject methods are
//expected to return slices if the tensor is of greater rank
//than the nd method implies. This is why NDBuffers aren't
//Buffers.
default Object ndReadObject(long idx) {
if (1 != rank()) {
throw new RuntimeException("Tensor is not rank 1");
}
return ndReadLong(idx);
}
default Object ndReadObject(long y, long x) {
if (2 != rank()) {
throw new RuntimeException("Tensor is not rank 2");
}
return ndReadLong(y,x);
}
default Object ndReadObject(long y, long x, long c) {
if (3 != rank()) {
throw new RuntimeException("Tensor is not rank 3");
}
return ndReadLong(y,x,c);
}
}
12 changes: 6 additions & 6 deletions java/tech/v3/datatype/NDBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ public interface NDBuffer extends DatatypeBase, Iterable, IFn,
void ndWriteLong(long idx, long value);
void ndWriteLong(long row, long col, long value);
void ndWriteLong(long height, long width, long chan, long value);
double ndReadDouble(double idx);
double ndReadDouble(double row, double col);
double ndReadDouble(double height, double width, double chan);
void ndWriteDouble(double idx, double value);
void ndWriteDouble(double row, double col, double value);
void ndWriteDouble(double height, double width, double chan, double value);
double ndReadDouble(long idx);
double ndReadDouble(long row, long col);
double ndReadDouble(long height, long width, long chan);
void ndWriteDouble(long idx, double value);
void ndWriteDouble(long row, long col, double value);
void ndWriteDouble(long height, long width, long chan, double value);

// Object read methods can return slices or values.
Object ndReadObject(long idx);
Expand Down
36 changes: 36 additions & 0 deletions java/tech/v3/datatype/ObjectTensorReader.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package tech.v3.datatype;


import clojure.lang.Keyword;


public interface ObjectTensorReader extends NDBuffer {
default Object elemwiseDatatype() { return Keyword.intern(null, "float64"); }
default boolean ndReadBoolean(long idx) {
return BooleanConversions.from(ndReadObject(idx));
}
default boolean ndReadBoolean(long y, long x) {
return BooleanConversions.from(ndReadObject(y,x));
}
default boolean ndReadBoolean(long y, long x, long c) {
return BooleanConversions.from(ndReadObject(y,x,c));
}
default long ndReadLong(long idx) {
return NumericConversions.longCast(ndReadObject(idx));
}
default long ndReadLong(long row, long col) {
return NumericConversions.longCast(ndReadObject(row,col));
}
default long ndReadLong(long height, long width, long chan) {
return NumericConversions.longCast(ndReadObject(height,width,chan));
}
default double ndReadDouble(long idx) {
return NumericConversions.doubleCast(ndReadObject(idx));
}
default double ndReadDouble(long row, long col) {
return NumericConversions.doubleCast(ndReadObject(row,col));
}
default double ndReadDouble(long height, long width, long chan) {
return NumericConversions.doubleCast(ndReadObject(height,width,chan));
}
}
7 changes: 3 additions & 4 deletions src/tech/v3/datatype/copy.clj
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
[tech.v3.datatype.protocols :as dtype-proto])
(:import [sun.misc Unsafe]
[tech.v3.datatype.native_buffer NativeBuffer]
[tech.v3.datatype.array_buffer ArrayBuffer]))
[tech.v3.datatype.array_buffer ArrayBuffer]
[tech.v3.datatype NDBuffer]))


(set! *warn-on-reflection* true)
Expand All @@ -16,16 +17,14 @@

(defn generic-copy!
[src dst]
(let [src-dtype (dtype-base/elemwise-datatype src)
dst-dtype (dtype-base/elemwise-datatype dst)
(let [dst-dtype (dtype-base/elemwise-datatype dst)
op-space (casting/simple-operation-space dst-dtype)
src (dtype-base/->reader src dst-dtype)
dst (dtype-base/->writer dst)
n-elems (.lsize src)]
(when-not (== n-elems (.lsize dst))
(throw (Exception. (format "src,dst ecount mismatch: %d-%d"
n-elems (.lsize dst)))))

(case op-space
:boolean
(parallel-for/parallel-for
Expand Down

0 comments on commit acf5703

Please sign in to comment.