In [None]:
import base64
import io
import time

import IPython.display
import ipywidgets
import PIL.Image
import requests

# Metadata dicts of the videos created in this session, in creation order.
# The handlers below append one dict per generated or remixed video.
generated_videos = []

# --- Generation parameters --------------------------------------------------

prompt = ipywidgets.Textarea(
    description="Prompt: ",
    value="A serene sunset over a calm ocean with gentle waves",
    layout=ipywidgets.Layout(width="auto"),
)

model = ipywidgets.Dropdown(
    description="Model: ",
    options=["sora-2", "sora-2-pro"],
    value="sora-2",
)

# "<width>x<height>" strings; resize_image_to_size() parses this format.
size = ipywidgets.Dropdown(
    description="Size: ",
    options=["720x1280", "1280x720", "1024x1792", "1792x1024"],
    value="720x1280",
)

# Clip length choices are kept as strings and sent to the API unchanged.
seconds = ipywidgets.Dropdown(
    description="Duration (s): ",
    options=["4", "8", "12"],
    value="4",
)

# --- Optional reference image -----------------------------------------------

input_reference_upload = ipywidgets.FileUpload(
    description="Reference Image: ",
    accept="image/*",
    multiple=False,
)

# Output area in which the uploaded reference image is previewed.
input_reference_preview = ipywidgets.Output()


def display_uploaded_image(change):
    """Redraw the reference-image preview from the upload widget's current value.

    Registered as an observer on ``input_reference_upload``; the ``change``
    payload itself is not inspected. The preview area is always cleared, and
    the first uploaded file is rendered only when an upload is present.
    """
    with input_reference_preview:
        input_reference_preview.clear_output()
        if not input_reference_upload.value:
            return
        first_upload = input_reference_upload.value[0]
        preview = IPython.display.Image(data=first_upload["content"])
        IPython.display.display(preview)


# Refresh the preview whenever the upload widget's value changes.
input_reference_upload.observe(display_uploaded_image, names="value")

# Output area used by the @output.capture() decorators on the click handlers.
output = ipywidgets.Output()

button = ipywidgets.Button(description="Generate")

# Single-line readout the handlers use for progress and error messages.
status_label = ipywidgets.Label(value="")


def resize_image_to_size(image_bytes, target_size):
    """Return ``image_bytes`` re-encoded as an RGB PNG of exactly ``target_size``.

    Args:
        image_bytes: Encoded image data readable by ``io.BytesIO`` / Pillow.
        target_size: Dimensions as a ``"<width>x<height>"`` string.

    Returns:
        The PNG-encoded bytes of the resized image. The image is stretched to
        the requested dimensions; aspect ratio is not preserved.
    """
    target_width, target_height = (int(part) for part in target_size.split("x"))
    source = PIL.Image.open(io.BytesIO(image_bytes))

    if source.mode == "RGBA":
        # Flatten transparency onto a white canvas, using the alpha band
        # (fourth channel) as the paste mask.
        alpha_band = source.split()[3]
        flattened = PIL.Image.new("RGB", source.size, (255, 255, 255))
        flattened.paste(source, mask=alpha_band)
        source = flattened
    elif source.mode != "RGB":
        source = source.convert("RGB")

    scaled = source.resize(
        (target_width, target_height),
        PIL.Image.Resampling.LANCZOS,
    )

    png_buffer = io.BytesIO()
    scaled.save(png_buffer, format="PNG")
    # getvalue() returns the whole buffer regardless of the stream position.
    return png_buffer.getvalue()


