In [1]:
import os

# Move to the parent directory so the relative "dataset_prepare/" and
# "dataset_followup/" paths used below resolve. The guard keeps the cell
# idempotent: re-running it must not keep climbing the directory tree.
if not os.path.isdir("dataset_prepare"):
    os.chdir("..")

### Sampling Criteria

1. Each author should have exactly 10 samples per evaluation dataset.
2. Each author's evaluation samples should form at least 2 non-outlier clusters.
3. These clusters must also appear in the train set, each with at least 5 train samples from the same author (outliers excluded).

In [2]:
import os
import pandas as pd
from tqdm import tqdm
from bertopic import BERTopic


def load_dataset(dataset):
    """
    Load the prepared train and test CSVs of a dataset into one DataFrame.

    The test frame is restricted to the train columns plus ``summary``, and a
    ``split`` column ("train" / "test") records where each row came from.
    Train rows come first, the result has a fresh RangeIndex, and train rows
    get NaN in ``summary``.

    Parameters:
    - dataset (str): Prefix of ``dataset_prepare/{dataset}_train.csv`` and
      ``dataset_prepare/{dataset}_test.csv`` (relative to the working directory).

    Returns:
    - pd.DataFrame: The concatenated train and test rows.
    """
    train = pd.read_csv(f"dataset_prepare/{dataset}_train.csv")
    test = pd.read_csv(f"dataset_prepare/{dataset}_test.csv")
    # The column list must be taken before ``split`` is added to train.
    # ``assign`` returns a new frame, whereas setting a column on the
    # column-selected slice triggers pandas' SettingWithCopyWarning.
    test = test[train.columns.to_list() + ["summary"]].assign(split="test")
    train = train.assign(split="train")
    return pd.concat([train, test], ignore_index=True)


def berttopic_clustering(docs, **bertopic_kwargs):
    """
    Fit a fresh BERTopic model on ``docs`` and return one topic id per document.

    Parameters:
    - docs (list[str]): Documents to cluster.
    - **bertopic_kwargs: Forwarded unchanged to the ``BERTopic`` constructor,
      e.g. a seeded ``umap_model`` to make the clustering reproducible. With no
      extra arguments this is exactly ``BERTopic()``.

    Returns:
    - The topic assignments from ``fit_transform``, one per document. Topic -1
      is treated as "outlier" by the caller in this notebook.
    """
    topic_model = BERTopic(**bertopic_kwargs)
    topics, _ = topic_model.fit_transform(docs)
    return topics


def _select_authors(test, train, n_test_samples=10, min_clusters=2, min_train_per_cluster=5):
    """
    Return the authors of ``test`` that satisfy the sampling criteria.

    An author qualifies when they have at least ``n_test_samples`` rows in
    ``test`` spanning at least ``min_clusters`` distinct ``cluster`` values, and
    each of those clusters has at least ``min_train_per_cluster`` rows of the
    same author in ``train``. Outlier rows are expected to be removed from
    ``test`` beforehand. Authors are returned most-test-rows first.
    """
    selected = []
    # value_counts() orders authors by number of test rows (descending), so a
    # caller-side [:max_num_authors] slice keeps the most frequent authors.
    for author in test.author.value_counts().index:
        author_test = test[test.author == author]
        test_clusters = author_test.cluster.unique()
        if len(author_test) < n_test_samples or len(test_clusters) < min_clusters:
            continue
        train_counts = train.loc[train.author == author, "cluster"].value_counts()
        if all(train_counts.get(c, 0) >= min_train_per_cluster for c in test_clusters):
            selected.append(author)
    return selected


def _sample_author_rows(author_test, n_test_samples=10, min_clusters=2, random_state=None):
    """
    Draw ``n_test_samples`` rows of ``author_test`` spanning at least ``min_clusters`` clusters.

    A plain ``sample(n)`` can land entirely inside one cluster. Instead, the
    rows are shuffled, one row from each of the first ``min_clusters`` clusters
    (in shuffled order) is reserved, and the sample is topped up with the next
    rows of the shuffle.

    Assumes:
    - ``author_test`` has at least ``n_test_samples`` rows, at least
      ``min_clusters`` clusters, and unique index labels;
    - ``n_test_samples >= min_clusters``.
    """
    shuffled = author_test.sample(frac=1, random_state=random_state)
    reserved = shuffled.drop_duplicates(subset="cluster").head(min_clusters)
    filler = shuffled.drop(index=reserved.index).head(n_test_samples - len(reserved))
    keep = reserved.index.union(filler.index)
    # Preserve the shuffled order so the output is randomly ordered like sample(n).
    return shuffled[shuffled.index.isin(keep)]


