使用代码实现下面场景的需求。

场景：
现有一批文本数据，只有少部分有标记。使用对比学习的深度聚类进行训练。对比学习使用基于同义词替换的数据增强方式完成样本构造。
在一个mini-batch中，对于样本x1来说正样本为其增强后的样本x1_,负样本为该mini-batch中其他样本x2及其增强后的样本x2_。
损失函数包括对比损失（InfoNCE），聚类损失（KL散度）和纯净度损失（即聚类后属于同一个类别的簇包含不同真实标签数据的多少。
若该簇没有真实标签的数据或有真实标签的数据但都属于同一类别，则该损失为0，否则有越多不同类别的真实标签数据损失越大）。

方法思路
1. 划分数据：
从有标签数据中按各类别20%的比例划分出验证集V1，剩余的标记数据记为L，未标记数据记为N。

2. 初始化阶段：
使用标记数据L训练微调一个bert语言模型。

3. 伪标签生成：
用微调后的语言模型提取所有数据(L+N)的词嵌入特征，然后用k-means进行聚类，给所有数据打上伪标签。

4. 打真实标签：
遍历聚类结果，计算每一个簇的置信度（该簇中数量最多的真实类别的样本数，除以该簇中全部有真实标签的样本数），给置信度大于50%且数量大于3的簇中的未标记数据打上该簇多数类的真实标签。记本轮新打上标签的数据为m，此时L扩充为L=L+m，N缩减为N=N-m。

5. 循环执行：
重复步骤2-4，直到所有N都被打上真实标签或达到最大的迭代次数。

6. 最终评估：
在独立的验证集V1上评估模型的最终性能，确保模型的有效性和泛化能力。

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertModel, BertTokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
import numpy as np
from nltk.corpus import wordnet
import random
from collections import Counter, defaultdict
from sklearn.metrics import silhouette_score

In [None]:
# 数据增强
class TextAugmenter:
    """Build augmented text views by swapping words for WordNet synonyms.

    Args:
        tokenizer: Stored on the instance; ``synonym_replace`` itself splits
            on whitespace and does not use it.
    """

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def synonym_replace(self, text, replace_ratio=0.1):
        """Return ``text`` with a random subset of its words replaced.

        Args:
            text: Whitespace-separated input text.
            replace_ratio: Fraction of words to try to replace. At least one
                word is always attempted when the text is non-empty.

        Returns:
            The augmented text joined with single spaces. Text without any
            words (empty or whitespace only) is returned unchanged.
        """
        tokens = text.split()
        if not tokens:
            # random.sample() raises ValueError on an empty population.
            return text

        n = min(len(tokens), max(1, int(len(tokens) * replace_ratio)))
        positions = random.sample(range(len(tokens)), n)

        for pos in positions:
            word = tokens[pos]
            synonyms = []
            for syn in wordnet.synsets(word):
                for lemma in syn.lemmas():
                    # WordNet joins multi-word lemmas with underscores.
                    candidate = lemma.name().replace("_", " ")
                    # Skip the word itself so a replacement really changes it.
                    if candidate.lower() != word.lower():
                        synonyms.append(candidate)
            if synonyms:
                tokens[pos] = random.choice(synonyms)

        return " ".join(tokens)

# 自定义数据集
class TextDataset(Dataset):
    """Dataset yielding an encoded text, an encoded augmented copy and a label.

    Args:
        texts: Indexable collection of raw strings.
        labels: Indexable collection of labels, parallel to ``texts``.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt',
            padding='max_length', truncation=True, max_length=128)``.
        augmenter: Object exposing ``synonym_replace(text)``.
    """

    def __init__(self, texts, labels, tokenizer, augmenter):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.augmenter = augmenter

    def __len__(self):
        return len(self.texts)

    def _encode(self, sentence):
        """Tokenize one sentence with the fixed padding/truncation settings."""
        return self.tokenizer(
            sentence,
            return_tensors='pt',
            padding='max_length',
            truncation=True,
            max_length=128,
        )

    def __getitem__(self, idx):
        """Return ``(encoding, augmented_encoding, label)`` for item ``idx``."""
        sentence = self.texts[idx]
        target = self.labels[idx]

        original_view = self._encode(sentence)
        # The augmented view is re-sampled on every access.
        augmented_view = self._encode(self.augmenter.synonym_replace(sentence))

        return original_view, augmented_view, target

