In [1]:
#Notebook for assembling different training/val sets of crowdsourced and self-made rewrites from crowdsourced and self/ChatGPT-made prompts. 

import numpy as np 
import pandas as pd 
import os
import gc
import re
import time
import random
from tqdm.auto import tqdm

# Register Series/DataFrame.progress_apply (used later for long-running scoring).
tqdm.pandas()
# Compact tabular previews, but never truncate the (long) text cells themselves.
pd.options.display.max_rows = 30
pd.options.display.max_columns = 5
pd.options.display.max_colwidth = None

In [2]:
#Make train/test splits

from sklearn.model_selection import train_test_split
def clean_and_split_data(df, prefix, maxpromptlen, maxtextlen, maxsize=1000, test_size=0.2, val_frac=0.0, random_state=None):
    """Clean a prompt/rewrite table, subsample it and write train/val(/test) CSVs.

    Cleaning steps, in order:
      * drop rows with a missing ``rewrite_prompt`` or ``original_text``;
      * drop rows whose ``rewrite_prompt`` contains any non-ASCII character;
      * normalise prompts: collapse whitespace runs to one space, remove spaces
        before ``? . ! , ' : ;`` and strip trailing ``.``/``?``/``!`` runs;
      * keep rows with ``len(rewrite_prompt) <= maxpromptlen`` and
        ``len(original_text) <= maxtextlen``.
    The surviving rows are shuffled and truncated to at most ``maxsize`` rows.

    Args:
        df: DataFrame with at least ``rewrite_prompt`` and ``original_text`` columns.
        prefix: Output path prefix; files are ``<prefix>.csv`` or
            ``<prefix>_train.csv`` / ``<prefix>_val.csv`` (/ ``<prefix>_test.csv``).
        maxpromptlen: Inclusive maximum prompt length in characters.
        maxtextlen: Inclusive maximum original-text length in characters.
        maxsize: Maximum number of rows kept after shuffling.
        test_size: Fraction held out from training (passed to ``train_test_split``).
            ``0`` disables splitting and writes everything to ``<prefix>.csv``.
        val_frac: If > 0, the held-out rows are split again: this fraction becomes
            the validation set and the remainder the test set. Must be < 1.
        random_state: Optional seed for the initial shuffle (``None`` = unseeded).

    Returns:
        ``(n_train, n_val)``; with ``test_size == 0`` this is ``(n_rows, 0)``.

    Raises:
        ValueError: if ``val_frac >= 1``.
    """
    def _normalize_prompt(prompt):
        # Same three substitutions as before, applied innermost-first.
        prompt = re.sub(r'\s+', ' ', prompt.strip())
        prompt = re.sub(r'\s+([?.!,\':;])', r'\1', prompt)
        return re.sub(r'[\.\?\!]+\s*$', '', prompt)

    # A missing prompt would crash _normalize_prompt; a missing text could never
    # pass the length filter anyway.
    df = df.dropna(subset=['rewrite_prompt', 'original_text'])
    df = df[~df['rewrite_prompt'].str.contains('[^\x00-\x7F]+', na=False)]
    # assign() builds a new frame, so no SettingWithCopyWarning on the filtered slice.
    df = df.assign(rewrite_prompt=df['rewrite_prompt'].apply(_normalize_prompt))
    df = df[(df['rewrite_prompt'].str.len() <= maxpromptlen) & (df['original_text'].str.len() <= maxtextlen)]

    # Randomly reorder rows, then keep at most `maxsize` of them.
    df = df.sample(frac=1, random_state=random_state).reset_index(drop=True).head(maxsize)

    if test_size == 0.0:  # no held-out data
        df.to_csv(prefix + '.csv', index=False)
        return (len(df), 0)
    if val_frac > 0.0:  # carve separate val and test sets out of the held-out rows only
        if val_frac >= 1.0:
            raise ValueError(f"val_frac must be in (0, 1), got {val_frac}")
        df_train, df_holdout = train_test_split(df, test_size=test_size, random_state=42)
        df_val, df_test = train_test_split(df_holdout, test_size=1.0 - val_frac, random_state=42)
        df_test.to_csv(prefix + '_test.csv', index=False)
    else:
        df_train, df_val = train_test_split(df, test_size=test_size, random_state=42)
    df_train.to_csv(prefix + '_train.csv', index=False)
    df_val.to_csv(prefix + '_val.csv', index=False)
    return (len(df_train), len(df_val))

