In [1]:
# ====> ENVIRONMENT SETUP
import os
import sys
import yaml

def read_yaml(fpath: str) -> dict:
    """Read the UTF-8 file at `fpath`, parse it with `yaml.safe_load` and return the parsed value."""
    with open(fpath, encoding="utf-8") as stream:
        return yaml.safe_load(stream)

def abs_join(par: str, path: str) -> str:
    """Join `path` onto the directory `par` and return the normalized absolute path."""
    joined = os.path.join(par, path)
    return os.path.abspath(joined)

CREDENTIALS = read_yaml("../credentials.yaml")
# Export each credential under the environment-variable name read later in the notebook
# (NBLM_EMAIL / NBLM_PASSWORD are read when the NotebookLM bot is created).
for _env_name, _cred_key in (
    ("GOOGLE_API_KEY", "google_key"),
    ("OPENAI_API_KEY", "openai_key"),
    ("NBLM_EMAIL", "nblm_email"),
    ("NBLM_PASSWORD", "nblm_password"),
):
    os.environ[_env_name] = CREDENTIALS[_cred_key]
del _env_name, _cred_key

# Put the parent directory first on the import path (presumably so local packages such
# as `chatbots` can be imported from the notebook directory).
project_root = abs_join(os.getcwd(), os.pardir)
sys.path.insert(0, project_root)

In [2]:
# ===========> COMMON VARIABLES
# Machine-specific executables: set the BROWSER_PATH / DRIVER_PATH environment variables to
# point at your own install instead of editing the notebook. The literals are only fallbacks.
BROWSER_PATH = os.environ.get(
    "BROWSER_PATH",
    "C:/Users/jcmcs/AppData/Local/BraveSoftware/Brave-Browser/Application/brave.exe",
)
DRIVER_PATH = os.environ.get("DRIVER_PATH", "../drivers/chromedriver.exe")

# Input templates (relative paths resolve against the notebook's working directory).
TEX_TEMPLATE_PATH = "../templates/paper_template.tex"
PROMPT_CONFIG_PATH = "../templates/prompt_config.yaml"
REVISION_CONFIG_PATH = "../templates/review_config.yaml"

# Output artifacts, all under ./out
CWD = os.getcwd()
OUT_DIR = abs_join(CWD, "out")
OUT_GEN_STRUCTURE_PATH = abs_join(OUT_DIR, "genstruct.yaml")
OUT_TEX_PATH = abs_join(OUT_DIR, "lastgenerated.tex")
OUT_DUMP_PATH = abs_join(OUT_DIR, "lastgenerated.dump")
OUT_REVIEWED_TEX_PATH = abs_join(OUT_DIR, "reviewed-"+os.path.basename(OUT_TEX_PATH))
OUT_REVIEWED_DUMP_PATH = abs_join(OUT_DIR, "reviewed-"+os.path.basename(OUT_DUMP_PATH))

# Prototyping

In [None]:
from typing import List,Union
import undetected_chromedriver as uc
from fake_useragent import UserAgent
from langchain.document_loaders import PyPDFLoader
from langchain_core.messages import SystemMessage, AIMessage
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.vectorstores import FAISS
from chatbots import NotebookLMBot

import re
import os
from time import sleep

def countdown_print(msg: str, sec: int):
    """
    Block for `sec` seconds while showing `msg` followed by the seconds remaining.

    The counter is redrawn in place (carriage return) and counts down from `sec` to 0.
    A final newline is printed so that subsequent output starts on a fresh line instead
    of overwriting the counter.

    Args:
        msg: text shown before the counter.
        sec: number of seconds to wait; values <= 0 return immediately after printing 0.
    """
    # Pad to the widest value so e.g. "9" fully overwrites a previously drawn "10".
    width = len(str(max(sec, 0)))
    for remaining in range(sec, 0, -1):
        print(f"{msg} {remaining:>{width}} s", end="\r", flush=True)
        sleep(1)
    print(f"{msg} {0:>{width}} s")