def topic_model_a_dataset(
    dataset,
    max_num_authors=50,
    n_test_samples=10,
    min_clusters=2,
    min_train_per_cluster=5,
    random_state=None,
    save_dir="dataset_followup",
):
    """
    Cluster each author's documents with BERTopic and build a topic-aware test split.

    Workflow:
    - Load the train and test CSVs of ``dataset`` from ``dataset_prepare/``.
    - Cluster every author's documents (train and test together) with BERTopic
      and store the labels in a new ``cluster`` column.
    - Restrict the test set to authors meeting the sampling criteria:
      - after dropping outlier documents (``cluster == -1``), the author has at
        least ``n_test_samples`` test documents spanning at least
        ``min_clusters`` distinct clusters;
      - every one of those clusters has at least ``min_train_per_cluster``
        documents of the same author in the train set.
    - For the first ``max_num_authors`` qualifying authors (most test documents
      first), draw ``n_test_samples`` test documents such that the sample
      itself still spans at least ``min_clusters`` clusters.
    - Save the full train set and the sampled test set (both with ``cluster``)
      to ``save_dir``.

    Parameters:
    - dataset (str): The name of the dataset to be processed.
    - max_num_authors (int): The maximum number of most frequently occurring
      authors to keep in the test set.
    - n_test_samples (int): Test documents kept per author.
    - min_clusters (int): Minimum number of distinct non-outlier clusters per
      author, both before and after sampling.
    - min_train_per_cluster (int): Minimum train documents the author needs in
      each of their test clusters.
    - random_state (int | None): Seed passed to ``DataFrame.sample``. None uses
      pandas' default RNG. BERTopic itself is not seeded here.
    - save_dir (str): Output directory for the CSVs.

    Returns:
    - (pd.DataFrame, pd.DataFrame): The train set and the sampled test set.

    Raises:
    - ValueError: If ``n_test_samples < min_clusters``, or if no author satisfies
      the criteria.
    - AssertionError: If the train split no longer matches the original train CSV.
    """
    if n_test_samples < min_clusters:
        raise ValueError("n_test_samples must be at least min_clusters.")

    df = load_dataset(dataset)

    # One BERTopic fit per author, over train and test documents together, so
    # cluster ids are comparable across the two splits. The column starts as
    # float NaN, matching the dtype the former per-row .loc writes produced.
    df["cluster"] = float("nan")
    for author in tqdm(df.author.unique()):
        author_idx = df.index[df.author == author]
        df.loc[author_idx, "cluster"] = berttopic_clustering(df.loc[author_idx, "text"].tolist())

    # Non-inplace drops: in-place drops on these boolean-mask slices of ``df``
    # trigger SettingWithCopyWarning.
    test = df[df.split == "test"].drop(columns=["split"])
    train = df[df.split == "train"].drop(columns=["summary", "split"])

    original_train = pd.read_csv(f"dataset_prepare/{dataset}_train.csv")
    assert train.shape[0] == original_train.shape[0], "Train set size mismatch after topic modeling."
    assert train[["author", "text"]].equals(original_train[["author", "text"]]), "Train set content mismatch after topic modeling."

    # Cluster -1 is treated as "outlier" and never counts towards the criteria.
    test = test[test.cluster != -1]
    authors_to_keep = _select_authors(test, train, n_test_samples, min_clusters, min_train_per_cluster)
    if not authors_to_keep:
        raise ValueError(f"No author in the {dataset} test set satisfies the sampling criteria.")

    test = pd.concat(
        [
            _sample_author_rows(test[test.author == author], n_test_samples, min_clusters, random_state)
            for author in authors_to_keep[:max_num_authors]
        ],
        ignore_index=True,
    )

    os.makedirs(save_dir, exist_ok=True)
    train.to_csv(f"{save_dir}/{dataset}_train.csv", index=False)
    test.to_csv(f"{save_dir}/{dataset}_test.csv", index=False)
    print(f"Saved {dataset} dataset with topics to {save_dir} folder.")
    return train, test

  from .autonotebook import tqdm as notebook_tqdm
