In [1]:
pip install tensorflow tensorflow-hub pillow

Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import tkinter as tk
from tkinter import filedialog, Label, Button, Frame
from PIL import Image, ImageTk
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np

class StyleTransferApp:
    """Tkinter GUI for neural style transfer with a TensorFlow Hub model.

    The user picks a content image and a style image. The app runs both
    through the model loaded from ``MODEL_URL``, previews the result, and
    can save it at the resolution the model produced. All feedback (loading,
    processing, success, errors) goes to a single status label in the
    toolbar, so messages replace one another instead of accumulating.
    """

    # TensorFlow Hub handle of the pretrained stylization model.
    MODEL_URL = 'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2'
    # Bounding box (in pixels) for on-screen previews; aspect ratio is kept.
    THUMBNAIL_SIZE = (300, 300)
    # Tk parses a file-dialog pattern string as a Tcl list, so several
    # patterns must be separated by whitespace. A ";"-joined string is only
    # understood by the native Windows dialog and matches nothing on
    # Linux/macOS. Upper-case variants help on case-sensitive filesystems.
    OPEN_FILETYPES = [("Image files", "*.jpg *.jpeg *.png *.JPG *.JPEG *.PNG")]
    SAVE_FILETYPES = [("PNG files", "*.png"), ("JPEG files", "*.jpg *.jpeg"), ("All files", "*.*")]

    def __init__(self, root):
        """Build the UI inside ``root`` and load the stylization model."""
        self.root = root
        self.root.title("Neural Style Transfer")
        self.root.geometry("1200x800")

        # Initialize variables
        self.content_image_path = None
        self.style_image_path = None
        # Raw model output; index [0] selects the single image of the batch.
        self.stylized_image = None
        self.hub_model = None

        # Build the widget tree: layout frames, toolbar, status line, previews.
        self.create_frames()
        self.create_buttons()
        self.create_status_label()
        self.create_image_displays()

        # Load the TensorFlow Hub model (synchronous call)
        self.load_model()

    def create_frames(self):
        """Create the toolbar frame and the three side-by-side image frames."""
        # Top frame for buttons
        self.top_frame = Frame(self.root)
        self.top_frame.pack(side=tk.TOP, fill=tk.X, padx=10, pady=10)

        # Bottom frame for images
        self.bottom_frame = Frame(self.root)
        self.bottom_frame.pack(side=tk.BOTTOM, fill=tk.BOTH, expand=True, padx=10, pady=10)

        # Content image frame
        self.content_frame = Frame(self.bottom_frame)
        self.content_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)

        # Style image frame
        self.style_frame = Frame(self.bottom_frame)
        self.style_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)

        # Result image frame
        self.result_frame = Frame(self.bottom_frame)
        self.result_frame.pack(side=tk.LEFT, fill=tk.BOTH, expand=True, padx=5, pady=5)

    def create_buttons(self):
        """Create the toolbar buttons; Generate and Save start disabled."""
        # Content image button
        self.content_btn = Button(self.top_frame, text="Select Content Image", command=self.select_content_image)
        self.content_btn.pack(side=tk.LEFT, padx=5, pady=5)

        # Style image button
        self.style_btn = Button(self.top_frame, text="Select Style Image", command=self.select_style_image)
        self.style_btn.pack(side=tk.LEFT, padx=5, pady=5)

        # Generate button
        self.generate_btn = Button(self.top_frame, text="Generate Stylized Image", command=self.generate_stylized_image, state=tk.DISABLED)
        self.generate_btn.pack(side=tk.LEFT, padx=5, pady=5)

        # Save button
        self.save_btn = Button(self.top_frame, text="Save Stylized Image", command=self.save_stylized_image, state=tk.DISABLED)
        self.save_btn.pack(side=tk.LEFT, padx=5, pady=5)

    def create_status_label(self):
        """Create the single toolbar label used for all status messages."""
        # wraplength keeps long error texts (e.g. TensorFlow errors) from
        # stretching the toolbar horizontally.
        self.status_label = Label(self.top_frame, text="", wraplength=500, justify=tk.LEFT)
        self.status_label.pack(side=tk.RIGHT, padx=5, pady=5)

    def create_image_displays(self):
        """Create a caption and a preview label in each of the three image frames."""
        # Content image display
        self.content_label = Label(self.content_frame, text="Content Image")
        self.content_label.pack(side=tk.TOP, padx=5, pady=5)
        self.content_display = Label(self.content_frame, text="No image selected", bg="lightgray", width=30, height=15)
        self.content_display.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=5, pady=5)

        # Style image display
        self.style_label = Label(self.style_frame, text="Style Image")
        self.style_label.pack(side=tk.TOP, padx=5, pady=5)
        self.style_display = Label(self.style_frame, text="No image selected", bg="lightgray", width=30, height=15)
        self.style_display.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=5, pady=5)

        # Result image display
        self.result_label = Label(self.result_frame, text="Stylized Image")
        self.result_label.pack(side=tk.TOP, padx=5, pady=5)
        self.result_display = Label(self.result_frame, text="No image generated", bg="lightgray", width=30, height=15)
        self.result_display.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=5, pady=5)

    def _set_status(self, text, color):
        """Replace the toolbar status message with ``text`` drawn in ``color``."""
        self.status_label.config(text=text, fg=color)

    def load_model(self):
        """Load the stylization model from ``MODEL_URL`` into ``self.hub_model``.

        On failure the error is shown in the status label, ``hub_model``
        stays ``None``, and the Generate button therefore stays disabled.
        """
        self._set_status("Loading model...", "blue")
        # Redraw now so the message is visible during the synchronous load.
        self.root.update()

        try:
            self.hub_model = hub.load(self.MODEL_URL)
        except Exception as e:
            self.hub_model = None
            self._set_status(f"Error loading model: {e}", "red")
            return

        self._set_status("Model loaded successfully!", "green")
        self.check_generate_button()

    def load_image(self, image_path, max_dim=512):
        """Read an image file as a batched float32 tensor for the model.

        The image is decoded to 3 channels, converted to float32 (values in
        [0, 1]) and resized so its longer side equals ``max_dim``, keeping
        the aspect ratio. A leading batch axis of size 1 is added.
        """
        img = tf.io.read_file(image_path)
        img = tf.image.decode_image(img, channels=3, expand_animations=False)
        img = tf.image.convert_image_dtype(img, tf.float32)

        shape = tf.cast(tf.shape(img)[:-1], tf.float32)
        scale = max_dim / tf.reduce_max(shape)
        # Clamp to 1 so extremely thin images never resize to a zero dimension.
        new_shape = tf.maximum(tf.cast(shape * scale, tf.int32), 1)
        img = tf.image.resize(img, new_shape)
        return img[tf.newaxis, :]

    def _ask_image_path(self, title, display_label):
        """Ask the user for an image file and preview it in ``display_label``.

        Returns the chosen path, or ``None`` if the dialog was cancelled or
        the file could not be opened as an image (the error is then shown
        in the status label).
        """
        file_path = filedialog.askopenfilename(title=title, filetypes=self.OPEN_FILETYPES)
        if not file_path:
            return None
        try:
            self.display_image(file_path, display_label)
        except Exception as e:
            # PIL raises several exception types for missing/corrupt/huge files.
            self._set_status(f"Could not open image: {e}", "red")
            return None
        return file_path

    def select_content_image(self):
        """Let the user choose the content image, preview it, and update Generate."""
        file_path = self._ask_image_path("Select Content Image", self.content_display)
        if file_path:
            self.content_image_path = file_path
            self.check_generate_button()

    def select_style_image(self):
        """Let the user choose the style image, preview it, and update Generate."""
        file_path = self._ask_image_path("Select Style Image", self.style_display)
        if file_path:
            self.style_image_path = file_path
            self.check_generate_button()

    def _show_pil_image(self, pil_image, display_label):
        """Render an (already resized) PIL image into ``display_label``."""
        img_tk = ImageTk.PhotoImage(pil_image)
        display_label.config(image=img_tk, text="")
        # Keep a Python reference on the widget; otherwise the PhotoImage is
        # garbage-collected and the label goes blank.
        display_label.image = img_tk

    def display_image(self, image_path, display_label):
        """Show a thumbnail of the image file at ``image_path`` in ``display_label``."""
        # The PhotoImage must be built before the file is closed: Pillow
        # invalidates a closed image's pixel data.
        with Image.open(image_path) as img:
            img.thumbnail(self.THUMBNAIL_SIZE)
            self._show_pil_image(img, display_label)

    def check_generate_button(self):
        """Enable Generate only when both images are chosen and the model is loaded."""
        ready = bool(self.content_image_path and self.style_image_path) and self.hub_model is not None
        self.generate_btn.config(state=tk.NORMAL if ready else tk.DISABLED)

    @staticmethod
    def _tensor_to_pil(image_batch):
        """Convert ``image_batch[0]`` (float, nominally in [0, 1]) to a uint8 PIL image."""
        image = tf.clip_by_value(image_batch[0], 0.0, 1.0)
        image = tf.image.convert_image_dtype(image, tf.uint8)
        return Image.fromarray(image.numpy())

    def generate_stylized_image(self):
        """Stylize the content image with the style image and preview the result.

        The full model output is kept in ``self.stylized_image`` for saving.
        Success or the error text is reported in the status label.
        """
        self._set_status("Processing...", "blue")
        # Disable before update(): update() processes pending events, and a
        # queued second click must not re-enter this method.
        self.generate_btn.config(state=tk.DISABLED)
        self.root.update()

        try:
            # Load content and style images
            content_image = self.load_image(self.content_image_path)
            style_image = self.load_image(self.style_image_path)

            # Generate the stylized image; the first output is the image batch.
            outputs = self.hub_model(content_image, style_image)
            self.stylized_image = outputs[0]

            # Convert to PIL, shrink for the preview, and display
            preview = self._tensor_to_pil(self.stylized_image)
            preview.thumbnail(self.THUMBNAIL_SIZE)
            self._show_pil_image(preview, self.result_display)

            self.save_btn.config(state=tk.NORMAL)
            self._set_status("Stylization complete!", "green")
        except Exception as e:
            self._set_status(f"Error: {e}", "red")
        finally:
            self.check_generate_button()

    def save_stylized_image(self):
        """Ask for a destination and save the full-size stylized image there.

        Does nothing if no image has been generated or the dialog is
        cancelled. Save errors (unknown extension, unwritable path) are shown
        in the status label instead of escaping the Tk callback.
        """
        if self.stylized_image is None:
            return
        file_path = filedialog.asksaveasfilename(
            title="Save Stylized Image",
            defaultextension=".png",
            filetypes=self.SAVE_FILETYPES,
        )
        if not file_path:
            return
        try:
            # Pillow picks the output format from the file extension.
            self._tensor_to_pil(self.stylized_image).save(file_path)
        except (OSError, ValueError) as e:
            self._set_status(f"Error saving image: {e}", "red")
            return
        self._set_status(f"Image saved to: {os.path.basename(file_path)}", "green")

def main():
    """Entry point: open the Neural Style Transfer window and run Tk's event loop until it is closed."""
    application = StyleTransferApp(tk.Tk())
    application.root.mainloop()


if __name__ == "__main__":
    main()