In [None]:
#Make crowdsourced dataset 1

df_nbv2 = pd.read_csv("/kaggle/input/gemma-rewrite-nbroad/nbroad-v2.csv")
print(df_nbv2.head(1))
print("Length of dataset:",len(df_nbv2))
maxpromptlen=80
maxtextlen=4000
# Normalise prompts the same way the split helper does, so the statistics below
# describe what will actually be kept: drop missing / non-ASCII prompts, collapse
# whitespace, remove spaces before punctuation, strip trailing .?!
# (.copy() gives an independent frame so the column assignment doesn't warn.)
df_nbv2 = df_nbv2.dropna(subset=['rewrite_prompt'])
df_nbv2 = df_nbv2[~df_nbv2['rewrite_prompt'].str.contains('[^\x00-\x7F]+', na=False)].copy()
df_nbv2['rewrite_prompt'] = df_nbv2['rewrite_prompt'].apply(lambda x: re.sub(r'[\.\?\!]+\s*$', '', re.sub(r'\s+([?.!,\':;])', r'\1', re.sub(r'\s+', ' ', x.strip()))))
arr = sorted(df_nbv2["rewrite_prompt"].unique())
print("Number of unique prompts:",len(arr))
arr = [x for x in arr if len(x)<=maxpromptlen]
print("Number of unique prompts with len<=maxpromptlen:",len(arr))
# Both statistics use the same limits that are passed to the split below.
text_ok = df_nbv2['original_text'].str.len() <= maxtextlen
prompt_ok = df_nbv2['rewrite_prompt'].str.len() <= maxpromptlen
print("Number of orignal texts with len<=maxtextlen:",len(df_nbv2[text_ok]))
print("Number of orignal texts with len<=maxtextlen and len<=80:",len(df_nbv2[text_ok & prompt_ok]))
(n_train,n_val) = clean_and_split_data(df_nbv2,"crowdsourced_dataset_1",maxpromptlen,maxtextlen,maxsize=1000)
print(n_train,n_val)

In [None]:
#Make crowdsourced dataset 2

df_gem70k = pd.read_csv("/kaggle/input/70k-prompt-rewrite-triples/70k_gemma_template_built.csv")
print(df_gem70k.head(1))
# Match the column names used by the other sources: generated_text -> rewritten_text,
# prompt -> rewrite_prompt; drop 'input' and add a positional id column.
df_gem70k = (df_gem70k
             .rename(columns={'generated_text':'rewritten_text','prompt':'rewrite_prompt'})
             .drop(columns=['input']))
df_gem70k['id'] = df_gem70k.index
print("Length of dataset:",len(df_gem70k))
print("Number of unique prompts:",len(df_gem70k['rewrite_prompt'].unique()))
maxpromptlen=80
maxtextlen=4000
# Statistics use the same limits that are passed to the splits below.
prompt_ok = df_gem70k['rewrite_prompt'].str.len() <= maxpromptlen
text_ok = df_gem70k['original_text'].str.len() <= maxtextlen
print("Number of unique prompts with len<=maxpromptlen:",len(df_gem70k.loc[prompt_ok, 'rewrite_prompt'].unique()))
print("Number of orignal texts with len<=maxtextlen:",len(df_gem70k[text_ok]))
print("Number of orignal texts with len<=maxtextlen and maxpromptlen<=80:",len(df_gem70k[text_ok & prompt_ok]))
# Small 1000-row set (default 80/20 split) and a 10000-row set with a 2% val split.
(n_train,n_val) = clean_and_split_data(df_gem70k,"crowdsourced_dataset_2",maxpromptlen,maxtextlen,maxsize=1000)
print(n_train,n_val)
(n_train,n_val) = clean_and_split_data(df_gem70k,"crowdsourced_dataset_2_big",maxpromptlen,maxtextlen,maxsize=10000,test_size=0.02)
print(n_train,n_val)

