In [None]:
import numpy as np
import pandas as pd

from utils.text_tensor import load_text_dataset_from_json

import zstandard as zstd
import json

import io
import re

## Load All Datasets

In [None]:
# Prompt for the dataset file paths. The reply is split on whitespace, so an
# individual path cannot itself contain spaces.
dataset_paths = input('Enter dataset paths, space separated: ').split()

In [None]:
# Echo the parsed list of paths as the cell output.
dataset_paths

In [None]:
# Load every dataset file; one result object per path, in input order.
results = [load_text_dataset_from_json(path) for path in dataset_paths]

In [None]:
# Flatten the per-file results into two flat lists: element 0 of each result
# feeds `data`, element 1 feeds `labels`.
data = [item for result in results for item in result[0]]
labels = [item for result in results for item in result[1]]

In [None]:
# Convert both Python lists to NumPy arrays.
data, labels = np.array(data), np.array(labels)

In [None]:
# Sanity check: print the shapes of both arrays.
print(data.shape, labels.shape)

In [None]:
# Wrap `data` in a DataFrame with a single column named 'Data'.
dataset = pd.DataFrame(data, columns=['Data'])

In [None]:
# Attach `labels` as a second column named 'Labels'.
dataset['Labels'] = labels

In [None]:
# Show the distinct label values and how many rows carry each one.
np.unique(dataset['Labels'], return_counts=True)

## Remove Unnecessary characters and symbols

Removing the leading '>' character, which marks quoted text (replies) in Reddit comments. GPT would not generate this symbol.

In [None]:
# Strip a '>' only when it is the very first character of the text: the `^`
# anchor matches at the start of the string, so a '>' further in (including at
# the start of a later line) is left untouched.
dataset['Data'] = dataset['Data'].str.replace('^>', '', regex=True)

Removing bot comments

In [None]:
# Keep only rows whose text does not contain the literal phrase "I am a bot"
# (case-sensitive substring match).
dataset = dataset.loc[~dataset['Data'].str.contains("I am a bot", regex=False)]

Removing profanity

In [None]:
from better_profanity import profanity


def contains_explicit(text):
    """Return the result of ``profanity.contains_profanity`` for ``text``."""
    return profanity.contains_profanity(text)


# Select the rows whose text is flagged by contains_explicit.
# NOTE: despite its name, `filtered_data` holds the flagged rows, not the kept ones.
filtered_data = dataset.loc[dataset['Data'].apply(contains_explicit)]

# Keep a CSV record of the flagged rows.
filtered_data.to_csv('explicit_content.csv', index=False)

# All remaining rows go into `clean_df` and its own CSV file.
# `dataset` itself is not modified by this cell.
clean_df = dataset.drop(index=filtered_data.index)
clean_df.to_csv('cleaned_content.csv', index=False)

In [None]:
# Compare the shape of the frame without flagged rows against the full dataset.
print(clean_df.shape, dataset.shape)

In [None]:
# Compare the shape of the flagged rows against the full dataset.
print(filtered_data.shape, dataset.shape)

## Store Dataset

In [None]:
# Store the DataFrame as CSV (index=False leaves out the row index column).
# NOTE(review): this writes `dataset`, not the `clean_df` frame produced by the
# profanity step - confirm which of the two is meant to be stored.
dataset.to_csv('datasets/reddit_datasets/gpt_reddit_dataset.csv', index=False)

In [None]:
# Serialize the DataFrame to CSV into an in-memory byte buffer
df_bytes = io.BytesIO()
dataset.to_csv(df_bytes, index=False)

# Compress the buffered CSV bytes with Zstandard (compressor created with its defaults)
cmpr = zstd.ZstdCompressor()
compressed_bytes = cmpr.compress(df_bytes.getvalue())

In [None]:
# Write the Zstandard-compressed CSV bytes to disk.
with open('datasets/reddit_datasets/gpt_reddit_dataset.zst', 'wb') as zst_file:
    zst_file.write(compressed_bytes)

## Read the data

In [None]:
# Round trip: decompress the in-memory `compressed_bytes` (not the file on
# disk) back into raw CSV bytes.
dctx = zstd.ZstdDecompressor()
decompressed_bytes = dctx.decompress(compressed_bytes)

# Parse the CSV bytes again; `dataset` is rebound to the re-read frame.
# NOTE(review): pandas' default NA parsing may turn empty or NA-like text cells
# into NaN here - confirm that downstream cells tolerate missing values.
with io.BytesIO(decompressed_bytes) as csv_buffer:
    dataset = pd.read_csv(csv_buffer)

# Optional visual check of the re-read frame.
print(dataset)


## Store dataset without special characters

