In [1]:
# prepare_dataset.py
import pandas as pd
import glob
import os
from sklearn.model_selection import train_test_split

In [2]:
# Configuration: where the synthetic .txt documents are read from, and where the
# labeled CSV files are written. The output folder is created up front if missing.
data_path, output_path = "../data/data_synthetic", "../data/data_processed"
os.makedirs(output_path, mode=0o777, exist_ok=True)

In [4]:
# Step 2: Collect and label all .txt files
# glob() returns paths in an arbitrary, filesystem-dependent order. Sorting fixes the
# row order of the DataFrame built later, which makes the seeded train/val/test split
# reproducible across machines.
files = sorted(glob.glob(os.path.join(data_path, "*.txt")))
# Accumulator for one record per file; records are appended in the following cell.
data = []


In [5]:
# Fail fast with an actionable message if the synthetic corpus has not been generated.
if not files:
    raise FileNotFoundError(f"No .txt files found in '{data_path}'. Run make_synthetic_data.py first.")

# Rebuild the record list from scratch so re-running this cell does not append a
# second copy of every file to records left over from a previous run.
data = []
for f in files:
    filename = os.path.basename(f)
    # Label is the filename prefix before the first "_", e.g. "invoice_12.txt" -> "invoice".
    # The extension is removed first so a name without "_" ("invoice.txt") still yields "invoice".
    stem = os.path.splitext(filename)[0]
    label = stem.split("_")[0].lower()
    with open(f, "r", encoding="utf-8") as file:
        text = file.read().strip()
    data.append({"filename": filename, "text": text, "label": label})

In [6]:
# One row per source file with columns filename / text / label.
df = pd.DataFrame(data)
# Show per-class counts so an imbalanced or missing class is visible before splitting.
print("📊 Label distribution:")
print(df["label"].value_counts())

📊 Label distribution:
label
contract     100
complaint    100
invoice      100
order        100
reminder     100
Name: count, dtype: int64


In [7]:
# Step 3: Persist the complete labeled corpus (all rows, no index column) as UTF-8 CSV.
df.to_csv(
    path_or_buf=os.path.join(output_path, "all_data.csv"),
    index=False,
    encoding="utf-8",
)


In [8]:
# Step 4: Split into train, validation, test sets (80 / 10 / 10)
# Two-stage stratified split: first carve off the held-out share, then divide it
# between validation and test. Both stages keep the label proportions of their input.
VAL_FRAC = 0.1
TEST_FRAC = 0.1
SEED = 42

holdout_frac = VAL_FRAC + TEST_FRAC  # 0.2 -> the remaining 80% stays in train
train_df, temp_df = train_test_split(
    df, test_size=holdout_frac, stratify=df["label"], random_state=SEED
)
# Fraction of the held-out rows that go to test (0.1 / 0.2 = 0.5 with the defaults).
val_df, test_df = train_test_split(
    temp_df, test_size=TEST_FRAC / holdout_frac, stratify=temp_df["label"], random_state=SEED
)

# Sanity checks: the three splits must cover every row exactly once.
assert len(train_df) + len(val_df) + len(test_df) == len(df)
assert train_df.index.append([val_df.index, test_df.index]).is_unique, "rows leaked between splits"


In [9]:

# Write each split to its own UTF-8 CSV (no index column) in the output folder:
# train.csv, val.csv and test.csv, in that order.
for split_name, split_df in (("train", train_df), ("val", val_df), ("test", test_df)):
    split_df.to_csv(os.path.join(output_path, f"{split_name}.csv"), index=False, encoding="utf-8")


In [10]:
# Final run summary: split sizes and the folder the CSVs were written to.
print("\n✅ Data preparation complete!")
print(f"Train set: {len(train_df)} | Validation set: {len(val_df)} | Test set: {len(test_df)}")
print(f"Labeled CSVs saved in '{output_path}/'")


✅ Data preparation complete!
Train set: 400 | Validation set: 50 | Test set: 50
Labeled CSVs saved in '../data/data_processed/'
