In [6]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer, models


# 下載20 newsgroups測試集(shuffle() 資料是否隨機打亂)
#  Bunch格式，類似字典
newsgroups_test = fetch_20newsgroups(subset='test', shuffle=True, random_state=30241)


df = pd.DataFrame({
    'text': newsgroups_test.data,
    'target': newsgroups_test.target
})

# 用 sample 函數，隨機抽取 50 筆資料
sample_df = df.groupby('target', group_keys=False).apply(lambda x: x.sample(50, random_state=30241))
sample_df


# List[]
# 每一個類別隨機抽取50筆資料
# sampled_data = []
# sampled_targets = []
# unique_labels = set(newsgroups_test.target)

# for label in unique_labels:
#     indices = [i for i, target in enumerate(newsgroups_test.target) if target == label]
#     sampled_indices = random.sample(indices, 50)
#     sampled_data.extend([newsgroups_test.data[i] for i in sampled_indices])
#     sampled_targets.extend([newsgroups_test.target[i] for i in sampled_indices])

# print(len(newsgroups_test.data))
# print(len(sampled_data))
# newsgroups_test




Unnamed: 0,text,target
392,From: marshall@csugrad.cs.vt.edu (Kevin Marsha...,0
6693,From: mangoe@cs.umd.edu (Charley Wingate)\nSub...,0
3229,From: mathew <mathew@mantis.co.uk>\nSubject: D...,0
880,From: frank@D012S658.uucp (Frank O'Dwyer)\nSub...,0
622,From: pww@spacsun.rice.edu (Peter Walker)\nSub...,0
...,...,...
6346,From: mls@panix.com (Michael Siemon)\nSubject:...,19
2793,From: pharvey@quack.kfu.com (Paul Harvey)\nSub...,19
57,From: syshtg@gsusgi2.gsu.edu (Tom Gillman)\nSu...,19
4825,From: keith@cco.caltech.edu (Keith Allan Schne...,19


In [7]:
# 使用 [cls] 建立embedding 

# 定義 BERT 模型和 CLS pooling
model = models.Transformer('bert-base-uncased')     # 預訓練的 bert 模型
cls_pooling = models.Pooling(model.get_word_embedding_dimension(),pooling_mode_cls_token=True)  # 啟用 CLS token pooling

# 加載 SentenceTransformer 並使用自定義配置(使用 BERT 模型結合 CLS token 的嵌入方式來提取文本特徵)
cls_model = SentenceTransformer(modules=[model, cls_pooling])

# 提取文本嵌入(轉換成 embedding)
# 預設為 numpy array，轉成 tensor 有助於後續操作
embeddings = cls_model.encode(sample_df['text'].tolist(), convert_to_tensor=True)

# 將嵌入添加到 DataFrame 中
sample_df['cls_embedding'] = embeddings.tolist()
print(sample_df.head())


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


                                                   text  target  \
392   From: marshall@csugrad.cs.vt.edu (Kevin Marsha...       0   
6693  From: mangoe@cs.umd.edu (Charley Wingate)\nSub...       0   
3229  From: mathew <mathew@mantis.co.uk>\nSubject: D...       0   
880   From: frank@D012S658.uucp (Frank O'Dwyer)\nSub...       0   
622   From: pww@spacsun.rice.edu (Peter Walker)\nSub...       0   

                                          cls_embedding  
392   [0.22063714265823364, 0.04412689805030823, -0....  
6693  [-0.1905311942100525, -0.06609037518501282, -0...  
3229  [-0.12532277405261993, -0.17267800867557526, 0...  
880   [-0.24792765080928802, -0.09719590097665787, 0...  
622   [0.16099587082862854, 0.13795432448387146, 0.0...  


In [8]:
sample_df['text'].tolist()

['From: marshall@csugrad.cs.vt.edu (Kevin Marshall)\nSubject: Re: Death Penalty (was Re: Political Atheists?)\nOrganization: Virginia Tech Computer Science Dept, Blacksburg, VA\nLines: 46\nNNTP-Posting-Host: csugrad.cs.vt.edu\n\nbil@okcforum.osrhe.edu (Bill Conner) writes:\n\n>This is fascinating. Atheists argue for abortion, defend homosexuality\n>as a means of population control, insist that the only values are\n>biological and condemn war and capital punishment. According to\n>Benedikt, if something is contardictory, it cannot exist, which in\n>this case means atheists I suppose.\n>I would like to understand how an atheist can object to war (an\n>excellent means of controlling population growth), or to capital\n>punishment, I\'m sorry but the logic escapes me.\n\nFirst, you seem to assume all atheists think alike.  An atheist does not\nbelieve in the existence of a god.  Our opinions on issues such as \ncapital punishment and abortion, however, vary greatly.  \n\nIf you were attacki

In [10]:
# 使用 mean pooling 建立embedding
pooling_model = SentenceTransformer('bert-base-uncased')

# 提取文本做嵌入
# SentenceTransformer 預設使用的是 mean pooling
pooling = pooling_model.encode(sample_df['text'].tolist(), convert_to_tensor=True)

# 將嵌入添加到 DataFrame 中
sample_df['mean_pooling'] = pooling.tolist()
print(sample_df.head())

No sentence-transformers model found with name /home/leo85741/.cache/torch/sentence_transformers/bert-base-uncased. Creating a new one with MEAN pooling.
Some weights of the model checkpoint at /home/leo85741/.cache/torch/sentence_transformers/bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification mo

                                                   text  target  \
392   From: marshall@csugrad.cs.vt.edu (Kevin Marsha...       0   
6693  From: mangoe@cs.umd.edu (Charley Wingate)\nSub...       0   
3229  From: mathew <mathew@mantis.co.uk>\nSubject: D...       0   
880   From: frank@D012S658.uucp (Frank O'Dwyer)\nSub...       0   
622   From: pww@spacsun.rice.edu (Peter Walker)\nSub...       0   

                                          cls_embedding  \
392   [0.22063714265823364, 0.04412689805030823, -0....   
6693  [-0.1905311942100525, -0.06609037518501282, -0...   
3229  [-0.12532277405261993, -0.17267800867557526, 0...   
880   [-0.24792765080928802, -0.09719590097665787, 0...   
622   [0.16099587082862854, 0.13795432448387146, 0.0...   

                                           mean_pooling  
392   [-0.07167752832174301, 0.19374574720859528, 0....  
6693  [-0.017095178365707397, 0.11830613017082214, 0...  
3229  [0.061830248683691025, 0.18986721336841583, 0....  
880   [-0.

In [11]:
sample_df

Unnamed: 0,text,target,cls_embedding,mean_pooling
392,From: marshall@csugrad.cs.vt.edu (Kevin Marsha...,0,"[0.22063714265823364, 0.04412689805030823, -0....","[-0.07167752832174301, 0.19374574720859528, 0...."
6693,From: mangoe@cs.umd.edu (Charley Wingate)\nSub...,0,"[-0.1905311942100525, -0.06609037518501282, -0...","[-0.017095178365707397, 0.11830613017082214, 0..."
3229,From: mathew <mathew@mantis.co.uk>\nSubject: D...,0,"[-0.12532277405261993, -0.17267800867557526, 0...","[0.061830248683691025, 0.18986721336841583, 0...."
880,From: frank@D012S658.uucp (Frank O'Dwyer)\nSub...,0,"[-0.24792765080928802, -0.09719590097665787, 0...","[-0.2074468582868576, 0.03253013640642166, 0.6..."
622,From: pww@spacsun.rice.edu (Peter Walker)\nSub...,0,"[0.16099587082862854, 0.13795432448387146, 0.0...","[0.23597092926502228, 0.27605223655700684, 0.2..."
...,...,...,...,...
6346,From: mls@panix.com (Michael Siemon)\nSubject:...,19,"[-0.17418424785137177, 0.2910330593585968, 0.2...","[0.003673875704407692, 0.32565951347351074, 0...."
2793,From: pharvey@quack.kfu.com (Paul Harvey)\nSub...,19,"[-0.33445265889167786, 0.4663501977920532, -0....","[-0.030506253242492676, 0.29554998874664307, 0..."
57,From: syshtg@gsusgi2.gsu.edu (Tom Gillman)\nSu...,19,"[-0.1854058802127838, -0.1649809181690216, 0.0...","[-0.08183643221855164, 0.09059812128543854, 0...."
4825,From: keith@cco.caltech.edu (Keith Allan Schne...,19,"[-0.17625266313552856, -0.025276828557252884, ...","[-0.13266976177692413, -0.08071672916412354, 0..."


In [20]:
# sample_df['cls_embedding', 'mean_pooling'] 中存儲的是嵌入向量
# 將嵌入向量轉換為 numpy array
X = np.array(sample_df['cls_embedding'].tolist())
X_mean = np.array(sample_df['mean_pooling'].tolist())

# 初始化 KMeans 模型，設定群集數量為 20
kmeans = KMeans(n_clusters=20, random_state=30241)
kmeans_mean = KMeans(n_clusters=20, random_state=30241)

# 訓練模型並預測每個文檔的 clustering 標籤
sample_df['cls_cluster'] = kmeans.fit_predict(X)
sample_df['mean_cluster'] = kmeans_mean.fit_predict(X_mean)
sample_df


Unnamed: 0,text,target,cls_embedding,mean_pooling,cls_cluster,mean_cluster
392,From: marshall@csugrad.cs.vt.edu (Kevin Marsha...,0,"[0.22063714265823364, 0.04412689805030823, -0....","[-0.07167752832174301, 0.19374574720859528, 0....",14,17
6693,From: mangoe@cs.umd.edu (Charley Wingate)\nSub...,0,"[-0.1905311942100525, -0.06609037518501282, -0...","[-0.017095178365707397, 0.11830613017082214, 0...",4,7
3229,From: mathew <mathew@mantis.co.uk>\nSubject: D...,0,"[-0.12532277405261993, -0.17267800867557526, 0...","[0.061830248683691025, 0.18986721336841583, 0....",3,5
880,From: frank@D012S658.uucp (Frank O'Dwyer)\nSub...,0,"[-0.24792765080928802, -0.09719590097665787, 0...","[-0.2074468582868576, 0.03253013640642166, 0.6...",12,12
622,From: pww@spacsun.rice.edu (Peter Walker)\nSub...,0,"[0.16099587082862854, 0.13795432448387146, 0.0...","[0.23597092926502228, 0.27605223655700684, 0.2...",4,7
...,...,...,...,...,...,...
6346,From: mls@panix.com (Michael Siemon)\nSubject:...,19,"[-0.17418424785137177, 0.2910330593585968, 0.2...","[0.003673875704407692, 0.32565951347351074, 0....",8,7
2793,From: pharvey@quack.kfu.com (Paul Harvey)\nSub...,19,"[-0.33445265889167786, 0.4663501977920532, -0....","[-0.030506253242492676, 0.29554998874664307, 0...",8,7
57,From: syshtg@gsusgi2.gsu.edu (Tom Gillman)\nSu...,19,"[-0.1854058802127838, -0.1649809181690216, 0.0...","[-0.08183643221855164, 0.09059812128543854, 0....",4,10
4825,From: keith@cco.caltech.edu (Keith Allan Schne...,19,"[-0.17625266313552856, -0.025276828557252884, ...","[-0.13266976177692413, -0.08071672916412354, 0...",3,1


In [21]:
from sklearn.metrics import adjusted_mutual_info_score

# 計算 CLS 向量聚類的調整互信息，用來衡量兩個分類結果之間的一致性
ami_cls = adjusted_mutual_info_score(sample_df['target'], sample_df['cls_cluster'])

# 計算 mean pooling 向量聚類的調整互信息
ami_mean = adjusted_mutual_info_score(sample_df['target'], sample_df['mean_cluster'])

print(f"Adjusted Mutual Information (CLS): {ami_cls}")
print(f"Adjusted Mutual Information (Mean Pooling): {ami_mean}")

# MI 值的範圍是從 0 到 1，1 表示兩個分類結果完全一致，0 表示兩者之間無關聯


Adjusted Mutual Information (CLS): 0.19065666128008857
Adjusted Mutual Information (Mean Pooling): 0.3118966637278098
