In [None]:
import os
# Keras reads KERAS_BACKEND at import time, so this assignment must run before the import below.
# NOTE(review): later cells call TensorFlow ops directly (tf.image.resize_with_crop_or_pad, tf.concat),
# so "tensorflow" is effectively required here; "jax"/"torch" would break those cells.
os.environ["KERAS_BACKEND"] = "tensorflow" # or "jax" or "torch"
# NOTE(review): this binds the name `keras` to tensorflow.keras, while a later cell imports layers
# from the standalone `keras` package. With TF >= 2.16 both resolve to Keras 3; on older setups
# they can be different packages — confirm the installed versions.
import tensorflow.keras as keras
# Sanity check: print the name of the active Keras backend.
print(keras.backend.backend())

# Architektur
Die U-Net-Modellarchitektur ist, wie der Name schon andeutet, U-förmig angeordnet.

## Bestandteile
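- Convolutions (blaue Pfeile): Kernel 3x3, kein Padding, Stride 1 (so in der Skizze; die Implementierung unten verwendet `padding='same'`, damit die Größe innerhalb eines Blocks erhalten bleibt)
- Max-Pooling (rote Pfeile): Kernel 2x2, Stride 2
- Up-Convolutions (transponierte Convolutions, oft „De-Convolutions“ genannt): Kernel 2x2, Stride 2
- Skip Connections (ähnlich wie bei ResNet, allerdings werden die Feature Maps hier konkateniert statt addiert)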

### Dimensionen der Skip Connections
Die Skizze unten bezeichnet die grauen Pfeile mit „crop and copy“. Warum nicht einfach nur kopieren?

Hint: Dimensionen...

Kann man es auch anders lösen, wenn ja wie?

## Features
- Je tiefer in der Architektur, desto größer das „Receptive Field“ (wie bei ConvNets allgemein).
- Nur geringer Verlust von Information, insbesondere von Low-Level-Features, da die Skip Connections diese direkt an den Decoder weitergeben.


![alt text](U-Net.png)

# Wie implementiert man so etwas?
Zerlege die Architektur in ihre einzelnen Bestandteile und implementiere zuerst die „Kleinigkeiten“; arbeite dich so bis zum vollständigen U-Net hoch.
-> Bottom-up-Ansatz.

In [None]:
from keras.layers import Input, Conv2D, MaxPooling2D, Concatenate, Conv2DTranspose, Dropout, Layer
import tensorflow as tf

In [None]:
def conv_block(inputs, num_filters):
    """Apply two stacked 3x3 ReLU convolutions.

    'same' padding keeps height and width unchanged, so only the channel
    count changes.

    Args:
        inputs: Input feature map (Keras tensor).
        num_filters: Number of output channels of both convolutions.

    Returns:
        Feature map with ``num_filters`` channels and the input's spatial size.
    """
    features = inputs
    for _ in range(2):
        features = Conv2D(num_filters, 3, activation='relu', padding='same')(features)
    return features

In [None]:
def encoder_block(inputs, num_filters):
    """Run one encoder stage of the U-Net.

    Args:
        inputs: Input feature map (Keras tensor).
        num_filters: Channel count used by the stage's ``conv_block``.

    Returns:
        ``(skip, pooled)``: the ``conv_block`` output, kept for the skip
        connection, and its 2x2 max-pooled version, which feeds the next
        (deeper) stage.
    """
    skip = conv_block(inputs, num_filters)
    pooled = MaxPooling2D(pool_size=(2, 2))(skip)
    return skip, pooled

In [None]:
class ResizeLayer(Layer):
    """Center-crop (or zero-pad) a feature map to a fixed spatial size.

    Used on skip connections so they match the spatial size of the upsampled
    decoder tensor before concatenation (the "crop" in "crop and copy").
    ``tf.image.resize_with_crop_or_pad`` crops when the input is larger than
    the target and pads with zeros when it is smaller.

    Args:
        target_shape: ``(height, width)`` of the output feature map; both
            entries must be known integers.
    """

    def __init__(self, target_shape, **kwargs):
        super(ResizeLayer, self).__init__(**kwargs)
        # Store as a plain tuple (not a TensorShape) so get_config() can serialize it.
        self.target_shape = tuple(target_shape)

    def call(self, x):
        return tf.image.resize_with_crop_or_pad(x, self.target_shape[0], self.target_shape[1])

    def get_config(self):
        # Required so a model containing this layer can be re-created from its
        # config (e.g. load_model(..., custom_objects={"ResizeLayer": ResizeLayer})).
        config = super(ResizeLayer, self).get_config()
        config.update({"target_shape": self.target_shape})
        return config


def decoder_block(inputs, skip_features, num_filters):
    """Run one decoder stage: 2x2 transposed-conv upsampling, skip concat, conv_block.

    Args:
        inputs: Feature map from the deeper stage.
        skip_features: Matching encoder output (skip connection).
        num_filters: Channel count of the upsampling and of the ``conv_block``.

    Returns:
        Feature map with ``num_filters`` channels at the upsampled spatial size.
    """
    # Transposed convolution with stride 2 doubles height and width.
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')(inputs)
    # Pooling floors odd sizes (e.g. 71 -> 35), so after upsampling (35 -> 70)
    # the skip tensor can be larger; crop it to the upsampled size.
    if x.shape[1] != skip_features.shape[1] or x.shape[2] != skip_features.shape[2]:
        target_shape = x.shape[1:3]
        skip_features = ResizeLayer(target_shape=target_shape)(skip_features)

    x = Concatenate()([x, skip_features])
    x = conv_block(x, num_filters)
    return x

In [None]:
def build_unet(input_shape, out_classes, activation='sigmoid'):
    """Build the U-Net with the Keras functional API.

    Encoder: four ``encoder_block`` stages (64, 128, 256, 512 filters), each
    halving height and width. Bottleneck: ``conv_block`` with 1024 filters.
    Decoder: four ``decoder_block`` stages (512, 256, 128, 64 filters), each
    doubling height and width and concatenating the matching encoder skip.

    The 3x3 convolutions use 'same' padding, pooling floors odd sizes, and skips
    are cropped to the upsampled size. Each output spatial dimension is therefore
    ``16 * floor(n / 16)`` for an input dimension ``n``. Example: 572 -> 560.
    It equals the input only when ``n`` is divisible by 16.

    Args:
        input_shape: ``(height, width, channels)`` of a single input image.
        out_classes: Number of output channels of the final 1x1 convolution.
        activation: Activation of the output layer. ``'sigmoid'`` (default)
            gives independent per-channel probabilities (pair with binary
            cross-entropy / Dice). Use ``'softmax'`` for mutually exclusive
            classes with categorical cross-entropy.

    Returns:
        An uncompiled ``tf.keras.Model`` named 'U-Net'.
    """
    inputs = Input(input_shape)

    s1, p1 = encoder_block(inputs, 64)
    s2, p2 = encoder_block(p1, 128)
    s3, p3 = encoder_block(p2, 256)
    s4, p4 = encoder_block(p3, 512)

    b1 = conv_block(p4, 1024)

    # Skips are consumed in reverse order: deepest encoder output first.
    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    outputs = Conv2D(out_classes, 1, activation=activation)(d4)

    model = tf.keras.Model(inputs, outputs, name='U-Net')
    return model

In [None]:
# Functional U-Net for 3-channel images with 2 output channels.
# NOTE(review): height 573 vs. width 572 looks like a typo — confirm the intended input size.
UNET = build_unet(input_shape=(573, 572, 3), out_classes=2)

In [None]:
# The model ends in a per-channel sigmoid, so every output channel is an
# independent probability; binary cross-entropy is the matching loss.
# (CategoricalCrossentropy assumes a softmax distribution over the channel axis.)
# With this loss, Keras resolves 'accuracy' to binary accuracy.
UNET.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=keras.losses.BinaryCrossentropy(),  # or a Dice loss
    metrics=['accuracy']
)

In [None]:
# Print the layer table of the functional model; line_length sets the total width of the printed table.
UNET.summary(line_length=96)

## Nun mit Subclassing

In [None]:
class ConvBlock(tf.keras.layers.Layer):
    """Layer version of ``conv_block``: two 3x3 ReLU convolutions with 'same' padding.

    Args:
        num_filters: Number of output channels of both convolutions.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, num_filters, **kwargs):
        super().__init__(**kwargs)
        self.conv1 = tf.keras.layers.Conv2D(num_filters, 3, activation='relu', padding='same')
        self.conv2 = tf.keras.layers.Conv2D(num_filters, 3, activation='relu', padding='same')

    def call(self, inputs):
        features = inputs
        for conv in (self.conv1, self.conv2):
            features = conv(features)
        return features

In [None]:
class EncoderBlock(tf.keras.layers.Layer):
    """One encoder stage: ``ConvBlock`` followed by 2x2 max-pooling.

    ``call`` returns ``(skip, pooled)``: the ConvBlock output, kept for the skip
    connection, and its pooled version, which feeds the next stage.
    """

    def __init__(self, num_filters, **kwargs):
        super().__init__(**kwargs)
        self.conv = ConvBlock(num_filters)
        self.pool = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))

    def call(self, inputs):
        skip = self.conv(inputs)
        return skip, self.pool(skip)

In [None]:
class DecoderBlock(tf.keras.layers.Layer):
    """One decoder stage: 2x2 transposed-conv upsampling, skip concat, ``ConvBlock``.

    Like the functional ``decoder_block``, the skip tensor is center-cropped
    (or zero-padded) to the upsampled tensor's height/width before
    concatenation. This is needed whenever pooling floored an odd size
    (e.g. 71 -> 35 -> upsampled 70); without it ``tf.concat`` fails for
    inputs not divisible by 16. Channels-last (NHWC) tensors are assumed.
    """

    def __init__(self, num_filters, **kwargs):
        super(DecoderBlock, self).__init__(**kwargs)
        self.up = tf.keras.layers.Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')
        self.conv = ConvBlock(num_filters)

    @staticmethod
    def _match_spatial(skip, reference):
        """Crop/pad ``skip`` to the height and width of ``reference``."""
        ref_h, ref_w = reference.shape[1], reference.shape[2]
        if ref_h is not None and ref_w is not None and skip.shape[1] == ref_h and skip.shape[2] == ref_w:
            return skip  # static sizes already agree: nothing to do
        runtime_shape = tf.shape(reference)
        # Prefer static sizes so downstream layers keep a known output shape;
        # fall back to the runtime size when a dimension is unknown.
        target_h = ref_h if ref_h is not None else runtime_shape[1]
        target_w = ref_w if ref_w is not None else runtime_shape[2]
        return tf.image.resize_with_crop_or_pad(skip, target_h, target_w)

    def call(self, inputs, skip):
        x = self.up(inputs)
        skip = self._match_spatial(skip, x)
        x = tf.concat([x, skip], axis=-1)
        x = self.conv(x)
        return x

In [None]:
class UNet(tf.keras.Model):
    """U-Net built by subclassing ``tf.keras.Model``.

    Layout: four EncoderBlocks (64/128/256/512 filters), a 1024-filter
    ConvBlock bottleneck, four DecoderBlocks (512/256/128/64) that consume the
    encoder skips in reverse order, and a 1x1 sigmoid convolution with
    ``num_classes`` output channels.
    """

    def __init__(self, num_classes, **kwargs):
        super().__init__(**kwargs)
        # Sub-layers are created in a fixed order; Keras tracks them in assignment order.
        self.enc1 = EncoderBlock(64)
        self.enc2 = EncoderBlock(128)
        self.enc3 = EncoderBlock(256)
        self.enc4 = EncoderBlock(512)
        self.center = ConvBlock(1024)
        self.dec4 = DecoderBlock(512)
        self.dec3 = DecoderBlock(256)
        self.dec2 = DecoderBlock(128)
        self.dec1 = DecoderBlock(64)
        self.out = tf.keras.layers.Conv2D(num_classes, 1, activation='sigmoid', name="outs")

    def call(self, inputs):
        encoders = (self.enc1, self.enc2, self.enc3, self.enc4)
        decoders = (self.dec4, self.dec3, self.dec2, self.dec1)

        skips = []
        x = inputs
        for encoder in encoders:
            skip, x = encoder(x)
            skips.append(skip)

        x = self.center(x)

        # The deepest skip pairs with the first decoder stage.
        for decoder, skip in zip(decoders, reversed(skips)):
            x = decoder(x, skip)

        return self.out(x)

# Einblick: Formel des Dice Loss
![alt text](dl.png)

## Aber was genau bedeutet das?

![alt text](dice_vis.png)

In [None]:
def dice_coeff(y_true, y_pred, smooth=1e-7):
    """Soft Dice coefficient over all elements of the batch.

    dice = (2 * sum(y_true * y_pred) + smooth) / (sum(y_true) + sum(y_pred) + smooth)

    Args:
        y_true: Ground-truth mask (any numeric or bool dtype); cast to ``y_pred``'s dtype.
        y_pred: Predicted probabilities.
        smooth: Small additive term. It keeps the result finite (== 1) when both
            tensors are all zeros, where the plain formula is 0/0 = NaN, and is
            negligible otherwise.

    Returns:
        Scalar tensor; lies in [0, 1] when all values are in [0, 1].
    """
    y_pred_f = tf.reshape(y_pred, [-1])
    # Integer/bool masks cannot be multiplied with float predictions directly.
    y_true_f = tf.reshape(tf.cast(y_true, y_pred_f.dtype), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    dice = (2. * intersection + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)
    return dice

def dice_loss(y_true, y_pred):
    """Dice loss, ``1 - dice_coeff(y_true, y_pred)``; it is 0 for identical binary masks."""
    return 1 - dice_coeff(y_true, y_pred)

# ODER


class DiceLoss(tf.keras.losses.Loss):  # generalized version: 1 - DiceCoeff ** (1 / gamma)
    """Soft Dice loss computed over all elements of the batch.

    loss = 1 - ((2 * sum(y_true * y_pred) + smooth)
                / (sum(y_true) + sum(y_pred) + smooth)) ** (1 / gamma)

    Args:
        name: Loss name shown in Keras logs.
        smooth: Additive smoothing. It keeps the ratio defined (== 1) when both
            masks are empty.
        gamma: Exponent parameter; ``gamma=1`` gives the plain Dice loss.
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``reduction``).
            This lets ``from_config`` round-trip the base config.
    """

    def __init__(self, name='dice_loss', smooth=1.0, gamma=1.0, **kwargs):
        super().__init__(name=name, **kwargs)
        # call() reads both attributes, so they must be set here.
        self.smooth = smooth
        self.gamma = gamma

    def call(self, y_true, y_pred):
        y_true = tf.reshape(tf.cast(y_true, dtype=tf.float32), [-1])
        y_pred = tf.reshape(tf.cast(y_pred, dtype=tf.float32), [-1])
        numerator = 2.0 * tf.reduce_sum(y_true * y_pred) + self.smooth
        denominator = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + self.smooth
        dice_loss = 1.0 - tf.pow(numerator / denominator, 1.0 / self.gamma)
        return dice_loss

    def get_config(self):
        config = super().get_config()
        config.update({"smooth": self.smooth, "gamma": self.gamma})
        return config

In [None]:
unet = UNet(2)
unet.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=DiceLoss(),
    metrics=['accuracy']
)
# A subclassed model has no layers built until it is called once, and
# summary() needs built layers. Trace it on a symbolic input first.
# 512 is divisible by 2**4, so all four pool/upsample stages line up exactly.
unet(tf.keras.Input(shape=(512, 512, 3)))
unet.summary()