
场景：
现有一批文本数据，只有少部分有标记。使用对比学习的深度聚类进行训练。对比学习使用基于同义词替换的数据增强方式完成样本构造。
在一个mini-batch中，对于样本x1来说正样本为其增强后的样本x1_,负样本为该mini-batch中其他样本x2及其增强后的样本x2_。
损失函数包括对比损失（InfoNCE），聚类损失（KL散度）和纯度损失（即聚类后属于同一个类别的簇包含不同真实标签数据的多少。
若该簇没有真实标签的数据或有真实标签的数据但都属于同一类别，则该损失为0，否则有越多不同类别的真实标签数据损失越大）。

方法思路
1. 划分数据：
从有标签数据中按各类别20%的比例划分出验证集V1，剩余的标记数据记为L，未标记数据记为N。

2. 初始化阶段：
使用标记数据L训练微调一个bert语言模型。

3. 生成真实标签并更新数据集：
用微调后的语言模型对未标记数据N进行预测，只有预测置信度大于阈值50%的样本才打上标签（该标签在后续步骤中视为真实标签）。记本轮新打上标签的样本集合为m，则更新后的数据集为：L扩充为L=L+m，N缩减为N=N-m。

4. 聚类打类别标签：
用语言模型提取所有数据(L+N)的词嵌入特征，然后用GMM进行聚类，给所有数据打上类别标签。

5. 计算损失：
计算对比损失，聚类损失和纯度损失。

6. 循环执行：
重复步骤3-5，直到所有N都被打上真实标签或达到最大的迭代次数或损失不再显著下降。

7. 最终评估：
在独立的验证集V1上评估模型的最终性能，确保模型的有效性和泛化能力。




In [None]:
# Toy review corpus. The position of each sentence matches the position of its
# entry in the ``labels`` list defined below.
texts = [
        "This is a positive review",
        "I don't like this product",
        "Great service and quality",
        "The worst experience ever",
        "Amazing product, highly recommend",
        "This is just okay",
        "The prouduct is really good",
        "I buy this product again",
        "Good quality!",
        "Not bad, Not bad",
        "It works fine at present",
        "It isn't out of my expectation",
        "It's no hurt to give it a try",
        "I don't think i will buy it again",
        "I didn't use it a lot because i have thrown it away",
        "who will buy this product?",
        # Unlabelled samples
        "Need to try this again",
        "Not sure about this one",
        "Will update my review later",
        "I love it",
        "It saves time"
]
    
# Partial labels (0: negative, 1: neutral, 2: positive); None marks an unlabelled sample
labels = [2, 0, 2, 0, 2, 1, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0, None, None, None, None, None]

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer
from sklearn.mixture import GaussianMixture
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from nltk.corpus import wordnet
import random
from typing import List, Tuple, Optional

In [None]:
# Configuration
class Config:
    """Hyper-parameters and runtime settings shared by the whole pipeline."""

    def __init__(self):
        # Tokenisation length and mini-batch size.
        self.max_length, self.batch_size = 128, 32
        # Number of clusters predicted by the cluster head / fitted by the GMM.
        self.n_clusters = 10
        # Optimisation schedule.
        self.learning_rate, self.max_epochs = 2e-5, 10
        # Contrastive temperature and pseudo-labelling confidence cut-off.
        self.temperature, self.confidence_threshold = 0.07, 0.5
        # DataLoader worker processes.
        self.num_workers = 4
        # Prefer the GPU whenever one is visible to torch.
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')
        # Weights of the contrastive, clustering and purity loss terms.
        self.alpha = self.beta = self.gamma = 1

