[POC] Track tracing state with thread local state #57

Closed
wants to merge 19 commits into from

ezyang
Owner

@ezyang ezyang commented Aug 8, 2017

A bit of hacking I did on the flight to India.

The basic premise of this patch is that we use thread-local state to track what our tracing state is. The convention goes like this:

  • For tracing forward passes from Python, the user manually toggles tracing on/off, similar to what we had with a global tracing state variable
  • For tracing backward passes from autograd, every Function in the autograd graph is associated with a tracing state, which specifies that when this Function is executed (in a single-threaded manner, natch!) this is where the traced operators should go. The saved_trace_state associated with the Function itself is populated from the thread-local tracing variable.
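
The convention above can be sketched in a few lines of Python. This is only an illustration of the idea, not the code in this patch: `TraceState`, `current_trace`, `set_trace` and the toy `Function` class are hypothetical names, and the real implementation lives in the C++ autograd engine.

```python
import threading

# Hypothetical stand-ins; none of these names come from the patch itself.
_tls = threading.local()


class TraceState:
    """Collects the names of the operators recorded into one trace."""

    def __init__(self, name):
        self.name = name
        self.ops = []


def current_trace():
    """Return the trace active on this thread, or None if tracing is off."""
    return getattr(_tls, "trace", None)


def set_trace(state):
    """Install `state` as this thread's trace and return the previous one."""
    previous = current_trace()
    _tls.trace = state
    return previous


class Function:
    """Toy autograd node that remembers the trace it was created under."""

    def __init__(self, op_name):
        self.op_name = op_name
        # Captured from thread-local state at construction time.
        self.saved_trace_state = current_trace()
        if self.saved_trace_state is not None:
            self.saved_trace_state.ops.append(op_name)

    def backward(self):
        # Reinstate the saved trace for the duration of the backward call.
        previous = set_trace(self.saved_trace_state)
        try:
            trace = current_trace()
            if trace is not None:
                trace.ops.append(self.op_name + "_backward")
        finally:
            set_trace(previous)


# Forward pass: the user toggles tracing on and off by hand.
trace = TraceState("t0")
set_trace(trace)
fn = Function("mul")
set_trace(None)

# Backward pass: tracing is off on this thread, but the Function still
# knows where its operators belong.
fn.backward()
```

In this sketch the backward operator lands in the same trace as the forward one even though the thread-local state was already cleared, because the destination is read from the Function rather than from any input variable.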

This is in contrast to the current mechanism, which tracks state by variables.

Why do I think using TLS (or some equivalent mechanism) is better here?

  1. Tracking state by variables has too many degrees of freedom. You get multiple variables as inputs: what do you do if they're from different traces? In contrast, there isn't any way to get it wrong with a Function saved trace.
  2. It has the correct lifetime tracking. With variable tracing, we had to have weak pointers to the trace, and carefully ensure the trace stayed live as long as any variables were involved. With this new scheme, liveness is tied to the autograd closures, which makes sense: as long as grad_fn is retained, it's possible for the user to call backwards (and trace into our old trace). If you clear it, no more traces!
  3. I conjecture this scheme works with stochastic, whereas the variable scheme does not. The intuition here is that stochastic backwards don't get any variables as input. So how do they know what to trace into? Certainly not by the input variables...
  4. Maybe we don't need to insert dummy nodes anymore?? I'm not sure about this one, but I got rid of the code that overwrites the input variables on entry to a trace and things still seemed to work
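
The lifetime argument in point 2 can be illustrated with a small Python sketch. The `Trace` and `GradFn` classes here are hypothetical toys, not types from the patch; the only thing they model is that the closure holds the one strong reference to its trace.

```python
import gc
import weakref


class Trace:
    """Hypothetical trace object; it only needs to be weak-referenceable."""


class GradFn:
    """Toy autograd closure that holds a strong reference to its trace."""

    def __init__(self, trace):
        self.saved_trace_state = trace


def make_graph():
    # The local name `trace` goes out of scope on return, so the GradFn
    # is the only thing keeping the trace alive afterwards.
    trace = Trace()
    return GradFn(trace), weakref.ref(trace)


grad_fn, trace_ref = make_graph()
alive_while_retained = trace_ref() is not None

# Dropping the closure drops the last strong reference to the trace.
# CPython frees it right away through reference counting; the explicit
# collection is there for interpreters that do not.
del grad_fn
gc.collect()
alive_after_clear = trace_ref() is not None
```

No weak pointers are needed on the tracing side in this arrangement: the observer's weak reference is only there to show when the trace dies.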

All the existing tests pass, but lots of slightly perturbed versions of them still fail.

Some technical problems had to be solved along the way:

  • As Adam pointed out, you can't just use Variable* naively as the key in the variable-to-node map, because variables can pop in and out of existence via the SavedVariable mechanism. I ended up implementing a simple (and inefficient) unique numbering scheme using UniqueVariable, which is just a heap-allocated dummy object whose pointer serves as the unique identifier. As it has no fields, it can be safely stored in a SavedVariable.
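
A rough Python analogue of the UniqueVariable idea, for illustration only. The toy `Variable` and `SavedVariable` classes and the `unique` attribute are assumptions made for this sketch; the patch itself does this in C++ with a pointer, whereas here object identity plays that role.

```python
class UniqueVariable:
    """Field-less heap object; only its identity matters."""
    __slots__ = ()


class Variable:
    """Toy wrapper that may be destroyed and recreated around the same data."""

    def __init__(self, data, unique=None):
        self.data = data
        self.unique = unique if unique is not None else UniqueVariable()


class SavedVariable:
    """Toy version of the save/unpack round trip."""

    def __init__(self, variable):
        self.data = variable.data
        self.unique = variable.unique  # the token travels with the saved data

    def unpack(self):
        return Variable(self.data, self.unique)


var_to_node = {}  # keyed by the token, not by the Variable wrapper

x = Variable([1.0, 2.0])
var_to_node[x.unique] = "node_0"

saved = SavedVariable(x)
del x               # the original wrapper goes away
y = saved.unpack()  # a brand-new wrapper comes back
```

Looking up `var_to_node[y.unique]` still finds the node recorded for `x`, which is the property a raw wrapper address cannot guarantee once the wrapper has been freed and its address possibly reused.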

Squishy bits:

  • I'm really not sure if the hook logic is right in the new world order. Have to think about it carefully.

There are also a bunch of misc cleanups which should get merged no matter what we decide on the PoC.

  • I added torch.jit.trace_fn and cleaned up the testsuite to use it
  • I made alexnet out-of-place for now
  • Added a utility function to check if tracing is enabled; also a utility to disable tracing unconditionally, although this one is more of a hack to work around a bug where tracing state isn't reset upon exception
  • Fixed the DCE UB
  • Made lint failure print out the IR that failed lint
  • Reworked how we maintain stages in the IR. Previously there was a graph-wide variable that controlled where all nodes were inserted; this made it really easy to violate stage invariants in optimization passes. Now the stage is inferred as much as possible, with only parameters having their stage controlled by this variable. There may still be room for improvement here.
  • Renamed enter/exit to forward_enter/forward_exit, and made them callable only from Python. Not sure about this one.
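
On the "tracing state isn't reset upon exception" bug mentioned above: the patch works around it with an unconditional disable, but one common way to avoid the problem altogether is to restore the state in a `finally` block. A minimal Python sketch, assuming a hypothetical thread-local flag; `is_tracing` and `tracing` are illustrative names, not the utilities added in this PR.

```python
import threading
from contextlib import contextmanager

_tls = threading.local()  # hypothetical holder for the tracing flag


def is_tracing():
    """Report whether tracing is enabled on the current thread."""
    return getattr(_tls, "tracing", False)


@contextmanager
def tracing():
    """Enable tracing for a block and restore the old state on any exit."""
    previous = is_tracing()
    _tls.tracing = True
    try:
        yield
    finally:
        # Runs on normal exit and when the body raises.
        _tls.tracing = previous


def failing_forward():
    raise RuntimeError("forward pass failed")


try:
    with tracing():
        enabled_inside = is_tracing()
        failing_forward()
except RuntimeError:
    pass

enabled_after = is_tracing()
```

Restoring the previous value rather than writing `False` keeps the sketch correct if such blocks are ever nested.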

Also: don't look at the commits. One big diff is good.

apaszke and others added 19 commits August 6, 2017 13:36
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
@ezyang ezyang mentioned this pull request Aug 15, 2017
@apaszke apaszke closed this Aug 21, 2017
@ezyang ezyang deleted the pr/jit-function-trace-poc branch September 7, 2017 20:28
2 participants