def init_driver(browser_path: Union[str,None] = None, driver_path: Union[str,None] = None) -> uc.Chrome:
    """
    Start an undetected-chromedriver Chrome session.

    Args:
        browser_path: browser executable to drive (None lets undetected_chromedriver pick one).
        driver_path: chromedriver executable (None lets undetected_chromedriver pick one).

    Returns:
        The started `uc.Chrome` driver.
    """
    op = uc.ChromeOptions()
    # `random` is an instance property of UserAgent: reading it off the class yields the
    # property object itself (so the argument would be "user-agent=<property object ...>").
    op.add_argument(f"user-agent={UserAgent().random}")
    # Browser profile stored in the current working directory.
    op.add_argument("user-data-dir=./")
    op.add_experimental_option("detach", True)
    op.add_experimental_option("excludeSwitches", ["enable-logging"])
    # NOTE(review): undetected_chromedriver v3 documents the keyword as `options=`; confirm that
    # `chrome_options=` is honored by the installed version rather than silently ignored.
    # (Switching keywords may surface "unrecognized chrome option" errors for the
    # experimental options above, so test before changing.)
    driver = uc.Chrome(
        chrome_options=op,
        browser_executable_path=browser_path,
        driver_executable_path=driver_path
    )
    return driver

def get_pdf_contents(pdf_paths: List[str]):
    """Load every PDF in `pdf_paths` with PyPDFLoader and return all loaded Documents, in order, as one flat list."""
    return [document for pdf_path in pdf_paths for document in PyPDFLoader(pdf_path).load()]

def generate_paper_structure(nblm: NotebookLMBot, prompt: str, subject: str, pdf_paths: List[str], outfile: str, driver_path: str = "../drivers/chromedriver", browser_path: Union[str,None] = None):
    """
    Generate the paper structure (its list of sections) using NotebookLM.

    "{subject}" in `prompt` is replaced by `subject`, the prompt is sent through `nblm`, and
    the reply is nested under a top-level "sections:" key, written to `outfile` and parsed
    back as YAML.

    Args:
        nblm: NotebookLM bot used to send the prompt and read the reply.
        prompt: prompt template; may contain the "{subject}" placeholder.
        subject: paper subject substituted into the prompt.
        pdf_paths, driver_path, browser_path: not used here; kept so existing calls keep working.
        outfile: path of the YAML file to write.

    Returns:
        The parsed YAML dict. The writing loop later reads ``section["title"]`` and
        ``section["description"]`` from each item, so the prompt presumably asks NotebookLM
        for a YAML list of such mappings.

    Raises:
        ValueError: if the parsed file has no "sections" list (empty or malformed reply).
    """
    prompt = prompt.replace("{subject}", subject)

    # Use NotebookLM bot to send it
    nblm.send_prompt(prompt, sleep_for=40)
    response = nblm.get_last_response()

    # Indent every reply line under "sections:". Markdown fence lines (```yaml / ```) are
    # dropped: a backtick cannot start a plain YAML scalar, so they would break parsing.
    result = "sections:\n" + "".join(
        f"  {line}\n" for line in response.split("\n") if not line.strip().startswith("```")
    )

    with open(outfile, "w", encoding="utf-8") as f:
        f.write(result)

    structure = read_yaml(outfile)
    sections = structure.get("sections") if isinstance(structure, dict) else None
    if not isinstance(sections, list):
        raise ValueError(
            f"Could not parse a 'sections' list from the NotebookLM reply (saved to {outfile})"
        )
    print(f"Finished generating paper structure. Got a structure with {len(sections)} sections.\n")

    return structure

def setup_context_msg(
        header_prompt: str, 
        pdf_paths: List[str], 
        summarize = False,
        summary_llm = None,
        summary_cooldown_sec: int = 60,
    ):
    """
    Build the context SystemMessage: writing instructions followed by the reference PDFs' text.

    Args:
        header_prompt: instructions placed at the top of the context.
        pdf_paths: PDF files whose text is appended.
        summarize: when True *and* `summary_llm` is given, the PDF text is split into
            ~2000-character chunks, each chunk is summarized by `summary_llm`, and the
            summaries are appended instead of the full text. With `summary_llm=None` the
            full text is used regardless of `summarize`.
        summary_llm: LangChain chat model or LLM (anything with `.invoke`) used for summaries.
        summary_cooldown_sec: pause between consecutive summary requests (rate limiting);
            there is no pause after the last request.

    Returns:
        SystemMessage whose content is the assembled context.
    """
    pdf_content = get_pdf_contents(pdf_paths)

    if summarize and (summary_llm is not None):
        # Summarize pdf contents
        text_splitter = CharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
        chunks = text_splitter.split_documents(pdf_content)
        summaries = []
        for i, chunk in enumerate(chunks):
            if i > 0:
                sleep(summary_cooldown_sec)  # throttle requests to the summary model
            # Summarize the chunk's text, not the Document repr (which includes metadata).
            reply = summary_llm.invoke(f"Summarize this text: {chunk.page_content}")
            # Chat models return a message object with `.content`; plain LLMs return a str.
            summaries.append(getattr(reply, "content", reply))
        ctx_content = "\n".join(summaries)
    else:
        # Include entire PDF content
        ctx_content = "\n".join(doc.page_content for doc in pdf_content)

    context = header_prompt + "\n\nThe PDF content of the given references are:\n" + ctx_content

    return SystemMessage(content=context)