In [None]:
#Generate custom dataset: 1) Come up with some prompt prefixes

# Imperative verbs that open a rewrite instruction (order matters: indexed by random.sample later).
prefix_verbs = ('Adapt Change Convert Craft Express Frame Present Recast '
                'Recreate Reformulate Rephrase Rewrite Style Transform Write').split()
# Ways of referring to the text being rewritten.
prefix_nouns = [
    'it', 'this', 'this passage', 'this story', 'this text',
    'the passage', 'the story', 'the text',
    'the following passage', 'the following story', 'the following text',
    'the following',
]
# Connectives placed between the noun phrase and the target style.
prefix_preps = ['as', 'into', 'to be', 'to be like', 'to be more like']


In [None]:
#Generate custom dataset: 2a) Come up with some prompt rewrite styles (with ChatGPT's help)

# Target styles for the synthetic rewrite prompts (brainstormed with ChatGPT).
# Every entry is kept ASCII-only: the dataset cleaning step drops any row whose
# rewrite_prompt contains a non-ASCII character (e.g. a curly apostrophe), which
# would silently discard every prompt generated from that style.
suffix_vars = [
    'a late-night infomercial script',
    'a dramatic play',
    'a series of haikus',
    'a business proposal',
    'a sci fi setting',
    'a case study for a successful project',
    'a sitcom script scene',
    'a 1920s jazz song',
    'a Shakespearean sonnet',
    'a Ted Talk',
    "a cheery children's book",
    'a classic rock anthem',
    'a reality TV show plot',
    'a recipe',
    'a heartfelt eulogy',
    'a romance',
    'a poem',
    'a noir detective story',
    'an ancient prophecy',
    'an exchange between genie and its master',
    'an online dating profile',
    'an interview',  # the missing comma here used to fuse this with the next entry
    "an old sailor's sea shanty",
    'a philosophical debate',
    'a personal diary entry',
    'a political speech',
    'a wartime correspondence',
    'a gothic novel',
    'a travel guide description',
    'a superhero comic book',
    'a silent film screenplay',
    'a tech startup pitch',
    'a meditation guide',
    'a sports commentary',
    'a high fantasy tale',
    'a detective noir monologue',
    'an epic poem',
    'a steamy love letter',
    'a horror movie script',
    'a mockumentary scene',
    'a self-help book snippet',
    'a dystopian short story',
    'an opera libretto',
    'a slapstick comedy script',
    'a cyberpunk narrative',
    'a classical myth retelling',
    'a spy thriller',
    'an academic lecture',
    'a ceremonial speech',
    'a radio drama script',
    'a fast-paced thriller',
    'a historical biography',
    'a philosophical essay',
    'a cold war espionage novel',
    'a documentary script',
    'a medieval ballad',
    'an absurdist play',
    'a technical manual',
    'a beat poetry reading',
    'an architectural critique',
    'a culinary review',
    'a young adult dystopia',
    'a rags-to-riches story',
    'a motivational speech',
    'a minimalist short story',
    'a magical realism narrative',
    'a courtroom drama',
    'a pulpy adventure novel',
    'a satirical article',
    'a fashion magazine feature',
    'a tragic love story',
    "a survivalist's handbook",
    'a cryptic crossword puzzle',
    'a Victorian ghost story',
    'a celebrity interview',
    'a surrealist painting description',
    'an urban legend recount',
    'a public service announcement',
    'a speculative science essay',
    'a slapstick comedy routine',
    'a post-apocalyptic journal',
    'a cosmic horror story',
    'a puppet show script',
    'a sports play-by-play',
    'a caper story outline',
    'a confessional poem',
    'a folk tale retelling',
    'a limerick sequence',
    'a viral marketing campaign',
    'a board game rulebook',
    'a military strategy analysis',
    'a religious sermon',
    "an auctioneer's chant",
    'a wilderness survival guide',
    'a jazz improvisation description',
    'a graphic novel panel',
    'a fantasy epic prologue',
    'a nature documentary narration',
    'a video game storyline',
    'an art exhibition review',
    'a celebrity roast script',
    'a cyber security briefing',
]

