-
Notifications
You must be signed in to change notification settings - Fork 1
/
fft3dc2r.go
65 lines (56 loc) · 1.7 KB
/
fft3dc2r.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
package cuda
import (
"fmt"
"github.com/kuchkin/mumax3-gneb/cuda/cu"
"github.com/kuchkin/mumax3-gneb/cuda/cufft"
"github.com/kuchkin/mumax3-gneb/data"
"github.com/kuchkin/mumax3-gneb/timer"
)
// fft3DC2RPlan is a 3D single-precision complex-to-real (inverse) FFT plan.
type fft3DC2RPlan struct {
	fftplan         // embedded plan wrapper; supplies the cuFFT handle used by ExecAsync
	size    [3]int // logical real-space size {Nx, Ny, Nz}, i.e. the output extent
}
// newFFT3DC2R creates a 3D single-precision complex-to-real FFT plan whose
// real output has Nx×Ny×Nz elements. The underlying cuFFT handle is bound to
// stream0, so executions are queued on that stream.
func newFFT3DC2R(Nx, Ny, Nz int) fft3DC2RPlan {
	// cuFFT takes dimensions slowest-varying first, so the order is Z, Y, X.
	plan := cufft.Plan3d(Nz, Ny, Nx, cufft.C2R)
	plan.SetStream(stream0)

	dims := [3]int{Nx, Ny, Nz}
	return fft3DC2RPlan{
		fftplan: fftplan{plan},
		size:    dims,
	}
}
// ExecAsync runs the complex-to-real transform from src into dst.
//
// Both slices hold their 3D data as flat 1D device arrays; only component 0
// of each slice is passed to cuFFT. src must contain exactly InputLenFloats()
// floats and dst exactly OutputLenFloats() floats, otherwise ExecAsync panics.
//
// The transform is normally just issued on the plan's stream. When the
// package-level Synchronous flag is set, the device is synchronized before
// and after the transform, and the elapsed time is recorded under "fft".
func (p *fft3DC2RPlan) ExecAsync(src, dst *data.Slice) {
	if Synchronous {
		Sync()
		timer.Start("fft")
	}

	if want, got := p.InputLenFloats(), src.Len(); got != want {
		panic(fmt.Errorf("fft size mismatch: expecting src len %v, got %v", want, got))
	}
	if want, got := p.OutputLenFloats(), dst.Len(); got != want {
		panic(fmt.Errorf("fft size mismatch: expecting dst len %v, got %v", want, got))
	}

	in := cu.DevicePtr(uintptr(src.DevPtr(0)))
	out := cu.DevicePtr(uintptr(dst.DevPtr(0)))
	p.handle.ExecC2R(in, out)

	if Synchronous {
		Sync()
		timer.Stop("fft")
	}
}
// InputSizeFloats returns the 3D extent, counted in float32 values, of the
// complex input array. Along X only Nx/2+1 complex numbers are stored, each
// taking two floats; Y and Z keep their full logical size.
func (p *fft3DC2RPlan) InputSizeFloats() (Nx, Ny, Nz int) {
	complexX := p.size[X]/2 + 1
	return 2 * complexX, p.size[Y], p.size[Z]
}
// OutputSizeFloats returns the 3D extent, counted in float32 values, of the
// real output array. This equals the plan's logical size.
func (p *fft3DC2RPlan) OutputSizeFloats() (Nx, Ny, Nz int) {
	s := p.size
	return s[X], s[Y], s[Z]
}
// InputLenFloats returns how many float32 values the flattened (1D) input
// array must hold: the product of the three InputSizeFloats extents.
func (p *fft3DC2RPlan) InputLenFloats() int {
	nx, ny, nz := p.InputSizeFloats()
	return prod3(nx, ny, nz)
}
// OutputLenFloats returns how many float32 values the flattened (1D) output
// array must hold: the product of the three OutputSizeFloats extents.
func (p *fft3DC2RPlan) OutputLenFloats() int {
	nx, ny, nz := p.OutputSizeFloats()
	return prod3(nx, ny, nz)
}