In [207]:
# Standard library imports
import os
import sys
import json
import threading
import time
import sqlite3
import re
import argparse
import copy
import logging

# Third-party imports
import numpy as np
import torch
from tqdm import trange, tqdm
from sentence_transformers import SentenceTransformer
import faiss
import requests
from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification

# Local imports (if any)


# Configure logging
# Root logging setup: timestamped records at INFO and above, written to stdout
# so they show up in the notebook cell output.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,  # use logging.DEBUG for more detailed output
)

# Logger used by the classes and helper functions defined in the cells below.
logger = logging.getLogger(__name__)


In [208]:
# Constants and Prompts

# Outline Generation Prompts
# Template asking the model to draft a section-level outline from one chunk of papers.
# Placeholders: [TOPIC], [PAPER LIST], [SECTION NUM]
# (substituted by outlineWriter.generate_rough_outlines).
# The requested reply format is "Title: ..." followed by numbered
# "Section i: ..." / "Description i: ..." pairs.
ROUGH_OUTLINE_PROMPT = '''
You want to write an overall and comprehensive academic survey about "[TOPIC]".
You are provided with a list of papers related to the topic below:
---
[PAPER LIST]
---
You need to draft an outline based on the given papers.
The outline should contain a title and several sections.
Each section follows with a brief sentence to describe what to write in this section.
The outline is supposed to be comprehensive and contains [SECTION NUM] sections.

Return in the format:
<format>
Title: [TITLE OF THE SURVEY]
Section 1: [NAME OF SECTION 1]
Description 1: [DESCRIPTION OF SECTION 1]

Section 2: [NAME OF SECTION 2]
Description 2: [DESCRIPTION OF SECTION 2]

...

Section K: [NAME OF SECTION K]
Description K: [DESCRIPTION OF SECTION K]
</format>
The outline:
'''

# Template asking the model to merge several candidate outlines into one.
# Placeholders: [TOPIC], [OUTLINE LIST]
# (substituted by outlineWriter.merge_outlines).
# The reply uses the same "Title / Section i / Description i" layout that
# outlineWriter.extract_title_sections_descriptions parses.
MERGING_OUTLINE_PROMPT = '''
You are an expert in artificial intelligence who wants to write an overall survey about [TOPIC].
You are provided with a list of outlines as candidates below:
---
[OUTLINE LIST]
---
Each outline contains a title and several sections.
Each section follows with a brief sentence to describe what to write in this section.

You need to generate a final outline based on these provided outlines.
Return in the format:
<format>
Title: [TITLE OF THE SURVEY]
Section 1: [NAME OF SECTION 1]
Description 1: [DESCRIPTION OF SECTION 1]

Section 2: [NAME OF SECTION 2]
Description 2: [DESCRIPTION OF SECTION 2]

...

Section K: [NAME OF SECTION K]
Description K: [DESCRIPTION OF SECTION K]
</format>
Only return the final outline without any other information:
'''

# Template asking the model to break one section of the outline into subsections.
# Placeholders: [TOPIC], [OVERALL OUTLINE], [SECTION NAME], [SECTION DESCRIPTION],
# [PAPER LIST] (substituted by outlineWriter.generate_subsection_outlines).
# The reply uses numbered "Subsection i: ..." / "Description i: ..." pairs, the
# layout that outlineWriter.extract_subsections_subdescriptions parses.
SUBSECTION_OUTLINE_PROMPT = '''
You are an expert in artificial intelligence who wants to write an overall survey about [TOPIC].
You have created an overall outline below:
---
[OVERALL OUTLINE]
---
The outline contains a title and several sections.
Each section follows with a brief sentence to describe what to write in this section.

<instruction>
You need to enrich the section [SECTION NAME].
The description of [SECTION NAME]: [SECTION DESCRIPTION]
You need to generate the framework containing several subsections based on the overall outline.
Each subsection follows with a brief sentence to describe what to write in this subsection.
These papers are provided for reference:
---
[PAPER LIST]
---
Return the outline in the format:
<format>
Subsection 1: [NAME OF SUBSECTION 1]
Description 1: [DESCRIPTION OF SUBSECTION 1]

Subsection 2: [NAME OF SUBSECTION 2]
Description 2: [DESCRIPTION OF SUBSECTION 2]

...

Subsection K: [NAME OF SUBSECTION K]
Description K: [DESCRIPTION OF SUBSECTION K]
</format>
</instruction>
Only return the outline without any other information:
'''

# Template asking the model to polish the merged section/subsection outline and
# return it as markdown headings (#, ##, ###).
# Placeholders: [TOPIC], [OVERALL OUTLINE].
# Used by outlineWriter.edit_final_outline, which strips the <format> tags from the reply.
EDIT_FINAL_OUTLINE_PROMPT = '''
You are an expert in artificial intelligence who wants to write an overall survey about [TOPIC].
You have created a draft outline below:
---
[OVERALL OUTLINE]
---
The outline contains a title and several sections.
Each section follows with a brief sentence to describe what to write in this section.

Under each section, there are several subsections.
Each subsection also follows with a brief sentence of description.
You need to modify the outline to make it both comprehensive and coherent with no repeated subsections.
Return the final outline in the format:
<format>
# [TITLE OF SURVEY]

## [NAME OF SECTION 1]

### [NAME OF SUBSECTION 1]

### [NAME OF SUBSECTION 2]

...

### [NAME OF SUBSECTION L]

## [NAME OF SECTION 2]

...

## [NAME OF SECTION K]

...
</format>
Only return the final outline without any other information:
'''

# Subsection Writing Prompts
# Template asking the model to write the body text of one subsection, citing
# papers as [Paper Title].
# Placeholders: [TOPIC], [SUBSECTION NAME], [DESCRIPTION], [WORD NUM], [PAPER LIST].
SUBSECTION_WRITING_PROMPT = '''
You are an expert in the field of "[TOPIC]". You are writing a survey and need to generate the content for a subsection.

Subsection Title: [SUBSECTION NAME]
Subsection Description: [DESCRIPTION]

Please write a detailed and informative subsection of at least [WORD NUM] words. Ensure that the content is coherent, well-structured, and provides valuable insights into the topic.

When mentioning specific papers, cite them in the format [Paper Title]. Use the provided papers as references:
---
[PAPER LIST]
---

Return the content without any additional explanations.
'''

# Template asking the model to rewrite one subsection so that it reads coherently
# with the subsections immediately before and after it.
# Placeholders: [TOPIC], [OVERALL OUTLINE], [SUBSECTION NAME], [PREVIOUS],
# [SUBSECTION], [FOLLOWING]. Per the prompt text, an empty [PREVIOUS] or
# [FOLLOWING] marks the first or last subsection respectively.
LCE_PROMPT = '''
You are an expert in artificial intelligence who wants to write an overall and comprehensive survey about [TOPIC].
You have created an overall outline below:
---
[OVERALL OUTLINE]
---
<instruction>

Now you need to help to refine one of the subsections to improve the coherence of your survey.

You are provided with the content of the subsection "[SUBSECTION NAME]" along with the previous subsections and following subsections.

Previous Subsection:
--- 
[PREVIOUS]
---

Subsection to Refine: 
---
[SUBSECTION]
---

Following Subsection:
---
[FOLLOWING]
---

If the content of Previous Subsection is empty, it means that the subsection to refine is the first subsection.
If the content of Following Subsection is empty, it means that the subsection to refine is the last subsection.

Now edit the middle subsection to enhance coherence, remove redundancies, and ensure that it connects more fluidly with the previous and following subsections. 
Please keep the essence and core information of the subsection intact. 
</instruction>

Directly return the refined subsection without any other information:
'''

# Template asking the model to add missing [Paper Title] citations to a written
# subsection, using only the supplied papers.
# Placeholders: [TOPIC], [SUBSECTION], [PAPER LIST].
CHECK_CITATION_PROMPT = '''
You are an expert in the field of "[TOPIC]". You have written the following subsection:
---
[SUBSECTION]
---
Please ensure that all statements that require citations are properly cited using the papers provided below:
---
[PAPER LIST]
---
If any statements are missing citations, add appropriate citations from the provided papers in the format [Paper Title].

Return the updated subsection without any additional explanations.
'''