def poll_video_status(video_id, max_attempts=120, poll_interval=5, request_timeout=30):
    """Poll the video job until it completes, fails, or the attempts run out.

    Every outcome other than success is reported on ``status_label`` and on
    stderr before returning.

    Args:
        video_id: Identifier of the video job to poll.
        max_attempts: Maximum number of status requests to make.
        poll_interval: Seconds to sleep between two status requests.
        request_timeout: Timeout in seconds passed to each HTTP request.

    Returns:
        ``True`` once the job reports ``status == "completed"``; ``None`` on an
        API error, a ``failed``/``cancelled`` job, a network or decoding
        failure, or when ``max_attempts`` is exhausted.
    """
    import sys

    def report_failure(message):
        # The Label widget only accepts strings, so always coerce.
        message = str(message)
        status_label.value = message
        print(message, file=sys.stderr)

    status_url = (
        f"http://cortex-api.cortex-api.svc.cluster.local:8080/v1/videos/{video_id}"
    )

    for _ in range(max_attempts):
        try:
            status_response = requests.get(
                status_url,
                headers={"Content-Type": "application/json"},
                timeout=request_timeout,
            )
            video_data = status_response.json()
        except (requests.RequestException, ValueError) as exc:
            # ValueError covers a body that is not valid JSON.
            report_failure(f"Failed to fetch video status: {exc}")
            return None

        if not isinstance(video_data, dict):
            report_failure("Unexpected response while fetching video status")
            return None

        error = video_data.get("error")
        if error is not None:
            # Accept both {"message": ...} objects and plain error values.
            if isinstance(error, dict):
                report_failure(error.get("message") or error)
            else:
                report_failure(error)
            return None

        status = video_data.get("status")
        progress = video_data.get("progress", 0)

        status_label.value = f"Status: {status}, Progress: {progress}%"

        if status == "completed":
            return True
        if status in ("failed", "cancelled"):
            # "error" is known to be absent or null here, so build the message
            # from the status instead of reading the error field again.
            report_failure(f"Video generation {status}")
            return None

        time.sleep(poll_interval)

    total_minutes = max_attempts * poll_interval / 60
    report_failure(f"Polling timeout after {total_minutes:g} minutes")
    return None


def fetch_video_content(video_id, timeout=60):
    """Download the rendered video and return it base64-encoded.

    Failures are reported on ``status_label`` and on stderr.

    Args:
        video_id: Identifier of the video whose content is downloaded.
        timeout: Timeout in seconds passed to the HTTP request.

    Returns:
        The video bytes as a base64 ``str``, or ``None`` when the request
        fails, the response status is not 200, or the body is empty.
    """
    import sys

    def report_failure(message):
        status_label.value = message
        print(message, file=sys.stderr)

    try:
        video_response = requests.get(
            f"http://cortex-api.cortex-api.svc.cluster.local:8080/v1/videos/{video_id}/content",
            timeout=timeout,
        )
    except requests.RequestException as exc:
        report_failure(f"Failed to retrieve video: {exc}")
        return None

    if video_response.status_code != 200:
        report_failure(f"Failed to retrieve video: {video_response.status_code}")
        return None

    if not video_response.content:
        # Callers treat a falsy result as failure; tell the user why.
        report_failure("Failed to retrieve video: empty response")
        return None

    return base64.b64encode(video_response.content).decode("utf-8")


