In [None]:
from datasets import load_dataset

## Load HumanEval dataset

In [None]:
# Load the original HumanEval test split and materialize it as a plain list of dicts
# (features: 'task_id', 'prompt', 'canonical_solution', 'test', 'entry_point').
dataset = list(load_dataset("openai_humaneval", split="test"))

In [None]:
# Keep only the prompt text of every task, in dataset order.
prompts = []
for task in dataset:
    prompts.append(task["prompt"])

## Load model

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

model_id = "mistralai/Mistral-7B-Instruct-v0.2"

# Tokenizer: use the Python ("slow") implementation rather than the fast one.
# NOTE(review): trust_remote_code=True permits running code shipped with the model
# repository; transformers supports Mistral natively, so this flag is probably
# unnecessary — confirm and consider removing it.
tokenizer_options = {"use_fast": False, "trust_remote_code": True}
tokenizer = AutoTokenizer.from_pretrained(model_id, **tokenizer_options)

# Model: let transformers choose device placement automatically and load the
# weights in the dtype recorded for the checkpoint.
model_options = {"device_map": "auto", "torch_dtype": "auto"}
model = AutoModelForCausalLM.from_pretrained(model_id, **model_options)

# Text-generation pipeline over the same model and tokenizer.
# NOTE(review): `pipe` is not used by any later cell in this notebook.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

#### Prompt test

## Tuned prompt generation

In [None]:
from transformers import set_seed

# Sampling is enabled in `generate`, so seed the RNGs to make re-runs reproducible.
# Later cells hard-code handling for specific responses (e.g. index 10), which only
# stays valid if the generations do not change between runs.
GENERATION_SEED = 42
MAX_NEW_TOKENS = 1024
REWRITER_SYSTEM_PROMPT = (
    "You are an expert prompt rewriter. Your task is to improve the clarity and helpfulness of docstrings "
    "for code generation. Only rewrite the docstring (the text inside triple double quotes: \"\"\" ... \"\"\"). "
    "Do not modify the function signature or write any implementation code."
)

set_seed(GENERATION_SEED)

tuned_prompts = []
for original_prompt in prompts:
    messages = [
        {"role": "system", "content": REWRITER_SYSTEM_PROMPT},
        {"role": "user", "content": f"{original_prompt}"},
    ]

    chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered chat template already contains the template's special tokens
    # (including BOS), so tokenize it without adding them a second time; otherwise
    # the model input starts with a duplicated BOS token.
    model_inputs = tokenizer(chat_text, return_tensors="pt", add_special_tokens=False).to(model.device)
    out_ids = model.generate(**model_inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=True)

    # Decode prompt + completion together; the "Extract response" step later keeps
    # only the text after the "[/INST]" marker.
    tuned_prompts.append(tokenizer.decode(out_ids[0], skip_special_tokens=True))

### Dump results into a JSON file

In [None]:
import json

# One JSON record per generation, identified by its position in `tuned_prompts`.
data = [dict(id=idx, prompt=text) for idx, text in enumerate(tuned_prompts)]

with open("tuned_prompts.json", mode="w", encoding="utf-8") as out_file:
    json.dump(data, out_file, indent=2, ensure_ascii=False)

### Extract response

In [None]:
import re
# Each decoded generation is expected to hold the chat-formatted request followed by
# the model's answer; keep only what follows the first "[/INST]" marker.
pattern = re.compile(r"\[/INST\](.+)", re.DOTALL)
docstrings = []
unmatched = []  # indices of generations without a "[/INST]" marker
for i, prompt in enumerate(tuned_prompts):
    match = pattern.search(prompt)
    if match:
        docstrings.append(match.group(1).strip())
    else:
        print(f"No match found in generation {i}.")
        unmatched.append(i)

# Later cells pair docstrings[i] with dataset entry i by position, so silently
# skipping a generation would attach every following docstring to the wrong task.
if unmatched:
    raise ValueError(
        f"No '[/INST]' marker found in generations {unmatched}; docstrings would be misaligned."
    )

