
Commit

initial import from development repo
enalisnick committed May 20, 2016
1 parent be45ed2 commit e2839cd
Showing 21 changed files with 1,716 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Expand Up @@ -60,3 +60,6 @@ target/

#Ipython Notebook
.ipynb_checkpoints

*~
*.pyc
Empty file added models/__init__.py
Empty file.
Empty file added models/neural_net/__init__.py
Empty file.
23 changes: 23 additions & 0 deletions models/neural_net/activation_fns.py
@@ -0,0 +1,23 @@
import theano.tensor as T

def ReLU(x):
y = T.maximum(0.0, x)
return(y)

def Sigmoid(x):
y = T.nnet.sigmoid(x)
return(y)

def Softplus(x):
y = T.nnet.softplus(x)
return(y)

def Softmax(x):
y = T.nnet.softmax(x)
return(y)

def Identity(x):
return(x)

def Beta_fn(a, b):
    # Beta function computed via the log-Gamma function for numerical stability
    return T.exp(T.gammaln(a) + T.gammaln(b) - T.gammaln(a+b))
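
A minimal usage sketch (hypothetical, not part of this commit) showing how these activations compile into an ordinary Theano function; the import path assumes the package layout introduced above.

import numpy as np
import theano
import theano.tensor as T
from models.neural_net.activation_fns import ReLU, Softmax

x = T.matrix('x')
f = theano.function([x], [ReLU(x), Softmax(x)])
batch = np.asarray([[-1.0, 0.5], [2.0, -3.0]], dtype=theano.config.floatX)
relu_out, softmax_out = f(batch)   # ReLU zeroes the negatives; each Softmax row sums to 1
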
33 changes: 33 additions & 0 deletions models/neural_net/layers.py
@@ -0,0 +1,33 @@
import numpy as np
import theano
import theano.tensor as T

class HiddenLayer(object):
def __init__(self, rng, input, n_in, n_out, activation, W=None, b=None):
self.input = input
self.activation = activation

if W is None:
W_values = np.asarray(0.01 * rng.standard_normal(size=(n_in, n_out)), dtype=theano.config.floatX)
W = theano.shared(value=W_values, name='W')
if b is None:
b_values = np.zeros((n_out,), dtype=theano.config.floatX)
b = theano.shared(value=b_values, name='b')
self.W = W
self.b = b

self.output = self.activation( T.dot(self.input, self.W) + self.b )

# parameters of the model
self.params = [self.W, self.b]


class ResidualHiddenLayer(HiddenLayer):
def __init__(self, rng, input, n_in, n_out, activation, W=None, b=None):

super(ResidualHiddenLayer, self).__init__(rng=rng, input=input, n_in=n_in, n_out=n_out, W=W, b=b, activation=activation)

        # residual connection: output = F(h_{l-1}) + h_{l-1}; requires n_in == n_out so the shapes match
self.output += self.input
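
A hypothetical construction sketch for the two layer classes above (sizes are made up): the residual layer adds its input to its output, so it only makes sense when n_in equals n_out.

import numpy as np
import theano.tensor as T
from models.neural_net.layers import HiddenLayer, ResidualHiddenLayer
from models.neural_net.activation_fns import ReLU

rng = np.random.RandomState(1234)
x = T.matrix('x')
h1 = HiddenLayer(rng, input=x, n_in=784, n_out=200, activation=ReLU)
h2 = ResidualHiddenLayer(rng, input=h1.output, n_in=200, n_out=200, activation=ReLU)  # n_in == n_out
params = h1.params + h2.params  # shared variables to hand to the optimizer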


34 changes: 34 additions & 0 deletions models/neural_net/loss_fns.py
@@ -0,0 +1,34 @@
import theano.tensor as T

def calc_binaryVal_negative_log_likelihood(data, probabilities, axis_to_sum=1):
if axis_to_sum != 1:
# addresses the case where we marginalize
data = T.extra_ops.repeat(T.shape_padaxis(data, axis=1), repeats = probabilities.shape[1], axis=1)
return - T.sum(data * T.log(probabilities) + (1 - data) * T.log(1 - probabilities), axis=axis_to_sum)

def calc_categoricalVal_negative_log_likelihood(data, probabilities, axis_to_sum=1):
if axis_to_sum != 1:
# addresses the case where we marginalize
data = T.extra_ops.repeat(T.shape_padaxis(data, axis=1), repeats = probabilities.shape[1], axis=1)
return - T.sum(data * T.log(probabilities), axis=axis_to_sum)

def calc_realVal_negative_log_likelihood(data, recon, axis_to_sum=1):
if axis_to_sum != 1:
# addresses the case where we marginalize
data = T.extra_ops.repeat(T.shape_padaxis(data, axis=1), repeats = recon.shape[1], axis=1)
    # Gaussian negative log-likelihood with unit variance (constant terms dropped)
    return .5 * T.sum( (data - recon)**2, axis=axis_to_sum )

def calc_poissonVal_negative_log_likelihood(data, recon, axis_to_sum=1):
if axis_to_sum != 1:
# addresses the case where we marginalize
data = T.extra_ops.repeat(T.shape_padaxis(data, axis=1), repeats = recon.shape[1], axis=1)
    # Poisson negative log-likelihood with recon as the log-rate (log(data!) term dropped)
    return T.sum( T.exp(recon) - data * recon, axis=axis_to_sum )

def calc_cat_entropy(probabilities):
return - T.sum(probabilities * T.log(probabilities), axis=1)

def calc_cat_kl_divergence(p1, p2):
return -T.sum(p1 * T.log(p2), axis=1) - calc_cat_entropy(p1)

def calc_prediction_errors(class_idxs, pred_idxs):
return T.sum(T.neq(class_idxs, pred_idxs))
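
A small sanity-check sketch (not part of the commit) for the binary negative log-likelihood; with the default axis_to_sum=1 it returns one value per row of the batch.

import numpy as np
import theano
import theano.tensor as T
from models.neural_net.loss_fns import calc_binaryVal_negative_log_likelihood

x = T.matrix('x')
p = T.matrix('p')
nll = theano.function([x, p], calc_binaryVal_negative_log_likelihood(x, p))
data = np.asarray([[1., 0.], [0., 1.]], dtype=theano.config.floatX)
probs = np.asarray([[0.9, 0.2], [0.1, 0.8]], dtype=theano.config.floatX)
print(nll(data, probs))   # one negative log-likelihood per example
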
32 changes: 32 additions & 0 deletions models/neural_net/noise_fns.py
@@ -0,0 +1,32 @@
import numpy as np
import theano
import theano.tensor as T

def no_noise(input):
    # identity noise function -- needed because the DNN pseudo-ensemble code assumes every input / hidden layer is given a noise function
return input

def dropout_noise(rng, input, p=0.5):

srng = theano.tensor.shared_randomstreams.RandomStreams(rng.randint(999999))

# Bernoulli(1-p) multiplicative noise
mask = T.cast(srng.binomial(n=1, p=1-p, size=input.shape), theano.config.floatX)
return mask * input


def beta_noise(rng, input):

srng = theano.tensor.shared_randomstreams.RandomStreams(rng.randint(999999))

# Beta(.5,.5) multiplicative noise
mask = T.cast(T.sin( (np.pi / 2.0) * srng.uniform(size=input.shape, low=0.0, high=1.0) )**2, theano.config.floatX)
return mask * input

def poisson_noise(rng, input, lam=0.5):

srng = theano.tensor.shared_randomstreams.RandomStreams(rng.randint(999999))

# Poisson noise
mask = T.cast(srng.poisson(lam=lam, size=input.shape), theano.config.floatX)
return mask * input
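
A usage sketch for the multiplicative-noise functions (hypothetical, not part of the commit): each evaluation of the compiled function draws a fresh Bernoulli(1-p) mask, so the same batch is corrupted differently on every call.

import numpy as np
import theano
import theano.tensor as T
from models.neural_net.noise_fns import dropout_noise

rng = np.random.RandomState(1234)
x = T.matrix('x')
corrupt = theano.function([x], dropout_noise(rng, x, p=0.5))
batch = np.ones((2, 4), dtype=theano.config.floatX)
print(corrupt(batch))   # roughly half the entries zeroed, re-sampled on every call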
84 changes: 84 additions & 0 deletions models/ss_Gauss_DGM.py
@@ -0,0 +1,84 @@
import numpy as np
import theano
import theano.tensor as T

from variational_coders.encoders import Gauss_Encoder_w_Labels
from variational_coders.decoders import Supervised_Decoder, Marginalized_Decoder

### Gaussian Semi-Supervised DGM ###
class SS_Gaussian_DGM(object):
def __init__(self, rng, sup_input, un_sup_input, labels,
sup_batch_size, un_sup_batch_size,
layer_sizes, layer_types, activations,
label_size, latent_size, out_activation, label_fn): # architecture specs

# check lists are correct sizes
assert len(layer_types) == len(layer_sizes) - 1
assert len(activations) == len(layer_sizes) - 1
assert label_size > 1 # labels need to be one-hot encoded!

# Set up the NN that parametrizes the encoder
layer_specs = zip(layer_types, layer_sizes, layer_sizes[1:])
self.encoding_layers = []
next_sup_layer_input = sup_input
next_un_sup_layer_input = un_sup_input
activation_counter = 0
for layer_type, n_in, n_out in layer_specs:
next_sup_layer = layer_type(rng=rng, input=next_sup_layer_input, activation=activations[activation_counter], n_in=n_in, n_out=n_out)
next_sup_layer_input = next_sup_layer.output
self.encoding_layers.append(next_sup_layer)
next_un_sup_layer = layer_type(rng=rng, input=next_un_sup_layer_input, activation=activations[activation_counter], n_in=n_in, n_out=n_out, W=next_sup_layer.W, b=next_sup_layer.b)
next_un_sup_layer_input = next_un_sup_layer.output
activation_counter += 1

# init encoders -- one supervised, one un_supervised
self.supervised_encoder = Gauss_Encoder_w_Labels(rng, input=next_sup_layer_input, batch_size=sup_batch_size, in_size=layer_sizes[-1], label_size=label_size, latent_size=latent_size, label_fn=label_fn)
self.un_supervised_encoder = Gauss_Encoder_w_Labels(rng, input=next_un_sup_layer_input, batch_size=un_sup_batch_size, in_size=layer_sizes[-1], label_size=label_size, latent_size=latent_size,
label_fn=label_fn,
W_y = self.supervised_encoder.W_y, b_y = self.supervised_encoder.b_y,
W_mu = self.supervised_encoder.W_mu, W_sigma = self.supervised_encoder.W_sigma)

        # init decoders -- one supervised, one un_supervised
self.supervised_decoder = Supervised_Decoder(rng, input=self.supervised_encoder.latent_vars, labels=labels, latent_size=latent_size,
label_size=label_size, out_size=layer_sizes[-1], activation=activations[-1])
self.un_supervised_decoder = Marginalized_Decoder(rng, input=self.un_supervised_encoder.latent_vars, batch_size=un_sup_batch_size, latent_size=latent_size, label_size=label_size,
out_size=layer_sizes[-1], activation=activations[-1],
W_z=self.supervised_decoder.W_z, W_y=self.supervised_decoder.W_y, b=self.supervised_decoder.b)

# setup the NN that parametrizes the decoder (generative model)
layer_specs = zip(reversed(layer_types), reversed(layer_sizes), reversed(layer_sizes[:-1]))
self.decoding_layers = []
        # prepend the output activation as the first activation; the final activation is handled by the decoder
activations = [out_activation] + activations[:-1]
activation_counter = len(activations)-1
next_sup_layer_input = self.supervised_decoder.output
next_un_sup_layer_input = self.un_supervised_decoder.output
for layer_type, n_in, n_out in layer_specs:
# supervised decoding layers
next_sup_layer = layer_type(rng=rng, input=next_sup_layer_input, activation=activations[activation_counter], n_in=n_in, n_out=n_out)
next_sup_layer_input = next_sup_layer.output
self.decoding_layers.append(next_sup_layer)
# un supervised decoding layers
next_un_sup_layer = layer_type(rng=rng, input=next_un_sup_layer_input, activation=activations[activation_counter], n_in=n_in, n_out=n_out, W=next_sup_layer.W, b=next_sup_layer.b)
next_un_sup_layer_input = next_un_sup_layer.output
activation_counter -= 1

        # Grab all the parameters -- only the supervised half is needed since the un_supervised parameters are tied to it
self.params = [p for layer in self.encoding_layers for p in layer.params] + self.supervised_encoder.params + self.supervised_decoder.params + [p for layer in self.decoding_layers for p in layer.params]

# Grab the posterior params
self.sup_post_mu = self.supervised_encoder.mu
self.sup_post_log_sigma = self.supervised_encoder.log_sigma
self.un_sup_post_mu = self.un_supervised_encoder.mu
self.un_sup_post_log_sigma = self.un_supervised_encoder.log_sigma

# grab the kl-divergence functions
self.calc_sup_kl_divergence = self.supervised_encoder.calc_kl_divergence
self.calc_un_sup_kl_divergence = self.un_supervised_encoder.calc_kl_divergence

# Grab the reconstructions and predictions
self.x_recon_sup = next_sup_layer_input
self.x_recon_un_sup = next_un_sup_layer_input
self.y_probs_sup = self.supervised_encoder.y_probs
self.y_probs_un_sup = self.un_supervised_encoder.y_probs
self.y_preds_sup = T.argmax(self.y_probs_sup, axis=1)
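
A hypothetical construction sketch for the class above, with made-up MNIST-like sizes. It only illustrates the constructor signature shown in this file, and it assumes the variational_coders package imported at the top of the file provides the encoder and decoder classes referenced there.

import numpy as np
import theano.tensor as T
from models.neural_net.layers import HiddenLayer
from models.neural_net.activation_fns import ReLU, Sigmoid, Softmax
from models.ss_Gauss_DGM import SS_Gaussian_DGM

rng = np.random.RandomState(123)
x_sup, x_un_sup = T.matrix('x_sup'), T.matrix('x_un_sup')
y = T.matrix('y')   # one-hot labels, so label_size > 1

model = SS_Gaussian_DGM(rng, sup_input=x_sup, un_sup_input=x_un_sup, labels=y,
                        sup_batch_size=100, un_sup_batch_size=100,
                        layer_sizes=[784, 500, 500], layer_types=[HiddenLayer, HiddenLayer],
                        activations=[ReLU, ReLU], label_size=10, latent_size=50,
                        out_activation=Sigmoid, label_fn=Softmax)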
85 changes: 85 additions & 0 deletions models/ss_StickBreaking_DGM.py
@@ -0,0 +1,85 @@
import numpy as np
import theano
import theano.tensor as T

from variational_coders.encoders import StickBreaking_Encoder_w_Labels
from variational_coders.decoders import Supervised_Decoder, Marginalized_Decoder

### Stick-Breaking Semi-Supervised DGM ###
class SS_StickBreaking_DGM(object):
def __init__(self, rng, sup_input, un_sup_input, labels,
sup_batch_size, un_sup_batch_size,
layer_sizes, layer_types, activations,
label_size, latent_size, out_activation, label_fn): # architecture specs

# check lists are correct sizes
assert len(layer_types) == len(layer_sizes) - 1
assert len(activations) == len(layer_sizes) - 1
assert label_size > 1 # labels need to be one-hot encoded!

# Set up the NN that parametrizes the encoder
layer_specs = zip(layer_types, layer_sizes, layer_sizes[1:])
self.encoding_layers = []
next_sup_layer_input = sup_input
next_un_sup_layer_input = un_sup_input
activation_counter = 0
for layer_type, n_in, n_out in layer_specs:
next_sup_layer = layer_type(rng=rng, input=next_sup_layer_input, activation=activations[activation_counter], n_in=n_in, n_out=n_out)
next_sup_layer_input = next_sup_layer.output
self.encoding_layers.append(next_sup_layer)
next_un_sup_layer = layer_type(rng=rng, input=next_un_sup_layer_input, activation=activations[activation_counter], n_in=n_in, n_out=n_out, W=next_sup_layer.W, b=next_sup_layer.b)
next_un_sup_layer_input = next_un_sup_layer.output
activation_counter += 1

# init encoders -- one supervised, one un_supervised
self.supervised_encoder = StickBreaking_Encoder_w_Labels(rng, input=next_sup_layer_input, batch_size=sup_batch_size,
in_size=layer_sizes[-1], label_size=label_size, latent_size=latent_size, label_fn=label_fn)
self.un_supervised_encoder = StickBreaking_Encoder_w_Labels(rng, input=next_un_sup_layer_input, batch_size=un_sup_batch_size,
in_size=layer_sizes[-1], label_size=label_size, latent_size=latent_size, label_fn=label_fn,
W_y = self.supervised_encoder.W_y, b_y = self.supervised_encoder.b_y,
W_a = self.supervised_encoder.W_a, W_b = self.supervised_encoder.W_b)

        # init decoders -- one supervised, one un_supervised
self.supervised_decoder = Supervised_Decoder(rng, input=self.supervised_encoder.latent_vars, labels=labels, latent_size=latent_size,
label_size=label_size, out_size=layer_sizes[-1], activation=activations[-1])
self.un_supervised_decoder = Marginalized_Decoder(rng, input=self.un_supervised_encoder.latent_vars, batch_size=un_sup_batch_size, latent_size=latent_size, label_size=label_size,
out_size=layer_sizes[-1], activation=activations[-1],
W_z=self.supervised_decoder.W_z, W_y=self.supervised_decoder.W_y, b=self.supervised_decoder.b)

# setup the NN that parametrizes the decoder (generative model)
layer_specs = zip(reversed(layer_types), reversed(layer_sizes), reversed(layer_sizes[:-1]))
self.decoding_layers = []
        # prepend the output activation as the first activation; the final activation is handled by the decoder
activations = [out_activation] + activations[:-1]
activation_counter = len(activations)-1
next_sup_layer_input = self.supervised_decoder.output
next_un_sup_layer_input = self.un_supervised_decoder.output
for layer_type, n_in, n_out in layer_specs:
# supervised decoding layers
next_sup_layer = layer_type(rng=rng, input=next_sup_layer_input, activation=activations[activation_counter], n_in=n_in, n_out=n_out)
next_sup_layer_input = next_sup_layer.output
self.decoding_layers.append(next_sup_layer)
# un supervised decoding layers
next_un_sup_layer = layer_type(rng=rng, input=next_un_sup_layer_input, activation=activations[activation_counter], n_in=n_in, n_out=n_out, W=next_sup_layer.W, b=next_sup_layer.b)
next_un_sup_layer_input = next_un_sup_layer.output
activation_counter -= 1

        # Grab all the parameters -- only the supervised half is needed since the un_supervised parameters are tied to it
self.params = [p for layer in self.encoding_layers for p in layer.params] + self.supervised_encoder.params + self.supervised_decoder.params + [p for layer in self.decoding_layers for p in layer.params]

# Grab the posterior params
self.sup_post_a = self.supervised_encoder.a
self.sup_post_b = self.supervised_encoder.b
self.un_sup_post_a = self.un_supervised_encoder.a
self.un_sup_post_b = self.un_supervised_encoder.b

# grab the kl-divergence functions
self.calc_sup_kl_divergence = self.supervised_encoder.calc_kl_divergence
self.calc_un_sup_kl_divergence = self.un_supervised_encoder.calc_kl_divergence

# Grab the reconstructions and predictions
self.x_recon_sup = next_sup_layer_input
self.x_recon_un_sup = next_un_sup_layer_input
self.y_probs_sup = self.supervised_encoder.y_probs
self.y_probs_un_sup = self.un_supervised_encoder.y_probs
self.y_preds_sup = T.argmax(self.y_probs_sup, axis=1)
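
SS_StickBreaking_DGM exposes the same attributes as the Gaussian model (reconstructions, label probabilities, KL functions), so it drops into the same training code. Continuing the hypothetical sketch shown after ss_Gauss_DGM.py (reusing x_sup, y, and model from there), label predictions and error counts can be compiled as follows.

import theano
import theano.tensor as T
from models.neural_net.loss_fns import calc_prediction_errors

predict = theano.function([x_sup], model.y_preds_sup)
count_errors = theano.function([x_sup, y],
                               calc_prediction_errors(T.argmax(y, axis=1), model.y_preds_sup))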
Empty file.
61 changes: 61 additions & 0 deletions models/variational_coders/decoders.py
@@ -0,0 +1,61 @@
import numpy as np
import theano
import theano.tensor as T

### Regular Decoder
class Decoder(object):
def __init__(self, rng, input, latent_size, out_size, activation, W_z = None, b = None):
self.input = input
self.activation = activation

# setup the params
if W_z is None:
W_values = np.asarray(0.01 * rng.standard_normal(size=(latent_size, out_size)), dtype=theano.config.floatX)
W_z = theano.shared(value=W_values, name='W_hid_z')
if b is None:
b_values = np.zeros((out_size,), dtype=theano.config.floatX)
b = theano.shared(value=b_values, name='b')
self.W_z = W_z
self.b = b

self.pre_act_out = T.dot(self.input, self.W_z) + self.b
self.output = self.activation(self.pre_act_out)

# gather parameters
self.params = [self.W_z, self.b]

### Supervised Decoder
class Supervised_Decoder(Decoder):
def __init__(self, rng, input, labels, latent_size, label_size, out_size, activation, W_z = None, W_y = None, b = None):
self.labels = labels

# init parent class
super(Supervised_Decoder, self).__init__(rng=rng, input=input, latent_size=latent_size, out_size=out_size, activation=activation, W_z=W_z, b=b)

# setup the params
if W_y is None:
W_values = np.asarray(0.01 * rng.standard_normal(size=(label_size, out_size)), dtype=theano.config.floatX)
W_y = theano.shared(value=W_values, name='W_y')
self.W_y = W_y

self.output = self.activation( self.pre_act_out + T.dot(self.labels, self.W_y) )

# gather parameters
self.params += [self.W_y]

### Marginalized Decoder (for semi-supervised model)
class Marginalized_Decoder(Decoder):
def __init__(self, rng, input, batch_size, latent_size, label_size, out_size, activation, W_z, W_y, b):

# init parent class
super(Marginalized_Decoder, self).__init__(rng=rng, input=input, latent_size=latent_size, out_size=out_size, activation=activation, W_z=W_z, b=b)

# setup the params
self.W_y = W_y

# compute marginalized outputs
labels_tensor = T.extra_ops.repeat( T.shape_padaxis(T.eye(n=label_size, m=label_size), axis=0), repeats=batch_size, axis=0)
self.output = self.activation(T.extra_ops.repeat(T.shape_padaxis(T.dot(self.input, self.W_z), axis=1), repeats=label_size, axis=1) + T.dot(labels_tensor, self.W_y) + self.b)

# no params here since we'll grab them from the supervised decoder
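
A shape-check sketch for the marginalized decoder (hypothetical sizes, not part of the commit): its output carries one slice per possible label, shape (batch_size, label_size, out_size), which matches the marginalized branch of the loss functions above (axis_to_sum=2).

import numpy as np
import theano
import theano.tensor as T
from models.neural_net.activation_fns import Sigmoid
from models.variational_coders.decoders import Decoder, Marginalized_Decoder

rng = np.random.RandomState(0)
z = T.matrix('z')
sup_dec = Decoder(rng, input=z, latent_size=50, out_size=200, activation=Sigmoid)
W_y = theano.shared(np.asarray(0.01 * rng.standard_normal(size=(10, 200)),
                               dtype=theano.config.floatX), name='W_y')   # (label_size, out_size)
marg_dec = Marginalized_Decoder(rng, input=z, batch_size=100, latent_size=50, label_size=10,
                                out_size=200, activation=Sigmoid,
                                W_z=sup_dec.W_z, W_y=W_y, b=sup_dec.b)
f = theano.function([z], marg_dec.output)
print(f(np.zeros((100, 50), dtype=theano.config.floatX)).shape)   # (100, 10, 200)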
