In [8]:
### static variables

# Columns that appear, under the same name, in every document table below.
_CORE_COLUMNS = ["doc_id", "language", "domain", "content"]

# Schema of the original documents: core columns plus one entity name per domain.
COLUMNS_DOCS = _CORE_COLUMNS + ["company_name", "court_name", "hospital_patient_name"]

# Textual manipulations keep the original schema and add a single parent id.
COLUMNS_DOCS_MANIPULATED_TEXTUAL = COLUMNS_DOCS + ["original_doc_id"]

# Tabular manipulations use plural entity columns and a plural parent-id column.
COLUMNS_DOCS_MANIPULATED_TABULAR = _CORE_COLUMNS + [
    "company_names",
    "court_names",
    "hospital_patient_names",
    "original_doc_ids",
]

In [None]:
### imports and data loading

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import NMF, TruncatedSVD, LatentDirichletAllocation
from sklearn.preprocessing import normalize
import os
import numpy as np
from typing import List
import pandas as pd
import ast


def get_documents() -> pd.DataFrame:
    """Read the original and manipulated document CSVs and stack them into one frame.

    All four files are read first, then the row count of each is printed, and
    finally the frames are concatenated in the order: originals, single textual
    manipulations, single tabular manipulations, multi textual manipulations.

    Returns:
        The concatenated frame. Each source keeps its own row index (nothing is
        re-indexed) and the column order is not sorted; a column that a source
        does not have (``original_doc_id`` / ``original_doc_ids``) is left
        missing for that source's rows.
    """
    shared_columns = ["doc_id", "domain", "content"]

    # Both textual manipulation files carry one parent id, read as nullable integer.
    textual_options = {
        "usecols": [*shared_columns, "original_doc_id"],
        "dtype": {"original_doc_id": "Int64"},
    }

    originals = pd.read_csv("data/DRAGONball/en/docs.csv", usecols=shared_columns)
    single_textual = pd.read_csv(
        "data/additional_data/docs/textual_manipulations_result.csv",
        **textual_options,
    )
    single_tabular = pd.read_csv(
        "data/additional_data/docs/tabular_manipulations_result.csv",
        usecols=[*shared_columns, "original_doc_ids"],
        # The parent ids are stored as a Python literal and parsed cell by cell.
        converters={"original_doc_ids": ast.literal_eval},
    )
    multi_textual = pd.read_csv(
        "data/additional_data/docs/multi_textual_manipulations.csv",
        **textual_options,
    )

    print(f"# original docs: {len(originals)}")
    print(f"# manipulated textual docs: {len(single_textual)}")
    print(f"# manipulated tabular docs: {len(single_tabular)}")
    print(f"# manipulated textual multi docs: {len(multi_textual)}")

    frames = [originals, single_textual, single_tabular, multi_textual]
    return pd.concat(frames, sort=False)

In [28]:
# Load every document (originals first, then the manipulated variants) and
# pull out the raw text that the vectorizer is fitted on.
docs_df = get_documents()
docs_list = docs_df["content"].to_list()


# TF-IDF over the full corpus: English stop words are removed and terms that
# occur in more than 95% of the documents are dropped (max_df=0.95).
tfidf_vectorizer = TfidfVectorizer(max_df=0.95, min_df=1, stop_words="english")
# Row i of tfidf_features corresponds to row i of docs_df / docs_list.
tfidf_features = tfidf_vectorizer.fit_transform(docs_list)
tfidf_feature_names = tfidf_vectorizer.get_feature_names_out()
print(f"Size of vocabulary: {len(tfidf_feature_names)}")

# original docs: 108
# manipulated textual docs: 30
# manipulated tabular docs: 3
# manipulated textual multi docs: 30
Size of vocabulary: 7581


In [51]:
### NMF
# 108 components, learned from the first 108 TF-IDF rows only.
# NOTE(review): 108 is hard-coded; get_documents() concatenates the original
# documents first and reported 108 of them, so this slice is presumably "the
# originals" — confirm, and re-check the number if the corpus changes.
nmf = NMF(n_components=108, init="nndsvda", max_iter=400)
nmf = nmf.fit(tfidf_features[:108])
# Project every document (including the manipulated ones) onto the learned components.
nmf_data = nmf.transform(tfidf_features)
# Alternative kept for reference: fit on the whole corpus instead of the first 108 rows.
# nmf_data = nmf.fit_transform(tfidf_features)
# L1 row normalisation: the weights of each non-zero row sum to 1.
nmf_data_normalised = normalize(nmf_data, norm="l1", axis=1)

