# Interactive Story Generator - Complete Workflow

This notebook creates interactive stories for SubQuest using a combination of local and external LLMs for optimal results.

## Complete Workflow:
1. **Set up Ollama** with Llama 3.1 8B (`llama3.1:8b`) for story concept refinement
2. **Connect Google Drive** for automatic file storage (optional)
3. **Configure Story Parameters** (nodes, choices, themes, atmosphere)
4. **Refine Story Concept** using local Llama 3.1:8B
5. **Generate External LLM Prompt** optimized for ChatGPT/Claude/Gemini
6. **Create Story Structure** via external LLM (recommended) or local generation
7. **Visualize Story Tree** for review and approval
8. **Generate Context-Aware Image Prompts** using story analysis
9. **Create Images** with memory-optimized Stable Diffusion
10. **Export Final JSON** in SubQuest-compatible format

**💡 Recommended Approach:** Use local Llama for concept refinement and prompt generation, then external LLMs (ChatGPT-4, Claude, Gemini) for actual story creation due to their superior context handling and consistency.

**Run each cell in order for the complete experience!**


## 📦 Step 1: Import Required Packages

In [None]:
# Core libraries
import json
import requests
import os
import time
import subprocess
import uuid
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field
from datetime import datetime

# Visualization
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import FancyBboxPatch
import networkx as nx

# Detect Google Colab: the google.colab package can only be imported inside a
# Colab runtime, so a failed import means we are running locally.
try:
    from google.colab import drive, files
except ImportError:
    IN_COLAB = False
    print("💻 Running locally")
else:
    IN_COLAB = True
    print("🔗 Running in Google Colab")

# Optional image-generation stack. IMAGES_AVAILABLE records whether all three
# imports succeeded; the message below says they will be installed when needed.
try:
    from PIL import Image
    import torch
    from diffusers import StableDiffusionPipeline, DiffusionPipeline
except ImportError:
    IMAGES_AVAILABLE = False
    print("⚠️ Image generation libraries not available - will install when needed")
else:
    IMAGES_AVAILABLE = True

print("✅ Core packages imported successfully!")
print("🚀 Ready to create interactive stories!")

🔗 Running in Google Colab


## 🔧 Step 2: Install and Set Up Ollama

In [None]:
class OllamaSetup:
    """Install, launch and probe a local Ollama server for this notebook.

    Attributes:
        base_url: HTTP endpoint of the local Ollama API.
        server_process: ``Popen`` handle of the ``ollama serve`` process
            launched by :meth:`start_server`, or ``None`` before that.
        log_path: File that receives the server's stdout and stderr.
    """

    def __init__(self, log_path: str = "ollama_server.log"):
        self.base_url = "http://localhost:11434"
        self.server_process = None
        self.log_path = log_path

    def install_ollama(self):
        """Download the official install script and run it with ``sh``.

        Returns:
            True if both the download and the script succeeded, else False.
        """
        print("🔧 Installing Ollama...")

        try:
            # Equivalent to `curl ... | sh`, but without building a shell
            # pipeline: fetch the script text, then feed it to `sh` on stdin.
            script = subprocess.run(
                ["curl", "-fsSL", "https://ollama.com/install.sh"],
                capture_output=True, text=True, check=True
            )
            subprocess.run(
                ["sh"], input=script.stdout,
                capture_output=True, text=True, check=True
            )
        except subprocess.CalledProcessError as e:
            print(f"❌ Installation failed: {e}")
            # The captured stderr is the only hint as to why curl/sh failed.
            if e.stderr:
                print(e.stderr.strip()[-2000:])
            return False
        except Exception as e:
            print(f"❌ Installation failed: {e}")
            return False

        print("✅ Ollama installed successfully!")

        # Exported so child processes (e.g. `ollama serve`) inherit it.
        # NOTE(review): confirm Ollama actually reads OLLAMA_USE_CUDA.
        os.environ['OLLAMA_USE_CUDA'] = '1'
        print("🚀 GPU acceleration enabled")
        return True

    def start_server(self, startup_timeout: float = 30.0):
        """Launch ``ollama serve`` in the background and wait for the API.

        Args:
            startup_timeout: Maximum seconds to keep polling ``/api/tags``
                for an HTTP 200 before giving up.

        Returns:
            True once the API answers with HTTP 200, else False.
        """
        print("🔄 Starting Ollama server...")

        try:
            # Log to a file rather than an unread PIPE: if nobody drains a pipe,
            # its OS buffer fills and the server blocks on its next log write.
            # The child keeps its own copy of the descriptor after we close ours.
            with open(self.log_path, "ab") as log_file:
                self.server_process = subprocess.Popen(
                    ["ollama", "serve"],
                    stdout=log_file,
                    stderr=subprocess.STDOUT,
                    env=dict(os.environ, OLLAMA_USE_CUDA='1')
                )
        except Exception as e:
            print(f"❌ Failed to start server: {e}")
            return False

        return self._wait_until_ready(startup_timeout)

    def _wait_until_ready(self, timeout: float) -> bool:
        """Poll ``/api/tags`` until HTTP 200, the timeout, or the server exits."""
        deadline = time.monotonic() + timeout
        last_problem = "no response"
        while True:
            try:
                response = requests.get(f"{self.base_url}/api/tags", timeout=5)
                if response.status_code == 200:
                    print("✅ Ollama server is running!")
                    return True
                last_problem = str(response.status_code)
            except requests.RequestException as e:
                last_problem = str(e)

            # Our process died and nothing answers on the port: waiting is futile.
            if self.server_process is not None and self.server_process.poll() is not None:
                print(f"❌ Failed to start server: process exited with code "
                      f"{self.server_process.returncode} (see {self.log_path})")
                return False
            if time.monotonic() >= deadline:
                print(f"❌ Server not responding: {last_problem}")
                return False
            time.sleep(0.5)

    @staticmethod
    def _parse_gpu_info(smi_output: str) -> Optional[Tuple[str, str]]:
        """Return ``(name, memory)`` of the first GPU in nvidia-smi CSV output.

        Expects lines like ``"Tesla T4, 15360 MiB"`` (one per GPU). Returns
        None when there is no non-blank line or it has no comma.
        """
        for line in smi_output.splitlines():
            line = line.strip()
            if not line:
                continue
            # Memory is the last column; everything before it is the name.
            name, sep, memory = line.rpartition(",")
            if not sep:
                return None
            return name.strip(), memory.strip()
        return None

    def check_gpu(self):
        """Report the first NVIDIA GPU via ``nvidia-smi``.

        Returns:
            True if a GPU was reported, False if none was found or the query
            failed (for example when ``nvidia-smi`` is not installed).
        """
        try:
            result = subprocess.run(
                ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
                capture_output=True, text=True, check=True, timeout=30
            )
        except (OSError, subprocess.SubprocessError):
            print("⚠️ GPU check failed - will use CPU")
            return False

        gpu_info = self._parse_gpu_info(result.stdout)
        if gpu_info is None:
            print("⚠️ No GPU detected - will use CPU")
            return False

        name, memory = gpu_info
        print(f"🎮 GPU: {name}")
        print(f"💾 Memory: {memory}")
        return True

# Build the setup helper and probe for a GPU before installing anything.
ollama_setup = OllamaSetup()
has_gpu = ollama_setup.check_gpu()

# The server is only started once installation has succeeded.
if not ollama_setup.install_ollama():
    print("\n❌ Installation failed")
elif not ollama_setup.start_server():
    print("\n❌ Server failed to start")
else:
    print("\n🎉 Ollama is ready for model installation!")
    print(
        "💡 GPU detected - you can use larger models for better quality"
        if has_gpu
        else "💡 No GPU - recommend using smaller models (1B-3B)"
    )

## 🤖 Step 3: Install Llama Model

In [None]:
class ModelManager:
    """Make sure the required Ollama model is present locally and responding.

    All requests go to the Ollama HTTP API at ``base_url``.
    """

    def __init__(self):
        self.base_url = "http://localhost:11434"
        self.current_model = "llama3.1:8b"
        self.required_model = "llama3.1:8b"

    def setup_required_model(self):
        """Ensure ``required_model`` is downloaded and answers a test prompt.

        An already-downloaded copy is used if it passes :meth:`test_model`.
        Otherwise the model is pulled (again) and re-tested.

        Returns:
            True if the model ends up working, else False.
        """
        print("🤖 Setting up Llama 3.1:8B Model")
        print("=" * 50)
        print("📋 This notebook requires Llama 3.1:8B for optimal story concept refinement.")
        print("💡 For actual story generation, we'll create a prompt for external LLMs.")
        print()

        if self.required_model in self.list_downloaded_models():
            print(f"✅ {self.required_model} is already available!")
            if self.test_model(self.required_model):
                print("🎉 Model is ready for use!")
                return True
            # Present but not answering: fall through and pull it again.

        print(f"📥 Installing {self.required_model}...")
        print("📊 Size: ~5.0GB - This may take several minutes")
        print("💡 This is a one-time download")

        if not self.pull_model(self.required_model):
            print(f"❌ Failed to install {self.required_model}")
            return False
        if not self.test_model(self.required_model):
            print("❌ Model installed but not working properly")
            return False
        print(f"🎉 {self.required_model} is ready!")
        return True

    def pull_model(self, model_name: str) -> bool:
        """Download *model_name* via ``POST /api/pull`` and show its progress.

        Returns:
            True only if the streamed response reports ``"status": "success"``.
            Returns False on a non-200 reply, an in-stream ``"error"`` message,
            a network failure, or a stream that ends without success.
        """
        try:
            with requests.post(
                f"{self.base_url}/api/pull",
                # Current API docs call this field "model"; "name" is the older
                # spelling and is kept for older servers.
                json={"model": model_name, "name": model_name},
                stream=True,
                # requests applies this to connecting and to each socket read,
                # not to the whole download, so long pulls are fine while data flows.
                timeout=600
            ) as response:
                if response.status_code != 200:
                    print(f"\n❌ Failed to download {model_name}: HTTP {response.status_code}")
                    return False
                print("📦 Downloading...")
                return self._consume_pull_stream(response.iter_lines(), model_name)
        except Exception as e:
            print(f"\n❌ Download error: {e}")
            return False

    @staticmethod
    def _consume_pull_stream(lines, model_name: str) -> bool:
        """Print progress for each NDJSON pull line and report whether it succeeded.

        Args:
            lines: Iterable of raw ``bytes`` (or ``str``) lines, such as
                ``Response.iter_lines()`` of a streaming ``/api/pull`` call.
            model_name: Only used in the printed messages.

        Returns:
            True on a ``"success"`` status, False on an ``"error"`` entry or
            when the stream ends first.
        """
        for raw in lines:
            if not raw:
                continue  # keep-alive / blank line
            try:
                text = raw.decode('utf-8') if isinstance(raw, bytes) else raw
                data = json.loads(text)
            except (UnicodeDecodeError, json.JSONDecodeError):
                continue
            if not isinstance(data, dict):
                continue

            # Failures can arrive in-band after the HTTP 200 has already been sent.
            if 'error' in data:
                print(f"\n❌ Failed to download {model_name}: {data['error']}")
                return False

            status = data.get('status')
            if status == 'success':
                print(f"\n✅ {model_name} downloaded successfully!")
                return True
            if status:
                total, completed = data.get('total'), data.get('completed')
                if total and completed is not None:
                    status = f"{status} ({completed * 100 // total}%)"
                # Pad so a shorter status fully overwrites the previous one.
                print(f"\r{status}".ljust(80), end='', flush=True)

        print(f"\n❌ {model_name} download ended without a success status")
        return False

    def test_model(self, model_name: str, timeout: float = 120) -> bool:
        """Ask *model_name* for one sentence; True if it returns non-empty text.

        Args:
            model_name: Ollama model tag to query.
            timeout: Seconds to wait for the reply. The first request after a
                download also loads the model into memory, which may take
                longer than 30 s (especially without a GPU).
        """
        try:
            response = requests.post(
                f"{self.base_url}/api/generate",
                json={
                    "model": model_name,
                    "prompt": "Write one sentence about storytelling.",
                    "stream": False
                },
                timeout=timeout
            )
            if response.status_code != 200:
                return False
            text = (response.json().get('response') or '').strip()
        except Exception:
            # Best effort: any failure means "not working". `except Exception`
            # (not bare except) still lets KeyboardInterrupt stop the cell.
            return False

        if text:
            print("✅ Model test successful!")
            return True
        return False

    def list_downloaded_models(self) -> list:
        """Return names of locally available models (``GET /api/tags``); [] on failure."""
        try:
            response = requests.get(f"{self.base_url}/api/tags", timeout=10)
            if response.status_code != 200:
                return []
            models = response.json().get('models', [])
            return [model['name'] for model in models if 'name' in model]
        except Exception:
            return []

# Create the model manager, then make sure the required model is downloaded
# and answering before moving on.
model_manager = ModelManager()

if not model_manager.setup_required_model():
    print("\n❌ Could not setup required model. Please check your Ollama installation.")
else:
    print(f"\n🎉 Ready to proceed with {model_manager.current_model}!")


## 💾 Step 4: Set Up Google Drive Connection

In [None]:
class DriveManager:
    """Choose and prepare the folder where this story's files are written.

    In Colab the user may mount Google Drive. Otherwise a timestamped folder
    is created under the current working directory.

    Attributes:
        drive_mounted: True after ``drive.mount`` succeeded.
        drive_path: Google Drive mount point (Colab) or a local stand-in path.
        project_folder: Folder created for this story, or ``None`` before
            setup. It always gets an ``images`` subfolder.
        use_drive: True only when the project folder lives on Google Drive.
    """

    # Used when the user leaves the image base URL blank.
    DEFAULT_IMAGE_BASE_URL = "./images/"

    def __init__(self):
        self.drive_mounted = False
        self.drive_path = "/content/drive" if IN_COLAB else "./local_drive"
        self.project_folder = None
        self.use_drive = False

    def setup_storage(self):
        """Ask (in Colab only) whether to use Google Drive, then create the project folder.

        Returns:
            True once a project folder exists, on Drive or locally.
        """
        print("💾 Storage Setup")
        print("=" * 40)

        if IN_COLAB:
            choice = input("Connect Google Drive for automatic storage? (y/n): ").strip().lower()
            if choice in ('y', 'yes'):
                return self.connect_drive()
            print("📁 Using local storage (files will be downloadable)")
        else:
            print("📁 Running locally - using local storage")

        self.use_drive = False
        return self.setup_local_storage()

    def connect_drive(self):
        """Mount Google Drive at ``drive_path`` and create a timestamped project folder.

        If mounting or folder creation fails, falls back to
        :meth:`setup_local_storage`. ``use_drive`` is only set to True once
        both steps have succeeded.
        """
        print("🔗 Connecting to Google Drive...")

        try:
            drive.mount(self.drive_path)
            self.drive_mounted = True

            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            project_folder = os.path.join(
                self.drive_path, "MyDrive", "SubQuest_Stories", f"Story_{timestamp}"
            )
            self._make_project_dirs(project_folder)
        except Exception as e:
            print(f"❌ Drive connection failed: {e}")
            print("📁 Falling back to local storage")
            # Files will be written locally, so image URLs must not use Drive.
            self.use_drive = False
            return self.setup_local_storage()

        self.use_drive = True
        self.project_folder = project_folder
        print("✅ Google Drive connected!")
        print(f"📁 Project folder: {self.project_folder}")
        return True

    def setup_local_storage(self):
        """Create ``./SubQuest_Story_<timestamp>`` (with ``images/``) and use it."""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.project_folder = f"./SubQuest_Story_{timestamp}"
        self._make_project_dirs(self.project_folder)

        print("✅ Local storage ready!")
        print(f"📁 Project folder: {self.project_folder}")
        return True

    @staticmethod
    def _make_project_dirs(folder: str) -> None:
        """Create *folder* and its ``images`` subfolder (no error if they exist)."""
        os.makedirs(os.path.join(folder, "images"), exist_ok=True)

    @classmethod
    def _normalize_base_url(cls, raw: str) -> str:
        """Return a base URL ending in ``/``; blank input gives the default.

        The trailing slash lets an image file name be appended directly (or
        via ``urljoin``) without losing the last path segment.
        """
        base_url = raw.strip()
        if not base_url:
            return cls.DEFAULT_IMAGE_BASE_URL
        if not base_url.endswith("/"):
            base_url += "/"
        return base_url

    def get_image_base_url(self):
        """Return the prefix used to build image URLs.

        For Drive storage this is the placeholder ``"drive://"``. Otherwise the
        user is asked; a blank answer gives ``"./images/"``.
        """
        if self.use_drive:
            # Placeholder scheme; the original intent is to use Drive file IDs later.
            return "drive://"

        print("\n🌐 Image URL Configuration")
        print("For local storage, you need to specify where images will be hosted.")

        raw = input("Enter base URL for images (e.g., 'https://mysite.com/images/'): ")
        base_url = self._normalize_base_url(raw)
        if not raw.strip():
            print(f"Using default: {base_url}")
        return base_url

# Create the storage helper and run the interactive storage prompt; on success
# also ask where images will be served from.
drive_manager = DriveManager()

if not drive_manager.setup_storage():
    print("\n❌ Storage setup failed")
else:
    image_base_url = drive_manager.get_image_base_url()
    print("\n🎉 Storage configured successfully!")
    print(f"📁 Project: {drive_manager.project_folder}")
    print(f"🖼️ Image base URL: {image_base_url}")

## ⚙️ Step 5: Configure Story Parameters

In [None]:
@dataclass
class StoryConfig:
    """Complete story configuration.

    Built by ``StoryConfigurator.interactive_setup`` and interpolated into the
    concept-refinement and external-LLM prompts.
    """
    title: str  # story title
    description: str  # short description shown with the title
    theme: str  # genre, e.g. "fantasy" or "sci-fi"
    target_nodes: int  # total number of story nodes requested
    max_choices_per_node: int  # upper bound on choices for a non-ending node
    min_choices_per_node: int  # lower bound on choices for a non-ending node
    allow_fan_in: bool  # whether several paths may lead into the same node
    ending_nodes: int  # number of distinct ending nodes requested
    story_concept: str  # the user's free-text story idea
    tone: str  # e.g. "serious", "humorous"
    target_audience: str  # e.g. "all ages", "teen", "adult"

class StoryConfigurator:
    """Collect a :class:`StoryConfig` from the user through ``input()`` prompts."""

    def __init__(self):
        self.config = None

    def interactive_setup(self):
        """Prompt for every story parameter, store the resulting StoryConfig and return it.

        Blank answers fall back to defaults. After the summary is printed, a
        warning is shown if no story graph can satisfy the requested node,
        ending and choice counts.
        """
        print("⚙️ Story Configuration")
        print("=" * 50)
        print("Let's configure your interactive story parameters!")
        print()

        # Basic story info
        title = input("📖 Story Title: ").strip() or "My Interactive Adventure"
        description = input("📝 Brief Description: ").strip() or "An exciting interactive story"
        theme = input("🎭 Theme/Genre (fantasy, sci-fi, mystery, adventure, etc.): ").strip() or "adventure"

        # Story structure
        print("\n📊 Story Structure Configuration:")

        target_nodes = self._get_int_input(
            "🎯 Target number of story nodes (5-15 recommended): ",
            default=8, min_val=3, max_val=20
        )

        max_choices = self._get_int_input(
            "🔀 Maximum choices per node (2-4 recommended): ",
            default=3, min_val=2, max_val=5
        )

        min_choices = self._get_int_input(
            "🔀 Minimum choices per node (2-3 recommended): ",
            default=2, min_val=2, max_val=max_choices
        )

        # Advanced options
        print("\n🔧 Advanced Options:")

        fan_in_choice = input("🔄 Allow fan-in (multiple paths to same node)? (y/n): ").strip().lower()
        allow_fan_in = fan_in_choice in ('y', 'yes')

        # Endings have no choices, so at least the start node must be a
        # non-ending node: never allow as many endings as nodes.
        max_endings = self._max_endings(target_nodes)
        ending_nodes = self._get_int_input(
            "🏁 Number of different endings (2-5 recommended): ",
            default=min(3, max_endings), min_val=1, max_val=max_endings
        )

        # Story details
        print("\n✍️ Story Details:")

        story_concept = input("💡 Story Concept (describe your story idea): ").strip()
        if not story_concept:
            story_concept = f"An interactive {theme} story with meaningful choices"

        tone = input("🎨 Tone (serious, humorous, dark, light, etc.): ").strip() or "engaging"
        audience = input("👥 Target Audience (all ages, teen, adult, etc.): ").strip() or "all ages"

        self.config = StoryConfig(
            title=title,
            description=description,
            theme=theme,
            target_nodes=target_nodes,
            max_choices_per_node=max_choices,
            min_choices_per_node=min_choices,
            allow_fan_in=allow_fan_in,
            ending_nodes=ending_nodes,
            story_concept=story_concept,
            tone=tone,
            target_audience=audience
        )

        self.display_config()

        warning = self._structure_warning(
            target_nodes, ending_nodes, min_choices, max_choices, allow_fan_in
        )
        if warning:
            print(warning)

        return self.config

    @staticmethod
    def _max_endings(target_nodes: int) -> int:
        """Largest allowed ending count: at most 8 and at most ``target_nodes - 1``."""
        return max(1, min(8, target_nodes - 1))

    @staticmethod
    def _structure_warning(target_nodes: int, ending_nodes: int, min_choices: int,
                           max_choices: int, allow_fan_in: bool) -> Optional[str]:
        """Return a warning if no story graph can satisfy the counts, else None.

        Every node except the start needs at least one incoming choice, and
        only the ``target_nodes - ending_nodes`` non-ending nodes have choices.
        Without fan-in each node has exactly one incoming choice, so the graph
        is a tree with exactly ``target_nodes - 1`` choices in total.
        """
        branching = target_nodes - ending_nodes
        needed = target_nodes - 1
        problems = []
        if branching * max_choices < needed:
            problems.append(
                f"{branching} non-ending node(s) with at most {max_choices} choices "
                f"cannot reach all {target_nodes} nodes"
            )
        if not allow_fan_in and branching * min_choices > needed:
            problems.append(
                f"without fan-in there are exactly {needed} choices in total, but "
                f"{branching} non-ending node(s) need at least {branching * min_choices}"
            )
        if not problems:
            return None
        return ("⚠️ This structure is not achievable: " + "; ".join(problems)
                + ". Consider adjusting nodes, endings, minimum choices, or fan-in.")

    @staticmethod
    def _format_int_prompt(prompt: str, default: int) -> str:
        """Append ``(default: N): `` to *prompt*, dropping its own trailing colon."""
        return f"{prompt.rstrip().rstrip(':')} (default: {default}): "

    def _get_int_input(self, prompt: str, default: int, min_val: int, max_val: int) -> int:
        """Ask until the user enters an int in [min_val, max_val]; blank returns *default*."""
        full_prompt = self._format_int_prompt(prompt, default)
        while True:
            value = input(full_prompt).strip()
            if not value:
                return default
            try:
                num = int(value)
            except ValueError:
                print("❌ Please enter a valid number")
                continue
            if min_val <= num <= max_val:
                return num
            print(f"❌ Please enter a number between {min_val} and {max_val}")

    def display_config(self):
        """Print a summary of ``self.config`` (or a notice if none is set)."""
        if not self.config:
            print("❌ No configuration available")
            return

        print("\n✅ Story Configuration Summary:")
        print("=" * 50)
        print(f"📖 Title: {self.config.title}")
        print(f"📝 Description: {self.config.description}")
        print(f"🎭 Theme: {self.config.theme}")
        print(f"🎯 Target Nodes: {self.config.target_nodes}")
        print(f"🔀 Choices per Node: {self.config.min_choices_per_node}-{self.config.max_choices_per_node}")
        print(f"🔄 Fan-in Allowed: {'Yes' if self.config.allow_fan_in else 'No'}")
        print(f"🏁 Ending Nodes: {self.config.ending_nodes}")
        print(f"💡 Concept: {self.config.story_concept}")
        print(f"🎨 Tone: {self.config.tone}")
        print(f"👥 Audience: {self.config.target_audience}")
        print("=" * 50)

# Collect the story settings interactively; the prompt-generation cell below
# reads the resulting `story_config`.
configurator = StoryConfigurator()
story_config = configurator.interactive_setup()

print("\n🎉 Configuration complete! Ready for prompt generation.")

## 🧠 Step 6: Generate Optimized Prompt with LLM

This step uses your local Llama 3.1:8B model to refine your story concept and create an optimized prompt for external LLMs. The local model analyzes your basic concept and expands it into a rich, detailed story description perfect for interactive storytelling.

The generated prompt is specifically designed for ChatGPT, Claude, or Gemini, which have larger context windows and better consistency for complex story generation.


In [None]:
class ExternalPromptGenerator:
    """Refine a story concept with a local Ollama model and build a copy-paste
    prompt that asks an external LLM for the full story JSON.
    """

    def __init__(self, model_name: str):
        self.model_name = model_name
        self.base_url = "http://localhost:11434"

    def refine_story_concept(self, config) -> str:
        """Ask the local model to expand ``config.story_concept``.

        Returns:
            The refined concept. Falls back to ``config.story_concept`` when the
            reply is empty, non-200, or the request raises.
        """
        print(f"🧠 Refining story concept with {self.model_name}...")

        refinement_prompt = f"""You are a creative writing expert. Take this basic story concept and expand it into a rich, detailed story description perfect for interactive storytelling.

Original Story Details:
- Title: {config.title}
- Theme: {config.theme}
- Basic Concept: {config.story_concept}
- Tone: {config.tone}
- Target Audience: {config.target_audience}

Technical Requirements:
- Will have {config.target_nodes} story nodes
- Each node will have {config.min_choices_per_node}-{config.max_choices_per_node} choices
- Will have {config.ending_nodes} different endings
- Fan-in allowed: {'Yes' if config.allow_fan_in else 'No'}

Create an enhanced story concept that includes:
1. A compelling opening scenario with specific setting details
2. Key characters with clear motivations and backgrounds
3. The central conflict or challenge driving the story
4. Potential plot developments, twists, and complications
5. How different choices lead to meaningful consequences
6. Varied ending possibilities with different outcomes
7. Rich atmospheric and world-building details
8. Specific scenarios where player agency matters

Make it detailed, engaging, and perfect for creating {config.target_nodes} interconnected story nodes. Focus on creating scenarios where choices have real impact on the story's direction and outcome.

Return only the enhanced story concept, nothing else."""

        try:
            response = requests.post(
                f"{self.base_url}/api/generate",
                json={
                    "model": self.model_name,
                    "prompt": refinement_prompt,
                    "stream": False,
                    "options": {
                        "temperature": 0.8,
                        "top_p": 0.9,
                        "num_ctx": 4096
                    }
                },
                timeout=120
            )

            if response.status_code != 200:
                print(f"❌ Refinement failed: HTTP {response.status_code}")
                return config.story_concept

            refined_concept = response.json().get('response', '').strip()
            if not refined_concept:
                print("❌ Empty response, using original concept")
                return config.story_concept

            print("✅ Story concept refined!")
            return refined_concept

        except Exception as e:
            print(f"❌ Error refining concept: {e}")
            return config.story_concept

    def generate_external_prompt(self, config, refined_concept: str) -> str:
        """Build a complete prompt for external LLMs like ChatGPT or Gemini.

        Args:
            config: Story settings (title, counts, tone, ...).
            refined_concept: Approved concept text, inserted verbatim.

        Returns:
            Prompt text containing the requirements and an example of the
            expected JSON format.
        """
        # Values placed inside the JSON example are JSON-encoded, so quotes,
        # backslashes or newlines in user input keep the example valid JSON.
        title_json = json.dumps(config.title, ensure_ascii=False)
        description_json = json.dumps(config.description, ensure_ascii=False)

        prompt_text = f"""# Interactive Story Generation Task

You are tasked with creating a complete interactive story in JSON format. Follow these specifications exactly:

## Story Requirements

**Basic Information:**
- Title: {config.title}
- Description: {config.description}
- Theme: {config.theme}
- Tone: {config.tone}
- Target Audience: {config.target_audience}

**Enhanced Story Concept:**
{refined_concept}

**Technical Specifications:**
- Total nodes: {config.target_nodes}
- Choices per non-ending node: {config.min_choices_per_node} to {config.max_choices_per_node}
- Number of different endings: {config.ending_nodes}
- Fan-in allowed: {'Yes' if config.allow_fan_in else 'No'}
- All ending nodes must be reachable from the start node

## Critical Requirements

1. RETURN ONLY THE JSON - No explanations, no markdown formatting, no additional text
2. Every ending node must be reachable through valid choice paths from the start
3. Use descriptive node IDs (not just numbers like "node1", "node2")
4. Each story node should have 2-3 engaging paragraphs of content
5. Choices must have meaningful consequences that affect the story outcome

## Exact JSON Format Required

Return ONLY the JSON in this exact format (the example is abbreviated: it shows the required fields and nesting, not the required number of nodes or choices):

{{
  "title": {title_json},
  "description": {description_json},
  "startNodeId": "opening_scene",
  "nodes": {{
    "opening_scene": {{
      "id": "opening_scene",
      "title": "Descriptive Scene Title",
      "content": "Engaging story content here. Write 2-3 paragraphs that immerse the reader in the scene. Include sensory details, character thoughts, and atmospheric elements that bring the story to life.",
      "imageUrl": "placeholder.jpg",
      "choices": [
        {{
          "id": "choice_1",
          "text": "Descriptive choice that shows consequences",
          "nextNodeId": "consequence_node_1"
        }},
        {{
          "id": "choice_2",
          "text": "Alternative choice with different outcome",
          "nextNodeId": "consequence_node_2"
        }}
      ]
    }},
    "consequence_node_1": {{
      "id": "consequence_node_1",
      "title": "Result of First Choice",
      "content": "Content showing the consequences of the first choice. Continue the narrative based on the player's decision.",
      "imageUrl": "placeholder.jpg",
      "choices": [
        {{
          "id": "choice_3",
          "text": "Next decision point",
          "nextNodeId": "ending_1"
        }}
      ]
    }},
    "ending_1": {{
      "id": "ending_1",
      "title": "One Possible Ending",
      "content": "A complete ending that resolves the story based on the choices made. Make it satisfying and conclusive.",
      "imageUrl": "placeholder.jpg",
      "isEnd": true
    }}
  }}
}}

## Validation Rules

- All nextNodeId values must reference actual node IDs in the nodes object
- Ending nodes must have "isEnd": true and NO choices array
- Non-ending nodes must NOT have the isEnd field
- Every node must have: id, title, content, imageUrl
- Every choice must have: id, text, nextNodeId
- Ensure exactly {config.target_nodes} total nodes
- Ensure exactly {config.ending_nodes} ending nodes
- Make sure all endings are reachable through different choice paths

## Content Guidelines

- Write engaging, immersive content for each node
- Make choices meaningful - they should lead to genuinely different outcomes
- Include rich descriptions that help visualize the scenes
- Maintain consistency with the {config.theme} theme and {config.tone} tone
- Keep the {config.target_audience} audience in mind
- Each node should advance the story or develop character/plot

Generate the complete interactive story now. Remember: RETURN ONLY THE JSON, nothing else."""

        return prompt_text

    def display_external_prompt(self, prompt: str):
        """Print *prompt* between separators, followed by copy-paste instructions."""
        print("\n" + "="*80)
        print("🚀 EXTERNAL LLM PROMPT READY")
        print("="*80)
        print("📋 Copy the prompt below and paste it into ChatGPT, Claude, or Gemini:")
        print("💡 These models have larger context windows and can handle complex stories better.")
        print("\n" + "-"*80)
        print(prompt)
        print("-"*80)
        print("\n📝 Instructions:")
        print("1. Copy the entire prompt above")
        print("2. Paste it into your preferred LLM (ChatGPT-4, Claude, Gemini Pro)")
        print("3. Wait for the JSON response")
        print("4. Copy the JSON output")
        print("5. Return to this notebook and paste it in the next step")
        print("\n⚠️ Important: Make sure you get ONLY the JSON in the response!")
        print("="*80)

    def _show_concept_review(self, original_concept: str, refined_concept: str):
        """Print the original and refined concepts side by side for review."""
        print("\n📋 Story Concept Refinement Review")
        print("=" * 60)
        print("ORIGINAL CONCEPT:")
        print("-" * 30)
        print(original_concept)
        print("\n" + "="*60)
        print("REFINED CONCEPT:")
        print("-" * 30)
        print(refined_concept)
        print("=" * 60)

    def review_and_approve_concept(self, original_concept: str, refined_concept: str,
                                   config=None) -> str:
        """Let the user approve, edit, revert to, or regenerate the refined concept.

        Args:
            original_concept: The concept as the user typed it.
            refined_concept: The model-refined concept under review.
            config: Story settings used when regenerating. Defaults to the
                notebook-global ``story_config`` so existing callers keep working.

        Returns:
            The concept the user settled on. A regenerated concept is shown
            for review again rather than being accepted unseen.
        """
        valid = {'y', 'yes', 'e', 'edit', 'o', 'original', 'r', 'regenerate'}

        while True:
            self._show_concept_review(original_concept, refined_concept)

            while True:
                choice = input("\nApprove refined concept? (y)es / (e)dit / (o)riginal / (r)egenerate: ").strip().lower()
                if choice in valid:
                    break
                print("❌ Invalid choice. Please enter 'y', 'e', 'o', or 'r'")

            if choice in ('y', 'yes'):
                print("✅ Refined concept approved!")
                return refined_concept

            if choice in ('e', 'edit'):
                print("\n✏️ Enter your custom concept:")
                custom_concept = input("> ").strip()
                if custom_concept:
                    print("✅ Custom concept set!")
                    return custom_concept
                print("❌ Empty concept, keeping refined version")
                return refined_concept

            if choice in ('o', 'original'):
                print("✅ Using original concept")
                return original_concept

            # Regenerate, then loop back so the user reviews the new version.
            print("🔄 Regenerating concept...")
            regen_config = config if config is not None else story_config
            refined_concept = self.refine_story_concept(regen_config)

# The prompt generator talks to the model that ModelManager prepared above.
prompt_generator = ExternalPromptGenerator(model_manager.current_model)

# (1) Expand the user's concept with the local model.
print("🎯 Step 1: Refining your story concept...")
refined_concept = prompt_generator.refine_story_concept(story_config)

# (2) Let the user approve, edit, revert, or regenerate the refinement.
final_story_concept = prompt_generator.review_and_approve_concept(
    story_config.story_concept, refined_concept
)

# (3) Turn the approved concept into a copy-paste prompt and show it.
print("\n🎯 Step 2: Generating external LLM prompt...")
external_prompt = prompt_generator.generate_external_prompt(
    story_config, final_story_concept
)
prompt_generator.display_external_prompt(external_prompt)

print("\n✅ External prompt generated! Use it with ChatGPT, Claude, or Gemini for best results.")
print("🚀 Ready to proceed to Step 7 for story generation!")


## 📚 Step 7: Generate Story Structure

Choose between two generation methods:

1. **Local Generation (Llama 3.1:8B)**: Fully automated, but limited by the local model's context window
   - Best for: Simple stories (5-8 nodes)
   - Pros: No manual steps, immediate results
   - Cons: May struggle with complex narratives or consistency

2. **External LLM Generation (RECOMMENDED)**: Manual copy-paste but superior quality
   - Best for: Complex stories (8+ nodes), better narrative consistency
   - Pros: Larger context windows, better story coherence, handles complex branching
   - Cons: Requires manual copy-paste step

**💡 Recommendation:** Use external LLM generation for best results, especially for
stories with multiple endings or complex choice consequences.



In [None]:
# Check if Ollama server is running and restart if needed
def check_and_restart_ollama(startup_timeout: float = 30.0, log_path: str = "ollama_server.log"):
    """Make sure the local Ollama API answers, restarting ``ollama serve`` if it does not.

    Args:
        startup_timeout: Seconds to keep polling after a restart.
        log_path: File that receives the restarted server's stdout/stderr.

    Returns:
        True when ``/api/tags`` answers with HTTP 200 (already, or after a
        restart), else False.
    """
    tags_url = "http://localhost:11434/api/tags"
    print("🔍 Checking Ollama server status...")

    try:
        if requests.get(tags_url, timeout=5).status_code == 200:
            print("✅ Ollama server is running!")
            return True
    except requests.RequestException:
        pass  # unreachable: fall through and (re)start it

    print("❌ Ollama server not running. Starting it...")

    # Stop stale Ollama processes. `-x ollama` matches the process *name*
    # exactly; the old `-f ollama` matched any command line containing
    # "ollama" and could kill unrelated processes. The second pattern targets
    # the separately-named runner used by some Ollama versions.
    for pkill_args in (["pkill", "-x", "ollama"], ["pkill", "-f", "ollama_llama_server"]):
        try:
            subprocess.run(pkill_args, capture_output=True)
        except OSError:
            break  # pkill unavailable; still try to start a fresh server
    time.sleep(2)

    try:
        # Log to a file instead of an unread PIPE: an undrained pipe fills up
        # and blocks the server on its next log write.
        with open(log_path, "ab") as log_file:
            subprocess.Popen(
                ["ollama", "serve"],
                stdout=log_file,
                stderr=subprocess.STDOUT,
                env=dict(os.environ, OLLAMA_USE_CUDA='1')
            )
    except OSError as e:
        print(f"❌ Failed to start Ollama server: {e}")
        return False

    print("⏳ Starting Ollama server...")

    # Poll instead of a fixed sleep: return as soon as the API is up.
    deadline = time.monotonic() + startup_timeout
    got_response = False
    while time.monotonic() < deadline:
        try:
            if requests.get(tags_url, timeout=5).status_code == 200:
                print("✅ Ollama server restarted successfully!")
                return True
            got_response = True
        except requests.RequestException:
            pass
        time.sleep(1)

    if got_response:
        print("❌ Server started but not responding properly")
    else:
        print(f"❌ Failed to start Ollama server: no response within {startup_timeout:.0f}s (see {log_path})")
    return False

# Confirm the local server is reachable before any generation request is sent.
if not check_and_restart_ollama():
    print("❌ Cannot start Ollama server. Please check your installation.")
else:
    print("🚀 Ready to generate story!")

@dataclass
class Choice:
    """Represents a choice in the story.

    Field names mirror one entry of a node's ``choices`` array in the story
    JSON format requested from the external LLM.
    """
    id: str  # choice identifier, e.g. "choice_1"
    text: str  # choice text presented for this option
    nextNodeId: str  # id of the target node; expected to be a key of Story.nodes

@dataclass
class StoryNode:
    """Represents a node in the story graph.

    Field names (camelCase) mirror the node objects in the requested story JSON.
    """
    id: str  # node id; expected to match this node's key in Story.nodes
    title: str  # scene title
    content: str  # narrative text of the scene
    imageUrl: str  # image reference; the prompt template uses "placeholder.jpg"
    choices: List[Choice] = field(default_factory=list)  # outgoing choices; the JSON format gives endings none
    isEnd: bool = False  # True for ending nodes

@dataclass
class Story:
    """Represents the complete story structure.

    Mirrors the top-level story JSON: title, description, the id of the first
    node, and every node keyed by its id.
    """
    title: str
    description: str
    startNodeId: str  # id of the opening node; expected to be a key of `nodes`
    nodes: Dict[str, StoryNode] = field(default_factory=dict)  # node id -> StoryNode

class StoryGenerationManager:
    """Create a ``Story`` with the local Ollama model or from pasted external-LLM JSON.

    Both generation paths end in ``parse_story_json``, which turns raw LLM text
    into a ``Story`` object, or returns ``None`` after printing why it failed.
    """

    def __init__(self, model_name: str):
        # Model name sent to Ollama's /api/generate endpoint.
        self.model_name = model_name
        self.base_url = "http://localhost:11434"

    def choose_generation_method(self, config, refined_concept: str):
        """Ask the user to pick internal (1) or external (2) generation.

        Re-prompts until the answer is "1" or "2", then returns the result of
        the chosen path: a ``Story`` on success, ``None`` on failure.
        """
        print("🎯 Story Generation Options")
        print("=" * 50)
        print("Choose how you want to generate your story:")
        print()
        print("1. 📱 Generate in Notebook (Llama 3.1:8B)")
        print("   ✅ Fully automated workflow")
        print("   ⚠️ Limited by local model context window")
        print("   💡 Best for smaller stories (5-8 nodes)")
        print()
        print("2. 🌐 Generate with External LLM (ChatGPT/Claude/Gemini)")
        print("   ✅ Larger context windows, better quality")
        print("   ✅ Can handle complex stories (10+ nodes)")
        print("   ⚠️ Requires manual copy-paste step")
        print()

        while True:
            choice = input("Choose generation method (1 or 2): ").strip()

            if choice == "1":
                print("🚀 Using internal generation with Llama 3.1:8B")
                return self.generate_internal_story(config, refined_concept)
            elif choice == "2":
                print("🌐 Preparing external LLM prompt")
                return self.setup_external_generation(config, refined_concept)
            else:
                print("❌ Please enter 1 or 2")

    def generate_internal_story(self, config, refined_concept: str):
        """Generate the story with the local Ollama model.

        Makes up to three non-streaming /api/generate requests. It returns the
        first response that ``parse_story_json`` accepts, or ``None`` if every
        attempt fails. Request errors are printed and the attempt is retried.
        """
        print("\n📚 Generating story with local Llama 3.1:8B...")

        # Simpler prompt than the external one, sized for the local model.
        prompt = self.create_internal_prompt(config, refined_concept)

        max_attempts = 3
        for attempt in range(max_attempts):
            if attempt > 0:
                print(f"🔄 Attempt {attempt + 1}/{max_attempts}...")

            try:
                response = requests.post(
                    f"{self.base_url}/api/generate",
                    json={
                        "model": self.model_name,
                        "prompt": prompt,
                        "stream": False,
                        "options": {
                            # Low temperature keeps the JSON output more deterministic.
                            "temperature": 0.3,
                            "top_p": 0.8,
                            "num_ctx": 8192,
                            "num_predict": 6000
                        }
                    },
                    timeout=300
                )

                if response.status_code != 200:
                    print(f"❌ Generation failed: HTTP {response.status_code}")
                    continue

                story_text = response.json().get('response', '').strip()
                if not story_text:
                    print("❌ Empty response from LLM")
                    continue

                print("✅ Story generated! Parsing JSON...")
                story = self.parse_story_json(story_text)
                if story:
                    return story
                print(f"❌ Attempt {attempt + 1} failed, trying again...")

            except Exception as e:
                # Best effort: report and retry on network/decoding errors.
                print(f"❌ Error: {e}")

        print(f"❌ All {max_attempts} attempts failed")
        print("💡 Try using external LLM generation for better results")
        return None

    def create_internal_prompt(self, config, refined_concept: str) -> str:
        """Create a simpler prompt optimized for local Llama model.

        Reads ``theme``, ``title``, ``description``, ``target_nodes``,
        ``ending_nodes``, ``min_choices_per_node`` and ``max_choices_per_node``
        from ``config``. Title and description are embedded in the example
        JSON through ``json.dumps``, so quotes or backslashes in them still give
        the model a valid JSON example to copy.
        """
        title_json = json.dumps(str(config.title), ensure_ascii=False)
        description_json = json.dumps(str(config.description), ensure_ascii=False)
        return f"""Create an interactive {config.theme} story in JSON format.

Story: {config.title}
Theme: {config.theme}
Concept: {refined_concept}

Requirements:
- {config.target_nodes} total nodes
- {config.ending_nodes} ending nodes
- {config.min_choices_per_node}-{config.max_choices_per_node} choices per node
- All endings must be reachable

Return ONLY this JSON structure:
{{
  "title": {title_json},
  "description": {description_json},
  "startNodeId": "start",
  "nodes": {{
    "start": {{
      "id": "start",
      "title": "Opening Scene",
      "content": "Story content here (2 paragraphs).",
      "imageUrl": "placeholder.jpg",
      "choices": [
        {{"id": "choice1", "text": "First choice", "nextNodeId": "node2"}},
        {{"id": "choice2", "text": "Second choice", "nextNodeId": "node3"}}
      ]
    }},
    "node2": {{
      "id": "node2",
      "title": "Scene Title",
      "content": "Content based on first choice.",
      "imageUrl": "placeholder.jpg",
      "choices": [{{"id": "choice3", "text": "Next choice", "nextNodeId": "ending1"}}]
    }},
    "ending1": {{
      "id": "ending1",
      "title": "Ending Title",
      "content": "Ending content.",
      "imageUrl": "placeholder.jpg",
      "isEnd": true
    }}
  }}
}}

Generate the complete story JSON now:"""

    def setup_external_generation(self, config, refined_concept: str):
        """Tell the user to run the Step 6 prompt externally, then read back the JSON.

        ``config`` and ``refined_concept`` are not used here; the external
        prompt was built in an earlier step.
        """
        print("🌐 Using the external prompt from Step 6...")
        print("💡 Copy the prompt from Step 6 and use it with ChatGPT/Claude/Gemini")

        return self.input_external_json()

    def input_external_json(self):
        """Read pasted JSON line by line until two consecutive empty lines.

        Blank lines are dropped (JSON does not need them). EOF or Ctrl-C also
        ends input. Returns the parsed ``Story``, or ``None`` if nothing was
        entered or parsing fails.
        """
        print("\n📥 Paste your JSON from the external LLM:")
        print("(Press Enter twice when done)")

        json_lines = []
        empty_count = 0

        while True:
            try:
                line = input()
                if line.strip() == "":
                    empty_count += 1
                    if empty_count >= 2:
                        break
                else:
                    empty_count = 0
                    json_lines.append(line)
            except (EOFError, KeyboardInterrupt):
                break

        if not json_lines:
            print("❌ No input received")
            return None

        json_text = "\n".join(json_lines)
        return self.parse_story_json(json_text)

    def parse_story_json(self, story_text: str):
        """Extract the JSON object from LLM text and convert it to a ``Story``.

        Strips a surrounding markdown code fence if present, then parses the
        span from the first ``{`` to the last ``}``. Returns ``None`` (after
        printing the reason) when no JSON is found, it does not parse, or it
        fails validation or conversion.
        """
        try:
            story_text = story_text.strip()

            # Remove markdown fences if present.
            if "```json" in story_text:
                story_text = story_text.split("```json")[1].split("```")[0]
            elif "```" in story_text:
                story_text = story_text.split("```")[1].split("```")[0]

            # Ignore any chatter before/after the outermost JSON object.
            json_start = story_text.find('{')
            json_end = story_text.rfind('}') + 1

            if json_start == -1 or json_end == 0:
                print("❌ No valid JSON found")
                return None

            json_text = story_text[json_start:json_end]
            story_data = json.loads(json_text)

            if self.validate_story_data(story_data):
                story = self.convert_to_story_object(story_data)
                if story:
                    print("✅ Story loaded successfully!")
                    print(f"📊 {len(story.nodes)} nodes, {len([n for n in story.nodes.values() if n.isEnd])} endings")
                    return story

            return None

        except json.JSONDecodeError as e:
            print(f"❌ JSON error: {e}")
            return None
        except Exception as e:
            print(f"❌ Processing error: {e}")
            return None

    def validate_story_data(self, data: dict) -> bool:
        """Check that parsed JSON has the minimum structure a ``Story`` needs.

        Returns False, after printing the reason, when a required top-level key
        is missing, ``nodes`` is not a JSON object, or ``startNodeId`` does not
        name an existing node. Choices pointing at unknown nodes only produce a
        warning, because the story can still be loaded and reviewed.
        """
        required = ['title', 'description', 'startNodeId', 'nodes']
        for key in required:
            if key not in data:
                print(f"❌ Missing field: {key}")
                return False

        nodes = data['nodes']
        if not isinstance(nodes, dict):
            # convert_to_story_object iterates nodes.items(); a list would crash there.
            print("❌ 'nodes' must be an object keyed by node id")
            return False

        if data['startNodeId'] not in nodes:
            print("❌ Start node not found")
            return False

        # Report broken links early. The visualizer silently skips them.
        for node_id, node_data in nodes.items():
            if not isinstance(node_data, dict):
                continue
            for choice in node_data.get('choices') or []:
                target = choice.get('nextNodeId') if isinstance(choice, dict) else None
                if isinstance(target, str) and target not in nodes:
                    print(f"⚠️ Choice in '{node_id}' points to unknown node '{target}'")

        return True

    def convert_to_story_object(self, data: dict):
        """Build a ``Story`` from validated JSON data.

        A missing ``imageUrl`` defaults to ``'placeholder.jpg'`` and a missing
        ``isEnd`` to ``False``. A missing or ``null`` ``choices`` value is treated
        as no choices, and a node without its own ``id`` uses its key in
        ``nodes``. Returns ``None`` (after printing the error) if a node or
        choice lacks another required field.
        """
        try:
            story = Story(
                title=data['title'],
                description=data['description'],
                startNodeId=data['startNodeId']
            )

            for node_id, node_data in data['nodes'].items():
                # Tolerate "choices": null (e.g. on ending nodes) as well as an absent key.
                choices = [
                    Choice(
                        id=choice_data['id'],
                        text=choice_data['text'],
                        nextNodeId=choice_data['nextNodeId']
                    )
                    for choice_data in node_data.get('choices') or []
                ]

                story.nodes[node_id] = StoryNode(
                    id=node_data.get('id', node_id),
                    title=node_data['title'],
                    content=node_data['content'],
                    imageUrl=node_data.get('imageUrl', 'placeholder.jpg'),
                    choices=choices,
                    isEnd=node_data.get('isEnd', False)
                )

            return story

        except Exception as e:
            print(f"❌ Conversion error: {e}")
            return None

# Build the manager around the currently selected Ollama model and let the
# user decide between local generation and pasting external-LLM JSON.
story_manager = StoryGenerationManager(model_manager.current_model)
generated_story = story_manager.choose_generation_method(story_config, final_story_concept)

if not generated_story:
    print("\n❌ Story generation failed")
else:
    print(f"\n🎉 Story '{generated_story.title}' ready!")
    print("🚀 Proceeding to visualization...")



## 🌳 Step 8: Visualize Story Structure

This step draws a static graph of your story's branching structure (start node in green, story nodes in blue, endings in coral) and prints detailed information about each node. Review the story flow to make sure all paths make sense and all endings are reachable, then approve the structure when prompted before proceeding to image generation.


In [None]:
import networkx as nx
import matplotlib.pyplot as plt

class StoryVisualizer:
    """Draw a story's branching graph with matplotlib and ask the user to approve it."""

    def __init__(self):
        # NOTE(review): not read by any method in this class.
        self.fig = None
        self.ax = None

    def visualize_story_tree(self, story: Story):
        """Plot the story graph, print node details, and return the user's approval.

        The start node is drawn light green, ending nodes light coral and all
        other nodes light blue. The plot is a static matplotlib figure. Choices
        whose ``nextNodeId`` does not exist, and nodes that cannot be reached
        from the start node, are printed as warnings, because neither problem
        is visible in the drawing itself.

        Returns False if ``story`` is falsy, otherwise the result of
        ``get_approval()``.
        """
        if not story:
            print("❌ No story to visualize")
            return False

        print("🌳 Creating story tree visualization...")

        G = nx.DiGraph()

        for node_id, node in story.nodes.items():
            G.add_node(node_id,
                       title=node.title,
                       is_end=node.isEnd,
                       choices=len(node.choices))

        # Edges are the choices. Links to unknown nodes cannot be drawn, so
        # collect them for the warning report instead of dropping them silently.
        dangling = []
        for node_id, node in story.nodes.items():
            for choice in node.choices:
                if choice.nextNodeId in story.nodes:
                    label = choice.text[:30] + "..." if len(choice.text) > 30 else choice.text
                    G.add_edge(node_id, choice.nextNodeId, choice_text=label)
                else:
                    dangling.append((node_id, choice.nextNodeId))

        self._report_structure_issues(G, story, dangling)

        try:
            pos = nx.spring_layout(G, k=3, iterations=50)
        except Exception:
            # Fall back to a layout that cannot fail on odd graphs.
            pos = nx.random_layout(G)

        plt.figure(figsize=(15, 10))

        start_nodes = [story.startNodeId]
        end_nodes = [node_id for node_id, node in story.nodes.items() if node.isEnd]
        regular_nodes = [node_id for node_id in story.nodes
                         if node_id not in start_nodes and node_id not in end_nodes]

        if start_nodes:
            nx.draw_networkx_nodes(G, pos, nodelist=start_nodes,
                                   node_color='lightgreen', node_size=1500, alpha=0.8)

        if end_nodes:
            nx.draw_networkx_nodes(G, pos, nodelist=end_nodes,
                                   node_color='lightcoral', node_size=1500, alpha=0.8)

        if regular_nodes:
            nx.draw_networkx_nodes(G, pos, nodelist=regular_nodes,
                                   node_color='lightblue', node_size=1200, alpha=0.8)

        nx.draw_networkx_edges(G, pos, edge_color='gray', arrows=True,
                               arrowsize=20, arrowstyle='->', alpha=0.6)

        labels = {node_id: node_id for node_id in story.nodes}
        nx.draw_networkx_labels(G, pos, labels, font_size=8, font_weight='bold')

        plt.title(f"Story Structure: {story.title}", fontsize=16, fontweight='bold', pad=20)

        legend_elements = [
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightgreen',
                       markersize=10, label='Start Node'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightblue',
                       markersize=10, label='Story Node'),
            plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='lightcoral',
                       markersize=10, label='End Node')
        ]
        plt.legend(handles=legend_elements, loc='upper right')

        plt.axis('off')
        plt.tight_layout()
        plt.show()

        self.display_node_details(story)

        return self.get_approval()

    def _report_structure_issues(self, graph, story: Story, dangling):
        """Print warnings for broken choice links and nodes unreachable from the start node."""
        for source_id, target_id in dangling:
            print(f"⚠️ Choice in '{source_id}' points to missing node '{target_id}' (not drawn)")

        if story.startNodeId in graph:
            reachable = nx.descendants(graph, story.startNodeId) | {story.startNodeId}
            unreachable = [str(node_id) for node_id in graph.nodes if node_id not in reachable]
            if unreachable:
                print(f"⚠️ Unreachable from start: {', '.join(unreachable)}")

    def get_approval(self) -> bool:
        """Ask the user to approve the structure.

        Returns True for 'y'/'yes' and False for 'n'/'no'. After three
        unrecognised answers, or if input is interrupted (EOF / Ctrl-C), it
        defaults to True (approved).
        """
        print("\n" + "="*50)
        print("📋 STORY STRUCTURE REVIEW")
        print("="*50)
        print("Please review the visualization and node details above.")

        max_attempts = 3
        for attempt in range(max_attempts):
            try:
                print(f"\nAttempt {attempt + 1}/{max_attempts}")
                approval = input("✅ Approve this story structure? (y/n): ").strip().lower()

                if approval in ['y', 'yes']:
                    print("🎉 Story structure approved! Ready for image generation.")
                    return True
                elif approval in ['n', 'no']:
                    print("❌ Story not approved. You may want to regenerate with different parameters.")
                    return False
                else:
                    print("❌ Please enter 'y' for yes or 'n' for no")
                    continue

            except (EOFError, KeyboardInterrupt):
                print("\n⚠️ Input interrupted. Defaulting to approved.")
                return True
            except Exception as e:
                print(f"⚠️ Input error: {e}. Trying again...")
                continue

        print("⚠️ Max attempts reached. Defaulting to approved.")
        return True

    def display_node_details(self, story: Story):
        """Print every node's id, title, status, a 150-character content preview and its choices."""
        print("\n📋 Detailed Node Information")
        print("=" * 60)

        for node_id, node in story.nodes.items():
            status = "🏁 END" if node.isEnd else f"🔀 {len(node.choices)} choices"
            start_marker = "🚀 START" if node_id == story.startNodeId else ""

            print(f"\n{start_marker} [{node_id}] {node.title} {status}")
            print("-" * 40)

            content_preview = node.content[:150] + "..." if len(node.content) > 150 else node.content
            print(f"Content: {content_preview}")

            if node.choices:
                print("Choices:")
                for i, choice in enumerate(node.choices, 1):
                    print(f"  {i}. {choice.text} → {choice.nextNodeId}")

        print("\n" + "=" * 60)

# Show the story graph and record whether the user approved it; without a
# story there is nothing to review, so the flag stays False.
story_approved = False
if not generated_story:
    print("❌ No story available for visualization")
else:
    visualizer = StoryVisualizer()
    story_approved = visualizer.visualize_story_tree(generated_story)

## 🎨 Step 9: Generate Images for Story Nodes

### Step 9A: Context-Aware Image Prompt Generation

This step scans your complete story for recurring characters and settings (via keyword matching), combines them with the style and atmosphere you select, and generates a short image prompt for each node.

Because every prompt shares the same characters, settings, and style, the images stay visually consistent across your story. Keeping each prompt brief helps it fit within the 77-token limit of Stable Diffusion's CLIP text encoder.


In [None]:
# Step 9A: Fixed Context-Aware Prompt Generation

class ContextAwareImagePromptGenerator:
    """Generate image prompts with full story context for consistency."""

    def __init__(self, model_name: str):
        # Model name sent to Ollama's /api/generate endpoint.
        self.model_name = model_name
        self.base_url = "http://localhost:11434"

    def check_and_restart_ollama(self, startup_timeout: float = 30.0):
        """Make sure the Ollama server answers on ``/api/tags``, restarting it if not.

        If the server is unreachable, every process whose command line matches
        "ollama" is killed (``pkill -f ollama``) and ``ollama serve`` is started
        in the background. The server is then polled about once per second for
        up to ``startup_timeout`` seconds.

        Returns True once ``/api/tags`` answers with HTTP 200, otherwise False.
        """
        print("🔍 Checking Ollama server status...")
        tags_url = f"{self.base_url}/api/tags"
        try:
            response = requests.get(tags_url, timeout=5)
            if response.status_code == 200:
                print("✅ Ollama server is running!")
                return True
        except requests.RequestException:
            pass  # not reachable: fall through and (re)start it

        print("❌ Ollama server not running. Starting it...")
        try:
            subprocess.run(["pkill", "-f", "ollama"], capture_output=True)
            time.sleep(2)

            # Discard the server's output. Pipes that nobody reads fill up once
            # the server has logged enough; the server then blocks on write and
            # stops answering requests.
            self._server_process = subprocess.Popen(
                ["ollama", "serve"],
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
                env=dict(os.environ, OLLAMA_USE_CUDA='1')
            )

            print("⏳ Starting Ollama server...")
            last_status = None
            deadline = time.monotonic() + startup_timeout
            while time.monotonic() < deadline:
                time.sleep(1)
                try:
                    response = requests.get(tags_url, timeout=5)
                except requests.RequestException:
                    continue  # still starting up
                if response.status_code == 200:
                    print("✅ Ollama server restarted successfully!")
                    return True
                last_status = response.status_code

            if last_status is not None:
                print("❌ Server started but not responding properly")
            else:
                print(f"❌ Failed to start Ollama server: no response within {startup_timeout:.0f}s")
            return False

        except Exception as e:
            print(f"❌ Failed to start Ollama server: {e}")
            return False

    @staticmethod
    def _count_keyword_hits(text: str, keywords) -> int:
        """Return the number of whole-word occurrences of any of *keywords* in *text*.

        Matching is on word boundaries, with an optional plural "s". So "bar"
        counts "bar" and "bars" but not "barely", and "cop" does not match
        "copy" or "scope". Callers pass lower-cased text and keywords.
        """
        import re
        return sum(
            len(re.findall(rf"\b{re.escape(keyword)}s?\b", text))
            for keyword in keywords
        )

    @staticmethod
    def _clean_generated_prompt(text: str) -> str:
        """Reduce a raw model answer to one short scene description.

        Removes double quotes and strips leading boilerplate such as
        "Visual description:" or markdown asterisks, repeating until none is
        left. It keeps only the text before the first period and caps results
        longer than 200 characters at 35 words. May return an empty string.
        """
        prefixes_to_remove = (
            'visual description:', 'scene:', 'here is', 'the scene shows',
            'visual prompt:', 'description:', 'scene title:', '**', '*'
        )

        text = text.replace('"', '').strip()

        # Re-scan after every removal. A single ordered pass misses prefixes
        # that only become leading once an earlier one (e.g. '**') is gone.
        stripped = True
        while stripped:
            stripped = False
            lowered = text.lower()
            for prefix in prefixes_to_remove:
                if lowered.startswith(prefix):
                    text = text[len(prefix):].strip()
                    stripped = True
                    break

        # Keep the first sentence only.
        text = text.split('.')[0].strip()

        # Rough CLIP budget: ~4 chars per token, so >200 chars is ~50+ tokens.
        if len(text) > 200:
            text = ' '.join(text.split()[:35])

        return text

    def analyze_full_story_context(self, story: "Story", theme: str, selected_atmosphere: str) -> dict:
        """Scan every node's title and content for recurring characters and settings.

        A character or setting type is kept when at least one of its keywords
        occurs as a whole word anywhere in the lower-cased story text. The
        atmosphere is taken verbatim from ``selected_atmosphere`` rather than
        detected.

        Returns a dict with these keys:
        - ``characters``: type name -> phrase such as "a detective"
        - ``settings``: type name -> phrase such as "a war room"
        - ``atmosphere``, ``theme`` and ``story_title``
        """
        print("🔍 Analyzing full story for visual consistency...")
        print("📖 HOW STORY CONTEXT WORKS:")
        print("   1. Reading ALL story nodes to find common elements")
        print("   2. Extracting character types that appear across multiple nodes")
        print("   3. Identifying settings that repeat in the story")
        print("   4. Using YOUR selected atmosphere (not auto-detected)")

        all_content = []
        all_titles = []
        print(f"\n📚 Analyzing {len(story.nodes)} story nodes:")

        for i, node in enumerate(story.nodes.values(), 1):
            all_content.append(node.content.lower())
            all_titles.append(node.title.lower())
            print(f"   {i}. {node.title} - {len(node.content)} chars")

        full_text = " ".join(all_content + all_titles)
        print(f"\n🔍 Total text analyzed: {len(full_text)} characters")

        characters = {}
        character_patterns = {
            'detective': ['detective', 'investigator', 'agent', 'cop'],
            'businessman': ['businessman', 'executive', 'trader', 'wall street', 'corporate'],
            'bartender': ['bartender', 'server', 'barkeeper', 'bar'],
            'criminal': ['criminal', 'suspect', 'thief', 'gangster']
        }

        print(f"\n👥 Character Analysis:")
        for char_type, keywords in character_patterns.items():
            count = self._count_keyword_hits(full_text, keywords)
            if count > 0:
                characters[char_type] = f"a {char_type}"
                print(f"   ✅ {char_type}: found {count} references")

        settings = {}
        setting_patterns = {
            'speakeasy': ['speakeasy', 'bar', 'tavern', 'drinking'],
            'war_room': ['war room', 'briefing', 'command', 'intel'],
            'office': ['office', 'corporate', 'building', 'headquarters'],
            'rooftop': ['rooftop', 'roof', 'building top']
        }

        print(f"\n🏢 Setting Analysis:")
        for setting_type, keywords in setting_patterns.items():
            count = self._count_keyword_hits(full_text, keywords)
            if count > 0:
                settings[setting_type] = f"a {setting_type.replace('_', ' ')}"
                print(f"   ✅ {setting_type}: found {count} references")

        context = {
            'characters': characters,
            'settings': settings,
            'atmosphere': selected_atmosphere,  # the user's choice, not detected
            'theme': theme,
            'story_title': story.title
        }

        print(f"\n✅ Story context extracted:")
        print(f"   Characters: {list(characters.keys())}")
        print(f"   Settings: {list(settings.keys())}")
        print(f"   Atmosphere: {selected_atmosphere}")

        return context

    def generate_context_aware_prompt(self, node: "StoryNode", story_context: dict) -> str:
        """Ask the LLM for a short (under ~40 words) visual description of *node*.

        The detected characters, settings and atmosphere are included so that
        scenes stay consistent. Up to three attempts are made. An answer of 10
        characters or fewer, or one that cleans to an empty string, triggers a
        retry. Returns the cleaned description, or ``None`` if every attempt
        fails.
        """
        context_prompt = f"""Create a SHORT visual prompt (under 40 words) for this story scene.

STORY CONTEXT (use for consistency):
Characters: {', '.join(story_context['characters'].keys())}
Settings: {', '.join(story_context['settings'].keys())}
Atmosphere: {story_context['atmosphere']}

CURRENT SCENE:
Title: {node.title}
Content: {node.content[:300]}...

Create a concise visual description focusing on:
1. Main character (if any from the context list)
2. Setting (if any from the context list)
3. Key visual moment
4. Atmosphere

Keep it under 40 words. No style terms - just the visual scene.

Visual description:"""

        max_attempts = 3
        for attempt in range(max_attempts):
            if attempt > 0:
                print(f"🔄 Attempt {attempt + 1}/{max_attempts}...")

            try:
                response = requests.post(
                    f"{self.base_url}/api/generate",
                    json={
                        "model": self.model_name,
                        "prompt": context_prompt,
                        "stream": False,
                        "options": {
                            "temperature": 0.7,
                            "top_p": 0.9,
                            "num_ctx": 4096,
                            "num_predict": 60  # Limit response length
                        }
                    },
                    timeout=120
                )

                if response.status_code != 200:
                    print(f"❌ Generation failed: HTTP {response.status_code}")
                    continue

                raw_prompt = response.json().get('response', '').strip()
                if len(raw_prompt) <= 10:
                    continue  # empty or trivial answer: retry

                cleaned = self._clean_generated_prompt(raw_prompt)
                if cleaned:
                    return cleaned

            except Exception as e:
                print(f"❌ Error: {e}")

        return None

    def create_context_aware_fallback(self, node: "StoryNode", story_context: dict) -> str:
        """Build a short "<character> <setting>" prompt without calling the LLM.

        Uses the first context character and the first context setting whose
        name appears as a whole word in the node's title or content. Otherwise
        it falls back to "a figure" and "indoors".
        """
        title = node.title.lower()
        content = node.content.lower()

        def mentions(name):
            return bool(self._count_keyword_hits(content, [name])
                        or self._count_keyword_hits(title, [name]))

        character = next(
            (f"a {name}" for name in story_context['characters'] if mentions(name)),
            "a figure",
        )
        setting = next(
            (f"in a {name.replace('_', ' ')}" for name in story_context['settings']
             if mentions(name.replace('_', ' '))),
            "indoors",
        )

        return f"{character} {setting}"

# Configure style guide with SHORT options
print("🎨 Style Guide Configuration")
print("=" * 50)
print("Choose a style for consistent image generation:")
print()
print("1. Film Noir (dramatic shadows, high contrast)")
print("2. Modern Cinematic (professional, sleek)")
print("3. Photorealistic (camera-like, realistic)")
print("4. Digital Art (concept art style)")
print("5. Custom (enter your own)")

style_choice = input("Choose style (1-5, default: 1): ").strip()

# Preset (style_guide, atmosphere) pairs. style_guide is appended to every image
# prompt, so it is kept SHORT (CLIP-friendly); atmosphere goes into the story
# context used when generating scene descriptions.
style_presets = {
    "1": ("film noir, dramatic shadows, high contrast",
          "film noir atmosphere with dramatic shadows"),
    "2": ("cinematic, professional lighting, sharp focus",
          "modern cinematic atmosphere with professional aesthetics"),
    "3": ("photorealistic, professional photography, dramatic lighting",
          "realistic photographic atmosphere"),
    "4": ("digital art, concept art style, detailed",
          "artistic digital atmosphere"),
}

if style_choice == "5":
    custom_style = input("Enter your SHORT style guide (under 10 words): ").strip()
    style_guide = custom_style if custom_style else "cinematic, dramatic lighting"
    # Derive the atmosphere from the *resolved* style, so an empty entry does
    # not produce the meaningless atmosphere " atmosphere".
    atmosphere = f"{style_guide} atmosphere"
else:
    # Anything other than 2-5 (including empty input) selects Film Noir.
    style_guide, atmosphere = style_presets.get(style_choice, style_presets["1"])

print(f"✅ Selected style: {style_guide}")
print(f"✅ Atmosphere: {atmosphere}")

# Generate context-aware prompts with PROPER style
image_prompts = {}
if generated_story and model_manager.current_model:
    print(f"\n🧠 Context-Aware Prompt Generation for {len(generated_story.nodes)} nodes")

    prompt_generator = ContextAwareImagePromptGenerator(model_manager.current_model)

    if prompt_generator.check_and_restart_ollama():
        # Story context uses the atmosphere chosen in the style step, not a detected one.
        story_context = prompt_generator.analyze_full_story_context(
            generated_story,
            story_config.theme,
            atmosphere
        )

        print("🚀 Generating context-aware prompts...")

        failed_prompts = []
        total_nodes = len(generated_story.nodes)

        for i, (node_id, node) in enumerate(generated_story.nodes.items(), 1):
            print(f"\n📝 [{i}/{total_nodes}] Processing: {node_id}")
            print(f"📖 Title: {node.title}")

            # The LLM describes only the scene; the style guide is appended below.
            try:
                scene_description = prompt_generator.generate_context_aware_prompt(node, story_context)
            except Exception as e:
                print(f"❌ Error: {str(e)[:50]}")
                scene_description = None

            if scene_description:
                final_prompt = f"{scene_description}, {style_guide}"
                print(f"✅ Scene: {scene_description}")
                print(f"🎨 Final: {final_prompt}")
                print(f"📏 Length: {len(final_prompt)} chars (~{len(final_prompt.split())} words)")
            else:
                # LLM gave nothing usable or raised: use the keyword-based fallback
                # and report it the same way in both cases.
                scene_description = prompt_generator.create_context_aware_fallback(node, story_context)
                final_prompt = f"{scene_description}, {style_guide}"
                failed_prompts.append(node_id)
                print(f"⚠️ Fallback: {final_prompt}")

            image_prompts[node_id] = final_prompt

            progress = (i / total_nodes) * 100
            print(f"📊 Progress: {progress:.1f}%")

        successful = len(image_prompts) - len(failed_prompts)
        print(f"\n✅ Context-aware prompt generation complete!")
        print(f"   Story-aware prompts: {successful}/{total_nodes}")
        print(f"   Fallbacks used: {len(failed_prompts)}/{total_nodes}")

        prompts_file = os.path.join(drive_manager.project_folder, "context_aware_prompts.json")
        prompts_data = []

        print(f"\n📋 COMPLETE PROMPTS GENERATED:")
        print("=" * 80)

        for node_id, prompt in image_prompts.items():
            prompts_data.append({
                "node_id": node_id,
                "title": generated_story.nodes[node_id].title,
                "prompt": prompt
            })
            print(f"{node_id}: {prompt}")

        print("=" * 80)

        # Saving is a convenience. A filesystem error should not abort the cell
        # after all the LLM calls; the prompts remain in image_prompts.
        try:
            with open(prompts_file, 'w', encoding='utf-8') as f:
                json.dump(prompts_data, f, indent=2)
        except OSError as e:
            print(f"⚠️ Could not save prompts to {prompts_file}: {e}")
        else:
            print(f"\n📁 Prompts saved to: {prompts_file}")

    else:
        print("❌ Cannot start Ollama server for prompt generation")
        image_prompts = {}
else:
    print("⏭️ No story available for prompt generation")
    image_prompts = {}

print("\n🎉 Prompt generation phase complete!")
print("💡 All prompts are now CLIP-friendly (under 77 tokens)")


### Step 9B: Memory-Optimized Image Generation

This step generates high-quality images using Stable Diffusion XL with aggressive memory management. The system automatically handles VRAM optimization and generates images one at a time to prevent out-of-memory errors. Choose between fast generation (SDXL Turbo) and higher quality (SDXL Base) based on your hardware.

**Important:** Depending on your chosen options, you may encounter out-of-memory errors. If generation fails, restart this step and choose more conservative settings.

**Safe Settings:** SDXL Turbo + 512x512 pixels + Maximum memory optimization
These settings work reliably on most systems with 6GB+ VRAM.


In [None]:
# Install image generation libraries if needed
try:
    from PIL import Image
    import torch
    from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DiffusionPipeline
    import gc
    IMAGES_AVAILABLE = True
except ImportError:
    import importlib
    import sys

    print("📦 Installing image generation libraries...")
    # Run pip through the kernel's own interpreter. A bare "pip" on PATH may
    # belong to a different Python, leaving this kernel without the packages.
    pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install",
         "diffusers", "transformers", "accelerate", "torch", "torchvision", "Pillow"],
        capture_output=True,
        text=True,
    )
    if pip_result.returncode != 0:
        # capture_output hides pip's diagnostics; show the tail before failing.
        print(pip_result.stderr[-2000:])
        pip_result.check_returncode()  # raises CalledProcessError, as check=True did
    # Make the import system notice packages installed after interpreter start.
    importlib.invalidate_caches()
    from PIL import Image
    import torch
    from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline, DiffusionPipeline
    import gc
    IMAGES_AVAILABLE = True
    print("✅ Image libraries installed!")

@dataclass
class ImageConfig:
    """Settings for one Stable Diffusion XL image-generation run.

    The defaults describe SDXL Turbo at 512x512 with 4 steps.
    """

    # Output resolution in pixels.
    width: int = field(default=512)
    height: int = field(default=512)
    # Checkpoint selector: "turbo" (SDXL Turbo) or "base" (SDXL Base).
    model_type: str = field(default="turbo")
    # Number of denoising steps per image.
    inference_steps: int = field(default=4)
    # Classifier-free guidance strength.
    guidance_scale: float = field(default=1.0)
    # Memory-optimization level, kept as a string: "1", "2" or "3".
    memory_optimization: str = field(default="2")

class MemoryOptimizedImageGenerator:
    """Generate SDXL images one at a time while keeping RAM/VRAM usage low.

    Intended call order: ``aggressive_memory_cleanup()`` ->
    ``configure_image_settings()`` -> ``setup_image_generation(config)`` ->
    ``generate_single_image(node_id, prompt)`` for each story node.
    """

    # Top-level packages whose object instances are treated as releasable
    # model state by ``aggressive_memory_cleanup``. Matching on the defining
    # package (instead of a substring of the class name) keeps unrelated user
    # globals intact.
    _MODEL_PACKAGES = ("diffusers", "transformers")

    def __init__(self):
        self.pipeline = None  # diffusers pipeline, set by setup_image_generation
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.setup_complete = False
        self.config = None  # ImageConfig the loaded pipeline was set up with

    def _free_cached_memory(self):
        """Run the garbage collector, then release cached CUDA blocks."""
        # Collect first so tensors that just became unreachable are freed
        # before the caching allocator is asked to give memory back.
        gc.collect()
        if self.device == "cuda":
            torch.cuda.empty_cache()

    def aggressive_memory_cleanup(self):
        """Free as much RAM/VRAM as possible before loading the diffusion model.

        Stops any running Ollama server, drops this generator's pipeline,
        deletes notebook globals that hold torch modules or
        diffusers/transformers objects, runs the garbage collector and empties
        the CUDA cache.

        Returns:
            True if the cleanup ran to completion, False if an unexpected
            error interrupted it (generation can still be attempted).
        """
        print("\n🧹 AGGRESSIVE MEMORY CLEANUP")
        print("=" * 50)

        try:
            # 1. Stop Ollama server completely. Best effort: `pkill` does not
            #    exist on Windows, which must not abort the rest of the cleanup.
            print("🔄 Stopping Ollama server...")
            try:
                subprocess.run(["pkill", "-f", "ollama"], capture_output=True)
                time.sleep(3)  # give the process time to exit and release memory
            except OSError as e:
                print(f"   Could not stop Ollama automatically: {e}")

            # 2. Clear any existing pipeline
            if self.pipeline is not None:
                print("🗑️ Clearing existing pipeline...")
                self.pipeline = None
                self.setup_complete = False

            # 3. Delete notebook globals that hold model/pipeline objects.
            print("🗑️ Clearing Python variables...")
            import types

            def holds_model(name, obj):
                # Never touch private/IPython names, classes or modules
                # (transformers/diffusers modules are lazy ModuleType subclasses).
                if name.startswith("_") or isinstance(obj, (type, types.ModuleType)):
                    return False
                if isinstance(obj, torch.nn.Module):
                    return True
                package = (getattr(type(obj), "__module__", None) or "").split(".")[0]
                return package in self._MODEL_PACKAGES

            namespace = globals()
            globals_to_clear = [name for name, obj in list(namespace.items()) if holds_model(name, obj)]
            for name in globals_to_clear:
                del namespace[name]
                print(f"   Cleared: {name}")

            # 4. Aggressive garbage collection (multiple passes)
            print("🗑️ Running garbage collection...")
            for i in range(5):  # later passes catch objects freed by earlier ones
                collected = gc.collect()
                print(f"   Pass {i+1}: collected {collected} objects")

            # 5. Clear CUDA cache completely
            if torch.cuda.is_available():
                print("🗑️ Clearing CUDA cache...")
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                torch.cuda.ipc_collect()

                # mem_get_info reports device-wide free memory, which also
                # accounts for other processes and the caching allocator;
                # total - memory_allocated() would overstate what is free.
                free_bytes, total_bytes = torch.cuda.mem_get_info()
                memory_total = total_bytes / 1024**3
                memory_free = free_bytes / 1024**3
                memory_allocated = torch.cuda.memory_allocated() / 1024**3

                print(f"💾 GPU Memory Status:")
                print(f"   Total: {memory_total:.1f}GB")
                print(f"   Allocated: {memory_allocated:.1f}GB")
                print(f"   Free: {memory_free:.1f}GB")

                if memory_free < 8.0:
                    print("⚠️ Warning: Less than 8GB free - may need smaller images")

            # 6. Flush filesystem buffers (`sync`); absent on Windows.
            print("🗑️ Clearing system cache...")
            try:
                subprocess.run(["sync"], capture_output=True)
            except OSError:
                pass

            print("✅ Aggressive memory cleanup complete!")
            print(f"🎨 Maximum RAM now available for image generation")

            return True

        except Exception as e:
            print(f"⚠️ Memory cleanup warning: {e}")
            print("🔄 Continuing with image generation...")
            return False

    def configure_image_settings(self) -> ImageConfig:
        """Interactively choose model, image size and memory-optimization level.

        Returns:
            An ImageConfig built from the user's answers. Unrecognized or empty
            answers fall back to SDXL Turbo, 512x512 and level "1".
        """
        print("🎨 Memory-Optimized Image Generation Configuration")
        print("=" * 50)

        # Model selection
        print("🤖 Model Selection:")
        print("1. SDXL Turbo (Fast, 1-4 steps, ~6GB VRAM)")
        print("2. SDXL Base (Slower, 20-30 steps, ~8GB VRAM, higher quality)")

        model_choice = input("Choose model (1-2, default: 1): ").strip()
        if model_choice == "2":
            model_type = "base"
            inference_steps = 25
            guidance_scale = 7.5
            print("✅ Selected: SDXL Base (High Quality)")
        else:
            model_type = "turbo"
            inference_steps = 4
            guidance_scale = 1.0
            print("✅ Selected: SDXL Turbo (Fast)")

        # Image size with memory warnings
        print("\n📐 Image Size (affects VRAM usage):")
        print("1. 512x512 (Fast, ~4GB VRAM)")
        print("2. 768x768 (Better quality, ~6GB VRAM)")
        print("3. 1024x1024 (Best quality, ~10GB VRAM)")

        size_choice = input("Choose size (1-3, default: 1): ").strip()
        if size_choice == "2":
            width, height = 768, 768
            print("⚠️ 768x768 requires ~6GB VRAM")
        elif size_choice == "3":
            width, height = 1024, 1024
            print("⚠️ 1024x1024 requires ~10GB VRAM")
        else:
            width, height = 512, 512
            print("✅ 512x512 is memory-safe")

        # Memory optimization
        print("\n💾 Memory Optimization:")
        print("1. Maximum (safest, CPU offloading)")
        print("2. Balanced (recommended)")
        print("3. Minimal (fastest, highest VRAM usage)")
        memory_choice = input("Choose (1-3, default: 1): ").strip()

        config = ImageConfig(
            width=width,
            height=height,
            model_type=model_type,
            inference_steps=inference_steps,
            guidance_scale=guidance_scale,
            memory_optimization=memory_choice if memory_choice in ['1', '2', '3'] else '1'
        )

        print(f"\n📋 Configuration Summary:")
        print(f"  Model: {config.model_type.upper()}")
        print(f"  Size: {config.width}x{config.height}")
        print(f"  Steps: {config.inference_steps}")
        print(f"  Memory: Level {config.memory_optimization} optimization")

        return config

    def setup_image_generation(self, config: ImageConfig):
        """Load the selected SDXL pipeline and apply memory optimizations.

        Args:
            config: Model type, image size and memory-optimization level.

        Returns:
            True when the pipeline is ready, False if loading failed.
        """
        print(f"\n🎨 Setting up {config.model_type.upper()} pipeline with memory optimization...")

        try:
            # Pre-setup cleanup
            self._free_cached_memory()

            # float16 halves memory on the GPU, but many CPU kernels have no
            # half-precision implementation, so use float32 on CPU.
            on_cuda = self.device == "cuda"
            load_kwargs = dict(
                torch_dtype=torch.float16 if on_cuda else torch.float32,
                use_safetensors=True,
                variant="fp16" if on_cuda else None,
            )

            if config.model_type == "base":
                print("📥 Loading SDXL Base (this may take a few minutes)...")
                self.pipeline = StableDiffusionXLPipeline.from_pretrained(
                    "stabilityai/stable-diffusion-xl-base-1.0", **load_kwargs
                )
            else:  # turbo
                print("📥 Loading SDXL Turbo...")
                self.pipeline = DiffusionPipeline.from_pretrained(
                    "stabilityai/sdxl-turbo", **load_kwargs
                )

            if on_cuda:
                print("🔧 Applying memory optimizations...")
                # With CPU offloading the pipeline must NOT be moved to the GPU
                # first: that would load every weight into VRAM up front, which
                # is exactly what offloading is meant to avoid.
                if config.memory_optimization == "1":  # Maximum
                    # Sequential offload supersedes model offload; enable only one.
                    print("   - Sequential CPU offloading (lowest VRAM usage)")
                    self.pipeline.enable_sequential_cpu_offload()
                    print("   - Attention slicing (slice_size=1)")
                    self.pipeline.enable_attention_slicing(slice_size=1)
                elif config.memory_optimization == "2":  # Balanced
                    print("   - Model CPU offloading")
                    self.pipeline.enable_model_cpu_offload()
                    print("   - Attention slicing")
                    self.pipeline.enable_attention_slicing()
                else:  # Minimal: the whole pipeline stays on the GPU
                    self.pipeline = self.pipeline.to(self.device)
                    print("   - Attention slicing")
                    self.pipeline.enable_attention_slicing()
            else:
                self.pipeline = self.pipeline.to(self.device)

            print("✅ Pipeline ready with memory optimizations!")
            self.setup_complete = True
            self.config = config
            return True

        except Exception as e:
            print(f"❌ Setup failed: {e}")
            # Drop any partially loaded model so its memory can be reclaimed.
            self.pipeline = None
            self.setup_complete = False
            return False

    def generate_single_image(self, node_id: str, prompt: str) -> str:
        """Generate and save exactly one PNG image for a story node.

        The file is written to ``<drive_manager.project_folder>/images/<node_id>.png``,
        where ``drive_manager`` is the notebook-level storage helper.

        Args:
            node_id: Story node id, also used as the image file stem.
            prompt: Text prompt for the diffusion pipeline.

        Returns:
            The saved file name, or ``"placeholder.jpg"`` if the pipeline is
            not set up or generation/saving fails.
        """
        if not self.setup_complete:
            return "placeholder.jpg"

        print(f"\n🎨 Generating image for: {node_id}")
        print(f"📝 Prompt: {prompt}")
        print(f"📏 Prompt length: {len(prompt)} chars, ~{len(prompt.split())} words")

        saved_name = None
        result = image = None
        try:
            # Pre-generation cleanup
            self._free_cached_memory()

            if torch.cuda.is_available():
                memory_before = torch.cuda.memory_allocated() / 1024**3
                print(f"💾 Memory before: {memory_before:.1f}GB")

            print("🔄 Generating...")
            # SDXL Turbo is run without classifier-free guidance (scale 0.0);
            # SDXL Base uses the configured guidance scale.
            guidance = 0.0 if self.config.model_type == "turbo" else self.config.guidance_scale
            result = self.pipeline(
                prompt=prompt,
                num_inference_steps=self.config.inference_steps,
                guidance_scale=guidance,
                height=self.config.height,
                width=self.config.width,
                num_images_per_prompt=1,  # EXACTLY 1 image
                # Fixed seed so reruns reproduce the same image.
                generator=torch.Generator(device=self.device).manual_seed(42),
            )
            image = result.images[0]

            filename = f"{node_id}.png"
            images_dir = os.path.join(drive_manager.project_folder, "images")
            os.makedirs(images_dir, exist_ok=True)
            # `quality` is a JPEG/WebP option, so only `optimize` is passed for PNG.
            image.save(os.path.join(images_dir, filename), "PNG", optimize=True)
            saved_name = filename

        except torch.cuda.OutOfMemoryError as e:
            print(f"❌ GPU Out of Memory: {e}")
            print("💡 Try smaller image size or maximum memory optimization")
        except Exception as e:
            print(f"❌ Generation failed: {e}")

        # Release references and cached memory on every path, including after
        # an OOM, so the next image starts from a clean cache.
        del result, image
        self._free_cached_memory()

        if saved_name is None:
            return "placeholder.jpg"

        if torch.cuda.is_available():
            memory_after = torch.cuda.memory_allocated() / 1024**3
            print(f"💾 Memory after: {memory_after:.1f}GB")

        print(f"✅ Single image saved: {saved_name}")
        return saved_name

# Load prompts and generate images with memory optimization.
# image_results maps node_id -> saved file name ("placeholder.jpg" on failure);
# the export step reads it.
image_results = {}

# First, check if we have prompts
if 'image_prompts' in globals() and image_prompts:
    print(f"\n🎨 Memory-Optimized Image Generation for {len(image_prompts)} nodes")
    print("🔄 Process: Cleanup → Setup → Generate → Cleanup per image")

    # Step 1: Aggressive memory cleanup. It is best-effort: a failure is
    # reported, but generation is still attempted.
    generator = MemoryOptimizedImageGenerator()
    if not generator.aggressive_memory_cleanup():
        print("⚠️ Memory cleanup incomplete - proceeding anyway")

    # Step 2: Configure image generation
    image_config = generator.configure_image_settings()

    # Step 3: Setup image generation pipeline, then generate one image per node
    if generator.setup_image_generation(image_config):
        print(f"\n🚀 Generating {len(image_prompts)} images...")
        print("💡 Each image is generated individually to minimize memory usage")

        total_nodes = len(image_prompts)
        successful_images = 0

        for i, (node_id, prompt) in enumerate(image_prompts.items(), 1):
            print(f"\n📸 [{i}/{total_nodes}] Processing: {node_id}")

            # Generate single image with memory management
            filename = generator.generate_single_image(node_id, prompt)
            image_results[node_id] = filename

            if filename != "placeholder.jpg":
                successful_images += 1

            # Show progress
            progress = (i / total_nodes) * 100
            print(f"📊 Progress: {progress:.1f}% ({successful_images}/{i} successful)")

            # Brief pause between images to let memory settle
            if i < total_nodes:
                time.sleep(1)

        print(f"\n🎉 Image generation complete!")
        print(f"✅ Successfully generated: {successful_images}/{total_nodes} images")
        print(f"📁 Images saved to: {drive_manager.project_folder}/images/")

    else:
        print("❌ Failed to setup image generation pipeline")

    # Drop the pipeline so the final cleanup below can actually return its memory.
    generator.pipeline = None
    generator.setup_complete = False

else:
    print("⏭️ No prompts available for image generation")
    print("💡 Run the prompt generation step first")

print("\n✅ Memory-optimized image generation phase complete!")

# Final memory cleanup
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # mem_get_info()[0] is device-wide free memory (other processes included).
    memory_final = torch.cuda.mem_get_info()[0] / 1024**3
    print(f"💾 Final free memory: {memory_final:.1f}GB")


## 💾 Step 10: Save Final Story with Proper Format

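Creates the final SubQuest-compatible JSON file containing the story structure and image URLs. Both Google Drive and local storage are handled. For local storage, images are linked through the configured base URL. For Google Drive, each image gets a placeholder link that you must replace with its real shareable link; step-by-step instructions are printed.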


In [None]:
import json
import os
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass, field

class StoryExporter:
    """Export a generated story as a SubQuest-compatible JSON file.

    Args:
        drive_manager: Storage helper; this class reads its ``use_drive`` flag
            and its ``project_folder`` path.
        image_base_url: Base URL prepended to image file names when images are
            stored locally (not used for Google Drive storage).
    """

    # Stock image used for nodes that have no generated picture.
    PLACEHOLDER_IMAGE_URL = "https://images.unsplash.com/photo-1518709268805-4e9042af2176?w=800"

    def __init__(self, drive_manager, image_base_url: str):
        self.drive_manager = drive_manager
        self.image_base_url = image_base_url

    def _get_image_url(self, filename: str) -> str:
        """Return the URL to store in the JSON for one image file name.

        ``"placeholder.jpg"`` maps to the stock placeholder image. With Drive
        storage an editable placeholder Drive link is returned; with local
        storage the file name is joined onto ``image_base_url``.
        """
        if filename == "placeholder.jpg":
            return self.PLACEHOLDER_IMAGE_URL

        if self.drive_manager.use_drive:
            return self._create_drive_shareable_link(filename)
        return f"{self.image_base_url.rstrip('/')}/{filename}"

    def _create_drive_shareable_link(self, filename: str) -> str:
        """Return a find-and-replace placeholder Google Drive link for ``filename``.

        Real Drive file IDs would need the Google Drive API. Instead, when the
        image exists, manual instructions are printed and a
        ``REPLACE_WITH_ACTUAL_FILE_ID_FOR_...`` URL is returned. Missing files,
        or any error, fall back to the stock placeholder image.
        """
        try:
            filepath = os.path.join(self.drive_manager.project_folder, "images", filename)
            if not os.path.exists(filepath):
                return self.PLACEHOLDER_IMAGE_URL

            print(f"📝 Note: For image '{filename}', you'll need to:")
            print(f"   1. Go to Google Drive: {self.drive_manager.project_folder}/images/")
            print(f"   2. Right-click on {filename} → Share → Copy link")
            print(f"   3. Replace the placeholder URL in the JSON file")

            # A placeholder that is easy to find and replace in the JSON.
            return f"https://drive.google.com/file/d/REPLACE_WITH_ACTUAL_FILE_ID_FOR_{filename.replace('.', '_')}/view?usp=sharing"

        except Exception as e:
            print(f"⚠️ Could not process Drive link for {filename}: {e}")
            return self.PLACEHOLDER_IMAGE_URL

    def create_subquest_json(self, story: 'Story', image_filenames: Dict[str, str]) -> Optional[Dict]:
        """Create JSON in exact SubQuest format matching demo_simple.json.

        Args:
            story: Story with ``title``, ``description``, ``startNodeId`` and a
                ``nodes`` dict whose values have ``id``, ``title``, ``content``,
                ``isEnd`` and ``choices``.
            image_filenames: node_id -> image file name; nodes that are missing
                get the placeholder image.

        Returns:
            The JSON-ready dict, or None when ``story`` is falsy.
        """
        if not story:
            return None

        print("📝 Creating SubQuest-compatible JSON...")

        subquest_json = {
            "title": story.title,
            "description": story.description,
            "startNodeId": story.startNodeId,
            "nodes": {}
        }

        for node_id, node in story.nodes.items():
            image_filename = image_filenames.get(node_id, "placeholder.jpg")
            image_url = self._get_image_url(image_filename)

            node_json = {
                "id": node.id,
                "title": node.title,
                "content": node.content,
                "imageUrl": image_url
            }

            # Only non-ending nodes carry choices.
            if not node.isEnd and node.choices:
                node_json["choices"] = [
                    {
                        "id": choice.id,
                        "text": choice.text,
                        "nextNodeId": choice.nextNodeId
                    }
                    for choice in node.choices
                ]

            if node.isEnd:
                node_json["isEnd"] = True

            subquest_json["nodes"][node_id] = node_json

        return subquest_json

    @staticmethod
    def _safe_filename_stem(title: str) -> str:
        """Turn a story title into a lowercase, filesystem-safe file name stem.

        Each run of characters other than letters, digits, ``_`` and ``-``
        (spaces, ``/``, ``:``, ``?``, ...) becomes one ``_``. This stops an
        LLM-generated title from creating nested paths or using characters
        that some filesystems reject. An empty result becomes ``"untitled"``.
        """
        import re

        stem = re.sub(r"[^\w\-]+", "_", title or "").strip("_").lower()
        return stem or "untitled"

    def save_story(self, story: 'Story', image_filenames: Dict[str, str]) -> Tuple[bool, str]:
        """Build the SubQuest JSON and write it into the project folder.

        Returns:
            ``(True, path)`` on success, ``(False, "")`` if the JSON could not
            be built or written.
        """
        print("\n💾 Saving the final story...")

        subquest_json_data = self.create_subquest_json(story, image_filenames)

        if not subquest_json_data:
            print("❌ Failed to create SubQuest JSON data.")
            return False, ""

        output_filename = f"{self._safe_filename_stem(story.title)}_story.json"
        output_filepath = os.path.join(self.drive_manager.project_folder, output_filename)

        try:
            os.makedirs(self.drive_manager.project_folder, exist_ok=True)
            with open(output_filepath, 'w', encoding='utf-8') as f:
                json.dump(subquest_json_data, f, indent=2)

            print(f"✅ Story JSON saved to: {output_filepath}")
            return True, output_filepath

        except Exception as e:
            print(f"❌ Failed to save story JSON: {e}")
            return False, ""

    def display_final_results(self, success: bool, filepath: str, story: 'Story'):
        """Print a summary of saved files, story statistics and next steps."""
        print("\n" + "=" * 60)
        print("🎉 STORY GENERATION COMPLETE!")
        print("=" * 60)

        if success:
            print(f"✅ Story successfully saved in SubQuest format!")
            print(f"📁 Project folder: {self.drive_manager.project_folder}")
            print(f"📄 Story JSON: {os.path.basename(filepath)}")
            print(f"🖼️ Images folder: images/")

            # Count image files actually present on disk.
            images_dir = os.path.join(self.drive_manager.project_folder, "images")
            if os.path.exists(images_dir):
                image_count = len([f for f in os.listdir(images_dir) if f.endswith(('.png', '.jpg', '.jpeg'))])
                print(f"📸 Generated images: {image_count}")

            print(f"\n📊 Story Statistics:")
            print(f"  - Title: {story.title}")
            print(f"  - Total nodes: {len(story.nodes)}")
            print(f"  - Ending nodes: {len([n for n in story.nodes.values() if n.isEnd])}")

            if self.drive_manager.use_drive:
                print(f"\n☁️ Files saved to Google Drive")
                print(f"💡 You can access them in: MyDrive/SubQuest_Stories/")
                print(f"\n🔗 Image URL Setup:")
                print(f"  1. Go to your Google Drive folder")
                print(f"  2. For each image, right-click → Share → Copy link")
                print(f"  3. Replace the placeholder URLs in the JSON file")
                print(f"  4. Make sure images are set to 'Anyone with the link can view'")
            else:
                print(f"\n💻 Files saved locally")
                print(f"💡 Upload the JSON file to your SubQuest app")

        print("\n" + "=" * 60)
        print("🚀 Next Steps:")
        print("  1. Review the generated story JSON file")
        print("  2. Update Google Drive image links if needed")
        print("  3. Test the story in your SubQuest app")
        print("  4. Share your interactive story with users!")
        print("\n🎉 Thank you for using the Interactive Story Generator!")

# Export the finished story (JSON plus image links) when one was generated.
if not generated_story:
    print("❌ No story available to save")
else:
    exporter = StoryExporter(drive_manager, image_base_url)
    success, result = exporter.save_story(generated_story, image_results)
    exporter.display_final_results(success, result, generated_story)