### Standardize docstrings by removing quotes

In [None]:
docstring_pattern = re.compile(r"['\"]{3}(.+?)['\"]{3}", re.DOTALL)


def _strip_quote_delimiters(text, quote):
    """Return the docstring body of `text` delimited by `quote`, or None.

    `quote` is a triple-quote delimiter (three double or three single quotes).
    - Starts and ends with `quote`: the text in between.
    - Only starts with `quote`: everything after the opening delimiter, cut at the
      next `quote` if one exists, so commentary after the docstring is dropped.
    - Only ends with `quote`: everything before the closing delimiter, starting
      after the previous `quote` if one exists, so leading commentary is dropped.
    - Neither: None.
    The returned body is whitespace-stripped.
    """
    starts, ends = text.startswith(quote), text.endswith(quote)
    if starts and ends:
        return text[3:-3].strip()
    if starts:
        return text[3:].split(quote, 1)[0].strip()
    if ends:
        return text[:-3].rsplit(quote, 1)[-1].strip()
    return None


docstrings_without_quotes = []
unparsed = []  # indices of responses from which no docstring could be extracted
for i, response in enumerate(docstrings):
    body = _strip_quote_delimiters(response, '"""')
    if body is None:  # fall back to single-quote delimiters
        body = _strip_quote_delimiters(response, "'''")

    if body is not None:
        docstrings_without_quotes.append(body)
    elif i == 10:
        # Hard-coded special case: response 10 is expected to contain several
        # docstrings, which are kept together as a list and unpacked by later cells.
        found = [doc.strip() for doc in docstring_pattern.findall(response)]
        if found:
            docstrings_without_quotes.append(found)
        else:
            unparsed.append(i)
    elif response.startswith('```python'):
        # Code-fenced answer: take the content of the first triple-quoted string.
        match = docstring_pattern.search(response)
        if match:
            docstrings_without_quotes.append(match.group(1).strip())
        else:
            unparsed.append(i)
    else:
        print(f"Unexpected format in response {i}: {response}")
        unparsed.append(i)

# Later cells pair docstrings_without_quotes[i] with dataset entry i by position,
# so a skipped response would shift every following docstring onto the wrong task.
if unparsed:
    raise ValueError(
        f"Could not extract a docstring from responses {unparsed}; fix them before continuing."
    )

In [None]:
# Both lists are paired by position downstream, so these two counts should match.
raw_count, cleaned_count = len(docstrings), len(docstrings_without_quotes)
raw_count, cleaned_count

In [None]:
# Sanity check: report every original prompt that lacks a triple-quoted docstring
# (either quote style), since such prompts are skipped by the replacement step.
for i, prompt in enumerate(prompts):
    has_docstring = ('"""' in prompt) or ("'''" in prompt)
    if not has_docstring:
        print(f"No docstring found in prompt {i}.")

In [None]:
# Work on an independent deep copy so the original `dataset` stays untouched for comparison.
import copy
new_dataset = [copy.deepcopy(entry) for entry in dataset]

In [None]:
# Wrap each cleaned docstring in triple double quotes, each delimiter on its own line.
# NOTE(review): a list entry (the index-10 multi-docstring case) is stringified here
# and is overwritten by the special-case cell that follows.
docstrings_with_quotes = []
for body in docstrings_without_quotes:
    docstrings_with_quotes.append('"""\n{}\n"""'.format(body))

In [None]:
# Entry 10 was kept as a list of docstrings during standardization; replace the
# stringified list from the previous wrap step with a 2-tuple of its first two items.
# NOTE(review): element 0 stays bare while element 1 is wrapped in triple double
# quotes — the replacement step has to cope with this asymmetry.
first_doc, second_doc = docstrings_without_quotes[10][0], docstrings_without_quotes[10][1]
docstrings_with_quotes[10] = (f"{first_doc}", '"""' + f"{second_doc}" + '"""')

In [None]:
replace_pattern = re.compile(r"['\"]{3}(.+?)['\"]{3}", re.DOTALL)


