# In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.activations import *
from tensorflow.keras.regularizers import *
from tensorflow.keras import Input

from typing import Union , Tuple , Dict , List

# In [None]:
def conv(input_tensor: tf.float32,
         filters: int) -> tf.float32:
    """Apply a 3x3 'same'-padded Conv2D followed by a ReLU activation.

    The activation is added as its own ``Activation`` layer rather than being
    passed to ``Conv2D``, so each call contributes two layers to the graph.

    Args:
        input_tensor: Tensor fed to the convolution (a Keras tensor when
            called from the model-building code in this file).
        filters: Number of output filters of the convolution.

    Returns:
        The output of the ReLU activation layer.
    """
    conv_layer = Conv2D(filters=filters, kernel_size=(3, 3), padding='same')
    features = conv_layer(input_tensor)
    return Activation('relu')(features)

 

def multihead_attention(input_tensor: tf.float32,
                        num_heads: int,
                        dims: int) -> tf.float32:
    """Run self-attention over ``input_tensor`` with a residual and layer norm.

    Args:
        input_tensor: Tensor used as both the query and the value of the
            attention layer, and as the residual branch.
        num_heads: Number of attention heads.
        dims: Passed to ``MultiHeadAttention`` as ``key_dim``.

    Returns:
        ``LayerNormalization(attention(input, input) + input)``.
    """
    attention = MultiHeadAttention(num_heads=num_heads,
                                   key_dim=dims,
                                   dropout=0.1)
    # Same tensor for query and value -> self-attention.
    attended = attention(input_tensor, input_tensor)
    residual = Add()([attended, input_tensor])
    return LayerNormalization()(residual)



def TransBlock(input_tensor: tf.float32,
               num_heads: int,
               dims: int,
               filters: int) -> tf.float32:
    """Build one transformer block: attention, conv, residual add, layer norm.

    Args:
        input_tensor: Tensor entering the block; also used as the residual
            branch of the final ``Add``.
        num_heads: Number of attention heads, forwarded to
            ``multihead_attention``.
        dims: Attention key dimension, forwarded to ``multihead_attention``.
        filters: Number of filters of the conv that follows the attention.
            The final ``Add`` sums the conv output with ``input_tensor``, so
            this should equal the channel count of ``input_tensor``.

    Returns:
        The layer-normalized sum of ``input_tensor`` and the conv output.
    """
    attended = multihead_attention(input_tensor=input_tensor,
                                   num_heads=num_heads,
                                   dims=dims)
    projected = conv(input_tensor=attended, filters=filters)
    merged = Add()([input_tensor, projected])
    return LayerNormalization()(merged)
 

def TransUnet(height: int,
              width: int,
              color_channels: int,
              num_classes: int,
              n_transblock: int) -> Model:
    """Build the TransUnet Keras model.

    The encoder applies three conv + 2x2 max-pool stages (128, 256 and 512
    filters), batch-normalizes the deepest feature map and passes it through
    ``n_transblock`` transformer blocks. A fourth pooling followed by a conv
    forms the bottleneck. The decoder alternates conv and stride-2
    ``Conv2DTranspose`` steps, concatenating the pooled encoder feature maps
    as skip connections after each of the first three upsamplings.

    Note: upsampling is done with ``Conv2DTranspose`` (not ``UpSampling2D``).

    The network halves the spatial size four times and doubles it four times,
    so ``height`` and ``width`` that are multiples of 16 keep the skip
    connections shape-aligned and give an output with the input's spatial
    size.

    Args:
        height: First spatial dimension of the model input.
        width: Second spatial dimension of the model input.
        color_channels: Number of channels of the model input.
        num_classes: Number of filters of the final conv, i.e. the number of
            output channels.
        n_transblock: Number of transformer blocks applied to the deepest
            encoder feature map.

    Returns:
        An uncompiled ``Model`` named ``'TransUnet'``. The last layer is
        produced by ``conv``, so the outputs go through its ReLU activation.
    """
    pool_size = (2, 2)
    kernel_size = (3, 3)
    strides = (2, 2)
    padding = 'same'
    encoder_filters = (128, 256, 512)

    def upsample(tensor, filters):
        # Stride-2 transposed conv with 'same' padding doubles height/width.
        return Conv2DTranspose(filters=filters,
                               kernel_size=kernel_size,
                               strides=strides,
                               padding=padding)(tensor)

    inputs = Input(shape=(height, width, color_channels))

    # Encoder: keep every pooled feature map for the decoder skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        x = conv(input_tensor=x, filters=filters)
        x = MaxPooling2D(pool_size=pool_size)(x)
        skips.append(x)

    # The deepest skip connection uses the batch-normalized feature map.
    x = BatchNormalization()(x)
    skips[-1] = x

    for _ in range(n_transblock):
        x = TransBlock(input_tensor=x, num_heads=12, dims=512, filters=512)

    # Bottleneck: one more pooling before the decoder starts.
    x = MaxPooling2D(pool_size=pool_size)(x)

    # Decoder: conv -> upsample -> concatenate with the matching encoder map,
    # walking the encoder stages from deepest to shallowest.
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        x = conv(input_tensor=x, filters=filters)
        x = upsample(x, filters)
        x = Concatenate()([x, skip])

    # Final upsampling back to the input resolution (no skip connection here).
    x = conv(input_tensor=x, filters=16)
    x = upsample(x, 16)

    outputs = conv(input_tensor=x, filters=num_classes)
    return Model(inputs=inputs, outputs=outputs, name='TransUnet')