# Include any other prompts as needed


In [209]:
class tokenCounter:
    """Counts tokens with a GPT-2 tokenizer and estimates API cost from token counts.

    The tokenizer is loaded once in ``__init__`` and reused by every counting call.
    """

    # (input, output) price in USD per 1000 tokens, keyed by model name.
    # These are example prices; check the provider's pricing for up-to-date values.
    _PRICES_PER_1K = {
        'gpt-3.5-turbo': (0.0015, 0.002),
        'gpt-4': (0.03, 0.06),
    }
    # Prices applied to any model name missing from the table above.
    _DEFAULT_PRICES_PER_1K = (0.01, 0.02)

    def __init__(self, encoding_name='gpt2'):
        """Load the GPT-2 tokenizer identified by ``encoding_name``."""
        # Imported here so that defining the class does not itself need the tokenizer class.
        from transformers import GPT2Tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained(encoding_name)
        logger.info(f"TokenCounter initialized with encoding: {encoding_name}")

    def num_tokens_from_string(self, string):
        """Returns the number of tokens in a text string.

        ``None`` and the empty string count as 0 tokens. ``None`` occurs when an
        API call produced no response; counting it as zero keeps usage accounting
        from failing inside the tokenizer.
        """
        if not string:
            logger.debug("Counted 0 tokens in an empty or missing string.")
            return 0
        num = len(self.tokenizer.encode(string))
        logger.debug(f"Counted {num} tokens in a string.")
        return num

    def num_tokens_from_list_string(self, strings):
        """Returns the total number of tokens in a list of text strings.

        ``None`` entries (and a ``None`` list) contribute 0 tokens.
        """
        total = sum(self.num_tokens_from_string(s) for s in (strings or ()))
        logger.debug(f"Counted total {total} tokens in a list of strings.")
        return total

    def compute_price(self, input_tokens, output_tokens, model='gpt-3.5-turbo'):
        """
        Computes the cost based on the number of input and output tokens.
        Prices are based on OpenAI's pricing as of September 2023.

        Args:
            input_tokens: number of prompt tokens.
            output_tokens: number of generated tokens.
            model: model name used to look up the per-1K-token prices; unknown
                names fall back to the default prices.

        Returns:
            The total cost in USD as a float.
        """
        logger.info(f"Computing price for model: {model}")
        input_price_per_1k, output_price_per_1k = self._PRICES_PER_1K.get(
            model, self._DEFAULT_PRICES_PER_1K
        )

        input_cost = (input_tokens / 1000) * input_price_per_1k
        output_cost = (output_tokens / 1000) * output_price_per_1k
        total_cost = input_cost + output_cost
        logger.info(f"Computed price: Input Cost = ${input_cost:.4f}, Output Cost = ${output_cost:.4f}, Total Cost = ${total_cost:.4f}")
        return total_cost


In [210]:
# Tokenizers already loaded by num_tokens_from_string, keyed by encoding name.
_TOKENIZER_CACHE = {}


def num_tokens_from_string(string, encoding_name='gpt2'):
    """Returns the number of tokens in a text string.

    The tokenizer for ``encoding_name`` is loaded on first use and then reused,
    so repeated calls (for example once per text in ``chunk_texts``) do not
    reload it every time.

    Args:
        string: text to tokenize.
        encoding_name: name passed to ``GPT2Tokenizer.from_pretrained``.

    Returns:
        Number of token ids produced by the tokenizer's ``encode``.
    """
    tokenizer = _TOKENIZER_CACHE.get(encoding_name)
    if tokenizer is None:
        from transformers import GPT2Tokenizer
        tokenizer = GPT2Tokenizer.from_pretrained(encoding_name)
        _TOKENIZER_CACHE[encoding_name] = tokenizer
    return len(tokenizer.encode(string))

def remove_descriptions(text):
    """Return ``text`` without the lines that begin (after leading whitespace) with "Description"."""
    kept = []
    for line in text.split('\n'):
        if line.strip().startswith("Description"):
            continue
        kept.append(line)
    return '\n'.join(kept)

def extract_num(string):
    """Return the first run of digits in ``string`` as an int, or '' when there is none."""
    first_digits = re.search(r'\d+', string)
    if first_digits is None:
        return ''
    return int(first_digits.group())

def generate_prompt(template, paras):
    """Fill the ``[KEY]`` placeholders of ``template`` with the values in ``paras``.

    All placeholders are substituted in a single pass over the template, so text
    inserted for one key is never scanned again: a value that happens to contain
    another placeholder (e.g. a paper abstract containing "[TOPIC]") is inserted
    verbatim. Values are converted with ``str()`` so numbers can be passed directly.

    Args:
        template: prompt text containing placeholders such as ``[TOPIC]``.
        paras: mapping of placeholder name (without brackets) to its value.

    Returns:
        The template with every known placeholder replaced; unknown placeholders
        are left untouched.
    """
    if not paras:
        return template
    replacements = {f'[{key}]': str(value) for key, value in paras.items()}
    # Longest tokens first so a longer placeholder is tried before a shorter one.
    tokens = sorted(replacements, key=len, reverse=True)
    pattern = re.compile('|'.join(re.escape(token) for token in tokens))
    # A function replacement keeps backslashes in values from being read as regex escapes.
    return pattern.sub(lambda match: replacements[match.group(0)], template)

def chunk_texts(texts, max_tokens):
    """Group consecutive texts into chunks of at most ``max_tokens`` tokens each.

    Texts are packed greedily in their original order. A new chunk is started
    when adding the next text would push the running total above ``max_tokens``.
    A single text larger than ``max_tokens`` gets a chunk of its own; no empty
    chunk is ever emitted.

    Args:
        texts: iterable of strings, counted with ``num_tokens_from_string``.
        max_tokens: token budget per chunk.

    Returns:
        A list of chunks, each a non-empty list of texts. Empty input gives ``[]``.
    """
    chunks = []
    current_chunk = []
    current_tokens = 0
    for idx, text in enumerate(texts):
        text_tokens = num_tokens_from_string(text)
        # Only flush when there is something to flush; otherwise an oversized
        # first text would produce an empty leading chunk.
        if current_chunk and current_tokens + text_tokens > max_tokens:
            logger.debug(f"Chunk {len(chunks)+1}: Adding {len(current_chunk)} texts with total tokens {current_tokens}")
            chunks.append(current_chunk)
            current_chunk = []
            current_tokens = 0
            logger.debug(f"Starting new chunk with text index {idx}")
        current_chunk.append(text)
        current_tokens += text_tokens
    if current_chunk:
        logger.debug(f"Final Chunk {len(chunks)+1}: Adding {len(current_chunk)} texts with total tokens {current_tokens}")
        chunks.append(current_chunk)
    logger.info(f"Total chunks created: {len(chunks)}")
    return chunks



