<a href="https://colab.research.google.com/github/nnilayy/Unet/blob/main/Unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input, Dropout
from tensorflow.keras.models import Model

In [2]:
def conv_block(inputs, num_filters):
    """Apply two (3x3 Conv2D -> BatchNormalization -> ReLU) stages.

    Args:
        inputs: Keras tensor to transform.
        num_filters: Number of output channels for both convolutions.

    Returns:
        Keras tensor with `num_filters` channels; "same" padding with the
        default stride of 1 keeps the spatial size of `inputs`.
    """
    out = inputs
    for _stage in range(2):
        out = Conv2D(num_filters, 3, padding="same")(out)
        out = BatchNormalization()(out)
        out = Activation("relu")(out)
    return out

def encoder_block(inputs, num_filters):
    """Run one U-Net encoder stage.

    Args:
        inputs: Keras tensor entering this stage.
        num_filters: Channel count passed to `conv_block`.

    Returns:
        Tuple `(features, pooled)`: `features` is the `conv_block` output,
        kept for the skip connection; `pooled` is `features` after a 2x2
        max-pool and feeds the next stage.
    """
    features = conv_block(inputs, num_filters)
    return features, MaxPool2D((2, 2))(features)

def decoder_block(inputs, skip_features, num_filters):
    """Run one U-Net decoder stage.

    Args:
        inputs: Keras tensor from the previous (coarser) stage.
        skip_features: Encoder feature map to concatenate after upsampling;
            its spatial size must equal twice that of `inputs`.
        num_filters: Channels for the transposed conv and the `conv_block`.

    Returns:
        Keras tensor with `num_filters` channels at the skip's resolution.
    """
    upsampled = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")(inputs)
    merged = Concatenate()([upsampled, skip_features])
    return conv_block(merged, num_filters)

In [18]:
def build_unet(input_shape):
    """Build (uncompiled) a 4-level U-Net with a 1-channel sigmoid output.

    Args:
        input_shape: Shape passed to `Input`, e.g. (96, 96, 3). The spatial
            dims should be divisible by 16 so every upsampled map matches its
            skip connection after four 2x2 poolings.

    Returns:
        keras `Model` named "U-Net".
    """
    inputs = Input(input_shape)

    # Encoder: keep each pre-pool feature map for the skip connections.
    skips = []
    x = inputs
    for filters in (64, 128, 256, 512):
        skip, x = encoder_block(x, filters)
        skips.append(skip)

    # Bottleneck.
    x = conv_block(x, 1024)

    # Decoder: pair each stage with the encoder map of matching resolution.
    for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
        x = decoder_block(x, skip, filters)

    outputs = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(x)
    return Model(inputs, outputs, name="U-Net")

In [19]:
# Build the U-Net for 96x96 inputs with 3 channels and print a layer summary.
input_shape = (96, 96, 3)
model = build_unet(input_shape=input_shape)
model.summary()