In [None]:
#Generate custom dataset: 2b) Filter out too semantically similar prompt rewrite styles

#https://colab.research.google.com/drive/1GH8PW9-zAe4cXEZyOIE-T9uHXblIldAg?usp=sharing
!pip install -Uq transformers sentence_transformers faiss-gpu
from sentence_transformers import SentenceTransformer
import faiss
from tqdm.autonotebook import tqdm

def deduplicate_prompts(prompts: list, model: str, threshold: float):
    """Greedily drop prompts that are near-duplicates of an already-kept prompt.

    Each prompt is embedded with ``SentenceTransformer(model)``. The embeddings
    are L2-normalised so that inner product equals cosine similarity, and a FAISS
    flat inner-product index finds each prompt's nearest *other* prompt. Prompts
    are visited in order. A prompt is dropped only when that neighbour has cosine
    similarity >= ``threshold`` AND has already been kept; otherwise it is kept.

    Args:
        prompts: Strings to deduplicate.
        model: SentenceTransformer model name/path used for the embeddings.
        threshold: Cosine-similarity cut-off for treating two prompts as duplicates.

    Returns:
        Ascending list of indices into ``prompts`` to keep.
    """
    if len(prompts) == 0:
        return []
    sentence_model = SentenceTransformer(model)

    print("Converting text to embeddings...")
    embeddings = sentence_model.encode(prompts, show_progress_bar=True)
    normalized_embeddings = np.ascontiguousarray(
        embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True), dtype='float32')
    index = faiss.IndexFlatIP(normalized_embeddings.shape[1])
    index.add(normalized_embeddings)

    print("Filtering out near-duplicates...")
    # k=2: the query itself plus its closest other prompt.
    D, I = index.search(normalized_embeddings, k=2)

    kept = set()  # O(1) membership instead of scanning the result list
    to_keep = []
    for i in tqdm(range(len(embeddings)), desc="Filtering"):
        # With exact duplicates the similarity scores tie and the *other* prompt may
        # come back in column 0 with i itself in column 1; use whichever column isn't i.
        col = 1 if I[i, 0] == i else 0
        nearest_neighbor = int(I[i, col])
        if D[i, col] >= threshold and nearest_neighbor in kept:
            continue  # a near-identical prompt is already kept
        kept.add(i)
        to_keep.append(i)

    return to_keep

# Keep only stylistically distinct suffixes (cosine similarity < 0.9 under gte-large).
to_keep = deduplicate_prompts(suffix_vars, "thenlper/gte-large", 0.9)
kept_styles = np.array(suffix_vars)[to_keep]
df_unique = pd.DataFrame({'prompt': kept_styles})
print(len(df_unique), len(suffix_vars))
df_unique.to_csv("deduped_prompts.csv", index=False)

In [3]:
#Generate custom dataset: 2c) Compute prompt rewrite style probability (to filter for quality; probably overkill)

#load llama-13B
import numpy as np 
import pandas as pd 
import torch
import os
import gc
import random
import time
from kaggle_secrets import UserSecretsClient
# Read the Hugging Face access token from Kaggle's secret store; it is passed as
# `token=` to the from_pretrained calls below for meta-llama/Llama-2-13b-hf.
hf_access_token = UserSecretsClient().get_secret("HF_AUTH_TOKEN") 
!pip install -U bitsandbytes
!pip install -U transformers
!pip install -U accelerate
!pip install optimum
import optimum
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from torch import cuda, bfloat16
# Release anything earlier cells left behind before loading the 13B model.
gc.collect()
torch.cuda.empty_cache()

modelName = "meta-llama/Llama-2-13b-hf"
tokenizer = AutoTokenizer.from_pretrained(modelName, token=hf_access_token)

