# Run this code block to install dependencies

In [1]:
# !git clone https://github.com/KenzaB27/TransUnet.git
# %cd TransUnet
# !pip install -r requirements.txt 
# %cd ..

In [2]:
%cd TransUnet
import models.transunet as transunet
import utils.visualize as visualize
import experiments.config as conf
import importlib
%cd ..

/Users/srinathramalingam/Desktop/codebase/TransUnet/TransUnet
/Users/srinathramalingam/Desktop/codebase/TransUnet


In [3]:
import os
import cv2
import pickle
import imageio
import numpy as np
from tqdm import tqdm
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from bp import Environment, String
from focal_loss import BinaryFocalLoss
from dTurk.generators import SemsegData
from dTurk.builders import model_builder
from tensorflow.keras import backend as K
from dTurk.utils.clr_callback import CyclicLR
from dTurk.metrics import MeanIoU, WeightedMeanIoU
from tensorflow.keras.callbacks import ModelCheckpoint
from dTurk.loaders.dataset_loader import SemsegDatasetLoader
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping
from dTurk.augmentation.transforms import get_train_transform_policy, get_validation_transform_policy
from dTurk.models.sm_models.losses import CategoricalCELoss, CategoricalFocalLoss, DiceLoss, JaccardLoss

Segmentation Models: using `tf.keras` framework.


In [4]:
# Create the `bp` Environment object; `env.paths.remote` is read further down
# when the TensorBoard log directory is assembled.
env = Environment()

In [5]:
# Start from the TransUnet configuration provided by experiments.config and
# override selected entries for this run.
config = conf.get_transunet()
config['image_size'] = 256
config["filters"] = 3  # presumably the number of output classes — TODO confirm
config['n_skip'] = 3
config['decoder_channels'] = [128, 64, 32, 16]
config['resnet']['n_layers'] = (3,4,9,12)
config['dropout'] = 0.1
# NOTE(review): the printed model summary shows 256 tokens after the reshape
# layer; confirm that a (28, 28) grid is what image_size=256 is meant to use.
config['grid'] = (28,28)

In [6]:
# Training hyper-parameters and output locations shared by the cells below.
loss = 'iou'  # key passed to get_loss() when the model is compiled
lr = 0.005  # max_lr of the CyclicLR schedule; its base_lr is lr / 10
batch_size = 12  # only feeds step_size here; the create_dataset() call below does not pass it
patience = 12  # EarlyStopping patience
monitor="val_loss"  # quantity watched by EarlyStopping
log="transunet1"  # directory component of the TensorBoard log path
# CyclicLR step size: 2 * 6400 / batch_size iterations
# (6400 is presumably the training-set size — TODO confirm).
step_size = int(2.0 * 6400 / batch_size)
save_path = 'weights'
checkpoint_filepath = save_path + '/checkpoint/'  # read by load_weights() in a later cell

In [7]:
# Label paths for the MACH-77-it3 dataset. File names are listed from the image
# folders (`train` / `val`) and re-rooted under the matching `*_labels` folder,
# which assumes every image has a label with the same file name — TODO confirm.
_MACH77_IT3_DIR = "/Users/srinathramalingam/Desktop/codebase/Jupyter_predictions_Nearspace/MACH-77-it3"


def _label_paths(subset):
    """Return `<root>/<subset>_labels/<name>` for every `.png` listed in `<root>/<subset>`."""
    return [
        f"{_MACH77_IT3_DIR}/{subset}_labels/{name}"
        for name in os.listdir(f"{_MACH77_IT3_DIR}/{subset}")
        if name.endswith(".png")
    ]


train_input_names = _label_paths("train")
val_input_names = _label_paths("val")

In [8]:
from dTurk.models.SM_UNet import SM_UNet_Builder
import dTurk.models.sm_models as sm

In [9]:
# Builder arguments for the SM U-Net: 256x256x3 input, 3 softmax outputs,
# 'efficientnetv2-l' encoder with ImageNet weights and `train_encoder=False`
# (presumably a frozen encoder — TODO confirm), upsampling decoder blocks,
# and both dropout rates set to 0.
builder = SM_UNet_Builder(
    encoder_name='efficientnetv2-l',
    input_shape=(256, 256, 3),
    num_classes=3,
    activation="softmax",
    train_encoder=False,
    encoder_weights="imagenet",
    decoder_block_type="upsampling",
    head_dropout=0,  # dropout at head
    dropout=0,  # dropout at feature extraction
)

