# Extracting dataset and loading it in the environment

In [None]:
import os
import glob
import tarfile
import numpy as np
import cv2
import tensorflow as tf
from sklearn.model_selection import train_test_split
from google.colab import drive

# Mount Google Drive so the dataset archives are reachable under /content/drive
drive.mount('/content/drive')

# Paths to tar.gz files in Google Drive
mountain_tar_path = '/content/drive/MyDrive/seminar_proj_dataset/mountain.tar.gz'
forest_tar_path = '/content/drive/MyDrive/seminar_proj_dataset/forest_path.tar.gz'

# Local directories the archives are extracted into
mountain_extract_dir = '/content/mountain'
forest_extract_dir = '/content/forest_path'

for tar_path, extract_dir in [(mountain_tar_path, mountain_extract_dir), (forest_tar_path, forest_extract_dir)]:
    # Skip re-extraction on notebook re-runs. Only the directory's existence is
    # checked, so an interrupted earlier extraction is NOT retried; delete the
    # directory to force a fresh extraction.
    if not os.path.exists(extract_dir):
        with tarfile.open(tar_path, 'r:gz') as tar:
            # The 'data' extraction filter (available when tarfile exposes
            # data_filter) rejects absolute paths, '..' components and special
            # files, and avoids the Python 3.12+ DeprecationWarning about the
            # unspecified default filter.
            if hasattr(tarfile, 'data_filter'):
                tar.extractall(path=extract_dir, filter='data')
            else:
                tar.extractall(path=extract_dir)

# Assuming tar files contain folders 'mountain' and 'forest_path' respectively
mountain_base_dir = os.path.join(mountain_extract_dir, 'mountain')
forest_base_dir = os.path.join(forest_extract_dir, 'forest_path')

# Collect .jpg paths (recursively, case-sensitive extension match) from both folders
mountain_image_paths = glob.glob(os.path.join(mountain_base_dir, '**', '*.jpg'), recursive=True)
forest_image_paths = glob.glob(os.path.join(forest_base_dir, '**', '*.jpg'), recursive=True)

# Combine both lists (kept in glob order; no shuffling happens here)
all_image_paths = mountain_image_paths + forest_image_paths
print(f"Total mountain images: {len(mountain_image_paths)}")
print(f"Total forest images: {len(forest_image_paths)}")
print(f"Total combined images: {len(all_image_paths)}")

# Loading and preprocessing the data for training, and visualizing it as well

In [None]:
import os
import glob
import numpy as np
import cv2
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# Load all mountain and forest images (combined logic from earlier loader)
# NOTE(review): these directories differ from the extraction cell above, which
# extracts into /content/mountain/mountain and /content/forest_path/forest_path.
# Confirm /content/dataset_combined is populated by a step not shown here.
mountain_base_dir = '/content/dataset_combined/mountain'
forest_base_dir = '/content/dataset_combined/forest_path'

mountain_image_paths = glob.glob(os.path.join(mountain_base_dir, '**', '*.jpg'), recursive=True)
forest_image_paths = glob.glob(os.path.join(forest_base_dir, '**', '*.jpg'), recursive=True)

# Combine both sets; sorting fixes the order so the seeded split below is reproducible
image_paths = sorted(mountain_image_paths + forest_image_paths)
print(f"Total images found: {len(image_paths)}")

# Fail fast with the searched locations instead of a less specific error
# surfacing later from train_test_split on an empty list.
if not image_paths:
    raise FileNotFoundError(
        f"No .jpg images found under {mountain_base_dir!r} or {forest_base_dir!r}"
    )

