In [45]:
import time
from typing import Tuple

import pandas as pd

from tqdm import tqdm
from google.cloud import aiplatform
from sklearn.metrics import f1_score
from vertexai.language_models import ChatModel, InputOutputTextPair
from utils import html_parsing_ncbi, html_parsing_n2c2, get_classification_report, get_macro_average_f1

# Google Cloud project whose Vertex AI API is used; replace the placeholder with your own ID.
PROJECT_ID = 'xxx-xxxxxxx-xxxxxx'
aiplatform.init(project=PROJECT_ID)

# Chat model handle shared by every inference function in this notebook.
chat_model = ChatModel.from_pretrained("chat-bison@002")

**Note**: You need to set up a Google Cloud project with the Vertex AI API enabled and replace `xxx-xxxxxxx-xxxxxx` with your own project ID. During the Google Cloud free trial, the Vertex AI quota for the PaLM 2 Bison model is limited (~60 calls per minute). If you encounter a quota-exceeded error, wait a minute and then continue from where you left off.

# 1. NER (Named Entity Recognition)

## 1.1 NCBI-Disease Dataset

### 1.1.1 Inference

In [3]:
# Test sentences and the pool of few-shot examples for NCBI-disease NER.
ncbi_ner_dir = 'data/NER/NCBI-disease'
ncbi_df = pd.read_csv(f'{ncbi_ner_dir}/test_200.csv')
ncbi_example_df = pd.read_csv(f'{ncbi_ner_dir}/examples.csv')

In [10]:
# Prompt for NCBI-disease NER. It is stored under a task-specific name because later
# sections rebind the generic `system_message` variable to other prompts.
NCBI_NER_SYSTEM_MESSAGE = """You are a helpful assistant to perform the following task.
"TASK: the task is to extract disease entities in a sentence."
"INPUT: the input is a sentence."
"OUTPUT: the output is an HTML that highlights all the disease entities in the sentence. The highlighting should only use HTML tags <span style=\"background-color: #FFFF00\"> and </span> and no other tags."
"""
# Kept so that code referring to the generic name keeps working.
system_message = NCBI_NER_SYSTEM_MESSAGE


def get_ner_ncbi_disease(sentence: str, shot: int = 0) -> Tuple[str, float]:
    """
    Ask the chat model to highlight disease entities in one sentence.

    The prompt is always `NCBI_NER_SYSTEM_MESSAGE`, so the result does not depend
    on which section last assigned `system_message`.

    Args:
        sentence: input sentence
        shot: number of few-shot examples; rows 0..shot-1 of `ncbi_example_df`
            (columns `text` and `label_text`) are used
    Returns:
        (response_text, elapsed_seconds): the reply text and the time spent in
        the `send_message` call
    """
    parameters = {
        "temperature": 0.0,
        "max_output_tokens": 2048,
        "top_p": 0.95,
        "top_k": 1,
    }

    examples = [
        InputOutputTextPair(
            input_text=ncbi_example_df['text'][i],
            output_text=ncbi_example_df['label_text'][i],
        )
        for i in range(shot)
    ]

    chat = chat_model.start_chat(
        context=NCBI_NER_SYSTEM_MESSAGE,
        examples=examples,
    )

    # perf_counter is monotonic, so the measured latency cannot be distorted
    # (or turn negative) when the system clock is adjusted mid-call.
    time_start = time.perf_counter()
    response = chat.send_message(sentence, **parameters)
    elapsed = time.perf_counter() - time_start

    return response.text, elapsed

In [None]:
# Every row costs four API calls (1-, 5-, 10- and 20-shot). Pause before every 15th
# row because PaLM2 has a quota limit.
ncbi_shot_settings = [('one', 1), ('five', 5), ('ten', 10), ('twenty', 20)]
for i in tqdm(range(len(ncbi_df))):
    if (i + 1) % 15 == 0:
        time.sleep(65)
    for shot_name, n_shot in ncbi_shot_settings:
        html_col = f'html_palm2_{shot_name}_shot'
        time_col = f'palm2_{shot_name}_shot_time'
        ncbi_df.loc[i, html_col], ncbi_df.loc[i, time_col] = get_ner_ncbi_disease(ncbi_df.loc[i, 'text'], n_shot)

