In [None]:
# DALL-E is not available through the AWS Bedrock service, so install the OpenAI Python package first
#!pip install openai

In [None]:
from openai import OpenAI
import os
import boto3
import json
from botocore.exceptions import ClientError
from IPython.display import HTML, display, Markdown
from typing import Union
from tqdm.notebook import tqdm
import time
import random


def get_secret(secret_name: str, region_name: str = "us-east-1") -> str:
    """
    Retrieve a secret value from AWS Secrets Manager.

    Args:
        secret_name (str): The name of the secret to retrieve.
        region_name (str): AWS region that hosts the secret. Defaults to
            "us-east-1", the region this helper previously hard-coded.

    Returns:
        secret (str): The secret string.

    Raises:
        botocore.exceptions.ClientError: If the Secrets Manager call fails.
        KeyError: If the response carries no 'SecretString' field.
    """
    # Create a Secrets Manager client for the requested region
    session = boto3.session.Session()
    client = session.client(
        service_name='secretsmanager',
        region_name=region_name
    )

    # ClientError propagates to the caller unchanged; the former
    # `except ClientError as e: raise e` wrapper added nothing.
    get_secret_value_response = client.get_secret_value(SecretId=secret_name)
    return get_secret_value_response['SecretString']


def create_novacanvas_body(
    prompt: str,
    *,
    seed: int = 42,
    cfg_scale: float = 8,
    width: int = 1024,
    height: int = 1024,
    quality: str = "standard",
    number_of_images: int = 1,
) -> str:
    """
    Build the request body for the Amazon Nova Canvas image model.

    The keyword-only arguments default to the values this helper always
    used, so calling it with just a prompt produces the same JSON as before.

    Args:
        prompt (str): The user prompt.
        seed (int): Generation seed; a fixed seed reproduces the same image
            for the same prompt.
        cfg_scale (float): Value sent as "cfgScale".
        width (int): Requested image width.
        height (int): Requested image height.
        quality (str): Value sent as "quality".
        number_of_images (int): Value sent as "numberOfImages".

    Returns:
        str: JSON-formatted request body.
    """
    return json.dumps({
        "textToImageParams": {
            "text": prompt,
        },
        "taskType": "TEXT_IMAGE",
        "imageGenerationConfig": {
            "cfgScale": cfg_scale,
            "seed": seed,
            "quality": quality,
            "width": width,
            "height": height,
            "numberOfImages": number_of_images
        }
    })


def create_titan_body(
    prompt: str,
    *,
    seed: int = 42,
    cfg_scale: float = 8,
    width: int = 1024,
    height: int = 1024,
    quality: str = "standard",
    number_of_images: int = 1,
) -> str:
    """
    Build the request body for the Amazon Titan image model.

    The keyword-only arguments default to the values this helper always
    used, so calling it with just a prompt produces the same JSON as before.

    Args:
        prompt (str): The user prompt.
        seed (int): Generation seed; a fixed seed reproduces the same image
            for the same prompt.
        cfg_scale (float): Value sent as "cfgScale".
        width (int): Requested image width.
        height (int): Requested image height.
        quality (str): Value sent as "quality".
        number_of_images (int): Value sent as "numberOfImages".

    Returns:
        str: JSON-formatted request body.
    """
    return json.dumps({
        "textToImageParams": {
            "text": prompt,
        },
        "taskType": "TEXT_IMAGE",
        "imageGenerationConfig": {
            "cfgScale": cfg_scale,
            "seed": seed,
            "quality": quality,
            "width": width,
            "height": height,
            "numberOfImages": number_of_images
        }
    })


