In [1]:
import json
import os
import re
import sys
from concurrent.futures import wait

sys.path.append('..')
sys.path.append('.')

# NOTE: the relative order of the imports below is kept as-is because the wildcard
# imports may rebind names introduced by earlier lines.
from eval_fully_contrastive import *
from sklearn.metrics import confusion_matrix
import numpy as np
from few_shot import *
from generative_method import *
import matplotlib.pyplot as plt
from card import GenerativeCard

## Setting up

In [2]:
# Experiment configuration: card format / baseline under evaluation and the run to load.
card_format = 'dict'
baseline = 'card'

model = 'Llama-2-13b-chat-hf'
# model = 'Mixtral-8x7B-Instruct-v0.1'
topic = 'high_school_chemistry'
folder_path = '03-14_21-59-48_Llama-2-13b-chat-hf_main'

# Only the 'dict' card format starts from topic-specific criteria; every other format uses none.
initial_criteria = get_initial_criteria(topic) if card_format == 'dict' else []

# Guard against pairing a run folder with a different model name.
assert model in folder_path

In [3]:
# Cost manager handed to the method object constructed at the end of this cell.
cm = CostManager()
# NOTE(review): there is no separator between 'web_eval' and the topic, so the name
# reads e.g. 'web_evalhigh_school_chemistry' — confirm whether 'web_eval_' was intended.
exp = 'web_eval' + topic

# Hyperparameters passed to GenerativeMethod; values come from the configuration cell above.
hp = {
            # general
            'experiment'      : exp,
            'name'            : f'{model}_main',
            'method'          : 'generative',
            'load_from'       : None,
            # dataset
            'dataset_folder'  : 'datasets/mmlu',
            'topic'           : topic,
            'shuffle'         : False,
            'seed'            : 311,
            'batch_nums'      : [8] * 5 + [60],
            'model'           : model,
            # training
            'epoch'           : None,  # None: use the number of training batches as epoch
            'use_refine'      : False,
            'initial_criteria': initial_criteria,
            'CoT'             : False,
            'card_format'     : card_format,
            # eval
            'async_eval'      : True,
        }
# Method object whose training_batches / testing_batch are read by the cells below.
m =  GenerativeMethod(hp, cm)

In [4]:
# Show the model attribute of the method object as a sanity check.
print(m.model)

In [5]:
# Answer index -> letter. Ground-truth values outside 0-3 are rendered as 'N'.
no_to_letter = {0: 'A', 1: 'B', 2: 'C', 3: 'D'}


def format_single_entry(entry):
    """Render one evaluation record as prompt text.

    ``entry`` is a 5-tuple ``(question, choices, ground_truth, prediction, completion)``.
    Only the question, the lettered choices and the ground-truth letter are emitted;
    the prediction and completion are unpacked but not used.
    """
    question, options, answer_idx, _pred, _completion = entry

    lettered = []
    for idx, text in enumerate(options):
        lettered.append(f'{no_to_letter[idx]}: {text}')
    answer_letter = no_to_letter.get(answer_idx, 'N')

    header = f'{question}\nChoices:\n'
    return header + '\n'.join(lettered) + f'\nCorrect answer: {answer_letter}'

In [6]:
def format_few_shot_entry(entry):
    """Render one record as a few-shot example including the student's outcome.

    ``entry`` is a 5-tuple ``(question, choices, ground_truth, prediction, completion)``.
    The output lists the question, lettered choices, the student's completion, the
    ground-truth letter ('N' when the index has no letter) and whether the prediction
    equals the ground truth.
    """
    question, options, answer_idx, predicted_idx, student_output = entry

    option_lines = '\n'.join(
        f'{no_to_letter[idx]}: {text}' for idx, text in enumerate(options)
    )
    answer_letter = no_to_letter.get(answer_idx, 'N')
    is_correct = answer_idx == predicted_idx

    return (
        f'{question}\nChoices:\n{option_lines}'
        f'\nStudent Completion: {student_output}'
        f'\nGround Truth Answer: {answer_letter}'
        f'\nStudents correctness: {is_correct}'
    )