In [15]:
# Drop row 89: no usable prediction exists for it because the request is blocked by the
# safety filter (the original note named Gemini, but this notebook queries PaLM 2).
# errors='ignore' keeps the cell re-runnable: once row 89 is gone, running it again is a
# no-op instead of raising KeyError.
ncbi_df = ncbi_df.drop(index=[89], errors='ignore')

### 1.1.2 Evaluation

In [None]:
# Optional: you can just load the llm output from the csv file instead of running the above code
# ncbi_df = pd.read_csv("data/NER/NCBI-disease/test_200_palm2_bison_results.csv")

In [16]:
# Parse each few-shot HTML column into label sequences. Every call also returns the
# ground-truth labels; only the ones from the first call are stored.
for shot_name in ('one', 'five', 'ten', 'twenty'):
    gt_labels, pred_labels = html_parsing_ncbi(ncbi_df, f'html_palm2_{shot_name}_shot')
    if shot_name == 'one':
        ncbi_df['gt_labels'] = gt_labels
    ncbi_df[f'palm2_{shot_name}_shot_labels'] = pred_labels

In [21]:
# Strict-matching F1 for each few-shot setting.
for shot_title in ('One', 'Five', 'Ten', 'Twenty'):
    report = get_classification_report(ncbi_df, 'gt_labels', f'palm2_{shot_title.lower()}_shot_labels', 'strict')
    print(f"F1 Score {shot_title} Shot (Strict): {report['default']['f1-score']}")

F1 Score One Shot (Strict): 0.5778275475923852
F1 Score Five Shot (Strict): 0.596888260254597
F1 Score Ten Shot (Strict): 0.601156069364162
F1 Score Twenty Shot (Strict): 0.6402116402116402


In [22]:
# Lenient-matching F1 for each few-shot setting.
for shot_title in ('One', 'Five', 'Ten', 'Twenty'):
    report = get_classification_report(ncbi_df, 'gt_labels', f'palm2_{shot_title.lower()}_shot_labels', 'lenient')
    print(f"F1 Score {shot_title} Shot (Lenient): {report['default']['f1-score']}")

F1 Score One Shot (Lenient): 0.6920492721164613
F1 Score Five Shot (Lenient): 0.6789250353606789
F1 Score Ten Shot (Lenient): 0.7023121387283238
F1 Score Twenty Shot (Lenient): 0.7566137566137566


In [23]:
# Mean latency of one chat request for each few-shot setting.
for shot_name in ('one', 'five', 'ten', 'twenty'):
    mean_seconds = ncbi_df[f'palm2_{shot_name}_shot_time'].mean()
    print(f"Average PaLM 2 {shot_name}-shot prediction time: {mean_seconds:.2f} seconds")

Average PaLM 2 one-shot prediction time: 1.27 seconds
Average PaLM 2 five-shot prediction time: 1.14 seconds
Average PaLM 2 ten-shot prediction time: 1.23 seconds
Average PaLM 2 twenty-shot prediction time: 1.54 seconds


In [24]:
# Persist predictions, parsed labels and timings; the row index is not written.
ncbi_results_path = 'data/NER/NCBI-disease/test_200_palm2_bison_results.csv'
ncbi_df.to_csv(ncbi_results_path, index=False)

## 1.2 2018 n2c2 Dataset

### 1.2.1 Inference

In [26]:
# Test sentences and the pool of few-shot examples for 2018 n2c2 NER.
n2c2_ner_dir = 'data/NER/2018_n2c2'
n2c2_df = pd.read_csv(f'{n2c2_ner_dir}/test_200.csv')
n2c2_example_df = pd.read_csv(f'{n2c2_ner_dir}/examples.csv')

