Skip to content
This repository was archived by the owner on May 6, 2021. It is now read-only.

Commit 970e442

Browse files
authored
Add functionality; backward integration (#4)
1 parent 4f83242 commit 970e442

File tree

4 files changed

+110
-37
lines changed

4 files changed

+110
-37
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "FymEnvs"
22
uuid = "d6fd7ba0-2ca9-4676-ba23-5bed9e863cfb"
33
authors = ["JinraeKim <kjl950403@gmail.com> and contributors"]
4-
version = "0.2.4"
4+
version = "0.3.0"
55

66
[deps]
77
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"

README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ systems = Dict("sys" => BaseSystem(initial_state=x0, name="3d_sys"))
4747
log_dir = "data/test"
4848
file_name = "fym.h5"
4949
logger = Logger(log_dir=log_dir, file_name=file_name)
50-
env = BaseEnv(max_t=100.00, logger=logger, name="test_env")
50+
env = BaseEnv(max_t=100.00, logger=logger, name="test_env",)
5151
systems!(env, systems) # set systems; required
5252
dyn!(env, set_dyn) # set dynamics; required
5353
step!(env, step) # set step; required
@@ -66,27 +66,27 @@ i = 0
6666
end
6767
close!(env)
6868
data = load(env.logger.path)
69-
@show env
70-
@show size(data["state"]["sys"])
69+
show(env)
70+
show(size(data["state"]["sys"]))
7171
```
7272

7373
Result:
7474

7575
```julia
7676
# time and progressbar
77-
99%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌ | ETA: 0:00:00
78-
0.744015 seconds (3.85 M allocations: 225.629 MiB, 3.30% gc time)
77+
99%|████████████████████████████████████████████████████████████████████████████████████▏| ETA: 0:00:00 3.513103 seconds (11.45 M allocations: 621.034 MiB, 3.28% gc time)
7978
# representation, i.e., show (nested env supported)
8079
name: test_env
80+
max_t: 100.0
81+
dt: 0.01
8182
+---name: 3d_sys
8283
| state: [3.7200760072278747e-44, 7.440152014455749e-44, 1.1160228021683672e-43]
8384
| dot: [-3.7200756925403154e-44, -7.440151385080631e-44, -1.1160227077620995e-43]
8485
| initial_state: [1.0, 2.0, 3.0]
8586
| state_size: (3,)
8687
| flat_index: 1:3
87-
env =
8888
# saved data
89-
size((data["state"])["sys"]) = (10000, 3)
89+
(10000, 3)
9090
```
9191

9292
For more examples, see directory `test`.

src/FymCore.jl

Lines changed: 42 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ export time_over, time
3131

3232

3333
############ Clock ############
34+
"""
35+
Clock
36+
37+
# Arguments
38+
max_t: maximum (minimum) terminal time for forward (backward) integration,
39+
respectively.
40+
"""
3441
mutable struct Clock
3542
t::Float64
3643
dt::Float64
@@ -51,8 +58,8 @@ function init!(clock::Clock, dt, ode_step_len; max_t=10.0)
5158
return clock
5259
end
5360

54-
function reset!(clock::Clock)
55-
clock.t = 0.0
61+
function reset!(clock::Clock; t=0.0)
62+
clock.t = t
5663
return clock
5764
end
5865

@@ -72,16 +79,17 @@ end
7279
"Check if the time is larger than max_t."
7380
function time_over(clock::Clock; t=nothing)
7481
if t == nothing
75-
return time(clock) >= clock.max_t
82+
return sign(clock.dt) * (time(clock) - clock.max_t) >= 0.0
7683
else
77-
return t >= clock.max_t
84+
return sign(clock.dt) * (t - clock.max_t) >= 0.0
7885
end
7986
end
8087

8188
function thist(clock::Clock)
8289
thist = clock.thist .+ time(clock)
8390
if time_over(clock, t=thist[end])
84-
index = findfirst(thist .> clock.max_t)
91+
index = findfirst(sign(clock.dt)*(thist .- clock.max_t) .> 0.0)
92+
# index = findfirst(thist .> clock.max_t)
8593
if index == nothing
8694
return thist
8795
else
@@ -106,21 +114,9 @@ end
106114

107115
function _show(sys::BaseSystem; i=0)
108116
result = []
109-
for symbol in [:name, :state, :dot, :initial_state,
110-
:state_size, :flat_index]
111-
if isdefined(sys, symbol)
112-
value = getproperty(sys, symbol)
113-
else
114-
value = "undef"
115-
end
116-
if symbol == :name
117-
space = "+---"
118-
else
119-
space = "| "
120-
end
121-
push!(result,
122-
_add_space("$(String(symbol)): $value", i, space=space))
123-
end
117+
_stack_property!(result, sys,
118+
[:name, :state, :dot, :initial_state,
119+
:state_size, :flat_index], i=i)
124120
return join(result, "\n")
125121
end
126122

@@ -169,6 +165,7 @@ mutable struct BaseEnv
169165
dyn
170166
step
171167

168+
initial_time
172169
dt::Float64
173170
clock::Clock
174171
progressbar
@@ -191,9 +188,29 @@ function _add_space(string, i; space=" "^4)
191188
return space^i * string
192189
end
193190

191+
function _stack_property!(result, sys, symbol_array; i=0, space="")
192+
for symbol in symbol_array
193+
if isdefined(sys, symbol)
194+
value = getproperty(sys, symbol)
195+
else
196+
value = "undef"
197+
end
198+
if symbol == :name
199+
space = "+---"
200+
else
201+
space = "| "
202+
end
203+
push!(result, _add_space("$(String(symbol)): $value", i, space=space))
204+
# push!(result, _add_space("$(String(symbol)): $value", 0))
205+
end
206+
end
207+
194208
function _show(env::BaseEnv; i=0)
195209
result = []
196-
push!(result, _add_space("name: $(env.name)", i, space="+---"))
210+
_stack_property!(result, env, [:name]) # from env
211+
if i == 0
212+
_stack_property!(result, env.clock, [:max_t, :dt]) # from env.clock
213+
end
197214
for system in _systems(env)
198215
if typeof(system) == BaseSystem
199216
v_str = _show(system, i=i+1)
@@ -213,15 +230,16 @@ end
213230
function init!(env::BaseEnv;
214231
systems=Dict(), dyn=nothing, step=nothing,
215232
# params=Dict(),
216-
dt=0.01, max_t=1.0, ode_step_len=1,
233+
initial_time=0.0, dt=0.01, max_t=1.0,
234+
ode_step_len=1,
217235
logger=nothing, ode_option=Dict(), solver="rk4",
218236
name=nothing,
219237
)
220238
env.name = name
221239
systems!(env, systems)
222240

223241
env.clock = Clock(dt, ode_step_len, max_t=max_t)
224-
env.dt = env.clock.dt
242+
env.initial_time = initial_time
225243

226244
env.logger = logger
227245

@@ -279,7 +297,7 @@ function reset!(env::BaseEnv)
279297
for system in _systems(env)
280298
reset!(system)
281299
end
282-
reset!(env.clock)
300+
reset!(env.clock, t=env.initial_time)
283301
end
284302

285303
function state(env::BaseEnv)

test/runtests.jl

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function test_Fym()
4646
log_dir = "data/test"
4747
file_name = "fym.h5"
4848
logger = Logger(log_dir=log_dir, file_name=file_name)
49-
env = BaseEnv(max_t=100.00, logger=logger, name="test_env")
49+
env = BaseEnv(max_t=100.00, logger=logger, name="test_env",)
5050
systems!(env, systems) # set systems; required
5151
dyn!(env, set_dyn) # set dynamics; required
5252
step!(env, step) # set step; required
@@ -65,8 +65,61 @@ function test_Fym()
6565
end
6666
close!(env)
6767
data = load(env.logger.path)
68-
@show env
69-
@show size(data["state"]["sys"])
68+
show(env)
69+
show(size(data["state"]["sys"]))
70+
end
71+
72+
function test_reverse_time()
73+
print_msg("reverse_time")
74+
function set_dyn(env, t)
75+
# corresponding to `set_dot` of the original fym
76+
# you can use any names in this package
77+
sys = env.systems["sys"]
78+
x = sys.state
79+
A = Matrix(I, 3, 3)
80+
sys.dot = -A * x
81+
end
82+
function step(env)
83+
t = time(env.clock)
84+
sys = env.systems["sys"]
85+
x = sys.state
86+
update!(env)
87+
next_obs = sys.state
88+
reward = zeros(1)
89+
done = time_over(env.clock)
90+
info = Dict(
91+
"time" => t,
92+
"state" => x,
93+
)
94+
return next_obs, reward, done, info
95+
end
96+
97+
x0 = collect(1:3)
98+
systems = Dict("sys" => BaseSystem(initial_state=x0, name="3d_sys"))
99+
log_dir = "data/test"
100+
file_name = "fym.h5"
101+
logger = Logger(log_dir=log_dir, file_name=file_name)
102+
env = BaseEnv(max_t=0.00, dt=-0.01, initial_time=1.0, logger=logger, name="test_env",)
103+
systems!(env, systems) # set systems; required
104+
dyn!(env, set_dyn) # set dynamics; required
105+
step!(env, step) # set step; required
106+
107+
reset!(env) # reset env; required before propagation
108+
obs = observe_flat(env)
109+
i = 0
110+
@time while true
111+
render(env) # not mendatory; would make simulator slow
112+
next_obs, reward, done, info = env.step()
113+
obs = next_obs
114+
i += 1
115+
if done
116+
break
117+
end
118+
end
119+
close!(env)
120+
data = load(env.logger.path)
121+
show(env)
122+
show(size(data["state"]["sys"]))
70123
end
71124

72125
function test_largescale_env()
@@ -110,7 +163,7 @@ function test_largescale_env()
110163
"sys" => BaseSystem(initial_state=x0),
111164
"sys2" => BaseSystem(initial_state=y0),
112165
)
113-
env0 = BaseEnv(systems=systems0, dyn=set_dyn0)
166+
env0 = BaseEnv(systems=systems0, dyn=set_dyn0, name="env0")
114167
systems = Dict(
115168
"sys" => BaseSystem(initial_state=x0),
116169
"sys2" => BaseSystem(initial_state=y0),
@@ -119,7 +172,7 @@ function test_largescale_env()
119172
log_dir = "data/test"
120173
file_name = "largescale.h5"
121174
logger = Logger(log_dir=log_dir, file_name=file_name)
122-
env = BaseEnv(max_t=100.00, logger=logger)
175+
env = BaseEnv(max_t=10.00, logger=logger, name="env")
123176
systems!(env, systems)
124177
dyn!(env, set_dyn)
125178
step!(env, step)
@@ -151,6 +204,7 @@ function test_largescale_env()
151204
savefig(p2, joinpath(log_dir, "state2.pdf"))
152205
savefig(env0p, joinpath(log_dir, "env0state.pdf"))
153206
savefig(env0p2, joinpath(log_dir, "env0state2.pdf"))
207+
show(env)
154208
end
155209

156210
function test_custom_env()
@@ -227,6 +281,7 @@ end
227281

228282
function test_all()
229283
test_Fym()
284+
test_reverse_time()
230285
test_largescale_env()
231286
test_custom_fym()
232287
end

0 commit comments

Comments
 (0)