Permalink
Switch branches/tags
Find file Copy path
Fetching contributors…
Cannot retrieve contributors at this time
220 lines (176 sloc) 7.34 KB
{-# LANGUAGE TypeOperators, PatternGuards, RankNTypes, ScopedTypeVariables, BangPatterns, FlexibleContexts #-}
{-# OPTIONS -fno-warn-incomplete-patterns #-}
-- | Fast computation of Discrete Fourier Transforms using the Cooley-Tuckey algorithm.
-- Time complexity is O(n log n) in the size of the input.
--
-- This uses a naive divide-and-conquer algorithm, the absolute performance is about
-- 50x slower than FFTW in estimate mode.
--
module Data.Array.Repa.Algorithms.FFT
( Mode(..)
, isPowerOfTwo
, fft3dP
, fft2dP
, fft1dP)
where
import Data.Array.Repa.Algorithms.Complex
import Data.Array.Repa as R
import Data.Array.Repa.Eval as R
import Data.Array.Repa.Unsafe as R
import Data.Bits ((.&.))
import Prelude as P
data Mode
= Forward
| Reverse
| Inverse
deriving (Show, Eq)
signOfMode :: Mode -> Double
signOfMode mode
= case mode of
Forward -> (-1)
Reverse -> 1
Inverse -> 1
{-# INLINE signOfMode #-}
-- | Check if an `Int` is a power of two. Assumes `n` is a natural number.
-- The implementation can be found in Henry S. Warren, Jr.'s book
-- Hacker's delight, Chapter 2.
isPowerOfTwo :: Int -> Bool
isPowerOfTwo 0 = True
isPowerOfTwo 1 = False
isPowerOfTwo n = (n .&. (n-1)) == 0
{-# INLINE isPowerOfTwo #-}
-- 3D Transform -----------------------------------------------------------------------------------
-- | Compute the DFT of a 3d array. Array dimensions must be powers of two else `error`.
fft3dP :: (Source r Complex, Monad m)
=> Mode
-> Array r DIM3 Complex
-> m (Array U DIM3 Complex)
fft3dP mode arr
= let _ :. depth :. height :. width = extent arr
!sign = signOfMode mode
!scale = fromIntegral (depth * width * height)
in if not (isPowerOfTwo depth && isPowerOfTwo height && isPowerOfTwo width)
then error $ unlines
[ "Data.Array.Repa.Algorithms.FFT: fft3d"
, " Array dimensions must be powers of two,"
, " but the provided array is "
P.++ show height P.++ "x" P.++ show width P.++ "x" P.++ show depth ]
else arr `deepSeqArray`
case mode of
Forward -> now $ fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr
Reverse -> now $ fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr
Inverse -> computeP
$ R.map (/ scale)
$ fftTrans3d sign $ fftTrans3d sign $ fftTrans3d sign arr
{-# INLINE fft3dP #-}
fftTrans3d
:: Source r Complex
=> Double
-> Array r DIM3 Complex
-> Array U DIM3 Complex
fftTrans3d sign arr
= let (sh :. len) = extent arr
in suspendedComputeP $ rotate3d $ fft sign sh len arr
{-# INLINE fftTrans3d #-}
rotate3d
:: Source r Complex
=> Array r DIM3 Complex -> Array D DIM3 Complex
rotate3d arr
= backpermute (sh :. m :. k :. l) f arr
where (sh :. k :. l :. m) = extent arr
f (sh' :. m' :. k' :. l') = sh' :. k' :. l' :. m'
{-# INLINE rotate3d #-}
-- Matrix Transform -------------------------------------------------------------------------------
-- | Compute the DFT of a matrix. Array dimensions must be powers of two else `error`.
fft2dP :: (Source r Complex, Monad m)
=> Mode
-> Array r DIM2 Complex
-> m (Array U DIM2 Complex)
fft2dP mode arr
= let _ :. height :. width = extent arr
sign = signOfMode mode
scale = fromIntegral (width * height)
in if not (isPowerOfTwo height && isPowerOfTwo width)
then error $ unlines
[ "Data.Array.Repa.Algorithms.FFT: fft2d"
, " Array dimensions must be powers of two,"
, " but the provided array is " P.++ show height P.++ "x" P.++ show width ]
else arr `deepSeqArray`
case mode of
Forward -> now $ fftTrans2d sign $ fftTrans2d sign arr
Reverse -> now $ fftTrans2d sign $ fftTrans2d sign arr
Inverse -> computeP $ R.map (/ scale) $ fftTrans2d sign $ fftTrans2d sign arr
{-# INLINE fft2dP #-}
fftTrans2d
:: Source r Complex
=> Double
-> Array r DIM2 Complex
-> Array U DIM2 Complex
fftTrans2d sign arr
= let (sh :. len) = extent arr
in suspendedComputeP $ transpose $ fft sign sh len arr
{-# INLINE fftTrans2d #-}
-- Vector Transform -------------------------------------------------------------------------------
-- | Compute the DFT of a vector. Array dimensions must be powers of two else `error`.
fft1dP :: (Source r Complex, Monad m)
=> Mode
-> Array r DIM1 Complex
-> m (Array U DIM1 Complex)
fft1dP mode arr
= let _ :. len = extent arr
sign = signOfMode mode
scale = fromIntegral len
in if not $ isPowerOfTwo len
then error $ unlines
[ "Data.Array.Repa.Algorithms.FFT: fft1d"
, " Array dimensions must be powers of two, "
, " but the provided array is " P.++ show len ]
else arr `deepSeqArray`
case mode of
Forward -> now $ fftTrans1d sign arr
Reverse -> now $ fftTrans1d sign arr
Inverse -> computeP $ R.map (/ scale) $ fftTrans1d sign arr
{-# INLINE fft1dP #-}
fftTrans1d
:: Source r Complex
=> Double
-> Array r DIM1 Complex
-> Array U DIM1 Complex
fftTrans1d sign arr
= let (sh :. len) = extent arr
in fft sign sh len arr
{-# INLINE fftTrans1d #-}
-- Rank Generalised Worker ------------------------------------------------------------------------
fft :: (Shape sh, Source r Complex)
=> Double -> sh -> Int
-> Array r (sh :. Int) Complex
-> Array U (sh :. Int) Complex
fft !sign !sh !lenVec !vec
= go lenVec 0 1
where go !len !offset !stride
| len == 2
= suspendedComputeP $ fromFunction (sh :. 2) swivel
| otherwise
= combine len
(go (len `div` 2) offset (stride * 2))
(go (len `div` 2) (offset + stride) (stride * 2))
where swivel (sh' :. ix)
= case ix of
0 -> (vec `unsafeIndex` (sh' :. offset)) + (vec `unsafeIndex` (sh' :. (offset + stride)))
1 -> (vec `unsafeIndex` (sh' :. offset)) - (vec `unsafeIndex` (sh' :. (offset + stride)))
{-# INLINE combine #-}
combine !len' evens odds
= evens `deepSeqArray` odds `deepSeqArray`
let odds' = unsafeTraverse odds id (\get ix@(_ :. k) -> twiddle sign k len' * get ix)
in suspendedComputeP $ (evens +^ odds') R.++ (evens -^ odds')
{-# INLINE fft #-}
-- Compute a twiddle factor.
twiddle :: Double
-> Int -- index
-> Int -- length
-> Complex
twiddle sign k' n'
= (cos (2 * pi * k / n), sign * sin (2 * pi * k / n))
where k = fromIntegral k'
n = fromIntegral n'
{-# INLINE twiddle #-}