In [52]:
### LSA
# LSA = truncated SVD of the TF-IDF matrix, fitted on the whole corpus.
# TruncatedSVD uses a randomized solver by default, so the seed is pinned to
# make the projection (and every metric derived from it) repeatable across
# kernel restarts.
lsa = TruncatedSVD(n_components=108, random_state=0)
lsa_data = lsa.fit_transform(tfidf_features)
# L2 row normalisation: every non-zero document vector gets unit Euclidean length.
lsa_data_normalised = normalize(lsa_data, norm="l2", axis=1)

In [53]:
### LDA
# LDA is randomly initialised; without a fixed seed the topics, and therefore
# the hit rates computed from them, change on every run. Pin the seed so that
# Restart & Run All reproduces the reported numbers.
lda = LatentDirichletAllocation(n_components=108, random_state=0)
# normalize=True asks for per-document topic distributions (rows sum to 1),
# so no separate normalisation step is applied afterwards.
lda_data_normalised = lda.fit_transform(tfidf_features, normalize=True)

In [54]:
### helper functions (1/2)
def merge_original_ids(row):
    """Collapse the two temporary parent-id columns of a row into one value.

    Args:
        row: mapping with the keys ``original_doc_ids_tmp`` and
            ``original_doc_id_tmp``.

    Returns:
        The list in ``original_doc_ids_tmp`` if that value is a list; otherwise
        a one-element list holding ``original_doc_id_tmp`` if it is not missing;
        otherwise ``pd.NA``.
    """
    many_ids = row["original_doc_ids_tmp"]
    if isinstance(many_ids, list):
        return many_ids

    # No list available: fall back to the single-id column, wrapped in a list.
    single_id = row["original_doc_id_tmp"]
    return [single_id] if pd.notna(single_id) else pd.NA


