In [72]:
from main_workflow import *
import pandas as pd
from tqdm import tqdm
import evaluate
import re

### Reading in data files

In [74]:
bs = pd.read_csv('baseline_responses.csv')
gc = pd.read_csv('golden_copy.csv')

In [None]:
merged_df = pd.merge(bs, gc, on='user_query', how='outer')
df = merged_df[['user_query', 'response', 'golden_response']]

### Functions for Critical Fields Analysis

In [113]:
def normalize_datetime(date_str, time_str):
    for fmt in ["%Y-%m-%d", "%d %b %Y", "%d %B %Y"]:
        try:
            d = datetime.strptime(date_str, fmt)
            return d.strftime("%Y-%m-%d") + " " + time_str
        except:
            pass
    return f"{date_str} {time_str}"

def extract_fields(text):
    pattern_flight = r'''
    (?P<airline>[\w\s,&]+?)\s*[—\-–]+\s*
    (?P<dep_airport>[A-Za-z]{3,})\s+
    (?P<dep_date>\d{4}-\d{2}-\d{2}|\d{1,2}\s+\w+\s+\d{4})\s+
    (?P<dep_time>\d{2}:\d{2})\s*
    →\s*
    (?P<arr_airport>[A-Za-z]{3,})\s+
    (?P<arr_date>\d{4}-\d{2}-\d{2}|\d{1,2}\s+\w+\s+\d{4})\s+
    (?P<arr_time>\d{2}:\d{2})
    '''

    pattern_hotel = r'(?mi)^\*\s*Hotel(?: Name)?:\s*(.+)$'

    # hotel name
    hotel_match = re.search(pattern_hotel, text)
    hotel_name = hotel_match.group(1).strip() if hotel_match else None

    # flights
    flight_matches = list(re.finditer(pattern_flight, text, re.X))

    if len(flight_matches) == 0:
        return {
            'Outbound Airline': None,
            'Departure Datetime': None,
            'Inbound Airline': None,
            'Arrival Datetime': None,
            'Hotel Name': hotel_name
        }

    # outbound flight is first flight found
    first = flight_matches[0].groupdict()

    outbound_airline = first["airline"].strip()
    dep_datetime = normalize_datetime(first["dep_date"], first["dep_time"])

    # inbound flight is second flight (if found)
    if len(flight_matches) > 1:
        second = flight_matches[1].groupdict()
        inbound_airline = second["airline"].strip()
        arr_datetime = normalize_datetime(second["arr_date"], second["arr_time"])
    else:
        inbound_airline = None
        arr_datetime = None

    def airline_set(s):
        if s is None:
            return None
        parts = [a.strip() for a in s.split(",") if a.strip()]
        return set(parts)

    outbound_airline_set = airline_set(outbound_airline)
    inbound_airline_set = airline_set(inbound_airline)

    return {
        'Outbound Airline': outbound_airline_set,
        'Departure Datetime': dep_datetime,
        'Inbound Airline': inbound_airline_set,
        'Arrival Datetime': arr_datetime,
        'Hotel Name': hotel_name
    }

def calculate_score(fields_1, fields_2):
    score = 0
    for field, value in fields_1.items():
        if value == fields_2[field]:
            score += 1
    return score


### Critical Fields Score

In [122]:
total = 0
mismatches = {
    'Outbound Airline': 0,
    'Departure Datetime': 0,
    'Inbound Airline': 0,
    'Arrival Datetime': 0,
    'Hotel Name': 0
}

for _, row in df.iterrows():
    fields1 = extract_fields(row['response'])
    fields2 = extract_fields(row['golden_response'])
    for field, value in fields1.items():
        if value == fields2[field]:
            total += 1
        else:
            mismatches[field] += 1

total/39

2.717948717948718

In [123]:
mismatches

{'Outbound Airline': 13,
 'Departure Datetime': 10,
 'Inbound Airline': 15,
 'Arrival Datetime': 15,
 'Hotel Name': 36}

### BERTScore, METEOR

In [91]:
from evals import ModelEvaluator

In [92]:
evaluator = ModelEvaluator(df)

In [93]:
bertscore_df, bertscore_f1 = evaluator.bertEval()

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [95]:
bertscore_f1

np.float64(0.8550686163780017)

In [96]:
meteor_df, meteor_score = evaluator.meteorEval()

In [98]:
meteor_score

np.float64(0.425310123487645)