# Knowledge Classification Result Check

In [None]:
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
# from matplotlib import cm

from tqdm import tqdm

from sources.utils import *
from sources.process import *

## Load model results

In [None]:
# Model keys to evaluate. Each key is passed as `model_name` to the result
# loaders and indexes the lookup tables below. The commented-out lines are
# alternative selections kept for quick switching.
# model_name_list = ["Llama2_7B", "Llama3_8B", "Llama3.1_8B", "Mistral7B", "Phi3.5_Mini", "SOLAR_10.7B", "Gemma_7B", "Gemma2_9B"]
model_name_list = ["Llama3.1_8B", "Mistral7B", "Phi3.5_Mini", "SOLAR_10.7B", "Gemma2_9B", 'gpt-4o-mini']
# model_name_list = ['gpt-4o-mini']

# Display name of each model, used as the legend label in the plots.
official_model_name = {
    "Llama3.1_8B": "Llama-3.1-8B-Instruct",
    "Mistral7B": "Mistral-7B-Instruct-v0.3",
    "Phi3.5_Mini": "Phi-3.5-mini-instruct",
    "SOLAR_10.7B": "SOLAR-10.7B-Instruct-v1.0",
    "Gemma2_9B": "gemma-2-9b-it",
    "gpt-4o-mini": "GPT-4o mini",
}

# Line colour of each model in the plots.
official_model_color = {
    "Llama3.1_8B": "blue",
    "Mistral7B": "orange",
    "Phi3.5_Mini": "red",
    "SOLAR_10.7B": "purple",
    "Gemma2_9B": "green",
    "gpt-4o-mini": "black",
}


# Year associated with each model checkpoint (presumably the knowledge-cutoff
# year -- confirm against the linked model cards). Values are an int year, or
# None when the year has not been filled in yet.
official_model_checkpoint = {
    "Llama3.1_8B": 2023, # https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct
    "Mistral7B": 2024, # https://aimlapi.com/models/mistral-7b-instruct-v0-3 (unofficial)
    "Phi3.5_Mini": 2023, # https://huggingface.co/microsoft/Phi-3.5-mini-instruct
    # The two entries below previously held the colour strings "purple" and
    # "green", copy-pasted from official_model_color.
    "SOLAR_10.7B": None, # TODO: fill in the checkpoint year
    "Gemma2_9B": None, # TODO: fill in the checkpoint year
    "gpt-4o-mini": 2023, # https://openai.com/index/gpt-4o-mini-advancing-cost-efficient-intelligence/
}


In [None]:
def get_results_list(model_name_list, domain, mode):
    """Classify every model's results for the "Dynamic" and "Static" states.

    For each temporal state and each model, the raw outputs are loaded with
    ``load_result`` and classified with ``classify_results_time``. A second,
    fine-grained classification is computed with
    ``classify_results_time_fine_graining`` and handed to
    ``sampling_results``, whose return value is discarded.

    Args:
        model_name_list: Model keys to process; output order follows it.
        domain: Benchmark domain forwarded to the helper functions.
        mode: Evaluation mode forwarded to the helper functions.

    Returns:
        A two-element list ``[dynamic_results, static_results]``. Each
        element contains one ``classify_results_time`` result per model.
    """
    per_state = []

    for temporal_state in ("Dynamic", "Static"):
        per_model = []

        for name in tqdm(model_name_list):
            bench, parsed_t0, parsed_t7 = load_result(
                model_name=name,
                domain=domain,
                temp_state=temporal_state,
                mode=mode,
            )

            # Coarse classification: only the first returned item is kept.
            coarse, _ = classify_results_time(
                bench, parsed_t0, parsed_t7, name, domain, mode
            )
            fine, fine_indices = classify_results_time_fine_graining(
                bench, parsed_t0, parsed_t7, name, domain, mode
            )

            per_model.append(coarse)

            sampling_results(fine, fine_indices, name, domain, temporal_state)

        per_state.append(per_model)

    return per_state

def get_results_list_invariant(model_name_list, domain, mode):
    """Classify every model's results for a domain without temporal states.

    Each model's outputs are loaded with ``load_result`` using an empty
    ``temp_state`` and classified with ``classify_results_time``.

    Args:
        model_name_list: Model keys to process; output order follows it.
        domain: Benchmark domain forwarded to the helper functions.
        mode: Evaluation mode forwarded to the helper functions.

    Returns:
        A list with one ``classify_results_time`` result per model.
    """
    no_temporal_state = ""
    collected = []

    for name in tqdm(model_name_list):
        bench, parsed_t0, parsed_t7 = load_result(
            model_name=name,
            domain=domain,
            temp_state=no_temporal_state,
            mode=mode,
        )

        # Only the first item of the classification output is kept.
        coarse, _ = classify_results_time(
            bench, parsed_t0, parsed_t7, name, domain, mode
        )
        collected.append(coarse)

    return collected

