In [1]:
import os
import json
import wandb
import torch
import transformers

from datasets import load_dataset
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [9]:
# W&B dataset artifacts to convert and publish. Other artifacts previously listed
# here (disabled): "minute-dataset", "mini-dataset", "medium-dataset", "huge-dataset".
DATASETS = ["large-dataset"]

In [10]:
# W&B run under which this upload job executes; artifacts are fetched through it.
run = wandb.init(
    entity="jhu-llm-prompt-recovery",
    project="llm-prompt-recovery",
    job_type="hf_dataset_upload",
)

In [11]:
# Llama 3 8B Instruct tokenizer; used for its chat template when building the "llama" variant.
model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
tokenizer = transformers.AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path=model_id,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [12]:
# Shared instruction used by both the chat-template ("llama") and plain-text ("instr") formats.
SYSTEM_PROMPT = "Find the AI prompt used to rewrite the old text into the new text."
# Fraction of each artifact's train split held out (from the front, unshuffled) as validation.
VALIDATION_FRACTION = 0.125
# Hugging Face Hub repo every variant is pushed to, each under its own config name.
HF_REPO_ID = "prompt-recovery"


def _read_json(path):
    """Return the parsed JSON document at `path`, closing the file afterwards."""
    with open(path) as f:
        return json.load(f)


def _write_json(obj, path):
    """Serialize `obj` to `path` as JSON, closing the file before returning.

    Closing explicitly guarantees the data is flushed before `load_dataset`
    reads the directory back (rather than relying on CPython refcounting).
    """
    with open(path, "w") as f:
        json.dump(obj, f)


def _simplify(records):
    """Flatten raw artifact records into {"original_text", "prompt", "rewritten_text"} dicts.

    Reads `record["original_text"]["text"]`, `record["instruction"]["prompt"]`
    and `record["rewritten_text"]` from each record.
    """
    return [
        {
            "original_text": record["original_text"]["text"],
            "prompt": record["instruction"]["prompt"],
            "rewritten_text": record["rewritten_text"],
        }
        for record in records
    ]


def _to_llama(records, chat_tokenizer):
    """Render simplified records as chat-template prompts; the target prompt is the completion."""
    formatted = []
    for sample in records:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"Old Text: {sample['original_text']} New Text: {sample['rewritten_text']}"},
        ]
        prompt = chat_tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        formatted.append({"prompt": prompt, "completion": sample["prompt"]})
    return formatted


def _to_instr(records):
    """Render simplified records as plain-text instruction prompts (no chat template)."""
    return [
        {
            "prompt": f"{SYSTEM_PROMPT}\nOld Text: {sample['original_text']}\nNew Text: {sample['rewritten_text']}",
            "completion": sample["prompt"],
        }
        for sample in records
    ]


def _write_splits(out_dir, splits):
    """Create `out_dir` and write each split as `<out_dir>/<split name>.json`."""
    os.makedirs(out_dir, exist_ok=True)
    for split_name, split_records in splits.items():
        _write_json(split_records, out_dir / f"{split_name}.json")


def _publish(out_dir, config_name):
    """Load the JSON splits in `out_dir` with `datasets` and push them to the Hub as `config_name`."""
    load_dataset(str(out_dir)).push_to_hub(HF_REPO_ID, config_name)


for dataset_name in DATASETS:
    artifact = run.use_artifact(f"{dataset_name}:latest")
    dataset_dir = Path(artifact.download())

    simplified_test = _simplify(_read_json(dataset_dir / "test.json"))
    simplified_full_train = _simplify(_read_json(dataset_dir / "train.json"))

    # Hold out the first VALIDATION_FRACTION of train (no shuffle) as validation.
    n_validation = int(VALIDATION_FRACTION * len(simplified_full_train))
    splits = {
        "train": simplified_full_train[n_validation:],
        "validation": simplified_full_train[:n_validation],
        "test": simplified_test,
    }

    # Config base name, e.g. a download dir named "large-dataset:v0" -> "large".
    base_name = dataset_dir.stem.split(":")[0].split("-")[0]

    simplified_dir = dataset_dir / "simplified"
    _write_splits(simplified_dir, splits)
    _publish(simplified_dir, base_name)

    llama_dir = dataset_dir / "llama"
    _write_splits(llama_dir, {name: _to_llama(recs, tokenizer) for name, recs in splits.items()})
    _publish(llama_dir, f"{base_name}-llama")

    # Own directory: this previously reused "llama/" and overwrote the llama JSON files.
    instr_dir = dataset_dir / "instr"
    _write_splits(instr_dir, {name: _to_instr(recs) for name, recs in splits.items()})
    # NOTE(review): config name kept as "<base>-llama-instr" to match what was already
    # published (it came from appending to the mutated llama name); "<base>-instr" was
    # probably intended — rename only after checking consumers of the Hub dataset.
    _publish(instr_dir, f"{base_name}-llama-instr")



[34m[1mwandb[0m:   2 of 2 files downloaded.  
Generating train split: 1260 examples [00:00, 4815.22 examples/s]
Generating validation split: 180 examples [00:00, 52479.82 examples/s]
Generating test split: 360 examples [00:00, 4121.72 examples/s]
Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 85.04ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  1.73it/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 813.01ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  3.80it/s]
Creating parquet from Arrow format: 100%|██████████| 1/1 [00:00<00:00, 227.73ba/s]
Uploading the dataset shards: 100%|██████████| 1/1 [00:00<00:00,  2.85it/s]
Generating train split: 1260 examples [00:00, 4929.82 examples/s]
Generating validation split: 180 examples [00:00, 57724.19 examples/s]
Generating test split: 360 examples [00:00, 4080.00 examples/s]
Creating parquet from Arrow format: 100%|██████████| 2/2 [00:00<00:00, 55

In [13]:
# Mark the W&B run started with `wandb.init` as finished.
run.finish()