In [1]:
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as layers

In [5]:
class DropPath(tf.keras.layers.Layer):
    """Keras layer applying stochastic depth via the module-level ``drop_path``.

    Args:
        drop_prob: Probability of zeroing a sample's branch while training.
            ``None`` is treated as 0.0 (no dropping).
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """
    # borrowed from https://github.com/rishigami/Swin-Transformer-TF/blob/main/swintransformer/model.py
    def __init__(self, drop_prob=None, **kwargs):
        super().__init__(**kwargs)
        # Normalise None -> 0.0 so training never evaluates `1.0 - None`.
        self.drop_prob = 0.0 if drop_prob is None else float(drop_prob)

    def call(self, x, training=None):
        return drop_path(x, self.drop_prob, training)

    def get_config(self):
        """Return the layer config so the layer can be re-created via ``from_config``."""
        config = super().get_config()
        config.update({'drop_prob': self.drop_prob})
        return config

class ConvNeXt_Block(layers.Layer):
    r"""ConvNeXt Block.

    Computes ``x + drop_path(gamma * pwconv2(gelu(pwconv1(norm(dwconv(x))))))``:
    a 7x7 depthwise conv, LayerNorm, Dense(4*dim), GELU, Dense(dim), optional
    per-channel layer scale, then stochastic depth and a residual add.

    Args:
        dim (int): Number of input channels. The residual add requires the
            input's last axis to equal ``dim``.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale. Values
            <= 0 (or None) disable layer scale: no ``gamma`` weight is created.
            Default: 1e-6.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.dwconv = layers.DepthwiseConv2D(kernel_size=7, padding='same')  # depthwise conv
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        # pointwise/1x1 convs, implemented with Dense layers acting on the last (channel) axis
        self.pwconv1 = layers.Dense(4 * dim)
        self.act = layers.Activation('gelu')
        self.pwconv2 = layers.Dense(dim)
        self.drop_path = DropPath(drop_path)
        self.dim = dim
        self.drop_path_rate = drop_path  # kept for get_config
        self.layer_scale_init_value = layer_scale_init_value

    def build(self, input_shape):
        if self.layer_scale_init_value is not None and self.layer_scale_init_value > 0:
            # add_weight (instead of a raw tf.Variable) lets Keras track the
            # weight and apply the layer's dtype policy (e.g. mixed precision).
            self.gamma = self.add_weight(
                name='_gamma',
                shape=(self.dim,),
                initializer=tf.keras.initializers.Constant(self.layer_scale_init_value),
                trainable=True)
        else:
            # Layer scale disabled; `call` skips the multiplication.
            self.gamma = None
        super().build(input_shape)

    def call(self, x, training=None):
        shortcut = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * x
        # Forward `training` explicitly so stochastic depth only fires while training.
        return shortcut + self.drop_path(x, training=training)

    def get_config(self):
        """Return the layer config so the block can be re-created via ``from_config``."""
        config = super().get_config()
        config.update({
            'dim': self.dim,
            'drop_path': self.drop_path_rate,
            'layer_scale_init_value': self.layer_scale_init_value,
        })
        return config

In [6]:
def drop_path(inputs, drop_prob, is_training):
    """Stochastic depth: randomly zero whole samples of a residual branch.

    While training, each sample along axis 0 is kept with probability
    ``1 - drop_prob`` and rescaled by ``1 / (1 - drop_prob)``; dropped samples
    become zero. When not training, or when ``drop_prob`` is 0 or None,
    ``inputs`` is returned unchanged (the same object).

    Args:
        inputs: Tensor whose first axis is the batch axis.
        drop_prob: Drop probability in [0, 1]. ``None`` is treated as 0.
        is_training: Truthy to apply dropping; falsy/None returns ``inputs``.

    Returns:
        Tensor with the same shape and dtype as ``inputs``.

    Raises:
        ValueError: If ``drop_prob`` is negative while training.
    """
    # borrowed from https://github.com/rishigami/Swin-Transformer-TF/blob/main/swintransformer/model.py
    if drop_prob is None:
        drop_prob = 0.0
    if (not is_training) or (drop_prob == 0.):
        return inputs
    if drop_prob < 0.0:
        raise ValueError(f"drop_prob must be in [0, 1], got {drop_prob}")
    if drop_prob >= 1.0:
        # Every sample is dropped; avoid the 0/0 -> NaN that dividing by keep_prob=0 gives.
        return tf.zeros_like(inputs)

    keep_prob = 1.0 - drop_prob

    # One mask value per sample, broadcast over the remaining axes. Use the
    # static rank: len(tf.shape(x)) raises TypeError on symbolic graph tensors.
    rank = inputs.shape.rank
    if rank is not None:
        mask_shape = [tf.shape(inputs)[0]] + [1] * (rank - 1)
    else:
        mask_shape = tf.concat(
            [tf.shape(inputs)[:1], tf.ones([tf.rank(inputs) - 1], dtype=tf.int32)], axis=0)

    # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
    random_tensor = keep_prob + tf.random.uniform(mask_shape, dtype=inputs.dtype)
    binary_tensor = tf.floor(random_tensor)
    return tf.math.divide(inputs, keep_prob) * binary_tensor

def downsample_block(input, dim):
    """Downsampling layer placed between ConvNeXt stages.

    Applies a LayerNorm over the last axis, then a 2x2 Conv2D with stride 2
    that projects to ``dim`` channels (halving the spatial size, 'valid' padding).
    """
    normalized = layers.LayerNormalization(axis=-1, epsilon=1e-6)(input)
    return layers.Conv2D(dim, kernel_size=2, strides=2)(normalized)


def create_convnext_model(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), num_classes=1000,
                          input_shape=(224, 224, 3), drop_path_rate=0.0,
                          layer_scale_init_value=1e-6):
    """Build a ConvNeXt image classifier as a functional Keras model.

    Layout: a stem (4x4 stride-4 Conv2D + LayerNorm), then one stage per
    ``(depth, dim)`` pair, where every stage after the first starts with
    ``downsample_block``. The head is global average pooling, LayerNorm and a
    Dense layer producing ``num_classes`` logits (no softmax).

    Args:
        depths: Number of ConvNeXt blocks in each stage.
        dims: Channel width of each stage; must have the same length as ``depths``.
        num_classes: Number of output units of the final Dense layer.
        input_shape: Shape of one input image, excluding the batch axis.
        drop_path_rate: Maximum stochastic-depth rate. Per-block rates rise
            linearly from 0 to this value across all blocks (default: all 0).
        layer_scale_init_value: Forwarded to every ``ConvNeXt_Block``.

    Returns:
        An uncompiled ``tf.keras.Model``.

    Raises:
        ValueError: If ``depths`` and ``dims`` differ in length or are empty.
    """
    depths = list(depths)
    dims = list(dims)
    if len(depths) != len(dims):
        raise ValueError(
            f"depths and dims must have the same length, got {len(depths)} and {len(dims)}")
    if not depths:
        raise ValueError("at least one stage is required")

    # Linear stochastic-depth schedule over all blocks.
    block_rates = [float(rate) for rate in np.linspace(0.0, drop_path_rate, sum(depths))]

    inputs = layers.Input(shape=tuple(input_shape))
    y = inputs
    block_index = 0
    for stage, (depth, dim) in enumerate(zip(depths, dims)):
        if stage == 0:
            # Stem: non-overlapping 4x4 patchify conv + LayerNorm
            y = layers.Conv2D(dim, kernel_size=4, strides=4)(y)
            y = layers.LayerNormalization(epsilon=1e-6)(y)
        else:
            y = downsample_block(y, dim)
        for _ in range(depth):
            y = ConvNeXt_Block(dim,
                               drop_path=block_rates[block_index],
                               layer_scale_init_value=layer_scale_init_value)(y)
            block_index += 1

    y = layers.GlobalAveragePooling2D()(y)
    y = layers.LayerNormalization(epsilon=1e-6)(y)  # final norm layer
    y = layers.Dense(num_classes)(y)

    return tf.keras.Model(inputs=inputs, outputs=y)




In [10]:
# Build the classifier with its default stage depths/widths spelled out explicitly.
model = create_convnext_model(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], num_classes=1000)

In [11]:
# summary() prints the table itself and returns None; wrapping it in print() emits a stray "None".
model.summary()

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 conv2d_9 (Conv2D)           (None, 56, 56, 96)        4704      
                                                                 
 layer_normalization_41 (Lay  (None, 56, 56, 96)       192       
 erNormalization)                                                
                                                                 
 conv_ne_xt__block_30 (ConvN  (None, 56, 56, 96)       79296     
 eXt_Block)                                                      
                                                                 
 conv_ne_xt__block_31 (ConvN  (None, 56, 56, 96)       79296     
 eXt_Block)                                                      
                                                           