Skip to content

Commit

Permalink
Merge pull request #23 from aarontrowbridge/dev-aaron
Browse files Browse the repository at this point in the history
Dev aaron
  • Loading branch information
aarontrowbridge committed Jul 14, 2023
2 parents 1ec6604 + b813b18 commit 57f6e60
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 11 deletions.
17 changes: 15 additions & 2 deletions src/methods_named_trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@ using ..StructNamedTrajectory
using ..StructKnotPoint


function StructKnotPoint.KnotPoint(
Z::NamedTrajectory,
t::Int
)
@assert 1 t Z.T
timestep = timesteps(Z)[t]
return KnotPoint(t, Z.data[:, t], timestep, Z.components, Z.names, Z.control_names)
end





"""
copy(::NamedTrajectory)
Expand Down Expand Up @@ -284,14 +297,14 @@ function Base.:*(traj::NamedTrajectory, α::Float64)
end

function Base.:+(traj1::NamedTrajectory, traj2::NamedTrajectory)
@assert sort([traj1.names...]) == sort([traj2.names...])
@assert traj1.names == traj2.names
@assert traj1.dim == traj2.dim
@assert traj1.T == traj2.T
return NamedTrajectory(traj1.datavec + traj2.datavec, traj1)
end

function Base.:-(traj1::NamedTrajectory, traj2::NamedTrajectory)
@assert sort([traj1.names...]) == sort([traj2.names...])
@assert traj1.names == traj2.names
@assert traj1.dim == traj2.dim
@assert traj1.T == traj2.T
return NamedTrajectory(traj1.datavec - traj2.datavec, traj1)
Expand Down
10 changes: 1 addition & 9 deletions src/struct_knot_point.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,12 @@ using ..StructNamedTrajectory
struct KnotPoint
t::Int
data::AbstractVector{Float64}
timestep::Float64
components::NamedTuple{
cnames, <:Tuple{Vararg{AbstractVector{Int}}}
} where cnames
names::Tuple{Vararg{Symbol}}
control_names::Tuple{Vararg{Symbol}}
end

function KnotPoint(
Z::NamedTrajectory,
t::Int
)
@assert 1 t Z.T
data = view(Z.data, :, t)
return KnotPoint(t, data, Z.components, Z.names, Z.control_names)
end

end
1 change: 1 addition & 0 deletions src/struct_named_trajectory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -351,4 +351,5 @@ function NamedTrajectory(

end


end
106 changes: 106 additions & 0 deletions test/test_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,110 @@ test: methods_named_trajectories.jl
@test vec(free_time_traj.b) vec(data)
@test name free_time_traj.names
@test name free_time_traj.control_names

# testing removing control component

name = :a

# case: fixed time

fixed_time_traj = remove_component(fixed_time_traj, name)
@test name fixed_time_traj.names
@test name fixed_time_traj.control_names

# case: free time

free_time_traj = remove_component(free_time_traj, name)
@test name free_time_traj.names
@test name free_time_traj.control_names

# testing removing state components

names = [:z, :y]

# case: fixed time

fixed_time_traj = remove_components(fixed_time_traj, names)
@test all(name fixed_time_traj.names for name in names)

# case: free time

free_time_traj = remove_components(free_time_traj, names)
@test all(name free_time_traj.names for name in names)

# testing updating traj data

name = :x
data = rand(3, T)

# case: fixed time

update!(fixed_time_traj, name, data)
@test fixed_time_traj.x == data

# case: free time

update!(free_time_traj, name, data)
@test free_time_traj.x == data

# testing returning times

# case: free time

@test times(free_time_traj) [0.0, cumsum(vec(free_time_traj.Δt))[1:end-1]...]

# case: fixed time

@test times(fixed_time_traj) 0.1 .* [0:T-1...]


# test get size

@test size(fixed_time_traj) == (dim = sum(fixed_time_traj.dims[fixed_time_traj.names]), T = T)
@test size(free_time_traj) == (dim = sum(free_time_traj.dims[free_time_traj.names]), T = T)


# ---------------------------------------------------------
# knot point methods
# ---------------------------------------------------------


# test getindex
# ---------------------------------------------------------
# freetime
@test free_time_traj[1] isa KnotPoint
@test free_time_traj[1].x == free_time_traj.x[:, 1]
@test free_time_traj[end] isa KnotPoint
@test free_time_traj[end].x == free_time_traj.x[:, end]
@test free_time_traj[:x] == free_time_traj.x
@test free_time_traj.timestep isa Symbol

# fixed time
@test fixed_time_traj[1] isa KnotPoint
@test fixed_time_traj[1].x == fixed_time_traj.x[:, 1]
@test fixed_time_traj[end] isa KnotPoint
@test fixed_time_traj[end].x == fixed_time_traj.x[:, end]
@test fixed_time_traj[:x] == fixed_time_traj.x
@test fixed_time_traj.timestep isa Float64



# ---------------------------------------------------------
# algebraic methods
# ---------------------------------------------------------

free_time_traj2 = copy(free_time_traj)
fixed_time_traj2 = copy(fixed_time_traj)

@test (free_time_traj + free_time_traj2).x == free_time_traj.x + free_time_traj2.x
@test (fixed_time_traj + fixed_time_traj2).x == fixed_time_traj.x + fixed_time_traj2.x

@test (free_time_traj - free_time_traj2).x == free_time_traj.x - free_time_traj2.x
@test (fixed_time_traj - fixed_time_traj2).x == fixed_time_traj.x - fixed_time_traj2.x

@test (2.0 * free_time_traj).x == (free_time_traj * 2.0).x == free_time_traj.x * 2.0
@test (2.0 * fixed_time_traj).x == (fixed_time_traj * 2.0).x == fixed_time_traj.x * 2.0



end

0 comments on commit 57f6e60

Please sign in to comment.