In [211]:
class APIModel:
    """Minimal client for a chat-completions style HTTP endpoint.

    Each request sends a single user message and reads the reply from
    ``choices[0].message.content`` of the JSON response.
    """

    # Value stored by batch_chat for prompts that produced no reply.
    NO_RESPONSE = 'No response'

    def __init__(self, model, api_key, api_url, timeout=600):
        """
        Args:
            model: model name sent in the request payload.
            api_key: bearer token for the ``Authorization`` header.
            api_url: full URL of the endpoint to POST to.
            timeout: seconds to wait for each HTTP request before it counts as
                a failed attempt (passed to ``requests.post``).
        """
        self.__api_key = api_key
        self.__api_url = api_url
        self.model = model
        self.timeout = timeout

    def __req(self, text, temperature, max_try=5):
        """POST one prompt, retrying up to ``max_try`` times.

        Returns:
            The reply text, or ``None`` when every attempt failed.
        """
        url = f"{self.__api_url}"
        payload = json.dumps({
            "model": f"{self.model}",
            "messages": [{
                "role": "user",
                "content": f"{text}"
            }],
            "temperature": temperature
        })
        headers = {
            'Accept': 'application/json',
            'Authorization': f'Bearer {self.__api_key}',
            'Content-Type': 'application/json'
        }
        for attempt in range(1, max_try + 1):
            try:
                # Without a timeout a stalled connection would block this call forever.
                response = requests.post(url, headers=headers, data=payload, timeout=self.timeout)
                response.raise_for_status()
                return response.json()['choices'][0]['message']['content']
            except (requests.exceptions.RequestException, KeyError, IndexError, TypeError, ValueError) as e:
                # KeyError/IndexError/TypeError/ValueError cover a response body
                # that is not the expected JSON structure.
                logger.warning(f"API request attempt {attempt}/{max_try} failed: {e!r}")
                if attempt < max_try:
                    time.sleep(0.2)
        logger.error(f"API request failed after {max_try} attempts.")
        return None

    def chat(self, text, temperature=1):
        """Send one prompt; returns the reply text or ``None`` on failure."""
        return self.__req(text, temperature=temperature)

    def batch_chat(self, text_batch, temperature=0):
        """Send several prompts concurrently.

        Args:
            text_batch: sequence of prompt strings.
            temperature: sampling temperature applied to every prompt.

        Returns:
            A list with one entry per prompt, in the same order. Prompts that
            produced no reply (``chat`` returned ``None`` or raised) are
            reported as ``'No response'``.
        """
        from concurrent.futures import ThreadPoolExecutor

        max_threads = 5  # Adjust as needed

        def worker(text):
            try:
                reply = self.chat(text, temperature)
            except Exception:
                # Keep one failing prompt from discarding the rest of the batch.
                logger.exception("Unexpected error while requesting a batch item.")
                return self.NO_RESPONSE
            return self.NO_RESPONSE if reply is None else reply

        with ThreadPoolExecutor(max_workers=max_threads) as pool:
            # map() yields results in input order regardless of completion order.
            return list(pool.map(worker, text_batch))


In [212]:
class Database:
    """Paper store combining a SQLite table with FAISS title/abstract indexes.

    Files read from ``db_path``:
        - ``arxiv_paper_db.sqlite`` (table ``cs_paper_info``)
        - ``faiss_paper_title_embeddings.bin`` and ``index_to_arxivid_title.json``
        - ``faiss_paper_abs_embeddings.bin`` and ``index_to_arxivid_abs.json``
    """

    # Maximum number of bound parameters per SQL statement. Kept below 999, the
    # default limit of SQLite builds older than 3.32.
    _SQL_BATCH_SIZE = 900

    def __init__(self, db_path, embedding_model_name):
        """Load the embedding model, open the SQLite database and read both FAISS indexes.

        Args:
            db_path: directory containing the files listed in the class docstring.
            embedding_model_name: name passed to ``SentenceTransformer``.
        """
        logger.info("Initializing Database...")
        self.embedding_model = SentenceTransformer(embedding_model_name, trust_remote_code=True)
        self.embedding_model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
        logger.info("SentenceTransformer model loaded and moved to appropriate device.")

        self.db_path = db_path
        self.conn = sqlite3.connect(os.path.join(self.db_path, 'arxiv_paper_db.sqlite'))
        self.cursor = self.conn.cursor()
        logger.info("Connected to SQLite database.")

        # Load FAISS indexes
        logger.info("Loading FAISS indexes...")
        self.title_index = faiss.read_index(os.path.join(db_path, 'faiss_paper_title_embeddings.bin'))
        with open(os.path.join(db_path, 'index_to_arxivid_title.json'), 'r') as f:
            self.index_to_id_title = json.load(f)
        logger.info("Title FAISS index and mapping loaded.")

        self.abs_index = faiss.read_index(os.path.join(db_path, 'faiss_paper_abs_embeddings.bin'))
        with open(os.path.join(db_path, 'index_to_arxivid_abs.json'), 'r') as f:
            self.index_to_id_abs = json.load(f)
        logger.info("Abstract FAISS index and mapping loaded.")

    def close(self):
        """Close the SQLite connection."""
        self.conn.close()
        logger.info("Database connection closed.")

    def get_embeddings(self, texts):
        """Encode ``texts`` with the embedding model and return a 2-D array (one row per text)."""
        logger.debug(f"Generating embeddings for {len(texts)} texts.")
        embeddings = self.embedding_model.encode(texts, show_progress_bar=False)
        # Ensure embeddings are in the shape (n, d)
        if embeddings.ndim == 1:
            embeddings = embeddings.reshape(1, -1)
        return embeddings

    def get_ids_from_query(self, query, num, title=True, shuffle=False):
        """Return up to ``num`` paper ids nearest to ``query``.

        Args:
            query: free-text query to embed.
            num: number of neighbours to request from the index.
            title: search the title index when True, the abstract index otherwise.
            shuffle: shuffle the returned ids in place with ``np.random.shuffle``.
        """
        logger.info(f"Fetching top {num} IDs for query: {query}")
        q = self.get_embeddings([query])  # Keep as 2D array
        ids = self.search(q, top_k=num, title=title)
        if shuffle:
            np.random.shuffle(ids)
            logger.debug("Shuffled the IDs.")
        return ids

    def search(self, query_vector, top_k=1, title=True):
        """Look up the ``top_k`` nearest entries for one query vector.

        Only the first row of ``query_vector`` is used. Index slots reported as
        ``-1`` (no neighbour) are skipped.

        Returns:
            A list of ids taken from the index-to-id mapping, in index result order.
        """
        # Ensure query_vector is a 2D float32 array
        if query_vector.ndim == 1:
            query_vector = query_vector.reshape(1, -1)
        query_vector = query_vector.astype('float32')

        logger.debug(f"Query vector shape: {query_vector.shape}")

        index = self.title_index if title else self.abs_index
        index_to_id = self.index_to_id_title if title else self.index_to_id_abs
        distances, indices = index.search(query_vector, top_k)
        results = [index_to_id[str(idx)] for idx in indices[0] if idx != -1]
        logger.debug(f"Search returned {len(results)} results.")
        return results

    def _fetch_in_batches(self, sql_template, values):
        """Run an ``IN (...)`` query over ``values`` in batches.

        ``sql_template`` must contain ``{placeholders}`` where the ``?`` list
        goes. Batching keeps each statement under SQLite's bound-parameter limit.

        Returns:
            ``(columns, rows)``: the column names of the result and all fetched rows.
        """
        columns, rows = [], []
        for start in range(0, len(values), self._SQL_BATCH_SIZE):
            batch = values[start:start + self._SQL_BATCH_SIZE]
            placeholders = ','.join('?' for _ in batch)
            self.cursor.execute(sql_template.format(placeholders=placeholders), batch)
            columns = [description[0] for description in self.cursor.description]
            rows.extend(self.cursor.fetchall())
        return columns, rows

    def get_paper_info_from_ids(self, ids):
        """Fetch the ``cs_paper_info`` rows for ``ids`` as dicts.

        Rows are returned in the order the ids were given (first occurrence of
        each id), so a ranking or shuffle applied to ``ids`` is preserved. Ids
        with no matching row are dropped. Rows whose ``id`` value cannot be
        matched back to a requested id are kept at the end in database order.

        Returns:
            A list of ``{column_name: value}`` dicts; ``[]`` for empty input.
        """
        # De-duplicate while keeping order so no id is queried twice.
        unique_ids = list(dict.fromkeys(ids))
        if not unique_ids:
            logger.info("Retrieved information for 0 papers.")
            return []
        columns, rows = self._fetch_in_batches(
            "SELECT * FROM cs_paper_info WHERE id IN ({placeholders})", unique_ids
        )
        result = [dict(zip(columns, row)) for row in rows]
        rank = {paper_id: position for position, paper_id in enumerate(unique_ids)}
        # Stable sort: unmatched rows share the last rank and keep their DB order.
        result.sort(key=lambda record: rank.get(record.get('id'), len(unique_ids)))
        logger.info(f"Retrieved information for {len(result)} papers.")
        return result

    def get_titles_from_citations(self, citations):
        """Return the ids of the papers whose title exactly equals one of ``citations``.

        Despite the method name, the returned values are ids, not titles.
        """
        logger.info(f"Fetching IDs for {len(citations)} citations.")
        unique_titles = list(dict.fromkeys(citations))
        if not unique_titles:
            logger.debug("Found 0 matching IDs for citations.")
            return []
        _, rows = self._fetch_in_batches(
            "SELECT id FROM cs_paper_info WHERE title IN ({placeholders})", unique_titles
        )
        ids = [row[0] for row in rows]
        logger.debug(f"Found {len(ids)} matching IDs for citations.")
        return ids


