A simple evaluation harness for the GSM8K dataset, with options for chain-of-thought (CoT) prompting and few-shot examples.

In [6]:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import re

In [7]:
class GSM8KEval:
  """Minimal GSM8K evaluation harness around a Hugging Face causal language model.

  Typical usage: construct with a model name, call ``load_dataset`` to fetch the
  train/test splits, optionally call ``get_few_shot_data`` to build a few-shot
  prefix, then for each test question chain ``build_request`` ->
  ``generate_output`` -> ``extract_answer`` -> ``isExactmatch``.
  """

  # Integer or decimal number with optional thousands separators, e.g. "-1,234.5".
  _NUMBER_PATTERN = re.compile(r'-?\d+(?:,\d{3})*(?:\.\d+)?')
  # Marker that precedes the final numeric answer in GSM8K-style solutions.
  _ANSWER_MARKER = '####'
  # Prefix used for every question in the prompt; seeing it in the generation
  # means the model has started inventing a new question.
  _QUESTION_PREFIX = 'Question:'
  # End-of-turn strings that are only usable as stop ids when the tokenizer
  # maps them to exactly one token.
  _STOP_TOKEN_CANDIDATES = ('</s>', '<|im_end|>')

  def __init__(self, model_name: str, device: str = None):
    """Load the tokenizer and model and move the model to ``device``.

    Args:
      model_name: Name or path passed to ``from_pretrained``.
      device: Target device string. When omitted, ``'cuda'`` is used if
        ``torch.cuda.is_available()`` and ``'cpu'`` otherwise.
    """
    self.tokenizer = AutoTokenizer.from_pretrained(model_name)
    self.model = AutoModelForCausalLM.from_pretrained(model_name)
    if device is None:
      # Previously hard-coded to 'cuda', which fails on CPU-only machines.
      device = 'cuda' if torch.cuda.is_available() else 'cpu'
    self.model.to(device)

    if self.tokenizer.pad_token is None:
      self.tokenizer.pad_token = self.tokenizer.eos_token

    # Default to zero-shot so build_request works even if get_few_shot_data
    # has never been called.
    self.few_shot_examples = ''

  def load_dataset(self, path: str):
    """Load the ``'main'`` configuration at ``path`` and keep its train/test splits."""
    dataset = load_dataset(path, 'main')
    self.train_dataset = dataset['train']
    self.test_dataset = dataset['test']

  def get_few_shot_data(self, num_few_shot: int = 0):
    """Build the few-shot prompt prefix from the first ``num_few_shot`` train examples.

    The result is stored on ``self.few_shot_examples``; nothing is returned.
    ``load_dataset`` must have been called first.
    """
    train_data = self.train_dataset.select(range(num_few_shot))
    self.few_shot_examples = ''.join(
        f'Question: {example["question"]}\nAnswer: {example["answer"]}\n'
        for example in train_data
    )

  def build_request(self, question: str, use_cot: bool = False):
    """Return the full prompt: few-shot prefix, the question, and an ``Answer:`` cue.

    When ``use_cot`` is true, a "Solve this step by step." instruction is
    appended to the question.
    """
    if use_cot:
      return f'{self.few_shot_examples}Question: {question} Solve this step by step.\nAnswer:'
    return f'{self.few_shot_examples}Question: {question}\nAnswer:'

  def _stop_token_ids(self):
    """Return token ids that should end generation.

    Includes the tokenizer's own EOS id (if any) plus every candidate stop
    string that encodes to exactly one token. Multi-token strings are skipped:
    using only their first sub-token would stop generation on unrelated text.
    """
    stop_ids = []
    if self.tokenizer.eos_token_id is not None:
      stop_ids.append(self.tokenizer.eos_token_id)
    for candidate in self._STOP_TOKEN_CANDIDATES:
      ids = self.tokenizer.encode(candidate, add_special_tokens=False)
      if len(ids) == 1 and ids[0] not in stop_ids:
        stop_ids.append(ids[0])
    return stop_ids

  def generate_output(self, request: str, max_new_tokens: int = 150):
    """Greedily generate a completion for ``request`` and return only the new text.

    Args:
      request: Prompt string, normally produced by ``build_request``.
      max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
      The decoded continuation, cut off before the first ``Question:`` the
      model emits (if any).
    """
    inputs = self.tokenizer(request, return_tensors="pt", padding=True)
    # Move input tensors to the same device as the model
    inputs = {k: v.to(self.model.device) for k, v in inputs.items()}

    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=self.tokenizer.pad_token_id,
    )
    stop_ids = self._stop_token_ids()
    if stop_ids:
      # Only override the model's default EOS handling when there is
      # something to stop on.
      generate_kwargs['eos_token_id'] = stop_ids

    with torch.no_grad():
      outputs = self.model.generate(**inputs, **generate_kwargs)

    # Decode output for just what is generated by the model
    prompt_length = inputs['input_ids'].shape[1]
    generated_text = self.tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True
    )
    # 'Question:' is generally not a single token, so it is handled as a
    # stop string on the decoded text instead of as an EOS token id.
    return generated_text.split(self._QUESTION_PREFIX, 1)[0]

  def extract_answer(self, output: str):
    """Pull the predicted numeric answer out of generated text.

    If the text contains ``####``, the last number on the line following the
    first marker is used. Otherwise (or if that line has no number) the last
    number anywhere in the text is used.

    Returns:
      The answer as a ``float``, or ``None`` when no number is found.
    """
    if not output:
      return None

    if self._ANSWER_MARKER in output:
      # Use the first marker: later ones belong to text the model kept
      # generating after it had already answered.
      after_marker = output.split(self._ANSWER_MARKER, 1)[1].strip()
      answer_line = after_marker.split('\n', 1)[0]
      numbers_in_answer = self._NUMBER_PATTERN.findall(answer_line)
      if numbers_in_answer:
        return float(numbers_in_answer[-1].replace(',', ''))

    numbers_in_output = self._NUMBER_PATTERN.findall(output)
    if numbers_in_output:
      # Remove thousands separators and convert to float
      return float(numbers_in_output[-1].replace(',', ''))

    return None

  def isExactmatch(self, output, answer: str):
    """Return True when the predicted number equals the reference answer.

    Args:
      output: Predicted value, normally the ``float`` (or ``None``) returned
        by ``extract_answer``.
      answer: Reference solution text whose final answer follows ``####``.

    Raises:
      ValueError: If the text after ``####`` in ``answer`` is not a number.
    """
    # Format of final answer is ####answer, so get that number from string
    final_answer = float(answer.split(self._ANSWER_MARKER)[-1].strip().replace(',', ''))
    if output is None:
      return False
    try:
      predicted = float(output)
    except (TypeError, ValueError):
      return False
    # Small tolerance so float round-off cannot turn a match into a miss.
    return abs(predicted - final_answer) <= 1e-6

