# In [None]:
import tkinter as tk
from tkinter import filedialog
from PIL import ImageTk, Image
import customtkinter as ctk
import tensorflow as tf
import numpy as np

# def tensor_to_image(tensor):
#     """Converts a tensor to a proper PIL image (scaling it to uint8 and clipping)."""
#     # Ensure the tensor is a float32 tensor
#     if tensor.dtype != tf.float32:
#         tensor = tf.cast(tensor, tf.float32)
    
#     # Replace NaNs with 0 (or another value you prefer)
#     tensor = tf.where(tf.math.is_nan(tensor), tf.zeros_like(tensor), tensor)
    
#     # Clip the values to ensure they fall within the valid range for uint8
#     tensor = tf.clip_by_value(tensor, 0.0, 1.0)  # Clip to the [0, 1] range
    
#     # Scale to [0, 255]
#     tensor = tensor * 255.0
    
#     # Convert to uint8 after scaling
#     tensor = np.array(tensor, dtype=np.uint8)
    
#     # Remove the batch dimension if present
#     if np.ndim(tensor) > 3:
#         tensor = tensor[0]
    
#     return Image.fromarray(tensor)
def tensor_to_image(tensor):
    """Convert an image tensor to a PIL image for display.

    Args:
        tensor: an image of shape (height, width, channels), or a batch
            holding exactly one image, shape (1, height, width, channels).
            Tensors, ``tf.Variable`` objects and NumPy arrays all work, since
            only ``.shape`` and indexing are used before the Keras call.

    Returns:
        The PIL image produced by ``tf.keras.preprocessing.image.array_to_img``.

    Raises:
        ValueError: if a batched input holds more than one image.
    """
    # len(tensor.shape) is a plain Python int, so the rank comparison does not
    # depend on truth-testing a one-element tensor.
    if len(tensor.shape) > 3:
        if tensor.shape[0] != 1:
            # Explicit exception instead of `assert`, which is stripped under -O.
            raise ValueError(
                f"Expected a batch of exactly one image, got shape {tuple(tensor.shape)}."
            )
        tensor = tensor[0]
    return tf.keras.preprocessing.image.array_to_img(tensor)

# Function to load an image and scale it to 512px max dimension
def load_img(path_to_img, max_dim=512):
    """Load an image file as a batched float32 tensor scaled to ``max_dim``.

    The longer side is resized to ``max_dim`` pixels, preserving the aspect
    ratio. Pixel values stay on the 0-255 scale. That is the range
    ``preprocess_image`` (Keras VGG19 ``preprocess_input``) and
    ``clip_image_values`` (default upper bound 255) work in.

    Args:
        path_to_img: path of the image file (JPEG, PNG, BMP or GIF).
        max_dim: target length, in pixels, of the image's longer side.

    Returns:
        A float32 tensor of shape (1, height, width, 3) with values in [0, 255].
    """
    raw = tf.io.read_file(path_to_img)
    # decode_image accepts every common format the file dialog may return.
    # channels=3 drops alpha and expands grayscale so VGG19 always gets RGB.
    # expand_animations=False yields a single 3-D frame for GIFs.
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    # A plain cast (not convert_image_dtype) keeps the uint8 0-255 scale.
    image = tf.cast(image, tf.float32)

    spatial_shape = tf.cast(tf.shape(image)[:-1], tf.float32)
    scale = max_dim / tf.reduce_max(spatial_shape)
    new_shape = tf.cast(spatial_shape * scale, tf.int32)

    image = tf.image.resize(image, new_shape)
    # Add the batch dimension: (1, height, width, channels).
    return image[tf.newaxis, :]

def clip_image_values(image, min_value=0.0, max_value=255.0):
    """Return ``image`` with every element clamped to [min_value, max_value].

    The defaults describe the 0-255 pixel range.
    """
    lower, upper = min_value, max_value
    return tf.clip_by_value(image, lower, upper)

def preprocess_image(image):
    """Prepare an image batch for the VGG19 feature extractor.

    The input is cast to float32 and passed through
    ``tf.keras.applications.vgg19.preprocess_input``.
    """
    as_float = tf.cast(image, tf.float32)
    return tf.keras.applications.vgg19.preprocess_input(as_float)