In [213]:
class outlineWriter:
    """Drafts a survey outline for a topic by prompting a chat model with retrieved papers.

    Pipeline (see ``draft_outline``): retrieve papers, split them into chunks,
    ask for one rough outline per chunk, merge those into a section outline,
    expand every section into subsections, then ask for a final edit.
    Prompt and reply token counts accumulate in ``input_token_usage`` and
    ``output_token_usage``.
    """

    def __init__(self, model: str, api_key: str, api_url: str, database) -> None:
        """
        Args:
            model: model name forwarded to ``APIModel``.
            api_key: API key forwarded to ``APIModel``.
            api_url: endpoint URL forwarded to ``APIModel``.
            database: object providing ``get_ids_from_query`` and
                ``get_paper_info_from_ids``.
        """
        self.model, self.api_key, self.api_url = model, api_key, api_url
        self.api_model = APIModel(self.model, self.api_key, self.api_url)
        self.db = database
        self.token_counter = tokenCounter()
        self.input_token_usage, self.output_token_usage = 0, 0
        logger.info("OutlineWriter initialized.")

    def _generate_prompt(self, template, paras):
        """Fill the ``[KEY]`` placeholders of ``template`` from ``paras``.

        Substitution happens in a single pass, so inserted text (such as paper
        abstracts) is never rescanned for placeholders. Values are converted
        with ``str()``.
        """
        if not paras:
            return template
        replacements = {f'[{key}]': str(value) for key, value in paras.items()}
        tokens = sorted(replacements, key=len, reverse=True)
        pattern = re.compile('|'.join(re.escape(token) for token in tokens))
        # Function replacement: backslashes in values are not treated as escapes.
        return pattern.sub(lambda match: replacements[match.group(0)], template)

    @staticmethod
    def _format_papers(titles, papers):
        """Render title/content pairs as the '---'-separated block used in the prompts."""
        parts = [f'---\npaper_title: {t}\n\npaper_content:\n\n{p}\n' for t, p in zip(titles, papers)]
        parts.append('---\n')
        return ''.join(parts)

    def draft_outline(self, topic, reference_num=600, chunk_size=30000, section_num=6):
        """Run the full outline pipeline for ``topic`` and return the final outline text.

        Args:
            topic: survey topic, used both as retrieval query and in the prompts.
            reference_num: number of papers to retrieve for the rough outlines.
            chunk_size: approximate token budget per chunk of abstracts.
            section_num: number of sections requested in each rough outline.
        """
        logger.info(f"Drafting outline for topic: '{topic}' with {reference_num} references and {section_num} sections.")
        references_ids = self.db.get_ids_from_query(topic, num=reference_num, shuffle=True)
        logger.debug(f"Retrieved {len(references_ids)} reference IDs.")
        references_infos = self.db.get_paper_info_from_ids(references_ids)
        logger.debug("Fetched paper information from database.")

        references_titles = [r['title'] for r in references_infos]
        references_abs = [r['abs'] for r in references_infos]
        logger.info("Starting chunking of references.")
        abs_chunks, titles_chunks = self.chunking(references_abs, references_titles, chunk_size=chunk_size)
        logger.info(f"Completed chunking into {len(abs_chunks)} chunks.")

        # Generate rough section-level outline
        outlines = self.generate_rough_outlines(topic=topic, papers_chunks=abs_chunks, titles_chunks=titles_chunks, section_num=section_num)
        logger.info("Generated rough section-level outlines.")

        # Merge outline
        section_outline = self.merge_outlines(topic=topic, outlines=outlines)
        logger.info("Merged outlines into a final section outline.")

        # Generate subsection-level outline
        subsection_outlines = self.generate_subsection_outlines(topic=topic, section_outline=section_outline, rag_num=50)
        logger.info("Generated subsection-level outlines.")

        # Process outlines
        merged_outline = self.process_outlines(section_outline, subsection_outlines)
        logger.info("Processed and merged section and subsection outlines.")

        # Edit final outline; the topic is passed so the prompt's [TOPIC] placeholder gets filled.
        final_outline = self.edit_final_outline(merged_outline, topic=topic)
        logger.info("Edited final outline for coherence and comprehensiveness.")

        return final_outline

    def generate_rough_outlines(self, topic, papers_chunks, titles_chunks, section_num=8):
        """Request one rough outline per chunk of papers.

        Chunks whose request produced no reply are dropped with a warning.

        Returns:
            The list of outline texts that were received.

        Raises:
            RuntimeError: if prompts were sent but no outline came back at all.
        """
        logger.info("Generating rough outlines for each chunk.")
        prompts = []
        for i in trange(len(papers_chunks), desc="Generating Rough Outlines"):
            paper_texts = self._format_papers(titles_chunks[i], papers_chunks[i])
            prompt = self._generate_prompt(ROUGH_OUTLINE_PROMPT, paras={'PAPER LIST': paper_texts, 'TOPIC': topic, 'SECTION NUM': str(section_num)})
            prompts.append(prompt)
            logger.debug(f"Generated prompt for chunk {i+1}.")
        # Counted once and reused for both accounting and logging.
        prompt_tokens = self.token_counter.num_tokens_from_list_string(prompts)
        self.input_token_usage += prompt_tokens
        logger.debug(f"Total tokens for rough outline prompts: {prompt_tokens}")
        replies = self.api_model.batch_chat(text_batch=prompts, temperature=1)
        outlines = [reply for reply in replies if reply is not None]
        if len(outlines) < len(replies):
            logger.warning(f"{len(replies) - len(outlines)} of {len(replies)} rough outline requests returned no response.")
        if prompts and not outlines:
            raise RuntimeError("No rough outline could be generated: every API request failed.")
        self.output_token_usage += self.token_counter.num_tokens_from_list_string(outlines)
        logger.info("Rough outlines generated via API.")
        return outlines

    def merge_outlines(self, topic, outlines):
        """Ask the model to merge the candidate ``outlines`` into one section outline.

        Raises:
            RuntimeError: if the API returned no response.
        """
        logger.info("Merging multiple outlines into a final outline.")
        outline_texts = ''.join(
            f'---\noutline_id: {i}\n\noutline_content:\n\n{o}\n' for i, o in enumerate(outlines)
        )
        outline_texts += '---\n'
        prompt = self._generate_prompt(MERGING_OUTLINE_PROMPT, paras={'OUTLINE LIST': outline_texts, 'TOPIC': topic})
        prompt_tokens = self.token_counter.num_tokens_from_string(prompt)
        self.input_token_usage += prompt_tokens
        logger.debug(f"Token usage after merging prompt: {prompt_tokens}")
        outline = self.api_model.chat(prompt, temperature=1)
        if outline is None:
            raise RuntimeError("Merging outlines failed: the API returned no response.")
        self.output_token_usage += self.token_counter.num_tokens_from_string(outline)
        logger.info("Final outline merged via API.")
        return outline

    def generate_subsection_outlines(self, topic, section_outline, rag_num):
        """Request a subsection outline for every section of ``section_outline``.

        For each section, ``rag_num`` papers are retrieved using the section
        description as the query and included in the prompt.

        Returns:
            One outline text per section, in section order. A section whose
            request produced no reply gets an empty string.
        """
        logger.info("Generating subsection outlines for each section.")
        survey_title, survey_sections, survey_section_descriptions = self.extract_title_sections_descriptions(section_outline)
        logger.debug(f"Survey Title: {survey_title}")
        logger.debug(f"Number of Sections: {len(survey_sections)}")
        prompts = []
        for section_name, section_description in zip(survey_sections, survey_section_descriptions):
            logger.debug(f"Generating subsections for section: {section_name}")
            references_ids = self.db.get_ids_from_query(section_description, num=rag_num, shuffle=True)
            references_infos = self.db.get_paper_info_from_ids(references_ids)
            paper_texts = self._format_papers(
                [r['title'] for r in references_infos],
                [r['abs'] for r in references_infos],
            )
            prompt = self._generate_prompt(SUBSECTION_OUTLINE_PROMPT, paras={
                'OVERALL OUTLINE': section_outline,
                'SECTION NAME': section_name,
                'SECTION DESCRIPTION': section_description,
                'TOPIC': topic,
                'PAPER LIST': paper_texts
            })
            prompts.append(prompt)
            logger.debug(f"Generated prompt for subsection in section: {section_name}")
        prompt_tokens = self.token_counter.num_tokens_from_list_string(prompts)
        self.input_token_usage += prompt_tokens
        logger.debug(f"Total tokens for subsection outline prompts: {prompt_tokens}")
        replies = self.api_model.batch_chat(prompts, temperature=1)
        # Positions must stay aligned with the sections, so failures become '' instead of being dropped.
        sub_outlines = ['' if reply is None else reply for reply in replies]
        failed = sum(1 for reply in replies if reply is None)
        if failed:
            logger.warning(f"{failed} of {len(replies)} subsection outline requests returned no response.")
        self.output_token_usage += self.token_counter.num_tokens_from_list_string(sub_outlines)
        logger.info("Subsection outlines generated via API.")
        return sub_outlines

    def edit_final_outline(self, outline, topic=None):
        """Ask the model to polish the merged outline and return it as markdown.

        Args:
            outline: merged outline text produced by ``process_outlines``.
            topic: survey topic substituted for the prompt's ``[TOPIC]``
                placeholder. When omitted the placeholder is left as is.

        Returns:
            The model's reply with the ``<format>`` tags removed.

        Raises:
            RuntimeError: if the API returned no response.
        """
        logger.info("Editing final outline for coherence and comprehensiveness.")
        paras = {'OVERALL OUTLINE': outline}
        if topic is not None:
            paras['TOPIC'] = topic
        prompt = self._generate_prompt(EDIT_FINAL_OUTLINE_PROMPT, paras=paras)
        self.input_token_usage += self.token_counter.num_tokens_from_string(prompt)
        reply = self.api_model.chat(prompt, temperature=1)
        if reply is None:
            raise RuntimeError("Editing the final outline failed: the API returned no response.")
        refined_outline = reply.replace('<format>\n', '').replace('</format>', '')
        self.output_token_usage += self.token_counter.num_tokens_from_string(refined_outline)
        logger.info("Final outline edited via API.")
        return refined_outline

    @staticmethod
    def _extract_numbered_entries(outline, label):
        """Collect '<label> i: name' entries and their 'Description i: text' lines.

        Numbers 1 to 100 are checked. For each number present, the text after
        the first occurrence of the key up to the end of that line is taken. A
        missing description yields ''.

        Returns:
            ``(names, descriptions)`` as two parallel lists.
        """
        outline = outline or ''
        names, descriptions = [], []
        for i in range(1, 101):
            name_key = f'{label} {i}: '
            if name_key not in outline:
                continue
            name = outline.split(name_key, 1)[1].split('\n', 1)[0].strip()
            desc_key = f'Description {i}: '
            if desc_key in outline:
                description = outline.split(desc_key, 1)[1].split('\n', 1)[0].strip()
                logger.debug(f"{label} {i}: {name} - Description: {description}")
            else:
                description = ''
                logger.debug(f"{label} {i}: {name} - No Description Found")
            names.append(name)
            descriptions.append(description)
        return names, descriptions

    def extract_title_sections_descriptions(self, outline):
        """Parse a 'Title / Section i / Description i' outline.

        Returns:
            ``(title, sections, descriptions)``. The title falls back to
            "Untitled Survey" when no 'Title: ' line is found.
        """
        logger.debug("Extracting title, sections, and descriptions from outline.")
        outline = outline or ''
        try:
            title = outline.split('Title: ')[1].split('\n')[0]
        except IndexError:
            logger.error("Failed to extract title from outline.")
            title = "Untitled Survey"
        sections, descriptions = self._extract_numbered_entries(outline, 'Section')
        logger.debug(f"Extracted {len(sections)} sections.")
        return title, sections, descriptions

    def extract_subsections_subdescriptions(self, outline):
        """Parse a 'Subsection i / Description i' outline.

        Returns:
            ``(subsections, subdescriptions)``; both empty for an empty or missing outline.
        """
        logger.debug("Extracting subsections and their descriptions from outline.")
        subsections, subdescriptions = self._extract_numbered_entries(outline, 'Subsection')
        logger.debug(f"Extracted {len(subsections)} subsections.")
        return subsections, subdescriptions

    def chunking(self, papers, titles, chunk_size=14000):
        """Split papers (and their titles, in parallel) into token-balanced chunks.

        The per-chunk budget is the total token count divided evenly over
        ``int(total / chunk_size) + 1`` chunks. Papers are packed in order and a
        new chunk starts when the next paper would exceed that budget. A paper
        larger than the budget gets a chunk of its own.

        Returns:
            ``(paper_chunks, title_chunks)`` as parallel lists of lists. Empty
            input yields a single empty chunk.
        """
        logger.info("Starting chunking of papers and titles.")
        logger.debug(f"Total papers: {len(papers)}")
        # Count each paper once; the total is the sum of the same numbers.
        token_counts = [self.token_counter.num_tokens_from_string(paper) for paper in papers]
        total_length = sum(token_counts)
        num_of_chunks = int(total_length / chunk_size) + 1
        avg_len = int(total_length / num_of_chunks) + 1
        logger.debug(f"Total tokens: {total_length}, Number of chunks: {num_of_chunks}, Average tokens per chunk: {avg_len}")

        paper_chunks, title_chunks = [], []
        start = 0
        current_tokens = 0
        for j, paper_tokens in enumerate(token_counts):
            # j > start guarantees the chunk being closed is never empty.
            if j > start and current_tokens + paper_tokens > avg_len:
                paper_chunks.append(papers[start:j])
                title_chunks.append(titles[start:j])
                logger.debug(f"Chunk {len(paper_chunks)}: Papers {start} to {j-1}")
                start = j
                current_tokens = 0
            # The paper that triggered the split counts towards the new chunk.
            current_tokens += paper_tokens
        paper_chunks.append(papers[start:])
        title_chunks.append(titles[start:])
        logger.debug(f"Final chunk {len(paper_chunks)}: Papers {start} to end")
        logger.info("Completed chunking of papers and titles.")
        return paper_chunks, title_chunks

    def process_outlines(self, section_outline, sub_outlines):
        """Combine the section outline and per-section subsection outlines into one markdown draft.

        Args:
            section_outline: text in 'Title / Section i / Description i' form.
            sub_outlines: one subsection outline per section, in section order.
                Missing or empty entries produce a section without subsections.

        Returns:
            Markdown text with '#', '## i' and '### i.j' headings, each heading
            followed by a 'Description:' line.
        """
        logger.info("Processing and merging section and subsection outlines into final survey outline.")
        survey_title, survey_sections, survey_section_descriptions = self.extract_title_sections_descriptions(outline=section_outline)
        logger.debug(f"Survey Title: {survey_title}")
        logger.debug(f"Number of Sections: {len(survey_sections)}")
        parts = [f'# {survey_title}\n\n']
        for i, section in enumerate(survey_sections):
            parts.append(f'## {i+1} {section}\nDescription: {survey_section_descriptions[i]}\n\n')
            sub_outline = sub_outlines[i] if i < len(sub_outlines) else ''
            subsections, subsection_descriptions = self.extract_subsections_subdescriptions(sub_outline)
            for j, subsection in enumerate(subsections):
                parts.append(f'### {i+1}.{j+1} {subsection}\nDescription: {subsection_descriptions[j]}\n\n')
                logger.debug(f"Added subsection {j+1} to section {i+1}: {subsection}")
        logger.info("Final survey outline constructed.")
        return ''.join(parts)

