# First PART

# WORKING MULTICLASS SEGMENTATION
## START


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
import numpy as np

# ---------------------------------------------------------------------------
# Global configuration
# ---------------------------------------------------------------------------
IMG_SIZE = (256, 256)  # target image size passed to the parsing/resizing code
BATCH_SIZE = 32
EPOCHS = 5

# Class id -> display name and RGB overlay colour.
# Id 0 is the background class; ids 1-5 are the injury types.
INJURY_CLASSES = {
    0: {"name": "background", "color": (0, 0, 0)},      # Black for background
    1: {"name": "abrasion", "color": (255, 0, 0)},      # Red
    2: {"name": "bruise", "color": (0, 0, 255)},        # Blue
    3: {"name": "cut", "color": (0, 255, 0)},           # Green
    4: {"name": "laceration", "color": (255, 255, 0)},  # Yellow
    5: {"name": "stab_wound", "color": (255, 0, 255)}   # Magenta
}

# 5 injury types + 1 background class, derived from the table above
NUM_CLASSES = len(INJURY_CLASSES)

# TFRecord file lists, one per dataset split
DATA_DIR = "/kaggle/input/compound-injury-dataset"
train_files = [f"{DATA_DIR}/train.tfrecord"]
val_files = [f"{DATA_DIR}/val.tfrecord"]
test_files = [f"{DATA_DIR}/test.tfrecord"]

def weighted_categorical_accuracy(y_true, y_pred):
    """
    Class-balanced metric: mean per-class recall over the classes present.

    For every class channel the prediction is binarised at 0.5 and the
    fraction of that class's true pixels that were hit is computed
    (true positives / true pixels, i.e. recall). The per-class values are
    then averaged with equal weight, so large classes do not dominate.

    Classes with no true pixels in ``y_true`` are excluded from the average.
    Counting them as 0 would lower the score for a class the model could not
    possibly have scored on.

    Args:
        y_true: One-hot ground truth with the class axis last.
        y_pred: Predicted per-class scores, same shape as ``y_true``.

    Returns:
        Scalar float32 tensor in [0, 1]; 0 when no class has any true pixel.
    """
    threshold = 0.5
    y_true = tf.cast(y_true, tf.float32)

    recalls = []
    present_flags = []
    for i in range(NUM_CLASSES):
        y_true_class = y_true[..., i]
        y_pred_class = tf.cast(y_pred[..., i] > threshold, tf.float32)

        true_positives = K.sum(y_true_class * y_pred_class)
        total_pixels = K.sum(y_true_class)

        # divide_no_nan yields 0 for an absent class instead of NaN/inf
        recalls.append(tf.math.divide_no_nan(true_positives, total_pixels))
        present_flags.append(tf.cast(total_pixels > 0, tf.float32))

    # Absent classes contribute 0 to the numerator and 0 to the count
    num_present = K.sum(tf.stack(present_flags))
    return tf.math.divide_no_nan(K.sum(tf.stack(recalls)), num_present)


