In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Layer, Input, UpSampling2D, Concatenate, Reshape, Lambda
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model

class SpatialRNNCell(Layer):
    def __init__(self, direction, **kwargs):
        super(SpatialRNNCell, self).__init__(**kwargs)
        self.direction = direction

    def roll_output(self, inputs, direction):
        shifts = {"right": (0, 1), "left": (0, -1), "up": (-1, 0), "down": (1, 0)}
        return tf.roll(inputs, shift=shifts[direction], axis=(1, 2))

    def build(self, input_shape):
        self.alpha = self.add_weight(shape=(1,), initializer="ones", trainable=True)
        super(SpatialRNNCell, self).build(input_shape)

    def call(self, inputs):
        num_steps = tf.shape(inputs)[2] if self.direction in ['left', 'right'] else tf.shape(inputs)[1]
        outputs = inputs
        for _ in tf.range(num_steps):
            outputs = Lambda(lambda x: self.roll_output(x, self.direction))(outputs)
            outputs *= self.alpha
            outputs = tf.nn.relu(outputs)
        return outputs

    def compute_output_shape(self, input_shape):
        return input_shape

class AttentionEstimatorNetwork(Layer):
    def __init__(self, **kwargs):
        super(AttentionEstimatorNetwork, self).__init__(**kwargs)
        self.conv1 = Conv2D(32, (3, 3), activation='relu', padding='same')
        self.conv2 = Conv2D(32, (3, 3), activation='relu', padding='same')
        self.conv3 = Conv2D(4, (1, 1), activation=None, padding='same')

    def build(self, input_shape):
        self.conv1.build(input_shape)
        self.conv2.build(self.conv1.compute_output_shape(input_shape))
        self.conv3.build(self.conv2.compute_output_shape(self.conv1.compute_output_shape(input_shape)))
        super(AttentionEstimatorNetwork, self).build(input_shape)

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.conv2(x)
        return self.conv3(x)
    
    def compute_output_shape(self, input_shape):
        output_shape = self.conv1.compute_output_shape(input_shape)
        output_shape = self.conv2.compute_output_shape(output_shape)
        return self.conv3.compute_output_shape(output_shape)
    