In [214]:
class subsectionWriter:
    def __init__(self, model: str, api_key: str, api_url: str, database) -> None:
        self.model, self.api_key, self.api_url = model, api_key, api_url
        self.api_model = APIModel(self.model, self.api_key, self.api_url)
        self.db = database
        self.token_counter = tokenCounter()
        self.input_token_usage, self.output_token_usage = 0, 0
        logger.info("subsectionWriter initialized.")

    def _generate_prompt(self, template, paras):
        prompt = template
        for k in paras.keys():
            prompt = prompt.replace(f'[{k}]', paras[k])
        return prompt

    def write(self, topic, outline, rag_num=30, subsection_len=500, refining=True, reflection=True):
        logger.info(f"Starting to write subsections for topic: '{topic}' with rag_num={rag_num} and subsection_len={subsection_len}.")
        # Parse the outline to get sections and subsections
        parsed_outline = self.parse_outline(outline=outline)
        logger.debug("Parsed Outline:")
        logger.debug(json.dumps(parsed_outline, indent=2))

        # Initialize structures
        section_content = [[] for _ in range(len(parsed_outline['sections']))]
        section_paper_texts = [[] for _ in range(len(parsed_outline['sections']))]

        total_ids = []
        section_references_ids = [[] for _ in range(len(parsed_outline['sections']))]
        for i, sub_desc in enumerate(parsed_outline['subsection_descriptions']):
            for d in sub_desc:
                references_ids = self.db.get_ids_from_query(d, num=rag_num, shuffle=False)
                total_ids += references_ids
                section_references_ids[i].append(references_ids)
                logger.debug(f"Section {i+1}: Retrieved {len(references_ids)} references for description.")

        # Fetch unique references
        unique_ids = list(set(total_ids))
        logger.info(f"Total unique references fetched: {len(unique_ids)}")
        total_references_infos = self.db.get_paper_info_from_ids(unique_ids)
        logger.debug("Fetched all unique paper information from database.")
        temp_title_dic = {p['id']: p['title'] for p in total_references_infos}
        temp_abs_dic = {p['id']: p['abs'] for p in total_references_infos}

        # Prepare paper texts for each section
        for i in range(len(parsed_outline['sections'])):
            for references_ids in section_references_ids[i]:
                references_titles = [temp_title_dic[_] for _ in references_ids]
                references_papers = [temp_abs_dic[_] for _ in references_ids]
                paper_texts = '' 
                for t, p in zip(references_titles, references_papers):
                    paper_texts += f'---\n\npaper_title: {t}\n\npaper_content:\n\n{p}\n'
                paper_texts += '---\n'
                section_paper_texts[i].append(paper_texts)
                logger.debug(f"Section {i+1}: Prepared paper texts for {len(references_ids)} references.")

        # Start threads to write subsections
        thread_l = []
        for i in range(len(parsed_outline['sections'])):
            thread = threading.Thread(target=self.write_subsection_with_reflection, args=(
                section_paper_texts[i], topic, outline, 
                parsed_outline['sections'][i], 
                parsed_outline['subsections'][i], 
                parsed_outline['subsection_descriptions'][i], 
                section_content, i, rag_num, str(subsection_len)
            ))
            thread_l.append(thread)
            thread.start()
            logger.debug(f"Started thread for Section {i+1}.")
            time.sleep(0.1)  # Slight delay to prevent overwhelming the API

        # Wait for all threads to complete
        for thread in thread_l:
            thread.join()
        logger.info("All subsection threads have completed.")

        # Generate the raw survey document
        raw_survey = self.generate_document(parsed_outline, section_content)
        logger.debug("Generated raw survey document.")
        
        # Process references
        raw_survey_with_references, raw_references = self.process_references(raw_survey)
        logger.info("Processed references for raw survey.")

        if refining:
            logger.info("Starting refinement of subsections.")
            final_section_content = self.refine_subsections(topic, outline, section_content)
            refined_survey = self.generate_document(parsed_outline, final_section_content)
            logger.debug("Generated refined survey document.")
            refined_survey_with_references, refined_references = self.process_references(refined_survey)
            logger.info("Processed references for refined survey.")
            return raw_survey + '\n', raw_survey_with_references + '\n', raw_references, refined_survey + '\n', refined_survey_with_references + '\n', refined_references
        else:
            return raw_survey + '\n', raw_survey_with_references + '\n', raw_references

    def refine_subsections(self, topic, outline, section_content):
        logger.info("Refining subsections for coherence and consistency.")
        section_content_even = copy.deepcopy(section_content)
        logger.debug("Created a deep copy of section_content for even indexing.")

        thread_l = []
        for i in range(len(section_content)):
            for j in range(len(section_content[i])):
                if j % 2 == 0:
                    if j == 0:
                        contents = ['', section_content[i][j], section_content[i][j+1] if j+1 < len(section_content[i]) else '']
                    elif j == (len(section_content[i]) - 1):
                        contents = [section_content[i][j-1], section_content[i][j], '']
                    else:
                        contents = [section_content[i][j-1], section_content[i][j], section_content[i][j+1]]
                    thread = threading.Thread(target=self.lce, args=(topic, outline, contents, section_content_even[i], j))
                    thread_l.append(thread)
                    thread.start()
                    logger.debug(f"Started LCE thread for Section {i+1}, Subsection {j+1} (even).")

        for thread in thread_l:
            thread.join()
        logger.info("Completed first pass of subsection refinement (even indices).")

        final_section_content = copy.deepcopy(section_content_even)
        thread_l = []
        for i in range(len(section_content_even)):
            for j in range(len(section_content_even[i])):
                if j % 2 == 1:
                    if j == (len(section_content_even[i]) - 1):
                        contents = [section_content_even[i][j-1], section_content_even[i][j], '']
                    else:
                        contents = [section_content_even[i][j-1], section_content_even[i][j], section_content_even[i][j+1]]
                    thread = threading.Thread(target=self.lce, args=(topic, outline, contents, final_section_content[i], j))
                    thread_l.append(thread)
                    thread.start()
                    logger.debug(f"Started LCE thread for Section {i+1}, Subsection {j+1} (odd).")

        for thread in thread_l:
            thread.join()
        logger.info("Completed second pass of subsection refinement (odd indices).")
        
        return final_section_content

    def write_subsection_with_reflection(self, paper_texts_l, topic, outline, section, subsections, subdescriptions, res_l, idx, rag_num=20, subsection_len=1000, citation_num=8):
        logger.info(f"Writing subsections for Section '{section}'.")
        prompts = []
        for j in range(len(subsections)):
            subsection = subsections[j]
            description = subdescriptions[j]

            prompt = self._generate_prompt(SUBSECTION_WRITING_PROMPT, paras={
                'OVERALL OUTLINE': outline,
                'SUBSECTION NAME': subsection,
                'DESCRIPTION': description,
                'TOPIC': topic,
                'PAPER LIST': paper_texts_l[j],
                'SECTION NAME': section,
                'WORD NUM': str(subsection_len),
                'CITATION NUM': str(citation_num)
            })
            prompts.append(prompt)
            logger.debug(f"Generated prompt for Subsection {j+1}: '{subsection}'.")

        self.input_token_usage += self.token_counter.num_tokens_from_list_string(prompts)
        logger.debug(f"Total tokens for subsection prompts: {self.token_counter.num_tokens_from_list_string(prompts)}")
        contents = self.api_model.batch_chat(prompts, temperature=1)
        self.output_token_usage += self.token_counter.num_tokens_from_list_string(contents)
        logger.info(f"Generated content for Section '{section}' subsections via API.")

        # Clean up content
        contents = [c.replace('<format>', '').replace('</format>', '') for c in contents]

        # Check and add citations
        prompts = []
        for content, paper_texts in zip(contents, paper_texts_l):
            prompts.append(self._generate_prompt(CHECK_CITATION_PROMPT, paras={'SUBSECTION': content, 'TOPIC': topic, 'PAPER LIST': paper_texts}))
            logger.debug("Generated citation checking prompt.")

        self.input_token_usage += self.token_counter.num_tokens_from_list_string(prompts)
        contents = self.api_model.batch_chat(prompts, temperature=1)
        self.output_token_usage += self.token_counter.num_tokens_from_list_string(contents)
        logger.info(f"Checked and updated citations for Section '{section}' subsections via API.")

        # Clean up citation-checked content
        contents = [c.replace('<format>', '').replace('</format>', '') for c in contents]

        # Assign the processed content to the result list
        res_l[idx] = contents
        logger.debug(f"Assigned processed content to section {idx+1}.")
        return contents

    def lce(self, topic, outline, contents, res_l, idx):
        logger.info(f"Refining subsection at index {idx}.")
        prompt = self._generate_prompt(LCE_PROMPT, paras={
            'OVERALL OUTLINE': outline,
            'PREVIOUS': contents[0],
            'FOLLOWING': contents[2],
            'TOPIC': topic,
            'SUBSECTION': contents[1]
        })
        self.input_token_usage += self.token_counter.num_tokens_from_string(prompt)
        refined_content = self.api_model.chat(prompt, temperature=1).replace('<format>', '').replace('</format>', '')
        self.output_token_usage += self.token_counter.num_tokens_from_string(refined_content)
        logger.debug(f"Refined subsection {idx+1}: {refined_content[:100]}...")  # Log first 100 characters
        res_l[idx] = refined_content
        return refined_content

    def parse_outline(self, outline):
        logger.debug("Parsing the outline.")
        result = {
            "title": "",
            "sections": [],
            "section_descriptions": [],
            "subsections": [],
            "subsection_descriptions": []
        }

        lines = outline.strip().split('\n')
        current_section_index = -1
        current_subsection_index = -1

        for i, line in enumerate(lines):
            line = line.strip()
            if line.startswith('# '):
                result["title"] = line[2:].strip()
                logger.debug(f"Found title: {result['title']}")
            elif line.startswith('## '):
                section_title = line[3:].strip()
                result["sections"].append(section_title)
                result["section_descriptions"].append("")  # Placeholder for description
                result["subsections"].append([])
                result["subsection_descriptions"].append([])
                current_section_index += 1
                current_subsection_index = -1
                logger.debug(f"Found section {current_section_index +1}: {section_title}")
                # Check if next line is 'Description:'
                if i + 1 < len(lines) and lines[i + 1].startswith('Description:'):
                    description = lines[i + 1].split('Description:', 1)[1].strip()
                    result["section_descriptions"][current_section_index] = description
                    logger.debug(f"Found description for section {current_section_index +1}: {description}")
            elif line.startswith('### '):
                subsection_title = line[4:].strip()
                result["subsections"][current_section_index].append(subsection_title)
                result["subsection_descriptions"][current_section_index].append("")
                current_subsection_index += 1
                logger.debug(f"Found subsection {current_subsection_index +1} in section {current_section_index +1}: {subsection_title}")
                # Check if next line is 'Description:'
                if i + 1 < len(lines) and lines[i + 1].startswith('Description:'):
                    description = lines[i + 1].split('Description:', 1)[1].strip()
                    result["subsection_descriptions"][current_section_index][current_subsection_index] = description
                    logger.debug(f"Found description for subsection {current_subsection_index +1} in section {current_section_index +1}: {description}")

        logger.info("Completed parsing the outline.")
        return result

    def process_references(self, survey):
        logger.info("Processing references in the survey.")
        citations = self.extract_citations(survey)
        logger.debug(f"Extracted {len(citations)} citations.")
        updated_survey, references = self.replace_citations_with_numbers(citations, survey)
        logger.debug("Replaced citations with numbered references.")
        return updated_survey, references

    def extract_citations(self, markdown_text):
        logger.debug("Extracting citations from markdown text.")
        # Regular expression to match citations within square brackets
        pattern = re.compile(r'\[(.*?)\]')
        matches = pattern.findall(markdown_text)
        citations = []
        for match in matches:
            # Exclude matches that are likely not citations (e.g., text with asterisks)
            if '*' in match:  # Simple heuristic: skip if asterisks are present
                logger.debug(f"Skipping non-citation match: {match}")
                continue
            parts = match.split(';')
            for part in parts:
                cit = part.strip()
                if cit and cit not in citations:
                    citations.append(cit)
        logger.debug(f"Filtered citations: {citations}")
        return citations

    def replace_citations_with_numbers(self, citations, markdown_text):
        logger.info("Replacing citations with numbered references.")
        ids = self.db.get_titles_from_citations(citations)
        citation_to_ids = {citation: idx for citation, idx in zip(citations, ids)}
        logger.debug(f"Mapped citations to IDs: {citation_to_ids}")

        paper_infos = self.db.get_paper_info_from_ids(ids)
        temp_dic = {p['id']: p['title'] for p in paper_infos}

        titles = [temp_dic.get(_, '') for _ in ids]
        logger.debug(f"Retrieved titles for IDs: {titles}")

        ids_to_titles = {idx: title for idx, title in zip(ids, titles)}
        titles_to_ids = {title: idx for idx, title in ids_to_titles.items()}
        title_to_number = {title: num+1 for num, title in enumerate(titles)}

        number_to_title = {num: title for title, num in title_to_number.items()}
        number_to_title_sorted = {key: number_to_title[key] for key in sorted(number_to_title)}

        def replace_match(match):
            citation_text = match.group(1)
            individual_citations = citation_text.split(';')
            numbered_citations = []
            for citation in individual_citations:
                citation = citation.strip()
                try:
                    paper_id = citation_to_ids[citation]
                    title = ids_to_titles[paper_id]
                    number = title_to_number[title]
                    numbered_citations.append(str(number))
                except KeyError:
                    logger.warning(f"Citation not found in database: '{citation}'")
                    # Optionally, keep the original citation or skip
                    # Here, we'll skip it
                    continue
            if numbered_citations:
                return '[' + '; '.join(numbered_citations) + ']'
            else:
                return match.group(0)

        updated_text = re.sub(r'\[(.*?)\]', replace_match, markdown_text)
        logger.debug("Citations replaced with numbers in survey text.")

        # Generate the references section
        references_section = "\n\n## References\n\n"
        references = {num: titles_to_ids[title] for num, title in number_to_title_sorted.items()}
        for idx, title in number_to_title_sorted.items():
            t = title.replace('\n', '')
            references_section += f"[{idx}] {t}\n\n"
        logger.info("References section created.")

        return updated_text + references_section, references

    def generate_document(self, parsed_outline, subsection_contents):
        logger.info("Generating the final survey document.")
        document = []
        
        # Append title
        title = parsed_outline['title']
        document.append(f"# {title}\n")
        logger.debug(f"Added title: {title}")
        
        # Iterate over sections and their content
        for i, section in enumerate(parsed_outline['sections']):
            document.append(f"## {section}\n")
            logger.debug(f"Added section {i+1}: {section}")
            
            # Append subsections and their contents
            for j, subsection in enumerate(parsed_outline['subsections'][i]):
                document.append(f"### {subsection}\n")
                logger.debug(f"Added subsection {i+1}.{j+1}: {subsection}")
                
                # Append detailed content for each subsection
                if i < len(subsection_contents) and j < len(subsection_contents[i]):
                    content = subsection_contents[i][j]
                    document.append(content + "\n")
                    logger.debug(f"Added content for subsection {i+1}.{j+1}")
        
        final_document = "\n".join(document)
        logger.info("Final survey document generated.")
        return final_document

