Skip to content

Commit

Permalink
Added simple linear regression functionality.
Browse files Browse the repository at this point in the history
  • Loading branch information
cnuernber committed Jul 5, 2021
1 parent 678b099 commit bf8d175
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 5 deletions.
42 changes: 41 additions & 1 deletion src/tech/v3/datatype/functional.clj
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@
UnaryPredicate BinaryPredicate
PrimitiveList ArrayHelpers]
[org.roaringbitmap RoaringBitmap]
[java.util List])
[java.util List]
[org.apache.commons.math3.stat.regression SimpleRegression])
(:refer-clojure :exclude [+ - / *
<= < >= >
identity
Expand Down Expand Up @@ -462,3 +463,42 @@
(cumop options (binary-op/builtin-ops :tech.numerics/*) data))
([data]
(cumprod nil data)))


(defn linear-regressor
"Create a simple linear regressor. Returns a function that given a (double) 'x'
predicts a (double) 'y'. The function has metadata that contains the regressor and
some regressor info, notably slope and intercept.
Example:
```clojure
tech.v3.datatype.functional> (def regressor (linear-regressor [1 2 3] [4 5 6]))
#'tech.v3.datatype.functional/regressor
tech.v3.datatype.functional> (regressor 1)
4.0
tech.v3.datatype.functional> (regressor 2)
5.0
tech.v3.datatype.functional> (meta regressor)
{:regressor
#object[org.apache.commons.math3.stat.regression.SimpleRegression 0x52091e82 \"org.apache.commons.math3.stat.regression.SimpleRegression@52091e82\"],
:intercept 3.0,
:slope 1.0,
:mean-squared-error 0.0}
```"
[x y]
(let [reg (SimpleRegression.)
x (dtype-base/->reader x :float64)
y (dtype-base/->reader y :float64)]
(errors/when-not-errorf
(== (.lsize x) (.lsize y))
"x length (%d) doesn't match y length (%d)"
(.lsize x) (.lsize y))
(dotimes [idx (.size x)]
(.addData reg (.readDouble x idx) (.readDouble y idx)))
(with-meta
#(.predict reg (double %))
{:regressor reg
:intercept (.getIntercept reg)
:slope (.getSlope reg)
:mean-squared-error (.getMeanSquareError reg)})))
6 changes: 3 additions & 3 deletions src/tech/v3/datatype/gradient.clj
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,18 @@ user> (dt-grad/diff1d (dt-grad/diff1d [1 2 4 7 0]))
(.readObject data (Math/round (* multiple idx))))))))
(^Buffer [data n-elems window-fn]
(let [data-size (dt-base/ecount data)
new-n (min n-elems data-size)
n-elems (min n-elems data-size)
window-size (/ (double data-size) (double n-elems))]
(->
(if (casting/numeric-type? (dt-base/elemwise-datatype data))
(reify DoubleReader
(lsize [rdr] new-n)
(lsize [rdr] n-elems)
(readDouble [rdr idx]
(let [start-idx (Math/round (* idx window-size))
end-idx (min data-size (+ start-idx (Math/round window-size)))]
(double (window-fn (dt-base/sub-buffer data start-idx (- end-idx start-idx)))))))
(reify ObjectReader
(lsize [rdr] new-n)
(lsize [rdr] n-elems)
(readObject [rdr idx]
(let [start-idx (Math/round (* idx window-size))
end-idx (min data-size (+ start-idx (Math/round window-size)))]
Expand Down
2 changes: 1 addition & 1 deletion src/tech/v3/datatype/wavelet.clj
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,6 @@
'[tech.viz.pyplot :as pyplot])

(-> (pyplot/figure {:figsize [6 4.5]})
(pyplot/plot (range 100) (ricker 100 4))
(pyplot/plot (range (* 5 240)) (ricker (* 5 240) (/ (* 5 240) 5)))
(pyplot/show))
)

0 comments on commit bf8d175

Please sign in to comment.