From bb0a7c4f4d7c490da3b3b0fb787ae13fffef89c4 Mon Sep 17 00:00:00 2001 From: Aaron Trowbridge Date: Mon, 15 Jan 2024 21:37:14 -0500 Subject: [PATCH] fixes removing components that are controls --- src/methods_named_trajectory.jl | 22 ++++++++++++++++++++-- src/struct_named_trajectory.jl | 2 +- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/methods_named_trajectory.jl b/src/methods_named_trajectory.jl index 96b73c1..848fad9 100644 --- a/src/methods_named_trajectory.jl +++ b/src/methods_named_trajectory.jl @@ -183,11 +183,16 @@ end Remove a component from the trajectory. """ -function remove_component(traj::NamedTrajectory, name::Symbol) +function remove_component(traj::NamedTrajectory, name::Symbol; new_control=nothing) @assert name ∈ traj.names comps = NamedTuple([ (key => data) for (key, data) ∈ pairs(components(traj)) if key != name ]) + if name ∈ traj.control_names + @assert !isnothing(new_control) + traj.control_names = filter!(n -> n != name, traj.control_names) + traj.control_names = (traj.control_names..., new_control) + end return NamedTrajectory(comps, traj) end @@ -196,11 +201,24 @@ end Remove a set of components from the trajectory. """ -function remove_components(traj::NamedTrajectory, names::Vector{Symbol}) +function remove_components( + traj::NamedTrajectory, + names::Vector{Symbol}; + new_control_names=nothing +) @assert all([name ∈ traj.names for name ∈ names]) comps = NamedTuple([ (key => data) for (key, data) ∈ pairs(components(traj)) if !(key ∈ names) ]) + if any([name ∈ traj.control_names for name ∈ names]) + @assert !isnothing(new_control_names) + if new_control_names isa Symbol + new_control_names = (new_control_names,) + end + traj.control_names = Tuple(filter!(n -> n ∉ names, [traj.control_names...])) + traj.control_names = (traj.control_names..., new_control_names...) + end + return NamedTrajectory(comps, traj) end diff --git a/src/struct_named_trajectory.jl b/src/struct_named_trajectory.jl index 08175c0..0aea5d6 100644 --- a/src/struct_named_trajectory.jl +++ b/src/struct_named_trajectory.jl @@ -41,7 +41,7 @@ function NamedTrajectory( @assert !isempty(controls) @assert !isnothing(timestep) @assert timestep isa Symbol && timestep ∈ keys(comp_data) || - timestep isa Real + timestep isa Real "timestep $(timestep)::$(typeof(timestep)) must be a symbol or real" @assert all([k ∈ keys(comp_data) for k ∈ controls]) @assert all([k ∈ keys(comp_data) for k ∈ keys(initial)])