In [215]:
class Judge:
    """Scores a survey against named criteria by prompting an API model."""

    def __init__(self, model, api_key, api_url, database):
        """Build the API client used for judging and keep a handle to the database."""
        self.model = APIModel(model, api_key, api_url)
        self.db = database

    def criteria_based_judging(self, survey, topic, criterion):
        """Score ``survey`` on one criterion.

        The criterion entry in ``CRITERIA`` supplies a description and the texts
        for scores 1 to 5; together with the topic and the survey they fill
        ``CRITERIA_BASED_JUDGING_PROMPT``. The model is queried with temperature 0.

        Returns:
            The value ``extract_num`` derives from the model response.
        """
        criterion_paras = CRITERIA[criterion]
        content_paras = {
            'TOPIC': topic,
            'SURVEY': survey,
            'Criterion Description': criterion_paras['description']
        }
        content_paras.update({
            f'Score {level} Description': criterion_paras[f'score {level}']
            for level in range(1, 6)
        })
        prompt = generate_prompt(CRITERIA_BASED_JUDGING_PROMPT, content_paras)
        response = self.model.chat(prompt, temperature=0)
        return extract_num(response)

    def batch_criteria_based_judging(self, survey, topic, criteria):
        """Score ``survey`` on each criterion in ``criteria``; returns the scores in that order."""
        return [self.criteria_based_judging(survey, topic, criterion) for criterion in criteria]


