In [1]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow" # or "jax" or "torch"
import tensorflow.keras as keras
print(keras.backend.backend())

tensorflow


# Architektur
Die U Net Modelarchitektur, wie der Name schon impliziert, ist U-formig angeordnet. 

## Bestandteile
- Convolutions (Blau Pfeile) kernel: 3x3, kein padding, stride: 1
- Pooling (rote Pfeile) kernel: 2x2, stride: 2
- Up Convolutions (de-convolutions): kernel 2x2
- Skip connections (ähnlich zum ResNet)

### Dimensionen von den skip connections
Die untere Skizze bezeichnet die grauen Pfeile mit "crop and copy". Warum nicht nur einfan nur kopieren ? 

Hint: Dimensionen...

Kann man es auch anders lösen, wenn ja wie?

## Features
- je tiefer in der Architektur desto größer das "Receptive Field" (ConvNets)
- nur geringer verlust von Information aka low level features.


![alt text](U-Net.png)

# Wie implementiert man sowas ?
Breche die einzelnen Bestandteile in kleinere runter und implementiere zuerst die "Kleinigkeiten", arbeite dich somit bis zum vollständigen UNet. 
-> Bottom up approach.

In [2]:
from keras.layers import Input, Conv2D, MaxPooling2D, Concatenate, Conv2DTranspose, Dropout, Layer
import tensorflow as tf

In [3]:
def conv_block(inputs, num_filters):
    x = Conv2D(num_filters, 3, activation='relu', padding='same')(inputs)
    x = Conv2D(num_filters, 3, activation='relu', padding='same')(x)
    return x

In [4]:
def encoder_block(inputs, num_filters):
    x = conv_block(inputs, num_filters)
    p = MaxPooling2D((2, 2))(x)
    return x, p

In [5]:
class ResizeLayer(Layer):  # you can easily change this to be a cropping layer
    def __init__(self, target_shape, **kwargs):
        super(ResizeLayer, self).__init__(**kwargs)
        self.target_shape = target_shape

    def call(self, x):
        return tf.image.resize(x, (self.target_shape[0], self.target_shape[1]))

def decoder_block(inputs, skip_features, num_filters):
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')(inputs)  # deconvolution ...
    # Check dimensions and scale skip features if needed
    if x.shape[1] != skip_features.shape[1] or x.shape[2] != skip_features.shape[2]:
        target_shape = x.shape[1:3]
        skip_features = ResizeLayer(target_shape=target_shape)(skip_features)
    
    x = Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    return x

In [6]:
def build_unet(input_shape, out_classes, activation="sigmoid"):
    inputs = Input(input_shape)
    
    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)
    
    b1 = conv_block(p4, 1024)
    
    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)
    
    outputs = Conv2D(out_classes, 1, activation=activation)(d4)
    
    model = tf.keras.Model(inputs, outputs, name='U-Net')
    return model

In [7]:
UNET = build_unet((572,572,1), 1)

In [8]:
UNET.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=keras.losses.BinaryCrossentropy(),  # oder dice loss
    metrics=['accuracy']
)

In [9]:
UNET.summary(line_length=96)

## Nun mit Subclassing

In [10]:
class ConvBlock(tf.keras.layers.Layer):
    def __init__(self, num_filters, **kwargs):
        super(ConvBlock, self).__init__(**kwargs)
        self.conv1 = tf.keras.layers.Conv2D(num_filters, 3, activation='relu', padding='same')
        self.conv2 = tf.keras.layers.Conv2D(num_filters, 3, activation='relu', padding='same')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        return x