In [7]:
def format_full_eval_str(method):
    """Join every entry of ``method.testing_batch`` into one numbered question list.

    Questions are numbered from 1 and each formatted entry is followed by a blank line.
    """
    numbered = (
        f'Question {idx}. {format_single_entry(entry)}\n\n'
        for idx, entry in enumerate(method.testing_batch, start=1)
    )
    return ''.join(numbered)


In [8]:
def format_few_shot_string(method):
    """Join all entries of ``method.training_batches`` into one numbered few-shot block.

    Numbering starts at 1 and runs continuously across batches; each formatted
    example is followed by a blank line.
    """
    pieces = []
    for batch in method.training_batches:
        for entry in batch:
            number = len(pieces) + 1
            pieces.append(
                f'Few shot sample question {number}. {format_few_shot_entry(entry)}\n\n'
            )
    return ''.join(pieces)


In [9]:
def get_training_accuracy(method):
    """Return the share of training entries whose prediction equals the ground truth.

    Every entry in ``method.training_batches`` is a 5-tuple whose third and fourth
    items are the ground truth and the prediction. The ratio is rounded to two
    decimals. Raises ``ZeroDivisionError`` when there are no training entries.
    """
    outcomes = [
        truth == guess
        for batch in method.training_batches
        for _q, _choices, truth, guess, _completion in batch
    ]
    hits = sum(1 for matched in outcomes if matched)

    # Rounded to 2 decimals.
    return round(hits / len(outcomes), 2)

In [10]:
# Choose the card file location: the 'str' format lives under a format-specific
# output folder, every other format under 'generative_<topic>'.
# NOTE(review): the next cell assigns card_path again unconditionally, so in a
# top-to-bottom run the value chosen here is overwritten — confirm which directory
# layout is the current one.
if card_format != 'str':
    card_path = f'./outputs/generative_{topic}/{folder_path}/cards/epoch_4_card.json'
else:
    card_path = f'./outputs/generative_{card_format}_{topic}/{folder_path}/cards/epoch_4_card.json'

In [11]:
# Epoch-4 card under the format-specific output folder; replaces any earlier card_path value.
card_path = f'./outputs/generative_{card_format}_{topic}/{folder_path}/cards/epoch_4_card.json'

In [12]:
# Load the saved card JSON from disk.
with open(card_path) as f:
    card_dict = json.load(f)

In [13]:
# Turn the loaded JSON into the card representation for the selected format.
if card_format == 'dict':
    card = GenerativeCard(d=card_dict)
elif card_format == 'bullet_point':
    # Entries are looked up by stringified index ('0', '1', ...) and joined with '\n -'.
    # NOTE(review): join() places the '\n -' marker only between items, so the first
    # entry gets no leading dash — confirm this is the intended layout.
    card = '\n -'.join([card_dict['card'][str(i)] for i in range(len(card_dict['card']))])
else:
    card = card_dict['card']

In [14]:
# For the few-shot baseline the card is replaced by the formatted training examples.
if baseline == 'few_shot':
    few_shot_str = format_few_shot_string(m)
    card = few_shot_str

In [15]:
# Preview the first 50 characters of the card text.
str(card)[:50]

In [16]:
# From here on the card is handled as plain text.
card = str(card)
# Disabled: optionally append the student's training accuracy to the card text.
# card += f"\n# LLM Student's Past Performance:\nThe LLM student's accuracy on previous questions under this topic: {get_training_accuracy(m)}.\n"

In [17]:
# Display the student's accuracy over the training batches (rounded to 2 decimals).
get_training_accuracy(m)

In [18]:
# Numbered list of all test questions, inserted into the prompt as {qa}.
str_all = format_full_eval_str(m)

