# GeoLlama: Fine-Tuning on Toponym Extraction
This notebook walks through the process of producing the `GeoLlama_toponym` model, which extracts toponyms from text. We fine-tune the model on the LGL and GeoVirus datasets, which have been pre-processed and combined into the `llama3_toponym_extraction_ft.json` dataset.

We will use Unsloth to optimize the fine-tuning process. This makes training faster and more memory-efficient, although it does mean the training can only be done on a Linux machine. A Google Colab GPU instance is optimal for this.

Fine-tuning requires ~20GB of GPU RAM, depending on the GPU model.

## 1. Setting up the Unsloth package

In [31]:
# standard library imports
import random
import json
import os
# third party imports
from google.colab import drive
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import SFTTrainer
from transformers import TrainingArguments
from huggingface_hub import login
from datasets import Dataset

import pandas as pd
import numpy as np
import torch

# drive.mount('/content/drive')

In [3]:
os.environ["WANDB_DISABLED"] = "true"

In [None]:
max_seq_length = 2048
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/mistral-7b-bnb-4bit",
    "unsloth/mistral-7b-instruct-v0.2-bnb-4bit",
    "unsloth/llama-2-7b-bnb-4bit",
    "unsloth/llama-2-13b-bnb-4bit",
    "unsloth/codellama-34b-bnb-4bit",
    "unsloth/tinyllama-bnb-4bit",
    "unsloth/llama-3-8b-bnb-4bit",
    "unsloth/llama-3-70b-bnb-4bit",
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit", # Choose ANY! eg teknium/OpenHermes-2.5-Mistral-7B
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)




==((====))==  Unsloth: Fast Llama patching release 2024.7
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 7.5. CUDA Toolkit = 12.1.
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.25.post1. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!



In [None]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    use_gradient_checkpointing = "unsloth", # 4x longer contexts auto supported!
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)



Unsloth 2024.7 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
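As a rough check on the LoRA configuration above: each adapted weight matrix with `in` input features and `out` output features gains two low-rank factors holding `r * (in + out)` parameters in total. The sketch below sums this over the seven target modules and 32 layers. The layer dimensions are assumptions taken from the public Llama-3-8B configuration (hidden size 4096, MLP width 14336, 8 key/value heads of size 128); they do not appear in this notebook.

```python
# Assumed Llama-3-8B dimensions (from the public model config, not from this notebook).
HIDDEN = 4096          # hidden_size
INTERMEDIATE = 14336   # intermediate_size (MLP width)
KV_DIM = 8 * 128       # num_key_value_heads * head_dim
NUM_LAYERS = 32
RANK = 16              # the `r` passed to get_peft_model

# (in_features, out_features) for each module listed in target_modules
MODULE_SHAPES = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """Parameters added by one LoRA adapter: A is (rank x in), B is (out x rank)."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in MODULE_SHAPES.values())
total_trainable = per_layer * NUM_LAYERS
print(f"{total_trainable:,}")
```

Under those assumed dimensions the total comes to 41,943,040, the same figure the Unsloth training banner in Section 3 reports as the number of trainable parameters.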


## 2. Data Pre-processing
We will use the `llama3_toponym_extraction_ft.json` dataset to fine-tune and test the model. The dataset has 817 examples, which is not a large number, but the relative simplicity of the task and the requirement for a highly structured output mean we don't need an extremely large training set.

We'll use 750 samples from the dataset as training data and use the remaining 67 samples to test the model.

In [6]:
# load the data
with open('../data/fine_tuning_data/llama3_toponym_extraction_ft.json') as f:
    data = json.load(f)

In [7]:
# Set up the prompt template for model training

RAG_prompt = """Below is an instruction that describes a task, paired with an input that provides a specfic example which the task should be applied to. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}
"""

EOS_TOKEN = tokenizer.eos_token # Must add EOS_TOKEN
def formatting_prompts_func(examples):
    instructions = examples["instruction"]
    inputs       = examples["input"]
    outputs      = examples["response"]
    texts = []
    for instruction, input_text, output in zip(instructions, inputs, outputs):
        # Must add EOS_TOKEN, otherwise your generation will go on forever!
        text = RAG_prompt.format(instruction, input_text, output) + EOS_TOKEN
        texts.append(text)
    return { "text" : texts, }

ft_data = {"items":data}
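To make the formatting step concrete, here is a minimal, self-contained sketch of what one formatted training string looks like. The template is a shortened stand-in for `RAG_prompt`, the EOS string is a placeholder for `tokenizer.eos_token`, and the instruction, input and response values are hypothetical rather than taken from the dataset.

```python
# Shortened stand-in for the notebook's template (the real one has a longer preamble).
TEMPLATE = """### Instruction:
{}

### Input:
{}

### Response:
{}
"""
EOS = "<|end_of_text|>"  # placeholder; the notebook reads tokenizer.eos_token


def format_batch(examples: dict) -> dict:
    """Fill the template for each row of a batched example and append the EOS marker."""
    texts = [
        TEMPLATE.format(instruction, input_text, response) + EOS
        for instruction, input_text, response in zip(
            examples["instruction"], examples["input"], examples["response"]
        )
    ]
    return {"text": texts}


# Hypothetical batch with a single row, in the column layout the notebook expects.
batch = {
    "instruction": ["Extract all toponyms from the text."],
    "input": ["Flooding hit Fargo and Moorhead."],
    "response": ["{'toponyms': ['Fargo', 'Moorhead']}"],
}
formatted = format_batch(batch)
print(formatted["text"][0])
```

The braces inside the response are safe here because the response is passed as an argument to `str.format` rather than being part of the format string itself.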

In [10]:
# Set up the test/train split
random.seed(7723)
trn_idxs = random.sample(range(len(data)), 750)
val_idxs = [x for x in range(len(data)) if x not in trn_idxs]
trn_data = [data[i] for i in trn_idxs]
val_data = [data[i] for i in val_idxs]

trn_dataset = Dataset.from_pandas(pd.DataFrame(trn_data), split="train")
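A quick sanity check of the split can be sketched on placeholder records. This assumes 817 records, as stated above, and uses a set for the membership test. The `records` list is a stand-in for the real JSON data.

```python
import random

# Placeholder records standing in for the 817 entries of the real JSON file.
records = [{"id": i} for i in range(817)]

random.seed(7723)
train_idxs = random.sample(range(len(records)), 750)
train_idx_set = set(train_idxs)  # set membership keeps the filter linear in len(records)
test_idxs = [i for i in range(len(records)) if i not in train_idx_set]

# The two index lists should partition the dataset with no overlap.
assert len(train_idxs) == 750
assert len(test_idxs) == 67
assert train_idx_set.isdisjoint(test_idxs)
print(len(train_idxs), len(test_idxs))
```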

In [None]:
# format data for input into the model
trn_dataset = trn_dataset.map(formatting_prompts_func, batched = True,)


## 3. Training
This section shows the process of model training. Training takes ~45 minutes on a T4 GPU. We will use the `SFTTrainer` class to train the model, setting the parameters as shown below.

In [None]:
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = trn_dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Can make training 5x faster for short sequences.
    args = TrainingArguments(
        num_train_epochs = 1,
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
    ),
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).



In [None]:
#@title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = NVIDIA L4. Max memory = 22.168 GB.
5.594 GB of memory reserved.


In [None]:
# train the model
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 3,600 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 450
 "-____-"     Number of trainable parameters = 41,943,040


Step,Training Loss
1,1.9599
2,1.7759
3,1.7483
4,1.8338
5,1.6697
6,1.6293
7,1.7065
8,1.5009
9,1.46
10,1.4191
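The "Total batch size" and "Total steps" figures in the banner above follow from the training arguments. The sketch below reproduces that arithmetic from the values printed in the banner. It is an approximation of how the trainer counts optimizer steps, and the exact rounding may differ when the number of examples is not a multiple of the effective batch size.

```python
import math

num_examples = 3600               # "Num examples" in the banner
per_device_batch_size = 2         # per_device_train_batch_size
gradient_accumulation_steps = 4
num_gpus = 1
num_epochs = 1

# Gradients from several small batches are accumulated before each optimizer step.
effective_batch_size = per_device_batch_size * gradient_accumulation_steps * num_gpus
steps_per_epoch = math.ceil(num_examples / effective_batch_size)
total_steps = steps_per_epoch * num_epochs
print(effective_batch_size, total_steps)  # 8 450
```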


In [None]:

#@title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory         /max_memory*100, 3)
lora_percentage = round(used_memory_for_lora/max_memory*100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.")
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")



2730.1956 seconds used for training.
45.5 minutes used for training.
Peak reserved memory = 8.967 GB.
Peak reserved memory for training = 3.373 GB.
Peak reserved memory % of max memory = 40.45 %.
Peak reserved memory for training % of max memory = 15.216 %.


In [None]:
# log in to huggingface so the model can be saved
login()


In [None]:
# save to huggingface
model.push_to_hub("JoeShingleton/GeoLlama-3.1-8b-toponym")


Saved model to https://huggingface.co/JoeShingleton/GeoLlama


## 4. Testing the model
We will now test the model with our 67 test samples. We'll use the `TopoModel` class from the `geo_llama` package as a wrapper for the fine-tuned model. This class handles input processing and cleans up the model's responses.

We'll assess accuracy as the proportion of unique toponyms from the text that the model finds, reported as precision, recall and F1 for each article. Note: we do not expect the model to identify the location of each toponym within the text, so span overlap is not assessed.
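As a worked illustration of scoring on unique toponyms, the sketch below compares two small hypothetical lists as sets, so repeated mentions count once. The function name and the example lists are illustrative assumptions; this is not the `geo_llama` package's own scoring code.

```python
def unique_toponym_scores(true_toponyms: list, pred_toponyms: list) -> dict:
    """Precision, recall and F1 over unique toponym strings (exact, case-sensitive match)."""
    true_set, pred_set = set(true_toponyms), set(pred_toponyms)
    tp = len(true_set & pred_set)
    precision = tp / len(pred_set) if pred_set else 0.0
    recall = tp / len(true_set) if true_set else 0.0
    total = precision + recall
    f1 = 2 * precision * recall / total if total else 0.0
    return {"precision": precision, "recall": recall, "F1": f1}


# Hypothetical article: 'Fargo' is mentioned twice but counts once.
true = ["Fargo", "Moorhead", "Minnesota", "Fargo"]
pred = ["Fargo", "Minnesota", "Nebraska"]
scores = unique_toponym_scores(true, pred)
print(scores)
```

Here two of the three unique predictions are correct and two of the three unique true toponyms are found, so precision, recall and F1 are all 2/3. Because matching is by exact string, variants such as `COLUMBUS` and `Columbus` would be treated as different toponyms unless they are normalized first.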

In [21]:
from geo_llama.model import TopoModel

# set up the model
model = TopoModel(model_name='JoeShingleton/GeoLlama-3.1-8b-toponym',
                  prompt_path='../data/prompt_templates/prompt_template.txt',
                  instruct_path='../data/prompt_templates/topo_instruction.txt',
                  input_path=None,
                  config_path='../data/config_files/model_config.json')

==((====))==  Unsloth 2024.8: Fast Llama patching. Transformers = 4.43.3.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2+cu121. CUDA = 7.5. CUDA Toolkit = 12.1.
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.25.post1. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [53]:
from tqdm import tqdm

outputs = []
for d in tqdm(val_data):
  prompt = model.toponym_prompt(d['input'])
  tst_outputs = model.get_output(prompt['text'], d['input'])
  outputs.append(tst_outputs)


  1%|▏         | 1/67 [00:02<02:18,  2.10s/it]


{'toponyms': ['Pointe Coupee', 'St. James']}



  3%|▎         | 2/67 [00:08<04:51,  4.49s/it]


{'toponyms': ['Minnesota', 'Fargo', 'Moorhead', 'Nebraska', 'Fargo-Moorhead', 'Red River Valley', 'Oakport', 'North Dakota', 'Alexandria']}



  4%|▍         | 3/67 [00:12<04:40,  4.39s/it]


{'toponyms': ['Augusta', 'Athens-Clarke', 'Clarke County']}



  6%|▌         | 4/67 [00:14<03:46,  3.59s/it]


{'toponyms': ['Athens', 'McMinn County', 'Etowah']}



  7%|▋         | 5/67 [00:16<02:58,  2.88s/it]


{'toponyms': ['Madison County']}



  9%|▉         | 6/67 [01:16<22:36, 22.24s/it]


{'toponyms': ['Arizona', 'Georgia', 'ATLANTA', 'Oregon', 'Washington', 'Chicago', 'British', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'American', 'Ameri

 10%|█         | 7/67 [01:18<15:32, 15.53s/it]


{'toponyms': ['Putnam County', 'Hurricane', 'W.Va.']}



 12%|█▏        | 8/67 [01:20<11:12, 11.39s/it]


{'toponyms': ['Gaston', 'Columbia', 'Lexington County', 'W. Columbia', 'Parkwood Drive', 'West Columbia']}



 13%|█▎        | 9/67 [01:21<07:54,  8.18s/it]


{'toponyms': ['Battle Ground']}



 15%|█▍        | 10/67 [01:26<06:41,  7.04s/it]


{'toponyms': ['Georgia', 'Nigeria', 'Clayton County', 'Dunwoody', 'Clayton', 'Ohio', 'Jonesboro', 'Lovejoy']}



 16%|█▋        | 11/67 [01:28<05:05,  5.45s/it]


{'toponyms': ['Kirkland', 'Route 72']}



 18%|█▊        | 12/67 [01:29<03:55,  4.28s/it]


{'toponyms': ['Newberry', 'Alachua County']}



 19%|█▉        | 13/67 [01:31<03:08,  3.49s/it]


{'toponyms': ['Gilchrist', 'Alachua County', 'Alachua']}



 21%|██        | 14/67 [01:34<02:53,  3.28s/it]


{'toponyms': ['Batumi', 'Georgian', 'Georgia', 'Tbilisi']}



 22%|██▏       | 15/67 [01:37<02:52,  3.31s/it]


{'toponyms': ['Ushba', 'Svaneti', 'Caucasus', 'Georgian', 'Latvian', 'Latvians']}



 24%|██▍       | 16/67 [01:39<02:28,  2.91s/it]


{'toponyms': ['Bloomfield', 'BLOOMFIELD']}



 25%|██▌       | 17/67 [01:41<02:07,  2.54s/it]


{'toponyms': ['Huntsville', 'Walker County']}



 27%|██▋       | 18/67 [01:44<02:17,  2.81s/it]


{'toponyms': ['Spencer Street', 'Pahrump', 'Silverado Ranch', 'Eldorado Lane', 'Clark County', 'U.S.', 'Las Vegas', 'Nevada']}



 28%|██▊       | 19/67 [01:46<02:08,  2.67s/it]


{'toponyms': ['Jackson Township', 'Grantville', 'Myerstown', 'Cumberland St.', 'Lebanon']}



 30%|██▉       | 20/67 [01:47<01:41,  2.15s/it]


{'toponyms': ['London']}



 31%|███▏      | 21/67 [01:49<01:29,  1.95s/it]


{'toponyms': ['Los Angeles', 'San Fernando']}



 33%|███▎      | 22/67 [01:50<01:19,  1.77s/it]


{'toponyms': ['Kansas', 'Manhattan']}



 34%|███▍      | 23/67 [01:52<01:14,  1.68s/it]


{'toponyms': ['Chile', 'Miami']}



 36%|███▌      | 24/67 [01:54<01:16,  1.77s/it]


{'toponyms': ['Winfield', 'New London']}



 37%|███▋      | 25/67 [01:56<01:16,  1.81s/it]


{'toponyms': ['Texas', 'Canton', 'Mount Pleasant', 'North East Texas']}



 39%|███▉      | 26/67 [01:57<01:10,  1.73s/it]


{'toponyms': ['Iowa', 'Norfolk', 'Denison']}



 40%|████      | 27/67 [01:59<01:08,  1.70s/it]


{'toponyms': ['Oxford', 'East Oxford', 'Divinity Road']}



 42%|████▏     | 28/67 [02:00<01:06,  1.71s/it]


{'toponyms': ['HAMILTON', 'Butler County']}



 43%|████▎     | 29/67 [02:03<01:13,  1.93s/it]


{'toponyms': ['Ontario', 'Oxford', 'Woodstock']}



 45%|████▍     | 30/67 [02:07<01:35,  2.59s/it]


{'toponyms': ['Tennessee', 'Carroll County', 'Benton County', 'Stewart County', 'Perry County', 'Henry County', 'Weakley County', 'Williamson County']}



 46%|████▋     | 31/67 [02:08<01:19,  2.21s/it]


{'toponyms': ['Phila.', 'Lock Haven']}



 48%|████▊     | 32/67 [02:09<01:03,  1.81s/it]


{'toponyms': ['Virginia']}



 49%|████▉     | 33/67 [02:11<00:58,  1.72s/it]


{'toponyms': ['Berea', 'Candlewood Drive']}



 51%|█████     | 34/67 [02:13<00:57,  1.76s/it]


{'toponyms': ['Kew', 'Richmond', 'Teddington']}



 52%|█████▏    | 35/67 [02:14<00:55,  1.74s/it]


{'toponyms': ['London', 'Richmond', 'Twickenham']}



 54%|█████▎    | 36/67 [02:17<00:59,  1.93s/it]


{'toponyms': ['St. Petersburg', 'Nevsky Prospekt', 'Sochi']}



 55%|█████▌    | 37/67 [02:19<01:02,  2.07s/it]


{'toponyms': ['Sochi', 'Britain', 'London', 'Russia', 'Nice']}



 57%|█████▋    | 38/67 [02:21<01:02,  2.16s/it]


{'toponyms': ['Gilroy', 'U.S.', 'Santa Clara County', 'San Jose']}



 58%|█████▊    | 39/67 [02:25<01:10,  2.51s/it]


{'toponyms': ['Los Altos', 'Santa Clara', 'Sunnyvale', 'Morgan Hill', 'Gilroy', 'San Jose', 'Silicon Valley', 'Santa Clara County']}



 60%|█████▉    | 40/67 [02:27<01:04,  2.37s/it]


{'toponyms': ['San Jose', 'Santa Clara County']}



 61%|██████    | 41/67 [02:30<01:07,  2.60s/it]


{'toponyms': ['Columbus', 'Georgia', 'Alabama', 'Blackmon Road', 'Phenix City', 'Chattahoochee River']}



 63%|██████▎   | 42/67 [02:32<01:00,  2.42s/it]


{'toponyms': ['Chattahoochee', 'Harris County', 'Columbus']}



 64%|██████▍   | 43/67 [02:34<00:52,  2.19s/it]


{'toponyms': ['Syria', 'Lebanon']}



 66%|██████▌   | 44/67 [02:36<00:51,  2.22s/it]


{'toponyms': ['Minnesota', 'Lake Harriet', 'Rochester']}



 67%|██████▋   | 45/67 [02:40<01:00,  2.74s/it]


{'toponyms': ['Golan Heights', 'Iran', 'U.S.', 'Palestine', 'Gaza', 'Israel', 'Syria', 'Lebanon', 'Middle East', 'Syrians', 'United States']}



 69%|██████▊   | 46/67 [02:42<00:50,  2.41s/it]


{'toponyms': ['COLUMBUS', 'Lincoln']}



 70%|███████   | 47/67 [02:47<01:08,  3.44s/it]


{'toponyms': ['Nebraska', 'Lancaster', 'Hayes', 'Sarpy', 'Platte County', 'Washington', 'Lincoln County', 'Dakota', 'Sherman County', 'Otoe', 'Douglas', 'North Platte', 'Madison County', 'Platte']}



 72%|███████▏  | 48/67 [02:50<01:03,  3.35s/it]


{'toponyms': ['Columbus', 'NEBRASKA', 'Wayne State College', 'COLUMBUS']}



 73%|███████▎  | 49/67 [02:52<00:51,  2.86s/it]


{'toponyms': ['Omaha', 'COLUMBUS']}



 75%|███████▍  | 50/67 [02:56<00:53,  3.17s/it]


{'toponyms': ['Ukraine', 'Greece', 'Denmark', 'Austria', 'Europe', 'Slovenia', 'Italy', 'Croatia', 'Russia', 'Black Sea', 'Germany', 'Bulgaria']}



 76%|███████▌  | 51/67 [03:02<01:02,  3.94s/it]


{'toponyms': ['Maliuc', 'Ukraine', 'Bucharest', 'Brăila', 'Russia', 'Ceamurlia de Jos', 'Ciocile', 'Romania', 'Crimea', 'Ukrainian', 'Dudescu', 'Bumbacari', 'Caraorman']}



 78%|███████▊  | 52/67 [03:06<01:00,  4.02s/it]


{'toponyms': ['Kisumu', 'Nyanza', 'Nairobi', 'Kogelo', 'Kenya', 'London', 'Sarit', 'US', 'Britain', 'India']}



 79%|███████▉  | 53/67 [03:08<00:48,  3.50s/it]


{'toponyms': ['Glasgow', 'UK', 'Paisley', 'Aberdeen', 'Scotland']}



 81%|████████  | 54/67 [03:11<00:42,  3.25s/it]


{'toponyms': ['Port-Au-Prince', 'Haiti', 'St. Nicolas', 'St. Marc', 'Artibonite']}



 82%|████████▏ | 55/67 [03:13<00:35,  2.97s/it]


{'toponyms': ['Australia', 'U.K.', 'United Kingdom', 'Canada', 'UK', 'U.S.']}



 84%|████████▎ | 56/67 [03:16<00:32,  2.91s/it]


{'toponyms': ['U.S.', 'Japan', 'Europe', 'Hong Kong', 'Canada', 'United States']}



 85%|████████▌ | 57/67 [03:19<00:29,  2.96s/it]


{'toponyms': ['Burma', 'Cambodia', 'Laos', 'Thailand', 'Vietnam', 'Asia']}



 87%|████████▋ | 58/67 [03:24<00:32,  3.57s/it]


{'toponyms': ['Telemark', 'Oslo', 'Buskerud', 'Aust-Agder', 'Vestfold', 'Hedmark', 'Vest-Agder', 'Rogaland', 'Oppland', 'Østfold', 'Akershus', 'Russia', 'Norway']}



 88%|████████▊ | 59/67 [03:30<00:33,  4.15s/it]


{'toponyms': ['Shwe Kyin', 'Binh Phuoc', 'Cambodia', 'Pingilikani', 'DR Congo', 'Attapeu', 'Vietnam', 'Asia', 'Africa', 'Pailin', 'Europe', 'Laos', 'UK', 'Kenya', 'Myanmar']}



 90%|████████▉ | 60/67 [03:32<00:24,  3.56s/it]


{'toponyms': ['Greenlane', 'New Zealand', 'Auckland']}



 91%|█████████ | 61/67 [03:34<00:19,  3.19s/it]


{'toponyms': ['England', 'Westminster', 'United Kingdom', 'London', 'UK']}



 93%|█████████▎| 62/67 [03:38<00:16,  3.32s/it]


{'toponyms': ['Valsad', 'Teethal', 'Mumbai', 'India', 'Mahim Creek', 'Maharashtra', 'Gujarat', 'Mahim Bay', 'Mithi River']}



 94%|█████████▍| 63/67 [03:42<00:14,  3.56s/it]


{'toponyms': ['Suriname', 'US', 'French Guiana', 'Haiti', 'Europe', 'Americas', 'Canada', 'Middle East', 'Nicaragua', 'Africa', 'Dominican Republic', 'Europeans']}



 96%|█████████▌| 64/67 [03:47<00:12,  4.12s/it]


{'toponyms': ['Cameroon', 'Kinshasa', 'Democratic Republic of the Congo', 'U.S.', 'Léopoldville', 'Australia', 'Denmark', 'France', 'Arizona', 'United States', 'Belgium']}



 97%|█████████▋| 65/67 [04:05<00:16,  8.20s/it]


{'toponyms': ['São Paulo', 'Japan', 'U.S.', 'Minas Gerais', 'New South Wales', 'Spain', 'Toronto', 'New York', 'Tocantins', 'U.S. Centers for Disease Control and Prevention', 'Ho Chi Minh City', 'Victoria', 'Moscow', 'Glasgow', 'Santa Catarina', 'Paisley', 'Istanbul', 'United Kingdom', 'Santiago', 'Canad', 'China', 'England', 'Lebanon', 'United States', 'New Zealand', 'Vietnam', 'Nicaragua', 'American', 'Australia', 'Russia', 'Philippines', 'Iraq', 'Iran', 'Turkey', 'Egypt', 'Bulgaria', 'Scotland', 'Brazil', 'Canada', 'Singapore', 'Vietname', 'US', 'Pensylvania', 'Vietnams', 'California', 'U.S.', 'Australia', 'Vietnam', 'Lebanon', 'U.S.']}



 99%|█████████▊| 66/67 [04:17<00:09,  9.19s/it]


{'toponyms': ['New Zealand', 'Tennessee', 'Delaware', 'Kentucky', 'Cairo', 'Illinois', 'Hong Kong', 'China', 'Kentucky', 'Alabama', 'Nigeria', 'Shanghai', 'New Brunswick', 'United States', 'South Carolina', 'Mexico', 'Fort Worth', 'US', 'British Columbia', 'Houston', 'Texas', 'African', 'Bavaria', 'Canada', 'Egypt', 'Spain', 'German', 'American', 'Alberta', 'North America']}



100%|██████████| 67/67 [04:19<00:00,  3.88s/it]


{'toponyms': ['San Jose', 'Santa Clara County', 'Monterey Highway', 'California', 'Las Vegas', 'Monterey']}






In [55]:
# Get the precision, recall and F1 on each article.
topo_precision = []
topo_recall = []
topo_f1 = []
topo_distance = []

def get_toponym_metrics(true_toponyms: list, pred_toponyms: list) -> dict[str, int]:
    true_positives = len([t for t in pred_toponyms if t in true_toponyms])
    false_positives = len([t for t in pred_toponyms if t not in true_toponyms])
    false_negatives = len([t for t in true_toponyms if t not in pred_toponyms])
    return {'TP':true_positives, 'FP':false_positives, 'FN':false_negatives}

def get_accuracy_metrics(topo_metrics: dict[str, int]) -> dict[str, float]:
    tp = topo_metrics['TP']
    fp = topo_metrics['FP']
    fn = topo_metrics['FN']

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0

    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    return {'precision': precision, 'recall': recall, 'F1': f1}

for pred, true in zip(outputs, [x['response'] for x in val_data]):
    pred_toponyms = pred['toponyms']
    true_toponyms = json.loads(true.replace("'", '"'))['toponyms']
    # get true/false positives/negatives
    toponym_metrics = get_toponym_metrics(true_toponyms, pred_toponyms)
    # get prec, rec, f1
    accuracy_metrics = get_accuracy_metrics(toponym_metrics)
    topo_precision.append(accuracy_metrics['precision'])
    topo_recall.append(accuracy_metrics['recall'])
    topo_f1.append(accuracy_metrics['F1'])

In [56]:
# print the results
macro_precision = np.mean(topo_precision)
macro_recall = np.mean(topo_recall)
macro_f1 = np.mean(topo_f1)

print(f'Macro precision: {macro_precision:.3f}')
print(f'Macro recall: {macro_recall:.3f}')
print(f'Macro F1: {macro_f1:.3f}')

Macro precision: 0.892
Macro recall: 0.859
Macro F1: 0.863
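The figures above are macro averages: each article's score is computed first, and the scores are then averaged, so a short article weighs as much as a long one. For comparison, a micro average pools the counts across articles before dividing. The sketch below contrasts the two for precision on hypothetical per-article counts, which are not taken from this notebook's results.

```python
# Hypothetical per-article counts as (TP, FP, FN); not taken from the results above.
counts = [(1, 0, 0), (1, 3, 2), (9, 1, 0)]


def precision(tp: int, fp: int) -> float:
    return tp / (tp + fp) if (tp + fp) > 0 else 0.0


# Macro: average the per-article precision values.
macro_p = sum(precision(tp, fp) for tp, fp, _ in counts) / len(counts)

# Micro: pool the counts over all articles, then compute precision once.
total_tp = sum(tp for tp, _, _ in counts)
total_fp = sum(fp for _, fp, _ in counts)
micro_p = precision(total_tp, total_fp)

print(f"macro precision: {macro_p:.3f}")
print(f"micro precision: {micro_p:.3f}")
```

With these counts the poorly scored second article pulls the macro value down more than the micro value, since it contributes a full third of the macro average but only a few of the pooled predictions.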