2025-05-08 20:24:39.540327: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1746750279.551866 2864383 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746750279.555305 2864383 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1746750279.565864 2864383 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746750279.565874 2864383 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1746750279.565876 2864383

In [3]:
# Enron: keep at most 30 qualifying authors, then report how many made it.
train, test = topic_model_a_dataset(dataset="enron", max_num_authors=30)
print("Number of unique authors in test set:", test["author"].nunique(dropna=False))
test.columns

100%|██████████| 150/150 [01:28<00:00,  1.70it/s]

Saved enron dataset with topics to dataset_followup folder.
Number of unique authors in test set: 30





Index(['author', 'text', 'subject', 'AA-label', 'summary', 'cluster'], dtype='object')

In [4]:
# Blog: keep at most 50 qualifying authors, then report how many made it.
train, test = topic_model_a_dataset(dataset="blog", max_num_authors=50)
print("Number of unique authors in test set:", test["author"].nunique(dropna=False))
test.columns

100%|██████████| 100/100 [01:46<00:00,  1.07s/it]


Saved blog dataset with topics to dataset_followup folder.
Number of unique authors in test set: 50


Index(['author', 'text', 'topic', 'gender', 'age', 'sign', 'date', 'AA-label',
       'summary', 'cluster'],
      dtype='object')

In [5]:
# CCAT50: keep at most 30 qualifying authors, then report how many made it.
train, test = topic_model_a_dataset(dataset="CCAT50", max_num_authors=30)
print("Number of unique authors in test set:", test["author"].nunique(dropna=False))
test.columns

100%|██████████| 50/50 [00:27<00:00,  1.84it/s]

Saved CCAT50 dataset with topics to dataset_followup folder.
Number of unique authors in test set: 30





Index(['author', 'text', 'file_name', 'AA-label', 'summary', 'cluster'], dtype='object')

In [6]:
# Reddit: keep at most 50 qualifying authors, then report how many made it.
train, test = topic_model_a_dataset(dataset="reddit", max_num_authors=50)
print("Number of unique authors in test set:", test["author"].nunique(dropna=False))
test.columns

100%|██████████| 100/100 [00:56<00:00,  1.76it/s]


Saved reddit dataset with topics to dataset_followup folder.
Number of unique authors in test set: 50


Index(['index', 'author', 'text', 'subreddit', 'AA-label', 'summary',
       'cluster'],
      dtype='object')

### Double Check

In [7]:
# Reload the saved CCAT50 test split and confirm its column layout
# (the name previously said "blog" although this is the CCAT50 file).
ccat50_test = pd.read_csv("dataset_followup/CCAT50_test.csv")
ccat50_test.columns

Index(['author', 'text', 'file_name', 'AA-label', 'summary', 'cluster'], dtype='object')

In [8]:
# Reload the saved reddit test split and confirm its column layout
# (the name previously said "blog" although this is the reddit file).
reddit_test = pd.read_csv("dataset_followup/reddit_test.csv")
reddit_test.columns

Index(['index', 'author', 'text', 'subreddit', 'AA-label', 'summary',
       'cluster'],
      dtype='object')

In [9]:
# Reload the saved enron test split and confirm its column layout
# (the name previously said "blog" although this is the enron file).
enron_test = pd.read_csv("dataset_followup/enron_test.csv")
enron_test.columns

Index(['author', 'text', 'subject', 'AA-label', 'summary', 'cluster'], dtype='object')

In [10]:
# Reload the saved blog test split and show its columns.
blog_test = pd.read_csv(filepath_or_buffer="dataset_followup/blog_test.csv")
blog_test.keys()

Index(['author', 'text', 'topic', 'gender', 'age', 'sign', 'date', 'AA-label',
       'summary', 'cluster'],
      dtype='object')