In [None]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from sklearn.model_selection import train_test_split
from skimage.metrics import structural_similarity as ssim, peak_signal_noise_ratio as psnr
import shutil

# Paths to the dataset folders (Kaggle "natural-color-dataset" layout):
# one sub-folder per category under each root, with matching file names.
gray_image_path = "/kaggle/input/natural-color-dataset/NCDataset/Gray/"
color_image_path = "/kaggle/input/natural-color-dataset/ColorfulOriginal/ColorfulOriginal/"
IMAGE_SIZE = 256  # side length (pixels) every image is resized to

def load_images(grayscale_path, color_path, image_size=(IMAGE_SIZE, IMAGE_SIZE)):
    """Load paired grayscale/colour images organised in per-category folders.

    Both roots must contain one sub-directory per category. An image is used
    only when a file with the same name exists in the matching category folder
    under both roots and both files decode successfully; anything else is
    silently skipped.

    Args:
        grayscale_path: Root folder of the grayscale (input) images.
        color_path: Root folder of the colour (ground-truth) images.
        image_size: Target size as ``(width, height)`` -- the order
            ``cv2.resize`` expects.

    Returns:
        Tuple ``(grayscale, color, categories)``:
        ``grayscale`` -- float32 array ``(N, height, width, 1)`` in [0, 1];
        ``color`` -- float32 RGB array ``(N, height, width, 3)`` in [0, 1];
        ``categories`` -- category folder name of each pair, shape ``(N,)``.
        If no pairs are found, the arrays are empty but keep those ranks.
    """
    valid_extensions = {".jpg", ".jpeg", ".png"}
    width, height = image_size
    grayscale_images = []
    color_images = []
    category_names = []

    # Sort the listings: os.listdir order is filesystem-dependent, and a stable
    # sample order is required for the seeded train/test split to be reproducible.
    for category in sorted(os.listdir(grayscale_path)):
        gray_dir = os.path.join(grayscale_path, category)
        color_dir = os.path.join(color_path, category)
        if not (os.path.isdir(gray_dir) and os.path.isdir(color_dir)):
            continue

        for file_name in sorted(os.listdir(gray_dir)):
            if os.path.splitext(file_name)[1].lower() not in valid_extensions:
                continue

            gray_file = os.path.join(gray_dir, file_name)
            color_file = os.path.join(color_dir, file_name)
            # Ground truth must share the grayscale file's name.
            if not (os.path.exists(gray_file) and os.path.exists(color_file)):
                continue

            gray_img = cv2.imread(gray_file, cv2.IMREAD_GRAYSCALE)
            color_img = cv2.imread(color_file)
            if gray_img is None or color_img is None:
                continue  # unreadable / corrupt image

            gray_img = cv2.resize(gray_img, image_size)
            color_img = cv2.resize(color_img, image_size)

            # OpenCV decodes colour images as BGR; the rest of the pipeline uses RGB.
            color_img = cv2.cvtColor(color_img, cv2.COLOR_BGR2RGB)

            gray_img = gray_img.astype('float32') / 255.0
            color_img = color_img.astype('float32') / 255.0

            # Add a trailing channel axis so inputs are (H, W, 1).
            grayscale_images.append(np.expand_dims(gray_img, axis=-1))
            color_images.append(color_img)
            category_names.append(category)

    if not grayscale_images:
        # np.array([]) would have shape (0,); keep the ranks callers rely on.
        return (
            np.empty((0, height, width, 1), dtype=np.float32),
            np.empty((0, height, width, 3), dtype=np.float32),
            np.array([], dtype=str),
        )

    return np.array(grayscale_images), np.array(color_images), np.array(category_names)

print("Grayscale Folder Contents:", os.listdir(gray_image_path))
print("Color Folder Contents:", os.listdir(color_image_path))

grayscale_images, color_images, category_names = load_images(gray_image_path, color_image_path)
print("Grayscale Images Shape:", grayscale_images.shape)
print("Color Images Shape:", color_images.shape)
print("Category Names Shape:", category_names.shape)