def calc_topics(row, top_k=10):
    """Pick the topic indices that represent a document.

    A document derived from more than one original gets its ``top_k`` strongest
    topics; every other document gets only its single strongest topic.

    Args:
        row: mapping with ``original_doc_ids`` (a list of parent ids, or any
            non-list value when there are none) and ``doc_vector`` (1-D
            array-like of topic weights).
        top_k: how many topics to keep for multi-parent documents. Must be at
            least 1. Defaults to 10.

    Returns:
        A list of plain Python ``int`` topic indices, strongest first.

    Raises:
        ValueError: if ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        # A slice of [-0:] would silently return *all* topics instead of none.
        raise ValueError(f"top_k must be >= 1, got {top_k}")

    parent_ids = row["original_doc_ids"]
    if isinstance(parent_ids, list) and len(parent_ids) > 1:
        # argsort is ascending: take the last top_k entries and reverse them.
        return np.argsort(row["doc_vector"])[-top_k:][::-1].tolist()

    # int() keeps the element type identical to the .tolist() branch above.
    return [int(np.argmax(row["doc_vector"]))]


def calc_topics_for_cumulative_threshold(row, threshold=0.9):
    """Select the strongest topics whose weights add up to ``threshold``.

    Topics are taken in order of decreasing weight until the running total
    reaches or exceeds ``threshold``; the topic that crosses the threshold is
    included in the result.

    Args:
        row: 1-D array-like of topic weights (ndarray, list, tuple, Series).
            It is converted with ``np.asarray``, so indexing is positional.
        threshold: cumulative weight to reach. Defaults to 0.9.

    Returns:
        ``np.ndarray`` of topic indices, strongest first. If the weights never
        add up to ``threshold``, all indices are returned.
    """
    # Coerce once so that plain lists work and a Series is not indexed by label.
    weights = np.asarray(row)

    # Indices ordered from the largest weight to the smallest.
    descending = np.argsort(weights)[::-1]
    running_total = np.cumsum(weights[descending])

    # First position whose running total is >= threshold (searchsorted, side="left").
    cutoff = np.searchsorted(running_total, threshold)

    # +1 so that the topic which crosses the threshold is part of the selection.
    return descending[: cutoff + 1]

In [55]:
### helper functions (2/2)
def calc_topic_hitrate(row):
    """Share of a document's parent documents that have a topic in common with it.

    Each id in ``row["original_doc_ids"]`` is looked up in the notebook-level
    ``docs`` frame (matched on ``doc_id`` cast to int). A parent counts as a
    hit when its ``topics`` and ``row["topics"]`` share at least one element.

    Args:
        row: mapping with ``original_doc_ids`` and ``topics``.

    Returns:
        ``None`` if ``original_doc_ids`` is not a list, ``nan`` if the list is
        empty, otherwise the mean of the per-parent hit flags (0.0 to 1.0).

    Raises:
        IndexError: if a parent id is not present in ``docs``.
    """
    parent_ids = row["original_doc_ids"]
    if not isinstance(parent_ids, list):
        return None

    # Loop-invariant work, done once per row instead of once per parent id.
    topics_row = set(row["topics"])
    known_ids = docs["doc_id"].astype(int)

    hits = []
    for parent_id in parent_ids:
        parent_topics = docs.loc[known_ids == int(parent_id), "topics"]
        if len(parent_topics) == 0:
            # Same exception type as the bare .iloc[0] would raise, but with context.
            raise IndexError(f"original doc_id {parent_id} not found in docs")
        topics_original = set(parent_topics.iloc[0])
        hits.append(len(topics_row.intersection(topics_original)) > 0)

    if not hits:
        # np.mean([]) also yields nan, but emits a "mean of empty slice" warning.
        return np.nan
    return np.mean(hits)

In [56]:
### evaluate method
# Document-topic matrix under evaluation, one row per document in docs_df order.
transformed_data = nmf_data_normalised

docs = pd.DataFrame(
    {
        "doc_id": docs_df["doc_id"].to_list(),
        "original_doc_id_tmp": docs_df["original_doc_id"].to_list(),
        "original_doc_ids_tmp": docs_df["original_doc_ids"].to_list(),
        # One topic-weight vector (matrix row) per document.
        "doc_vector": list(transformed_data),
    }
)

# Unify the single-id and multi-id parent columns into one list-valued column.
docs["original_doc_ids"] = docs.apply(merge_original_ids, axis=1)
docs = docs.drop(["original_doc_id_tmp", "original_doc_ids_tmp"], axis=1)


# Topics of a document: strongest components covering 95% of its total weight.
docs["topics"] = docs["doc_vector"].apply(calc_topics_for_cumulative_threshold, args=(0.95,))
docs["len(topics)"] = docs["topics"].apply(len)
# Must run after "topics" is filled for every row: the hit rate looks up the parents' topics in docs.
docs["topic_hitrate"] = docs.apply(calc_topic_hitrate, axis=1)
docs["num_non-zeros_in_vector"] = docs["doc_vector"].apply(lambda v: sum(i > 0 for i in v))

# Values are computed outside the f-strings: reusing double quotes inside a
# replacement field is a SyntaxError before Python 3.12.
avg_num_topics = round(docs["len(topics)"].mean(), 2)
print(f"Avg. number of topics: {avg_num_topics}")

# Display only: from here on doc_vector holds the weights sorted descending and
# rounded, so positions no longer correspond to topic indices.
docs["doc_vector"] = docs["doc_vector"].apply(lambda v: np.sort(v)[::-1]).apply(lambda v: [round(i, 4) for i in v])

# Only documents that have parent ids (the manipulated ones) carry a hit rate.
filtered_docs = docs.loc[docs["original_doc_ids"].notna()]
recall = round(filtered_docs["topic_hitrate"].mean(), 2)
print(f"Recall: {recall}")

filtered_docs


Avg. number of topics: 3.01
Recall: 0.97


Unnamed: 0,doc_id,doc_vector,original_doc_ids,topics,len(topics),topic_hitrate,num_non-zeros_in_vector
108,100134,"[0.9294, 0.0359, 0.0137, 0.0062, 0.0035, 0.002...",[134],"[37, 98]",2,1.0,14
109,100136,"[0.9968, 0.0016, 0.0007, 0.0003, 0.0003, 0.000...",[136],[35],1,1.0,16
110,100139,"[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",[139],[62],1,1.0,10
111,100046,"[0.9989, 0.0003, 0.0003, 0.0002, 0.0001, 0.000...",[46],[29],1,1.0,9
112,100047,"[0.8849, 0.0932, 0.0082, 0.0035, 0.0034, 0.002...",[47],"[17, 1]",2,1.0,16
...,...,...,...,...,...,...,...
166,400207,"[0.4945, 0.2989, 0.0514, 0.0268, 0.0251, 0.023...",[207],"[9, 103, 101, 85, 64, 23, 60, 40]",8,1.0,19
167,400213,"[0.5692, 0.0653, 0.0509, 0.0461, 0.0431, 0.040...",[213],"[72, 98, 73, 13, 11, 40, 101, 60, 23, 2, 17, 69]",12,1.0,19
168,400214,"[0.5735, 0.1258, 0.0654, 0.0547, 0.0431, 0.036...",[214],"[27, 89, 60, 100, 23, 87, 101, 13, 51]",9,1.0,22
169,400110,"[0.7654, 0.0411, 0.0283, 0.0271, 0.0244, 0.023...",[110],"[76, 71, 100, 95, 26, 23, 59, 25, 3]",9,1.0,22
