In [1]:
import pandas as pd
import string
import os
import tqdm

tqdm.tqdm.pandas()

In [2]:
# load raw data files

sample_dict = {}

for file in sorted(os.listdir("../data/raw")):
    if file.endswith(".csv"):
        sample_dict[file.replace(".csv","")] = pd.read_csv(f"../data/raw/{file}")
        print(f"Loaded {file}: {sample_dict[file.replace('.csv','')].shape[0]} rows")

Loaded hhonline.csv: 23144 rows
Loaded lmsys.csv: 1000000 rows
Loaded prism.csv: 8011 rows
Loaded sharegpt.csv: 90665 rows
Loaded wildchat.csv: 652148 rows


  sample_dict[file.replace(".csv","")] = pd.read_csv(f"../data/raw/{file}")


In [3]:
# restrict to English language only, where information is given

for dataset in sample_dict:
    if "language" in sample_dict[dataset].columns:
        sample_dict[dataset] = sample_dict[dataset][sample_dict[dataset]["language"] == "English"]
        print(f"Restricted {dataset} to English language only: {sample_dict[dataset].shape[0]} rows remain")

Restricted lmsys to English language only: 777453 rows remain
Restricted wildchat to English language only: 360136 rows remain


In [4]:
# for lmsys only, remove redacted prompts

sample_dict["lmsys"] = sample_dict["lmsys"][sample_dict["lmsys"]["redacted"] == False]
print(f"Removed redacted prompts from lmsys: {sample_dict['lmsys'].shape[0]} rows remain")

Removed redacted prompts from lmsys: 511193 rows remain


In [5]:
# print descriptive stats on user_prompt length

for dataset in sample_dict:
    sample_dict[dataset]["user_prompt_length"] = sample_dict[dataset]["user_prompt"].str.len()
    print(f"{dataset.upper()} user_prompt_length:")
    display(sample_dict[dataset]['user_prompt_length'].describe(percentiles=[0.01,0.05,0.1,0.5,0.9,0.95,0.99]))
    print()

HHONLINE user_prompt_length:


count    23144.000000
mean        87.586286
std        151.693693
min          4.000000
1%          18.000000
5%          26.000000
10%         31.000000
50%         58.000000
90%        151.000000
95%        220.000000
99%        572.280000
max       5510.000000
Name: user_prompt_length, dtype: float64


LMSYS user_prompt_length:


count    511192.000000
mean        265.389441
std         430.021297
min           1.000000
1%            5.000000
5%           12.000000
10%          21.000000
50%         109.000000
90%         664.000000
95%        1293.000000
99%        2409.000000
max        2560.000000
Name: user_prompt_length, dtype: float64


PRISM user_prompt_length:


count    8011.000000
mean       65.732243
std        59.164021
min         2.000000
1%          5.000000
5%         17.000000
10%        23.000000
50%        50.000000
90%       121.000000
95%       167.000000
99%       295.000000
max      1195.000000
Name: user_prompt_length, dtype: float64


SHAREGPT user_prompt_length:


count     90665.000000
mean        680.437368
std        3585.515032
min           1.000000
1%            3.000000
5%           13.000000
10%          24.000000
50%         118.000000
90%        1379.600000
95%        2838.800000
99%        9652.600000
max      382782.000000
Name: user_prompt_length, dtype: float64


WILDCHAT user_prompt_length:


count    359141.000000
mean       1601.274032
std        2866.879993
min           1.000000
1%            4.000000
5%           27.000000
10%          42.000000
50%         410.000000
90%        3932.000000
95%        6013.000000
99%       13824.000000
max       99874.000000
Name: user_prompt_length, dtype: float64




In [6]:
# restrict user_prompt_length

MIN_SIZE = 10
MAX_SIZE = 1000

for dataset in sample_dict:
    sample_dict[dataset] = sample_dict[dataset][(sample_dict[dataset]["user_prompt_length"] > MIN_SIZE) & (sample_dict[dataset]["user_prompt_length"] < MAX_SIZE)]
    print(f"Restricted {dataset} to user_prompt_length > {MIN_SIZE} and < {MAX_SIZE}: {sample_dict[dataset].shape[0]} rows remain")

Restricted hhonline to user_prompt_length > 10 and < 1000: 23038 rows remain
Restricted lmsys to user_prompt_length > 10 and < 1000: 454484 rows remain
Restricted prism to user_prompt_length > 10 and < 1000: 7845 rows remain
Restricted sharegpt to user_prompt_length > 10 and < 1000: 75823 rows remain
Restricted wildchat to user_prompt_length > 10 and < 1000: 228650 rows remain


In [7]:
# drop prompts that mention specific words

