Merging of ComponentArrays #69

Open
scheidan opened this issue Mar 18, 2021 · 4 comments

Comments

@scheidan
Contributor

Thanks for the really useful package! It has helped a lot to clean up our model code.

One feature I missed is the ability to merge ComponentArrays. With NamedTuples we can do:

t1 = (a=1, b=2, c=3)
t2 = (a=111, d=444)
merge(t1, t2)            # (a = 111, b = 2, c = 3, d = 444)

I can emulate this behavior, but I suspect my implementation is far from optimal:

using ComponentArrays, Zygote

# Extend Base.merge instead of defining a new local merge, so that the
# NamedTuple method is still found inside the bodies below. The arguments are
# not constrained to one type T, since arrays with different fields have
# different types and would not match such a method.
function Base.merge(a::ComponentVector, b::ComponentVector)
    ComponentVector(merge(NamedTuple(a), NamedTuple(b)))
end
function Base.merge(a::ComponentVector, b::NamedTuple)
    ComponentVector(merge(NamedTuple(a), b))
end
function Base.merge(a::NamedTuple, b::ComponentVector)
    ComponentVector(merge(a, NamedTuple(b)))
end

ca1 = ComponentVector(a=1, b=2, c=3)
ca2 = ComponentVector(b=22, d=44)

merge(ca1, ca2)         # ComponentVector{Int64}(a = 1, b = 22, c = 3, d = 44)
# also useful to add a parameter
merge(ca1, (; new=222)) # ComponentVector{Int64}(a = 1, b = 2, c = 3, new = 222)

# this works with ForwardDiff but not with Zygote
f(x) = sum(merge(ca1, x))
f(ca2)
Zygote.gradient(f, ca2)

A good use case would be optimizing some parameters while keeping others fixed:

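foo(ca) = ca.a + ca.b + ca.c + ca.d
ca_fix = ComponentVector(a=1, b=2)
# optimize only the parameters 'c' and 'd'; 'a' and 'b' stay fixed
objective(ca_opt) = foo(merge(ca_fix, ca_opt))
objective(ComponentVector(c=3, d=4))  # 10
# pass objective to the optimizer of your choice, e.g. optim(objective, ...)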
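The same fixed/free split can be tried with plain NamedTuples, which Base.merge already supports. This is only an illustrative sketch: the loss function, the helper descend, and the forward-difference gradient descent are hypothetical stand-ins for a real model and a real optimizer. Only the free parameters c and d are updated, and a and b are merged back in on every evaluation.

```julia
# Hypothetical loss; with a = 1 and b = 2 it is minimized at c = 4, d = 2.
loss(p) = (p.a + p.c - 5.0)^2 + (p.b * p.d - 4.0)^2

fixed = (a = 1.0, b = 2.0)                    # held constant
partial_loss(free) = loss(merge(fixed, free)) # free holds only c and d

# Hypothetical helper: plain gradient descent with forward differences that
# only touches the fields of `free`. A sketch, not a production optimizer.
function descend(obj, free::NamedTuple; lr = 0.05, steps = 2000, h = 1e-6)
    for _ in 1:steps
        f0 = obj(free)
        # finite-difference derivative with respect to each free field
        grad = map(keys(free)) do k
            bumped = merge(free, NamedTuple{(k,)}((free[k] + h,)))
            (obj(bumped) - f0) / h
        end
        free = NamedTuple{keys(free)}(values(free) .- lr .* grad)
    end
    return free
end

sol = descend(partial_loss, (c = 0.0, d = 0.0))  # sol.c ≈ 4, sol.d ≈ 2
```

With a merge for ComponentVectors, the NamedTuples here could be swapped for ComponentVectors without changing the structure of the objective.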
@jonniedie
Owner

There is sort of a way to handle this already: pass an existing ComponentArray to the constructor, with keyword arguments for the fields you want to override or add. But I think a dedicated merge method would be better. Here is how that sort of thing is currently done:

julia> ca1 = ComponentVector(a=1, b=2, c=3);

julia> ComponentArray(ca1; new=222, a=20)
ComponentVector{Int64}(a = 20, b = 2, c = 3, new = 222)

One problem with this approach is that there is no easy way to splat new fields from another ComponentArray; merge would fix that. Getting a performant version is a little tricky, though. Slow construction from a NamedTuple has been an open issue for a while, and this is a similar problem, although this one should be somewhat easier to tackle.
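For comparison, the keyword-constructor style already composes with splatting when the extra fields live in a NamedTuple. The sketch below uses a hypothetical helper, with_fields, on plain NamedTuples as a stand-in for ComponentArray(ca; kw...); it is not part of ComponentArrays. It relies on Base.merge accepting any iterator of Symbol => value pairs as its second argument, which is what the collected keywords are.

```julia
# Hypothetical stand-in for ComponentArray(ca; kw...): keyword arguments
# override existing fields or append new ones, in that order.
with_fields(nt::NamedTuple; kw...) = merge(nt, kw)

nt1 = (a = 1, b = 2, c = 3)
nt2 = (b = 22, d = 44)

with_fields(nt1; new = 222, a = 20)  # (a = 20, b = 2, c = 3, new = 222)
with_fields(nt1; nt2...)             # (a = 1, b = 22, c = 3, d = 44)
```

For flat ComponentArrays, the analogous workaround would presumably be ComponentArray(ca1; NamedTuple(ca2)...), converting the second array to a NamedTuple before splatting; this is an untested assumption.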

@scheidan
Contributor Author

Thanks a lot, that's good to know!
Should I make a PR to add this to the quick start section?

Is the problem with splatting new fields that you mentioned related to this?

ca = ComponentVector(a=(a1=1, a2=2), b=(b1=33, b2=44), c=555)
ComponentArray(ca; c = 5, new=222, a=(a2=33, a1=99))  # doesn't work
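If the intent is that the override for a gets merged into the existing a field by field (rather than replacing it wholesale), a flat merge is not enough and a recursive one is needed. A minimal sketch on plain NamedTuples follows; recursive_merge is a hypothetical helper, not a ComponentArrays function.

```julia
# Fallback: for non-NamedTuple values the right-hand side wins, as in Base.merge.
recursive_merge(a, b) = b

# Hypothetical helper: merge nested NamedTuples field by field, keeping the
# field order of `a` and appending fields that only exist in `b`.
function recursive_merge(a::NamedTuple, b::NamedTuple)
    merged = merge(a, b)
    names = keys(merged)
    vals = map(names) do k
        haskey(a, k) && haskey(b, k) ? recursive_merge(a[k], b[k]) : merged[k]
    end
    return NamedTuple{names}(vals)
end

nt = (a = (a1 = 1, a2 = 2), b = (b1 = 33, b2 = 44), c = 555)
recursive_merge(nt, (c = 5, new = 222, a = (a2 = 33, a1 = 99)))
# (a = (a1 = 99, a2 = 33), b = (b1 = 33, b2 = 44), c = 5, new = 222)
```

Assuming NamedTuple(ca) turns nested components into nested NamedTuples, the result could be converted back with ComponentVector(recursive_merge(NamedTuple(ca), overrides)), where overrides is a hypothetical NamedTuple of new values.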

@jonniedie
Owner

Yes, a PR would be much appreciated.

Yeah, that's part of it. Having merge(x::NamedTuple, y::ComponentArray) and the reverse method should fix that and other splatting issues. For example, you should be able to splat a ComponentArray after the semicolon in a function call and have its fields passed as keyword arguments, just as with a NamedTuple. I think merge would fix that too.
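A small illustration of keyword splatting with types Base already handles. As far as I know, Julia lowers a splat after the semicolon, f(; x...), to a call of merge with an empty NamedTuple as the first argument, so anything merge accepts there (a NamedTuple or an iterator of Symbol => value pairs) can be splatted. The function area is hypothetical and only serves as an example.

```julia
# Hypothetical keyword-only function used for illustration.
area(; width, height) = width * height

p = (width = 2, height = 3)
area(; p...)                                  # 6: the fields become keywords

# Any iterator of Symbol => value pairs splats the same way.
area(; Dict(:width => 4, :height => 5)...)    # 20
area(; [:width => 1, :height => 7]...)        # 7
```

A ComponentArray iterates over its numeric values rather than over name => value pairs, which is presumably why it does not splat this way today; a merge(::NamedTuple, ::ComponentArray) method that returns a NamedTuple would be the natural hook.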

@scheidan
Contributor Author

Just for reference: PropDicts.jl implements this kind of merge for dicts.
