In [None]:
import pandas as pd
import re
import copy

In [None]:
# Read the forum-messages CSV into a DataFrame and preview its first five rows.
forum_messages_df = pd.read_csv(filepath_or_buffer='../data/ForumMessages.csv')
forum_messages_df.head(n=5)

In [None]:
# Read the Wikipedia movie-plots CSV into a DataFrame and preview its first five rows.
wikipedia_movie_plots_df = pd.read_csv(filepath_or_buffer='../data/wiki_movie_plots_deduped.csv')
wikipedia_movie_plots_df.head(n=5)

In [None]:
def remove_html_tags(text):
    """Return ``text`` with every ``<...>`` tag removed.

    Args:
        text: The value to clean. Anything that is not a ``str`` (for
            example the float NaN pandas uses for missing cells) is
            treated as having no content.

    Returns:
        The input string with each span from a ``<`` up to the next ``>``
        deleted, or ``''`` when ``text`` is not a string.
    """
    if not isinstance(text, str):
        return ''
    # `[^>]*` is used instead of `.*?` because `.` does not match newlines
    # (without re.DOTALL), so tags whose attributes wrap onto another line
    # would otherwise be left in the text.
    return re.sub(r'<[^>]*>', '', text)

# Clean every entry of the 'Message' column with remove_html_tags
# (non-string entries come back as ''), then preview the result.
forum_messages_df['Message'] = forum_messages_df['Message'].apply(remove_html_tags)

forum_messages_df.head(n=5)

In [None]:
# Stack the forum messages and the movie plots into one Series with a fresh 0..n-1 index.
original_texts = pd.concat(
    [forum_messages_df['Message'], wikipedia_movie_plots_df['Plot']],
    ignore_index=True,
)
# Independent deep copy of only the plot column, detached from the source DataFrame.
original_texts_wiki_plot = wikipedia_movie_plots_df['Plot'].copy(deep=True)
original_texts_wiki_plot

In [None]:
# Pool of rewrite instructions. Double-quoted literals keep the apostrophes
# readable without backslash escapes; the string values themselves are unchanged.
rewrite_prompts = [
    "Explain this to me like I'm five.",
    "Convert this into a sea shanty.",
    "Make this rhyme.",
    "Make this shorter.",
    "Make this longer.",
    "Make this more detailed.",
    "Rewrite this essay but do it using the writing style of Dr. Seuss",
    "Rewrite this essay but do it using the writing style of William Shakespeare",
    "Rewrite this essay but do it using the writing style of Tupac Shakur",
    "Make this a haiku.",
    "Make this into a poem.",
    "Turn this into a sonnet.",
    "Summarize this.",
    "Give me the highlights.",
    "Rewrite this essay using the writing style of Jane Austen.",
    "Rewrite this essay with the terse, straightforward prose and understated tone characteristic of Ernest Hemingway's works, focusing on clarity and the power of simple statements.",
    "Transform this piece to reflect Virginia Woolf's stream-of-consciousness style, focusing on the psychological depth and introspective nature of her characters.",
    "Rephrase this essay using Mark Twain's distinctive humor and satirical edge, capturing his unique perspective on American society and culture.",
    "Revise this text to mirror J.K. Rowling's engaging narrative style, blending magical elements with a touch of mystery and a strong sense of moral integrity.",
    "Recreate this content with George Orwell's clear, direct language and his propensity for exploring themes of social injustice and authoritarianism.",
    "Convert this text into a form that evokes Edgar Allan Poe's gothic style, focusing on the macabre, the mysterious, and the psychological depth of his narratives.",
]

In [None]:
!pip install -q -U immutabledict sentencepiece
!git clone https://github.com/google/gemma_pytorch.git

In [None]:
import sys
sys.path.append("gemma_pytorch")
from gemma_pytorch.gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM
from gemma_pytorch.gemma.tokenizer import Tokenizer
import contextlib
import os
import torch

# Checkpoint variant: selects the 2b/7b config, toggles the quant flag and
# forms part of the checkpoint file name.
VARIANT = "7b-it-quant"
# Device string handed to torch.device().
MACHINE_TYPE = "cuda"
# Directory expected to hold tokenizer.model and the gemma-<VARIANT>.ckpt file.
weights_dir = "../models/gemma-7b-it"

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  yield
  torch.set_default_dtype(torch.float)
    
# Choose the config factory from the variant name, then attach the tokenizer
# path and set the quant flag according to whether VARIANT contains "quant".
if "2b" in VARIANT:
  model_config = get_config_for_2b()
else:
  model_config = get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")
model_config.quant = "quant" in VARIANT

device = torch.device(MACHINE_TYPE)
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
# Construct the model and load its weights while the config's dtype is the
# torch default, then move it to the target device and switch to eval mode.
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  model.load_weights(ckpt_path)
  model = model.to(device).eval()

In [None]:
from tqdm import tqdm
import random
# Seed every RNG involved so a fresh run reproduces the same dataset:
# `random` drives the prompt choice below, and torch's generator is seeded
# too because model.generate presumably samples tokens through torch
# (NOTE(review): confirm against the gemma_pytorch sampling code).
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# Chat wrapper placed around each prompt before it is passed to model.generate.
USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

# One record per plot: the source text, the instruction used and the model output.
rewrite_data = []

for original_text in tqdm(original_texts_wiki_plot):
    # Pair each plot with one randomly chosen rewrite instruction.
    rewrite_prompt = random.choice(rewrite_prompts)
    prompt = f'{rewrite_prompt}\n{original_text}'
    rewritten_text = model.generate(
        USER_CHAT_TEMPLATE.format(prompt=prompt),
        device=device,
        output_len=100,
    )
    rewrite_data.append({
        'original_text': original_text,
        'rewrite_prompt': rewrite_prompt,
        'rewritten_text': rewritten_text,
    })

In [None]:
# Turn the list of per-plot record dicts into a DataFrame (one row per record).
rewrite_data_df = pd.DataFrame(data=rewrite_data)

In [None]:
# Write the dataset to a CSV in the working directory, omitting the row index.
rewrite_data_df.to_csv(path_or_buf='prompts_and_rewrites.csv', index=False)