def drop_prompts(df_dict, drop_word):
    
    for dataset in df_dict:

        print(f"  Dropping { df_dict[dataset][df_dict[dataset]['user_prompt'].str.lower().str.contains(drop_word)].shape[0]} rows from {dataset}")
        df_dict[dataset] = df_dict[dataset][~df_dict[dataset]["user_prompt"].str.lower().str.contains(drop_word)]
    
    return df_dict

# drop prompts that mention programming languages and terms
for drop_word in ["python", "javascript", "sql", "ruby", "matplotlib", "dataframe", "http", "="]:
    print(f"Dropping prompts that mention '{drop_word}'")
    sample_dict = drop_prompts(sample_dict, drop_word)
print()

# drop prompts that use jailbreak phrases
for drop_word in ["say something toxic", "do anything now"]:
    print(f"Dropping prompts that mention '{drop_word}'")
    sample_dict = drop_prompts(sample_dict, drop_word)
print()

# weird quirks in the LMSYS data
for drop_word in ["give me an introduction over 200 words", "chemical industry", "hydrometry", "\[your answer\]"]:
    print(f"Dropping prompts that mention '{drop_word}'")
    sample_dict = drop_prompts(sample_dict, drop_word)
print()

# weird quirks in the ShareGPT data
for drop_word in ["you are chatgpt", "with bing"]:
    print(f"Dropping prompts that mention '{drop_word}'")
    sample_dict = drop_prompts(sample_dict, drop_word)
print()

# weird quirks in the WildChat
for drop_word in ["give me a response to"]:
    print(f"Dropping prompts that mention '{drop_word}'")
    sample_dict = drop_prompts(sample_dict, drop_word)

Dropping prompts that mention 'python'
  Dropping 17 rows from hhonline
  Dropping 34343 rows from lmsys
  Dropping 11 rows from prism
  Dropping 2374 rows from sharegpt
  Dropping 2647 rows from wildchat
Dropping prompts that mention 'javascript'
  Dropping 4 rows from hhonline
  Dropping 1937 rows from lmsys
  Dropping 0 rows from prism
  Dropping 715 rows from sharegpt
  Dropping 512 rows from wildchat
Dropping prompts that mention 'sql'
  Dropping 0 rows from hhonline
  Dropping 2846 rows from lmsys
  Dropping 0 rows from prism
  Dropping 712 rows from sharegpt
  Dropping 1059 rows from wildchat
Dropping prompts that mention 'ruby'
  Dropping 0 rows from hhonline
  Dropping 175 rows from lmsys
  Dropping 1 rows from prism
  Dropping 60 rows from sharegpt
  Dropping 407 rows from wildchat
Dropping prompts that mention 'matplotlib'
  Dropping 0 rows from hhonline
  Dropping 74 rows from lmsys
  Dropping 0 rows from prism
  Dropping 24 rows from sharegpt
  Dropping 37 rows from wildch

In [8]:
# clean prompts for purposes of deduplication

def clean_prompt(prompt):
    
    # lowercase
    prompt = prompt.lower()

    # strip all punctuation using translate
    prompt = prompt.translate(str.maketrans("", "", string.punctuation))


    # remove non-ASCII characters
    prompt = prompt.replace("\n", " ")
    prompt = prompt.replace("\r", " ")
    prompt = prompt.replace("\t", " ")
    prompt = prompt.replace("  ", " ")
    prompt = prompt.strip()

    return prompt

for dataset in sample_dict:
    sample_dict[dataset]["user_prompt_clean"] = sample_dict[dataset]["user_prompt"].apply(clean_prompt)

In [9]:
# write function to deduplicate prompt dataframe, which writes number of duplicates of each prompt to a new column
def deduplicate_prompts(df):
    
    # count duplicates
    df["n_duplicates"] = df.groupby("user_prompt_clean")["user_prompt_clean"].transform("count")
    
    # drop duplicates
    df = df.drop_duplicates(subset="user_prompt_clean")

    return df

# deduplicate prompts
for dataset in sample_dict:
    sample_dict[dataset] = deduplicate_prompts(sample_dict[dataset])
    print(f"Deduplicated {dataset} user_prompt: {sample_dict[dataset].shape[0]} rows remain")

Deduplicated hhonline user_prompt: 9016 rows remain
Deduplicated lmsys user_prompt: 192933 rows remain
Deduplicated prism user_prompt: 7624 rows remain
Deduplicated sharegpt user_prompt: 52304 rows remain
Deduplicated wildchat user_prompt: 177743 rows remain


In [10]:
# additonal language filtering with GlotLID

import fasttext
from huggingface_hub import hf_hub_download

glotlid_model = fasttext.load_model(hf_hub_download(repo_id="cis-lmu/glotlid", filename="model.bin"))

