# In [None]:
import os
from PIL import Image
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import img_to_array
import numpy as np
import matplotlib.pyplot as plt

# Report how many physical GPU devices TensorFlow can see before any heavy work starts.
available_gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(available_gpus))

# Step 1: Load and Preprocess Image from a Path
def load_and_resize_image(file_path, desired_size=(500, 500)):
    """Open an image file, convert it to RGB and resize it with LANCZOS filtering.

    Args:
        file_path: Path of the image file to open.
        desired_size: Target size passed to PIL's ``Image.resize``, i.e. a
            ``(width, height)`` pair.

    Returns:
        A new RGB ``PIL.Image.Image`` of ``desired_size``, or ``None`` if the file
        could not be opened or processed (the error is printed).
    """
    try:
        # Use the opened image as a context manager so its file handle is closed.
        # convert() returns a new, fully loaded image, so it stays valid after
        # the file is closed.
        with Image.open(file_path) as img:
            rgb_image = img.convert("RGB")
        return rgb_image.resize(desired_size, Image.Resampling.LANCZOS)
    except Exception as e:
        print(f"Error processing {file_path}: {e}")
        return None

# Step 2: Convert Selected Image to Tensor
def preprocess_image(image):
    """Turn a PIL image into a single-image batch ready for VGG19-style models.

    The image is converted to a float array, given a leading batch axis of size 1,
    and passed through ``tf.keras.applications.vgg19.preprocess_input``.

    Args:
        image: A PIL image (RGB, as produced by ``load_and_resize_image``).

    Returns:
        The preprocessed batch tensor.
    """
    pixels = img_to_array(image)
    batch = tf.expand_dims(pixels, axis=0)
    return tf.keras.applications.vgg19.preprocess_input(batch)

# Step 3: Define Gram Matrix
def gram_matrix(features):
    """Compute a normalized Gram matrix for every item in a batch of feature maps.

    Args:
        features: 4-D tensor laid out as (batch, height, width, channels).

    Returns:
        Tensor of shape (batch, channels, channels). Each entry is a
        channel-by-channel inner product over all spatial positions, divided by
        height * width * channels.
    """
    # Index the dynamic shape instead of tuple-unpacking it. Unpacking a symbolic
    # tensor fails inside tf.function; indexing works both eagerly and in graphs.
    shape = tf.shape(features)
    batch, height, width, channels = shape[0], shape[1], shape[2], shape[3]
    flat = tf.reshape(features, (batch, height * width, channels))
    gram = tf.matmul(flat, flat, transpose_a=True)
    # Cast the normalizer to the Gram matrix's own dtype. This gives the same
    # result for float32 and also works for float16/float64 features.
    normalizer = tf.cast(height * width * channels, gram.dtype)
    return gram / normalizer

# Step 4: Define Loss Functions
def content_loss(content_features, generated_features):
    """Return the mean squared error between two feature tensors.

    Args:
        content_features: Features extracted from the content image.
        generated_features: Features of the same shape from the image being optimized.

    Returns:
        Scalar tensor: mean of the element-wise squared difference.
    """
    residual = content_features - generated_features
    return tf.reduce_mean(tf.square(residual))

def style_loss(style_features, generated_features):
    """Return the mean squared difference between the Gram matrices of two feature tensors.

    Args:
        style_features: Features extracted from the style image.
        generated_features: Features of the same shape from the image being optimized.

    Returns:
        Scalar tensor comparing ``gram_matrix`` of both inputs.
    """
    target_gram, current_gram = (
        gram_matrix(feats) for feats in (style_features, generated_features)
    )
    return tf.reduce_mean(tf.square(target_gram - current_gram))

# Step 5: Build Residual Blocks
def residual_block(x, filters):
    """Apply two 3x3 convolutions and add a shortcut connection (Keras functional API).

    Args:
        x: Input Keras tensor (NHWC).
        filters: Number of filters in both convolutions and in the block's output.

    Returns:
        ``conv(relu(conv(x))) + shortcut``. The shortcut is ``x`` itself when it
        already has ``filters`` channels; otherwise it is a 1x1 projection of ``x``.
    """
    shortcut = x
    # An identity shortcut can only be added when the channel counts match.
    # Otherwise project the input so layers.add gets compatible shapes instead
    # of failing.
    if x.shape[-1] != filters:
        shortcut = layers.Conv2D(filters, (1, 1), padding='same')(shortcut)
    y = layers.Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    y = layers.Conv2D(filters, (3, 3), activation=None, padding='same')(y)
    return layers.add([y, shortcut])

