/
solver.go
146 lines (128 loc) · 3.38 KB
/
solver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package pf
import (
"encoding/json"
"github.com/davidkleiven/gopf/pfutil"
"github.com/vbauerster/mpb"
"github.com/vbauerster/mpb/decor"
)
// SolverCB is the function type of a callback that can be registered on a
// Solver. Solve executes every registered callback once after each epoch,
// passing the solver itself and the epoch index (offset by Solver.StartEpoch).
type SolverCB func(s *Solver, epoch int)
// TimeStepper is a generic interface for the time stepping schemes used by a
// Solver (for example Euler and RK4).
type TimeStepper interface {
	// Step advances the model by one time step; Propagate calls it once per step.
	Step(m *Model)
	// SetFilter sets the ModalFilter used by the stepper. NOTE(review): how
	// the filter is applied is defined by each implementation — confirm there.
	SetFilter(filter ModalFilter)
	// GetTime returns the stepper's current time. Propagate reads it after
	// every step and forwards it to the model terms' OnStepFinished hooks.
	GetTime() float64
}
// FourierTransform is a type used to represent Fourier transforms.
type FourierTransform interface {
	// FFT returns the forward transform of data.
	FFT(data []complex128) []complex128
	// IFFT returns the inverse transform of data.
	IFFT(data []complex128) []complex128
	// Freq returns the frequency associated with the flat grid index i.
	// NOTE(review): presumably one component per spatial dimension — confirm
	// against the implementations.
	Freq(i int) []float64
}
// Solver is a type used to solve phase field equations. It couples a Model
// with a FourierTransform and a TimeStepper and drives the time evolution.
type Solver struct {
	// Model holds the fields and the implicit, explicit and mixed terms that
	// are evolved.
	Model *Model
	// Dt is the time step. NewSolver and SetStepper copy this value into the
	// stepper they create, so changing Dt later does not affect an existing
	// Stepper.
	Dt float64
	// FT is the Fourier transform that is also handed to the stepper.
	FT FourierTransform
	// Stepper performs the individual time steps in Propagate.
	Stepper TimeStepper
	// Callbacks are executed after every epoch in Solve.
	Callbacks []SolverCB
	// Monitors are updated with the model's bricks after every epoch in Solve.
	Monitors []Monitor
	// StartEpoch is added to the loop index before it is passed to the
	// callbacks (presumably to continue epoch numbering when resuming a run).
	StartEpoch int
}
// NewSolver initializes the model, creates an FFTW-backed Fourier transform
// for a grid of the given domainSize and returns a solver that uses explicit
// Euler stepping with time step dt.
//
// It panics if any field of the model does not hold exactly as many values as
// there are grid points in domainSize.
func NewSolver(m *Model, domainSize []int, dt float64) *Solver {
	m.Init()

	solver := &Solver{
		Model:     m,
		Dt:        dt,
		Callbacks: []SolverCB{},
		Monitors:  []Monitor{},
		FT:        pfutil.NewFFTW(domainSize),
	}
	solver.Stepper = &Euler{Dt: solver.Dt, FT: solver.FT}

	// Every field must match the grid described by domainSize, otherwise the
	// transforms would operate on inconsistent data.
	numNodes := pfutil.ProdInt(domainSize)
	for idx := range m.Fields {
		if len(m.Fields[idx].Data) == numNodes {
			continue
		}
		panic("solver: Inconsistent domain size and number of grid points")
	}
	return solver
}
// AddCallback appends a new callback function to the solver. Registered
// callbacks are executed in insertion order after every epoch of Solve.
func (s *Solver) AddCallback(cb SolverCB) {
	s.Callbacks = append(s.Callbacks, cb)
}
// Propagate evolves the equation by nsteps time steps using s.Stepper.
//
// After each step, every implicit, explicit and mixed term of the model is
// notified through OnStepFinished with the stepper's current time and the
// model's bricks. A non-positive nsteps does nothing.
func (s *Solver) Propagate(nsteps int) {
	model := s.Model
	for step := 0; step < nsteps; step++ {
		s.Stepper.Step(model)
		now := s.Stepper.GetTime()

		// Index into the slices (rather than ranging by value) so the hooks
		// run on the stored terms, not on copies.
		for k := range model.ImplicitTerms {
			model.ImplicitTerms[k].OnStepFinished(now, model.Bricks)
		}
		for k := range model.ExplicitTerms {
			model.ExplicitTerms[k].OnStepFinished(now, model.Bricks)
		}
		for k := range model.MixedTerms {
			model.MixedTerms[k].OnStepFinished(now, model.Bricks)
		}
	}
}
// SetStepper replaces the solver's time stepper with the scheme named by name,
// which must be "euler" or "rk4". The new stepper is built from the solver's
// current Dt and FT. Any other name causes a panic.
func (s *Solver) SetStepper(name string) {
	var stepper TimeStepper
	switch name {
	case "euler":
		stepper = &Euler{Dt: s.Dt, FT: s.FT}
	case "rk4":
		stepper = &RK4{Dt: s.Dt, FT: s.FT}
	default:
		panic("Unknown stepper scheme")
	}
	s.Stepper = stepper
}
// Solve solves the equation by running nepochs epochs of nsteps time steps each.
//
// After every epoch it executes each registered callback with the epoch index
// offset by s.StartEpoch, then updates every monitor with the model's bricks.
// Progress is shown with an mpb progress bar. A non-positive nepochs returns
// immediately without creating a progress bar.
func (s *Solver) Solve(nepochs int, nsteps int) {
	// A bar with a non-positive total can never complete, which would make
	// progress.Wait below block forever; there is no work to do anyway.
	if nepochs <= 0 {
		return
	}

	progress := mpb.New(mpb.WithWidth(64))
	name := "ETA: "
	bar := progress.AddBar(int64(nepochs),
		mpb.PrependDecorators(
			decor.Name(name, decor.WC{W: len(name) + 1, C: decor.DidentRight}),
			decor.OnComplete(decor.AverageETA(decor.ET_STYLE_GO, decor.WC{W: 4}), "done"),
		),
		mpb.AppendDecorators(decor.Percentage()),
	)

	for epoch := 0; epoch < nepochs; epoch++ {
		s.Propagate(nsteps)
		for _, cb := range s.Callbacks {
			cb(s, epoch+s.StartEpoch)
		}
		// Distinct index name so the epoch counter is not shadowed.
		for m := range s.Monitors {
			s.Monitors[m].Add(s.Model.Bricks)
		}
		bar.Increment()
	}

	// The bar has reached its total here, so Wait returns once the final frame
	// is rendered. It then shuts down mpb's rendering goroutine, which would
	// otherwise leak.
	progress.Wait()
}
// AddMonitor adds a new monitor to the solver. Monitors are updated in
// insertion order with the model's bricks after every epoch of Solve.
func (s *Solver) AddMonitor(m Monitor) {
	s.Monitors = append(s.Monitors, m)
}
// JSONifyMonitors returns the JSON encoding of all the solver's monitors.
// It panics if the monitors cannot be marshaled.
func (s *Solver) JSONifyMonitors() []byte {
	encoded, err := json.Marshal(s.Monitors)
	if err == nil {
		return encoded
	}
	panic(err)
}