Model: "U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_8 (InputLayer)           [(None, 96, 96, 3)]  0           []                               
                                                                                                  
 conv2d_100 (Conv2D)            (None, 96, 96, 256)  7168        ['input_8[0][0]']                
                                                                                                  
 batch_normalization_94 (BatchN  (None, 96, 96, 256)  1024       ['conv2d_100[0][0]']             
 ormalization)                                                                                    
                                                                                                  
 activation_94 (Activation)     (None, 96, 96, 256)  0           ['batch_normalization_94[0][0

In [None]:
def unet(pretrained_weights=None, input_size=(256,256,3)):
    """Build and compile the plain (no batch-norm) U-Net.

    Args:
        pretrained_weights: Optional path to a weights file. When given, the
            weights are loaded into the model before it is compiled.
        input_size: Input shape (H, W, C). H and W should be divisible by 16
            so each upsampled map matches its skip connection.

    Returns:
        The compiled keras `Model` (Adam with learning rate 1e-4, binary
        cross-entropy loss, accuracy metric) with a 1-channel sigmoid output.
    """
    # Adam is only imported in a later cell of this notebook, so bring it
    # into scope here.
    from tensorflow.keras.optimizers import Adam

    def double_conv(x, filters):
        # Two 3x3 ReLU convolutions, He-normal initialised, "same" padding.
        x = Conv2D(filters, 3, activation="relu", padding="same", kernel_initializer="he_normal")(x)
        return Conv2D(filters, 3, activation="relu", padding="same", kernel_initializer="he_normal")(x)

    def up_merge(x, skip, filters):
        # 2x learned upsampling, then concatenate with the encoder map of the
        # same resolution. `axis` belongs to the layer constructor, not the call.
        up = Conv2DTranspose(filters, (2, 2), strides=2, padding="same")(x)
        return Concatenate(axis=3)([up, skip])

    inputs = Input(input_size)

    # Encoder
    conv1 = double_conv(inputs, 64)
    pool1 = MaxPool2D((2, 2))(conv1)
    conv2 = double_conv(pool1, 128)
    pool2 = MaxPool2D((2, 2))(conv2)
    conv3 = double_conv(pool2, 256)
    pool3 = MaxPool2D((2, 2))(conv3)
    conv4 = double_conv(pool3, 512)
    pool4 = MaxPool2D((2, 2))(conv4)

    # Bottleneck
    conv5 = double_conv(pool4, 1024)

    # Decoder: each stage merges with the encoder output of matching size
    # (conv4, conv3, conv2, conv1 — previously conv5 and conv2 were reused,
    # which either fails on shape or drops the finest skip connection).
    conv6 = double_conv(up_merge(conv5, conv4, 512), 512)
    conv7 = double_conv(up_merge(conv6, conv3, 256), 256)
    conv8 = double_conv(up_merge(conv7, conv2, 128), 128)
    conv9 = double_conv(up_merge(conv8, conv1, 64), 64)

    conv10 = Conv2D(1, 1, activation="sigmoid")(conv9)

    model = Model(inputs=inputs, outputs=conv10)
    if pretrained_weights:
        model.load_weights(pretrained_weights)
    model.compile(optimizer=Adam(learning_rate=1e-4), loss="binary_crossentropy", metrics=["accuracy"])
    return model
  

In [None]:
# Variant: add a BatchNormalization layer after each convolution.
# NOTE(review): this cell is an indented fragment meant to replace the first
# encoder stage inside unet(); it will not run on its own (leading
# indentation at cell top level, and `inputs` is only defined inside unet()).
  conv1=Conv2D(64,3,activation="relu",padding="same",kernel_initializer="he_normal")(inputs)
  conv1=BatchNormalization()(conv1)
  conv1=Conv2D(64,3,activation="relu",padding="same",kernel_initializer="he_normal")(conv1)
  conv1=BatchNormalization()(conv1)
  pool1=MaxPool2D((2,2))(conv1)

# Variant: add a Dropout layer (rate 0.2) after the second convolution.
# NOTE(review): this fragment also replaces the first encoder stage inside
# unet(); pool1 is not recomputed here, so the MaxPool2D line presumably
# still follows it — confirm when pasting.
  conv1=Conv2D(64,3,activation="relu",padding="same",kernel_initializer="he_normal")(inputs)
  conv1=Conv2D(64,3,activation="relu",padding="same",kernel_initializer="he_normal")(conv1)
  conv1=Dropout(0.2)(conv1) 

In [None]:
import tensorflow as tf
import numpy as np
from akida_models import akidanet_imagenet
from keras import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import BatchNormalization, Conv2D, Softmax, ReLU
from cnn2snn import check_model_compatibility
from ei_tensorflow.constrained_object_detection import models, dataset, metrics, util

def build_model(input_shape: tuple, alpha: float, num_classes: int, weight_regularizer=None) -> tf.keras.Model:
    """ Construct a constrained object detection model on an AkidaNet base.

    Args:
        input_shape: Passed to AkidaNet construction.
        alpha: AkidaNet alpha value.
        num_classes: Number of classes, i.e. final dimension size, in output.
        weight_regularizer: Optional kernel regularizer applied to the two
            1x1 convolutions of the added head (not to the AkidaNet base).

    Returns:
        Uncompiled keras model. It takes (B, H, W, C) input and returns
        (B, H//8, W//8, num_classes) logits (the final conv has no activation).

    Exits the process with status 1 (SystemExit) if cnn2snn reports that the
    model is not compatible with Akida.
    """
    import sys  # used for sys.exit below; not imported at module level

    #! Create a quantized base model without top layers
    a_base_model = akidanet_imagenet(input_shape=input_shape, alpha=alpha, include_top=False, input_scaling=None)
    #! Get pretrained quantized weights and load them into the base model
    #! Available base models are:
    #! akidanet_imagenet_224_alpha_50.h5             - float32 model, 224x224x3, alpha=0.5
    #! akidanet_imagenet_160_alpha_50.h5             - float32 model, 160x160x3, alpha=0.5
    # NOTE(review): the path is fixed to the 224 / alpha=0.5 file regardless of
    # `input_shape` and `alpha`; with skip_mismatch=True any layer whose shape
    # differs is silently left un-loaded.
    pretrained_weights = './transfer-learning-weights/akidanet/akidanet_imagenet_224_alpha_50.h5'
    a_base_model.load_weights(pretrained_weights, by_name=True, skip_mismatch=True)
    a_base_model.trainable = True
    #! Default batch norm is configured for huge networks, let's speed it up
    for layer in a_base_model.layers:
        if isinstance(layer, BatchNormalization):
            layer.momentum = 0.9
    #! Cut AkidaNet where it hits 1/8th input resolution; i.e. (HW/8, HW/8, C)
    a_cut_point = a_base_model.get_layer('separable_5_relu')
    #! Now attach a small additional head on the AkidaNet
    a_model_part_head = Conv2D(filters=32, kernel_size=1, strides=1, padding='same', kernel_regularizer=weight_regularizer)(a_cut_point.output)
    a_model_part = ReLU()(a_model_part_head)
    a_logits = Conv2D(filters=num_classes, kernel_size=1, strides=1, padding='same', activation=None, kernel_regularizer=weight_regularizer)(a_model_part)
    fomo_akida = Model(inputs=a_base_model.input, outputs=a_logits)
    #! Check if the model is compatible with Akida (fail quickly before training)
    compatible = check_model_compatibility(fomo_akida, input_is_image=True)
    if not compatible:
        print("Model is not compatible with Akida!")
        sys.exit(1)

    return fomo_akida

def train(num_classes: int, learning_rate: float, num_epochs: int, alpha: float, object_weight: int, train_dataset: tf.data.Dataset, validation_dataset: tf.data.Dataset, best_model_path: str, input_shape: tuple, callbacks: 'list', quantize_function, lr_finder: bool = False) -> 'tuple':
    """ Construct, train and quantize a constrained object detection model.

    Args:
        num_classes: Number of classes in datasets. This does not include
            implied background class introduced by segmentation map dataset
            conversion.
        learning_rate: Learning rate for Adam (replaced by the LR finder's
            result when `lr_finder` is True).
        num_epochs: Number of epochs passed to model.fit
        alpha: Alpha used to construct AkidaNet. Pretrained weights will be
            used if there is a matching set.
        object_weight: The weighting to give the object in the loss function
            where background has an implied weight of 1.0.
        train_dataset: Training dataset of (x, (bbox, one_hot_y))
        validation_dataset: Validation dataset of (x, (bbox, one_hot_y))
        best_model_path: location to save best model path. note: weights
            will be restored from this path based on best val_f1 score.
        input_shape: The shape of the model's input; must be square.
        callbacks: List of extra keras callbacks for both training phases.
            Not mutated.
        quantize_function: Callable invoked with keyword arguments (model,
            train_dataset, validation_dataset, optimizer, fine_tune_loss,
            fine_tune_metrics, best_model_path, callbacks, stopping_metric,
            verbose); its return value is returned as `akida_model`.
        lr_finder: If True, run ei_tensorflow.lr_finder.find_lr and use its
            learning rate instead of `learning_rate`.

    Returns:
        Tuple `(model, akida_model)`: the trained float model with a softmax
        head appended, and whatever `quantize_function` returned.

    Raises:
        ValueError: if the model input or output is not square.

    Constructs a new constrained object detection model with num_classes+1
    outputs (denoting the classes with an implied background class of 0).
    Both training and validation datasets are adapted from
    (x, (bbox, one_hot_y)) to (x, segmentation_map). Model is trained with a
    custom weighted cross entropy function. Exits the process with status 1
    if the final model is not compatible with Akida.
    """
    import sys  # used for sys.exit below; not imported at module level

    num_classes_with_background = num_classes + 1

    in_height, in_width, _input_channels = input_shape
    if in_height != in_width:
        raise ValueError(f"Only square inputs are supported; not {input_shape}")

    model = build_model(input_shape=input_shape, alpha=alpha, num_classes=num_classes_with_background, weight_regularizer=tf.keras.regularizers.l2(4e-5))
    #! Derive output size from model (do not shadow the num_classes argument)
    model_output_shape = model.layers[-1].output.shape
    _batch, out_height, out_width, _output_channels = model_output_shape
    if out_height != out_width:
        raise ValueError(f"Only square outputs are supported; not {model_output_shape}")
    output_width_height = out_height

    #! Build weighted cross entropy loss specific to this model size
    weighted_xent = models.construct_weighted_xent_fn(model.output.shape, object_weight)
    #! Transform bounding box labels into segmentation maps
    train_segmentation_dataset = train_dataset.map(dataset.bbox_to_segmentation(output_width_height, num_classes_with_background)).batch(32, drop_remainder=False).prefetch(1)
    validation_segmentation_dataset = validation_dataset.map(dataset.bbox_to_segmentation(output_width_height, num_classes_with_background, validation=True)).batch(32, drop_remainder=False).prefetch(1)
    #! Initialise bias of final classifier based on training data prior.
    util.set_classifier_biases_from_dataset(model, train_segmentation_dataset)
    if lr_finder:
        # Only `ei_tensorflow.constrained_object_detection` names were imported
        # at module level, so the `ei_tensorflow` package itself must be bound.
        import ei_tensorflow.lr_finder
        learning_rate = ei_tensorflow.lr_finder.find_lr(model, train_segmentation_dataset, weighted_xent)

    opt = Adam(learning_rate=learning_rate)
    model.compile(loss=weighted_xent, optimizer=opt)

    #! Create callback that will do centroid scoring on end of epoch against
    #! validation data. Include a callback to show % progress in slow cases.
    centroid_callback = metrics.CentroidScoring(validation_segmentation_dataset, output_width_height, num_classes_with_background)
    print_callback = metrics.PrintPercentageTrained(num_epochs)

    #! Include a callback for model checkpointing based on the best validation f1.
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(best_model_path, monitor='val_f1', save_best_only=True, mode='max', save_weights_only=True, verbose=0)
    model.fit(train_segmentation_dataset, validation_data=validation_segmentation_dataset, epochs=num_epochs, callbacks=callbacks + [centroid_callback, print_callback, checkpoint_callback], verbose=0)
    #! Restore best weights.
    model.load_weights(best_model_path)
    #! Add explicit softmax layer before export.
    softmax_layer = Softmax()(model.layers[-1].output)
    model = Model(model.input, softmax_layer)
    #! Check if model is compatible with Akida
    compatible = check_model_compatibility(model, input_is_image=True)
    if not compatible:
        print("Model is not compatible with Akida!")
        sys.exit(1)

    akida_model = quantize_function(model=model, train_dataset=train_segmentation_dataset, validation_dataset=validation_segmentation_dataset, optimizer=opt, fine_tune_loss=weighted_xent, fine_tune_metrics=None, best_model_path=best_model_path, callbacks=callbacks + [centroid_callback, print_callback], stopping_metric='val_f1', verbose=0)

    return model, akida_model


# Training hyper-parameters: use the CLI value when it is set (truthy),
# otherwise fall back to the defaults below.
EPOCHS = args.epochs if args.epochs else 100
LEARNING_RATE = args.learning_rate if args.learning_rate else 0.001

def quantize_brainchip(model, train_dataset: tf.data.Dataset, validation_dataset: tf.data.Dataset, best_model_path: str, optimizer, fine_tune_loss, fine_tune_metrics: 'list[str]', callbacks, stopping_metric='val_accuracy', verbose=2, fine_tune_epochs=30):
    """Quantize a float keras model with cnn2snn, then fine-tune it (QAT).

    Args:
        model: Trained float keras model to quantize.
        train_dataset: Dataset used for quantization-aware fine-tuning.
        validation_dataset: Dataset used for validation during fine-tuning.
        best_model_path: Unused here; accepted so the signature matches the
            `quantize_function` interface that `train` calls.
        optimizer: Optimizer (instance or name) passed to `compile`.
        fine_tune_loss: Loss (callable or name) passed to `compile`.
        fine_tune_metrics: Metrics passed to `compile`; may be None.
        callbacks: Extra keras callbacks for `fit`, or None. The caller's list
            is not modified.
        stopping_metric: Metric monitored (maximised) by early stopping.
        verbose: Verbosity passed to `fit`.
        fine_tune_epochs: Maximum number of fine-tuning epochs (default 30).

    Returns:
        The quantized model returned by `cnn2snn.quantize`, after fine-tuning.
    """
    import tensorflow as tf
    import cnn2snn

    print('Performing post-training quantization...')
    akida_model = cnn2snn.quantize(model, weight_quantization=4, activ_quantization=4, input_weight_quantization=8)
    print('Performing post-training quantization OK')
    print('')

    early_stopping = tf.keras.callbacks.EarlyStopping(monitor=stopping_metric, mode='max', verbose=1, min_delta=0, patience=10, restore_best_weights=True)
    # Build a new list instead of appending, so the caller's list is untouched.
    fit_callbacks = list(callbacks or []) + [early_stopping]

    print('Running quantization-aware training...')
    akida_model.compile(optimizer=optimizer, loss=fine_tune_loss, metrics=fine_tune_metrics)
    akida_model.fit(train_dataset, epochs=fine_tune_epochs, verbose=verbose, validation_data=validation_dataset, callbacks=fit_callbacks)
    print('Running quantization-aware training OK')
    print('')

    return akida_model


# Train the float model, then quantize / fine-tune it via quantize_brainchip.
model, akida_model = train(
    num_classes=classes,
    learning_rate=LEARNING_RATE,
    num_epochs=EPOCHS,
    alpha=0.5,
    object_weight=100,
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    best_model_path=BEST_MODEL_PATH,
    input_shape=MODEL_INPUT_SHAPE,
    callbacks=callbacks,
    quantize_function=quantize_brainchip,
    lr_finder=False,
)
# NOTE(review): these flags are not read anywhere in the visible code;
# presumably consumed by the surrounding export pipeline — confirm.
override_mode = 'segmentation'
disable_per_channel_quantization = False