## function calling as `tool use`

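> The core, foundational capability of an AI agent: LLMs can do more than talk, because function calling gives them the ability to use tools and execute actions;
> - automatically select (decide) which function to use;
> - generate the function's arguments (argument generation) from a purely natural-language query or conversation, which instantiates a concrete function call. The call itself is still executed by our own code;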


> messages: list
> - in software-engineering terms, it represents the data flow;
> - it is the working memory of stateless LLMs;
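
To make the "working memory" point concrete, here is a minimal offline sketch. `stateless_model` is a hypothetical stand-in for an LLM (not an OpenAI API); the only thing it can "remember" is whatever is present in the `messages` list it receives on that call.

```python
PREFIX = "My name is "


def stateless_model(messages):
    """Hypothetical stand-in for an LLM: it sees only the messages passed in."""
    names = [
        m["content"][len(PREFIX):]
        for m in messages
        if m["role"] == "user" and m["content"].startswith(PREFIX)
    ]
    reply = f"Your name is {names[-1]}." if names else "I don't know your name."
    return {"role": "assistant", "content": reply}


# Resending the whole history lets the model "remember" earlier turns.
history = [{"role": "user", "content": "My name is Ada"}]
history.append(stateless_model(history))
history.append({"role": "user", "content": "What is my name?"})
with_memory = stateless_model(history)["content"]

# Sending only the latest turn loses that context.
without_memory = stateless_model(
    [{"role": "user", "content": "What is my name?"}]
)["content"]

print(with_memory)
print(without_memory)
```

This is why every cell below appends both the user turn and the assistant reply to `messages` before the next request.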


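- Datafication => automation => intelligence
- HuggingGPT does essentially the same thing, in a DIY fashion;
- Functions (in the coding sense) are a formal, precise way to define, compute, and process logic, and they are deterministic;
- LLMs with function calling aim for more precise output and execution, and can be coupled more tightly into concrete, complex code logic and business workflows;
- Recommended reading: [《大模型应用开发 动手做AI Agent GPT大语言模型应用》](https://www.bilibili.com/opus/935785456083140628?spm_id_from=333.999.0.0)
    - aimed at developers
    - systematic and comprehensive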

## call functions with chat models

- https://cookbook.openai.com/examples/how_to_call_functions_with_chat_models
    - https://github.com/openai/openai-cookbook/blob/main/examples/How_to_call_functions_with_chat_models.ipynb

In [1]:
import os
os.environ['http_proxy'] = 'http://127.0.0.1:7890'
os.environ['https_proxy'] = 'http://127.0.0.1:7890'

In [4]:
!pip install scipy --quiet
!pip install tenacity --quiet
!pip install tiktoken --quiet
!pip install termcolor --quiet
!pip install openai --quiet

In [2]:
import json
import openai
from openai import OpenAI
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored  

In [3]:
openai.__version__

'1.35.13'

In [4]:
from dotenv import load_dotenv
# .env (OPENAI_API_KEY=sk-proj-xxxx)
load_dotenv()

True

In [5]:
GPT_MODEL = "gpt-4o"
client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])

### utilities

In [6]:
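@retry(wait=wait_random_exponential(multiplier=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request(messages, tools=None, tool_choice=None, model=GPT_MODEL):
    """Call the Chat Completions API, retrying up to 3 times with random exponential backoff."""
    try:
        response = client.chat.completions.create(
            model=model,
            messages=messages,
            tools=tools,
            tool_choice=tool_choice,
        )
        return response
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        # Re-raise so that tenacity can retry; returning the exception would
        # skip the retries and break callers that expect a response object.
        raise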
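
A tenacity `@retry` decorator only retries when the wrapped function raises. As a rough plain-Python approximation of what `wait_random_exponential` combined with `stop_after_attempt(3)` amounts to, consider the sketch below. The helper name `retry_with_backoff` and its injectable `sleep` parameter are assumptions made for illustration and are not part of tenacity.

```python
import random
import time


def retry_with_backoff(fn, attempts=3, multiplier=1.0, max_wait=40.0, sleep=time.sleep):
    """Call fn() until it succeeds or `attempts` calls have failed."""
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt == attempts:
                raise  # out of attempts: surface the last error
            # Random wait whose upper bound doubles each attempt, capped at max_wait.
            upper = min(max_wait, multiplier * 2 ** (attempt - 1))
            sleep(random.uniform(0, upper))


calls = {"count": 0}


def flaky():
    """Fails twice, then succeeds (simulates a transient API error)."""
    calls["count"] += 1
    if calls["count"] < 3:
        raise RuntimeError("transient failure")
    return "ok"


# Inject a no-op sleep so the demo finishes instantly.
result = retry_with_backoff(flaky, sleep=lambda seconds: None)
print(result, calls["count"])
```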

### tools

- `def get_current_weather(location: str, format: Literal["celsius", "fahrenheit"])`
- `def get_n_day_weather_forecast(location: str, format: Literal['celsius', 'fahrenheit'], num_days: int)`

In [7]:
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                },
                "required": ["location", "format"],
            },
        }
    },
    {
        "type": "function",
        "function": {
            "name": "get_n_day_weather_forecast",
            "description": "Get an N-day weather forecast",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {
                        "type": "string",
                        "description": "The city and state, e.g. San Francisco, CA",
                    },
                    "format": {
                        "type": "string",
                        "enum": ["celsius", "fahrenheit"],
                        "description": "The temperature unit to use. Infer this from the users location.",
                    },
                    "num_days": {
                        "type": "integer",
                        "description": "The number of days to forecast",
                    }
                },
                "required": ["location", "format", "num_days"]
            },
        }
    },
]
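
The hand-written schemas above mirror the two Python signatures listed earlier. As a hedged sketch of how that mapping could be automated, the helper below derives the `parameters` block from type hints. `function_to_tool` is a hypothetical helper, not an OpenAI SDK function, and it omits the per-parameter `description` fields that the hand-written version includes.

```python
import inspect
from typing import Literal, get_args, get_origin, get_type_hints

# Mapping from Python annotations to JSON Schema type names (only what we need here).
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def function_to_tool(fn, description):
    """Build a Chat Completions style tool entry from a function's type hints."""
    hints = get_type_hints(fn)
    properties, required = {}, []
    for name, param in inspect.signature(fn).parameters.items():
        hint = hints.get(name, str)
        if get_origin(hint) is Literal:
            # Literal["a", "b"] becomes a string enum.
            properties[name] = {"type": "string", "enum": list(get_args(hint))}
        else:
            properties[name] = {"type": _JSON_TYPES.get(hint, "string")}
        if param.default is inspect.Parameter.empty:
            required.append(name)
    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


def get_n_day_weather_forecast(
    location: str, format: Literal["celsius", "fahrenheit"], num_days: int
):
    ...  # body not needed to derive the schema


tool = function_to_tool(get_n_day_weather_forecast, "Get an N-day weather forecast")
print(tool["function"]["parameters"])
```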

### Building messages (`get_current_weather`)

- Not just a list of dicts: `ChatCompletionMessage` objects are also accepted (they carry the same fields as a dict)
    - `{'role': '', 'content': ''}`
    - role: system, user, assistant
        - tool
    - user/assistant: these form question-and-answer pairs;
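
Because `messages` can mix plain dicts with SDK objects, it is sometimes handy to normalize everything to dicts, for example before logging or saving a conversation as JSON. The sketch below assumes the SDK message objects are pydantic-style models exposing `model_dump()`. `FakeAssistantMessage` is a hypothetical stand-in so that the example runs without the `openai` package.

```python
import json


def to_plain_dict(message):
    """Return a plain dict for either a dict or a pydantic-style message object."""
    if isinstance(message, dict):
        return message
    # Assumption: the object exposes model_dump(), as pydantic models do.
    return message.model_dump(exclude_none=True)


class FakeAssistantMessage:
    """Hypothetical stand-in for ChatCompletionMessage (offline demo only)."""

    def __init__(self, content):
        self.role = "assistant"
        self.content = content
        self.tool_calls = None

    def model_dump(self, exclude_none=False):
        data = {"role": self.role, "content": self.content, "tool_calls": self.tool_calls}
        if exclude_none:
            data = {k: v for k, v in data.items() if v is not None}
        return data


messages = [
    {"role": "user", "content": "What's the weather like today"},
    FakeAssistantMessage("Which city are you in?"),
]
plain = [to_plain_dict(m) for m in messages]
print(json.dumps(plain, indent=2))
```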

In [8]:
messages = []
messages.append({"role": "system", 
                 "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."})
messages.append({"role": "user", 
                 "content": "What's the weather like today"})
messages

[{'role': 'system',
  'content': "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
 {'role': 'user', 'content': "What's the weather like today"}]

In [9]:
chat_response = chat_completion_request(
    messages, tools=tools
)
chat_response

ChatCompletion(id='chatcmpl-9miuqM8DcABewg8LJTM9p7F6FRGWk', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content='Could you please provide me with the city and state (or country) for which you would like to know the current weather?', role='assistant', function_call=None, tool_calls=None))], created=1721399428, model='gpt-4o-2024-05-13', object='chat.completion', service_tier=None, system_fingerprint='fp_5e997b69d8', usage=CompletionUsage(completion_tokens=26, prompt_tokens=182, total_tokens=208))

In [11]:
chat_response.choices[0].message, type(chat_response.choices[0].message)

(ChatCompletionMessage(content='Could you please provide me with the city and state (or country) for which you would like to know the current weather?', role='assistant', function_call=None, tool_calls=None),
 openai.types.chat.chat_completion_message.ChatCompletionMessage)

In [12]:
assistant_message = chat_response.choices[0].message
messages.append(assistant_message)
assistant_message

ChatCompletionMessage(content='Could you please provide me with the city and state (or country) for which you would like to know the current weather?', role='assistant', function_call=None, tool_calls=None)

In [14]:
messages

[{'role': 'system',
  'content': "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
 {'role': 'user', 'content': "What's the weather like today"},
 ChatCompletionMessage(content='Could you please provide me with the city and state (or country) for which you would like to know the current weather?', role='assistant', function_call=None, tool_calls=None)]

In [15]:
messages.append({"role": "user", "content": "I'm in Glasgow, Scotland."})
chat_response = chat_completion_request(
    messages, tools=tools
)
assistant_message = chat_response.choices[0].message
messages.append(assistant_message)
assistant_message

ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_ZugwC3KTxMeYjWQ9SZj0rSk6', function=Function(arguments='{"location":"Glasgow, Scotland","format":"celsius"}', name='get_current_weather'), type='function')])

In [17]:
messages

[{'role': 'system',
  'content': "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
 {'role': 'user', 'content': "What's the weather like today"},
 ChatCompletionMessage(content='Could you please provide me with the city and state (or country) for which you would like to know the current weather?', role='assistant', function_call=None, tool_calls=None),
 {'role': 'user', 'content': "I'm in Glasgow, Scotland."},
 ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_ZugwC3KTxMeYjWQ9SZj0rSk6', function=Function(arguments='{"location":"Glasgow, Scotland","format":"celsius"}', name='get_current_weather'), type='function')])]

In [18]:
messages = messages[:-2]
messages

[{'role': 'system',
  'content': "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."},
 {'role': 'user', 'content': "What's the weather like today"},
 ChatCompletionMessage(content='Could you please provide me with the city and state (or country) for which you would like to know the current weather?', role='assistant', function_call=None, tool_calls=None)]

In [19]:
messages.append({"role": "user", "content": "I'm in Beijing, China."})
chat_response = chat_completion_request(
    messages, tools=tools
)
assistant_message = chat_response.choices[0].message
messages.append(assistant_message)
assistant_message

ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_9onbR17EcwDBhlC0gUF3wQl0', function=Function(arguments='{"location":"Beijing, China","format":"celsius"}', name='get_current_weather'), type='function')])

### Building messages (`get_n_day_weather_forecast`)

In [21]:
messages = []
messages.append({"role": "system", 
                 "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."})
messages.append({"role": "user", 
                 "content": "what is the weather going to be like in Beijing, China over the next x days"})
chat_response = chat_completion_request(
    messages, tools=tools
)
assistant_message = chat_response.choices[0].message
messages.append(assistant_message)
assistant_message

ChatCompletionMessage(content='Could you please specify the number of days you would like the forecast for?', role='assistant', function_call=None, tool_calls=None)

In [22]:
messages.append({"role": "user", "content": "5 days"})
chat_response = chat_completion_request(
    messages, tools=tools
)
chat_response.choices[0]

Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_uVCEvzGM0IVFMkhWoBdBdNoL', function=Function(arguments='{"location":"Beijing, China","format":"celsius","num_days":5}', name='get_n_day_weather_forecast'), type='function')]))
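
The `arguments` field is a JSON string generated by the model, so it is worth checking before the real function is executed. Below is a minimal validation sketch. `parse_arguments` and `FORECAST_PARAMETERS` are hypothetical names; the schema is a trimmed copy of the `parameters` block defined in `tools`, and only the checks shown (required keys, basic types, enums) are performed.

```python
import json

# Trimmed copy of the get_n_day_weather_forecast parameter schema (descriptions omitted).
FORECAST_PARAMETERS = {
    "type": "object",
    "properties": {
        "location": {"type": "string"},
        "format": {"type": "string", "enum": ["celsius", "fahrenheit"]},
        "num_days": {"type": "integer"},
    },
    "required": ["location", "format", "num_days"],
}

_PYTHON_TYPES = {"string": str, "integer": int}


def parse_arguments(arguments_json, parameters):
    """Decode model-generated arguments and check them against the schema."""
    args = json.loads(arguments_json)
    for name in parameters["required"]:
        if name not in args:
            raise ValueError(f"missing required argument: {name}")
    for name, value in args.items():
        spec = parameters["properties"].get(name)
        if spec is None:
            raise ValueError(f"unexpected argument: {name}")
        if not isinstance(value, _PYTHON_TYPES[spec["type"]]):
            raise ValueError(f"{name} should be of type {spec['type']}")
        if "enum" in spec and value not in spec["enum"]:
            raise ValueError(f"{name} must be one of {spec['enum']}")
    return args


args = parse_arguments(
    '{"location":"Beijing, China","format":"celsius","num_days":5}',
    FORECAST_PARAMETERS,
)
print(args)
```

A full JSON Schema validator would be more thorough; this sketch only covers the cases that appear in this notebook.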

### tool_choice

- `tool_choice=None`
- `tool_choice={"type": "function", "function": {"name": "get_n_day_weather_forecast"}}`
- Newer models such as gpt-4o or gpt-3.5-turbo can call multiple functions in one turn.
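
As a small sketch of the forced form of `tool_choice`, the helper below builds the dict shown above and checks that the named function really is in `tools`. `force_tool` is a hypothetical helper. As far as I know, the API also accepts the string modes `"auto"`, `"none"`, and (in recent versions) `"required"`, but treat that list as an assumption and check it against the current API reference.

```python
def force_tool(tools, name):
    """Return a tool_choice value that forces the model to call `name`."""
    available = {tool["function"]["name"] for tool in tools}
    if name not in available:
        raise ValueError(f"unknown tool {name!r}; available: {sorted(available)}")
    return {"type": "function", "function": {"name": name}}


# Only the names matter for this check, so the schemas are abbreviated here.
demo_tools = [
    {"type": "function", "function": {"name": "get_current_weather"}},
    {"type": "function", "function": {"name": "get_n_day_weather_forecast"}},
]

choice = force_tool(demo_tools, "get_n_day_weather_forecast")
print(choice)
```

The resulting value would be passed as `tool_choice=choice` in `chat_completion_request`.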

In [23]:
messages = []
messages.append({"role": "system", "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."})
messages.append({"role": "user", "content": "Give me the current weather (use Celsius) for Toronto, Canada."})
chat_response = chat_completion_request(
    messages, tools=tools, tool_choice="none"
)
chat_response.choices[0].message

ChatCompletionMessage(content='Sure, I will fetch the current weather for Toronto, Canada in Celsius.', role='assistant', function_call=None, tool_calls=None)

In [24]:
# if we don't force the model to use get_n_day_weather_forecast it may not
messages = []
messages.append({"role": "system", "content": "Don't make assumptions about what values to plug into functions. Ask for clarification if a user request is ambiguous."})
messages.append({"role": "user", "content": "Give me a weather report for Toronto, Canada."})
chat_response = chat_completion_request(
    messages, tools=tools
)
chat_response.choices[0].message

ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_Ija511nv3AbyKF7FdSqfOCjh', function=Function(arguments='{"location": "Toronto, Canada", "format": "celsius"}', name='get_current_weather'), type='function'), ChatCompletionMessageToolCall(id='call_x6XAima2QsE1FghBPTx3UnMU', function=Function(arguments='{"location": "Toronto, Canada", "format": "celsius", "num_days": 3}', name='get_n_day_weather_forecast'), type='function')])

In [25]:
chat_response

ChatCompletion(id='chatcmpl-9mjBGPqxkqPlTRVZe39IHGlMBLRXD', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_Ija511nv3AbyKF7FdSqfOCjh', function=Function(arguments='{"location": "Toronto, Canada", "format": "celsius"}', name='get_current_weather'), type='function'), ChatCompletionMessageToolCall(id='call_x6XAima2QsE1FghBPTx3UnMU', function=Function(arguments='{"location": "Toronto, Canada", "format": "celsius", "num_days": 3}', name='get_n_day_weather_forecast'), type='function')]))], created=1721400446, model='gpt-4o-2024-05-13', object='chat.completion', service_tier=None, system_fingerprint='fp_c4e5b6fa31', usage=CompletionUsage(completion_tokens=68, prompt_tokens=187, total_tokens=255))
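
When one assistant message carries several tool calls, as in the output above, each call needs its own `tool` message with a matching `tool_call_id`. The sketch below shows that bookkeeping offline. The two weather functions are stub implementations, the ids are made up, and the tool calls are plain dicts standing in for the SDK objects (which use attribute access instead).

```python
import json


def get_current_weather(location, format):
    """Stub: a real implementation would call a weather service."""
    return f"Current weather in {location} ({format}): placeholder"


def get_n_day_weather_forecast(location, format, num_days):
    """Stub: a real implementation would call a weather service."""
    return f"{num_days}-day forecast for {location} ({format}): placeholder"


AVAILABLE_FUNCTIONS = {
    "get_current_weather": get_current_weather,
    "get_n_day_weather_forecast": get_n_day_weather_forecast,
}


def run_tool_calls(tool_calls):
    """Execute every requested call and return one tool message per call."""
    tool_messages = []
    for call in tool_calls:
        name = call["function"]["name"]
        fn = AVAILABLE_FUNCTIONS.get(name)
        if fn is None:
            content = f"Error: function {name} does not exist"
        else:
            content = fn(**json.loads(call["function"]["arguments"]))
        tool_messages.append(
            {"role": "tool", "tool_call_id": call["id"], "content": content}
        )
    return tool_messages


# Dict stand-ins for the two parallel tool calls (ids are hypothetical).
demo_calls = [
    {"id": "call_a", "function": {"name": "get_current_weather",
     "arguments": '{"location": "Toronto, Canada", "format": "celsius"}'}},
    {"id": "call_b", "function": {"name": "get_n_day_weather_forecast",
     "arguments": '{"location": "Toronto, Canada", "format": "celsius", "num_days": 3}'}},
]

tool_messages = run_tool_calls(demo_calls)
for message in tool_messages:
    print(message)
```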

## argument generation (sql)

- https://www.sqlitetutorial.net/sqlite-sample-database/
    - https://www.sqlitetutorial.net/wp-content/uploads/2018/03/chinook.zip

Steps to invoke a function call using the Chat Completions API:

Step 1: Prompt the model with content that may result in the model selecting a tool to use. The tool descriptions, such as function names and signatures, are defined in the `tools` list and passed to the model in the API call. If a tool is selected, the function name and parameters are included in the response.

Step 2: Check programmatically whether the model wants to call a function. If so, proceed to step 3.

Step 3: Extract the function name and parameters from the response, then call the function with those parameters. **Append the result to messages.**

Step 4: Invoke the Chat Completions API with the updated message list to get the final response.

In [31]:
import sqlite3

conn = sqlite3.connect("data/chinook.db")
print("Opened database successfully")

Opened database successfully


In [32]:
def get_table_names(conn):
    """Return a list of table names."""
    table_names = []
    tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
    for table in tables.fetchall():
        table_names.append(table[0])
    return table_names


def get_column_names(conn, table_name):
    """Return a list of column names."""
    column_names = []
    columns = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
    for col in columns:
        column_names.append(col[1])
    return column_names


def get_database_info(conn):
    """Return a list of dicts containing the table name and columns for each table in the database."""
    table_dicts = []
    for table_name in get_table_names(conn):
        columns_names = get_column_names(conn, table_name)
        table_dicts.append({"table_name": table_name, "column_names": columns_names})
    return table_dicts

In [33]:
database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(
    [
        f"Table: {table['table_name']}\nColumns: {', '.join(table['column_names'])}"
        for table in database_schema_dict
    ]
)
database_schema_dict

[{'table_name': 'albums', 'column_names': ['AlbumId', 'Title', 'ArtistId']},
 {'table_name': 'sqlite_sequence', 'column_names': ['name', 'seq']},
 {'table_name': 'artists', 'column_names': ['ArtistId', 'Name']},
 {'table_name': 'customers',
  'column_names': ['CustomerId',
   'FirstName',
   'LastName',
   'Company',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email',
   'SupportRepId']},
 {'table_name': 'employees',
  'column_names': ['EmployeeId',
   'LastName',
   'FirstName',
   'Title',
   'ReportsTo',
   'BirthDate',
   'HireDate',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email']},
 {'table_name': 'genres', 'column_names': ['GenreId', 'Name']},
 {'table_name': 'invoices',
  'column_names': ['InvoiceId',
   'CustomerId',
   'InvoiceDate',
   'BillingAddress',
   'BillingCity',
   'BillingState',
   'BillingCountry',
   'BillingPostalCode',
   'Total']},
 {'table_name': 'invoice_i

In [34]:
tools = [
    {
        "type": "function",
        "function": {
            "name": "ask_database",
            "description": "Use this function to answer user questions about music. Input should be a fully formed SQL query.",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": f"""
                                SQL query extracting info to answer the user's question.
                                SQL should be written using this database schema:
                                {database_schema_string}
                                The query should be returned in plain text, not in JSON.
                                """,
                    }
                },
                "required": ["query"],
            },
        }
    }
]

In [35]:
def ask_database(conn, query):
    """Function to query SQLite database with a provided SQL query."""
    try:
        results = str(conn.execute(query).fetchall())
    except Exception as e:
        results = f"query failed with error: {e}"
    return results
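
Since the SQL here is written by the model, one cautious option is to make the connection reject writes before running it. The sketch below uses SQLite's `PRAGMA query_only` on an in-memory database. `ask_database_readonly` and the tiny `albums` table are assumptions for illustration, not part of the original notebook.

```python
import sqlite3


def ask_database_readonly(conn, query):
    """Run a model-generated query on a connection that refuses data changes."""
    conn.execute("PRAGMA query_only = ON;")
    try:
        return str(conn.execute(query).fetchall())
    except Exception as e:
        return f"query failed with error: {e}"


# Tiny in-memory stand-in for chinook.db so the demo is self-contained.
demo_conn = sqlite3.connect(":memory:")
demo_conn.execute("CREATE TABLE albums (AlbumId INTEGER PRIMARY KEY, Title TEXT)")
demo_conn.execute("INSERT INTO albums (Title) VALUES ('Greatest Hits')")
demo_conn.commit()

read_result = ask_database_readonly(demo_conn, "SELECT Title FROM albums")
write_result = ask_database_readonly(demo_conn, "DELETE FROM albums")
remaining = demo_conn.execute("SELECT COUNT(*) FROM albums").fetchone()[0]

print(read_result)
print(write_result)
print(remaining)
```

For a database file, opening the connection with a read-only URI (`sqlite3.connect("file:data/chinook.db?mode=ro", uri=True)`) is another way to get a similar guarantee.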

In [36]:
# Step #1: Prompt with content that may result in function call. In this case the model can identify the information requested by the user is potentially available in the database schema passed to the model in Tools description. 
messages = [{
    "role": "user", 
    "content": "What is the name of the album with the most tracks?"
}]

response = client.chat.completions.create(
    model='gpt-4o', 
    messages=messages, 
    tools=tools, 
    tool_choice="auto"
)

# Append the message to messages list
response_message = response.choices[0].message 
messages.append(response_message)
response_message

ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_fr3IUDJgvMxonRTweFWNPlH7', function=Function(arguments='{"query":"SELECT albums.Title, COUNT(tracks.TrackId) AS TrackCount FROM albums JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY albums.AlbumId ORDER BY TrackCount DESC LIMIT 1;"}', name='ask_database'), type='function')])

In [42]:
response_message.tool_calls[0].function.name, json.loads(response_message.tool_calls[0].function.arguments)['query']

('ask_database',
 'SELECT albums.Title, COUNT(tracks.TrackId) AS TrackCount FROM albums JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY albums.AlbumId ORDER BY TrackCount DESC LIMIT 1;')

In [43]:
# Step 2: determine if the response from the model includes a tool call.   
tool_calls = response_message.tool_calls
if tool_calls:
    # If true the model will return the name of the tool / function to call and the argument(s)  
    tool_call_id = tool_calls[0].id
    tool_function_name = tool_calls[0].function.name
    tool_query_string = json.loads(tool_calls[0].function.arguments)['query']

    # Step 3: Call the function and retrieve results. Append the results to the messages list.      
    if tool_function_name == 'ask_database':
        results = ask_database(conn, tool_query_string)
        
        messages.append({
            "role":"tool", 
            "tool_call_id":tool_call_id, 
            "name": tool_function_name, 
            "content":results
        })
        
        # Step 4: Invoke the chat completions API with the function response appended to the messages list
        # Note that messages with role 'tool' must be a response to a preceding message with 'tool_calls'
        model_response_with_function_call = client.chat.completions.create(
            model="gpt-4o",
            messages=messages,
        )  # get a new response from the model where it can see the function response
        print(model_response_with_function_call.choices[0].message.content)
    else: 
        print(f"Error: function {tool_function_name} does not exist")
else: 
    # Model did not identify a function to call, result can be returned to the user 
    print(response_message.content) 

The album with the most tracks is "Greatest Hits," which features a total of 57 tracks.


In [44]:
messages

[{'role': 'user',
  'content': 'What is the name of the album with the most tracks?'},
 ChatCompletionMessage(content=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_fr3IUDJgvMxonRTweFWNPlH7', function=Function(arguments='{"query":"SELECT albums.Title, COUNT(tracks.TrackId) AS TrackCount FROM albums JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY albums.AlbumId ORDER BY TrackCount DESC LIMIT 1;"}', name='ask_database'), type='function')]),
 {'role': 'tool',
  'tool_call_id': 'call_fr3IUDJgvMxonRTweFWNPlH7',
  'name': 'ask_database',
  'content': "[('Greatest Hits', 57)]"}]

## RAG & function calling (functions with a knowledge base)

- https://cookbook.openai.com/examples/how_to_call_functions_for_knowledge_retrieval

In [46]:
# !pip install scipy --quiet
# !pip install tenacity --quiet
# !pip install tiktoken==0.3.3 --quiet
# !pip install termcolor --quiet
# !pip install openai --quiet
# !pip install arxiv --quiet
# !pip install pandas --quiet
# !pip install PyPDF2 --quiet
# !pip install tqdm --quiet

In [55]:
import os
import arxiv
import ast
import concurrent.futures  # import the submodule explicitly; `import concurrent` alone does not load it
import json
import pandas as pd
import tiktoken
from csv import writer
from IPython.display import display, Markdown, Latex
from openai import OpenAI
from PyPDF2 import PdfReader
from scipy import spatial
from tenacity import retry, wait_random_exponential, stop_after_attempt
from tqdm import tqdm
from termcolor import colored

GPT_MODEL = "gpt-4o"
EMBEDDING_MODEL = "text-embedding-ada-002"
client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])

In [48]:
directory = './data/papers'

# Check if the directory already exists
if not os.path.exists(directory):
    # If the directory doesn't exist, create it and any necessary intermediate directories
    os.makedirs(directory)
    print(f"Directory '{directory}' created successfully.")
else:
    # If the directory already exists, print a message indicating it
    print(f"Directory '{directory}' already exists.")

Directory './data/papers' created successfully.


In [49]:
# Set a directory to store downloaded papers
data_dir = os.path.join(os.curdir, "data", "papers")
paper_dir_filepath = "./data/arxiv_library.csv"

# Generate a blank dataframe where we can store downloaded files
df = pd.DataFrame(list())
df.to_csv(paper_dir_filepath)

In [50]:
@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
def embedding_request(text):
    response = client.embeddings.create(input=text, model=EMBEDDING_MODEL)
    return response


@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
def get_articles(query, library=paper_dir_filepath, top_k=5):
    """This function gets the top_k articles based on a user's query, sorted by relevance.
    It also downloads the files and stores them in arxiv_library.csv to be retrieved by the read_article_and_summarize.
    """
    client = arxiv.Client()
    search = arxiv.Search(
        query=query,
        max_results=top_k,
        sort_by = arxiv.SortCriterion.SubmittedDate
    )
    result_list = []
    for result in client.results(search):
        result_dict = {}
        result_dict.update({"title": result.title})
        result_dict.update({"summary": result.summary})

        # Taking the first url provided
        result_dict.update({"article_url": [x.href for x in result.links][0]})
        result_dict.update({"pdf_url": [x.href for x in result.links][1]})
        result_list.append(result_dict)

        # Store references in library file
        response = embedding_request(text=result.title)
        file_reference = [
            result.title,
            result.download_pdf(data_dir),
            response.data[0].embedding,
        ]

        # Write to file
        with open(library, "a", newline="") as f_object:
            # The with-block closes the file, so no explicit close() is needed
            writer_object = writer(f_object)
            writer_object.writerow(file_reference)
    return result_list


In [51]:
# Test that the search is working
result_output = get_articles("ppo rlhf")
result_output

[{'title': 'Understanding Reinforcement Learning-Based Fine-Tuning of Diffusion Models: A Tutorial and Review',
  'summary': 'This tutorial provides a comprehensive survey of methods for fine-tuning\ndiffusion models to optimize downstream reward functions. While diffusion\nmodels are widely known to provide excellent generative modeling capability,\npractical applications in domains such as biology require generating samples\nthat maximize some desired metric (e.g., translation efficiency in RNA, docking\nscore in molecules, stability in protein). In these cases, the diffusion model\ncan be optimized not only to generate realistic samples but also to explicitly\nmaximize the measure of interest. Such methods are based on concepts from\nreinforcement learning (RL). We explain the application of various RL\nalgorithms, including PPO, differentiable optimization, reward-weighted MLE,\nvalue-weighted sampling, and path consistency learning, tailored specifically\nfor fine-tuning diffusion

In [52]:
def strings_ranked_by_relatedness(
    query: str,
    df: pd.DataFrame,
    relatedness_fn=lambda x, y: 1 - spatial.distance.cosine(x, y),
    top_n: int = 100,
) -> list[str]:
    """Returns a list of strings and relatednesses, sorted from most related to least."""
    query_embedding_response = embedding_request(query)
    query_embedding = query_embedding_response.data[0].embedding
    strings_and_relatednesses = [
        (row["filepath"], relatedness_fn(query_embedding, row["embedding"]))
        for i, row in df.iterrows()
    ]
    strings_and_relatednesses.sort(key=lambda x: x[1], reverse=True)
    strings, relatednesses = zip(*strings_and_relatednesses)
    return strings[:top_n]
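
The default `relatedness_fn` above is cosine similarity, since `spatial.distance.cosine` returns one minus it. To show the ranking step without scipy or an embeddings API, here is a pure-Python sketch on made-up 3-dimensional vectors. The file names and vectors are invented for illustration, and `rank_by_relatedness` is a hypothetical helper.

```python
import math


def cosine_relatedness(x, y):
    """Cosine similarity: dot(x, y) / (|x| * |y|)."""
    dot = sum(a * b for a, b in zip(x, y))
    norm = math.sqrt(sum(a * a for a in x)) * math.sqrt(sum(b * b for b in y))
    return dot / norm


def rank_by_relatedness(query_embedding, library, top_n=1):
    """Return the top_n keys of `library`, most related first."""
    scored = sorted(
        library.items(),
        key=lambda item: cosine_relatedness(query_embedding, item[1]),
        reverse=True,
    )
    return [path for path, _ in scored[:top_n]]


# Toy "embeddings"; real ones would come from the embeddings endpoint.
toy_library = {
    "ppo.pdf": [0.9, 0.1, 0.0],
    "diffusion.pdf": [0.1, 0.9, 0.2],
    "sql.pdf": [0.0, 0.2, 0.9],
}

ranked = rank_by_relatedness([1.0, 0.0, 0.0], toy_library, top_n=2)
print(ranked)
```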

In [58]:
def read_pdf(filepath):
    """Takes a filepath to a PDF and returns a string of the PDF's contents"""
    # creating a pdf reader object
    reader = PdfReader(filepath)
    pdf_text = ""
    page_number = 0
    for page in reader.pages:
        page_number += 1
        pdf_text += page.extract_text() + f"\nPage Number: {page_number}"
    return pdf_text


# Split a text into smaller chunks of size n, preferably ending at the end of a sentence
def create_chunks(text, n, tokenizer):
    """Returns successive n-sized chunks from provided text."""
    tokens = tokenizer.encode(text)
    i = 0
    while i < len(tokens):
        # Find the nearest end of sentence within a range of 0.5 * n and 1.5 * n tokens
        j = min(i + int(1.5 * n), len(tokens))
        while j > i + int(0.5 * n):
            # Decode the tokens and check for full stop or newline
            chunk = tokenizer.decode(tokens[i:j])
            if chunk.endswith(".") or chunk.endswith("\n"):
                break
            j -= 1
        # If no end of sentence found, use n tokens as the chunk size
        if j == i + int(0.5 * n):
            j = min(i + n, len(tokens))
        yield tokens[i:j]
        i = j


def extract_chunk(content, template_prompt):
    """This function applies a prompt to some input content. In this case it returns a summarized chunk of text"""
    prompt = template_prompt + content
    response = client.chat.completions.create(
        model=GPT_MODEL, messages=[{"role": "user", "content": prompt}], temperature=0
    )
    return response.choices[0].message.content


def summarize_text(query):
    """This function does the following:
    - Reads in the arxiv_library.csv file, including the embeddings
    - Finds the closest file to the user's query
    - Scrapes the text out of the file and chunks it
    - Summarizes each chunk in parallel
    - Does one final summary and returns this to the user"""

    # A prompt to dictate how the recursive summarizations should approach the input paper
    summary_prompt = """Summarize this text from an academic paper. Extract any key points with reasoning.\n\nContent:"""

    # If the library is empty (no searches have been performed yet), we perform one and download the results
    library_df = pd.read_csv(paper_dir_filepath).reset_index()
    if len(library_df) == 0:
        print("No papers searched yet, downloading first.")
        get_articles(query)
        print("Papers downloaded, continuing")
        library_df = pd.read_csv(paper_dir_filepath).reset_index()
    library_df.columns = ["title", "filepath", "embedding"]
    library_df["embedding"] = library_df["embedding"].apply(ast.literal_eval)
    strings = strings_ranked_by_relatedness(query, library_df, top_n=1)
    print(f"Chunking text from paper: {strings[0]}")
    pdf_text = read_pdf(strings[0])

    # Initialise tokenizer
    tokenizer = tiktoken.get_encoding("cl100k_base")
    results = ""

    # Chunk up the document into 1500 token chunks
    chunks = create_chunks(pdf_text, 1500, tokenizer)
    text_chunks = [tokenizer.decode(chunk) for chunk in chunks]
    print("Summarizing each chunk of text")

    # Parallel process the summaries
    with concurrent.futures.ThreadPoolExecutor(
        max_workers=len(text_chunks)
    ) as executor:
        futures = [
            executor.submit(extract_chunk, chunk, summary_prompt)
            for chunk in text_chunks
        ]
        with tqdm(total=len(text_chunks)) as pbar:
            for _ in concurrent.futures.as_completed(futures):
                pbar.update(1)
        for future in futures:
            data = future.result()
            results += data

    # Final summary
    print("Summarizing into overall summary")
    response = client.chat.completions.create(
        model=GPT_MODEL,
        messages=[
            {
                "role": "user",
                "content": f"""Write a summary collated from this collection of key points extracted from an academic paper.
                        The summary should highlight the core argument, conclusions and evidence, and answer the user's query.
                        User query: {query}
                        The summary should be structured in bulleted lists following the headings Core Argument, Evidence, and Conclusions.
                        Key points:\n{results}\nSummary:\n""",
            }
        ],
        temperature=0,
    )
    return response
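
To see how `create_chunks` picks sentence boundaries without needing tiktoken, the sketch below repeats the function as defined above and drives it with a toy tokenizer. `WordTokenizer` is a hypothetical stand-in that treats each whitespace-prefixed word as one token; real token counts from `cl100k_base` would differ.

```python
import re


class WordTokenizer:
    """Hypothetical stand-in for a tiktoken encoding: one token per word."""

    def encode(self, text):
        return re.findall(r"\s*\S+", text)

    def decode(self, tokens):
        return "".join(tokens)


def create_chunks(text, n, tokenizer):
    """Returns successive n-sized chunks from provided text."""
    tokens = tokenizer.encode(text)
    i = 0
    while i < len(tokens):
        # Find the nearest end of sentence within a range of 0.5 * n and 1.5 * n tokens
        j = min(i + int(1.5 * n), len(tokens))
        while j > i + int(0.5 * n):
            chunk = tokenizer.decode(tokens[i:j])
            if chunk.endswith(".") or chunk.endswith("\n"):
                break
            j -= 1
        # If no end of sentence found, use n tokens as the chunk size
        if j == i + int(0.5 * n):
            j = min(i + n, len(tokens))
        yield tokens[i:j]
        i = j


tokenizer = WordTokenizer()
text = "One two three. Four five six seven. Eight nine ten eleven twelve."
chunks = [tokenizer.decode(c) for c in create_chunks(text, 4, tokenizer)]
for chunk in chunks:
    print(repr(chunk))
```

With a target size of 4 tokens, each chunk is allowed to shrink or grow (between 2 and 6 tokens) so that it ends on a full stop.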


In [59]:
chat_test_response = summarize_text("PPO reinforcement learning sequence generation")

Chunking text from paper: ./data/papers/2407.13734v1.Understanding_Reinforcement_Learning_Based_Fine_Tuning_of_Diffusion_Models__A_Tutorial_and_Review.pdf
Summarizing each chunk of text


100%|██████████| 16/16 [00:24<00:00,  1.54s/it]


Summarizing into overall summary
