In [1]:
import tkinter as tk
from tkinter import filedialog
from PIL import ImageTk, Image
import customtkinter as ctk
from customtkinter import CTkImage
import tensorflow as tf
import numpy as np

# Turn a float tensor on the [0, 1] scale into a displayable PIL image
def tensor_to_image(tensor):
    """Build an 8-bit PIL image from a float tensor.

    NaN entries are replaced with 0, values are clamped to [0, 1] and
    multiplied by 255 before the cast to uint8. When the resulting array
    has more than three dimensions, only the first entry along the leading
    axis is kept.
    """
    # Replace NaNs with zeros before clipping
    without_nans = tf.where(tf.math.is_nan(tensor), tf.zeros_like(tensor), tensor)

    # Clamp to [0, 1], then stretch to the 8-bit range
    scaled = tf.clip_by_value(without_nans, 0.0, 1.0) * 255.0

    # Materialise as a uint8 NumPy array
    pixels = np.array(scaled, dtype=np.uint8)

    # Drop the leading (batch) axis when one is present
    if pixels.ndim > 3:
        pixels = pixels[0]

    return Image.fromarray(pixels)




# Function to load an image and scale its longest side to max_dim pixels
def load_img(path_to_img, max_dim=512):
    '''Loads an image file as a batched float32 RGB tensor.

    Args:
        path_to_img: path of the image file to read.
        max_dim: target size, in pixels, of the longest side after resizing.

    Returns:
        A float32 tensor with a leading batch axis of size 1, values in the
        [0, 1] range and exactly three channels.
    '''
    raw_bytes = tf.io.read_file(path_to_img)

    # decode_image detects the file format itself. channels=3 forces RGB so
    # grayscale and RGBA files still match the 3-channel input VGG19 expects;
    # expand_animations=False keeps the result 3-D even for GIF input.
    image = tf.image.decode_image(raw_bytes, channels=3, expand_animations=False)
    image = tf.image.convert_image_dtype(image, tf.float32)  # float32 in [0, 1]

    # Height and width as floats so they can be multiplied by the scale factor
    shape = tf.cast(tf.shape(image)[:-1], tf.float32)
    scale = max_dim / tf.reduce_max(shape)

    new_shape = tf.cast(shape * scale, tf.int32)  # back to integer pixel sizes

    image = tf.image.resize(image, new_shape)
    return image[tf.newaxis, :]  # add the batch dimension



def clip_image_values(image, min_value=0.0, max_value=255.0):
    """Return ``image`` with every value limited to [min_value, max_value]."""
    clipped = tf.clip_by_value(
        image,
        clip_value_min=min_value,
        clip_value_max=max_value,
    )
    return clipped


# Preprocessing for the VGG19 feature extractor
def preprocess_image(image):
    '''Casts the image to float32, multiplies it by 255 and applies the VGG19 preprocessing'''
    as_float = tf.cast(image, tf.float32)
    rescaled = as_float * 255.0
    return tf.keras.applications.vgg19.preprocess_input(rescaled)


# The complete VGG19 classifier is deliberately not instantiated here: the
# feature extractor built by vgg_model() (include_top=False) is all this
# script needs, and constructing the full model would fetch the much larger
# weights that include the dense head only to throw the model away.

# VGG19 layers whose activations are compared for style (first conv of each block)
style_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']

# VGG19 layer whose activations are compared for content
content_layers = ['block5_conv2']

# Style layers come first, then content layers; the feature-extraction
# helpers slice the model outputs at NUM_STYLE_LAYERS based on this order.
output_layers = style_layers + content_layers
NUM_STYLE_LAYERS = len(style_layers)
NUM_CONTENT_LAYERS = len(content_layers)


def vgg_model(layer_names):
    '''Builds a frozen VGG19 (ImageNet weights, no classifier head) that outputs the named layers' activations'''
    base = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    base.trainable = False

    # Collect the output tensor of each requested layer, in the given order
    selected_outputs = []
    for layer_name in layer_names:
        selected_outputs.append(base.get_layer(layer_name).output)

    return tf.keras.Model([base.input], selected_outputs)

# Module-level feature extractor, built once when this cell runs; its outputs
# follow the order of output_layers (style layers first, then content layers).
vgg = vgg_model(output_layers)

def get_style_loss(features, targets, weight=1.0):
    """Return the weighted mean squared difference between two tensors.

    Args:
        features: tensor computed for the generated image (the callers in
            this file pass Gram matrices).
        targets: tensor of the same shape computed for the style image.
        weight: factor the mean squared error is multiplied by.
    """
    squared_error = tf.square(features - targets)
    return weight * tf.reduce_mean(squared_error)


