In [1]:
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
from threading import Thread
import os
os.environ['CUDA_VISIBLE_DEVICES']='0,1,2,3,4,5,6,7'
import torch.nn as nn
import inspect
import torch.multiprocessing as mp
# mp.set_start_method('spawn')
import json, pickle, copy
from peft import PeftModel
from tqdm import tqdm
import re
import concurrent.futures

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Local checkpoint directory of the base model.  The two commented-out paths are
# the earlier Llama-2 + LoRA configuration, kept for reference.
# model_pth = '../Llama-2-7b-chat-hf'
# peft_model_pth = './llama_2_7b_lora_3/5_epoch_finetuning'
model_pth = '../Meta-Llama-3-8B-Instruct'

# dtype / quantisation switches are model-loading options; a tokenizer holds no
# weights, so none of them are passed to it.
tokenizer = AutoTokenizer.from_pretrained(model_pth)
# load_in_4bit / load_in_8bit default to False, so requesting fp16 weights is
# the only option that needs to be spelled out.
model = AutoModelForCausalLM.from_pretrained(model_pth, torch_dtype=torch.float16)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.73it/s]


In [3]:
class StopOnTokens(StoppingCriteria):
    """Stopping criterion for a single sequence: stop once the newest token is a stop id.

    Args:
        stop_ids: iterable of token ids that end generation.  Defaults to ``(2,)``,
            the value that used to be hard-coded here.
            NOTE(review): 2 is presumably the ``</s>`` id of the Llama-2 vocabulary;
            for another tokenizer pass e.g. ``[tokenizer.eos_token_id]`` -- confirm
            against the checkpoint actually loaded.
    """

    def __init__(self, stop_ids=(2,)):
        super().__init__()
        self.stop_ids = list(stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only row 0 is inspected, so this is meant for batch size 1.
        last_token = input_ids[0][-1]
        return any(bool(last_token == stop_id) for stop_id in self.stop_ids)
    
class BatchedStopOnTokens(StoppingCriteria):
    """Stopping criterion for a batch: stop once every row contains a stop id.

    A row counts as finished when any of ``stop_ids`` occurs anywhere in it (not
    only as its last token), so rows that finished early keep counting while the
    remaining rows are still generating.

    Args:
        stop_ids: iterable of token ids that end generation.  Defaults to ``(2,)``,
            the value that used to be hard-coded here.
            NOTE(review): 2 is presumably the ``</s>`` id of the Llama-2 vocabulary;
            for another tokenizer pass e.g. ``[tokenizer.eos_token_id]`` -- confirm
            against the checkpoint actually loaded.
    """

    def __init__(self, stop_ids=(2,)):
        super().__init__()
        self.stop_ids = list(stop_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if not self.stop_ids:
            return False
        # finished[i] becomes True as soon as row i contains any stop id
        finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
        for stop_id in self.stop_ids:
            finished |= (input_ids == stop_id).any(dim=-1)
        return bool(finished.all())

In [4]:
# Every input below is a JSON-lines file: one JSON object per line.

# Problem sets (each record is later read through its 'question' / 'answer' keys).
with open('../data/grade_school_math/data/test.jsonl', 'r') as f:
    data_test = [json.loads(line) for line in f]

with open('../data/grade_school_math/data/train.jsonl', 'r') as f:
    data_train = [json.loads(line) for line in f]

# Previously generated actor responses.
with open('./generated_data/actor_response_data_01.jsonl', 'r') as f:
    actor_response_data = [json.loads(line) for line in f]

# Previously generated critic responses.
with open('./generated_data/critic_response_data_01.jsonl', 'r') as f:
    critic_response_data = [json.loads(line) for line in f]

In [5]:
# One independent copy of the model per GPU, plus one tokenizer per copy.
n_replicas = 8  # one per device listed in CUDA_VISIBLE_DEVICES
models = [copy.deepcopy(model).to(torch.device(f'cuda:{gpu_idx}')) for gpu_idx in range(n_replicas)]
# dtype / quantisation switches are model-loading options, so the tokenizers are
# loaded without them.
tokenizers = [AutoTokenizer.from_pretrained(model_pth) for _ in range(n_replicas)]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [30]:
# Role -> model-list mapping.
# The commented-out lines are the earlier setup that wrapped each replica in a
# PEFT adapter loaded from `peft_model_pth`; they are kept for reference only.
# actor_models = [PeftModel.from_pretrained(models[i], peft_model_pth, torch_dtype=torch.float16) for i in range(8)]
# critic_models = [PeftModel.from_pretrained(models[i], peft_model_pth, torch_dtype=torch.float16) for i in range(8)]
# summarizer_models = [PeftModel.from_pretrained(models[i], peft_model_pth, torch_dtype=torch.float16) for i in range(8)]

# actor_models = [PeftModel.from_pretrained(models[i], peft_model_pth, torch_dtype=torch.float16) for i in range(4)]
# Both names are aliases of the same `models` list (no copy is made), so the
# critic and the summarizer share the very same model objects.
critic_models = models
summarizer_models = models

In [6]:
# Disabled: assigns a running integer 'id' to every train + test record.
# id = 0
# for data in data_train + data_test:
#     data['id'] = id
#     id += 1
# Sanity check: number of records in each loaded dataset.
print(len(data_train), len(data_test), len(actor_response_data), len(critic_response_data))

7473 1319 133632 133632


In [32]:
# Tally the stored 'judge_critic' labels: Accepted / Wrong Answer / anything else.
verdict_counts = {'Accepted': 0, 'Wrong Answer': 0}
iv = 0
for critic_response in critic_response_data:
    verdict = critic_response['judge_critic']
    if verdict in verdict_counts:
        verdict_counts[verdict] += 1
    else:
        # every other label is counted as invalid
        iv += 1
ac = verdict_counts['Accepted']
wa = verdict_counts['Wrong Answer']
print(ac, wa, iv)

71477 0 62155


In [35]:
# Eyeball the first 16 stored critic responses next to their stored label.
for critic_response in critic_response_data[:16]:
    print(critic_response['critic_response'], critic_response['judge_critic'], sep='\n')

#### The answer is: Accepted.
Accepted
#### Wrong Answer: The actor added the extra 2 gallons for the bleach cycle to the total amount of water already calculated, which is incorrect. The correct approach would be to add the extra 2 gallons to the total amount of water needed, which is already 72 gallons, resulting in a total of 74 gallons.
Invalid
#### The answer is: Accepted.
Accepted
#### Wrong Answer: The actor's answer does not take into account the correct calculation of the amount Ian paid to Helen and Benedict.
Invalid
#### The answer is: Accepted.
Accepted
#### The answer is: Accepted.
Accepted
#### The answer is: Accepted.
Accepted
#### The answer is: Accepted.
Accepted
#### Wrong Answer: The calculation is correct, but the answer is wrong because the actor's answer is not in line with the problem. The problem asks for the age of the fourth child, not the year of birth.
Invalid
#### Wrong Answer: The actor's calculation for the total number of push-ups is incorrect.
Invalid
#

In [7]:
# Use the BOS token as the padding token on every tokenizer replica; the
# generation code reads `tokenizer.pad_token_id`.
# Note: this loop rebinds the notebook-level name `tokenizer` to the last replica.
for tokenizer in tokenizers:
    tokenizer.pad_token = tokenizer.bos_token

In [8]:
# Width of one input-embedding vector (last dimension of the padding tensors).
# NOTE(review): presumably the hidden size of the loaded checkpoint -- confirm
# against the model config before switching to a different model.
embedding_dim = 4096

In [33]:
# Final-answer marker: '#### ' followed by an optionally negative run of digits,
# dots and commas.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
def extract_answer(completion):
    """Return the number written after the first '#### ' marker in `completion`.

    Thousands separators (commas) are removed.  The sentinel string "[invalid]"
    is returned when no marker is found.
    """
    found = ANS_RE.search(completion)
    if not found:
        return "[invalid]"
    return found.group(1).strip().replace(",", "")

# Verdict marker at the start of a critic reply.  Accepts the requested form
# "#### The answer is: <verdict>" as well as the shortened "#### <verdict>" that
# the model also produces, and tolerates leading whitespace.
_JUDGEMENT_RE = re.compile(r"\s*####\s*(?:The answer is:\s*)?(Accepted|Wrong Answer)\b")

def extract_judgement(output):
    """Classify a critic reply as 'Accepted', 'Wrong Answer' or 'Invalid'.

    The verdict must open the reply.  Replies such as
    "#### Wrong Answer: <reason>" were previously labelled 'Invalid' because only
    the exact "#### The answer is: ..." prefix was recognised; they are now
    labelled by their verdict.

    Args:
        output: decoded critic reply.
    Returns:
        'Accepted', 'Wrong Answer', or 'Invalid' (no verdict marker at the start,
        or `output` is not a string).
    """
    if not isinstance(output, str):
        return 'Invalid'
    verdict = _JUDGEMENT_RE.match(output)
    return verdict.group(1) if verdict else 'Invalid'


In [10]:
# First 16 training records; referenced by the (currently commented-out) trial
# calls further down in the notebook.
data_use = data_train[:16]
def data_generate_e2e(actor, critic, summarizer, tokenizer, sys_prompts, data, batch_size, device, use_tqdm=False):
    """Debug driver: sample actor responses batch by batch and print each one.

    Only the actor stage runs.  `critic` and `summarizer` are accepted so the
    signature stays stable but are not used here; the critic stage is provided
    by `generate_critic_response`.

    Args:
        actor: model exposing `get_input_embeddings()` and `generate()`.
        critic, summarizer: unused.
        tokenizer: tokenizer matching `actor`; `eos_token_id` and `pad_token_id`
            must be set.
        sys_prompts: mapping whose 'actor' entry is the actor system prompt.
        data: sequence of records with a 'question' entry.
        batch_size: number of prompts per `generate()` call (must be positive).
        device: device the prompt tensors are created on.
        use_tqdm: show a progress bar.
    Returns:
        List with one decoded actor response per record, in input order
        (earlier versions returned None and dropped a trailing partial batch).
    Raises:
        ValueError: if `batch_size` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    generate_kwargs = dict(
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=50,
        temperature=0.7,
        num_beams=1,
        # NOTE(review): BatchedStopOnTokens stops on a fixed token id -- confirm
        # it matches the EOS id of `tokenizer`.
        stopping_criteria=StoppingCriteriaList([BatchedStopOnTokens()]),
    )
    embed_tokens = actor.get_input_embeddings()
    results = []
    pbar = tqdm(total=len(data), desc=f'Data generating for actor response on {device}', ncols=100) if use_tqdm else None
    try:
        # step through the data in slices so a final short batch is processed too
        for start in range(0, len(data), batch_size):
            batch_data = data[start:start + batch_size]
            input_prompts = ['<|system|>:' + sys_prompts['actor'] + '</s>\n<|user|>:' + d['question'] + '</s>\n<|assistant|>:' for d in batch_data]

            # Embed every prompt separately (no tokenizer padding); embeddings
            # are only inputs here, so no autograd graph is needed.
            with torch.no_grad():
                prompt_embeds = [
                    embed_tokens(tokenizer(input_prompt, return_tensors='pt', padding=False, truncation=True, max_length=512).input_ids.to(device))
                    for input_prompt in input_prompts
                ]

            # Left-pad to the longest prompt: zero vectors in front, mask 0 over
            # the padding and 1 over real tokens.  The padding width is read from
            # the embeddings themselves.
            max_len = max(embeds.size(1) for embeds in prompt_embeds)
            mask_rows, embed_rows = [], []
            for embeds in prompt_embeds:
                n_tokens = embeds.size(1)
                n_pad = max_len - n_tokens
                mask_rows.append(torch.cat([torch.zeros(n_pad, device=device), torch.ones(n_tokens, device=device)]).unsqueeze(0))
                embed_rows.append(torch.cat([torch.zeros(1, n_pad, embeds.size(-1), device=device, dtype=embeds.dtype), embeds], dim=1))
            attention_mask = torch.cat(mask_rows, dim=0).to(dtype=torch.float16)
            input_embeds = torch.cat(embed_rows, dim=0).to(dtype=torch.float16)

            outputs = actor.generate(inputs_embeds=input_embeds, attention_mask=attention_mask, **generate_kwargs)

            # Overwrite everything after the first EOS of each row with the pad
            # token so text sampled while other rows were still running is dropped.
            for actor_output in outputs:
                eos_positions = (actor_output == tokenizer.eos_token_id).nonzero()
                if eos_positions.numel():
                    actor_output[eos_positions[0].item() + 1:] = tokenizer.pad_token_id
            actor_responses = tokenizer.batch_decode(outputs, skip_special_tokens=True)

            # visual check that no padding leaks into the decoded text
            for actor_response in actor_responses:
                print(f'actor_response: {actor_response}\n')

            results.extend(actor_responses)
            if pbar is not None:
                pbar.update(len(batch_data))
    finally:
        if pbar is not None:
            pbar.close()
    return results

In [27]:
# Final-answer marker: '#### ' followed by an optionally negative run of digits,
# dots and commas.
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")

# Sampling settings shared by the actor, critic and summarizer stages.
_SAMPLING_KWARGS = dict(
    max_new_tokens=1024,
    do_sample=True,
    top_p=0.95,
    top_k=50,
    temperature=0.7,
    num_beams=1,
)


def extract_answer(completion):
    """Return the number after the first '#### ' marker (commas removed), or "[invalid]"."""
    match = ANS_RE.search(completion)
    if match:
        return match.group(1).strip().replace(",", "")
    return "[invalid]"


def _iter_batches(items, batch_size):
    """Yield consecutive slices of `items`, including a final short one."""
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def _left_pad_prompt_embeds(model, tokenizer, prompts, max_length, device):
    """Embed each prompt and left-pad the batch to its longest prompt.

    Returns `(inputs_embeds, attention_mask)`, both float16: zero vectors are put
    in front of shorter prompts and the mask is 0 over them, 1 over real tokens.
    """
    embed_tokens = model.get_input_embeddings()
    # embeddings are only inputs to generation, so no autograd graph is needed
    with torch.no_grad():
        prompt_embeds = [
            embed_tokens(tokenizer(prompt, return_tensors='pt', padding=False, truncation=True, max_length=max_length).input_ids.to(device))
            for prompt in prompts
        ]
    longest = max(embeds.size(1) for embeds in prompt_embeds)
    mask_rows, embed_rows = [], []
    for embeds in prompt_embeds:
        n_tokens = embeds.size(1)
        n_pad = longest - n_tokens
        mask_rows.append(torch.cat([torch.zeros(n_pad, device=device), torch.ones(n_tokens, device=device)]).unsqueeze(0))
        # padding width is taken from the embeddings, not from a global constant
        embed_rows.append(torch.cat([torch.zeros(1, n_pad, embeds.size(-1), device=device, dtype=embeds.dtype), embeds], dim=1))
    attention_mask = torch.cat(mask_rows, dim=0).to(dtype=torch.float16)
    inputs_embeds = torch.cat(embed_rows, dim=0).to(dtype=torch.float16)
    return inputs_embeds, attention_mask


def _decode_until_eos(tokenizer, outputs):
    """Decode generated rows, discarding whatever follows the first EOS token of each row."""
    for token_ids in outputs:
        eos_positions = (token_ids == tokenizer.eos_token_id).nonzero()
        if eos_positions.numel():
            # overwrite in place so skip_special_tokens drops the tail on decode
            token_ids[eos_positions[0].item() + 1:] = tokenizer.pad_token_id
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)


def _generate_in_batches(model, tokenizer, records, build_prompts, generate_kwargs,
                         max_length, batch_size, device, progress_desc=None):
    """Run `model.generate` over `records` in batches and collect the decoded texts.

    `build_prompts(batch)` turns a slice of records into prompt strings.  A
    progress bar is shown only when `progress_desc` is given, and it is closed
    even if generation raises.
    """
    results = []
    pbar = tqdm(total=len(records), desc=progress_desc, ncols=100) if progress_desc is not None else None
    try:
        for batch in _iter_batches(records, batch_size):
            inputs_embeds, attention_mask = _left_pad_prompt_embeds(model, tokenizer, build_prompts(batch), max_length, device)
            outputs = model.generate(inputs_embeds=inputs_embeds, attention_mask=attention_mask, **generate_kwargs)
            results.extend(_decode_until_eos(tokenizer, outputs))
            if pbar is not None:
                pbar.update(len(batch))
    finally:
        if pbar is not None:
            pbar.close()
    return results


def generate_actor_response(actor, tokenizer, sys_prompt, data, batch_size, device, use_tqdm=False):
    """Sample one actor solution per record of `data`.

    Args:
        actor: model exposing `get_input_embeddings()` and `generate()`.
        tokenizer: tokenizer matching `actor`.
        sys_prompt: system prompt placed in front of every question.
        data: sequence of records with a 'question' entry.
        batch_size: prompts per `generate()` call (must be positive).
        device: device the prompt tensors are created on.
        use_tqdm: show a progress bar.
    Returns:
        List of decoded responses, one per record, in input order (a trailing
        partial batch is processed as well).
    Raises:
        ValueError: if `batch_size` is not positive.
    """
    def build_prompts(batch):
        return ['<|system|>:' + sys_prompt + '</s>\n<|user|>:' + d['question'] + '</s>\n<|assistant|>:' for d in batch]

    # NOTE(review): BatchedStopOnTokens stops on a fixed token id -- confirm it
    # matches the EOS id of `tokenizer`.
    generate_kwargs = dict(_SAMPLING_KWARGS, stopping_criteria=StoppingCriteriaList([BatchedStopOnTokens()]))
    return _generate_in_batches(
        actor, tokenizer, data, build_prompts, generate_kwargs, 512, batch_size, device,
        progress_desc=f'Generating actor response data on {device}' if use_tqdm else None,
    )


def generate_critic_response(critic, tokenizer, actor_data, question_data, batch_size, device, use_tqdm=False):
    """Sample one critic verdict per actor response.

    Args:
        critic: model exposing `get_input_embeddings()` and `generate()`.
        tokenizer: tokenizer matching `critic`; must provide `apply_chat_template`.
        actor_data: sequence of records with 'question_id' and 'actor_response'.
        question_data: problems, indexed by each record's 'question_id'; every
            problem has 'question' and 'answer' entries.
        batch_size: prompts per `generate()` call (must be positive).
        device: device the prompt tensors are created on.
        use_tqdm: show a progress bar.
    Returns:
        List of decoded critic replies, one per actor record, in input order
        (a trailing partial batch is processed as well).
    Raises:
        ValueError: if `batch_size` is not positive.
    """
    system_prompt = 'You are a critic who is responsible for judging the correctness of the actor\'s answer. Provided with the math problem, correct answer and the student\'s answer, you need to judge whether the actor\'s answer is correct. If the actor\'s answer is right, respond with "#### The answer is: Accepted." Otherwise, analyze the reason why the actor arrived at the wrong answer and respond with "#### The answer is: Wrong Answer. [Reason for the wrong answer, without displaying the correct number to the question]".'

    def build_prompts(batch):
        prompts = []
        for record in batch:
            problem = question_data[record['question_id']]
            # the critic sees only the final reference number, not the worked solution
            messages = [
                {'role': 'system', 'content': system_prompt},
                {'role': 'question', 'content': problem['question']},
                {'role': 'correct answer', 'content': extract_answer(problem['answer'])},
                {'role': 'actor\'s answer', 'content': record['actor_response']},
            ]
            prompts.append(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
        return prompts

    generate_kwargs = dict(
        _SAMPLING_KWARGS,
        pad_token_id=tokenizer.pad_token_id,
        # stop on the tokenizer's EOS token as well as on "<|eot_id|>"
        eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")],
    )
    return _generate_in_batches(
        critic, tokenizer, actor_data, build_prompts, generate_kwargs, 1024, batch_size, device,
        progress_desc=f'Generating critic response data on {device}' if use_tqdm else None,
    )


def generate_summarizer_response(critic, tokenizer, critic_data, actor_data, question_data, batch_size, device, use_tqdm=False):
    """Sample one summarizer answer per critic verdict.

    Args:
        critic: the model used as summarizer (parameter name kept for existing
            callers); exposes `get_input_embeddings()` and `generate()`.
        tokenizer: tokenizer matching the model; must provide `apply_chat_template`.
        critic_data: sequence of records with 'actor_response_id' and
            'critic_response'.
        actor_data: actor records, indexed by 'actor_response_id'; each has
            'question_id' and 'actor_response'.
        question_data: problems, indexed by 'question_id'; each has 'question'.
        batch_size: prompts per `generate()` call (must be positive).
        device: device the prompt tensors are created on.
        use_tqdm: show a progress bar.
    Returns:
        List of decoded summarizer replies, one per critic record, in input order
        (a trailing partial batch is processed as well).
    Raises:
        ValueError: if `batch_size` is not positive.
    """
    system_prompt = 'You are a summarizer who is responsible for deciding the final answer to a given math problem, with the help of an actor\'s solution and a critic\'s judgement of whether the actor\'s answer is correct or not. If the actor\'s answer is correct, summarize the analysis. Otherwise, fix the actor\'s answer according to the critic\'s feedback. Only the correct analysis is allowed to be presented. Do not include statements about whether the actor or critic is correct. Finally, add "\n\n#### [Answer to the question with digits only]" as a summarization.'

    def build_prompts(batch):
        prompts = []
        for record in batch:
            actor_record = actor_data[record['actor_response_id']]
            problem = question_data[actor_record['question_id']]
            messages = [
                {'role': 'system', 'content': system_prompt},
                {'role': 'question', 'content': problem['question']},
                {'role': 'actor\'s answer', 'content': actor_record['actor_response']},
                # key was misspelled 'contenet', so the critic feedback never reached the prompt
                {'role': 'critic\'s judgement', 'content': record['critic_response']},
            ]
            prompts.append(tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
        return prompts

    generate_kwargs = dict(
        _SAMPLING_KWARGS,
        pad_token_id=tokenizer.pad_token_id,
        # stop on the tokenizer's EOS token as well as on "<|eot_id|>"
        eos_token_id=[tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")],
    )
    return _generate_in_batches(
        critic, tokenizer, critic_data, build_prompts, generate_kwargs, 1024, batch_size, device,
        progress_desc=f'Generating summarizer response data on {device}' if use_tqdm else None,
    )


In [12]:
# data_generate_e2e(models[0], models[0], summarizer_models[0], tokenizers[0], {'actor': "Solving the following math problem and respond with '\n#### <answer>' with <answer> substituted by the correct number in the very end:\n", 'critic': "Judge the student's answer (not solving the problem by yourself) to the following math question with the given correct answer. Respond with '\n#### accepted' if the student is correct, or '\n#### wrong answer: <why the answer is wrong>' if the student's answer is wrong.\n"}, data_use[4:6], 2, torch.device('cuda:0'), use_tqdm=True)

In [34]:
# Smoke test: run the critic on the first four stored actor responses (one batch
# on cuda:0) and print every reply followed by its parsed verdict.
critic_responses = generate_critic_response(
    critic_models[0],
    tokenizers[0],
    actor_response_data[:4],
    data_train,
    4,
    torch.device('cuda:0'),
    use_tqdm=True,
)
for critic_response in critic_responses:
    print(critic_response, extract_judgement(critic_response), sep='\n')

# Earlier / alternative trial calls, currently disabled:
# actor_responses = generate_actor_response(actor_models[0], tokenizers[0], "Solving the following math problem and respond with '\n#### <answer>' with <answer> substituted by the correct number in the very end:\n", [data_use[0]] * 4, 2, torch.device('cuda:0'), use_tqdm=True)
# for actor_response in actor_responses:
#     print(actor_response)
# critic_responses = generate_critic_response(critic_models[0], tokenizers[4], "Judge the student's answer (not solving the problem by yourself) to the following math question using the <correct answer> provided as reference. If the student is correct, respond with '\n#### accepted.' Otherwise, analyse why the student is wrong and respond with '\n#### wrong answer: <why the student is wrong>'.\n", data_use[0:2], actor_responses, 2, torch.device('cuda:4'), use_tqdm=True)
# print(critic_responses)
# summarizer_responses = generate_summarizer_response(summarizer_models[0], tokenizers[0], critic_response_data[:4], actor_response_data, data_train, 4, torch.device('cuda:0'), use_tqdm=True)
# for summarizer_response in summarizer_responses:
#     print(summarizer_response)
# print(len(actor_responses), len(critic_responses), len(summarizer_responses))

Data generating end to end on cuda:0: 100%|███████████████████████████| 4/4 [00:01<00:00,  2.46it/s]

#### The answer is: Wrong Answer. Jackson's calculation for the total number of calories in the salad and pizza is correct, but then he adds the number of calories in the salad and pizza to get 132.5, which is incorrect.
Wrong Answer
#### The answer is: Accepted.
Accepted
#### The answer is: Accepted.
Accepted
#### The answer is: Accepted.
Accepted





In [28]:
# Verdict marker at the start of a critic reply.  Accepts the requested form
# "#### The answer is: <verdict>" as well as the shortened "#### <verdict>" that
# the model also produces, and tolerates leading whitespace.
_JUDGEMENT_RE = re.compile(r"\s*####\s*(?:The answer is:\s*)?(Accepted|Wrong Answer)\b")

def extract_judgement(output):
    """Classify a critic reply as 'Accepted', 'Wrong Answer' or 'Invalid'.

    The verdict must open the reply.  Replies such as
    "#### Wrong Answer: <reason>" were previously labelled 'Invalid' because only
    the exact "#### The answer is: ..." prefix was recognised; they are now
    labelled by their verdict.

    Args:
        output: decoded critic reply.
    Returns:
        'Accepted', 'Wrong Answer', or 'Invalid' (no verdict marker at the start,
        or `output` is not a string).
    """
    if not isinstance(output, str):
        return 'Invalid'
    verdict = _JUDGEMENT_RE.match(output)
    return verdict.group(1) if verdict else 'Invalid'

In [29]:
# Parsed verdict for every critic reply of the smoke test, as one list.
print(list(map(extract_judgement, critic_responses)))

['Wrong Answer', 'Accepted', 'Accepted', 'Accepted']
