Difficulties with nn.sparse_softmax_cross_entropy_with_logits #124

Closed
oxinabox opened this issue Jan 31, 2017 · 6 comments

Comments

@oxinabox
Collaborator

oxinabox commented Jan 31, 2017

I have been having trouble working out how to use nn.sparse_softmax_cross_entropy_with_logits.
See this StackOverflow question

I always seem to get "Tensorflow error: Status: Incompatible shapes" when I run an optimizer over it.

I feel like it would benefit from a test proving it works (and giving an example),
and maybe from some docs if its behavior differs from the Python version's.

Here is an MWE (smaller than the one on SO):

using TensorFlow
using Distributions
sess = Session(Graph())
input  = constant([0. 2. 2.; 1. 2. 3.])

variable_scope("myscope", initializer=Normal(0, .1)) do
    global W = get_variable("weights3", [3, 3], Float64)
end

logits = input*W
labels = constant([1, 2])
costs = nn.sparse_softmax_cross_entropy_with_logits(logits, labels)
loss = reduce_mean(-costs)
optimizer = train.minimize(train.AdamOptimizer(0.1), loss)

run(sess, initialize_all_variables())
run(sess, [costs, optimizer])

Gives me:

Tensorflow error: Status: Incompatible shapes: [1,2] vs. [2,3]
	 [[Node: gradients/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_2/Const/SparseSoftmaxCrossEntropyWithLogits_3/SparseSoftmaxCrossEntropyWithLogits_20_grad/mul = Mul[T=DT_DOUBLE, _class=[], _device="/job:localhost/replica:0/task:0/cpu:0"](gradients/SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_2/Const/SparseSoftmaxCrossEntropyWithLogits_3/SparseSoftmaxCrossEntropyWithLogits_20_grad/ExpandDims, SparseSoftmaxCrossEntropyWithLogits/SparseSoftmaxCrossEntropyWithLogits_2/Const/SparseSoftmaxCrossEntropyWithLogits_3/SparseSoftmaxCrossEntropyWithLogits_20:1)]]


 in check_status(::TensorFlow.Status) at /home/ubuntu/.julia/v0.5/TensorFlow/src/core.jl:101
 in run(::TensorFlow.Session, ::Array{TensorFlow.Port,1}, ::Array{Any,1}, ::Array{TensorFlow.Port,1}, ::Array{Ptr{Void},1}) at /home/ubuntu/.julia/v0.5/TensorFlow/src/run.jl:96
 in run(::TensorFlow.Session, ::Array{TensorFlow.Tensor,1}, ::Dict{Any,Any}) at /home/ubuntu/.julia/v0.5/TensorFlow/src/run.jl:143
 in run(::TensorFlow.Session, ::Array{TensorFlow.Tensor,1}) at /home/ubuntu/.julia/v0.5/TensorFlow/src/run.jl:169
@malmaud
Owner

malmaud commented Jan 31, 2017

I will take a look soonish

@staticfloat
Contributor

staticfloat commented Feb 1, 2017

This is expected, is it not? The docs for nn.sparse_softmax_cross_entropy_with_logits state that the labels and predicted labels need to have the same shape. Take a look at the shape of logits and the shape of labels.

Nope, I'm dead wrong: there's no strict relationship defined between the number of valid label values and the number of output nodes in the net. I had confused the docs of sparse_softmax_cross_entropy_with_logits with those of softmax_cross_entropy_with_logits.
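
For reference, the two ops have different label-shape contracts (a hedged summary based on the upstream TensorFlow docs; the Julia calls below assume this package mirrors those contracts): the sparse variant takes integer class ids of shape [batch], while the dense variant takes one-hot rows of shape [batch, num_classes] matching the logits.

using TensorFlow
sess = Session(Graph())
logits = constant(randn(2, 3))                 # [batch, num_classes]
sparse_labels = constant([1, 2])               # [batch], integer class ids
dense_labels = constant([1. 0. 0.; 0. 1. 0.])  # [batch, num_classes], one-hot rows
nn.sparse_softmax_cross_entropy_with_logits(logits, sparse_labels)
nn.softmax_cross_entropy_with_logits(logits, dense_labels)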

@staticfloat
Contributor

If you alter the shape of your weights from [3, 3] to [3, 2], it works.
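
Applied to the MWE above, that workaround would look like this (a sketch of the shape change only, not a root-cause fix):

variable_scope("myscope", initializer=Normal(0, .1)) do
    global W = get_variable("weights3", [3, 2], Float64)  # was [3, 3]
end

logits = input*W  # now has shape [2, 2] instead of [2, 3]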

@malmaud
Owner

malmaud commented Feb 2, 2017

I checked; there is a genuine issue with taking gradients of SparseSoftmaxCrossEntropyWithLogits operations. Not sure of the cause yet, though.
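
A narrower reproduction (a sketch, assuming this package exposes a gradients function analogous to tf.gradients; loss, costs, and W come from the MWE above) would drop the optimizer and request the gradient directly:

# The forward pass alone runs; it is the gradient of the loss
# w.r.t. W that produces the incompatible-shapes error.
g = gradients(loss, W)
run(sess, initialize_all_variables())
run(sess, costs)  # succeeds
run(sess, g)      # raises the shape error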

@malmaud
Owner

malmaud commented Feb 2, 2017

This is fixed by JuliaIO/ProtoBuf.jl#87

@malmaud
Owner

malmaud commented Feb 4, 2017

Closed by JuliaLang/METADATA.jl#7811 (comment)

malmaud closed this as completed Feb 4, 2017