# Fail fast with an actionable message instead of an opaque error from
# train_test_split when no grayscale/colour pairs could be matched.
if len(grayscale_images) == 0:
    raise RuntimeError(
        "No matching grayscale/color image pairs found under "
        f"{gray_image_path!r} and {color_image_path!r}"
    )

# Split data into training and validation sets (80/20, fixed seed)
X_train, X_val, y_train, y_val, category_train, category_val = train_test_split(
    grayscale_images, color_images, category_names, test_size=0.2, random_state=42
)

print(f"Training Set: {X_train.shape}, {y_train.shape}, {category_train.shape}")
print(f"Validation Set: {X_val.shape}, {y_val.shape}, {category_val.shape}")

def build_unet_colorization_model(input_size=(256, 256, 1)):
    """Build a U-Net mapping a grayscale image to a 3-channel colour image.

    Architecture: four encoder stages (64/128/256/512 filters, each two 3x3
    ReLU convolutions followed by 2x2 max-pooling), a 1024-filter bottleneck,
    and a mirrored decoder (2x upsampling, 2x2 convolution, concatenation with
    the matching encoder output, two 3x3 convolutions). A final 1x1 sigmoid
    convolution yields 3 channels in [0, 1].

    Args:
        input_size: ``(height, width, channels)`` of the input. Height and
            width must each be divisible by 16 (or ``None``), because the four
            pooling stages must be undone exactly for the skip concatenations.

    Returns:
        An uncompiled Keras ``Model``.

    Raises:
        ValueError: If height or width is not divisible by 16.
    """
    encoder_filters = (64, 128, 256, 512)
    downsample_factor = 2 ** len(encoder_filters)
    # Validate up front: otherwise the mismatch only surfaces as a cryptic
    # shape error inside a decoder concatenate layer.
    for dim in input_size[:2]:
        if dim is not None and dim % downsample_factor != 0:
            raise ValueError(
                f"Spatial dimensions of input_size {tuple(input_size)} must be "
                f"divisible by {downsample_factor}"
            )

    def conv_pair(x, filters):
        # Two 3x3 ReLU convolutions with 'same' padding (spatial size preserved).
        x = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
        return layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)

    inputs = layers.Input(input_size)

    # Encoder: keep each stage's pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for filters in encoder_filters:
        x = conv_pair(x, filters)
        skips.append(x)
        x = layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = conv_pair(x, 1024)

    # Decoder: walk the encoder stages in reverse (deepest skip first).
    for filters, skip in zip(reversed(encoder_filters), reversed(skips)):
        up = layers.UpSampling2D(size=(2, 2))(x)
        up = layers.Conv2D(filters, (2, 2), activation='relu', padding='same')(up)
        x = layers.concatenate([skip, up], axis=3)
        x = conv_pair(x, filters)

    # Sigmoid keeps RGB predictions in [0, 1].
    outputs = layers.Conv2D(3, (1, 1), activation='sigmoid')(x)

    return Model(inputs=inputs, outputs=outputs)

def psnr_metric(y_true, y_pred):
    """Keras metric: peak signal-to-noise ratio between target and prediction.

    ``max_val`` is fixed at 1.0, i.e. pixel values are assumed to lie in
    [0, 1] -- consistent with the /255 scaling in ``load_images`` and the
    sigmoid output layer of the model.
    """
    peak_value = 1.0
    return tf.image.psnr(y_true, y_pred, max_val=peak_value)

# Build the network and train it with an MSE reconstruction loss,
# additionally tracking MAE and PSNR.
model = build_unet_colorization_model()
model.compile(
    optimizer=Adam(learning_rate=0.001),
    loss='mse',
    metrics=['mae', psnr_metric],
)

model.summary()

