Seems to be #265ing #22

Open
oxinabox opened this issue Sep 28, 2018 · 13 comments

Comments

@oxinabox
Member

oxinabox commented Sep 28, 2018

Maybe there is a cache that is not getting updated when a function is redefined?

Code:

using Zygote;
g(x) = 2x;
@show g'(10);
g(x) = 3x;
@show g'(10);

Expected output is 2, then 3
Actual output is 2, then 2 again.

Demo:

julia> using Zygote;

julia> g(x) = 2x;

julia> @show g'(10);
(g')(10) = 2

julia> g(x) = 3x;

julia> @show g'(10);
(g')(10) = 2
@MikeInnes
Member

Yes, this is mentioned in the readme; you can call Zygote.refresh() to avoid it.

@oxinabox
Member Author

oxinabox commented Sep 28, 2018

Ah, I missed that.
Worth leaving this open, to track against it?

I feel like hooking into the julia #265 stuff,
should be doable?

@jrevels

jrevels commented Sep 28, 2018

I feel like hooking into the julia #265 stuff,
should be doable?

ref JuliaLang/julia#27073

IIRC we hit a roadblock with making the world-age mechanism hookable when we realized that the compiler would need to backprop the new world age bounds in a way it didn't have to before. So it's going to take some redesign of the world-age mechanism to make that feasible.

Also note that there are still some compiler bugs that make it quite possible to trigger miscompiles that manifest as 265-style issues (ref JuliaLabs/Cassette.jl#6, JuliaLang/julia#28595). It's possible that fixing those bugs will fix some of the more blatant 265-style problems without requiring any world age redesign.
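For readers unfamiliar with world ages, here is a minimal Base-only sketch of the behaviour being discussed. It is not Zygote code and does not show how Zygote caches anything; it only illustrates that code pinned to an older world keeps seeing the older method, which is the same symptom as in the report above. It assumes Julia 1.6 or later for `Base.invoke_in_world`; `old_world`, `stale` and `fresh` are names made up for the illustration.

```julia
# Define g and remember the world age in which that definition is current.
g(x) = 2x
old_world = Base.get_world_counter()

# Redefining g adds a method in a newer world.
g(x) = 3x

# Calling in the remembered world still dispatches to the old method,
# while an ordinary top-level call sees the new one.
stale = Base.invoke_in_world(old_world, g, 10)
fresh = g(10)

println("old world: ", stale)
println("current world: ", fresh)
```

Run as a top-level script, the first call uses `2x` and the second uses `3x`. A gradient that was generated once and never invalidated behaves like the first call.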

@oxinabox
Member Author

oxinabox commented Sep 30, 2018

How about this as a workaround?

We don't want to call refresh() when it is not going to change the answer, since it triggers an expensive recompile.
We do want to call it when it would change the answer.

So we track the world-ages of functions,
that we have operated on.

min_world(f) = maximum(mm.min_world for mm in methods(f).ms)

const _tracked_worldages = Dict{Function,Int}()

"""
Refreshes Zygote, if it is required to correctly work with `f`.
This avoids expensive recompilation if it is not required.
Returns true if a refresh was done.
"""
function Zygote.refresh(f)
	cur_world = min_world(f)
	old_world = get!(_tracked_worldages, f, cur_world)
	if cur_world > old_world
		_tracked_worldages[f] = cur_world
		Zygote.refresh() # trigger refresh
		true
	else
		false
	end
end

Then inside each of the calls to functions affected by this (just gradients(f)?)
insert a call to refresh(f).

Demonstration

for proof of concept,
just doing this for my own function:

function der(f::Function)
	function (x)
		Zygote.refresh(f)
		derivative(f, x)
	end
end

demo

julia> h(x)  = 2x
h (generic function with 1 method)

julia> dh = der(h)
#7 (generic function with 1 method)

julia> dh(1)
2

julia> dh(1)
2

julia> h(x) = 10x
h (generic function with 1 method)

julia> dh(1)
2

julia> g(x) = 2x
g (generic function with 1 method)

julia> h(x) = 2g(x)
h (generic function with 1 method)

julia> dh(1)
4

julia> g(x) = 10x
g (generic function with 1 method)

julia> dh(1)
20

@oxinabox
Member Author

oxinabox commented Oct 1, 2018

Hmmm, that is not quite as reliable as it seemed.
I think it will always catch redefinitions of h,
but it will only sometimes catch redefinitions of g.

@oxinabox
Member Author

oxinabox commented Oct 1, 2018

The following is more reliable, but slower.
It recursively searches the code, looking for functions being called.

I think there is a way to speed it up,
especially if we actually only need to worry about functions that are calling functions we have called gradient on before.
Then we can short circuit it in i_min_world(ff::Function).
Also for at least the outer call, we have type information, so don't need to check all methods.

function min_world(f)
	visited = Set{Function}()

	########################################
	# Inner Dispatches
	function i_min_world(ff::Function)
		ff in visited && return 0
		# We take the max over all mins, so returning 0 for an already-visited
		# function is fine: its true value has already been accounted for.
		push!(visited, ff)
		
		meths=methods(ff).ms
		isempty(meths) && return 0
		maximum(i_min_world.(meths))
	end
	
	function i_min_world(mm::Method)
		if isdefined(mm, :generator)
			# I don't know how to deal with generated functions
			# Can't search all methods since there are an infinite number of them
			mm.min_world
		else
			max(mm.min_world, i_min_world(Base.uncompressed_ast(mm)))
		end
	end
	
	i_min_world(u_ast::Core.CodeInfo) = maximum(i_min_world.(u_ast.code))

	i_min_world(expr::Expr) = expr.head == :call ? maximum(i_min_world.(expr.args)) : 0 
	
	function i_min_world(gr::GlobalRef)
		func = try 
			eval(gr)
			# This will occasionally error,
			# but never for the case we care about AFAIK
			# which is when the result is a function
		catch err
			err isa UndefVarError || rethrow()
			return 0
		end
		
		i_min_world(func)
	end
	
	i_min_world(x::Any) = 0
	
	################################################
	i_min_world(f)
end

@oxinabox
Member Author

oxinabox commented Nov 7, 2018

Thoughts on the code I posted before?
I think it can be made sufficiently smart to solve this without a large overhead.

@MikeInnes
Member

So the idea is that we call refresh() at the entry point of gradient? And then that can call eval/invokelatest on _forward to get the latest definition.

It seems hard to do this in general. I could see it working in the case that everything is fully well-typed, but what about if the functions that f calls are not known at compile time? We either make that significantly more expensive (call refresh again at the boundary) or just push this issue into more complex code (which might actually make it more surprising when it comes up). I could see something like this working, but it would be a relatively invasive change.

@oxinabox
Member Author

oxinabox commented Nov 7, 2018

So the idea is that we call refresh() at the entry point of gradient?

Yes, we call a version of refresh before the existing gradient, where this version only refreshes if required according to world_ages.

And calculating world ages has to be done recursively, since only functions directly changed have the min_world field updated (I am pretty sure you understand that already, but just for anyone else reading this in the future. Like future-me.)

I could see it working in the case that everything is fully well-typed, but what about if the functions that f calls are not known at compile time?

I assume you mean if the methods called are not known at compile time,
since the functions always are (excluding maybe some kind of eval hackery that I'm not sure is actually possible).

Yes, if the method called is not known at compile time, the world-age needs to be (recursively) checked for all possibilities. And that set can be lowered somewhat by various things like argument counting (doesn't work for splatting), and knowing some of the argument types, even if not all of them.

The code in #22 (comment)
always checks all methods; it does none of the narrowing just mentioned,
and so is quite slow.

But yes, the other way would be to call refresh at the boundary. I think that makes sense.
There is a timing trade-off between how long it takes to just do a refresh that is not required, vs how long it takes to recurse the AST of all methods to show that it is not required.

A heuristic to make this much faster would be a max nesting depth for how deep to check for changes in called methods.
Checking only the world age of the function being called is really fast.
(No recursion, as in min_world(f) = maximum(mm.min_world for mm in methods(f).ms))

@MikeInnes
Member

MikeInnes commented Nov 7, 2018

I assume you mean if the methods called are not known at compile time. Since the functions always are ...

Slightly contrived, but for example:

function foo(fs)
  f = pop!(fs)
  f(1, 2)
end

foo(Any[+])
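To make the point concrete, here is a small Base-only sketch that asks inference what it can prove about the example. It uses `Base.return_types`, which is a reflection helper rather than anything Zygote calls, and the names `inferred` and `result` exist only for the illustration.

```julia
# The example from above, unchanged: the callee comes out of an untyped container.
function foo(fs)
    f = pop!(fs)
    f(1, 2)
end

# With a Vector{Any} argument, inference only knows that f is some value of
# type Any, so it cannot tell which function (let alone which method) is called.
inferred = Base.return_types(foo, (Vector{Any},))
println(inferred)

# At run time the call resolves normally.
result = foo(Any[+])
println(result)
```

Because the callee is only known at run time, a check performed at the entry point of gradient cannot enumerate the functions whose world ages it would need to inspect.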

@oxinabox
Member Author

oxinabox commented Nov 7, 2018

Ah, yes, I see what you mean.
If we can't work out the type of f there is nothing we can do at compile time.

@willow-ahrens

Would it make sense to call the generators with Core._apply_pure() as is done in https://github.com/NHDaly/StagedFunctions.jl/blob/master/src/StagedFunctions.jl#L116? You don't need to invalidate later like StagedFunctions.jl does, but it would at least free things up so that if you define relevant rules before calling the generator you wouldn't need to refresh(). This is the approach I'm taking with finch-tensor/Finch.jl#176 and it seems to work okay for me so far.

@ToucheSir
Member

There are some changes to world age handling being discussed in FluxML/IRTools.jl#109, but I lack the know-how to say how they fit into your proposal.

5 participants