Skip to content

Commit

Permalink
fixes removing components that are controls
Browse files Browse the repository at this point in the history
  • Loading branch information
aarontrowbridge committed Jan 16, 2024
1 parent e95bb74 commit bb0a7c4
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
22 changes: 20 additions & 2 deletions src/methods_named_trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,11 +183,16 @@ end
Remove a component from the trajectory.
"""
function remove_component(traj::NamedTrajectory, name::Symbol)
function remove_component(traj::NamedTrajectory, name::Symbol; new_control=nothing)
@assert name traj.names
comps = NamedTuple([
(key => data) for (key, data) pairs(components(traj)) if key != name
])
if name traj.control_names
@assert !isnothing(new_control)
traj.control_names = filter!(n -> n != name, traj.control_names)
traj.control_names = (traj.control_names..., new_control)
end
return NamedTrajectory(comps, traj)
end

Expand All @@ -196,11 +201,24 @@ end
Remove a set of components from the trajectory.
"""
function remove_components(traj::NamedTrajectory, names::Vector{Symbol})
function remove_components(
traj::NamedTrajectory,
names::Vector{Symbol};
new_control_names=nothing
)
@assert all([name traj.names for name names])
comps = NamedTuple([
(key => data) for (key, data) pairs(components(traj)) if !(key names)
])
if any([name traj.control_names for name names])
@assert !isnothing(new_control_names)
if new_control_names isa Symbol
new_control_names = (new_control_names,)
end
traj.control_names = Tuple(filter!(n -> n names, [traj.control_names...]))
traj.control_names = (traj.control_names..., new_control_names...)
end

return NamedTrajectory(comps, traj)
end

Expand Down
2 changes: 1 addition & 1 deletion src/struct_named_trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function NamedTrajectory(
@assert !isempty(controls)
@assert !isnothing(timestep)
@assert timestep isa Symbol && timestep keys(comp_data) ||
timestep isa Real
timestep isa Real "timestep $(timestep)::$(typeof(timestep)) must be a symbol or real"

@assert all([k keys(comp_data) for k controls])
@assert all([k keys(comp_data) for k keys(initial)])
Expand Down

0 comments on commit bb0a7c4

Please sign in to comment.