def get_content_loss(features, targets, weight=1.0):
    """Return the weighted sum of squared differences between two tensors.

    Args:
        features: activations computed for the generated image.
        targets: activations of the same shape computed for the content image.
        weight: factor the summed squared error is multiplied by.
    """
    squared_error = tf.square(features - targets)
    return weight * tf.reduce_sum(squared_error)

# Gram matrix calculation to extract style
def gram_matrix(input_tensor):
    """Return the Gram matrix of a 4-D activation tensor, averaged over positions.

    The einsum contracts axes 1 and 2 of ``input_tensor`` with itself, giving
    one (last-axis x last-axis) matrix per entry of axis 0; the result is
    divided by the product of the sizes of axes 1 and 2.
    """
    dims = tf.shape(input_tensor)

    # Number of positions (axis 1 size times axis 2 size), as float32 for the division
    positions = tf.cast(dims[1] * dims[2], tf.float32)

    # Sum, over every position, of the pairwise products along the last axis
    pairwise_sums = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)

    return pairwise_sums / positions

# Function to extract style image features
def get_style_image_features(image, vgg_model):
    '''Extracts style features (Gram matrices) from the given image.

    Args:
        image: image tensor accepted by preprocess_image.
        vgg_model: feature-extraction model to call; its first
            NUM_STYLE_LAYERS outputs are treated as the style layers. When
            None, the module-level ``vgg`` extractor is used. Note that this
            parameter name shadows the vgg_model() builder inside this body.

    Returns:
        A list with one Gram matrix per style layer.
    '''
    # Use the model the caller supplied instead of silently falling back to
    # the global one; None keeps working for callers that relied on that.
    extractor = vgg if vgg_model is None else vgg_model

    outputs = extractor(preprocess_image(image))
    style_outputs = outputs[:NUM_STYLE_LAYERS]
    return [gram_matrix(style_output) for style_output in style_outputs]

# Function to extract content image features
def get_content_image_features(image, vgg_model):
    '''Extracts content features from the given image.

    Args:
        image: image tensor accepted by preprocess_image.
        vgg_model: feature-extraction model to call; the outputs after the
            first NUM_STYLE_LAYERS are treated as the content layers. When
            None, the module-level ``vgg`` extractor is used. Note that this
            parameter name shadows the vgg_model() builder inside this body.

    Returns:
        The content-layer outputs of the model, in model output order.
    '''
    # Use the model the caller supplied instead of silently falling back to
    # the global one; None keeps working for callers that relied on that.
    extractor = vgg if vgg_model is None else vgg_model

    outputs = extractor(preprocess_image(image))
    return outputs[NUM_STYLE_LAYERS:]



def get_style_content_loss(style_targets, style_outputs, content_targets, content_outputs, style_weight, content_weight):
    """Combines the style and content loss to compute total loss.

    Args:
        style_targets: per-layer style features of the style image.
        style_outputs: per-layer style features of the generated image.
        content_targets: per-layer content features of the content image.
        content_outputs: per-layer content features of the generated image.
        style_weight: weight of the style term.
        content_weight: weight of the content term.

    Returns:
        style_weight * (summed style loss / NUM_STYLE_LAYERS)
        + content_weight * (summed content loss / NUM_CONTENT_LAYERS).
    """
    # Per-layer losses are summed unweighted; each weight is applied exactly
    # once below. Passing the weight into the per-layer loss as well would
    # square its effect.
    style_loss = tf.add_n([get_style_loss(style_output, style_target)
                           for style_output, style_target in zip(style_outputs, style_targets)])

    content_loss = tf.add_n([get_content_loss(content_output, content_target)
                             for content_output, content_target in zip(content_outputs, content_targets)])

    # Weight each term and average it over its number of layers
    style_loss = style_loss * style_weight / NUM_STYLE_LAYERS
    content_loss = content_loss * content_weight / NUM_CONTENT_LAYERS

    # Total loss is the sum of style and content loss
    return style_loss + content_loss

def calculate_gradients(image, style_targets, content_targets, style_weight, content_weight, var_weight, vgg):
    """Calculates the gradients of the total loss with respect to the image.

    Args:
        image: the generated image; the caller in this file passes a
            tf.Variable, which the gradient tape tracks automatically.
        style_targets: style features of the style image.
        content_targets: content features of the content image.
        style_weight: weight of the style term.
        content_weight: weight of the content term.
        var_weight: weight of the total-variation term; 0 disables it.
        vgg: feature-extraction model forwarded to the feature helpers.

    Returns:
        The gradient of the loss with respect to ``image``.
    """
    with tf.GradientTape() as tape:
        # Features of the current generated image
        style_features = get_style_image_features(image, vgg)
        content_features = get_content_image_features(image, vgg)

        # Combined style and content loss
        loss = get_style_content_loss(style_targets, style_features,
                                      content_targets, content_features,
                                      style_weight, content_weight)

        # Total-variation penalty; skipped entirely when the weight is zero
        # so the loss is unchanged for callers that pass var_weight=0.
        if var_weight:
            loss += var_weight * tf.reduce_sum(tf.image.total_variation(image))

    # The forward pass is fully recorded, so differentiate outside the tape
    gradients = tape.gradient(loss, image)

    return gradients


