In [50]:
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Optional


class StoryFlavor(str, Enum):
    """Genre requested for a story.

    Members subclass ``str``, so each compares equal to its lowercase
    snake_case value, e.g. ``StoryFlavor.THRILLER == "thriller"``.
    """

    FAIRY_TALE = "fairy_tale"
    THRILLER = "thriller"
    ROMANCE = "romance"
    SCIENCE_FICTION = "science_fiction"


class StoryStatus(str, Enum):
    """Lifecycle state stored on ``Story.status``.

    Members subclass ``str`` and compare equal to their snake_case values.
    Declaration order is preserved as-is (it is the iteration order).
    """

    GENERATING_STORY = "generating_story"
    COMPLETED = "completed"
    GENERATING_AUDIO = "generating_audio"
    FAILED = "failed"
    RESTRICTED_CONTENT_DETECTED = "restricted_content_detected"
    JUST_CREATED = "just_created"


@dataclass
class Story:
    """A generated story record.

    ``id``, ``flavor``, ``title``, ``story_text`` and ``created_at`` are
    required. ``status`` defaults to ``StoryStatus.GENERATING_STORY``; every
    other field defaults to ``None``.
    """

    # Required identity and content.
    id: str
    flavor: StoryFlavor
    title: str
    story_text: str
    created_at: datetime

    # NOTE(review): StoryStatus also defines JUST_CREATED; confirm that
    # GENERATING_STORY is the intended initial state.
    status: StoryStatus = StoryStatus.GENERATING_STORY

    # Optional media / error details; None when absent.
    image_url: str | None = None
    audio_url: str | None = None
    audio_duration_seconds: float | None = None
    error_message: str | None = None

In [57]:
from datetime import datetime
from typing import Optional, List

from pydantic import BaseModel, Field


class StoryGenerationRequest(BaseModel):
    """Client request to generate a story.

    The optional fields use camelCase aliases on the wire
    (``additionalContext``, ``eightingPlusEnabled``).
    """

    # Accept the Python field names as well as the aliases. Without this,
    # ``StoryGenerationRequest(eighting_plus_enabled=True, ...)`` validates but
    # the keyword is treated as an unknown extra and dropped, so the default
    # (False / None) is silently used instead.
    model_config = {"populate_by_name": True}

    flavor: StoryFlavor
    additional_context: Optional[str] = Field(
        default=None,
        alias="additionalContext",
        description="Additional instructions or context for the story",
    )
    eighting_plus_enabled: bool = Field(
        default=False,
        alias="eightingPlusEnabled",
        description="Whether to allow 18+ content",
    )


class ImageInsights(BaseModel):
    """Structured description of an image extracted by the vision model.

    Used as the structured-output schema of the insights step and as the brief
    handed to the story writer. Field descriptions are presumably forwarded to
    the model as part of the JSON schema, so their wording is kept verbatim.
    """

    # Required: a field with no default is required in pydantic v2.
    title: str = Field(description="A short title for the future story.")
    caption: str = Field(description="A detailed description of the scene.")

    # Optional: lists default to empty, setting defaults to None.
    subjects: list[str] = Field(
        default_factory=list,
        description="Main visible entities (short noun phrases).",
    )
    setting: str | None = Field(
        default=None,
        description="Concise environment description (e.g., 'snowy forest at dusk').",
    )
    actions: list[str] = Field(
        default_factory=list,
        description="Observable actions (e.g., 'walking', 'reaching for door').",
    )
    mood: list[str] = Field(
        default_factory=list,
        description="1-3 word mood descriptors (e.g., 'mysterious', 'cozy').",
    )
    hooks: list[str] = Field(
        default_factory=list,
        description="2-5 short, grounded narrative hooks to start a story.",
    )


class RestrictedContentResponse(BaseModel):
    """Verdict of the under-18 content check.

    Field order is kept as declared (``reasoning`` before ``is_restricted``),
    since it is also the order of properties in the generated schema.
    """

    reasoning: str = Field(description="A detailed reasoning for the decision.")
    summary: str | None = Field(
        default=None,
        description="One-sentence rationale grounded in visible/explicit cues.",
    )
    # NOTE(review): defaults to False, i.e. the check fails open if the model
    # omits this field — confirm that is acceptable for a safety gate.
    is_restricted: bool = Field(
        default=False,
        description="True if content should be blocked for under-18.",
    )


