<a href="https://colab.research.google.com/github/datakind/hxl-metadata-prediction/blob/main/openai-hxl-prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

A data standard on platforms such as the [Humanitarian Data Exchange (HDX)](https://data.humdata.org/) is the [Humanitarian Exchange Language (HXL)](https://hxlstandard.org/), a column level set of attributes and tags and attributes which improve data interoperability and discovery. These tags and attributes are typically set by hand by data owners, which being a manual process can result in poor dataset coverage. Improving coverage through ML and AI techniques is desirable for faster and more efficient use of data in responding to Humanitarian disasters.

Previous work has focussed on fine tuning LLMs to complete tags and attrubutes, starting with the study [Predicting Metadata on Humanitarian Datasets with GPT 3](https://medium.com/towards-data-science/predicting-metadata-for-humanitarian-datasets-using-gpt-3-b104be17716d). This has yielded promosing results, but is constrained by the quality of training data and the HDX team have confirmed that basic tags related to location and dates are popular, more esoteric tags defined in [the standard](https://hxlstandard.org/standard/1-1final/tagging/) are not well represented.

This notebook fine-tunes an OpenAI model to test performance.

# Setup

1. Run notebook [generate-test-train-data.ipynb]([generate-test-train-data.ipynb]) to generate test and train data files for use in fine-tuning
2. Set `OPENAI_API_KEY` in file `.env` or in Colab secrets

In [None]:
!pip install pandas==2.2.2
!pip install openai==1.35.3
!pip install python-dotenv==1.0.1

In [27]:
import openai
import os
import time
import openai
from openai import OpenAI
import pandas as pd
import json
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import sys

from dotenv import load_dotenv
load_dotenv()

if os.getenv("OPENAI_API_KEY") is None:
  from google.colab import userdata
  OPENAI_API_KEY =  userdata.get('OPENAI_API_KEY')
else:
  OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

# If using Colab, this is where Google drive gets mounted. Otherwise leave blank
GOOGLE_BASE_DIR = "/content/drive/MyDrive/Colab"

# Where to save local data files
LOCAL_DATA_DIR = f"{GOOGLE_BASE_DIR}/hxl-metadata-prediction/data/"

# As generated by generate-test-train-data.ipynb
TRAINING_FILE = f"{LOCAL_DATA_DIR}/hxl_chat_prompts_train.jsonl"
TEST_FILE = f"{LOCAL_DATA_DIR}/hxl_chat_prompts_test.jsonl"

# Base model to fine-tune
MODEL = "gpt-4o-mini-2024-07-18"

pd.set_option('display.max_colwidth', 900)
pd.set_option('display.max_rows', 200)

client = OpenAI(
    api_key=OPENAI_API_KEY
)

In [7]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Analysis

## Check test/train

Let's do a sanity check to ensure the test set doesn't include data from orgnaizations in the traning set.

In [50]:
def read_prompts_file(filename):
  results = []
  with open(filename) as f:
    prompts = [json.loads(line) for line in f]
    for p in prompts:
      exp = p["messages"][-1]["content"]
      p["prompt"] = p["messages"][0:2]
      p["expected"] = exp
      results.append(p)
    results = pd.DataFrame(results)

    print(f"\nFound {len(results)} prompts")
    print(f"\nData providers {results['Data provider'].unique()}")

    results['tag'] = results['expected'].apply(lambda x: x.split('+')[0])
    tag_counts = results['tag'].value_counts()
    print("\n",tag_counts)

  return results

X_train = read_prompts_file(TRAINING_FILE)
X_test = read_prompts_file(TEST_FILE)

# Print data providers in X_test that are in X_train
common_providers = list(set(X_train["Data provider"]).intersection(set(X_test["Data provider"])))
if len(common_providers) == 0:
  print("No common Data providers")
else:
  print(f"Common providers: {common_providers} found in both Train and test sets!!!")
  sys.exit()

display(X_train)


Found 2919 prompts

Data providers ['international-organization-for-migration'
 'eth-zurich-weather-and-climate-risks' 'ifrc' 'ocha-fts' 'cerf' 'awsd'
 'insecurity-insight' 'ocha-sudan' 'ocha-niger' 'wfp' 'ocha-car' 'cred'
 'fao' 'water-point-data-exchange' 'ipc' 'interaction' 'ocha-somalia'
 'hdx' 'ocha-yemen' 'ocha-afghanistan' 'ourairports' 'hxl'
 'world-bank-group' 'unrwa-for-palestine-refugees-in-the-near-east'
 'ocha-fiss' 'ocha-ukraine' 'unhcr' 'ocha-ethiopia' 'ocha-haiti'
 'ocha-colombia' 'ocha-chad' 'ocha-nigeria' 'ocha-myanmar'
 'ocha-south-sudan' 'ocha-mali' 'ocha-dr-congo'
 'blavatnik-school-of-government-university-of-oxford' 'ocha-burkina'
 'un-ocha' 'ocha-ds' 'reliefweb' 'ocha-rosc' 'ocha-cameroon' 'unicef-rdc'
 'ocha-rosea' 'ocha-rolac' 'ocha-burundi' 'world-health-organization'
 'jcc' 'international-displacement-monitoring-centre-idmc' 'ocha-iraq'
 'ocha-opt' 'qcri' 'health-cluster' 'ocha-mozambique-hat' 'unicef-data'
 'unesco' 'ocha-libya' 'ocha-rowca' 'iati' 'clear'

Unnamed: 0,messages,HDX resource id,HDX dataset id,Data provider,Date created,Locations,URL,prompt,expected,tag
0,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/DRC - Baseline Assessment - M23 Crisis 13 - February 2024.xlsx'; column_name:'Total IDP HH'; examples:[319283]'}, {'role': 'assistant', 'content': '#affected+hh'}]",26ecc26f-74e7-46af-b450-8872dca0b63b,drc-displacement-idps-returnees-m23-crisis-north-kivu-province-baseline-assessment-iom-dtm,international-organization-for-migration,2023-10-16,COD,https://data.humdata.org/dataset/3554c498-660a-45cb-ada5-86a1fbcd6056/resource/26ecc26f-74e7-46af-b450-8872dca0b63b/download/adc_27jan-12_feb_update_public_v2.xlsx,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/DRC - Baseline Assessment - M23 Crisis 13 - February 2024.xlsx'; column_name:'Total IDP HH'; examples:[319283]'}]",#affected+hh,#affected
1,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/DRC - Baseline Assessment - M23 Crisis 13 - February 2024.xlsx'; column_name:'Total Returnees'; examples:[587705]'}, {'role': 'assistant', 'content': '#affected+ind+returnees'}]",26ecc26f-74e7-46af-b450-8872dca0b63b,drc-displacement-idps-returnees-m23-crisis-north-kivu-province-baseline-assessment-iom-dtm,international-organization-for-migration,2023-10-16,COD,https://data.humdata.org/dataset/3554c498-660a-45cb-ada5-86a1fbcd6056/resource/26ecc26f-74e7-46af-b450-8872dca0b63b/download/adc_27jan-12_feb_update_public_v2.xlsx,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/DRC - Baseline Assessment - M23 Crisis 13 - February 2024.xlsx'; column_name:'Total Returnees'; examples:[587705]'}]",#affected+ind+returnees,#affected
2,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/admin1-summaries-earthquake.csv'; column_name:'country_name'; examples:['Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan']'}, {'role': 'assistant', 'content': '#country'}]",dbf9b4bd-1321-4846-b6f0-4654509d3626,climada-earthquake-dataset,eth-zurich-weather-and-climate-risks,2024-02-23,AFG BFA BDI CMR CAF TCD COL COD ETH HTI MLI MOZ MMR NER NGA SOM SSD PSE SDN SYR UKR VEN YEM,https://data.humdata.org/dataset/744f4f0b-3172-4397-9609-5ec0b9d34fcb/resource/dbf9b4bd-1321-4846-b6f0-4654509d3626/download/admin1-summaries-earthquake.csv,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/admin1-summaries-earthquake.csv'; column_name:'country_name'; examples:['Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan', 'Afghanistan']'}]",#country,#country
3,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/admin1-summaries-earthquake.csv'; column_name:'latitude'; examples:['34.5527', '34.9568', '34.9619', '34.3033', '34.0121', '34.2743', '34.7693', '35.4474', '35.8025', '34.8046', '33.3211']'}, {'role': 'assistant', 'content': '#geo+lat'}]",dbf9b4bd-1321-4846-b6f0-4654509d3626,climada-earthquake-dataset,eth-zurich-weather-and-climate-risks,2024-02-23,AFG BFA BDI CMR CAF TCD COL COD ETH HTI MLI MOZ MMR NER NGA SOM SSD PSE SDN SYR UKR VEN YEM,https://data.humdata.org/dataset/744f4f0b-3172-4397-9609-5ec0b9d34fcb/resource/dbf9b4bd-1321-4846-b6f0-4654509d3626/download/admin1-summaries-earthquake.csv,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/admin1-summaries-earthquake.csv'; column_name:'latitude'; examples:['34.5527', '34.9568', '34.9619', '34.3033', '34.0121', '34.2743', '34.7693', '35.4474', '35.8025', '34.8046', '33.3211']'}]",#geo+lat,#geo
4,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/admin1-summaries-earthquake.csv'; column_name:'longitude'; examples:['69.3376', '69.6258', '68.887', '68.2174', '69.1631', '70.4529', '70.1638', '69.798', '68.9114', '67.2373', '67.812']'}, {'role': 'assistant', 'content': '#geo+lon'}]",dbf9b4bd-1321-4846-b6f0-4654509d3626,climada-earthquake-dataset,eth-zurich-weather-and-climate-risks,2024-02-23,AFG BFA BDI CMR CAF TCD COL COD ETH HTI MLI MOZ MMR NER NGA SOM SSD PSE SDN SYR UKR VEN YEM,https://data.humdata.org/dataset/744f4f0b-3172-4397-9609-5ec0b9d34fcb/resource/dbf9b4bd-1321-4846-b6f0-4654509d3626/download/admin1-summaries-earthquake.csv,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/admin1-summaries-earthquake.csv'; column_name:'longitude'; examples:['69.3376', '69.6258', '68.887', '68.2174', '69.1631', '70.4529', '70.1638', '69.798', '68.9114', '67.2373', '67.812']'}]",#geo+lon,#geo
...,...,...,...,...,...,...,...,...,...,...
2914,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/1501 Sierra Leone Health Centers.xlsx'; column_name:'Province'; examples:['Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern']'}, {'role': 'assistant', 'content': '#adm1'}]",f78dc606-04e2-4fb6-a7eb-9eb995c33f76,141121-sierra-leone-health-facilities,standby-task-force,2014-11-01,SLE,https://data.humdata.org/dataset/7453fb80-752b-4078-a892-d936f9846dab/resource/f78dc606-04e2-4fb6-a7eb-9eb995c33f76/download/1501-sierra-leone-health-centers.xlsx,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/1501 Sierra Leone Health Centers.xlsx'; column_name:'Province'; examples:['Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern', 'Eastern']'}]",#adm1,#adm1
2915,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/1501 Sierra Leone Health Centers.xlsx'; column_name:'District'; examples:['Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema']'}, {'role': 'assistant', 'content': '#adm2'}]",f78dc606-04e2-4fb6-a7eb-9eb995c33f76,141121-sierra-leone-health-facilities,standby-task-force,2014-11-01,SLE,https://data.humdata.org/dataset/7453fb80-752b-4078-a892-d936f9846dab/resource/f78dc606-04e2-4fb6-a7eb-9eb995c33f76/download/1501-sierra-leone-health-centers.xlsx,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/1501 Sierra Leone Health Centers.xlsx'; column_name:'District'; examples:['Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema', 'Kenema']'}]",#adm2,#adm2
2916,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/1501 Sierra Leone Health Centers.xlsx'; column_name:'Chiefdom'; examples:['Dama', 'Dama', 'Dama', 'Dama', 'Dama', 'Dama', 'Dama', 'Dama', 'Dama', 'Dama', 'Dama']'}, {'role': 'assistant', 'content': '#adm3'}]",f78dc606-04e2-4fb6-a7eb-9eb995c33f76,141121-sierra-leone-health-facilities,standby-task-force,2014-11-01,SLE,https://data.humdata.org/dataset/7453fb80-752b-4078-a892-d936f9846dab/resource/f78dc606-04e2-4fb6-a7eb-9eb995c33f76/download/1501-sierra-leone-health-centers.xlsx,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/1501 Sierra Leone Health Centers.xlsx'; column_name:'Chiefdom'; examples:['Dama', 'Dama', 'Dama', 'Dama', 'Dama', 'Dama', 'Dama', 'Dama', 'Dama', 'Dama', 'Dama']'}]",#adm3,#adm3
2917,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/Guinea health-facility master data.google sheet'; column_name:'Nom de la région'; examples:['Boke', 'Conakry', 'Faranah', 'Kankan', 'Kindia', 'Labe', 'Mamou', 'Nzerekore']'}, {'role': 'assistant', 'content': '#adm1+name'}]",5d2531d6-c03a-449b-afdd-52c07d687679,guinea-healthcare-master-data,ipc-cluster-guinea,2015-09-03,GIN,https://docs.google.com/spreadsheets/d/1x0MgLKLG3fxWBJ200VV5Fr67GqgSvYISefO-EYEp2wg/edit#gid=0,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with these details? resource_name='/content/drive/MyDrive/Colab/hxl-metadata-prediction/data/Guinea health-facility master data.google sheet'; column_name:'Nom de la région'; examples:['Boke', 'Conakry', 'Faranah', 'Kankan', 'Kindia', 'Labe', 'Mamou', 'Nzerekore']'}]",#adm1+name,#adm1


