In [1]:
import gradio as gr
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.tools import BaseTool
from langchain.chat_models import ChatOpenAI
from PIL import Image
import os
import requests
from io import BytesIO

In [2]:
# This will set up the OPENAI API key in the OS environment variables
#os.environ["OPENAI_API_KEY"] = "xxxxx"
def setup_environment(module_dir: str = 'C:\\gitworkspace\\aimldemo\\jupyterworkapce'):
    """
    Make the local environment-setup module importable and run it.

    Adds ``module_dir`` to ``sys.path`` (only if it is not already present, so
    re-running the cell does not pile up duplicate entries), imports
    ``stratup_env_setup`` from it and calls its ``set_env()`` function.

    Args:
        module_dir (str): Directory that contains ``stratup_env_setup.py``.
            Defaults to the original workstation path; pass another directory
            to run the notebook on a different machine.

    Raises:
        ImportError: If ``stratup_env_setup`` cannot be found on ``sys.path``.
    """
    import sys

    # Guard against duplicates: this function may be executed many times in
    # one kernel session.
    if module_dir not in sys.path:
        sys.path.append(module_dir)

    # Imported lazily because the module is only resolvable after the path
    # entry above has been added.
    import stratup_env_setup
    stratup_env_setup.set_env()

In [3]:
# Run the environment setup defined above (expected to export the OpenAI API
# key) before the `openai` package is imported in the next cell.
setup_environment()

In [4]:
import openai

In [5]:
# Report whether the openai module has picked up an API key.
if openai.api_key:
    print("OpenAI API key is set.")
else:
    print("Error: OpenAI API key is missing.")

OpenAI API key is set.


In [6]:
# Custom LangChain Tool to interact with DALL-E
class DalleTool(BaseTool):
    """LangChain tool that turns a text prompt into a DALL-E image URL."""

    name: str = "DALL-E Image Generator"
    description: str = "Generates an image from a text prompt using OpenAI's DALL-E."

    def _run(self, prompt: str) -> str:
        """
        Generate an image URL from the given prompt.

        Uses ``openai.Image.create``, i.e. the module-level interface of the
        pre-1.0 ``openai`` SDK, requesting a single 512x512 image.

        Args:
            prompt (str): Text description of the desired image.

        Returns:
            str: URL of the generated image.

        Raises:
            ValueError: If the prompt is blank, or if the API call or the
                parsing of its response fails (the original error is chained
                as ``__cause__``).
        """
        # Reject blank prompts up front instead of spending an API call that
        # is bound to fail.
        if not isinstance(prompt, str) or not prompt.strip():
            raise ValueError("Error generating image: prompt must be a non-empty string.")

        try:
            response = openai.Image.create(
                prompt=prompt,
                n=1,  # Number of images
                size="512x512"  # Image dimensions
            )
            return response["data"][0]["url"]
        except Exception as e:
            # Keep the ValueError contract but preserve the root cause.
            raise ValueError(f"Error generating image: {e}") from e

    async def _arun(self, *args, **kwargs):
        """Asynchronous method required by BaseTool but not used here."""
        raise NotImplementedError("DalleTool does not support async operations.")


In [7]:
# Create the chat model that the LangChain chain below will call
llm = ChatOpenAI(
    temperature=0.7,
    model="gpt-3.5-turbo",
)

  llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.7)


In [8]:
# Prompt template with a single placeholder, {description}, filled per request
prompt_template = PromptTemplate(
    template="Generate an image of: {description}",
    input_variables=["description"],
)

In [9]:
# Chain that formats the template and sends the result to the chat model
image_chain = LLMChain(
    prompt=prompt_template,
    llm=llm,
)

  image_chain = LLMChain(llm=llm, prompt=prompt_template)


