# Creating datasets

In [None]:
import os
import json
import wandb
import random
import pandas as pd

from dotenv import load_dotenv
from tqdm.autonotebook import tqdm

# Read a local .env file (if there is one) into the process environment.
# NOTE(review): presumably this is where the W&B credentials used by
# wandb.init() below come from — confirm.
load_dotenv()

In [None]:
# Open the single W&B run that every dataset artifact in this notebook is
# logged to.
run = wandb.init(
    entity="jhu-llm-prompt-recovery",
    project="llm-prompt-recovery",
    job_type="upload-dataset",
)

## Stories (Writing Prompts) Dataset

In [None]:
# Name under which the stories artifact is logged to W&B.
DATASET_NAME = "stories-dataset"
# Raw input file; it is read line by line and each line is treated as one text.
RAW_FP = "data/writingPrompts/train.wp_target"
# JSON file the filtered stories are written to before being uploaded.
PROCESSED_FP = "outputs/datasetStories.json"

In [None]:
# Load the whole raw file into memory, one list entry per line
# (each entry keeps its trailing newline).
with open(RAW_FP, mode="r", encoding='utf-8') as raw_file:
    raw = raw_file.readlines()

In [None]:
# Keep the first 10000 lines whose word count is strictly between 500 and
# 1000 (whitespace-separated words); stop scanning once the quota is reached.
filtered = []
for candidate in raw:
    if not 500 < len(candidate.split()) < 1000:
        continue
    filtered.append(candidate)
    if len(filtered) == 10000:
        break