## Fine-tune

In [61]:
def fine_tune_model(train_file, model_name="gpt-4o-mini"):
    """
    Fine-tune an OpenAI model using training data.

    Args:
        prompt_file (str): The file containing the prompts to use for fine-tuning.
        model_name (str): The name of the model to fine-tune. Default is "davinci-002".

    Returns:
        str: The ID of the fine-tuned model.
    """

    # Create a version of the train_file jsonl which only has "messages"
    train_file_short = train_file.replace(".jsonl", "_short.jsonl")
    with open(train_file) as f:
        prompts = [json.loads(line) for line in f]
        prompts = [p["messages"] for p in prompts]
        with open(train_file_short, "w") as f:
            for p in prompts:
              row = {}
              row["messages"] = p
              f.write(json.dumps(row) + "\n")

    # Create a file on OpenAI for fine-tuning
    file = client.files.create(
        file=open(train_file_short, "rb"),
        purpose="fine-tune"
    )
    file_id = file.id
    print(f"Uploaded training file with ID: {file_id}")

    # Start the fine-tuning job
    ft = client.fine_tuning.jobs.create(
        training_file=file_id,
        model=model_name
    )
    ft_id = ft.id
    print(f"Fine-tuning job started with ID: {ft_id}")

    # Monitor the status of the fine-tuning job
    ft_result = client.fine_tuning.jobs.retrieve(ft_id)
    while ft_result.status != 'succeeded':
        print(f"Current status: {ft_result.status}")
        time.sleep(120)  # Wait for 60 seconds before checking again
        ft_result = client.fine_tuning.jobs.retrieve(ft_id)
        if 'failed' in ft_result.status.lower():
            sys.exit()

    print(f"Fine-tuning job {ft_id} succeeded!")

    # Retrieve the fine-tuned model
    fine_tuned_model = ft_result.fine_tuned_model
    print(f"Fine-tuned model: {fine_tuned_model}")

    return fine_tuned_model