# 4-bit NF4 weights with double quantisation and bfloat16 compute; the whole
# model is placed on GPU 0 (the original run notes a T4).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    modelName,
    quantization_config=quantization_config,
    device_map={"": 0},
    token=hf_access_token,
)
# https://huggingface.co/docs/transformers/generation_strategies#default-text-generation-configuration
print(model.generation_config)
model.config.use_cache = False
model.config.pretraining_tp = 1

def compute_sent_prob(sent):
    """Return a length-normalised plausibility score for `sent` under the loaded LM.

    The sentence is tokenised with the global `tokenizer` and passed to the
    global `model` with itself as `labels`; the result is ``exp(-loss)``.
    Hugging Face causal-LM heads report `loss` as the mean next-token
    cross-entropy, so this is the geometric mean of the per-token probabilities
    (a value in (0, 1]). It is not the joint probability of the whole sentence,
    so long sentences are not penalised for their length.
    """
    # Keep the input ids on the same device as the model weights.
    input_ids = tokenizer.encode(sent, return_tensors='pt').to(model.device)
    # Inference only: don't record operations for backprop (saves memory and time).
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
    return torch.exp(-outputs.loss).item()
    
# Score every deduplicated style prompt, show the 10 least plausible, and save all scores.
df_unique = pd.read_csv("deduped_prompts.csv")
prompt_column = df_unique.iloc[:, 0]
df_unique['prompt_prob'] = prompt_column.progress_apply(compute_sent_prob)
least_plausible = df_unique.sort_values(by='prompt_prob', ascending=True).head(10)
print(least_plausible)
df_unique.to_csv("deduped_prompt_probs.csv", index=False)

Collecting bitsandbytes
  Downloading bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl.metadata (2.2 kB)