def update_image_with_style(image, style_targets, content_targets, style_weight,
                            var_weight, content_weight, optimizer, vgg):
    """Runs one optimisation step on the generated image, in place.

    Args:
        image: tf.Variable holding the generated image on the [0, 1] scale
            (the scale produced by load_img and expected by preprocess_image
            and tensor_to_image).
        style_targets: style features of the style image.
        content_targets: content features of the content image.
        style_weight: weight of the style term.
        var_weight: weight of the total-variation term.
        content_weight: weight of the content term.
        optimizer: optimizer instance exposing ``apply_gradients``.
        vgg: feature-extraction model forwarded to calculate_gradients.
    """
    # Note the argument order: calculate_gradients takes content_weight
    # before var_weight, unlike this function's own signature.
    gradients = calculate_gradients(image, style_targets, content_targets,
                                    style_weight, content_weight, var_weight, vgg)

    # Apply the gradients to the image
    optimizer.apply_gradients([(gradients, image)])

    # Keep pixels in the valid [0, 1] range. Clipping to [0, 255] here would
    # let values drift far outside the scale the rest of the pipeline uses.
    image.assign(clip_image_values(image, min_value=0.0, max_value=1.0))



# Function to perform neural style transfer
def fit_style_transfer(content_image, style_image, style_weight=1e-2, content_weight=1e-4,
                       var_weight=0, optimizer='adam', epochs=1, steps_per_epoch=1):
    """Optimises a copy of the content image towards the style image.

    Args:
        content_image: batched image tensor providing the content.
        style_image: batched image tensor providing the style.
        style_weight: weight of the style term of the loss.
        content_weight: weight of the content term of the loss.
        var_weight: weight of the total-variation term of the loss.
        optimizer: a Keras optimizer instance, or an identifier such as
            'adam' that tf.keras.optimizers.get can resolve.
        epochs: number of epochs to run.
        steps_per_epoch: optimisation steps per epoch.

    Returns:
        The tf.Variable holding the generated image.
    """
    # Resolve identifiers such as the default 'adam' into an optimizer object;
    # optimizer instances are returned unchanged. A bare string would fail
    # later on .apply_gradients.
    optimizer = tf.keras.optimizers.get(optimizer)

    # Reuse the module-level extractor (built from the same output_layers)
    # instead of constructing a fresh VGG19 on every call.
    feature_extractor = vgg

    # Targets are computed once: style from the style image, content from the content image
    style_targets = get_style_image_features(style_image, feature_extractor)
    content_targets = get_content_image_features(content_image, feature_extractor)

    # Start from a float32 copy of the content image that the optimizer can update
    generated_image = tf.Variable(tf.cast(content_image, dtype=tf.float32))

    # Training loop
    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs} running...")
        for step in range(steps_per_epoch):
            print(f"Step {step + 1}/{steps_per_epoch}")
            update_image_with_style(generated_image, style_targets, content_targets,
                                    style_weight, var_weight, content_weight,
                                    optimizer, feature_extractor)

        print(f"Epoch {epoch + 1} completed")

    return generated_image






# Function to open file dialog and select image
def select_image(image_type):
    """Lets the user pick an image file and shows a 150x150 preview of it.

    Args:
        image_type: 'content' or 'style'; selects which path variable and
            preview label are updated. Any other value changes nothing.
    """
    filepath = filedialog.askopenfilename()
    if not filepath:
        # Dialog was cancelled: keep the current selection untouched instead
        # of trying to open an empty path.
        return

    if image_type == 'content':
        path_var, preview_label = content_image_path, content_img_label
    elif image_type == 'style':
        path_var, preview_label = style_image_path, style_img_label
    else:
        return

    # Open and resize first so the stored path is only updated for files
    # that could actually be read as images.
    with Image.open(filepath) as source:
        thumbnail = source.resize((150, 150), Image.Resampling.LANCZOS)

    path_var.set(filepath)

    # Keep a reference on the label, then show the preview in place of the
    # "No Image Selected" placeholder text.
    preview_label.imgtk = CTkImage(light_image=thumbnail, size=(150, 150))
    preview_label.configure(image=preview_label.imgtk, text="")


