In [2]:
import yaml
import pandas as pd
import json
from word2number import w2n

# Evaluation

In [3]:

class EventExtractionEvaluation:
    """Score predicted events against ground-truth events, row by row.

    ``df`` must have ``GT_Events`` and ``events`` columns whose cells are lists of
    event dicts. Events are compared as sets of sorted ``(key, value)`` tuples,
    after disease-synonym normalization and key filtering, so two events match
    only when every retained field is equal.
    """

    # Default disease-synonym map: JSON object of canonical name -> list of synonyms.
    # TODO: absolute, machine-specific path; pass ``synonyms_path`` instead.
    DEFAULT_SYNONYMS_PATH = "/home/ubuntu/devesh/Prod_changes/idsp_feb_5th/idsp-score/data/05_model_input/disease_synonyms.json"
    # File that precision_recall_method_2 appends non-matching rows to; None disables it.
    incorrect_log_path = 'incorrect_predictions_gpt3.5.txt'
    # Event fields that take part in the comparison; every other key is dropped.
    # NOTE(review): LLMEventFilterer produces lower-cased keys 'incident (case or death)'
    # and 'incident type (new or total)', which are not listed here, so predicted
    # incident fields are dropped before comparison. Confirm against the GT schema.
    KEYS_TO_KEEP = ('disease', 'location', 'incident', 'incident_type', 'number')

    def __init__(self, df, synonyms_path=DEFAULT_SYNONYMS_PATH,
                 incorrect_log_path='incorrect_predictions_gpt3.5.txt'):
        """Store ``df`` and load the synonym map from ``synonyms_path``."""
        self.df = df
        with open(synonyms_path, encoding="utf-8") as fh:
            self.synonyms = json.load(fh)
        self.incorrect_log_path = incorrect_log_path

    @staticmethod
    def _to_event_set(events):
        """Convert a list of event dicts to a set of sorted ``(key, value)`` tuples.

        Raises TypeError/AttributeError when ``events`` is not iterable, holds
        non-dicts, or holds unhashable values.
        """
        return set(tuple(sorted(d.items())) for d in events)

    def _synonym_lookup(self):
        """Map each lower-cased synonym to its lower-cased canonical key.

        The first key (in ``self.synonyms`` order) wins when a synonym is listed
        under several keys.
        """
        lookup = {}
        for key, names in self.synonyms.items():
            for name in names:
                lookup.setdefault(name.lower(), key.lower())
        return lookup

    def group_diseases(self, events):
        """Return copies of ``events`` with synonymous disease names canonicalized.

        A matched event gets ``disease`` set to the lower-cased canonical key and
        ``original_disease`` set to the raw value. Unmatched events, and events
        without a string ``disease``, are copied unchanged. The input dicts are
        never modified, because DataFrame.copy() shares them with other frames.
        """
        lookup = self._synonym_lookup()
        new_events = []
        for event in events:
            event = dict(event)
            disease = event.get("disease")
            if isinstance(disease, str):
                canonical = lookup.get(disease.lower().strip())
                if canonical is not None:
                    event["disease"] = canonical
                    event["original_disease"] = disease
            new_events.append(event)
        return new_events

    def precision_recall_method_1(self, pred, gt):
        """Set-based precision/recall/F1 where an empty denominator counts as 1.0.

        Returns ``(precision, recall, f1, exact_match)``. F1 is 0.0 when
        precision and recall are both 0.
        """
        gt = self._to_event_set(gt)
        pred = self._to_event_set(pred)

        tp = len(pred & gt)
        fp = len(pred - gt)
        fn = len(gt - pred)

        precision = tp / (tp + fp) if tp + fp > 0 else 1.0
        recall = tp / (tp + fn) if tp + fn > 0 else 1.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
        exact_match = 1.0 if pred == gt else 0.0
        return precision, recall, f1, exact_match

    def precision_recall_method_2(self, pred, gt):
        """Set-based precision/recall/F1 where any miss with zero hits scores 0.

        - both sets empty: all metrics 1.0
        - tp == 0 with any fp/fn: all metrics 0.0
        - otherwise: the usual ratios

        If ``pred`` cannot be converted to a set, it is printed and treated as
        empty. When the sets differ and ``self.incorrect_log_path`` is set, the
        GT/pred pair is appended to that file.

        Returns ``(precision, recall, f1, exact_match)``.
        """
        try:
            pred = self._to_event_set(pred)
        except (TypeError, AttributeError):
            print('pred: ', pred)
            pred = set()
        gt = self._to_event_set(gt)

        tp = len(pred & gt)
        fp = len(pred - gt)
        fn = len(gt - pred)

        if tp == 0 and fp == 0 and fn == 0:
            precision = recall = f1 = 1.0
        elif tp == 0:
            precision = recall = f1 = 0.0
        else:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            f1 = 2 * (precision * recall) / (precision + recall)
        exact_match = 1.0 if pred == gt else 0.0

        # Append incorrect predictions; the with-block closes the file each call.
        if exact_match == 0 and self.incorrect_log_path:
            with open(self.incorrect_log_path, 'a') as ptr:
                ptr.write('GT: ' + str(gt) + '\n')
                ptr.write('Pred: ' + str(pred) + '\n\n')

        return precision, recall, f1, exact_match

    def jaccard_index(self, pred, gt):
        """|pred ∩ gt| / |pred ∪ gt| over event sets; 1.0 when both are empty.

        An unconvertible ``pred`` is treated as empty.
        """
        try:
            pred = self._to_event_set(pred)
        except (TypeError, AttributeError):
            pred = set()
        gt = self._to_event_set(gt)
        union = pred | gt
        if not union:
            return 1.0
        return len(pred & gt) / len(union)

    def subset_accuracy(self, pred, gt):
        """1.0 if every predicted event is in ``gt`` (an empty prediction counts), else 0.0.

        An unconvertible ``pred`` is printed and treated as empty.
        """
        try:
            pred = self._to_event_set(pred)
        except (TypeError, AttributeError):
            print('pred: ', pred)
            pred = set()
        gt = self._to_event_set(gt)
        return float(pred.issubset(gt))

    def evaluate_event_extraction(self):
        """Normalize both event columns, score each row, and return the mean metrics.

        Normalization means: canonical disease names, a default ``'number': ''``,
        and only ``KEYS_TO_KEEP`` fields. The normalized lists replace
        ``GT_Events`` / ``events`` in ``self.df``. Per-row ``precision``,
        ``recall``, ``f1``, ``exact_match``, ``jaccard`` and ``subset`` columns
        are added to ``self.df``, and a dict of their means is returned.
        """
        def normalize(events):
            events = self.group_diseases(events)
            # {'number': '', **e} keeps e's own 'number' when present.
            return [
                {k: v for k, v in {'number': '', **e}.items() if k in self.KEYS_TO_KEEP}
                for e in events
            ]

        for column in ('GT_Events', 'events'):
            self.df[column] = self.df[column].apply(normalize)

        pairs = list(zip(self.df['events'], self.df['GT_Events']))
        score_names = ('precision', 'recall', 'f1', 'exact_match')
        scores = [self.precision_recall_method_2(pred, gt) for pred, gt in pairs]
        # Built column by column so an empty frame doesn't break tuple unpacking.
        for position, name in enumerate(score_names):
            self.df[name] = [row[position] for row in scores]
        self.df['jaccard'] = [self.jaccard_index(pred, gt) for pred, gt in pairs]
        self.df['subset'] = [self.subset_accuracy(pred, gt) for pred, gt in pairs]

        return {name: self.df[name].mean() for name in score_names + ('jaccard', 'subset')}

