Skip to content

Commit

Permalink
Tests passing again.
Browse files Browse the repository at this point in the history
  • Loading branch information
cnuernber committed Oct 13, 2020
1 parent 564bbdb commit 6c24077
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 105 deletions.
34 changes: 12 additions & 22 deletions src/tech/v3/datatype/argops.clj
Original file line number Diff line number Diff line change
Expand Up @@ -119,15 +119,14 @@


(defmacro ^:private impl-arglast-every-loop
[datatype init-value n-elems rdr pred]
[datatype n-elems rdr pred]
(let [{:keys [read-fn pred-fn]} (compare-compile-time-family datatype)]
(when-not (and read-fn pred-fn)
(throw (Exception. (format "Compile failure: %s, :read-fn %s, :pred-fn %s"
datatype read-fn pred-fn))))
`(loop [idx# 0
`(loop [idx# 1
max-idx# 0
max-value# (casting/datatype->cast-fn
:unknown ~datatype ~init-value)]
max-value# (~read-fn ~rdr 0)]
(if (== ~n-elems idx#)
max-idx#
(let [cur-val# (~read-fn ~rdr idx#)
Expand All @@ -141,37 +140,28 @@
(defn arglast-every
"Return the last index where (pred (rdr idx) (rdr (dec idx))) was true by
comparing every value and keeping track of the last index where pred was true."
[rdr pred-op init-val-map]
[rdr pred-op]
(let [pred (->binary-predicate pred-op)
op-space (casting/simple-operation-space
(dtype-base/elemwise-datatype rdr))
rdr (dtype-base/->reader rdr op-space)
n-elems (.lsize rdr)]
(case op-space
:int64 (impl-arglast-every-loop :int64 (init-val-map :int64)
n-elems rdr pred)
:float64 (impl-arglast-every-loop :float64 (init-val-map :float64)
n-elems rdr pred)
(impl-arglast-every-loop :object (init-val-map :object)
n-elems rdr pred))))
:int64 (impl-arglast-every-loop :int64 n-elems rdr pred)
:float64 (impl-arglast-every-loop :float64 n-elems rdr pred)
(impl-arglast-every-loop :object n-elems rdr pred))))


(defn argmax
"Return the index of the max item in the reader."
^long [rdr]
(arglast-every rdr :>
{:int64 Long/MIN_VALUE
:float64 (- Double/MAX_VALUE)
:object nil}))
(arglast-every rdr :>))


(defn argmin
"Return the index of the min item in the reader."
^long [rdr]
(arglast-every rdr :<
{:int64 Long/MAX_VALUE
:float64 Double/MAX_VALUE
:object nil}))
(arglast-every rdr :<))

(defmacro impl-index-of
[datatype comp-value n-elems pred rdr]
Expand Down Expand Up @@ -459,10 +449,10 @@
See arggroup for Options."
(^Map [rdr options partition-fn]
(if (= identity partition-fn)
(arggroup options rdr)
(arggroup options (unary-op/reader (->unary-operator partition-fn) rdr))))
(arggroup rdr options)
(arggroup (unary-op/reader (->unary-operator partition-fn) rdr) options)))
(^Map [rdr partition-fn]
(arggroup-by partition-fn nil rdr)))
(arggroup-by rdr nil partition-fn)))


(defn- do-argpartition-by
Expand Down
104 changes: 41 additions & 63 deletions src/tech/v3/datatype/binary_pred.clj
Original file line number Diff line number Diff line change
Expand Up @@ -60,45 +60,53 @@
(== 0 (.compare item lhs rhs)))
dtype-proto/POperator
(op-name [this] opname)))
(instance? java.util.function.BiPredicate item)
(let [^java.util.function.BiPredicate item item]
(reify
BinaryPredicates$ObjectBinaryPredicate
(binaryObject [this lhs rhs]
(.test item lhs rhs))
dtype-proto/POperator
(op-name [this] opname)))
(instance? IFn item) (ifn->binary-predicate item opname)))
(^BinaryPredicate [item] (->predicate item :_unnamed)))


