# Definitions

In [1]:
# imports
from typing import Tuple, List

In [2]:
# constants
DATASET_NAMES = ["evalplus/humanevalplus", "dz1/CodeScore-HumanEval-ET"]
MODEL_NAMES = ["Qwen/Qwen2.5-1.5B-Instruct", "lmsys/vicuna-7b-v1.5"]
NUM_PROMPTS = 10
SEED = 42
INDICES = [124,  31,  59,  76,  81,   7,  89,  25, 121, 112]

In [3]:
# set environmental variables to suppress warnings
import os

os.environ["HF_ALLOW_CODE_EVAL"] = "1"
os.environ["TOKENIZERS_PARALLELISM"]= "false"

In [4]:
import re

def extract_python_code(cot_response: str) -> list[str]:
    """
    Extracts all Python code blocks from a string.

    Args:
        cot_response: The string containing the LLM response.

    Returns:
        A list of strings, where each string is a Python code block.
    """
    # The `re.DOTALL` flag is crucial for matching across multiple lines.
    pattern = r"```python\n(.*?)```"
    matches = re.findall(pattern, cot_response, re.DOTALL)

    # Clean up the captured code by removing leading/trailing whitespace
    return [code.strip() for code in matches]

In [5]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from evaluate import load
import torch

def create_cot_prompt(prompt: str) -> str:
  cot_prompt = f"Implement the python function given by this definition:\n\n{prompt}\n\nLet's think step by step.\n"

  return cot_prompt

class PassKEvaluator:
  def __init__(
    self,
    model_name: str,
    dtype='auto',
    device_map='auto',
    quantization_config=None
  ):
    # instance variables
    self.model_name = model_name
    self.metric = load("code_eval")

    # create model
    self.model = AutoModelForCausalLM.from_pretrained(
      self.model_name,
      dtype=dtype,
      device_map=device_map,
      quantization_config=quantization_config
    )

    # set model to eval mode
    self.model.eval()

    # create tokenizer
    self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

  def evaluate(
    self,
    test_cases: List[str],
    prompt: str,
    k: List[int]=[1, 10, 20]
  ):
    # get largest (usually last) k value
    largest_k = max(k)

    # query model wth prompts and output responses
    generated_prompts = self.generate(prompt, largest_k)

    # parse python code
    candidates = []

    for p in generated_prompts:
        extracted = extract_python_code(p)
        if len(extracted) > 0:
            candidates.append(extracted[0])
        else:
            candidates.append('')

    # compute pass@k metric
    results = self.metric.compute(
      references=test_cases,
      predictions=[candidates],
      k=k
    )

    return results, candidates

  def generate(self, prompt: str, num_samples: int) -> List[str]:
    # tokenize prompt/input
    inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

    # Generate multiple samples using a high temperature for diversity
    outputs = self.model.generate(
        **inputs,
        do_sample=True,
        temperature=0.8,
        num_return_sequences=num_samples,
        max_new_tokens=300
    ).to('cpu')

    # return code examples/responses as list of strings
    return [
        self.tokenizer.decode(
          out,
          skip_special_tokens=True
        )
      for out in outputs
    ]

  def __del__(self):
    del self.model
    del self.tokenizer
    del self.metric
    torch.cuda.empty_cache()

In [6]:
import multiprocessing as mp
from tqdm import trange
import json

# evaluation process
def run_evals(ev, ds, prefix):
    for i in trange(len(ds)):
        try:
            res = {}
    
            sample = ds[i]
    
            prompt = create_cot_prompt(sample['prompt'])
            test = sample['test']
            results, candidates = ev.evaluate(
                test,
                prompt,
                [1,5]
            )
    
            print(results)
    
            res['prompt'] = prompt
            res['test'] = test
            res['results'] = results
            res['candidates'] = candidates
    
            with open(f'{prefix}_{i}.json', "w") as f:
                json.dump(res, f, indent=4)
        except Exception as e:
            print(e)
        finally:
            continue

# Loading the Dataset

In [7]:
from datasets import load_dataset

In [8]:
# load humanevalplus dataset
humanevalplus_ds = load_dataset(DATASET_NAMES[0])['test']

# Finding Test Cases

In [9]:
# create 1.5B parameter qwen model
ev = PassKEvaluator(MODEL_NAMES[0])