In [None]:
def clean_text(text):
    """Strip links and unwanted symbols from a piece of text.

    Steps, in order:
      1. remove URL-like tokens starting with ``http://``, ``https://`` or ``www.``;
      2. remove every character that is not a word character, whitespace, or
         one of ``$ . , ! ? " '``;
      3. collapse each whitespace run to a single space and trim both ends.

    Args:
        text: The text to clean. Missing values (``None`` / NaN, e.g. an empty
            CSV cell parsed by pandas) yield ``''``; any other non-string value
            is converted with ``str()`` first.

    Returns:
        The cleaned string.
    """
    if not isinstance(text, str):
        # re.sub raises TypeError on non-strings; a DataFrame column read back
        # from CSV can contain NaN for empty cells, so handle that explicitly.
        if pd.isna(text):
            return ''
        text = str(text)

    # Replace any URL-like patterns with an empty string
    text = re.sub(r'\b(?:https?://|www\.)\S+\b', '', text)
    # Drop everything except word characters, whitespace and basic punctuation
    text = re.sub(r'[^\w\s$.,!?"\']', '', text)
    # Collapse whitespace runs and trim the ends
    text = re.sub(r'\s+', ' ', text).strip()
    return text

# Build a new frame with the 'Data' column passed through clean_text and the
# 'Labels' column carried over unchanged.
cleaned_dataset = dataset[['Data', 'Labels']].assign(
    Data=dataset['Data'].apply(clean_text)
)

In [None]:
# Serialize the cleaned DataFrame to CSV into an in-memory byte buffer
df_bytes = io.BytesIO()
cleaned_dataset.to_csv(df_bytes, index=False)

# Compress the buffered CSV bytes with Zstandard; this overwrites the earlier
# `compressed_bytes` value.
# NOTE(review): no later cell in this notebook writes these bytes to a file - confirm intent.
cmpr = zstd.ZstdCompressor()
compressed_bytes = cmpr.compress(df_bytes.getvalue())

In [None]:
# Store the cleaned dataset as CSV (index=False leaves out the row index column).
cleaned_dataset.to_csv('datasets/reddit_datasets/gpt_reddit_dataset_cleaned.csv', index=False)

In [None]:
# Dump the whole cleaned dataset as a Markdown table into a text file.
# The explicit UTF-8 encoding keeps the write from failing on non-ASCII text
# when the platform's default encoding is narrower (e.g. cp1252 on Windows).
with open('dataset_output.txt', 'w', encoding='utf-8') as f:
    f.write(cleaned_dataset.to_markdown())

## Filter existing dataset

Filtering out further one-off outliers from the text data. The goal is to standardize the text format as much as possible.

In [None]:
# Load reddit_filtered_data.csv as the starting point for this section.
data = pd.read_csv("datasets/reddit_datasets/reddit_filtered_data.csv")

# Silence pandas' chained-assignment warning (the default setting is 'warn').
pd.set_option('mode.chained_assignment', None)

# Disabled step, kept for reference: strip one pair of enclosing double quotes.
# data['TextOnly'] = data['Data'].apply(lambda x: x[1:-1] if x.startswith('"') and x.endswith('"') else x)


In [None]:
# Flatten Markdown links "[text](link)" to just their text, then delete every
# remaining token that starts with "http" (up to the next whitespace).
pattern = r'\[([^]]*)\]\([^)]*\)'

data['Data'] = data['Data'].apply(
    lambda text: re.sub(r'http\S+', '', re.sub(pattern, r'\1', text))
)

In [None]:
# Collapse line breaks into single spaces.
# Deleting '\n' outright glued the last word of one line to the first word of
# the next ("end.\nNext" -> "end.Next"). Replacing each whitespace run that
# contains a newline with one space keeps the words separate.
data['Data'] = data['Data'].str.replace(r'\s*\n\s*', ' ', regex=True)

In [None]:
# Strip the Reddit user-mention prefix "u/" (u/name -> name).
# The leading word boundary stops the pattern from also removing "u/" inside
# ordinary text such as "you/me", which a plain substring replace mangles.
data['Data'] = data['Data'].str.replace(r'\bu/', '', regex=True)

In [None]:
# Delete every literal "> " (the Reddit quote marker followed by a space),
# wherever it occurs in the text. regex=False states the literal match
# explicitly instead of relying on the pandas default.
data['Data'] = data['Data'].str.replace('> ', '', regex=False)

In [None]:
# Strip Markdown emphasis markers in two passes:
#   1. drop each pair of asterisks, keeping whatever text sits between them
#      (non-greedy, so "**" on its own is consumed as an empty pair);
#   2. squeeze any run of two or more asterisks that is still left into one "*".
data['Data'] = (
    data['Data']
    .apply(lambda text: re.sub(r'\*(.*?)\*', r'\1', text))
    .apply(lambda text: re.sub(r'\*{2,}', '*', text))
)

In [None]:
# Drop short texts: keep only rows whose text has at least MIN_TEXT_LENGTH
# characters. (The earlier comment said 50, but the threshold applied by the
# code is 100.)
MIN_TEXT_LENGTH = 100
# .str.len() gives NaN for missing values, so such rows fail the comparison and
# are dropped instead of raising TypeError the way len() would.
data = data[data['Data'].str.len() >= MIN_TEXT_LENGTH]

In [None]:
# Store the result of this filtering section as CSV (index=False leaves out the row index column).
data.to_csv('reddit_more_filtered_dataset.csv', index=False)