In [31]:
import tensorflow as tf

class UNet:
    """Configurable 2-D U-Net-style segmentation model built with ``tf.keras``.

    Each encoder stage is ``Conv2D -> BatchNormalization -> ReLU`` and is
    described by one ``(filters, kernel_size, strides)`` tuple in
    ``encoder_config``. The decoder walks back up the encoder. For each
    shallower stage it:

    * upsamples by the stride of the deeper stage;
    * concatenates that stage's skip connection;
    * applies a stride-1 conv block.

    A final upsampling undoes the first encoder stage's stride, so the
    prediction has the same spatial size as the input.

    Args:
        input_shape: Shape of one input sample, e.g. ``(256, 256, 3)``.
        num_classes: Number of output channels. ``1`` produces a sigmoid
            (binary) mask; larger values produce a per-pixel softmax.
        encoder_config: List of ``(filters, kernel_size, strides)`` tuples.
            A falsy value (``None`` or empty) selects the two-stage default.
    """

    def __init__(self, input_shape=(256, 256, 3), num_classes=1, encoder_config=None):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.encoder_config = encoder_config if encoder_config else [(64, (3, 3), (2, 2)), (128, (3, 3), (2, 2))]
        self.model = self.build_model()

    def encoder(self, x, filters, kernel_size, strides):
        """Apply one ``Conv2D('same') -> BatchNormalization -> ReLU`` block to ``x``."""
        x = tf.keras.layers.Conv2D(filters, kernel_size, strides=strides, padding='same')(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        return x

    def decoder(self, x, skip, filters, kernel_size, strides, up_size=(2, 2)):
        """Upsample ``x``, concatenate it with ``skip``, then apply one conv block.

        Args:
            x: Feature map coming from the deeper level.
            skip: Encoder feature map whose spatial size ``x`` should match
                after upsampling.
            filters, kernel_size, strides: Forwarded to :meth:`encoder`.
                ``build_model`` passes ``(1, 1)`` strides so the conv does not
                shrink the map straight back down.
            up_size: Upsampling factor. It should equal the stride of the
                encoder stage that produced ``x``. The default keeps the
                previous hard-coded 2x behaviour.
        """
        x = tf.keras.layers.UpSampling2D(size=up_size)(x)
        x = tf.keras.layers.concatenate([x, skip])
        return self.encoder(x, filters, kernel_size, strides)

    @staticmethod
    def _decoder_plan(encoder_config):
        """Return one ``(skip_index, filters, kernel_size, up_size)`` entry per decoder stage, deepest first.

        The deepest encoder stage has no decoder counterpart; every shallower
        stage ``i`` gets one, upsampling by the stride of stage ``i + 1``.
        """
        return [
            (level, encoder_config[level][0], encoder_config[level][1], encoder_config[level + 1][2])
            for level in reversed(range(len(encoder_config) - 1))
        ]

    @staticmethod
    def _output_activation(num_classes):
        """Pick the output activation.

        Softmax over a single channel is identically 1, so a one-channel
        (binary) head must use sigmoid instead.
        """
        return 'sigmoid' if num_classes == 1 else 'softmax'

    def build_model(self):
        """Assemble the encoder/decoder graph and return it as a ``tf.keras.Model``."""
        inputs = tf.keras.Input(shape=self.input_shape)
        x = inputs
        skip_connections = []

        # Encoder: every stage's output is kept as a potential skip connection.
        for filters, kernel_size, strides in self.encoder_config:
            x = self.encoder(x, filters, kernel_size, strides)
            skip_connections.append(x)

        # Decoder: upsample by the deeper stage's stride and convolve at stride 1,
        # so the resolution actually grows back toward the input's.
        for level, filters, kernel_size, up_size in self._decoder_plan(self.encoder_config):
            x = self.decoder(x, skip_connections[level], filters, kernel_size, (1, 1), up_size=up_size)

        # The first encoder stage already downsampled the raw input; undo that
        # so the prediction is per input pixel.
        x = tf.keras.layers.UpSampling2D(size=self.encoder_config[0][2])(x)

        # Output layer
        outputs = tf.keras.layers.Conv2D(
            self.num_classes, (1, 1), activation=self._output_activation(self.num_classes)
        )(x)

        return tf.keras.Model(inputs=inputs, outputs=outputs)

    def compile(self, optimizer, loss, metrics):
        """Forward to ``self.model.compile`` with the given optimizer, loss and metrics."""
        self.model.compile(optimizer=optimizer, loss=loss, metrics=metrics)

    def summary(self):
        """Return ``self.model.summary()`` (Keras prints the layer table)."""
        return self.model.summary()


2023-10-04 07:36:03.776587: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-04 07:36:07.051744: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-04 07:36:07.051774: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-04 07:36:07.068863: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-10-04 07:36:08.886693: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-10-04 07:36:08.899751: I tensorflow/core/platform/cpu_feature_guard.cc:182] This Tens

In [32]:
# Example usage: binary segmentation of 256x256 RGB images with a two-stage
# encoder (each stage uses stride (2, 2)).
encoder_config = [
    (64, (3, 3), (2, 2)),
    (128, (3, 3), (2, 2)),
]
unet = UNet(
    input_shape=(256, 256, 3),
    num_classes=1,
    encoder_config=encoder_config,
)
unet.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
unet.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 256, 256, 3)]        0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 128, 128, 64)         1792      ['input_1[0][0]']             
                                                                                                  
 batch_normalization (Batch  (None, 128, 128, 64)         256       ['conv2d[0][0]']              
 Normalization)                                                                                   
                                                                                                  
 re_lu (ReLU)                (None, 128, 128, 64)         0         ['batch_normalization[0][0