In [None]:
# run qwen on dataset

# file to log to
file_prefix = "cot_results"

run_evals(ev, humanevalplus_ds, file_prefix)

  1%|▎                                        | 1/164 [00:30<1:23:09, 30.61s/it]

list index out of range


  1%|▌                                        | 2/164 [01:00<1:21:44, 30.28s/it]

list index out of range


  2%|▊                                        | 3/164 [01:15<1:02:54, 23.44s/it]

In [9]:
# create vicuna model to evluate and use quantization so can run on 8gb VRAM hardware
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16  # Recommended for speed
)

ev = PassKEvaluator(MODEL_NAMES[1], quantization_config=quantization_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


In [None]:
# run vicuna on dataset
from tqdm import trange
import json

# file to log to
file_prefix = "cot_results_llama"

# run on llama model
run_evals(ev, humanevalplus_ds, file_prefix)

  1%|▎                                        | 1/164 [00:29<1:21:23, 29.96s/it]

list index out of range


  1%|▌                                        | 2/164 [00:59<1:19:37, 29.49s/it]

list index out of range


  2%|▊                                        | 3/164 [01:27<1:18:22, 29.21s/it]

list index out of range


  2%|█                                        | 4/164 [01:58<1:18:54, 29.59s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


  3%|█▎                                       | 5/164 [02:27<1:18:15, 29.53s/it]

list index out of range


  4%|█▌                                       | 6/164 [02:56<1:17:19, 29.37s/it]

list index out of range


  4%|█▊                                       | 7/164 [03:25<1:16:42, 29.31s/it]

list index out of range


  5%|██                                       | 8/164 [03:55<1:16:19, 29.36s/it]

{'pass@1': np.float64(0.8), 'pass@5': np.float64(1.0)}


  5%|██▎                                      | 9/164 [04:24<1:16:05, 29.46s/it]

{'pass@1': np.float64(0.8), 'pass@5': np.float64(1.0)}


  6%|██▍                                     | 10/164 [04:54<1:15:44, 29.51s/it]

{'pass@1': np.float64(0.8), 'pass@5': np.float64(1.0)}


  7%|██▋                                     | 11/164 [05:24<1:15:23, 29.56s/it]

list index out of range


  7%|██▉                                     | 12/164 [05:53<1:14:44, 29.50s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


  8%|███▏                                    | 13/164 [06:23<1:14:19, 29.53s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


  9%|███▍                                    | 14/164 [06:52<1:13:19, 29.33s/it]

list index out of range


  9%|███▋                                    | 15/164 [07:20<1:12:22, 29.14s/it]

list index out of range


 10%|███▉                                    | 16/164 [07:50<1:12:01, 29.20s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 10%|████▏                                   | 17/164 [08:16<1:09:04, 28.19s/it]

list index out of range


 11%|████▍                                   | 18/164 [08:46<1:10:06, 28.82s/it]

{'pass@1': np.float64(0.6), 'pass@5': np.float64(1.0)}


 12%|████▋                                   | 19/164 [09:15<1:10:04, 29.00s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 12%|████▉                                   | 20/164 [09:44<1:09:48, 29.08s/it]

list index out of range


 13%|█████                                   | 21/164 [10:14<1:09:46, 29.27s/it]

list index out of range


 13%|█████▎                                  | 22/164 [10:44<1:09:20, 29.30s/it]

list index out of range


 14%|█████▌                                  | 23/164 [11:12<1:08:22, 29.10s/it]

list index out of range


 15%|█████▊                                  | 24/164 [11:41<1:07:30, 28.93s/it]

list index out of range


 15%|██████                                  | 25/164 [12:09<1:06:50, 28.85s/it]

list index out of range


 16%|██████▎                                 | 26/164 [12:39<1:06:43, 29.01s/it]

list index out of range


 16%|██████▌                                 | 27/164 [13:08<1:06:33, 29.15s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 17%|██████▊                                 | 28/164 [13:37<1:05:45, 29.01s/it]

list index out of range


 18%|███████                                 | 29/164 [14:06<1:05:19, 29.04s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 18%|███████▎                                | 30/164 [14:32<1:02:49, 28.13s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 19%|███████▌                                | 31/164 [15:02<1:03:26, 28.62s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 20%|███████▊                                | 32/164 [15:31<1:03:22, 28.81s/it]

list index out of range


 20%|████████                                | 33/164 [16:02<1:04:07, 29.37s/it]

list index out of range


 21%|████████▎                               | 34/164 [16:31<1:03:43, 29.41s/it]

list index out of range


 21%|████████▌                               | 35/164 [17:00<1:02:48, 29.21s/it]

list index out of range


 22%|████████▊                               | 36/164 [17:27<1:00:41, 28.45s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 23%|█████████                               | 37/164 [17:56<1:00:34, 28.62s/it]

list index out of range


 23%|█████████▎                              | 38/164 [18:25<1:00:38, 28.87s/it]

list index out of range


 24%|█████████▌                              | 39/164 [18:56<1:01:15, 29.40s/it]

list index out of range


 24%|█████████▊                              | 40/164 [19:26<1:01:07, 29.58s/it]

list index out of range


 25%|██████████                              | 41/164 [19:56<1:01:15, 29.88s/it]

list index out of range


 26%|██████████▏                             | 42/164 [20:27<1:01:12, 30.10s/it]

list index out of range


 26%|███████████                               | 43/164 [20:54<59:01, 29.27s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 27%|███████████▎                              | 44/164 [21:24<58:55, 29.46s/it]

list index out of range


 27%|███████████▌                              | 45/164 [21:53<58:15, 29.37s/it]

list index out of range


 28%|███████████▊                              | 46/164 [22:22<57:17, 29.13s/it]

list index out of range


 29%|████████████                              | 47/164 [22:52<57:12, 29.33s/it]

list index out of range


 29%|████████████▎                             | 48/164 [23:21<56:28, 29.21s/it]

list index out of range


 30%|████████████▌                             | 49/164 [23:50<55:51, 29.14s/it]

list index out of range


 30%|████████████▊                             | 50/164 [24:19<55:25, 29.17s/it]

list index out of range


 31%|█████████████                             | 51/164 [24:48<54:55, 29.16s/it]

list index out of range


 32%|█████████████▎                            | 52/164 [25:18<54:52, 29.40s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 32%|█████████████▌                            | 53/164 [25:48<54:29, 29.45s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 33%|█████████████▊                            | 54/164 [26:16<53:36, 29.24s/it]

list index out of range


 34%|██████████████                            | 55/164 [26:46<53:23, 29.39s/it]

list index out of range


 34%|██████████████▎                           | 56/164 [27:15<52:31, 29.18s/it]

list index out of range


 35%|██████████████▌                           | 57/164 [27:44<52:01, 29.17s/it]

list index out of range


 35%|██████████████▊                           | 58/164 [28:13<51:27, 29.13s/it]

list index out of range


 36%|███████████████                           | 59/164 [28:42<51:03, 29.18s/it]

list index out of range


 37%|███████████████▎                          | 60/164 [29:11<50:22, 29.07s/it]

list index out of range


 37%|███████████████▌                          | 61/164 [29:41<50:12, 29.25s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 38%|███████████████▉                          | 62/164 [30:10<49:43, 29.25s/it]

list index out of range


 38%|████████████████▏                         | 63/164 [30:39<49:12, 29.23s/it]

list index out of range


 39%|████████████████▍                         | 64/164 [31:09<48:57, 29.37s/it]

list index out of range


 40%|████████████████▋                         | 65/164 [31:38<48:31, 29.41s/it]

list index out of range


 40%|████████████████▉                         | 66/164 [32:07<47:50, 29.29s/it]

list index out of range


 41%|█████████████████▏                        | 67/164 [32:37<47:23, 29.31s/it]

list index out of range


 41%|█████████████████▍                        | 68/164 [33:07<47:28, 29.67s/it]

list index out of range


 42%|█████████████████▋                        | 69/164 [33:39<47:49, 30.21s/it]

list index out of range


 43%|█████████████████▉                        | 70/164 [34:08<47:07, 30.08s/it]

list index out of range


 43%|██████████████████▏                       | 71/164 [34:38<46:22, 29.92s/it]

list index out of range


 44%|██████████████████▍                       | 72/164 [35:07<45:37, 29.76s/it]

list index out of range


 45%|██████████████████▋                       | 73/164 [35:38<45:31, 30.02s/it]

list index out of range


 45%|██████████████████▉                       | 74/164 [36:08<44:54, 29.93s/it]

list index out of range


 46%|███████████████████▏                      | 75/164 [36:38<44:32, 30.02s/it]

list index out of range


 46%|███████████████████▍                      | 76/164 [37:07<43:36, 29.73s/it]

list index out of range


 47%|███████████████████▋                      | 77/164 [37:37<43:05, 29.72s/it]

list index out of range


 48%|███████████████████▉                      | 78/164 [38:06<42:26, 29.62s/it]

list index out of range


 48%|████████████████████▏                     | 79/164 [38:37<42:43, 30.15s/it]

list index out of range


 49%|████████████████████▍                     | 80/164 [39:07<42:01, 30.02s/it]

list index out of range


 49%|████████████████████▋                     | 81/164 [39:37<41:29, 30.00s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 50%|█████████████████████                     | 82/164 [40:08<41:25, 30.31s/it]

list index out of range


 51%|█████████████████████▎                    | 83/164 [40:37<40:24, 29.93s/it]

list index out of range


 51%|█████████████████████▌                    | 84/164 [41:06<39:23, 29.55s/it]

list index out of range


 52%|█████████████████████▊                    | 85/164 [41:36<39:09, 29.74s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 52%|██████████████████████                    | 86/164 [42:05<38:18, 29.46s/it]

list index out of range


 53%|██████████████████████▎                   | 87/164 [42:34<37:49, 29.47s/it]

list index out of range


 54%|██████████████████████▌                   | 88/164 [43:05<37:48, 29.84s/it]

list index out of range


 54%|██████████████████████▊                   | 89/164 [43:35<37:25, 29.94s/it]

list index out of range


 55%|███████████████████████                   | 90/164 [44:05<36:40, 29.73s/it]

list index out of range


 55%|███████████████████████▎                  | 91/164 [44:34<36:02, 29.62s/it]

list index out of range


 56%|███████████████████████▌                  | 92/164 [45:03<35:24, 29.51s/it]

list index out of range


 57%|███████████████████████▊                  | 93/164 [45:33<34:56, 29.53s/it]

list index out of range


 57%|████████████████████████                  | 94/164 [46:02<34:23, 29.47s/it]

list index out of range


 58%|████████████████████████▎                 | 95/164 [46:33<34:27, 29.96s/it]

list index out of range


 59%|████████████████████████▌                 | 96/164 [47:00<32:51, 29.00s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 59%|████████████████████████▊                 | 97/164 [47:30<32:38, 29.24s/it]

list index out of range


 60%|█████████████████████████                 | 98/164 [47:59<32:09, 29.24s/it]

list index out of range


 60%|█████████████████████████▎                | 99/164 [48:28<31:34, 29.14s/it]

list index out of range


 61%|█████████████████████████                | 100/164 [48:58<31:15, 29.30s/it]

list index out of range


 62%|█████████████████████████▎               | 101/164 [49:27<30:57, 29.48s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 62%|█████████████████████████▌               | 102/164 [49:57<30:25, 29.44s/it]

list index out of range


 63%|█████████████████████████▊               | 103/164 [50:26<29:57, 29.47s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 63%|██████████████████████████               | 104/164 [50:56<29:32, 29.55s/it]

list index out of range


 64%|██████████████████████████▎              | 105/164 [51:26<29:07, 29.63s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 65%|██████████████████████████▌              | 106/164 [51:57<29:02, 30.04s/it]

list index out of range


 65%|██████████████████████████▊              | 107/164 [52:26<28:20, 29.83s/it]

list index out of range


 66%|███████████████████████████              | 108/164 [52:56<27:58, 29.97s/it]

list index out of range


 66%|███████████████████████████▎             | 109/164 [53:26<27:18, 29.80s/it]

list index out of range


 67%|███████████████████████████▌             | 110/164 [53:57<27:11, 30.21s/it]

list index out of range


 68%|███████████████████████████▊             | 111/164 [54:27<26:38, 30.15s/it]

list index out of range


 68%|████████████████████████████             | 112/164 [54:57<25:59, 30.00s/it]

list index out of range


 69%|████████████████████████████▏            | 113/164 [55:26<25:25, 29.91s/it]

list index out of range


 70%|████████████████████████████▍            | 114/164 [55:56<24:55, 29.90s/it]

list index out of range


 70%|████████████████████████████▊            | 115/164 [56:26<24:18, 29.77s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 71%|█████████████████████████████            | 116/164 [56:57<24:13, 30.29s/it]

list index out of range


 71%|█████████████████████████████▎           | 117/164 [57:27<23:36, 30.14s/it]

list index out of range


 72%|█████████████████████████████▌           | 118/164 [57:57<23:03, 30.09s/it]

list index out of range


 73%|█████████████████████████████▊           | 119/164 [58:27<22:34, 30.10s/it]

list index out of range


 73%|██████████████████████████████           | 120/164 [58:58<22:17, 30.40s/it]

list index out of range


 74%|██████████████████████████████▎          | 121/164 [59:29<21:49, 30.47s/it]

list index out of range


 74%|██████████████████████████████▌          | 122/164 [59:58<21:08, 30.21s/it]

list index out of range


 75%|█████████████████████████████▎         | 123/164 [1:00:28<20:28, 29.95s/it]

list index out of range


 76%|█████████████████████████████▍         | 124/164 [1:00:59<20:14, 30.37s/it]

list index out of range


 76%|█████████████████████████████▋         | 125/164 [1:01:31<19:57, 30.70s/it]

list index out of range


 77%|█████████████████████████████▉         | 126/164 [1:02:01<19:21, 30.57s/it]

list index out of range


 77%|██████████████████████████████▏        | 127/164 [1:02:32<19:01, 30.84s/it]

list index out of range


 78%|██████████████████████████████▍        | 128/164 [1:03:04<18:34, 30.97s/it]

list index out of range


 79%|██████████████████████████████▋        | 129/164 [1:03:35<18:09, 31.12s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 79%|██████████████████████████████▉        | 130/164 [1:04:08<17:56, 31.67s/it]

list index out of range


 80%|███████████████████████████████▏       | 131/164 [1:04:39<17:18, 31.48s/it]

list index out of range


 80%|███████████████████████████████▍       | 132/164 [1:05:09<16:29, 30.94s/it]

list index out of range


 81%|███████████████████████████████▋       | 133/164 [1:05:38<15:45, 30.50s/it]

list index out of range


 82%|███████████████████████████████▊       | 134/164 [1:06:08<15:06, 30.21s/it]

list index out of range


 82%|████████████████████████████████       | 135/164 [1:06:37<14:31, 30.04s/it]

list index out of range


 83%|████████████████████████████████▎      | 136/164 [1:07:07<13:56, 29.86s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 84%|████████████████████████████████▌      | 137/164 [1:07:37<13:28, 29.93s/it]

list index out of range


 84%|████████████████████████████████▊      | 138/164 [1:08:07<12:55, 29.81s/it]

list index out of range


 85%|█████████████████████████████████      | 139/164 [1:08:36<12:23, 29.75s/it]

list index out of range


 85%|█████████████████████████████████▎     | 140/164 [1:09:06<11:52, 29.70s/it]

list index out of range


 86%|█████████████████████████████████▌     | 141/164 [1:09:35<11:19, 29.56s/it]

list index out of range


 87%|█████████████████████████████████▊     | 142/164 [1:10:05<10:55, 29.80s/it]

list index out of range


 87%|██████████████████████████████████     | 143/164 [1:10:35<10:24, 29.74s/it]

list index out of range


 88%|██████████████████████████████████▏    | 144/164 [1:11:04<09:53, 29.68s/it]

list index out of range


 88%|██████████████████████████████████▍    | 145/164 [1:11:34<09:23, 29.64s/it]

list index out of range


 89%|██████████████████████████████████▋    | 146/164 [1:12:04<08:56, 29.82s/it]

list index out of range


 90%|██████████████████████████████████▉    | 147/164 [1:12:35<08:30, 30.02s/it]

list index out of range


 90%|███████████████████████████████████▏   | 148/164 [1:13:06<08:08, 30.51s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 91%|███████████████████████████████████▍   | 149/164 [1:13:37<07:39, 30.66s/it]

list index out of range


 91%|███████████████████████████████████▋   | 150/164 [1:14:07<07:04, 30.33s/it]

list index out of range


 92%|███████████████████████████████████▉   | 151/164 [1:14:38<06:37, 30.59s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 93%|████████████████████████████████████▏  | 152/164 [1:15:10<06:10, 30.87s/it]

{'pass@1': np.float64(1.0), 'pass@5': np.float64(1.0)}


 93%|████████████████████████████████████▍  | 153/164 [1:15:41<05:40, 30.93s/it]

list index out of range


 94%|████████████████████████████████████▌  | 154/164 [1:16:13<05:12, 31.27s/it]

list index out of range


# Question 2 Experiments

In [7]:
# find results where there were failures and success cases
from glob import glob
import json

# placehold for cases
cases = []

for p in glob("test_data/*.json"):
    with open(p, 'r') as f:
        # load json file
        payload = json.load(f)
        res = payload['results']

        # create relevant info dict
        info = {}
        info['prompt'] = payload['prompt']
        info['pass@1'] = res[0]['pass@1']
        info['pass@5'] = res[0]['pass@5']
        info['result'] = res[1]
        info['test'] = payload['test']

        # append model information
        if (f.name.split('_')[3] == 'llama'):
            info['model'] = 'vicuna'
        else:
            info['model'] = 'qwen'

        # append pass/fail information
        if (res[0]['pass@1'] < 1.0):
            info['status'] = 'fail'
        else:
            info['status'] = 'pass'

        # append to all cases
        cases.append(info)

In [8]:
import pandas as pd

# create results for creating tables
tabular_results = [{k:v for k,v in info.items() if k != 'result'} for info in cases]

# separate results for each model
tab_results_qwen = pd.DataFrame([res for res in tabular_results if res['model'] == 'qwen'])
tab_results_vicuna = pd.DataFrame([res for res in tabular_results if res['model'] == 'vicuna'])

# pick two failures and eight successes from for each model's results
tab_results_qwen = pd.concat([tab_results_qwen[tab_results_qwen['pass@5'] < 1][:2], tab_results_qwen[tab_results_qwen['status'] == 'pass'][:8]], axis=0)
tab_results_vicuna = pd.concat([tab_results_vicuna[tab_results_vicuna['pass@1'] < 1][:2], tab_results_vicuna[tab_results_vicuna['status'] == 'pass'][:8]], axis=0)

# reset indices
tab_results_qwen.reset_index(inplace=True, drop=True)
tab_results_vicuna.reset_index(inplace=True, drop=True)

tab_results_qwen.to_excel('qwen_cot.xlsx', index=False)
tab_results_vicuna.to_excel('vicuna_cot.xlsx', index=False)

In [9]:
tab_results_qwen

Unnamed: 0,prompt,pass@1,pass@5,test,model,status
0,Implement the python function given by this de...,0.4,0.97619,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,qwen,fail
1,Implement the python function given by this de...,0.1,0.5,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,qwen,fail
2,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,qwen,pass
3,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,qwen,pass
4,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,qwen,pass
5,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,qwen,pass
6,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,qwen,pass
7,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,qwen,pass
8,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,qwen,pass
9,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,qwen,pass


In [10]:
tab_results_vicuna

Unnamed: 0,prompt,pass@1,pass@5,test,model,status
0,Implement the python function given by this de...,0.8,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,vicuna,fail
1,Implement the python function given by this de...,0.8,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,vicuna,fail
2,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,vicuna,pass
3,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,vicuna,pass
4,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,vicuna,pass
5,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,vicuna,pass
6,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,vicuna,pass
7,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,vicuna,pass
8,Implement the python function given by this de...,1.0,1.0,import numpy as np\n\ndef is_floats(x) -> bool...,vicuna,pass
9,Implement the python function given by this de...,1.0,1.0,\n\nimport numpy as np\n\ndef is_floats(x) -> ...,vicuna,pass


In [11]:
failed_prompts_qwen = tab_results_qwen[tab_results_qwen['status'] == 'fail']['prompt']
qwen_failures = [(info['result'], info['prompt'], info['pass@1'], info['pass@5'], info['test'])
 for info in cases if info['prompt'] == failed_prompts_qwen[0] or
 info['prompt'] == failed_prompts_qwen[1]
][:2]

In [12]:
failed_prompts_vicuna = tab_results_vicuna[tab_results_vicuna['status'] == 'fail']['prompt']
vicuna_failures = [(info['result'], info['prompt'], info['pass@1'], info['pass@5'], info['test'])
 for info in cases if info['prompt'] == failed_prompts_vicuna[0] or
 info['prompt'] == failed_prompts_vicuna[1]
][:2]

In [13]:
for fail, prompt, k1, k5, _ in qwen_failures:
    for f in fail:
        print(prompt)
        print(k1)
        print(k5)
        for results in fail['0']:
            result = results[-1]
            if result['result'].startswith('failed'):
                print(result['result'])
    print()

Implement the python function given by this definition:

from typing import List


def sort_numbers(numbers: str) -> str:
    """ Input is a space-delimited string of numberals from 'zero' to 'nine'.
    Valid choices are 'zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight' and 'nine'.
    Return the string with numbers sorted from smallest to largest
    >>> sort_numbers('three one five')
    'one three five'
    """


Let's think step by step.

0.4
0.9761904761904762
failed: invalid literal for int() with base 10: 'three'
failed: invalid literal for int() with base 10: 'three'
failed: '1'
failed: invalid literal for int() with base 10: 'three'
failed: 
failed: list indices must be integers or slices, not list

Implement the python function given by this definition:

from typing import List


def parse_music(music_string: str) -> List[int]:
    """ Input to this function is a string representing musical notes in a special ASCII format.
    Your task is to parse this 

In [14]:
for fail, prompt, k1, k5, _ in vicuna_failures:
    for f in fail:
        print(prompt)
        print(k1)
        print(k5)
        for results in fail['0']:
            result = results[-1]
            if result['result'].startswith('failed'):
                print(result['result'])
    print()

Implement the python function given by this definition:

from typing import List


def filter_by_substring(strings: List[str], substring: str) -> List[str]:
    """ Filter an input list of strings only for ones that contain given substring
    >>> filter_by_substring([], 'a')
    []
    >>> filter_by_substring(['abc', 'bacd', 'cde', 'array'], 'a')
    ['abc', 'bacd', 'array']
    """


Let's think step by step.

0.5999999999999999
1.0
failed: name 'List' is not defined
failed: name 'List' is not defined
failed: name 'List' is not defined
failed: name 'List' is not defined

Implement the python function given by this definition:

from typing import List, Tuple


def rolling_max(numbers: List[int]) -> List[int]:
    """ From a given list of integers, generate a list of rolling maximum element found until given moment
    in the sequence.
    >>> rolling_max([1, 2, 3, 2, 3, 4, 2])
    [1, 2, 3, 3, 3, 4, 4]
    """


Let's think step by step.

0.8
1.0
failed: name 'List' is not defined



In [17]:
# load qwen model
ev = PassKEvaluator(MODEL_NAMES[0])

In [34]:
# first qwen prompt
results, candidates = ev.evaluate(
    [qwen_failures[0][-1]],
    qwen_failures[0][1] + ' convert the input to a list of numbers before sorting them.',
    [1,5]
)

In [23]:
print('original: ', qwen_failures[0][1])
print()
print('new: ', qwen_failures[0][1] + ' convert the input to a list of numbers before sorting them.')

original:  Implement the python function given by this definition:

from typing import List


def sort_numbers(numbers: str) -> str:
    """ Input is a space-delimited string of numberals from 'zero' to 'nine'.
    Valid choices are 'zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight' and 'nine'.
    Return the string with numbers sorted from smallest to largest
    >>> sort_numbers('three one five')
    'one three five'
    """


Let's think step by step.


new:  Implement the python function given by this definition:

from typing import List


def sort_numbers(numbers: str) -> str:
    """ Input is a space-delimited string of numberals from 'zero' to 'nine'.
    Valid choices are 'zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight' and 'nine'.
    Return the string with numbers sorted from smallest to largest
    >>> sort_numbers('three one five')
    'one three five'
    """


Let's think step by step.
 convert the input to a list of numbers befor

In [39]:
print("new results: ", results[0]['pass@1'], results[0]['pass@5'])
print("old results: ", qwen_failures[0][2], qwen_failures[0][3])

new results:  1.0 1.0
old results:  0.4 0.9761904761904762


In [40]:
# second qwen prompt
results, candidates = ev.evaluate(
    [qwen_failures[1][-1]],
    qwen_failures[1][1] + ' Include necessary imports in the generated python code',
    [1,5]
)

In [24]:
print('original: ', qwen_failures[1][1])
print()
print('new: ', qwen_failures[1][1] + ' convert the input to a list of numbers before sorting them.')

original:  Implement the python function given by this definition:

from typing import List


def parse_music(music_string: str) -> List[int]:
    """ Input to this function is a string representing musical notes in a special ASCII format.
    Your task is to parse this string and return list of integers corresponding to how many beats does each
    not last.

    Here is a legend:
    'o' - whole note, lasts four beats
    'o|' - half note, lasts two beats
    '.|' - quater note, lasts one beat

    >>> parse_music('o o| .| o| o| .| .| .| .| o o')
    [4, 2, 1, 2, 2, 1, 1, 1, 1, 4, 4]
    """


Let's think step by step.


new:  Implement the python function given by this definition:

from typing import List


def parse_music(music_string: str) -> List[int]:
    """ Input to this function is a string representing musical notes in a special ASCII format.
    Your task is to parse this string and return list of integers corresponding to how many beats does each
    not last.

    Here is

In [42]:
print("new results: ", results[0]['pass@1'], results[0]['pass@5'])
print("old results: ", qwen_failures[1][2], qwen_failures[1][3])

new results:  0.8 1.0
old results:  0.09999999999999998 0.5


In [15]:
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16  # Recommended for speed
)

ev = PassKEvaluator(MODEL_NAMES[1], quantization_config=quantization_config)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


In [21]:
# first vicuna prompt
results, candidates = ev.evaluate(
    [vicuna_failures[0][-1]],
    vicuna_failures[0][1] + ' Include necessary imports in the generated python code before the python function definition',
    [1,5]
)

In [25]:
print('original: ', vicuna_failures[0][1])
print()
print('new: ', vicuna_failures[0][1] + ' convert the input to a list of numbers before sorting them.')

original:  Implement the python function given by this definition:

from typing import List


def filter_by_substring(strings: List[str], substring: str) -> List[str]:
    """ Filter an input list of strings only for ones that contain given substring
    >>> filter_by_substring([], 'a')
    []
    >>> filter_by_substring(['abc', 'bacd', 'cde', 'array'], 'a')
    ['abc', 'bacd', 'array']
    """


Let's think step by step.


new:  Implement the python function given by this definition:

from typing import List


def filter_by_substring(strings: List[str], substring: str) -> List[str]:
    """ Filter an input list of strings only for ones that contain given substring
    >>> filter_by_substring([], 'a')
    []
    >>> filter_by_substring(['abc', 'bacd', 'cde', 'array'], 'a')
    ['abc', 'bacd', 'array']
    """


Let's think step by step.
 convert the input to a list of numbers before sorting them.


In [22]:
print("new results: ", results[0]['pass@1'], results[0]['pass@5'])
print("old results: ", vicuna_failures[0][2], vicuna_failures[0][3])

new results:  1.0 1.0
old results:  0.5999999999999999 1.0


In [19]:
# second vicuna prompt
results, candidates = ev.evaluate(
    [vicuna_failures[1][-1]],
    vicuna_failures[1][1] + ' Include necessary imports in the generated python code before the python function definition',
    [1,5]
)

In [26]:
print('original: ', vicuna_failures[1][1])
print()
print('new: ', vicuna_failures[1][1] + ' convert the input to a list of numbers before sorting them.')

original:  Implement the python function given by this definition:

from typing import List, Tuple


def rolling_max(numbers: List[int]) -> List[int]:
    """ From a given list of integers, generate a list of rolling maximum element found until given moment
    in the sequence.
    >>> rolling_max([1, 2, 3, 2, 3, 4, 2])
    [1, 2, 3, 3, 3, 4, 4]
    """


Let's think step by step.


new:  Implement the python function given by this definition:

from typing import List, Tuple


def rolling_max(numbers: List[int]) -> List[int]:
    """ From a given list of integers, generate a list of rolling maximum element found until given moment
    in the sequence.
    >>> rolling_max([1, 2, 3, 2, 3, 4, 2])
    [1, 2, 3, 3, 3, 4, 4]
    """


Let's think step by step.
 convert the input to a list of numbers before sorting them.


In [20]:
print("new results: ", results[0]['pass@1'], results[0]['pass@5'])
print("old results: ", vicuna_failures[1][2], vicuna_failures[1][3])

new results:  1.0 1.0
old results:  0.8 1.0
