In [None]:
import os
import time
import tempfile
from typing import Dict, Any, Tuple, Optional, List

import torch
import gradio as gr
from PIL import Image

# ---------------------------------------------------------------------
# IMPORTS / PATHS
# ---------------------------------------------------------------------
import sys
# sys.path.append('/content/drive/MyDrive/University of London/Year 3/FINAL PROJECT/Final Project Code/UI')

# PSPM (Run C)  function:
from NST_models.starry_pspm import stylize_image as pspm_stylize

# IOB function:
from NST_models.iob_model import run_style_transfer as iob_run_style_transfer

# MSPM functions (single model w/ style_id switch + interpolation):
from NST_models.mspm_model import (
    stylize_image as mspm_stylize,
    stylize_image_interpolated as mspm_stylize_interpolated,
    load_model as mspm_load_model,   # returns a cached model instance
)

# ASPM (AdaIN) function
from NST_models.aspm_model import stylize_adain_image as aspm_stylize

# PSPM model paths: one trained checkpoint per style. The filenames encode the
# training hyperparameters (content weight c, style weight s, learning rate),
# which match the PSPM entry in TOP_CFG below.
STARRY_MODEL_PATH = "./Model_Paths/starry_pspm_model_hyp_c1.0_s200000.0_lr0.001.pth"
TSUNAMI_MODEL_PATH = "./Model_Paths/tsunami_pspm_model_hyp_c1.0_s200000.0_lr0.001.pth"
WATERLILIES_MODEL_PATH = "./Model_Paths/waterlilies_pspm_model_hyp_c1.0_s200000.0_lr0.001.pth"
COMPOSITION_MODEL_PATH = "./Model_Paths/composition_pspm_model_hyp_c1.0_s200000.0_lr0.001.pth"
VISITOR_MODEL_PATH = "./Model_Paths/visitor_pspm_model_hyp_c1.0_s200000.0_lr0.001.pth"

# MSPM model path: a single multi-style checkpoint; the style is picked via style_id
MSPM_MODEL_PATH = "./Model_Paths/mspm_model_hyp_c10000.0_s1000000000.0_adam.pth"

# ASPM model path: AdaIN decoder weights (passed as decoder_path)
ASPM_MODEL_PATH = "./Model_Paths/adain_model_hyp_c5.0_s10.0_adam.pth"

# Preview images shown in the UI, keyed by style display name.
# Missing files are tolerated: previews are loaded with _open_image_or_none.
STYLE_PREVIEW_PATHS: Dict[str, Optional[str]] = {
    "The Starry Night":             "./style/Starry_Night.jpg",
    "The Great Wave off Kanagawa":  "./style/Tsunami_by_hokusai.jpg",
    "Water Lilies":                 "./style/Water_Lillies.jpg",
    "Composition VIII":             "./style/Vassily_Kandinsky_Composition_VIII.jpg",
    "The Visitor":                  "./style/The_Visitor_by_Hennie_Niemann.jpg",
}

# PSPM styles: display name -> checkpoint path. Key order is the order shown in
# the PSPM style dropdown.
PSPM_STYLE_MODELS: Dict[str, Optional[str]] = {
    "The Starry Night":             STARRY_MODEL_PATH,
    "The Great Wave off Kanagawa":  TSUNAMI_MODEL_PATH,
    "Water Lilies":                 WATERLILIES_MODEL_PATH,
    "Composition VIII":             COMPOSITION_MODEL_PATH,
    "The Visitor":                  VISITOR_MODEL_PATH,
}

# MSPM style name -> style_id mapping.
# NOTE(review): the ids presumably must match the style order the MSPM model was
# trained with — confirm against the training code before reordering.
MSPM_STYLE_ID: Dict[str, int] = {
    "The Starry Night": 0,
    "The Visitor": 1,
    "The Great Wave off Kanagawa": 2,
    "Composition VIII": 3,
    "Water Lilies": 4,
}

