In [1]:
import torch

# Toy batch: two labelled samples (classes 1 and 2) followed by two unlabelled ones (-1).
labels = torch.tensor([1, 2, -1, -1])
# Random stand-in for the pairwise similarity matrix.
sim = torch.rand(4, 4)

# Relation matrix R, initialised to -1 everywhere (= relation unknown).
global_R = torch.full((len(labels),) * 2, -1.0)
# Per-sample flag: True where a ground-truth label is available.
mask_label = labels.ne(-1)
# Pair (i, j) is "labelled" only when both samples carry a label.
label_mask = mask_label[None, :] & mask_label[:, None]
# 1.0 where the two labels agree, 0.0 otherwise.
label_R = labels[None, :].eq(labels[:, None]).float()
# Copy the known relations into R and keep -1 for every other pair.
global_R = torch.where(label_mask, label_R, global_R)

print(mask_label.unsqueeze(0), '\n', mask_label.unsqueeze(1))
print(global_R)

tensor([[ True,  True, False, False]]) 
 tensor([[ True],
        [ True],
        [False],
        [False]])
tensor([[ 1.,  0., -1., -1.],
        [ 0.,  1., -1., -1.],
        [-1., -1., -1., -1.],
        [-1., -1., -1., -1.]])


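从上面的输出结果可以看到，`.unsqueeze(dim)` 会在参数 `dim` 指定的位置插入一个长度为 1 的新维度：`unsqueeze(0)` 把长度为 4 的向量变成形状为 (1, 4) 的行向量，`unsqueeze(1)` 则变成形状为 (4, 1) 的列向量；两者做 `&` 运算时通过广播得到 (4, 4) 的样本对掩码。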

In [2]:
# Pairs whose relation is not fixed by ground-truth labels.
other_mask = label_mask.logical_not()
# Lower / upper similarity thresholds.
l, u = 0.5, 0.9
# Confident positives: similarity at or above the upper threshold.
global_R[other_mask & sim.ge(u)] = 1.0
# Confident negatives: similarity at or below the lower threshold.
global_R[other_mask & sim.le(l)] = 0.0
# Pairs with l < sim < u keep the value -1 (still uncertain).
print(sim)
print(global_R)

tensor([[0.4763, 0.6616, 0.2477, 0.4547],
        [0.0404, 0.4367, 0.2876, 0.0377],
        [0.7667, 0.4985, 0.9509, 0.6440],
        [0.2946, 0.5130, 0.3965, 0.5472]])
tensor([[ 1.,  0.,  0.,  0.],
        [ 0.,  1.,  0.,  0.],
        [-1.,  0.,  1., -1.],
        [ 0., -1.,  0., -1.]])


In [4]:
# Entries of R that are still undecided.
uncert_mask = global_R.eq(-1)
# Keep only the strictly lower triangle so each unordered pair (i, j), i > j,
# is listed once and the diagonal is skipped.
mask = uncert_mask.tril(diagonal=-1)
row, col = mask.nonzero(as_tuple=True)
# List of [row, col] index pairs.
uncert_ij = torch.stack((row, col), dim=1).tolist()
print(uncert_ij)

[[2, 0], [3, 1]]


In [5]:
# Compute the loss
pos_mask = global_R.eq(1)
neg_mask = global_R.eq(0)
# Lower clamp bound keeps log() away from 0.
eps = 1e-10
# -log(sim) on positive pairs, 0 elsewhere.
pos_entropy = -torch.log(sim.clamp(eps, 1.0)) * pos_mask
# -log(1 - sim) on negative pairs, 0 elsewhere.
neg_entropy = -torch.log((1 - sim).clamp(eps, 1.0)) * neg_mask
# Means are taken over all entries (masked-out ones count as 0); the
# threshold gap (u - l) is added as an extra term.
loss = pos_entropy.mean() + neg_entropy.mean() + u - l
print(loss)

tensor(0.7474)


## 第1次train梳理

- 各个 batch 得到的 R 矩阵没有办法直接拼接成全局的 R（batch 内的 R 只描述同一 batch 内样本对的关系，跨 batch 的样本对不在其中），所以挑选不确定性样本对必须在全局范围内进行。
- 也就是说，得先得到全局的 feats，然后在全局上计算相似度；主要的计算开销就在这一步。

In [None]:
from tqdm import trange, tqdm
def train(self, args):
    """First-draft sketch of the training loop (pseudo-code, not runnable as-is).

    Per epoch, a similarity loss is accumulated over every batch of
    ``train_semi_dataloader``. On every epoch where ``epoch % 5 == 0`` the
    collected text pairs are sent to the LLM, R is updated, the triplet
    dataset is rebuilt, and a triplet loss is accumulated over it. A single
    backward pass and optimizer step is taken at the end of each epoch.

    Args:
        args: Run configuration; only ``args.train_epochs`` is read here.

    NOTE(review): ``model``, ``optimizer``, ``train_semi_dataloader``, ``u``,
    ``l``, ``pos_entropy``, ``neg_entropy``, ``anchor``, ``pos``, ``neg`` and
    ``DataLoader`` are not defined inside this function; they must be
    provided by the surrounding scope.
    """
    for epoch in trange(args.train_epochs, desc="Training"):

        model.train()

        # Per-epoch accumulators. Resetting them here keeps `loss` well defined
        # on epochs where the triplet stage is skipped, and avoids reusing a
        # stale loss tensor from an earlier epoch.
        sim_loss = 0.0
        tri_loss = 0.0

        for batch in tqdm(train_semi_dataloader):

            with torch.set_grad_enabled(True):
                # input_ids, attention_mask, label
                feats_batch = model()

                sim_batch = self.get_sim_score(feats_batch)
                R_batch = self.get_R(sim_batch, u, l)

                # Similarity loss: accumulate over batches instead of keeping
                # only the last batch's value.
                sim_loss = sim_loss + (pos_entropy + neg_entropy)

                indices_pairs_batch = self.get_uncert_pairs(R_batch)
                text_pairs_batch = self.get_text_pairs(indices_pairs_batch)

                # NOTE(review): `global_R`, `indices_pairs` and `text_pairs` are
                # assigned in this function but never initialised before this
                # point, so the first torch.cat raises UnboundLocalError.
                # R_batch is presumably a within-batch matrix, so concatenating
                # it would not give a global R either — TODO decide how the
                # global structures should be built.
                global_R = torch.cat((global_R, R_batch))
                indices_pairs = torch.cat((indices_pairs, indices_pairs_batch))
                text_pairs = torch.cat((text_pairs, text_pairs_batch))

        if epoch % 5 == 0:
            # Call the LLM: dict{pair_index, llm_pred, conf}
            llm_generated_outputs = self.llm_labeling(text_pairs)
            # Update R
            new_R = self.update_R(llm_generated_outputs, global_R)
            # Refresh the dataset (on the first round this builds it).
            # triplet_dataset: (anchor, pos, neg)
            # Open question: if every negative is selected this way, will
            # training be too hard at the beginning?
            # FaceNet discusses this issue; LANID simply samples negatives
            # at random.
            triplet_dataset = self.update_data(llm_generated_outputs)
            triplet_dataloader = DataLoader(triplet_dataset)

            for batch in triplet_dataloader:
                seq_emb = model(batch)

                # Accumulate per batch. The draft's `tri_loss += tri_loss`
                # after the loop only doubled the last batch's loss.
                tri_loss = tri_loss + self.tri_loss(anchor, pos, neg)

        loss = sim_loss + tri_loss
        # Backpropagation
        loss.backward()
        optimizer.step()
        # `zero_step` is not an Optimizer method; gradients are cleared with
        # `zero_grad`.
        optimizer.zero_grad()

## 第2次train梳理

In [None]:
def train(self, args, ):
    """Second-draft sketch of the training loop (pseudo-code, not runnable as-is).

    Per epoch:
      1. accumulate the similarity loss over ``train_semi_dataloader``;
      2. on every 5th epoch (``(epoch + 1) % 5 == 0``) call the LLM, update
         R and rebuild the triplet dataset and its loader;
      3. accumulate the triplet loss over the current triplet loader, if one
         has been built yet;
      4. backpropagate the summed loss and take one optimizer step.

    Args:
        args: Run configuration; only ``args.num_train_epochs`` is read here.

    NOTE(review): ``model``, ``optimizer``, ``train_semi_dataloader``, ``u``,
    ``l``, ``pos_entropy``, ``neg_entropy`` and ``text_pairs`` are not defined
    inside this function; they must be provided by the surrounding scope.
    """
    from torch.utils.data import DataLoader

    # No triplet data exists until the first LLM labelling round.
    tri_dataloader = None

    for epoch in trange(args.num_train_epochs, desc="Training"):
        # Per-epoch accumulators, so `loss` below is always defined.
        sim_loss = 0.0
        tri_loss = 0.0

        # 1. Similarity loss
        for batch in tqdm(train_semi_dataloader):

            with torch.set_grad_enabled(True):
                # input_ids, attention_mask, label
                feats_batch = model()

                sim_batch = self.get_sim_score(feats_batch)
                R_batch = self.get_R(sim_batch, u, l)

                # Similarity loss, accumulated over batches. The draft's
                # `sim_loss += sim_loss` doubled the current batch's loss and
                # was then overwritten by the next batch.
                sim_loss = sim_loss + (pos_entropy + neg_entropy)

        # 2. Triplet loss (data refreshed every few epochs)
        if (epoch + 1) % 5 == 0:
            # Call the LLM to label pairs and collect its outputs.
            # The llm_labeling method is meant to already include the
            # preliminary steps: first obtain the **global** features / sim / R
            # feats, y_true = self.eval(train_semi_dataloader)
            # sim_mat = self.get_sim_score(feats)
            # global_R = self.get_global_R(sim_mat, u, l)
            # indices_pairs = self.get_uncert_pairs(global_R)
            # text_pairs = self.get_text_pairs(indices_pairs)
            llm_outputs = self.llm_labeling(text_pairs)

            # Update the R matrix and the triplet dataset.
            # NOTE(review): `global_R` is read here before it is ever assigned
            # in this function (UnboundLocalError); the steps that would build
            # it are only present as the commented-out lines above — TODO wire in.
            global_R = self.update_R(global_R)
            tri_dataset = self.update_data(llm_outputs)
            # The draft built `tri_dataset` but iterated an undefined
            # `tri_dataloader`; wrap the fresh dataset in a loader.
            tri_dataloader = DataLoader(tri_dataset)

        if tri_dataloader is not None:
            for batch in tri_dataloader:
                # NOTE(review): unpack order is (anchor, neg, pos) — confirm it
                # matches what the model actually returns.
                anchor, neg, pos = model()

                tri_loss = tri_loss + self.tri_loss(anchor, pos, neg)

        loss = sim_loss + tri_loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    
        


            

## 关于模型

### CDAC
- backbone + [activate + dropout + classify]
- mean pooling

### ALUP
- feature_ext: 
    - backbone + dropout
    - 没有维度变化，默认768

- simple_forward(对比学习投影层): 
    - backbone + dropout + head
    - head: linear + relu + dropout + linear
    - head_feat_dim
- CLS token

### LANID
- 纯特征提取器
- CLS token
- backbone + normalize(head)
- head: linear + relu + dropout + linear
- head_feat_dim



## 第3次train梳理(1026)

- 得到伪标签
    - 模型前向传播得到特征向量，
    - 进行聚类
    - 顺带得到聚类评估结果

- 更新数据集，得到 *train_dataloader*

- 正常训练流程


In [None]:
def update_data(self, indices_pairs, ):
    """Unfinished draft: derive pseudo labels by clustering, then compare the labels of each pair.

    Steps visible in the code:
      1. extract features with ``self.get_features``;
      2. fit KMeans with ``self.num_labels`` clusters;
      3. align the new centroids/labels with ``self.centroids`` via
         ``self.alignment`` and store the aligned centroids back on ``self``;
      4. map predicted labels with ``self.get_hungray_aligment``;
      5. loop over ``indices_pairs`` and compare the mapped labels of the two
         members of each pair.

    Args:
        indices_pairs: Iterable of 2-element index pairs; each element is used
            to index ``y_pred_map``.

    NOTE(review): this function is incomplete —
      * ``args``, ``train_semi_dataloader`` and ``y_true`` are neither
        parameters nor locals here (the second value returned by
        ``get_features`` is discarded as ``_``);
      * ``relations[""]`` indexes a list with a string, which raises TypeError;
      * nothing is returned.
    """

    feats, _ = self.get_features(args, train_semi_dataloader)
    km = KMeans(n_clusters=self.num_labels).fit(feats)
    cluster_centroids, y_pred = km.cluster_centers_, km.labels_
    # Align the new clusters with the previous centroids (Hungarian algorithm)
    cluster_centroids, y_pred = self.alignment(self.centroids, cluster_centroids, y_pred)
    self.centroids = cluster_centroids
    # Map the predictions onto the true labels with the Hungarian algorithm:
    # y_pred_map: mapped label for each individual sample's predicted label,
    # cluster_map: mapped label for each cluster centre
    y_pred_map, cluster_map, cluster_map_opp = self.get_hungray_aligment(y_pred, y_true)

    relations = []
    for step, (i, j) in enumerate(indices_pairs):
        # Mapped labels of the two samples in this pair.
        i_label = y_pred_map[indices_pairs[step][0]]
        j_label = y_pred_map[indices_pairs[step][1]]
        if i_label == j_label:
            # NOTE(review): placeholder — presumably meant to record a relation
            # for this pair; as written it raises TypeError.
            relations[""]

        
def alignment(self, old_centroids, new_centroids, cluster_labels):
    """Re-index freshly fitted clusters so they line up with the previous centroids.

    When ``old_centroids`` is given, each old centroid is matched to one new
    centroid by solving a linear assignment problem on the pairwise Euclidean
    distances. The new centroids are then reordered to follow the old indices
    and every entry of ``cluster_labels`` is translated to the index of its
    matched old centroid. When ``old_centroids`` is ``None`` the inputs are
    returned unchanged.

    Args:
        old_centroids: Previous centroids as a 2-D array, or ``None``.
        new_centroids: Newly fitted centroids as a 2-D array.
        cluster_labels: Cluster id (index into ``new_centroids``) per sample.

    Returns:
        Tuple ``(aligned_centroids, pseudo_labels)``.
    """
    self.logger.info("***** Conducting Alignment *****")

    if old_centroids is None:
        # Nothing to align against: pass the new clustering straight through.
        aligned_centroids, pseudo_labels = new_centroids, cluster_labels
    else:
        # Pairwise distances: rows index old centroids, columns index new ones.
        diffs = old_centroids[:, None, :] - new_centroids[None, :, :]
        distance_matrix = np.linalg.norm(diffs, axis=2)
        _, col_ind = linear_sum_assignment(distance_matrix)
        # alignment_labels[k] is the new cluster matched to old centroid k.
        alignment_labels = list(col_ind)

        aligned_centroids = np.zeros_like(old_centroids)
        for old_idx in range(self.num_labels):
            aligned_centroids[old_idx] = new_centroids[alignment_labels[old_idx]]

        # New cluster id -> old centroid index.
        pseudo2label = dict(zip(alignment_labels, range(len(alignment_labels))))
        pseudo_labels = np.array([pseudo2label[cluster] for cluster in cluster_labels])

    self.logger.info("***** Update Pseudo Labels With Real Labels *****")

    return aligned_centroids, pseudo_labels


def get_hungray_aligment(self, y_pred, y_true):
    """Map predicted cluster ids onto true label ids with the Hungarian algorithm.

    A co-occurrence count matrix between predicted and true ids is built, and
    the one-to-one assignment that maximises the total count is solved for.

    Args:
        y_pred: 1-D integer array of predicted cluster ids.
        y_true: 1-D integer array of true label ids, indexed like ``y_pred``.

    Returns:
        Tuple ``(y_pred_map, cluster_map, cluster_map_opp)``:
        ``y_pred_map`` holds the mapped label of every sample,
        ``cluster_map[c]`` is the label assigned to cluster ``c`` and
        ``cluster_map_opp[t]`` is the cluster assigned to label ``t``.
    """
    n_ids = max(y_pred.max(), y_true.max()) + 1

    # counts[p, t] = number of samples predicted as p whose true label is t.
    counts = np.zeros((n_ids, n_ids))
    for pos, pred in enumerate(y_pred):
        counts[pred, y_true[pos]] += 1

    # Turn "maximise agreement" into a minimum-cost assignment problem.
    cost = counts.max() - counts
    ind = np.asarray(linear_sum_assignment(cost)).T  # rows are (cluster, label)

    # Row c of `ind` describes cluster c, so column 1 is a lookup table.
    y_pred_map = ind[y_pred, 1]

    cluster_map = np.zeros(len(ind), dtype=ind.dtype)
    cluster_map_opp = np.zeros(len(ind), dtype=ind.dtype)
    cluster_map[ind[:, 0]] = ind[:, 1]
    cluster_map_opp[ind[:, 1]] = ind[:, 0]

    # The two maps must be inverse permutations of each other.
    assert np.all(cluster_map[cluster_map_opp] == np.arange(len(ind)))
    return y_pred_map, cluster_map, cluster_map_opp