def vector_store_from_pdf_content(pdf_content, txtembed_model: str = "google"):
    text_splitter = CharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
    chunks = text_splitter.split_documents(pdf_content)
    chunks = [chunk.page_content for chunk in chunks]

    match txtembed_model.strip().lower():
        case "openai":
            embeddings = OpenAIEmbeddings()
        case "google":
            embeddings = GoogleGenerativeAIEmbeddings(model="models/text-embedding-004")
        case _:
            raise ValueError(f"Valid txtembed_model values are 'openai' or 'google'")
    vector_store = FAISS.from_texts(chunks, embeddings)
    return vector_store

def setup_ctx_msg_faiss(header_prompt: str, vector_store: FAISS, section_query: str, k: int = 5):
    """
    Build a context string from `header_prompt` plus the `k` chunks of `vector_store` most
    similar to `section_query`.

    The result is `header_prompt`, a newline, then each retrieved chunk's page content on
    its own line. Note that this returns a plain str, not a SystemMessage.
    """
    hits = vector_store.similarity_search(section_query, k=k)
    retrieved_text = "\n".join(hit.page_content for hit in hits)
    return f"{header_prompt}\n{retrieved_text}"

def init_chain(llm, ctx_msg: SystemMessage, write_prompt: str):
    """
    Compose the chat prompt (fixed context message + templated human request) and pipe it into `llm`.

    Returns:
        (prompt, chain): the ChatPromptTemplate and the runnable `prompt | llm`.
    """
    full_prompt = ChatPromptTemplate.from_messages(
        [ctx_msg, HumanMessagePromptTemplate.from_template(write_prompt)]
    )
    return full_prompt, full_prompt | llm

def write_section(chain, subject: str, title: str, description: str) -> AIMessage:
    """Run `chain` with the subject/title/description template variables and return the model's reply."""
    prompt_vars = dict(subject=subject, title=title, description=description)
    return chain.invoke(prompt_vars)

def dump_generated_sections(sections: dict, outpath: str):
    """
    Write `sections` to `outpath` as UTF-8 YAML (used as a checkpoint after each section).

    Non-ASCII characters are written as-is instead of being escaped, and keys keep their
    insertion order (e.g. "title" before "content") so the dump stays human-readable.
    The data read back with yaml.safe_load is the same either way.
    """
    with open(outpath, "w", encoding="utf-8") as f:
        yaml.safe_dump(sections, f, allow_unicode=True, sort_keys=False)


def save_latex_sections(tex_template_path: str, sections: List[dict], outpath: str):
    """
    Join the contents of every section into the output LaTeX file and write its bibliography.

    'sections' must be a list of dictionaries with two keys: 'title' and 'content'.
    Each section's ``filecontents`` environment (starred or not, with optional ``[options]``
    and ``{file.bib}`` arguments) is cut out of its text, and its entries are collected into
    ``outpath + "bib.bib"`` (e.g. ``paper.tex`` -> ``paper.texbib.bib``). The revision cell
    reads the bibliography back from that same path.

    In the template, "{content}" is replaced by the section texts (newline-separated) and
    "{bibresourcefile}" by the bibliography file's base name. Markdown fences (```latex / ```)
    are removed from the final document.

    A section without a ``filecontents`` block is reported on stdout, but its text is still
    included in the paper rather than dropped.
    """
    with open(tex_template_path, "r", encoding="utf-8") as f:
        tex_template = f.read()

    # Group 1 captures only the bibliography entries, so the target file name
    # (e.g. {mybib.bib}) and any [overwrite]-style options never reach the .bib output.
    bib_re = re.compile(
        r"\\begin\{filecontents\*?\}"
        r"(?:\[[^\]]*\])?"
        r"(?:\{[^}]*\})?"
        r"(.*?)"
        r"\\end\{filecontents\*?\}",
        re.DOTALL,
    )

    text_parts = []
    bib_parts = []
    for section in sections:
        content = section["content"]
        match = bib_re.search(content)
        if match is None:
            print("FAILED TO MATCH BIBLATEX CONTENT IN SECTION:", section["title"])
        else:
            bib_parts.append(match.group(1).strip())
        text_parts.append(bib_re.sub("", content))

    # Newline separators keep one section's last line / bib entry from fusing with the next.
    paper_content = "\n".join(text_parts)
    bib_content = "\n".join(bib_parts)
    bib_file = outpath+"bib.bib"

    # Replace paper content in latex template and save it
    tex_content = tex_template.replace("{content}", paper_content).replace("{bibresourcefile}", os.path.basename(bib_file))
    tex_content = tex_content.replace("```latex", "").replace("```","")
    with open(outpath, "w", encoding="utf-8") as f:
        f.write(tex_content)

    # also save the biblatex file
    with open(bib_file, "w", encoding="utf-8") as f:
        f.write(bib_content)

