In [1]:
!pip install bitsandbytes
!pip install datasets
!pip install accelerate
# !pip install transformers -q

import pandas as pd
from datasets import load_dataset
import random
import gc
import torch
from tqdm import tqdm
from accelerate import Accelerator
import warnings
warnings.filterwarnings('ignore')



  from .autonotebook import tqdm as notebook_tqdm


This should be able to generate about 1000 texts in 6 hours on the P100 GPU. For some reason it's far slower (about 6x slower, doesn't make sense) on the 2xT4 GPUs, have to look into that later. The catch being the text lengths are capped at 200 words right now, will increase it later at the cost of speed later. Thanks to [nbroad](https://www.kaggle.com/nbroad) for introducing me to the writingprompts [here](https://www.kaggle.com/competitions/llm-prompt-recovery/discussion/480262). 

In [2]:
from datasets import load_dataset

dataset = load_dataset("euclaise/WritingPrompts")
dataset = dataset.remove_columns(["prompt"])

In [3]:
df = pd.DataFrame(dataset["train"])

del dataset
gc.collect()

# Output token size can be very expensive, so for now we stick to small sizes. 
df = df.loc[(df["story"].apply(lambda x: len(x.split())) < 150), :]
df = df.sample(1000).reset_index(drop=True)
df["original_text"] = df["story"]
df = df.drop("story", axis=1)
display(df)

Unnamed: 0,original_text
0,As sure as the sun rises and falls into the su...
1,"Hi. my name is Luke, I should introduce to you..."
2,Wake up and turn the alarm clock off; rub my e...
3,"Your hair is brown while her hair is red, and ..."
4,"October 17, 2015 / main house \n \n I really d..."
...,...
995,She does n't -- ca n't -- see herself the way ...
996,"Jeff slammed the door on his Ford Vista shut, ..."
997,"By Jove!'T is ſuch a horrid, evil day! \n The ..."
998,Sonnet Number Seventy-Six \n \n Our expedition...


In [4]:
# Generating Random Prompts
import json

# Load prompts from JSON file
with open('data/raw/prompts.json', 'r') as file:
    prompts_data = json.load(file)

def generate_random_prompt():
    prompt = random.choice(prompts_data)
    return prompt
    

df["prompt"] = pd.Series([generate_random_prompt() for _ in range(df.shape[0])], dtype=str)
df["rewrite_prompt"] = df["prompt"] + ": \"\"\" " + df["original_text"] + "\"\"\""
display(df)


Unnamed: 0,original_text,prompt,rewrite_prompt
0,As sure as the sun rises and falls into the su...,Rewrite the opening paragraph of a mystery nov...,Rewrite the opening paragraph of a mystery nov...
1,"Hi. my name is Luke, I should introduce to you...","Rewrite the passage from George Orwell's ""1984...","Rewrite the passage from George Orwell's ""1984..."
2,Wake up and turn the alarm clock off; rub my e...,"Rewrite the passage from Shakespeare's ""Hamlet...","Rewrite the passage from Shakespeare's ""Hamlet..."
3,"Your hair is brown while her hair is red, and ...",Rewrite the classic fairy tale of Cinderella i...,Rewrite the classic fairy tale of Cinderella i...
4,"October 17, 2015 / main house \n \n I really d...",Rewrite the passage from a Victorian Gothic pe...,Rewrite the passage from a Victorian Gothic pe...
...,...,...,...
995,She does n't -- ca n't -- see herself the way ...,Rewrite this passage in the style of a Gothic ...,Rewrite this passage in the style of a Gothic ...
996,"Jeff slammed the door on his Ford Vista shut, ...",Rewrite the following passage from a gothic ho...,Rewrite the following passage from a gothic ho...
997,"By Jove!'T is ſuch a horrid, evil day! \n The ...",Rewrite the following passage in a whimsical a...,Rewrite the following passage in a whimsical a...
998,Sonnet Number Seventy-Six \n \n Our expedition...,Rewrite the passage from a modern gothic persp...,Rewrite the passage from a modern gothic persp...


In [5]:
# Set Max New Tokens to the Longest text in our data * 1.3
CMAX_NEW_TOKENS = max(df["original_text"].apply(lambda x: len(x.split()))) * 1.3
CMAX_NEW_TOKENS

193.70000000000002

In [6]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch

# Assuming you want to use a model from the HuggingFace Hub, replace the local MODEL_PATH with the model's name on the hub.
MODEL_NAME = "google/gemma-7b-it"  # Replace 'gpt2' with the actual model name you want to use from the HuggingFace Hub

quantization_config = BitsAndBytesConfig(
    load_in_4bit = True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    device_map = "auto",
    trust_remote_code = True,
    quantization_config=quantization_config,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


Loading checkpoint shards: 100%|██████████| 4/4 [00:32<00:00,  8.03s/it]


In [7]:
df['rewritten_text'] = ""

it = iter(df.iterrows())
idx, row = next(it, (None, None))

device = 'cuda'

MAX_NEW_TOKENS = CMAX_NEW_TOKENS

pbar = tqdm(total=df.shape[0])

while idx is not None:
    torch.cuda.empty_cache()
    gc.collect()
    
    try:
        
        encoded_input = tokenizer(row["rewrite_prompt"], return_tensors="pt").to(device)

        with torch.no_grad():
            encoded_output = model.generate(**encoded_input, max_new_tokens=MAX_NEW_TOKENS, do_sample=True)

        decoded_output = tokenizer.decode(encoded_output[0], skip_special_tokens=True).replace(row["rewrite_prompt"], '')
    
        df.loc[idx, "rewritten_text"] = decoded_output
        
        idx, row = next(it, (None, None))
        pbar.update(1)
        
        # If working again, reset back to max
        MAX_NEW_TOKENS = CMAX_NEW_TOKENS
        
        del encoded_input, decoded_output, encoded_output
        
        # Backing up so we don't lose current progress
        df.to_csv("data.csv", index=False)
        
    except Exception as e:
        print("ERROR: ", e)
        # If we are failing due to memory isses, halve the output size and retry
        MAX_NEW_TOKENS /= 2
    

pbar.close()

  1%|          | 6/1000 [02:14<6:10:04, 22.34s/it]

In [None]:
df.to_csv("data.csv", index=False)
df

Unnamed: 0,original_text,prompt,rewrite_prompt,rewritten_text
0,I get it now. You have to give to get. Where's...,Convey the same message as this text but throu...,Convey the same message as this text but throu...,\n\n**Rewritten from the perspective of an ali...
1,Atlas was not seen to be special by any means....,Restyle this text as if it were written by a A...,Restyle this text as if it were written by a A...,"\n\nSure, here is the text rewritten as if it ..."
2,Gleaming eyes shining in the dark. \n Visions ...,Adapt this text as a script for a wizard in a ...,Adapt this text as a script for a wizard in a ...,\n\n**Script:**\n\n(A smoky bar in the heart o...
3,"Remember, there's security cameras at every co...",Convey the same message as this text but throu...,Convey the same message as this text but throu...,"\n\nNow, rewrite the text as if it is being na..."
4,Waves \n 17 Septillion ships per cloud \n Like...,Imagine this text was a villain in the world o...,Imagine this text was a villain in the world o...,\n\nIf written as a villain in the world of vi...
...,...,...,...,...
995,I knew I only had 10 seconds to change things....,Rewrite this text in the style of a philosophi...,Rewrite this text in the style of a philosophi...,\n\n## The Knight's Tale of Temporal Flux:\n\n...
996,The rebels ran into the building lining up in ...,Restyle this text as if it were written by a p...,Restyle this text as if it were written by a p...,"\n\nSure, here is the text rewritten as if it ..."
997,"Jeanne loomed over the coffin, a great nagging...",Restyle this text as if it were written by a p...,Restyle this text as if it were written by a p...,"\n\nSure, here is the text rewritten as if it ..."
998,The officers examined the graphic scene. The v...,Translate the essence of this text into a pira...,Translate the essence of this text into a pira...,"\n\n**Pirate Narrative:**\n\nAvast, me heartie..."