(defn reader
^Buffer [pred lhs-rdr rhs-rdr]
(let [pred (->predicate pred)
lhs-rdr (dtype-base/->reader lhs-rdr)
rhs-rdr (dtype-base/->reader rhs-rdr)
op-dtype (casting/widest-datatype (.elemwiseDatatype lhs-rdr)
(.elemwiseDatatype rhs-rdr))]
op-dtype (casting/simple-operation-space
(dtype-base/elemwise-datatype lhs-rdr)
(dtype-base/elemwise-datatype rhs-rdr))
lhs-rdr (dtype-base/->reader lhs-rdr op-dtype)
rhs-rdr (dtype-base/->reader rhs-rdr op-dtype)]
(when-not (== (.lsize lhs-rdr)
(.lsize rhs-rdr))
(errors/throwf "lhs size (%d), rhs size (%d) mismatch"
(.lsize lhs-rdr)
(.lsize rhs-rdr)))
(cond
(= :boolean op-dtype)
(case op-dtype
:boolean
(reify BooleanReader
(lsize [rdr] (.lsize lhs-rdr))
(readBoolean [rdr idx]
(.binaryBoolean pred
(.readBoolean lhs-rdr idx)
(.readBoolean rhs-rdr idx))))
(casting/integer-type? op-dtype)
:int64
(reify BooleanReader
(lsize [rdr] (.lsize lhs-rdr))
(readBoolean [rdr idx]
(.binaryLong pred
(.readLong lhs-rdr idx)
(.readLong rhs-rdr idx))))
(casting/float-type? op-dtype)
:float64
(reify BooleanReader
(lsize [rdr] (.lsize lhs-rdr))
(readBoolean [rdr idx]
(.binaryDouble pred
(.readDouble lhs-rdr idx)
(.readDouble rhs-rdr idx))))
:else
(reify ObjectReader
(lsize [rdr] (.lsize lhs-rdr))
(readObject [rdr idx]
Expand Down Expand Up @@ -180,57 +188,27 @@
dtype-proto/POperator
(op-name [this] :eq))

:> (make-numeric-binary-predicate :> (pmath/> x y)
(cond
(and (instance? Instant x)
(instance? Instant y))
(.isAfter ^Instant x ^Instant y)
(and (instance? LocalDate x)
(instance? LocalDate y))
(.isAfter ^LocalDate x ^LocalDate y)
(and (instance? ZonedDateTime x)
(instance? ZonedDateTime y))
(.isAfter ^ZonedDateTime x ^ZonedDateTime y)
:else
(Numbers/gt x y)))
:>= (make-numeric-binary-predicate :>= (pmath/>= x y)
(or (.equals ^Object x y)
(cond
(and (instance? Instant x)
(instance? Instant y))
(.isAfter ^Instant x ^Instant y)
(and (instance? LocalDate x)
(instance? LocalDate y))
(.isAfter ^LocalDate x ^LocalDate y)
(and (instance? ZonedDateTime x)
(instance? ZonedDateTime y))
(.isAfter ^ZonedDateTime x ^ZonedDateTime y)
:else
(Numbers/gt x y))))
:< (make-numeric-binary-predicate :< (pmath/< x y)
(cond
(and (instance? Instant x)
(instance? Instant y))
(.isBefore ^Instant x ^Instant y)
(and (instance? LocalDate x)
(instance? LocalDate y))
(.isBefore ^LocalDate x ^LocalDate y)
(and (instance? ZonedDateTime x)
(instance? ZonedDateTime y))
(.isBefore ^ZonedDateTime x ^ZonedDateTime y)
:else
(Numbers/lt x y)))
:<= (make-numeric-binary-predicate :<= (pmath/<= x y)
(or (.equals ^Object x y)
(cond
(and (instance? Instant x)
(instance? Instant y))
(.isBefore ^Instant x ^Instant y)
(and (instance? LocalDate x)
(instance? LocalDate y))
(.isBefore ^LocalDate x ^LocalDate y)
(and (instance? ZonedDateTime x)
(instance? ZonedDateTime y))
(.isBefore ^ZonedDateTime x ^ZonedDateTime y)
:else
(Numbers/lte x y))))})
:> (make-numeric-binary-predicate
:> (pmath/> x y)
(let [comp-val (long (if (instance? Comparable x)
(.compareTo ^Comparable x y)
(compare x y)))]
(pmath/> comp-val 0)))
:>= (make-numeric-binary-predicate
:>= (pmath/>= x y)
(let [comp-val (long (if (instance? Comparable x)
(.compareTo ^Comparable x y)
(compare x y)))]
(pmath/>= comp-val 0)))
:< (make-numeric-binary-predicate
:< (pmath/< x y)
(let [comp-val (long (if (instance? Comparable x)
(.compareTo ^Comparable x y)
(compare x y)))]
(pmath/< comp-val 0)))
:<= (make-numeric-binary-predicate
:<= (pmath/<= x y)
(let [comp-val (long (if (instance? Comparable x)
(.compareTo ^Comparable x y)
(compare x y)))]
(pmath/<= comp-val 0)))})
14 changes: 7 additions & 7 deletions src/tech/v3/datatype/unary_pred.clj
Original file line number Diff line number Diff line change
Expand Up @@ -76,25 +76,25 @@
(defn reader
^Buffer [pred src-rdr]
(let [pred (->predicate pred)
src-rdr (dtype-base/->reader src-rdr)
src-dtype (dtype-base/elemwise-datatype src-rdr)]
(cond
(= :boolean src-dtype)
op-space (casting/simple-operation-space
(dtype-base/elemwise-datatype src-rdr))
src-rdr (dtype-base/->reader src-rdr op-space)]
(case op-space
:boolean
(reify BooleanReader
(lsize [rdr] (.lsize src-rdr))
(readBoolean [rdr idx]
(.unaryBoolean pred (.readBoolean src-rdr idx))))
(casting/integer-type? src-dtype)
:int64
(reify BooleanReader
(lsize [rdr] (.lsize src-rdr))
(readBoolean [rdr idx]
(.unaryLong pred (.readLong src-rdr idx))))
(casting/float-type? src-dtype)
:float64
(reify BooleanReader
(lsize [rdr] (.lsize src-rdr))
(readBoolean [rdr idx]
(.unaryDouble pred (.readDouble src-rdr idx))))
:else
(reify BooleanReader
(lsize [rdr] (.lsize src-rdr))
(readBoolean [rdr idx]
Expand Down
22 changes: 11 additions & 11 deletions test/tech/v3/datatype_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@
(let [test-floats (float-array [0 ##NaN 2 4 ##NaN])
nan-floats (float-array (repeat 5 ##NaN))]
(is (= [1 4]
(-> (argops/argfilter #(= 0 (Float/compare % Float/NaN)) test-floats)
(-> (argops/argfilter test-floats #(= 0 (Float/compare % Float/NaN)))
vec)))
(is (= [false true false false true]
(dfn/nan? test-floats)))
Expand All @@ -248,9 +248,9 @@
(dfn/eq test-floats Float/NaN)))

(is (= [1 4]
(->> (dfn/eq test-floats nan-floats)
(argops/argfilter identity)
vec)))))
(-> (dfn/eq test-floats nan-floats)
(argops/argfilter identity)
vec)))))


(deftest round-and-friends
Expand Down Expand Up @@ -331,7 +331,7 @@
;;if it operates as :object datatype. This is the default.
(let [{truevals true
falsevals false}
(argops/arggroup-by even? (range 20))]
(argops/arggroup-by (range 20) even?)]
(is (= (set truevals)
(set (filter even? (range 20)))))
(is (= (set falsevals)
Expand All @@ -342,7 +342,7 @@
;;operates and returns values in long space.
(let [{truevals true
falsevals false}
(argops/arggroup-by even? {:storage-datatype :int64} (range 20))]
(argops/arggroup-by (range 20) {:storage-datatype :int64} even?)]
(is (= (set truevals)
(set (filter even? (range 20)))))
(is (= (set falsevals)
Expand All @@ -354,7 +354,7 @@
;;if it operates as :object datatype. This is the default.
(let [{truevals true
falsevals false}
(argops/arggroup-by even? {:storage-datatype :int32} (range 20))]
(argops/arggroup-by (range 20) {:storage-datatype :int32} even?)]
;;Default group by is ordered
(is (= truevals
(filter even? (range 20))))
Expand All @@ -364,7 +364,7 @@

(let [{truevals true
falsevals false}
(argops/arggroup-by even? {:storage-datatype :int64} (range 20))]
(argops/arggroup-by (range 20) {:storage-datatype :int64} even?)]
(is (= truevals
(filter even? (range 20))))
(is (= falsevals
Expand All @@ -376,7 +376,7 @@
;;if it operates as :object datatype. This is the default.
(let [{truevals true
falsevals false}
(argops/arggroup-by even? {:storage-datatype :bitmap} (range 20))]
(argops/arggroup-by (range 20) {:storage-datatype :bitmap} even?)]
(is (= (vec truevals)
(vec (filter even? (range 20)))))
(is (= (vec falsevals)
Expand All @@ -390,7 +390,7 @@
[1 (range 5 10)]
[2 (range 10 15)]
[3 (range 15 20)]]
(vec (argops/argpartition-by #(quot (long %) 5) (range 20))))))
(vec (argops/argpartition-by (range 20) #(quot (long %) 5))))))


(deftest typed-buffer-destructure
Expand Down Expand Up @@ -463,7 +463,7 @@
(deftest argsort-generic
(let [data (dtype/make-container :java-array :int16 (shuffle
(range 10)))
indexes (argops/argsort > data)
indexes (argops/argsort data >)
new-data (dtype/indexed-buffer indexes data)]
(is (= (vec (reverse (range 10)))
(vec new-data)))))
Expand Down
4 changes: 2 additions & 2 deletions test/tech/v3/tensor/integration_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,8 @@
(dotimes [iter (count test-indexes)]
(.ndWriteLong writer iter 3 255))
(is (dfn/equals (sort test-indexes)
(argops/argfilter #(not= 0 %)
(dtt/select test-tens :all :all 3))))))
(argops/argfilter (dtt/select test-tens :all :all 3)
#(not= 0 %))))))


(deftest normal-tensor-select
Expand Down

0 comments on commit 6c24077

Please sign in to comment.