# ---------------------------------------------------------------------
# TOP CONFIG (display)
# Per-model hyperparameters, selected by model name in run_nst and passed to the
# model adapter. PSPM/MSPM only echo these into the "Run Info" text (their
# weights are baked into the checkpoints). ASPM's "alpha" is the initial value of
# the alpha slider. IOB's adapter looks hyperparameters up by key, so keep these
# key names in sync with run_iob.
# ---------------------------------------------------------------------
TOP_CFG: Dict[str, Dict[str, Any]] = {
    "IOB":   {"Content Weight": 100, "Style Weight": 1000, "TV Weight": 1e-6, "Steps": 300, "lr": 1e-1},
    "PSPM":  {"Content Weight": 1.0, "Style Weight": 2e5, "TV Weight": 1e-6, "Learning Rate": 1e-3, "Optimizer": "adam"},
    "MSPM":  {"Content Weight": 1e4, "Style Weight": 1e9, "Learning Rate": 1e-3, "Optimizer": "adam"},
    "ASPM":  {"content_w": 5.0, "style_w": 10.0, "tv_w": 1e-6, "alpha": 0.8},
}

# ---------------------------------------------------------------------
# HELPERS
# ---------------------------------------------------------------------
def _open_image_or_none(path: Optional[str]):
    if not path or not os.path.isfile(path):
        return None
    try:
        return Image.open(path).convert("RGB")
    except Exception:
        return None

def preprocess_image(img: Image.Image, target_size: Optional[int] = None):
    """Optionally shrink an image so it fits in a square bounding box.

    Args:
        img: Input image; it is never modified.
        target_size: Side length of the bounding box. Falsy (None or 0)
            means "no resizing" and the original object is returned.

    Returns:
        ``img`` itself when no resize is requested; otherwise a resized copy
        whose aspect ratio is preserved (PIL ``thumbnail`` semantics).
    """
    if not target_size:
        return img
    # thumbnail() works in place, so operate on a copy to keep the caller's image intact.
    shrunk = img.copy()
    shrunk.thumbnail((target_size, target_size))
    return shrunk

def postprocess_image(img: Image.Image):
    """Resize a model output to exactly 512x512 for display/saving.

    Uses bicubic resampling. The aspect ratio is NOT preserved: non-square
    outputs are stretched, and larger outputs are downscaled.
    """
    side = 512
    return img.resize((side, side), resample=Image.BICUBIC)

def _pil_to_temp_png(pil_img: Image.Image, prefix: str = "content") -> str:
    """Write a PIL image to a new temporary PNG file and return its path.

    The model entry points used by this UI take file paths rather than
    in-memory images, so inputs are round-tripped through disk. The caller
    owns the returned file and is responsible for deleting it.

    Args:
        pil_img: Image to save.
        prefix: Filename prefix for the temp file (an underscore is appended).

    Returns:
        Absolute path of the written ``.png`` file.

    Raises:
        Whatever ``pil_img.save`` raises; in that case the temp file is
        removed before the exception propagates.
    """
    fd, path = tempfile.mkstemp(prefix=f"{prefix}_", suffix=".png")
    # PIL reopens the path itself; we only needed mkstemp to reserve a unique name.
    os.close(fd)
    try:
        pil_img.save(path, "PNG")
    except BaseException:
        # Don't leave an empty or partially written file behind on failure.
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path

# ---------------------------------------------------------------------
# MODEL ADAPTERS
# ---------------------------------------------------------------------
def run_pspm(content: Image.Image,
             style_name: str,
             model_path: Optional[str],
             cfg: Dict[str, Any],
             progress: gr.Progress):
    """Stylize ``content`` with a per-style (PSPM) feed-forward model.

    Args:
        content: Content image.
        style_name: Display name of the style (used in progress text and metadata).
        model_path: Path to the PSPM checkpoint for ``style_name``; may be
            None when the style has no registered model.
        cfg: Hyperparameter dict, echoed into the returned metadata.
        progress: Gradio progress tracker.

    Returns:
        ``(image, meta)``: the 512x512 stylized image and a metadata dict.

    Raises:
        gr.Error: if there is no model file for the selected style.
    """
    for _ in progress.tqdm(range(10), desc=f"Loading PSPM — {style_name}"):
        time.sleep(0.01)
    # A None path (unregistered style) would make os.path.isfile raise
    # TypeError; report it as the same user-facing error as a missing file.
    if not model_path or not os.path.isfile(model_path):
        raise gr.Error(f"PSPM model not found at:\n{model_path}")

    # The PSPM entry point takes a file path, so round-trip via a temp PNG.
    content_path = _pil_to_temp_png(content)
    try:
        out = pspm_stylize(model_path, content_path, None)
    finally:
        # The temp input is only needed during the call; remove it so repeated
        # runs don't accumulate files in the temp directory.
        try:
            os.remove(content_path)
        except OSError:
            pass

    # Ensure output is resized to 512×512
    out = postprocess_image(out)

    meta = {"model": "PSPM", "style_name": style_name, "style_model": os.path.basename(model_path), "cfg": cfg}
    return out, meta

