# Co-STORM quick test (OpenAI + Tavily)

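This notebook exercises the Co-STORM pipeline using **OpenAI** for the LLMs and **Tavily** for retrieval. It warm-starts the system, runs a few conversation turns (including one injected user utterance), and saves the generated report plus run artifacts to `./results/co-storm-notebook`.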

**Prerequisites**
- `OPENAI_API_KEY` set in the environment
- `TAVILY_API_KEY` set in the environment
- `OPENAI_API_TYPE` set to `openai` (optional; it defaults to `openai` when unset, and any other value is rejected)

Alternatively, put the keys in `secrets.toml` and uncomment the `load_api_key(toml_file_path="secrets.toml")` line in the setup cell below.


In [None]:
import os
import json

from knowledge_storm.collaborative_storm.engine import (
    CollaborativeStormLMConfigs,
    RunnerArgument,
    CoStormRunner,
)
from knowledge_storm.lm import OpenAIModel
from knowledge_storm.logging_wrapper import LoggingWrapper
from knowledge_storm.rm import TavilySearchRM
from knowledge_storm.utils import load_api_key


In [None]:
# Optional: load from secrets.toml if present
# load_api_key(toml_file_path="secrets.toml")

# Default the provider type so the check below passes when the variable is unset.
os.environ.setdefault("OPENAI_API_TYPE", "openai")

if os.getenv("OPENAI_API_TYPE") != "openai":
    raise ValueError(
        "This notebook expects OPENAI_API_TYPE=openai "
        f"(got {os.getenv('OPENAI_API_TYPE')!r})."
    )

# Check every required key up front so a single run reports all of the missing ones,
# instead of failing on the first and making the user re-run to find the next.
# Empty values count as missing (`not ""` is True).
REQUIRED_ENV_VARS = ("OPENAI_API_KEY", "TAVILY_API_KEY")
missing_env_vars = [name for name in REQUIRED_ENV_VARS if not os.getenv(name)]
if missing_env_vars:
    raise ValueError(f"Missing {', '.join(missing_env_vars)} in the environment.")

# Run inputs: the research topic, the utterance injected mid-conversation,
# and the directory that receives the report and dumps.
topic = "Recent advances in battery recycling"
user_utterance = "Focus on policy changes in the last two years."
output_dir = "./results/co-storm-notebook"


In [None]:
# Language models for every Co-STORM role. All roles use OpenAI's gpt-4o with the
# same sampling settings; they differ only in the per-call token budget.
lm_config = CollaborativeStormLMConfigs()
openai_kwargs = dict(
    api_key=os.getenv("OPENAI_API_KEY"),
    api_provider="openai",
    temperature=1.0,
    top_p=0.9,
    api_base=None,
)

gpt_4o_model_name = "gpt-4o"


def _gpt_4o(max_tokens):
    """Return a gpt-4o ``OpenAIModel`` built from ``openai_kwargs`` with the given token cap."""
    return OpenAIModel(model=gpt_4o_model_name, max_tokens=max_tokens, **openai_kwargs)


question_answering_lm = _gpt_4o(1000)
discourse_manage_lm = _gpt_4o(500)
utterance_polishing_lm = _gpt_4o(2000)
warmstart_outline_gen_lm = _gpt_4o(500)
question_asking_lm = _gpt_4o(300)
knowledge_base_lm = _gpt_4o(1000)

# Register each model with its role on the shared config object.
lm_config.set_question_answering_lm(question_answering_lm)
lm_config.set_discourse_manage_lm(discourse_manage_lm)
lm_config.set_utterance_polishing_lm(utterance_polishing_lm)
lm_config.set_warmstart_outline_gen_lm(warmstart_outline_gen_lm)
lm_config.set_question_asking_lm(question_asking_lm)
lm_config.set_knowledge_base_lm(knowledge_base_lm)


In [None]:
# Runner and retriever setup. The size knobs below are small values suited to a
# quick test; they are unpacked into RunnerArgument alongside the topic.
runner_settings = dict(
    retrieve_top_k=5,
    max_search_queries=2,
    total_conv_turn=10,
    max_search_thread=3,
    max_search_queries_per_turn=2,
    warmstart_max_num_experts=2,
    warmstart_max_turn_per_experts=2,
    warmstart_max_thread=2,
    max_thread_num=5,
    max_num_round_table_experts=2,
    moderator_override_N_consecutive_answering_turn=2,
    node_expansion_trigger_count=10,
)
runner_argument = RunnerArgument(topic=topic, **runner_settings)

# Tavily retriever; it returns as many results per query as the runner's top-k.
rm = TavilySearchRM(
    tavily_search_api_key=os.getenv("TAVILY_API_KEY"),
    k=runner_argument.retrieve_top_k,
    include_raw_content=True,
)

logging_wrapper = LoggingWrapper(lm_config)
costorm_runner = CoStormRunner(
    lm_config=lm_config,
    runner_argument=runner_argument,
    logging_wrapper=logging_wrapper,
    rm=rm,
)

# Warm start builds the shared conceptual space before the conversation begins.
costorm_runner.warm_start()


In [None]:
def _show_turn(turn):
    """Print one conversation turn as a Markdown line with the speaker role in bold."""
    print(f"**{turn.role}**: {turn.utterance}\n")


# Let the system produce one turn on its own.
conv_turn = costorm_runner.step()
_show_turn(conv_turn)

# Steer the discussion with the user's utterance; the turn this returns is not displayed.
costorm_runner.step(user_utterance=user_utterance)

# Let the system respond, and display that turn too.
conv_turn = costorm_runner.step()
_show_turn(conv_turn)


In [None]:
# Generate a report and save artifacts.
# The knowledge base is reorganized first, then the report is generated from it.
costorm_runner.knowledge_base.reorganize()
article = costorm_runner.generate_report()

os.makedirs(output_dir, exist_ok=True)

# Use an explicit UTF-8 encoding for every file: the report is LLM-generated text
# that may contain non-ASCII characters, which the platform default encoding
# (e.g. cp1252 on Windows) cannot always encode.
with open(os.path.join(output_dir, "report.md"), "w", encoding="utf-8") as f:
    f.write(article)

instance_copy = costorm_runner.to_dict()
with open(os.path.join(output_dir, "instance_dump.json"), "w", encoding="utf-8") as f:
    json.dump(instance_copy, f, indent=2)

# NOTE(review): judging by its name, dump_logging_and_reset() clears the runner's
# log, so re-running this cell may write an empty/partial log.json — confirm.
log_dump = costorm_runner.dump_logging_and_reset()
with open(os.path.join(output_dir, "log.json"), "w", encoding="utf-8") as f:
    json.dump(log_dump, f, indent=2)

print(f"Artifacts saved to: {output_dir}")