# Data augmentation
class TextAugmenter:
    """Build augmented views of a sentence by WordNet synonym replacement."""

    def __init__(self, tokenizer):
        # Kept for callers; synonym_replace itself splits on whitespace.
        self.tokenizer = tokenizer

    def synonym_replace(self, text):
        """Return ``text`` with roughly 10% of its words swapped for synonyms.

        At least one position is always attempted. A word with no usable
        WordNet synonym is left untouched, so the result may equal the input.
        Text that contains no words (empty or whitespace only) is returned
        unchanged instead of raising.

        Args:
            text: The sentence to augment.

        Returns:
            The augmented sentence, with words joined by single spaces.
        """
        tokens = text.split()
        if not tokens:
            # random.sample would raise ValueError on an empty population.
            return text

        n = max(1, int(len(tokens) * 0.1))  # replace 10% of the words
        positions = random.sample(range(len(tokens)), n)

        for pos in positions:
            word = tokens[pos]
            # WordNet joins multi-word lemmas with "_"; turn those back into
            # spaces. Drop the word itself so a replacement is a real change.
            # Sorting keeps the choice reproducible for a seeded RNG.
            synonyms = sorted({
                lemma.name().replace('_', ' ')
                for syn in wordnet.synsets(word)
                for lemma in syn.lemmas()
                if lemma.name().lower() != word.lower()
            })
            if synonyms:
                tokens[pos] = random.choice(synonyms)

        return " ".join(tokens)
    
# Custom dataset
class TextDataset(Dataset):
    """Tokenised text samples, optionally paired with an augmented view.

    Args:
        texts: List of raw sentences.
        labels: List of integer labels aligned with ``texts``, or ``None``
            when the whole dataset is unlabelled. Individual entries may
            also be ``None``.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors='pt',
            padding='max_length', truncation=True, max_length=...)``.
        augmenter: Object exposing ``synonym_replace(text)``, or ``None`` to
            disable augmentation.
        config: Object providing ``max_length``.
    """

    # Label returned for samples without a ground-truth label. ``None``
    # cannot be collated into a tensor batch by the default DataLoader.
    UNLABELED = -1

    def __init__(self, texts, labels, tokenizer, augmenter, config):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.augmenter = augmenter
        self.max_length = config.max_length

    def __len__(self):
        return len(self.texts)

    def _encode(self, text):
        """Tokenise one sentence, padded/truncated to ``max_length``."""
        return self.tokenizer(text, return_tensors='pt', padding='max_length',
                              truncation=True, max_length=self.max_length)

    def _label_at(self, idx):
        """Return the label of sample ``idx`` or ``UNLABELED`` if it has none."""
        if self.labels is None:
            return self.UNLABELED
        label = self.labels[idx]
        return self.UNLABELED if label is None else label

    def __getitem__(self, idx):
        """Return ``(encoding, aug_encoding, label)`` when an augmenter is set,
        otherwise ``(encoding, label)``."""
        text = self.texts[idx]
        label = self._label_at(idx)

        # Original sample
        encoding = self._encode(text)
        if self.augmenter:
            # Augmented sample (the positive view for contrastive learning)
            aug_encoding = self._encode(self.augmenter.synonym_replace(text))
            return encoding, aug_encoding, label
        return encoding, label
    