def parse_tfrecord(example_proto, img_size=IMG_SIZE, training=False):
    """
    Parse one serialized TFRecord example into an (image, one-hot mask) pair.

    The segmentation mask is built from the example's bounding boxes: every
    pixel inside a box receives that box's class label, later boxes overwrite
    earlier ones where they overlap, and pixels covered by no box become
    background (class 0).

    Box coordinates are multiplied by the target size, so they are assumed to
    be normalised fractions of the image -- TODO confirm against the dataset.
    Boxes that extend past the image or are inverted (max < min) simply paint
    fewer or no pixels instead of raising.

    Args:
        example_proto: Serialized example to parse.
        img_size: Static (height, width) of the output image and mask.
        training: Currently unused; the augmentation it used to enable is
            commented out below.

    Returns:
        Tuple ``(image, masks)``: ``image`` is float32 of shape
        (height, width, 3) scaled by 1/255; ``masks`` is the one-hot mask of
        shape (height, width, NUM_CLASSES).
    """
    feature_description = {
        'image/encoded': tf.io.FixedLenFeature([], tf.string),
        'image/height': tf.io.FixedLenFeature([], tf.int64),
        'image/width': tf.io.FixedLenFeature([], tf.int64),
        'image/object/bbox/xmin': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/xmax': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymin': tf.io.VarLenFeature(tf.float32),
        'image/object/bbox/ymax': tf.io.VarLenFeature(tf.float32),
        'image/object/class/label': tf.io.VarLenFeature(tf.int64),
    }

    example = tf.io.parse_single_example(example_proto, feature_description)

    # Decode and preprocess image
    image = tf.io.decode_jpeg(example['image/encoded'], channels=3)
    image = tf.image.resize(image, img_size)
    image = tf.cast(image, tf.float32) / 255.0

    # NOTE: augmentation is disabled, so `training` has no effect.
    # if training:
    #     image = tf.image.random_brightness(image, 0.1)
    #     image = tf.image.random_contrast(image, 0.9, 1.1)
    #     image = tf.clip_by_value(image, 0, 1)

    # Get bounding box coordinates and class labels
    xmin = tf.sparse.to_dense(example['image/object/bbox/xmin'])
    xmax = tf.sparse.to_dense(example['image/object/bbox/xmax'])
    ymin = tf.sparse.to_dense(example['image/object/bbox/ymin'])
    ymax = tf.sparse.to_dense(example['image/object/bbox/ymax'])
    class_labels = tf.cast(tf.sparse.to_dense(example['image/object/class/label']), tf.int32)

    height, width = img_size[0], img_size[1]
    height_f = tf.cast(height, tf.float32)
    width_f = tf.cast(width, tf.float32)

    # Row / column index grids that broadcast to (height, width)
    rows = tf.range(height)[:, tf.newaxis]
    cols = tf.range(width)[tf.newaxis, :]

    # Initialize masks tensor with -1 to distinguish unclassified pixels
    masks = tf.fill(img_size, -1)

    def process_box(i, masks):
        y_start = tf.cast(ymin[i] * height_f, tf.int32)
        y_end = tf.cast(ymax[i] * height_f, tf.int32)
        x_start = tf.cast(xmin[i] * width_f, tf.int32)
        x_end = tf.cast(xmax[i] * width_f, tf.int32)

        # Half-open box [start, end). Comparing against the index grids keeps
        # out-of-range or inverted boxes harmless: they select no pixels,
        # where a scatter update would be handed invalid indices.
        inside = (
            (rows >= y_start) & (rows < y_end)
            & (cols >= x_start) & (cols < x_end)
        )

        # Labels are used as stored (no shift). Presumably the dataset
        # already reserves 0 for background -- TODO confirm.
        return tf.where(inside, class_labels[i], masks)

    # Process each bounding box in order so later boxes win on overlap
    num_boxes = tf.shape(xmin)[0]
    i = tf.constant(0)

    def condition(i, masks):
        return tf.less(i, num_boxes)

    def body(i, masks):
        masks = process_box(i, masks)
        return i + 1, masks

    _, final_masks = tf.while_loop(
        condition,
        body,
        [i, masks],
        shape_invariants=[i.get_shape(), tf.TensorShape(img_size)]
    )

    # Set unclassified pixels (still -1) to background class (0)
    final_masks = tf.where(final_masks < 0, 0, final_masks)

    # Convert to one-hot encoding.
    # NOTE: a label >= NUM_CLASSES would give an all-zero vector for its pixels.
    final_masks = tf.one_hot(final_masks, NUM_CLASSES)

    return image, final_masks

def multiclass_dice_loss(y_true, y_pred):
    """
    Soft Dice loss averaged over the NUM_CLASSES class channels.

    For each class channel the Dice coefficient is computed per sample by
    summing over axes 1 and 2 of that channel, with a smoothing term of 1.0
    in numerator and denominator. The per-sample coefficients are averaged,
    turned into a loss as ``1 - dice``, and the class losses are averaged.

    Args:
        y_true: One-hot ground truth with the class axis last
            (presumably (batch, height, width, classes)).
        y_pred: Predicted class probabilities, same shape as ``y_true``.

    Returns:
        Scalar loss tensor.
    """
    smooth = 1.0
    spatial_axes = [1, 2]

    def class_loss(channel):
        # Slice out one class channel from both tensors
        truth = y_true[..., channel]
        probs = y_pred[..., channel]

        overlap = K.sum(truth * probs, axis=spatial_axes)
        total = K.sum(truth, axis=spatial_axes) + K.sum(probs, axis=spatial_axes)

        dice = (2. * overlap + smooth) / (total + smooth)
        return 1 - K.mean(dice)

    # Built-in sum starts at 0 and adds the class losses in order
    return sum(class_loss(channel) for channel in range(NUM_CLASSES)) / NUM_CLASSES


