<a href="https://colab.research.google.com/github/guscldns/TestProject/blob/main/0728/10_2_(solution)UNET_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Implementing a custom U-Net

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

### Standard U-Net 구현

- 주요 함수:
  1. `layers.Cropping2D(cropping=((top_crop, bottom_crop), (left_crop, right_crop)))`
  2. `upconv_block`:
     ```python
     def upconv_block(inputs, filters, kernel_size=2):
         x = layers.UpSampling2D(size=(2, 2))(inputs)
         x = layers.Conv2D(filters, kernel_size, activation='relu', padding='same')(x)
         return x
     ```


In [None]:
# Standard U-Net configuration: 572x572 RGB input, two output classes.
img_size, num_classes = (572, 572), 2

In [None]:
def conv_block(inputs, filters, kernel_size=3):
    """Apply two stacked unpadded ReLU convolutions.

    No ``padding`` argument is passed, so each convolution shrinks the
    height and width by ``kernel_size - 1`` (e.g. 572 -> 570 -> 568).

    Args:
        inputs: Input feature tensor.
        filters: Number of output channels of both convolutions.
        kernel_size: Size of the square convolution kernel.

    Returns:
        The output tensor of the second convolution.
    """
    out = inputs
    for _ in range(2):
        out = layers.Conv2D(filters, kernel_size, activation='relu')(out)
    return out

def upconv_block(inputs, filters, kernel_size=2):
    """Double the spatial resolution, then convolve with 'same' padding.

    Args:
        inputs: Feature tensor to upsample.
        filters: Number of output channels of the convolution.
        kernel_size: Size of the square convolution kernel.

    Returns:
        A tensor with twice the height and width of ``inputs`` and
        ``filters`` channels.
    """
    upsample = layers.UpSampling2D(size=(2, 2))
    project = layers.Conv2D(filters, kernel_size, activation='relu', padding='same')
    return project(upsample(inputs))

def _center_crop_to(skip, target, default):
    """Crop ``skip`` spatially so its height and width match ``target``.

    The required crop is derived from the static shapes and split as evenly
    as possible between the two sides of each axis. Any odd leftover
    row/column is removed from the bottom/right.

    Args:
        skip: Encoder feature map (channels-last) used as a skip connection.
        target: Decoder feature map whose height/width ``skip`` must match.
        default: Per-side crop used for an axis whose static size is unknown
            (``None``). This is the fixed crop the 572x572 configuration
            uses at that level.

    Returns:
        The output of a ``layers.Cropping2D`` applied to ``skip``.
    """
    crops = []
    for axis in (1, 2):
        skip_dim, target_dim = skip.shape[axis], target.shape[axis]
        if skip_dim is None or target_dim is None:
            crops.append((default, default))
        else:
            diff = skip_dim - target_dim
            crops.append((diff // 2, diff - diff // 2))
    return layers.Cropping2D(cropping=tuple(crops))(skip)


def unet(img_size, num_classes):
    """Build the standard U-Net with unpadded convolutions and cropped skips.

    Because every convolution in ``conv_block`` is unpadded, the decoder
    feature maps are smaller than the matching encoder maps. The encoder maps
    are therefore center-cropped before concatenation. The output is smaller
    than the input (388x388 for a 572x572 input).

    The crops are computed from the static tensor shapes rather than
    hard-coded. This lets inputs whose pooled feature maps have odd sizes
    (e.g. 570x570, where MaxPooling2D floors) build correctly. For the 572x572
    input, the crops are identical to the fixed 4/16/40/88 values.

    Args:
        img_size: ``(height, width)`` tuple of the RGB input.
        num_classes: Number of channels of the per-pixel softmax output.

    Returns:
        A ``tf.keras.Model`` mapping ``img_size + (3,)`` inputs to per-pixel
        class probabilities.
    """
    inputs = keras.Input(shape=img_size + (3,))

    # Contracting Path
    c1 = conv_block(inputs, 64, kernel_size=3)
    p1 = layers.MaxPooling2D(pool_size=(2, 2))(c1)

    c2 = conv_block(p1, 128, kernel_size=3)
    p2 = layers.MaxPooling2D(pool_size=(2, 2))(c2)

    c3 = conv_block(p2, 256, kernel_size=3)
    p3 = layers.MaxPooling2D(pool_size=(2, 2))(c3)

    c4 = conv_block(p3, 512, kernel_size=3)
    p4 = layers.MaxPooling2D(pool_size=(2, 2))(c4)

    # Bottom
    b = conv_block(p4, 1024, kernel_size=3)

    # Expanding Path: upsample, crop the matching encoder map to the same
    # size, concatenate along channels, then convolve.
    u1 = upconv_block(b, 512, kernel_size=2)
    c4_crop = _center_crop_to(c4, u1, 4)
    u1_concat = layers.Concatenate()([u1, c4_crop])
    c5 = conv_block(u1_concat, 512, kernel_size=3)

    u2 = upconv_block(c5, 256, kernel_size=2)
    c3_crop = _center_crop_to(c3, u2, 16)
    u2_concat = layers.Concatenate()([u2, c3_crop])
    c6 = conv_block(u2_concat, 256, kernel_size=3)

    u3 = upconv_block(c6, 128, kernel_size=2)
    c2_crop = _center_crop_to(c2, u3, 40)
    u3_concat = layers.Concatenate()([u3, c2_crop])
    c7 = conv_block(u3_concat, 128, kernel_size=3)

    u4 = upconv_block(c7, 64, kernel_size=2)
    c1_crop = _center_crop_to(c1, u4, 88)
    u4_concat = layers.Concatenate()([u4, c1_crop])
    c8 = conv_block(u4_concat, 64, kernel_size=3)

    # 1x1 convolution to per-pixel class probabilities.
    outputs = layers.Conv2D(num_classes, 1, activation='softmax')(c8)

    return tf.keras.Model(inputs, outputs)



# Reset Keras' global state so auto-generated layer names restart
# ("model", "input_1", "conv2d", ...) for each model built in this notebook.
keras.backend.clear_session()

model = unet(img_size=img_size, num_classes=num_classes)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 572, 572, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 570, 570, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_1 (Conv2D)              (None, 568, 568, 64  36928       ['conv2d[0][0]']                 
                                )                                                             

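## small image size를 위한 custom U-Net 구현

1. no padding (Conv2D default `padding='valid'`) → `padding='same'`
2. crop → no crop (skip connections are concatenated directly, since feature-map sizes now match)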

In [None]:
# Small-image configuration: 128x128 RGB input, three output classes.
img_size, num_classes = (128, 128), 3

In [None]:
def conv_block(inputs, filters, kernel_size=3):
    """Apply two stacked 'same'-padded ReLU convolutions.

    With ``padding='same'`` the height and width are preserved (e.g.
    128 -> 128 -> 128), so skip connections need no cropping.

    Args:
        inputs: Input feature tensor.
        filters: Number of output channels of both convolutions.
        kernel_size: Size of the square convolution kernel.

    Returns:
        The output tensor of the second convolution.
    """
    out = inputs
    for _ in range(2):
        out = layers.Conv2D(filters, kernel_size, activation='relu', padding='same')(out)
    return out

def upconv_block(inputs, filters, kernel_size=2):
    """Upsample by 2x with nearest repetition, then apply a 'same' ReLU conv.

    Args:
        inputs: Feature tensor to upsample.
        filters: Number of output channels of the convolution.
        kernel_size: Size of the square convolution kernel.

    Returns:
        A tensor with twice the height and width of ``inputs``.
    """
    upsampled = layers.UpSampling2D(size=(2, 2))(inputs)
    conv = layers.Conv2D(filters, kernel_size, activation='relu', padding='same')
    return conv(upsampled)

def unet(img_size, num_classes):
    """Build a U-Net whose convolutions use 'same' padding and no cropping.

    Because ``conv_block`` preserves spatial size, each decoder feature map
    has exactly the size of the matching encoder map. The skip connections
    are therefore concatenated directly, and the output has the same height
    and width as the input.

    Args:
        img_size: ``(height, width)`` tuple of the RGB input. Each known
            dimension must be a multiple of 16; ``None`` entries are not
            checked.
        num_classes: Number of channels of the per-pixel softmax output.

    Returns:
        A ``tf.keras.Model`` mapping ``img_size + (3,)`` inputs to
        per-pixel class probabilities.

    Raises:
        ValueError: If a known dimension of ``img_size`` is not a multiple
            of 16.
    """
    # Four 2x2 poolings followed by four 2x upsamplings restore the input
    # size only when each dimension is a multiple of 2**4. Otherwise
    # MaxPooling2D floors an odd size and a Concatenate layer below fails
    # with an opaque shape error, so reject such sizes up front.
    size_multiple = 2 ** 4
    for dim in img_size:
        if dim is not None and dim % size_multiple != 0:
            raise ValueError(
                f"img_size dimensions must be multiples of {size_multiple}, "
                f"got {img_size}"
            )

    inputs = keras.Input(shape=img_size + (3,))

    # Contracting Path
    c1 = conv_block(inputs, 64, kernel_size=3)
    p1 = layers.MaxPooling2D(pool_size=(2, 2))(c1)

    c2 = conv_block(p1, 128, kernel_size=3)
    p2 = layers.MaxPooling2D(pool_size=(2, 2))(c2)

    c3 = conv_block(p2, 256, kernel_size=3)
    p3 = layers.MaxPooling2D(pool_size=(2, 2))(c3)

    c4 = conv_block(p3, 512, kernel_size=3)
    p4 = layers.MaxPooling2D(pool_size=(2, 2))(c4)

    # Bottom
    b = conv_block(p4, 1024, kernel_size=3)

    # Expanding Path: sizes already match, so encoder maps are concatenated
    # without cropping.
    u1 = upconv_block(b, 512, kernel_size=2)
    u1_concat = layers.Concatenate()([u1, c4])
    c5 = conv_block(u1_concat, 512, kernel_size=3)

    u2 = upconv_block(c5, 256, kernel_size=2)
    u2_concat = layers.Concatenate()([u2, c3])
    c6 = conv_block(u2_concat, 256, kernel_size=3)

    u3 = upconv_block(c6, 128, kernel_size=2)
    u3_concat = layers.Concatenate()([u3, c2])
    c7 = conv_block(u3_concat, 128, kernel_size=3)

    u4 = upconv_block(c7, 64, kernel_size=2)
    u4_concat = layers.Concatenate()([u4, c1])
    c8 = conv_block(u4_concat, 64, kernel_size=3)

    # 1x1 convolution to per-pixel class probabilities.
    outputs = layers.Conv2D(num_classes, 1, activation='softmax')(c8)

    return tf.keras.Model(inputs, outputs)


# Reset Keras' global state so this model's layer names start again from
# "model", "input_1", "conv2d", ... instead of continuing the previous cell's.
keras.backend.clear_session()

model = unet(img_size=img_size, num_classes=num_classes)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 128, 128, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 conv2d_1 (Conv2D)              (None, 128, 128, 64  36928       ['conv2d[0][0]']                 
                                )                                                             