In [None]:
model = fine_tune_model(TRAINING_FILE, model_name=MODEL)

Uploaded training file with ID: file-uTRS3liJunTqLytyxIbpFTc1
Fine-tuning job started with ID: ftjob-cnmkffixfVvVCefjnxT8fH2h
Current status: validating_files


In [21]:
model = fine_tune_model(TRAINING_FILE, model_name=MODEL)

Uploaded training file with ID: file-aceqVIqxkn1PIYnhKHnBKAok
Fine-tuning job started with ID: ftjob-ilc329fOocaog6LjcXfabKO7
Current status: validating_files
Current status: validating_files
Current status: queued
Current status: queued
Current status: queued
Current status: queued
Current status: queued
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Current status: running
Fine-tuning job ftjob-ilc329fOocaog6LjcXfabKO

'ft:gpt-4o-mini-2024-07-18:datakind::9oJXzcfa'

In [20]:
# model = "ft:gpt-4o-mini-2024-07-18:datakind::9oJXzcfa"
print(f"Fine-tuned model: {model}")

Fine-tuned model: ft:gpt-4o-mini-2024-07-18:datakind::9oJXzcfa


## Prediction Test

In [21]:
def make_chat_predictions(prompts, model, temperature=0.1, max_tokens=13):
  results = []
  for p in prompts:
    exp = p["messages"][-1]["content"]
    p["messages"] = p["messages"][0:2]
    completion = client.chat.completions.create(
      model=model,
      messages=p["messages"],
      temperature=temperature,
      max_tokens=max_tokens
    )
    actual = completion.choices[0].message.content
    res = {
        "prompt": p["messages"],
        "expected": exp,
        "actual": actual
    }
    results.append(res)
    #print(json.dumps(res, indent=4))

  results = pd.DataFrame(results)

  return results


