# 在线学习

我希望在新一批 embeddings 到来时只进行少量训练：既让模型适应新数据，又尽量避免原有的 embedding–label 映射发生偏移。

In [8]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# NOTE(review): DEC.fit also uses KMeans (scikit-learn), tqdm and the acc/nmi/ari
# metric helpers, none of which are imported in this notebook — add them before use.

In [17]:
# Device: use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hyper-parameters shared by the model and training code below.
config = dict(
    dims=[768, 512, 256, 64, 10],  # autoencoder layer widths, input -> latent
    n_clusters=100,                # number of cluster centres
    pretrain_epochs=100,           # autoencoder pretraining epochs
    maxiter=2000,                  # max clustering-refinement iterations
    batch_size=256,
    update_interval=100,           # iterations between target-distribution refreshes
    tol=0.001,                     # stop when this fraction of labels changes or less
    alpha=1.0,                     # Student's t degrees of freedom
    save_dir="./model",            # checkpoint directory
)

# Load a trained DEC model from disk.
def load_dec_model(config, device):
    """Build a DEC model and restore its weights from the saved checkpoint.

    The checkpoint path is ``{config['save_dir']}/best_model.pth``; tensors are
    mapped onto ``device``. The returned model is in eval mode.
    """
    dec = DEC().to(device)
    checkpoint_path = f"{config['save_dir']}/best_model.pth"
    state = torch.load(checkpoint_path, map_location=device, weights_only=True)
    dec.load_state_dict(state)

    dec.eval()
    return dec

