The details of "samples one sub-graph at one training iteration" #12

Closed
kongxz opened this issue May 28, 2019 · 4 comments

Comments

@kongxz

kongxz commented May 28, 2019

Could you explain the details of "samples one sub-graph at one training iteration"?

As far as I know, the result of Gumbel-Softmax may not be a one-hot vector; it could be something like [0.96, 0.01, 0.01, 0.01, 0.01].

When you sample one sub-graph during training, do you simply drop all the connections with weight 0.01?

Thanks.
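
For reference, PyTorch's torch.nn.functional.gumbel_softmax shows both behaviours mentioned above; a small illustration (the 5-way logits are made up for this example):

import torch
import torch.nn.functional as F

logits = torch.randn(5, requires_grad=True)          # 5 candidate connections (hypothetical)
soft = F.gumbel_softmax(logits, tau=1.0)              # soft sample, e.g. [0.96, 0.01, 0.01, 0.01, 0.01]
hard = F.gumbel_softmax(logits, tau=1.0, hard=True)   # straight-through sample, exactly one entry is 1.0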

@D-X-Y
Owner

D-X-Y commented May 28, 2019

Sure, we use the hard mode, so it is a one-hot vector. Something like this in PyTorch:

y_soft = gumbel_softmax( ... )                 # relaxed (soft) sample
y_hard = one_hot( y_soft.argmax() )            # hard one-hot sample
y_hard = y_hard - y_soft.detach() + y_soft     # straight-through: one-hot forward, soft gradients

During the forward pass, you could use:

cals = []
for i, w in enumerate(y_hard):
    if w.item() == 1:                  # the sampled operation
        cals.append( op[i](x) * w )    # multiply by w so gradients reach the architecture logits
    else:
        cals.append( x )               # non-sampled paths contribute the input unchanged
return sum(cals)
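
For readers who want to run this end to end, here is a self-contained sketch of the same idea (the GumbelMixedOp class, its ops/logits names, and the argmax-index comparison are illustrative choices, not the repository's actual code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GumbelMixedOp(nn.Module):
    def __init__(self, ops):
        super().__init__()
        self.ops = nn.ModuleList(ops)
        self.logits = nn.Parameter(torch.zeros(len(ops)))    # architecture parameters

    def forward(self, x):
        y_soft = F.gumbel_softmax(self.logits, tau=1.0)       # relaxed sample
        index = y_soft.argmax()
        y_hard = F.one_hot(index, len(self.ops)).float()      # hard one-hot sample
        y_hard = y_hard - y_soft.detach() + y_soft            # straight-through trick
        cals = []
        for i, w in enumerate(y_hard):
            if i == index:                                    # only the sampled op is computed
                cals.append(self.ops[i](x) * w)
            else:
                cals.append(x)                                # unselected paths, as in the snippet above
        return sum(cals)

# usage: two candidate operations on a batch of 8-dim features
mixed = GumbelMixedOp([nn.Linear(8, 8), nn.Identity()])
out = mixed(torch.randn(4, 8))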

D-X-Y closed this as completed May 28, 2019
@coolKeen

Sure, we use the hard mode, so it is a one-hot vector. Something like this in PyTorch:

y_soft = gumbel_softmax( ... )                 # relaxed (soft) sample
y_hard = one_hot( y_soft.argmax() )            # hard one-hot sample
y_hard = y_hard - y_soft.detach() + y_soft     # straight-through: one-hot forward, soft gradients

How do you implement the backward pass?

@D-X-Y
Owner

D-X-Y commented Jun 18, 2019

If you implement the forward pass in the above style, PyTorch handles the backward pass automatically: the straight-through construction lets gradients flow through y_soft.
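
A quick way to see this (purely illustrative, with made-up logits and a toy loss): the forward value is one-hot, but gradients still reach the logits through the soft term.

import torch
import torch.nn.functional as F

logits = torch.randn(5, requires_grad=True)
y_soft = F.gumbel_softmax(logits, tau=1.0)
y_hard = F.one_hot(y_soft.argmax(), 5).float()
y_hard = y_hard - y_soft.detach() + y_soft    # forward: one-hot, backward: gradient of y_soft

loss = (y_hard * torch.arange(5.0)).sum()     # any loss that consumes y_hard
loss.backward()
print(logits.grad)                            # non-zero: autograd went through the soft term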

@brdav

brdav commented Nov 4, 2019

Sure, we use the hard mode, so it is a one-hot vector. Something like this in PyTorch:

y_soft = gumbel_softmax( ... )                 # relaxed (soft) sample
y_hard = one_hot( y_soft.argmax() )            # hard one-hot sample
y_hard = y_hard - y_soft.detach() + y_soft     # straight-through: one-hot forward, soft gradients

During the forward pass, you could use:

cals = []
for i, w in enumerate(y_hard):
    if w.item() == 1:                  # the sampled operation
        cals.append( op[i](x) * w )    # multiply by w so gradients reach the architecture logits
    else:
        cals.append( x )               # non-sampled paths contribute the input unchanged
return sum(cals)

Thanks for the code snippet!
Why would you append x for the paths with weight 0? Shouldn't there be no forward propagation?
