In [1]:
from datasets import load_dataset
import pandas as pd

# embeddings
from sentence_transformers import SentenceTransformer

# dimensionality reduction
import umap

# clustering
import hdbscan

# extract keywords from texts
# used to assign meaningful names to clusters
from keybert import KeyBERT

# visualization
import plotly.express as px

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset = load_dataset('ag_news', split='train')
dataset

Found cached dataset ag_news (/home/mushahid/.cache/huggingface/datasets/ag_news/default/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


Dataset({
    features: ['text', 'label'],
    num_rows: 120000
})

In [3]:
dataset_subset = dataset.train_test_split(train_size=3000)["train"]
print(dataset_subset)

Dataset({
    features: ['text', 'label'],
    num_rows: 3000
})


In [4]:
df = pd.DataFrame(dataset_subset).drop("label", axis=1)
df.head()

Unnamed: 0,text
0,Piazza to come off disabled list NEW YORK -- T...
1,Slate #39;s Jurisprudence: Court Mulls Wine by...
2,Last Rites Sounded for Life-Changing Video For...
3,"Battleground voters wrestle with US economy, s..."
4,"Microsoft, Ask Jeeves unveil desktop search to..."


In [5]:
embedder = SentenceTransformer('all-mpnet-base-v2')

In [6]:
corpus_embeddings = embedder.encode(df['text'].values)
print(corpus_embeddings.shape)

(3000, 768)


In [7]:
reduced_embeddings = umap.UMAP(n_components=2, n_neighbors=100, min_dist=0.02).fit_transform(corpus_embeddings)

df['x'] = reduced_embeddings[:,0]
df['y'] = reduced_embeddings[:,1]

df["text_short"] = df["text"].str[:100]

In [8]:
hover_data = {
    "text_short": True,
    "x": False,
    "y": False
}
fig = px.scatter(df, x="x", y="y", template="plotly_dark",
                   title="Embeddings", hover_data=hover_data)
fig.update_layout(showlegend=False)
fig.show()

In [9]:
clusterer = hdbscan.HDBSCAN(min_cluster_size=9)
labels = clusterer.fit_predict(reduced_embeddings)
df["label"] = [str(label) for label in labels]
print(f"Num of clusters: {labels.max()}")

Num of clusters: 10


In [10]:
# number of outliers
num_outliers = len(df[df["label"] == "-1"])
print(f"Num of outliers: {num_outliers} ({num_outliers / len(df) * 100:.2f} % of total)")

Num of outliers: 171 (5.70 % of total)


In [11]:
# remove outliers
df_no_outliers = df[df["label"] != "-1"]

# scatter plot
hover_data = {
    "text_short": True,
    "x": False,
    "y": False
}
fig = px.scatter(df_no_outliers, x="x", y="y", template="plotly_dark",
                   title="Embeddings", color="label", hover_data=hover_data)
fig.show()

In [28]:
cluster = "0"
df_subset = df[df["label"] == cluster].reset_index()
texts_concat = ". ".join(df_subset["text"].values)

keywords_and_scores = KeyBERT().extract_keywords(texts_concat,
                                    keyphrase_ngram_range=(1, 1), top_n=10)
print(keywords_and_scores)

[('judiciary', 0.4365), ('justices', 0.3954), ('judicial', 0.3896), ('courts', 0.3846), ('jeeves', 0.3719), ('defendants', 0.3572), ('court', 0.3569), ('courtroom', 0.3537), ('lawmakers', 0.3518), ('proceedings', 0.3471)]


In [29]:
def filter_keywords(keywords, n_keep=3):
  new_keywords = []
  for candidate_keyword in keywords:
    is_ok = True
    for compare_keyword in keywords:
      if candidate_keyword == compare_keyword:
        continue
      if compare_keyword in candidate_keyword:
        is_ok = False
        break
    if is_ok:
      new_keywords.append(candidate_keyword)
      if len(new_keywords) >= n_keep:
        break
  return new_keywords

keywords = [t[0] for t in keywords_and_scores]
keywords_filtered = filter_keywords(keywords)
print(keywords_filtered)

['judiciary', 'justices', 'judicial']


In [36]:
def get_cluster_name(df, cluster):
  df_subset = df[df["label"] == cluster].reset_index()
  texts_concat = ". ".join(df_subset["text"].values)
  kw_model = KeyBERT()
  keywords_and_scores = kw_model.extract_keywords(texts_concat, keyphrase_ngram_range=(1, 1), top_n=10)
  keywords = [t[0] for t in keywords_and_scores]
  keywords_filtered = filter_keywords(keywords)

  return " - ".join(keywords_filtered)

# get all the new cluster names
all_clusters = df_no_outliers["label"].unique()
d_cluster_name_mapping = {}

for cluster in all_clusters:
  if cluster == "-1":
    d_cluster_name_mapping[cluster] = "outliers"
  else:
    cluster_name = get_cluster_name(df_no_outliers, cluster)
    print(cluster_name)
    d_cluster_name_mapping[cluster] = cluster_name

mets - inning - braves
judiciary - justices - judicial
kobe - yao - nba
qb - quarterback - cornerback
robson - tottenham - striker
ryder - mcilroy - mickelson
sven - finland - sweden
olympian - olympic - sprinter
auburn - baylor - usc
mascot - pacquiao - busch
cricket - batsman - icc


In [38]:
df_no_outliers["label"] = df_no_outliers["label"].apply(lambda label: d_cluster_name_mapping[label])



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



In [39]:
df_no_outliers

Unnamed: 0,text,x,y,text_short,label
0,Piazza to come off disabled list NEW YORK -- T...,3.601893,5.672773,Piazza to come off disabled list NEW YORK -- T...,mets - inning - braves
1,Slate #39;s Jurisprudence: Court Mulls Wine by...,5.768422,1.359118,Slate #39;s Jurisprudence: Court Mulls Wine by...,judiciary - justices - judicial
2,Last Rites Sounded for Life-Changing Video For...,6.516179,-0.577455,Last Rites Sounded for Life-Changing Video For...,judiciary - justices - judicial
3,"Battleground voters wrestle with US economy, s...",7.480041,3.987694,"Battleground voters wrestle with US economy, s...",judiciary - justices - judicial
4,"Microsoft, Ask Jeeves unveil desktop search to...",5.649577,-0.441817,"Microsoft, Ask Jeeves unveil desktop search to...",judiciary - justices - judicial
...,...,...,...,...,...
2995,Turkey on Chirac-Schroeder agenda President Ja...,8.200398,3.891457,Turkey on Chirac-Schroeder agenda President Ja...,judiciary - justices - judicial
2996,"Third-period comeback sends US out Finland, a ...",4.704815,4.790061,"Third-period comeback sends US out Finland, a ...",sven - finland - sweden
2997,Gerrard stunner seals comeback STEVEN GERRARD ...,4.326807,4.539196,Gerrard stunner seals comeback STEVEN GERRARD ...,robson - tottenham - striker
2998,Turner Doesn't Believe in Moral Victories (AP)...,3.036557,4.551551,Turner Doesn't Believe in Moral Victories (AP)...,qb - quarterback - cornerback


In [40]:
hover_data = {
    "text_short": True,
    "x": False,
    "y": False
}
fig = px.scatter(df_no_outliers, x="x", y="y", template="plotly_dark",
                   title="Embeddings", color="label", hover_data=hover_data)
fig.show()