In [11]:
class EncoderBlock(tf.keras.layers.Layer):
    def __init__(self, num_filters, **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.conv = ConvBlock(num_filters)
        self.pool = tf.keras.layers.MaxPooling2D((2, 2))

    def call(self, inputs):
        x = self.conv(inputs)
        p = self.pool(x)
        return x, p

In [12]:
class DecoderBlock(tf.keras.layers.Layer):
    def __init__(self, num_filters, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.up = tf.keras.layers.Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')
        self.conv = ConvBlock(num_filters)

    def call(self, inputs, skip):
        x = self.up(inputs)
        if x.shape[1] != skip.shape[1] or x.shape[2] != skip.shape[2]:
            target_shape = x.shape[1:3]
            print(x.shape, skip.shape)
            skip = ResizeLayer(target_shape=target_shape)(skip)
        print(skip.shape)
        x = tf.concat([x, skip])
        x = self.conv(x)
        return x

In [13]:
class UNet(tf.keras.Model):
    def __init__(self, num_classes, **kwargs):
        super(UNet, self).__init__(**kwargs)
        self.enc1 = EncoderBlock(64)
        self.enc2 = EncoderBlock(128)
        self.enc3 = EncoderBlock(256)
        self.enc4 = EncoderBlock(512)
        self.center = ConvBlock(1024)
        self.dec4 = DecoderBlock(512)
        self.dec3 = DecoderBlock(256)
        self.dec2 = DecoderBlock(128)
        self.dec1 = DecoderBlock(64)
        self.out = tf.keras.layers.Conv2D(num_classes, 1, activation='sigmoid', name="outs")

    def call(self, inputs):
        s1, p1 = self.enc1(inputs)
        s2, p2 = self.enc2(p1)
        s3, p3 = self.enc3(p2)
        s4, p4 = self.enc4(p3)

        b1 = self.center(p4)

        d4 = self.dec4(b1, s4)
        d3 = self.dec3(d4, s3)
        d2 = self.dec2(d3, s2)
        d1 = self.dec1(d2, s1)

        outputs = self.out(d1)
        return outputs

In [14]:
unet = UNet(2)
unet.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=keras.losses.BinaryCrossentropy(),
    metrics=['accuracy']
)
unet.summary()

## Training auf 2 Arten
### Daten

In [15]:
import numpy as np
import tensorflow as tf
from tensorflow.data import Dataset 

In [16]:
N_SAMPLES = 800

In [17]:
def get_sample_gen(n=10_000):
    """
    creates a generator for n pictures of shape (572,572,1)
    """ 
    def gen():
        for _ in range(n):
            image = np.random.randint(0, 256, size=(572, 572, 1))
            label = np.random.randint(0, 1, size=(560, 560, 1))
            yield image.astype(np.float64), label.astype(np.uint8)
    return gen

In [18]:
def create_datasets(n_train, n_val):
    train = Dataset.from_generator(
        get_sample_gen(n_train),
        output_signature=(
         tf.TensorSpec(shape=(572, 572, 1), dtype=tf.float64),
         tf.TensorSpec(shape=(560, 560, 1), dtype=tf.uint8))
    )

    validation= Dataset.from_generator(
        get_sample_gen(n_val),
        output_signature=(
         tf.TensorSpec(shape=(572, 572, 1), dtype=tf.float64),
         tf.TensorSpec(shape=(560, 560, 1), dtype=tf.uint8))
    )

    return train, validation

In [19]:
train_ds, validation_ds = create_datasets(N_SAMPLES, 200)

### Was nun?
Was muss mit den Daten __immer__ gemacht werden?

In [20]:
def get_mean_std(dataset):
    count = np.int64(0) 

    sum_, sum_squared_diff = np.float64(0), np.float64(0)
    for img, _ in train_ds:
        y_dim, x_dim, n_chanels = img.shape
        sum_ += tf.reduce_sum(img).numpy()
        count += x_dim * y_dim * n_chanels
    mean = sum_ / count

    for img, _ in train_ds:
        sum_squared_diff += tf.reduce_sum(tf.square(img-mean)).numpy()
    std = np.sqrt(sum_squared_diff/count)
    
    return mean, std

In [21]:
mean, std = get_mean_std(train_ds)
print(mean, std)

2024-04-07 20:00:59.878644: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


127.49508384425889 73.90147898073124


2024-04-07 20:01:01.449568: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [22]:
def normalize_image(image, mean, std):
    return (image - mean) / std

In [23]:
train_ds = train_ds.map(lambda image, label: (normalize_image(image, mean, std), label))
validation_ds = validation_ds.map(lambda image, label: (normalize_image(image, mean, std), label)) 

In [24]:
mean_train_after, std_train_after = get_mean_std(train_ds)
mean_val_after, std_val_after = get_mean_std(validation_ds)
print(f"{mean_train_after=}\n{std_train_after=}\n{mean_val_after=}\n{std_val_after=}")


2024-04-07 20:01:03.073211: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-07 20:01:04.652030: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-04-07 20:01:06.136019: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


mean_train_after=4.2661192555753586e-05
std_train_after=1.0000138110846999
mean_val_after=3.706246580558664e-05
std_val_after=0.9999864632374262


2024-04-07 20:01:07.755627: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


## Prep for training

In [25]:
train_ds = train_ds.shuffle(800).batch(8).prefetch(tf.data.AUTOTUNE)
validation_ds = validation_ds.shuffle(200).batch(8).prefetch(tf.data.AUTOTUNE)

In [26]:
# tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs", histogram_freq=1)

# UNET.fit(
#     train_ds,
#     validation_data=validation_ds,
#     epochs=5,
#     callbacks=[tensorboard_callback]
# )

In [27]:
num_epochs = 5
loss_fn = keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam()

for epoch in range(num_epochs):
    for batch_idx, (images, labels) in enumerate(train_ds):
        print(f'started batch {batch_idx + 1}')
        # Open a GradientTape to record the gradients
        with tf.GradientTape() as tape:
            predictions = UNET(images, training=True)
            loss = loss_fn(labels, predictions)

        # Compute the gradients
        gradients = tape.gradient(loss, UNET.trainable_variables)

        # Update the weights
        optimizer.apply_gradients(zip(gradients, UNET.trainable_variables))
        print(f"loss in batch {batch_idx + 1}: {loss}")


    # Print the loss for the current epoch
    print(f"Epoch {epoch + 1}/{num_epochs}")


started batch 1
loss in batch 0: 0.6775063872337341
started batch 2
loss in batch 1: 0.5276795625686646
started batch 3
loss in batch 2: 0.05412348359823227
started batch 4
loss in batch 3: 1.9861663247411343e-07
started batch 5