## Making Plots with Time Variant Results

In [None]:
# Available domains: General, Biomedical, Legal, CommonSense, Math
domain = "General"

# Run the Dynamic/Static classification once per evaluation mode.
mode = "generation"
results_list_generation = get_results_list(model_name_list, domain, mode)

mode = "QA"
results_list_mcqa = get_results_list(model_name_list, domain, mode)

mode = "TF"
results_list_tf = get_results_list(model_name_list, domain, mode)

In [None]:
def plot_histogram_time_model_total(model_name_list, results_list_change, bench_name, mode, changed_title=True):
    """Plot the yearly percentage of known answers per model and save it.

    Two line charts are drawn side by side, one for each of the first two
    entries of ``results_list_change``, with one line per model. The plotted
    value is the share of samples whose category is anything other than
    ``'incorrect'``.

    Args:
        model_name_list: Model keys, aligned with every inner results list.
            They must also be keys of ``official_model_name`` and
            ``official_model_color``.
        results_list_change: Sequence of per-panel result lists. Each result
            list holds one ``{year_key: {category: count}}`` mapping per
            model, and every year key must end in ``_<year>``.
        bench_name: Benchmark name, used in the output file name.
        mode: Evaluation-mode label, used in the output file name. Path
            separators in it are replaced by ``_``.
        changed_title: When false, a shared legend is drawn below the panels.

    Side effects:
        Saves ``./Results/known_percentage_{bench_name}_{mode}.png`` and
        shows the figure.
    """
    year_keys = sorted(results_list_change[0][0].keys())
    # The year is the part after the last underscore of each key.
    year_labels = [int(key.split('_')[-1]) for key in year_keys]

    categories = ['correct', 'partial_correct1', 'partial_correct2', 'incorrect']
    # Select the known categories by exclusion: a substring test such as
    # `"correct" in category` also matches 'incorrect' and would make every
    # percentage equal to 100.
    known_categories = [c for c in categories if c != 'incorrect']

    # Two panels in a single row.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 4.5))

    for ax, results_list in zip([ax1, ax2], results_list_change):

        max_val = 0
        min_val = 100
        for model_name, results in zip(model_name_list, results_list):

            data = []
            for key in year_keys:
                year_counts = results[key]
                total = sum(year_counts[c] for c in categories)
                if not total:
                    # No samples for this year: leave a gap in the line
                    # instead of dividing by zero.
                    data.append(float('nan'))
                    continue

                knowns = sum(year_counts[c] for c in known_categories)
                knowns_percentage = knowns / total * 100
                data.append(knowns_percentage)

                max_val = max(max_val, knowns_percentage)
                min_val = min(min_val, knowns_percentage)

            # Only the lines of the second panel carry a label, so the
            # figure legend lists every model exactly once.
            ax.plot(year_labels, data, linestyle='-', linewidth=5,
                    label=official_model_name[model_name] if ax is ax2 else "",
                    color=official_model_color[model_name])

        # Widen narrow value ranges so the y-axis spans at least `min_space`.
        min_space = 35
        if max_val - min_val <= min_space:
            stride = min_space - (max_val - min_val)
            min_val = int(min_val - stride / 2)
            max_val = int(max_val + stride / 2)

        # Keep the lower axis limit (min_val - 3) from going below zero.
        min_val = max(min_val, 3)

        ax.set_xticks(year_labels)
        ax.set_xticklabels(year_labels, fontsize=15)

        ax.set_ylim(min_val - 3, max_val + 3)
        # Set the y tick font size without freezing the tick labels; the
        # formatter below renders the values as integers.
        ax.tick_params(axis='y', labelsize=15)
        ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{int(x)}'))
        ax.grid(True)
        ax.margins(x=0.03)
        ax.grid(False, axis='x')
        ax.patch.set_facecolor('#f1f1f0')

    if not changed_title:
        legend_properties = {'size': 16}
        fig.legend(loc='lower center', ncol=6, prop=legend_properties, bbox_to_anchor=(0.5,-0.1))

    # A label such as 'True/False' would otherwise be read as a directory.
    safe_mode = str(mode).replace('/', '_').replace('\\', '_')

    plt.tight_layout()
    plt.savefig(f'./Results/known_percentage_{bench_name}_{safe_mode}.png', bbox_inches='tight')
    plt.show()