for dataset in sample_dict:
    print(f"Detecting language for {dataset}...")
    sample_dict[dataset]["detected_language"] = sample_dict[dataset]["user_prompt_clean"].progress_apply(glotlid_model.predict)


def in_english(prediction):
    if prediction[0][0] == "__label__eng_Latn":
        return True
    else:
        return False
    
for dataset in sample_dict:
    sample_dict[dataset] = sample_dict[dataset][sample_dict[dataset]["detected_language"].apply(in_english)]
    print(f"Restricted {dataset} to English language only (using GlotLID): {sample_dict[dataset].shape[0]} rows remain")



Detecting language for hhonline...


100%|██████████| 9016/9016 [00:04<00:00, 2134.66it/s]


Detecting language for lmsys...


100%|██████████| 192933/192933 [01:32<00:00, 2083.50it/s]


Detecting language for prism...


100%|██████████| 7624/7624 [00:03<00:00, 2199.30it/s]


Detecting language for sharegpt...


100%|██████████| 52304/52304 [00:26<00:00, 2011.58it/s]


Detecting language for wildchat...


100%|██████████| 177743/177743 [01:28<00:00, 2005.18it/s]


Restricted hhonline to English language only (using GlotLID): 8839 rows remain
Restricted lmsys to English language only (using GlotLID): 184600 rows remain
Restricted prism to English language only (using GlotLID): 7393 rows remain
Restricted sharegpt to English language only (using GlotLID): 36667 rows remain
Restricted wildchat to English language only (using GlotLID): 170911 rows remain


In [11]:
# show most common prompts in each dataset

for dataset in sample_dict:
    print("#"*50)
    print(f"Most common prompts in {dataset}:")
    print("#"*50, "\n")

    for _, row in sample_dict[dataset].sort_values("n_duplicates", ascending=False).head(5).iterrows():
        print(f"### (n={row['n_duplicates']})")
        print(row['user_prompt'].replace("\n", " "))
        print()

##################################################
Most common prompts in hhonline:
################################################## 

### (n=33)
I have a song stuck in my head. but I can't think of the title of the song. I know it a song from the 90s but just can't figure it out. Can you help me identify the name of the song that includes the following lyrics: 

### (n=24)
I have a song stuck in my head. but I can't think of the title of the song. I know it a song from the 2000s but just can't figure it out. Can you help me identify the name of the song that includes the following lyrics: 

### (n=19)
Please summarize the article below into three concise sentences:

### (n=13)
Can you tell me the difference between a solar system, a galaxy, and the universe?

### (n=13)
Who are you?

##################################################
Most common prompts in lmsys:
################################################## 

### (n=8934)
You are the text completion model and you must complete

In [12]:
# export whole dataframes, plus samples for annotation

N_SAMPLES = 100

for key in sample_dict.keys():
    sample_dict[key][["id", "user_prompt", "n_duplicates"]].to_csv(f"../data/clean/{key}.csv", index=False)
    sample_dict[key][["id", "user_prompt"]].sample(N_SAMPLES, random_state=42).to_csv(f"../data/samples/{key}.csv", index=False)

In [13]:
# create annotation file

out_dict = dict()
for file in os.listdir("../data/samples"):
    if file.endswith(".csv") and file != "all_samples.csv":
        out_dict[file.replace(".csv", "")] = pd.read_csv(f"../data/samples/{file}")

# concat all samples into one dataframe
all_samples = pd.concat(out_dict.values(), ignore_index=True)

# shuffle
all_samples = all_samples.sample(frac=1, random_state=42)

# export
all_samples.to_csv("../data/samples/all_samples.csv", index=False)


In [14]:
# create combined clean file

out_dict = dict()
for file in os.listdir("../data/clean"):
    if file.endswith(".csv") and "all_clean" not in file:
        out_dict[file.replace(".csv", "")] = pd.read_csv(f"../data/clean/{file}")

# concat all samples into one dataframe
all_clean = pd.concat(out_dict.values(), ignore_index=True)

# shuffle
all_clean = all_clean.sample(frac=1, random_state=42)

# deduplicate
print("all_clean before deduplication:", all_clean.shape[0], "rows")
all_clean["user_prompt_clean"] = all_clean["user_prompt"].apply(clean_prompt)
all_clean = all_clean.drop_duplicates(subset="user_prompt_clean")
print(f"Deduplicated all_clean user_prompt: {all_clean.shape[0]} rows remain")

# export
all_clean[["id", "user_prompt"]].to_csv("../data/clean/all_clean.csv", index=False)
all_clean.to_csv("../data/clean/all_clean_full.csv", index=False)


all_clean before deduplication: 408410 rows
Deduplicated all_clean user_prompt: 406884 rows remain