def multiclass_iou_metric(y_true, y_pred):
    """
    Soft IoU (Jaccard) metric averaged over the NUM_CLASSES class channels.

    Works on the raw predicted probabilities (no thresholding or argmax).
    For each class channel, intersection and union are summed over axes 1
    and 2 per sample, smoothed by 1.0, and averaged over the samples; the
    class values are then averaged.

    Args:
        y_true: One-hot ground truth with the class axis last
            (presumably (batch, height, width, classes)).
        y_pred: Predicted class probabilities, same shape as ``y_true``.

    Returns:
        Scalar tensor holding the mean soft IoU.
    """
    smooth = 1.0
    spatial_axes = [1, 2]

    def class_iou(channel):
        truth = y_true[..., channel]
        probs = y_pred[..., channel]

        overlap = K.sum(truth * probs, axis=spatial_axes)
        combined = K.sum(truth, axis=spatial_axes) + K.sum(probs, axis=spatial_axes) - overlap
        return K.mean((overlap + smooth) / (combined + smooth))

    per_class = [class_iou(channel) for channel in range(NUM_CLASSES)]
    return K.mean(tf.stack(per_class))

def load_dataset(tfrecord_files, img_size=IMG_SIZE, batch_size=BATCH_SIZE, training=False):
    """
    Build a batched, endlessly repeating tf.data pipeline from TFRecord files.

    Args:
        tfrecord_files: TFRecord file path(s) to read.
        img_size: Target size forwarded to ``parse_tfrecord``.
        batch_size: Number of examples per batch.
        training: When True, records are shuffled (buffer of 1000). The flag
            is also forwarded to ``parse_tfrecord``.

    Returns:
        A ``tf.data.Dataset`` of ``(image, mask)`` batches. It repeats
        forever, so callers must bound iteration themselves
        (e.g. ``steps_per_epoch`` / ``steps`` or ``.take(n)``).
    """
    dataset = tf.data.TFRecordDataset(tfrecord_files)

    if training:
        # Shuffle the serialized records, not the parsed tensors. Shuffling
        # after map() keeps 1000 decoded float32 images plus one-hot masks
        # in the buffer (about 2.25 MiB each at 256x256 with 6 classes).
        dataset = dataset.shuffle(1000)

    dataset = dataset.map(
        lambda record: parse_tfrecord(record, img_size, training),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.repeat()
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

def enhanced_unet_model(img_size=(256,256),filters_base=64, dropout_rate=0.3):
    """
    Build a 3-level U-Net that outputs a per-pixel softmax over NUM_CLASSES.

    Building blocks:
    1. Two Conv2D(3x3) -> BatchNormalization -> LeakyReLU(0.2) stages per block
    2. SpatialDropout2D at the end of every block except the first encoder
       block and the last decoder block
    3. Squeeze-and-Excitation channel attention after the third encoder
       block and after the bridge (always applied, not optional)
    4. UpSampling2D + 1x1 Conv2D in the decoder, concatenated with the
       matching encoder output

    NOTE(review): ``conv_block`` adds a residual connection only when its
    input already has ``filters`` channels. With the wiring below that never
    happens (only the first block could match, and only if
    ``filters_base == 3``), so the residual add is effectively unused.
    Adding a projection shortcut would change the architecture and
    invalidate saved weights, so it is left as is.

    Args:
        img_size: (height, width) of the input images.
        filters_base: Filters in the first block; deeper blocks use 2x, 4x
            and 8x this value.
        dropout_rate: Rate for the SpatialDropout2D layers.

    Returns:
        An uncompiled ``Model`` mapping (height, width, 3) inputs to
        (height, width, NUM_CLASSES) softmax outputs.
    """
    inputs = layers.Input(shape=(img_size[0], img_size[1], 3))

    def conv_block(x, filters, dropout=True):
        # First conv layer in block
        conv = layers.Conv2D(filters, 3, padding='same')(x)
        conv = layers.BatchNormalization()(conv)
        conv = layers.LeakyReLU(0.2)(conv)

        # Second conv layer in block
        conv = layers.Conv2D(filters, 3, padding='same')(conv)
        conv = layers.BatchNormalization()(conv)
        conv = layers.LeakyReLU(0.2)(conv)

        # Optional dropout
        if dropout:
            conv = layers.SpatialDropout2D(dropout_rate)(conv)

        # Residual connection (only possible when channel counts match)
        if x.shape[-1] == filters:
            conv = layers.Add()([conv, x])

        return conv

    def se_block(x, filters):
        """Squeeze-and-Excitation block for channel attention"""
        # Clamp the bottleneck width: filters // 16 is 0 for filters < 16,
        # and a Dense layer cannot have zero units.
        bottleneck = max(1, filters // 16)
        se = layers.GlobalAveragePooling2D()(x)
        se = layers.Dense(bottleneck, activation='relu')(se)
        se = layers.Dense(filters, activation='sigmoid')(se)
        se = layers.Reshape((1, 1, filters))(se)
        return layers.multiply([x, se])

    def decoder_stage(x, skip, filters, dropout=True):
        """Upsample x, match its channels to skip, concatenate, convolve."""
        up = layers.UpSampling2D(2)(x)
        # Add 1x1 conv for channel matching if needed
        if up.shape[-1] != skip.shape[-1]:
            up = layers.Conv2D(skip.shape[-1], 1, padding='same')(up)
        merged = layers.Concatenate()([up, skip])
        return conv_block(merged, filters, dropout=dropout)

    # Encoder
    conv1 = conv_block(inputs, filters_base, dropout=False)
    pool1 = layers.MaxPooling2D(2)(conv1)

    conv2 = conv_block(pool1, filters_base*2)
    pool2 = layers.MaxPooling2D(2)(conv2)

    conv3 = conv_block(pool2, filters_base*4)
    conv3 = se_block(conv3, filters_base*4)  # Add SE block
    pool3 = layers.MaxPooling2D(2)(conv3)

    # Bridge
    conv4 = conv_block(pool3, filters_base*8)
    conv4 = se_block(conv4, filters_base*8)  # Add SE block

    # Decoder (mirrors the encoder, deepest skip first)
    conv5 = decoder_stage(conv4, conv3, filters_base*4)
    conv6 = decoder_stage(conv5, conv2, filters_base*2)
    conv7 = decoder_stage(conv6, conv1, filters_base, dropout=False)

    # Output: 1x1 conv to per-pixel class probabilities
    outputs = layers.Conv2D(NUM_CLASSES, 1, activation='softmax')(conv7)

    return Model(inputs, outputs)

def plot_training_history(history):
    """
    Draw training vs. validation curves for loss and IoU side by side.

    Args:
        history: Object whose ``history`` attribute is a mapping containing
            the keys 'loss', 'val_loss', 'multiclass_iou_metric' and
            'val_multiclass_iou_metric' (as returned by ``Model.fit`` when
            the IoU metric and validation data are used).
    """
    fig, (loss_ax, iou_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, history key, human-readable label) for each panel
    panels = (
        (loss_ax, 'loss', 'Loss'),
        (iou_ax, 'multiclass_iou_metric', 'IoU'),
    )

    for axis, key, label in panels:
        axis.plot(history.history[key], label=f'Training {label}')
        axis.plot(history.history[f'val_{key}'], label=f'Validation {label}')
        axis.set_title(f'Model {label}')
        axis.set_xlabel('Epoch')
        axis.set_ylabel(label)
        axis.legend()

    plt.tight_layout()
    plt.show()






In [None]:
# Input pipelines: only the training split is shuffled
train_dataset = load_dataset(train_files, training=True)
val_dataset = load_dataset(val_files)

# Build the U-Net, then attach optimizer, loss and metrics.
# Metric order matters later: evaluate() returns
# [loss, categorical_accuracy, multiclass_iou_metric].
multiClassWorkingModel = enhanced_unet_model()
adam_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
multiClassWorkingModel.compile(
    optimizer=adam_optimizer,
    loss='categorical_crossentropy',  # standard categorical crossentropy
    metrics=['categorical_accuracy', multiclass_iou_metric],
)

In [None]:
#multiClassWorkingModel.summary()

In [None]:
# The datasets repeat forever, so Keras needs explicit step counts.
# The literals are presumably the number of examples per split -- TODO confirm.
train_examples, val_examples, test_examples = 2560, 190, 120
steps_per_epoch = train_examples // BATCH_SIZE
validation_steps = val_examples // BATCH_SIZE
test_steps = test_examples // BATCH_SIZE

# First training run; `history` feeds the plots in the next cell
history = multiClassWorkingModel.fit(
    train_dataset,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_dataset,
    validation_steps=validation_steps,
)

In [None]:
# Plot training vs. validation curves: loss on the left, IoU on the right
plt.figure(figsize=(12, 4))

curve_specs = (('loss', 'Loss'), ('multiclass_iou_metric', 'IoU'))
for panel_index, (metric_key, metric_label) in enumerate(curve_specs, start=1):
    plt.subplot(1, 2, panel_index)
    plt.plot(history.history[metric_key], label=f'Training {metric_label}')
    plt.plot(history.history[f'val_{metric_key}'], label=f'Validation {metric_label}')
    plt.title(f'Model {metric_label}')
    plt.xlabel('Epoch')
    plt.ylabel(metric_label)
    plt.legend()

plt.tight_layout()
plt.show()
# Evaluate on the test split (unshuffled, repeating; bounded by test_steps)
test_dataset = load_dataset(test_files)

# evaluate() returns [loss, *metrics] in compile order:
# index 1 is categorical_accuracy, index 2 is multiclass_iou_metric.
test_results = multiClassWorkingModel.evaluate(test_dataset, steps=test_steps)
print(f"\nTest Loss: {test_results[0]:.4f}")
print(f"Test Accuracy: {test_results[1]:.4f}")
print(f"Test IoU: {test_results[2]:.4f}")

In [None]:
# Class id -> display name and RGB overlay colour, used by the visualization
# helpers below. This re-declares the same table defined with the constants
# at the top of the notebook, so the cell can be run on its own.
INJURY_CLASSES = {
    0: {"name": "background", "color": (0, 0, 0)},      # Black for background
    1: {"name": "abrasion", "color": (255, 0, 0)},      # Red
    2: {"name": "bruise", "color": (0, 0, 255)},        # Blue
    3: {"name": "cut", "color": (0, 255, 0)},           # Green
    4: {"name": "laceration", "color": (255, 255, 0)},  # Yellow
    5: {"name": "stab_wound", "color": (255, 0, 255)}   # Magenta
}

# Visualization helpers: colour overlays and side-by-side prediction plots

def create_overlay_mask(mask, alpha=0.4):
    """
    Turn a per-class score mask into an RGBA overlay image.

    Each pixel is assigned to its highest-scoring class (argmax over the
    last axis) and painted with that class's colour from INJURY_CLASSES,
    scaled to [0, 1], with the given opacity.

    Args:
        mask: Array of shape (height, width, classes). If it has more than
            three dimensions, only the first element along axis 0 is used.
        alpha: Opacity written to the alpha channel of every painted pixel.

    Returns:
        float32 array of shape (height, width, 4).
    """
    # Drop a leading batch axis, keeping only the first sample
    scores = mask[0] if len(mask.shape) > 3 else mask

    n_rows, n_cols = scores.shape[:2]
    rgba = np.zeros((n_rows, n_cols, 4), dtype=np.float32)

    # Winning class per pixel makes the classes mutually exclusive
    winners = np.argmax(scores, axis=-1)

    for label in range(NUM_CLASSES):
        selected = winners == label
        if not np.any(selected):
            continue
        rgb = INJURY_CLASSES[label]["color"]
        rgba[selected] = (rgb[0] / 255.0, rgb[1] / 255.0, rgb[2] / 255.0, alpha)

    return rgba

def visualize_prediction(model, dataset):
    """
    Show image, ground truth and prediction for the first 10 batches.

    Only the first sample of each batch is displayed. For that sample a
    three-panel figure is drawn (original, true overlay, predicted overlay)
    with a class legend, followed by printed per-class pixel counts based on
    the argmax class of every pixel.

    Args:
        model: Object with a ``predict`` method that returns per-class
            scores for a batch of images.
        dataset: Iterable dataset supporting ``take`` and yielding
            ``(image_batch, mask_batch)`` pairs.
    """
    for batch_images, batch_masks in dataset.take(10):
        if isinstance(batch_images, tf.Tensor):
            batch_images = batch_images.numpy()
        if isinstance(batch_masks, tf.Tensor):
            batch_masks = batch_masks.numpy()

        # Predict on the first sample only, keeping the batch axis
        predicted = model.predict(batch_images[0:1])
        first_image = batch_images[0]

        # create_overlay_mask picks element 0 of a batched mask itself
        truth_overlay = create_overlay_mask(batch_masks)
        predicted_overlay = create_overlay_mask(predicted)

        plt.figure(figsize=(20, 7))

        # (title, overlay) per panel; None means the bare image
        panels = (
            ("Original Image", None),
            ("True Segmentation", truth_overlay),
            ("Predicted Segmentation", predicted_overlay),
        )
        for slot, (title, overlay) in enumerate(panels, start=1):
            plt.subplot(1, 3, slot)
            plt.imshow(first_image)
            if overlay is not None:
                plt.imshow(overlay)
            plt.title(title)
            plt.axis('off')

        # One legend patch per class, in class-id order
        legend_patches = []
        for class_id in range(NUM_CLASSES):
            entry = INJURY_CLASSES[class_id]
            face_colour = tuple(c/255 for c in entry["color"]) + (0.6,)
            legend_patches.append(
                plt.Rectangle((0,0), 1, 1, fc=face_colour, label=entry["name"])
            )
        plt.figlegend(handles=legend_patches,
                     loc='center right',
                     bbox_to_anchor=(0.98, 0.5),
                     title="Injury Types")

        plt.tight_layout()
        plt.show()

        # Per-class pixel counts from the winning class of each pixel
        print("\nDetected injuries:")
        predicted_labels = np.argmax(predicted[0], axis=-1)
        truth_labels = np.argmax(batch_masks[0], axis=-1)

        for class_id in range(NUM_CLASSES):
            true_pixels = np.sum(truth_labels == class_id)
            pred_pixels = np.sum(predicted_labels == class_id)
            if true_pixels == 0 and pred_pixels == 0:
                continue
            print(f"{INJURY_CLASSES[class_id]['name']}:")
            print(f"  True pixels: {true_pixels}")
            print(f"  Predicted pixels: {pred_pixels}")


In [None]:
# Visualize predictions on the (shuffled) training pipeline
visualize_prediction(multiClassWorkingModel,train_dataset)

## END


# WORKING MULTICLASS SEGMENTATION END


In [None]:
# Call fit() again on the same model object for another EPOCHS epochs.
# Reassigning `history` replaces the metrics recorded by the earlier run.
history = multiClassWorkingModel.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
)

In [None]:
# Another 5-epoch fit() call on the same model; the returned History is not
# stored, so `history` still holds the metrics of the last assigned run.
multiClassWorkingModel.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=5,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
)