def run_iob(content: Image.Image, style: Image.Image, cfg: Dict[str, Any], progress: gr.Progress):
    """Run iterative optimization-based (IOB) style transfer.

    Hyperparameters come from ``cfg``. Both snake_case keys (``content_w``,
    ``style_w``, ``tv_w``, ``steps``, ``lr``) and the display names used in
    ``TOP_CFG["IOB"]`` (``"Content Weight"``, ``"Style Weight"``,
    ``"TV Weight"``, ``"Steps"``) are accepted. Previously only ``lr``
    matched, so the configured weights and steps were silently replaced by
    defaults. A missing key still falls back to the original default.

    Args:
        content: Content image.
        style: Style image.
        cfg: Hyperparameter dict (see above).
        progress: Gradio progress tracker (not used by this adapter).

    Returns:
        ``(image, meta)``: the 512x512 stylized image and a dict with the
        effective config plus the final losses and runtime reported by
        ``iob_run_style_transfer``.
    """
    content = preprocess_image(content)
    style   = preprocess_image(style)

    def _hp(key: str, display_key: str, default: Any) -> Any:
        # Prefer the snake_case key, then the TOP_CFG display name, then the default.
        return cfg.get(key, cfg.get(display_key, default))

    # Map UI config to model hyperparams
    content_w = _hp("content_w", "Content Weight", 1.0)
    style_w   = _hp("style_w",   "Style Weight",   10.0)
    tv_w      = _hp("tv_w",      "TV Weight",      1e-6)
    steps     = _hp("steps",     "Steps",          300)
    lr        = _hp("lr",        "Learning Rate",  10.0)

    # Use LBFGS and content init to align with the notebook's typical settings
    optimizer = "lbfgs"
    init_type = "content"

    temp_paths: List[str] = []
    try:
        # The IOB function works on file paths: write both inputs to temp PNGs.
        content_path = _pil_to_temp_png(content, "iob_content")
        temp_paths.append(content_path)
        style_path = _pil_to_temp_png(style, "iob_style")
        temp_paths.append(style_path)

        # Reserve an output path; iob_run_style_transfer writes the result there.
        fd, output_path = tempfile.mkstemp(prefix="iob_out_", suffix=".png")
        os.close(fd)
        temp_paths.append(output_path)

        c_loss, s_loss, runtime_sec = iob_run_style_transfer(
            content_path, style_path, output_path,
            content_weight=content_w,
            style_weight=style_w,
            tv_weight=tv_w,
            optimizer_type=optimizer,
            init_type=init_type,
            iterations=steps,
            learning_rate=lr,
        )

        # convert() fully decodes into a new image, so it stays valid after the
        # temp file is deleted below.
        with Image.open(output_path) as produced:
            out_pil = produced.convert("RGB")
    finally:
        # Remove every temp file created above so repeated runs don't leak files.
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                pass

    # Ensure output is resized to 512×512
    out_pil = postprocess_image(out_pil)

    meta = {
        "model": "IOB",
        "cfg": {
            "content_w": content_w,
            "style_w": style_w,
            "tv_w": tv_w,
            "optimizer": optimizer,
            "init": init_type,
            "iterations": steps,
            "lr": lr
        },
        "metrics": {
            "final_content_loss": float(c_loss),
            "final_style_loss": float(s_loss),
            "runtime_sec": float(runtime_sec)
        }
    }
    return out_pil, meta