In [4]:

class LLMEventFilterer:
    """Helpers that clean raw LLM event-extraction text into lists of event dicts."""

    # Accepted spellings for the incident-type field, in lookup priority order.
    _INCIDENT_TYPE_ALIASES = (
        'Incident Type (total or new)',
        'Incident Type (new or total)',
        'Incident Type (total)',
        'Incident (new)',
        'Incident',
    )
    # Accepted spellings for the case/death field, in lookup priority order.
    _INCIDENT_ALIASES = (
        'Incident (cases or deaths)',
        'Incident Type (case or death)',
        'Incident (case or death)',
    )

    def __init__(self, disease_syn_path):
        """Load the disease-synonym JSON map (canonical key -> list of synonyms)."""
        with open(disease_syn_path, encoding="utf-8") as fh:
            self.disease_synonyms = json.load(fh)
        # Exact key set a numbered event must have (default for filter_rows).
        self.target_keys = {'Disease', 'Location', 'Incident (case or death)', 'Incident Type (new or total)', 'Number'}

    def lower_case_keys(self, d):
        """Return a copy of ``d`` with lower-cased keys and lower-cased string values."""
        return {k.lower(): v.lower() if isinstance(v, str) else v for k, v in d.items()}

    def filter_rows(self, row, target_keys=None):
        """Keep only the events in ``row`` whose key set equals ``target_keys`` exactly.

        ``target_keys`` defaults to ``self.target_keys``. If ``row`` is malformed
        (e.g. contains non-dicts), the error is printed and [] is returned.
        """
        if target_keys is None:
            target_keys = self.target_keys
        wanted = sorted(target_keys)
        try:
            return [event for event in row if sorted(event.keys()) == wanted]
        except Exception as e:
            print(f"Error processing row: {e}")
            return []

    @staticmethod
    def _first_present(event, keys, default=''):
        """Return the value of the first key in ``keys`` present in ``event``, else ``default``."""
        for key in keys:
            if key in event:
                return event[key]
        return default

    def map_keys(self, events):
        """Rename the accepted key spellings to the canonical event keys.

        Each dict event becomes a new dict with 'Incident Type (new or total)',
        'Incident (case or death)', 'Location' and 'Disease' (missing -> ''),
        plus 'Number' when the source event has it. Non-dict items are skipped
        instead of raising AttributeError.
        """
        new_events = []
        for event in events:
            if not isinstance(event, dict):
                continue
            new_event = {
                "Incident Type (new or total)": self._first_present(event, self._INCIDENT_TYPE_ALIASES),
                "Incident (case or death)": self._first_present(event, self._INCIDENT_ALIASES),
                "Location": event.get('Location', ''),
                "Disease": event.get('Disease', ''),
            }
            if 'Number' in event:
                new_event["Number"] = event["Number"]
            new_events.append(new_event)
        return new_events

    def group_diseases(self, events):
        """Replace each event's 'disease' with its canonical synonym key, in place.

        Matched events get 'original_disease' set to the raw value. Events with
        no 'original_disease' yet get ''. When a name is listed under several
        keys, the first key wins, as in EventExtractionEvaluation.group_diseases.
        If an event is malformed, ``events`` is returned as-is (possibly partly
        updated).
        """
        try:
            lookup = {}
            for key, names in self.disease_synonyms.items():
                for name in names:
                    lookup.setdefault(name.lower(), key)
            new_events = []
            for event in events:
                disease = event["disease"]
                canonical = lookup.get(disease.lower().strip())
                if canonical is not None:
                    event["disease"] = canonical
                    event["original_disease"] = disease  # storing original disease to be used after post-processing
                if "original_disease" not in event:  # keep the original string.
                    event["original_disease"] = ""
                new_events.append(event)
            return new_events
        except (KeyError, AttributeError, TypeError):
            return events

    def filter_events(self, response):
        """Return the first ``[{ ... }]`` substring of ``response``, or '[]'.

        The closing ``}]`` is searched for only after the opening ``[{``. A
        nested list can still end the match early, in which case parsing later
        fails and yields [].
        """
        if not isinstance(response, str) or not response:
            return '[]'
        start = response.find("[{")
        if start == -1:
            return '[]'
        end = response.find("}]", start)
        if end == -1:
            return '[]'
        return response[start:end + 2]

    def safe_eval(self, x):
        """Evaluate a Python-literal string; return [] if it cannot be evaluated.

        Covers syntax errors, bare names such as ``true`` (NameError) and
        non-string cells such as NaN (TypeError).
        NOTE(security): ``eval`` executes arbitrary expressions from model output
        and CSV cells. ``ast.literal_eval`` is the safe replacement; TODO confirm
        that no input relies on non-literal expressions.
        """
        try:
            return eval(x)
        except (SyntaxError, NameError, TypeError, ValueError):
            return []

    def filter_number(self, events):
        """Normalize each event's 'Number' to a digit string with word2number.

        Falsy numbers become ''. Values word2number cannot parse, events without
        'Number', and non-dict items are left unchanged; one bad value no longer
        stops the others from being converted. Events are updated in place.
        """
        new_events = []
        for event in events:
            if isinstance(event, dict) and "Number" in event:
                number = event["Number"]
                if not number:
                    event["Number"] = ''
                else:
                    try:
                        event["Number"] = str(w2n.word_to_num(number))
                    except Exception:
                        # Best effort: keep the raw value when it isn't parseable.
                        pass
            new_events.append(event)
        return new_events

    def filter_unknown_na(self, events_str):
        """Blank out 'N/A', 'null', 'Unknown' and 'unknown' in the raw event text.

        NOTE(review): this is a plain substring replace. It also hits words that
        contain these tokens, and a bare JSON ``null`` becomes an empty value
        that no longer parses.
        """
        # Replacing 'N/A', 'null', 'Unknown', 'unknown' with empty string
        events_str = events_str.replace('N/A', '').replace('null', '').replace('Unknown', '').replace('unknown', '')
        return events_str