In [None]:
# Another 5-epoch fit() call on the same model; the returned History is not
# stored, so `history` still holds the metrics of the last assigned run.
multiClassWorkingModel.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=5,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
)

In [None]:
# Visualize predictions on the test pipeline
visualize_prediction(multiClassWorkingModel,test_dataset)

In [None]:
# evaluate() returns [loss, *metrics] in compile order:
# index 1 is categorical_accuracy, index 2 is multiclass_iou_metric.
test_results = multiClassWorkingModel.evaluate(test_dataset, steps=test_steps)
print(f"\nTest Loss: {test_results[0]:.4f}")
print(f"Test Accuracy: {test_results[1]:.4f}")
print(f"Test IoU: {test_results[2]:.4f}")

In [None]:
# Another 5-epoch fit() call on the same model, then re-evaluate on the test split
multiClassWorkingModel.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=5,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
)

# evaluate() returns [loss, *metrics] in compile order:
# index 1 is categorical_accuracy, index 2 is multiclass_iou_metric.
test_results = multiClassWorkingModel.evaluate(test_dataset, steps=test_steps)
print(f"\nTest Loss: {test_results[0]:.4f}")
print(f"Test Accuracy: {test_results[1]:.4f}")
print(f"Test IoU: {test_results[2]:.4f}")

In [None]:
# Another 5-epoch fit() call on the same model, then re-evaluate on the test split
multiClassWorkingModel.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=5,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
)

# evaluate() returns [loss, *metrics] in compile order:
# index 1 is categorical_accuracy, index 2 is multiclass_iou_metric.
test_results = multiClassWorkingModel.evaluate(test_dataset, steps=test_steps)
print(f"\nTest Loss: {test_results[0]:.4f}")
print(f"Test Accuracy: {test_results[1]:.4f}")
print(f"Test IoU: {test_results[2]:.4f}")

In [None]:
# Another 5-epoch fit() call on the same model, then re-evaluate on the test split
multiClassWorkingModel.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=5,
    steps_per_epoch=steps_per_epoch,
    validation_steps=validation_steps,
)

# evaluate() returns [loss, *metrics] in compile order:
# index 1 is categorical_accuracy, index 2 is multiclass_iou_metric.
test_results = multiClassWorkingModel.evaluate(test_dataset, steps=test_steps)
print(f"\nTest Loss: {test_results[0]:.4f}")
print(f"Test Accuracy: {test_results[1]:.4f}")
print(f"Test IoU: {test_results[2]:.4f}")