# 标签生成



In [1]:
import os
import collections
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import utils

from typing import Dict, Set, Tuple, List
from sklearn.cluster import DBSCAN

CSV_PATH = './data'
SAMPLE_NUM = 7000

os.environ["LOKY_MAX_CPU_COUNT"] = "4"

if utils.in_jupyter():
    # 在 Jupyter 时 tqdm 的导入方式
    from tqdm.notebook import tqdm
else:
    # 在终端时 tqdm 的导入方式
    from tqdm import tqdm

In [2]:
# 将 csv 读入 DataFrame
train_csv_path = os.path.join(CSV_PATH, 'embed_label.csv')
df = utils.read_embedding_csv(csv_path=train_csv_path,
                              ebd_cols=['embeddings'])
len(df), len(set(df['labels'].tolist()))

(10000, 100)

In [3]:
df = df.drop(columns=['labels'])
df.head()

Unnamed: 0,embeddings
0,"[0.024523582309484482, -0.03633105754852295, 0..."
1,"[-0.002521098591387272, 0.022899063304066658, ..."
2,"[0.008400454185903072, -0.012612388469278812, ..."
3,"[-0.004734962247312069, -0.0035224033053964376..."
4,"[-0.021240245550870895, -0.03918471559882164, ..."


In [4]:
df['embeddings'][0].shape

(1408,)

In [5]:
train_df = df.head(SAMPLE_NUM).copy()

eps=0.1
clustering = DBSCAN(eps=eps, min_samples=3, metric='cosine').fit(train_df['embeddings'].tolist())
labels = clustering.labels_

In [6]:
len(labels), max(labels), len([1 for e in labels if e != -1])

(7000, 147, 1993)

In [7]:
labels_counter = collections.Counter(labels)
sorted_labels = sorted(labels_counter.items(), key=lambda e: e[1], reverse=True)
sorted_labels[:5]

[(-1, 5007), (10, 117), (17, 72), (8, 67), (6, 64)]

In [8]:
# 把 labels 作为 dbscan_id 写入 DataFrame 中
train_df['dbscan_id'] = labels
train_df

Unnamed: 0,embeddings,dbscan_id
0,"[0.024523582309484482, -0.03633105754852295, 0...",-1
1,"[-0.002521098591387272, 0.022899063304066658, ...",0
2,"[0.008400454185903072, -0.012612388469278812, ...",-1
3,"[-0.004734962247312069, -0.0035224033053964376...",-1
4,"[-0.021240245550870895, -0.03918471559882164, ...",1
...,...,...
6995,"[0.013036987744271755, 0.004907825030386448, 0...",5
6996,"[0.006075031124055386, 0.06972860544919968, 0....",-1
6997,"[-0.0033710638526827097, 0.03442999720573425, ...",-1
6998,"[-0.02391742914915085, 0.04698537290096283, 0....",100


In [9]:
# 新增一个列 cluster_id，为值为 -1 的类赋予 label

def id_generator(used_id_set: Set[int]):
    """生成未被使用的最小ID"""
    i = 0
    while True:
        while i in used_id_set:
            i += 1
        yield i
        i += 1

dbscan_ids = train_df['dbscan_id'].tolist()
gen = id_generator(set(dbscan_ids))

cluster_id = list()
for e in train_df['dbscan_id'].tolist():
    if e == -1:
        cluster_id.append(next(gen))
    else:
        cluster_id.append(e)

train_df['cluster_id'] = cluster_id
train_df.head(5)

Unnamed: 0,embeddings,dbscan_id,cluster_id
0,"[0.024523582309484482, -0.03633105754852295, 0...",-1,148
1,"[-0.002521098591387272, 0.022899063304066658, ...",0,0
2,"[0.008400454185903072, -0.012612388469278812, ...",-1,149
3,"[-0.004734962247312069, -0.0035224033053964376...",-1,150
4,"[-0.021240245550870895, -0.03918471559882164, ...",1,1


In [10]:
# 为每个 cluster_id 计算聚类中心
cluster_center = list()
for e in train_df['cluster_id'].tolist():
    embeds = train_df[train_df["cluster_id"] == e]["embeddings"]
    cluster_center.append(np.mean(embeds, axis=0).tolist())