# 模型定义
class ContrastiveClusterModel(nn.Module):
    """BERT encoder with a projection head and a linear cluster head.

    Args:
        n_clusters: Output size of the cluster head.
        pretrained_name: Checkpoint name passed to ``BertModel.from_pretrained``.
        feature_dim: Size of the projected feature vector.
    """

    def __init__(self, n_clusters, pretrained_name='bert-base-uncased', feature_dim=128):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_name)
        # Read the width from the checkpoint instead of hard-coding 768 so
        # other BERT sizes work too.
        hidden_size = self.bert.config.hidden_size
        self.projector = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, feature_dim)
        )
        self.cluster_head = nn.Linear(feature_dim, n_clusters)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Encode a batch and return ``(features, logits)``.

        ``token_type_ids`` is accepted because the tokenizer output is passed
        in via ``model(**encoding)`` and may contain that key; without the
        parameter the call fails with ``TypeError``.

        Returns:
            features: Projected embedding of the first token of each sequence.
            logits: Cluster head output computed from ``features``.
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        embeddings = outputs.last_hidden_state[:, 0]  # first ([CLS]) position
        features = self.projector(embeddings)
        logits = self.cluster_head(features)
        return features, logits

# 损失函数
class ContrastiveClusterLoss(nn.Module):
    """Sum of an InfoNCE loss, a KL clustering loss and a cluster purity loss.

    Args:
        temperature: Softmax temperature of the contrastive term.
        unlabeled_label: Label value marking samples without a true label;
            these are ignored by the purity term.
    """

    def __init__(self, temperature=0.07, unlabeled_label=-1):
        super().__init__()
        self.temperature = temperature
        self.unlabeled_label = unlabeled_label

    def forward(self, features, aug_features, cluster_pred, true_labels, pseudo_labels):
        """Return the scalar total loss for one mini-batch.

        Args:
            features: ``(batch, dim)`` embeddings of the original samples.
            aug_features: ``(batch, dim)`` embeddings of the augmented samples,
                row-aligned with ``features``.
            cluster_pred: ``(batch, n_clusters)`` raw cluster logits.
            true_labels: ``(batch,)`` labels; ``unlabeled_label`` marks
                samples without a true label.
            pseudo_labels: Unused; kept so existing callers keep working.
        """
        nce_loss = self.info_nce_loss(features, aug_features)

        # Clustering loss: KL(target || prediction). The sharpened target must
        # be built from probabilities (not log-probabilities) and is treated
        # as a constant so gradients only flow through the prediction.
        log_probs = F.log_softmax(cluster_pred, dim=1)
        target = self.target_distribution(log_probs.exp()).detach()
        cluster_loss = F.kl_div(log_probs, target, reduction='batchmean')

        purity_loss = self.compute_purity_loss(cluster_pred, true_labels)

        return nce_loss + cluster_loss + purity_loss

    def info_nce_loss(self, features, aug_features):
        """InfoNCE over the 2N views of a batch.

        For every view the positive is the other view of the same sample; the
        negatives are all remaining originals and augmented samples.
        """
        batch_size = features.size(0)
        views = F.normalize(torch.cat([features, aug_features], dim=0), dim=1)
        similarity = torch.mm(views, views.t()) / self.temperature

        # A view must never be scored against itself.
        self_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=similarity.device)
        similarity = similarity.masked_fill(self_mask, float('-inf'))

        index = torch.arange(batch_size, device=similarity.device)
        positives = torch.cat([index + batch_size, index])
        return F.cross_entropy(similarity, positives)

    def target_distribution(self, q):
        """Sharpen soft assignments ``q`` (rows are probability vectors).

        Computes ``q**2 / column_sum`` and renormalizes every row to sum to 1.
        """
        weight = q ** 2 / q.sum(0).clamp_min(1e-12)
        return weight / weight.sum(1, keepdim=True)

    def compute_purity_loss(self, cluster_pred, true_labels):
        """Penalize predicted clusters that mix several true classes.

        Each predicted cluster contributes ``distinct true classes - 1``;
        clusters with no labeled sample or a single true class contribute 0.
        The sum is divided by the number of predicted clusters in the batch.

        NOTE: the value is derived from hard ``argmax`` assignments, so it is
        a constant with respect to the model parameters (no gradient).
        """
        assignments = torch.argmax(cluster_pred, dim=1)
        clusters = torch.unique(assignments)
        if clusters.numel() == 0:
            return cluster_pred.new_tensor(0.0)

        true_labels = torch.as_tensor(true_labels, device=assignments.device)
        # Unlabeled samples carry no class information and must not count as
        # an extra class.
        labeled = true_labels != self.unlabeled_label

        extra_classes = 0
        for cluster_id in clusters:
            in_cluster = true_labels[labeled & (assignments == cluster_id)]
            n_classes = torch.unique(in_cluster).numel()
            if n_classes > 1:
                extra_classes += n_classes - 1

        return cluster_pred.new_tensor(extra_classes / clusters.numel())