In [10]:
# Build the model from the builder above; this `model` is the object that is
# compiled and fitted in the later cells.
model = builder.build_model()

2022-06-29 16:02:33.193031: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Instructions for updating:
Restoring a name-based tf.train.Saver checkpoint using the object-based restore API. This mode uses global names to match variables, and so is somewhat fragile. It also adds new restore ops to the graph each time it is called when graph building. Prefer re-encoding training checkpoints in the object-based format: run save() on the object-based saver (the same one this message is coming from) and use that checkpoint in the future.


In [11]:
# Instantiate TransUnet from `config`.
# NOTE(review): in the cells shown, `network` is never compiled or fitted
# (training runs on `model`), yet it is the object used for load_weights(),
# save() and predict() further down — confirm this is intended.
network = transunet.TransUnet(config, trainable=False)

ListWrapper([128, 64, 32, 16])


In [12]:
def data_dir(machine='akami', dataset='MACH-77-it1'):
    """Return the semseg base directory of `dataset` for the given machine.

    Args:
        machine: 'local' selects the '/Users/srinathramalingam/mnt' root; any
            other value selects the '/home/bv' root.
        dataset: Dataset folder name appended under 'datasets/semseg_base'.

    Returns:
        The dataset directory path as a string.
    """
    root = '/Users/srinathramalingam/mnt' if machine == 'local' else '/home/bv'
    return root + f'/datasets/semseg_base/{dataset}'

In [13]:
def get_loss(loss_name: str):
    """Build the loss object selected by `loss_name` (case-insensitive).

    Supported names:
        "weighted_categorical_cross_entropy", "binary_crossentropy",
        "iou" / "jaccard", "dice" / "f1", "focal_dice", "focal_iou".

    All class-aware losses are built with class weights [1, 1, 1] and class
    indexes [0, 1, 2].

    Args:
        loss_name: Name of the loss to build.

    Returns:
        The constructed loss object.

    Raises:
        ValueError: If `loss_name` is not one of the supported names.
    """
    import logging

    name = loss_name.lower()

    class_weights = [1, 1, 1]
    class_indexes = [0, 1, 2]

    if name == "weighted_categorical_cross_entropy":
        loss_function = CategoricalCELoss(class_weights=class_weights, class_indexes=class_indexes)
    elif name == "binary_crossentropy":
        # Imported here because the notebook's import cell does not bring
        # BinaryCrossentropy into scope.
        from tensorflow.keras.losses import BinaryCrossentropy

        loss_function = BinaryCrossentropy()
    elif name in ["iou", "jaccard"]:
        loss_function = JaccardLoss(class_weights=class_weights, class_indexes=class_indexes, per_image=False)
    elif name in ["dice", "f1"]:
        loss_function = DiceLoss(class_weights=class_weights, class_indexes=class_indexes, per_image=False)
    elif name in ["focal_dice"]:
        dice_loss = DiceLoss(class_weights=class_weights, class_indexes=class_indexes, per_image=False)
        focal_loss = CategoricalFocalLoss(alpha=0.25, gamma=2, class_indexes=class_indexes)
        loss_function = dice_loss + focal_loss
    elif name in ["focal_iou"]:
        iou_loss = JaccardLoss(class_weights=class_weights, class_indexes=class_indexes, per_image=False)
        focal_loss = CategoricalFocalLoss(alpha=0.25, gamma=2, class_indexes=class_indexes)
        loss_function = iou_loss + focal_loss
    else:
        message = f"Loss function '{name}' is not supported"
        # The original referenced an undefined `logger`, which raised NameError
        # before the intended error; use the standard logging module instead.
        logging.getLogger(__name__).error(message)
        raise ValueError(message)
    return loss_function

