In [1]:
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical

import matrix as mt
import prepareVOC12 as voc
from sspfpn import SSPFPN, customized_loss

In [2]:
# Load the VOC12 images and their label maps, then hold out 20% for validation.
class_dict = voc.class_dict
data_dir, labels_dir = "./VOC12/train", "./VOC12/train_label"

x, y = voc.load_data(data_dir, labels_dir)
(x_train, x_val,
 y_train, y_val) = train_test_split(
    x, y,
    test_size=0.2,
    random_state=42,  # fixed seed so the split is reproducible across runs
)

# Shapes of the four splits: images and label maps are both (N, 224, 224, 3)
# per the output captured below.
for split in (x_train, x_val, y_train, y_val):
    print(split.shape)

(1171, 224, 224, 3)
(293, 224, 224, 3)
(1171, 224, 224, 3)
(293, 224, 224, 3)


In [None]:
def color_contained(label_map):
    """Show a label map and print the distinct pixel colour vectors it contains.

    Prints the array's shape, displays it with ``plt.imshow``, then prints the
    unique rows obtained by flattening it into rows of 3 values.

    Parameters
    ----------
    label_map : np.ndarray
        Image-like array; its size must be divisible by 3. NOTE(review):
        presumably (H, W, 3) RGB, as the label maps in this notebook are;
        confirm for other inputs.

    Returns
    -------
    None
    """
    print("Label map shape:", label_map.shape)

    # One row per pixel, then deduplicate whole rows (colour vectors).
    palette = np.unique(label_map.reshape(-1, 3), axis=0)

    plt.imshow(label_map)
    plt.show()

    print("Unique color vectors in label map:", palette)

In [None]:
# Encode every label map with voc.label_to_onehot (using class_dict) and stack
# the results into one array per split.
y_train_onehot = np.array(
    [voc.label_to_onehot(label_map, class_dict) for label_map in y_train]
)
y_val_onehot = np.array(
    [voc.label_to_onehot(label_map, class_dict) for label_map in y_val]
)

In [None]:
# Pick one validation sample and decode its one-hot encoding back to a label map.
index = 52
original_label_map = y_val[index]
one_hot_map = y_val_onehot[index]
converted_label_map = voc.onehot_to_label(one_hot_map, class_dict)

# Side by side: raw label map vs. the one reconstructed from its one-hot form.
fig, axs = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    (original_label_map, 'Original Label Map'),
    (converted_label_map, 'Converted Label Map'),
)
for ax, (image, title) in zip(axs, panels):
    ax.imshow(image)
    ax.set_title(title)
plt.show()

# Compare the colour palettes of both maps.
color_contained(original_label_map)
color_contained(converted_label_map)

In [None]:
# Create the model (training happens in a later cell).
input_shape = (224, 224, 3)

# Fail fast if the loaded images do not match the shape the model is built for.
assert x_train.shape[1:] == input_shape, (
    f"training images have shape {x_train.shape[1:]}, model expects {input_shape}"
)

# NOTE(review): the 'mixed_float16' global policy is set in a later cell, after
# this model is constructed. Keras layers capture the global policy when they
# are created, so (assuming SSPFPN builds its layers here) that policy does not
# apply to this model. Set it before this cell if mixed precision is intended.
model = SSPFPN(input_shape)

In [4]:
# Sanity-check mt.split_image_into_blocks on one training image: print the
# image shape, then the shape of the block at index [7][7] of the split.
sample_image = x_train[0]
print(sample_image.shape)

blocks = np.array(mt.split_image_into_blocks(sample_image))
print(blocks[7][7].shape)

(224, 224, 3)
(28, 28, 3)


In [None]:
# Prepare dataset: pair images with their one-hot targets and batch them.
# NOTE(review): the training dataset is not reshuffled between epochs.
BATCH_SIZE = 5
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train_onehot)).batch(BATCH_SIZE)
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val_onehot)).batch(BATCH_SIZE)

# Set the optimizer: SGD with momentum and a staircase exponential decay that
# multiplies the learning rate by `decay_rate` every `decay_steps` steps.
initial_learning_rate = 2.5e-4
decay_rate = 0.9
# One epoch = number of training batches = ceil(len(x_train) / BATCH_SIZE).
# Derived from the dataset rather than hard-coded (the old value 300 did not
# match 1171 / 5 ≈ 235), so the rate decays exactly once per epoch.
decay_steps = len(train_dataset)

lr_schedule = ExponentialDecay(initial_learning_rate, decay_steps, decay_rate, staircase=True)
opt = SGD(learning_rate=lr_schedule, momentum=0.9)

In [None]:
# Switch the Keras global dtype policy to 'mixed_float16'.
# NOTE(review): this runs after `model = SSPFPN(input_shape)`. Keras layers take
# the global policy at construction time, so the already-built `model`
# presumably keeps its original policy. If mixed precision is intended, set the
# policy before building the model and wrap the optimizer in
# tf.keras.mixed_precision.LossScaleOptimizer.
from tensorflow.keras.mixed_precision import Policy, global_policy, set_global_policy

policy = Policy(name='mixed_float16')
set_global_policy(policy)

In [None]:
# Custom training loop: one SGD step per training batch, then a full
# validation pass after every epoch. Prints mean train/validation loss.
EPOCHS = 50

for epoch in range(EPOCHS):
    print(f"Epoch {epoch + 1}/{EPOCHS}")

    # Train
    train_losses = []
    for batch_x, batch_y in train_dataset:
        # Same input cast in training and validation so the model and
        # customized_loss see identical dtypes in both phases.
        batch_x = tf.cast(batch_x, tf.float32)
        with tf.GradientTape() as tape:
            y_pred = model(batch_x, training=True)
            loss = customized_loss(batch_x, batch_y, y_pred)
        # Gradients are computed/applied outside the tape context so the tape
        # does not also record the gradient computation itself.
        gradients = tape.gradient(loss, model.trainable_variables)
        opt.apply_gradients(zip(gradients, model.trainable_variables))
        train_losses.append(loss.numpy())
    print(f"Training loss: {np.mean(train_losses):.4f}")

    # Validate
    val_losses = []
    for batch_x, batch_y in val_dataset:
        batch_x = tf.cast(batch_x, tf.float32)
        y_pred = model(batch_x, training=False)
        loss = customized_loss(batch_x, batch_y, y_pred)
        val_losses.append(loss.numpy())
    val_loss = np.mean(val_losses)
    print(f"Validation loss: {val_loss:.4f}")