# Autoencoder
class Autoencoder(nn.Module):
    """Symmetric fully-connected autoencoder.

    ``dims`` lists layer widths from the input to the latent code. The encoder
    maps ``dims[0] -> dims[-1]`` with a ReLU after every layer except the last.
    The decoder mirrors it back to ``dims[0]``, also with no activation after
    its final layer.

    ``forward`` L2-normalises the latent code before decoding. Calling
    ``self.encoder`` directly returns the *un-normalised* code.
    """

    def __init__(self, dims):
        super().__init__()
        n_layers = len(dims) - 1

        # Encoder
        encoder_layers = []
        for i in range(n_layers):
            encoder_layers.append(nn.Linear(dims[i], dims[i + 1]))
            if i != n_layers - 1:  # no activation on the latent layer
                encoder_layers.append(nn.ReLU())
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder (mirror image of the encoder)
        decoder_layers = []
        for i in reversed(range(n_layers)):
            decoder_layers.append(nn.Linear(dims[i + 1], dims[i]))
            if i != 0:  # no activation on the reconstruction layer
                decoder_layers.append(nn.ReLU())
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Reconstruct ``x`` from its L2-normalised latent code."""
        h = self.encoder(x)
        # Use nn.functional directly: `F` is not imported by this notebook, so the
        # previous `F.normalize` raised NameError.
        h = nn.functional.normalize(h, p=2, dim=1)
        return self.decoder(h)

# Student's t soft-assignment clustering layer
class ClusteringLayer(nn.Module):
    """Soft-assign latent points to learnable cluster centres (DEC kernel).

    q_ij ∝ (1 + ||z_i - mu_j||^2 / alpha) ^ (-(alpha + 1) / 2), normalised so
    each row of q sums to 1.

    Args:
        n_clusters: number of centres.
        alpha: Student's t degrees of freedom.
        feat_dim: width of the latent features. Defaults to the global
            ``config["dims"][-1]`` for backward compatibility.
    """

    def __init__(self, n_clusters, alpha=1.0, feat_dim=None):
        super().__init__()
        if feat_dim is None:
            feat_dim = config["dims"][-1]
        self.n_clusters = n_clusters
        self.alpha = alpha
        self.clusters = nn.Parameter(torch.empty(n_clusters, feat_dim))
        nn.init.xavier_normal_(self.clusters)

    def forward(self, x):
        """Return soft assignments of shape (batch, n_clusters) for 2-D ``x``."""
        x = x.unsqueeze(1)  # (batch, 1, feat_dim)
        clusters = self.clusters.unsqueeze(0)  # (1, n_clusters, feat_dim)
        dist = torch.sum((x - clusters) ** 2, dim=2) / self.alpha  # (batch, n_clusters)

        q = 1.0 / (1.0 + dist)
        q = q ** ((self.alpha + 1.0) / 2.0)
        return q / torch.sum(q, dim=1, keepdim=True)

class DEC(nn.Module):
    """Deep Embedded Clustering model.

    Combines a symmetric autoencoder (pretrained with an MSE reconstruction
    loss) with a Student's t clustering layer that is refined by minimising
    KL(P || Q) against a sharpened target distribution P.

    Architecture and training hyper-parameters are read from the module-level
    ``config``; tensors are moved to the module-level ``device``.
    """

    def __init__(self):
        super().__init__()
        self.autoencoder = Autoencoder(config["dims"]).to(device)
        # Shares its parameters with the autoencoder's encoder.
        self.encoder = self.autoencoder.encoder
        self.clustering = ClusteringLayer(config["n_clusters"], config["alpha"])
        # Parameter-free; kept only so existing attribute access keeps working.
        # The clustering loss uses log(q) directly (see fit).
        self.log_softmax = nn.LogSoftmax(dim=1)

    def target_distribution(self, q):
        """Return DEC's auxiliary target distribution P for soft assignments q.

        p_ij = (q_ij^2 / f_j) / sum_j' (q_ij'^2 / f_j'), where f_j = sum_i q_ij.
        Each row of P sums to 1. The result is detached so it acts as a fixed
        training target.

        Args:
            q: soft assignments, shape (n_samples, n_clusters).
        """
        weight = q ** 2 / torch.sum(q, dim=0)
        # Normalise per sample (over clusters). The previous code summed over
        # samples, so rows of P were not probability distributions.
        return (weight / torch.sum(weight, dim=1, keepdim=True)).detach()

    def pretrain(self, data_loader):
        """Pretrain the autoencoder with an MSE reconstruction loss.

        Args:
            data_loader: yields ``(index, x)`` batches, as built in ``fit``.
        """
        optimizer = optim.Adam(self.parameters())
        criterion = nn.MSELoss()
        self.train()

        for epoch in range(config["pretrain_epochs"]):
            total_loss = 0.0
            for _, x in data_loader:
                x = x.to(device)
                optimizer.zero_grad()
                loss = criterion(self.autoencoder(x), x)
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            print(f"Pretrain Epoch {epoch+1}/{config['pretrain_epochs']}, Loss: {total_loss/len(data_loader):.4f}")

    @torch.no_grad()
    def _encode_all(self, dataset, batch_size=1024):
        """Encode every ``(index, x)`` sample of ``dataset``; return CPU features in order."""
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        return torch.cat([self.encoder(x.to(device)).cpu() for _, x in loader])

    @torch.no_grad()
    def _soft_assign_all(self, dataset, batch_size=1024):
        """Return soft assignments q (on ``device``) for every sample, in dataset order."""
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        return torch.cat([self.clustering(self.encoder(x.to(device))) for _, x in loader], dim=0)

    def fit(self, X, y_true=None):
        """Pretrain, initialise centres with k-means, then refine with DEC.

        Args:
            X: numpy array of inputs, shape (n_samples, input_dim); cast to float32.
            y_true: optional ground-truth labels. When given, ACC/NMI/ARI are
                reported and the best-ACC weights are saved to
                ``{save_dir}/best_model.pth``. Without labels, the final weights
                are saved to the same path.

        Returns:
            numpy array of cluster labels from the last target-distribution update.
        """
        # Dataset yields (index, x) so each batch can look up its rows of P.
        dataset = TensorDataset(torch.arange(len(X)), torch.from_numpy(X.astype(np.float32)))
        pretrain_loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=True)

        self.pretrain(pretrain_loader)

        # Initialise cluster centres with k-means on the latent features.
        features = self._encode_all(dataset).numpy()
        kmeans = KMeans(n_clusters=config["n_clusters"], n_init=20)
        y_pred = kmeans.fit_predict(features)
        # Force float32 so the centres match the model's parameter dtype.
        self.clustering.clusters.data = torch.tensor(
            kmeans.cluster_centers_, dtype=torch.float32, device=device)

        optimizer = optim.SGD(self.parameters(), lr=0.01, momentum=0.9)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.5)
        y_pred_last = y_pred.copy()
        best_acc = 0.0
        os.makedirs(config["save_dir"], exist_ok=True)

        cluster_loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=False)
        with tqdm(total=config["maxiter"], desc="Clustering") as pbar:
            for ite in range(config["maxiter"]):
                # Periodically refresh the target distribution on the full dataset.
                if ite % config["update_interval"] == 0:
                    q = self._soft_assign_all(dataset)
                    p = self.target_distribution(q)

                    y_pred = q.argmax(1).cpu().numpy()
                    if y_true is not None:
                        current_acc = acc(y_true, y_pred)
                        current_nmi = nmi(y_true, y_pred)
                        current_ari = ari(y_true, y_pred)
                        pbar.set_postfix(ACC=current_acc, NMI=current_nmi, ARI=current_ari)

                        if current_acc > best_acc:
                            best_acc = current_acc
                            torch.save(self.state_dict(), f"{config['save_dir']}/best_model.pth")

                    # Converged when few labels changed since the last refresh.
                    delta_label = np.sum(y_pred != y_pred_last) / X.shape[0]
                    if delta_label < config["tol"]:
                        print(f"\nConverged at iteration {ite}")
                        break
                    y_pred_last = y_pred.copy()

                for idx, x in cluster_loader:
                    x = x.to(device)
                    optimizer.zero_grad()

                    q_batch = self.clustering(self.encoder(x))
                    # q_batch already sums to 1 per row: take its log directly.
                    # log_softmax(q) would describe a different distribution.
                    log_q = torch.log(q_batch.clamp_min(1e-10))
                    p_batch = p[idx].to(device)

                    loss = nn.functional.kl_div(log_q, p_batch, reduction='batchmean')
                    loss.backward()
                    optimizer.step()

                scheduler.step()
                pbar.update(1)

        if y_true is None:
            # No labels means no "best" checkpoint was chosen above; persist the
            # final weights so load_dec_model() has something to load.
            torch.save(self.state_dict(), f"{config['save_dir']}/best_model.pth")

        return y_pred

In [23]:
class DynamicClusteringLayer(ClusteringLayer):
    """Clustering layer whose set of centres can grow as new data arrives.

    A new embedding whose Euclidean distance to every existing centre exceeds
    ``expansion_threshold`` spawns a new centre (see ``dynamic_update``).
    """

    def __init__(self, n_clusters, alpha=1.0, expansion_threshold=0.3):
        super().__init__(n_clusters, alpha)
        self.expansion_threshold = expansion_threshold
        # Bool mask over centres: True = existed before the most recent
        # expansion ("history"), False = added by it. None until the first
        # expansion happens.
        self.history_centers_mask = None

    def dynamic_update(self, new_embeddings):
        """Add centres for embeddings that no existing centre covers.

        Candidates are processed greedily in order. A point that falls within
        ``expansion_threshold`` of a centre created earlier in the same call
        does not get its own centre, which avoids many near-duplicate centres.

        Args:
            new_embeddings: tensor (n_points, feat_dim) in the same latent
                space, dtype and device as ``self.clusters``.

        Returns:
            Number of centres added (0 if none).
        """
        with torch.no_grad():
            centers = self.clusters.detach()
            n_old = centers.shape[0]

            # Cheap vectorised pre-filter: points far from every existing centre.
            min_dists = torch.cdist(new_embeddings, centers).min(dim=1).values
            candidates = new_embeddings[min_dists > self.expansion_threshold]

            pool = centers
            for point in candidates:
                point = point.unsqueeze(0)
                if torch.cdist(point, pool).min() > self.expansion_threshold:
                    pool = torch.cat([pool, point], dim=0)

            n_added = pool.shape[0] - n_old
            if n_added == 0:
                return 0

            self.clusters = nn.Parameter(pool.clone())
            self.n_clusters += n_added

            # Previous centres are history (True); the new ones are not. The old
            # code used torch.ones, which marked the new centres as history too.
            mask = torch.zeros(self.n_clusters, dtype=torch.bool)
            mask[:n_old] = True
            self.history_centers_mask = mask
            return n_added

def elastic_loss(current_centers, original_centers, fisher_matrix, lambda_=0.5):
    """EWC-style elastic penalty on cluster centres.

    Returns ``lambda_ * sum(fisher_matrix * (current_centers - original_centers) ** 2)``,
    i.e. a Fisher-weighted quadratic pull back towards the original centres.
    """
    squared_shift = (current_centers - original_centers).pow(2)
    weighted = fisher_matrix * squared_shift
    return weighted.sum() * lambda_

def compute_fisher(model, data_loader, samples=1000):
    """Estimate diagonal Fisher-style importance weights for the cluster centres.

    For each batch, the per-sample maximum soft assignment is summed and
    back-propagated to ``model.clustering.clusters``. The squared
    (batch-summed) gradient is accumulated and finally averaged over the number
    of batches actually processed.

    Args:
        model: exposes ``encoder`` and ``clustering`` (with a ``clusters``
            parameter). It is switched to eval mode.
        data_loader: iterable of ``(x, _)`` batches. NOTE(review): DEC.fit
            builds ``(index, x)`` datasets; a loader built that way would feed
            indices here — confirm the tuple order of the loader you pass.
        samples: stop once at least this many samples have been processed.

    Returns:
        Tensor shaped like ``model.clustering.clusters``; all zeros if no
        batch was processed.
    """
    clusters = model.clustering.clusters
    fisher = torch.zeros_like(clusters)
    target_device = clusters.device

    model.eval()
    model.zero_grad()  # do not fold stale gradients into the first batch
    n_batches = 0
    seen = 0
    for x, _ in data_loader:
        if seen >= samples:
            break
        x = x.to(target_device)
        q = model.clustering(model.encoder(x))
        q.max(dim=1).values.sum().backward()
        fisher += clusters.grad.pow(2)
        model.zero_grad()
        n_batches += 1
        seen += x.shape[0]

    # Average over processed batches, not len(data_loader): the loop may stop early.
    return fisher / n_batches if n_batches else fisher

def online_update(model, new_embeddings, original_centers, fisher_matrix,
                  lr=0.01, max_iters=50, batch_size=256):
    """Adapt the cluster centres to a new batch of raw embeddings.

    The encoder stays frozen. Only ``model.clustering.clusters`` is optimised,
    using KL(P || Q) on the new data plus an EWC penalty (``elastic_loss``)
    that keeps the first ``len(original_centers)`` centres close to their
    snapshot. Gradients of centres that existed before this call are scaled
    by 0.3, so newly created centres move more freely.

    Args:
        model: DEC-like model with ``encoder``, ``clustering`` and
            ``target_distribution``. If ``model.clustering`` has no
            ``dynamic_update`` (e.g. the plain ClusteringLayer built by
            load_dec_model), it is replaced in place by a
            DynamicClusteringLayer holding the same centres.
        new_embeddings: raw encoder inputs (numpy array or tensor), shape
            (n, input_dim). They are encoded before centres are expanded.
        original_centers: snapshot of the centres before online learning.
        fisher_matrix: importance weights, same shape as ``original_centers``.
        lr: SGD learning rate for the centres.
        max_iters: number of passes (epochs) over the new data.
        batch_size: mini-batch size for encoding and training.

    Returns:
        ``model``, updated in place. The encoder's ``requires_grad`` flags are
        restored on exit.
    """
    dev = model.clustering.clusters.device
    # numpy input may be float64; the model is float32.
    x_new = torch.as_tensor(new_embeddings, dtype=torch.float32)
    if x_new.shape[0] == 0:
        return model

    encoder_params = list(model.encoder.parameters())
    saved_flags = [param.requires_grad for param in encoder_params]
    for param in encoder_params:
        param.requires_grad = False

    try:
        # The encoder is frozen, so latent features are fixed: compute them once.
        with torch.no_grad():
            z_new = torch.cat([model.encoder(chunk.to(dev))
                               for chunk in torch.split(x_new, batch_size)])

        # A static clustering layer cannot grow: upgrade it, keeping its centres.
        clustering = model.clustering
        if not hasattr(clustering, "dynamic_update"):
            upgraded = DynamicClusteringLayer(clustering.n_clusters, clustering.alpha)
            upgraded.clusters = nn.Parameter(clustering.clusters.detach().clone())
            clustering = upgraded.to(dev)
            model.clustering = clustering

        # Expand centres in latent space (not in the raw embedding space).
        n_before = clustering.clusters.shape[0]
        clustering.dynamic_update(z_new)
        clusters = clustering.clusters  # may be a new Parameter after expansion

        optimizer = optim.SGD([{'params': [clusters], 'lr': lr}], momentum=0.9)

        # Centres that existed before this call count as "history".
        history = torch.zeros(clusters.shape[0], dtype=torch.bool, device=dev)
        history[:n_before] = True
        n_orig = len(original_centers)
        original_centers = original_centers.to(dev)
        fisher_matrix = fisher_matrix.to(dev)

        loader = DataLoader(TensorDataset(z_new), batch_size=batch_size, shuffle=True)
        for _ in range(max_iters):
            for (z,) in loader:
                q = clustering(z)
                # Target distribution computed from the new data only.
                p = model.target_distribution(q)

                kl_loss = nn.functional.kl_div(
                    torch.log(q.clamp_min(1e-10)), p, reduction='batchmean')
                ewc_loss = elastic_loss(clusters[:n_orig], original_centers, fisher_matrix)
                total_loss = kl_loss + ewc_loss

                optimizer.zero_grad()
                total_loss.backward()
                clusters.grad[history] *= 0.3  # damp updates to historical centres
                optimizer.step()
    finally:
        for param, flag in zip(encoder_params, saved_flags):
            param.requires_grad = flag

    return model

In [24]:
# Load the base model from the saved checkpoint.
base_model = load_dec_model(config=config, device=device)

In [25]:
# Snapshot the initial cluster centres and precompute Fisher information on the legacy data.
# NOTE(review): the recorded run failed with NameError because `legacy_data_loader` is not
# defined anywhere in this notebook. Create it before this cell; compute_fisher unpacks each
# batch as (x, _), so x must come first — presumably batches of the original embeddings, TODO confirm.
original_centers = base_model.clustering.clusters.detach().clone()
fisher_matrix = compute_fisher(base_model, legacy_data_loader)

NameError: name 'legacy_data_loader' is not defined

In [None]:
# Run one round of online learning on the new batch of embeddings.
# NOTE(review): `new_emb` is not defined in the visible notebook; it must hold the new
# embeddings before this cell runs.
updated_model = online_update(
    base_model,
    new_embeddings=new_emb,
    original_centers=original_centers,
    fisher_matrix=fisher_matrix,
)