diff --git a/Project.toml b/Project.toml index 77618dd..79eaf6e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.15.12" +version = "0.15.13" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/ext/ComponentArraysTrackerExt.jl b/ext/ComponentArraysTrackerExt.jl index 55754cb..b0e6841 100644 --- a/ext/ComponentArraysTrackerExt.jl +++ b/ext/ComponentArraysTrackerExt.jl @@ -10,6 +10,8 @@ end Tracker.extract_grad!(ca::ComponentArray) = Tracker.extract_grad!(getdata(ca)) +Tracker.data(ca::ComponentArray) = ComponentArray(Tracker.data(getdata(ca)), getaxes(ca)) + function Base.materialize(bc::Base.Broadcast.Broadcasted{Tracker.TrackedStyle, Nothing, typeof(zero), <:Tuple{<:ComponentVector}}) ca = first(bc.args) diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index b088642..68b371d 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -117,3 +117,13 @@ end @test Δ isa AbstractVector{Float64} end + +@testset "Tracker untrack" begin + ps = Tracker.param(ComponentArray(; a = rand(2))) + @test eltype(getdata(ps)) <: Tracker.TrackedReal{Float64} + + ps_data = Tracker.data(ps) + @test !(eltype(getdata(ps_data)) <: Tracker.TrackedReal{Float64}) + @test eltype(getdata(ps_data)) <: Float64 +end +