In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Apply matplotlib's "fivethirtyeight" style sheet globally so every figure
# created in later cells shares the same look.
plt.style.use(style="fivethirtyeight")

In [None]:
# Load the per-question MMLU results for the base and fine-tuned models.
# Later cells read the `subject` and `answer` columns from both frames, plus a
# model-specific prediction column (`base_answer` / `finetuned_answer`).
mmlu_base, mmlu_finetuned = (
    pd.read_csv(csv_path) for csv_path in ("mmlu_base.csv", "mmlu_finetuned.csv")
)

In [None]:
# Per-subject accuracy of the base model. A question counts as correct when the
# model's `base_answer` equals the reference `answer`.
base_is_correct = mmlu_base["base_answer"] == mmlu_base["answer"]

# Summing the boolean mask per subject counts correct answers and yields 0 for
# subjects where the model got nothing right, so the lookups below never hit
# a missing key. groupby also sorts subjects alphabetically.
base_counts = base_is_correct.groupby(mmlu_base["subject"]).agg(["sum", "count"])

subjects = base_counts.index.tolist()                   # sorted subject names
subject_totals = base_counts["count"].to_dict()         # subject -> # questions
base_correct = base_counts["sum"].tolist()              # correct counts, ordered like `subjects`
base_scores = (100 * base_counts["sum"] / base_counts["count"]).tolist()  # percent correct

# Draw on a fresh figure so these bars never end up on another cell's axes.
fig, ax = plt.subplots()
ax.bar(range(len(subjects)), base_scores, width=0.6, color="blue")
ax.set_title("Base Model MMLU Scores")
ax.set_ylabel("Percentage Correct")
ax.set_xlabel("Subjects")
ax.set_xticks([])
ax.set_ylim(0, 100)
fig.savefig("mmlu_base.png", bbox_inches='tight', dpi=500)

In [None]:
# Per-subject accuracy of the fine-tuned model (CogControLM). A question counts
# as correct when `finetuned_answer` equals the reference `answer`.
finetuned_is_correct = mmlu_finetuned["finetuned_answer"] == mmlu_finetuned["answer"]

# Summing the boolean mask per subject yields 0 for subjects with no correct
# answers instead of a missing key. groupby sorts subjects alphabetically, so
# the resulting lists are ordered the same way as the base model's whenever
# both files cover the same subjects.
finetuned_counts = finetuned_is_correct.groupby(mmlu_finetuned["subject"]).agg(["sum", "count"])

subjects = finetuned_counts.index.tolist()                 # sorted subject names
subject_totals = finetuned_counts["count"].to_dict()       # subject -> # questions
finetuned_correct = finetuned_counts["sum"].tolist()       # correct counts, ordered like `subjects`
finetuned_scores = (100 * finetuned_counts["sum"] / finetuned_counts["count"]).tolist()  # percent correct

# Draw on a fresh figure so these bars do not overlay the base-model plot.
fig, ax = plt.subplots()
ax.bar(range(len(subjects)), finetuned_scores, width=0.6, color='red')
ax.set_title("CogControLM MMLU Scores")
ax.set_ylabel("Percentage Correct")
ax.set_xlabel("Subjects")
ax.set_xticks([])
ax.set_ylim(0, 100)
fig.savefig("mmlu_ft.png", bbox_inches='tight', dpi=500)

In [None]:
# Per-subject change in the number of correctly answered questions
# (CogControLM minus base), drawn next to each subject's total question count.
# NOTE(review): the two correct-count lists are paired by position, which is
# only meaningful if both CSVs cover exactly the same subjects — confirm.
assert len(base_correct) == len(finetuned_correct) == len(subjects), (
    "base and fine-tuned results cover different numbers of subjects"
)
differences = np.subtract(finetuned_correct, base_correct)
width = 0.4
x = np.arange(len(subjects))

# A fresh figure keeps these bars off any previously current axes. Outside
# Jupyter's inline backend (e.g. run as a script), those axes would otherwise
# carry the earlier 0-100 y-limits and clip negative differences.
fig, ax = plt.subplots()
ax.bar(x - width / 2, differences, width=width, label='CogControLM - Base', color='purple')
ax.bar(x + width / 2, [subject_totals[s] for s in subjects], label='Total', width=width, color='grey')
ax.set_title("MMLU Correct Answer Differences")
ax.set_ylabel("# of Questions")
ax.set_xlabel("Subjects")
ax.set_xticks([])
ax.legend()
fig.savefig("mmlu_diff.png", bbox_inches='tight', dpi=500)