Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from invenia/ox/create
Create
- Loading branch information
Showing
12 changed files
with
554 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,14 @@ | ||
name = "NamedDims" | ||
uuid = "356022a1-0364-5f58-8944-0da4b18d706f" | ||
authors = ["Lyndon White <lyndon.white@invenialabs.co.uk>"] | ||
authors = ["Invenia Technical Computing Corporation"] | ||
version = "0.1.0" | ||
|
||
[deps] | ||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" | ||
|
||
[extras] | ||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" | ||
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | ||
|
||
[targets] | ||
test = ["Test"] | ||
test = ["Test", "SparseArrays"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,12 @@ | ||
module NamedDims | ||
using Base: @propagate_inbounds | ||
using Statistics | ||
|
||
greet() = print("Hello World!") | ||
export NamedDimsArray, dim | ||
|
||
include("name_core.jl") | ||
include("wrapper_array.jl") | ||
include("functions.jl") | ||
include("functions_dims.jl") | ||
|
||
end # module |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# This file is for functions that just need simple standard overloading. | ||
|
||
## Helpers: | ||
|
||
# Wrap `reduced_data` in a `NamedDimsArray` that carries the same dimension names as
# `original_nda`. `reduction_dims` is not read here; it exists so that other methods
# can dispatch on it.
function nameddimsarray_result(original_nda, reduced_data, reduction_dims)
    return NamedDimsArray{names(original_nda)}(reduced_data)
end
|
||
# Reducing over `:` (every dimension) gives a scalar, so there is nothing to attach
# names to: the reduced value is handed back as-is.
nameddimsarray_result(original_nda, reduced_data, reduction_dims::Colon) = reduced_data
|
||
|
||
################################################################################### | ||
# Overloads | ||
|
||
# 1 Arg: functions called as `fun(array; dims, kwargs...)`.
# Each generated method translates `dims` (names or numbers) into whatever `dim` returns
# for this array, calls the original function on the parent array, and passes the
# output through `nameddimsarray_result`.
for (mod, funs) in (
        :Base => (
            :sum, :prod, :count, :maximum, :minimum, :extrema, :cumsum, :cumprod,
            :sort, :sort!,
        ),
        :Statistics => (:mean, :std, :var, :median, :cov, :cor),
    ), fun in funs

    @eval function $mod.$fun(a::NamedDimsArray; dims=:, kwargs...)
        positions = dim(a, dims)
        reduced = $mod.$fun(parent(a); dims=positions, kwargs...)
        return nameddimsarray_result(a, reduced, positions)
    end
end
|
||
# 1 arg before: functions called as `fun(f, array; dims, kwargs...)`.
# `dims` is translated by `dim`, the call is forwarded to the parent array, and the
# output goes through `nameddimsarray_result`.
function Base.mapslices(f, a::NamedDimsArray; dims=:, kwargs...)
    positions = dim(a, dims)
    sliced = Base.mapslices(f, parent(a); dims=positions, kwargs...)
    return nameddimsarray_result(a, sliced, positions)
end
|
||
# 2 arg before: functions called as `fun(f1, f2, array; dims, kwargs...)`.
# `dims` is translated by `dim`, the call is forwarded to the parent array, and the
# output goes through `nameddimsarray_result`.
function Base.mapreduce(f1, f2, a::NamedDimsArray; dims=:, kwargs...)
    positions = dim(a, dims)
    reduced = Base.mapreduce(f1, f2, parent(a); dims=positions, kwargs...)
    return nameddimsarray_result(a, reduced, positions)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# This file is for functions that explicitly mess with the dimensions of a NamedDimsArray | ||
|
||
# Drop the dimensions given by `dims` (resolved through `dim`, so names are accepted)
# from the parent array, and name the result with the dimension names that remain.
function Base.dropdims(nda::NamedDimsArray; dims)
    dropped = dim(nda, dims)
    # Drop on the parent first, so an invalid `dims` fails there before names are touched
    squeezed = dropdims(parent(nda); dims=dropped)
    kept_names = remaining_dimnames_after_dropping(names(nda), dropped)
    return NamedDimsArray{kept_names}(squeezed)
end
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
"""
    dim(dimnames, [name])

Map dimension names to dimension numbers.

`dimnames` is a tuple of names (symbols), one per dimension, in order.

- Called with just the tuple, return a named tuple mapping each name to its
  dimension, e.g. `dim((:a, :b)) == (a=1, b=2)`.
- Called with a `name::Symbol`, return the dimension for that name,
  e.g. `dim((:a, :b), :b) == 2`. If the name is not found, `0` is returned.
- Called with a collection of names, return the result of mapping `dim` over it.
- Called with an `Integer` or `:`, return that argument unchanged.
"""
function dim(dimnames::Tuple)
    # Note: This code is runnable at compile time if input is a constant
    # If modified, make sure to recheck that it still can run at compile time
    # e.g. via `@code_llvm (()->dim((:a, :b)))()` which should be very short
    n = length(dimnames)
    return NamedTuple{dimnames, NTuple{n, Int}}(1:n)
end

function dim(dimnames::Tuple, name::Symbol)
    # Note: This code is runnable at compile time if inputs are constants
    # If modified, make sure to recheck that it still can run at compile time
    # e.g. via `@code_llvm (()->dim((:a, :b), :a))()` which should just say `return 1`

    # Start from `(name = 0,)`. Merging the real mapping over it replaces the 0 when
    # `name` is a known dimension, and `name` stays the first key either way.
    fallback = NamedTuple{(name,), Tuple{Int}}((0,))
    return first(merge(fallback, dim(dimnames)))
end

function dim(dimnames::Tuple, names)
    # Handles collections such as `(:x, :y)` or `[:x, :y]`, and, through the
    # integer fallback below, `(1, 2)` or `1:5`.
    return map(names) do name
        dim(dimnames, name)
    end
end

# Fallback that lets dimensions be referred to by number, so code written for plain
# `AbstractArray`s (e.g. `sum(xs; dims=2)`) keeps working without changes.
# `:` is the default for most methods that take `dims`.
dim(dimnames::Tuple, d::Union{Integer, Colon}) = d
|
||
|
||
"""
    default_inds(dimnames::Tuple)

Return the default value for every index when indexing with the given `dimnames`:
a named tuple mapping each name to `:`, i.e. a full slice on everything.
"""
function default_inds(dimnames::NTuple{N}) where N
    # Note: This code is runnable at compile time if input is a constant
    # If modified, make sure to recheck that it still can run at compile time
    return NamedTuple{dimnames, NTuple{N, Colon}}(ntuple(_ -> Colon(), N))
end
|
||
|
||
"""
    order_named_inds(dimnames::Tuple; named_inds...)

Return the values of `named_inds` as a tuple, ordered as the names appear in
`dimnames`. Any name in `dimnames` that is missing from `named_inds` gets the value `:`.

Throws a `DimensionMismatch` if `named_inds` contains a name that does not occur in
`dimnames`.
"""
function order_named_inds(dimnames::Tuple; named_inds...)
    # Note: This code is runnable at compile time if input is a constant
    # If modified, make sure to recheck that it still can run at compile time

    # Start from "slice everything" and overwrite with the indices that were given.
    merged = merge(default_inds(dimnames), named_inds)

    # An unknown name adds an extra key during the merge, which shows up as a
    # length mismatch.
    length(merged) == length(dimnames) ||
        throw(DimensionMismatch("Expected $(dimnames), got $(keys(named_inds))"))

    return Tuple(merged)
end
|
||
# Decide, from the *type* of one index, whether the dimension it indexes is kept.
# Integer (scalar) indices drop their dimension; every other index type keeps it.
# These methods must be defined BEFORE the generated function below: a generator
# runs in the world age of its own definition, so it can only call methods that
# already existed at that point.
keep_dim_ind_type(::Type{<:Integer}) = false
keep_dim_ind_type(::Any) = true

"""
    remaining_dimnames_from_indexing(dimnames::Tuple, inds::Tuple)

Given a tuple of dimension names and a tuple of index expressions,
e.g. `(1, :, 1:3, [true, false])`, return the names of the dimensions that are
not dropped by the indexing. Dimensions indexed with scalars (`Integer`s) are dropped.
"""
@generated function remaining_dimnames_from_indexing(dimnames::Tuple, inds)
    # Note: This allocates once, and it shouldn't have to
    # See: `@btime (()->remaining_dimnames_from_indexing((:a, :b, :c), (:,390,:)))()`
    # this is because returning tuple of symbols allocates.
    # See: https://discourse.julialang.org/t/zero-allocation-tuple-subsetting/23122/8
    # In general this allocation should be optimised out anyway, when not benchmarking
    # just this.

    # Inside the generator `inds` is the tuple *type*, so its parameters are the
    # types of the individual indices.
    ind_types = inds.parameters
    kept_dims = findall(keep_dim_ind_type, ind_types)
    keep_names = [:(getfield(dimnames, $ii)) for ii in kept_dims]
    return Expr(:tuple, keep_names...)
end
|
||
|
||
"""
    remaining_dimnames_after_dropping(dimnames::Tuple, dropped_dims)

Return the names in `dimnames` that are left after the dimensions in `dropped_dims`
are dropped. `dropped_dims` is expressed as numbers: either a single dimension or a
collection of dimensions.
"""
remaining_dimnames_after_dropping(dimnames::Tuple, dropped_dim::Integer) =
    remaining_dimnames_after_dropping(dimnames, (dropped_dim,))

function remaining_dimnames_after_dropping(dimnames::Tuple, dropped_dims)
    # Note: This allocates once, and it shouldn't have to. Reason is same as for
    # remaining_dimnames_from_indexing. I.e. returning tuple of symbols allocates.
    # see `@btime remaining_dimnames_after_dropping((:a, :b, :c, :d, :e), (1, 2))`
    dropped = identity_namedtuple(map(d -> dimnames[d], dropped_dims))
    everything = identity_namedtuple(dimnames)

    # `merge` keeps the keys of its first argument in front, so the names to remove
    # occupy the first `offset` slots and the names to keep come after them.
    reordered = merge(dropped, everything)
    offset = length(dropped)
    return ntuple(ii -> reordered[ii + offset], length(everything) - offset)
end
|
||
# Build a named tuple in which every symbol of `tup` is both the key and the value,
# e.g. `(:a, :b)` becomes `(a = :a, b = :b)`.
identity_namedtuple(tup::NTuple{N, Symbol}) where N =
    NamedTuple{tup, NTuple{N, Symbol}}(tup)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
"""
    NamedDimsArray{L,T,N,A}(data)

A `NamedDimsArray` is a wrapper array type that provides a view onto the original
array, and whose dimensions can be referred to by name rather than by position.

For example:
```
xs = NamedDimsArray{(:features, :observations)}(data);
n_obs = size(xs, :observations)
feature_totals = sum(xs; dims=:observations)
first_obs_vector = xs[observations=1]
x = xs[observations=15, features=2]  # 2nd feature in 15th observation.
```

`NamedDimsArray`s are normally a (near) zero-cost abstraction.
They generally resolve most dimension name related operations at compile time.
"""
struct NamedDimsArray{L, T, N, A<:AbstractArray{T, N}} <: AbstractArray{T, N}
    # The type parameter `L` holds the labels; it should be an `NTuple{N, Symbol}`.
    # `data` is the wrapped (parent) array.
    data::A
end
|
||
|
||
# Wrap `orig` with the dimension names `L`, filling in the remaining type parameters.
# Throws an `ArgumentError` unless `L` is a tuple of exactly `N` symbols.
function NamedDimsArray{L}(orig::AbstractArray{T,N}) where {L, T, N}
    L isa NTuple{N, Symbol} || throw(ArgumentError(
        "A $N dimentional array, needs a $N-tuple of dimension names. Got: $L"
    ))
    return NamedDimsArray{L, T, N, typeof(orig)}(orig)
end

# Convenience form taking the names as a positional argument.
NamedDimsArray(orig::AbstractArray{T,N}, names::NTuple{N, Symbol}) where {T, N} =
    NamedDimsArray{names}(orig)
|
||
# Type of the wrapped array, read from the `A` type parameter.
function parent_type(::Type{<:NamedDimsArray{L,T,N,A}}) where {L,T,N,A}
    return A
end

# The wrapped array itself.
function Base.parent(x::NamedDimsArray)
    return x.data
end
|
||
|
||
"""
    names(A)

Return a tuple containing the names of all the dimensions of the array `A`.
`A` may be a `NamedDimsArray` or a `NamedDimsArray` type.
"""
function names(::Type{<:NamedDimsArray{L}}) where L
    return L
end

# Instances defer to the method on their type.
names(x::NamedDimsArray) = names(typeof(x))
|
||
|
||
# Resolve `name` against this array's own dimension names.
function dim(a::NamedDimsArray{L}, name) where L
    return dim(L, name)
end
|
||
|
||
|
||
############################# | ||
# AbstractArray Interface | ||
# https://docs.julialang.org/en/v1/manual/interfaces/index.html#man-interface-array-1 | ||
|
||
## Minimal | ||
# Size is that of the parent array.
function Base.size(a::NamedDimsArray)
    return size(parent(a))
end

# Size along one dimension; `d` is passed through `dim`, so it may be a name.
function Base.size(a::NamedDimsArray, d)
    return size(parent(a), dim(a, d))
end
|
||
|
||
## optional | ||
# Use the same index style as the parent array type.
function Base.IndexStyle(::Type{A}) where A<:NamedDimsArray
    return Base.IndexStyle(parent_type(A))
end
|
||
# Length is that of the parent array.
function Base.length(a::NamedDimsArray)
    return length(parent(a))
end
|
||
# Axes are those of the parent array.
function Base.axes(a::NamedDimsArray)
    return axes(parent(a))
end

# Axis along one dimension; `d` is passed through `dim`, so it may be a name.
function Base.axes(a::NamedDimsArray, d)
    return axes(parent(a), dim(a, d))
end
|
||
|
||
# `similar` with no extra arguments, or with only type arguments: build the similar
# parent array and give it the same dimension names.
Base.similar(a::NamedDimsArray{L}, args::Type...) where L =
    NamedDimsArray{L}(similar(parent(a), args...))
|
||
|
||
############################### | ||
# kwargs indexing | ||
|
||
"""
    order_named_inds(A; named_inds...)

Return the indices that have the names and values given by `named_inds`,
sorted into the order expected for the dimensions of the array `A`.
Any dimension of `A` that is not present in `named_inds` is given the value `:`,
for slicing.

For example:
```
A = NamedDimsArray(rand(4, 4), (:x, :y))
order_named_inds(A; y=10, x=13) == (13, 10)
order_named_inds(A; x=2, y=1:3) == (2, 1:3)
order_named_inds(A; y=5) == (:, 5)
```

This provides the core indexed lookup for `getindex` and `setindex!` on the array `A`.
"""
function order_named_inds(A::AbstractArray; named_inds...)
    return order_named_inds(names(A); named_inds...)
end
|
||
################### | ||
# getindex / view / dotview | ||
# Note that `dotview` is undocumented but needed for making `a[x=2] .= 3` work | ||
|
||
# Define the same three methods for each of `getindex`, `view` and `dotview`.
for f in (:getindex, :view, :dotview)
    @eval begin
        # Keyword form, e.g. `A[x=1, y=2:3]`: put the named indices into positional
        # order (missing names become `:`) and defer to the positional methods below.
        @propagate_inbounds function Base.$f(A::NamedDimsArray; named_inds...)
            inds = order_named_inds(A; named_inds...)
            return Base.$f(A, inds...)
        end

        # All-integer indices: forward to the parent and return its result without
        # re-wrapping. For `getindex` that is the element itself.
        # `Vararg{Integer}` matches the same calls as the former `Vararg{<:Integer}`,
        # without wrapping `Vararg` directly in a `UnionAll` (deprecated in Julia 1.7).
        @propagate_inbounds function Base.$f(a::NamedDimsArray, inds::Vararg{Integer})
            return Base.$f(parent(a), inds...)
        end

        # Some nonscalar case: the parent returns an array, so give that array the
        # names of the dimensions that survive the indexing.
        @propagate_inbounds function Base.$f(a::NamedDimsArray, inds...)
            data = Base.$f(parent(a), inds...)
            L = remaining_dimnames_from_indexing(names(a), inds)
            return NamedDimsArray{L}(data)
        end
    end
end
|
||
############################################ | ||
# setindex! | ||
# Keyword form, e.g. `a[x=1, y=2] = value`: put the named indices into positional
# order (missing names become `:`) and defer to the positional method.
@propagate_inbounds function Base.setindex!(a::NamedDimsArray, value; named_inds...)
    positional = order_named_inds(a; named_inds...)
    return setindex!(a, value, positional...)
end
|
||
# Positional form: write straight into the parent array and return whatever the
# parent's `setindex!` returns.
@propagate_inbounds function Base.setindex!(a::NamedDimsArray, value, inds...)
    wrapped = parent(a)
    return setindex!(wrapped, value, inds...)
end
Oops, something went wrong.