def create_sdxl_body(
    prompt: str,
    *,
    weight: float = 1,
    cfg_scale: float = 10,
    seed: int = 0,
    steps: int = 50,
    width: int = 512,
    height: int = 512,
) -> str:
    """
    Build the request body for the Stability AI SDXL model on Bedrock
    ('stability.stable-diffusion-xl-v1').

    The keyword-only arguments default to the values this helper always
    used, so calling it with just a prompt produces the same JSON as before.

    Args:
        prompt (str): The user prompt.
        weight (float): Weight attached to the single text prompt.
        cfg_scale (float): Value sent as "cfg_scale".
        seed (int): Generation seed.
        steps (int): Value sent as "steps".
        width (int): Requested image width.
        height (int): Requested image height.

    Returns:
        str: JSON-formatted request body.
    """
    return json.dumps({
        "text_prompts": [
            {"text": prompt,
             "weight": weight
            }
        ],
        "cfg_scale": cfg_scale,
        "seed": seed,
        "steps": steps,
        "width": width,
        "height": height
    })


def invoke_bedrock_model(model_id: str, prompt: str, max_retries: int = 5, **kwargs) -> dict | None:
    """
    Invoke an AWS Bedrock image model, retrying when the call is throttled.

    Args:
        model_id (str): The model identifier. Must be one of the models with
            a request-body builder registered in ``body_creators`` below.
        prompt (str): The user prompt.
        max_retries (int): Maximum number of invocation attempts while the
            service keeps raising ThrottlingException.
        **kwargs: Accepted for backward compatibility; currently ignored.

    Returns:
        dict or None: The parsed JSON response body, or None if a
        non-throttling error occurred or every attempt was throttled.

    Raises:
        ValueError: If ``model_id`` has no registered request-body builder.
    """
    body_creators = {
        'amazon.nova-canvas-v1:0': create_novacanvas_body,
        'amazon.titan-image-generator-v2:0': create_titan_body,
        'stability.stable-diffusion-xl-v1': create_sdxl_body
    }

    # Fail fast with a clear message; previously an unknown id surfaced as
    # "'NoneType' object is not callable".
    create_body = body_creators.get(model_id)
    if create_body is None:
        raise ValueError(
            f"Unsupported Bedrock model id {model_id!r}; "
            f"expected one of {sorted(body_creators)}"
        )
    body = create_body(prompt)

    bedrock_runtime = boto3.client(service_name='bedrock-runtime')

    for attempt in range(max_retries):
        try:
            response = bedrock_runtime.invoke_model(
                body=body,
                modelId=model_id,
                accept="application/json",
                contentType="application/json"
            )
            return json.loads(response.get('body').read())

        except bedrock_runtime.exceptions.ThrottlingException:
            if attempt == max_retries - 1:
                # Final attempt: no retry follows, so sleeping would only
                # delay reporting the failure.
                break
            # Exponential backoff with jitter.
            # NOTE(review): max() makes 9 s a *floor* on every wait, not a cap;
            # kept as in the original — switch to min() if a cap was intended.
            wait_time = max(9, (2 ** attempt) + random.uniform(0, 1))
            time.sleep(wait_time)

        except Exception as e:
            # Best-effort: report the error and let the caller handle None.
            print(f"Error invoking model: {e}")
            return None

    print("Max retries reached. Exiting.")
    return None


def invoke_openai_model(
    model_id: str,
    prompt: str,
    *,
    size: str = "1024x1024",
    quality: str = "standard",
) -> str:
    """
    Invoke an OpenAI image model and return the image URL.

    Args:
        model_id (str): The OpenAI model identifier.
        prompt (str): The user prompt.
        size (str): Requested image size, passed through to the API.
        quality (str): Requested image quality, passed through to the API.

    Returns:
        str: URL of the generated image.
    """
    # Load the API key from Secrets Manager only when it is not already set.
    # Storing it in the environment skips the lookup on later calls; OpenAI()
    # is constructed without an explicit key, so it relies on this variable.
    if not os.environ.get("OPENAI_API_KEY"):
        secret_response_openai = json.loads(get_secret("prod/openai"))
        os.environ["OPENAI_API_KEY"] = secret_response_openai["api_key"]

    client = OpenAI()

    response = client.images.generate(
        model=model_id,
        prompt=prompt,
        size=size,
        quality=quality,
        n=1,
    )

    return response.data[0].url