In [27]:
# Prompt for 2018 n2c2 NER. It is stored under a task-specific name because other
# sections rebind the generic `system_message` variable to different prompts.
# NOTE(review): the wording says "disease entities" although the listed types are
# medication attributes; it looks copied from the NCBI prompt. Left unchanged because
# editing the prompt would change the model outputs reported below.
N2C2_NER_SYSTEM_MESSAGE = """You are a helpful assistant to perform the following task.
"TASK: the task is to extract disease entities in a sentence. The entity type includes Form, Route, Frequency, Dosage, Strength, Duration, Reason, Ade, Drug."
"INPUT: the input is a sentence."
"OUTPUT: the output is an HTML that highlights all the disease entities in the sentence in different colors: Form(#FF0000), Route(#FFA500), Frequency(#FFFF00), Dosage(#00FF00), Strength(#0000FF), Duration(#800080), Reason(#FFC0CB), Ade(#964B00), Drug(#808080) in hex code. \
         The highlighting should only use HTML tags <span style=\"background-color: #XXXXXX\"> and </span> and no other tags."
"""
# Kept so that code referring to the generic name keeps working.
system_message = N2C2_NER_SYSTEM_MESSAGE


def get_ner_2018_n2c2(sentence: str, shot: int = 0) -> Tuple[str, float]:
    """
    Ask the chat model to highlight the 2018 n2c2 entity types in one sentence.

    The prompt is always `N2C2_NER_SYSTEM_MESSAGE`, so the result does not depend
    on which section last assigned `system_message`. The few-shot examples are
    still read from the module-level `n2c2_example_df`, which the relation
    extraction section reloads from a different file; re-run the NER loading cell
    before calling this function again after that section has run.

    Args:
        sentence: input sentence
        shot: number of few-shot examples; rows 0..shot-1 of `n2c2_example_df`
            (columns `text` and `label_text`) are used
    Returns:
        (response_text, elapsed_seconds): the reply text and the time spent in
        the `send_message` call
    """
    parameters = {
        "temperature": 0.0,
        "max_output_tokens": 2048,
        "top_p": 0.95,
        "top_k": 1,
    }

    examples = [
        InputOutputTextPair(
            input_text=n2c2_example_df['text'][i],
            output_text=n2c2_example_df['label_text'][i],
        )
        for i in range(shot)
    ]

    chat = chat_model.start_chat(
        context=N2C2_NER_SYSTEM_MESSAGE,
        examples=examples,
    )

    # perf_counter is monotonic, so the measured latency cannot be distorted
    # (or turn negative) when the system clock is adjusted mid-call.
    time_start = time.perf_counter()
    response = chat.send_message(sentence, **parameters)
    elapsed = time.perf_counter() - time_start

    return response.text, elapsed

In [None]:
# Every row costs four API calls (1-, 5-, 10- and 20-shot). Pause before every 15th
# row because PaLM2 has a quota limit.
n2c2_ner_shot_settings = [('one', 1), ('five', 5), ('ten', 10), ('twenty', 20)]
for i in tqdm(range(len(n2c2_df))):
    if (i + 1) % 15 == 0:
        time.sleep(65)
    for shot_name, n_shot in n2c2_ner_shot_settings:
        html_col = f'html_palm2_{shot_name}_shot'
        time_col = f'palm2_{shot_name}_shot_time'
        n2c2_df.loc[i, html_col], n2c2_df.loc[i, time_col] = get_ner_2018_n2c2(n2c2_df.loc[i, 'text'], n_shot)

### 1.2.2 Evaluation

In [29]:
# Optional: you can just load the llm output from the csv file instead of running the above code
# n2c2_df = pd.read_csv("data/NER/2018_n2c2/test_200_palm2_bison_results.csv")

