# 测试 ptdec：用 DEC（Deep Embedded Clustering）对图片 Embeddings 聚类

如尚未安装 `tensorboard`，请先安装（取消下方单元格的注释后运行）：

In [1]:
# Uncomment to install. `%pip` (unlike `!pip`) installs into the environment of the running kernel.
# %pip install tensorboard

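打开终端，启动 TensorBoard 监控训练过程（注意：本 notebook 中的代码并未向 `runs/` 写入任何日志，需要自行用 `torch.utils.tensorboard.SummaryWriter` 记录指标后，才能在 TensorBoard 中看到曲线）：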

```bash
tensorboard --logdir=runs --port=6006
```

在浏览器打开：http://localhost:6006/

In [2]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

import utils

# NOTE(review): presumably caps the CPU count that joblib's loky backend (used by
# scikit-learn) detects, to avoid its core-count warning — confirm.
os.environ["LOKY_MAX_CPU_COUNT"] = "4"

CSV_PATH = '../data'

# Seed both global RNGs so Restart & Run All reproduces the same numbers.
# NumPy's global RNG is presumably what an unseeded KMeans initialisation falls back to.
# torch's RNG drives nn.Linear weight init (and presumably DataLoader shuffling).
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

## 1. 加载 Embedding 数据

加载第二节中计算得到的图片 Embeddings 及其对应的 labels。

In [3]:
# Load the training-set embeddings and their labels from CSV into a DataFrame.
train_csv_path = os.path.join(CSV_PATH, 'train_embed_label.csv')
train_df = utils.read_embedding_csv(
    csv_path=train_csv_path,
    ebd_cols=['embeddings'],  # column(s) holding the embedding vectors
)

# Show (number of rows, number of distinct labels).
(len(train_df), len(set(train_df['labels'].tolist())))

(10000, 100)

In [4]:
# Preview the first rows; each `embeddings` cell holds one embedding vector.
train_df.head(n=5)

Unnamed: 0,embeddings,labels
0,"[0.013868028298020363, -0.01785886101424694, 0...",19
1,"[0.03667556121945381, -0.08648686856031418, 0....",29
2,"[0.0741165354847908, -0.008068534545600414, 0....",0
3,"[-0.034709382802248, 0.048253390938043594, -0....",11
4,"[-0.06292618066072464, 0.06838615983724594, 0....",1


In [5]:
# Stack the per-row embedding lists into a single matrix.
# The dataset below casts to float32 anyway, so storing float32 here halves memory
# without changing the tensors the model sees.
train_embeds = np.asarray(train_df['embeddings'].tolist(), dtype=np.float32)
train_labels = train_df['labels'].to_numpy()

# Fail fast on malformed input instead of deep inside training.
assert train_embeds.ndim == 2, f"expected a 2-D embedding matrix, got shape {train_embeds.shape}"
assert len(train_embeds) == len(train_labels), "embeddings and labels are misaligned"

## 2. 训练 DEC 模型

In [6]:
import torch
from torch.utils.data import Dataset, DataLoader

class EmbeddingDataset(Dataset):
    """Map-style dataset that pairs each embedding vector with its integer label.

    Each item is an ``(embedding, label)`` tuple of tensors.

    Args:
        embeddings: array-like, expected 2-D ``(n_samples, embed_dim)``; stored as float32.
        labels: array-like of ``n_samples`` integer labels; stored as ``torch.long``.

    Raises:
        ValueError: if ``embeddings`` and ``labels`` have different lengths.
    """

    def __init__(self, embeddings, labels):
        self.embeddings = torch.tensor(embeddings, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.long)
        # __len__ is driven by the embeddings alone. A mismatch would otherwise drop extra
        # labels silently, or raise IndexError in the middle of an epoch.
        if len(self.embeddings) != len(self.labels):
            raise ValueError(
                "embeddings and labels must have the same length, "
                f"got {len(self.embeddings)} and {len(self.labels)}"
            )

    def __len__(self):
        """Number of samples."""
        return len(self.embeddings)

    def __getitem__(self, idx):
        """Return the ``(embedding, label)`` pair at position ``idx``."""
        return self.embeddings[idx], self.labels[idx]

# Wrap the arrays so they can be iterated as (embedding, label) pairs.
dataset = EmbeddingDataset(embeddings=train_embeds, labels=train_labels)

In [8]:
class Encoder(nn.Module):
    """Single-layer projection: ``Linear(input_dim -> hidden_dim)`` followed by ``ReLU``.

    Args:
        input_dim: size of each input embedding vector (last dimension of ``x``).
        hidden_dim: size of the produced feature vector.
    """

    def __init__(self, input_dim=768, hidden_dim=256):
        super().__init__()
        projection = nn.Linear(input_dim, hidden_dim)
        activation = nn.ReLU()
        # Keep the attribute name ``layer``: it fixes the state_dict keys (``layer.0.*``).
        self.layer = nn.Sequential(projection, activation)

    def forward(self, x):
        """Project ``x`` and apply ReLU, returning non-negative features."""
        features = self.layer(x)
        return features

# Take the input width from the data instead of hard-coding 768. Switching the upstream
# embedding model (different vector size) then cannot cause a shape mismatch here.
encoder = Encoder(input_dim=train_embeds.shape[1], hidden_dim=256)

In [9]:
from ptdec.dec import DEC

# Derive both sizes instead of repeating magic numbers:
# - one cluster per distinct ground-truth label (100 for this data set);
# - the clustering layer's width must equal the encoder's output width.
model = DEC(
    cluster_number=len(np.unique(train_labels)),
    hidden_dimension=encoder.layer[0].out_features,
    encoder=encoder,
    alpha=1.0,  # NOTE(review): presumably the Student's t degrees of freedom (DEC default) — confirm
)

In [10]:
from ptdec.model import train
import torch.optim as optim

# The original call hard-coded cuda=False although its comment promised GPU use when
# available. Pick the device once, and move the model there *before* building the optimizer.
# Presumably ptdec's train() moves batches but not the model itself (its examples call
# model.cuda() beforehand) — confirm.
use_cuda = torch.cuda.is_available()
if use_cuda:
    model.cuda()

optimizer = optim.Adam(model.parameters(), lr=0.001)

# NOTE(review): in the captured log the reported acc drops from 0.6632 at epo=0 to
# 0.02-0.04 in later epochs, and the final clustering accuracy is 8.59%. Worth trying
# e.g. optim.SGD(model.parameters(), lr=0.01, momentum=0.9) or a pretrained encoder — TODO.
train(
    dataset=dataset,
    model=model,
    epochs=100,            # maximum number of epochs
    batch_size=256,        # mini-batch size
    optimizer=optimizer,
    stopping_delta=0.001,  # stop early once fewer than 0.1% of labels change between epochs
    cuda=use_cuda,         # use the GPU only when one is available
)

100%|█████████████████████████████| 40/40 [00:00<00:00, 276.23batch/s, acc=0.0000, dlb=-1.0000, epo=-1, lss=0.00000000]
100%|████████████████████████████████| 40/40 [00:00<00:00, 52.55batch/s, acc=0.6632, dlb=0.0000, epo=0, lss=0.05070312]
100%|████████████████████████████████| 40/40 [00:00<00:00, 55.31batch/s, acc=0.0706, dlb=0.9241, epo=1, lss=0.07348070]
100%|████████████████████████████████| 40/40 [00:00<00:00, 57.28batch/s, acc=0.0249, dlb=0.2110, epo=2, lss=0.07093086]
100%|████████████████████████████████| 40/40 [00:00<00:00, 57.08batch/s, acc=0.0303, dlb=0.1010, epo=3, lss=0.06684899]
100%|████████████████████████████████| 40/40 [00:00<00:00, 58.43batch/s, acc=0.0341, dlb=0.5470, epo=4, lss=0.05959591]
100%|████████████████████████████████| 40/40 [00:00<00:00, 55.35batch/s, acc=0.0373, dlb=0.8530, epo=5, lss=0.04909821]
100%|████████████████████████████████| 40/40 [00:00<00:00, 52.52batch/s, acc=0.0433, dlb=0.3845, epo=6, lss=0.08741675]
100%|████████████████████████████████| 4

In [11]:
from ptdec.model import predict
from ptdec.utils import cluster_accuracy

# Predict a cluster id for every sample and collect the ground-truth labels alongside.
# Run on whichever device the model's parameters live on, so this cell agrees with
# the training cell whether or not a GPU was used.
model_on_gpu = next(model.parameters()).is_cuda
predicted_labels, actual_labels = predict(
    dataset=dataset,
    model=model,
    return_actual=True,
    cuda=model_on_gpu,
)

# Cluster ids are arbitrary, so cluster_accuracy maps clusters to labels before scoring
# (the returned reassignment is discarded here). .cpu() is a no-op for CPU tensors and
# makes .numpy() safe if the outputs were left on the GPU.
_, accuracy = cluster_accuracy(
    y_true=actual_labels.cpu().numpy(),
    y_predicted=predicted_labels.cpu().numpy(),
    cluster_number=model.cluster_number,
)

print(f"Clustering Accuracy: {accuracy * 100:.2f}%")

100%|███████████████████████████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 40.00batch/s]

Clustering Accuracy: 8.59%



