Create unet_lstm.py #207

Merged: 21 commits, Aug 12, 2022
Changes from 14 commits

Commits (21)
75048ca
Create unet_lstm.py
Aakanksha-Rana Jan 19, 2022
d8c1f9b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2022
7df74ec
Update unet_lstm.py
Aakanksha-Rana Jan 20, 2022
80f7817
Merge branch 'master' into 3D_LSTM_Transfomers
Aakanksha-Rana Jan 24, 2022
dea32f6
Merge branch 'master' into 3D_LSTM_Transfomers
satra Feb 28, 2022
2b3a217
Merge branch 'master' into 3D_LSTM_Transfomers
Aakanksha-Rana Mar 9, 2022
dd5310c
Merge branch 'neuronets:master' into 3D_LSTM_Transfomers
Aakanksha-Rana Mar 13, 2022
49396b7
Merge branch 'master' into 3D_LSTM_Transfomers
Aakanksha-Rana May 10, 2022
aba42ce
Merge branch 'master' into 3D_LSTM_Transfomers
satra May 11, 2022
194ea7c
Update unet_lstm.py
Aakanksha-Rana May 11, 2022
f1cf8bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2022
aff4172
updated unet_lstm test
Aakanksha-Rana May 11, 2022
1cd8082
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 11, 2022
449ba7f
Update unet_lstm.py
Aakanksha-Rana May 11, 2022
96b1e44
docstrings
Aakanksha-Rana May 16, 2022
f0a46bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 16, 2022
afeace7
Create max_pool4d.py
Aakanksha-Rana May 16, 2022
2d683b7
Update max_pool4d.py
Aakanksha-Rana May 16, 2022
497095e
Merge branch 'master' into 3D_LSTM_Transfomers
Aakanksha-Rana Jul 25, 2022
295d33f
Merge branch 'master' into 3D_LSTM_Transfomers
Hoda1394 Aug 3, 2022
9d6bc88
Merge branch 'master' into 3D_LSTM_Transfomers
satra Aug 8, 2022
11 changes: 11 additions & 0 deletions nobrainer/models/tests/models_test.py
@@ -13,6 +13,7 @@
from ..meshnet import meshnet
from ..progressivegan import progressivegan
from ..unet import unet
from ..unet_lstm import unet_lstm
from ..vnet import vnet
from ..vox2vox import Vox_ensembler, vox_gan

@@ -208,6 +209,16 @@ def test_bayesian_vnet():
    )


def test_unet_lstm():
    input_shape = (1, 32, 32, 32, 32, 1)  # (batch, time, dim1, dim2, dim3, channels)
    n_classes = 1
    x = 10 * np.random.random(input_shape)
    model = unet_lstm(input_shape=input_shape[1:], n_classes=n_classes)
    actual_output = model.predict(x)
    # the final ConvLSTM3D uses return_sequences=False, so the time axis is collapsed
    assert actual_output.shape == input_shape[:1] + input_shape[2:-1] + (n_classes,)


def test_vox2vox():
    input_shape = (1, 32, 32, 32, 1)
    n_classes = 1
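For context, a minimal training sketch for the new model (not part of this PR; the Adam optimizer, MSE loss, and toy arrays here are illustrative assumptions — the model's final ConvLSTM3D uses a linear activation, so a regression-style loss is plausible):

import numpy as np
import tensorflow as tf

from nobrainer.models.unet_lstm import unet_lstm

# toy data: 1 sample, 32 timepoints of 32x32x32 volumes with 1 channel
x = np.random.random((1, 32, 32, 32, 32, 1)).astype("float32")
# targets drop the time axis, matching the model's return_sequences=False output
y = np.random.random((1, 32, 32, 32, 1)).astype("float32")

model = unet_lstm(input_shape=(32, 32, 32, 32, 1), n_classes=1)
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4), loss="mse")
model.fit(x, y, epochs=1, batch_size=1)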
313 changes: 313 additions & 0 deletions nobrainer/models/unet_lstm.py
@@ -0,0 +1,313 @@
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.regularizers import l2


def unet_lstm(
    n_classes=1,
    input_shape=(32, 32, 32, 8, 1),
    filters=8,
    activation="tanh",
    reg_val=1e-08,
    drop_val=0.0,
    drop_val_recur=0.0,
    name="unet_lstm",
):
    """U-Net built from ConvLSTM3D blocks, modeling the spatial and temporal
    evolution of 3D fields.

    Takes inputs of shape (time, dim1, dim2, dim3, channels) and returns a
    volume of shape (dim1, dim2, dim3, n_classes); the final layer uses
    return_sequences=False, which collapses the time axis.
    """

    batch_norm = False  # hard-coded off; set True to normalize after each block
    concat_axis = -1

    inputs = layers.Input(shape=input_shape)

    # encoder
    x_layer = layers.ConvLSTM3D(
        filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(inputs)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    conv1 = layers.ConvLSTM3D(
        filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        conv1 = layers.BatchNormalization(axis=concat_axis)(conv1)
    x_layer = conv1  # carried forward on the main path and reused as a skip connection

    # x_layer = layers.MaxPooling4D(pool_size=(1, 2, 2, 2))(conv1) ToDo
    x_layer = layers.ConvLSTM3D(
        2 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    conv2 = layers.ConvLSTM3D(
        2 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        conv2 = layers.BatchNormalization(axis=concat_axis)(conv2)
    x_layer = conv2

    # x_layer = layers.MaxPooling4D(pool_size=(1, 2, 2, 2))(conv2) ToDo
    x_layer = layers.ConvLSTM3D(
        4 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    conv3 = layers.ConvLSTM3D(
        4 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        conv3 = layers.BatchNormalization(axis=concat_axis)(conv3)
    x_layer = conv3

    # x_layer = layers.MaxPooling4D(pool_size=(1, 2, 2, 2))(conv3) ToDo
    x_layer = layers.ConvLSTM3D(
        8 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    conv4 = layers.ConvLSTM3D(
        8 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        conv4 = layers.BatchNormalization(axis=concat_axis)(conv4)
    x_layer = conv4

    # bottleneck
    # x_layer = layers.MaxPooling4D(pool_size=(1, 2, 2, 2))(conv4) ToDo
    x_layer = layers.ConvLSTM3D(
        16 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    conv5 = layers.ConvLSTM3D(
        16 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        conv5 = layers.BatchNormalization(axis=concat_axis)(conv5)
    x_layer = conv5

    # decoder
    # x_layer = layers.UpSampling4D(size=(1, 2, 2, 2))(conv5) ToDo
    x_layer = layers.concatenate([x_layer, conv4], axis=concat_axis)

    x_layer = layers.ConvLSTM3D(
        8 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    x_layer = layers.ConvLSTM3D(
        8 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    # x_layer = layers.UpSampling4D(size=(1, 2, 2, 2))(x_layer) ToDo
    x_layer = layers.concatenate([x_layer, conv3], axis=concat_axis)

    x_layer = layers.ConvLSTM3D(
        4 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    x_layer = layers.ConvLSTM3D(
        4 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    # x_layer = layers.UpSampling4D(size=(1, 2, 2, 2))(x_layer) ToDo
    x_layer = layers.concatenate([x_layer, conv2], axis=concat_axis)

    x_layer = layers.ConvLSTM3D(
        2 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    x_layer = layers.ConvLSTM3D(
        2 * filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    # x_layer = layers.UpSampling4D(size=(1, 2, 2, 2))(x_layer) ToDo
    x_layer = layers.concatenate([x_layer, conv1], axis=concat_axis)

    x_layer = layers.ConvLSTM3D(
        filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    x_layer = layers.ConvLSTM3D(
        filters,
        3,
        activation=activation,
        padding="same",
        kernel_regularizer=l2(reg_val),
        recurrent_regularizer=l2(reg_val),
        bias_regularizer=l2(reg_val),
        dropout=drop_val,
        recurrent_dropout=drop_val_recur,
        return_sequences=True,
    )(x_layer)
    if batch_norm:
        x_layer = layers.BatchNormalization(axis=concat_axis)(x_layer)

    # collapse the time axis and map to n_classes output channels
    outputs = layers.ConvLSTM3D(
        n_classes, 1, activation="linear", padding="same", return_sequences=False
    )(x_layer)

    return tf.keras.Model(inputs=inputs, outputs=outputs, name=name)
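The commented-out MaxPooling4D/UpSampling4D calls are placeholders: Keras ships no 4D pooling layer (a later commit in this PR adds a separate max_pool4d.py, not shown in this 14-commit view). One possible workaround for the intended pool_size=(1, 2, 2, 2) — a sketch under that assumption, not the PR's implementation — is to wrap the 3D resampling layers in TimeDistributed so the time axis is left untouched:

from tensorflow.keras import layers

def max_pool4d(x):
    # halve the three spatial axes of a (batch, time, d1, d2, d3, channels)
    # tensor while leaving the time axis intact, i.e. pool_size=(1, 2, 2, 2)
    return layers.TimeDistributed(layers.MaxPooling3D(pool_size=(2, 2, 2)))(x)

def up_sample4d(x):
    # inverse resampling for the decoder path, i.e. size=(1, 2, 2, 2)
    return layers.TimeDistributed(layers.UpSampling3D(size=(2, 2, 2)))(x)

With pooling enabled this way, each skip concatenation would still join tensors of matching spatial size, since the decoder upsamples symmetrically.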