In [30]:
# Parse each few-shot HTML column into label sequences. Every call also returns the
# ground-truth labels; only the ones from the first call are stored.
for shot_name in ('one', 'five', 'ten', 'twenty'):
    gt_labels, pred_labels = html_parsing_n2c2(n2c2_df, f'html_palm2_{shot_name}_shot')
    if shot_name == 'one':
        n2c2_df['gt_labels'] = gt_labels
    n2c2_df[f'palm2_{shot_name}_shot_labels'] = pred_labels

In [33]:
# Macro-averaged strict-matching F1 for each few-shot setting.
for shot_title in ('One', 'Five', 'Ten', 'Twenty'):
    report = get_classification_report(n2c2_df, 'gt_labels', f'palm2_{shot_title.lower()}_shot_labels', 'strict')
    print(f"F1 Score {shot_title} Shot (Strict): {get_macro_average_f1(report)}")

F1 Score One Shot (Strict): 0.2483551418438149
F1 Score Five Shot (Strict): 0.4415615290491005
F1 Score Ten Shot (Strict): 0.5191967371610239
F1 Score Twenty Shot (Strict): 0.5447600480768185


In [34]:
# Macro-averaged lenient-matching F1 for each few-shot setting.
for shot_title in ('One', 'Five', 'Ten', 'Twenty'):
    report = get_classification_report(n2c2_df, 'gt_labels', f'palm2_{shot_title.lower()}_shot_labels', 'lenient')
    print(f"F1 Score {shot_title} Shot (Lenient): {get_macro_average_f1(report)}")

F1 Score One Shot (Lenient): 0.38963457779104854
F1 Score Five Shot (Lenient): 0.5801350723986359
F1 Score Ten Shot (Lenient): 0.6524081629707195
F1 Score Twenty Shot (Lenient): 0.6531610078079909


In [35]:
# Mean latency of one chat request for each few-shot setting.
for shot_name in ('one', 'five', 'ten', 'twenty'):
    mean_seconds = n2c2_df[f'palm2_{shot_name}_shot_time'].mean()
    print(f"Average PaLM 2 {shot_name}-shot prediction time: {mean_seconds:.2f} seconds")

Average PaLM 2 one-shot prediction time: 1.65 seconds
Average PaLM 2 five-shot prediction time: 2.20 seconds
Average PaLM 2 ten-shot prediction time: 2.06 seconds
Average PaLM 2 twenty-shot prediction time: 2.91 seconds


In [36]:
# Persist predictions, parsed labels and timings; the row index is not written.
n2c2_ner_results_path = 'data/NER/2018_n2c2/test_200_palm2_bison_results.csv'
n2c2_df.to_csv(n2c2_ner_results_path, index=False)

# 2. RE (Relation Extraction)

## 2.1 2018 n2c2 Dataset

### 2.1.1 Inference

In [38]:
# Test sentences and few-shot examples for 2018 n2c2 relation extraction.
# This rebinds `n2c2_df` and `n2c2_example_df`, replacing the NER frames loaded earlier.
n2c2_re_dir = 'data/RE/2018_n2c2'
n2c2_df = pd.read_csv(f'{n2c2_re_dir}/test_200.csv')
n2c2_example_df = pd.read_csv(f'{n2c2_re_dir}/examples.csv')

In [39]:
# Prompt for 2018 n2c2 relation extraction. It is stored under a task-specific name
# because other sections rebind the generic `system_message` variable to different prompts.
N2C2_RE_SYSTEM_MESSAGE = """You are a helpful assistant to perform the following task.
"TASK: the task is to classify relations for a sentence."
"INPUT: the input is a sentence where the entities are labeled within [E${X}] and [E${X}/] in a sentence, where X is an integer representing an unique entity."
"OUTPUT: your task is to select one out of the nine types of relations ('STRENGTH-DRUG', 'ROUTE-DRUG', 'FREQUENCY-DRUG', 'FORM-DRUG', 'DOSAGE-DRUG', 'REASON-DRUG', 'DURATION-DRUG', 'ADE-DRUG', and 'No relation')."
"""
# Kept so that code referring to the generic name keeps working.
system_message = N2C2_RE_SYSTEM_MESSAGE