class VideoTabContent:
    """Widgets and behaviour for one tab: a video player plus a remix form.

    Args:
        video_data: Dict with the keys ``id``, ``video_base64``, ``prompt``,
            ``size`` and ``seconds`` describing the video.
        video_index: Zero-based position of the video; shown one-based.

    The assembled widget tree is exposed as ``self.widgets``.
    """

    def __init__(self, video_data, video_index):
        import html

        self.video_data = video_data
        self.video_index = video_index

        self._video_output = ipywidgets.widgets.Output()

        # The values below end up inside HTML widgets, so escape them; a
        # prompt containing "<" or "&" would otherwise be rendered as markup.
        video_id_html = html.escape(str(video_data["id"]))
        size_html = html.escape(str(video_data["size"]))
        seconds_html = html.escape(str(video_data["seconds"]))
        prompt_html = html.escape(str(video_data["prompt"]))

        self._video_info = ipywidgets.widgets.HTML(
            value=(
                f"<b>Video {video_index + 1}</b><br>ID: {video_id_html}<br>"
                f"Size: {size_html}, Duration: {seconds_html}s"
            ),
        )

        self._original_prompt_display = ipywidgets.widgets.HTML(
            value=f"<b>Original Prompt:</b><br>{prompt_html}",
        )

        self._remix_prompt = ipywidgets.widgets.Textarea(
            value="",
            placeholder="Describe the changes you want to make (e.g., 'Change colors to teal and rust')",
            description="Changes: ",
            layout=ipywidgets.Layout(width="auto"),
        )

        self._remix_button = ipywidgets.widgets.Button(
            description="Remix This Video",
            button_style="info",
        )

        self._remix_button.on_click(self.remix_handler)

        self.widgets = ipywidgets.widgets.VBox([
            self._video_info,
            self._video_output,
            ipywidgets.widgets.HTML(value="<hr>"),
            ipywidgets.widgets.HTML(value="<h3>Remix This Video</h3>"),
            self._original_prompt_display,
            self._remix_prompt,
            self._remix_button,
        ])

        self.display_video()

    @staticmethod
    def _error_text(error):
        """Return a displayable string for an API ``error`` value."""
        # Accept both {"message": ...} objects and plain error values.
        if isinstance(error, dict):
            return str(error.get("message") or error)
        return str(error)

    def display_video(self):
        """Decode the stored base64 video and render it in this tab's output."""
        with self._video_output:
            self._video_output.clear_output()
            video_bytes = base64.b64decode(self.video_data["video_base64"])
            IPython.display.display(IPython.display.Video(data=video_bytes, embed=True, mimetype="video/mp4"))

    @output.capture()
    def remix_handler(self, b):
        """Create a remix of this tab's video and show it in a new tab.

        Click handler for the remix button (``b``). Both the remix button and
        the main generate button are disabled while the request runs and are
        always re-enabled afterwards, including when an exception occurs.
        Errors are reported on ``status_label`` and on stderr.
        """
        import sys

        def report_failure(message):
            status_label.value = message
            print(message, file=sys.stderr)

        remix_prompt = self._remix_prompt.value
        if not remix_prompt.strip():
            report_failure("Please enter changes to make in the Remix prompt")
            return

        b.disabled = True
        button.disabled = True
        try:
            status_label.value = "Creating remix..."

            r = requests.post(
                f"http://cortex-api.cortex-api.svc.cluster.local:8080/v1/videos/{self.video_data['id']}/remix",
                json={"prompt": remix_prompt},
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response = r.json()

            error = response.get("error")
            if error is not None:
                report_failure(self._error_text(error))
                return

            video_id = response.get("id")
            if not video_id:
                report_failure("Failed to get video ID")
                return

            status_label.value = f"Remix video created: {video_id}"

            if not poll_video_status(video_id):
                return

            status_label.value = "Remix completed"

            video_base64 = fetch_video_content(video_id)
            if not video_base64:
                return

            status_label.value = f"Remixed video loaded: {video_id}"

            new_video_data = {
                "id": video_id,
                "video_base64": video_base64,
                "prompt": remix_prompt,
                # The remix request carries only a prompt, so record the
                # source video's size and duration rather than whatever the
                # dropdowns currently show.
                "size": self.video_data["size"],
                "seconds": self.video_data["seconds"],
            }

            generated_videos.append(new_video_data)

            new_tab = VideoTabContent(new_video_data, len(generated_videos) - 1)
            video_tabs[f"Video {len(generated_videos)}"] = new_tab
            tabs.children = tabs.children + (new_tab.widgets,)
            tabs.set_title(len(tabs.children) - 1, f"Video {len(generated_videos)}")
            tabs.selected_index = len(tabs.children) - 1
        except (requests.RequestException, ValueError) as exc:
            # ValueError covers a response body that is not valid JSON.
            report_failure(f"Remix failed: {exc}")
        finally:
            # Never leave the UI locked, whatever path was taken above.
            b.disabled = False
            button.disabled = False


@output.capture()
def generate(b: ipywidgets.widgets.Button):
    """Create a video from the form values and show it in a new tab.

    Click handler for the main button (``b``). The button is disabled while
    the request runs and is always re-enabled afterwards, including when an
    exception occurs. Errors are reported on ``status_label`` and on stderr.

    When a reference image has been uploaded it is resized to the selected
    size and sent as multipart form data; otherwise the parameters are sent
    as JSON.
    """
    import sys

    def report_failure(message):
        message = str(message)
        status_label.value = message
        print(message, file=sys.stderr)

    videos_url = "http://cortex-api.cortex-api.svc.cluster.local:8080/v1/videos"

    b.disabled = True
    try:
        status_label.value = "Creating video..."

        data = {
            "prompt": prompt.value,
            "model": model.value,
            "size": size.value,
            "seconds": seconds.value,
        }

        files = {}
        if input_reference_upload.value:
            uploaded_file = input_reference_upload.value[0]
            try:
                resized_image_bytes = resize_image_to_size(
                    uploaded_file["content"], data["size"]
                )
            except (OSError, ValueError) as exc:
                # Pillow raises OSError subclasses for unreadable images.
                report_failure(f"Invalid reference image: {exc}")
                return

            # resize_image_to_size() always re-encodes as PNG, so advertise
            # PNG instead of the original file's extension and MIME type.
            name_stem = uploaded_file["name"].rsplit(".", 1)[0] or "input_reference"
            files["input_reference"] = (
                f"{name_stem}.png",
                resized_image_bytes,
                "image/png",
            )

        if files:
            r = requests.post(videos_url, data=data, files=files, timeout=60)
        else:
            r = requests.post(
                videos_url,
                json=data,
                headers={"Content-Type": "application/json"},
                timeout=60,
            )

        response = r.json()

        error = response.get("error")
        if error is not None:
            # Accept both {"message": ...} objects and plain error values.
            if isinstance(error, dict):
                report_failure(error.get("message") or error)
            else:
                report_failure(error)
            return

        video_id = response.get("id")
        if not video_id:
            report_failure("Failed to get video ID")
            return

        status_label.value = f"Video created: {video_id}"

        if not poll_video_status(video_id):
            return

        status_label.value = "Video generation completed"

        video_base64 = fetch_video_content(video_id)
        if not video_base64:
            return

        status_label.value = f"Video loaded: {video_id}"

        video_data = {
            "id": video_id,
            "video_base64": video_base64,
            "prompt": data["prompt"],
            "size": data["size"],
            "seconds": data["seconds"],
        }

        generated_videos.append(video_data)

        new_tab = VideoTabContent(video_data, len(generated_videos) - 1)
        video_tabs[f"Video {len(generated_videos)}"] = new_tab
        tabs.children = tabs.children + (new_tab.widgets,)
        tabs.set_title(len(tabs.children) - 1, f"Video {len(generated_videos)}")
        tabs.selected_index = len(tabs.children) - 1
    except (requests.RequestException, ValueError) as exc:
        # ValueError covers a response body that is not valid JSON.
        report_failure(f"Video generation failed: {exc}")
    finally:
        # Never leave the button locked, whatever path was taken above.
        b.disabled = False


# Wire the main button to the generation handler.
button.on_click(generate)

# Maps a tab title ("Video N") to the VideoTabContent shown under it.
video_tabs = {}

# Tab container; starts empty and gains one tab per generated or remixed video.
tabs = ipywidgets.Tab()
tabs.children = []

# Render the form, the captured output and status line, then the video tabs.
IPython.display.display(
    prompt,
    model,
    size,
    seconds,
    input_reference_upload,
    input_reference_preview,
    button,
    output,
    status_label,
    ipywidgets.HTML(value="<hr>"),
    ipywidgets.HTML(value="<h3>Generated Videos</h3>"),
    tabs,
)