# VGG19 layers whose activations define the style representation.
style_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1']
# VGG19 layer whose activations define the content representation.
content_layers = ['block5_conv2']
# Style layers come first: the feature helpers split the model's outputs at
# index NUM_STYLE_LAYERS.
output_layers = [*style_layers, *content_layers]
NUM_STYLE_LAYERS, NUM_CONTENT_LAYERS = len(style_layers), len(content_layers)

def vgg_model(layer_names):
    """Build a frozen VGG19 feature extractor.

    Args:
        layer_names: names of VGG19 layers whose outputs the model returns,
            in this order.

    Returns:
        A ``tf.keras.Model`` from VGG19's input to the requested layer
        outputs. The ImageNet weights are marked non-trainable.
    """
    backbone = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    backbone.trainable = False
    layer_outputs = [backbone.get_layer(layer_name).output for layer_name in layer_names]
    return tf.keras.Model(inputs=[backbone.input], outputs=layer_outputs)

# Frozen VGG19 extractor over all style + content layers, built once at import time.
vgg = vgg_model(layer_names=output_layers)

def get_style_loss(features, targets, weight=1.0):
    """Return ``weight`` times the mean squared difference of two style representations."""
    squared_diff = tf.square(features - targets)
    return weight * tf.reduce_mean(squared_diff)

def get_content_loss(features, targets, weight=1.0):
    """Return ``weight`` times the summed squared difference of two feature maps."""
    residual = features - targets
    return weight * tf.reduce_sum(tf.square(residual))

def gram_matrix(input_tensor):
    """Return per-image Gram matrices normalised by the number of spatial locations.

    The einsum indexes ``input_tensor`` as (batch, height, width, channels).
    The result has shape (batch, channels, channels) and is divided by
    height * width.
    """
    dims = tf.shape(input_tensor)
    locations = tf.cast(dims[1] * dims[2], tf.float32)
    correlations = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
    return correlations / locations

def get_style_image_features(image, vgg_model):
    """Compute the style representation of ``image``.

    Args:
        image: image batch to analyse; it is run through ``preprocess_image`` here.
        vgg_model: feature extractor whose first ``NUM_STYLE_LAYERS`` outputs
            are the style layers.

    Returns:
        A list with one Gram matrix per style layer.
    """
    layer_activations = vgg_model(preprocess_image(image))
    return [gram_matrix(activation) for activation in layer_activations[:NUM_STYLE_LAYERS]]

def get_content_image_features(image, vgg_model):
    """Return the content-layer activations of ``image``.

    ``image`` is run through ``preprocess_image`` and ``vgg_model``. The
    model outputs that follow the first ``NUM_STYLE_LAYERS`` entries are
    returned unchanged.
    """
    layer_activations = vgg_model(preprocess_image(image))
    return layer_activations[NUM_STYLE_LAYERS:]

def get_style_content_loss(style_targets, style_outputs, content_targets, content_outputs, style_weight, content_weight):
    """Return the weighted combination of the style and content losses.

    Args:
        style_targets: style representations (Gram matrices) of the style image.
        style_outputs: style representations of the generated image, same order.
        content_targets: content-layer features of the content image.
        content_outputs: content-layer features of the generated image, same order.
        style_weight: multiplier for the style loss.
        content_weight: multiplier for the content loss.

    Returns:
        A scalar tensor: ``style_weight * sum(style layer losses) / NUM_STYLE_LAYERS``
        plus ``content_weight * sum(content layer losses) / NUM_CONTENT_LAYERS``.
    """
    # Per-layer losses are computed unweighted (the helpers default to
    # weight=1.0). Each weight is applied exactly once, below. Passing it to the
    # helpers as well would effectively square the weights.
    style_loss = tf.add_n([get_style_loss(style_output, style_target)
                           for style_output, style_target in zip(style_outputs, style_targets)])

    content_loss = tf.add_n([get_content_loss(content_output, content_target)
                             for content_output, content_target in zip(content_outputs, content_targets)])

    style_loss = style_loss * style_weight / NUM_STYLE_LAYERS
    content_loss = content_loss * content_weight / NUM_CONTENT_LAYERS

    return style_loss + content_loss

