In [1]:
# 使用Snorkel来ensemble结果

In [2]:
import json
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
from snorkel.labeling import LabelingFunction
from snorkel.labeling import PandasLFApplier
from snorkel.labeling import LFAnalysis
from snorkel.labeling.model import MajorityLabelVoter, LabelModel

In [3]:
# Annotators to ensemble. Each name must have a matching
# result_<name>_<DATASET>_<split>.json file inside the dataset directory.
ensemble_model = [
    'chatglm2-6b',
    'qwen-7b-chat',
    'siamese_uninlu',
    'paddle_nlp',
]
# Dataset sub-directory name and the split whose annotations are ensembled.
DATASET = 'tnews'
split = 'train'

In [4]:
def get_anno(dataset, model_name, split='train', data_dir='/mnt/workspace/exp_dataset'):
    """Load a sentence -> label mapping from a JSON-lines annotation file.

    Args:
        dataset: dataset sub-directory name under ``data_dir`` (e.g. 'tnews').
        model_name: annotator whose predictions to load. The special value
            'gt' loads the gold split file ``<split>.json`` instead of
            ``result_<model_name>_<dataset>_<split>.json``.
        split: split name used in the file name (default 'train').
        data_dir: root directory holding one sub-directory per dataset.

    Returns:
        dict mapping each record's 'sentence' to its 'label'. If a sentence
        occurs more than once, the last occurrence wins.
    """
    if model_name == 'gt':
        fname = f'{data_dir}/{dataset}/{split}.json'
    else:
        fname = f'{data_dir}/{dataset}/result_{model_name}_{dataset}_{split}.json'

    res = {}
    with open(fname, 'r', encoding='utf-8') as f:
        for raw in f:
            raw = raw.strip()
            if not raw:
                # Tolerate empty lines (e.g. an extra blank line at the end of the
                # file) instead of crashing in json.loads('').
                continue
            record = json.loads(raw)
            res[record['sentence']] = record['label']
    return res

In [None]:
# Label value meaning "no vote"; Snorkel's labeling-function convention uses -1.
ABSTAIN = -1

def get_acc(gt, pred):
    """Restrict gold labels and predictions to the covered (non-abstained) rows.

    Despite the name, no accuracy is computed here: the function returns the
    aligned pair of arrays for positions where ``pred[i] >= 0``, ready to be
    passed to ``accuracy_score`` so accuracy is measured on covered rows only.

    Args:
        gt: gold labels (indexable). The loop runs over ``len(gt)``, so
            ``pred`` must be at least as long as ``gt``.
        pred: predicted labels; any negative value counts as an abstain.

    Returns:
        (np.ndarray, np.ndarray): the filtered gold labels and predictions.
    """
    covered = [idx for idx in range(len(gt)) if pred[idx] >= 0]
    kept_gt = np.array([gt[idx] for idx in covered])
    kept_pred = np.array([pred[idx] for idx in covered])
    return kept_gt, kept_pred

