Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mat index bug #249

Merged
merged 22 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/src/indexing_behavior.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ ComponentVector{Int64}(b = [4, 1], c = (a = 2, b = [6, 30]))
But what if our range doesn't capture a full component? We can see below that using `KeepIndex` on the first five elements returns a `ComponentVector` with those elements but only the `a` and `b` names, since the `c` component wasn't fully captured.
```jldoctest indexing-label-retain
julia> ca[KeepIndex(1:5)]
5-element ComponentVector{Int64} with axis Axis(a = 1, b = 2:3):
5-element ComponentVector{Int64} with axis Axis(a = 1, b = ViewAxis(2:3, Shaped1DAxis((2,)))):
5
4
1
Expand Down
6 changes: 3 additions & 3 deletions docs/src/quickstart.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ ComponentVector{Int64}(a = 11, b = 2, c = 3, new = 42)
Higher dimensional `ComponentArray`s can be created too, but it's a little messy at the moment. The nice thing for modeling is that dimension expansion through broadcasted operations can create higher-dimensional `ComponentArray`s automatically, so Jacobian cache arrays that are created internally with `false .* x .* x'` will be `ComponentArray`s with proper axes. Check out the [ODE with Jacobian](https://github.com/SciML/ComponentArrays.jl/blob/master/examples/ODE_jac_example.jl) example in the examples folder to see how this looks in practice.
```jldoctest quickstart
julia> x2 = x .* x'
7×7 ComponentMatrix{Float64} with axes Axis(a = 1, b = 2:4, c = ViewAxis(5:7, Axis(a = 1, b = 2:3))) × Axis(a = 1, b = 2:4, c = ViewAxis(5:7, Axis(a = 1, b = 2:3)))
7×7 ComponentMatrix{Float64} with axes Axis(a = 1, b = ViewAxis(2:4, Shaped1DAxis((3,))), c = ViewAxis(5:7, Axis(a = 1, b = ViewAxis(2:3, Shaped1DAxis((2,)))))) × Axis(a = 1, b = ViewAxis(2:4, Shaped1DAxis((3,))), c = ViewAxis(5:7, Axis(a = 1, b = ViewAxis(2:3, Shaped1DAxis((2,))))))
1.0 2.0 1.0 4.0 400.0 1.0 2.0
2.0 4.0 2.0 8.0 800.0 2.0 4.0
1.0 2.0 1.0 4.0 400.0 1.0 2.0
Expand All @@ -54,7 +54,7 @@ julia> x2 = x .* x'
2.0 4.0 2.0 8.0 800.0 2.0 4.0

julia> x2[:c,:c]
3×3 ComponentMatrix{Float64} with axes Axis(a = 1, b = 2:3) × Axis(a = 1, b = 2:3)
3×3 ComponentMatrix{Float64} with axes Axis(a = 1, b = ViewAxis(2:3, Shaped1DAxis((2,)))) × Axis(a = 1, b = ViewAxis(2:3, Shaped1DAxis((2,))))
160000.0 400.0 800.0
400.0 1.0 2.0
800.0 2.0 4.0
Expand All @@ -66,7 +66,7 @@ julia> x2[:a,:c]
ComponentVector{Float64}(a = 400.0, b = [1.0, 2.0])

julia> x2[:b,:c]
3×3 ComponentMatrix{Float64} with axes FlatAxis() × Axis(a = 1, b = 2:3)
3×3 ComponentMatrix{Float64} with axes Shaped1DAxis((3,)) × Axis(a = 1, b = ViewAxis(2:3, Shaped1DAxis((2,))))
800.0 2.0 4.0
400.0 1.0 2.0
1600.0 4.0 8.0
Expand Down
2 changes: 1 addition & 1 deletion src/ComponentArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ export fastindices # Deprecated
include("lazyarray.jl")

include("axis.jl")
export AbstractAxis, Axis, PartitionedAxis, ShapedAxis, ViewAxis, FlatAxis
export AbstractAxis, Axis, PartitionedAxis, ShapedAxis, Shaped1DAxis, ViewAxis, FlatAxis

include("componentarray.jl")
export ComponentArray, ComponentVector, ComponentMatrix, getaxes, getdata, valkeys
Expand Down
22 changes: 18 additions & 4 deletions src/axis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,22 @@ example)
"""
struct ShapedAxis{Shape} <: AbstractAxis{nothing} end
@inline ShapedAxis(Shape) = ShapedAxis{Shape}()
ShapedAxis(::Tuple{<:Int}) = FlatAxis()
# ShapedAxis(::Tuple{<:Int}) = FlatAxis()
Base.length(::ShapedAxis{Shape}) where{Shape} = prod(Shape)

struct Shaped1DAxis{Shape} <: AbstractAxis{nothing} end
ShapedAxis(shape::Tuple{<:Int}) = Shaped1DAxis{shape}()
Shaped1DAxis(shape::Tuple{<:Int}) = Shaped1DAxis{shape}()
Base.length(::Shaped1DAxis{Shape}) where {Shape} = only(Shape)

const Shape = ShapedAxis

unshape(ax) = ax
unshape(ax::ShapedAxis) = Axis(indexmap(ax))
unshape(ax::Shaped1DAxis) = Axis(indexmap(ax))

Base.size(::ShapedAxis{Shape}) where {Shape} = Shape
Base.size(::Shaped1DAxis{Shape}) where {Shape} = Shape



Expand Down Expand Up @@ -133,9 +141,9 @@ Axis(::Number) = NullAxis()
Axis(::NamedTuple{()}) = FlatAxis()
Axis(x) = FlatAxis()

const NotShapedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis} where {IdxMap}
const NotPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Shape}} where {Shape, IdxMap}
const NotShapedOrPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis} where {IdxMap}
const NotShapedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, Shaped1DAxis} where {IdxMap}
const NotPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, NullAxis, ShapedAxis{Shape}, Shaped1DAxis} where {Shape, IdxMap}
const NotShapedOrPartitionedAxis = Union{Axis{IdxMap}, FlatAxis, Shaped1DAxis} where {IdxMap}


Base.merge(axs::Vararg{Axis}) = Axis(merge(indexmap.(axs)...))
Expand All @@ -149,6 +157,10 @@ reindex(i, offset) = i .+ offset
reindex(ax::FlatAxis, _) = ax
reindex(ax::Axis, offset) = Axis(map(x->reindex(x, offset), indexmap(ax)))
reindex(ax::ViewAxis, offset) = ViewAxis(viewindex(ax) .+ offset, indexmap(ax))
function reindex(ax::ViewAxis{OldInds,IdxMap,Ax}, offset) where {OldInds,IdxMap,Ax<:Shaped1DAxis}
NewInds = viewindex(ax) .+ offset
return ViewAxis(NewInds, Ax())
end

# Get AbstractAxis index
@inline Base.getindex(::AbstractAxis, idx) = ComponentIndex(idx)
Expand All @@ -175,6 +187,7 @@ end

_maybe_view_axis(inds, ax::AbstractAxis) = ViewAxis(inds, ax)
_maybe_view_axis(inds, ::NullAxis) = inds[1]
_maybe_view_axis(inds, ax::Union{ShapedAxis,Shaped1DAxis}) = ViewAxis(inds, ax)

struct CombinedAxis{C,A} <: AbstractUnitRange{Int}
component_axis::C
Expand All @@ -188,6 +201,7 @@ _component_axis(ax) = FlatAxis()

_array_axis(ax::CombinedAxis) = ax.array_axis
_array_axis(ax) = ax
_array_axis(ax::Int) = Shaped1DAxis((ax,))

Base.first(ax::CombinedAxis) = first(_array_axis(ax))

Expand Down
3 changes: 2 additions & 1 deletion src/compat/static_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ end

_maybe_SArray(x::SubArray, ::Val{N}, ::FlatAxis) where {N} = SVector{N}(x)
_maybe_SArray(x::Base.ReshapedArray, ::Val, ::ShapedAxis{Sz}) where {Sz} = SArray{Tuple{Sz...}}(x)
_maybe_SArray(x, ::Val, ::Shaped1DAxis{Sz}) where {Sz} = SArray{Tuple{Sz...}}(x)
_maybe_SArray(x, vals...) = x

@generated function static_getproperty(ca::ComponentVector, ::Val{s}) where {s}
Expand Down Expand Up @@ -32,4 +33,4 @@ macro static_unpack(expr)
push!(out.args, :($esc_name = static_getproperty($parent_var_name, $(Val(name)))))
end
return out
end
end
9 changes: 7 additions & 2 deletions src/componentarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ ComponentArray{T}(::UndefInitializer, ax::Axes) where {T,Axes<:Tuple} =

# Entry from data array and AbstractAxis types dispatches to correct shapes and partitions
# then packs up axes into a tuple for inner constructor
ComponentArray(data, ::FlatAxis...) = data
# ComponentArray(data, ::FlatAxis...) = data
ComponentArray(data, ::Union{FlatAxis,Shaped1DAxis}...) = data
ComponentArray(data, ax::NotShapedOrPartitionedAxis...) = ComponentArray(data, ax)
ComponentArray(data, ax::NotPartitionedAxis...) = ComponentArray(maybe_reshape(data, ax...), unshape.(ax)...)
function ComponentArray(data, ax::AbstractAxis...)
Expand Down Expand Up @@ -179,6 +180,10 @@ function make_idx(data, nt::Union{NamedTuple, AbstractDict}, last_val)
)...)
return (data, ViewAxis(last_index(last_val) .+ (1:len), kvs))
end
function make_idx(data, nt::NamedTuple{(), Tuple{}}, last_val)
out = last_index(last_val) .+ (1:length(nt))
return (data, ViewAxis(out, ShapedAxis((length(nt),))))
end
function make_idx(data, pair::Pair, last_val)
data, ax = make_idx(data, pair.second, last_val)
len = recursive_length(data)
Expand Down Expand Up @@ -245,7 +250,7 @@ end
# Reshape ComponentArrays with ShapedAxis axes
maybe_reshape(data, ::NotShapedOrPartitionedAxis...) = data
function maybe_reshape(data, axs::AbstractAxis...)
shapes = filter_by_type(ShapedAxis, axs...) .|> size
shapes = filter_by_type(Union{ShapedAxis,Shaped1DAxis}, axs...) .|> size
shapes = reduce((tup, s) -> (tup..., s...), shapes)
return reshape(data, shapes)
end
Expand Down
4 changes: 3 additions & 1 deletion src/componentindex.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ struct ComponentIndex{Idx, Ax<:AbstractAxis}
ax::Ax
end
ComponentIndex(idx) = ComponentIndex(idx, FlatAxis())
ComponentIndex(idx::CartesianIndex) = ComponentIndex(idx, ShapedAxis((1,)))
ComponentIndex(idx::AbstractArray{<:Integer}) = ComponentIndex(idx, ShapedAxis(size(idx)))
ComponentIndex(idx::Int) = ComponentIndex(idx, NullAxis())
ComponentIndex(vax::ViewAxis{Inds,IdxMap,Ax}) where {Inds,IdxMap,Ax} = ComponentIndex(Inds, vax.ax)

Expand Down Expand Up @@ -44,4 +46,4 @@ function _getindex_keep(ax::AbstractAxis, sym::Symbol)
end
new_ax = reindex(new_ax, -first(idx)+1)
return ComponentIndex(idx, new_ax)
end
end
2 changes: 2 additions & 0 deletions src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Base.show(io::IO, ::PartitionedAxis{PartSz, IdxMap, Ax}) where {PartSz, IdxMap,

Base.show(io::IO, ::ShapedAxis{Shape}) where {Shape} =
print(io, "ShapedAxis($Shape)")
Base.show(io::IO, ::Shaped1DAxis{Shape}) where {Shape} =
print(io, "Shaped1DAxis($Shape)")

Base.show(io::IO, ::MIME"text/plain", ::ViewAxis{Inds, IdxMap, Ax}) where {Inds, IdxMap, Ax} =
print(io, "ViewAxis($Inds, $(Ax()))")
Expand Down
Loading
Loading