def calculate_gradients(image, style_targets, content_targets, style_weight, content_weight, var_weight, vgg_model):
    """Return the gradient of the total style-transfer loss w.r.t. ``image``.

    Args:
        image: the ``tf.Variable`` being optimised.
        style_targets: style representations of the style image.
        content_targets: content features of the content image.
        style_weight: multiplier for the style loss.
        content_weight: multiplier for the content loss.
        var_weight: multiplier for the total-variation regulariser. 0 disables
            it, leaving only the style/content loss.
        vgg_model: the feature extractor used for both representations.

    Returns:
        A tensor with the same shape as ``image``.
    """
    with tf.GradientTape() as tape:
        style_features = get_style_image_features(image, vgg_model)
        content_features = get_content_image_features(image, vgg_model)
        loss = get_style_content_loss(style_targets, style_features, content_targets, content_features, style_weight, content_weight)
        # Total variation penalises differences between neighbouring pixels.
        # It is summed over the batch so the loss stays a scalar.
        loss += var_weight * tf.reduce_sum(tf.image.total_variation(image))
    # Taken outside the `with` block so the gradient computation itself is
    # not recorded on the tape.
    return tape.gradient(loss, image)

def update_image_with_style(image, style_targets, content_targets, style_weight, var_weight, content_weight, optimizer, vgg_model):
    """Apply one optimisation step to ``image`` in place, then clamp it to 0-255.

    Note the parameter order here is (style_weight, var_weight, content_weight),
    unlike ``calculate_gradients``. The arguments are reordered when forwarding.
    """
    grads = calculate_gradients(
        image, style_targets, content_targets,
        style_weight, content_weight, var_weight, vgg_model,
    )
    optimizer.apply_gradients([(grads, image)])
    clipped = clip_image_values(image, min_value=0.0, max_value=255.0)
    image.assign(clipped)

def fit_style_transfer(style_image, content_image, style_weight=1e-2, content_weight=1e-4, var_weight=0, optimizer=None, epochs=1, steps_per_epoch=1, vgg=None, display_fn=None, snapshot_every=10):
    """Optimise a copy of ``content_image`` towards the style of ``style_image``.

    Args:
        style_image: style image batch (presumably as returned by ``load_img``
            -- confirm at call sites).
        content_image: content image batch. It is also the starting point of
            the optimisation.
        style_weight: multiplier for the style loss.
        content_weight: multiplier for the content loss.
        var_weight: multiplier for the total-variation regulariser.
        optimizer: optimizer to use. Defaults to Adam with an exponentially
            decaying learning rate (20.0, halved every 100 steps).
        epochs: number of epochs.
        steps_per_epoch: optimisation steps per epoch.
        vgg: feature extractor (required).
        display_fn: optional callback that receives a PIL image after each epoch.
        snapshot_every: within an epoch, keep a snapshot every this many steps.

    Returns:
        ``(generated_image, images)``. ``images`` is a list of NumPy arrays:
        the content image, then a snapshot every ``snapshot_every`` steps
        within each epoch, plus one at the end of each epoch.

    Raises:
        ValueError: if ``vgg`` is not provided.
    """
    if vgg is None:
        raise ValueError("VGG model must be provided.")

    if optimizer is None:
        optimizer = tf.optimizers.Adam(
            tf.keras.optimizers.schedules.ExponentialDecay(
                initial_learning_rate=20.0,
                decay_steps=100,
                decay_rate=0.5
            )
        )

    style_targets = get_style_image_features(style_image, vgg)
    content_targets = get_content_image_features(content_image, vgg)

    generated_image = tf.Variable(tf.cast(content_image, dtype=tf.float32))

    # Store every snapshot as a NumPy array, including the first one, so
    # consumers of `images` see a single element type.
    images = [np.asarray(content_image)]
    step = 0

    for _ in range(epochs):
        for m in range(steps_per_epoch):
            step += 1
            update_image_with_style(generated_image, style_targets, content_targets, style_weight, var_weight, content_weight, optimizer, vgg)
            print(".", end="")
            if (m + 1) % snapshot_every == 0:
                images.append(generated_image.numpy())

        if display_fn:
            display_fn(tensor_to_image(generated_image))

        images.append(generated_image.numpy())
        print(f"Train step: {step}")

    generated_image = tf.cast(generated_image, dtype=tf.float32)
    return generated_image, images