In [14]:
def create_dataset(train_input_names, val_input_names, batch_size=12):
    """Create the batched training and validation datasets.

    The training generator uses the 'light.yaml' augmentation policy; the
    validation generator uses no transform policy. Both use the preprocessing
    returned by `model_builder.get_preprocessing('s')` and the same three
    layer colors.

    Args:
        train_input_names: Input names handed to the training generator.
        val_input_names: Input names handed to the validation generator.
        batch_size: Batch size for both datasets. Defaults to 12, the value
            that was previously hard-coded.

    Returns:
        (train_ds_batched, val_ds_batched) as returned by
        `SemsegData.get_tf_data`.
    """
    train_data = SemsegData(
        subset="train",
        transform_policy=get_train_transform_policy(augmentation_file='/Users/srinathramalingam/Desktop/codebase/dTurk/dTurk/augmentation/configs/light.yaml'),
        preprocess=model_builder.get_preprocessing('s'),
        layer_colors=[[0, 0, 0], [255, 0, 0], [0, 255, 0]],
        use_mixup=False,
        use_sample_weights=False,
        use_distance_weights=False,
    )
    val_data = SemsegData(
        subset="val",
        transform_policy=None,
        preprocess=model_builder.get_preprocessing('s'),
        layer_colors=[[0, 0, 0], [255, 0, 0], [0, 255, 0]],
        use_sample_weights=False,
        use_distance_weights=False,
    )

    # Images are 256x256x3; the label's last dimension is left unspecified.
    image_shape = (256, 256, 3)
    label_shape = (256, 256, None)
    shapes = [image_shape, label_shape]

    train_ds_batched = train_data.get_tf_data(
        batch_size=batch_size, input_names=train_input_names, shapes=shapes
    )
    val_ds_batched = val_data.get_tf_data(
        batch_size=batch_size, input_names=val_input_names, shapes=shapes
    )

    return train_ds_batched, val_ds_batched

In [15]:
def iou():
    """Return the WeightedMeanIoU metric named "primary_mean_iou" for 3 classes.

    The class weights passed to the metric are [0, 0, 1] divided by their sum,
    so only the class at index 2 has a non-zero weight.
    """
    class_weights_primary = np.array([0, 0, 1])
    class_weights_primary = class_weights_primary / class_weights_primary.sum()

    return WeightedMeanIoU(
        num_classes=3, class_weights=class_weights_primary, name="primary_mean_iou"
    )

In [16]:
def load_image(filename):
    """Read an image with imageio and return it with at most 3 channels, channels last.

    Args:
        filename: Path passed to `imageio.imread`.

    Returns:
        The image array. For 3-D data, a leading axis of size 3 or 4 is moved
        to the end, and a 4-channel result is cut down to its first 3 channels.
        Arrays that are not 3-D are returned unchanged.
    """
    data = imageio.imread(filename)
    if data.ndim == 3:
        # A leading axis of size 3 or 4 is treated as the channel axis.
        # NOTE(review): an image whose first dimension is 3 or 4 pixels would
        # be transposed as well.
        if data.shape[0] in (3, 4):
            data = np.transpose(data, (1, 2, 0))
        # Keep only the first three channels of a 4-channel image.
        if data.shape[2] == 4:
            data = data[:, :, :3]
    return data

In [17]:
def split_sample(layer_colored, malady_layer_dict):
    """List the layers whose color occurs in `layer_colored`.

    Args:
        layer_colored: Color data compared against each layer color along
            its last axis.
        malady_layer_dict: Mapping of layer name to layer color.

    Returns:
        The names (in dict order) of all layers whose color matches at least
        one position, or ["base_layer"] when none match.
    """
    matched_layers = [
        name
        for name, color in malady_layer_dict.items()
        if np.all(np.equal(layer_colored, color), axis=-1).any()
    ]
    return matched_layers if matched_layers else ["base_layer"]

In [18]:
def _using_slice_tag(
    malady_layer_dict: dict, label_filename: list, split_files, split_labels
):
    """Fill `split_files` / `split_labels` from the pickled slice-tag file.

    The file `semseg_slice_tag.pickle` under `data_dir('local')` is read as a
    stream of pickled dicts (file name -> metadata) until EOF. Each entry is
    assigned to the layers returned by `split_sample(metadata["color"], ...)`.

    Args:
        malady_layer_dict: Mapping of layer name to layer color.
        label_filename: Sequence of label paths; only the directory of the
            first element is used.
        split_files: Dict of layer name -> list, extended in place with image
            paths (label directory with "train_labels" replaced by "train").
        split_labels: Dict of layer name -> list, extended in place with
            label paths.

    Raises:
        IndexError: If `label_filename` is empty.
    """
    train_label_path = os.path.dirname(label_filename[0])
    # Loop-invariant: the image directory mirrors the label directory.
    train_image_path = train_label_path.replace("train_labels", "train")
    slice_tag_path = os.path.join(data_dir('local'), "semseg_slice_tag.pickle")

    # NOTE(security): pickle.load can execute arbitrary code; only use this
    # with a slice-tag file from a trusted source.
    with open(slice_tag_path, "rb") as slice_tag_file:
        while True:
            # Only the read itself signals the end of the stream.
            try:
                slice_label = pickle.load(slice_tag_file)
            except EOFError:
                break
            for filename, metadata in slice_label.items():
                for layer_name in split_sample(metadata["color"], malady_layer_dict):
                    split_files[layer_name].append(os.path.join(train_image_path, filename))
                    split_labels[layer_name].append(os.path.join(train_label_path, filename))