def preprocess_image(image_path, size=(128, 128)):
    """Load an image and split it into normalized Lab channels.

    The image is read with OpenCV (BGR), converted to RGB, resized to ``size``
    and converted to 8-bit Lab.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.
        size: Target ``(width, height)`` passed to ``cv2.resize``
            (default ``(128, 128)``).

    Returns:
        Tuple ``(L_channel, ab_channels)`` of float32 arrays shaped
        ``(height, width, 1)`` and ``(height, width, 2)``. ``L_channel`` is the
        8-bit L value divided by 255 (range [0, 1]); ``ab_channels`` is
        ``(ab - 128) / 127`` (range about [-1.008, 1]).

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    img_bgr = cv2.imread(image_path)
    if img_bgr is None:
        # cv2.imread signals failure by returning None; without this check the
        # next call fails with an opaque cv2.error that does not name the file.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_rgb = cv2.resize(img_rgb, size)

    # 8-bit Lab: L is stored as L*·255/100 and a/b are offset by +128.
    img_lab = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2Lab)

    L_channel = img_lab[:, :, 0:1].astype(np.float32) / 255.0
    ab_channels = (img_lab[:, :, 1:].astype(np.float32) - 128.0) / 127.0

    return L_channel, ab_channels

def view_preprocessed_image(image_path):
    """Show an image next to its reconstruction from the preprocessed Lab channels.

    Runs ``preprocess_image`` on ``image_path``, inverts its normalization to
    float Lab, converts back to RGB and plots it beside the resized original,
    followed by the preprocessed L channel as a grayscale image. Channel
    min/max values are printed as a sanity check.

    Args:
        image_path: Path to an image file readable by ``cv2.imread``.
    """
    L_channel, ab_channels = preprocess_image(image_path)

    # Denormalize to OpenCV's float Lab ranges (L in [0, 100], a/b about
    # [-128, 127]). preprocess_image divides ab by 127, so multiplying by 127
    # (not 128) is the exact inverse.
    L_denorm = L_channel * 100.0
    ab_denorm = ab_channels * 127.0

    # Keep the Lab image float32: cv2.cvtColor accepts 8-bit, 16-bit or float32
    # input, and a uint8 cast would truncate the negative a/b values.
    lab_image = np.concatenate((L_denorm, ab_denorm), axis=-1).astype(np.float32)

    # Float Lab -> float RGB; clamp out-of-gamut excursions so imshow gets [0, 1].
    rgb_image = np.clip(cv2.cvtColor(lab_image, cv2.COLOR_Lab2RGB), 0.0, 1.0)

    print("L_channel min/max:", L_channel.min(), L_channel.max())
    print("ab_channel min/max:", ab_channels.min(), ab_channels.max())
    print("L_denorm min/max:", L_denorm.min(), L_denorm.max())
    print("ab_denorm min/max:", ab_denorm.min(), ab_denorm.max())

    # Original image resized to the same 128x128 used by preprocess_image
    original_img = cv2.imread(image_path)
    original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
    original_img = cv2.resize(original_img, (128, 128))

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(original_img)
    plt.title('Original Image')
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(rgb_image)
    plt.title('Preprocessed Reconstructed Image')
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    plt.figure(figsize=(5, 5))
    plt.imshow(L_channel.squeeze(), cmap='gray')
    plt.title('Preprocessed L Channel (Grayscale)')
    plt.axis('off')
    plt.show()


def prepare_data(image_paths):
    """Split ``image_paths`` roughly 90/5/5 into train/val/test and preprocess each split.

    Args:
        image_paths: Sequence of image file paths.

    Returns:
        ``((L_train, ab_train), (L_val, ab_val), (L_test, ab_test),
        train_paths, val_paths, test_paths)``; each ``L_*`` / ``ab_*`` array
        stacks the per-image outputs of ``preprocess_image`` for that split.
    """
    # Carve off 5% for test first, then 0.05263 of the remainder (~5% of the
    # total) for validation; both splits use the same fixed seed.
    train_val_paths, test_paths = train_test_split(image_paths, test_size=0.05, random_state=42)
    train_paths, val_paths = train_test_split(train_val_paths, test_size=0.05263, random_state=42)

    print(f"Train: {len(train_paths)}, Val: {len(val_paths)}, Test: {len(test_paths)}")

    def to_arrays(split_paths):
        # Preprocess every path in order, then stack the L and ab parts separately.
        pairs = [preprocess_image(p) for p in split_paths]
        L_stack = np.array([L_part for L_part, _ in pairs])
        ab_stack = np.array([ab_part for _, ab_part in pairs])
        return L_stack, ab_stack

    train_arrays = to_arrays(train_paths)
    val_arrays = to_arrays(val_paths)
    test_arrays = to_arrays(test_paths)

    print("Train L shape:", train_arrays[0].shape, "| ab shape:", train_arrays[1].shape)
    print("Val   L shape:", val_arrays[0].shape, "| ab shape:", val_arrays[1].shape)
    print("Test  L shape:", test_arrays[0].shape, "| ab shape:", test_arrays[1].shape)

    return train_arrays, val_arrays, test_arrays, train_paths, val_paths, test_paths

# Prepare the train/val/test arrays and visualize one preprocessed training image
(L_train, ab_train), (L_val, ab_val), (L_test, ab_test), train_paths, val_paths, test_paths = prepare_data(image_paths)

# Index 700 is an arbitrary example; clamp it so smaller datasets don't raise IndexError.
view_preprocessed_image(train_paths[min(700, len(train_paths) - 1)])


# Building the model architecture

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model

def _unet_conv_block(x, filters):
    """Apply Conv(ReLU) -> Conv -> BatchNorm -> ReLU, all 3x3 with 'same' padding."""
    x = layers.Conv2D(filters, 3, padding='same', activation='relu')(x)
    x = layers.Conv2D(filters, 3, padding='same')(x)
    x = layers.BatchNormalization()(x)
    return layers.Activation('relu')(x)


def build_unet_model(input_shape=(128, 128, 1), dropout_rate=0.3):
    """
    Build a lightweight U-Net model with dropout for image colorization.

    Encoder: four conv blocks (32, 64, 128, 256 filters), each followed by 2x2
    average pooling, then a 256-filter bottleneck with dropout. Decoder: four
    2x upsamplings, each concatenated with the matching encoder block output
    and followed by a conv block (128, 64, 32, 32 filters); dropout follows all
    but the last decoder block.

    Args:
        input_shape: Shape of input L channel (default: (128, 128, 1)). Height
            and width must be divisible by 16 (four 2x poolings); ``None``
            dimensions are accepted and not checked.
        dropout_rate: Dropout rate after the bottleneck and the first three
            decoder blocks (default 0.3).
    Returns:
        Keras Model predicting ab channels of shape (height, width, 2) with a
        tanh output, i.e. values in [-1, 1].
    Raises:
        ValueError: If a known spatial dimension is not divisible by 16.
    """
    # Pool/upsample symmetry: any other size makes the skip concatenations
    # fail later with a less obvious shape-mismatch error.
    for dim in input_shape[:2]:
        if dim is not None and dim % 16 != 0:
            raise ValueError(
                f"input_shape height/width must be divisible by 16, got {input_shape}"
            )

    inputs = layers.Input(shape=input_shape)

    # Encoder: remember each block's pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for filters in (32, 64, 128, 256):
        x = _unet_conv_block(x, filters)
        skips.append(x)
        x = layers.AveragePooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = _unet_conv_block(x, 256)
    x = layers.Dropout(dropout_rate)(x)

    # Decoder: deepest skip first; the final full-resolution block has no dropout.
    decoder_specs = ((128, True), (64, True), (32, True), (32, False))
    for (filters, use_dropout), skip in zip(decoder_specs, reversed(skips)):
        x = layers.UpSampling2D(size=(2, 2))(x)
        x = layers.Concatenate()([x, skip])
        x = _unet_conv_block(x, filters)
        if use_dropout:
            x = layers.Dropout(dropout_rate)(x)

    # Output layer: tanh bounds predictions to [-1, 1], the normalized ab range.
    outputs = layers.Conv2D(2, 3, padding='same', activation='tanh')(x)

    return Model(inputs, outputs, name='Lightweight_U-Net_Colorization')

# Instantiate the U-Net for 128x128 single-channel (L) input and print its layer summary
model = build_unet_model((128, 128, 1))
model.summary()

# Defining custom metrics (SSIM & PSNR) and training the model


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model, callbacks
import tensorflow.keras.backend as K

# Define custom metrics (SSIM and PSNR)
def ssim_metric(y_true, y_pred):
    """Return the batch-mean SSIM between true and predicted ab maps.

    Both tensors are shifted from the [-1, 1] normalized range to [0, 1] via
    ``(x + 1) / 2`` so that ``max_val=1.0`` is the dynamic range passed to
    ``tf.image.ssim``; the per-image SSIM values are then averaged.
    """
    def to_unit_range(t):
        return (t + 1) / 2

    per_image_ssim = tf.image.ssim(to_unit_range(y_true), to_unit_range(y_pred), max_val=1.0)
    return tf.reduce_mean(per_image_ssim)

def psnr_metric(y_true, y_pred):
    """Return the batch-mean PSNR (dB) between true and predicted ab maps.

    ``max_val=2.0`` is the width of the [-1, 1] range the values are
    normalized to, so no rescaling happens before ``tf.image.psnr``.
    """
    per_image_psnr = tf.image.psnr(y_true, y_pred, max_val=2.0)
    return tf.reduce_mean(per_image_psnr)

# Compile the model (assumes 'model' is already defined)
# The loss is pixel-wise MSE on the normalized ab channels; 'mse' is also listed
# as a metric, so it is reported alongside the loss, plus the custom SSIM/PSNR.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='mse',
    metrics=['mse', ssim_metric, psnr_metric]
)

# Define callbacks
# Stop after 8 epochs without val_loss improvement and roll back to the best
# weights seen, so `model` holds the best-validation weights after fit().
early_stopping = callbacks.EarlyStopping(
    monitor='val_loss',
    patience=8,
    restore_best_weights=True,
    verbose=1
)

# Save the full model to Google Drive whenever val_loss improves.
# NOTE(review): reloading this file with the compiled metrics presumably needs
# custom_objects={'ssim_metric': ..., 'psnr_metric': ...} (or compile=False) — confirm.
checkpoint = callbacks.ModelCheckpoint(
    filepath='/content/drive/MyDrive/mountains_forest_u_net_best.h5',
    monitor='val_loss',
    save_best_only=True,
    verbose=1
)

# Train the model using in-memory NumPy arrays
# Input is the L channel, target is the ab channels (as produced by prepare_data).
history = model.fit(
    x=L_train,
    y=ab_train,
    validation_data=(L_val, ab_val),
    batch_size=64,
    epochs=30,
    callbacks=[early_stopping, checkpoint],
    verbose=1
)

# Analyzing the results and visualizing the output produced


In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Convert normalized Lab to RGB
def lab_to_rgb(L, ab):
    """Convert normalized L / ab channels back to an RGB image in [0, 1].

    Inverts preprocess_image's normalization (L = L8 / 255,
    ab = (ab8 - 128) / 127). L is scaled to [0, 100] and ab to about
    [-128, 127], the float Lab ranges cv2.cvtColor expects.

    Args:
        L: Array of shape (H, W, 1) with values in [0, 1].
        ab: Array of shape (H, W, 2) with values in about [-1, 1]
            (ground truth or tanh model output).

    Returns:
        float32 RGB array of shape (H, W, 3), clipped to [0, 1] for display.
    """
    # 127 (not 128) is the exact inverse of preprocess_image's "/ 127.0".
    lab = np.concatenate((L * 100.0, ab * 127.0), axis=-1).astype(np.float32)
    rgb = cv2.cvtColor(lab, cv2.COLOR_Lab2RGB)
    # Out-of-gamut Lab values can map slightly outside [0, 1]; clamp for imshow.
    return np.clip(rgb, 0, 1)

# Pick the first 10 samples from the test set (fewer if the test set is smaller).
# Slicing must be [:num_samples]; [num_samples:] would select everything *after* them.
num_samples = min(10, len(L_test))
L_sample = L_test[:num_samples]
ab_true_sample = ab_test[:num_samples]

# Predict ab channels
ab_pred_sample = model.predict(L_sample, verbose=1)

# Convert Lab to RGB for display
rgb_original = [lab_to_rgb(L, ab) for L, ab in zip(L_sample, ab_true_sample)]
rgb_colorized = [lab_to_rgb(L, ab) for L, ab in zip(L_sample, ab_pred_sample)]
grayscale = [L.squeeze() for L in L_sample]

# Per-image metrics on the normalized ab channels (same conventions as the
# training metrics: SSIM on [0, 1]-shifted values, PSNR with max_val=2.0).
# These cover only the sampled images, not the whole test set.
mse_scores = [np.mean((true - pred) ** 2) for true, pred in zip(ab_true_sample, ab_pred_sample)]
ssim_scores = [tf.image.ssim((true + 1) / 2, (pred + 1) / 2, max_val=1.0).numpy() for true, pred in zip(ab_true_sample, ab_pred_sample)]
psnr_scores = [tf.image.psnr(true, pred, max_val=2.0).numpy() for true, pred in zip(ab_true_sample, ab_pred_sample)]

print(f"\nAverage MSE: {np.mean(mse_scores):.4f}")
print(f"Average SSIM: {np.mean(ssim_scores):.4f}")
print(f"Average PSNR: {np.mean(psnr_scores):.2f} dB")

# Plot results: one row per sample -> original | grayscale input | colorized
plt.figure(figsize=(15, 30))
for i in range(num_samples):
    plt.subplot(num_samples, 3, i * 3 + 1)
    plt.imshow(rgb_original[i])
    plt.title('Original')
    plt.axis('off')

    plt.subplot(num_samples, 3, i * 3 + 2)
    plt.imshow(grayscale[i], cmap='gray')
    plt.title('Grayscale')
    plt.axis('off')

    plt.subplot(num_samples, 3, i * 3 + 3)
    plt.imshow(rgb_colorized[i])
    plt.title('Colorized')
    plt.axis('off')

plt.tight_layout()
plt.show()