# Uncomment and run the next code cell to install dependencies

In [1]:
# Optional one-time setup: uncomment these lines to clone the TransUnet repo
# into the current working directory and install its requirements.
# NOTE(review): prefer `%pip install` over `!pip install` so the packages go
# into the running kernel's environment, and pin versions for reproducibility.
# !git clone https://github.com/KenzaB27/TransUnet.git
# %cd TransUnet
# !pip install -r requirements.txt 
# %cd ..

In [2]:
# Import the TransUnet modules from the cloned checkout. `models`, `utils` and
# `experiments` are resolved relative to the `TransUnet/` directory, so the
# working directory is switched there only for these imports. The try/finally
# guarantees we always return to the notebook directory, even if an import
# fails. Otherwise later relative paths (train_helpers, weights/, the CSV
# log) would silently resolve from inside the checkout.
import os
import importlib

_notebook_dir = os.getcwd()
os.chdir("TransUnet")
try:
    import models.transunet as transunet
    import utils.visualize as visualize
    import experiments.config as conf
finally:
    os.chdir(_notebook_dir)

/Users/srinathramalingam/Desktop/codebase/TransUnet/TransUnet
/Users/srinathramalingam/Desktop/codebase/TransUnet


In [3]:
import os
import cv2
import pickle
import imageio
import numpy as np
from tqdm import tqdm
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from bp import Environment, String
from focal_loss import BinaryFocalLoss
from dTurk.generators import SemsegData
from dTurk.builders import model_builder
from tensorflow.keras import backend as K
from dTurk.utils.clr_callback import CyclicLR
from dTurk.metrics import MeanIoU, WeightedMeanIoU
from tensorflow.keras.callbacks import ModelCheckpoint
from dTurk.loaders.dataset_loader import SemsegDatasetLoader
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping
from dTurk.augmentation.transforms import get_train_transform_policy, get_validation_transform_policy
from dTurk.models.sm_models.losses import CategoricalCELoss, CategoricalFocalLoss, DiceLoss, JaccardLoss

Segmentation Models: using `tf.keras` framework.


In [4]:
import os
import argparse
import pandas as pd
import tensorflow as tf
from bp import Environment
import TransUnet.experiments.config as conf
from dTurk.utils.clr_callback import CyclicLR
import TransUnet.models.transunet as transunet
from tensorflow.keras.callbacks import TensorBoard, EarlyStopping
from train_helpers import dice_loss, mean_iou, oversampling, create_dataset

In [5]:
env = Environment()

# Start from the repo's TransUnet defaults and override them for this run.
config = conf.get_transunet()
transunet_overrides = {
    "image_size": 256,
    # NOTE(review): presumably the number of output channels/classes; confirm
    # against TransUnet/experiments/config.py.
    "filters": 3,
    "n_skip": 3,
    "decoder_channels": [128, 64, 32, 16],
    "dropout": 0.1,
    # NOTE(review): confirm (28, 28) is intended together with image_size=256;
    # check how models/transunet.py consumes `grid`.
    "grid": (28, 28),
    "n_layers": 12,
}
for override_key, override_value in transunet_overrides.items():
    config[override_key] = override_value
config["resnet"]["n_layers"] = (3, 4, 9, 12)

In [6]:
# Which dataset to train on and where it is being run.
dataset = "MACH-77-it3"
machine = "local"

# Training schedule. `monitor` is the quantity early stopping watches.
monitor = "val_loss"
epochs, patience, batch_size = 75, 12, 32
lr = 0.005  # upper learning rate of the cyclic schedule; the lower bound is lr / 10

# NOTE(review): absolute, user-specific path; consider deriving it from a
# configurable root so the notebook runs on other machines.
train_augmentation_file = "/Users/srinathramalingam/Desktop/codebase/dTurk/dTurk/augmentation/configs/light.yaml"

# Output locations, relative to the notebook's working directory.
save_path = "weights/TransUnet"
checkpoint_filepath = f"{save_path}/checkpoint/"

In [7]:
# Resolve the dataset folder under the BP remote root. Fail early with a clear
# message when the variable is unset. Without this check, the None returned by
# os.environ.get would surface as an opaque `None + str` TypeError.
bp_path_remote = os.environ.get("BP_PATH_REMOTE")
if bp_path_remote is None:
    raise RuntimeError(
        "Environment variable BP_PATH_REMOTE is not set; it must point at the "
        "root directory that contains datasets/semseg_base."
    )
dataset_directory = bp_path_remote + "/datasets/semseg_base" + "/" + dataset

In [8]:
# Restrict TensorFlow to a single GPU when any are present.
# `args_dict` (presumably carried over from a command-line version of this
# script) is never defined in this notebook. The old bare `except:` swallowed
# that NameError, so even GPU machines printed "Gpus not found" and no device
# was selected. The index is now set here explicitly.
GPU_INDEX = 0

gpus = tf.config.list_physical_devices("GPU")
if not gpus:
    print("Gpus not found")
else:
    try:
        tf.config.set_visible_devices(gpus[GPU_INDEX], "GPU")
    except RuntimeError as err:
        # set_visible_devices raises RuntimeError once the runtime is initialized.
        print(f"Could not select GPU {GPU_INDEX}: {err}")

Gpus not found


In [9]:
def _list_png_paths(subdir):
    """Return `dataset_directory + subdir + name` for each `.png` file in that folder.

    The order is whatever `os.listdir` yields; no sorting is applied.
    """
    folder = dataset_directory + subdir
    return [folder + name for name in os.listdir(folder) if name.endswith(".png")]