In [19]:
def _using_local_files(
    malady_layer_dict: dict,
    train_labels,
    split_files,
    split_labels,
):
    """Fill `split_files` / `split_labels` by loading every label image.

    Each label path is loaded with `load_image`, assigned to the layers that
    `split_sample` reports, and recorded in place: the label path goes into
    `split_labels`, and the same path with "train_labels" replaced by "train"
    goes into `split_files`.
    """
    for path_to_label in tqdm(train_labels):
        layer_names = split_sample(load_image(path_to_label), malady_layer_dict)
        for name in layer_names:
            path_to_image = path_to_label.replace("train_labels", "train")
            split_files[name].append(path_to_image)
            split_labels[name].append(path_to_label)

In [20]:
def _do_oversampling(train_labels):
    """Group the training samples by layer ('exposed_deck' vs. 'base_layer').

    Uses the pickled slice tags when `semseg_slice_tag.pickle` exists under
    `data_dir('local')`; otherwise every label image is loaded and inspected.

    NOTE(review): `data_dir('local')` is called with its default dataset
    argument; confirm that this is the dataset `train_labels` belongs to,
    since the slice-tag file names are joined onto the `train_labels` directory.

    Returns:
        (split_files, split_labels): dicts mapping each group name to the
        list of image paths and label paths respectively.
    """
    malady_layer_dict = {'exposed_deck': [0, 255, 0]}

    group_names = [*malady_layer_dict, "base_layer"]
    split_files = {name: [] for name in group_names}
    split_labels = {name: [] for name in group_names}

    slice_tag_path = os.path.join(data_dir('local'), "semseg_slice_tag.pickle")
    collect = _using_slice_tag if os.path.exists(slice_tag_path) else _using_local_files
    collect(malady_layer_dict, train_labels, split_files, split_labels)
    return split_files, split_labels


In [21]:
def oversampling(train_labels, oversample: int):
    """Return the sorted list of training image paths after oversampling.

    Samples are grouped by `_do_oversampling`; every group's file list is then
    repeated by an integer factor.

    Args:
        train_labels: Label paths forwarded to `_do_oversampling`.
        oversample: -1 repeats each group round(largest group / group size)
            times. Any other value is used as the factor for groups whose
            rounded ratio exceeds 1; the remaining groups get factor 1.

    Returns:
        Sorted list of image paths, with repeats.
    """
    split_files, split_labels = _do_oversampling(train_labels)
    major_class_counts = max(map(len, split_labels.values()))

    def _factor(group_size):
        # An empty group contributes no files; skip the ratio so it cannot
        # trigger a division by zero.
        if group_size == 0:
            return 1
        ratio = round(major_class_counts / group_size)
        if oversample == -1:
            return ratio
        return oversample if ratio > 1 else 1

    oversampling_factor = {k: _factor(len(v)) for k, v in split_files.items()}

    train_files = [item for k, v in split_files.items() for item in v * oversampling_factor[k]]
    train_files.sort()

    return train_files

In [22]:
# Compile `model` with the 'adam' optimizer string, the loss selected by `loss`
# and the metric returned by iou().
# NOTE(review): `lr` is not passed to the optimizer here; in the cells shown it
# only reaches training through the CyclicLR callback.
model.compile(optimizer='adam', loss=get_loss(loss), metrics=iou())

In [23]:
# Cyclic learning rate between lr / 10 and lr in 'triangular2' mode.
cyclic_lr = CyclicLR(
    base_lr=lr / 10.0,
    max_lr=lr,
    step_size=step_size,
    mode='triangular2',
    cyclic_momentum=False,
    max_momentum=False,
    base_momentum=0.8,
)