class StoryGenerationResponse(BaseModel):
    """Result of story generation: a title plus the story body."""

    title: str = Field(description="The title of the story.")
    text: str = Field(description="The text of the story.")


class RestrictedContentDetected(Exception):
    """Raised when the under-18 content check blocks a story request.

    The single argument, when given, is the classifier's one-sentence summary
    (which may be ``None``).
    """

In [52]:
# Baseline narration speed, in words per minute.
BASE_WPM = 150

# Per-flavor narration speed: BASE_WPM scaled by a genre factor
# (thrillers read slowest, science fiction closest to baseline).
flavour_to_wpm: dict[StoryFlavor, float] = {
    flavor: factor * BASE_WPM
    for flavor, factor in (
        (StoryFlavor.FAIRY_TALE, 0.95),
        (StoryFlavor.THRILLER, 0.90),
        (StoryFlavor.ROMANCE, 0.95),
        (StoryFlavor.SCIENCE_FICTION, 0.98),
    )
}

In [61]:
import os
import json
import base64
from typing import cast
from io import BytesIO
from pathlib import Path
from PIL import Image

from langchain_ollama import ChatOllama


class StoryGenerator:
    """Generate a story from an image using local Ollama models (via LangChain).

    Pipeline run by ``generate``:

    1. Re-encode the uploaded image as an RGB JPEG, applying its EXIF
       orientation so the model sees it upright.
    2. Unless 18+ is enabled, ask the vision model whether the image plus the
       extra text should be restricted for under-18 viewers, and raise
       ``RestrictedContentDetected`` if so.
    3. Ask the vision model for structured ``ImageInsights``.
    4. Ask the text model to write the story from those insights.

    Each step's result is cached in ``<cwd>/../tests/data/cache/<cache_name>.json``
    (one JSON object keyed by step name) and reused on later calls with the
    same ``cache_name``. NOTE(review): the cache key ignores the request and
    the image, so reusing a ``cache_name`` returns the earlier results.

    Model names and the Ollama URL are read from ``OLLAMA_VLM_MODEL``,
    ``OLLAMA_TXT_MODEL`` and ``OLLAMA_URL`` at construction time.
    """

    def __init__(self) -> None:
        self._vision_model_name = os.getenv("OLLAMA_VLM_MODEL", "qwen2.5vl:7b")
        self._txt_model_name = os.getenv("OLLAMA_TXT_MODEL", "qwen2.5:7b")
        self._ollama_url = os.getenv("OLLAMA_URL", "http://localhost:11434")
        # Cache name of the most recent ``generate`` call. It is only a fallback
        # for the cache helpers; the pipeline passes ``cache_name`` explicitly
        # so concurrent calls with different names don't touch each other's cache.
        self._debug_cache_name: str | None = None

    async def generate(
        self,
        request: StoryGenerationRequest,
        image_bytes: bytes,
        cache_name: str,
    ) -> StoryGenerationResponse:
        """Run the full image-to-story pipeline.

        Args:
            request: Flavor, optional extra instructions and the 18+ switch.
            image_bytes: Raw image bytes in any format Pillow can open.
            cache_name: Base name of the JSON cache file for this run.

        Returns:
            The story, titled with the title from the image insights.

        Raises:
            RestrictedContentDetected: if 18+ is disabled and the content check
                flags the image or the extra text.
        """
        image_bytes = self._convert_image_to_jpeg(image_bytes)
        self._debug_cache_name = cache_name

        if not request.eighting_plus_enabled:
            await self._perform_elder_content_check(request, image_bytes, cache_name=cache_name)

        insights = await self._get_image_insights(request, image_bytes, cache_name=cache_name)

        return await self._generate_story(request, insights, cache_name=cache_name)

    # ---------------------
    # Tiny debug-cache helpers
    # ---------------------
    # All steps of one run share a single JSON file; each step is a top-level
    # key in it. ``cache_name`` falls back to the last ``generate`` call's name.
    def _cache_path(self, step: str, cache_name: str | None = None) -> Path:
        """Return the cache file path, creating its directory if needed.

        ``step`` does not affect the path; it is accepted for symmetry with
        the get/set helpers.
        """
        name = self._debug_cache_name if cache_name is None else cache_name
        base_dir = Path.cwd().parent / "tests" / "data" / "cache"
        base_dir.mkdir(parents=True, exist_ok=True)
        return base_dir / f"{name}.json"

    def _cache_get(self, step: str, cache_name: str | None = None):
        """Return the cached value for ``step``, or None if missing or unreadable."""
        path = self._cache_path(step, cache_name)
        if not path.exists():
            return None
        try:
            with path.open("r", encoding="utf-8") as f:
                return json.load(f).get(step)
        except Exception:
            # Best effort: a corrupt or unexpected cache file is treated as a miss.
            return None

    def _cache_set(self, step: str, value, cache_name: str | None = None) -> None:
        """Store ``value`` under ``step``, keeping other steps already in the file."""
        path = self._cache_path(step, cache_name)
        blob = {}
        if path.exists():
            try:
                with path.open("r", encoding="utf-8") as f:
                    blob = json.load(f)
            except Exception:
                # Unreadable cache: start afresh rather than failing the run.
                blob = {}
        if not isinstance(blob, dict):
            # Valid JSON but not an object (e.g. a list): cannot key into it.
            blob = {}
        blob[step] = value
        with path.open("w", encoding="utf-8") as f:
            json.dump(blob, f, ensure_ascii=False, indent=2)

    @staticmethod
    def _vision_messages(system: str, user: str, image_url: str) -> list[dict]:
        """Build a system message plus a (text, image) user message for the VLM."""
        return [
            {"role": "system", "content": [{"type": "text", "text": system}]},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user},
                    {"type": "image_url", "image_url": image_url},
                ],
            },
        ]

    async def _perform_elder_content_check(
        self,
        request: StoryGenerationRequest,
        image_bytes: bytes,
        cache_name: str | None = None,
    ) -> None:
        """Block content unsuitable for under-18 viewers.

        Classifies the JPEG ``image_bytes`` together with
        ``request.additional_context`` using the vision model (temperature 0)
        and caches the verdict under the ``"elder_check"`` step.

        Raises:
            RestrictedContentDetected: if the cached or fresh verdict has
                ``is_restricted`` set; the message is the model's summary.
        """
        cached = self._cache_get("elder_check", cache_name)
        if cached:
            if cached.get("is_restricted"):
                raise RestrictedContentDetected(cached.get("summary"))
            return
        print("Performing elder content check...")

        img_bytes_url: str = self._image_to_data_url(image_bytes)
        # An absent context is sent as empty text, not the literal "None".
        extra_text = request.additional_context or ""

        system = (
            "You are a concise content safety classifier."
            " Block explicit sexual content (nudity/acts/exploitation), graphic violence/gore,"
            " sexualization of minors, or hateful/terrorist propaganda."
            " Allow 16+ content: mild romance/affection, non-graphic injuries, sports, everyday scenes."
            " Decisions must be grounded strictly in the visible image and the user's extra text."
            " Output a compact JSON object only."
        )
        user = (
            "Classify if the content should be restricted for under-18 viewers."
            " Consider the image and this extra text (may be empty):\n\n"
            f"EXTRA_TEXT: {extra_text}\n\n"
            " Keep summary one sentence, grounded in visible cues."
        )
        messages = self._vision_messages(system, user, img_bytes_url)

        structured = (
            ChatOllama(
                model=self._vision_model_name,
                base_url=self._ollama_url,
                temperature=0,
                num_ctx=8192,
            )
            .with_structured_output(RestrictedContentResponse)
        )

        result = cast(RestrictedContentResponse, await structured.ainvoke(messages))
        self._cache_set("elder_check", result.model_dump(), cache_name)

        if result.is_restricted:
            raise RestrictedContentDetected(result.summary)

    async def _get_image_insights(
        self,
        request: StoryGenerationRequest,
        image_bytes: bytes,
        cache_name: str | None = None,
    ) -> ImageInsights:
        """Extract structured story-building cues from the image with the VLM.

        The result is cached under the ``"image_insights"`` step.
        """
        cached = self._cache_get("image_insights", cache_name)
        if cached:
            return ImageInsights(**cached)
        print("Getting image insights...")

        img_bytes_url: str = self._image_to_data_url(image_bytes)

        system = (
            "You are a vision assistant extracting grounded story-building cues."
            " Be literal and faithful to the image; do not invent entities."
            " Output only JSON that matches the schema precisely."
        )
        user = (
            "Extract grounded insights for later story writing."
            " Keep items concise, no punctuation beyond commas where natural."
        )
        if request.additional_context:
            # Only mention user instructions when there are some; otherwise the
            # model would be shown a literal "None".
            user += (
                " Take in consideration the user instructions for later story writing: "
                f"```{request.additional_context}```"
            )
        messages = self._vision_messages(system, user, img_bytes_url)

        structured = (
            ChatOllama(
                model=self._vision_model_name,
                base_url=self._ollama_url,
                temperature=0.3,
                num_ctx=8192,
            )
            .with_structured_output(ImageInsights)
        )

        result = cast(ImageInsights, await structured.ainvoke(messages))
        self._cache_set("image_insights", result.model_dump(), cache_name)
        return result

    async def _generate_story(
        self,
        request: StoryGenerationRequest,
        insights: ImageInsights,
        cache_name: str | None = None,
    ) -> StoryGenerationResponse:
        """Write the story text with the text model from the image insights.

        The title comes from ``insights.title``, not from the model. The
        result is cached under the ``"story"`` step.
        """
        cached = self._cache_get("story", cache_name)
        if cached:
            return StoryGenerationResponse(**cached)

        # Word budget for ~4 minutes of narration at this flavor's speaking rate,
        # shrunk by speech_margin to leave room for pauses and extra effects.
        wpm = flavour_to_wpm[request.flavor]
        minutes = 4.0
        speech_margin = 0.92
        max_words = int(wpm * minutes * speech_margin)

        # Generation cap: ~1.3 tokens per word, clamped to [256, 1024].
        tokens_to_predict = max(256, min(1024, int(max_words * 1.3)))

        if request.eighting_plus_enabled:
            content_guideline = (
                "18+ is enabled: mature themes are permitted. Do NOT depict minors, "
                "illegal or non-consensual acts. Avoid pornographic detail; be tasteful."
            )
        else:
            content_guideline = (
                "18+ is NOT enabled: content must be suitable for under-18. 16+ content is allowed. "
                "No explicit sexual content; no graphic violence/gore."
            )

        system = (
            "You are a seasoned storyteller. Write vivid, coherent prose tailored to the requested flavor."
            " Keep language accessible and engaging. "
            " Use natural rhythm for spoken delivery: short to medium sentences, varied cadence."
            " Format the story with clear paragraph breaks (one blank line between paragraphs)."
            " For major shifts in scene or time, insert a line with '---' as a break."
            # Leading space: otherwise this sentence fuses with the previous one ("break.18+").
            f" {content_guideline}"
            " Output only the final story text; do not include a title or any commentary."
        )
        flavor_line = f"Flavor: {request.flavor.value}."
        # Omit the instructions line entirely when there is no extra context,
        # rather than sending "Additional instructions for the story: None".
        context_line = (
            f"Additional instructions for the story: {request.additional_context}"
            if request.additional_context
            else ""
        )
        insight_brief = (
            "Use the following details as inspiration for your story, but feel free to creatively expand, add new elements, or imagine additional context to make the story more engaging and vivid:\n"
            f"- Story Title: {insights.title}\n"
            f"- Caption: {insights.caption}\n"
            f"- Subjects: {', '.join(insights.subjects)}\n"
            f"- Setting: {insights.setting}\n"
            f"- Actions: {', '.join(insights.actions)}\n"
            f"- Mood: {', '.join(insights.mood)}\n"
            f"- Possible hooks (may be ignored): {', '.join(insights.hooks)}\n"
        )
        user = (
            f"Write a story. The story text should be medium-large: at least 300 words, but less than the {max_words} words.\n"
            f"{flavor_line}\n\n"
            f"{insight_brief}\n"
            f"{context_line}"
        )

        messages = [
            {"role": "system", "content": [{"type": "text", "text": system}]},
            {"role": "user", "content": [{"type": "text", "text": user}]},
        ]

        # Plain (unstructured) chat call: the story is free text.
        llm = ChatOllama(
            model=self._txt_model_name,
            base_url=self._ollama_url,
            temperature=1.2,
            num_ctx=8192,
            num_predict=tokens_to_predict,
        )

        story_text = (await llm.ainvoke(messages)).content
        # Use the title from image insights, not from the model output.
        result = StoryGenerationResponse(
            title=insights.title,
            text=story_text,
        )
        self._cache_set("story", result.model_dump(), cache_name)
        return result

    @staticmethod
    def _image_to_data_url(image_bytes: bytes, mime: str = "image/jpeg") -> str:
        """Encode ``image_bytes`` as a base64 ``data:`` URL with the given MIME type."""
        b64 = base64.b64encode(image_bytes).decode("ascii")
        return f"data:{mime};base64,{b64}"

    @staticmethod
    def _convert_image_to_jpeg(image_bytes: bytes) -> bytes:
        """Decode any Pillow-readable image and re-encode it as an RGB JPEG.

        The EXIF orientation tag is applied to the pixels first. The re-encode
        does not carry EXIF over, so without this, rotated phone photos would
        reach the model sideways.
        """
        from PIL import ImageOps  # Pillow is already a dependency (Image is imported)

        # Close the source image once the upright RGB copy exists.
        with BytesIO(image_bytes) as input_buffer, Image.open(input_buffer) as source:
            image = ImageOps.exif_transpose(source).convert("RGB")

        with BytesIO() as output_buffer:
            image.save(output_buffer, format="JPEG", quality=92, subsampling=2, optimize=True, progressive=False)
            return output_buffer.getvalue()

