In [1]:
# Reload edited project modules (e.g. chart_reasoning) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [1]:
import os
import sys

# Project root: the parent of the current working directory (the notebook's folder).
APP_ROOT = os.path.abspath('..')
# Make the project package importable. Guarded so that re-running this cell does not
# append a duplicate entry to sys.path each time.
if APP_ROOT not in sys.path:
    sys.path.append(APP_ROOT)

from chart_reasoning.pipeline_oss import ChartReasoningOSSPipeline


vistext_data_dir = os.path.join(APP_ROOT, 'data', 'vistext-data')
output_dir = os.path.join(APP_ROOT, 'output', 'chart-reasoning-trail-run')

# NOTE: the outputs need to be moved to chart-reasoning-trail-run/03-oss-chart-reasoning-output,
# e.g. under a per-model folder such as phi-3.5-vision.
model_name = 'phi-3.5-vision'
pipeline = ChartReasoningOSSPipeline(vistext_data_dir, output_dir, model_name=model_name)

# Ensure the pipeline arguments are the same as the ones in chart-reasoning-pipeline.ipynb.

In [3]:
# Task generation is intentionally disabled here (NOTE(review): presumably the tasks were already
# generated via chart-reasoning-pipeline.ipynb — confirm). Uncomment to regenerate a sample.
# pipeline.task_generation(sample_size=100)

In [4]:
# Debug aid: uncomment to display the pipeline's graded_reasoning_output_dir attribute.
# pipeline.graded_reasoning_output_dir

In [6]:
# Only grading is needed for open-source models.
# NOTE(review): grading is repeated N_GRADING_PASSES times, presumably so later passes pick up
# questions whose grading failed earlier — confirm against grade_with_text_agent().
N_GRADING_PASSES = 10
# Not `_` as the loop variable: assigning `_` in IPython shadows the last-output cache.
for _grading_pass in range(N_GRADING_PASSES):
    pipeline.grade_with_text_agent()

100%|██████████| 15302/15302 [02:32<00:00, 100.15it/s]
100%|██████████| 15302/15302 [02:10<00:00, 116.99it/s] 
100%|██████████| 15302/15302 [02:43<00:00, 93.86it/s]  
100%|██████████| 15302/15302 [03:05<00:00, 82.37it/s]  
100%|██████████| 15302/15302 [02:01<00:00, 126.15it/s] 
100%|██████████| 15302/15302 [01:53<00:00, 135.33it/s] 
100%|██████████| 15302/15302 [02:13<00:00, 114.55it/s] 
100%|██████████| 15302/15302 [02:02<00:00, 125.07it/s] 
100%|██████████| 15302/15302 [02:05<00:00, 121.69it/s] 
100%|██████████| 15302/15302 [01:43<00:00, 148.24it/s] 


In [7]:
# Export the graded outputs via graded_output_to_md (NOTE(review): presumably writes Markdown — confirm).
pipeline.graded_output_to_md()


  0%|          | 0/8741 [00:00<?, ?it/s]

100%|██████████| 8741/8741 [47:48<00:00,  3.05it/s]  


# Evaluate

In [None]:
from tqdm import tqdm
import json
from copy import deepcopy

# code_graded_dir = os.path.join(output_dir, '06-code-assistant-grading-output')
text_graded_dir = os.path.join(output_dir, '04-oss-text-grading-output', model_name)

# RQ2 exp1
# chart_types = ['line', 'scatter', 'bar']
# RQ2 exp2
# chart_types = ['pie', 'table', 'bar_anno', 'line_anno', 'scatter_anno']

# RQ1 exp1
# chart_types = ['unaligned_rule', 'color', 'size', 'scatter']
# RQ1 exp2
chart_types = ['rule', 'scatter_size', 'bar', 'bar_color']

# chart_type -> lowercased correctness label -> list of (f"{task_id}_{question_index}", task_type).
text_student_answer_correctness_dict_by_chart_type = {chart_type: dict() for chart_type in chart_types}
# chart_type -> every lowercased correctness label in file order (questions without a
# usable label are not included here).
text_student_judgement_list = {chart_type: [] for chart_type in chart_types}