# Stop after `patience` epochs without improvement of `monitor` and restore
# the best weights; loss-like monitors are minimised, everything else maximised.
early_stopping = EarlyStopping(
    monitor=monitor,
    mode="max" if "loss" not in monitor else "min",
    patience=patience,
    verbose=1,
    restore_best_weights=True,
)

# TensorBoard logs go under <remote>/dTurk/logs/<log>/tensorboard.
tensorboard_path = os.path.join(env.paths.remote, "dTurk", "logs", f"{log}", "tensorboard")
tensorboard = TensorBoard(tensorboard_path, histogram_freq=1)

callbacks = [cyclic_lr, early_stopping, tensorboard]

In [24]:
# Oversample with oversample=-1, i.e. each group is repeated
# round(largest group size / own size) times.
# NOTE(review): this rebinds `train_input_names` (label paths) to the image
# paths returned by oversampling(), so re-running the cell feeds
# already-oversampled image paths back in.
train_input_names = oversampling(train_input_names,-1)

In [25]:
# Build the batched training and validation datasets consumed by model.fit().
train_ds_batched, val_ds_batched = create_dataset(train_input_names, val_input_names)



In [26]:
# Print the layer-by-layer summary of the TransUnet model
# (not of `model`, which is the one being trained).
network.model.summary()