# NOTE(review): the training list comes from `train_labels/` but the validation
# list comes from `val/` (not `val_labels/`). Confirm which kind of path
# `oversampling` / `create_dataset` expect for each split.
train_input_names = _list_png_paths("/train_labels/")
val_input_names = _list_png_paths("/val/")

In [10]:
# The oversampled list is rebound to `train_input_names`, so re-running this
# cell on its own would feed an already-oversampled list back into
# `oversampling`. Re-run the file-listing cell first. The trailing -1 is passed
# through unchanged; its meaning is defined in train_helpers.
train_input_names = oversampling(train_input_names, machine, dataset, -1)

train_ds_batched, val_ds_batched = create_dataset(
    train_input_names,
    val_input_names,
    train_augmentation=train_augmentation_file,
)

100%|████████████████████████████████████████████| 90/90 [00:00<00:00, 236.28it/s]
2022-06-29 18:30:51.409412: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [11]:
# CyclicLR step size: two epochs' worth of batches over the (oversampled)
# training list, presumably the half-cycle length in iterations.
# Clamped to at least 1. A training list shorter than batch_size / 2 would
# otherwise give 0, a degenerate schedule.
# NOTE(review): assumes dTurk's CyclicLR divides by step_size, as cyclical-LR
# implementations usually do; confirm.
step_size = max(1, int(2.0 * len(train_input_names) / batch_size))

In [12]:
# Build the TransUnet model from `config`; the Keras model is exposed as
# `network.model` and is compiled and trained in the following cells.
# NOTE(review): confirm which layers `trainable=False` freezes in
# models/transunet.py. If it freezes the encoder, only the remaining layers train.
network = transunet.TransUnet(config, trainable=False)

ListWrapper([128, 64, 32, 16])


In [13]:
# Compile with Dice loss and mean IoU from train_helpers. `metrics` is given as
# a list, the form documented by Keras `Model.compile`.
# NOTE(review): Adam's default learning rate is presumably overridden by the
# CyclicLR callback during training; confirm in dTurk.utils.clr_callback.
network.model.compile(optimizer="adam", loss=dice_loss, metrics=[mean_iou])

In [14]:
# Early stopping and checkpointing watch the same quantity. Lower is better for
# anything named like a loss; higher is better otherwise.
monitor_mode = "min" if "loss" in monitor else "max"

callbacks = []

# Cyclic learning rate between lr / 10 and lr, using dTurk's "triangular2" mode
# (presumably halving the amplitude each cycle, per the standard CLR modes).
# Momentum cycling is disabled.
# NOTE(review): max_momentum=False looks like a placeholder for a float; it is
# presumably unused while cyclic_momentum=False.
cyclic_lr = CyclicLR(
    base_lr=lr / 10.0,
    max_lr=lr,
    step_size=step_size,
    mode="triangular2",
    cyclic_momentum=False,
    max_momentum=False,
    base_momentum=0.8,
)
callbacks.append(cyclic_lr)

early_stopping = EarlyStopping(
    monitor=monitor,
    mode=monitor_mode,
    patience=patience,
    verbose=1,
    restore_best_weights=True,
)
callbacks.append(early_stopping)

# Persist the best weights to `checkpoint_filepath`, which the export cell
# reloads with `load_weights`. Without this callback nothing is ever written
# there and that reload fails.
model_checkpoint = ModelCheckpoint(
    filepath=checkpoint_filepath,
    monitor=monitor,
    mode=monitor_mode,
    save_best_only=True,
    save_weights_only=True,
    verbose=1,
)
callbacks.append(model_checkpoint)

In [None]:
# Train the model. `callbacks` is already a list of callbacks, so it is passed
# directly rather than wrapped in a second list.
history = network.model.fit(
    train_ds_batched,
    epochs=epochs,
    validation_data=val_ds_batched,
    callbacks=callbacks,
)

Epoch 1/75

In [None]:
metrics_history = history.history

# The training IoU key depends on how the metric/output is named (for example
# "mean_iou" or "primary_mean_iou"), so look it up instead of hard-coding it.
# The validation entry uses the same name with a "val_" prefix, just as
# "loss" / "val_loss" do.
iou_key = next(
    (key for key in metrics_history if key.endswith("mean_iou") and not key.startswith("val_")),
    None,
)
if iou_key is None:
    raise KeyError(f"No mean-IoU metric in training history; keys: {sorted(metrics_history)}")
val_iou_key = "val_" + iou_key

iou = metrics_history[iou_key]
val_iou = metrics_history[val_iou_key]  # validation IoU, not a copy of the training curve
loss = metrics_history["loss"]
val_loss = metrics_history["val_loss"]

# One row per epoch; the column names are kept stable for downstream readers of the CSV.
df = pd.DataFrame(
    {"mean_iou": iou, "val_mean_iou": val_iou, "loss": loss, "val_loss": val_loss}
)

df.to_csv("TransUnet-logs.csv")

In [None]:
# Reload the best checkpointed weights when a checkpoint was written, then save
# the full model under `<save_path>/model`.
# `save_path` is the notebook-level setting. The previous `args_dict` lookup
# referred to a name that is never defined here and raised NameError.
if os.path.exists(checkpoint_filepath):
    network.model.load_weights(checkpoint_filepath)
else:
    print(f"No checkpoint found at {checkpoint_filepath}; saving the in-memory weights instead.")
saved_model_path = save_path + "/model"
network.model.save(saved_model_path)