# Function to start the style transfer process
def start_style_transfer():
    """Runs style transfer on the selected images and shows the result.

    Reads the selected content and style paths, optimises the image for
    10 epochs of 10 steps each and displays a 150x150 preview in the result
    label. Does nothing except show a hint when either image is missing.
    """
    content_path = content_image_path.get()
    style_path = style_image_path.get()
    if not content_path or not style_path:
        # Nothing to process yet; avoid handing an empty path to the loader
        result_img_label.configure(text="Select both a content and a style image first")
        return

    content_image = load_img(content_path)
    style_image = load_img(style_path)

    # define style and content weight
    style_weight = 2e-2
    content_weight = 1e-2

    # Images live on the [0, 1] scale and an Adam step moves each pixel by
    # roughly the learning rate, so the rate must be a small fraction of
    # that range; a rate of several units saturates the image in one step.
    adam = tf.optimizers.Adam(
        tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=0.02,
            decay_steps=100,
            decay_rate=0.50
        )
    )

    output_image = fit_style_transfer(style_image=style_image, content_image=content_image,
                                      style_weight=style_weight, content_weight=content_weight,
                                      var_weight=0, optimizer=adam, epochs=10, steps_per_epoch=10)

    # Display the final stylized image
    stylized_image_pil = tensor_to_image(output_image)
    stylized_image_pil = stylized_image_pil.resize((150, 150), Image.Resampling.LANCZOS)

    result_image_ctk = CTkImage(light_image=stylized_image_pil, size=(150, 150))
    # Keep a reference on the label, as the preview labels do, and replace
    # the placeholder text with the image.
    result_img_label.imgtk = result_image_ctk
    result_img_label.configure(image=result_image_ctk, text="")
    
    

# Build the main application window (600x400) and set its title
root = ctk.CTk()
root.geometry("600x400")
root.title("Neural Style Transfer")

# String variables holding the selected content and style image paths;
# select_image() writes them and start_style_transfer() reads them
content_image_path = ctk.StringVar()
style_image_path = ctk.StringVar()

# Title label at the top of the window
title_label = ctk.CTkLabel(root, text="Neural Style Transfer", font=('Arial', 20))
title_label.pack(pady=20)

# Button that opens the file dialog for the content image, followed by the
# label that select_image('content') updates with a preview
content_btn = ctk.CTkButton(root, text="Select Content Image", command=lambda: select_image('content'))
content_btn.pack(pady=10)
content_img_label = ctk.CTkLabel(root, text="No Image Selected")
content_img_label.pack()

# Button that opens the file dialog for the style image, followed by the
# label that select_image('style') updates with a preview
style_btn = ctk.CTkButton(root, text="Select Style Image", command=lambda: select_image('style'))
style_btn.pack(pady=10)
style_img_label = ctk.CTkLabel(root, text="No Image Selected")
style_img_label.pack()

# Button that runs start_style_transfer() when clicked
start_btn = ctk.CTkButton(root, text="Start Style Transfer", command=start_style_transfer)
start_btn.pack(pady=20)

# Label that start_style_transfer() configures with the stylized image
result_img_label = ctk.CTkLabel(root, text="Result will be displayed here")
result_img_label.pack()

# Enter the GUI event loop
root.mainloop()



Epoch 1/10 running...
Step 1/10
Step 2/10
Step 3/10
Step 4/10
Step 5/10
Step 6/10
Step 7/10
Step 8/10
Step 9/10
Step 10/10
Epoch 1 completed
Epoch 2/10 running...
Step 1/10
Step 2/10
Step 3/10
Step 4/10
Step 5/10
Step 6/10
Step 7/10
Step 8/10
Step 9/10
Step 10/10
Epoch 2 completed
Epoch 3/10 running...
Step 1/10
Step 2/10
Step 3/10
Step 4/10
Step 5/10
Step 6/10
Step 7/10
Step 8/10
Step 9/10
Step 10/10
Epoch 3 completed
Epoch 4/10 running...
Step 1/10
Step 2/10
Step 3/10
Step 4/10
Step 5/10
Step 6/10
Step 7/10
Step 8/10
Step 9/10
Step 10/10
Epoch 4 completed
Epoch 5/10 running...
Step 1/10
Step 2/10
Step 3/10
Step 4/10
Step 5/10
Step 6/10
Step 7/10
Step 8/10
Step 9/10
Step 10/10
Epoch 5 completed
Epoch 6/10 running...
Step 1/10
Step 2/10
Step 3/10
Step 4/10
Step 5/10
Step 6/10
Step 7/10
Step 8/10
Step 9/10
Step 10/10
Epoch 6 completed
Epoch 7/10 running...
Step 1/10
Step 2/10
Step 3/10
Step 4/10
Step 5/10
Step 6/10
Step 7/10
Step 8/10
Step 9/10
Step 10/10
Epoch 7 completed
Epoch 8/10 ru