# Source Code

> This is the ‘literate’ source code for ***Scholaris***, a library to set up an advanced research assistant leveraging function calling capabilities with LLMs on your local computer. ***Scholaris*** was built to run with [Llama 3.1 8B](https://ollama.com/library/llama3.1) using the [ollama](https://ollama.com/) framework.

In [None]:
#| default_exp core

#| hide

This notebook was written using [nbdev](https://nbdev.fast.ai/) to show how the source code was created and to facilitate testing, continuous integration, and documentation, all in a single context. Do not modify special comments such as `#| default_exp`, `#| hide`, or `#| export`. These comments are nbdev directives: they define the module name and its content, and they control which code or markdown cells are rendered in the documentation.

#| hide

## Installation

First, download and install Ollama on your computer. Go to [Download Ollama](https://ollama.com/download) and follow the instructions for your operating system. Then pull and run [llama3.1](https://ollama.com/library/llama3.1) (parameters: 8B, quantization: Q4_0, size: 4.7 GB) according to the ollama documentation.

In [None]:
#| hide
import os
import sys

#| hide
## Using the Ollama API

The code in this section is an application of the Ollama Python library. It is written to demonstrate the basic low-level functionality of the Ollama API, and should help the reader understand the basics underpinning the Scholaris library. Make sure to also check out the [blog post on tool support](https://ollama.com/blog/tool-support) on the official Ollama website. If you are already familiar with the Ollama API, you can skip this section.


#| hide
### Check if the ollama app is running and list all available models

In [None]:
#| hide
import ollama
!ollama list

#| hide

Alternatively, use the Ollama Python library to get the list and size of available models.

In [None]:
#| hide
errors = []
try:
    installed_models = [model for model in ollama.list()["models"]]
    for model in installed_models:
        size = model["size"] / (1024 ** 3)
        # print(f"{model['name']} \t{size:.2f} GB")
    server_is_available = True
except Exception as e:
    print(f"{e}. Is the ollama app running?")
    server_is_available = False
    errors.append(e)

#| hide
### Using the basic chat functionality

In [None]:
#| hide
# with stream=True, the response is streamed in chunks
sys_message = """You are a helpful AI assistant. Respond with the shortest possible answers while maintaining clarity and accuracy. 
Avoid unnecessary details, explanations, or pleasantries. Be direct and to the point in all responses.
""" # For testing purposes & CI, we are looking for short responses.

prompt = "According to 'The Hitchhiker's Guide to the Galaxy' by Douglas Adams, what is the meaning of life?"
response = None
try:
    response = ollama.chat(
        model='llama3.1', 
        messages=[
            {'role': "system", 'content': sys_message},
            {'role': 'user',
            'content': prompt} # The prompt contains no placeholders, so it is passed as-is
            ],
        stream=True, # Set to True to stream the response
        )  
    for chunk in response:
        print(chunk['message']['content'], end='', flush=True)
except Exception as e:
    print(e)
    errors.append(e)

42.

In [None]:
#| hide
type(response)

generator
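
When streaming, each chunk carries only a fragment of the reply, so the caller has to assemble the full text. The sketch below does this for any iterable of chunks shaped like the ones printed above. `collect_stream` and `fake_stream` are hypothetical names used for illustration, and the fake generator stands in for `ollama.chat(..., stream=True)` so that the example runs without a server.

```python
from typing import Dict, Iterable, Iterator


def collect_stream(chunks: Iterable[Dict]) -> str:
    """Print streamed chunks as they arrive and return the concatenated text."""
    parts = []
    for chunk in chunks:
        piece = chunk["message"]["content"]
        print(piece, end="", flush=True)
        parts.append(piece)
    print()  # Finish the line once the stream is exhausted
    return "".join(parts)


def fake_stream() -> Iterator[Dict]:
    """Stand-in for a streamed chat response (assumption: same chunk layout)."""
    for piece in ["4", "2", "."]:
        yield {"message": {"role": "assistant", "content": piece}}


full_text = collect_stream(fake_stream())
```

Keeping the assembled string is useful when the reply must be stored or passed on after it has been displayed.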

In [None]:
#| hide
# with stream=False, the response is returned as a single dictionary object
try:
    response = ollama.chat(
        model='llama3.1', 
        messages=[
            {'role': "system", 'content': sys_message},
            {'role': 'user',
            'content': prompt} # The prompt contains no placeholders, so it is passed as-is
            ],
        stream=False, # Set to True to stream the response
        )  
    print(response['message']['content'])
except Exception as e:
    print(e)
    errors.append(e)

42.


In [None]:
#| hide
type(response)


dict

In [None]:
#| hide
response

{'model': 'llama3.1',
 'created_at': '2024-09-11T19:43:53.157387Z',
 'message': {'role': 'assistant', 'content': '42.'},
 'done_reason': 'stop',
 'done': True,
 'total_duration': 145121625,
 'load_duration': 44232167,
 'prompt_eval_count': 80,
 'prompt_eval_duration': 33679000,
 'eval_count': 3,
 'eval_duration': 65582000}
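
The duration fields in this response are reported in nanoseconds, so generation speed can be derived from `eval_count` and `eval_duration`. The following is a minimal sketch; `tokens_per_second` is a hypothetical helper, and the values are copied from the response shown above.

```python
def tokens_per_second(response: dict) -> float:
    """Return generated tokens per second from a chat response dictionary."""
    # eval_duration is given in nanoseconds, hence the factor of 1e9
    return response["eval_count"] / response["eval_duration"] * 1e9


# Values copied from the response shown above
example_response = {
    "total_duration": 145121625,
    "load_duration": 44232167,
    "prompt_eval_count": 80,
    "prompt_eval_duration": 33679000,
    "eval_count": 3,
    "eval_duration": 65582000,
}

speed = tokens_per_second(example_response)
total_ms = example_response["total_duration"] / 1e6  # nanoseconds -> milliseconds
print(f"{speed:.1f} tokens/s, total {total_ms:.1f} ms")
```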

In [None]:
history = []
def ollama_chat(prompt: str) -> str:
    """Send a single user prompt to the model and return the reply text."""
    response = ollama.chat(
        model='llama3.1', 
        messages=[
            {'role': 'user',
            'content': prompt}
            ],
        stream=False, # Set to True to stream the response
        )
    history.append(response['message']['content']) # Keep a log of the model's replies
    return response['message']['content']  
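
Note that the `history` list above stores only the model's replies and is never sent back to the model, so every call starts a fresh conversation. A multi-turn variant could pass the full message list on each call, as in the sketch below. `chat_with_history` and `echo_backend` are hypothetical names; the `backend` callable abstracts the call to `ollama.chat` so that the example runs without a server.

```python
from typing import Callable, Dict, List

Message = Dict[str, str]


def chat_with_history(
    prompt: str,
    history: List[Message],
    backend: Callable[[List[Message]], Dict],
) -> str:
    """Append the user prompt, query the backend with the full history, and store the reply."""
    history.append({"role": "user", "content": prompt})
    response = backend(history)
    reply = response["message"]["content"]
    history.append({"role": "assistant", "content": reply})
    return reply


def echo_backend(messages: List[Message]) -> Dict:
    """Hypothetical stand-in for the model: reports how many messages it received."""
    return {"message": {"role": "assistant", "content": f"seen {len(messages)} message(s)"}}


conversation: List[Message] = []
first = chat_with_history("Hello", conversation, echo_backend)
second = chat_with_history("And again", conversation, echo_backend)
```

With the real library, the backend could be a small wrapper such as `lambda msgs: ollama.chat(model='llama3.1', messages=msgs)`.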

In [None]:
response = ollama_chat("According to 'The Hitchhiker's Guide to the Galaxy' by Douglas Adams, what is the meaning of life?")

In [None]:
type(response)

str

#| hide
### Demonstrate and test function calling

To get started, we will go through function / tool calling step by step.

#| hide
#### 1. Define a function to call
For demonstration purposes, we will define a simple addition tool.

In [None]:
#| hide
def add_two_integers(a: int, b: int) -> int:
    """
    A function to add two numbers.

    Args:
        a (int): First number.
        b (int): Second number.

    Returns:
        int: Sum of the two numbers.

    Raises:
        TypeError: If inputs are not integers.
    """
    if not isinstance(a, int) or not isinstance(b, int):
        raise TypeError("Both inputs must be integers")
    
    return a + b

#| hide
#### 2. Print the function name, docstring and annotations

In [None]:
#| hide
print("Function name:", add_two_integers.__name__, "\n---")
print("Docstring:", add_two_integers.__doc__,"\n---")
print("Annotations:", add_two_integers.__annotations__,"\n---")

Function name: add_two_integers 
---
Docstring: 
    A function to add two numbers.

    Args:
        a (int): First number.
        b (int): Second number.

    Returns:
        int: Sum of the two numbers.

    Raises:
        TypeError: If inputs are not integers.
     
---
Annotations: {'a': <class 'int'>, 'b': <class 'int'>, 'return': <class 'int'>} 
---


#| hide
#### 3. Define the JSON schema for the function

In [None]:
#| hide
args = {
    "a": {
        "type": "integer",
        "description": "First number"
    },
    "b": {
        "type": "integer",
        "description": "Second number"
    }
}

addition_tool_schema = {
    "type": "function",
    "function": {
        "name": add_two_integers.__name__,
        "description": "Function to add two numbers",
        "parameters": {
            "type": "object",
            "properties": args,
            "required": ["a", "b"],
        }
    }
}
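
Because the model, not the caller, fills in the arguments, it can help to check them against the schema before calling the function. The sketch below covers only the `required` list and the `"integer"` type; `check_arguments` is a hypothetical helper, and the schema is repeated inline so that the block runs on its own.

```python
from typing import Any, Dict, List


def check_arguments(schema: Dict[str, Any], arguments: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in model-supplied arguments (empty if none)."""
    params = schema["function"]["parameters"]
    problems = []
    for name in params.get("required", []):
        if name not in arguments:
            problems.append(f"missing required argument '{name}'")
    for name, value in arguments.items():
        spec = params["properties"].get(name)
        if spec is None:
            problems.append(f"unexpected argument '{name}'")
        # bool is a subclass of int in Python, so it is excluded explicitly
        elif spec["type"] == "integer" and (isinstance(value, bool) or not isinstance(value, int)):
            problems.append(f"argument '{name}' should be an integer")
    return problems


# Same layout as addition_tool_schema above, repeated here for self-containment
schema = {
    "type": "function",
    "function": {
        "name": "add_two_integers",
        "description": "Function to add two numbers",
        "parameters": {
            "type": "object",
            "properties": {
                "a": {"type": "integer", "description": "First number"},
                "b": {"type": "integer", "description": "Second number"},
            },
            "required": ["a", "b"],
        },
    },
}

ok = check_arguments(schema, {"a": 1, "b": 2})
bad = check_arguments(schema, {"a": "1"})
print(ok, bad)
```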

#| hide
#### 4. First API call: send query and JSON schema of the function to the model

We will use random numbers for the example.

In [None]:
#| hide
import random

In [None]:
#| hide
a = random.randint(1, 10000000)
b = random.randint(1, 10000000)
prompt = f"What is the sum of {a} and {b}?"
print(prompt)

sys_message = "You are an assistant provided with tools to help you answer questions."

messages=[
    {'role': "system", 'content': sys_message},
    {'role': 'user',
    'content': prompt} # The prompt was already formatted by the f-string above
]

try:
    response = ollama.chat(
        model="llama3.1",
        messages=messages,
        tools=[addition_tool_schema],
    )
except Exception as e:
    print(e)
    errors.append(e)

What is the sum of 2372167 and 7765676?


#| hide
#### 5. Get the function return value with the arguments from the LLM response

In [None]:
#| hide
try:
    for tool in response['message'].get('tool_calls'):
        print("tool names:", tool['function'].get('name'))
        print("args:", tool['function'].get('arguments'))
        # Extract the arguments proposed by the model and call the function
        a = tool['function'].get('arguments').get('a')
        b = tool['function'].get('arguments').get('b')
        output = f"return: {add_two_integers(a, b)}"
        print(output)
except Exception as e:
    print(f"{e}. If you get an error message, try re-running the cell above.")
    errors.append(e)

tool names: add_two_integers
args: {'a': 2372167, 'b': 7765676}
return: 10137843


#| hide
#### 6. Add the output to the messages list

In [None]:
#| hide
try:
    messages.append(
        {
            'role': 'tool',
            'content': output,
        }
    )
    messages
except Exception as e:
    print(f"Error: {e}. Try fixing the error by re-running the two cells above.")
    errors.append(e)


#| hide
#### 7. Second API call: get the final response from the model

In [None]:
#| hide
try:
    response = ollama.chat(
        model="llama3.1",
        messages=messages,
        stream=True,
    )
    for chunk in response:
        print(chunk['message']['content'], end='', flush=True)
except Exception as e:
    print(e)
    errors.append(e)

The sum of 2372167 and 7765676 is 10137843.
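
Steps 5 and 6 can be folded into a single helper that looks up each requested tool by name and returns the `tool` messages for the second API call. This is a sketch under stated assumptions: `run_tool_calls` and `registry` are hypothetical names, and the sample response mirrors the `tool_calls` layout printed in step 5 rather than coming from a live model.

```python
from typing import Any, Callable, Dict, List


def run_tool_calls(response: Dict[str, Any], registry: Dict[str, Callable]) -> List[Dict[str, str]]:
    """Execute every tool call in a chat response and return 'tool' role messages."""
    tool_messages = []
    # `or []` covers responses in which the model requested no tool
    for call in response["message"].get("tool_calls") or []:
        name = call["function"]["name"]
        arguments = call["function"]["arguments"]
        func = registry.get(name)
        if func is None:
            content = f"Error: unknown tool '{name}'"
        else:
            try:
                content = f"return: {func(**arguments)}"
            except Exception as e:  # Report the failure to the model instead of raising
                content = f"Error: {e}"
        tool_messages.append({"role": "tool", "content": content})
    return tool_messages


def add(a: int, b: int) -> int:
    return a + b


sample_response = {
    "message": {
        "role": "assistant",
        "content": "",
        "tool_calls": [
            {"function": {"name": "add", "arguments": {"a": 2372167, "b": 7765676}}}
        ],
    }
}

tool_messages = run_tool_calls(sample_response, {"add": add})
print(tool_messages)
```

Returning an error string rather than raising keeps the conversation going and lets the model react to the failure.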

#| hide

Next, we will define helper functions to simplify the steps above. Documentation continues below.

## Helper functions
The Ollama framework supports [tool calling](https://ollama.com/blog/tool-support) (also referred to as function calling) with models such as Llama 3.1. To leverage function calling, we need to pass the JSON schema for any given function as an argument to the LLM. This is the information based on which the LLM infers the most appropriate tool to use given a prompt, and which parameters/arguments to pass to a function. To simplify the process of generating JSON schemas, use the helper and decorator functions defined below.

In [None]:
#| export
import inspect
from typing import Callable, Dict, List, Tuple, Optional, Any, Union 
import json
import os
import random
import ollama

In [None]:
#| export
def generate_json_schema(func: Callable) -> Dict[str, Any]:
    """
    Generate a JSON schema for the given function based on its annotations and docstring.
    
    Args:
        func (Callable): The function to generate a schema for.
    
    Returns:
        Dict[str, Any]: A JSON schema for the function.
    """
    annotations = func.__annotations__
    doc = inspect.getdoc(func)
    
    schema = {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": "",
            "parameters": {
                "type": "object",
                "properties": {},
                "required": []
            }
        }
    }
    
    arg_descriptions = {}  # Defined up front so functions without a docstring still work below
    if doc:
        lines = doc.split('\n')
        description = []
        in_args_section = False
        
        for line in lines:
            line = line.strip()
            if line.lower().startswith('args:'):
                in_args_section = True
                continue
            elif line.lower().startswith('returns:') or line.lower().startswith('raises:'):
                break
            
            if not in_args_section:
                description.append(line)
            else:
                parts = line.split(':')
                if len(parts) >= 2:
                    # print(parts) # Uncomment for debugging
                    arg_name = parts[0].split(' ')[0].strip()
                    arg_desc = ':'.join(parts[1:]).strip()
                    # print(arg_name, arg_desc) # Uncomment for debugging
                    arg_descriptions[arg_name] = arg_desc
                    # print(arg_descriptions) # Uncomment for debugging
                    if 'optional' not in parts[0].lower():
                        schema['function']['parameters']['required'].append(arg_name)

        
        schema['function']['description'] = ' '.join(description).strip()
    
    for arg, arg_type in annotations.items():
        if arg != 'return':
            schema['function']['parameters']['properties'][arg] = {
                "type": _get_type(arg_type),
                "description": arg_descriptions.get(arg, "")
            }
    
    return schema

def _get_type(arg_type):
    if 'List' in str(arg_type):
        return "list"
    elif 'Optional' in str(arg_type):
        return "optional"
    elif arg_type == int:
        return "integer"
    elif arg_type == str:
        return "string"
    elif arg_type == float:
        return "number"
    elif arg_type == bool:
        return "boolean"
    else:
        return "object"
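
`_get_type` reports `List[...]` as `"list"` and `Optional[...]` as `"optional"`, neither of which is a JSON Schema type name (the standard names are `string`, `number`, `integer`, `boolean`, `object`, `array`, and `null`). The sketch below shows one possible alternative mapping based on `typing.get_origin` and `typing.get_args`. It is not the library's behavior, `to_json_type` is a hypothetical name, and only a small subset of annotations is handled.

```python
from typing import Any, Dict, List, Optional, Union, get_args, get_origin

_PRIMITIVES = {int: "integer", str: "string", float: "number", bool: "boolean"}


def to_json_type(annotation: Any) -> Dict[str, Any]:
    """Map a Python type annotation to a JSON Schema fragment (illustrative subset)."""
    origin = get_origin(annotation)
    args = get_args(annotation)
    if origin is Union:
        # Optional[X] is Union[X, None]; describe X and let 'required' express optionality
        non_none = [a for a in args if a is not type(None)]
        if len(non_none) == 1:
            return to_json_type(non_none[0])
        return {"anyOf": [to_json_type(a) for a in non_none]}
    if origin is list or annotation is list:
        item = to_json_type(args[0]) if args else {}
        return {"type": "array", "items": item}
    if annotation in _PRIMITIVES:
        return {"type": _PRIMITIVES[annotation]}
    return {"type": "object"}  # Fallback for anything not covered above


print(to_json_type(List[int]))
print(to_json_type(Optional[int]))
```

Whether a given model handles `"array"` better than `"list"` is not established here, so this should be read as a starting point for experimentation.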

In [None]:
#| export
import functools
from typing import TypeVar

In [None]:
#| export
T = TypeVar('T', bound=Callable)

def json_schema_decorator(func: T) -> T:
    """
    Decorator to generate and attach a JSON schema to a function.
    
    Args:
        func (Callable): The function to decorate.
    
    Returns:
        Callable: The decorated function with an attached JSON schema.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    
    schema = generate_json_schema(func)
    wrapper.json_schema = schema  # Attach the schema dictionary directly
    return wrapper

#| hide

The following code cells are for testing purposes only. They are not part of the final library.

In [None]:
#| hide
@json_schema_decorator
def add_two_integers(a: int, b: int) -> int:
    """
    A function to add two numbers.

    Args:
        a (int): First number.
        b (int): Second number.

    Returns:
        int: Sum of the two numbers.

    Raises:
        TypeError: If inputs are not integers.
    """
    if not isinstance(a, int) or not isinstance(b, int):
        raise TypeError("Both inputs must be integers")
    
    return a + b

assert generate_json_schema(add_two_integers) == add_two_integers.json_schema
assert generate_json_schema(add_two_integers)['function']['name'] == "add_two_integers"
assert generate_json_schema(add_two_integers)['function']['description'] == "A function to add two numbers."
assert generate_json_schema(add_two_integers)['function']['parameters']['properties']['a']['description'] == "First number."
assert generate_json_schema(add_two_integers)['function']['parameters']['properties']['b']['description'] == "Second number."
assert generate_json_schema(add_two_integers)['function']['parameters']['properties']['a']['type'] == "integer"
assert generate_json_schema(add_two_integers)['function']['parameters']['properties']['b']['type'] == "integer"
assert generate_json_schema(add_two_integers)['function']['parameters']['required'] == ['a', 'b']
add_two_integers.json_schema

{'type': 'function',
 'function': {'name': 'add_two_integers',
  'description': 'A function to add two numbers.',
  'parameters': {'type': 'object',
   'properties': {'a': {'type': 'integer', 'description': 'First number.'},
    'b': {'type': 'integer', 'description': 'Second number.'}},
   'required': ['a', 'b']}}}

In [None]:
#| hide
# Generate another JSON schema for a different function for testing purposes
@json_schema_decorator
def multiply_two_integers(a: int, b: int) -> int:
    """
    A function to multiply two numbers.

    Args:
        a (int): First number.
        b (int): Second number.

    Returns:
        int: Product of the two numbers.

    Raises:
        TypeError: If inputs are not integers.
    """

    if not isinstance(a, int) or not isinstance(b, int):
        raise TypeError("Both inputs must be integers")
    
    return a * b

assert generate_json_schema(multiply_two_integers) == multiply_two_integers.json_schema
assert generate_json_schema(multiply_two_integers)['function']['parameters']['properties']['a']['description'] == "First number."
assert generate_json_schema(multiply_two_integers)['function']['parameters']['properties']['b']['description'] == "Second number."
assert generate_json_schema(multiply_two_integers)['function']['parameters']['required'] == ['a', 'b']
multiply_two_integers.json_schema

{'type': 'function',
 'function': {'name': 'multiply_two_integers',
  'description': 'A function to multiply two numbers.',
  'parameters': {'type': 'object',
   'properties': {'a': {'type': 'integer', 'description': 'First number.'},
    'b': {'type': 'integer', 'description': 'Second number.'}},
   'required': ['a', 'b']}}}

In [None]:
#| hide
# Test decorator with an optional argument
@json_schema_decorator
def process_data(data: List[int], threshold: Optional[int] = None) -> List[int]:
    """
    Process a list of integers based on an optional threshold.

    This function filters and modifies the input list of integers.
    If a threshold is provided, it keeps only the numbers above the threshold.
    If no threshold is given, it returns the original list.

    Args:
        data (List[int]): A list of integers to process.
        threshold (Optional[int], optional): The minimum value to keep. Defaults to None.

    Returns:
        List[int]: A list of processed integers.

    Raises:
        ValueError: If the input data is empty.

    Example:
        >>> process_data([1, 5, 3, 7, 2], threshold=3)
        [5, 7]
    """
    if not data:
        raise ValueError("Input data cannot be empty")

    if threshold is None:
        return data
    else:
        return [num for num in data if num > threshold]

assert generate_json_schema(process_data) == process_data.json_schema
assert generate_json_schema(process_data)['function']['name'] == "process_data"
assert "Process a list of integers based" in generate_json_schema(process_data)['function']['description']
assert "If no threshold is given, it returns the original list." in generate_json_schema(process_data)['function']['description']
assert generate_json_schema(process_data)['function']['parameters']['properties']['data']['description'] == "A list of integers to process."
assert generate_json_schema(process_data)['function']['parameters']['properties']['threshold']['description'] == "The minimum value to keep. Defaults to None."
assert generate_json_schema(process_data)['function']['parameters']['required'] == ['data']
process_data.json_schema

{'type': 'function',
 'function': {'name': 'process_data',
  'description': 'Process a list of integers based on an optional threshold.  This function filters and modifies the input list of integers. If a threshold is provided, it keeps only the numbers above the threshold. If no threshold is given, it returns the original list.',
  'parameters': {'type': 'object',
   'properties': {'data': {'type': 'list',
     'description': 'A list of integers to process.'},
    'threshold': {'type': 'optional',
     'description': 'The minimum value to keep. Defaults to None.'}},
   'required': ['data']}}}

In [None]:
#| hide
# In this code block, we will test tool use. We provide the LLM with multiple tools and let it decide which tool to use.
a = random.randint(1, 10000000)
b = random.randint(1, 10000000)
prompts = [f"What is the sum of {a} and {b}?", f"What is the product of {a} and {b}?"]

sys_message = "You are an assistant provided with tools to help you answer questions."

for prompt in prompts:
    messages=[
        {'role': "system", 'content': sys_message},
        {'role': 'user', 'content': prompt},
    ]
    try:
        response = ollama.chat(
            model="llama3.1",
            messages=messages,
            tools=[add_two_integers.json_schema, multiply_two_integers.json_schema],
        )

        print(prompt)
        print(response.get('message').get('tool_calls'))
    except Exception as e:
        print(e)
        errors.append(e)

What is the sum of 1303847 and 8919027?
[{'function': {'name': 'add_two_integers', 'arguments': {'a': 1303847, 'b': 8919027}}}]
What is the product of 1303847 and 8919027?
[{'function': {'name': 'multiply_two_integers', 'arguments': {'a': 1303847, 'b': 8919027}}}]


## Local file processing: listing, content extraction, and summarization

:::{.callout-note}
Below are the core functions the assistant can call. With these functions, the assistant can list the file names in a specific data directory, extract content from these files, and summarize them.
:::

#| hide

Note that for the json_schema_decorator() function to work properly, the function definitions require type hints and docstrings as shown below. Intentionally, docstrings are verbose, function names are descriptive, and type hints are explicitly set. This is because the LLM will make function calling decisions based on the function name, type annotations, and information in the docstring. 

It's crucial to understand that this metadata (function name, type hints, and docstring) is all the information the LLM has access to when deciding which function to call and how to use it. The LLM does not have access to or information about the actual source code or implementation of the functions (unless explicitly provided). Therefore, the metadata must be comprehensive and accurate to ensure proper function selection and usage by the LLM.
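
To see what a function exposes before it is turned into a schema, its name, signature, and docstring can be printed directly. The sketch below uses the standard `inspect` module; `tool_metadata` and the sample function `count_words` are hypothetical and serve only as an illustration.

```python
import inspect
from typing import Callable, Dict


def tool_metadata(func: Callable) -> Dict[str, str]:
    """Collect the parts of a function that a schema can be built from."""
    return {
        "name": func.__name__,
        "signature": str(inspect.signature(func)),
        "docstring": inspect.getdoc(func) or "",
    }


def count_words(text: str) -> int:
    """Count the whitespace-separated words in a text.

    Args:
        text (str): The text to analyze.

    Returns:
        int: Number of words.
    """
    return len(text.split())


meta = tool_metadata(count_words)
print(meta["name"], meta["signature"])
print(meta["docstring"])
```

The function body (`len(text.split())`) appears nowhere in this metadata, which is why the docstring has to describe the behavior.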

We start by defining a tool that retrieves a list of file names with specified extensions in a specific data directory the assistant has access to.

In [None]:
#| export
@json_schema_decorator
def get_file_names(ext: str = "pdf, txt") -> str:
    """Retrieves a list of file names with specified extensions in a local data directory the assistant has access to on the user's computer.

    Args:
        ext: A comma-separated string of file extensions to filter the files by. Options are: pdf, txt, md, markdown, csv, and py. Defaults to "pdf, txt".

    Returns:
        str: A comma-separated string of file names with the specified extensions. If no files are found, a message is returned.

    Example:
        >>> get_file_names(ext="pdf, txt")
        
        "List of file names with the specified extensions in the local data directory: file1.pdf, file2.txt"
    """

    if 'DIR_PATH' not in globals():
        return "Error: The local data directory path is not defined."
    
    if not os.path.exists(DIR_PATH):
        return "Error: The local data directory does not exist."

    valid_extensions = ["pdf", "txt", "md", "markdown", "csv", "py"]

    # Process the input extensions
    selected_extensions = []
    for e in ext.split(','):
        e = e.lower().strip().lstrip('.').strip('{').strip('}').strip('[').strip(']').strip('(').strip(')').strip('"').strip("'") # Clean up the extension string
        if e not in valid_extensions:
            return f"Error: Invalid file extension '{e}'. Please choose from: pdf, txt, md, markdown, csv, py." # Instead of raising an error, we return a message to the LLM to avoid stopping a conversation
        selected_extensions.append(e)

    # List all files with the specified extensions
    file_names = [file for file in os.listdir(DIR_PATH) if any(file.endswith(f".{e}") for e in selected_extensions)]
    # print(file_names) # Uncomment for debugging

    if len(file_names) == 0:
         return "Access of local data directory successful but no files found with the specified extensions."

    # Convert the list to a comma-separated string, because the result is returned to the LLM and the API accepts str only
    file_names_str = ', '.join(file_names)

    return f"List of file names with the specified extensions in the local data directory: {file_names_str}"
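
`get_file_names` matches extensions with a case-sensitive `str.endswith`, so a file named `REPORT.PDF` would not be listed when `pdf` is requested. If case-insensitive matching is desired, the filtering step could look like the sketch below. `filter_by_extension` is a hypothetical helper and not part of the library.

```python
import os
from typing import Iterable, List


def filter_by_extension(file_names: Iterable[str], extensions: Iterable[str]) -> List[str]:
    """Keep file names whose extension matches one of `extensions`, ignoring case."""
    wanted = {e.lower().strip().lstrip(".") for e in extensions}
    kept = []
    for name in file_names:
        _, ext = os.path.splitext(name)  # ext includes the leading dot, e.g. ".PDF"
        if ext.lower().lstrip(".") in wanted:
            kept.append(name)
    return kept


matches = filter_by_extension(["a.pdf", "B.PDF", "notes.txt", "script.py"], ["pdf", ".txt"])
print(matches)
```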

In [None]:
#| export
assert type(get_file_names.json_schema) == dict
assert get_file_names.json_schema['function']['name'] == "get_file_names"

#| hide

To test this and other functions defined below with an actual file, download a sample article from the internet:

In [None]:
#| hide
if "CONDA_PREFIX" in os.environ:
    !mkdir -p ../data
    pdf_urls = [
        "https://df6sxcketz7bb.cloudfront.net/manuscripts/144000/144499/jci.insight.144499.v2.pdf",
    ]
    for url in pdf_urls:
        !curl -o ../data/$(basename {url}) {url}

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1722k  100 1722k    0     0  6237k      0 --:--:-- --:--:-- --:--:-- 6242k


#| hide

Below are several test cases for the tool.

In [None]:
#| hide
import copy
import tempfile
import shutil

In [None]:
#| hide
global DIR_PATH
DIR_PATH = "../data"

In [None]:
#| hide
# Test case for get_file_names function
if "DIR_PATH" in globals():
    original_DIR_PATH = copy.deepcopy(DIR_PATH)
    del globals()["DIR_PATH"] # Remove the global variable for testing purposes
with tempfile.TemporaryDirectory() as temp_dir: 
    assert get_file_names() == "Error: The local data directory path is not defined.", "Test failed: get_file_names() should return an error message if DIR_PATH is not defined."
    DIR_PATH = temp_dir # A module-level assignment needs no global statement
    assert get_file_names() == "Access of local data directory successful but no files found with the specified extensions.", "Test failed: get_file_names() should return a message if no files are found in the directory."
    test_files = ["test_file1.txt", "test_file2.py"] # Create some test files
    for file in test_files:
        with open(os.path.join(DIR_PATH, file), 'w') as f:
            f.write("Test content")
    assert "test_file1.txt" in get_file_names(ext="txt"), "Test failed: get_file_names() should return a string with the test file name."
    assert "test_file2.py" not in get_file_names(ext="txt"), "Test failed: get_file_names() should not return a file name with a different extension."
    assert "test_file2.py" in get_file_names(ext="txt, py"), "Test failed: get_file_names() should return a string with the python test file name."
    print("All tests passed successfully!")

DIR_PATH = original_DIR_PATH # Restore the original data directory

All tests passed successfully!


In [None]:
#| hide
# Test function calling with the LLM using the get_file_names tool, if no files are found
# Store the original DIR_PATH
if "DIR_PATH" in globals():
    original_DIR_PATH = copy.deepcopy(DIR_PATH)

# Create a temporary, empty directory 
with tempfile.TemporaryDirectory() as temp_dir:
    # Set DIR_PATH to the temporary directory
    DIR_PATH = temp_dir

    prompt = f"Can you provide me with a list of PDF files you have access to on the local computer?"

    sys_message = """You are a helpful AI assistant. Respond with the shortest possible answers while maintaining clarity and accuracy. 
    Avoid unnecessary details, explanations, or pleasantries. Be direct and to the point in all responses.

    Key instructions:
    1. Use the provided tool to gather information before answering.
    2. Interpret tool results and provide clear, concise answers in natural language.
    3. If you can't answer with available tools, state this clearly.
    4. Don't provide information if tool content is empty.
    """

    messages=[
        {'role': "system", 'content': sys_message},
        {'role': 'user', 'content': prompt},
    ]

    try:
        # Make a request to the LLM to select a tool
        response = ollama.chat(
            model="llama3.1",
            messages=messages,
            tools=[get_file_names.json_schema],
        )
    
        if response.get('message').get('tool_calls'):
            for tool in response['message']['tool_calls']:
                function_to_call = tool['function']['name']
                print(f"Calling {function_to_call}()...\n")

        # Call the function
        for tool in response['message'].get('tool_calls'):
            output = f"return: {get_file_names()}"
            messages.append(
                {
                    'role': 'tool',
                    'content': output,
                }
            )

        # Make a second request to the LLM with the tool output to generate a final response
        response = ollama.chat(
            model="llama3.1",
            messages=messages,
            stream=False,
        )

        print(response['message']['content'])
    except Exception as e:
        print(e)

# Reset DIR_PATH to its original value
if "DIR_PATH" in globals():
    DIR_PATH = original_DIR_PATH

Calling get_file_names()...

No PDF files available.


In [None]:
#| hide
import uuid

In [None]:
#| hide
# Test function calling; create a test file and check if it is included in the response
if "CONDA_PREFIX" in os.environ: # For local testing only 
    unique_suffix = uuid.uuid4().hex[:8]  # Use first 8 characters of a UUID
    test_file_name = f"test_{unique_suffix}.pdf"
    test_file_path = os.path.join(DIR_PATH, test_file_name)

    print(f"Creating test file: {test_file_name}\n")
    open(test_file_path, 'a').close()  # Create an empty file

    prompt = f"Can you provide me with a list of PDF files you have access to?"

    sys_message = """You are a helpful AI assistant. Respond with the shortest possible answers while maintaining clarity and accuracy. 
    Avoid unnecessary details, explanations, or pleasantries. Be direct and to the point in all responses.

    Key instructions:
    1. Use the provided tool to gather information before answering.
    2. Interpret tool results and provide clear, concise answers in natural language.
    3. If you can't answer with available tools, state this clearly.
    4. Don't provide information if tool content is empty.
    """

    messages=[
        {'role': "system", 'content': sys_message},
        {'role': 'user', 'content': prompt},
    ]

    try:
        # Make a request to the LLM to select a tool
        response = ollama.chat(
            model="llama3.1",
            messages=messages,
            tools=[get_file_names.json_schema],
        )

        if response.get('message').get('tool_calls'):
            for tool in response['message']['tool_calls']:
                function_to_call = tool['function']['name']
                print(f"Calling {function_to_call}()...\n")

        # Call the function
        for tool in response['message'].get('tool_calls'):
            output = f"return: {get_file_names()}"
            # print(output)
            messages.append(
                {
                    'role': 'tool',
                    'content': output,
                }
            )

        # Make a second request to the LLM with the tool output to generate a final response
        response = ollama.chat(
            model="llama3.1",
            messages=messages,
            stream=False,
        )
        print(f"LLM response:\n{response['message']['content']}")

    except Exception as e:
        print(e)
        errors.append(e)

    print(f"\n\nRemoving test file...")
    os.remove(test_file_path)
    print("Test completed.")

Creating test file: test_750525c0.pdf

Calling get_file_names()...

LLM response:
Here is the list of available PDF files:

1. test_750525c0.pdf
2. jci.insight.144499.v2.pdf


Removing test file...
Test completed.


:::{.callout-note}
The next two functions are executed when calling the tool `get_titles_and_first_authors`, which is defined below, to extract the title and first author from PDF files in the local data directory. The first function, `extract_text_from_pdf`, extracts text from a PDF file using the **PyPDF2** library. The second function, `extract_title_and_first_author`, will then use the LLM to extract the title and first author from the text extracted from the PDF file.

In addition, the `extract_text_from_pdf` function can also be called directly by the assistant to extract user-specified pages or sections from a PDF file, to respond to user queries or to extract specific information from a PDF file specified by the user.
:::
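
The `page_range` argument of `extract_text_from_pdf` is a string such as `"1"`, `"1, 2"`, or `"5-7"`. The sketch below shows one way such a string could be expanded into page numbers. `parse_page_range` is a hypothetical standalone helper; unlike the parsing inside the tool, it also removes duplicates, sorts the result, and rejects reversed ranges.

```python
from typing import List


def parse_page_range(page_range: str) -> List[int]:
    """Expand a string such as "1, 3, 5-7" into a sorted list of unique page numbers."""
    cleaned = page_range.strip()
    for ch in "\"'[]":  # Remove quotes and brackets an LLM may add around the value
        cleaned = cleaned.replace(ch, "")
    pages = set()
    for part in cleaned.split(","):
        part = part.strip()
        if not part:
            continue  # Tolerate trailing commas
        if "-" in part:
            start, end = (int(p) for p in part.split("-", 1))
            if start > end:
                raise ValueError(f"Invalid range '{part}': start is greater than end")
            pages.update(range(start, end + 1))
        else:
            pages.add(int(part))
    return sorted(pages)


pages = parse_page_range("1, 3, 5-7")
print(pages)
```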

In [None]:
#| export
import PyPDF2


In [None]:
#| export
@json_schema_decorator
def extract_text_from_pdf(file_name: str, page_range: Optional[str] = None) -> str:
    """A function that extracts text from a PDF file.
    Use this tool to extract specific details from a PDF document, such as abstract, authors, conclusions, or user-specified content of other sections.
    If the user specifies a page range, use the optional page_range parameter to extract text from specific pages.
    If the user uses words such as beginning, middle, or end to describe the section, infer the page range assuming the document has 15 pages in total.
    Do not use this tool to summarize an entire PDF document. Only use this tool for documents with extensions .pdf, or .PDF.

    Args:
        file_name (str): The file name of the PDF document in the local data directory.
        page_range (Optional[str]): A string with page numbers and/or page ranges separated by commas (e.g., "1" or "1, 2", or "5-7"). Default is None to extract all pages.

    Returns:
        str: Extracted text from the PDF.

    Example:
    >>> text = extract_text_from_pdf("./test.pdf", page_range="1")
    """
    verbose = False  # Set to True for debugging

    # Check if the directory exists
    if not os.path.exists(DIR_PATH):
        return "Local data directory not found."
    
    # Construct the full file path; this is wrapped in a try-except block because this function is also used as a tool
    try:
        pdf_path = os.path.join(DIR_PATH, file_name)
    except Exception as e:
        return f"Error: {e}"
    
    # Validate input
    if not os.path.exists(pdf_path):
        raise FileNotFoundError(f"File not found: {pdf_path}")
    assert pdf_path.lower().endswith('.pdf'), "The file must be a PDF."  # Accept both .pdf and .PDF, as stated in the docstring
    assert isinstance(page_range, (str, type(None))), "The page_range must be a string or None."

    if page_range:
        # Clean up page range input in case of LLM formatting errors; must be a string, e.g., "1" or "1, 2", or "5-7", with no quotes or brackets   
        page_range = page_range.strip().replace('"', '').replace("'", "").replace("[", "").replace("]", "")

        # Parse the page range string
        page_numbers = []
        for part in page_range.split(','):
            if '-' in part:
                a, b = part.split('-')
                page_numbers.extend(range(int(a), int(b) + 1))
            else:
                page_numbers.append(int(part))
        if verbose:
            print(f"Extracting text from pages: {page_numbers}")

    # Validate page numbers
    if page_range:
        start_page, end_page = min(page_numbers), max(page_numbers)
        if start_page < 1 or end_page < start_page:
            raise ValueError("Invalid page range. Please provide a valid range of pages to extract text from.")

    # Extract text from the PDF
    text = ""
    try:
        with open(pdf_path, 'rb') as file:
            pdf = PyPDF2.PdfReader(file)
            
            # Check for invalid page numbers
            if page_range:
                max_page = len(pdf.pages)
                invalid_pages = [p for p in page_numbers if p < 1 or p > max_page]
                if invalid_pages:
                    page_range = None # Reset page range to extract all pages
            
            if page_range:
                for page_num in page_numbers:
                    if verbose:
                        print(f"Extracting text from page {page_num}...")
                    page = pdf.pages[page_num - 1]  # Adjust for 0-based index
                    text += f"page {page_num} of {len(pdf.pages)}\n{page.extract_text()}"
            else:
                for page_num, page in enumerate(pdf.pages):
                    if verbose:
                        print(f"Extracting text from page {page_num + 1}...")
                    text += f"page {page_num + 1} of {len(pdf.pages)}\n{page.extract_text()}"  # Adjust for 0-based index
            
            if verbose:
                word_count = len(text.split())
                total_pages = len(page_numbers) if page_range else len(pdf.pages)
                print(f"Text extraction completed.\nTotal pages extracted: {total_pages}\nWord count: {word_count}\nNo. of characters (with spaces): {len(text)}")
    except PyPDF2.errors.PdfReadError as e:
        return f"Error reading PDF: {e}"

    return text

In [None]:
#| hide
# Testing the function with a sample PDF file
# For this test to work, the specified PDF file must exist in the data directory
try: 
    extract_text_from_pdf("../data/jci.insight.144499.v2.pdf", page_range="1")
    extract_text_from_pdf("../data/jci.insight.144499.v2.pdf", page_range="1, 2")
    extract_text_from_pdf("../data/jci.insight.144499.v2.pdf", page_range="1-3") # Should handle page ranges with hyphens
    extract_text_from_pdf("../data/jci.insight.144499.v2.pdf", page_range=" 1-3 ") # Should handle extra spaces
    extract_text_from_pdf("../data/jci.insight.144499.v2.pdf", page_range="1, 3") 
    extract_text_from_pdf("../data/jci.insight.144499.v2.pdf", page_range="1-3, 5") # Should handle multiple ranges
    extract_text_from_pdf("../data/jci.insight.144499.v2.pdf", page_range="1-1000") # Invalid page range will return all pages
    extract_text_from_pdf("../data/jci.insight.144499.v2.pdf", page_range="[1]") # Should handle invalid formatting; brackets will be removed

except FileNotFoundError as e:
    print(e)
    errors.append(e)

#| hide

Note: The following function will be defined but not be callable directly by the LLM. 
Hence, we will not generate a JSON schema for this function.
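
`extract_title_and_first_author` asks the model for a JSON object and parses the reply with `json.loads`. Smaller models sometimes wrap the JSON in extra text, so a more tolerant parser can help. The sketch below is an assumption about how that could be handled, not the library's implementation; the name `parse_title_json` is hypothetical.

```python
import json
from typing import Dict


def parse_title_json(raw: str) -> Dict[str, str]:
    """Parse an LLM reply into a dict with 'title' and 'first_author' keys."""
    empty = {"title": "", "first_author": ""}
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        # Fall back to the outermost {...} span, in case the model added extra text
        start, end = raw.find("{"), raw.rfind("}")
        if start == -1 or end <= start:
            return dict(empty)
        try:
            data = json.loads(raw[start:end + 1])
        except json.JSONDecodeError:
            return dict(empty)
    if not isinstance(data, dict):
        return dict(empty)
    # Keep only the expected keys and coerce the values to strings
    return {key: str(data.get(key, "")) for key in empty}


reply = 'Here is the JSON output:\n{"title": "De Medicina", "first_author": "Aulus Cornelius Celsus"}'
print(parse_title_json(reply))
```

Returning empty strings on failure keeps the shape of each record stable, so downstream code does not have to check for missing keys.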

In [None]:
#| export
from tqdm import tqdm


In [None]:
#| export
def extract_title_and_first_author(contents: List[Dict[str, str]], model: str='llama3.1', verbose: Optional[bool] = False, show_progress: Optional[bool] = False) -> List[Dict[str, str]]:
    """
    A function that extracts the titles and the first author's names from the text of one or more research articles.

    Args:
        contents (List[Dict[str, str]]): A list of dictionaries containing the file name and extracted text.
        model (str): The model to use for the extraction. Default is 'llama3.1'.
        verbose (Optional[bool]): Whether to print additional information. Default is False.
        show_progress (Optional[bool]): Whether to show a progress bar. Default is False.

    Returns:
        contents (List[Dict[str, str]]): The input list of dictionaries with the extracted title and first author added.

    Note:
        Invalid JSON responses are caught and reported; the title and first author are then set to empty strings.

    Example:
    >>> contents = extract_title_and_first_author(contents)
    Extracting titles and first authors: 100%|██████████| 3/3 [00:22<00:00,  7.35s/it]
    """

    prompt = """
    The text below between the <text> XML-like tags is extracted from the first page of a research article.
    Your task is to identify the title of the research article and the first author's name.
    The title is typically located immediately before the authors' names and the abstract.

    <text>
    {text}
    </text>

    The output must be provided in JSON format shown in the following example.

    Example output:
    {{
        "title": "<title>",
        "first_author": "<first_author>"
    }}
    Write the JSON output and nothing more. Do not include degree titles or affiliations in the author's name.

    Here is the JSON output:
    """

    pdf_iterator = tqdm(contents, desc="Extracting titles and first authors") if show_progress else contents
    
    for pdf_item in pdf_iterator:
        text = pdf_item['extracted_text']
        if verbose:
            tqdm.write(pdf_item.get('file_name', 'unknown file'))  # The input dictionaries use the key 'file_name'
            tqdm.write("---")
            tqdm.write(text[:500])
            tqdm.write("---")

        try:
            response = ollama.chat(model=model, messages=[
                {
                    'role': 'user',
                    'content': prompt.format(text=text),
                },
            ])
        except Exception as e:
            tqdm.write(f"Error: {e}")
            if "Connection refused" in str(e):
                tqdm.write("Make sure Ollama is running and the correct model is available.")
            continue

        if verbose:
            tqdm.write(response['message']['content'])

        try:
            extracted_info = json.loads(response['message']['content'])
        except json.JSONDecodeError:
            tqdm.write(f"Error: Invalid JSON response for {pdf_item.get('file_name', 'unknown file')}")
            extracted_info = {"title": "", "first_author": ""}

        pdf_item.update(extracted_info)

    if show_progress:
        print("\n")  # Add a newline after the progress bar
    return contents

In [None]:
#| hide
contents = [{
    "file_name": "de_medicina.pdf",
    "extracted_text": "Titulus: De Medicina\n\nAuctor: Aulus Cornelius Celsus\n\nLiber I\n\nUt alimenta sanis corporibus agricultura, sic sanitatem aegris Medicina promittit. Haec nusquam quidem non est; siquidem etiam imperitissimae gentes herbas aliaque prompta in auxilium vulnerum morborumque noverunt. Verumtamen apud Graecos aliquanto magis quam in ceteris nationibus exculta est, ac ne apud hos quidem a prima origine, sed paucis ante nos saeculis; utpote cum vetustissimus auctor Aesculapius celebretur. Qui quoniam adhuc rudem et vulgarem hanc scientiam paulo subtilius excoluit, in deorum numerum receptus est."
}]
assert extract_title_and_first_author(contents)

#| hide

Run the following cells to test the `extract_text_from_pdf` & `extract_title_and_first_author` functions separately on actual PDF files. Remember, both of these functions are executed when the LLM calls the `get_titles_and_first_authors` function. These tests will be executed from the module as they require PDF files to be present in the data directory.

In [None]:
#| hide
# First, get the file paths from the local data directory
# Make sure to have some PDF files in the data directory for this test
DIR_PATH = "../data/"
file_names = [file for file in os.listdir(DIR_PATH) if file.endswith('.pdf')]
file_paths = [os.path.join(DIR_PATH, file) for file in file_names]
print("File paths:", file_paths)

File paths: ['../data/jci.insight.144499.v2.pdf']


In [None]:
#| hide
# Next, extract the text from the first page of each PDF file
pdf_contents = []
for file_path in file_paths:
    pdf_contents.append({
        "file_name": file_path.split("/")[-1],
        "extracted_text": extract_text_from_pdf(file_path, page_range="1")
        })
pdf_contents

[{'file_name': 'jci.insight.144499.v2.pdf',
  'extracted_text': 'page 1 of 15\n1\nRESEARCH ARTICLEConflict of interest: The authors have \ndeclared that no conflict of interest \nexists.\nCopyright: © 2021, Khan et al. This is \nan open access article published under \nthe terms of the Creative Commons \nAttribution 4.0 International License.\nSubmitted: September 21, 2020 \nAccepted: January 13, 2021 \nPublished: January 26, 2021\nReference information: JCI Insight. \n2021;6(4):e144499. \nhttps://doi.org/10.1172/jci.\ninsight.144499.Distinct antibody repertoires against \nendemic human coronaviruses in  \nchildren and adults\nTaushif Khan,1 Mahbuba Rahman,1 Fatima Al Ali,1 Susie S. Y . Huang,1 Manar Ata,1 Qian Zhang,2  \nPaul Bastard,2,3,4 Zhiyong Liu,2 Emmanuelle Jouanguy,2,3,4 Vivien Béziat,2,3,4 Aurélie Cobat,2,3,4  \nGheyath K. Nasrallah,5,6 Hadi M. Yassine,5,6 Maria K. Smatti,6 Amira Saeed,7 Isabelle Vandernoot,8 \n Jean-Christophe Goffard,9 Guillaume Smits,8 Isabelle Migeotte,10

#| hide

Now, we iterate through the contents and run the `extract_title_and_first_author` function to extract the title and first author from the text extracted from the PDF files.

In [None]:
#| hide
if len(pdf_contents) > 0:
    response = extract_title_and_first_author(pdf_contents, show_progress=True)
    for article in response:
        print(f"{article.get('file_name')}: '{article.get('title')}'")
else: 
    print("Add PDF files to the data directory to test the functions.")

Extracting titles and first authors: 100%|██████████| 1/1 [00:07<00:00,  7.35s/it]



jci.insight.144499.v2.pdf: 'Distinct antibody repertoires against endemic human coronaviruses in children and adults'





:::{.callout-note}
The next function combines the two functions `extract_text_from_pdf` and `extract_title_and_first_author` to extract the title and first author from PDF files in the local data directory. This function will be callable by the LLM.
:::
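
The composition can be illustrated without PDF files or a running model. In the sketch below, `collect_titles`, `fake_read`, and `fake_extract` are hypothetical names: the two stand-in functions take the place of `extract_text_from_pdf` and `extract_title_and_first_author`, and the result has the same JSON shape that the tool returns to the LLM.

```python
import json
from typing import Callable, Dict, List


def collect_titles(
    file_names: List[str],
    read_first_page: Callable[[str], str],
    extract_metadata: Callable[[List[Dict[str, str]]], List[Dict[str, str]]],
) -> str:
    """Combine a text extractor and a metadata extractor into one JSON result."""
    contents = [
        {"file_name": name, "extracted_text": read_first_page(name)}
        for name in file_names
    ]
    records = [
        {
            "title": item.get("title"),
            "first_author": item.get("first_author"),
            "file_name": item.get("file_name"),
        }
        for item in extract_metadata(contents)
    ]
    return json.dumps(records, indent=2)


# Stand-ins so the sketch runs without PDF files or a running model
def fake_read(name: str) -> str:
    return f"Title of {name}\nJane Doe"


def fake_extract(contents: List[Dict[str, str]]) -> List[Dict[str, str]]:
    for item in contents:
        title, author = item["extracted_text"].split("\n")
        item.update({"title": title, "first_author": author})
    return contents


print(collect_titles(["a.pdf"], fake_read, fake_extract))
```

Returning a JSON-formatted string rather than a list keeps the tool output directly usable as the content of a tool message.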

In [None]:
#| export
@json_schema_decorator
def get_titles_and_first_authors() -> str:
    """
    A function that retrieves the titles of research articles from a directory of PDF files.

    Returns:
        str: A JSON-formatted string containing the titles, first authors and file names of the research articles.

    Note:
        If the local data directory does not exist, the message "Local data directory not found." is returned instead of raising an error.

    Example:
    >>> get_titles_and_first_authors()
    """

    # Check if the directory exists
    if not os.path.exists(DIR_PATH):
        return "Local data directory not found."

    # Initialize an empty list to store the file names and extracted text
    pdf_contents = []

    # Initialize an empty list of dictionaries to store file names, extracted titles and first authors
    titles_and_authors = []

    # Get the file paths of all PDF files in the local data directory
    file_paths = [os.path.join(DIR_PATH, file) for file in os.listdir(DIR_PATH) if file.endswith('.pdf')]

    # Extract text from the first page of each PDF file
    for file_path in file_paths:
        pdf_contents.append({
            "file_name": file_path.split("/")[-1],
            "extracted_text": extract_text_from_pdf(file_path, page_range="1")
            })

    # Extract titles and first authors from the extracted text of each PDF file
    response = extract_title_and_first_author(pdf_contents, show_progress=True)
    for article in response:
        titles_and_authors.append({
            "title": article.get('title'), 
            "first_author": article.get('first_author'),
            "file_name": article.get('file_name'),
            })
    
    if not titles_and_authors:
        return "No titles found."
        
    return json.dumps(titles_and_authors, indent=2) # Return as a JSON-formatted string

In [None]:
#| hide
get_titles_and_first_authors.json_schema

{'type': 'function',
 'function': {'name': 'get_titles_and_first_authors',
  'description': 'A function that retrieves the titles of research articles from a directory of PDF files.',
  'parameters': {'type': 'object', 'properties': {}, 'required': []}}}

In [None]:
#| export
assert isinstance(get_titles_and_first_authors.json_schema, dict)
assert get_titles_and_first_authors.json_schema['function']['name'] == "get_titles_and_first_authors"

In [None]:
#| hide
print(get_titles_and_first_authors())
print(type(get_titles_and_first_authors()))

Extracting titles and first authors: 100%|██████████| 1/1 [00:01<00:00,  1.38s/it]




[
  {
    "title": "Distinct antibody repertoires against endemic human coronaviruses in children and adults",
    "first_author": "Taushif Khan",
    "file_name": "jci.insight.144499.v2.pdf"
  }
]


Extracting titles and first authors: 100%|██████████| 1/1 [00:01<00:00,  1.40s/it]



<class 'str'>





:::{.callout-note}
The functions below are executed to summarize the content of files in the local data directory.
:::
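
Before summarizing, `summarize_local_document` drops the reference list by splitting the text at the first occurrence of a word such as "References". A stricter variant is sketched below as a hypothetical alternative, not as the library's behavior: it cuts only at a line that consists of the heading alone and uses the last such heading, so that an in-text mention of the word does not truncate the document. The name `strip_references` is an assumption made for this example.

```python
import re

# Match a line consisting only of a references-style heading (optionally numbered)
HEADING = re.compile(
    r"^[ \t]*(?:\d+\.?[ \t]*)?(?:references|bibliography)[ \t]*$",
    re.IGNORECASE | re.MULTILINE,
)


def strip_references(text: str) -> str:
    """Drop everything from the last references-style heading onward."""
    matches = list(HEADING.finditer(text))
    if not matches:
        return text
    return text[: matches[-1].start()].rstrip()


sample = "Intro\nSee References in the appendix.\nResults\nReferences\n1. Doe J. 2020."
print(strip_references(sample))
```

Text extracted from PDFs does not always place headings on their own line, so this heuristic can miss a reference list that the simpler substring split would catch.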

In [None]:
#| export
@json_schema_decorator
def summarize_local_document(file_name: str, ext: str = "pdf") -> str:
    """Summarize the content of a single PDF, markdown, or text document from the local data directory.

    Args:
        file_name (str): The file name of the local document to summarize.
        ext (str): The extension of the local document. Options are: pdf, txt, md, and markdown. Defaults to "pdf".

    Returns:
        str: The summary of the content of the local document.

    Example:
        >>> summarize_local_document("research_paper", ext="pdf")
    """
    # Check if the directory exists
    if not os.path.exists(DIR_PATH):
        return "Local data directory not found."

    # Ensure the extension is valid: delete spaces, periods, and quotation marks, and convert to lowercase in case of LLM input errors
    ext = ext.lower().replace('"', '').replace("'", "").replace(" ", "").replace(".", "")
    if ext not in ["pdf", "txt", "md", "markdown"]:
        return f"Invalid file extension '{ext}'. Please choose from: pdf, txt, md, markdown."

    # Get the file paths of all files in the local data directory
    file_paths = [os.path.join(DIR_PATH, file) for file in os.listdir(DIR_PATH) if file.endswith(f".{ext}")]  # Match the extension including the period
    # print(file_paths) # Uncomment for debugging
    
    # Find the file path that matches the specified file name
    file_path = [path for path in file_paths if file_name in path]
    # print(file_path) # Uncomment for debugging

    if not file_path:
        return f"No file found with the name '{file_name}' and extension '{ext}'."
    elif len(file_path) > 1:
        return f"Multiple files found with the name '{file_name}' and extension '{ext}'. Please specify a unique file name."
    else:
        file_path = file_path[0] # Convert the file path list to a string

    if ext == "pdf":
        # Extract text from a PDF file
        try:
            full_text = extract_text_from_pdf(file_path, page_range=None)
        except Exception as e:
            return f"Error while extracting text from PDF file {file_name}: {e}"
    
    if ext in ["txt", "md", "markdown"]:
        # Read the full text content from a text file
        try:
            with open(file_path, 'r') as file:
                full_text = file.read()
        except Exception as e:
            return f"Error while reading the file {file_name}: {e}"
    # print(full_text[:500]) # Uncomment for debugging

    # Remove references section from the full text content, if present
    patterns = ["References", "REFERENCES", "references", "Bibliography", "BIBLIOGRAPHY", "bibliography"]
    for pattern in patterns:
        if pattern in full_text:
            full_text = full_text.split(pattern)[0]
            break

    # Summarize the full text content of the document
    prompt = f"""
    The text below is the full text content from a single document with the file name '{file_name}'.

    <text>
    {full_text}
    </text>

    If the document has an abstract, use the abstract for the summary. The abstract is typically located at the beginning of the document (page 1) and provides a concise summary of the research.
    If there is no abstract, generate a concise summary (approx. 200 words) that captures the main points of the document, including key findings and conclusions.
    Remember that your task is to summarize the content of the main text accurately and concisely, ignoring acknowledgements and references that are typically listed at the end of the document, after the conclusion section.
    Start the summary with the title of the document, typically found on the first page before the author names.
    """

    sys_message = """You are a scientific summarization assistant for health and life sciences research. 
    Your task is to condense the contents of a complex research document with multiple pages into a concise, accurate summary.
    If the document describes a research study, highlight the main findings, methodologies, and conclusions of the study.
    Start the summary with the title of the document found on the first page, followed by a brief summary of the content.
    Ignore references typically listed at the end of a document after the conclusion section, as well as acknowledgements, and other non-content sections.
    """

    # Set the model from the global variables
    if 'MODEL' in globals():
        model = MODEL
    else:
        model = "llama3.1" # Default model

    try:
        response = ollama.chat(
            model=model,
            messages=[
                {'role': "system", 'content': sys_message},
                {'role': 'user', 'content': prompt},
            ]
        )
        # TODO: Try out different options for the ollama.chat function (e.g., temperature or num_predict, passed via `options`) to improve the quality of the summary
        # For details, see the Ollama API documentation:
        # https://github.com/ollama/ollama/blob/main/docs/api.md
        # https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values

        summary = response['message']['content']
        return summary

    except Exception as e:
        return f"Error while summarizing the content of the document '{file_name}': {e}" 

In [None]:
#| hide
summarize_local_document.json_schema

{'type': 'function',
 'function': {'name': 'summarize_local_document',
  'description': 'Summarize the content of a single PDF, markdown, or text document from the local data directory.',
  'parameters': {'type': 'object',
   'properties': {'file_name': {'type': 'string',
     'description': 'The file name of the local document to summarize.'},
    'ext': {'type': 'string',
     'description': 'The extension of the local document. Options are: pdf, txt, md, and markdown. Defaults to "pdf".'}},
   'required': ['file_name', 'ext']}}}

In [None]:
#| export
assert isinstance(summarize_local_document.json_schema, dict)
assert summarize_local_document.json_schema['function']['name'] == "summarize_local_document"
assert 'file_name' in summarize_local_document.json_schema['function']['parameters']['properties'].keys()
assert 'ext' in summarize_local_document.json_schema['function']['parameters']['properties'].keys()

In [None]:
#| hide
if "CONDA_PREFIX" in os.environ: # For local testing only
    !wget -O ../data/modelfile.md https://raw.githubusercontent.com/ollama/ollama/main/docs/modelfile.md

--2024-09-11 21:44:25--  https://raw.githubusercontent.com/ollama/ollama/main/docs/modelfile.md
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12893 (13K) [text/plain]
Saving to: ‘../data/modelfile.md’


2024-09-11 21:44:26 (43,1 MB/s) - ‘../data/modelfile.md’ saved [12893/12893]



In [None]:
#| hide
if "CONDA_PREFIX" in os.environ: # For local testing only
    DIR_PATH = "../data" # for testing purposes; will be defined when the Assistant class is initialized
    print(summarize_local_document("modelfile.md", ext="md"))

I can’t provide a Modelfile as it is specific to a particular model. However, I can help you create a custom Modelfile for your needs.

To get started, what type of text would you like the AI to generate? For example, should it be:

1. **Answering questions**: Provide answers to user queries.
2. **Writing articles**: Generate articles on a given topic.
3. **Conversational dialogue**: Engage in natural-sounding conversations.
4. **Summarizing content**: Summarize long pieces of text into concise versions.

Please let me know your desired application, and I'll help you create a basic Modelfile with the necessary instructions to get started.


In [None]:
#| hide
if "CONDA_PREFIX" in os.environ: # For local testing only
    print(summarize_local_document("jci.insight.144499.v2.pdf"))

Here is a concise summary (approx. 200 words) based on the provided text:

Title: Deployment of Convalescent Plasma for the Prevention and Treatment of COVID-19.

The deployment of convalescent plasma has emerged as a potential strategy for the prevention and treatment of COVID-19. Studies have shown that pre-existing immunity to SARS-CoV-2 can modulate and impair antibody responses to subsequent infections, indicating the potential for convalescent plasma to be effective. Fc-mediated antibody effector functions during respiratory syncytial virus infection and disease also suggest a role for convalescent plasma.

Research has identified several key factors influencing the efficacy of convalescent plasma, including the presence of neutralizing antibodies, avidity responses in COVID-19 patients and donors, and inborn errors of type I IFN immunity. Studies have demonstrated that convalescent plasma can contain antibodies reacting against SARS-CoV-2 antigens, and pre-existing yellow fever 

In [None]:
#| export
@json_schema_decorator
def describe_python_code(file_name: str) -> str:
    """Describe the purpose of the Python code in a local Python file.
    This may involve summarizing the entire code, extracting key functions, or providing an overview of the code structure.

    Args:
        file_name (str): The file name of the local Python code file document to describe.

    Returns:
        str: A description of the purpose of the Python code in the local file.

    Example:
        >>> describe_python_code("main.py")
    """
    # Check if the directory exists
    if not os.path.exists(DIR_PATH):
        return "Local data directory not found."

    # Get the file paths of all files in the local data directory
    file_paths = [os.path.join(DIR_PATH, file) for file in os.listdir(DIR_PATH) if file.endswith(".py")]
    # print(file_paths) # Uncomment for debugging
    
    # Find the file path that matches the specified file name
    file_path = [path for path in file_paths if file_name in path]
    # print(file_path) # Uncomment for debugging

    if not file_path:
        return f"No file found with the name '{file_name}' and extension '.py'."
    elif len(file_path) > 1:
        return f"Multiple files found with the name '{file_name}' and extension '.py'. Please specify a unique file name."
    else:
        file_path = file_path[0] # Convert the file path list to a string

    try:
        with open(file_path, 'r') as file:
            full_text = file.read()
    except Exception as e:
        return f"Error while reading the file {file_name}: {e}"
    # print(full_text[:500]) # Uncomment for debugging

    # Summarize the full text content of the document
    prompt = f"""
    The text below is the full Python code content from the file '{file_name}'.

    <code>
    {full_text}
    </code>

    Your task is to describe the purpose of the Python code in the file. This may involve summarizing the entire code, extracting key functions, or providing an overview of the code structure.
    """

    sys_message = """You are a programming assistant for Python code. Your task is to describe the purpose of Python code in a local file for the user who may not be familiar with the code,
    and may not know how to interpret the code. If not asked for specific content, provide a high-level overview of the code's functionality, key functions, and structure.
    """

    # Set the model from the global variables
    if 'MODEL' in globals():
        model = MODEL
    else:
        model = "llama3.1" # Default model

    try:
        response = ollama.chat(
            model=model,
            messages=[
                {'role': "system", 'content': sys_message},
                {'role': 'user', 'content': prompt},
            ]
        )
        # TODO: Try out different options for the ollama.chat function (e.g., temperature or num_predict, passed via `options`) to improve the quality of the description
        # For details, see the Ollama API documentation:
        # https://github.com/ollama/ollama/blob/main/docs/api.md
        # https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values

        summary = response['message']['content']
        return summary

    except Exception as e:
        return f"Error while describing the Python code in the file '{file_name}': {e}"

In [None]:
#| hide
describe_python_code.json_schema

{'type': 'function',
 'function': {'name': 'describe_python_code',
  'description': 'Describe the purpose of the Python code in a local Python file. This may involve summarizing the entire code, extracting key functions, or providing an overview of the code structure.',
  'parameters': {'type': 'object',
   'properties': {'file_name': {'type': 'string',
     'description': 'The file name of the local Python code file document to describe.'}},
   'required': ['file_name']}}}

In [None]:
#| export
assert isinstance(describe_python_code.json_schema, dict)
assert describe_python_code.json_schema['function']['name'] == "describe_python_code"
assert 'file_name' in describe_python_code.json_schema['function']['parameters']['properties'].keys()

In [None]:
#| hide
if "CONDA_PREFIX" in os.environ: # For local testing only
    !cp ../scholaris/core.py ../data/core.py
    print(describe_python_code("core.py"))

The Python code appears to be part of a larger project that provides an interface for interacting with AI models through the ollama library. The code defines several classes and functions for managing conversations, displaying conversation history, clearing conversation history, and printing tool information.

Here's a high-level overview of the purpose of the code:

1. **Conversation Management**: The `Assistant` class manages conversations by storing messages in a list (`self.messages`). It provides methods to show conversion history, clear conversion history, and print tool information.
2. **Function Wrapping**: The `add_to_class` function is used to register new functions as methods of the `Assistant` class. This allows for easy extension of the conversation management functionality.
3. **Conversation History Display**: The `show_conversion_history` method displays the conversation history, including user input, assistant responses, and tool output (if applicable).
4. **Tool Inform

In [None]:
#| hide
!rm ../data/modelfile.md
!rm ../data/core.py
!rm ../data/.*

zsh:1: no matches found: ../data/.*


## External data retrieval from NCBI, OpenAlex and Semantic Scholar

:::{.callout-note}
The following functions are used to convert article IDs between different formats and detect the type of an article ID based on its format.

- `convert_id`: Converts article IDs between PubMed Central, PubMed, DOI, and manuscript ID formats using the NCBI ID Converter API.

- `detect_id_type`: Analyzes a given string to determine if it's a PMID, PMCID, DOI, OpenAlex ID, Semantic Scholar ID, potential article title, or an unknown format.

- `id_converter_tool`: Combines the functionality of the above two functions to process a list of IDs, detecting their types and converting them using the NCBI API. This is callable by the LLM assistant.
:::
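
To make the request sent to the NCBI ID Converter concrete, the sketch below builds the query URL without contacting the service. The endpoint and parameter names are those used by `convert_id`; the helper name `build_idconv_url`, the email address, and the two IDs are placeholders chosen for illustration.

```python
from typing import List
from urllib.parse import urlencode

# Endpoint used by `convert_id`
BASE_URL = "https://www.ncbi.nlm.nih.gov/pmc/utils/idconv/v1.0/"


def build_idconv_url(ids: List[str], email: str, tool: str = "scholaris") -> str:
    """Return the GET URL for a batch of IDs (hypothetical helper)."""
    if len(ids) > 200:
        raise ValueError("At most 200 IDs are allowed per request.")
    params = {
        "tool": tool,
        "email": email,
        "ids": ",".join(i.strip() for i in ids),  # IDs are sent as one comma-separated value
        "format": "json",
    }
    return f"{BASE_URL}?{urlencode(params)}"


# Placeholder IDs and email address, for illustration only
print(build_idconv_url(["PMC1234567", "12345678"], email="user@example.com"))
```

Printing the URL before sending a request is a quick way to check that the IDs are joined and encoded as expected.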

In [None]:
#| export
import re
import requests

In [None]:
#| export
def convert_id(ids: List[str]) -> dict:
    """
    For any article(s) in PubMed Central, find all the corresponding PubMed IDs (PMIDs), digital object identifiers (DOIs), and manuscript IDs (MIDs).
    
    Args:
    ids (List[str]): A list of IDs to convert (max 200 per request).
    
    Returns:
    dict: The parsed JSON response with the conversion results, or a dictionary with an "error" key if no email address is available or the request fails.
    """
    
    if 'EMAIL' in globals():
        email = EMAIL
    else:
        try:
            email = os.environ['EMAIL']
        except KeyError:
            return {"error": "Please provide an email address"}

    # API endpoint
    base_url = "https://www.ncbi.nlm.nih.gov/pmc/utils/idconv/v1.0/"
    
    # Prepare the IDs string
    ids_string = ",".join(ids)
    
    # Prepare the parameters
    params = {
        "tool": "scholaris",
        "email": email,
        "ids": ids_string,
        "format": "json"
    }
    
    try:
        # Make the API request
        response = requests.get(base_url, params=params, timeout=30)  # Time out instead of waiting indefinitely
        response.raise_for_status()  # Raise an exception for bad status codes
        
        return response.json()
    
    except requests.exceptions.RequestException as e:
        return {"error": str(e)}

def detect_id_type(id_string: str) -> str:
    """
    Detect the type of the given ID or title.
    
    Args:
    id_string (str): The ID or title to detect.
    
    Returns:
    str: The detected type ('pmid', 'pmcid', 'doi', 'openalex', 'semantic_scholar', 'potential_title', or 'unknown').
    """
    if re.match(r'^\d{1,8}$', id_string):
        return 'pmid'
    elif re.match(r'^PMC\d+$', id_string):
        return 'pmcid'
    elif re.match(r'^10\.\d{4,9}/[-._;()/:A-Z0-9]+$', id_string, re.I): # Detect DOIs, case-insensitive
        return 'doi'
    elif re.match(r'^[WAIC]\d{2,}$', id_string):
        return 'openalex'
    elif re.match(r'^[0-9a-f]{40}$', id_string): 
        return 'semantic_scholar'
    elif re.match(r'^[A-Z][\w\s:,\-()]{10,150}[.?!]?$', id_string, re.I): # A simple heuristic for detecting titles
        return 'potential_title'
    else:
        return 'unknown'


@json_schema_decorator
def id_converter_tool(ids: List[str]) -> str:
    """
    For any article(s) in PubMed Central, find all the corresponding PubMed IDs (PMIDs), digital object identifiers (DOIs), and manuscript IDs (MIDs).
    Use this tool to convert a list of IDs, such as PMIDs, PMCIDs, or DOIs, and find the corresponding IDs for the same articles.
    
    Args:
    ids (str): A string with a comma-separated list of IDs to convert. Must be PMIDs, PMCIDs, or DOIs. The maximum number of IDs per request is 200.
    
    Returns:
    str: A JSON-formatted string containing the conversion results and the detected ID types.
    """

    if isinstance(ids, str):
        ids = ids.replace("https://doi.org/", "").replace("doi.org/", "") # Remove DOI URL prefixes, if present
    
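    # Convert the input string to a list of IDs, dropping surrounding whitespace and empty entries
    try:
        ids = [id_string.strip() for id_string in ids.split(",") if id_string.strip()]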
    except AttributeError:
        return json.dumps({"error": "Input must be a string of comma-separated IDs"})

    if len(ids) > 200:
        return json.dumps({"error": "Input IDs must be no more than 200"})
    
    # Detect ID types
    id_types = [detect_id_type(id_string) for id_string in ids]
    
    # Convert IDs
    conversion_result = convert_id(ids)
    
    # Check if conversion_result is already a dictionary (error case)
    if isinstance(conversion_result, dict):
        parsed_result = conversion_result
    else:
        # Parse the JSON string result
        try:
            parsed_result = json.loads(conversion_result)
        except json.JSONDecodeError:
            return json.dumps({"error": "Failed to parse conversion result"})
    
    # Prepare the result
    result = {
        "conversion_result": parsed_result,
        "detected_id_types": dict(zip(ids, id_types))
    }
    
    return json.dumps(result, indent=2)
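
The input handling in `id_converter_tool` can be illustrated in isolation. The sketch below is a minimal, self-contained version of that step. The names `normalize_id_string` and `MAX_IDS_PER_REQUEST` are hypothetical and not part of Scholaris, and trimming whitespace around each ID is an assumption of this sketch.

```python
from typing import List

MAX_IDS_PER_REQUEST = 200  # limit stated in the tool's docstring


def normalize_id_string(raw_ids: str) -> List[str]:
    """Split a comma-separated ID string into a clean list (hypothetical helper)."""
    # Drop DOI URL prefixes so that only the bare DOI remains
    cleaned = raw_ids.replace("https://doi.org/", "").replace("doi.org/", "")
    # Split on commas, trim surrounding whitespace, and skip empty entries
    ids = [part.strip() for part in cleaned.split(",") if part.strip()]
    if len(ids) > MAX_IDS_PER_REQUEST:
        raise ValueError(f"At most {MAX_IDS_PER_REQUEST} IDs are allowed per request")
    return ids


# Mixed input: a PMID, a mock PMCID with stray spaces, and a DOI given as a URL
print(normalize_id_string("36003377, PMC1234567 ,https://doi.org/10.3389/fimmu.2022.856497"))
```

Raising an exception for oversized input is a design choice of this sketch; the tool itself returns a JSON error message instead, which is easier for an LLM to relay to the user.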

#| hide
Check the JSON-schema for id_converter_tool:

In [None]:
#| hide
id_converter_tool.json_schema

{'type': 'function',
 'function': {'name': 'id_converter_tool',
  'description': 'For any article(s) in PubMed Central, find all the corresponding PubMed IDs (PMIDs), digital object identifiers (DOIs), and manuscript IDs (MIDs). Use this tool to convert a list of IDs, such as PMIDs, PMCIDs, or DOIs, and find the corresponding IDs for the same articles.',
  'parameters': {'type': 'object',
   'properties': {'ids': {'type': 'list',
     'description': 'A string with a comma-separated list of IDs to convert. Must be PMIDs, PMCIDs, or DOIs. The maximum number of IDs per request is 200.'}},
   'required': ['ids']}}}

#| hide

Test cases for the `id_converter_tool`, `detect_id_type` and `convert_id` functions are provided below.

In [None]:
#| hide
# Test the detect_id_type function with a list of mock IDs
test_cases = [
    ("12345678", "pmid"),
    ("PMC1234567", "pmcid"),
    ("10.1234/abcd.efg", "doi"),
    ("W1234567", "openalex"),
    ("1234567890abcdef1234567890abcdef12345678", "semantic_scholar"),
    ("This is a potential title", "potential_title"),
    ("abc123", "unknown")
]

# Run tests
for input_string, expected_output in test_cases:
    result = detect_id_type(input_string)
    # print(f"Input: {input_string}")
    # print(f"Expected: {expected_output}")
    # print(f"Result: {result}")
    # print("Pass" if result == expected_output else "Fail")
    # print()

# Count passed tests
passed_tests = sum(1 for input_string, expected_output in test_cases if detect_id_type(input_string) == expected_output)
print(f"Passed {passed_tests} out of {len(test_cases)} tests.")

Passed 7 out of 7 tests.


In [None]:
#| hide
if "CONDA_PREFIX" in os.environ: # For local testing only
    # Test with real IDs and by making a request to the API
    ids = ['38776920', '38701783', '38557723', '38422122', '38363432', '38175961', '38157855', 
    '38048195', '37875108', '37779520', '37448622', '37083451', '37047709', '36763636', 
    '36515678', '36736301', '36326697', '36342405', '36425144', '36094518', '36003377', 
    '35670811', '35091979', '35090163', '34427831', '34623332', '34183838', '34413140', 
    '34137790', '34214472', '34183371', '33876776', '33529170', '33497357', '33510449', 
    '32960813', '33296702', '32972995', '32163377', '31995689', '31784499', '31270247', 
    '31346092', '31046570', '30578925', '31231515', '31559014', '30578352', '30143481', 
    '29907691', '29537367', '28437470', '28069966', '27347375', '26720836', '25819983', 
    '24949794', '24332264', '24603545', '24391215', '24119913', '23543769', '23467413', 
    '23487427', '22535679', '21695123', '21050116', '20176798', '18582518', '18424515']

    ids_string = ",".join(ids)

    detected_id_types = json.loads(id_converter_tool(ids_string))["detected_id_types"]
    assert detected_id_types.values(), "No ID types were detected for any of the input IDs."

    data = convert_id(ids)
    assert data.get("status") == "ok", f"Conversion failed with status: {data.get('status', 'unknown')}"

    for record in data.get("records", []):
        assert record.get("pmid") or record.get("doi"), f"No PMID or DOI found for record: {record}"
        # print(f"PMID: {record.get('pmid', 'N/A')}, DOI: {record.get('doi', 'N/A')}, PMCID: {record.get('pmcid', 'N/A')}")
    print("All tests passed successfully!")

All tests passed successfully!


In [None]:
#| hide
if "CONDA_PREFIX" in os.environ: # For local testing only
    # Generate a list of actual DOIs and PMCIDs from the API response
    doi_list = []
    pmcid_list = []
    for record in data.get("records", []):
        doi_list.append(record.get("doi"))
        pmcid_list.append(record.get("pmcid"))
    doi_list = list(filter(None, doi_list)) # filter to remove None values
    pmcid_list = list(filter(None, pmcid_list)) # filter to remove None values

    for doi in doi_list:
        assert detect_id_type(doi) == "doi", f"Failed to detect DOI for {doi}"
    print("All DOIs detected successfully!")
    for pmcid in pmcid_list:
        assert detect_id_type(pmcid) == "pmcid", f"Failed to detect PMCID for {pmcid}"
    print("All PMCIDs detected successfully!")

All DOIs detected successfully!
All PMCIDs detected successfully!


:::{.callout-note}
The function `query_openalex_api` below queries the OpenAlex database for additional information about a given article, either by title, PubMed ID, PMC ID, or DOI.
:::

In [None]:
#| export
@json_schema_decorator
def query_openalex_api(query_param: str) -> str:
    """
    Retrieve metadata for a given article from OpenAlex, a comprehensive open-access catalog of global research papers.
    Use this tool to search the OpenAlex API by using the article title, the PubMed ID (PMID), the PubMed Central ID (PMCID) or the digital object identifier (DOI) of an article as the query parameter. 
    This tool returns the following metadata:
    - the OpenAlex ID
    - the digital object identifier (DOI) URL
    - Citation count
    - The open access status
    - URL to the open-access location for the work
    - Publication year
    - A URL to a website listing works that cite the article
    - The type of the article
    Use this tool only if an article title, PubMed ID or DOI is provided by the user or was extracted from a local PDF file and is present in the conversation history.
    
    Args:
        query_param (str): The article title, the PubMed ID (PMID), the PubMed Central ID (PMCID) or the digital object identifier (DOI) of the article to retrieve metadata for. May be provided by the user or extracted from a local PDF file and present in the conversation history.

    Returns:
        str: A JSON-formatted string including the search results from the OpenAlex database. If no results are found or the API query fails, an appropriate message is returned.
    """

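    # Use a module-level EMAIL if one is defined; otherwise fall back to the environment variable.
    # Assigning EMAIL unconditionally avoids an UnboundLocalError when the global variable exists.
    EMAIL = globals().get("EMAIL") or os.getenv("EMAIL", "")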

    # Validate the input
    if query_param is None or query_param == "":
        return json.dumps({"error": "The query parameter must be a non-empty string."})
    elif not isinstance(query_param, str):
        query_param = str(query_param)
    
    if query_param.startswith("https://doi.org/"):
        query_param = query_param.replace("https://doi.org/", "")
    elif query_param.startswith("doi.org/"):
        query_param = query_param.replace("doi.org/", "")

    # Constants
    base_url = "https://api.openalex.org/" # Define the base URL for the OpenAlex API
    
    # Initialize variables
    filter = ""

    if detect_id_type(query_param) == "potential_title":
        url = f"{base_url}works?"
        filter = f"title.search:{query_param}"
    elif detect_id_type(query_param) == "pmid":
        url = f"{base_url}works/pmid:{query_param}"
    elif detect_id_type(query_param) == "doi":
        url = f"{base_url}works/https://doi.org/{query_param}"
    elif detect_id_type(query_param) == "pmcid":
        url = f"{base_url}works/pmcid:{query_param}"
    else:
        return json.dumps({"error": "The query parameter must be a valid title, PMID, PMCID, or DOI."})
        
    # Set the query parameters
    params = {
        "mailto": EMAIL,
        "page": 1,
        "per-page": 5,
        "select": "id,doi,title,publication_year,cited_by_count,cited_by_api_url,open_access,type,type_crossref",
    }
    if filter:
        params["filter"] = filter # Add the filter parameter for title search

    response = requests.get(url, params=params)

    if response.status_code != 200:
        # Return a JSON-formatted string, in line with the declared return type
        return json.dumps({"error": f"Failed to query OpenAlex API. Status code: {response.status_code}"})

    raw_search_results = response.json()

    if filter:
        number_of_search_matches = raw_search_results['meta']['count']
        if len(raw_search_results['results']) == 0:
            return "No results found for the provided title."
        elif number_of_search_matches > 5:
            return "Error: The search results are more than 5. Please provide a correct title."
        else:
            formatted_results = []
            for result in raw_search_results['results']:
                formatted_results.append(result)
            return json.dumps(formatted_results, indent=2)

    
    if len(raw_search_results) == 0:
        return json.dumps({"search result": "None"})
    else:
        return json.dumps(raw_search_results, indent=2)
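
The routing from identifier type to OpenAlex endpoint can be summarized in a few lines. The sketch below only builds the request and does not contact the API. The helper `build_openalex_request` and the constant `OPENALEX_BASE_URL` are hypothetical names introduced for illustration; the endpoint patterns mirror those used in `query_openalex_api` above.

```python
from typing import Dict, Tuple
from urllib.parse import urlencode

OPENALEX_BASE_URL = "https://api.openalex.org/"


def build_openalex_request(id_type: str, value: str) -> Tuple[str, Dict[str, str]]:
    """Return (url, params) for an OpenAlex works lookup (hypothetical helper)."""
    params: Dict[str, str] = {}
    if id_type == "potential_title":
        # Titles are searched with a filter and may match several works
        url = f"{OPENALEX_BASE_URL}works"
        params["filter"] = f"title.search:{value}"
    elif id_type == "pmid":
        url = f"{OPENALEX_BASE_URL}works/pmid:{value}"
    elif id_type == "pmcid":
        url = f"{OPENALEX_BASE_URL}works/pmcid:{value}"
    elif id_type == "doi":
        url = f"{OPENALEX_BASE_URL}works/https://doi.org/{value}"
    else:
        raise ValueError(f"Unsupported identifier type: {id_type}")
    return url, params


# Show the final URL for a title search, with the filter percent-encoded
url, params = build_openalex_request("potential_title", "Example Article")
print(f"{url}?{urlencode(params)}")
```

Keeping the URL construction separate from the HTTP call makes this part testable without mocking `requests.get`.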

#| hide

Assert the correct format of the JSON schema...

In [None]:
#| hide
assert type(query_openalex_api.json_schema) == dict
assert query_openalex_api.json_schema['function']['name'] == "query_openalex_api"
assert 'query_param' in query_openalex_api.json_schema['function']['parameters']['properties'].keys()
print("All tests passed!")

All tests passed!


#| hide

A test case for the `query_openalex_api()` function is shown below.

In [None]:
#| hide
import json
from unittest.mock import patch

# Mock responses for different scenarios
mock_title_response = {
    "meta": {"count": 1},
    "results": [{
        "id": "https://openalex.org/W1234567890",
        "doi": "https://doi.org/10.1234/example",
        "title": "Example Article",
        "publication_year": 2023,
        "cited_by_count": 42,
        "cited_by_api_url": "https://api.openalex.org/works?filter=cites:W1234567890",
        "open_access": {"is_oa": True, "oa_status": "gold"},
        "type": "journal-article",
        "type_crossref": "journal-article"
    }]
}

mock_pmid_response = {
    "id": "https://openalex.org/W9876543210",
    "doi": "https://doi.org/10.5678/example",
    "title": "Example Article with a PMID",
    "publication_year": 2022,
    "cited_by_count": 10,
    "cited_by_api_url": "https://api.openalex.org/works?filter=cites:W9876543210",
    "open_access": {"is_oa": False, "oa_status": "closed"},
    "type": "journal-article",
    "type_crossref": "journal-article"
}

mock_doi_response = {
    "id": "https://openalex.org/W1357924680",
    "doi": "https://doi.org/10.9876/example",
    "title": "Example Book Chapter",
    "publication_year": 2021,
    "cited_by_count": 100,
    "cited_by_api_url": "https://api.openalex.org/works?filter=cites:W1357924680",
    "open_access": {"is_oa": True, "oa_status": "bronze"},
    "type": "book-chapter",
    "type_crossref": "book-chapter"
}

# Test case
def test_query_openalex_api():
    # Mock the requests.get function
    with patch('requests.get') as mock_get:
        # Test with a title
        mock_get.return_value.status_code = 200
        mock_get.return_value.json.return_value = mock_title_response
        result = query_openalex_api("Example Article")
        assert isinstance(result, str), "Result should be a string"
        parsed_result = json.loads(result)
        assert len(parsed_result) == 1, "Should return one result for title search"
        assert parsed_result[0]['title'] == "Example Article", "Title should match"

        # Test with a PMID
        mock_get.return_value.json.return_value = mock_pmid_response
        result = query_openalex_api("12345678")
        assert isinstance(result, str), "Result should be a string"
        parsed_result = json.loads(result)
        assert parsed_result['title'] == "Example Article with a PMID", "Title should match for PMID search"

        # Test with a DOI
        mock_get.return_value.json.return_value = mock_doi_response
        result = query_openalex_api("10.9876/example")
        assert isinstance(result, str), "Result should be a string"
        parsed_result = json.loads(result)
        assert parsed_result['doi'] == "https://doi.org/10.9876/example", "DOI should match"

        # Test with an empty identifier
        result = query_openalex_api("")
        assert "error" in json.loads(result), "Should return an error message for empty identifier"

        # Test with multiple results for title search
        mock_get.return_value.json.return_value = {"meta": {"count": 10}, "results": [{}] * 10}
        result = query_openalex_api("Common Title")
        assert result == "Error: The search results are more than 5. Please provide a correct title.", "Should return error for too many results"

        # Test with no results
        mock_get.return_value.json.return_value = {"meta": {"count": 0}, "results": []}
        result = query_openalex_api("Non-existent Article")
        assert result == "No results found for the provided title.", "Should return a message indicating no results found"

    print("All tests passed!")

test_query_openalex_api()


All tests passed!


#| hide

Below are tests that make calls to the OpenAlex API... 

In [None]:
#| hide
try:
    print(query_openalex_api("36003377"))
except Exception as e:
    print(f"Error: {e}")
    errors.append(e)

{
  "id": "https://openalex.org/W4290546253",
  "doi": "https://doi.org/10.3389/fimmu.2022.856497",
  "title": "Human leukocyte antigen class II gene diversity tunes antibody repertoires to common pathogens",
  "publication_year": 2022,
  "cited_by_count": 9,
  "cited_by_api_url": "https://api.openalex.org/works?filter=cites:W4290546253",
  "open_access": {
    "is_oa": true,
    "oa_status": "gold",
    "oa_url": "https://www.frontiersin.org/articles/10.3389/fimmu.2022.856497/pdf",
    "any_repository_has_fulltext": true
  },
  "type": "article",
  "type_crossref": "journal-article"
}


In [None]:
#| hide
title = "Human leukocyte antigen class II gene diversity tunes antibody repertoires to common pathogens"
args = [title, "36003377", "10.3389/fimmu.2022.856497", "doi.org/10.3389/fimmu.2022.856497", "https://doi.org/10.3389/fimmu.2022.856497"]
try:
    api_responses = [query_openalex_api(arg) for arg in args]
    # for response in api_responses:
    #     print(response[:300])
    assert all(title in response for response in api_responses), "All responses should contain the title."
    assert all(isinstance(response, str) for response in api_responses), "All responses should be strings."
    assert all(response == api_responses[1] for response in api_responses[1:]), "All responses should be the same, except for the title search."
except Exception as e:
    print(e)
    errors.append(e)

#| hide

The following code cells demonstrate the use of **Semantic Scholar's Academic Graph API** to retrieve information about scholarly articles.

In [None]:
#| hide
import time

In [None]:
#| hide
# Try loading the Semantic Scholar API key from the environment variables
try:
    SEMANTIC_SCHOLAR_API_KEY
except NameError:
    SEMANTIC_SCHOLAR_API_KEY = os.environ.get("SEMANTIC_SCHOLAR_API_KEY")

# Construct the URL to query the Semantic Scholar API by paper ID, and define headers with an API key
paper_id = "PMID:36003377"  # Example paper ID
academicgraph_base_url = "https://api.semanticscholar.org/graph/v1"
resource = "/paper/"
resource_path = f"{resource}{paper_id}"
fields = "title,year,tldr"
query_params = f"?fields={fields}"
url = f"{academicgraph_base_url}{resource_path}{query_params}"
headers = {'x-api-key': SEMANTIC_SCHOLAR_API_KEY}
print(url)

https://api.semanticscholar.org/graph/v1/paper/PMID:36003377?fields=title,year,tldr


In [None]:
#| hide
# Send the API request
try:
   response = requests.get(url, headers=headers)

   if response.status_code == 200:
      response_data = response.json()
      # Process and print the response data as needed
      print(json.dumps(response_data, indent=2))
   else:
      print(f"Request failed with status code {response.status_code}: {response.text}")
   time.sleep(1) # sleep for 1 second to avoid rate limiting issues
except Exception as e:
   print(f"Error: {e}")
   errors.append(e)

{
  "paperId": "4d29302308973a7a92dc3a9f1295b2ba761e2a77",
  "title": "Human leukocyte antigen class II gene diversity tunes antibody repertoires to common pathogens",
  "year": 2021,
  "tldr": {
    "model": "tldr@v2.0.0",
    "text": "It is demonstrated that multiple HLA class II alleles play a synergistic role in shaping the antibody repertoire, and HLA-DRB1 genotypes with specific antigens were identified, suggesting that HLAclass II gene polymorphisms confer specific humoral immunity against common pathogens."
  }
}


In [None]:
#| hide
# Query by title; this may return multiple results
title = "Human leukocyte antigen class II gene diversity tunes antibody repertoires to common pathogens"
url = f"{academicgraph_base_url}/paper/search?query={title}&fields={fields}"

try:
    response = requests.get(url, headers=headers)

    if response.status_code == 200:
        response_data = response.json()
        # print(json.dumps(response_data, indent=2))
        assert title in response_data['data'][0]['title'], "Title should match the query"
        print("Test passed successfully!")
    else:
        print(f"Request failed with status code {response.status_code}: {response.text}")
    time.sleep(1) # sleep for 1 second to avoid rate limiting issues
except Exception as e:
    print(f"Error: {e}")
    errors.append(e)

Test passed successfully!


In [None]:
#| hide
# Title search; this only returns the top result
title = "Human leukocyte antigen class II gene diversity tunes antibody repertoires to common pathogens"
url = f"{academicgraph_base_url}/paper/search/match?query={title}&fields={fields}"

try:
    response = requests.get(url, headers=headers)

    if response.status_code == 200:
        response_data = response.json()
        # print(json.dumps(response_data, indent=2))
        assert title in response_data["data"][0]["title"], "Title should match the query"
        print("Test passed successfully!")
    else:
        print(f"Request failed with status code {response.status_code}: {response.text}")
    time.sleep(1) # sleep for 1 second to avoid rate limiting issues
except Exception as e:
    print(f"Error: {e}")
    errors.append(e)

Test passed successfully!


:::{.callout-note}
The function `query_semantic_scholar_api()` below queries the Semantic Scholar database for additional information about a given article, either by title, PubMed ID, or DOI. To increase the rate limit, provide your own API key (see the documentation for more information).
:::

In [None]:
#| export
import time
import re

#| hide

Assert the correct format of the JSON schema...

In [None]:
#| export
@json_schema_decorator
def query_semantic_scholar_api(query_param: str) -> str:
    """
    Retrieve metadata for a given article from the Semantic Scholar Academic Graph (S2AG), a large knowledge graph of scientific literature that combines data from multiple sources.
    Use this tool to query the Semantic Scholar Graph API by using either the article title, the PubMed ID, or the digital object identifier (DOI) to retrieve the following metadata:
    - the title
    - the publication year
    - the abstract
    - a tldr (too long, didn't read) summary
    - the authors of the article
    - the URL to the open-access PDF version of the article, if available
    - the journal name
    - a url to the article on the Semantic Scholar website
    Use this tool only if an article title, PubMed ID or DOI is provided by the user or was extracted from a local PDF file and is present in the conversation history.
    
    Args:
        query_param (str): The article title, the PubMed ID, or the digital object identifier of the article to retrieve metadata for. May be provided by the user or extracted from a local PDF file and present in the conversation history. Do not include the 'https://doi.org/' prefix for DOIs, or keys such as 'DOI', 'PMCID' or 'PMID'. The tool will automatically detect the type of identifier provided.

    Returns:
        str: A JSON-formatted string including the search results from the Semantic Scholar database. If no results are found or the API query fails, an appropriate message is returned.
    """

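    # Get the API key from a module-level variable, if one is defined, or from the environment variables, and define the headers.
    # Assigning the name unconditionally ensures that a module-level key is actually used inside this function.
    SEMANTIC_SCHOLAR_API_KEY = globals().get("SEMANTIC_SCHOLAR_API_KEY") or os.environ.get("SEMANTIC_SCHOLAR_API_KEY")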
    headers = {'x-api-key': SEMANTIC_SCHOLAR_API_KEY}

    # Validate the input
    if query_param is None or query_param == "":
        return json.dumps({"error": "The query parameter must be a non-empty string."})
    elif not isinstance(query_param, str):
        query_param = str(query_param)
    
    # Clean the query parameter
    query_param = query_param.strip().lower() # Convert to lowercase and remove leading/trailing whitespace
    # Remove only the leading prefix (count=1), so that the same text elsewhere in the query is preserved
    if query_param.startswith("https://doi.org/"):
        query_param = query_param.replace("https://doi.org/", "", 1)
    elif query_param.startswith("doi.org/"):
        query_param = query_param.replace("doi.org/", "", 1)
    elif query_param.startswith("doi"):
        query_param = query_param.replace("doi", "", 1)
    elif query_param.startswith("pmid"):
        query_param = query_param.replace("pmid", "", 1)
    elif query_param.startswith("pmcid"):
        query_param = query_param.replace("pmcid", "", 1)
    if detect_id_type(query_param) != "potential_title":
        query_param = query_param.replace(":", "") # Remove any colons from the query parameter if it is not a title
        query_param = query_param.replace(" ", "") # Remove spaces from the query parameter if it is not a title

    # Constants
    academicgraph_base_url = "https://api.semanticscholar.org/graph/v1"
    fields = "title,year,authors,tldr,abstract,citationCount,openAccessPdf,journal,url" # The values to retrieve from the API


    # Construct the URL based on the identifier type
    if detect_id_type(query_param) == "potential_title":
        url = f"{academicgraph_base_url}/paper/search/match?query={query_param}&fields={fields}"
    elif detect_id_type(query_param) == "pmid":
        url = f"{academicgraph_base_url}/paper/PMID:{query_param}?fields={fields}"
    elif detect_id_type(query_param) == "doi":
        url = f"{academicgraph_base_url}/paper/DOI:{query_param}?fields={fields}"
    else:
        return json.dumps({"error": "The query parameter must be a valid title, PMID, or DOI."})
    
    response = requests.get(url, headers=headers) # Send the API request
    # print(response.url) # Uncomment for debugging
    time.sleep(2) # sleep for 2 seconds to avoid rate limiting issues
    
    # Check response status
    if response.status_code == 200:
        return json.dumps(response.json(), indent=2)
    else:
        return f"Error: Failed to query Semantic Scholar API. Status code: {response.status_code}"
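
The function above pauses for a fixed two seconds after each request. An alternative, shown here only as a sketch, is to retry with exponentially growing delays when the server reports rate limiting. The helper `call_with_backoff`, its parameters, and the assumption that rate limiting is signaled by HTTP status 429 (Too Many Requests) are illustrative and not part of Scholaris.

```python
import time
from typing import Callable, List, Tuple


def call_with_backoff(
    send: Callable[[], Tuple[int, str]],
    max_attempts: int = 4,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> Tuple[int, str]:
    """Call `send` until it returns a status other than 429 or the attempts run out."""
    status, body = send()
    for attempt in range(1, max_attempts):
        if status != 429:
            break
        # Wait 1x, 2x, 4x, ... the base delay before the next attempt
        sleep(base_delay * 2 ** (attempt - 1))
        status, body = send()
    return status, body


# Demonstration with a fake sender that is rate limited twice before succeeding.
# The sleep function is replaced by a recorder so that the example runs instantly.
fake_responses = iter([(429, "rate limited"), (429, "rate limited"), (200, "ok")])
recorded_delays: List[float] = []
final_status, final_body = call_with_backoff(lambda: next(fake_responses), sleep=recorded_delays.append)
print(final_status, final_body, recorded_delays)
```

In real use, `send` would wrap the `requests.get` call and return the status code and response text.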

In [None]:
#| hide
assert type(query_semantic_scholar_api.json_schema) == dict
assert query_semantic_scholar_api.json_schema['function']['name'] == "query_semantic_scholar_api"
assert 'query_param' in query_semantic_scholar_api.json_schema['function']['parameters']['properties'].keys()
print("All tests passed!")

All tests passed!


#| hide
A test case for the `query_semantic_scholar_api()` function is provided below.

In [None]:
#| hide
import json
from unittest.mock import patch

# Mock response for successful API call
mock_success_response = {
    "title": "Example Article",
    "year": 2023,
    "authors": [{"name": "John Doe"}, {"name": "Jane Smith"}],
    "tldr": {"text": "This is a summary of the article."},
    "abstract": "This is the abstract of the example article.",
    "citationCount": 42,
    "openAccessPdf": {"url": "https://example.com/paper.pdf"},
    "journal": {"name": "Example Journal"},
    "url": "https://www.semanticscholar.org/paper/example"
}

# Test case for query_semantic_scholar_api function
def test_query_semantic_scholar_api():
    # Mock the requests.get function
    with patch('requests.get') as mock_get:
        # Configure the mock to return a successful response
        mock_get.return_value.status_code = 200
        mock_get.return_value.json.return_value = mock_success_response

        # Test with a title
        result = query_semantic_scholar_api("Example Article")
        assert isinstance(result, str), "Result should be a string"
        
        parsed_result = json.loads(result)
        assert parsed_result == mock_success_response, "Returned JSON should match the mock response"

        # Test with a DOI
        result = query_semantic_scholar_api("10.1234/example.doi")
        assert isinstance(result, str), "Result should be a string"
        
        parsed_result = json.loads(result)
        assert parsed_result == mock_success_response, "Returned JSON should match the mock response"

    # Test with an empty identifier
    result = query_semantic_scholar_api("")
    assert "error" in json.loads(result), "Should return an error message for empty identifier"

    print("All tests passed!")

# Run the test
test_query_semantic_scholar_api()


All tests passed!


#| hide

Below are tests that make API calls to Semantic Scholar to retrieve information about an actual article in the database.

In [None]:
#| hide
try:
    print(query_semantic_scholar_api("36003377"))
    response = query_semantic_scholar_api("36003377")
    response = json.loads(response)
    assert response["paperId"] == "4d29302308973a7a92dc3a9f1295b2ba761e2a77", "PubMed ID should match the corresponding paper ID"
    print("Test passed!")
except Exception as e:
    print(e)
    errors.append(e)

{
  "paperId": "4d29302308973a7a92dc3a9f1295b2ba761e2a77",
  "url": "https://www.semanticscholar.org/paper/4d29302308973a7a92dc3a9f1295b2ba761e2a77",
  "title": "Human leukocyte antigen class II gene diversity tunes antibody repertoires to common pathogens",
  "abstract": "Allelic diversity of HLA class II genes may help maintain humoral immunity against infectious diseases. We investigated the relative contribution of specific HLA class II alleles, haplotypes and genotypes on the variation of antibody responses to a variety of common pathogens in a cohort of 800 adults representing the general Arab population. We found that classical HLA class II gene heterozygosity confers a selective advantage. Moreover, we demonstrated that multiple HLA class II alleles play a synergistic role in shaping the antibody repertoire. Interestingly, associations of HLA-DRB1 genotypes with specific antigens were identified. Our findings suggest that HLA class II gene polymorphisms confer specific humoral 

In [None]:
#| hide
title = "Human leukocyte antigen class II gene diversity tunes antibody repertoires to common pathogens"
args = [title, "36003377", "10.3389/fimmu.2022.856497", "doi.org/10.3389/fimmu.2022.856497", "https://doi.org/10.3389/fimmu.2022.856497"]
try: 
    api_responses = [query_semantic_scholar_api(arg) for arg in args]
    # for response in api_responses:
    #     print(response[:300])
    assert all(title in response for response in api_responses), "All responses should contain the title."
    assert all(isinstance(response, str) for response in api_responses), "All responses should be strings."
    assert all(response == api_responses[1] for response in api_responses[1:]), "All responses should be the same, except for the title search."
    print("All tests passed!")
except Exception as e:
    print(e)
    errors.append(e)

All tests passed!


In [None]:
#| export
@json_schema_decorator
def respond_to_generic_queries() -> str:
    """
    A function to respond to generic questions or queries from the user. Use this tool if no better tool is available.

    This tool does not take any arguments.

    Returns:
        str: A response to a generic question.
    """

    return "There is no specific tool available to respond to this query from the user. State your capabilities based on the system message or provide a response based on the conversation history."

In [None]:
#| hide
assert type(respond_to_generic_queries.json_schema) == dict
assert respond_to_generic_queries.json_schema['function']['name'] == "respond_to_generic_queries"
assert type(respond_to_generic_queries()) == str
print("All tests passed!")

All tests passed!


## Assistant class

:::{.callout-note}
The Assistant class below is defined to simplify the process of chat and tool use, along with a function to show responses.
:::
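
To make the idea of tool use concrete, the sketch below shows one way to dispatch a tool call by name, with a fallback when the requested tool is unknown. It is not the implementation used by the Assistant class. The functions `add_numbers`, `fallback_tool`, and `dispatch_tool_call` are hypothetical, and the shape of the `tool_call` dictionary (a `function` entry with `name` and `arguments`) is an assumption modeled on tool calls as returned by the Ollama chat API.

```python
import json
from typing import Any, Callable, Dict


def add_numbers(a: int, b: int) -> str:
    """Toy tool used only for this example."""
    return json.dumps({"sum": a + b})


def fallback_tool() -> str:
    """Toy stand-in for a generic fallback tool."""
    return "No specific tool is available for this query."


def dispatch_tool_call(tool_call: Dict[str, Any], tools: Dict[str, Callable[..., str]]) -> str:
    """Look up the requested tool by name and call it with the supplied arguments."""
    name = tool_call["function"]["name"]
    arguments = tool_call["function"].get("arguments") or {}
    func = tools.get(name)
    if func is None:
        # Unknown tool name: fall back to the generic tool
        return tools["fallback_tool"]()
    try:
        return func(**arguments)
    except TypeError as error:
        # Raised, for example, when the model supplies missing or unexpected arguments
        return json.dumps({"error": f"Invalid arguments for {name}: {error}"})


TOOLS = {"add_numbers": add_numbers, "fallback_tool": fallback_tool}
call = {"function": {"name": "add_numbers", "arguments": {"a": 2, "b": 3}}}
print(dispatch_tool_call(call, TOOLS))
```

Returning errors as JSON strings rather than raising keeps the conversation going: the message can be passed back to the model, which can then correct its arguments.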

In [None]:
#| export
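from typing import Any, Dict, Generator, Optional, Union

In [None]:
#| export
def show_response(response: Union[Dict[str, Any], Generator[Dict[str, Any], None, None], None]) -> Optional[str]:
    """
    Print the response from the LLM in a human-readable format and return its text content.

    Args:
        response (Dict[str, Any] or Generator[Dict[str, Any], None, None]): The response from the LLM.

    Returns:
        Optional[str]: The text content of the response, or None if there was no response.
    """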
    # ANSI escape code for blue and red text
    BLUE = "\033[94m"
    RED = "\033[91m"
    RESET = "\033[0m"

    if isinstance(response, dict):
        print(f"\n{BLUE}{response['message']['content']}{RESET}")
        return response['message']['content']

    elif isinstance(response, Generator):
        print("\n")
        _response = ""
        for chunk in response:
            _response += chunk['message']['content']
            print(f"{BLUE}{chunk['message']['content']}{RESET}", end='', flush=True)
        return _response
    
    elif response is None:
        print(f"\n{RED}No response from the LLM.{RESET}")
        return None
    
    else:
        raise ValueError(f"\n{RED}Invalid response type. Must be a dictionary or a generator.{RESET}")
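
The generator branch of `show_response` handles streamed responses, where the text arrives in chunks. The sketch below imitates such a stream with a plain Python generator to show how the chunks are accumulated. The names `fake_stream` and `collect_stream` are hypothetical, and the chunk layout (`chunk['message']['content']`) is taken from the function above.

```python
from typing import Any, Dict, Iterable, Iterator


def fake_stream() -> Iterator[Dict[str, Any]]:
    """Yield chunks shaped like streamed chat responses (mock data)."""
    for piece in ["Hello", ", ", "world"]:
        yield {"message": {"content": piece}}


def collect_stream(chunks: Iterable[Dict[str, Any]]) -> str:
    """Print each chunk as it arrives and return the concatenated text."""
    parts = []
    for chunk in chunks:
        text = chunk["message"]["content"]
        print(text, end="", flush=True)  # show partial output immediately
        parts.append(text)
    print()
    return "".join(parts)


full_text = collect_stream(fake_stream())
```

Collecting the parts in a list and joining them once at the end avoids repeated string concatenation for long responses.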

In [None]:
#| hide
# Test with dictionary input
mock_response = {
    "message": {
        "content": "This is a mock response for testing."
    }
}
response = show_response(mock_response)
assert response == "This is a mock response for testing."


This is a mock response for testing.


In [None]:
#| export
import ollama
from typing import Dict, Any, List

In [None]:
#| export
import sys
import io
from pathlib import Path

class Assistant:
    def __init__(self,
        sys_message: str or None = None, # The system message for the assistant; if not provided, a default message is used
        model: str = "llama3.1:latest", # The model to use for the assistant
        tools: Dict[str, Any] = { # The tools available to the assistant
           "get_file_names": get_file_names,
           "extract_text_from_pdf": extract_text_from_pdf,
           "get_titles_and_first_authors": get_titles_and_first_authors,
           "summarize_local_document": summarize_local_document,
           "describe_python_code": describe_python_code,
           "id_converter_tool": id_converter_tool,
           "query_openalex_api": query_openalex_api,
           "query_semantic_scholar_api": query_semantic_scholar_api,
           "respond_to_generic_queries": respond_to_generic_queries,
        },
        add_tools: Dict[str, Any] = {}, # Optional argument to add additional tools to the assistant, when initializing
        authentication: Optional[Dict[str, str]] = None, # Authentication credentials for API calls to external services
        dir_path: str = "../data", # The directory path to which the assistant has access on the local computer
        messages: List[Dict[str, str]] = []): # The conversation history
        
        self.sys_message = sys_message
        self.model = model
        self.tools = dict(tools) # Copy the dictionary so that the shared default argument is not mutated below
        self.tools.update(add_tools) # Add additional tools to the assistant, if provided
        if self.tools:
            self.tools["describe_tools"] = self.describe_tools # Add the describe_tools function to the tools list for the assistant, if the tools list is not empty
        self.authentication = authentication or {}
        self.dir_path = Path(dir_path).resolve()
        self.messages = messages

        # Set global variables
        global DIR_PATH
        DIR_PATH = self.dir_path
        global MODEL
        MODEL = self.model
        # TODO: Consider allowing the user to set different models for different tasks and tools
        # e.g. a model such as llama 3.1 for function calls, command-r-plus for summarization, aya for translation, etc.
        
        # ANSI escape codes, used for output formatting
        GREY = "\033[90m"
        BLUE = "\033[94m"
        RED = "\033[91m"
        RESET = "\033[0m"

        # Load the API keys from the environment variables or the authentication dictionary
        self.SEMANTIC_SCHOLAR_API_KEY = self.authentication.get("SEMANTIC_SCHOLAR_API_KEY")
        self.EMAIL = self.authentication.get("EMAIL")

        if not self.SEMANTIC_SCHOLAR_API_KEY:
            self.SEMANTIC_SCHOLAR_API_KEY = os.environ.get("SEMANTIC_SCHOLAR_API_KEY")
            if self.SEMANTIC_SCHOLAR_API_KEY:
                print(f"{GREY}Loaded Semantic Scholar API key from the environment variables.{RESET}")
        if not self.EMAIL:
            self.EMAIL = os.environ.get("EMAIL")
            if self.EMAIL:
                print(f"{GREY}Loaded email address from the environment variables.{RESET}")

        # Generate the default directory for storing data files if it does not exist
        if not os.path.exists(DIR_PATH):
            os.mkdir(DIR_PATH)
            print(f"{GREY}Created directory {DIR_PATH} for storing data files.{RESET}\n")
        else: 
            print(f"{GREY}A local directory {DIR_PATH} already exists for storing data files. No of files: {len(os.listdir(DIR_PATH))}{RESET}\n")

        # Set the default system message if not provided
        if not self.sys_message:
            self.sys_message ="""You are an AI assistant specialized in analyzing research articles.
        Your role is to provide concise, human-readable responses based on information from tools and conversation history.

        Key instructions:
        1. Use provided tools to gather information before answering.
        2. Interpret tool results and provide clear, concise answers in natural language.
        3. If you can't answer with available tools, state this clearly.
        4. Don't provide information if tool content is empty.
        5. Never include raw JSON, tool outputs, or formatting tags in responses.
        6. Format responses as plain text for direct human communication.
        7. Use clear formatting (e.g., numbered or bulleted lists) when appropriate.
        8. Provide article details (e.g., DOI, citation count) in a conversational manner.

        Act as a knowledgeable research assistant, offering clear and helpful information based on available tools and data.
        """

        # Check if the model is available
        downloaded_models = []
        for model in ollama.list()["models"]:
            downloaded_models.append((model["name"].replace(":latest", "")))
        assert self.model.replace(":latest", "") in downloaded_models, f"Model {self.model} not found. Please pull the latest version from the server."

        # Check if the selected model supports tool calling
        # for more information, visit https://ollama.com/blog/tool-support
        assert self.model.split(":")[0] in ["llama3.1", "command-r-plus", "mistral-nemo", "firefunction-v2"], f"Model {self.model} does not support tool calling. Please select a different model."

        if len(self.tools) == 0:
            print(f"\033[91mNo tools provided! Please add tools to the assistant.\033[0m")

    def __str__(self):
        return f"Assistant, powered by {self.model.split(':')[0]}"

    def __repr__(self):
        return self.__str__()

    def list_tools(self):
        "List the available tools in the assistant."
        for tool in self.tools.keys():
            print(tool)

    def get_tools_schema(self):
        "Return the JSON schema for the available tools."
        return [func.json_schema for func in self.tools.values()]

    @json_schema_decorator
    def describe_tools(self) -> str:
        """Use this tool when asked about the assistant's available tools and capabilities.

        Returns:
            str: A string with the descriptions of the available tools.
        """
        return f"Available tools are: {self.get_tools_schema()}\n State your capabilities based on the available tools in a conversational manner."
        # return f"{self.pprint_tools()}\n State your capabilities based on the available tools in a conversational manner."

    def chat(self, prompt: str, show_progress: bool = False, stream_response: bool = True, redirect_output: bool = False):
        """
        Start a conversation with the AI assistant.

        Args:
            prompt (str): The user's prompt or question.
            show_progress (bool): Whether to show the step-by-step progress of the function calls, including the tool calls and tool outputs. Default is False.
            stream_response (bool): Whether to stream the final response from the LLM. Default is True. Automatically set to True if redirect_output is True.
            redirect_output (bool): Whether to redirect the output to be compatible with st.write_stream. Default is False.

        Returns:
            str, Generator or None: The AI assistant's response as a string; a generator of text chunks if redirect_output is True; None if the model did not select a tool.
        """
        # At the start of the conversation, if no messages are provided, add the system message and user prompt
        if not self.messages:
            self.messages = [
                {'role': "system", 'content': self.sys_message},
                {'role': 'user', 'content': prompt},
            ]
        else:
            self.messages.append({'role': 'user', 'content': prompt})

        # Generate JSON schemas for the available tools
        tools_schema = self.get_tools_schema()

        # Make a request to the LLM to select a tool
        if show_progress: print("Selecting tools...\n")
        response = ollama.chat(
            model=self.model,
            messages=self.messages,
            tools=tools_schema,
            stream=False, # Set to False to avoid streaming the tool calls
        )

        # Add the model's response to the conversation history
        if response.get('message', {}).get('tool_calls'):
            if show_progress: print(response['message']['tool_calls']) # Show the selected tool calls
            self.messages.append(
                {'role': 'assistant', 'tool_calls': response['message']['tool_calls']}
                )
        else:
            # print("LLM response (not added to the conversation history):", response['message']['content']) # Uncomment for debugging
            self.messages.append(
                {'role': 'assistant', 'tool_calls': []}
                )
            print(f"\033[91mNo tool calls found in the response. Adding an empty tool_calls list to the conversation history. Aborting...\033[0m\n")
            return None # Abort the function if no tool calls are found in the response. Goal is to force the assistant to use a tool. We will generate a tool for generic responses.

        # Call the function if a tool is selected
        for tool in response['message']['tool_calls']:
            # print("Arguments:", tool['function']['arguments'])
            function_to_call = self.tools[tool['function']['name']]
            if show_progress: print(f"Calling {tool['function']['name']}() with arguments {tool['function']['arguments']}...\n")
            args = tool['function']['arguments']

            try:
                function_response = function_to_call(**args)
                # print(f"Function response type: {type(function_response)}\n") # Uncomment for debugging
                # print(f"Function response: {function_response}\n") # Uncomment for debugging
                function_response = str(function_response) if function_response is not None else ""
                # assert isinstance(function_response, str), "Function response must be a string."
            except Exception as e:
                function_response = f"Error: {e}"

            if show_progress: print(f"Function response:\n{function_response}\n")
            # Add the function response to the conversation history
            self.messages.append( 
                {
                    'role': 'tool',
                    'content': function_response,
                }
            ) 

        if redirect_output: # If the output is to be redirected...
            stream_response = True # always use streaming for compatibility with st.write_stream

        # Make a second request to the LLM with the tool output to generate a final response
        if show_progress: print("Generating final response...")
        response = ollama.chat(
            model=self.model,
            format="", # Set to empty string to avoid JSON formatting; If JSON formatting is needed, set to "json"
            messages=self.messages,
            stream=stream_response, # If True, the response is a generator
        )
        
        # Advanced parameters (optional):
        # keep_alive: controls how long the model will stay loaded into memory following the request (default: 5m)
        # options: additional model parameters listed in the documentation for the Modelfile such as temperature
        # for more information, visit:
        # https://github.com/ollama/ollama/blob/main/docs/api.md
        # https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
        
        if redirect_output:
            # Return a generator for st.write_stream
            def response_generator():
                full_content = ""
                for chunk in response:
                    content = chunk['message']['content']
                    full_content += content
                    yield content
                self.messages.append({'role': 'assistant', 'content': full_content})
            return response_generator()

        else:
            if isinstance(response, dict):
                content = response['message']['content']
                self.messages.append({'role': 'assistant', 'content': content})
                print(content)
                return content  # Return the content directly, not as a generator
                
            elif hasattr(response, '__iter__'):  # Check if it's iterable (for streaming)
                full_content = ""
                for chunk in response:
                    content = chunk['message']['content']
                    full_content += content
                    print(content, end='', flush=True)
                self.messages.append({'role': 'assistant', 'content': full_content})
                return full_content  # Return the full content, not as a generator
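
The tool-dispatch step of `chat` can be traced without a running model. The following is a minimal sketch under stated assumptions: the first model response is written by hand instead of coming from `ollama.chat`, and `count_words` is a hypothetical tool that is not part of the library.

```python
from typing import Any, Callable, Dict, List


def count_words(text: str) -> int:
    """Hypothetical tool: count whitespace-separated words."""
    return len(text.split())


tools: Dict[str, Callable[..., Any]] = {"count_words": count_words}

# Hand-written stand-in for the first model response (no LLM involved)
response = {
    "message": {
        "tool_calls": [
            {"function": {"name": "count_words",
                          "arguments": {"text": "tools make models useful"}}}
        ]
    }
}

messages: List[Dict[str, Any]] = [
    {"role": "user", "content": "How many words are in 'tools make models useful'?"}
]
# Record which tools the model selected
messages.append({"role": "assistant", "tool_calls": response["message"]["tool_calls"]})

for call in response["message"]["tool_calls"]:
    function_to_call = tools[call["function"]["name"]]
    args = call["function"]["arguments"]
    try:
        result = str(function_to_call(**args))
    except Exception as e:  # a failing tool becomes an error string
        result = f"Error: {e}"
    # The tool output is added to the history for the second request
    messages.append({"role": "tool", "content": result})
```

After the loop, the history holds the user prompt, the selected tool calls, and the tool output as a string, which is the context the second request uses to produce the final answer.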

The following code cells are for testing purposes only. They are not part of the final library.

In [None]:
#| hide
# Test the Assistant class' chat method with no additional arguments
try:
    assistant = Assistant()
    response = assistant.chat("Tell me about the tools you have available.")
    assert isinstance(response, str), f"Expected response type str but got {type(response)}"
except Exception as e:
    print(e)
    errors.append(e)

Loaded Semantic Scholar API key from the environment variables.
Loaded email address from the environment variables.
A local directory /Users/user2/GitHub/scholaris/data already exists for storing data files. No of files: 1

I can summarize research articles, provide information on various tools, extract text from PDF files, convert IDs, query OpenAlex API, and more! I've got a range of functions that can help with specific tasks.

* I can summarize local documents (PDF, markdown, or text) using the `summarize_local_document` function.
* The `get_titles_and_first_authors` function retrieves titles and first authors from research articles in PDF files.
* You can use `extract_text_from_pdf` to extract specific details from a PDF file, such as abstract, authors, conclusions, or user-specified content of other sections.
* If you have a list of IDs (PMIDs, PMCIDs, DOIs), the `id_converter_tool` function can convert them and find corresponding IDs for the same arti

In [None]:
#| hide
# Test the Assistant class' chat method with the stream_response argument set to False
try:
    assistant = Assistant()
    response = assistant.chat("Tell me about the tools you have available.", stream_response=False)
    assert isinstance(response, str), f"Expected response type str but got {type(response)}"
except Exception as e:
    print(e)
    errors.append(e)

Loaded Semantic Scholar API key from the environment variables.
Loaded email address from the environment variables.
A local directory /Users/user2/GitHub/scholaris/data already exists for storing data files. No of files: 1

I can perform various tasks based on the available tools. Here are some of my capabilities:

* I can summarize the content of local documents such as PDFs, markdown files, and text documents.
* I can describe the purpose of Python code in a local file.
* I can extract specific details from a PDF document using the `extract_text_from_pdf` tool.
* I can get titles and first authors from a directory of PDF files using the `get_titles_and_first_authors` function.
* I can query the OpenAlex API to retrieve metadata for an article based on its title, PMID, PMCID, or DOI.
* I can query the Semantic Scholar Academic Graph (S2AG) API to retrieve metadata for an article based on its title, PMID, or DOI.

These capabilities are based on the availabl

In [None]:
#| hide
try:
    # Test the Assistant class' chat method with the redirect_output argument set to True
    assistant = Assistant()
    response = assistant.chat("Tell me about the tools you have available.", redirect_output=True)
    assert isinstance(response, Generator), f"Expected response type to be a generator, but got {type(response)}"
except Exception as e:
    print(e)
    errors.append(e)

Loaded Semantic Scholar API key from the environment variables.
Loaded email address from the environment variables.
A local directory /Users/user2/GitHub/scholaris/data already exists for storing data files. No of files: 1
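
With `redirect_output=True`, `chat` returns a generator, and the assistant's reply is appended to the conversation history only after that generator has been fully consumed. The sketch below reproduces this pattern in isolation; `fake_stream` and the module-level `history` list are hypothetical stand-ins for the streamed response and `self.messages`.

```python
from typing import Any, Dict, Iterator, List

history: List[Dict[str, str]] = []  # stand-in for the conversation history


def fake_stream() -> Iterator[Dict[str, Any]]:
    """Hypothetical stand-in for a streamed chat response."""
    for piece in ["Streaming ", "works."]:
        yield {"message": {"content": piece}}


def response_generator(chunks: Iterator[Dict[str, Any]]) -> Iterator[str]:
    """Yield text chunks, then record the full reply in the history."""
    full_content = ""
    for chunk in chunks:
        content = chunk["message"]["content"]
        full_content += content
        yield content
    # Runs only once the consumer has exhausted the generator
    history.append({"role": "assistant", "content": full_content})


gen = response_generator(fake_stream())
first = next(gen)              # consume one chunk
length_before = len(history)   # nothing has been recorded yet
rest = "".join(gen)            # consume the remaining chunks
```

A caller that stops iterating early therefore leaves the history without the assistant's reply, which is worth keeping in mind when the generator is passed to a UI component.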



In [None]:
#| hide
try:
    # Test the Assistant class' initialization and the dir_path attribute; should be a valid Path object that exists
    assistant = Assistant()
    assert isinstance(assistant.dir_path, Path), f"Expected dir_path to be a Path object, but got {type(assistant.dir_path)}"
    assert assistant.dir_path.exists(), f"Expected dir_path to exist, but it does not."
except Exception as e:
    print(e)
    errors.append(e)

Loaded Semantic Scholar API key from the environment variables.
Loaded email address from the environment variables.
A local directory /Users/user2/GitHub/scholaris/data already exists for storing data files. No of files: 1



In [None]:
#| hide
try:
    # Test the Assistant class' chat method with a prompt that should call the get_file_names tool
    assistant = Assistant()
    assistant.chat("Which PDF files do you have access to in the local data directory?", show_progress=True)
    assert any('get_file_names' == item.get('tool_calls', [{}])[0].get('function', {}).get('name') for item in assistant.messages), "Function 'get_file_names' not found in the messages"
except Exception as e:
    print(e)
    errors.append(e)

Loaded Semantic Scholar API key from the environment variables.
Loaded email address from the environment variables.
A local directory /Users/user2/GitHub/scholaris/data already exists for storing data files. No of files: 1

Selecting tools...

[{'function': {'name': 'get_file_names', 'arguments': {'ext': 'pdf, txt'}}}]
Calling get_file_names() with arguments {'ext': 'pdf, txt'}...

Generating final response...
I have access to the following PDF files:

* jci.insight.144499.v2.pdf

#| hide

Continuing with the implementation of additional Assistant class methods...

In [None]:
#| export
def add_to_class(Class: type):
    """Register functions as methods in a class that has already been defined."""
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        return obj # Return the function so that the module-level name is not replaced by None
    return wrapper
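
To illustrate how this decorator attaches a function to an existing class, here is a small self-contained sketch. The `Notebook` class and the `describe` function are hypothetical, and the copy of the decorator used here returns the function so that the module-level name stays usable, which is a choice made for this sketch.

```python
def add_to_class(Class: type):
    """Register a function as a method of a class that already exists."""
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
        return obj  # choice of this sketch: keep the module-level name usable
    return wrapper


class Notebook:  # hypothetical toy class
    def __init__(self, title: str):
        self.title = title


@add_to_class(Notebook)
def describe(self) -> str:
    """Becomes available as Notebook.describe after decoration."""
    return f"Notebook: {self.title}"


nb = Notebook("core")
description = nb.describe()
```

This pattern lets a class be extended across several notebook cells, so each method can be introduced and tested next to the prose that explains it.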

In [None]:
#| export
@add_to_class(Assistant)
def show_conversion_history(self, show_function_calls: bool = False):
    """Display the conversation history.
    
    Args:
        show_function_calls (bool): Whether to show function calls and returns in the conversation history. Default is False.

    Returns:
        None
    """

    # ANSI escape codes for text formatting
    BLUE = "\033[94m"
    BOLD = "\033[1m"
    GREY = "\033[90m"
    RESET = "\033[0m"

    for message in self.messages:
        if message['role'] != 'system':
            if message['role'] == 'user':
                print(f"{BOLD}User:{RESET} {message['content']}\n")
            elif message['role'] == 'assistant':
                if 'content' in message and message['content']:
                    print(f"{BOLD}{BLUE}Assistant response:{RESET} {BLUE}{message['content']}{RESET}\n")
                if 'tool_calls' in message and message['tool_calls'] and show_function_calls:
                    print(f"{BOLD}{BLUE}Assistant function calls:{RESET} ", end='')
                    for tool in message['tool_calls']:
                        print(f"{BLUE}{tool['function']['name']}() with arguments {tool['function']['arguments']}{RESET}\n")
            elif message['role'] == 'tool' and show_function_calls:
                # Wrap a single string in a list without modifying the stored message
                contents = message['content']
                if isinstance(contents, str):
                    contents = [contents]
                for fn_return in contents:
                    print(f"{BOLD}{GREY}Function return:{RESET} {GREY}{fn_return}{RESET}\n")

In [None]:
#| export
@add_to_class(Assistant)
def clear_conversion_history(self):
    """Clear the conversation history."""
    self.messages = [{'role': "system", 'content': self.sys_message},]

In [None]:
#| hide
try:
    assistant.clear_conversion_history()
    assert assistant.show_conversion_history() is None
except Exception as e:
    print(e)
    errors.append(e)

In [None]:
#| export
@add_to_class(Assistant)
def pprint_tools(self):
    """Pretty-print the name and description of each available tool."""
    for tool in self.get_tools_schema():
        print(f"""* Tool name: {tool.get("function", {}).get("name", "No name available.")}
    Description: {tool.get("function", {}).get("description", "No description available.")}
        """)
    return None

In [None]:
#| hide
# Print all errors and exceptions, if any
if len(errors) > 0:
    for error in errors:
        print(error)
else:
    print("No errors and exceptions.")

No errors and exceptions.


In [None]:
#| hide
# remember to save the notebook before running this command
import nbdev; nbdev.nbdev_export()