## Imports and Setup

In [1]:
!pip install -qq wandb

In [2]:
import wandb
from wandb.keras import WandbMetricsLogger
from wandb.keras import WandbEvalCallback

# Authenticate this session with Weights & Biases; this may prompt
# interactively for an API key.
wandb.login()

2024-04-03 17:50:41.766550: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-04-03 17:50:41.766658: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-04-03 17:50:41.895440: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models

import tensorflow_datasets as tfds

import os
import numpy as np
from argparse import Namespace
import matplotlib.pyplot as plt

In [4]:
# Hyperparameters shared by the data pipeline, the model and the W&B run config.
configs = Namespace(**{
    "epochs": 10,        # number of training epochs passed to model.fit
    "img_size": 128,     # images and masks are resized to img_size x img_size
    "batch_size": 32,    # examples per batch in both data loaders
    "num_classes": 3,    # one-hot depth of the masks / model output channels
})

# Dataloader

We will use the Oxford-IIIT Pet dataset, which can be loaded directly from TensorFlow Datasets.

In [5]:
# Load the "train" and "test" splits of the pets dataset; the "test" split is
# used as validation data in this notebook.
train_ds, valid_ds = tfds.load(
    name='oxford_iiit_pet',
    split=["train", "test"],
)

[1mDownloading and preparing dataset 773.52 MiB (download: 773.52 MiB, generated: 774.69 MiB, total: 1.51 GiB) to /root/tensorflow_datasets/oxford_iiit_pet/3.2.0...[0m


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/3680 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/oxford_iiit_pet/3.2.0.incompleteZ7KEQ2/oxford_iiit_pet-train.tfrecord*...:…

Generating test examples...:   0%|          | 0/3669 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/oxford_iiit_pet/3.2.0.incompleteZ7KEQ2/oxford_iiit_pet-test.tfrecord*...: …

[1mDataset oxford_iiit_pet downloaded and prepared to /root/tensorflow_datasets/oxford_iiit_pet/3.2.0. Subsequent calls will reuse this data.[0m


In [6]:
# Sentinel that lets tf.data choose parallelism / prefetch buffer sizes itself.
# `tf.data.AUTOTUNE` is the stable name of the older `tf.data.experimental.AUTOTUNE`.
AUTOTUNE = tf.data.AUTOTUNE


def parse_data(example):
    """Convert one dataset example into an ``(image, mask)`` training pair.

    Args:
        example: Mapping that provides the ``"image"`` and
            ``"segmentation_mask"`` entries.

    Returns:
        A tuple ``(image, mask)``: the image converted to ``tf.float32`` and
        resized to ``configs.img_size`` on both sides, and the mask resized to
        the same spatial size and one-hot encoded over ``configs.num_classes``.
    """
    target_size = (configs.img_size, configs.img_size)

    # Image branch: dtype conversion first, then resize.
    pixels = tf.image.convert_image_dtype(example["image"], tf.float32)
    pixels = tf.image.resize(pixels, size=target_size)

    # Mask branch: ground truth labels are [1,2,3]; shift them to start at 0.
    class_ids = example["segmentation_mask"] - 1
    # Nearest-neighbour resizing so no new (interpolated) label values appear.
    class_ids = tf.image.resize(class_ids, size=target_size, method='nearest')
    # Drop the trailing channel axis before expanding the ids to one-hot vectors.
    one_hot_mask = tf.one_hot(tf.squeeze(class_ids, axis=-1), depth=configs.num_classes)

    return pixels, one_hot_mask

# Training pipeline: shuffle raw examples, preprocess in parallel, batch, prefetch.
trainloader = train_ds.shuffle(1024)
trainloader = trainloader.map(parse_data, num_parallel_calls=AUTOTUNE)
trainloader = trainloader.batch(configs.batch_size)
trainloader = trainloader.prefetch(AUTOTUNE)

# Validation pipeline: the same steps, but without shuffling.
validloader = valid_ds.map(parse_data, num_parallel_calls=AUTOTUNE)
validloader = validloader.batch(configs.batch_size)
validloader = validloader.prefetch(AUTOTUNE)

## Model

In [7]:
# ref: https://github.com/ayulockin/deepimageinpainting/blob/master/Image_Inpainting_Autoencoder_Decoder_v2_0.ipynb
class SegmentationModel:
    """Builder for a U-Net style encoder/decoder network for segmentation."""

    def prepare_model(self, OUTPUT_CHANNEL, input_size=(configs.img_size, configs.img_size, 3)):
        """Assemble the network and return it as a Keras model.

        Args:
            OUTPUT_CHANNEL: Number of filters of the final softmax convolution.
            input_size: Shape handed to the input layer. The default,
                ``(configs.img_size, configs.img_size, 3)``, is evaluated once
                when the class body is executed.

        Returns:
            A ``models.Model`` mapping the input tensor to the softmax output.
        """
        inputs = layers.Input(input_size)

        # Contracting path: remember every pre-pooling activation as a skip tensor.
        skips = []
        tensor = inputs
        for width in (32, 64, 128, 256):
            skip, tensor = self.__ConvBlock(width, (3,3), (2,2), 'relu', 'same', tensor)
            skips.append(skip)

        # Expanding path: each stage upsamples and concatenates the matching
        # skip tensor, deepest skip first.
        decoder_widths = ((512, 256), (256, 128), (128, 64), (64, 32))
        for (width, up_width), skip in zip(decoder_widths, reversed(skips)):
            _, tensor = self.__UpConvBlock(width, up_width, (3,3), (2,2), (2,2), 'relu', 'same', tensor, skip)

        # Last double convolution (no pooling), followed by the per-pixel softmax layer.
        features = self.__ConvBlock(32, (3,3), (2,2), 'relu', 'same', tensor, False)
        outputs = layers.Conv2D(OUTPUT_CHANNEL, (3, 3), activation='softmax', padding='same')(features)

        return models.Model(inputs=[inputs], outputs=[outputs])

    def __ConvBlock(self, filters, kernel_size, pool_size, activation, padding, connecting_layer, pool_layer=True):
        """Apply two stacked convolutions, optionally followed by max pooling.

        Returns:
            ``(conv, pool)`` when ``pool_layer`` is true, otherwise only ``conv``.
        """
        conv_kwargs = dict(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)
        conv = layers.Conv2D(**conv_kwargs)(connecting_layer)
        conv = layers.Conv2D(**conv_kwargs)(conv)
        if not pool_layer:
            return conv
        return conv, layers.MaxPooling2D(pool_size)(conv)

    def __UpConvBlock(self, filters, up_filters, kernel_size, up_kernel, up_stride, activation, padding, connecting_layer, shared_layer):
        """Apply two convolutions, upsample, and concatenate with ``shared_layer``.

        Returns:
            ``(conv, up)``: the output of the second convolution and the
            transposed-convolution output concatenated with ``shared_layer``
            along axis 3.
        """
        conv_kwargs = dict(filters=filters, kernel_size=kernel_size, activation=activation, padding=padding)
        conv = layers.Conv2D(**conv_kwargs)(connecting_layer)
        conv = layers.Conv2D(**conv_kwargs)(conv)
        upsampled = layers.Conv2DTranspose(filters=up_filters, kernel_size=up_kernel, strides=up_stride, padding=padding)(conv)
        merged = layers.concatenate([upsampled, shared_layer], axis=3)

        return conv, merged

#### Initialize Model and Compile

In [8]:
# Reset Keras' global state before (re)building the model in this cell.
tf.keras.backend.clear_session()

# The output channel count is 3 because the mask has three classes.
model = SegmentationModel().prepare_model(OUTPUT_CHANNEL=configs.num_classes)

# Targets are one-hot masks, hence categorical cross-entropy.
model.compile(loss="categorical_crossentropy", optimizer="adam")

model.summary()

## Callback

In [9]:
# Display names for the mask class ids, listed in id order (index 0 first).
# NOTE(review): this order must match the class ids produced by the data
# pipeline (dataset labels shifted by -1) — confirm against the dataset docs.
segmentation_classes = ['pet', 'pet_outline', 'background']

def labels():
    """Return a dict mapping each integer class id to its display name."""
    return dict(enumerate(segmentation_classes))

In [10]:
class WandbSemanticLogger(WandbEvalCallback):
    """Log segmentation ground truth and per-epoch predictions to W&B tables.

    A fixed subset of the validation data is written once to the data table
    (image with its ground-truth mask overlay). For every epoch, one row per
    sample is added to the prediction table, holding the image with the
    ground-truth overlay and the image with the predicted-mask overlay.

    The tables (``data_table``, ``data_table_ref``, ``pred_table``) and the
    ``model`` attribute are expected to be provided by the base callback
    classes; see the ``WandbEvalCallback`` documentation for when the
    ``add_ground_truth`` / ``add_model_predictions`` hooks are invoked.
    """

    def __init__(
        self,
        validloader,
        data_table_columns=None,
        pred_table_columns=None,
        num_samples=100,
    ):
        """Set up the callback.

        Args:
            validloader: Batched dataset yielding ``(image, one_hot_mask)``
                pairs; it is unbatched here.
            data_table_columns: Column names of the ground-truth table.
                Defaults to ``["index", "image"]``.
            pred_table_columns: Column names of the prediction table.
                Defaults to ``["epoch", "index", "image", "prediction"]``.
            num_samples: Number of validation samples to log.
        """
        # Create the default column lists per instance rather than sharing one
        # mutable default object between all instances of the callback.
        if data_table_columns is None:
            data_table_columns = ["index", "image"]
        if pred_table_columns is None:
            pred_table_columns = ["epoch", "index", "image", "prediction"]

        super().__init__(
            data_table_columns,
            pred_table_columns,
        )

        # Individual (image, mask) samples; only the first `num_samples` are logged.
        self.val_data = validloader.unbatch().take(num_samples)

    def add_ground_truth(self, logs):
        """Fill the data table with one row per sample: index and image + true mask."""
        for idx, (image, mask) in enumerate(self.val_data):
            self.data_table.add_data(
                idx,
                self._prepare_wandb_mask(
                    image.numpy(),
                    # One-hot mask -> integer class id per pixel.
                    np.argmax(mask.numpy(), axis=-1),
                    "ground_truth"
                )
            )

    def add_model_predictions(self, epoch, logs):
        """Add one prediction-table row per sample for the given epoch."""
        data_table_ref = self.data_table_ref

        for idx, (image, mask) in enumerate(self.val_data):
            # Predict on a batch of one, then reduce class scores to class ids.
            prediction = self.model.predict(tf.expand_dims(image, axis=0), verbose=0)
            prediction = np.argmax(tf.squeeze(prediction, axis=0).numpy(), axis=-1)

            self.pred_table.add_data(
                epoch,
                # Columns 0 and 1 of the data table row hold the index and the image.
                data_table_ref.data[idx][0],
                self._prepare_wandb_mask(
                    data_table_ref.data[idx][1],
                    np.argmax(mask.numpy(), axis=-1),
                    "ground_truth"
                ),
                self._prepare_wandb_mask(
                    data_table_ref.data[idx][1],
                    prediction,
                    "prediction"
                )
            )

    def _prepare_wandb_mask(self, image, mask, mask_type):
        """Wrap ``image`` in a ``wandb.Image`` carrying ``mask`` as an overlay.

        Args:
            image: Image data (or an existing ``wandb.Image``) to display.
            mask: 2-D array of integer class ids.
            mask_type: Key under which the overlay is stored, e.g.
                ``"ground_truth"`` or ``"prediction"``. Using the caller's
                value keeps predicted masks from being labelled as ground truth.

        Returns:
            The ``wandb.Image`` with a single mask overlay named ``mask_type``.
        """
        return wandb.Image(
            image,
            masks={
                mask_type: {
                    "mask_data": mask,
                    "class_labels": labels(),
                }
            },
        )

## Train

In [11]:
# Open the W&B run as a context manager so the run is finished even when
# training raises, instead of relying on a trailing run.finish() call.
with wandb.init(project='seg-keras-AyushThakur-0404', config=configs) as run:
    _ = model.fit(
        trainloader,
        epochs=configs.epochs,
        validation_data=validloader,
        callbacks=[
            # Metric logging; log_freq sets the batch-level logging frequency.
            WandbMetricsLogger(log_freq=2),
            # Ground-truth / prediction tables built from the validation data.
            WandbSemanticLogger(validloader),
        ],
    )

[34m[1mwandb[0m: Currently logged in as: [33mmykcs[0m ([33mteam-mykcs[0m). Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m:   202 of 202 files downloaded.  


Epoch 1/10


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment
2024-04-03 17:53:08.792270: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 2175: 3.31527, expected 2.86333
2024-04-03 17:53:08.792361: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 8832: 3.04387, expected 2.59193
2024-04-03 17:53:08.792371: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 8960: 3.285, expected 2.83305
2024-04-03 17:53:08.792382: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 9600: 3.16471, expected 2.71277
2024-04-03 17:53:08.792400: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 12032: 3.08921, expected 2.63726
2024-04-03 17:53:08.792421: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 15360: 3.30205, expected 2.8501
2024-04-03 17:53:08.792518: E external/local_xla/xla/service/gpu/

[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 72ms/step - loss: 78101.1719

2024-04-03 17:53:41.506967: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 32768: 5.29405, expected 4.31245
2024-04-03 17:53:41.507021: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 32769: 7.69962, expected 6.71802
2024-04-03 17:53:41.507030: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 32770: 7.66582, expected 6.68423
2024-04-03 17:53:41.507038: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 32771: 7.36342, expected 6.38183
2024-04-03 17:53:41.507046: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 32772: 6.87316, expected 5.89156
2024-04-03 17:53:41.507054: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 32773: 6.8059, expected 5.82431
2024-04-03 17:53:41.507062: E external/local_xla/xla/service/gpu/buffer_comparator.cc:1137] Difference at 32774: 5.81217, expected 4.83057
2024-04-03 17:53:41.507069: 

[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m56s[0m 263ms/step - loss: 87447.4297 - val_loss: 2026468.3750
Epoch 2/10


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment


[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 194ms/step - loss: 1525764736.0000 - val_loss: 3395709.5000
Epoch 3/10


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment


[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 192ms/step - loss: 393510848.0000 - val_loss: 707528753152.0000
Epoch 4/10


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment


[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m46s[0m 234ms/step - loss: 618652416.0000 - val_loss: 9581776896.0000
Epoch 5/10


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment


[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 190ms/step - loss: 6883639.5000 - val_loss: 2351592704.0000
Epoch 6/10


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment


[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 192ms/step - loss: 3340133.5000 - val_loss: 15792557056.0000
Epoch 7/10


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment


[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 193ms/step - loss: 12052047.0000 - val_loss: 53262995456.0000
Epoch 8/10


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment


[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 191ms/step - loss: 18570660.0000 - val_loss: 297293611008.0000
Epoch 9/10


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment


[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 192ms/step - loss: 115496144.0000 - val_loss: 5369804881920.0000
Epoch 10/10


Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment


[1m115/115[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 191ms/step - loss: 33356281856.0000 - val_loss: 59445618933760.0000


VBox(children=(Label(value='2.249 MB of 2.249 MB uploaded (0.613 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
batch/batch_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
batch/loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█
epoch/epoch,▁▂▃▃▄▅▆▆▇█
epoch/loss,▁▁▁▁▁▁▁▁▁█
epoch/val_loss,▁▁▁▁▁▁▁▁▂█

0,1
batch/batch_step,1158.0
batch/loss,303274852352.0
epoch/epoch,9.0
epoch/loss,303274885120.0
epoch/val_loss,59445618933760.0
