In [1]:
# Install necessary libraries
!pip install -q transformers datasets sentence-transformers faiss-gpu chromadb tqdm openai gradio

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m19.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.0/18.0 MB[0m [31m80.8 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m52.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m94.9/94.9 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.5/46.5 MB[0m [31m34.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
!pip install --upgrade openai

Collecting openai
  Downloading openai-1.70.0-py3-none-any.whl.metadata (25 kB)
Downloading openai-1.70.0-py3-none-any.whl (599 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m599.1/599.1 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hInstalling collected packages: openai
  Attempting uninstall: openai
    Found existing installation: openai 1.57.4
    Uninstalling openai-1.57.4:
      Successfully uninstalled openai-1.57.4
Successfully installed openai-1.70.0


In [3]:
# Import libraries
import torch
import numpy as np
import pandas as pd
import os
import json
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
import chromadb
from tqdm.notebook import tqdm
import openai
import gradio as gr

In [4]:
def safe_process_chroma_results(results, process_func, default_message="No data available"):
    """Safely process results from ChromaDB query
    
    Args:
        results: The results from ChromaDB query
        process_func: Function to process each metadata item
        default_message: Default message if processing fails
        
    Returns:
        Processed string representation
    """
    try:
        if not results or 'metadatas' not in results:
            return default_message
            
        metadatas = results['metadatas']
        if not metadatas:
            return default_message
            
        
        if isinstance(metadatas, list) and metadatas and isinstance(metadatas[0], dict):
           
            return "\n".join([process_func(item) for item in metadatas])
        elif isinstance(metadatas, list):
            
            if metadatas:
                return f"Data available but in unexpected format ({len(metadatas)} items)"
            else:
                return default_message
        else:
            return default_message
    except Exception as e:
        print(f"Error in safe_process_chroma_results: {e}")
        return default_message

In [5]:

class StoryMemory:
    def __init__(self, embedding_model_name="all-MiniLM-L6-v2"):
        
        self.embedding_model = SentenceTransformer(embedding_model_name)
        
        
        self.client = chromadb.Client()
        
        
        try:
            self.char_collection = self.client.get_or_create_collection("characters")
            self.plot_collection = self.client.get_or_create_collection("plots")
            self.setting_collection = self.client.get_or_create_collection("settings")
            
            
            self.subplot_collection = self.client.get_or_create_collection("subplots")
            self.theme_collection = self.client.get_or_create_collection("themes")
            
            self.character_arc_collection = self.client.get_or_create_collection("character_arcs")
            
        except Exception as e:
            print(f"Error initializing collections: {e}")
            
            try:
                self.client.delete_collection("characters")
                self.client.delete_collection("plots")
                self.client.delete_collection("settings")
                self.client.delete_collection("subplots")
                self.client.delete_collection("themes")
                self.client.delete_collection("character_arcs")
                
                self.char_collection = self.client.create_collection("characters")
                self.plot_collection = self.client.create_collection("plots")
                self.setting_collection = self.client.create_collection("settings")
                self.subplot_collection = self.client.create_collection("subplots")
                self.theme_collection = self.client.create_collection("themes")
                self.character_arc_collection = self.client.create_collection("character_arcs")
            except Exception as e2:
                print(f"Error recreating collections: {e2}")
                raise
        
    def add_character(self, char_id, char_data):
        """Add or update character information in the memory"""
        char_text = f"Character {char_data['name']}: {json.dumps(char_data)}"
        
        sanitized_metadata = {}
        for key, value in char_data.items():
            if isinstance(value, (str, int, float, bool)):
                sanitized_metadata[key] = value
            elif isinstance(value, list):
                sanitized_metadata[key] = ", ".join(map(str, value))
            else:
                sanitized_metadata[key] = str(value)
        
        self.char_collection.add(
            documents=[char_text],
            metadatas=[sanitized_metadata],
            ids=[f"char_{char_id}"]
        )
    
    
    def update_character_arc(self, char_name, episode_num, development):
        """Track character development/arc across episodes"""
        arc_id = f"arc_{char_name}_{episode_num}"
        arc_text = f"Episode {episode_num} - {char_name}: {development}"
        
        self.character_arc_collection.add(
            documents=[arc_text],
            metadatas=[{
                "character": char_name,
                "episode": episode_num,
                "development": development
            }],
            ids=[arc_id]
        )
        
    def add_plot_point(self, plot_id, episode, description, related_chars=None):
        """Add a plot point to memory"""
        plot_text = f"Episode {episode} - Plot point: {description}"
        
        
        if related_chars and isinstance(related_chars, list):
            related_chars_str = ", ".join(related_chars)
        else:
            related_chars_str = str(related_chars) if related_chars else ""
        
        metadata = {
            "episode": episode,
            "description": description,
            "related_characters": related_chars_str
        }
        
        self.plot_collection.add(
            documents=[plot_text],
            metadatas=[metadata],
            ids=[f"plot_{plot_id}"]
        )

            
    def add_subplot(self, subplot_id, title, description, start_episode, characters_involved):
        """Add a subplot to memory"""
        subplot_text = f"Subplot: {title} - {description}"
        
        metadata = {
            "title": title,
            "description": description,
            "start_episode": start_episode,
            "status": "ongoing",
            "characters_involved": ", ".join(characters_involved) if isinstance(characters_involved, list) else characters_involved
        }
        
        self.subplot_collection.add(
            documents=[subplot_text],
            metadatas=[metadata],
            ids=[f"subplot_{subplot_id}"]
        )
    
    
    def update_subplot(self, subplot_id, status, resolution=None):
        """Update the status of a subplot (ongoing, resolved, etc.)"""

        pass
        
    def add_setting(self, setting_id, setting_data):
        """Add setting information to memory"""
        setting_text = f"Setting: {json.dumps(setting_data)}"
        
        
        sanitized_metadata = {}
        for key, value in setting_data.items():
            if isinstance(value, (str, int, float, bool)):
                sanitized_metadata[key] = value
            elif isinstance(value, list):
                sanitized_metadata[key] = ", ".join(map(str, value))
            else:
                sanitized_metadata[key] = str(value)
        
        self.setting_collection.add(
            documents=[setting_text],
            metadatas=[sanitized_metadata],
            ids=[f"setting_{setting_id}"]
        )
    
    def query_characters(self, query, n_results=5):
        """Query character information"""
        results = self.char_collection.query(
            query_texts=[query],
            n_results=n_results
        )
        return results
    
    def query_plot_points(self, query, n_results=5):
        """Query plot information"""
        results = self.plot_collection.query(
            query_texts=[query],
            n_results=n_results
        )
        return results
    
    
    def get_episode_context(self, episode_num, max_plot_points=10):
        """Get context for generating a specific episode"""
        
        try:
            all_chars = self.char_collection.get()
        except Exception as e:
            print(f"Error getting characters: {e}")
            all_chars = {"documents": [], "metadatas": [], "ids": []}
        
        
        query = f"Important plot points for episode {episode_num}"
        
        
        if episode_num > 1:
            try:
                prev_plots = self.plot_collection.query(
                    query_texts=[query],
                    n_results=max_plot_points,
                    where={"episode": {"$lt": episode_num}}
                )
            except Exception as e:
                print(f"Error querying plot points: {e}")
                prev_plots = {"documents": [], "metadatas": [], "ids": []}
        else:
            prev_plots = {"documents": [], "metadatas": [], "ids": []}
            
        
        try:
            active_subplots = self.subplot_collection.query(
                query_texts=[f"Relevant subplots for episode {episode_num}"],
                n_results=5
            )
        except Exception as e:
            print(f"Error getting subplots: {e}")
            active_subplots = {"documents": [], "metadatas": [], "ids": []}
            
        
        try:
            character_arcs = self.character_arc_collection.query(
                query_texts=[f"Character development before episode {episode_num}"],
                n_results=10,
                where={"episode": {"$lt": episode_num}}
            )
        except Exception as e:
            print(f"Error getting character arcs: {e}")
            character_arcs = {"documents": [], "metadatas": [], "ids": []}
        
        
        print(f"Characters data structure: {list(all_chars.keys())}")
        print(f"Plot points data structure: {list(prev_plots.keys())}")
        
        
        context = {
            "characters": all_chars,
            "previous_plots": prev_plots,
            "current_episode": episode_num,
            "active_subplots": active_subplots,
            "character_arcs": character_arcs
        }
        
        return context

In [6]:
class StoryInputProcessor:
    def __init__(self):
        pass
    
    def process_initial_prompt(self, prompt):
        """Process the initial story prompt/concept"""
        
        story_elements = {
            "genre": self._extract_genre(prompt),
            "setting": self._extract_setting(prompt),
            "main_characters": self._extract_characters(prompt),
            "theme": self._extract_theme(prompt),
            "original_prompt": prompt
        }
        return story_elements
    
    def _extract_genre(self, prompt):

        genres = ["fantasy", "sci-fi", "romance", "thriller", "comedy", "drama", "horror", "mystery", "adventure"]
        for genre in genres:
            if genre.lower() in prompt.lower():
                return genre
        return "general fiction"  # Default
    
    def _extract_setting(self, prompt):
        
        if "future" in prompt.lower() or "space" in prompt.lower():
            return "futuristic"
        elif "medieval" in prompt.lower() or "ancient" in prompt.lower():
            return "historical"
        else:
            return "contemporary"
    
    def _extract_characters(self, prompt):
        
        return []
    
    def _extract_theme(self, prompt):
        themes = ["love", "betrayal", "redemption", "survival", "growth", "friendship"]
        for theme in themes:
            if theme.lower() in prompt.lower():
                return theme
        return "journey"  

In [7]:
class OpenAIStoryGenerator:
    def __init__(self, api_key, model="gpt-3.5-turbo", max_tokens=1024):
        openai.api_key = api_key
        self.model = model
        self.max_tokens = max_tokens

    def generate_story_outline(self, story_elements):
        """Generate an overall story outline"""
        prompt = f"""
        Create a compelling story outline with the following elements:
        Genre: {story_elements['genre']}
        Setting: {story_elements['setting']}
        Theme: {story_elements['theme']}
        Number of episodes: 5

        The outline should include:
        1. Main characters with brief descriptions
        2. Overall story arc
        3. Brief summary of each episode
        4. Key plot points and how they develop across episodes
        5. At least 2 subplots that span multiple episodes
        6. Character development arcs for main characters

        Story Outline:
        """

        response = openai.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=self.max_tokens,
            temperature=0.7
        )
        return response.choices[0].message.content

    def generate_episode(self, episode_num, story_outline, memory_context):
        """Generate a specific episode"""
        

        
        characters = ""
        if 'characters' in memory_context and 'metadatas' in memory_context['characters']:
            try:
                characters = "\n".join([f"- {char.get('name', 'Unknown')}: {char.get('description', 'No description')}"
                                       for char in memory_context['characters']['metadatas']])
            except Exception as e:
                print(f"Error processing characters: {e}")
                characters = "Characters available but couldn't be processed."
                
            if not characters:
                characters = "No character information available."

       
        previous_plots = ""
        if episode_num > 1 and 'previous_plots' in memory_context:
            
            previous_plots = safe_process_chroma_results(
                memory_context['previous_plots'],
                lambda plot: f"- Episode {plot.get('episode', '?')}: {plot.get('description', 'No description')}",
                "No previous plot information available."
            )

        
        active_subplots = ""
        if 'active_subplots' in memory_context:
            active_subplots = safe_process_chroma_results(
                memory_context['active_subplots'],
                lambda subplot: f"- Subplot: {subplot.get('title', '?')}: {subplot.get('description', 'No description')} (Status: {subplot.get('status', 'ongoing')})",
                "No active subplots."
            )

        
        character_arcs = ""
        if 'character_arcs' in memory_context:
            character_arcs = safe_process_chroma_results(
                memory_context['character_arcs'],
                lambda arc: f"- {arc.get('character', '?')} (Episode {arc.get('episode', '?')}): {arc.get('development', 'No development info')}",
                "No character development information."
            )

        prompt = f"""
        You are writing episode {episode_num} of a multi-episode story.

        Story Outline: {story_outline}

        Characters:
        {characters}

        Previous Important Events:
        {previous_plots}
        
        Active Subplots:
        {active_subplots}
        
        Character Development So Far:
        {character_arcs}

        Write Episode {episode_num} in script format with character dialogue and actions. 
        Make sure to maintain consistency with previous episodes while advancing the plot.
        Ensure character personalities remain consistent with their established traits.
        Continue developing subplots and character arcs naturally.
        End the episode with a hook that leads into the next episode.

        EPISODE {episode_num}:
        """

        response = openai.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=self.max_tokens,
            temperature=0.7
        )
        return response.choices[0].message.content

In [8]:
class LongFormStoryGenerator(OpenAIStoryGenerator):
    def __init__(self, api_key, model="gpt-3.5-turbo", max_tokens=1024):
        super().__init__(api_key, model, max_tokens)
        
    def generate_episode_with_chunking(self, episode_num, story_outline, memory_context, max_chunk_length=10000):
        """Generate a long-form episode by creating it in manageable chunks"""
        
        episode_outline = self._generate_episode_outline(episode_num, story_outline, memory_context)
        
        
        scenes = self._break_into_scenes(episode_outline)
        
        
        complete_episode = f"EPISODE {episode_num}:\n\n"
        
        for i, scene in enumerate(scenes):
            print(f"Generating scene {i+1} of {len(scenes)}...")
            scene_prompt = f"""
            You are writing scene {i+1} of episode {episode_num}.
            
            Scene outline: {scene}
            
            Characters:
            {self._extract_characters_for_scene(scene, memory_context)}
            
            Write this scene in detail with rich dialogue and descriptions. 
            This is part of a longer episode, so focus only on this scene.
            
            SCENE {i+1}:
            """
            
            scene_content = self._call_openai_api(scene_prompt)
            complete_episode += f"\n\n[SCENE {i+1}]\n{scene_content}\n\n"
            
        return complete_episode
    
    def _generate_episode_outline(self, episode_num, story_outline, memory_context):
        """Generate just the outline/plan for the episode"""
        
        previous_episodes_summary = safe_process_chroma_results(
            memory_context.get('previous_plots', {}),
            lambda plot: f"- Episode {plot.get('episode', '?')}: {plot.get('description', 'No description')}",
            "No previous episode information."
        )
        
        prompt = f"""
        Create a detailed scene-by-scene outline for episode {episode_num}.
        
        Story Outline: {story_outline}
        
        Previous Episodes Summary:
        {previous_episodes_summary}
        
        For each scene, include:
        1. Setting
        2. Characters present
        3. Brief description of what happens
        4. How it advances the plot or character development
        
        Format each scene like this:
        SCENE X: [brief title]
        Setting: [location]
        Characters: [list of characters]
        Action: [what happens]
        Purpose: [how it advances story]
        
        Create 5-8 scenes for this episode.
        """
        
        response = self._call_openai_api(prompt)
        return response
        
    def _break_into_scenes(self, episode_outline):
        """Analyze the outline and break it into logical scene chunks"""
        
        import re
        scene_pattern = r'SCENE \d+:'
        scenes = re.split(scene_pattern, episode_outline)
        
        
        if scenes and scenes[0].strip() == '':
            scenes = scenes[1:]
            
        
        for i in range(len(scenes)):
            scenes[i] = f"SCENE {i+1}:{scenes[i]}"
            
        return scenes
    
    def _extract_characters_for_scene(self, scene, memory_context):
        """Extract which characters are likely in this scene"""
        character_info = ""
        
        if 'characters' in memory_context and 'metadatas' in memory_context['characters']:
            try:
                
                chars_in_scene = []
                for char in memory_context['characters']['metadatas']:
                    char_name = char.get('name', '')
                    if char_name and char_name in scene:
                        chars_in_scene.append(f"- {char_name}: {char.get('description', 'No description')}")
                
                character_info = "\n".join(chars_in_scene)
            except Exception as e:
                print(f"Error extracting characters for scene: {e}")
                character_info = "Character information available but couldn't be processed."
        
        return character_info or "No specific character information available."
    
    def _call_openai_api(self, prompt):
        """Helper method to call OpenAI API"""
        response = openai.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            max_tokens=self.max_tokens,
            temperature=0.7
        )
        return response.choices[0].message.content

In [9]:

class StoryPostProcessor:
    def __init__(self, memory):
        self.memory = memory
        
        try:
            self.ner_pipeline = pipeline("ner", model="dslim/bert-base-NER")
            self.use_ner = True
            print("NER model loaded successfully!")
        except Exception as e:
            print(f"Error loading NER model: {e}")
            self.use_ner = False
        
    def extract_and_store_elements(self, episode_num, episode_content):
        """Extract story elements from generated content and store in memory"""
        
        characters = self._extract_characters(episode_content)
        for i, char in enumerate(characters):
            
            sanitized_char = self._sanitize_metadata(char)
            self.memory.add_character(f"{episode_num}_{i}", sanitized_char)
        
        
        plot_points = self._extract_plot_points(episode_content)
        for i, plot in enumerate(plot_points):
            
            if isinstance(plot, str):
                related_char_names = [char["name"] for char in characters]
                self.memory.add_plot_point(
                    f"{episode_num}_{i}",
                    episode_num,
                    plot,
                    related_char_names
                )
        
        
        subplots = self._extract_subplots(episode_content, episode_num)
        for i, subplot in enumerate(subplots):
            self.memory.add_subplot(
                f"{episode_num}_{i}",
                subplot["title"],
                subplot["description"],
                episode_num,
                subplot["characters"]
            )
            
        
        character_developments = self._extract_character_developments(episode_content, characters)
        for char_name, development in character_developments.items():
            self.memory.update_character_arc(char_name, episode_num, development)
        
        return {
            "characters": characters,
            "plot_points": plot_points,
            "subplots": subplots,
            "character_developments": character_developments
        }
    
    def _sanitize_metadata(self, metadata_dict):
        """Convert all values to simple types that ChromaDB accepts"""
        sanitized = {}
        for key, value in metadata_dict.items():
            if isinstance(value, (str, int, float, bool)):
                sanitized[key] = value
            elif isinstance(value, list):
                sanitized[key] = ", ".join(map(str, value))
            elif isinstance(value, dict):
                sanitized[key] = json.dumps(value)
            else:
                sanitized[key] = str(value)
        return sanitized
    
    def _extract_characters(self, text):
        """Extract character information from text"""
        if self.use_ner:
            return self._extract_characters_with_ner(text)
        else:
            return self._extract_characters_with_regex(text)
            
    def _extract_characters_with_ner(self, text):
        """Extract characters using NER pipeline"""
        
        try:
            results = self.ner_pipeline(text[:10000])  
            
            
            characters = []
            current_entity = {"name": "", "type": ""}
            
            for entity in results:
                if entity["entity"].startswith("B-PER"):
                    
                    if current_entity["name"] and current_entity["type"] == "PER":
                        characters.append({"name": current_entity["name"].strip()})
                    current_entity = {"name": entity["word"], "type": "PER"}
                elif entity["entity"].startswith("I-PER") and current_entity["type"] == "PER":
                    
                    current_entity["name"] += " " + entity["word"]
            
            
            if current_entity["name"] and current_entity["type"] == "PER":
                characters.append({"name": current_entity["name"].strip()})
                
            
            unique_characters = []
            seen_names = set()
            for char in characters:
                if char["name"] not in seen_names:
                    seen_names.add(char["name"])
                    char["description"] = f"Character appearing in the story"
                    unique_characters.append(char)
                    
            return unique_characters
        except Exception as e:
            print(f"Error in NER character extraction: {e}")
            
            return self._extract_characters_with_regex(text)
    
    def _extract_characters_with_regex(self, text):
        """Extract character information using regex (fallback method)"""
        
        import re
        dialogue_pattern = r'([A-Z][A-Za-z\s]+):'
        character_names = list(set(re.findall(dialogue_pattern, text)))
        
        characters = []
        for name in character_names:
            name = name.strip()
            if name and len(name) > 1:  
                characters.append({
                    "name": name,
                    "description": f"Character appearing in the story",
                    "appearances": "yes"  
                })
        
        return characters
    
    def _extract_plot_points(self, text):
        """Extract key plot points from an episode"""
        
        
        
        
        lines = [line.strip() for line in text.split('\n') if line.strip()]
        
        
        
        plot_points = []
        for i, line in enumerate(lines):
            
            if ':' not in line and len(line) > 50:
                plot_points.append(line[:100] + "...")  
                
            if len(plot_points) >= 3:  
                break
                
        return plot_points
    
    
    def _extract_subplots(self, text, episode_num):
        """Extract subplot information from the episode"""
        
        
        
        
        subplot_keywords = [
            "meanwhile", "on the other hand", "elsewhere", 
            "at the same time", "subplot", "side story"
        ]
        
        lines = [line.strip() for line in text.split('\n') if line.strip()]
        
        subplots = []
        for keyword in subplot_keywords:
            for i, line in enumerate(lines):
                if keyword.lower() in line.lower() and len(line) > 30:
                    
                    if i < len(lines) - 1:
                        context = lines[i] + " " + lines[i+1]
                    else:
                        context = lines[i]
                    
                    
                    chars = self._extract_characters_with_regex(context)
                    char_names = [c["name"] for c in chars]
                    
                    subplots.append({
                        "title": f"Subplot from Episode {episode_num}",
                        "description": context[:100] + "..." if len(context) > 100 else context,
                        "characters": char_names
                    })
        
        
        unique_subplots = []
        seen_descriptions = set()
        for subplot in subplots:
            if subplot["description"] not in seen_descriptions:
                seen_descriptions.add(subplot["description"])
                unique_subplots.append(subplot)
                
            if len(unique_subplots) >= 2:  
                break
                
        return unique_subplots
    
    
    def _extract_character_developments(self, text, characters):
        """Extract character development information"""
        developments = {}
        
        for character in characters:
            char_name = character["name"]
            
            char_mentions = []
            
            
            paragraphs = [p.strip() for p in text.split('\n\n') if p.strip()]
            
            for paragraph in paragraphs:
                if char_name in paragraph and len(paragraph) > 50 and ":" not in paragraph[:len(char_name)+1]:
                    char_mentions.append(paragraph)
            
            
            if char_mentions:
                longest_mention = max(char_mentions, key=len)
                developments[char_name] = longest_mention[:150] + "..." if len(longest_mention) > 150 else longest_mention
        
        return developments

In [10]:

class AIStorytellingPipeline:
    def __init__(self, use_openai=False, openai_api_key=None, use_long_form=False):
        
        self.memory = StoryMemory()
        self.input_processor = StoryInputProcessor()
        
        
        if use_openai:
            if use_long_form:
                self.story_generator = LongFormStoryGenerator(api_key=openai_api_key, model="gpt-3.5-turbo")
            else:
                self.story_generator = OpenAIStoryGenerator(api_key=openai_api_key, model="gpt-3.5-turbo")
        else:
            
            self.story_generator = OpenAIStoryGenerator(api_key=openai_api_key, model="gpt-3.5-turbo")
            
        self.post_processor = StoryPostProcessor(self.memory)
        self.story_data = {}
        self.use_long_form = use_long_form
        
    def create_story(self, initial_prompt, num_episodes=5):
        """Create a complete story based on the initial prompt"""
        
        story_elements = self.input_processor.process_initial_prompt(initial_prompt)
        
        
        print("Generating story outline...")
        story_outline = self.story_generator.generate_story_outline(story_elements)
        self.story_data["outline"] = story_outline
        
        
        initial_elements = self.post_processor.extract_and_store_elements(0, story_outline)
        
        
        self.story_data["episodes"] = []
        for ep_num in range(1, num_episodes + 1):
            print(f"Generating episode {ep_num}...")
            
           
            memory_context = self.memory.get_episode_context(ep_num)
            
            
            if self.use_long_form and hasattr(self.story_generator, 'generate_episode_with_chunking'):
                episode_content = self.story_generator.generate_episode_with_chunking(ep_num, story_outline, memory_context)
            else:
                episode_content = self.story_generator.generate_episode(ep_num, story_outline, memory_context)
            
            
            extracted_elements = self.post_processor.extract_and_store_elements(ep_num, episode_content)
            
            
            self.story_data["episodes"].append({
                "episode_number": ep_num,
                "content": episode_content,
                "extracted_elements": extracted_elements
            })
            
            print(f"Episode {ep_num} completed.")
        
        return self.story_data
    
    def save_story(self, filename="generated_story.json"):
        """Save the generated story to a file"""
        with open(filename, 'w') as f:
            json.dump(self.story_data, f, indent=2)
        print(f"Story saved to {filename}")
    
    def load_story(self, filename="generated_story.json"):
        """Load a previously generated story"""
        with open(filename, 'r') as f:
            self.story_data = json.load(f)
        print(f"Story loaded from {filename}")
        
    def get_story_analytics(self):
        """Generate analytics about the story"""
        if not self.story_data or "episodes" not in self.story_data:
            return {"error": "No story data available"}
            
        analytics = {
            "total_episodes": len(self.story_data["episodes"]),
            "character_consistency": self._analyze_character_consistency(),
            "subplot_progression": self._analyze_subplot_progression(),
            "word_count": self._calculate_word_count(),
            "character_count": self._count_characters()
        }
        
        return analytics
    
    def _analyze_character_consistency(self):
        """Analyze how consistent characters are across episodes"""
        if not self.story_data or "episodes" not in self.story_data:
            return {}
            
        
        character_appearances = {}
        
        for ep in self.story_data["episodes"]:
            ep_num = ep["episode_number"]
            if "extracted_elements" in ep and "characters" in ep["extracted_elements"]:
                for char in ep["extracted_elements"]["characters"]:
                    char_name = char["name"]
                    if char_name not in character_appearances:
                        character_appearances[char_name] = []
                    character_appearances[char_name].append(ep_num)
        
        
        results = {
            "total_characters": len(character_appearances),
            "recurring_characters": sum(1 for chars in character_appearances.values() if len(chars) > 1),
            "character_details": character_appearances
        }
        
        return results
    
    def _analyze_subplot_progression(self):
        """Analyze how subplots progress across episodes"""
        
        return {"subplot_analysis": "Placeholder for subplot analysis"}
    
    def _calculate_word_count(self):
        """Calculate total word count for the story"""
        if not self.story_data:
            return 0
            
        word_count = 0
        
        
        if "outline" in self.story_data:
            word_count += len(self.story_data["outline"].split())
            
        
        if "episodes" in self.story_data:
            for ep in self.story_data["episodes"]:
                if "content" in ep:
                    word_count += len(ep["content"].split())
                    
        return word_count
    
    def _count_characters(self):
        """Count unique characters in the story"""
        if not self.story_data or "episodes" not in self.story_data:
            return 0
            
        unique_chars = set()
        
        for ep in self.story_data["episodes"]:
            if "extracted_elements" in ep and "characters" in ep["extracted_elements"]:
                for char in ep["extracted_elements"]["characters"]:
                    unique_chars.add(char["name"])
                    
        return len(unique_chars)

In [11]:

class StoryControlPanel:
    def __init__(self, pipeline):
        self.pipeline = pipeline
        
    def create_ui(self):
        with gr.Blocks() as app:
            gr.Markdown("# AI Story Generator")
            
            with gr.Tab("Generate Story"):
                prompt_input = gr.Textbox(label="Story Prompt", lines=5)
                num_episodes = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Episodes")
                
                with gr.Row():
                    generate_btn = gr.Button("Generate Story")
                    
                with gr.Accordion("Advanced Settings", open=False):
                    model_choice = gr.Dropdown(["gpt-3.5-turbo", "gpt-4"], label="OpenAI Model")
                    use_long_form = gr.Checkbox(label="Use Long-Form Generation", value=False)
                    
                story_output = gr.Textbox(label="Generated Story", lines=20)
                
            with gr.Tab("Story Memory"):
                with gr.Row():
                    character_list = gr.Dataframe(label="Characters")
                    plot_points = gr.Dataframe(label="Plot Points")
                
            with gr.Tab("Analytics"):
                analytics_output = gr.JSON(label="Story Analytics")
                analyze_btn = gr.Button("Analyze Story")
                
            generate_btn.click(
                fn=self.generate_story,
                inputs=[prompt_input, num_episodes, model_choice, use_long_form],
                outputs=[story_output, character_list, plot_points]
            )
            
            analyze_btn.click(
                fn=self.analyze_story,
                inputs=[],
                outputs=[analytics_output]
            )
            
        return app
    
    def generate_story(self, prompt, num_episodes, model, use_long_form):
        
        self.pipeline.story_generator.model = model
        self.pipeline.use_long_form = use_long_form
        
        
        story_data = self.pipeline.create_story(prompt, num_episodes=num_episodes)
        
       
        story_text = f"## STORY OUTLINE\n\n{story_data['outline']}\n\n"
        for i, episode in enumerate(story_data["episodes"]):
            story_text += f"\n\n## EPISODE {i+1}\n\n{episode['content']}\n\n"
            
        
        characters = []
        for ep in story_data["episodes"]:
            if "extracted_elements" in ep and "characters" in ep["extracted_elements"]:
                for char in ep["extracted_elements"]["characters"]:
                    characters.append({
                        "Name": char["name"],
                        "Episode": ep["episode_number"],
                        "Description": char.get("description", "")
                    })
        
        plots = []
        for ep in story_data["episodes"]:
            if "extracted_elements" in ep and "plot_points" in ep["extracted_elements"]:
                for plot in ep["extracted_elements"]["plot_points"]:
                    plots.append({
                        "Episode": ep["episode_number"],
                        "Description": plot
                    })
        
        return story_text, characters, plots
    
    def analyze_story(self):
        analytics = self.pipeline.get_story_analytics()
        return analytics

In [12]:

def run_pipeline(api_key, initial_prompt, num_episodes=3, use_long_form=False, show_ui=False):
    
    pipeline = AIStorytellingPipeline(use_openai=True, openai_api_key=api_key, use_long_form=use_long_form)
    
    if show_ui:
        
        control_panel = StoryControlPanel(pipeline)
        ui = control_panel.create_ui()
        ui.launch(share=True)
    else:
        
        try:
            story_data = pipeline.create_story(initial_prompt, num_episodes=num_episodes)
            
            
            pipeline.save_story()
            
            
            for i, episode in enumerate(story_data["episodes"]):
                print(f"\n\n=== EPISODE {i+1} ===")
                print(episode["content"])
                
            
            analytics = pipeline.get_story_analytics()
            print("\n\n=== STORY ANALYTICS ===")
            print(json.dumps(analytics, indent=2))
            
            return story_data
            
        except Exception as e:
            print(f"Error during story generation: {e}")
            
            
            try:
                if hasattr(pipeline, 'story_data') and pipeline.story_data:
                    pipeline.save_story("partial_story.json")
                    print(f"Partial story saved to partial_story.json")
                    
                    
                    for i, episode in enumerate(pipeline.story_data.get("episodes", [])):
                        print(f"\n\n=== EPISODE {i+1} ===")
                        print(episode["content"])
            except Exception as save_error:
                print(f"Error saving partial story: {save_error}")
            
            return None

In [13]:
openai_api_key = "YOUR_API_KEY"  # Your OpenAI API key

initial_prompt = """
Create a sci-fi adventure about a team of explorers who discover a hidden civilization 
on a distant planet. The story should have elements of mystery, friendship, and betrayal.
"""

# For regular generation (comment this out):
# run_pipeline(openai_api_key, initial_prompt, num_episodes=5)

# For long-form generation with UI (uncomment this):
run_pipeline(openai_api_key, initial_prompt, num_episodes=5, use_long_form=True, show_ui=True)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/829 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/433M [00:00<?, ?B/s]

Some weights of the model checkpoint at dslim/bert-base-NER were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


tokenizer_config.json:   0%|          | 0.00/59.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/2.00 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Device set to use cuda:0


NER model loaded successfully!
* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://337ef97a2f8ef9ca15.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Generating story outline...


/root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx.tar.gz: 100%|██████████| 79.3M/79.3M [00:01<00:00, 74.0MiB/s]


Generating episode 1...
Characters data structure: ['ids', 'embeddings', 'documents', 'uris', 'included', 'data', 'metadatas']
Plot points data structure: ['documents', 'metadatas', 'ids']
Episode 1 completed.
Generating episode 2...
Characters data structure: ['ids', 'embeddings', 'documents', 'uris', 'included', 'data', 'metadatas']
Plot points data structure: ['ids', 'embeddings', 'documents', 'uris', 'included', 'data', 'metadatas', 'distances']
Episode 2 completed.
Generating episode 3...
Characters data structure: ['ids', 'embeddings', 'documents', 'uris', 'included', 'data', 'metadatas']
Plot points data structure: ['ids', 'embeddings', 'documents', 'uris', 'included', 'data', 'metadatas', 'distances']
Episode 3 completed.
Generating episode 4...
Characters data structure: ['ids', 'embeddings', 'documents', 'uris', 'included', 'data', 'metadatas']
Plot points data structure: ['ids', 'embeddings', 'documents', 'uris', 'included', 'data', 'metadatas', 'distances']
Episode 4 comple

# Github- https://github.com/dasdebanna/storytelling

In [None]:
# Read the saved story
import json

with open('generated_story.json', 'r') as f:
    saved_story = json.load(f)

# Print all episodes
for i, episode in enumerate(saved_story["episodes"]):
    print(f"\n\n=== EPISODE {i+1} ===")
    print(episode["content"])