def select_image(image_type):
    """Ask the user for an image file and show a 150x150 preview of it.

    Args:
        image_type: ``'content'`` or ``'style'``. It selects which path
            variable and preview label are updated. Any other value opens the
            dialog but updates nothing.
    """
    filepath = filedialog.askopenfilename()
    # Cancelling the dialog yields an empty value. Opening that would raise,
    # so the previous selection is kept instead.
    if not filepath:
        return

    if image_type == 'content':
        path_var, preview_label = content_image_path, content_img_label
    elif image_type == 'style':
        path_var, preview_label = style_image_path, style_img_label
    else:
        return

    path_var.set(filepath)
    # The context manager closes the file. resize() returns an independent copy.
    with Image.open(filepath) as img:
        thumbnail = img.resize((150, 150), Image.Resampling.LANCZOS)
    # Keep a reference on the label so the PhotoImage is not garbage-collected.
    preview_label.imgtk = ImageTk.PhotoImage(thumbnail)
    preview_label.configure(image=preview_label.imgtk)

def start_style_transfer():
    """Run style transfer on the selected images and show a 150x150 result.

    Reads the paths stored by the image pickers and optimises for 10 epochs
    of 10 steps. The result is shown in ``result_img_label``. The work runs
    synchronously, so a Tk window that invokes this will not respond until
    it finishes.
    """
    content_path = content_image_path.get()
    style_path = style_image_path.get()
    if not content_path or not style_path:
        result_img_label.configure(text="Please select both a content and a style image first.")
        return

    content_image = load_img(content_path)
    style_image = load_img(style_path)

    adam = tf.optimizers.Adam(
        tf.keras.optimizers.schedules.ExponentialDecay(
            initial_learning_rate=20.0,
            decay_steps=100,
            decay_rate=0.50
        )
    )

    # Reuse the module-level VGG19 extractor rather than rebuilding it (and
    # reloading its ImageNet weights) on every run. Its weights are frozen,
    # so sharing it between runs is safe.
    output_image, _ = fit_style_transfer(
        style_image=style_image,
        content_image=content_image,
        style_weight=2e-2,
        content_weight=1e-2,
        var_weight=0,
        optimizer=adam,
        epochs=10,
        steps_per_epoch=10,
        vgg=vgg
    )

    stylized_image_pil = tensor_to_image(output_image).resize((150, 150), Image.Resampling.LANCZOS)

    result_photo = ImageTk.PhotoImage(stylized_image_pil)
    result_img_label.configure(image=result_photo, text="")
    result_img_label.imgtk = result_photo  # Keep reference to avoid garbage collection


# Front end
root = ctk.CTk()
root.geometry("1920x1080")
root.title("Neural Style Transfer")

# Selected file paths: written by select_image, read by start_style_transfer.
content_image_path = ctk.StringVar()
style_image_path = ctk.StringVar()

# Set up a 3-column grid layout
root.grid_columnconfigure(0, weight=1)  # Left side (Content)
root.grid_columnconfigure(1, weight=2)  # Center (Title and Result)
root.grid_columnconfigure(2, weight=1)  # Right side (Style)

# Title
title_label = ctk.CTkLabel(root, text="Neural Style Transfer", font=('Arial', 40))
title_label.grid(row=0, column=1, pady=40)

# Content image: preview label (select_image puts the thumbnail here) + picker button
content_img_label = ctk.CTkLabel(root, text="Content Image", font=('Arial', 20))
content_img_label.grid(row=1, column=0, padx=20, pady=10)

content_btn = ctk.CTkButton(root, text="Select Content Image", command=lambda: select_image('content'), width=300, height=50)
content_btn.grid(row=2, column=0, padx=100, pady=100)

# Style image: preview label + picker button
style_img_label = ctk.CTkLabel(root, text="Style Image", font=('Arial', 20))
style_img_label.grid(row=1, column=2, padx=20, pady=10)

style_btn = ctk.CTkButton(root, text="Select Style Image", command=lambda: select_image('style'), width=300, height=50)
style_btn.grid(row=2, column=2, padx=20, pady=10)

# Start Style Transfer button: runs the (blocking) optimisation when clicked.
start_btn = ctk.CTkButton(root, text="Start Style Transfer", command=start_style_transfer, width=400, height=60)
start_btn.grid(row=3, column=1, pady=40)

# Result Image Display
result_img_label = ctk.CTkLabel(root, text="Result will be displayed here", font=('Arial', 24))
result_img_label.grid(row=4, column=1, pady=40)

root.mainloop()
