Skip to content
This repository has been archived by the owner on Jun 26, 2021. It is now read-only.

Commit

Permalink
Merge pull request #92 from justusschock/fix_tf_resnet18
Browse files Browse the repository at this point in the history
Fix tf resnet18
  • Loading branch information
ORippler committed May 6, 2019
2 parents 25702e0 + 5b55552 commit 9a133c0
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 49 deletions.
170 changes: 124 additions & 46 deletions delira/models/classification/ResNet18.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,68 +6,146 @@
relu = tf.keras.layers.ReLU
gap2d = tf.keras.layers.GlobalAveragePooling2D
batchnorm2d = tf.keras.layers.BatchNormalization
add = tf.keras.layers.Add


def get_image_format_and_axis():
    """
    Look up the current keras image data format and map it to the
    corresponding channel-axis index.

    Returns
    -------
    str
        image data format (either "channels_first" or "channels_last")
    int
        integer corresponding to the channel_axis (either 1 or -1)

    Raises
    ------
    RuntimeError
        if the backend reports an unknown image data format
    """
    axis_by_format = {"channels_first": 1, "channels_last": -1}
    image_format = tf.keras.backend.image_data_format()
    if image_format not in axis_by_format:
        raise RuntimeError(
            "Image format unknown, got: {}".format(image_format)
        )
    return image_format, axis_by_format[image_format]


class ResBlock(tf.keras.Model):
    """Basic two-convolution residual block (ResNet v1 style).

    Main path: conv -> bn -> relu -> conv -> bn. The block output is
    ``relu(main_path(x) + shortcut(x))``, where the shortcut is the
    identity when input and output shapes match and a 1x1 projection
    convolution (followed by batchnorm) otherwise.

    Parameters
    ----------
    filters_in : int
        number of channels of the block input
    filters : int
        number of channels produced by the block
    strides : tuple
        strides of the first and second main-path convolution
    kernel_size : int
        kernel size of both main-path convolutions
    bias : bool
        whether the convolutions use a bias term (usually False, since
        a batchnorm directly follows each convolution)
    """

    def __init__(self, filters_in: int, filters: int,
                 strides: tuple, kernel_size: int, bias=False):
        super(ResBlock, self).__init__()

        _, _axis = get_image_format_and_axis()

        self.identity = None
        # A projection shortcut is needed whenever the main path changes
        # the feature-map shape: either the channel count changes OR the
        # first convolution downsamples spatially. (The original check on
        # filters alone would break for stride>1 with equal filters.)
        if filters_in != filters or strides[0] != 1:
            self.identity = conv2d(
                filters=filters, strides=strides[0],
                kernel_size=1, padding='same', use_bias=bias)
            self.bnorm_identity = batchnorm2d(axis=_axis)

        self.conv_1 = conv2d(
            filters=filters, strides=strides[0],
            kernel_size=kernel_size,
            padding='same', use_bias=bias)
        self.batchnorm_1 = batchnorm2d(axis=_axis)

        self.conv_2 = conv2d(
            filters=filters, strides=strides[1],
            kernel_size=kernel_size,
            padding='same', use_bias=bias)
        self.batchnorm_2 = batchnorm2d(axis=_axis)

        self.relu = relu()
        self.add = add()

    def call(self, inputs, training=None):
        """Apply the residual block.

        Parameters
        ----------
        inputs : tensor
            input feature map
        training : bool or None
            forwarded to the batchnorm layers (batch statistics are used
            during training, moving statistics otherwise)

        Returns
        -------
        tensor
            output feature map
        """
        # explicit None-check: do not rely on the truthiness of a layer
        if self.identity is not None:
            shortcut = self.identity(inputs)
            shortcut = self.bnorm_identity(shortcut, training=training)
        else:
            shortcut = inputs

        x = self.conv_1(inputs)
        x = self.batchnorm_1(x, training=training)
        x = self.relu(x)
        x = self.conv_2(x)
        x = self.batchnorm_2(x, training=training)
        x = self.add([x, shortcut])
        x = self.relu(x)

        return x


