feat(similarity.py): add multiple data-processing and plotting functions, switch clustering method from DBSCAN to Agglomerative Clustering

1 parent 5e5b235 · commit 4f38cba

Showing 1 changed file with 260 additions and 34 deletions.
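The core of the change is the clustering step: instead of DBSCAN with a fixed eps on the similarity matrix, the new code fits AgglomerativeClustering with n_clusters=None and a distance threshold on the cosine-distance matrix (1 - similarity). A minimal sketch of that step follows, using toy documents and an assumed distance_threshold of 0.5; the committed code derives the threshold from a percentile of the pairwise similarities when none is supplied.

from sklearn.cluster import AgglomerativeClustering
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

# Toy documents; the real pipeline works on pre-tokenized news text.
docs = ["apple banana", "apple banana cherry", "completely different words"]
tfidf = TfidfVectorizer().fit_transform(docs)
similarity = cosine_similarity(tfidf)

# Agglomerative clustering expects distances, so cluster on 1 - similarity.
# distance_threshold=0.5 is an assumed value for this sketch.
ac = AgglomerativeClustering(
    n_clusters=None,
    metric="precomputed",  # the diff uses affinity=, the older name of this parameter
    linkage="average",
    distance_threshold=0.5,
)
labels = ac.fit_predict(1 - similarity)
print(labels)  # documents sharing a label are treated as near-duplicates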
@@ -1,56 +1,282 @@
 import multiprocessing as mp
+from functools import partial
+from typing import Optional, Tuple
+from pathlib import Path
 
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
-from sklearn.cluster import DBSCAN
+from scipy.cluster.hierarchy import dendrogram
+from sklearn.cluster import AgglomerativeClustering
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
 
 from corprep import HyFI
 
-def process_week(weekly_data):
-    vectorizer = TfidfVectorizer()
-    tfidf_matrix = vectorizer.fit_transform(weekly_data["tokenizedText"])
-    similarity_matrix = cosine_similarity(tfidf_matrix)
-
-    # Perform DBSCAN clustering on the similarity matrix
-    db = DBSCAN(eps=0.8, min_samples=2, metric="precomputed", n_jobs=-1)
-    db.fit(similarity_matrix)
-
-    # For each cluster, get the index of the document with the earliest createdDt
-    df = pd.DataFrame(
-        {"cluster": db.labels_, "createdDt": weekly_data.index},
-        index=weekly_data["newsId"],
-    )
-    earliest_doc_indices = df.groupby("cluster")["createdDt"].idxmin().values
-
-    return weekly_data.loc[earliest_doc_indices]
-
-
-def filter_similar_docs(dataset):
-    num_cores = mp.cpu_count()
-    pool = mp.Pool(num_cores)
-
-    # Convert the HuggingFace dataset to a pandas DataFrame
-    dataset = dataset.to_pandas()
-
-    # Convert the createdDt column to datetime and set it as the index
-    dataset["createdDt"] = pd.to_datetime(dataset["createdDt"])
-    dataset.set_index("createdDt", inplace=True)
-
-    # Split the dataset into weekly chunks for multiprocessing
-    weeks = [g for n, g in dataset.groupby(pd.Grouper(freq="W"))]
-
-    # Process each week in parallel
-    filtered_weeks = pool.map(process_week, weeks)
-
-    # Concatenate the results into a single filtered dataset
-    filtered_dataset = pd.concat(filtered_weeks)
-
-    pool.close()
-    pool.join()
-
-    return filtered_dataset
-
-
-# Use the function as follows
-# dataset is your HuggingFace dataset
-# filtered_dataset = filter_similar_docs(dataset)
+logger = HyFI.getLogger(__name__)
+
+
+def plot_dendrogram(model, **kwargs):
+    """
+    Plots the hierarchical clustering dendrogram.
+    """
+    # Create linkage matrix and then plot the dendrogram
+
+    # create the counts of samples under each node
+    counts = np.zeros(model.children_.shape[0])
+    n_samples = len(model.labels_)
+    for i, merge in enumerate(model.children_):
+        current_count = sum(
+            1 if child_idx < n_samples else counts[child_idx - n_samples]
+            for child_idx in merge
+        )
+        counts[i] = current_count
+
+    linkage_matrix = np.column_stack(
+        [model.children_, model.distances_, counts]
+    ).astype(float)
+
+    # Plot the corresponding dendrogram
+    plt.title("Hierarchical Clustering Dendrogram")
+    dendrogram(linkage_matrix, truncate_mode="level", p=3)
+    plt.xlabel("Number of points in node (or index of point if no parenthesis).")
+    plt.show()
+
+
+def plot_similarity_distribution(
+    similarity_matrix: np.ndarray,
+    percentile: int = 80,
+    distance_threshold: Optional[float] = None,
+    title_name: str = "",
+    title_fontsize: int = 10,
+    output_dir: str = ".",
+    show_fig: bool = False,
+    save_fig: bool = True,
+) -> str:
+    """
+    Plots the distribution of cosine similarities between pairs of documents.
+    """
+    # Flatten the upper triangle of the matrix to a 1D array of pairwise similarities
+    similarity_array = similarity_matrix[
+        np.triu_indices(similarity_matrix.shape[0], k=1)
+    ]
+
+    # Compute the number of sample pairs, the median, and the requested percentile
+    num_samples = len(similarity_array)
+    med_similarity = np.median(similarity_array)
+    pct_similarity = np.percentile(similarity_array, percentile)
+
+    # Reset the matplotlib figure
+    plt.clf()
+
+    # Generate histogram
+    plt.hist(similarity_array, bins="auto", color="#0504aa", alpha=0.7, rwidth=0.85)
+    # If a distance threshold is provided, plot a vertical line at that threshold
+    if distance_threshold:
+        sim_threshold = 1 - distance_threshold
+        plt.axvline(x=sim_threshold, color="r", linestyle="dashed", linewidth=1)
+
+    # Add plot labels
+    plt.grid(axis="y", alpha=0.75)
+    plt.xlabel("Cosine Similarity")
+    plt.ylabel("Frequency")
+    title = "Distribution of Cosine Similarities"
+    title = f"{title} - {title_name}" if title_name else title
+    plt.title(title, fontsize=title_fontsize)
+
+    # Add labels for the number of sample pairs and the summary statistics
+    x_pos = 0.4
+    y_pos = 0.95
+    plt.text(
+        x_pos,
+        y_pos,
+        f"Number of Sample pairs: {num_samples}",
+        transform=plt.gca().transAxes,
+        fontsize=title_fontsize - 1,
+    )
+    plt.text(
+        x_pos,
+        y_pos - 0.05,
+        f"Median Similarity: {med_similarity:.2f}",
+        transform=plt.gca().transAxes,
+        fontsize=title_fontsize - 1,
+    )
+    plt.text(
+        x_pos,
+        y_pos - 0.1,
+        f"{percentile}th Percentile: {pct_similarity:.2f}",
+        transform=plt.gca().transAxes,
+        fontsize=title_fontsize - 1,
+    )
+    # Only report the over-threshold count when a threshold was given,
+    # so sim_threshold is guaranteed to be defined
+    if distance_threshold:
+        num_samples_over_threshold = np.sum(similarity_array > sim_threshold)
+        plt.text(
+            x_pos,
+            y_pos - 0.15,
+            f"Number of samples over threshold: {num_samples_over_threshold}",
+            transform=plt.gca().transAxes,
+            fontsize=title_fontsize - 1,
+        )
+
+    # Save and/or show the figure
+    filename = title_name.replace(" ", "_").lower()
+    filename = f"similarity_distribution_{filename}.png"
+    output_file = f"{output_dir}/{filename}"
+    if save_fig:
+        Path(output_dir).mkdir(parents=True, exist_ok=True)
+        plt.savefig(output_file, dpi=300)
+    if show_fig:
+        plt.show()
+    return filename
+
+
+def process_batch(
+    data: Tuple[str, pd.DataFrame],
+    min_num_docs: int = 5,
+    percentile: int = 80,
+    distance_threshold: Optional[float] = None,
+    linkage: str = "average",
+    token_col: str = "nouns",
+    id_col: str = "newsId",
+    ordering_col: str = "createdDt_int",
+    duplicate_col: str = "duplicate",
+    fig_col: str = "fig_filename",
+    output_dir: str = ".",
+    show_fig: bool = False,
+    save_fig: bool = False,
+) -> pd.DataFrame:
+    """
+    Processes a batch of data by calculating the cosine similarity between all pairs of documents in the batch.
+    """
+    batch_name, batch_data = data
+    batch_data.reset_index(inplace=True)
+    if len(batch_data) < min_num_docs:
+        logger.info(
+            "Batch %s has %d documents, which is less than the minimum number of documents (%d). Skipping.",
+            batch_name,
+            len(batch_data),
+            min_num_docs,
+        )
+        return batch_data
+    # Custom tokenizer that uses the already tokenized words
+    vectorizer = TfidfVectorizer(tokenizer=lambda x: x, lowercase=False)
+    tfidf_matrix = vectorizer.fit_transform(batch_data[token_col])
+    similarity_matrix = cosine_similarity(tfidf_matrix)
+    filename = plot_similarity_distribution(
+        similarity_matrix,
+        distance_threshold=distance_threshold,
+        title_name=batch_name,
+        percentile=percentile,
+        output_dir=output_dir,
+        show_fig=show_fig,
+        save_fig=save_fig,
+    )
+    batch_data[fig_col] = filename
+
+    # Calculate the requested percentile of the similarity scores
+    percentile_thres = np.percentile(similarity_matrix, percentile)
+    if not distance_threshold:
+        distance_threshold = max(0.5, 1 - percentile_thres)
+    distance_threshold = max(min(0.99, distance_threshold), 0.01)
+    # Perform Agglomerative Clustering on the similarity matrix
+    ac = AgglomerativeClustering(
+        n_clusters=None,
+        affinity="precomputed",
+        linkage=linkage,
+        distance_threshold=distance_threshold,
+    )
+    ac.fit(
+        1 - similarity_matrix
+    )  # AgglomerativeClustering expects distances, not similarities
+    # plot_dendrogram(ac)
+
+    # Create DataFrame df with cluster labels and Unix timestamp createdDt
+    df = pd.DataFrame(
+        {"cluster": ac.labels_, ordering_col: batch_data[ordering_col]},
+        index=batch_data[id_col],
+    )
+
+    # Convert createdDt_int to string and concatenate with newsId
+    df[ordering_col] = df[ordering_col].astype(str)
+    concat_col = f"{ordering_col}_{id_col}"
+    df[concat_col] = df[ordering_col] + "|" + df.index
+
+    # Find the minimum createdDt_newsId for each cluster
+    min_createdDt_newsId = df.groupby("cluster")[concat_col].min()
+
+    # Extract newsId from the minimum createdDt_newsId and flag the rest as duplicates
+    earliest_doc_indices = min_createdDt_newsId.str.split("|").str[-1]
+    similar_data = batch_data[~batch_data[id_col].isin(earliest_doc_indices)]
+    batch_data.loc[similar_data.index, duplicate_col] = True
+    return batch_data
+
+
+def find_similar_docs(
+    data: pd.DataFrame,
+    num_workers: int = 2,
+    min_num_docs: int = 5,
+    percentile: int = 80,
+    distance_threshold: Optional[float] = None,
+    linkage: str = "average",
+    grouping_freq: str = "W",
+    grouping_name: str = "Week",
+    date_col: str = "createdDt",
+    token_col: str = "nouns",
+    id_col: str = "newsId",
+    ordering_col: str = "createdDt_int",
+    duplicate_col: str = "duplicate",
+    fig_col: str = "fig_filename",
+    output_dir: str = ".",
+    show_fig: bool = False,
+    save_fig: bool = False,
+    verbose: bool = False,
+) -> pd.DataFrame:
+    """
+    Finds similar documents in the given data using cosine similarity and Agglomerative Clustering.
+    """
+    # Convert createdDt to datetime first
+    data[date_col] = pd.to_datetime(data[date_col])
+    # Convert createdDt to a Unix timestamp and store it in createdDt_int
+    data[ordering_col] = data[date_col].astype(np.int64) // 10**9
+    data.set_index(date_col, inplace=True)
+    data[duplicate_col] = False
+    data[fig_col] = None
+
+    # Group data by the grouping frequency (weekly by default)
+    batchs = [grp for _, grp in data.groupby(pd.Grouper(freq=grouping_freq))]
+    batchs = [
+        (f"{grouping_name} {min(week.index).date()}", week)
+        for week in batchs
+        if len(week) > 0
+    ]
+
+    # Prepare partial function for process_batch
+    process_batch_partial = partial(
+        process_batch,
+        min_num_docs=min_num_docs,
+        percentile=percentile,
+        distance_threshold=distance_threshold,
+        linkage=linkage,
+        token_col=token_col,
+        id_col=id_col,
+        ordering_col=ordering_col,
+        duplicate_col=duplicate_col,
+        fig_col=fig_col,
+        output_dir=output_dir,
+        show_fig=show_fig,
+        save_fig=save_fig,
+    )
+
+    # Apply process_batch to each group of data using multiprocessing
+    with mp.Pool(processes=num_workers) as pool:
+        filtered_weeks = pool.map(
+            process_batch_partial,
+            batchs,
+        )
+
+    data = pd.concat(filtered_weeks)
+    logger.info("Number of documents: %d", len(data))
+    if duplicate_col in data.columns:
+        logger.info("Number of duplicate documents: %d", data[duplicate_col].sum())
+    if verbose:
+        print(data.tail(10))
+    return data
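As a usage reference (not part of the commit), a call to the new entry point might look like the sketch below. The import path is an assumption, since the commit only shows similarity.py; the toy DataFrame simply mirrors the default column names (newsId, createdDt, nouns). Because find_similar_docs spawns worker processes, the call should run under a main guard.

import pandas as pd

# Assumed import; the commit does not show the package path of similarity.py.
from similarity import find_similar_docs

if __name__ == "__main__":  # needed because find_similar_docs uses multiprocessing
    # Toy input; the real data is a news corpus with pre-tokenized nouns.
    df = pd.DataFrame(
        {
            "newsId": ["a1", "a2", "a3"],
            "createdDt": ["2023-01-02", "2023-01-03", "2023-01-10"],
            "nouns": [["economy", "rate"], ["economy", "rate"], ["weather"]],
        }
    )

    result = find_similar_docs(
        df,
        num_workers=1,   # a single worker is enough for a toy example
        min_num_docs=2,  # weekly batches smaller than this are returned unprocessed
        save_fig=False,  # skip writing the similarity-distribution plots
    )
    # All but the earliest document in each weekly cluster are flagged as duplicates.
    print(result[["newsId", "duplicate"]])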