train_df['cluster_center'] = cluster_center
train_df.head(5)

Unnamed: 0,embeddings,dbscan_id,cluster_id,cluster_center
0,"[0.024523582309484482, -0.03633105754852295, 0...",-1,148,"[0.024523582309484482, -0.03633105754852295, 0..."
1,"[-0.002521098591387272, 0.022899063304066658, ...",0,0,"[-0.00018301361706107855, 0.022485706851714186..."
2,"[0.008400454185903072, -0.012612388469278812, ...",-1,149,"[0.008400454185903072, -0.012612388469278812, ..."
3,"[-0.004734962247312069, -0.0035224033053964376...",-1,150,"[-0.004734962247312069, -0.0035224033053964376..."
4,"[-0.021240245550870895, -0.03918471559882164, ...",1,1,"[-0.006647970941932206, -0.02852569787230875, ..."


In [11]:
# 新 embedding 加入现有聚类
test_embeddings = df['embeddings'][SAMPLE_NUM:].tolist()
len(test_embeddings)

3000

由于 DBSCAN 基于密度可达性成簇，其边界可能是任意非凸形状。因此 cluster_center 不一定在簇内部。

为了简化新 embeddings 加入现有簇的流程，我们只将在 cluster_center 的 eps 邻域内的 embeddings 加入现有簇，否则建立新簇。这种做法一点也不严谨hhh，但是看你的任务是什么，在我的任务下这种做法还算堪用，好处是计算量小，服务器不会太累。

In [12]:
from sklearn.metrics.pairwise import cosine_distances

# 将列表转换为元组进行去重
def remove_duplicate(lst):
    unique_tuples = set(tuple(sublist) for sublist in lst)
    return [list(t) for t in unique_tuples]

# 声明一个新的 id 生成器
cluster_ids = train_df['cluster_id'].tolist()
gen = id_generator(set(cluster_ids))

# 对于新加入的 embedding，比照现有类心，看这个 embedding 是否在某个类心的 eps 范围内，如果不在就建立新簇
# for embed in test_embeddings:
for embed in tqdm(test_embeddings, total=len(test_embeddings), desc="Processing embeddings"):
    # 现有类心
    centroids = remove_duplicate(train_df['cluster_center'].tolist())
    found_cluster = False
    for centroid in centroids:
        distance = cosine_distances([embed], [centroid])[0][0]
        if distance <= eps:
            found_cluster = True
            my_centroid = centroid
            break
    if not found_cluster:
        my_centroid = embed

    cluster_id = next(gen)
    new_row = {
        'embeddings': embed,
        'dbscan_id': None,
        'cluster_id': cluster_id,
        'cluster_center': my_centroid}
    train_df = pd.concat([train_df, pd.DataFrame([new_row])], ignore_index=True)

Processing embeddings:   0%|          | 0/3000 [00:00<?, ?it/s]

In [15]:
train_df

Unnamed: 0,embeddings,dbscan_id,cluster_id,cluster_center
0,"[0.024523582309484482, -0.03633105754852295, 0...",-1,148,"[0.024523582309484482, -0.03633105754852295, 0..."
1,"[-0.002521098591387272, 0.022899063304066658, ...",0,0,"[-0.00018301361706107855, 0.022485706851714186..."
2,"[0.008400454185903072, -0.012612388469278812, ...",-1,149,"[0.008400454185903072, -0.012612388469278812, ..."
3,"[-0.004734962247312069, -0.0035224033053964376...",-1,150,"[-0.004734962247312069, -0.0035224033053964376..."
4,"[-0.021240245550870895, -0.03918471559882164, ...",1,1,"[-0.006647970941932206, -0.02852569787230875, ..."
...,...,...,...,...
9995,"[0.02526906132698059, 0.006334671750664711, 0....",,8150,"[0.02526906132698059, 0.006334671750664711, 0...."
9996,"[-0.0032427890691906214, 0.0032633657101541758...",,8151,"[-0.0032427890691906214, 0.0032633657101541758..."
9997,"[0.001930834841914475, -0.025012478232383728, ...",,8152,"[0.0057026533932800876, -0.018566691178293376,..."
9998,"[0.025050941854715347, -0.017404677346348763, ...",,8153,"[0.025493682051698368, -0.011526024201884866, ..."