def run_mspm(content: Image.Image,
             style_name: str,
             cfg: Dict[str, Any],
             progress: gr.Progress,
             interpolate: bool = False,
             style_name_a: Optional[str] = None,
             style_name_b: Optional[str] = None,
             alpha: Optional[float] = None):
    """Stylize ``content`` with the single multi-style (MSPM) model.

    Without interpolation, ``style_name`` selects the style id. With
    ``interpolate=True``, the styles ``style_name_a`` and ``style_name_b`` are
    blended with weight ``alpha`` and ``style_name`` is ignored.

    Args:
        content: Content image.
        style_name: Style display name (a key of ``MSPM_STYLE_ID``).
        cfg: Hyperparameter dict, echoed into the returned metadata.
        progress: Gradio progress tracker.
        interpolate: Whether to blend two styles.
        style_name_a: First style for interpolation.
        style_name_b: Second style for interpolation.
        alpha: Interpolation weight, forwarded as ``float(alpha)``.

    Returns:
        ``(image, meta)``: the 512x512 stylized image and a metadata dict.

    Raises:
        gr.Error: if the weights are missing or the style selection is invalid.
    """
    if not os.path.isfile(MSPM_MODEL_PATH):
        raise gr.Error(f"MSPM weights not found:\n{MSPM_MODEL_PATH}")

    for _ in progress.tqdm(range(10), desc="Loading MSPM"):
        time.sleep(0.01)

    # Validate the selection BEFORE writing any temp file, so a rejected
    # request doesn't leave files behind.
    if not interpolate:
        if style_name not in MSPM_STYLE_ID:
            raise gr.Error("Unknown MSPM style selection.")
    else:
        if not style_name_a or not style_name_b:
            raise gr.Error("Please select Style A and Style B for interpolation.")
        if style_name_a not in MSPM_STYLE_ID or style_name_b not in MSPM_STYLE_ID:
            raise gr.Error("Unknown MSPM style selection(s).")
        if alpha is None:
            raise gr.Error("Please set an interpolation alpha (0.0 → 1.0).")

    # The MSPM functions take a file path, so round-trip via a temp PNG.
    content_path = _pil_to_temp_png(content, "mspm_content")
    try:
        if not interpolate:
            out = mspm_stylize(
                model_path=MSPM_MODEL_PATH,
                content_image_path=content_path,
                style_id=MSPM_STYLE_ID[style_name],
                output_path=None
            )
            label = style_name
        else:
            sid_a = MSPM_STYLE_ID[style_name_a]
            sid_b = MSPM_STYLE_ID[style_name_b]
            # The interpolation function needs a model instance (load_model caches it).
            model = mspm_load_model(MSPM_MODEL_PATH)
            out = mspm_stylize_interpolated(
                model,
                content_image_path=content_path,
                style_id1=sid_a,
                style_id2=sid_b,
                alpha=float(alpha),
            )
            label = f"Interpolated: {style_name_a} ↔ {style_name_b} (α={float(alpha):.2f})"
    finally:
        # The temp input is only needed during the call.
        try:
            os.remove(content_path)
        except OSError:
            pass

    out = postprocess_image(out)
    meta = {"model": "MSPM", "style_name": label, "cfg": cfg}
    return out, meta

def run_aspm(content: Image.Image, style: Image.Image, alpha: float, cfg: Dict[str, Any], progress: gr.Progress):
    """Stylize ``content`` with an arbitrary ``style`` image via AdaIN (ASPM).

    Args:
        content: Content image.
        style: Style image (required).
        alpha: Degree of stylization, forwarded to the AdaIN function as a float.
        cfg: Hyperparameter dict; copied into the metadata with ``alpha`` overridden.
        progress: Gradio progress tracker.

    Returns:
        ``(image, meta)``: the 512x512 stylized image and a metadata dict.

    Raises:
        gr.Error: if ``style`` is None.
    """
    for _ in progress.tqdm(range(10), desc="Preparing ASPM (AdaIN)"):
        time.sleep(0.01)

    if style is None:
        raise gr.Error("ASPM (AdaIN) requires a style image.")

    temp_paths: List[str] = []
    try:
        # The AdaIN function takes file paths, so round-trip both inputs via temp PNGs.
        content_path = _pil_to_temp_png(content, "aspm_content")
        temp_paths.append(content_path)
        style_path = _pil_to_temp_png(style, "aspm_style")
        temp_paths.append(style_path)

        # Run AdaIN stylization
        out_pil = aspm_stylize(
            content_image_path=content_path,
            style_image_path=style_path,
            decoder_path=ASPM_MODEL_PATH,
            alpha=float(alpha),
            output_path=None
        )
    finally:
        # Remove the temp inputs so repeated runs don't leak files.
        for path in temp_paths:
            try:
                os.remove(path)
            except OSError:
                pass

    # Ensure output is resized to 512×512
    out_pil = postprocess_image(out_pil)

    meta = {"model": "ASPM", "cfg": {**cfg, "alpha": float(alpha)}}
    return out_pil, meta