In [216]:
def generate_outline(topic, model, section_num, outline_reference_num, db, api_key, api_url):
    """Draft a survey outline for ``topic``.

    Builds an ``outlineWriter`` from the API settings and the database and asks it
    for a draft using ``outline_reference_num`` references and ``section_num``
    sections.

    Returns:
        tuple: ``(outline, remove_descriptions(outline))``.
    """
    writer = outlineWriter(model=model, api_key=api_key, api_url=api_url, database=db)
    drafted = writer.draft_outline(topic, reference_num=outline_reference_num, section_num=section_num)
    return drafted, remove_descriptions(drafted)


In [217]:
def write_subsection(topic, model, outline, subsection_len, rag_num, db, api_key, api_url, refinement=True):
    """Write the survey body for ``outline`` with a fresh ``subsectionWriter``.

    Returns:
        tuple: With ``refinement`` the six values produced by
        ``subsectionWriter.write`` (raw survey, raw survey with references, raw
        references, then the same three for the refined survey); without it only
        the first three.
    """
    writer = subsectionWriter(model=model, api_key=api_key, api_url=api_url, database=db)
    outputs = writer.write(
        topic=topic,
        outline=outline,
        subsection_len=subsection_len,
        rag_num=rag_num,
        refining=refinement
    )
    # Unpacking checks that the writer returned the expected number of values.
    if not refinement:
        survey_content, survey_with_references, references = outputs
        return survey_content, survey_with_references, references

    (survey_content, survey_with_references, references,
     refined_survey, refined_survey_with_references, refined_references) = outputs
    return (survey_content, survey_with_references, references,
            refined_survey, refined_survey_with_references, refined_references)