# 训练函数
def train(model, train_loader, optimizer, criterion, device):
    """Run one optimization pass over ``train_loader``.

    Each batch is ``(encoding, augmented_encoding, labels)``. Both encodings
    are dicts of tensors whose dimension 1 is squeezed before they are passed
    to the model as keyword arguments.

    Returns:
        The sum of the per-batch loss values divided by ``len(train_loader)``.
    """

    def _prepare(encoding):
        # Drop dimension 1 and move every tensor to the target device.
        return {name: tensor.squeeze(1).to(device) for name, tensor in encoding.items()}

    model.train()
    running_loss = 0

    for original, augmented, targets in train_loader:
        original = _prepare(original)
        augmented = _prepare(augmented)
        targets = targets.to(device)

        feats, cluster_logits = model(**original)
        aug_feats = model(**augmented)[0]

        batch_loss = criterion(feats, aug_feats, cluster_logits, targets, None)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(train_loader)

In [None]:
class PseudoLabelGenerator:
    """Cluster features with k-means and label confident clusters.

    Args:
        n_clusters: Number of k-means clusters.
        confidence_threshold: A cluster is labeled only when the share of its
            majority true class among its labeled members is strictly greater
            than this value (a tie such as 1:1 has no meaningful majority).
        min_samples: Minimum total number of samples a cluster must contain to
            be labeled (inclusive).
        random_state: Forwarded to ``KMeans`` for reproducible clustering.
    """

    def __init__(self, n_clusters, confidence_threshold=0.5, min_samples=3, random_state=None):
        self.n_clusters = n_clusters
        self.confidence_threshold = confidence_threshold
        self.min_samples = min_samples
        self.kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)

    def extract_features(self, model, dataloader, device):
        """Extract features for every sample, in dataset order.

        The returned rows must line up with dataset indices because callers
        index per-dataset masks with them. If ``dataloader`` is shuffled,
        drops the last batch or uses a custom batch sampler, a sequential
        loader over the same dataset is used instead.

        Returns:
            ``(features, labels, indices)`` where ``indices[i]`` is the
            dataset index of row ``i`` (always ``0..n-1``).
        """
        from torch.utils.data import DataLoader, SequentialSampler

        in_dataset_order = (
            isinstance(dataloader.sampler, SequentialSampler)
            and dataloader.batch_size is not None
            and not dataloader.drop_last
        )
        if not in_dataset_order:
            dataloader = DataLoader(
                dataloader.dataset,
                batch_size=dataloader.batch_size or 32,
                shuffle=False,
                num_workers=dataloader.num_workers,
                collate_fn=dataloader.collate_fn,
            )

        model.eval()
        features_list = []
        labels_list = []

        with torch.no_grad():
            for orig_encoding, _, labels in dataloader:
                orig_encoding = {k: v.squeeze(1).to(device) for k, v in orig_encoding.items()}
                features, _ = model(**orig_encoding)
                features_list.append(features.cpu().numpy())
                labels_list.append(np.asarray(labels))

        if not features_list:
            return np.empty((0, 0)), np.empty((0,), dtype=int), []

        features = np.concatenate(features_list, axis=0)
        labels = np.concatenate(labels_list, axis=0)
        return features, labels, list(range(len(features)))

    def generate_pseudo_labels(self, features, true_labels, labeled_mask):
        """Cluster ``features`` and derive pseudo labels for confident clusters.

        Args:
            features: ``(n_samples, dim)`` array.
            true_labels: ``(n_samples,)`` labels; only entries selected by
                ``labeled_mask`` are read.
            labeled_mask: ``(n_samples,)`` boolean mask of labeled samples.

        Returns:
            ``(pseudo_labels, confident_indices, cluster_labels)``:
            ``pseudo_labels`` holds the majority label for members of
            confident clusters and -1 elsewhere; ``confident_indices`` lists
            the previously unlabeled members of confident clusters.
        """
        true_labels = np.asarray(true_labels)
        labeled_mask = np.asarray(labeled_mask, dtype=bool)

        cluster_labels = self.kmeans.fit_predict(features)

        # silhouette_score is only defined for 2 <= n_labels <= n_samples - 1.
        n_found = len(np.unique(cluster_labels))
        if 1 < n_found < len(cluster_labels):
            silhouette_avg = silhouette_score(features, cluster_labels)
            print(f"Clustering Silhouette Score: {silhouette_avg:.4f}")
        else:
            print("Clustering Silhouette Score: undefined for this clustering")

        cluster_info = self._analyze_clusters(cluster_labels, true_labels, labeled_mask)
        pseudo_labels, confident_indices = self._assign_labels(
            cluster_labels, cluster_info, labeled_mask)

        return pseudo_labels, confident_indices, cluster_labels

    def _assign_labels(self, cluster_labels, cluster_info, labeled_mask):
        """Turn per-cluster statistics into pseudo labels and new-label indices."""
        cluster_labels = np.asarray(cluster_labels)
        labeled_mask = np.asarray(labeled_mask, dtype=bool)

        pseudo_labels = np.full(len(cluster_labels), -1)  # -1 means "no label assigned"
        confident_indices = []

        for cluster_id, info in cluster_info.items():
            if info['confidence'] <= self.confidence_threshold:
                continue
            if info['count'] < self.min_samples:
                continue
            cluster_mask = (cluster_labels == cluster_id)
            pseudo_labels[cluster_mask] = info['majority_label']
            # Only previously unlabeled samples receive a new label.
            confident_indices.extend(np.flatnonzero(cluster_mask & ~labeled_mask).tolist())

        return pseudo_labels, confident_indices

    def _analyze_clusters(self, cluster_labels, true_labels, labeled_mask):
        """Compute majority label, confidence and size for every cluster id.

        Confidence is the majority-class count divided by the number of
        labeled samples in the cluster; clusters without labeled samples get
        confidence 0.0 and majority label -1.
        """
        cluster_labels = np.asarray(cluster_labels)
        true_labels = np.asarray(true_labels)
        labeled_mask = np.asarray(labeled_mask, dtype=bool)
        cluster_info = {}

        for cluster_id in range(self.n_clusters):
            cluster_mask = (cluster_labels == cluster_id)
            labeled_in_cluster = labeled_mask & cluster_mask
            size = int(np.sum(cluster_mask))

            if not labeled_in_cluster.any():
                cluster_info[cluster_id] = {
                    'confidence': 0.0,
                    'majority_label': -1,
                    'count': size
                }
                continue

            # Class distribution among the labeled members of this cluster.
            label_counts = Counter(true_labels[labeled_in_cluster].tolist())
            majority_label, majority_count = label_counts.most_common(1)[0]
            total_labeled = sum(label_counts.values())

            cluster_info[cluster_id] = {
                'confidence': majority_count / total_labeled,
                'majority_label': majority_label,
                'count': size
            }

        return cluster_info

