In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, ReLU, GlobalAveragePooling2D, Dense, Add, UpSampling2D, concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import load_img, img_to_array

# Configuration constants.
# The two directories are machine-specific absolute paths and must be edited
# to point at a local copy of the dataset. Both are scanned one sub-directory
# level deep: input_dir for "*.png" images, target_dir for
# "*gtFine_labelIds.png" label masks.
input_dir = r"C:\Users\slowi\Deep Learning\Tensorflow test\leftImg8bit_trainvaltest\leftImg8bit\train"
target_dir = r"C:\Users\slowi\Deep Learning\Tensorflow test\gtFine_trainvaltest\gtFine\train"
img_size = (512, 512)  # target size images and masks are resized to on load; must agree with the model input
num_classes = 19  # channel count of the one-hot targets and of the final softmax layer
batch_size = 2  # number of samples per batch produced by the data generator
epochs = 2  # number of passes handed to model.fit; adjust as needed

# Data Generator
class Cityscapes(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of images and one-hot label masks.

    Each batch is a pair ``(x, y)``:

    * ``x`` -- ``float32`` array of shape ``(batch_size, *img_size, 3)`` with
      pixel values scaled to ``[0, 1]``.
    * ``y`` -- ``uint8`` array of shape ``(batch_size, *img_size, num_classes)``
      holding a one-hot encoding of the remapped label ids. Pixels whose label
      is not one of ``valid_classes`` are encoded as an all-zero vector.

    The sample order is reshuffled at construction time and in
    ``on_epoch_end``. A trailing partial batch is dropped (see ``__len__``).

    Relies on the module-level ``num_classes`` constant.

    Args:
        batch_size: Number of samples per batch.
        img_size: Two-element size passed as ``target_size`` to ``load_img``.
        input_img_paths: Paths of the input images.
        target_img_paths: Paths of the label masks, index-aligned with
            ``input_img_paths``.
        **kwargs: Forwarded unchanged to the Keras base-class constructor.

    Raises:
        ValueError: If the two path lists have different lengths, since the
            image/mask pairing is purely positional.
    """

    def __init__(self, batch_size, img_size, input_img_paths, target_img_paths, **kwargs):
        # Newer Keras versions expect the base constructor to be called;
        # with no kwargs this is also a harmless no-op on older versions.
        super().__init__(**kwargs)
        if len(input_img_paths) != len(target_img_paths):
            raise ValueError(
                "input_img_paths and target_img_paths must have the same length, "
                f"got {len(input_img_paths)} and {len(target_img_paths)}"
            )
        self.batch_size = batch_size
        # Stored as a tuple so it can be concatenated into array shapes below.
        self.img_size = tuple(img_size)
        self.input_img_paths = input_img_paths
        self.target_img_paths = target_img_paths
        # Raw label ids that are excluded from training.
        self.void_classes = [0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1]
        # Raw label ids that are kept; their position defines the training id.
        self.valid_classes = [7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33]
        self.class_map = dict(zip(self.valid_classes, range(num_classes)))
        self.ignore_index = 255
        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch."""
        return len(self.input_img_paths) // self.batch_size

    def on_epoch_end(self):
        """Draw a fresh random permutation of the sample indices."""
        self.indexes = np.arange(len(self.input_img_paths))
        np.random.shuffle(self.indexes)

    def __getitem__(self, idx):
        """Load and return batch number ``idx`` as ``(x, y)``."""
        start = idx * self.batch_size
        batch_indexes = self.indexes[start:start + self.batch_size]

        x = np.zeros((self.batch_size,) + self.img_size + (3,), dtype="float32")
        y = np.zeros((self.batch_size,) + self.img_size + (num_classes,), dtype="uint8")

        for j, k in enumerate(batch_indexes):
            img = load_img(self.input_img_paths[k], target_size=self.img_size)
            x[j] = img_to_array(img) / 255.0

            label = load_img(self.target_img_paths[k], target_size=self.img_size, color_mode="grayscale")
            # Drop the singleton channel axis to get a 2-D array of label ids.
            label = img_to_array(label).astype(np.int32).squeeze()
            y[j] = self.one_hot_encode(self.fix_indxs(label))

        return x, y

    def fix_indxs(self, mask):
        """Remap raw label ids to training ids.

        Every id listed in ``class_map`` (and not in ``void_classes``) is
        replaced by its training id; every other value, including
        ``ignore_index``, becomes ``num_classes``.

        Args:
            mask: Integer array (any shape) of raw label ids.

        Returns:
            ``int64`` array of the same shape with values in
            ``[0, num_classes]``.
        """
        mask = np.asarray(mask)
        void_ids = set(self.void_classes)
        # Start with everything marked as "ignore" and fill in the kept ids.
        remapped = np.full(mask.shape, num_classes, dtype=np.int64)
        # One vectorised comparison per class instead of a Python call per pixel.
        for label_id, train_id in self.class_map.items():
            if label_id in void_ids or label_id == self.ignore_index:
                continue
            remapped[mask == label_id] = train_id
        return remapped

    def one_hot_encode(self, lbl):
        """One-hot encode training ids into ``num_classes`` channels.

        Args:
            lbl: Integer array with values in ``[0, num_classes]``.

        Returns:
            Float array of shape ``lbl.shape + (num_classes,)``. The extra id
            ``num_classes`` (ignored pixels) selects an all-zero row.
        """
        # Rows 0..num_classes-1 are unit vectors; the last row is all zeros,
        # so no extra channel has to be created and sliced away afterwards.
        return np.eye(num_classes + 1, num_classes)[lbl]

# Helper function to get image paths
def get_image_paths(directory, extension):
    """Return the sorted paths of matching files one level below ``directory``.

    Only immediate sub-directories of ``directory`` are searched; files lying
    directly in ``directory`` are not considered.

    Args:
        directory: Root folder whose sub-directories are scanned.
        extension: Suffix a file name must end with to be included.

    Returns:
        A sorted list of the joined file paths.
    """
    subdir_candidates = (os.path.join(directory, entry) for entry in os.listdir(directory))
    return sorted(
        os.path.join(subdir, file_name)
        for subdir in subdir_candidates
        if os.path.isdir(subdir)
        for file_name in os.listdir(subdir)
        if file_name.endswith(extension)
    )

# Sorted image paths and sorted label-mask paths; the generator pairs them
# by position.
input_img_paths, target_img_paths = (
    get_image_paths(folder, suffix)
    for folder, suffix in (
        (input_dir, ".png"),
        (target_dir, "gtFine_labelIds.png"),
    )
)

# Training data generator
train_gen = Cityscapes(
    batch_size=batch_size,
    img_size=img_size,
    input_img_paths=input_img_paths,
    target_img_paths=target_img_paths,
)

# Define the model architecture
def SpatialPath(input_tensor):
    """Apply three stride-2 Conv -> BatchNorm -> ReLU stages to ``input_tensor``.

    The stages use 64 filters with a 7x7 kernel, then 128 and 256 filters with
    3x3 kernels, all with ``'same'`` padding.

    Args:
        input_tensor: Keras tensor to process.

    Returns:
        The output tensor of the last ReLU (256 channels).
    """
    # (filters, kernel_size) of each downsampling stage, in order.
    stages = (
        (64, (7, 7)),
        (128, (3, 3)),
        (256, (3, 3)),
    )
    features = input_tensor
    for filters, kernel_size in stages:
        features = Conv2D(filters, kernel_size, strides=2, padding='same')(features)
        features = BatchNormalization()(features)
        features = ReLU()(features)
    return features

def AttentionRefinementModule(input_tensor):
    """Re-weight the channels of ``input_tensor`` with a learned gate.

    The feature map is globally average-pooled, passed through a sigmoid
    ``Dense`` layer with one unit per input channel, and the resulting
    per-channel weights are multiplied back onto the input.

    Args:
        input_tensor: Keras feature map whose last axis is the channel axis
            and has a statically known size.

    Returns:
        A tensor with the same shape as ``input_tensor``.
    """
    channels = input_tensor.shape[-1]
    weights = GlobalAveragePooling2D()(input_tensor)
    weights = Dense(channels, activation='sigmoid')(weights)
    # Use a Reshape layer (as FeatureFusionModule does) rather than slicing
    # the symbolic tensor with tf.newaxis, so the graph contains only regular
    # Keras layers and the weights broadcast over both spatial axes.
    weights = tf.keras.layers.Reshape((1, 1, channels))(weights)
    return tf.keras.layers.Multiply()([input_tensor, weights])

def ContextPath(input_tensor):
    """Build the Xception-based context branch on top of ``input_tensor``.

    Two intermediate Xception activations (``block13_sepconv2_bn`` and
    ``block14_sepconv2_act``) are each passed through an
    ``AttentionRefinementModule``. The globally pooled block-14 feature is
    tiled back to the block-14 resolution and added to it. Finally the
    block-13 branch is upsampled by 2 and the block-14 branch by 4.

    Args:
        input_tensor: Keras input tensor fed to Xception. Its spatial size
            must be static, because the tiling factor is read from the
            static shape.

    Returns:
        Tuple ``(feature_13_arm, feature_14_arm)`` of the two upsampled
        branches.
    """
    base_model = tf.keras.applications.Xception(include_top=False, weights='imagenet', input_tensor=input_tensor)
    feature_13 = base_model.get_layer('block13_sepconv2_bn').output
    feature_14 = base_model.get_layer('block14_sepconv2_act').output
    feature_13_arm = AttentionRefinementModule(feature_13)
    feature_14_arm = AttentionRefinementModule(feature_14)

    # Global context: pool to one vector, then repeat it over the whole
    # block-14 grid so it can be added element-wise.
    global_context = GlobalAveragePooling2D()(feature_14_arm)
    global_context = tf.keras.layers.Reshape((1, 1, -1))(global_context)
    # Read the static shape via ``.shape`` (as AttentionRefinementModule
    # does) instead of the legacy ``tf.keras.backend.int_shape`` helper.
    grid_height, grid_width = feature_14_arm.shape[1], feature_14_arm.shape[2]
    global_context = tf.keras.layers.UpSampling2D(size=(grid_height, grid_width), interpolation='nearest')(global_context)
    feature_14_arm = tf.keras.layers.Add()([feature_14_arm, global_context])

    # NOTE(review): the factors 2 and 4 presumably bring both branches to the
    # spatial-path resolution -- confirm against the chosen input size.
    feature_13_arm = UpSampling2D(size=(2, 2), interpolation='bilinear')(feature_13_arm)
    feature_14_arm = UpSampling2D(size=(4, 4), interpolation='bilinear')(feature_14_arm)
    return feature_13_arm, feature_14_arm

def FeatureFusionModule(spatial_out, context_out, filters=256):
    """Fuse the spatial and context features into one gated feature map.

    The two inputs are concatenated along the last axis and passed through a
    3x3 Conv -> BatchNorm -> ReLU block. A per-channel gate is then computed
    from the globally pooled result (Dense-ReLU followed by Dense-sigmoid) and
    multiplied onto the fused feature map.

    Args:
        spatial_out: Feature map from the spatial branch.
        context_out: Feature map from the context branch; must be
            concatenable with ``spatial_out`` along the last axis.
        filters: Number of channels of the fused output and of the gate.
            Defaults to 256.

    Returns:
        The gated feature map with ``filters`` channels.
    """
    merged = concatenate([spatial_out, context_out])
    fused = Conv2D(filters, (3, 3), padding='same')(merged)
    fused = BatchNormalization()(fused)
    fused = ReLU()(fused)

    # Channel gate computed from the spatially averaged fused features.
    gate = GlobalAveragePooling2D()(fused)
    gate = Dense(filters, activation='relu')(gate)
    gate = Dense(filters, activation='sigmoid')(gate)
    # Shape (1, 1, filters) so the gate broadcasts over height and width.
    gate = tf.keras.layers.Reshape((1, 1, filters))(gate)
    return tf.keras.layers.Multiply()([fused, gate])

# Input layer: derived from img_size so the model input always matches the
# arrays produced by the data generator.
input_tensor = Input(shape=tuple(img_size) + (3,))

# Spatial Path
spatial_out = SpatialPath(input_tensor)

# Context Path
# NOTE(review): context_out_13 is computed but never used below, so only the
# block-14 branch reaches the output -- confirm this is intentional.
context_out_13, context_out_14 = ContextPath(input_tensor)

# Feature Fusion Module
fused_out = FeatureFusionModule(spatial_out, context_out_14)

# Upsample the fused features by 8 before classification
fused_out = UpSampling2D(size=(8, 8), interpolation='bilinear')(fused_out)

# Final classifier: per-pixel softmax over num_classes channels
output_tensor = Conv2D(num_classes, (1, 1), activation='softmax')(fused_out)

# Create the model
model = Model(input_tensor, output_tensor)

# Compile the model (targets from the generator are one-hot encoded)
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

# Model summary
model.summary()

# Train the model
model.fit(train_gen, epochs=epochs, verbose=1)


Epoch 1/2