In [None]:
# Serialise the filtered texts as a JSON list of {"id", "text"} records,
# where "id" is the position of the text in `filtered`.
# ensure_ascii=False writes non-ASCII characters verbatim, so the file is
# opened as UTF-8 explicitly instead of relying on the platform's default
# encoding (which can raise UnicodeEncodeError, e.g. cp1252 on Windows).
records = [{"id": idx, "text": text} for idx, text in enumerate(filtered)]
with open(PROCESSED_FP, "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=4)

In [None]:

# Wrap the processed stories file in a W&B artifact and log it to the run.
artifact = wandb.Artifact(
    name=DATASET_NAME,
    type="original-text-dataset",
)
artifact.add_file(local_path=PROCESSED_FP)
run.log_artifact(artifact)

## News Articles (CNN-DailyMail) Dataset

In [None]:
# Name under which the news-articles artifact is logged to W&B.
DATASET_NAME = "news-articles-dataset"
# Raw parquet file; its "article" column is used as the text source below.
RAW_FP = 'data/cnn_dailymail/3.0.0/validation-00000-of-00001.parquet'
# JSON file the filtered articles are written to before being uploaded.
PROCESSED_FP = "outputs/datasetNewsArticles.json"

In [None]:
# Load the raw parquet file into a DataFrame using the pyarrow engine.
df = pd.read_parquet(RAW_FP, engine='pyarrow')

In [None]:
# Pull the "article" column out as a plain Python list, one entry per row.
raw = df['article'].to_list()

In [None]:
# Keep the first 10000 articles whose word count is strictly between 500 and
# 1000 (whitespace-separated words); stop scanning once the quota is reached.
filtered = []
for article in raw:
    word_count = len(article.split())
    if 500 < word_count < 1000:
        filtered.append(article)
        if len(filtered) == 10000:
            break

In [None]:
# Drop every literal "(CNN)" tag from the kept articles. Only the tag itself
# is removed; the whitespace around it is left as it was.
cnn_tag = '(CNN)'
filtered = [article.replace(cnn_tag, '') for article in filtered]

In [None]:
# Serialise the filtered articles as a JSON list of {"id", "text"} records,
# where "id" is the position of the article in `filtered`.
# ensure_ascii=False writes non-ASCII characters verbatim, so the file is
# opened as UTF-8 explicitly instead of relying on the platform's default
# encoding (which can raise UnicodeEncodeError, e.g. cp1252 on Windows).
records = [{"id": idx, "text": text} for idx, text in enumerate(filtered)]
with open(PROCESSED_FP, "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=4)

In [None]:
# Wrap the processed news-articles file in a W&B artifact and log it to the
# run.
artifact = wandb.Artifact(
    name=DATASET_NAME,
    type="original-text-dataset",
)
artifact.add_file(local_path=PROCESSED_FP)
run.log_artifact(artifact)

## Enron Email Dataset

In [None]:
# Name under which the email artifact is logged to W&B.
DATASET_NAME = "email-dataset"
# Root of the raw mail directory tree that is walked below.
RAW_DIR = "data/maildir"
# JSON file the extracted email bodies are written to before being uploaded.
# This used to point at "outputs/datasetStories.json" (copy-paste slip), which
# made the email dump overwrite the stories dataset on disk.
PROCESSED_FP = "outputs/datasetEmails.json"

In [None]:
# Walk the whole mail directory tree and collect the path of every file that
# lives in a directory whose path contains "all_documents", whichever mailbox
# or subfolder it belongs to.
raw = []
for dirpath, _subdirs, filenames in os.walk(RAW_DIR):
    if "all_documents" not in dirpath:
        continue
    raw.extend(os.path.join(dirpath, filename) for filename in filenames)

In [None]:
# Shuffle the collected paths reproducibly.
# The order in which os.walk returns entries depends on the file system, so
# the list is sorted first; without that, the fixed seed would not produce
# the same permutation on every machine.
raw.sort()
random.seed(42)
random.shuffle(raw)

In [None]:
# NOTE: restricting bodies to a whitelist of punctuation marks
# (. , ' " ( ) ? ! [ ] @ -) was considered but is currently disabled.

# Substrings that mark a body as noisy; any body containing one is discarded.
spam_consecutive_chars = ["===", "---", "___", ">>>", "<<<", "[IMAGE]", "***", "> > >"]

# Substrings that mark the start of forwarded/quoted content; the body is cut
# at the first line containing one of them.
quoted_content_markers = ("Forwarded by", "Original Message")


def _extract_body(lines):
    """Return the whitespace-normalised body of one email, or None.

    The body is everything after the first line containing "X-FileName"
    (presumably the last header line of the mail files — confirm) up to, but
    not including, the first later line containing one of
    `quoted_content_markers`, or the end of the file if there is none.
    Line breaks and runs of whitespace are collapsed to single spaces.

    Args:
        lines: the lines of one email file, as returned by readlines().

    Returns:
        The cleaned body string, or None when no "X-FileName" line exists.
    """
    start = next(
        (idx for idx, line in enumerate(lines) if "X-FileName" in line),
        None,
    )
    if start is None:
        return None

    # Search by position rather than lines.index(line): index() returns the
    # first *equal* line anywhere in the file, which may precede the body.
    end = next(
        (
            idx
            for idx in range(start + 1, len(lines))
            if any(marker in lines[idx] for marker in quoted_content_markers)
        ),
        len(lines),
    )

    # str.split() with no argument splits on any whitespace (newlines
    # included), so this both flattens the lines and removes extra spaces.
    return " ".join(" ".join(lines[start + 1:end]).split())


bodies = []
for email_fp in tqdm(raw, desc="Processing emails", total=len(raw), unit="emails"):
    # Undecodable bytes are dropped (errors='ignore') rather than failing.
    with open(email_fp, "r", encoding='utf-8', errors='ignore') as f:
        lines = f.readlines()

    body = _extract_body(lines)
    if body is None:
        continue

    # Skip bodies that contain any of the spam/noise markers.
    if any(spam in body for spam in spam_consecutive_chars):
        continue

    # Keep only bodies of 500 to 1000 words (both bounds inclusive).
    word_count = len(body.split())
    if 500 <= word_count <= 1000:
        bodies.append(body)


In [None]:
# SPECIAL_MENTION_ID = 944
# bodies[SPECIAL_MENTION_ID]

In [None]:
# Serialise the extracted email bodies as a JSON list of {"id", "text"}
# records, where "id" is the position of the body in `bodies`.
# ensure_ascii=False writes non-ASCII characters verbatim, so the file is
# opened as UTF-8 explicitly instead of relying on the platform's default
# encoding (which can raise UnicodeEncodeError, e.g. cp1252 on Windows).
records = [{"id": idx, "text": text} for idx, text in enumerate(bodies)]
with open(PROCESSED_FP, "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=4)

In [None]:
# Wrap the processed email file in a W&B artifact and log it to the run.
artifact = wandb.Artifact(
    name=DATASET_NAME,
    type="original-text-dataset",
)
artifact.add_file(local_path=PROCESSED_FP)
run.log_artifact(artifact)

In [None]:
# Close the W&B run opened at the top of the notebook.
run.finish()
# NOTE(review): this second call looks redundant once run.finish() has
# already closed the run — confirm before removing.
wandb.finish()