In [1]:
# manage data
from datasets import load_dataset
import pandas as pd

# embeddings
from sentence_transformers import SentenceTransformer

# dimensionality reduction
import umap

# clustering
import hdbscan

# extract keywords from texts
# used to assign meaningful names to clusters
from keybert import KeyBERT

# visualization
import plotly.express as px

In [2]:
# download data
dataset = load_dataset("ag_news", split="train")
print(dataset)


Dataset({
    features: ['text', 'label'],
    num_rows: 120000
})


In [3]:
# keep only first 3k articles to make computations faster
dataset_subset = dataset.train_test_split(train_size=3000)["train"]
print(dataset_subset)

Dataset({
    features: ['text', 'label'],
    num_rows: 3000
})


In [4]:
# convert dataset to pandas dataframe
# df = pd.DataFrame(dataset_subset).drop("label", axis=1)
df = pd.DataFrame(dataset_subset)
df.head()

Unnamed: 0,text,label
0,Talks Resume In Bid For N. Ireland Gov #39;t T...,0
1,"Costco profits rise, but stock falls in early ...",2
2,Syria's grip on Lebanon tested The dominance o...,0
3,The Arafat void Even as a gravely ill Yasser A...,0
4,Missouri Linebacker Suspended for Game (AP) AP...,1


In [5]:
# Label distribution
print(df.label.value_counts())

label
0    767
3    754
2    741
1    738
Name: count, dtype: int64


In [6]:
# download the sentence embeddings model
embedder = SentenceTransformer('all-mpnet-base-v2')

In [7]:
# embed article texts
corpus_embeddings = embedder.encode(df["text"].values)
print(corpus_embeddings.shape)

(3000, 768)


In [8]:
# reduce the size of the embeddings using UMAP
reduced_embeddings = umap.UMAP(n_components=2, n_neighbors=100, min_dist=0.02).fit_transform(corpus_embeddings)
print(reduced_embeddings.shape)

# put the values of the two dimensions inside the dataframe
df["x"] = reduced_embeddings[:, 0]
df["y"] = reduced_embeddings[:, 1]

# substring of the full text, for visualization purposes
df["text_short"] = df["text"].str[:100]

OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.


(3000, 2)


In [9]:
# scatter plot
hover_data = {
    "text_short": True,
    "x": False,
    "y": False
}
fig = px.scatter(df, x="x", y="y", template="plotly_dark",
                   title="Embeddings", hover_data=hover_data)
fig.update_layout(showlegend=False)
fig.show()

In [10]:
# clustering with HDBSCAN
clusterer = hdbscan.HDBSCAN(min_cluster_size=9)
labels = clusterer.fit_predict(reduced_embeddings)
df["label"] = [str(label) for label in labels]
print(f"Num of clusters: {labels.max()}")

Num of clusters: 12


In [11]:
# number of outliers
num_outliers = len(df[df["label"] == "-1"])
print(f"Num of outliers: {num_outliers} ({num_outliers / len(df) * 100:.2f} % of total)")

Num of outliers: 129 (4.30 % of total)


In [12]:
# remove outliers
df_no_outliers = df[df["label"] != "-1"]

# scatter plot
hover_data = {
    "text_short": True,
    "x": False,
    "y": False
}
fig = px.scatter(df_no_outliers, x="x", y="y", template="plotly_dark",
                   title="Embeddings", color="label", hover_data=hover_data)
fig.show()

In [13]:
# show articles in a specific cluster
cluster = "0"
df_subset = df[df["label"] == cluster].reset_index()
for i,row in df_subset.iterrows():
  print(f"- {row['text_short']}")
  if i == 10:
    break

- Explosion rocks Israel checkpoint A number of Israelis are hurt in a blast on the border between the
- After Arafat It is often the case with charismatic rebels with a just cause that, when they pass on,
- First deadline passes for Supreme Court Grokster case First round of comments from friends of the co
- JC Penney #39;s Castagna to Leave Company JC Penney Co. said Friday that Vanessa Castagna - chairman
- Finance minister appears to tip his hand on Reserve Bank rate hike New Zealand Finance Minister Mich
- Humans may need fewer genes than thought How many genes does it take to make a human? Only about the
- Genesis data 'retrieved intact' Scientists are now optimistic some valuable data can be salvaged fro
- Black Eyed Peas to Produce Music for 'Urbz' Game  LOS ANGELES (Reuters) - The band that brought the 
- Top Shiite cleric insists vote be held on time even if America &lt;b&gt;...&lt;/b&gt; US warplanes b
- OPEC Weighs Output Increase Members of the Organization of Petroleum Ex

In [14]:
# extracting keywords from texts with KeyBERT
cluster = "0"
df_subset = df[df["label"] == cluster].reset_index()
texts_concat = ". ".join(df_subset["text"].values)
keywords_and_scores = KeyBERT().extract_keywords(texts_concat,
                                    keyphrase_ngram_range=(1, 1), top_n=10)
print(keywords_and_scores)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

[('aviv', 0.4019), ('israeli', 0.3916), ('knesset', 0.3907), ('gazans', 0.3746), ('gaza', 0.3745), ('israelis', 0.3697), ('israels', 0.3692), ('explosion', 0.3672), ('israel', 0.3507), ('explodes', 0.3272)]


In [15]:
# keep only the keywords with different stem
def filter_keywords(keywords, n_keep=3):
  new_keywords = []
  for candidate_keyword in keywords:
    is_ok = True
    for compare_keyword in keywords:
      if candidate_keyword == compare_keyword:
        continue
      if compare_keyword in candidate_keyword:
        is_ok = False
        break
    if is_ok:
      new_keywords.append(candidate_keyword)
      if len(new_keywords) >= n_keep:
        break
  return new_keywords

keywords = [t[0] for t in keywords_and_scores]
keywords_filtered = filter_keywords(keywords)
print(keywords_filtered)

['aviv', 'knesset', 'gaza']


In [16]:
# assign a meaningful name to each cluster

def get_cluster_name(df, cluster):
  df_subset = df[df["label"] == cluster].reset_index()
  texts_concat = ". ".join(df_subset["text"].values)
  kw_model = KeyBERT()
  keywords_and_scores = kw_model.extract_keywords(texts_concat, keyphrase_ngram_range=(1, 1),
                                      top_n=10)
  keywords = [t[0] for t in keywords_and_scores]
  keywords_filtered = filter_keywords(keywords)
  return " - ".join(keywords_filtered)

# get all the new cluster names
all_clusters = df_no_outliers["label"].unique()
d_cluster_name_mapping = {}
for cluster in all_clusters:
  if cluster == "-1":
    d_cluster_name_mapping[cluster] = "outliers"
  else:
    d_cluster_name_mapping[cluster] = get_cluster_name(df_no_outliers, cluster)

# rename clusters
df_no_outliers["label"] = df_no_outliers["label"].apply(lambda label: d_cluster_name_mapping[label])



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



In [17]:
# scatter plot
hover_data = {
    "text_short": True,
    "x": False,
    "y": False
}
fig = px.scatter(df_no_outliers, x="x", y="y", template="plotly_dark",
                   title="Embeddings", color="label", hover_data=hover_data)
fig.show()