def get_re_2018_n2c2(sentence: str, shot: int = 0) -> Tuple[str, float]:
    """
    Ask the chat model to classify the relation expressed in one sentence.

    The prompt is always `N2C2_RE_SYSTEM_MESSAGE`, so the result does not depend
    on which section last assigned `system_message`.

    Args:
        sentence: input sentence
        shot: number of few-shot examples; rows 0..shot-1 of `n2c2_example_df`
            (columns `text` and `labels`) are used
    Returns:
        (response_text, elapsed_seconds): the reply text and the time spent in
        the `send_message` call
    """
    parameters = {
        "temperature": 0.0,
        "max_output_tokens": 2048,
        "top_p": 0.95,
        "top_k": 1,
    }

    examples = [
        InputOutputTextPair(
            input_text=n2c2_example_df['text'][i],
            output_text=n2c2_example_df['labels'][i],
        )
        for i in range(shot)
    ]

    chat = chat_model.start_chat(
        context=N2C2_RE_SYSTEM_MESSAGE,
        examples=examples,
    )

    # perf_counter is monotonic, so the measured latency cannot be distorted
    # (or turn negative) when the system clock is adjusted mid-call.
    time_start = time.perf_counter()
    response = chat.send_message(sentence, **parameters)
    elapsed = time.perf_counter() - time_start

    return response.text, elapsed

In [None]:
# Every row costs four API calls (1-, 5-, 10- and 20-shot). Pause before every 15th
# row because PaLM2 has a quota limit.
n2c2_re_shot_settings = [('one', 1), ('five', 5), ('ten', 10), ('twenty', 20)]
for i in tqdm(range(len(n2c2_df))):
    if (i + 1) % 15 == 0:
        time.sleep(65)
    for shot_name, n_shot in n2c2_re_shot_settings:
        answer_col = f'palm2_{shot_name}_shot'
        time_col = f'palm2_{shot_name}_shot_time'
        n2c2_df.loc[i, answer_col], n2c2_df.loc[i, time_col] = get_re_2018_n2c2(n2c2_df.iloc[i]['text'], n_shot)

### 2.1.2 Evaluation

In [43]:
def _strip_wrapping_quotes(answer: str) -> str:
    """Remove surrounding whitespace and single quotes from an answer that contains a quote.

    Answers without any single quote are returned unchanged. The previous `x[1:-1]`
    slice removed the first and last character whatever they were, so an answer such
    as " 'ROUTE-DRUG'" lost its leading space but kept the opening quote.
    """
    return answer.strip().strip("'") if "'" in answer else answer


# get rid of ' ' if any
for shot_name in ('one', 'five', 'ten', 'twenty'):
    answer_col = f'palm2_{shot_name}_shot'
    n2c2_df[answer_col] = n2c2_df[answer_col].apply(_strip_wrapping_quotes)

In [70]:
# Relation label -> integer class id. 'No relation' comes first, so it wins when an
# answer mentions it together with another label.
label2digit = {
    label: digit
    for digit, label in enumerate((
        'No relation',
        'STRENGTH-DRUG',
        'ROUTE-DRUG',
        'FREQUENCY-DRUG',
        'FORM-DRUG',
        'DOSAGE-DRUG',
        'REASON-DRUG',
        'DURATION-DRUG',
        'ADE-DRUG',
    ))
}


def get_digit(x):
    """Return the class id of the first label (in `label2digit` order) contained in `x`.

    Matching is by substring, so extra characters around the label are tolerated.
    Returns 0 ('No relation') when no label occurs in `x`.
    """
    return next((digit for label, digit in label2digit.items() if label in x), 0)

