In [1]:
%reload_ext autoreload
%autoreload 2
import os
os.chdir('../')
from src.utils import find_text_parts, split_text
from datasets import load_dataset
import pandas as pd
import numpy as np
from evaluator.gpt_evaluator import GPT4Semantic, GPT4Accuracy
import time

[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /home/ubuntu/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /home/ubuntu/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


# Processing the data

In [20]:
def processData(window_size, unit, dataset, model_type, text_key_name, num_key_name):
    output_dir = f"/home/ubuntu/multimodal/Data/{dataset}-GPT4-Evaluation/{model_type}/{window_size}{unit}"
    filename = "processed.csv"
    num_pattern = fr"{unit}_\d+_{num_key_name}: '([\d.]+)'"
    text_pattern =fr'({unit}_\d+_date:\s*\S+\s+{unit}_\d+_{text_key_name}:.*?)(?=\s{unit}_\d+_date|\Z)'
    hf_dataset = f"Howard881010/{dataset}-{window_size}{unit}" + ("-mixed" if model_type.startswith("textTime2textTime") else "") +  "-inContext"

    data_all = load_dataset(hf_dataset)
    data = pd.DataFrame(data_all['test'])

    output_texts = data['output'].apply(lambda x: find_text_parts(x, num_pattern)).apply(lambda x: split_text(x, text_pattern)).to_list()
    pred_texts = data['pred_output'].apply(lambda x: find_text_parts(x, num_pattern)).apply(lambda x: split_text(x, text_pattern)).to_list()
    for idx, pred_text in enumerate(pred_texts):
        if len(pred_text) > window_size:
            pred_texts[idx] = pred_text[:window_size]
        while len(pred_text) < window_size:
            pred_texts[idx].append(None)

    output_texts = np.reshape(output_texts, -1)
    pred_texts = np.reshape(pred_texts, -1)

    results = pd.DataFrame({"output_text": output_texts, "pred_text": pred_texts})

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, filename)
    results.to_csv(output_path, index=False)

# Wait function

In [21]:
def wait_for_completion(job_id, processor, poll_interval=100):
    status = processor.check_status(job_id)
    while status.status not in ["completed", "failed"]:
        print(f"Current status: {status}. Waiting for {poll_interval} seconds...")
        time.sleep(poll_interval)
        status = processor.check_status(job_id)
    return status.status

# GPT4 semantic & GPT4 accuracy

In [22]:
def calculate_metrics(window_size, unit, dataset, model_type, text_key_name, num_key_name):
    processData(window_size, unit, dataset, model_type, text_key_name, num_key_name)

    gpt4semantic = GPT4Semantic()

    results_dir = f"/home/ubuntu/multimodal/Data/{dataset}-GPT4-Evaluation/{model_type}/{window_size}{unit}"
    data = pd.read_csv(f"/home/ubuntu/multimodal/Data/{dataset}-GPT4-Evaluation/{model_type}/{window_size}{unit}/processed.csv")

    jsonl_path = os.path.join(results_dir, "batch.jsonl")
    semantic_output_path = os.path.join(results_dir, "semantic.txt")

    semantic_batch_object_id = gpt4semantic.create_and_run_batch_job(data, jsonl_path, output_text_column="output_text",
                                    pred_text_column="pred_text")

    job_status = wait_for_completion(semantic_batch_object_id, gpt4semantic)

    if job_status == "completed":
        print("Batch job completed successfully!")
        semantic_outputs = gpt4semantic.check_status_and_parse(semantic_batch_object_id , semantic_output_path)
        semantic_score, count_none = gpt4semantic.calculate_metrics(semantic_outputs)


    gpt4accuracy = GPT4Accuracy()
    accuracy_output_path = os.path.join(results_dir, "accuracy.txt")
    accuracy_batch_object_id = gpt4accuracy.create_and_run_batch_job(data, jsonl_path, output_text_column="output_text",
                                    pred_text_column="pred_text")

    job_status = wait_for_completion(accuracy_batch_object_id, gpt4accuracy)
    if job_status == "completed":
        print("Batch job completed successfully!")
        accuracy_outputs = gpt4accuracy.check_status_and_parse(accuracy_batch_object_id, accuracy_output_path)
        precisions, recalls, f1_scores = gpt4accuracy.calculate_metrics(accuracy_outputs)

    results = {"semantic_score": semantic_score, "count_none": count_none, "precisions": precisions, "recalls": recalls, "f1_scores": f1_scores}
    results = pd.DataFrame.from_dict(results, orient="index").T
    results.to_csv(os.path.join(results_dir, "results.csv"))






In [24]:
unit = "day"
dataset = "climate"
text_key_name = "weather_forecast"
num_key_name = "temp"
model_type = "textTime2textTime-inContext"
for window_size in [6]:
    calculate_metrics(window_size, unit, dataset, model_type, text_key_name, num_key_name)

Downloading readme:   0%|          | 0.00/710 [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/6.10M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/821k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/1.17M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2888 [00:00<?, ? examples/s]

Generating valid split:   0%|          | 0/361 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/362 [00:00<?, ? examples/s]

batch job created with batch_object_id 
 batch_kVyaTz2SrCumFQluM8gLqotu
Current status: Batch(id='batch_kVyaTz2SrCumFQluM8gLqotu', completion_window='24h', created_at=1725034288, endpoint='/v1/chat/completions', input_file_id='file-KUlXTpoMIw0Iu5aPlsnpeSIt', object='batch', status='validating', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1725120688, failed_at=None, finalizing_at=None, in_progress_at=None, metadata={'description': 'Multimodal Forecasting'}, output_file_id=None, request_counts=BatchRequestCounts(completed=0, failed=0, total=0)). Waiting for 100 seconds...
Current status: Batch(id='batch_kVyaTz2SrCumFQluM8gLqotu', completion_window='24h', created_at=1725034288, endpoint='/v1/chat/completions', input_file_id='file-KUlXTpoMIw0Iu5aPlsnpeSIt', object='batch', status='in_progress', cancelled_at=None, cancelling_at=None, completed_at=None, error_file_id=None, errors=None, expired_at=None, expires_at=1725