In [1]:
import os
import json
import torch
import wandb
import random
import pandas as pd

from datetime import datetime
from dotenv import load_dotenv
from itertools import combinations
from tqdm.autonotebook import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load variables from a .env file into the process environment (python-dotenv).
# NOTE(review): presumably this supplies the W&B / Hugging Face credentials
# used by later cells — confirm.
load_dotenv()

  from tqdm.autonotebook import tqdm


In [22]:
# Sample sizes: how many instructions and how many original texts to draw.
NUM_INSTRUCTIONS, NUM_ORIGINAL_TEXTS = 20, 100

# Seed the stdlib RNG so that the random sampling done later is repeatable.
random.seed(42)

# Run identification: a job label plus a timestamp taken when this cell runs.
JOB_TYPE = "Dataset_Generation"
CURR_DATE_TIME = f"{datetime.now():%Y%m%d_%H%M%S}"

# TODO: settle on the number of instructions to sample from the dataset (NUM_INSTRUCTIONS)

In [3]:
# Start the W&B run; the run name combines the job label and the timestamp.
run = wandb.init(
    entity="jhu-llm-prompt-recovery",
    project="llm-prompt-recovery",
    job_type="upload-dataset",
    name=f"{JOB_TYPE}_{CURR_DATE_TIME}",
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnrishabh[0m ([33mjhu-llm-prompt-recovery[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [13]:
# Fetch the instruction-prompts artifact through the active W&B run and parse
# the JSON file it contains into `prompts`.
artifact = run.use_artifact("instruction-prompts-dataset:latest")
artifact_dir = artifact.download()

# os.listdir() returns entries in arbitrary order, so sort before taking the
# first one to keep the choice deterministic, and fail with a clear message
# if the download directory is empty.
artifact_files = sorted(os.listdir(artifact_dir))
if not artifact_files:
    raise FileNotFoundError(f"No files found in downloaded artifact: {artifact_dir}")

# Read as UTF-8 explicitly instead of relying on the platform's default
# encoding, which would mis-decode non-ASCII text on some systems.
with open(os.path.join(artifact_dir, artifact_files[0]), "r", encoding="utf-8") as f:
    prompts = json.load(f)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [10]:
# Fetch the source-text artifact through the active W&B run and parse the JSON
# file it contains into `og_text`. DATASET_NAME is also reused later to name
# the output file and artifact.
DATASET_NAME = "email-dataset"
artifact = run.use_artifact(f"{DATASET_NAME}:latest")
artifact_dir = artifact.download()

# os.listdir() returns entries in arbitrary order, so sort before taking the
# first one to keep the choice deterministic, and fail with a clear message
# if the download directory is empty.
artifact_files = sorted(os.listdir(artifact_dir))
if not artifact_files:
    raise FileNotFoundError(f"No files found in downloaded artifact: {artifact_dir}")

# Read as UTF-8 explicitly instead of relying on the platform's default
# encoding, which would mis-decode non-ASCII text on some systems.
with open(os.path.join(artifact_dir, artifact_files[0]), "r", encoding="utf-8") as f:
    og_text = json.load(f)

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [14]:
# Accumulator for the generated records; the generation loop appends to it.
rewritten_text_dataset = []

In [23]:
# Draw the instructions first, then the original texts. The order of the two
# random.sample() calls is kept because both consume the same global RNG.
sampled_instructions = random.sample(prompts, NUM_INSTRUCTIONS)
sampled_original_texts = random.sample(og_text, NUM_ORIGINAL_TEXTS)

# Pair every sampled instruction with every sampled original text, iterating
# instructions in the outer position. Each pair is wrapped in a 1-tuple, which
# is exactly what itertools.combinations(pairs, 1) yields for each element.
instruction_original_text_combinations = [
    ((instruction, original_text),)
    for instruction in sampled_instructions
    for original_text in sampled_original_texts
]

In [26]:
# Load the tokenizer and the model weights (in bfloat16) from the same
# Hugging Face checkpoint.
MODEL_ID = "google/gemma-1.1-7b-it"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.bfloat16)

Downloading shards: 100%|██████████| 4/4 [06:43<00:00, 100.79s/it]
Loading checkpoint shards: 100%|██████████| 4/4 [00:42<00:00, 10.51s/it]


In [31]:
# Unwrap the 1-tuples produced by itertools.combinations(..., 1) so that each
# item is a plain (instruction, original_text) pair.
#
# The result is rebound to the same name, and notebook cells can be executed
# more than once. A blind `combination[0]` would, on a second run, replace
# every pair with just its instruction. Only items that are still 1-tuples are
# unwrapped; already-flattened pairs (length 2) pass through unchanged, which
# makes this cell safe to re-run.
instruction_original_text_combinations = [
    combination[0] if len(combination) == 1 else combination
    for combination in instruction_original_text_combinations
]

In [38]:
# Generate one rewritten text per (instruction, original text) pair and append
# a record describing it to rewritten_text_dataset.
for instruction, original_text in tqdm(instruction_original_text_combinations):

    instruction_id = instruction['id']
    instruction_text = instruction['prompt']

    original_text_id = original_text['id']
    original_text_text = original_text['text']

    # The prompt is the instruction followed by the text it applies to.
    message = f"{instruction_text} {original_text_text}"

    # Place the encoded prompt on the model's device so generate() keeps
    # working if the model is moved off the CPU.
    inputs = tokenizer(message, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    outputs = model.generate(**inputs, max_new_tokens=1000)

    # For a decoder-only model, generate() returns the prompt tokens followed
    # by the continuation. Decode only the newly generated tokens so that
    # "rewritten_text" does not echo the instruction and the original text.
    output = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

    rewritten_text_dataset.append({
        "instruction_id": instruction_id,
        "instruction_text": instruction_text,
        "original_text_id": original_text_id,
        "original_text_text": original_text_text,
        "rewritten_text": output
    })


  0%|          | 0/2000 [00:27<?, ?it/s]


KeyboardInterrupt: 

In [37]:
# Preview only the first few records; the full list can hold one long
# generated text per (instruction, original text) pair and would flood the
# notebook output.
rewritten_text_dataset[:5]

[]

In [33]:
# Inspection only: display the tokenizer's default chat template.
# NOTE(review): the generation loop builds its prompt as a plain f-string and
# does not apply any chat template — confirm that this is intended.
tokenizer.default_chat_template


No chat template is defined for this tokenizer - using a default chat template that implements the ChatML format (without BOS/EOS tokens!). If the default is not appropriate for your model, please set `tokenizer.chat_template` to an appropriate template. See https://huggingface.co/docs/transformers/main/chat_templating for more information.



"{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"

In [None]:
# Persist the generated records to a local JSON file, then publish that file
# as a W&B artifact. File and artifact share the same timestamped base name.
artifact_name = f"{DATASET_NAME}_{CURR_DATE_TIME}"
output_path = f"{artifact_name}.json"

with open(output_path, "w") as f:
    json.dump(rewritten_text_dataset, f)

artifact = wandb.Artifact(artifact_name, type="rewritten-texts-dataset")
artifact.add_file(output_path)
run.log_artifact(artifact)

In [None]:
# Close the run object returned by wandb.init(), then call the module-level
# finish as well.
# NOTE(review): the second call looks redundant if `run` is the only active
# run — confirm before removing.
run.finish()
wandb.finish()