# Model definition
class ContrastiveClusterModel(nn.Module):
    """BERT encoder with a projection head and a linear cluster head.

    Args:
        config: Object providing ``n_clusters``, the width of the cluster head.
    """

    def __init__(self, config):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Read the encoder width from its config instead of hard-coding 768,
        # so the projector still fits if the checkpoint name is changed.
        hidden_size = self.bert.config.hidden_size
        self.projector = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Linear(256, 128)
        )
        self.cluster_head = nn.Linear(128, config.n_clusters)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Encode a batch and return ``(features, logits)``.

        ``token_type_ids`` is accepted because the BERT tokenizer emits it
        alongside ``input_ids`` and ``attention_mask``. Callers in this file
        unpack the whole tokenizer output with ``model(**encoding)``, which
        would otherwise fail with an unexpected-keyword ``TypeError``.

        Returns:
            features: Projected [CLS] representation, 128 values per sample.
            logits: Cluster-head scores, ``n_clusters`` values per sample.
        """
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        embeddings = outputs.last_hidden_state[:, 0]  # [CLS] token
        features = self.projector(embeddings)
        logits = self.cluster_head(features)
        return features, logits
    
# Loss function
class ContrastiveClusterLoss(nn.Module):
    """Weighted sum of a contrastive, a clustering and a purity term.

    ``loss = alpha * InfoNCE + beta * KL(target || prediction) + gamma * purity``

    Args:
        config: Object providing ``temperature``, ``alpha``, ``beta`` and
            ``gamma``.
    """

    def __init__(self, config):
        super().__init__()
        self.temperature = config.temperature
        self.alpha = config.alpha
        self.beta = config.beta
        self.gamma = config.gamma

    def forward(self, features, aug_features, cluster_pred, true_labels, pseudo_labels):
        """Compute the combined loss for one mini-batch.

        Args:
            features: Projected features of the original samples, one row per
                sample.
            aug_features: Projected features of the augmented views, row
                aligned with ``features``.
            cluster_pred: Raw cluster-head logits of the original samples.
            true_labels: Integer label per sample; negative values mark
                samples without a ground-truth label.
            pseudo_labels: Cluster assignments supplied by the caller. Kept
                for interface compatibility; not used by any term.

        Returns:
            Scalar tensor holding the weighted loss.
        """
        nce_loss = self._info_nce(features, aug_features)

        # Clustering loss (KL divergence). The sharpened target must be built
        # from probabilities, not log-probabilities, and is detached so it
        # acts as a fixed target for this step.
        log_probs = F.log_softmax(cluster_pred, dim=1)
        target = self.target_distribution(log_probs.exp()).detach()
        cluster_loss = F.kl_div(log_probs, target, reduction='batchmean')

        # Purity loss
        purity_loss = self.compute_purity_loss(log_probs, true_labels)

        return self.alpha * nce_loss + self.beta * cluster_loss + self.gamma * purity_loss

    def _info_nce(self, features, aug_features):
        """InfoNCE over the 2B views of a batch.

        For every view the positive is the other view of the same sample.
        The negatives are all remaining originals and augmentations in the
        batch. Self-similarity is masked out of the denominator.
        """
        batch = features.size(0)
        views = F.normalize(torch.cat([features, aug_features], dim=0), dim=1)
        sim = torch.mm(views, views.t()) / self.temperature

        # A view must not act as its own negative.
        self_mask = torch.eye(2 * batch, dtype=torch.bool, device=sim.device)
        sim = sim.masked_fill(self_mask, float('-inf'))

        # Row i (original) pairs with row i + B (augmented) and vice versa.
        index = torch.arange(batch, device=sim.device)
        positives = torch.cat([index + batch, index])
        return F.cross_entropy(sim, positives)

    def target_distribution(self, q):
        """Sharpen soft assignments ``q`` (rows are probabilities) into targets.

        Squares each assignment, normalises by the per-cluster mass and
        renormalises every row to sum to one.
        """
        weight = q ** 2 / q.sum(0)
        return (weight.t() / weight.sum(1)).t()

    def compute_purity_loss(self, cluster_pred, true_labels):
        """Average number of surplus true classes per predicted cluster.

        A cluster contributes ``(#distinct true labels among its labelled
        members) - 1``. Clusters with zero or one distinct label contribute
        nothing. Samples whose label is negative are treated as unlabelled
        and ignored. The value is built from hard ``argmax`` assignments, so
        it is a plain float and carries no gradient.

        Returns:
            Float impurity score; ``0.0`` for an empty batch.
        """
        pred_labels = torch.argmax(cluster_pred, dim=1)
        clusters = torch.unique(pred_labels)
        if clusters.numel() == 0:
            return 0.0

        labelled = true_labels >= 0
        impurity = 0
        for c in clusters:
            members = true_labels[labelled & (pred_labels == c)]
            n_classes = torch.unique(members).numel()
            if n_classes > 1:
                impurity += n_classes - 1

        return impurity / clusters.numel()
    

def evaluate(model, val_dataset, config):
    """Print and return the accuracy of ``model`` on ``val_dataset``.

    The prediction for a sample is the ``argmax`` of the model's second
    output (the cluster-head logits). It is compared directly with the
    sample's label.

    NOTE(review): this assumes cluster index ``k`` corresponds to class
    label ``k``; nothing in this function aligns the two. Confirm that the
    training procedure guarantees it.

    Args:
        model: Module returning ``(features, logits)`` when called with the
            unpacked encoding.
        val_dataset: Dataset yielding ``(encoding, label)`` pairs.
        config: Object providing ``batch_size`` and ``device``.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` when the validation set is empty.
    """
    model.eval()
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size)

    correct = 0
    total = 0

    with torch.no_grad():
        for orig_encoding, labels in val_loader:
            # Each per-sample encoding carries its own leading axis of size 1;
            # drop it after batching.
            orig_encoding = {k: v.squeeze(1).to(config.device) for k, v in orig_encoding.items()}
            _, logits = model(**orig_encoding)
            predictions = torch.argmax(logits, dim=1)

            correct += (predictions == labels.to(config.device)).sum().item()
            total += labels.size(0)

    if total == 0:
        # Avoid ZeroDivisionError when there is nothing to evaluate.
        print("Validation set is empty; accuracy is undefined, reporting 0.0")
        return 0.0

    accuracy = correct / total
    print(f"Validation Accuracy: {accuracy:.4f}")

    return accuracy

# Data loading and splitting
def prepare_data(texts: List[str], labels: Optional[List[int]],
                val_ratio: float = 0.2) -> Tuple:
    """Split the corpus into labelled train, unlabelled and validation parts.

    Samples whose label is ``None`` or NaN are treated as unlabelled. The
    labelled samples are split with a stratified ``train_test_split``, so
    every class keeps roughly ``val_ratio`` of its samples for validation.

    Args:
        texts: All sentences.
        labels: Labels aligned with ``texts`` (entries may be ``None``), or
            ``None`` when no sample is labelled.
        val_ratio: Fraction of the labelled samples held out for validation.

    Returns:
        ``(train_texts, train_labels, unlabeled_texts, val_texts, val_labels)``.
        When there is no labelled sample at all, the train and validation
        parts are empty lists and every text is returned as unlabelled.
    """
    if labels is None:
        # Everything is unlabelled
        return [], [], texts, [], []

    # Separate labelled from unlabelled samples
    labeled_texts = []
    labeled_labels = []
    unlabeled_texts = []

    for text, label in zip(texts, labels):
        if label is not None and not pd.isna(label):
            labeled_texts.append(text)
            labeled_labels.append(label)
        else:
            unlabeled_texts.append(text)

    if not labeled_texts:
        # train_test_split rejects empty input; there is nothing to split.
        return [], [], unlabeled_texts, [], []

    # Hold out the validation set
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        labeled_texts, labeled_labels,
        test_size=val_ratio,
        stratify=labeled_labels
    )

    print(f"数据集统计:")
    print(f"训练集大小: {len(train_texts)}")
    print(f"验证集大小: {len(val_texts)}")
    print(f"未标记数据大小: {len(unlabeled_texts)}")

    return train_texts, train_labels, unlabeled_texts, val_texts, val_labels
    
# Base class for clusterers
class BaseClusterer:
    """Interface every clustering backend must implement."""

    def fit_predict(self, features):
        """Fit on ``features`` and return one cluster id per row.

        Subclasses must override this; the base implementation only raises.
        """
        raise NotImplementedError
        
# GMM-based clustering
class GMMClusterer(BaseClusterer):
    """Clusterer backed by scikit-learn's ``GaussianMixture``."""

    def __init__(self, n_clusters):
        self.n_clusters = n_clusters
        self.gmm = GaussianMixture(n_components=n_clusters)

    def fit_predict(self, features):
        """Fit the mixture on ``features`` and return the component of each row.

        Torch tensors are moved to the CPU, detached and converted to numpy
        first; any other input is handed to scikit-learn unchanged.
        """
        data = features.cpu().detach().numpy() if torch.is_tensor(features) else features
        return self.gmm.fit_predict(data)
    
def _encoding_to_device(encoding, device):
    """Drop the per-sample axis of size 1 from each tensor and move it to ``device``."""
    return {k: v.squeeze(1).to(device) for k, v in encoding.items()}


def _extract_features(model, loader, device):
    """Return the model's projected features for every batch of ``loader``.

    Runs without gradients. The result is a list of CPU tensors, one per
    batch, in loader order; it is empty when the loader yields nothing.
    """
    chunks = []
    with torch.no_grad():
        for batch in loader:
            # The first element of every batch is the original-sample encoding.
            features, _ = model(**_encoding_to_device(batch[0], device))
            chunks.append(features.cpu())
    return chunks


def train_model(train_label_dataset, train_unlabel_dataset, config):
    """Train the contrastive clustering model with iterative pseudo-labelling.

    Each epoch:
      1. extracts features for all labelled and unlabelled samples,
      2. clusters them with a GMM,
      3. trains on the labelled samples with the combined loss,
      4. moves unlabelled samples whose prediction confidence exceeds
         ``config.confidence_threshold`` into the labelled dataset.

    Training stops after ``config.max_epochs`` epochs, when the average loss
    has not improved for 5 epochs, or when no unlabelled sample remains.

    Both datasets are modified in place (``texts`` / ``labels`` lists). The
    best weights are written to ``best_model.pth`` in the working directory.

    Args:
        train_label_dataset: Dataset yielding ``(encoding, aug_encoding, label)``.
        train_unlabel_dataset: Dataset whose items start with the encoding.
        config: Object providing ``device``, ``learning_rate``, ``n_clusters``,
            ``max_epochs``, ``batch_size``, ``num_workers`` and
            ``confidence_threshold``, plus whatever the model and loss read.

    Returns:
        The model, restored to the best checkpoint saved during this call
        (or as last trained if no checkpoint was saved).

    Raises:
        ValueError: If the labelled dataset is empty.
    """
    # Model initialisation
    model = ContrastiveClusterModel(config).to(config.device)
    criterion = ContrastiveClusterLoss(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3)

    # Clusterer initialisation
    clusterer = GMMClusterer(n_clusters=config.n_clusters)

    best_loss = float('inf')
    patience = 0
    max_patience = 5
    # Only reload a checkpoint written by this call; a file left by an earlier
    # run may not match this model.
    checkpoint_saved = False

    for epoch in range(config.max_epochs):
        print(f"\nEpoch {epoch+1}/{config.max_epochs}")

        # Feature extraction walks both datasets in dataset order so that
        # row i of the features (and of the cluster labels) is sample i.
        label_eval_loader = DataLoader(train_label_dataset,
                                       batch_size=config.batch_size,
                                       shuffle=False,
                                       num_workers=config.num_workers)
        unlabel_loader = DataLoader(train_unlabel_dataset,
                                    batch_size=config.batch_size,
                                    shuffle=False,
                                    num_workers=config.num_workers)

        # 1. Extract features for all data
        model.eval()
        label_chunks = _extract_features(model, label_eval_loader, config.device)
        unlabel_chunks = _extract_features(model, unlabel_loader, config.device)
        if not label_chunks:
            raise ValueError("train_label_dataset is empty; nothing to train on")

        n_labelled = sum(chunk.size(0) for chunk in label_chunks)
        combined_features = torch.cat(label_chunks + unlabel_chunks, dim=0)

        # 2. GMM clustering; keep the assignments of the labelled samples
        cluster_labels = np.asarray(clusterer.fit_predict(combined_features))
        label_cluster_labels = cluster_labels[:n_labelled]

        # 3. Training phase
        # Shuffle through an explicit permutation so the dataset index of
        # every sample in a batch is known. The cluster labels looked up for
        # a batch then belong to exactly those samples.
        order = torch.randperm(len(train_label_dataset)).tolist()
        train_loader = DataLoader(train_label_dataset,
                                  batch_size=config.batch_size,
                                  sampler=order,
                                  num_workers=config.num_workers)

        model.train()
        total_loss = 0
        batch_count = 0

        for batch_idx, batch in enumerate(train_loader):
            orig_encoding, aug_encoding, labels = batch

            # Move data to the device
            orig_encoding = _encoding_to_device(orig_encoding, config.device)
            aug_encoding = _encoding_to_device(aug_encoding, config.device)
            labels = labels.to(config.device)

            # Cluster labels of the samples in this batch
            start_idx = batch_idx * config.batch_size
            batch_indices = order[start_idx:start_idx + len(labels)]
            batch_cluster_labels = torch.as_tensor(
                label_cluster_labels[batch_indices], device=config.device)

            # Forward pass
            features_orig, logits_orig = model(**orig_encoding)
            features_aug, _ = model(**aug_encoding)

            # Loss: contrastive + clustering + purity
            loss = criterion(
                features=features_orig,
                aug_features=features_aug,
                cluster_pred=logits_orig,
                true_labels=labels,
                pseudo_labels=batch_cluster_labels  # clustering result as pseudo labels
            )

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            batch_count += 1

        avg_loss = total_loss / batch_count if batch_count > 0 else float('inf')
        print(f"Epoch {epoch+1} Average Loss: {avg_loss:.4f}")

        # Learning-rate schedule and early stopping
        scheduler.step(avg_loss)
        if avg_loss < best_loss:
            best_loss = avg_loss
            patience = 0
            torch.save(model.state_dict(), 'best_model.pth')
            checkpoint_saved = True
        else:
            patience += 1
            if patience >= max_patience:
                print("Early stopping triggered")
                break

        # 4. Update the datasets with confident predictions
        model.eval()
        confidences = []
        predictions = []

        with torch.no_grad():
            for batch in unlabel_loader:
                orig_encoding = _encoding_to_device(batch[0], config.device)
                _, logits = model(**orig_encoding)
                probs = F.softmax(logits, dim=1)
                confidence, prediction = torch.max(probs, dim=1)
                confidences.extend(confidence.cpu().numpy())
                predictions.extend(prediction.cpu().numpy())

        confident_mask = np.array(confidences) > config.confidence_threshold
        if np.sum(confident_mask) > 0:
            new_labeled_indices = np.where(confident_mask)[0]
            new_labeled_texts = [train_unlabel_dataset.texts[i] for i in new_labeled_indices]
            # Store plain ints so the label list stays homogeneous.
            new_labeled_labels = [int(predictions[i]) for i in new_labeled_indices]

            train_label_dataset.texts.extend(new_labeled_texts)
            train_label_dataset.labels.extend(new_labeled_labels)

            remaining_indices = np.where(~confident_mask)[0]
            train_unlabel_dataset.texts = [train_unlabel_dataset.texts[i] for i in remaining_indices]

            print(f"Added {len(new_labeled_texts)} new labeled samples")
            print(f"Remaining unlabeled samples: {len(train_unlabel_dataset)}")

        if len(train_unlabel_dataset) == 0:
            print("All unlabeled data has been labeled")
            break

    if checkpoint_saved:
        # map_location keeps loading valid when the checkpoint device differs.
        model.load_state_dict(torch.load('best_model.pth', map_location=config.device))
    return model

def count_unique_labels(lst):
    """Return how many distinct non-missing values ``lst`` contains.

    ``None`` and NaN entries are treated as missing and are not counted.
    """
    return pd.Series(lst).dropna().nunique()

def main():
    """Run the whole pipeline on the module-level ``texts`` / ``labels``.

    Seeds the RNGs, splits the data, builds the datasets, trains the model
    and evaluates it on the held-out validation split.
    """
    # Seed every RNG used by the pipeline
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)

    # Tokenizer and augmenter
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    augmenter = TextAugmenter(tokenizer)

    # Prepare the data
    train_texts, train_labels, unlabeled_texts, val_texts, val_labels = prepare_data(
        texts, labels
    )

    # The configuration must exist, with its overrides applied, before the
    # datasets are built: they copy ``max_length`` at construction time.
    config = Config()
    config.max_length = 512
    config.n_clusters = count_unique_labels(train_labels)

    # Only the labelled training set is augmented (it feeds the contrastive loss).
    train_label_dataset = TextDataset(train_texts, train_labels, tokenizer, augmenter, config)
    train_unlabel_dataset = TextDataset(unlabeled_texts, None, tokenizer, None, config)
    val_dataset = TextDataset(val_texts, val_labels, tokenizer, None, config)

    # Train the model
    model = train_model(train_label_dataset, train_unlabel_dataset, config)

    # Evaluate the model
    evaluate(model, val_dataset, config)

In [None]:
# Notebook entry point: run the full train-and-evaluate pipeline.
main()