class ResNet18(tf.keras.Model):
    """ResNet-18 assembled from :class:`ResBlock` units.

    Architecture: a stem (strided 7x7 convolution + batchnorm + relu +
    maxpool) followed by four stages of two residual blocks each with
    64, 128, 256 and 512 filters (stages 3-5 downsample by 2 in their
    first block), then global average pooling and one dense output layer.

    NOTE(review): this span was diff residue — stale pre-merge lines
    (plain conv2_x/conv5_x layers, dense1/dense2 head, old call body)
    were interleaved with the post-merge code and made the class invalid
    Python. This body is the clean post-merge version.

    Parameters
    ----------
    num_classes : int or None
        number of units of the final dense layer
    bias : bool
        whether convolutions use a bias term; forwarded to the residual
        blocks (usually False, since batchnorm follows each convolution)
    """

    def __init__(self, num_classes=None, bias=False):
        super(ResNet18, self).__init__()

        # channel axis depends on the backend's image data format
        _image_format, _axis = get_image_format_and_axis()

        # stem
        self.conv1 = conv2d(filters=64, strides=2, kernel_size=7,
                            padding='same', use_bias=bias)
        self.batchnorm1 = batchnorm2d(axis=_axis)
        self.relu = relu()
        self.pool1 = maxpool2d(pool_size=3, strides=2)

        # stage 2: 64 filters, no downsampling
        self.block_2_1 = ResBlock(filters_in=64, filters=64,
                                  strides=(1, 1), kernel_size=3,
                                  bias=bias)
        self.block_2_2 = ResBlock(filters_in=64, filters=64,
                                  strides=(1, 1), kernel_size=3,
                                  bias=bias)

        # stage 3: 128 filters, first block downsamples by 2
        self.block_3_1 = ResBlock(filters_in=64, filters=128,
                                  strides=(2, 1), kernel_size=3,
                                  bias=bias)
        self.block_3_2 = ResBlock(filters_in=128, filters=128,
                                  strides=(1, 1), kernel_size=3,
                                  bias=bias)

        # stage 4: 256 filters, first block downsamples by 2
        self.block_4_1 = ResBlock(filters_in=128, filters=256,
                                  strides=(2, 1), kernel_size=3,
                                  bias=bias)
        self.block_4_2 = ResBlock(filters_in=256, filters=256,
                                  strides=(1, 1), kernel_size=3,
                                  bias=bias)

        # stage 5: 512 filters, first block downsamples by 2
        self.block_5_1 = ResBlock(filters_in=256, filters=512,
                                  strides=(2, 1), kernel_size=3,
                                  bias=bias)
        self.block_5_2 = ResBlock(filters_in=512, filters=512,
                                  strides=(1, 1), kernel_size=3,
                                  bias=bias)

        # classification head; gap is built with the explicit data
        # format so pooling reduces the spatial axes in either layout
        self.dense = dense(num_classes)
        self.gap = gap2d(data_format=_image_format)

    def call(self, inputs, training=None):
        """Forward pass.

        Parameters
        ----------
        inputs : tensor
            input image batch
        training : bool or None
            forwarded to every batchnorm layer (directly in the stem,
            via the residual blocks elsewhere)

        Returns
        -------
        tensor
            class scores of shape (batch, num_classes)
        """
        x = self.conv1(inputs)
        # propagate the training flag so batchnorm switches between
        # batch statistics (training) and moving statistics (inference)
        x = self.batchnorm1(x, training=training)
        x = self.relu(x)
        x = self.pool1(x)

        x = self.block_2_1(x, training=training)
        x = self.block_2_2(x, training=training)

        x = self.block_3_1(x, training=training)
        x = self.block_3_2(x, training=training)

        x = self.block_4_1(x, training=training)
        x = self.block_4_2(x, training=training)

        x = self.block_5_1(x, training=training)
        x = self.block_5_2(x, training=training)

        x = self.gap(x)
        x = self.dense(x)

        return x
4 changes: 1 addition & 3 deletions delira/models/classification/classification_network_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from delira.models.abstract_network import AbstractTfNetwork
from delira.models.classification.ResNet18 import ResNet18

tf.keras.backend.set_image_data_format('channels_first')


logger = logging.getLogger(__name__)


Expand All @@ -32,6 +29,7 @@ def __init__(self, in_channels: int, n_outputs: int, **kwargs):
n_outputs : int
number of outputs (usually same as number of classes)
"""
tf.keras.backend.set_image_data_format('channels_first')
# register params by passing them as kwargs to parent class __init__
super().__init__(in_channels=in_channels,
n_outputs=n_outputs,
Expand Down

0 comments on commit 9a133c0

Please sign in to comment.