<a href="https://colab.research.google.com/github/ayyucedemirbas/vit_segmentor/blob/main/vit_segmentor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!wget https://datasets.simula.no/downloads/kvasir-seg.zip

--2024-12-08 19:40:12--  https://datasets.simula.no/downloads/kvasir-seg.zip
Resolving datasets.simula.no (datasets.simula.no)... 128.39.36.14
Connecting to datasets.simula.no (datasets.simula.no)|128.39.36.14|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 46227172 (44M) [application/zip]
Saving to: ‘kvasir-seg.zip’


2024-12-08 19:40:16 (14.0 MB/s) - ‘kvasir-seg.zip’ saved [46227172/46227172]



In [2]:
!unzip -qq kvasir-seg.zip

In [69]:
import tensorflow as tf
from tensorflow.keras import layers, Model, Input

class PatchEmbedding(layers.Layer):
    def __init__(self, patch_size, embed_dim, **kwargs):
        super(PatchEmbedding, self).__init__(**kwargs)
        self.patch_size = patch_size
        self.embed_dim = embed_dim

    def build(self, input_shape):
        _, self.img_height, self.img_width, self.num_channels = input_shape
        self.num_patches = (self.img_height // self.patch_size) * (self.img_width // self.patch_size)
        self.projection = layers.Conv2D(self.embed_dim, kernel_size=self.patch_size, strides=self.patch_size)
        self.flatten = layers.Reshape((self.num_patches, self.embed_dim))

    def call(self, inputs):
        x = self.projection(inputs)
        x = self.flatten(x)
        return x

class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout_rate=0.1, **kwargs):
        super(TransformerBlock, self).__init__(**kwargs)
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.dropout1 = layers.Dropout(dropout_rate)
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)

        self.mlp = tf.keras.Sequential([
            layers.Dense(mlp_dim, activation='gelu'),
            layers.Dropout(dropout_rate),
            layers.Dense(embed_dim),
            layers.Dropout(dropout_rate)
        ])
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs, training):
        attn_output = self.attn(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.norm1(inputs + attn_output)

        mlp_output = self.mlp(out1, training=training)
        out2 = self.norm2(out1 + mlp_output)
        return out2

class ViTEncoder(layers.Layer):
    def __init__(self, num_patches, embed_dim, num_layers, num_heads, mlp_dim, dropout_rate=0.1, **kwargs):
        super(ViTEncoder, self).__init__(**kwargs)
        self.num_patches = num_patches
        self.embed_dim = embed_dim
        self.pos_embedding = self.add_weight(
            name="pos_embedding",
            shape=(1, num_patches, embed_dim),
            initializer="random_normal",
            trainable=True,
        )
        self.dropout = layers.Dropout(dropout_rate)
        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, mlp_dim, dropout_rate) for _ in range(num_layers)
        ]

    def call(self, inputs, training):
        x = inputs + self.pos_embedding
        x = self.dropout(x, training=training)
        for block in self.transformer_blocks:
            x = block(x, training=training)
        return x

def create_vit_segmentation_model(input_shape, patch_size, num_classes, embed_dim, num_layers, num_heads, mlp_dim):
    inputs = Input(shape=input_shape)
    patch_embed = PatchEmbedding(patch_size, embed_dim)(inputs)

    num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
    encoder = ViTEncoder(num_patches, embed_dim, num_layers, num_heads, mlp_dim)
    encoded_features = encoder(patch_embed, training=False)

    # Decoder (Upsampling for segmentation)

    decoder_input = layers.Reshape((input_shape[0] // patch_size, input_shape[1] // patch_size, embed_dim))(encoded_features)

    # Upsampling layers to match the input resolution (128, 128)
    x = layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding="same", activation="relu")(decoder_input)  # (64, 64)
    x = layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding="same", activation="relu")(x)  # (128, 128)
    x = layers.Conv2DTranspose(32, kernel_size=3, strides=2, padding="same", activation="relu")(x)
    x = layers.Conv2DTranspose(16, kernel_size=3, strides=2, padding="same", activation="relu")(x)

    outputs = layers.Conv2D(num_classes, kernel_size=1, activation="sigmoid")(x)  # Output shape: (128, 128, num_classes)

    return Model(inputs, outputs)


input_shape = (128, 128, 3)
patch_size = 16
num_classes = 1
embed_dim = 64
num_layers = 4
num_heads = 8
mlp_dim = 128


vit_segmentation_model = create_vit_segmentation_model(
    input_shape, patch_size, num_classes, embed_dim, num_layers, num_heads, mlp_dim
)
vit_segmentation_model.summary()


In [7]:
import tensorflow as tf
import os

def load_image_and_mask(image_path, mask_path):
    image = tf.io.read_file(image_path)
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, (128, 128))
    image = tf.cast(image, tf.float32) / 255.0

    mask = tf.io.read_file(mask_path)
    mask = tf.image.decode_png(mask, channels=1)
    mask = tf.image.resize(mask, (128, 128))
    mask = tf.cast(mask, tf.uint8)

    return image, mask

def create_dataset(images_dir, masks_dir, batch_size):
    image_paths = sorted([os.path.join(images_dir, fname) for fname in os.listdir(images_dir)])
    mask_paths = sorted([os.path.join(masks_dir, fname) for fname in os.listdir(masks_dir)])

    dataset = tf.data.Dataset.from_tensor_slices((image_paths, mask_paths))
    dataset = dataset.map(load_image_and_mask, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset

train_dataset = create_dataset('/content/Kvasir-SEG/images', '/content/Kvasir-SEG/masks', batch_size=16)


In [65]:
for images, masks in train_dataset.take(1):
    predictions = vit_segmentation_model(images)
    print("Predictions shape:", predictions.shape)

Predictions shape: (16, 128, 128, 1)


In [66]:
for images, masks in train_dataset.take(1):
    print("Image shape:", images.shape)
    print("Mask shape:", masks.shape)


Image shape: (16, 128, 128, 3)
Mask shape: (16, 128, 128, 1)


In [70]:
vit_segmentation_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),

)

vit_segmentation_model.fit(train_dataset, epochs=20)

Epoch 1/20
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m44s[0m 268ms/step - loss: -3.0752
Epoch 2/20
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 37ms/step - loss: -313.3849
Epoch 3/20
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 49ms/step - loss: -5259.0391
Epoch 4/20
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 34ms/step - loss: -34475.5664
Epoch 5/20
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 34ms/step - loss: -135088.9062
Epoch 6/20
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 34ms/step - loss: -389779.2812
Epoch 7/20
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 58ms/step - loss: -923582.3125
Epoch 8/20
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 34ms/step - loss: -1910146.6250
Epoch 9/20
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 35ms/step - loss: -3577675.7500
Epoch 10/20
[1m63/63[0m [32m━━━━━━━━━━━━━━━━━━━━

<keras.src.callbacks.history.History at 0x7cccd172a500>