In [None]:
# One figure per evaluation mode. Only the last call passes
# changed_title=False, so only that figure gets the shared legend.
plot_histogram_time_model_total(model_name_list, results_list_generation, "General", 'Generation')
plot_histogram_time_model_total(model_name_list, results_list_mcqa, "General", 'Multi-choice QA')
# The true/false figure uses the TF results (not the MCQA ones). The label
# becomes part of the output file name, so it must not contain '/'.
plot_histogram_time_model_total(model_name_list, results_list_tf, "General", 'True_False', changed_title=False)

## Making Plots with Time Invariant Results

In [None]:
domain = "CommonSense"

# Run the state-independent classification once per evaluation mode.
mode = "generation"
results_list_generation_cs = get_results_list_invariant(model_name_list, domain, mode)

mode = "QA"
results_list_mcqa_cs = get_results_list_invariant(model_name_list, domain, mode)

mode = "TF"
results_list_tf_cs = get_results_list_invariant(model_name_list, domain, mode)

In [None]:
def plot_histogram_time_model_invariant(model_name_list, results_list_generation, results_list_mcqa, results_list_tf, bench_name, changed_title=True):
    """Plot the percentage of highly known answers per model and save it.

    Three line charts are drawn side by side, in the order generation,
    multiple-choice QA and true/false, with one line per model. The plotted
    value is the share of ``'highly_known'`` samples among all four
    categories.

    Args:
        model_name_list: Model keys, aligned with every results list. They
            must also be keys of ``official_model_name`` and
            ``official_model_color``.
        results_list_generation: One ``{key: {category: count}}`` mapping per
            model for the first panel. Every key must end in ``_<int>``;
            the keys of the first mapping define the x-axis for all panels.
        results_list_mcqa: Same structure, drawn in the second panel.
        results_list_tf: Same structure, drawn in the third panel.
        bench_name: Benchmark name, used in the output file name.
        changed_title: Currently unused; the legend is always drawn.

    Side effects:
        Saves ``./Results/known_percentage_{bench_name}.pdf`` and shows the
        figure.
    """
    year_keys = sorted(results_list_generation[0].keys())
    # The numeric label is the part after the last underscore of each key.
    year_labels = [int(key.split('_')[-1]) for key in year_keys]

    categories = ['highly_known', 'maybe_known', 'weakly_known', 'unknown']

    # Three panels in a single row.
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(22, 4.5))

    panels = zip([ax1, ax2, ax3], [results_list_generation, results_list_mcqa, results_list_tf])
    for ax, results_list in panels:

        max_val = 0
        min_val = 100
        for model_name, results in zip(model_name_list, results_list):

            data = []
            for key in year_keys:
                year_counts = results[key]
                total = sum(year_counts[c] for c in categories)
                if not total:
                    # No samples for this entry: leave a gap in the line
                    # instead of dividing by zero.
                    data.append(float('nan'))
                    continue

                knowns_percentage = year_counts['highly_known'] / total * 100
                data.append(knowns_percentage)

                max_val = max(max_val, knowns_percentage)
                min_val = min(min_val, knowns_percentage)

            # Only the lines of the first panel carry a label, so the
            # figure legend lists every model exactly once.
            ax.plot(year_labels, data, linestyle='-', linewidth=3,
                    label=official_model_name[model_name] if ax is ax1 else "",
                    color=official_model_color[model_name],
                    alpha=1)

        # Format the axes once per panel, after all models are plotted, so
        # the padded limits below never feed back into the running min/max.
        min_space = 35
        if max_val - min_val <= min_space:
            stride = min_space - (max_val - min_val)
            min_val = int(min_val - stride / 2)
            max_val = int(max_val + stride / 2)

        ax.set_xticks(year_labels)
        ax.set_xticklabels(year_labels, fontsize=15)

        # Keep the lower axis limit (min_val - 3) from going below zero.
        min_val = max(min_val, 3)

        ax.set_ylim(min_val - 3, max_val + 3)
        # Set the y tick font size without freezing the tick labels; the
        # formatter below renders the values as integers.
        ax.tick_params(axis='y', labelsize=15)
        ax.yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f'{int(x)}'))
        ax.grid(True)
        ax.margins(x=0.03)
        ax.grid(False, axis='x')
        ax.patch.set_facecolor('#f1f1f0')

    # A single shared legend for the whole figure.
    fig.legend(loc='lower center', ncol=6, fontsize=16, bbox_to_anchor=(0.5,-0.1))

    plt.tight_layout()
    plt.savefig(f'./Results/known_percentage_{bench_name}.pdf', bbox_inches='tight')
    plt.show()

In [None]:
plot_histogram_time_model_invariant(model_name_list, results_list_generation_cs, results_list_mcqa_cs, results_list_tf_cs, "CommonSense")