In [1]:
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
# from google import genai
import os
import re
from pathlib import Path
import google.generativeai as genai
from tqdm import tqdm
import json
from wasabi import msg

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import os
import google.generativeai as genai

# Never hard-code the key in the notebook: it would be committed with the .ipynb
# (and the literal placeholder is what produced the API_KEY_INVALID error below).
# Read it from the environment, falling back to a hidden interactive prompt.
from getpass import getpass

_api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if not _api_key:
    _api_key = getpass("Gemini API key (or set GEMINI_API_KEY): ")
genai.configure(api_key=_api_key)


class Config:
  """Hyper-parameters for one generated story.

  Class attributes hold the defaults; assigning an attribute on an instance
  overrides it for that instance only (the notebook also attaches ``title``
  to the instance at runtime).
  """
  n_ch = 5               # number of chapters
  genere = "fantasy"     # story genre (name kept as-is: other code reads `config.genere`)
  n_char = 10            # number of characters
  n_story_token = 1000   # requested chapter length (prompts phrase it as "words")

  def to_dict(self) -> dict:
    """Return all settings: class-level defaults merged with instance overrides.

    ``self.__dict__`` alone only contains attributes assigned on the instance,
    so any field left at its class default would be silently missing.
    """
    data = {}
    # Walk base classes first so subclasses can override inherited defaults.
    for klass in reversed(type(self).__mro__):
      for key, value in vars(klass).items():
        if key.startswith("_") or callable(value):
          continue  # skip dunders and methods
        if isinstance(value, (staticmethod, classmethod, property)):
          continue  # descriptors are not settings
        data[key] = value
    data.update(vars(self))
    return data

  def to_json(self, file_path: str) -> None:
    """Write ``to_dict()`` to ``file_path`` as indented UTF-8 JSON."""
    with open(file_path, 'w', encoding='utf-8') as f:
      json.dump(self.to_dict(), f, indent=4)
    

In [5]:
def split_plot(plot: str) -> List[str]:
    """Extract the per-chapter plot lines from the model's plot output.

    Markdown emphasis and heading markers (``*`` and ``#``) are removed and
    each line is stripped, so ``"**Chapter 1:** ..."`` and
    ``"## Chapter 1: ..."`` both become ``"Chapter 1: ..."``. Only lines that
    contain ``Chapter <digit>`` are kept, in their original order.

    Args:
        plot: Raw plot text, one chapter summary per line.

    Returns:
        The cleaned chapter lines; an empty list if none match.
    """
    cleaned_lines = (
        line.replace("*", "").replace("#", "").strip()
        for line in plot.split("\n")
    )
    return [line for line in cleaned_lines if re.search(r"Chapter \d", line)]

In [6]:
def _safe_dirname(title: str, fallback: str = "untitled") -> str:
    """Turn an LLM-generated title into a single filesystem-safe path component."""
    # Remove path separators and characters invalid on common filesystems, so a
    # title like "Stars/Gate" cannot create nested or invalid directories.
    cleaned = re.sub(r'[\\/:*?"<>|\x00-\x1f]', "", title).strip().strip(".")
    return cleaned or fallback


def _save_story(story: Dict[str, Any], story_dir: Path) -> None:
    """Checkpoint the (possibly partial) story dict to ``story_dir/story.json``."""
    with open(story_dir / "story.json", "w", encoding="utf-8") as f:
        json.dump(story, f, indent=4)


def generate(
    prompts: Dict[str, str], config, path, generation_config
) -> Tuple[Dict[str, Any], Config, Any]:
    """Generate a multi-chapter story in a single Gemini chat session.

    Steps: plot (whose first line is taken as the title) -> characters ->
    chapter 1 -> chapters 2..``config.n_ch``. The story is checkpointed to
    ``path/<title>/story.json`` after each step so a failure keeps partial work.

    Args:
        prompts: Prompt texts keyed by ``instr_sep``, ``plot_sep``,
            ``char_sep``, ``story_sep``, ``system``, ``plot``, ``char``,
            ``ch1`` and ``chn``.
        config: Story settings; ``n_ch`` is read and ``title`` is set here
            (sanitised so it can be used as a directory name).
        path: Root output directory (``Path`` or ``str``).
        generation_config: Passed unchanged to ``genai.GenerativeModel``.

    Returns:
        ``(story, config, chat_session)``. ``story`` has keys ``title``,
        ``plot``, ``char``, ``plot_split`` and ``ch1`` .. ``ch<n_ch>``.
    """
    story: Dict[str, Any] = {}

    model = genai.GenerativeModel(
        model_name="gemini-2.0-flash",
        generation_config=generation_config,
    )
    chat_session = model.start_chat(history=[])

    instruction_sep = prompts["instr_sep"]
    plot_sep = prompts["plot_sep"]
    char_sep = prompts["char_sep"]
    story_sep = prompts["story_sep"]

    def ask(*parts: str) -> str:
        """Send newline-joined prompt sections to the chat and return the reply text."""
        return chat_session.send_message("\n".join(parts) + "\n").text

    # 1) Plot; its first line (markdown stripped) is the title.
    plot = ask(instruction_sep, prompts["system"], prompts["plot"], plot_sep)
    title = plot.split("\n")[0].replace("*", "").replace("#", "").strip()
    story["title"] = title
    story["plot"] = plot

    config.title = _safe_dirname(title)
    print("Title registered: ", config.title)
    story_dir = Path(path) / config.title
    os.makedirs(story_dir, exist_ok=True)
    _save_story(story, story_dir)

    # 2) Characters — kept in their own variable and re-sent with every chapter.
    characters = ask(instruction_sep, prompts["char"], char_sep)
    story["char"] = characters
    _save_story(story, story_dir)

    # 3) One plot line per chapter.
    plot_split = split_plot(plot)
    print(plot_split)
    story["plot_split"] = plot_split

    # 4) Chapter 1.
    story["ch1"] = ask(instruction_sep, prompts["ch1"], story_sep)
    _save_story(story, story_dir)
    print("Chapter 1" + " generated.")

    # 5) Chapters 2..n_ch, each guided by the characters and its own plot line.
    for i in range(2, config.n_ch + 1):
        if i - 1 < len(plot_split):
            chapter_plot = plot_split[i - 1]
        else:
            # The model returned fewer "Chapter N" lines than requested;
            # fall back to the full plot instead of raising IndexError.
            print(f"Warning: no plot line for chapter {i}; sending the full plot.")
            chapter_plot = plot
        chapter = ask(
            instruction_sep, prompts["chn"], char_sep, characters,
            plot_sep, chapter_plot, story_sep,
        )
        story["ch" + str(i)] = chapter
        _save_story(story, story_dir)
        print("Chapter " + str(i) + " generated.")

    return story, config, chat_session
    

