From 554a9c03373680af84586762f68ebd32b6d34abe Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 20 May 2024 14:18:57 -0400 Subject: [PATCH] `Tracker.data` overload (#260) --- Project.toml | 2 +- ext/ComponentArraysTrackerExt.jl | 2 ++ test/autodiff_tests.jl | 10 ++++++++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 77618dd0..79eaf6e5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.15.12" +version = "0.15.13" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/ext/ComponentArraysTrackerExt.jl b/ext/ComponentArraysTrackerExt.jl index 55754cb2..b0e68413 100644 --- a/ext/ComponentArraysTrackerExt.jl +++ b/ext/ComponentArraysTrackerExt.jl @@ -10,6 +10,8 @@ end Tracker.extract_grad!(ca::ComponentArray) = Tracker.extract_grad!(getdata(ca)) +Tracker.data(ca::ComponentArray) = ComponentArray(Tracker.data(getdata(ca)), getaxes(ca)) + function Base.materialize(bc::Base.Broadcast.Broadcasted{Tracker.TrackedStyle, Nothing, typeof(zero), <:Tuple{<:ComponentVector}}) ca = first(bc.args) diff --git a/test/autodiff_tests.jl b/test/autodiff_tests.jl index b088642c..68b371de 100644 --- a/test/autodiff_tests.jl +++ b/test/autodiff_tests.jl @@ -117,3 +117,13 @@ end @test Δ isa AbstractVector{Float64} end + +@testset "Tracker untrack" begin + ps = Tracker.param(ComponentArray(; a = rand(2))) + @test eltype(getdata(ps)) <: Tracker.TrackedReal{Float64} + + ps_data = Tracker.data(ps) + @test !(eltype(getdata(ps_data)) <: Tracker.TrackedReal{Float64}) + @test eltype(getdata(ps_data)) <: Float64 +end +