Model: "TransUNet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_2 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 262, 262, 3)  0           ['input_2[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 128, 128, 64  9472        ['conv1_pad[0][0]']              
                                )                                                                 
                                                                                          

                                                                                                  
 conv2_block3_preact_relu (Acti  (None, 64, 64, 256)  0          ['conv2_block3_preact_bn[0][0]'] 
 vation)                                                                                          
                                                                                                  
 conv2_block3_1_conv (Conv2D)   (None, 64, 64, 64)   16384       ['conv2_block3_preact_relu[0][0]'
                                                                 ]                                
                                                                                                  
 conv2_block3_1_bn (BatchNormal  (None, 64, 64, 64)  256         ['conv2_block3_1_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv2_blo

 conv3_block2_2_conv (Conv2D)   (None, 32, 32, 128)  147456      ['conv3_block2_2_pad[0][0]']     
                                                                                                  
 conv3_block2_2_bn (BatchNormal  (None, 32, 32, 128)  512        ['conv3_block2_2_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv3_block2_2_relu (Activatio  (None, 32, 32, 128)  0          ['conv3_block2_2_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv3_block2_3_conv (Conv2D)   (None, 32, 32, 512)  66048       ['conv3_block2_2_relu[0][0]']    
                                                                                                  
 conv3_blo

 conv4_block1_1_conv (Conv2D)   (None, 16, 16, 256)  131072      ['conv4_block1_preact_relu[0][0]'
                                                                 ]                                
                                                                                                  
 conv4_block1_1_bn (BatchNormal  (None, 16, 16, 256)  1024       ['conv4_block1_1_conv[0][0]']    
 ization)                                                                                         
                                                                                                  
 conv4_block1_1_relu (Activatio  (None, 16, 16, 256)  0          ['conv4_block1_1_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv4_block1_2_pad (ZeroPaddin  (None, 18, 18, 256)  0          ['conv4_block1_1_relu[0][0]']    
 g2D)     

                                                                                                  
 conv4_block3_2_relu (Activatio  (None, 16, 16, 256)  0          ['conv4_block3_2_bn[0][0]']      
 n)                                                                                               
                                                                                                  
 conv4_block3_3_conv (Conv2D)   (None, 16, 16, 1024  263168      ['conv4_block3_2_relu[0][0]']    
                                )                                                                 
                                                                                                  
 conv4_block3_out (Add)         (None, 16, 16, 1024  0           ['conv4_block2_out[0][0]',       
                                )                                 'conv4_block3_3_conv[0][0]']    
                                                                                                  
 conv4_blo

 reshape (Reshape)              (None, 256, 768)     0           ['embedding[0][0]']              
                                                                                                  
 Transformer/posembed_input (Ad  (None, 256, 768)    196608      ['reshape[0][0]']                
 dPositionEmbs)                                                                                   
                                                                                                  
 dropout (Dropout)              (None, 256, 768)     0           ['Transformer/posembed_input[0][0
                                                                 ]']                              
                                                                                                  
 Transformer/encoderblock_0 (Tr  ((None, 256, 768),  7087872     ['dropout[0][0]']                
 ansformerBlock)                 (None, 12, None, N                                               
          

In [28]:
# Train `model` for up to 25 epochs with validation after each epoch.
# `callbacks` is already a list of callback instances, so it is passed as-is
# instead of being wrapped in a second list.
history = model.fit(train_ds_batched, epochs=25, validation_data=val_ds_batched, callbacks=callbacks)

Epoch 1/25


2022-06-29 16:03:21.227290: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at whole_file_read_ops.cc:116 : NOT_FOUND: /Users/srinathramalingam/Desktop/codebase/Jupyter_predictions_Nearspace/MACH-77-it3/train/294602_10.png; No such file or directory
2022-06-29 16:03:21.227385: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at whole_file_read_ops.cc:116 : NOT_FOUND: /Users/srinathramalingam/Desktop/codebase/Jupyter_predictions_Nearspace/MACH-77-it3/train/295715_1.png; No such file or directory
2022-06-29 16:03:21.227538: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at whole_file_read_ops.cc:116 : NOT_FOUND: /Users/srinathramalingam/Desktop/codebase/Jupyter_predictions_Nearspace/MACH-77-it3/train/982260_133.png; No such file or directory
2022-06-29 16:03:21.227664: W tensorflow/core/framework/op_kernel.cc:1745] OP_REQUIRES failed at whole_file_read_ops.cc:116 : NOT_FOUND: /Users/srinathramalingam/Desktop/codebase/Jupyter_predictio

UnknownError: Graph execution error:

NotFoundError: /Users/srinathramalingam/Desktop/codebase/Jupyter_predictions_Nearspace/MACH-77-it3/train/295460_0.png; No such file or directory [Op:ReadFile]
Traceback (most recent call last):

  File "/Users/srinathramalingam/opt/anaconda3/envs/bv2/lib/python3.10/site-packages/tensorflow/python/ops/script_ops.py", line 270, in __call__
    ret = func(*args)

  File "/Users/srinathramalingam/opt/anaconda3/envs/bv2/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)

  File "/Users/srinathramalingam/Desktop/codebase/dTurk/dTurk/generators/tf_data.py", line 259, in parse_data
    image_np, label_np = self.load_data(filename)

  File "/Users/srinathramalingam/Desktop/codebase/dTurk/dTurk/generators/tf_data.py", line 106, in load_data
    image_np = self.load_image(filename)

  File "/Users/srinathramalingam/Desktop/codebase/dTurk/dTurk/generators/tf_data.py", line 86, in load_image
    img = tf.io.read_file(filename)

  File "/Users/srinathramalingam/opt/anaconda3/envs/bv2/lib/python3.10/site-packages/tensorflow/python/ops/io_ops.py", line 133, in read_file
    return gen_io_ops.read_file(filename, name)

  File "/Users/srinathramalingam/opt/anaconda3/envs/bv2/lib/python3.10/site-packages/tensorflow/python/ops/gen_io_ops.py", line 566, in read_file
    return read_file_eager_fallback(

  File "/Users/srinathramalingam/opt/anaconda3/envs/bv2/lib/python3.10/site-packages/tensorflow/python/ops/gen_io_ops.py", line 589, in read_file_eager_fallback
    _result = _execute.execute(b"ReadFile", 1, inputs=_inputs_flat,

  File "/Users/srinathramalingam/opt/anaconda3/envs/bv2/lib/python3.10/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,

tensorflow.python.framework.errors_impl.NotFoundError: /Users/srinathramalingam/Desktop/codebase/Jupyter_predictions_Nearspace/MACH-77-it3/train/295460_0.png; No such file or directory [Op:ReadFile]


	 [[{{node PyFunc}}]]
	 [[IteratorGetNext]] [Op:__inference_train_function_81822]

nsorflow/python/ops/script_ops.py", line 270, in __call__
    ret = func(*args)

  File "/Users/srinathramalingam/opt/anaconda3/envs/bv2/lib/python3.10/site-packages/tensorflow/python/autograph/impl/api.py", line 642, in wrapper
    return func(*args, **kwargs)

  File "/Users/srinathramalingam/Desktop/codebase/dTurk/dTurk/generators/tf_data.py", line 259, in parse_data
    image_np, label_np = self.load_data(filename)

  File "/Users/srinathramalingam/Desktop/codebase/dTurk/dTurk/generators/tf_data.py", line 106, in load_data
    image_np = self.load_image(filename)

  File "/Users/srinathramalingam/Desktop/codebase/dTurk/dTurk/generators/tf_data.py", line 86, in load_image
    img = tf.io.read_file(filename)

  File "/Users/srinathramalingam/opt/anaconda3/envs/bv2/lib/python3.10/site-packages/tensorflow/python/ops/io_ops.py", line 133, in read_file
    return gen_io_ops.read_file(filename, name)

  File "/Users/srinathramalingam/opt/anaconda3/envs/bv2/lib/python3.10/site-packages/ten

    return read_file_eager_fallback(

  File "/Users/srinathramalingam/opt/anaconda3/envs/bv2/lib/python3.10/site-packages/tensorflow/python/ops/gen_io_ops.py", line 589, in read_file_eager_fallback
    _result = _execute.execute(b"ReadFile", 1, inputs=_inputs_flat,

  File "/Users/srinathramalingam/opt/anaconda3/envs/bv2/lib/python3.10/site-packages/tensorflow/python/eager/execute.py", line 54, in quick_execute
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,

tensorflow.python.framework.errors_impl.NotFoundError: /Users/srinathramalingam/Desktop/codebase/Jupyter_predictions_Nearspace/MACH-77-it3/train/299243_0.png; No such file or directory [Op:ReadFile]


2022-06-29 16:03:21.243801: W tensorflow/core/framework/op_kernel.cc:1733] UNKNOWN: NotFoundError: /Users/srinathramalingam/Desktop/codebase/Jupyter_predictions_Nearspace/MACH-77-it3/train/299999_2.png; No such file or directory [Op:ReadFile]
Traceback (most recent call last):

  File "/Users/srinathramali

In [None]:
# Load weights from `checkpoint_filepath` into the TransUnet model and export
# the model to `<save_path>/model`.
# NOTE(review): no ModelCheckpoint callback is added to `callbacks` in the
# cells shown, and fit() is called on `model`, not `network.model` — confirm
# that something actually writes `checkpoint_filepath` before this runs.
network.model.load_weights(checkpoint_filepath)
saved_model_path = save_path + "/model"
network.model.save(saved_model_path)

In [None]:
def predict(image, model, file):
    """Run `model` on one image file and return the scaled input and the prediction.

    Args:
        image: Path of the image file, read with OpenCV. The image must hold
            exactly 256 * 256 * 3 values, otherwise the reshape fails.
        model: Object exposing `predict()`; it is called with an array of
            shape (1, 256, 256, 3).
        file: Unused; kept so existing callers keep working.

    Returns:
        (img, pred): the input pixels divided by 255 and the prediction
        clipped to [0, 1], each reshaped to (256, 256, 3).

    Raises:
        FileNotFoundError: If OpenCV cannot read `image`.
    """
    img = cv2.imread(image)
    if img is None:
        # cv2.imread signals a missing or unreadable file by returning None.
        raise FileNotFoundError(f"Could not read image: {image}")
    # OpenCV loads images in BGR order; swap to RGB for the model and for plotting.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = np.array(img) / 255
    img = img.reshape((1, 256, 256, 3))
    prediction = model.predict(img)
    prediction = tf.clip_by_value(prediction, 0.0, 1.0)
    # The prediction stays in [0, 1]; the former `* 255` result was
    # overwritten immediately and never used.
    pred = np.array(prediction).reshape((256, 256, 3))
    img = img.reshape((256, 256, 3))
    return img, pred

In [None]:
# Image paths of the MACH-77-it2 validation images, used for the visual check below.
# NOTE(review): this reuses the name `val_input_names`, which earlier held
# MACH-77-it3 label paths.
_val_images_dir = "/Users/srinathramalingam/Desktop/codebase/Jupyter_predictions_Nearspace/MACH-77-it2/val/images"
val_input_names = [_val_images_dir + "/" + i for i in os.listdir(_val_images_dir)]

In [None]:
# Show each validation image (left) next to the network's prediction (right).
for file, image_path in enumerate(val_input_names):
    img, pred = predict(image_path, network.model, file)

    for position, picture in enumerate((img, pred), start=1):
        plt.subplot(1, 2, position)
        plt.imshow(picture)

    plt.show()