In [8]:
# Build the evaluator around the small GPT-2 checkpoint.
sample_eval = GSM8KEval(model_name='gpt2')

In [9]:
# Fetch the GSM8K train/test splits and attach them to the evaluator.
sample_eval.load_dataset(path='openai/gsm8k')

In [10]:
LIMIT = 5         # maximum number of test questions to evaluate
NUM_FEW_SHOT = 1  # number of train examples placed in front of every prompt
USE_COT = True    # append the "Solve this step by step." instruction

# The few-shot prefix is the same for every question, so build it once
# instead of rebuilding it on each loop iteration.
sample_eval.get_few_shot_data(NUM_FEW_SHOT)

# Do not index past the end of the test split when LIMIT exceeds its size.
num_examples = min(LIMIT, len(sample_eval.test_dataset))

correct_answers = 0
for i in tqdm(range(num_examples)):
  example = sample_eval.test_dataset[i]
  request = sample_eval.build_request(example['question'], USE_COT)
  output = sample_eval.generate_output(request)

  answer = sample_eval.extract_answer(output)
  is_correct = sample_eval.isExactmatch(answer, example['answer'])
  if is_correct:
    correct_answers += 1
print(f'Accuracy: {correct_answers}/{num_examples}')

100%|██████████| 5/5 [00:03<00:00,  1.58it/s]

Accuracy: 1/5