In [10]:
def generate_image_with_langchain(description: str) -> "Image.Image":
    """
    Generate an image using LangChain for processing the prompt and OpenAI's DALL-E.

    The description is run through ``image_chain``, the resulting prompt is
    passed to ``DalleTool`` to obtain an image URL, and the image is then
    downloaded and opened with PIL.

    Args:
        description (str): Text prompt describing the image.

    Returns:
        PIL.Image: The generated image.

    Raises:
        ValueError: If the description is blank, the chain returns an
            unexpected structure, or image generation fails.
        requests.RequestException: If the image download times out or the
            server answers with an HTTP error status.
    """
    # Fail fast on empty input so neither the LLM nor DALL-E is called for nothing.
    if description is None or not str(description).strip():
        raise ValueError("Error generating image: description must not be empty.")

    # Generate the raw prompt using LangChain
    raw_prompt = image_chain.invoke({"description": description})

    # Ensure the prompt is a plain string.
    # NOTE(review): "description" is the same key that was passed in as input,
    # and the prompts logged in the notebook output match the raw user text,
    # so the LLM completion appears to be unused here. Confirm whether the
    # completion (presumably under the chain's output key) was intended.
    if isinstance(raw_prompt, dict) and "description" in raw_prompt:
        prompt = raw_prompt["description"]
    elif isinstance(raw_prompt, str):
        prompt = raw_prompt
    else:
        raise ValueError(f"Unexpected prompt format: {raw_prompt}")

    # Debug: Print the prompt
    print(f"Using prompt for DALL-E: {prompt}")

    # Use DALL-E to generate the image
    try:
        dalle_tool = DalleTool()
        image_url = dalle_tool._run(prompt)  # Pass the plain string prompt
    except ValueError:
        # DalleTool._run already prefixes its message with
        # "Error generating image:"; wrapping again would duplicate it.
        raise
    except Exception as e:
        raise ValueError(f"Error generating image: {e}") from e

    # Fetch the image from the URL. A timeout keeps the UI callback from
    # hanging forever, and raise_for_status() stops an HTTP error body from
    # being handed to PIL as if it were image data.
    image_response = requests.get(image_url, timeout=30)
    image_response.raise_for_status()
    image = Image.open(BytesIO(image_response.content))
    return image

In [11]:
# Gradio UI
def gradio_ui():
    """
    Build the two-column Gradio Blocks page for the image generator and launch it.

    The left column holds the prompt textbox and the "Generate" button; the
    right column shows the resulting image. Clicking the button calls
    ``generate_image_with_langchain`` with the textbox content.
    """
    with gr.Blocks() as app:
        gr.Markdown("<h1 style='text-align: center;'>Image Generator</h1>")

        with gr.Row():
            # Left column: prompt entry and trigger
            with gr.Column():
                description_box = gr.Textbox(
                    lines=3,
                    placeholder="Describe the image you want to generate...",
                    label="Enter a text prompt",
                )
                run_button = gr.Button("Generate")

            # Right column: rendered result
            with gr.Column():
                result_image = gr.Image(label="Generated Image")

        # Wire the button to the LangChain-based generator
        run_button.click(
            fn=generate_image_with_langchain,
            inputs=[description_box],
            outputs=[result_image],
        )

        # Start the Gradio server
        app.launch()

In [12]:
# Build the interface and start the Gradio server (launch() is called inside gradio_ui).
gradio_ui()

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.


Using prompt for DALL-E: A serene lake surrounded by snow-capped mountains during sunset
Using prompt for DALL-E: A serene lake surrounded by snow-capped mountains during sunset
Using prompt for DALL-E: A beautiful areal view of Interlaken,  Switzerland with the two lakes
Using prompt for DALL-E: A beautiful areal view of Interlaken,  Switzerland with the two lakes
Using prompt for DALL-E: A tiger walking on Brooklyn bridge in teh night
Using prompt for DALL-E: A tiger walking on Brooklyn bridge in the night
Using prompt for DALL-E: New York city skyline in the night
Using prompt for DALL-E: New York city Time square in the night
Using prompt for DALL-E: New York city Time square in the night