# One task id per line. Blank lines are dropped: they would otherwise become an empty id
# (spurious "File not found" messages) and break the int() conversion in the next cell.
with open("reported_ids_1000.txt", "r") as ids_file:
    graded_task_ids = [line.strip() for line in ids_file]
graded_task_ids = [task_id for task_id in graded_task_ids if task_id]

# Deduplicate; sort so the iteration order (and hence the order inside the lists) is reproducible.
valid_task_ids = sorted(set(graded_task_ids))
for task_id in tqdm(valid_task_ids):
    for chart_type in chart_types:
        text_graded_task_file = os.path.join(text_graded_dir, f'{task_id}.{chart_type}.grade.json')
        if not os.path.exists(text_graded_task_file):
            print(f"File not found: {text_graded_task_file}")
            continue

        with open(text_graded_task_file, 'r') as f:
            text_graded_task = json.load(f)

        correctness_by_label = text_student_answer_correctness_dict_by_chart_type[chart_type]
        for question_index, question_dict in enumerate(text_graded_task):
            try:
                label = question_dict['student_answer_correctness'].lower()
            except (KeyError, AttributeError):
                # Missing or non-string grade: bucket the question under 'skipped'.
                label = 'skipped'
            else:
                text_student_judgement_list[chart_type].append(label)
            # Record every question (including the first one seen for a label) under the
            # lowercased label, so mixed-case labels from the grader share one bucket.
            correctness_by_label.setdefault(label, []).append(
                (question_dict['task_id'] + "_" + str(question_index), question_dict['task_type'])
            )

In [None]:
# sort valid_task_ids
# Order the distinct task ids numerically (they are read from the ids file as strings);
# the cell then displays how many there are.
valid_task_ids = sorted(map(int, valid_task_ids))
len(valid_task_ids)

In [None]:
# Print the correctness-label distribution per chart type and collect the percentage of
# 'correct' answers into `correct_list`, one entry per chart type in `chart_types` order:
# 0.0 when a chart type has no 'correct' answers, nan when it has no graded answers at all.
correct_list = []
for chart_type in chart_types:
    print(f"Chart Type: {chart_type}")
    print("Correctness Distribution:")
    correctness_by_label = text_student_answer_correctness_dict_by_chart_type[chart_type]
    total = sum(len(id_list) for id_list in correctness_by_label.values())
    if total == 0:
        # No graded files for this chart type: avoid dividing by zero but keep the
        # positional alignment between correct_list and chart_types.
        print("No graded answers found.")
        correct_list.append(float('nan'))
    else:
        for key, id_list in correctness_by_label.items():
            print(key, ":", len(id_list), f"({round(len(id_list)/total * 100, 2)}%)")
        correct_list.append(round(len(correctness_by_label.get('correct', [])) / total * 100, 2))
    print("\n\n")

print("Correctness:", correct_list)

In [None]:
# Analysis in terms of task type: for each task type, print one ';'-separated row with the
# percentage of 'correct' answers per chart type (in `chart_types` order). 'N/A' marks a
# chart type that has no examples of that task type.

task_types = ['Find Anomalies', 'Find Correlation', 'Determine Range', 'Order', 'Filter', 'Compute Derived Value', 'Find Extremum', 'Retrieve Value', 'Find Clusters', 'Characterize Distribution']


for task_type in task_types:
    print(f"Task Type: {task_type}")
    for chart_type in chart_types:
        correctness_by_label = text_student_answer_correctness_dict_by_chart_type[chart_type]
        # Each example is (f"{task_id}_{question_index}", task_type).
        correct_cnt = sum(
            1 for example in correctness_by_label.get('correct', []) if example[1] == task_type
        )
        all_example_cnt = sum(
            1
            for example_list in correctness_by_label.values()
            for example in example_list
            if example[1] == task_type
        )
        if all_example_cnt:
            print(f"{round(correct_cnt/all_example_cnt * 100, 2)}", end=';')
        else:
            print("N/A", end=';')
    print("\n\n")