def update_dataset(dataset, pseudo_labels, confident_indices):
    """Write pseudo labels into ``dataset.labels`` in place.

    Only the positions listed in ``confident_indices`` are overwritten, each
    with the entry of ``pseudo_labels`` at the same position. The same
    dataset object is returned.
    """
    for position in confident_indices:
        new_label = pseudo_labels[position]
        dataset.labels[position] = new_label
    return dataset

def train_with_pseudo_labels(model, train_dataset, device, max_epochs=10, pseudo_label_interval=2,
                             n_clusters=10, batch_size=32, lr=2e-5):
    """Train the model and periodically promote confident pseudo labels.

    Args:
        model: Model to optimize; returned after training.
        train_dataset: Dataset exposing a mutable ``labels`` sequence in
            which -1 marks unlabeled samples.
        device: Device the batches are moved to.
        max_epochs: Maximum number of training epochs.
        pseudo_label_interval: Pseudo labels are regenerated every this many
            epochs.
        n_clusters: Number of k-means clusters used for pseudo labeling.
        batch_size: Batch size for training and feature extraction.
        lr: Adam learning rate.

    Raises:
        ValueError: If ``pseudo_label_interval`` is smaller than 1.
    """
    if pseudo_label_interval < 1:
        raise ValueError("pseudo_label_interval must be >= 1")

    pseudo_label_generator = PseudoLabelGenerator(n_clusters=n_clusters)
    criterion = ContrastiveClusterLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Feature extraction must follow dataset order: the rows are matched
    # against labeled_mask and used to index train_dataset.labels, so the
    # shuffled training loader cannot be reused here.
    feature_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

    # Boolean mask of samples that currently carry a label.
    labeled_mask = np.array([label != -1 for label in train_dataset.labels], dtype=bool)

    for epoch in range(max_epochs):
        train_loss = train(model, train_loader, optimizer, criterion, device)
        print(f"Epoch {epoch+1}/{max_epochs}, Loss: {train_loss:.4f}")

        if (epoch + 1) % pseudo_label_interval != 0:
            continue

        features, true_labels, indices = pseudo_label_generator.extract_features(
            model, feature_loader, device)

        pseudo_labels, confident_indices, cluster_labels = pseudo_label_generator.generate_pseudo_labels(
            features, true_labels, labeled_mask)

        train_dataset = update_dataset(train_dataset, pseudo_labels, confident_indices)
        labeled_mask[confident_indices] = True

        print(f"Generated {len(confident_indices)} new pseudo labels")
        print(f"Total labeled samples: {np.sum(labeled_mask)}/{len(labeled_mask)}")

        # Stop early once nothing is left to label.
        if np.all(labeled_mask):
            print("All samples have been labeled. Stopping training.")
            break

    return model