In [71]:
# Convert the gold labels and every few-shot answer to integer class ids; answers that
# contain no known label count as 'No relation' (0).
# NOTE: `labels` is overwritten with integers, so this cell cannot be run twice on the
# same frame (get_digit expects strings).
n2c2_df['labels'] = n2c2_df['labels'].apply(get_digit)
for shot_name in ('one', 'five', 'ten', 'twenty'):
    n2c2_df[f'palm2_{shot_name}_shot_labels'] = n2c2_df[f'palm2_{shot_name}_shot'].apply(get_digit)

In [None]:
# Optional: you can just load the llm output from the csv file instead of running the above code
# n2c2_df = pd.read_csv("data/RE/2018_n2c2/test_200_palm2_bison_results.csv")

In [73]:
# Macro-averaged F1 over the relation classes for each few-shot setting.
y_true = n2c2_df['labels'].tolist()
for shot_title in ('One', 'Five', 'Ten', 'Twenty'):
    y_pred = n2c2_df[f'palm2_{shot_title.lower()}_shot_labels'].tolist()
    print(f"F1 Score {shot_title} Shot: {f1_score(y_true, y_pred, average='macro')}")

F1 Score One Shot: 0.30078351625254
F1 Score Five Shot: 0.3186707818930041
F1 Score Ten Shot: 0.40771004467502525
F1 Score Twenty Shot: 0.3189666977902272


In [74]:
# Mean latency of one chat request for each few-shot setting.
for shot_name in ('one', 'five', 'ten', 'twenty'):
    mean_seconds = n2c2_df[f'palm2_{shot_name}_shot_time'].mean()
    print(f"Average PaLM 2 {shot_name}-shot prediction time: {mean_seconds:.2f} seconds")

Average PaLM 2 one-shot prediction time: 0.41 seconds
Average PaLM 2 five-shot prediction time: 0.46 seconds
Average PaLM 2 ten-shot prediction time: 0.48 seconds
Average PaLM 2 twenty-shot prediction time: 0.56 seconds


In [75]:
# Persist answers, class ids and timings; the row index is not written.
n2c2_re_results_path = 'data/RE/2018_n2c2/test_200_palm2_bison_results.csv'
n2c2_df.to_csv(n2c2_re_results_path, index=False)

## 2.2 GAD

### 2.2.1 Inference

In [88]:
# Test sentences and the pool of few-shot examples for GAD relation extraction.
gad_dir = 'data/RE/GAD'
gad_df = pd.read_csv(f'{gad_dir}/test_200.csv')
gad_example_df = pd.read_csv(f'{gad_dir}/examples.csv')

In [104]:
# Prompt for GAD gene-disease relation classification. It is stored under a task-specific
# name because other sections rebind the generic `system_message` variable.
GAD_SYSTEM_MESSAGE = """You are a helpful assistant to perform the following task.
"TASK: the task is to classify relations between a disease and a gene for a sentence."
"INPUT: the input is a sentence where the disease is labeled as @DISEASE$ and the gene is labeled as @GENE$ accordingly in a sentence. "
"OUTPUT: your task is to select one out of the two types of relations (0 and 1) for the gene and disease without any explanation or other characters: 
        0, no relations 
        1, has relations"
"""
# Kept so that code referring to the generic name keeps working.
system_message = GAD_SYSTEM_MESSAGE


