Skip to content

Commit

Permalink
FFT-based convolutions w/ some great timing tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
cnuernber committed May 31, 2021
1 parent 44c0aa6 commit 1031999
Show file tree
Hide file tree
Showing 5 changed files with 288 additions and 31 deletions.
68 changes: 68 additions & 0 deletions java/tech/v3/datatype/Complex.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package tech.v3.datatype;


public final class Complex
{
public static double mulReal( double ar, double ai, double br, double bi) {
return ar*br - ai*bi;
}
public static double mulImg( double ar, double ai, double br, double bi) {
return ar*bi + ai*br;
}
public static double[] mul(double[] lhs, int lhsOffset, double[] rhs, int rhsOffset,
double[] result, int resOffset, int nElems) {
final int lhsOff = lhsOffset * 2;
final int rhsOff = rhsOffset * 2;
final int resOff = resOffset * 2;
for (int idx = 0; idx < nElems; ++idx ) {
int localIdx = idx*2;
result[localIdx+resOff] = mulReal(lhs[localIdx + lhsOff],
lhs[localIdx + 1 + lhsOff],
rhs[localIdx + rhsOff],
rhs[localIdx + 1 + rhsOff]);
result[localIdx+1+resOff] = mulImg(lhs[localIdx + lhsOff],
lhs[localIdx + 1 + lhsOff],
rhs[localIdx + rhsOff],
rhs[localIdx + 1 + rhsOff]);
}
return result;
}
public static double[] mul(double[] lhs, double[] rhs) {
return mul(lhs, 0, rhs, 0, new double[lhs.length], 0,
lhs.length/2);
}
public static double[] realToComplex(double[] real, int off, double[] complex,
int coff, int nElems) {
final int coffset = coff * 2;
for( int idx = 0; idx < nElems; ++idx ) {
final int localCOff = coffset + idx*2;
complex[localCOff] = real[idx+off];
complex[localCOff+1] = 0.0;
}
return complex;
}
public static double[] realToComplex(double[] real) {
return realToComplex(real, 0, new double[real.length*2], 0, real.length);
}
public static double[] realToComplex(double real, double[] complex, int coff,
int nElems) {
final int coffset = coff * 2;
for( int idx = 0; idx < nElems; ++idx ) {
final int localCOff = coffset + idx*2;
complex[localCOff] = real;
complex[localCOff+1] = 0.0;
}
return complex;
}
public static double[] complexToReal(double[] complex, int coff, double[] real,
int off, int nElems) {
final int coffset = coff * 2;
for(int idx = 0; idx < nElems; ++idx ) {
real[idx+off] = complex[coffset + idx*2];
}
return real;
}
public static double[] complexToReal(double[] complex) {
return complexToReal(complex, 0, new double[complex.length/2], 0, complex.length/2);
}
}
1 change: 1 addition & 0 deletions project.clj
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
[org.xerial.larray/larray-mmap "0.4.1"]
[org.apache.commons/commons-math3 "3.6.1"]
[org.roaringbitmap/RoaringBitmap "0.9.0"]
[com.github.wendykierp/JTransforms "3.1"]
[techascent/tech.resource "5.04"]
[techascent/tech.jna "4.05" :scope "provided"]]
:java-source-paths ["java"]
Expand Down
7 changes: 5 additions & 2 deletions src/tech/v3/datatype/clj_range.clj
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@