Downloading bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl (119.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.8/119.8 MB[0m [31m13.7 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.43.1
Collecting transformers
  Downloading transformers-4.40.1-py3-none-any.whl.metadata (137 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m138.0/138.0 kB[0m [31m6.3 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers<0.20,>=0.19 (from transformers)
  Downloading tokenizers-0.19.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.40.1-py3-none-any.whl (9.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.0/9.0 MB[0m [31m87.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDo

tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/610 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/33.4k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/9.95G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/9.90G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/6.18G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

GenerationConfig {
  "bos_token_id": 1,
  "do_sample": true,
  "eos_token_id": 2,
  "max_length": 4096,
  "pad_token_id": 0,
  "temperature": 0.6,
  "top_p": 0.9
}



In [None]:
#Generate custom dataset: 3) Combine prompt parts into custom prompts

# For every kept style: 2 random prepositions x 3 random nouns x 4 random verbs
# = 24 prompts of the form "<Verb> <noun> <prep> <style>". The nested
# comprehension draws the random samples in the same order as nested loops would.
kept_styles = np.array(suffix_vars)[to_keep]
rephrase_wording_prompts = [
    f"{prefix_verbs[v]} {prefix_nouns[n]} {prefix_preps[p]} {style}"
    for style in kept_styles
    for p in random.sample(range(len(prefix_preps)), 2)
    for n in random.sample(range(len(prefix_nouns)), 3)
    for v in random.sample(range(len(prefix_verbs)), 4)
]
print(len(rephrase_wording_prompts))
df_rephrase_wording_prompts = (
    pd.DataFrame(rephrase_wording_prompts, columns=['rewrite_prompt'])
    .sample(frac=1)
    .reset_index(drop=True)
)
print(df_rephrase_wording_prompts.head())
df_rephrase_wording_prompts.to_csv("custom_rewrite_prompts.csv", index=False)

In [None]:
#Generate custom dataset: 4) Assemble various sources for "original texts" to be rewritten

import sqlite3
# Load the English Wikibooks table; the connection is closed even if the query fails.
conn = sqlite3.connect('/kaggle/input/wikibooks-dataset/wikibooks.sqlite')
try:
    df_eng_text = pd.read_sql_query("SELECT * FROM en", conn)
finally:
    conn.close()
# Keep abstracts of at least 300 characters (missing abstracts are dropped, not crashed on).
s_abs = df_eng_text.loc[df_eng_text["abstract"].str.len() >= 300, "abstract"]
print(len(s_abs))
lengths = s_abs.str.len()
print(f"Wikibooks Abstracts Max length: {lengths.max()}, Min length: {lengths.min()}, Mean length: {lengths.mean()}")

# Emotion-labelled short texts; only those of 280+ characters are long enough to rewrite.
df_em = pd.read_csv("/kaggle/input/emotions/text.csv")
long_enough = df_em["text"].apply(lambda text: len(text) >= 280)
s_em = df_em[long_enough]["text"]
lengths = s_em.apply(len)
print(f"Emotions Length: {len(s_em)}, Max length: {lengths.max()}, Min length: {lengths.min()}, Mean length: {lengths.mean()}")

df_tds = pd.read_csv("/kaggle/input/1300-towards-datascience-medium-articles-dataset/medium.csv")
# Keep articles of at most 2000 characters (missing texts are dropped rather than crashing len()).
s_tds = df_tds.loc[df_tds["Text"].str.len() <= 2000, "Text"]
# Report statistics of the kept article texts, like the other sources do
# (not per-column row counts of the whole frame).
lengths = s_tds.str.len()
print(f"TDS Length: {len(s_tds)}, Max length: {lengths.max()}, Min length: {lengths.min()}, Mean length: {lengths.mean()}")

# Airline reviews between 300 and 1000 characters (inclusive).
df_rev = pd.read_csv("/kaggle/input/singapore-airlines-reviews/singapore_airlines_reviews.csv")
in_range = df_rev["text"].apply(lambda review: 300 <= len(review) <= 1000)
s_rev = df_rev[in_range]["text"]
lengths = s_rev.apply(len)
print(f"Plane Reviews Length: {len(s_rev)}, Max length: {lengths.max()}, Min length: {lengths.min()}, Mean length: {lengths.mean()}")

df_movies=pd.read_csv("/kaggle/input/wikipedia-movie-plots/wiki_movie_plots_deduped.csv")
print(df_movies.head(1))
print(len(df_movies))
plot_lengths = df_movies['Plot'].str.len()
print(plot_lengths.mean())
# Number of plots with more than 500 and at most 2000 characters.
is_mid_length = (plot_lengths <= 2000) & (plot_lengths > 500)
print(len(df_movies[is_mid_length]))
# Single 'original_text' column, matching the schema of the other datasets.
df_short = df_movies.loc[is_mid_length, ['Plot']].rename(columns={'Plot': 'original_text'})
print(df_short.head(1))
df_short.to_csv("movie_plots.csv", index=False)
s_plots = df_short['original_text']

# Pool every text source into one shuffled, single-column table.
mixed_texts = pd.concat([s_tds, s_plots, s_abs, s_em, s_rev], ignore_index=True)
mixed_texts = mixed_texts.sample(frac=1).reset_index(drop=True)  # randomly reorder rows
# pd.concat of Series returns a Series (unnamed here, since the inputs' names differ).
# Setting `.columns` on a Series renames nothing and to_csv would write the header "0",
# so build the one-column DataFrame explicitly.
df_mixed = mixed_texts.to_frame(name='original_text')
df_mixed.to_csv('mixed.csv', index=False)

In [None]:
#Generate custom dataset: 5) Send prompts and texts through gemma-7 to get rewrites, then create training and validation sets 

# Rewrites were generated with gemma-7b in Colab:
# https://colab.research.google.com/drive/1UoQeGoXulgO7daG7E83H_-UWZ3UHucks
maxpromptlen, maxtextlen = 80, 4000
df_cust = pd.read_csv("/kaggle/input/gemma7brewrites/gemma7b_custom_prompts_rewrites.csv")
n_train, n_val = clean_and_split_data(df_cust, "custom_dataset", maxpromptlen, maxtextlen, maxsize=1000)
print(n_train, n_val)

In [3]:
#Create mixed dataset to test diversity impact

# Roughly equal thirds from each source: 267+267+266 = 800 train rows, 67+67+66 = 200 val rows.
mix_sources = ("custom_dataset", "crowdsourced_dataset_1", "crowdsourced_dataset_2")
for split, counts in (("train", (267, 267, 266)), ("val", (67, 67, 66))):
    parts = [pd.read_csv(f"/kaggle/working/{source}_{split}.csv")[:count]
             for source, count in zip(mix_sources, counts)]
    pd.concat(parts, axis=0).to_csv(f"mixed_dataset_{split}.csv", index=False)

In [None]:
#Generate ood dataset for testing

maxpromptlen, maxtextlen = 80, 1000
ood_source = "/kaggle/input/data-from-starter/gemma1000_w7b.csv/gemma1000_w7b.csv"
df_ood = pd.read_csv(ood_source)
print(df_ood.head(1))
# Use the 'gemma_7b_rewritten_text_temp0' column as the rewrite and drop the alternatives.
df_ood = (
    df_ood
    .drop(columns=['rewritten_text', 'prompt', 'gemma_7b_rewritten_text_temp0_prefix_removed'])
    .rename(columns={'gemma_7b_rewritten_text_temp0': 'rewritten_text'})
)
# test_size=0: no split; everything is written to ood_dataset.csv.
n_train, n_val = clean_and_split_data(df_ood, "ood_dataset", maxpromptlen, maxtextlen, maxsize=1000, test_size=0)
print(n_train, n_val)

In [None]:
# Put together big set for training of final model

def _report_source(label, frame):
    """Print a source's columns and row count, then return it unchanged."""
    print(label, frame.columns)
    print(len(frame))
    return frame

df_cust = _report_source("Cust: ", pd.read_csv("/kaggle/input/gemma7brewrites/gemma7b_custom_prompts_rewrites.csv"))
df_gem70k = _report_source(
    "Gem70k: ",
    pd.read_csv("/kaggle/input/70k-prompt-rewrite-triples/70k_gemma_template_built.csv")
    .rename(columns={'generated_text': 'rewritten_text', 'prompt': 'rewrite_prompt'})
    .drop(columns=['input']),
)
df_gem1k = _report_source(
    "Gem1k: ",
    pd.read_csv("/kaggle/input/data-from-starter/gemma1000_w7b.csv/gemma1000_w7b.csv")
    .drop(columns=['rewritten_text', 'prompt', 'gemma_7b_rewritten_text_temp0_prefix_removed'])
    .rename(columns={'gemma_7b_rewritten_text_temp0': 'rewritten_text'}),
)
df_nbv2 = _report_source("NBV2: ", pd.read_csv("/kaggle/input/gemma-rewrite-nbroad/nbroad-v2.csv"))




In [None]:
#Screen for quality

#1) Get rid of unnatural/implausible prompts

# Drop multi-line prompts first, so no 13B forward passes are spent scoring rows
# that would be discarded anyway (the saved result is the same).
df_nbv2 = df_nbv2[df_nbv2['rewrite_prompt'].str.count('\n') == 0].copy()
df_nbv2['prompt_prob'] = df_nbv2['rewrite_prompt'].progress_apply(compute_sent_prob)
# Preview the 50 least plausible prompts. display() is needed because a bare
# expression that is not the last line of a cell is never shown.
display(df_nbv2.sort_values(by=['prompt_prob'], ascending=True).head(50)[['rewrite_prompt', 'prompt_prob']])
# Remove prompts with prob <= 0.01 (judged unnatural/implausible).
df_nbv2 = df_nbv2[df_nbv2['prompt_prob'] > 0.01]
df_nbv2.to_csv("nbv2_prob.csv", index=False)


#2) Get rid of prompts that led to too little change, or too much

!pip install -Uq transformers sentence_transformers 
from sentence_transformers import SentenceTransformer
from tqdm.autonotebook import tqdm

def scs(s: np.ndarray, k: np.ndarray, p: int = 3, q: float = 1e-6):
    """Sharpened cosine similarity between vectors ``s`` and ``k``.

    Computes ``sign(s.k) * |s.k / ((||s|| + q) * ||k||)| ** p``. The exponent
    ``p`` pushes moderate similarities toward 0 while keeping the sign. ``q`` is a
    small stabiliser added to the norm of ``s`` only.
    """
    dot = np.dot(s, k)
    scaled = dot / ((np.linalg.norm(s) + q) * np.linalg.norm(k))
    return np.sign(dot) * np.abs(scaled) ** p

def compute_dists(df: pd.DataFrame, model: str):
    """Score how similar each rewrite is to its original text.

    The ``original_text`` and ``rewritten_text`` columns of ``df`` are embedded
    with ``SentenceTransformer(model)``. Each row's pair of embeddings is then
    scored with ``scs`` (default exponent 3).

    Returns:
        List of per-row scores in the same order as the rows of ``df``.
    """
    encoder = SentenceTransformer(model)

    print("Converting text to embeddings...")
    originals = encoder.encode(df['original_text'].tolist(), show_progress_bar=True)
    rewrites = encoder.encode(df['rewritten_text'].tolist(), show_progress_bar=True)

    pairs = tqdm(zip(originals, rewrites), total=len(originals), desc="Computing distances")
    return [scs(np.array(orig_vec), np.array(rewrite_vec)) for orig_vec, rewrite_vec in pairs]

# Score plot rewrites by similarity to their originals, most similar first.
plot_rewrites_path = "/kaggle/working/gem7plotrewrites.csv"
df_plot_scored = pd.read_csv(plot_rewrites_path)
df_plot_scored['scs'] = compute_dists(df_plot_scored, "thenlper/gte-large")
df_plot_scored = df_plot_scored.sort_values(by='scs', ascending=False)
df_plot_scored.to_csv("gem7plotrewrites_scored.csv", index=False)

In [None]:
#choose prompts whose (filtered) similarity between original and rewritten isn't too high
# Rewrites with scs >= SCS_MAX changed the original too little to be useful.
SCS_MAX = 0.97
# Model refusals start with this text and are never useful training targets.
REFUSAL_PREFIX = "I am unable to "

def _drop_refusals(frame):
    """Return ``frame`` without rows whose rewritten_text is missing or a refusal."""
    rewritten = frame['rewritten_text']
    # na=False: a missing rewrite would otherwise make `~` raise on the object column.
    return frame[rewritten.notna() & ~rewritten.str.startswith(REFUSAL_PREFIX, na=False)]

df_mixed_scored = _drop_refusals(pd.read_csv("/kaggle/working/gem7mixedrewrites_scored.csv"))
df_plot_scored = _drop_refusals(pd.read_csv("/kaggle/working/gem7plotrewrites_scored.csv"))
df_mixed_scored = df_mixed_scored[df_mixed_scored['scs'] < SCS_MAX]
df_plot_scored = df_plot_scored[df_plot_scored['scs'] < SCS_MAX]
df_filtered_prompts = pd.concat(
    [df_mixed_scored['rewrite_prompt'], df_plot_scored['rewrite_prompt']], ignore_index=True
).drop_duplicates(keep='first')

# Combine the generated datasets and keep only rows whose prompt is in the filtered set.
df_mixed = pd.read_csv("/kaggle/input/gemma7brewrites/gemma7b_mixed_rewrites.csv")
print(len(df_mixed))
df_plots = pd.read_csv("/kaggle/input/gemma7brewrites/gemma7b_plot_rewrites.csv")
print(len(df_plots))
# FIXME: an earlier version also concatenated `df_mixed2`, which is never defined in
# this notebook (NameError on a fresh kernel). Load it explicitly here and add it
# to the list if that data is still wanted.
df_combined = pd.concat([df_mixed, df_plots], ignore_index=True)
print(len(df_combined))
df_combined = df_combined.drop_duplicates(keep='first')
print(len(df_combined))
df_combined = df_combined[df_combined['rewrite_prompt'].isin(df_filtered_prompts)]
print(len(df_combined))
df_combined = _drop_refusals(df_combined)
print(len(df_combined))
df_combined['rewritten_text'].nunique()