def get_re_gad(sentence: str, shot: int = 0) -> Tuple[str, float]:
    """
    Ask the chat model whether a sentence expresses a gene-disease relation.

    The prompt is always `GAD_SYSTEM_MESSAGE`, so the result does not depend on
    which section last assigned `system_message`.

    Args:
        sentence: input sentence
        shot: number of few-shot examples; rows 0..shot-1 of `gad_example_df`
            (columns `text` and `labels`) are used, with each label passed as a string
    Returns:
        (response_text, elapsed_seconds): the reply text and the time spent in
        the `send_message` call
    """
    parameters = {
        "temperature": 0.0,
        "max_output_tokens": 2048,
        "top_p": 0.95,
        "top_k": 1,
    }

    examples = [
        InputOutputTextPair(
            input_text=gad_example_df['text'][i],
            output_text=str(gad_example_df['labels'][i]),
        )
        for i in range(shot)
    ]

    chat = chat_model.start_chat(
        context=GAD_SYSTEM_MESSAGE,
        examples=examples,
    )

    # perf_counter is monotonic, so the measured latency cannot be distorted
    # (or turn negative) when the system clock is adjusted mid-call.
    time_start = time.perf_counter()
    response = chat.send_message(sentence, **parameters)
    elapsed = time.perf_counter() - time_start

    return response.text, elapsed

In [None]:
# Every row costs four API calls (1-, 5-, 10- and 20-shot). Pause before every 15th
# row because PaLM2 has a quota limit.
gad_shot_settings = [('one', 1), ('five', 5), ('ten', 10), ('twenty', 20)]
for i in tqdm(range(len(gad_df))):
    if (i + 1) % 15 == 0:
        time.sleep(65)
    for shot_name, n_shot in gad_shot_settings:
        answer_col = f'palm2_{shot_name}_shot'
        time_col = f'palm2_{shot_name}_shot_time'
        gad_df.loc[i, answer_col], gad_df.loc[i, time_col] = get_re_gad(gad_df.iloc[i]['text'], n_shot)

### 2.2.2 Evaluation

In [125]:
def _to_gad_label(answer: str) -> int:
    """Parse a model answer into an integer label, falling back to 0 ('No relation').

    Surrounding whitespace is stripped before the digit check. The previous
    `x[1:]` slice assumed exactly one leading character, so an answer of "1"
    (no leading space) or " 1\\n" was silently counted as 0.
    """
    digits = answer.strip()
    return int(digits) if digits.isdigit() else 0


# convert some strings to int while considering failed LLM outputs as 'No relation (0)'
for shot_name in ('one', 'five', 'ten', 'twenty'):
    gad_df[f'palm2_{shot_name}_shot_label'] = gad_df[f'palm2_{shot_name}_shot'].apply(_to_gad_label)

In [None]:
# Optional: you can just load the llm output from the csv file instead of running the above code
# gad_df = pd.read_csv("data/RE/GAD/test_200_palm2_bison_results.csv")

In [127]:
# Macro-averaged F1 over the two classes for each few-shot setting.
y_true = gad_df['labels'].tolist()
for shot_title in ('One', 'Five', 'Ten', 'Twenty'):
    y_pred = gad_df[f'palm2_{shot_title.lower()}_shot_label'].tolist()
    print(f"F1 Score {shot_title} Shot: {f1_score(y_true, y_pred, average='macro')}")

F1 Score One Shot: 0.4542477715117337
F1 Score Five Shot: 0.4480980012894906
F1 Score Ten Shot: 0.4563951230876757
F1 Score Twenty Shot: 0.46872985170857506


In [129]:
# Mean latency of one chat request for each few-shot setting.
for shot_name in ('one', 'five', 'ten', 'twenty'):
    mean_seconds = gad_df[f'palm2_{shot_name}_shot_time'].mean()
    print(f"Average PaLM 2 {shot_name}-shot prediction time: {mean_seconds:.2f} seconds")

Average PaLM 2 one-shot prediction time: 0.41 seconds
Average PaLM 2 five-shot prediction time: 0.40 seconds
Average PaLM 2 ten-shot prediction time: 0.46 seconds
Average PaLM 2 twenty-shot prediction time: 0.54 seconds


In [130]:
# Persist answers, integer labels and timings; the row index is not written.
gad_results_path = 'data/RE/GAD/test_200_palm2_bison_results.csv'
gad_df.to_csv(gad_results_path, index=False)