Skip to content

Tracing static parameters #8

@cscherrer

Description

@cscherrer

Lowered code sometimes contains `:($(Expr(:static_parameter, 1)))`, which seems to confuse Umlaut. The compiler's use of `:static_parameter` is new to me, so I don't yet have any suggestions for the right way to handle this.

There are probably simpler examples where this problem comes up, but I'm seeing it when tracing a Tilde.jl model. This works fine:

julia> m = @model begin
           x ~ Normal()
       end;

julia> r = rand(m())
(x = -0.8876255812527766,)

julia> logdensityof(m(), r)
-1.3128781194518375

But tracing this last call with Umlaut fails with:

julia> trace(logdensityof, m(), r)

ERROR: MethodError: no method matching getproperty(::NamedTuple{(:x,), Tuple{Float64}}, ::Expr)
Closest candidates are:
  getproperty(::Any, ::Symbol) at ~/julia/julia-1.8.0-beta1/share/julia/base/Base.jl:38
  getproperty(::Any, ::Symbol, ::Symbol) at ~/julia/julia-1.8.0-beta1/share/julia/base/Base.jl:50
Stacktrace:
  [1] mkcall(::Function, ::Variable, ::Vararg{Any}; val::Missing, kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/git/Umlaut.jl/src/tape.jl:196
  [2] mkcall(::Function, ::Variable, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/tape.jl:180
  [3] record_primitive!(::Tape{Umlaut.BaseCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:137
  [4] record_or_recurse!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:194
  [5] trace!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Core.CodeInfo, ::Variable, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:212
  [6] record_or_recurse!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Variable, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:196
  [7] trace!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:212
  [8] record_or_recurse!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:196
  [9] trace!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:212
 [10] record_or_recurse!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:196
 [11] trace!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:212
 [12] record_or_recurse!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:196
 [13] trace!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Core.CodeInfo, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:212
 [14] record_or_recurse!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Function, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:196
 [15] trace!(::Umlaut.Tracer{Umlaut.BaseCtx}, ::Core.CodeInfo, ::Variable, ::Vararg{Any})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:212
 [16] trace(::Function, ::Tilde.ModelClosure{Model{NamedTuple{()}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{40}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(), Tuple{}}}, ::Vararg{Any}; ctx::Umlaut.BaseCtx, fargtypes::Nothing, deprecated_kws::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:344
 [17] trace(::Function, ::Tilde.ModelClosure{Model{NamedTuple{()}, GeneralizedGenerated.NGG.TypeLevel{Expr, "Buf{40}()"}, GeneralizedGenerated.NGG.TypeLevel{Module, "Buf{17}()"}}, NamedTuple{(), Tuple{}}}, ::NamedTuple{(:x,), Tuple{Float64}})
    @ Umlaut ~/git/Umlaut.jl/src/trace.jl:334
 [18] top-level scope
    @ REPL[91]:1

If I add two `@show` debugging lines to `mkcall`:

# Create a `Call` record for `fn(args...; kwargs...)`.
# If `val` is not supplied and every argument is either a plain (non-Variable)
# value or a Variable whose operation already holds a value, the call is
# executed eagerly to obtain its result; otherwise `val` is stored as given.
# NOTE: this copy has two `@show` lines added while debugging this issue.
function mkcall(fn, args...; val=missing, kwargs...)
    kwargs = NamedTuple(kwargs)
    if !isempty(kwargs)
        # Keyword calls are rewritten into the positional keyword-sorter form
        # `Core.kwfunc(fn)(kwargs, fn, args...)`, so the recorded args below
        # also include `kwargs` and the original `fn`.
        args = (kwargs, fn, args...)
        fn = Core.kwfunc(fn)
    end
    fargs = (fn, args...)
    # Any non-Variable argument counts as calculable, so an unevaluated
    # `Expr(:static_parameter, 1)` passes this check and is later handed to
    # `fn_` as a literal Expr (see `args_` in the output below).
    calculable = all(
        a -> !isa(a, Variable) ||                      # not variable
        (a._op !== nothing && a._op.val !== missing),  # bound variable
        fargs
    )
    if val === missing && calculable
        # Swap each Variable for the value stored on its operation
        # (presumably `map_vars` applies the function to Variables only).
        fargs_ = map_vars(v -> v._op.val, fargs)
        fn_, args_ = fargs_[1], fargs_[2:end]
        @show fn_   # debugging: print the resolved function
        @show args_ # debugging: print the resolved argument tuple
        val_ = fn_(args_...)
    else
        val_ = val
    end
    # NOTE(review): the leading `0` looks like a placeholder id — confirm in Umlaut's tape.jl.
    return Call(0, val_, fn, [args...])
end

I see that when the error is thrown, we have:

fn_ = getproperty
args_ = ((x = -0.8876255812527766,), :($(Expr(:static_parameter, 1))))

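This makes me think Umlaut needs some extra code to handle `:($(Expr(:static_parameter, 1)))` as a special case. The second argument to `getproperty` should be the value of the method's first static parameter (presumably a `Symbol` here). Instead, the raw `Expr` is passed through unevaluated, which causes the `MethodError` above.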

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions