NaN gradient may be due to weight initialization #22

Open
victorconan opened this issue Nov 17, 2020 · 4 comments

@victorconan

Hi Ed,

I saw that in your code the weights are initialized with a truncated normal distribution. When I ran it, this seemed to produce large values feeding into the exp in the medical-code-loss part, resulting in inf in the loss and NaN gradients. Also, because of those initial weights, the loss in general is quite high, around several hundred, and the L2 loss in particular is in the tens of thousands. I then changed the weight initialization to uniform over a small interval [-0.1, 0.1], which seems to produce a reasonable loss magnitude (under 10). Do you remember whether you tried other weight initializations and how they affected the results?
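For example, here is a minimal TF2 sketch of the two initializers I am comparing (the shapes and stddev are placeholders, not the values from your repo):

import tensorflow as tf

vocab_size, emb_dim = 1000, 200  # placeholder sizes

# roughly what the repo does: truncated normal
w_normal = tf.Variable(tf.random.truncated_normal([vocab_size, emb_dim], stddev=0.01))

# what I switched to: uniform over a small interval
w_uniform = tf.Variable(tf.random.uniform([vocab_size, emb_dim], minval=-0.1, maxval=0.1))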

Another question I have: in the paper, the loss is averaged over T. Is T the number of visits in the batch or the number of visits per patient? In your code, ivec and jvec seem to be generated for the whole batch, so the medical-code-loss calculation averages over all visits in a batch rather than averaging per patient first and then averaging over the patients in a batch.
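To put the two averaging schemes side by side, here is a small sketch with made-up numbers (not your code):

import tensorflow as tf

# per-visit losses for a batch of 2 patients (made-up numbers)
visit_losses_by_patient = [tf.constant([1.0, 3.0]),        # patient A: 2 visits
                           tf.constant([2.0, 2.0, 8.0])]   # patient B: 3 visits

# scheme 1: average over all visits in the batch -> 16 / 5 = 3.2
batch_mean = tf.reduce_mean(tf.concat(visit_losses_by_patient, axis=0))

# scheme 2: average per patient, then over patients -> (2.0 + 4.0) / 2 = 3.0
patient_means = tf.stack([tf.reduce_mean(l) for l in visit_losses_by_patient])
patient_mean = tf.reduce_mean(patient_means)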

Thanks!

@mp2893 (Owner) commented Nov 18, 2020

Hi Victor,

Thanks for taking interest in my work.

As for your first question: no, I haven't tried other initialization strategies, but I think your approach makes sense. Maybe you'd care to contribute it to the repo?

For the second question: IIRC (it was a long time ago that I wrote this code), ivec and jvec are constructed from the preprocessed patient records, so there is no concept of a "patient" in the minibatch. There is just a bunch of random visits from the EHR.
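Going from memory, the pairs come from visits pooled across patients, roughly like the sketch below (just an illustration; the actual preprocessing may differ):

ivec, jvec = [], []
for visit in batch_of_visits:   # each visit is a list of code indices, pooled across patients
    for i in visit:
        for j in visit:
            if i != j:          # every ordered pair of codes co-occurring in the visit
                ivec.append(i)
                jvec.append(j)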

Best
Ed

@victorconan (Author)

Hi Ed! Thanks for the reply, I really appreciate it!

I am porting your code to TF2 and testing it, and I will see if I can contribute to the repo. I am also comparing the results against an implementation that follows the paper exactly. My data is larger (~2M patients, ~77k medical codes), and it seems to take 2.5 days to train one epoch on a single CPU...

@mp2893 (Owner) commented Nov 21, 2020

Sounds interesting. Feel free to share any results from your experiments, so that others might gain new knowledge!

@victorconan (Author) commented Jan 14, 2021

I got my 10 epochs of training done, and I found that 80% of the codes end up with all-zero embeddings (I am taking ReLU(W_emb))... In general, the visit loss (~1e-3) is much smaller than the code loss (~10). It seems the co-occurrence loss dominates the training, and it has difficulty learning embeddings for most of the codes.

I also found that porting the code loss to TF2 runs into an issue when calculating the exponential terms: taking the exponential of the inner products only works if the embedding vectors are sparse; otherwise the values become very large and overflow to inf:

emb_w = tf.maximum(emb_w, 0)                  # non-negative code embeddings, i.e. ReLU(W_emb)
emb_w_transpose = tf.transpose(emb_w)

# row i holds sum_k exp(w_i . w_k); the exp overflows when the inner products are large
norms = tf.reduce_sum(tf.math.exp(tf.matmul(emb_w, emb_w_transpose)), axis=1)

i = tf.gather(emb_w_transpose, ivec, axis=1)  # embeddings of the center codes
j = tf.gather(emb_w_transpose, jvec, axis=1)  # embeddings of the co-occurring codes

numerator = tf.math.exp(tf.reduce_sum(j * i, axis=0))
denominator = tf.gather(norms, ivec)
cost = -tf.math.log(numerator / denominator + eps)
cost = tf.reduce_mean(cost)

So I switched to the TensorFlow logsumexp function below, which prevents the inf loss:

norms = tf.matmul(emb_w, emb_w_transpose)     # raw inner products, no exp yet

numerator = tf.reduce_sum(j * i, axis=0)      # w_i . w_j for each (i, j) pair
# log-normalizer computed stably: logsumexp_k (w_i . w_k) for each gathered row
denominator = tf.math.reduce_logsumexp(tf.gather(norms, ivec), axis=1)
cost = -(numerator - denominator)
cost = tf.reduce_mean(cost)

And it's 3 times slower...
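One idea I may try to claw some of the speed back (just a sketch, assuming ivec repeats the same code indices many times): compute the per-row log-normalizer once over the full norms matrix and gather from it, instead of gathering the rows first and running logsumexp per (i, j) occurrence. The value should be identical:

# per-code log-normalizer, computed once (reuses emb_w, i, j, ivec from above)
log_norms = tf.math.reduce_logsumexp(tf.matmul(emb_w, emb_w_transpose), axis=1)

numerator = tf.reduce_sum(j * i, axis=0)
denominator = tf.gather(log_norms, ivec)      # just a lookup per pair
cost = tf.reduce_mean(-(numerator - denominator))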