In [6]:
# Shared filterer used by preprocess_data.
# TODO: absolute, machine-specific path; move to a configurable DATA_DIR.
disease_syn_path = "/home/ubuntu/devesh/Prod_changes/idsp_feb_5th/idsp-score/data/05_model_input/disease_synonyms.json"
event_filterer = LLMEventFilterer(disease_syn_path)

numbered_df = pd.read_csv('./output_llama_finetuned_numbered_lora.csv')
numberless_df = pd.read_csv('./llama_empty_numberless_lora.csv')

# remove irrelevant rows
# .copy() detaches the filtered frames from the originals, so later column
# assignments don't hit a masked slice (SettingWithCopyWarning).
irrelevant_df = pd.read_csv('articles_irrelevant.csv')
irrelevant_articles = irrelevant_df['Article'].tolist()
numbered_df = numbered_df[~numbered_df['Article'].isin(irrelevant_articles)].copy()
numberless_df = numberless_df[~numberless_df['Article'].isin(irrelevant_articles)].copy()


def preprocess_data(df, numberless=False, filterer=None):
    """Parse ``Generated_Events`` and ``GT_Events`` into lists of event dicts.

    Works on a copy of ``df`` and returns that copy with:
    - ``events``: events extracted from ``Generated_Events``, with key spellings
      normalized, kept only when their key set matches exactly (without
      'Number' when ``numberless`` is truthy), numbers normalized and keys
      lower-cased;
    - ``GT_Events``: the evaluated ground truth with keys lower-cased.
    Prints the total number of predicted and ground-truth events.

    ``filterer`` defaults to the module-level ``event_filterer``.
    """
    if filterer is None:
        filterer = event_filterer
    # Callers pass boolean-mask slices; assigning columns to a slice warns and
    # may not stick, so work on an explicit copy.
    df = df.copy()

    target_keys = {'Disease', 'Location', 'Incident (case or death)', 'Incident Type (new or total)'}
    if not numberless:
        target_keys = target_keys | {'Number'}

    events = df['Generated_Events'].fillna('')
    for step in (filterer.filter_events, filterer.filter_unknown_na, filterer.safe_eval, filterer.map_keys):
        events = events.apply(step)
    events = events.apply(lambda x: filterer.filter_rows(x, target_keys))
    events = events.apply(filterer.filter_number)

    # convert keys to lower case
    df['events'] = events.apply(lambda x: [filterer.lower_case_keys(i) for i in x])
    df['GT_Events'] = (
        df['GT_Events']
        .apply(filterer.safe_eval)
        .apply(lambda x: [filterer.lower_case_keys(i) for i in x])
    )

    print('Total no of non empty events:', df['events'].apply(len).sum())
    print('Total no of non empty GT events:', df['GT_Events'].apply(len).sum())

    return df

