
SELU values for a truncated normal distribution #10

Closed
carlthome opened this issue Mar 23, 2018 · 2 comments

carlthome commented Mar 23, 2018

SNNs/selu.py, line 31 in f992b22:

initializer = layers.variance_scaling_initializer(factor=1.0, mode='FAN_IN')
This, and many other implementations (e.g. Keras), do an additional trick where samples are resampled if they fall outside two standard deviations of the mean. I'm curious how much of an effect this truncation has on the fixed-point derivation. Are the fixed points analytically identical for a normal distribution and a truncated normal distribution?

I read in the paper that "Uniform and truncated Gaussian distributions with these moments led to networks with similar behavior." but this feels unsatisfactory to me. Maybe a small discrepancy becomes really problematic for deeper networks? This aligns with my experience that it's still beneficial to have batchnorm/layernorm with SELU.
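
For reference, here's a quick sketch (using scipy, which isn't part of this repo) of what truncation at +/- 2 standard deviations does to the variance, using the standard formula for a symmetrically truncated standard normal:

from scipy.stats import norm

# Variance of a standard normal truncated to +/- a standard deviations
a = 2.0
mass = norm.cdf(a) - norm.cdf(-a)           # probability mass kept, ~0.9545
var = 1.0 - 2.0 * a * norm.pdf(a) / mass    # ~0.7737

print(var)          # factor by which the variance shrinks
print(var ** 0.5)   # ~0.8796, the effective stddev scaling

So sampling with the same stddev parameter but truncating shrinks the variance by roughly a quarter, which suggests the two cases cannot give identical fixed points unless the stddev is corrected.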

agrinh commented Mar 23, 2018

@carlthome I don't really trust the truncation either. I was looking at how truncated normal weights might actually change the distribution, using a modified version of SNNs/getSELUparameters. These are the numbers I get for truncated vs. non-truncated normally distributed weights, and they seem quite different.

Normal weights

mean/var should be at: 0 / 1
Input data mean/var:   0.000050974286 / 0.999992370605
After selu:            0.021500919014 / 1.023049235344
After dropout mean/var 0.038502901793 / 1.021905064583

Truncated normal weights

mean/var should be at: 0 / 1
Input data mean/var:   -0.000023538434 / 1.000025033951
After selu:            -0.012259763665 / 0.838797271252
After dropout mean/var -0.007134077605 / 0.899475276470

Code

Here's the modified cell of SNNs/getSELUparameters I used:

from __future__ import absolute_import, division, print_function

import numpy as np
import tensorflow as tf

# selu and dropout_selu are defined in earlier cells of the notebook;
# the SELU fixed point is mean 0, variance 1.
myFixedPointMean = 0.0
myFixedPointVar = 1.0

# Input activations at the fixed point.
in_data = tf.random_normal([10000, 50000], mean=myFixedPointMean, stddev=np.sqrt(myFixedPointVar))

# Truncated normal weights (samples outside two stddevs are redrawn).
weights = tf.truncated_normal([50000, 1], mean=0., stddev=1 / np.sqrt(50000))

# Normal weights
#weights = tf.random_normal([50000, 1], mean=0., stddev=1 / np.sqrt(50000))

x = tf.matmul(in_data, weights)
w = selu(x)
y = dropout_selu(w, 0.2, training=True)
init = tf.global_variables_initializer()

gpu_options = tf.GPUOptions(allow_growth=True)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
    sess.run(init)
    in_data, x, w, y = sess.run([in_data, x, w, y])
    print("mean/var should be at:", myFixedPointMean, "/", myFixedPointVar)
    print("Input data mean/var:  ", "{:.12f}".format(np.mean(in_data)), "/", "{:.12f}".format(np.var(in_data)))
    print("After selu:           ", "{:.12f}".format(np.mean(w)), "/", "{:.12f}".format(np.var(w)))
    print("After dropout mean/var", "{:.12f}".format(np.mean(y)), "/", "{:.12f}".format(np.var(y)))

gklambauer (Member) commented Mar 24, 2018

It's simply because the variance is incorrect. The function that generates truncated normals applies the stddev parameter before truncation, but the standard deviation should be sqrt(1/n) after truncation.

My solution was to solve the expression for the variance of the truncated Gaussian for the variance of the non-truncated Gaussian, i.e. to find the pre-truncation stddev that yields the desired post-truncation variance.
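
In code, that correction looks roughly like this (a sketch, assuming truncation at +/- 2 pre-truncation standard deviations, which is what tf.truncated_normal does):

import numpy as np
from scipy.stats import norm

# Variance of a standard normal truncated at +/- a pre-truncation stddevs.
a = 2.0
mass = norm.cdf(a) - norm.cdf(-a)
trunc_var = 1.0 - 2.0 * a * norm.pdf(a) / mass   # ~0.7737

n = 50000
target_std = np.sqrt(1.0 / n)                    # desired stddev *after* truncation
pre_trunc_std = target_std / np.sqrt(trunc_var)  # inflate the parameter to compensate

# weights = tf.truncated_normal([n, 1], mean=0., stddev=pre_trunc_std)

This is the same compensation TensorFlow's VarianceScaling initializer later applies for truncated normal sampling, dividing the stddev by approximately 0.8796.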