# Callbacks: keep the lowest-val_loss checkpoint on disk, halve the learning
# rate after 5 stagnant epochs, stop after 10 epochs without improvement.
checkpoint = ModelCheckpoint(
    'best_colorization_model.keras',
    monitor='val_loss',
    save_best_only=True,
    mode='min',
)
lr_reduction = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-6)
early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Train the model
history = model.fit(
    x=X_train,
    y=y_train,
    batch_size=16,
    epochs=50,
    verbose=1,
    validation_data=(X_val, y_val),
    callbacks=[checkpoint, lr_reduction, early_stop],
)

# Plot training vs. validation curves: one subplot per metric, with both
# curves on the same axes so over/under-fitting is visible at a glance.
history_metrics = ['loss', 'mae', 'psnr_metric']
plt.figure(figsize=(18, 5))
for i, metric in enumerate(history_metrics):
    plt.subplot(1, len(history_metrics), i + 1)
    for key, split_label in ((metric, 'train'), (f'val_{metric}', 'validation')):
        values = history.history.get(key)
        if values:  # a metric may be absent, e.g. if renamed by Keras
            plt.plot(values, label=f'{split_label} {metric}')
    plt.xlabel('Epochs')
    plt.ylabel(metric)
    plt.title(f'Training and Validation {metric.upper()}')
    plt.legend()
plt.tight_layout()
plt.savefig('training_history_grid.png')
plt.show()

# Reload the best checkpoint (lowest val_loss) so evaluation does not depend
# on whether/how EarlyStopping restored weights at the end of training.
model.load_weights('best_colorization_model.keras')

# Select a few validation samples for grid visualization. Seeded for
# reproducibility; capped so a tiny validation set does not raise.
num_samples = min(3, len(X_val))  # Number of rows
rng = np.random.default_rng(42)
sample_indices = rng.choice(len(X_val), num_samples, replace=False)

# One batched predict call instead of one call per sample.
predicted_batch = np.clip(model.predict(X_val[sample_indices]), 0, 1)

plt.figure(figsize=(18, 6 * num_samples))
for i, (idx, predicted_color) in enumerate(zip(sample_indices, predicted_batch)):
    grayscale_image = X_val[idx].squeeze()
    ground_truth = y_val[idx]
    category_name = category_val[idx]

    # Compute metrics
    psnr_val = psnr(ground_truth, predicted_color, data_range=1.0)
    ssim_val = ssim(ground_truth, predicted_color, channel_axis=-1, data_range=1.0)

    print(f"Sample {i+1} - Category: {category_name}, PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")

    # Columns: grayscale input | model prediction | ground truth
    panels = [
        (grayscale_image, f"Grayscale\nCategory: {category_name}", 'gray'),
        (predicted_color, f"Colorized\nPSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}", None),
        (ground_truth, "Ground Truth", None),
    ]
    for col, (image, title, cmap) in enumerate(panels):
        plt.subplot(num_samples, 3, i * 3 + col + 1)
        plt.title(title)
        plt.imshow(image, cmap=cmap)
        plt.axis('off')

plt.tight_layout()
plt.savefig('colorization_grid.png')
plt.show()

# Zip the best checkpoint (already written by ModelCheckpoint during training)
# into best_colorization_model.zip for easy export.
shutil.make_archive(
    base_name='best_colorization_model',
    format='zip',
    root_dir='.',
    base_dir='best_colorization_model.keras',
)


Grayscale Folder Contents: ['ChilliGreen', 'Broccoli', 'Orange', 'Tomato', 'Brinjal', 'Pomegranate', 'Plum', 'Apple', 'Carrot', 'Pear', 'Strawberry', 'CapsicumGreen', 'LadyFinger', 'Lemon', 'Cucumber', 'Peach', 'Corn', 'Banana', 'Cherry', 'Potato']
Color Folder Contents: ['ChilliGreen', 'Broccoli', 'Orange', 'Tomato', 'Brinjal', 'Pomegranate', 'Plum', 'Apple', 'Carrot', 'Pear', 'Strawberry', 'CapsicumGreen', 'LadyFinger', 'Lemon', 'Cucumber', 'Peach', 'Corn', 'Banana', 'Cherry', 'Potato']