numbered_df = preprocess_data(numbered_df)
numberless_df = preprocess_data(numberless_df, numberless=True)

# Articles whose numbered extraction came back empty. Only these are looked up
# in the numberless extraction; they are also saved for inspection.
empty_events = numbered_df[numbered_df['events'].map(len).eq(0)]
empty_articles = list(empty_events['Article'])
empty_events.to_csv('empty_events-llama_numbered_lora.csv', index=False)

numberless_df = numberless_df[numberless_df['Article'].isin(empty_articles)]
def update_events(row, numberless_df):
    """Fall back to the numberless extraction when a row has no numbered events.

    If ``row['events']`` is empty and ``numberless_df`` has a row with the same
    ``Article``, return the ``events`` of the first such row. Otherwise return
    ``row['events']`` unchanged.
    """
    current = row['events']
    if len(current) > 0:
        return current
    same_article = numberless_df['Article'] == row['Article']
    candidates = numberless_df.loc[same_article, 'events']
    return candidates.iloc[0] if len(candidates) else current

# Fill empty numbered extractions from the numberless run (same Article).
numbered_df['events'] = numbered_df.apply(lambda row: update_events(row, numberless_df), axis=1)

# write filtered events to csv as 'Filtered_llama_finetuned.csv'
# Written BEFORE evaluation: DataFrame.copy() does not deep-copy the event
# dicts, so writing here keeps the file independent of anything the evaluator
# does to those shared dicts.
# TODO: keep only relevant columns (Article, GT_Events, events) if that is wanted.
numbered_df.to_csv('Filtered_llama_finetuned.csv', index=False)

thresholds = [0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]  # NOTE(review): unused in this cell
qa_only_results = {}  # NOTE(review): unused in this cell
df_copy = numbered_df.copy()
event_eval = EventExtractionEvaluation(df_copy.copy())
metrics = event_eval.evaluate_event_extraction()

metrics

Total no of non empty events: 684
Total no of non empty GT events: 849
Total no of non empty events: 10
Total no of non empty GT events: 133


{'precision': 0.5014965986394558,
 'recall': 0.5044444444444445,
 'f1': 0.4977196650666038,
 'exact_match': 0.4340136054421769,
 'jaccard': 0.4796016039893591,
 'subset': 0.5401360544217687}