# ---------------------------------------------------------------------
# ROUTER
# ---------------------------------------------------------------------
def run_nst(
    model_choice: str,                 # model
    content_img: Image.Image,          # content
    style_img: Optional[Image.Image],  # style (for IOB/ASPM)
    pspm_style_name: Optional[str],    # pspm style
    mspm_style_name: Optional[str],    # mspm style (single-select)
    aspm_alpha: float,                 # ASPM alpha
    mspm_interp: bool,                 # MSPM interpolate checkbox
    mspm_style_a: Optional[str],       # MSPM Style A
    mspm_style_b: Optional[str],       # MSPM Style B
    mspm_alpha: float,                 # MSPM interpolation alpha
    # Gradio only binds a progress tracker to the event when it is declared as a
    # parameter whose default is a gr.Progress instance (it is not counted as an
    # input). Creating it inside the body gives an unbound, inert tracker.
    progress: gr.Progress = gr.Progress(track_tqdm=True),
):
    """Route a Run-button click to the selected model and format the run info.

    Args:
        model_choice: One of the ``TOP_CFG`` keys ("IOB", "PSPM", "MSPM", "ASPM").
        content_img: Content image (required).
        style_img: Uploaded style image; required for IOB and ASPM.
        pspm_style_name: PSPM style; defaults to "The Starry Night".
        mspm_style_name: MSPM single style; defaults to "The Starry Night".
        aspm_alpha: ASPM stylization strength.
        mspm_interp: Whether MSPM should interpolate between two styles.
        mspm_style_a: MSPM interpolation Style A.
        mspm_style_b: MSPM interpolation Style B.
        mspm_alpha: MSPM interpolation weight.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        ``(image, info)``: the stylized image and a human-readable summary string.

    Raises:
        gr.Error: on missing inputs or an unknown model selection.
    """
    if content_img is None:
        raise gr.Error("Please upload a content image.")
    cfg = TOP_CFG.get(model_choice)
    if cfg is None:
        raise gr.Error(f"Unknown model selection: {model_choice!r}")

    if model_choice == "PSPM":
        style_name = pspm_style_name or "The Starry Night"
        model_path = PSPM_STYLE_MODELS.get(style_name)
        out_img, meta = run_pspm(content_img, style_name, model_path, cfg, progress)

    elif model_choice == "MSPM":
        if mspm_interp:
            # Interpolate between two styles
            out_img, meta = run_mspm(
                content_img, mspm_style_name or "The Starry Night", cfg, progress,
                interpolate=True,
                style_name_a=mspm_style_a or mspm_style_name or "The Starry Night",
                style_name_b=mspm_style_b or "The Great Wave off Kanagawa",
                alpha=mspm_alpha
            )
        else:
            style_name = mspm_style_name or "The Starry Night"
            out_img, meta = run_mspm(content_img, style_name, cfg, progress)

    elif model_choice == "ASPM":
        if style_img is None:
            raise gr.Error("ASPM requires a style image.")
        out_img, meta = run_aspm(content_img, style_img, aspm_alpha, cfg, progress)

    else:  # IOB
        if style_img is None:
            raise gr.Error("IOB requires a style image.")
        out_img, meta = run_iob(content_img, style_img, cfg, progress)

    # Meta info
    meta_str = f"Model: {meta.get('model')}\nConfig: {meta.get('cfg')}"

    if meta.get("model") == "PSPM":
        meta_str += f"\nStyle: {meta.get('style_name', meta.get('style_model'))}"

    if meta.get("model") == "MSPM":
        meta_str += f"\nStyle: {meta.get('style_name')}"

    if meta.get("model") == "ASPM":
        meta_str += f"\nAlpha: {meta.get('cfg', {}).get('alpha')}"

    return out_img, meta_str