##########################################################################################
# Generation pipeline: NotebookLM proposes the paper structure, then the LLM writes each
# section with the reference PDFs as context. Progress is dumped after every section.

paper_cfg = read_yaml(PROMPT_CONFIG_PATH)
pdf_paths = [
    "../refexamples/ArigaK2023_ChemOfMat.pdf",
    "../refexamples/FangC2022_LangmuirBattery.pdf",
    "../refexamples/OliveiraO2022_PastAndFuture.pdf",
    # "../refexamples/LuC2024_AIScientist.pdf"
]
paper_subject = paper_cfg["subject"]

driver = init_driver(BROWSER_PATH, DRIVER_PATH)
nblm = NotebookLMBot(
    user=os.environ["NBLM_EMAIL"],
    password=os.environ["NBLM_PASSWORD"],
    driver=driver,
    src_paths=pdf_paths,
)
if not nblm.login():
    # exit() inside a Jupyter kernel asks the kernel to shut down; raising stops this cell
    # (and "Run All") cleanly instead. Close the browser first so it is not left running.
    driver.quit()
    raise RuntimeError("Unable to login to NotebookLM")

os.makedirs(OUT_DIR, exist_ok=True)

paper_structure = generate_paper_structure(
    nblm=nblm,
    prompt=paper_cfg["gen_struct_prompt"],
    subject=paper_subject,
    pdf_paths=pdf_paths,
    outfile=OUT_GEN_STRUCTURE_PATH,
    driver_path=DRIVER_PATH,
    browser_path=BROWSER_PATH,
)
# To reuse a previously generated structure instead:
# paper_structure = read_yaml(OUT_GEN_STRUCTURE_PATH)

ctx = setup_context_msg(
    header_prompt=paper_cfg["response_format"],
    pdf_paths=pdf_paths,
    summarize=False,
    #summary_llm=ChatOpenAI(model="chatgpt-4o-latest")
)

#llm = ChatOpenAI(model="o1-preview")
llm = ChatGoogleGenerativeAI(model="gemini-1.5-pro")
_, chain = init_chain(llm, ctx, paper_cfg["write_prompt"])

# vecstore = vector_store_from_pdf_content(
#     pdf_content=get_pdf_contents(pdf_paths),
#     txtembed_model="google"
# )

# Pause between LLM requests because of the API's request/token quota.
cooldown_sec = int(60*1.5)
n_sections = len(paper_structure["sections"])
paper_content = []
for i, section in enumerate(paper_structure["sections"]):
    print(f"====> STARTED WRITING SECTION: {section['title']} ({i+1}/{n_sections})")

    airesponse = write_section(chain, paper_subject, section["title"], section["description"])
    paper_content.append(
        {
            "title": section["title"],
            "content": airesponse.content,
        }
    )
    print("====> FINISHED WRITING SECTION:", section["title"])
    print("====> REPONSE METADATA:", airesponse.usage_metadata)
    # Checkpoint after every section so a failure does not lose earlier work.
    dump_generated_sections({"sections": paper_content}, OUT_DUMP_PATH)

    # No cooldown needed after the final request.
    if i + 1 < n_sections:
        print(f"Cooldown of {cooldown_sec} s because of request limitations...\n")
        countdown_print("Countdown: ", cooldown_sec)

save_latex_sections(
    tex_template_path=TEX_TEMPLATE_PATH,
    sections=paper_content,
    outpath=OUT_TEX_PATH,
)