class DirectionAwareSpatialContextModule(Layer):
    def __init__(self, **kwargs):
        super(DirectionAwareSpatialContextModule, self).__init__(**kwargs)
        self.attention_network = AttentionEstimatorNetwork()
        self.spatial_rnn_cells = {
            'right': SpatialRNNCell('right'),
            'left': SpatialRNNCell('left'),
            'up': SpatialRNNCell('up'),
            'down': SpatialRNNCell('down')
        }
        self.hidden_to_hidden_conv = None

    def build(self, input_shape):
        self.attention_network.build(input_shape)
        for cell in self.spatial_rnn_cells.values():
            cell.build(input_shape)
        num_channels = self.attention_network.compute_output_shape(input_shape)[-1]
        self.hidden_to_hidden_conv = Conv2D(num_channels // 4, (1, 1), activation='relu')
        self.hidden_to_hidden_conv.build(input_shape)
        super(DirectionAwareSpatialContextModule, self).build(input_shape)

    def call(self, inputs):
        attention_weights = self.attention_network(inputs)
        attention_weights_split = tf.split(attention_weights, num_or_size_splits=4, axis=-1)
        
        context_features = []
        for direction, rnn_cell in self.spatial_rnn_cells.items():
            rnn_output = rnn_cell(inputs)
            weighted_feature = rnn_output * attention_weights_split.pop(0)
            context_features.append(weighted_feature)

        concatenated_features = tf.concat(context_features, axis=-1)
        output_dsc_features = self.hidden_to_hidden_conv(concatenated_features)
        return output_dsc_features

    def compute_output_shape(self, input_shape):
        output_shape = (input_shape[0], input_shape[1], input_shape[2], input_shape[3] // 4)
        return output_shape



# Instantiate the VGG16 model
vgg = VGG16(include_top=False, weights='imagenet', input_tensor=Input(shape=(224, 224, 3)))

# Initialize the DirectionAwareSpatialContextModule
dascm = DirectionAwareSpatialContextModule()

# Define the input for the functional model
input_tensor = Input(shape=(224, 224, 3))

# Get the feature maps from all layers of VGG16, except the input layer
vgg_outputs = [layer.output for layer in vgg.layers[1:]]  # Skip the input layer

# Upsample and apply the DirectionAwareSpatialContextModule to each of the feature maps
dascm_outputs = []
for output in vgg_outputs:
    # The output needs to be upsampled to the same size
    # Calculate the upsample size for current feature map
    upsample_size = (224 // output.shape[1], 224 // output.shape[2])
    upsampled_output = UpSampling2D(size=upsample_size)(output)
    # Process each upsampled output with the dascm
    processed_output = dascm(upsampled_output)
    dascm_outputs.append(processed_output)

# Concatenate all the DASC module feature maps
# We can concatenate along the channel axis as the spatial dimensions are now equal
concatenated_outputs = Concatenate(axis=-1)(dascm_outputs)

# Define the complete model
model = Model(inputs=input_tensor, outputs=concatenated_outputs)

# Compile the model
# model.compile(optimizer='adam', loss='categorical_crossentropy')

# Print the model summary to check if everything is connected properly
# model.summary()

In [2]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os

# Define paths to the datasets
image_dir = 'E:/CS 512 Project/Proj/shadow/SBU-shadow/SBUTrain4KRecoveredSmall/ShadowImages'
mask_dir = 'E:/CS 512 Project/Proj/shadow/SBU-shadow/SBUTrain4KRecoveredSmall/ShadowMasks'

# Initialize the data generators
image_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)
mask_datagen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

# Connect the generators to your dataset
# Since we are not classifying images, we set class_mode to None
image_generator = image_datagen.flow_from_directory(
    image_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode=None,  # No class labels
    subset='training',
    seed=1)

mask_generator = mask_datagen.flow_from_directory(
    mask_dir,
    target_size=(224, 224),
    batch_size=32,
    class_mode=None,  # No class labels
    color_mode='grayscale',  # Masks are grayscale
    subset='training',
    seed=1)



Found 3268 images belonging to 1 classes.
Found 3268 images belonging to 1 classes.


In [3]:
def weighted_cross_entropy(y_true, y_pred):
    # Calculate the weight for each class
    Np = tf.reduce_sum(y_true)  # Number of positive samples
    Nn = tf.reduce_sum(1 - y_true)  # Number of negative samples
    TN = tf.math.count_nonzero(y_pred * (1 - y_true))
    TP = tf.math.count_nonzero((1 - y_pred) * y_true)
    
    # Weights for balancing the classes
    weight_for_0 = (Nn / (Nn + Np)) * (1 / (1 - TN / Nn))
    weight_for_1 = (Np / (Nn + Np)) * (1 / (1 - TP / Np))
    
    # Calculate the actual weighted cross-entropy
    # The '+ epsilon()' part is to avoid log(0)
    loss = -(weight_for_1 * y_true * tf.math.log(y_pred + tf.keras.backend.epsilon()) + 
             weight_for_0 * (1 - y_true) * tf.math.log(1 - y_pred + tf.keras.backend.epsilon()))
    
    return tf.reduce_mean(loss)

model.compile(optimizer='adam', loss=weighted_cross_entropy)


In [4]:
import math
def generate_train_batches(image_generator, mask_generator):
    while True:
        # Get next image and mask batch
        image_batch = next(image_generator)
        mask_batch = next(mask_generator)
        yield (image_batch, mask_batch)

# Create a generator
train_generator = generate_train_batches(image_generator, mask_generator)

# Ensure that image_dir points directly to the folder containing the images
steps_per_epoch = math.ceil(len(os.listdir(image_dir)) / 32)

# Train the model
history = model.fit(
    train_generator,
    steps_per_epoch=steps_per_epoch,
    epochs=50  # or however many you choose
)



Epoch 1/50


KeyError: 'Exception encountered when calling Functional.call().\n\n\x1b[1m2189501697936\x1b[0m\n\nArguments received by Functional.call():\n  • inputs=tf.Tensor(shape=(None, 224, 224, 3), dtype=float32)\n  • training=True\n  • mask=None'

In [16]:
# Test the output of the generator
image_batch, mask_batch = next(train_generator)
print('Image batch shape:', image_batch.shape)
print('Mask batch shape:', mask_batch.shape)


Image batch shape: (32, 224, 224, 3)
Mask batch shape: (32, 224, 224, 1)


In [19]:
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