def evaluate(model, val_dataset, device):
    """Compute clustering accuracy of the model on ``val_dataset``.

    The cluster head is not trained against class ids, so its output indices
    are an arbitrary permutation of the classes. Predictions are therefore
    matched to the true labels with the one-to-one assignment that maximizes
    agreement (Hungarian algorithm) before counting correct samples.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` for an empty dataset.
    """
    from scipy.optimize import linear_sum_assignment

    model.eval()
    val_loader = DataLoader(val_dataset, batch_size=32)

    predicted = []
    expected = []

    with torch.no_grad():
        for orig_encoding, _, labels in val_loader:
            orig_encoding = {k: v.squeeze(1).to(device) for k, v in orig_encoding.items()}
            _, logits = model(**orig_encoding)
            predicted.append(torch.argmax(logits, dim=1).cpu().numpy())
            expected.append(np.asarray(labels))

    if not predicted:
        accuracy = 0.0
        print(f"Validation Accuracy: {accuracy:.4f}")
        return accuracy

    predicted = np.concatenate(predicted)
    expected = np.concatenate(expected)

    # Contingency table: rows are predicted clusters, columns true classes.
    _, cluster_idx = np.unique(predicted, return_inverse=True)
    _, class_idx = np.unique(expected, return_inverse=True)
    cluster_idx = cluster_idx.ravel()
    class_idx = class_idx.ravel()
    contingency = np.zeros((cluster_idx.max() + 1, class_idx.max() + 1), dtype=np.int64)
    np.add.at(contingency, (cluster_idx, class_idx), 1)

    rows, cols = linear_sum_assignment(contingency, maximize=True)
    accuracy = float(contingency[rows, cols].sum() / len(expected))
    print(f"Validation Accuracy: {accuracy:.4f}")

    return accuracy