def get_response(prompt: str, model_id: str) -> Union[dict, str, None]:
    """
    Get an image response from the specified model for a given prompt.

    Model ids containing 'dall' are routed to OpenAI; every other id is
    sent to AWS Bedrock.

    Args:
        prompt (str): The user prompt.
        model_id (str): The model identifier.

    Returns:
        dict, str or None: The Bedrock response dict, the OpenAI image URL,
        or None when the Bedrock invocation failed.
    """
    if 'dall' in model_id:
        return invoke_openai_model(model_id, prompt)
    return invoke_bedrock_model(model_id, prompt)


def create_img_tag(base64_image):
    """
    Wrap base64 image data in an inline HTML <img> element.

    Args:
        base64_image (str): Raw base64 payload; the data-URI prefix
            (labelled as PNG) is prepended here.

    Returns:
        str: A 340px-wide <img> tag whose src is the resulting data URI.
    """
    data_uri = f"data:image/png;base64,{base64_image}"
    tag = f'<img src="{data_uri}" alt="Image" style="width: 340px;"/>'
    return tag


def generate_table_of_responses(image_canvas: dict | None, image_titan: dict | None, image_url_dalle: str | None) -> None:
    """
    Display a table of images generated by different models.

    A model whose response is missing (e.g. None after a failed Bedrock
    call) is shown as a "No image" placeholder instead of crashing the cell.

    Args:
        image_canvas (dict | None): Nova Canvas model response.
        image_titan (dict | None): Titan model response.
        image_url_dalle (str | None): DALL-E image URL.

    Returns:
        None
    """
    from html import escape  # local import keeps this edit self-contained

    placeholder = "<em>No image</em>"

    def _bedrock_img_tag(response):
        # Build an <img> tag from the first base64 image in a Bedrock response.
        images = (response or {}).get("images") or []
        return create_img_tag(images[0]) if images else placeholder

    img_tag_1 = _bedrock_img_tag(image_canvas)
    img_tag_2 = _bedrock_img_tag(image_titan)

    # The URL comes from an external API; escape it before placing it inside
    # an HTML attribute so quotes or '&' cannot break the markup.
    if image_url_dalle:
        img_tag_3 = (
            f'<img src="{escape(image_url_dalle, quote=True)}" '
            f'alt="Image 3" style="width: 340px;"/>'
        )
    else:
        img_tag_3 = placeholder

    # Create HTML to display the images side by side
    html_code = f"""
    <table>
        <tr>
            <td>{img_tag_1}</td>
            <td>{img_tag_2}</td>
            <td>{img_tag_3}</td>
        </tr>
    </table>
    """
    display(HTML(html_code))


# Prompts to compare across models; the variants probe how each model reacts
# to "diverse" / "non-realistic" modifiers and to an underspecified request.
prompts = [
    "Generate an image showing two developers doing software peer review",
    "Generate a diverse image showing two developers doing software peer review",
    "Generate a non-realistic image showing two developers doing software peer review",
    "Give me a non-realistic image",
]
# Model ids queried for every prompt; ids containing 'dall' go to OpenAI,
# the rest to AWS Bedrock.
model_ids = [
    'amazon.nova-canvas-v1:0',
    'amazon.titan-image-generator-v2:0',
    'dall-e-3'
]




In [None]:
# Generate one image per (prompt, model) pair and show each prompt's results side by side.
for prompt in tqdm(prompts, desc="Processing Prompts"):
    # Reset per prompt: otherwise a model absent from model_ids would silently
    # reuse the previous prompt's image (or raise NameError on the first prompt).
    image_canvas = image_titan = image_url_dalle = None
    for model_id in tqdm(model_ids, desc="Processing Models", leave=False):
        if 'canvas' in model_id:
            image_canvas = get_response(prompt, model_id)
        elif 'titan' in model_id:
            image_titan = get_response(prompt, model_id)
        elif 'dall' in model_id:
            image_url_dalle = get_response(prompt, model_id)
    display(Markdown(f"# Prompt: {prompt}"))
    generate_table_of_responses(image_canvas, image_titan, image_url_dalle)