In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models

In [2]:
def scoped_layer_name(prefix, suffix):
    """Return ``"{prefix}_{suffix}"``, or ``None`` when no prefix is given.

    Returning ``None`` lets Keras auto-generate a unique layer name. Formatting
    a ``None`` prefix would instead produce the literal ``"None_<suffix>"``,
    which collides as soon as two unnamed blocks exist in the same model.
    """
    return None if prefix is None else f"{prefix}_{suffix}"


def conv_block(x, filters, name=None):
    """Apply two (3x3 Conv2D with ReLU -> BatchNormalization) stages to ``x``.

    Both convolutions use ``padding="same"``, so the spatial size of ``x`` is
    preserved and only the channel count changes to ``filters``. Note that
    ReLU is applied inside the Conv2D, i.e. before BatchNormalization.

    Args:
        x: input Keras tensor.
        filters: number of output channels for both convolutions.
        name: optional prefix for the layer names (``{name}_conv1``,
            ``{name}_bn1``, ``{name}_conv2``, ``{name}_bn2``). When ``None``,
            Keras assigns unique default names.

    Returns:
        The output tensor of the second BatchNormalization layer.
    """
    for stage in (1, 2):
        x = layers.Conv2D(
            filters,
            3,
            padding="same",
            activation="relu",
            name=scoped_layer_name(name, f"conv{stage}"),
        )(x)
        x = layers.BatchNormalization(name=scoped_layer_name(name, f"bn{stage}"))(x)
    return x

In [3]:
def encoder_block(x, filters, name):
    """One U-Net encoder stage: ``conv_block`` followed by 2x2 max-pooling.

    Args:
        x: input Keras tensor.
        filters: channel count passed to ``conv_block``.
        name: prefix for this stage's layer names (the pooling layer is
            ``{name}_pool``).

    Returns:
        ``(pooled, features)``: the downsampled tensor for the next stage, and
        the pre-pooling features kept for the decoder's skip connection.
    """
    features = conv_block(x, filters, name=name)
    pooled = layers.MaxPooling2D((2, 2), name=f"{name}_pool")(features)
    return pooled, features

In [4]:
def decoder_block(x, skip, filters, name):
    """One U-Net decoder stage: upsample by 2, merge the skip tensor, refine.

    Args:
        x: tensor from the previous (coarser) stage.
        skip: matching encoder features; its height/width must equal the
            upsampled ``x`` for the channel-axis concatenation to succeed.
        filters: channel count for the transposed conv and ``conv_block``.
        name: prefix for this stage's layer names.

    Returns:
        The output tensor of this stage's ``conv_block``.
    """
    upsampled = layers.Conv2DTranspose(
        filters, (2, 2), strides=2, padding="same", name=f"{name}_up"
    )(x)
    merged = layers.Concatenate(name=f"{name}_concat")([upsampled, skip])
    return conv_block(merged, filters, name=name)

In [8]:
def validate_unet_input_shape(input_shape, depth=4):
    """Check that ``input_shape`` survives ``depth`` 2x poolings and upsamplings.

    Each encoder stage halves height and width with 2x2 max-pooling (floor
    division) and each decoder stage doubles them again. The upsampled tensor
    therefore only matches its skip connection in ``Concatenate`` when the
    spatial size is divisible by ``2 ** depth``. Failing fast here gives a
    clear message instead of a shape mismatch deep inside the graph.

    Args:
        input_shape: ``(height, width, channels)`` without the batch axis
            (assumes channels-last layout). ``None`` dimensions (variable
            size) are not checked.
        depth: number of encoder/decoder stages.

    Raises:
        ValueError: if ``depth < 1``, ``input_shape`` is not rank 3, or a known
            spatial dimension is not a positive multiple of ``2 ** depth``.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    if len(input_shape) != 3:
        raise ValueError(
            f"input_shape must be (height, width, channels), got {input_shape!r}"
        )
    factor = 2 ** depth
    for axis_name, size in zip(("height", "width"), input_shape[:2]):
        if size is not None and (size < factor or size % factor != 0):
            raise ValueError(
                f"input {axis_name} {size} must be a positive multiple of "
                f"{factor} (2 ** depth) so skip connections line up"
            )


def build_unet_denoise(input_shape=(256, 256, 1), base_filters=64, depth=4):
    """Build a U-Net that maps a noisy spectrogram to a sigmoid mask.

    With the defaults, the model has encoder stages ``enc1..enc4`` (64, 128,
    256 and 512 filters), a 1024-filter ``bottleneck``, decoder stages
    ``dec4..dec1``, and a 1x1 ``mask`` conv with one output channel.

    Args:
        input_shape: ``(height, width, channels)`` of one input example;
            height and width must be multiples of ``2 ** depth``.
        base_filters: filter count of the first encoder stage; each deeper
            stage doubles it.
        depth: number of encoder/decoder stages.

    Returns:
        A ``keras.Model`` named ``"UNet_Denoise"`` whose output has values in
        ``(0, 1)`` and the same height/width as the input.

    Raises:
        ValueError: on an invalid ``depth``/``base_filters`` or an
            ``input_shape`` that the pooling/upsampling path cannot reproduce.
    """
    if base_filters < 1:
        raise ValueError(f"base_filters must be >= 1, got {base_filters}")
    validate_unet_input_shape(input_shape, depth)

    inputs = layers.Input(shape=input_shape, name="noisy_spectrogram")

    # -------- Encoder: keep each stage's pre-pooling features for the skips --------
    x = inputs
    skips = []
    for level in range(1, depth + 1):
        x, skip = encoder_block(x, base_filters * 2 ** (level - 1), f"enc{level}")
        skips.append(skip)

    # -------- Bottleneck --------
    x = conv_block(x, base_filters * 2 ** depth, "bottleneck")

    # -------- Decoder: deepest stage first, mirroring the encoder --------
    for level in range(depth, 0, -1):
        x = decoder_block(
            x, skips[level - 1], base_filters * 2 ** (level - 1), f"dec{level}"
        )

    # -------- Output (mask) --------
    mask = layers.Conv2D(1, (1, 1), activation="sigmoid", name="mask")(x)

    return models.Model(inputs, mask, name="UNet_Denoise")

In [9]:
# Build the denoiser for 256x256 single-channel spectrograms and show its layers.
model = build_unet_denoise(input_shape=(256, 256, 1))
model.summary()