In [62]:
# Shared generator; model names and the Ollama URL are read from the environment here.
generator: StoryGenerator = StoryGenerator()

In [55]:
with open("./data/some-basic.jpg", "rb") as f:
    image_bytes = f.read()

result = await generator.generate(
    StoryGenerationRequest(
        flavor=StoryFlavor.FAIRY_TALE,
        additionalContext="A fairy tale about a princess and a prince.",
        eightingPlusEnabled=False,
    ),
    image_bytes=image_bytes,
    cache_name="fairy_tale_basic",
)

Performing elder content check...
Getting image insights...
Got image insights: %s title='The Enchanted Ballroom' caption='A magical ballroom with glowing neon lights and a reflective dance floor' subjects=[] setting='A grand ballroom with a dance floor that reflects the vibrant neon lights creating a mesmerizing effect' actions=[] mood=['mysterious', 'magical', 'romantic'] hooks=["The ballroom's glowing dance floor", 'The neon lights reflecting off the mirrored walls', 'The fairy tale atmosphere of the ballroom']
You are a seasoned storyteller. Write vivid, coherent prose tailored to the requested flavor. Keep language accessible and engaging.  Use natural rhythm for spoken delivery: short to medium sentences, varied cadence. Format the story with clear paragraph breaks (one blank line between paragraphs). For major shifts in scene or time, insert a line with '---' as a break.18+ is NOT enabled: content must be suitable for under-18. 16+ content is allowed. No explicit sexual content;