def output_prediction_metrics(results, prediction_field="expected", actual_field="actual"):
    """
    Prints out model performance report if provided results as a dataframe, eg record ...

    'prompt': ' \'ISO3\' | "[\'RWA\', \'RWA\', \'RWA\', \'RWA\', \'RWA\', \'RWA\', \'RWA\', \'RWA\']"',
    'predicted': ' #country+code+iso3+v_iso3+',
    'predicted_post_processed': '#country+code',
    'actual_tags': '#country+code'

    Parameters
    ----------
    results : dataframe
        Dataframe of results
    prediction_field : str
        Field name of element with prediction. Handy for comparing raw and post-processed predictions.
    actual_field: str
        Field name of the actual result for comparison with prediction
    """
    y_test = []
    y_pred = []
    y_justtag_test = []
    y_justtag_pred = []
    for index, r in results.iterrows():
        if actual_field not in r and predicted_field not in r:
            print("Provided results do not contain expected values.")
            sys.exit()
        y_pred.append(r[prediction_field])
        y_test.append(r[actual_field])
        expected_tag = r[actual_field].split("+")[0]
        predicted_tag = r[prediction_field].split("+")[0]
        y_justtag_test.append(expected_tag)
        y_justtag_pred.append(predicted_tag)

    print(f"LLM results for {prediction_field}, {len(results)} predictions ...")
    print("\nJust HXL tags ...\n")
    print(f"Accuracy: {round(accuracy_score(y_justtag_test, y_justtag_pred),2)}")
    print(
        f"Precision: {round(precision_score(y_justtag_test, y_justtag_pred, average='weighted', zero_division=0),2)}"
    )
    print(
        f"Recall: {round(recall_score(y_justtag_test, y_justtag_pred, average='weighted', zero_division=0),2)}"
    )
    print(
        f"F1: {round(f1_score(y_justtag_test, y_justtag_pred, average='weighted', zero_division=0),2)}"
    )

    print(f"\nTags and attributes with {prediction_field} ...\n")
    print(f"Accuracy: {round(accuracy_score(y_test, y_pred),2)}")
    print(
        f"Precision: {round(precision_score(y_test, y_pred, average='weighted', zero_division=0),2)}"
    )
    print(
        f"Recall: {round(recall_score(y_test, y_pred, average='weighted', zero_division=0),2)}"
    )
    print(
        f"F1: {round(f1_score(y_test, y_pred, average='weighted', zero_division=0),2)}"
    )

    return

In [13]:
with open(TEST_FILE) as f:
    X_test = [json.loads(line) for line in f]

# Subsample
size = 50
X_test2 = X_test[-size:]

results = make_chat_predictions(X_test2, model)
#output_prediction_metrics(results)

display(results)

print("Done")

Unnamed: 0,prompt,expected,actual
0,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with...",#date,#date
1,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with...",#country+code,#country+code
2,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with...",#country+name,#country+name
3,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with...",#region+name,#adm1+name
4,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with...",#affected+infected,#affected+infected
5,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with...",#affected+killed,#affected+killed
6,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with...",#date,#date
7,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with...",#country+code,#country+code
8,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with...",#country+name,#country+name
9,"[{'role': 'system', 'content': '  You are an assistant that replies with HXL tags and attributes""  '}, {'role': 'user', 'content': 'What are the HXL tags and attributes for a column with...",#region+name,#adm1+name


Done


LLM results for expected, 50 predictions ...

Just HXL tags ...

Accuracy: 0.9
Precision: 0.95
Recall: 0.9
F1: 0.91

Tags and attributes with expected ...

Accuracy: 0.7
Precision: 0.85
Recall: 0.7
F1: 0.73