In [None]:
# LLM Revision process
rev_config = read_yaml(REVISION_CONFIG_PATH)

with open(OUT_TEX_PATH, "r", encoding="utf-8") as f, open(OUT_TEX_PATH+"bib.bib", "r", encoding="utf-8") as bf:
    latex_content = f.read()
    bib_content = bf.read()

# First copy to temporary .txt files (can only send .txt to NotebookLM).
# splitext swaps only the final extension; str.replace(".tex", ...) would also rewrite
# any ".tex" that appears earlier in the directory path.
tmp_tex = os.path.splitext(OUT_TEX_PATH)[0] + ".txt"
tmp_bib = OUT_TEX_PATH+"bib.txt"
with open(tmp_tex, "w", encoding="utf-8") as tmptex_f, open(tmp_bib, "w", encoding="utf-8") as tmpbib_f:
    tmptex_f.write(latex_content)
    tmpbib_f.write(bib_content)

# add files as sources; the temporary copies are removed even if the upload raises
try:
    nblm.append_sources([tmp_tex, tmp_bib], sleep_for=20)
finally:
    os.remove(tmp_tex)
    os.remove(tmp_bib)

# split paper content: each match runs from a \section{...} up to (not including) the next
# \section{, \printbibliography or \end{document}
pattern = r"(\\section\{.*?\}.*?)(?=(\\section\{|\\printbibliography|\\end\{document\}))"
matches = re.findall(pattern, latex_content, re.DOTALL)

# Process matches into a list of dictionaries
# (findall yields (section_text, lookahead_text) tuples because the pattern has two groups)
sections = []
for match, _ in matches:
    # Extract title from the section line
    title_match = re.search(r"\\section\{(.*?)\}", match)
    title = title_match.group(1).strip() if title_match else "Unknown"
    sections.append({"title": title, "content": match.strip()})


# Improve each section
improved = []
rev_ctx = setup_context_msg(
    header_prompt="The PDF content of the references is:",
    pdf_paths=pdf_paths,
    summarize=False,
)
_, chain = init_chain(llm, rev_ctx, rev_config["improve_prompt"])
# Pause between LLM requests because of the API's request/token quota.
cooldown_sec = int(60*1.5)
for i, section in enumerate(sections):
    print(f"====> STARTED REVIEW FOR SECTION {section['title']} ({i+1}/{len(sections)})")

    # ask points for notebooklm
    # NOTE(review): {generatedpaperfile} receives the full local path of the temporary .txt;
    # if NotebookLM refers to sources by file name, os.path.basename(tmp_tex) may fit better.
    nblm_prompt = rev_config["nblm_point_prompt"].replace("{generatedpaperfile}", tmp_tex).replace("{number}", str(i+1)).replace("{title}", section["title"])
    nblm.send_prompt(nblm_prompt, sleep_for=40)
    improv_points = nblm.get_last_response()
    # Keep the reply from the "Clarity and Coherence" heading onward. When the heading is
    # missing, find() returns -1 and slicing from it would keep only the last character,
    # so the whole reply is kept instead.
    marker_pos = improv_points.find("Clarity and Coherence")
    if marker_pos != -1:
        improv_points = improv_points[marker_pos:]

    print("Sending improvement prompt to LLM...")
    airesponse = chain.invoke({
        "title": section["title"],
        "sectionlatex": section["content"],
        "sectionimprovement": improv_points,
        "biblatex": bib_content,
    })
    improved.append({
        "title": section["title"],
        "content": airesponse.content,
    })
    print("====> FINISHED REVIEWING SECTION:", section["title"])
    print("====> REPONSE METADATA:", airesponse.usage_metadata, "\n")
    # Checkpoint after every section so a failure does not lose earlier work.
    dump_generated_sections({"sections": improved}, OUT_REVIEWED_DUMP_PATH)
    # No cooldown needed after the final request.
    if i + 1 < len(sections):
        print(f"Cooldown of {cooldown_sec} s because of request limitations...")
        countdown_print("Countdown:", cooldown_sec)

save_latex_sections(
    tex_template_path=TEX_TEMPLATE_PATH,
    sections=improved,
    outpath=OUT_REVIEWED_TEX_PATH,
)

driver.quit()

In [None]:
# Show the reviewed sections (list of {"title", "content"} dicts) as plain text.
print(repr(improved))