In [1]:
import json
import os 
import pandas as pd
import numpy as np
import ast
import time
import re

In [2]:
# Predicted records, one CSV per species (aegypti first, then albopictus).
# NOTE(review): the albopictus prediction path is rooted at "./data/" while the
# other three files are read from "./../data/" -- confirm this is intentional.
pred_aeg, pred_albo = (
    pd.read_csv(csv_path)
    for csv_path in ("./../data/output_aegypti.csv", "./data/output_albopictus.csv")
)

# Validation (ground-truth) records, same species order.
val_aeg, val_albo = (
    pd.read_csv(csv_path)
    for csv_path in (
        "./../data/validation_aegypti.csv",
        "./../data/validation_albopictus.csv",
    )
)

In [3]:
# Remove exact duplicate rows from every frame, mutating each one in place.
for _frame in (pred_aeg, pred_albo, val_aeg, val_albo):
    _frame.drop_duplicates(inplace=True)

In [4]:
# Renumber rows 0..n-1 in place (the old index is discarded) so that index
# labels and row positions coincide after the duplicate removal.
for _frame in (pred_aeg, pred_albo, val_aeg, val_albo):
    _frame.reset_index(drop=True, inplace=True)

In [5]:
# Coerce the columns used for matching to fixed dtypes on both prediction frames.
# Columns are cast in the listed order; a missing column raises KeyError.
_pred_dtypes = {
    "source_type": int,
    "year": int,
    "country": str,
    "y": float,
    "x": float,
}
for _frame in (pred_aeg, pred_albo):
    for _col, _dtype in _pred_dtypes.items():
        _frame[_col] = _frame[_col].astype(_dtype)

In [6]:
# Coerce the columns used for matching to fixed dtypes on both validation frames.
# Columns are cast in the listed order; a missing column raises KeyError.
_val_dtypes = {
    "source_type": int,
    "year": int,
    "country": str,
    "y": float,
    "x": float,
}
for _frame in (val_aeg, val_albo):
    for _col, _dtype in _val_dtypes.items():
        _frame[_col] = _frame[_col].astype(_dtype)

In [7]:
def _rows_match(gt_row, llm_row, tol=0.2):
    """Return True when a ground-truth row and a predicted row agree.

    Two rows agree when their ``country`` values are equal, their ``year``
    values are equal (a ground-truth ``year`` of 0 matches any predicted
    year), and both ``y`` and ``x`` differ by at most ``tol`` in absolute value.
    """
    return bool(
        gt_row["country"] == llm_row["country"]
        and (gt_row["year"] == llm_row["year"] or gt_row["year"] == 0)
        and np.abs(gt_row["y"] - llm_row["y"]) <= tol
        and np.abs(gt_row["x"] - llm_row["x"]) <= tol
    )


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def compare(df_gt, df_llm, tol=0.2):
    """Score predicted rows against ground-truth rows and report P/R/F1.

    Rows are only compared when they share the same ``source_type`` value
    (used here as the grouping / document id). Within a group, a predicted row
    and a ground-truth row match according to ``_rows_match``.

    Counting rules (unchanged from the original implementation):
      * TP -- number of *predicted* rows that match at least one ground-truth row.
      * FP -- number of predicted rows that are not TP (``len(df_llm) - TP``).
      * FN -- number of *ground-truth* rows matched by no predicted row.
    NOTE(review): because TP is counted on the prediction side, several
    predicted rows matching the same ground-truth row each add to TP, which
    also feeds into Recall -- confirm this is the intended metric.

    Parameters
    ----------
    df_gt : pandas.DataFrame
        Ground truth; must have columns ``source_type``, ``country``,
        ``year``, ``y`` and ``x``.
    df_llm : pandas.DataFrame
        Predictions with the same columns.
    tol : float, default 0.2
        Maximum absolute difference allowed on ``y`` and on ``x``.

    Returns
    -------
    tuple
        ``(F1, Precision, Recall, list_fp_idx, list_fn_idx, fp_df, fn_df)``
        where the two lists hold index *labels* (of ``df_llm`` and ``df_gt``
        respectively) and the two frames are the corresponding rows. Any
        ratio whose denominator is zero is reported as 0.0.

    Side effects
    ------------
    Prints the number of predicted rows visited and the metric summary.
    """
    list_tp_idx = []
    list_fp_idx = []
    list_fn_idx = []
    counter = 0

    # Pass 1 -- classify every predicted row as matched (TP) or unmatched (FP).
    for doc_id in df_llm["source_type"].unique():
        gt_sample = df_gt[df_gt["source_type"] == doc_id]
        llm_sample = df_llm[df_llm["source_type"] == doc_id]

        for llm_idx, llm_row in llm_sample.iterrows():
            counter += 1
            if any(
                _rows_match(gt_row, llm_row, tol)
                for _, gt_row in gt_sample.iterrows()
            ):
                list_tp_idx.append(llm_idx)
            else:
                # Predicted row with no ground-truth counterpart.
                list_fp_idx.append(llm_idx)

    # Pass 2 -- find ground-truth rows that no predicted row matches (FN).
    for doc_id in df_gt["source_type"].unique():
        gt_sample = df_gt[df_gt["source_type"] == doc_id]
        llm_sample = df_llm[df_llm["source_type"] == doc_id]

        for gt_idx, gt_row in gt_sample.iterrows():
            if not any(
                _rows_match(gt_row, llm_row, tol)
                for _, llm_row in llm_sample.iterrows()
            ):
                list_fn_idx.append(gt_idx)

    TP = len(list_tp_idx)
    # Kept as in the original: every predicted row that is not a TP is an FP.
    FP = len(df_llm) - TP
    FN = len(list_fn_idx)

    # The collected values are index labels (from iterrows), so select with
    # .loc; .iloc is only equivalent when the index is a fresh 0..n-1 range.
    fp_df = df_llm.loc[list_fp_idx]
    fn_df = df_gt.loc[list_fn_idx]

    # Guard every ratio so empty inputs or zero matches yield 0.0 instead of
    # raising ZeroDivisionError.
    Precision = _safe_ratio(TP, TP + FP)
    Recall = _safe_ratio(TP, TP + FN)
    if Precision + Recall:
        F1 = 2 * ((Precision * Recall) / (Precision + Recall))
    else:
        F1 = 0.0

    print(f"counter {counter}")
    print(f"TP: {TP}, FP: {FP}, FN: {FN}")
    print(f"Precision: {Precision:.3f}, Recall: {Recall:.3f}, F1: {F1:.3f}")

    return F1, Precision, Recall, list_fp_idx, list_fn_idx, fp_df, fn_df

In [None]:
# Score the aegypti predictions against the aegypti validation set.
(
    F1_aeg,
    Precision_aeg,
    Recall_aeg,
    list_fp_aeg,
    list_fn_aeg,
    fp_aeg,
    fn_aeg,
) = compare(df_gt=val_aeg, df_llm=pred_aeg)

In [None]:
# Score the albopictus predictions against the albopictus validation set.
(
    F1_albo,
    Precision_albo,
    Recall_albo,
    list_fp_albo,
    list_fn_albo,
    fp_albo,
    fn_albo,
) = compare(df_gt=val_albo, df_llm=pred_albo)