In [7]:
def run(config: Config, generation_config, output_dir: Optional[Path] = None):
    """Generate one story with Gemini and write all artifacts to its folder.

    Files written under ``<output_dir>/<config.title>/``: ``config.json``,
    ``story.json``, ``plot.txt``, ``story.txt`` (chapters concatenated in
    order), ``model_config.json`` and ``chat_history.txt``.

    Args:
        config: Story settings; ``generate`` sets ``config.title``.
        generation_config: Gemini generation settings, passed to the model and
            saved as ``model_config.json``.
        output_dir: Root folder for story folders. Defaults to the original
            location ``<cwd>/../data/raw/llm_ss/Gemini 2.0 Flash``.

    Returns:
        ``(story, config, chat)`` as returned by ``generate``.
    """
    if output_dir is None:
        path = Path(os.getcwd()).parent / "data/raw/llm_ss/Gemini 2.0 Flash"
    else:
        path = Path(output_dir)

    prompts = {
        "instr_sep": "### Instruction ###",
        "plot_sep": "### Plot ###",
        "char_sep": "### Characters ###",
        "story_sep": "### Story ###",
        # change the genre through config.genere as needed
        "system": f"You are a professional novelist. You will write a {config.n_ch}-chapter {config.genere} story with {config.n_char} characters.",
        # one plot line per *chapter* (n_ch), not per character
        "plot": f"Write the title in the first line. Next, use 2-3 sentence to write the plot for each of the {config.n_ch} chapters. The Chapter number and description should start in the same line (i.e. Chapter 1: [description]) Start with Chapter 1: ",
        "char": f"Next, use 1 sentence to write each of {config.n_char} characters who appear in your story.",
        "ch1": f"Use {config.n_story_token} words to write the first chapter.",
        "chn": f"Use {config.n_story_token} words to write the next chapter."
    }

    story, config, chat = generate(prompts, config, path, generation_config)
    story_dir = path / config.title

    # to_json is used rather than config.__dict__ so the serialisation lives in Config.
    config.to_json(story_dir / "config.json")
    with open(story_dir / "story.json", "w", encoding="utf-8") as f:
        json.dump(story, f, indent=4)

    with open(story_dir / "plot.txt", "w", encoding="utf-8") as f:
        f.write(story["plot"])

    # Chapter keys are exactly "ch1", "ch2", ...; dict insertion order keeps them in sequence.
    story_text = "".join(
        value + "\n" for key, value in story.items() if re.fullmatch(r"ch\d+", key)
    )
    with open(story_dir / "story.txt", "w", encoding="utf-8") as f:
        f.write(story_text)

    with open(story_dir / "model_config.json", "w", encoding="utf-8") as f:
        json.dump(generation_config, f, indent=4)

    with open(story_dir / "chat_history.txt", "w", encoding="utf-8") as f:
        f.write(str(chat.history))

    print("Done: ", config.title)

    return story, config, chat

In [8]:
N = 1  # number of independent stories to generate

generation_config = {
    "temperature": 1,
    "top_p": 0.95,
    "top_k": 40,
    "max_output_tokens": 8192,
    "response_mime_type": "text/plain",
}


def make_config() -> Config:
    """Build a fresh Config for one run so runs never share a mutated ``title``."""
    cfg = Config()
    cfg.n_ch = 5
    cfg.n_char = 10
    cfg.n_story_token = 1000
    cfg.genere = "sci-fi"
    return cfg


config = make_config()
results = []  # (story, config, chat) for every run, not only the last one
for _ in tqdm(range(N)):
    story, config, chat = run(make_config(), generation_config)
    results.append((story, config, chat))

  0%|          | 0/1 [00:00<?, ?it/s]


InvalidArgument: 400 API key not valid. Please pass a valid API key. [reason: "API_KEY_INVALID"
domain: "googleapis.com"
metadata {
  key: "service"
  value: "generativelanguage.googleapis.com"
}
, locale: "en-US"
message: "API key not valid. Please pass a valid API key."
]