In [19]:
# System prompt template; its {topic} placeholder is filled in when the final prompt is assembled.
with open('prompts/eval/predictive/system_mmlu.txt', 'r') as f:
   sys_prompt = f.read() 

# NOTE(review): user_prompt is reassigned with an inline template in a later cell,
# so the file content read here is not what ends up being formatted.
with open('prompts/eval/predictive/user.txt', 'r') as f:
   user_prompt = f.read() 

In [20]:
# Display the raw system prompt template.
sys_prompt

In [21]:
# Inline user prompt template; replaces whatever was previously bound to user_prompt.
# The {qa} and {card} placeholders are filled in with str.format when the final
# prompt is assembled.
# NOTE(review): the question count (60) is hard-coded in the template text.
user_prompt = """
Predict if the Large Language Model (LLM) student will answer The Question correctly based on the Student Evaluation. You will do this by reasoning step-by-step:

# Instruction:

1. Analyze how the student may answer the questions based on the criteria.
2. Based on your analysis, predict if the student will answer The Question correctly (true) or not (false).
3. Note that the student is an LLM, so their performance would not change over time.

Requirement: Don't make any assumptions about the student. Your prediction should be solely grounded on the Student Evaluation below.

## The Questions

{qa}

## Student Evaluation

{card}
# Formatting:
You should do one single aggregated analysis, and then output verdict one by one. 
Your task is not to establish expectations for the LLM student but to predict the student's performance based on the evaluation. So you should be objective and unbiased.

Provide an analysis using the format:
# Aggregated analysis: <student's capability>
...

Make predictions one by one, using the format:
Question index: Verdict

<1-60>: <T/F>
"""

# The few-shot baseline appends a reminder that only the listed questions are to be predicted.
if baseline == 'few_shot':
    user_prompt += "Note that questions under student evaluation are examples for you to check, the first 60 questions are what you are going to predict!"

In [22]:
# Final prompt: system part (topic filled in), a blank line, then the user part
# with the question list and the card text substituted.
formatted_str = '\n\n'.join((
    sys_prompt.format(topic=topic),
    user_prompt.format(qa=str_all, card=str(card)),
))

In [23]:
# Save the assembled prompt; the few-shot baseline is written to its own file.
with open('Web/eval_few_shot.txt' if baseline == 'few_shot' else 'Web/eval_input.txt', 'w') as f:
    f.write(formatted_str)

In [24]:
# One boolean per test entry: does the ground truth (index 2) equal the prediction (index 3)?
ground_truths = [entry[2] == entry[3] for entry in m.testing_batch]

# Numbered "<n>. <True/False>" lines, one per test question, starting at 1.
gt_str = ''.join(
    f'{idx}. {flag}\n' for idx, flag in enumerate(ground_truths, start=1)
)

with open('Web/eval_gt.txt', 'w') as f:
    f.write(gt_str)

In [25]:
# Fraction of test entries where the prediction equals the ground truth.
print(f'Accuracy is {sum(ground_truths) / len(ground_truths)}')

# Output

In [26]:
def count_words(txt):
    """Return the number of maximal word-character runs found in ``txt``."""
    return sum(1 for _ in re.finditer(r'\w+', txt))

In [27]:
# Read the judge model's saved response and extract the per-question verdict text.
with open('Web/eval_output.txt', 'r') as f:
    lines = f.readlines()

# Scan every line for "<label>: <verdict>" and keep only non-empty verdicts, with
# surrounding whitespace removed so the first character is the verdict letter.
outputs = []
for l in lines:
    match = re.search(r': (.*)', l)
    # Alternative pattern for responses that use "Verdict: <T/F>" lines:
    # match = re.search(r'Verdict: (.*)', l)
    if match and match.group(1).strip():
        outputs.append(match.group(1).strip())

# Keep one verdict per ground-truth entry, taken from the END of the response. The
# window is applied after filtering so trailing blank lines cannot push real verdicts
# out of it. A -0 slice would keep everything, hence the explicit empty case.
outputs = outputs[-len(ground_truths):] if ground_truths else []

