
Commit

Merge pull request #3 from Mr8ND/master
Adding Alexnet examples and LICENSE files.
Mr8ND committed Jun 26, 2019
2 parents 2c75dc0 + d55fbc9 commit 3a6e440
Showing 3 changed files with 229 additions and 1 deletion.
README.md (2 changes: 1 addition & 1 deletion)
@@ -1,6 +1,6 @@
DeepCDE: Neural Networks Conditional Density Estimation
===

Conditional density estimation via deep neural network using CDE loss.
TensorFlow and PyTorch code for conditional density estimation via deep neural networks using the CDE loss.

Authors: Taylor Pospisil, Nic Dalmasso
model_examples/alexnet_pytorch.md (85 changes: 85 additions & 0 deletions)
@@ -0,0 +1,85 @@

### AlexNet - PyTorch
PyTorch implementation of AlexNet for the task of estimating the probability distribution of the correct orientation of an image.
The input to the model consists of (227, 227) colored images with 3 channels (i.e. color bands).
The target is a continuous variable $y \in (0,2\pi)$ for image orientation.

Please note that the attached implementations do not include code for generating training and testing sets.
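
Since the cosine basis is defined on the unit interval, the target is first mapped to $z = y/(2\pi)$ via `box_transform`. By the change-of-variables formula, a density estimate $\hat{f}_Z(z \mid x)$ on $(0,1)$ corresponds to $\hat{f}(y \mid x) = \hat{f}_Z\big(y/(2\pi) \mid x\big)/(2\pi)$ on the original scale; this is the rescaling applied after prediction below.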
```python
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from deepcde.deepcde_pytorch import cde_layer, cde_loss, cde_predict
from deepcde.utils import box_transform
from deepcde.bases.cosine import CosineBasis

# PRE-PROCESSING #############################################################################
# Create basis (in this case, 31 cosine basis functions)
n_basis = 31
basis = CosineBasis(n_basis)

# ... Creation of training and testing set ...

# Evaluate the y_train over the basis
y_train = box_transform(y_train, 0, 2*math.pi) # transform to a variable between 0 and 1
y_basis = basis.evaluate(y_train) # evaluate basis
y_basis = y_basis.astype(np.float32)

# ALEXNET DEFINITION #########################################################################
# `basis_size` is the number of basis functions (in this case 31).
# `marginal_beta` is the initial value for the bias of the CDE layer, if available.
class AlexNetCDE(nn.Module):
    def __init__(self, basis_size, marginal_beta=None):
        super(AlexNetCDE, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)
        self.conv2 = nn.Conv2d(64, 192, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(192, 384, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(384, 256, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(256, 256, kernel_size=3, padding=1)

        self.fc1 = nn.Linear(256 * 6 * 6, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.cde = cde_layer(4096, basis_size - 1)
        if marginal_beta is not None:  # use marginal coefficients as initial bias, if provided
            self.cde.bias.data = torch.from_numpy(marginal_beta[1:]).type(torch.FloatTensor)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), kernel_size=3, stride=2)
        x = F.max_pool2d(F.relu(self.conv2(x)), kernel_size=3, stride=2)
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.max_pool2d(F.relu(self.conv5(x)), kernel_size=3, stride=2)
        x = F.dropout(x.view(x.size(0), 256 * 6 * 6), training=self.training)
        x = F.dropout(F.relu(self.fc1(x)), training=self.training)
        x = F.dropout(F.relu(self.fc2(x)), training=self.training)
        beta = self.cde(x)
        return beta

# Definition of model, loss function and optimizer (examples)
model = AlexNetCDE(basis_size=n_basis)
loss_fn = cde_loss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.5)

# TRAINING #################################################################################
# ... Creation of `training_set_loader` and `testing_set_loader` objects ...
n_epochs = 10  # example value
for epoch in range(n_epochs):
    model.train()
    for batch_idx, (x_batch, y_basis_batch) in enumerate(training_set_loader):
        x_batch, y_basis_batch = x_batch.to(device), y_basis_batch.to(device)
        optimizer.zero_grad()
        beta_batch = model(x_batch)
        batch_loss = loss_fn(beta_batch, y_basis_batch)  # CDE loss on the basis coefficients
        batch_loss.backward()
        optimizer.step()

# ... Evaluation of testing set ...

# PREDICTION ##############################################################################
# ... Selection of `x_test` to get conditional density estimate of ...
model.eval()
y_grid = np.linspace(0, 1, 1000)  # grid over the transformed density range (0, 1)
with torch.no_grad():
    beta_prediction = model(x_test)
cdes = cde_predict(beta_prediction, 0, 1, y_grid, basis, n_basis)
predicted_cdes = cdes / (2 * math.pi)  # change of variables: rescale densities from (0, 1) back to (0, 2*pi)
```
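
As a quick check of the output, one can plot the estimated density over the original orientation range. A minimal sketch, assuming `predicted_cdes` holds one density row of shape `(len(y_grid),)` per test image and that `matplotlib` is available (the plotting code and variable names below are illustrative, not part of the repository):

```python
import matplotlib.pyplot as plt

theta_grid = y_grid * 2 * math.pi  # map the unit grid back to orientations in (0, 2*pi)

plt.plot(theta_grid, predicted_cdes[0])  # estimated conditional density for the first test image
plt.xlabel("Orientation (radians)")
plt.ylabel("Estimated conditional density")
plt.show()

theta_hat = theta_grid[np.argmax(predicted_cdes[0])]  # point estimate: mode of the estimated density
```
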
model_examples/alexnet_tf.md (143 changes: 143 additions & 0 deletions)
@@ -0,0 +1,143 @@
### AlexNet - TensorFlow
TensorFlow implementation of AlexNet for the task of estimating the probability distribution of the correct orientation of an image.
The input to the model consists of (227, 227) colored images with 3 channels (i.e. color bands).
The target is a continuous variable $y \in (0,2\pi)$ for image orientation.

Please note that the attached implementations do not include code for generating training and testing sets.
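
The TensorFlow version assumes the data have already been serialized as `tf.train.Example` records matching the `features` specification in the code below. A minimal sketch of writing one such record (the file name, `x_img`, `y_val` and `y_basis_val` are hypothetical placeholders, not part of the repository):

```python
import tensorflow as tf

# Hypothetical inputs: x_img is a flattened (227*227*3,) float array, y_val the raw
# orientation, and y_basis_val the (n_basis - 1,) basis evaluation of the target.
example = tf.train.Example(features=tf.train.Features(feature={
    'x': tf.train.Feature(float_list=tf.train.FloatList(value=x_img)),
    'y': tf.train.Feature(float_list=tf.train.FloatList(value=[y_val])),
    'y_basis': tf.train.Feature(float_list=tf.train.FloatList(value=y_basis_val)),
}))
with tf.python_io.TFRecordWriter('train.tfrecord') as writer:
    writer.write(example.SerializeToString())
```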


```python
import numpy as np
import tensorflow as tf
import math
from deepcde.deepcde_tensorflow import cde_layer, cde_loss, cde_predict
from deepcde.utils import box_transform
from deepcde.bases.cosine import CosineBasis

# PRE-PROCESSING #############################################################################
# Create basis (in this case, 31 cosine basis functions)
n_basis = 31
basis = CosineBasis(n_basis)

# ... Creation of training and testing set ...

# Evaluate the y_train over the basis
y_train = box_transform(y_train, 0, 2*math.pi) # transform to a variable between 0 and 1
y_basis = basis.evaluate(y_train) # evaluate basis
y_basis = y_basis.astype(np.float32)

# Features and basis are inserted in a dictionary of this form
features = {
    'x': tf.FixedLenFeature([227 * 227 * 3], tf.float32),    # flattened input image
    'y': tf.FixedLenFeature([1], tf.float32),                 # raw orientation target
    'y_basis': tf.FixedLenFeature([n_basis - 1], tf.float32)  # basis evaluation of the target
}
# ... Generation of TensorFlow training and testing data ...

# ALEXNET DEFINITION #########################################################################
weight_sd = 0.01 # sd parameter for initialization of weights
marginal_beta = None # Initialization parameter for bias in CDE layer

def model_function(features, labels, mode,
                   weight_sd=weight_sd, marginal_beta=marginal_beta):
    y_basis = labels  # tf.estimator passes the basis-evaluated target as `labels`

    # Input Layer
    input_layer = tf.reshape(features, shape=[-1, 227, 227, 3])

    # Convolutional Layer 1
    conv1 = tf.layers.conv2d(
        inputs=input_layer,
        filters=64,
        kernel_size=[11, 11],
        strides=[4, 4],
        padding="same",
        activation=tf.nn.relu)
    pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[3, 3], strides=2)

    # Convolutional Layer 2
    conv2 = tf.layers.conv2d(
        inputs=pool1,
        filters=192,
        kernel_size=[5, 5],
        strides=[1, 1],
        padding="same",
        activation=tf.nn.relu)
    pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[3, 3], strides=2)

    # Convolutional Layers 3, 4 and 5
    conv3 = tf.layers.conv2d(
        inputs=pool2,
        filters=384,
        kernel_size=[3, 3],
        strides=[1, 1],
        padding="same",
        activation=tf.nn.relu)

    conv4 = tf.layers.conv2d(
        inputs=conv3,
        filters=256,
        kernel_size=[3, 3],
        strides=[1, 1],
        padding="same",
        activation=tf.nn.relu)

    conv5 = tf.layers.conv2d(
        inputs=conv4,
        filters=256,
        kernel_size=[3, 3],
        strides=[1, 1],
        padding="same",
        activation=tf.nn.relu)
    pool5 = tf.layers.max_pooling2d(inputs=conv5, pool_size=[3, 3], strides=2)
    pool5_flat = tf.reshape(pool5, [-1, 256 * 6 * 6])

    # Dense Layers
    dense_1 = tf.layers.dense(inputs=pool5_flat, units=4096, activation=tf.nn.relu)
    dropout_1 = tf.layers.dropout(inputs=dense_1, rate=0.5,
                                  training=mode == tf.estimator.ModeKeys.TRAIN)

    dense_2 = tf.layers.dense(inputs=dropout_1, units=4096, activation=tf.nn.relu)
    dropout_2 = tf.layers.dropout(inputs=dense_2, rate=0.5,
                                  training=mode == tf.estimator.ModeKeys.TRAIN)

    # CDE Layer
    beta = cde_layer(dropout_2, weight_sd, marginal_beta)

    # Loss Computation
    loss = cde_loss(beta, y_basis)
    metrics = {
        'cde_loss': tf.metrics.mean(loss)  # mean CDE loss as an evaluation metric
    }

    # Training and Evaluation Steps
    if mode == tf.estimator.ModeKeys.EVAL:
        return tf.estimator.EstimatorSpec(mode,
                                          loss=loss,
                                          eval_metric_ops=metrics)

    elif mode == tf.estimator.ModeKeys.TRAIN:
        # Get train operator, using Adam for instance
        train_op = tf.train.AdamOptimizer().minimize(
            loss, global_step=tf.train.get_or_create_global_step())
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

    else:
        raise NotImplementedError('Unknown mode {}'.format(mode))

# TRAINING #################################################################################
cfg = tf.estimator.RunConfig(save_checkpoints_secs=120)
estimator = tf.estimator.Estimator(model_fn=model_function, model_dir=model_dir, config=cfg)
# ... Creation of `train_spec` and `eval_spec` (`tf.estimator.TrainSpec` and `tf.estimator.EvalSpec`) objects ...
# ... Inclusion of all extra parameters like learning rate, momentum, etc. ...
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)


# PREDICTION ##############################################################################
# ... Selection of `x_test` to get conditional density estimate of ...
y_grid = np.linspace(0, 1, 1000)  # grid over the transformed density range (0, 1)
beta = tf.placeholder(tf.float32, [1, n_basis - 1])
with tf.Session() as sess:
    # ... Resume model from checkpoint ...
    cdes = cde_predict(sess, beta, 0, 1, y_grid, basis, n_basis,
                       input_dict={'features': x_test})
    cdes = cdes / (2 * math.pi)  # change of variables: rescale densities from (0, 1) back to (0, 2*pi)
```
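
The `train_spec` and `eval_spec` objects elided above wrap input functions that read back the serialized records. A minimal sketch under the same assumptions (file names are placeholders; this follows the standard TF 1.x `tf.data` plus `Estimator` pattern rather than code from the repository):

```python
def make_input_fn(path, batch_size=128, shuffle=False):
    def input_fn():
        def parse_fn(record):
            parsed = tf.parse_single_example(record, features)  # the `features` spec defined above
            return parsed['x'], parsed['y_basis']               # model input and basis-evaluated target
        dataset = tf.data.TFRecordDataset(path).map(parse_fn)
        if shuffle:
            dataset = dataset.shuffle(buffer_size=1000)
        return dataset.batch(batch_size)
    return input_fn

train_spec = tf.estimator.TrainSpec(input_fn=make_input_fn('train.tfrecord', shuffle=True))
eval_spec = tf.estimator.EvalSpec(input_fn=make_input_fn('test.tfrecord'))
```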
