In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split

# Load the preprocessed dataset from disk.
df = pd.read_csv("sentio-data-processed.csv")

# Quick overview of the raw data: total size, rows per label, and the
# distinct label values (in order of first appearance).
overview_lines = [
    f"Total rows: {len(df)}",
    f"Label distribution:\n{df['label'].value_counts()}",
    f"Unique labels: {df['label'].unique()}",
]
for overview_line in overview_lines:
    print(overview_line)

# Build a class-balanced "increment" subset: a 3000-row budget is divided
# evenly (integer division) across the labels present in the data.
labels = df['label'].unique().tolist()
samples_per_label = 3000 // len(labels)

# Draw the same number of rows from every label. Each draw uses the same
# fixed seed, so re-running the cell reproduces the exact same subset.
increment_samples = [
    df.loc[df['label'] == current_label].sample(n=samples_per_label, random_state=42)
    for current_label in labels
]

# Stack the per-label draws (rows stay grouped by label at this point), then
# remove those rows from the pool by index label so the sets do not overlap.
increment_df = pd.concat(increment_samples)
remaining_df = df.drop(index=increment_df.index)

# Split the rows that were not taken for the increment set into train (80%)
# and test (20%). Stratifying on the label keeps the class proportions of
# `remaining_df` in both parts; the fixed seed makes the split reproducible.
remaining_labels = remaining_df['label']
train_df, test_df = train_test_split(
    remaining_df,
    stratify=remaining_labels,
    test_size=0.2,
    random_state=42,
)

# Shuffle the rows of every split with a fixed seed.
#
# The previous version looped over the three frames and assigned the shuffled
# copy to the loop variable only, which left `increment_df`, `train_df` and
# `test_df` untouched. That matters most for `increment_df`: it is a
# concatenation of per-label samples, so without a real shuffle its rows are
# written out in label-sorted blocks. (`train_test_split` shuffles by default,
# so the train/test frames were already in random order.)
#
# `sample(frac=1)` returns a new frame containing all rows in random order, so
# the result has to be bound back to each name explicitly.
increment_df = increment_df.sample(frac=1, random_state=42)
train_df = train_df.sample(frac=1, random_state=42)
test_df = test_df.sample(frac=1, random_state=42)

# Keep only the two columns that are exported and renumber the rows from 0,
# discarding the index labels inherited from the original frame.
output_columns = ['text_preprocessed', 'label']
increment_df = increment_df.loc[:, output_columns].reset_index(drop=True)
train_df = train_df.loc[:, output_columns].reset_index(drop=True)
test_df = test_df.loc[:, output_columns].reset_index(drop=True)

# Write each split to its own CSV file, without the index column.
for split_frame, csv_path in (
    (increment_df, "sentio-data-increment.csv"),
    (train_df, "sentio-data-train.csv"),
    (test_df, "sentio-data-test.csv"),
):
    split_frame.to_csv(csv_path, index=False)

# Report the size and per-label row counts of each exported split.
split_reports = (
    (f"\nIncrement set: {len(increment_df)} rows", increment_df),
    (f"\nTrain set: {len(train_df)} rows", train_df),
    (f"\nTest set: {len(test_df)} rows", test_df),
)
for headline, split_frame in split_reports:
    print(headline)
    print(split_frame['label'].value_counts())

Total rows: 47960
Label distribution:
label
Depression    15016
Normal        12851
Suicidal      10590
Stress         9503
Name: count, dtype: int64
Unique labels: ['Stress' 'Normal' 'Depression' 'Suicidal']

Increment set: 3000 rows
label
Stress        750
Normal        750
Depression    750
Suicidal      750
Name: count, dtype: int64

Train set: 35968 rows
label
Depression    11413
Normal         9681
Suicidal       7872
Stress         7002
Name: count, dtype: int64

Test set: 8992 rows
label
Depression    2853
Normal        2420
Suicidal      1968
Stress        1751
Name: count, dtype: int64