In [None]:
# Parameters (Replace these with your own values or collect them via input)
db_path = './database'
embedding_model = 'nomic-ai/nomic-embed-text-v1'
saving_path = './output'
model_name = 'deepseek-chat'
topic = 'RAG Chunking in Large Language Models'  # Example topic
section_num = 2
subsection_len = 300
outline_reference_num = 100
rag_num = 60
api_url = 'https://api.deepseek.com/chat/completions'
# Credentials must never be stored in the notebook: the key is read from the
# DEEPSEEK_API_KEY environment variable, which has to be set before the kernel starts.
# NOTE: a key that was ever committed in this cell must be treated as leaked and revoked.
api_key = os.environ.get('DEEPSEEK_API_KEY', '').strip()

# Validate API Key
if not api_key:
    logger.error("API key is required. Set the DEEPSEEK_API_KEY environment variable.")
    sys.exit(1)
else:
    logger.info("API key provided.")


In [None]:
# Initialize Database
# Opens the paper database at `db_path` using the configured embedding model name.
db = Database(db_path=db_path, embedding_model_name=embedding_model)
logger.info("Database initialized.")

# Ensure saving path exists
# exist_ok=True makes this cell safe to re-run when the directory is already there.
os.makedirs(saving_path, exist_ok=True)
logger.info(f"Saving path '{saving_path}' is ready.")

# Initialize API Model
# NOTE(review): the later cells pass model_name/api_key/api_url to helpers that build
# their own APIModel; `api_model` itself is not referenced by the cells below - confirm
# whether it is still needed.
api_model = APIModel(model=model_name, api_key=api_key, api_url=api_url)
logger.info("API Model initialized.")

In [None]:
# Generate Outline
# Produces two variants: the outline with 'Description:' lines (used as input for
# writing the subsections) and the same outline with the descriptions removed.
logger.info("Generating outline...")
outline_with_description, outline_wo_description = generate_outline(
    topic=topic,
    model=model_name,
    section_num=section_num,
    outline_reference_num=outline_reference_num,
    db=db,
    api_key=api_key,
    api_url=api_url
)
logger.info("Outline generated.")
logger.debug(f"Outline with descriptions:\n{outline_with_description}")

In [None]:
# Show the generated outline (with descriptions) for review before the subsections are written.
print(outline_with_description)

In [None]:
# Write Subsections
logger.info("Writing subsections...")
outputs = write_subsection(
    topic=topic,
    model=model_name,
    outline=outline_with_description,
    subsection_len=subsection_len,
    rag_num=rag_num,
    db=db,
    api_key=api_key,
    api_url=api_url,
    refinement=True
)
# write_subsection returns six values with refinement and three without.
if len(outputs) == 6:
    raw_survey, raw_survey_with_references, raw_references, refined_survey, refined_survey_with_references, refined_references = outputs
    logger.info("Subsections written with refinement.")
else:
    raw_survey, raw_survey_with_references, raw_references = outputs
    # Without refinement there is no refined draft. Fall back to the raw draft so the
    # saving and evaluation cells neither fail with NameError on a fresh kernel nor
    # silently reuse stale refined_* values left over from an earlier run.
    refined_survey = raw_survey
    refined_survey_with_references = raw_survey_with_references
    refined_references = raw_references
    logger.info("Subsections written without refinement.")
logger.debug(f"Raw Survey:\n{raw_survey}")

In [None]:
# Save the survey to a file
# The file name is derived directly from `topic`, so the topic must be usable as a file name.
survey_filename = os.path.join(saving_path, f"{topic}.md")
with open(survey_filename, 'w', encoding='utf-8') as f:
    f.write(refined_survey_with_references)
print(f"Survey saved to {survey_filename}")

# Save references to a JSON file
# Stores the survey text together with the reference-number -> paper-id mapping.
references_filename = os.path.join(saving_path, f"{topic}.json")
save_dic = {
    'survey': refined_survey_with_references,
    'reference': refined_references
}
with open(references_filename, 'w', encoding='utf-8') as f:
    json.dump(save_dic, f, indent=4)
print(f"References saved to {references_filename}")


In [None]:
# Evaluate the generated survey
# The judge scores the refined survey (with references) once per criterion name.
print("Evaluating survey...")
judge = Judge(model=model_name, api_key=api_key, api_url=api_url, database=db)
survey_content = refined_survey_with_references
criteria = ['Coverage', 'Structure', 'Relevance']
scores = judge.batch_criteria_based_judging(survey_content, topic, criteria)

# Save Evaluation Results
# One "<criterion>: <score>" line per criterion, below a header naming the topic.
evaluation_filename = os.path.join(saving_path, f"{topic}_evaluation.txt")
with open(evaluation_filename, 'w', encoding='utf-8') as f:
    result = f"Evaluation Results for '{topic}'\n"
    for c, s in zip(criteria, scores):
        result += f"{c}: {s}\n"
    f.write(result)
print(f"Evaluation results saved to {evaluation_filename}")


In [None]:
# Close Database Connection
# Run this last: the outline, writing and evaluation cells above all receive `db`.
db.close()