In [60]:
# Character count, title, then the full story — one item per line.
print(len(result.text), result.title, result.text, sep="\n")

2245
The Enchanted Ballroom
In the heart of a grand old castle stood the Enchanted Ballroom, a place where magic danced with every step. Its walls shimmered like polished silver, and its floors glistened under the soft glow of neon lights that painted the air in hues of blue and green. The ballroom was more than just a room; it was a realm where dreams came true.

Princess Elara, with her long golden hair cascading down to her waist, stood before a full-length mirror, admiring her reflection as she adjusted the lace of her dress. Her emerald eyes sparkled with anticipation for the evening’s grand ball. Tonight, the enchanted ballroom would host a dance that promised to be magical and memorable.

Prince Liam, dressed in a suit that shone like polished silver, strode into the ballroom, his gaze sweeping over the room with an air of confidence and grace. His deep blue eyes met Elara’s as he approached, and she felt her heart flutter. He wore a smile that made him seem both regal and appro

In [None]:
# Recompute the word budget StoryGenerator._generate_story uses for a fairy tale.
# speech_margin must match the value used there (0.92) for this check to be meaningful.
wpm = flavour_to_wpm[StoryFlavor.FAIRY_TALE]
minutes = 4.0
speech_margin = 0.92  # same margin as the generator: room for pauses and extra effects
max_words = int(wpm * minutes * speech_margin)
max_words