if DATASET == 'tnews':
    # TNEWS (CLUE) has 15 news categories; labels are assumed to be ints in
    # [0, 15) in both the gold file and the model result files.
    N_CLASSES = 15

    # Gold labels in file order; `texts` fixes the row order of every array below.
    texts, labels = [], []
    with open(f'/mnt/workspace/exp_dataset/tnews/{split}.json', 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip empty lines instead of failing in json.loads('')
            record = json.loads(line)
            texts.append(record['sentence'])
            labels.append(record['label'])
    labels = np.array(labels)

    # One sentence -> label dict per ensemble member, in `ensemble_model` order.
    ensemble_model_annota = [get_anno(DATASET, model_name, split=split) for model_name in ensemble_model]

    df = pd.DataFrame({'sentence': texts})

    def get_llm_annota(x, model_index):
        """LF body: label given by ensemble member `model_index` to row `x`, or ABSTAIN if it has none."""
        return ensemble_model_annota[model_index].get(x['sentence'], ABSTAIN)

    def get_df(model_index):
        """Wrap ensemble member `model_index` as a Snorkel LabelingFunction."""
        return LabelingFunction(
            name=f"model_{ensemble_model[model_index]}",
            f=get_llm_annota,
            resources=dict(model_index=model_index))

    # One LF per configured model. This is derived from `ensemble_model`, so
    # adding or removing a model in the config cell needs no edit here.
    lfs = [get_df(i) for i in range(len(ensemble_model))]

    applier = PandasLFApplier(lfs=lfs)
    L_train = applier.apply(df=df)

    # Combine the LF votes: plain majority vote vs. Snorkel's LabelModel.
    majority_model = MajorityLabelVoter(cardinality=N_CLASSES)
    preds_train_mv = majority_model.predict(L=L_train)

    label_model = LabelModel(cardinality=N_CLASSES, verbose=True)
    label_model.fit(L_train=L_train, n_epochs=500, log_freq=100, seed=123)
    preds_train_lm = label_model.predict(L=L_train)

    # Accuracy measured only on rows where the ensemble did not abstain.
    gt_mv, pred_mv = get_acc(labels, preds_train_mv)
    gt_lm, pred_lm = get_acc(labels, preds_train_lm)

    print(f'Majority Vote acc（仅衡量覆盖结果): {accuracy_score(gt_mv, pred_mv)}')
    print(f'Label Model acc（仅衡量覆盖结果): {accuracy_score(gt_lm, pred_lm)}')

    # Fraction of rows that received a non-abstain prediction.
    coverage_mv = float(np.mean(preds_train_mv >= 0))
    coverage_lm = float(np.mean(preds_train_lm >= 0))

    # NOTE(review): score() below breaks ties with tie_break_policy="random",
    # while the coverage printed next to it comes from preds_train_*, which
    # were produced by predict() without overriding its tie-break policy.
    # The two numbers therefore describe slightly different predictions.
    majority_acc = majority_model.score(L=L_train, Y=labels, tie_break_policy="random")["accuracy"]
    print(f"{'Majority Vote Accuracy:':<25} {majority_acc * 100:.1f}%, coverage:{coverage_mv}")

    label_model_acc = label_model.score(L=L_train, Y=labels, tie_break_policy="random")["accuracy"]
    print(f"{'Label Model Accuracy:':<25} {label_model_acc * 100:.1f}%, coverage:{coverage_lm}")
    

In [None]:
# Per-LF diagnostics (polarity, coverage, overlaps, conflicts). Passing the gold
# labels as Y additionally reports each model's Correct / Incorrect counts and
# empirical accuracy on the rows it labels.
LFAnalysis(L=L_train, lfs=lfs).lf_summary(Y=labels)

In [19]:
# Export the ensemble results: one JSON line per sentence that received a
# non-abstain prediction.

def export_ensemble(preds, method):
    """Write one ensemble method's predictions as JSON lines.

    Args:
        preds: predicted labels aligned with `texts` by position; negative
            values (abstain) are skipped.
        method: method tag used in the output file name.

    Raises:
        ValueError: if `preds` is not aligned with `texts`.
    """
    if len(preds) != len(texts):
        raise ValueError(f'{method}: got {len(preds)} predictions for {len(texts)} texts')
    out_path = f'/mnt/workspace/exp_dataset/{DATASET}/result_ensemble_{method}_{DATASET}_{split}.json'
    with open(out_path, 'w', encoding='utf-8') as f:
        for text, pred in zip(texts, preds):
            label = int(pred)
            if label >= 0:
                f.write(json.dumps({
                    'sentence': text,
                    'label': label
                }, ensure_ascii=False) + '\n')

# Each file name must match the predictions written to it (these were swapped before).
export_ensemble(preds_train_mv, 'MajorityVote')
export_ensemble(preds_train_lm, 'LabelModel')
        