# Step 6: Build the CNN Model
def CNN_model(input_shape=(500, 500, 3)):
    """Build the encoder / residual / decoder network used as a feature extractor.

    The network has three stages:
    - Downsampling: four stride-2 convolutions (32, 64, 128 and 256 filters).
    - Three 256-filter residual blocks.
    - Upsampling: four stride-2 transposed convolutions and a final 3-channel
      sigmoid convolution.

    Nothing in this function trains the weights; they keep their Keras
    initialization.

    Args:
        input_shape: (height, width, channels) of the input images. Defaults to
            (500, 500, 3), the size the rest of this script produces.

    Returns:
        An uncompiled ``tf.keras.Model`` with six outputs, in this order:
        ``[x1, x2, x3, x5, x6, outputs]``. Callers index into this list
        (e.g. ``[2]`` is x3, ``[4]`` is x6).
    """
    inputs = tf.keras.Input(shape=input_shape)

    # Downsampling: each stride-2 'same' conv maps a spatial size n to ceil(n / 2).
    x = layers.Conv2D(32, (3, 3), strides=2, activation='relu', padding='same')(inputs)
    x1 = layers.Conv2D(64, (3, 3), strides=2, activation='relu', padding='same')(x)
    x2 = layers.Conv2D(128, (3, 3), strides=2, activation='relu', padding='same')(x1)
    x3 = layers.Conv2D(256, (3, 3), strides=2, activation='relu', padding='same')(x2)

    # Residual Blocks
    x4 = residual_block(x3, 256)
    x5 = residual_block(x4, 256)
    x6 = residual_block(x5, 256)

    # Upsampling: each stride-2 'same' transposed conv doubles the spatial size.
    x7 = layers.Conv2DTranspose(256, (3, 3), strides=2, activation='relu', padding='same')(x6)
    x8 = layers.Conv2DTranspose(128, (3, 3), strides=2, activation='relu', padding='same')(x7)
    x9 = layers.Conv2DTranspose(64, (3, 3), strides=2, activation='relu', padding='same')(x8)
    x10 = layers.Conv2DTranspose(32, (3, 3), strides=2, activation='relu', padding='same')(x9)

    # Final output layer
    outputs = layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x10)

    # Four halvings followed by four doublings turn a size n into ceil(n / 16) * 16,
    # e.g. 500 -> 512. Crop the surplus, split across both edges, so the image
    # output has the same spatial size as the input.
    height, width = input_shape[0], input_shape[1]
    if height is not None and width is not None:
        extra_h = -(-height // 16) * 16 - height
        extra_w = -(-width // 16) * 16 - width
        if extra_h or extra_w:
            outputs = layers.Cropping2D(
                cropping=((extra_h // 2, extra_h - extra_h // 2),
                          (extra_w // 2, extra_w - extra_w // 2))
            )(outputs)

    # Create the model with selected outputs
    model = tf.keras.Model(inputs, [x1, x2, x3, x5, x6, outputs])
    return model

# Step 7: Compute Total Loss with Selected Layers
def compute_total_loss(model, style_image, content_image, generated_image, content_weight=1.0, style_weight=100.0, loss_divisor=500.0):
    """Return the weighted content + style loss of ``generated_image``.

    Args:
        model: Feature extractor whose outputs are indexed like CNN_model's
            ``[x1, x2, x3, x5, x6, outputs]``.
        style_image: Batched, preprocessed style image.
        content_image: Batched, preprocessed content image.
        generated_image: Image being optimized, in the same layout as the inputs.
        content_weight: Multiplier for the content term.
        style_weight: Multiplier for the style term.
        loss_divisor: Constant the weighted sum is divided by (500 by default).
            It only rescales the loss value.

    Returns:
        Scalar tensor:
        ``(content_weight * content + style_weight * style) / loss_divisor``.
        - content is the summed feature MSE on x3 and x6.
        - style is the summed Gram-matrix MSE on x1, x2 and x5.
    """
    # Indices into the model's output list.
    content_layers = (2, 4)   # x3, x6
    style_layers = (0, 1, 3)  # x1, x2, x5

    # Run one forward pass per image and reuse it for every selected layer.
    # This assumes the model is deterministic (CNN_model has no dropout or
    # batch-norm). Repeated calls would only redo identical work and, for the
    # generated image, record extra copies of the graph on the gradient tape.
    content_outputs = model(content_image)
    style_outputs = model(style_image)
    generated_outputs = model(generated_image)

    content_loss_value = sum(
        content_loss(content_outputs[i], generated_outputs[i]) for i in content_layers
    )
    style_loss_value = sum(
        style_loss(style_outputs[i], generated_outputs[i]) for i in style_layers
    )

    total_loss_value = content_weight * content_loss_value + style_weight * style_loss_value
    return total_loss_value / loss_divisor

# Step 8: Training Function
def train_style_transfer(CNN_model, content_tensor, style_tensor, num_iterations=1000, learning_rate=0.001):
    """Optimize an image's pixels so its features match the content and style targets.

    The image starts as a copy of ``content_tensor`` and is updated with Adam for
    ``num_iterations`` steps, minimizing ``compute_total_loss``. Only the image is
    a variable here; the network built by ``CNN_model`` keeps its initial weights.

    NOTE(review): Adam's per-step update is roughly bounded by ``learning_rate``.
    VGG19-preprocessed pixels span a range of a few hundred units. With the
    defaults, each pixel can therefore move only about
    ``num_iterations * learning_rate`` (about 1 unit) in total. The result will
    likely look almost identical to the content image unless the learning rate
    is raised.

    Args:
        CNN_model: Zero-argument factory returning the feature-extractor model.
        content_tensor: Batched, preprocessed content image.
        style_tensor: Batched, preprocessed style image.
        num_iterations: Number of optimization steps.
        learning_rate: Adam learning rate.

    Returns:
        The optimized image as a ``tf.Variable``, or ``None`` if an exception was
        raised during training (the message and traceback are printed).
    """
    import traceback

    generated_image = tf.Variable(content_tensor, trainable=True)
    optimizer = tf.optimizers.Adam(learning_rate)

    model = CNN_model()

    try:
        for i in range(num_iterations):
            with tf.GradientTape() as tape:
                total_loss = compute_total_loss(model, style_tensor, content_tensor, generated_image)

            gradients = tape.gradient(total_loss, [generated_image])
            optimizer.apply_gradients(zip(gradients, [generated_image]))

            if i % 100 == 0:
                print(f"Iteration {i}/{num_iterations}, Total Loss: {total_loss.numpy()}")

        return generated_image

    except Exception as e:
        # Keep the best-effort contract (return None so the caller can report the
        # failure), but print the full traceback as well as the message so the
        # failing line is visible.
        print(f"An error occurred during training: {e}")
        traceback.print_exc()
        return None

# Main Code Execution
if __name__ == "__main__":
    # Must match the input shape CNN_model is built with (500 x 500).
    desired_size = (500, 500)

    # Direct file paths for content and style images
    content_image_path = "/content/Data/Romanticism_style.jpg"
    style_image_path = "/content/Data/Renaissance_style.jpg"

    # Load and preprocess images
    content_image = load_and_resize_image(content_image_path, desired_size)
    style_image = load_and_resize_image(style_image_path, desired_size)

    if content_image is None or style_image is None:
        raise ValueError("Content or Style image not found!")

    content_tensor = preprocess_image(content_image)
    style_tensor = preprocess_image(style_image)

    # Train the model and display the output image
    final_image = train_style_transfer(CNN_model, content_tensor, style_tensor)

    if final_image is not None:
        # The optimized pixels are still in vgg19.preprocess_input space: BGR
        # channel order with the ImageNet channel means subtracted. Add the
        # means back and flip to RGB before clipping to the displayable 0-255
        # range.
        imagenet_bgr_mean = np.array([103.939, 116.779, 123.68], dtype=np.float32)
        pixels = final_image[0].numpy() + imagenet_bgr_mean
        pixels = pixels[..., ::-1]
        final_image = np.clip(pixels, 0, 255).astype(np.uint8)

        # Display the image using matplotlib
        plt.imshow(final_image)
        plt.axis('off')  # Hide axis
        plt.show()
    else:
        print("Failed to generate the stylized image.")