# ---------------------------------------------------------------------
# UI
# ---------------------------------------------------------------------
# Top-level UI: every component and event handler below is created inside this
# Blocks context.
with gr.Blocks(title="NST Lab") as demo:
    # Intro text (user-facing; fixed typos: "nerual syle", "Uplaod").
    gr.Markdown(
        "## Neural Style Transfer — Unified UI (IOB / PSPM / MSPM / ASPM)\n"
        "This interactive interface lets you choose one of the 4 models and run your own neural style transfer on content images of your choice. \n"
        "You can download your resulting image by clicking the download icon. \n"
        "- **IOB/ASPM Model:** Upload the content and style images of your choice.\n"
        "- **PSPM/MSPM Model:** Upload the content image of your choice and choose from the 5 fixed style images provided.\n"
    )

    # Top row: model selector plus the fixed-style dropdowns. Only the dropdown
    # for the active model is visible; on_model_change toggles them.
    with gr.Row():
        model_dd = gr.Dropdown(
            choices=["IOB", "PSPM", "MSPM", "ASPM"],
            value="PSPM",
            label="Model",
            interactive=True
        )

        # PSPM: single-select style (visible initially because PSPM is the default model)
        pspm_style_dd = gr.Dropdown(
            choices=list(PSPM_STYLE_MODELS.keys()),
            value="The Starry Night",
            label="PSPM Style",
            allow_custom_value=False,
            visible=True
        )

        # MSPM: single-select dropdown (same 5 styles); hidden until MSPM is chosen
        mspm_style_dd = gr.Dropdown(
            choices=list(MSPM_STYLE_ID.keys()),
            value="The Starry Night",
            label="MSPM Styles (visualization)",
            allow_custom_value=False,
            visible=False
        )

    # Previews loaded once at startup.
    # NOTE(review): preview_images is not referenced anywhere else in this cell;
    # the change handlers re-read previews from STYLE_PREVIEW_PATHS on demand.
    preview_images: List[Tuple[str, Optional[Image.Image]]] = [
        (name, _open_image_or_none(path)) for name, path in STYLE_PREVIEW_PATHS.items()
    ]

    with gr.Row():
        # Content image: accepts upload, clipboard paste or webcam capture
        content_in = gr.Image(type="pil", image_mode="RGB",
                              label="Content image", sources=["upload", "clipboard", "webcam"])
        # IOB/ASPM: user uploads style; PSPM/MSPM ignore this (fixed styles).
        # Hidden initially because the default model (PSPM) uses fixed styles.
        style_in = gr.Image(type="pil", image_mode="RGB",
                            label="Style image (required for IOB/ASPM)",
                            sources=["upload", "clipboard"], visible=False)

    # STYLE PREVIEW (visibility is driven by on_model_change):
    # - IOB:  hidden (the uploaded style image is shown in the upload widget)
    # - PSPM: shows selected PSPM style preview
    # - ASPM: hidden (the uploaded style image is shown in the upload widget)
    # - MSPM: shows selected MSPM style preview (fixed); hidden when interpolating,
    #         where the two side-by-side previews below are shown instead

    style_preview = gr.Image(
        type="pil",
        label="Style Preview",
        value=_open_image_or_none(STYLE_PREVIEW_PATHS["The Starry Night"]),  # initial (PSPM default)
        interactive=False,
        visible=True
    )
    # Side-by-side previews for MSPM interpolation. They start empty and are
    # filled by the interpolation-toggle and Style A/B dropdown handlers.
    with gr.Row(visible=False) as mspm_interp_previews_row:
        mspm_style_a_preview = gr.Image(type="pil", label="Style A Preview", interactive=False)
        mspm_style_b_preview = gr.Image(type="pil", label="Style B Preview", interactive=False)

    # ASPM alpha slider (user-controllable); initial value comes from TOP_CFG
    aspm_alpha_slider = gr.Slider(
        minimum=0.0, maximum=1.0, step=0.05, value=float(TOP_CFG["ASPM"]["alpha"]),
        label="ASPM Alpha (degree of stylization)",
        visible=False
    )

    # MSPM interpolation controls (whole panel shown only for MSPM)
    with gr.Column(visible=False) as mspm_interp_panel:
        mspm_interp_chk = gr.Checkbox(value=False, label="Interpolate two styles")
        with gr.Row():
            mspm_style_a_dd = gr.Dropdown(
                choices=list(MSPM_STYLE_ID.keys()),
                value="The Starry Night",
                label="Style A",
                allow_custom_value=False
            )
            mspm_style_b_dd = gr.Dropdown(
                choices=list(MSPM_STYLE_ID.keys()),
                value="The Great Wave off Kanagawa",
                label="Style B",
                allow_custom_value=False
            )
        # NOTE(review): the label assumes alpha=0 gives pure Style A in
        # mspm_stylize_interpolated — confirm against that function.
        mspm_alpha_slider = gr.Slider(
            minimum=0.0, maximum=1.0, step=0.05, value=0.5,
            label="MSPM Blend (α -> 0 = 100% Style A, 1 = 100% Style B)"
        )

    # Run button and outputs (stylized image + textual run summary)
    run_btn = gr.Button("Run", variant="primary")
    with gr.Row():
        out_img = gr.Image(type="pil", label="Stylized Output")
        out_meta = gr.Textbox(label="Run Info", interactive=False)

    # ------------------ Reactive logic ------------------

    # When model changes, toggle which style selector & preview are visible.
    def on_model_change(model: str, current_pspm_style: str, current_mspm_style: str, current_style_upload: Optional[Image.Image], interp_on: bool):
        """Toggle which controls are visible for the newly selected model.

        Returns one update per output component, in this order: pspm_style_dd,
        mspm_style_dd, style_in, style_preview, aspm_alpha_slider,
        mspm_interp_panel, mspm_interp_previews_row.

        PSPM/MSPM use fixed styles, so the upload widget is hidden and the
        preview shows the currently selected style (for MSPM, only when not
        interpolating). Every other model (IOB/ASPM) needs an uploaded style
        image, so the upload is shown and the preview is hidden.
        ``current_style_upload`` is accepted for the wiring but not used.
        """
        is_pspm = model == "PSPM"
        is_mspm = model == "MSPM"
        fixed_style = is_pspm or is_mspm

        if fixed_style:
            selected = current_pspm_style if is_pspm else current_mspm_style
            preview_update = gr.update(
                visible=is_pspm or not interp_on,
                value=_open_image_or_none(STYLE_PREVIEW_PATHS.get(selected, None)),
            )
        else:
            preview_update = gr.update(visible=False)

        return (
            gr.update(visible=is_pspm),                # pspm_style_dd
            gr.update(visible=is_mspm),                # mspm_style_dd
            gr.update(visible=not fixed_style),        # style_in (upload)
            preview_update,                            # style_preview
            gr.update(visible=model == "ASPM"),        # aspm_alpha_slider
            gr.update(visible=is_mspm),                # mspm_interp_panel
            gr.update(visible=is_mspm and interp_on),  # mspm_interp_previews_row
        )

    model_dd.change(
        on_model_change,
        inputs=[model_dd, pspm_style_dd, mspm_style_dd, style_in, mspm_interp_chk],
        outputs=[pspm_style_dd, mspm_style_dd, style_in, style_preview,
                 aspm_alpha_slider, mspm_interp_panel, mspm_interp_previews_row]
    )

    # PSPM: refresh the single style preview whenever the PSPM style changes
    def on_pspm_style_change(style_name: str):
        """Return an update that shows the preview image for ``style_name``."""
        preview_path = STYLE_PREVIEW_PATHS.get(style_name)
        preview_img = _open_image_or_none(preview_path)
        return gr.update(value=preview_img)

    pspm_style_dd.change(on_pspm_style_change, inputs=[pspm_style_dd], outputs=[style_preview])

    # MSPM: refresh the single preview when the MSPM style changes; it stays
    # hidden while interpolation is on (the A/B pair is shown instead).
    def on_mspm_style_change(style_name: str, interp_on: bool):
        """Return a preview update for ``style_name``, visible only when not interpolating."""
        show_single = not interp_on
        return gr.update(
            value=_open_image_or_none(STYLE_PREVIEW_PATHS.get(style_name)),
            visible=show_single,
        )

    mspm_style_dd.change(on_mspm_style_change, inputs=[mspm_style_dd, mspm_interp_chk], outputs=[style_preview])

    # Interpolation toggle: checking the box swaps the single preview for the
    # Style A / Style B pair, and (re)loads both pair images.
    def on_mspm_interp_toggle(interp_on: bool, style_a: str, style_b: str):
        """Return updates for (style_preview, previews row, Style A preview, Style B preview)."""
        show_pair = interp_on
        img_a, img_b = (
            _open_image_or_none(STYLE_PREVIEW_PATHS.get(name)) for name in (style_a, style_b)
        )
        return (
            gr.update(visible=not show_pair),  # single style_preview
            gr.update(visible=show_pair),      # row holding the two previews
            gr.update(value=img_a),
            gr.update(value=img_b),
        )

    mspm_interp_chk.change(
        on_mspm_interp_toggle,
        inputs=[mspm_interp_chk, mspm_style_a_dd, mspm_style_b_dd],
        outputs=[style_preview, mspm_interp_previews_row, mspm_style_a_preview, mspm_style_b_preview],
    )

    # Keep each interpolation preview in sync with its own dropdown
    def on_mspm_style_a_change(style_a: str):
        """Return an update showing the preview image for Style A."""
        img_a = _open_image_or_none(STYLE_PREVIEW_PATHS.get(style_a))
        return gr.update(value=img_a)

    def on_mspm_style_b_change(style_b: str):
        """Return an update showing the preview image for Style B."""
        img_b = _open_image_or_none(STYLE_PREVIEW_PATHS.get(style_b))
        return gr.update(value=img_b)

    mspm_style_a_dd.change(
        on_mspm_style_a_change,
        inputs=[mspm_style_a_dd],
        outputs=[mspm_style_a_preview],
    )
    mspm_style_b_dd.change(
        on_mspm_style_b_change,
        inputs=[mspm_style_b_dd],
        outputs=[mspm_style_b_preview],
    )

    # IOB/ASPM style upload. The uploaded image is displayed by the upload
    # widget itself, so this handler deliberately leaves style_preview untouched.
    def on_style_upload_change(uploaded: Optional[Image.Image], model: str):
        """No-op handler: return an empty update for ``style_preview``."""
        return gr.update()

    # queue=False: the handler does no work, so skip the request queue
    # (shared with the long-running Run jobs) for this event.
    style_in.change(
        on_style_upload_change,
        inputs=[style_in, model_dd],
        outputs=[style_preview],
        queue=False,
    )

    # RUN: Gradio passes these inputs to run_nst positionally, so the list
    # must stay in the same order as run_nst's parameters. Hidden components
    # still deliver their current values; run_nst decides which ones to use.
    run_btn.click(
        run_nst,
        inputs=[
            model_dd, content_in, style_in,
            pspm_style_dd,           # PSPM style
            mspm_style_dd,           # MSPM single style
            aspm_alpha_slider,       # ASPM alpha
            mspm_interp_chk,         # MSPM interpolation checkbox
            mspm_style_a_dd,         # MSPM Style A
            mspm_style_b_dd,         # MSPM Style B
            mspm_alpha_slider        # MSPM alpha
        ],
        outputs=[out_img, out_meta]
    )

# Enable the request queue (used for streaming progress updates) with at most
# 8 pending requests, then launch. share=True also publishes a temporary public
# gradio.live URL (see the cell output). Per the Gradio docs, debug=True keeps
# this call blocking and prints errors, which is useful in a notebook.
demo.queue(max_size=8).launch(debug=True, share=True)

2025-09-22 14:14:10.353034: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://d744c150266c3d3525.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