In [None]:
# 主训练循环
def main(texts=None, labels=None):
    """Split the data, train with pseudo labeling and evaluate.

    Args:
        texts: Sequence of raw strings. When omitted, a module-level
            ``texts`` variable is used if one exists (notebook usage).
        labels: Sequence of labels parallel to ``texts``; -1 marks unlabeled
            samples. When omitted, a module-level ``labels`` is used.

    Raises:
        ValueError: If no data is available, the lengths differ, or there is
            no labeled sample to build a validation set from.
    """
    # Fall back to notebook-level variables so a bare main() keeps working
    # when the data was defined in an earlier cell.
    if texts is None:
        texts = globals().get('texts')
    if labels is None:
        labels = globals().get('labels')
    if texts is None or labels is None:
        raise ValueError("main() needs `texts` and `labels` (pass them in or define them globally)")

    texts = list(texts)
    labels = list(labels)
    if len(texts) != len(labels):
        raise ValueError("`texts` and `labels` must have the same length")

    labeled_idx = [i for i, label in enumerate(labels) if label != -1]
    unlabeled_idx = [i for i, label in enumerate(labels) if label == -1]
    if not labeled_idx:
        raise ValueError("at least some labeled samples are required")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    augmenter = TextAugmenter(tokenizer)

    # Hold out 20% of each class of the labeled data for validation; the
    # unlabeled samples all stay in the training pool.
    train_idx, val_idx = train_test_split(
        labeled_idx, test_size=0.2, stratify=[labels[i] for i in labeled_idx])
    train_idx = list(train_idx) + unlabeled_idx

    train_texts = [texts[i] for i in train_idx]
    train_labels = [labels[i] for i in train_idx]  # list: updated in place later
    val_texts = [texts[i] for i in val_idx]
    val_labels = [labels[i] for i in val_idx]

    train_dataset = TextDataset(train_texts, train_labels, tokenizer, augmenter)
    val_dataset = TextDataset(val_texts, val_labels, tokenizer, augmenter)

    model = ContrastiveClusterModel(n_clusters=10).to(device)

    model = train_with_pseudo_labels(
        model=model,
        train_dataset=train_dataset,
        device=device,
        max_epochs=10,
        pseudo_label_interval=2
    )

    evaluate(model, val_dataset, device)

In [None]:
# Run the pipeline only when executed as a script or notebook cell, not when
# this module is imported.
if __name__ == "__main__":
    main()