def _as_docstring_literal(text):
    """Return `text` wrapped in exactly one pair of triple double quotes.

    Parts of a multi-docstring entry are not quoted uniformly (some arrive bare,
    some already wrapped), so existing surrounding triple double quotes are removed
    before re-wrapping to avoid doubled delimiters.
    """
    body = f"{text}".strip()
    if len(body) >= 6 and body.startswith('"""') and body.endswith('"""'):
        body = body[3:-3]
    return '"""' + body + '"""'


for i, entry in enumerate(new_dataset):
    # Always rebuild from the untouched original prompt so re-running this cell
    # does not apply replacements on top of an already-modified prompt.
    source_prompt = dataset[i]["prompt"]
    if '"""' not in source_prompt and "'''" not in source_prompt:
        print(f"No docstring found in prompt {i}.")
        continue

    replacement = docstrings_with_quotes[i]
    n_double = source_prompt.count('"""')
    n_single = source_prompt.count("'''")

    if isinstance(replacement, (list, tuple)):
        # Several rewritten docstrings (the special-cased entry): substitute the
        # prompt's docstrings in order, one rewritten docstring per match.
        # NOTE(review): assumes the model's answer listed the docstrings in the same
        # order as the prompt's functions — check the printout below.
        if not replacement:
            print(f"Unexpected format in prompt {i}: {source_prompt}")
            continue
        literals = iter([_as_docstring_literal(part) for part in replacement])
        entry["prompt"] = replace_pattern.sub(
            lambda _m: next(literals), source_prompt, count=len(replacement)
        )
        print(f"Verify special case for entry {i}:\nOG:{dataset[i]['prompt']}\nNew:{entry['prompt']}")
    elif n_double == 2 or n_single == 2:
        # Single docstring. A callable replacement inserts the generated text
        # literally; a string replacement would treat backslashes in it as regex
        # escapes / group references (and raise on unknown escapes).
        entry["prompt"] = replace_pattern.sub(lambda _m: replacement, source_prompt, count=1)
    elif n_double == 4 or n_single == 4:
        # Two docstrings but a single rewrite: replace the first docstring only.
        # NOTE(review): the first docstring may belong to a helper rather than the
        # entry-point function — confirm this is the intended target.
        entry["prompt"] = replace_pattern.sub(lambda _m: replacement, source_prompt, count=1)
        print(f"Verify special case for entry {i}:\nOG:{dataset[i]['prompt']}\nNew:{entry['prompt']}")
    else:
        print(f"Unexpected format in prompt {i}: {source_prompt}")

In [None]:
# One cleaned docstring is expected per dataset entry.
dataset_count, docstring_count = len(new_dataset), len(docstrings_without_quotes)
dataset_count, docstring_count

In [None]:
# Export the modified dataset (a list of dicts) to a new JSON file.
import json

with open("humaneval_tuned_prompts.json", mode="w", encoding="utf-8") as json_out:
    json.dump(new_dataset, json_out, indent=2, ensure_ascii=False)

In [None]:
from datasets import Dataset

# Convert the list of dicts into a Hugging Face Dataset and persist it with
# save_to_disk so it can be reloaded via load_from_disk.
output_dir = "humaneval_tuned_prompts"
modified_dataset = Dataset.from_list(new_dataset)
modified_dataset.save_to_disk(output_dir)

In [None]:
from datasets import load_from_disk

# Reload the copy saved above to confirm it round-trips through disk.
saved_dir = "humaneval_tuned_prompts"
loaded_dataset = load_from_disk(saved_dir)

In [None]:
# Compare one example across the original dataset, the rewritten in-memory copy,
# and the copy reloaded from disk (the last two should be identical).
sample_idx = 0  # named to avoid shadowing the built-in `id`
for label, source in (
    ("original", dataset),
    ("rewritten", new_dataset),
    ("reloaded from disk", loaded_dataset),
):
    print(f"--- {label} ---")
    print(source[sample_idx]["prompt"])