In [None]:
import pickle
import json
import os
import pandas as pd
import numpy as np
import gc

from tqdm import tqdm
import itertools

class cfg:
    """Input and output locations used by this notebook."""
    # NOTE(review): both values are absolute paths on one machine; they have to be
    # edited before the notebook can run anywhere else.
    # Pickled DataFrame that is read into `pl_df` with pd.read_pickle.
    data_path = "/home/xm/workspace/nbme-score-clinical-patient-notes/train_pl_all.pkl"
    # Directory the resulting DataFrame is written to at the end of the notebook.
    out_dir = "/home/xm/workspace/nbme-score-clinical-patient-notes/"

# Maps each model output directory to the weight its character-level logits get
# in the weighted sum computed further down. A `pl_logits.pkl` file is read from
# every directory listed here. The three weights add up to 1.0.
blend_weights = {
    "/home/xm/workspace/output/1002": 0.19,
    "/home/xm/workspace/output/1012": 0.37,
    "/home/xm/workspace/output/1022": 0.44,
}

# DataFrame the spans are generated for; the code below reads its `id` and
# `pn_history` columns.
# NOTE(review): read_pickle unpickles arbitrary objects - only point it at trusted files.
pl_df = pd.read_pickle(cfg.data_path)

In [None]:
def get_spans(char_logits, texts, th=0):
    '''
    Turn per-character scores into character spans for every sample.

    For each sample, the characters whose score is strictly greater than ``th``
    are grouped into runs of consecutive indices. Every run is trimmed of
    leading and trailing space characters and reported as ``[start, end]``
    with an exclusive ``end``, so ``text[start:end]`` is the covered text.

    Args:
        char_logits: sequence with one 1-D array of per-character scores per sample.
        texts: sequence of strings; ``texts[i]`` is the text scored by ``char_logits[i]``.
        th: threshold a score has to exceed for its character to be selected.

    Returns:
        A list with one entry per sample, each a list of ``[start, end]`` pairs,
        e.g. ``[[[0, 5], [64, 72]], [[91, 99]], [[128, 134]]]``.
        Runs that consist only of spaces, or that lie entirely beyond the end
        of the text, produce no span.
    '''
    results = []
    for i, char_prob in enumerate(char_logits):  # one iteration per sample
        text = texts[i]
        hits = np.where(np.asarray(char_prob) > th)[0]  # indices scoring above the threshold
        spans = []
        if hits.size:
            # Cut the sorted indices wherever two neighbours are not consecutive:
            # [0, 1, 90, 91, 92, 628, 629] -> [0, 1], [90, 91, 92], [628, 629]
            breaks = np.where(np.diff(hits) != 1)[0] + 1
            for run in np.split(hits, breaks):
                s = int(run[0])
                # Clamp to the text so scores longer than the text cannot index past it.
                e = min(int(run[-1]), len(text) - 1)
                # Both trims are bounded by the opposite end: an all-space run must not
                # walk outside the run (or outside the text) while looking for a character.
                while s <= e and text[s] == ' ':  # strip spaces on the left
                    s += 1
                while e >= s and text[e] == ' ':  # strip spaces on the right
                    e -= 1
                if s <= e:  # skip runs that were nothing but spaces
                    spans.append([s, e + 1])
        results.append(spans)

    return results

def save_pickle(obj, path):
    '''
    Serialize ``obj`` with pickle and write it to ``path``.

    The file is opened in binary write mode, so an existing file at ``path``
    is overwritten.
    '''
    with open(path, mode='wb') as sink:
        pickle.dump(obj, file=sink)

In [None]:
# ---- Blend the character-level logits of all models ----
# Note texts as a plain array: everything below addresses samples by position,
# so the result does not depend on what index pl_df happens to carry.
pn_texts = pl_df.pn_history.values
char_logits_blend = [np.zeros(len(text)) for text in pn_texts]  # weighted sum of per-character logits, one array per sample
for model_dir, w in tqdm(blend_weights.items()):
    print("model_dir:", model_dir)
    # NOTE(review): pickle.load can execute arbitrary code - only load files from trusted runs.
    with open(os.path.join(model_dir, 'pl_logits.pkl'), 'rb') as f:  # context manager closes the handle
        char_logits = pickle.load(f)  # {id: char_logits}
    # Line the logits up with the rows of pl_df by id. Only the id column is merged
    # (no need to copy the whole frame); a left join keeps pl_df's row order and count.
    merged = pl_df[['id']].merge(pd.DataFrame({'id': list(char_logits.keys()),
                                               'char_logits': list(char_logits.values())}),
                                 on='id',
                                 how='left')
    # Ids without logits come back as NaN, which turns the blended scores of that
    # sample into NaN and therefore into "no spans". Surface that instead of hiding it.
    n_missing = int(merged['char_logits'].isna().sum())
    if n_missing:
        print(f"WARNING: {n_missing} ids have no logits in {model_dir}; their blended scores become NaN")
    for i, model_logits in enumerate(merged['char_logits'].values):
        char_logits_blend[i] += model_logits * w  # weighted sum
    # Release this model's data before the next one is loaded; done inside the loop
    # so the names are guaranteed to exist (blend_weights may be empty).
    del merged, char_logits
gc.collect()


# Spans for every sample, as a list of [start, end] pairs per sample
all_spans = get_spans(char_logits_blend, pn_texts, th=1e-6)

# "start end" strings per sample, e.g. [["0 5", "64 72"], ["91 99"], ["128 134"]]
locations = [[f'{s} {e}' for s, e in spans] for spans in all_spans]
pl_df['location'] = locations

# Text covered by each span, taken positionally from the same texts the spans were computed on
annotations = [[text[s:e] for s, e in spans] for text, spans in zip(pn_texts, all_spans)]

# Store the result
pl_df['annotation'] = annotations
pl_df['annotation_length'] = pl_df['annotation'].apply(len)  # number of spans per sample
# NOTE(review): the file name mentions only "1002" although the frame holds the blend of
# every model in blend_weights - confirm the name is intended before renaming it.
pl_df.to_pickle(os.path.join(cfg.out_dir, 'train_pl_1002.pkl'))