def tf_extractor(outputs):
    """Reduce each raw verdict string to its leading letter, upper-cased.

    Surrounding whitespace is ignored, so ' true' yields 'T'. A blank or empty
    verdict yields '' instead of raising ``IndexError``, so one malformed line does
    not abort the whole evaluation.

    Args:
        outputs: iterable of verdict strings such as 'T', 'F', 'True', 'false'.

    Returns:
        A list with one single-character (or empty) string per input.
    """
    return [o.strip()[:1].upper() for o in outputs]

def true_false_extractor(outputs):
    """Convert verdicts to booleans.

    String verdicts are True only when their first non-blank character is 't' or 'T'.
    A plain ``bool(o)`` would treat every non-empty string, including 'F' and
    'False', as True. Non-string values (for example already-converted booleans)
    are passed through ``bool()`` as before.

    Args:
        outputs: iterable of verdict strings and/or booleans.

    Returns:
        A list of ``bool`` values, one per input.
    """
    return [
        o.strip()[:1].upper() == 'T' if isinstance(o, str) else bool(o)
        for o in outputs
    ]

# Normalise the raw verdict strings to their leading letter.
outputs = tf_extractor(outputs)

# Verify alignment BEFORE scoring: with a mismatch the comparison below would either
# raise IndexError or score verdicts against the wrong questions.
assert len(outputs) == len(ground_truths), (
    f'Parsed {len(outputs)} verdicts but expected {len(ground_truths)}'
)

# A prediction is correct when "verdict is T" agrees with the student's actual correctness.
correct_cnt = sum(
    truth == (verdict == 'T') for truth, verdict in zip(ground_truths, outputs)
)

print(f'Claude 3 accuracy: {correct_cnt / len(outputs)}')
# Accuracy of always predicting the more frequent outcome.
oracle = sum(ground_truths) / len(ground_truths)
print(f'Oracle accuracy is {max(oracle, 1-oracle)}')

# Only the effective assignment is kept: a preceding if/else on card_format was
# always overwritten by this line, so it never influenced the path that was read.
mixtral_path = f'./outputs/generative_{card_format}_{topic}/{folder_path}/eval_test_epoch_4_predictive.json'
with open(mixtral_path) as f:
    mixtral_outputs = json.load(f)
mixtral_accuracy = mixtral_outputs['metrics']['accuracies']
print(f'Mixtral accuracy: {mixtral_accuracy[0]}')

print(f'Word count: {count_words(str(card))}')


In [28]:
# Confusion matrix of predicted vs. actual student correctness.

# Derive booleans under a new name instead of rebinding `outputs`: rebinding made the
# cell non-idempotent, because a second run compared booleans with 'T' and silently
# produced all-False predictions. Values that are already booleans are kept as they are.
predicted_correct = [o if isinstance(o, bool) else o == 'T' for o in outputs]

# `labels` pins the matrix to 2x2 in [False, True] order even when one class is absent.
# A dedicated name is used so that `cm` (the CostManager) is not overwritten.
conf_mat = confusion_matrix(ground_truths, predicted_correct, labels=[False, True])

fig, ax = plt.subplots()
cax = ax.matshow(conf_mat, cmap='Blues')
fig.colorbar(cax)
# Fix the tick positions before labelling them; labelling the automatic ticks relies
# on where the locator happens to place them.
ax.set_xticks([0, 1])
ax.set_xticklabels(['F', 'T'])
ax.set_yticks([0, 1])
ax.set_yticklabels(['F', 'T'])
# Write the count into each cell (row = true class, column = predicted class).
for (i, j), val in np.ndenumerate(conf_mat):
    ax.text(j, i, f'{val}', ha='center', va='center')
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
ax.set_title(f'Confusion Matrix on model {model}\n for topic {topic}')
plt.show()
