Skip to content

Commit

Permalink
Fixes #35
Browse files Browse the repository at this point in the history
  • Loading branch information
cnuernber committed Aug 18, 2021
1 parent f37d548 commit a0fa166
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 4 deletions.
83 changes: 79 additions & 4 deletions src/tech/v3/datatype/base.clj
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
[tech.v3.datatype.errors :as errors]
[tech.v3.datatype.io-sub-buffer :as io-sub-buf]
[tech.v3.datatype.casting :as casting]
[tech.v3.parallel.for :as parallel-for])
[tech.v3.parallel.for :as parallel-for]
[tech.v3.datatype.argtypes :as argtypes])
(:import [tech.v3.datatype Buffer BinaryBuffer
ObjectBuffer ElemwiseDatatype ObjectReader NDBuffer
LongReader DoubleReader]
Expand Down Expand Up @@ -297,13 +298,79 @@
(defn get-value
"Get a value from an object via conversion to a reader."
[item idx]
((->reader item) idx))
(cond
(number? item)
((->reader item) idx)
(sequential? idx)
(dtype-proto/select item idx)
:else
(throw (Exception. "Unrecognized idx type in get-value!"))))


(defn set-value!
"Set a value on an object via conversion to a writer."
"Set a value on an object via conversion to a writer. set-value can also take a tuple
in which case a select operation will be done on the tensor and value applied to the
result. See also [[set-constant!]].
Example:
```clojure
tech.v3.tensor.integration-test> (def test-tens (dtt/->tensor (->> (range 27)
(partition 3)
(partition 3))
:datatype :float64))
#'tech.v3.tensor.integration-test/test-tens
tech.v3.tensor.integration-test> test-tens
#tech.v3.tensor<float64>[3 3 3]
[[[0.000 1.000 2.000]
[3.000 4.000 5.000]
[6.000 7.000 8.000]]
[[9.000 10.00 11.00]
[12.00 13.00 14.00]
[15.00 16.00 17.00]]
[[18.00 19.00 20.00]
[21.00 22.00 23.00]
[24.00 25.00 26.00]]]
tech.v3.tensor.integration-test> (dtype/set-value! (dtype/clone test-tens) [:all :all (range 2)] 0)
#tech.v3.tensor<float64>[3 3 3]
[[[0.000 0.000 2.000]
[0.000 0.000 5.000]
[0.000 0.000 8.000]]
[[0.000 0.000 11.00]
[0.000 0.000 14.00]
[0.000 0.000 17.00]]
[[0.000 0.000 20.00]
[0.000 0.000 23.00]
[0.000 0.000 26.00]]]
tech.v3.tensor.integration-test> (def sv-tens (dtt/reshape (double-array [1 2 3]) [3 1]))
#'tech.v3.tensor.integration-test/sv-tens
tech.v3.tensor.integration-test> (dtype/set-value! (dtype/clone test-tens) [:all :all (range 2)] sv-tens)
#tech.v3.tensor<float64>[3 3 3]
[[[1.000 1.000 2.000]
[2.000 2.000 5.000]
[3.000 3.000 8.000]]
[[1.000 1.000 11.00]
[2.000 2.000 14.00]
[3.000 3.000 17.00]]
[[1.000 1.000 20.00]
[2.000 2.000 23.00]
[3.000 3.000 26.00]]]
```"
[item idx value]
((->writer item) idx value))
(cond
(number? idx)
((->writer item) idx value)
(sequential? idx)
(let [sub-tens (dtype-proto/select item idx)]
(if (= :scalar (argtypes/arg-type value))
(dtype-proto/set-constant! sub-tens 0 (ecount sub-tens) value)
(let [dst-shape (shape sub-tens)
value (dtype-proto/broadcast value dst-shape)]
(dtype-proto/mset! sub-tens nil value))))
:else
(throw (Exception. "Unrecognized idx type in set-value!")))
item)


(defn- random-access->io
Expand Down Expand Up @@ -941,6 +1008,14 @@ user> (dtt/transpose tensor [1 2 0])
"Set value(s) on an ND object. If fewer indexes are provided than dimension then a
tensor assignment is done and value is expected to be the same shape as the subrect
of the tensor as indexed by the provided dimensions. Returns t."
([t value]
(check-ns 'tech.v3.tensor)
(if (= :scalar (argtypes/arg-type value))
(dtype-proto/set-constant! t 0 (ecount t) value)
(let [t-shp (shape t)
value (dtype-proto/broadcast value t-shp)]
(dtype-proto/mset! t nil value)))
t)
([t x value]
(check-ns 'tech.v3.tensor)
(if (instance? NDBuffer t)
Expand Down
32 changes: 32 additions & 0 deletions test/tech/v3/tensor/integration_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,35 @@
into-array)]
(println "nested array of tensor shape" (dtype/shape d10000))
(time (dtt/->tensor d10000))))


(deftest set-value-mset-constant
(let [test-tens (dtt/->tensor (->> (range 27)
(partition 3)
(partition 3))
:datatype :float64)
sv-tens (dtt/reshape (double-array [1 2 3]) [3 1])
tt1 (dtype/set-value! (dtype/clone test-tens) [:all :all (range 2)] 0)
;;broadcast test
tt2 (dtype/set-value! (dtype/clone test-tens) [:all :all (range 2)] sv-tens)]

(is (= (dtt/->jvm tt1)
[[[0.000 0.000 2.000]
[0.000 0.000 5.000]
[0.000 0.000 8.000]]
[[0.000 0.000 11.00]
[0.000 0.000 14.00]
[0.000 0.000 17.00]]
[[0.000 0.000 20.00]
[0.000 0.000 23.00]
[0.000 0.000 26.00]]]))
(is (= (dtt/->jvm tt2)
[[[1.000 1.000 2.000]
[2.000 2.000 5.000]
[3.000 3.000 8.000]]
[[1.000 1.000 11.00]
[2.000 2.000 14.00]
[3.000 3.000 17.00]]
[[1.000 1.000 20.00]
[2.000 2.000 23.00]
[3.000 3.000 26.00]]]))))

0 comments on commit a0fa166

Please sign in to comment.