(extend-type Range
dtype-proto/PElemwiseDatatype
(elemwise-datatype [rng] (dtype-proto/elemwise-datatype (first rng)))
(elemwise-datatype [rng] (dtype-proto/elemwise-datatype
(if (> (.count rng) 1)
(second rng)
(first rng))))
dtype-proto/PECount
(ecount [rng] (.count rng))
dtype-proto/PClone
Expand All @@ -110,7 +113,7 @@
:float32 :float64}
(dtype-proto/elemwise-datatype rng)))
(->reader [rng]
(let [dtype (dtype-proto/elemwise-datatype (first rng))]
(let [dtype (dtype-proto/elemwise-datatype rng)]
(if (casting/integer-type? dtype)
(let [start (casting/datatype->cast-fn :unknown :int64 (first rng))
step (casting/datatype->cast-fn :unknown :int64
Expand Down
202 changes: 173 additions & 29 deletions src/tech/v3/datatype/convolve.clj
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,135 @@
(:require [tech.v3.datatype.base :as dt-base]
[tech.v3.datatype.array-buffer :as array-buffer]
[tech.v3.datatype.copy-make-container :as dt-cmc]
[tech.v3.parallel.for :as pfor])
[tech.v3.parallel.for :as pfor]
[primitive-math :as pmath])
(:import [tech.v3.datatype Convolve1D Convolve1D$Mode ArrayHelpers DoubleReader
Convolve1D$EdgeMode]
Convolve1D$EdgeMode Complex DoubleReader]
[java.util.function BiFunction]
[org.apache.commons.math3.distribution NormalDistribution]))
[java.util Arrays]
[org.apache.commons.math3.distribution NormalDistribution]
[org.jtransforms.fft DoubleFFT_1D]))


(defn- next-pow-2
^long [^long val]
(loop [retval 1]
(if (pmath/< retval val)
(recur (bit-shift-left retval 1))
retval)))

(deftype ^:private FFTWindow [^long start ^long end])


(defn- window-len
^long [^FFTWindow win]
(- (.end win) (.start win)))


(defn- window-intersect
^FFTWindow [^FFTWindow lhs ^FFTWindow rhs]
(let [win-off (max (.start lhs) (.start rhs))
win-end (min (.end lhs) (.end rhs))]
(when (pmath/> win-end win-off)
(FFTWindow. win-off win-end))))


(defn ^:no-doc convolve-fft-1d
([signal filter {:keys [mode edge-mode fft-size]
:or {mode :full
edge-mode :zero}}]
(let [filter (dt-cmc/->double-array filter)
filter-len (alength filter)
dec-filter-len (dec filter-len)
signal (dt-cmc/->double-array signal)
signal-len (count signal)
n-result (long (case mode
:full (+ signal-len dec-filter-len)
:same signal-len
:valid (inc (- signal-len filter-len))))
;;Because we pad the signal, we have a virtual coordinate space
;;for the padded signal.
[sig-virt-start sig-virt-end]
(case mode
:full [(- dec-filter-len)
(+ signal-len dec-filter-len)]
:same [(- (quot dec-filter-len 2))
(+ signal-len (quot dec-filter-len 2))]
:valid [0 signal-len])

sig-virt-start (long sig-virt-start)
sig-virt-end (long sig-virt-end)
virt-sig-len (- sig-virt-end sig-virt-start)

[sig-left-edge sig-right-edge]
(case edge-mode
:zero [0.0 0.0]
:clamp [(aget signal 0) (aget signal (dec signal-len))])

sig-left-edge (double sig-left-edge)
sig-right-edge (double sig-right-edge)
fft-win-offset (long
(case mode
:same (quot filter-len 2)
:full 0
:valid (dec filter-len)))
;;Multiply filter by 2 to respect nyquist frequency
fft-default-size (min 2048 (next-pow-2 signal-len))
fft-size (long (or fft-size (max fft-default-size
(next-pow-2 (* 2 filter-len)))))
;;Amount of each fft window is occupied by signal
fft-signal-size fft-size
;;Size of a double array that can be used in complexForward/Inverse.
fft-ary-size (* 2 fft-size)
;;Work on same size transfer for step 1.
result (double-array n-result)
;;Number of windows to execute
n-windows (quot (+ virt-sig-len (dec fft-signal-size)) fft-signal-size)
fft (DoubleFFT_1D. fft-size)
filter-fft (let [fft-input (Complex/realToComplex
filter 0
(double-array fft-ary-size) 0
filter-len)]
(.complexForward fft fft-input)
fft-input)
signal-input (double-array fft-ary-size)
fft-result (double-array fft-ary-size)
lhs-window (FFTWindow. sig-virt-start 0)
center-window (FFTWindow. 0 signal-len)
rhs-window (FFTWindow. signal-len sig-virt-end)]
(dotimes [window-idx n-windows]
(let [win-off (+ (* window-idx fft-signal-size) sig-virt-start)
win-end (min sig-virt-end (+ win-off fft-signal-size))
cur-window (FFTWindow. win-off win-end)
valid-lhs (window-intersect lhs-window cur-window)
lhs-win-len (long (if valid-lhs (window-len valid-lhs) 0))
valid-cen (window-intersect center-window cur-window)
cen-win-len (long (if valid-cen (window-len valid-cen) 0))
valid-rhs (window-intersect rhs-window cur-window)]
;;set complex parts to zero
(Arrays/fill signal-input 0.0)
;;fill left edge constant values
(when valid-lhs
(Complex/realToComplex sig-left-edge signal-input 0 (window-len valid-lhs)))
(when valid-cen
(Complex/realToComplex signal (.start valid-cen) signal-input lhs-win-len
(window-len valid-cen)))
(when valid-rhs
(Complex/realToComplex sig-right-edge signal-input (+ lhs-win-len
cen-win-len)
(window-len valid-rhs)))
(.complexForward fft signal-input)
(Complex/mul signal-input 0 filter-fft 0 fft-result 0 fft-size)
(.complexInverse fft fft-result true)
(dotimes [vidx fft-size]
(let [win-idx (- (+ vidx win-off) fft-win-offset)]
(when (and (>= win-idx 0)
(< win-idx n-result))
(ArrayHelpers/accumPlus result win-idx
(aget fft-result (* 2 vidx))))))))
(array-buffer/array-buffer result)))
([signal filter]
(convolve-fft-1d signal filter nil)))


(defn correlate1d
Expand All @@ -27,36 +151,56 @@
* `:force-serial` - For serial execution of the convolution. Unnecessary except
for profiling and comparison purposes.
* `:stepsize` - Defaults to 1, this steps across the input data in stepsize
increments.
increments. `:fft` alorithm cannot be used if stepsize is not 1.
* `:algorithm` - `:naive`, `:auto`, or `:fft`. Defaults to `:auto` which will choose
either `:naive` or `:fft` depending on input data size.
```"
([data win {:keys [mode finalizer force-serial edge-mode stepsize]
([data win {:keys [mode finalizer force-serial edge-mode stepsize algorithm]
:or {mode :full
edge-mode :zero
stepsize 1}}]
stepsize 1}
:as options}]
(let [data (dt-cmc/->double-array data)
win (dt-cmc/->double-array win)]
(-> (Convolve1D/correlate
(reify BiFunction
(apply [this n-elems applier]
(if force-serial
(.apply ^BiFunction applier 0 n-elems)
(pfor/indexed-map-reduce
n-elems
(fn [start-idx group-len]
(.apply ^BiFunction applier start-idx group-len))
dorun))))
data
win
(int stepsize)
finalizer
(case mode
:full Convolve1D$Mode/Full
:same Convolve1D$Mode/Same
:valid Convolve1D$Mode/Valid)
(case edge-mode
:zero Convolve1D$EdgeMode/Zero
:clamp Convolve1D$EdgeMode/Clamp))
(array-buffer/array-buffer))))
win (dt-cmc/->double-array win)
n-data (* (alength data) (alength win))
algorithm (if (== 1 stepsize)
(or algorithm
(if (< n-data 64000)
:naive
:fft))
:naive)]
(case algorithm
:fft
(let [win-len (alength win)
dec-win-len (dec win-len)]
(convolve-fft-1d data (reify DoubleReader
(lsize [this] win-len)
(readDouble [this idx]
(aget win (- dec-win-len idx))))
(assoc options :mode mode :edge-mode edge-mode)))
:naive
(-> (Convolve1D/correlate
(reify BiFunction
(apply [this n-elems applier]
(if force-serial
(.apply ^BiFunction applier 0 n-elems)
(pfor/indexed-map-reduce
n-elems
(fn [start-idx group-len]
(.apply ^BiFunction applier start-idx group-len))
dorun))))
data
win
(int stepsize)
finalizer
(case mode
:full Convolve1D$Mode/Full
:same Convolve1D$Mode/Same
:valid Convolve1D$Mode/Valid)
(case edge-mode
:zero Convolve1D$EdgeMode/Zero
:clamp Convolve1D$EdgeMode/Clamp))
(array-buffer/array-buffer)))))
([data win]
(correlate1d data win nil)))

Expand Down
41 changes: 41 additions & 0 deletions test/tech/v3/datatype/convolve_test.clj
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
(ns tech.v3.datatype.convolve-test
(:require [tech.v3.datatype :as dtype]
[tech.v3.datatype.convolve :as dt-conv]
[tech.v3.datatype.functional :as dfn]
[clojure.test :refer [deftest is]]))



(deftest basetest
(is (dfn/equals [0.000, 1.000, 2.500, 4.000, 1.500]
(dt-conv/convolve1d [1, 2, 3], [0, 1, 0.5])))
(is (dfn/equals [0.000, 1.000, 2.500, 4.000, 1.500]
(dt-conv/convolve1d [1, 2, 3], [0, 1, 0.5]
{:algorithm :fft})))
(is (dfn/equals [1 2.5 4]
(dt-conv/convolve1d [1, 2, 3], [0, 1, 0.5] {:mode :same})))
(is (dfn/equals [1 2.5 4]
(dt-conv/convolve1d [1, 2, 3], [0, 1, 0.5] {:mode :same
:algorithm :fft})))

(is (dfn/equals [2.5]
(dt-conv/convolve1d [1, 2, 3], [0, 1, 0.5] {:mode :valid})))
(is (dfn/equals [2.5]
(dt-conv/convolve1d [1, 2, 3], [0, 1, 0.5] {:mode :valid
:algorithm :fft})))

(let [src-data (dfn/sin (range 0 20 0.1))
modes [:same :valid :full]
edge-modes [:zero :clamp]
window-sizes (range 5 15)]
(->> (for [mode modes
edge-mode edge-modes
window-size window-sizes]
(is (dfn/equals (dt-conv/convolve1d src-data (range window-size)
{:mode mode :edge-mode edge-mode})
(dt-conv/convolve1d src-data (range window-size)
{:mode mode :edge-mode edge-mode
:algorithm :fft}))
(format "Algorithm mismatch: mode %s edge-mode %s window-size %d"
mode edge-mode window-size)))
dorun)))

0 comments on commit 1031999

Please sign in to comment.