In [1]:
import os
import torch
import evaluate
import numpy as np
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from eval_utils import read_pmcids, sent_json, load_ExtModel, load_AbstrModel, TrigramBlock, convert_sentence_obj

import warnings
warnings.simplefilter("ignore", FutureWarning)
warnings.simplefilter("ignore", UserWarning)

In [2]:
# Input locations, relative to this notebook.
IDS_PATH = '../../dataset/pmcids/test.txt'        # list of PMC ids to evaluate (read by read_pmcids)
JSON_DIR = '../../dataset/sentence_json/'         # per-article <pmcid>.json sentence files
PARQUET_DIR = '../../dataset/sentence_features/'  # per-article <pmcid>.parquet sentence features

# Columns removed from the feature table before prediction. 'label' presumably holds the
# gold extraction label; why F8/F9 are excluded is not recorded here.
BLOCK = ['F8', 'F9', 'label']

# Extractive sentence classifier (LGBM per the section headers). NOTE(review): loaded from
# a .pkl file — presumably unpickled, so only load trusted model files.
MODEL = load_ExtModel('../extractive_summarizer/model/LGB_model_F10_S.pkl')

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
def convert_sentence_df(sentJson, pred, true_proba, trigram_blocking=None):
    """Build per-sentence DataFrames for an article's body and abstract.

    Parameters
    ----------
    sentJson : dict
        Parsed sentence JSON. ``sentJson['abstract'][sec]`` and ``sentJson['body'][sec]``
        are lists of sentence dicts (with a ``'text'`` key) for each section ``sec`` in
        'IMRD'; body sentences also carry a ``'label'`` key.
    pred : array-like
        One 0/1 (or bool) prediction per body sentence, in the order the body sentences
        are enumerated here (section by section: I, M, R, D).
    true_proba : array-like
        One positive-class probability per body sentence, same order as ``pred``.
    trigram_blocking : bool or None, optional
        Apply per-section trigram blocking to the predicted sentences. ``None`` (default)
        falls back to the notebook-level ``set_trigram_blocking`` flag (False if undefined).

    Returns
    -------
    body : pandas.DataFrame
        Columns section (category), text (string), label (bool), predict (bool),
        proba (float16).
    abstract : pandas.DataFrame
        Columns section (category), text (string).

    Raises
    ------
    ValueError
        If ``pred`` / ``true_proba`` do not have one entry per body sentence.
    """
    if trigram_blocking is None:
        # Backward-compatible default: honour the global flag set in the experiment cells,
        # treating "not defined yet" as "off" instead of raising NameError.
        trigram_blocking = globals().get('set_trigram_blocking', False)

    # Abstract and body sentences -> DataFrames, in section order I, M, R, D.
    abstract = pd.DataFrame(
        [(section, sent['text'].strip())
         for section in 'IMRD' for sent in sentJson['abstract'][section]],
        columns=['section', 'text'],
    ).astype({'section': 'category', 'text': 'string'})

    body = pd.DataFrame(
        [(section, sent['text'].strip(), sent['label'])
         for section in 'IMRD' for sent in sentJson['body'][section]],
        columns=['section', 'text', 'label'],
    ).astype({'section': 'category', 'text': 'string', 'label': 'bool'})

    pred = np.asarray(pred)
    true_proba = np.asarray(true_proba)
    # Predictions are paired with body sentences purely by position, so a length mismatch
    # means the feature table and the sentence JSON disagree about the article.
    if len(pred) != len(body) or len(true_proba) != len(body):
        raise ValueError(
            f'expected {len(body)} predictions/probabilities (one per body sentence), '
            f'got {len(pred)} and {len(true_proba)}'
        )

    # Attach the prediction and its probability (float16: close probabilities may tie).
    body['predict'] = pred.astype('bool')
    body['proba'] = true_proba.astype('float16')

    # Per-section trigram blocking: visit the predicted sentences from most to least
    # probable and un-select any that TrigramBlock.check_overlap reports as overlapping
    # (presumably the blocker also records the trigrams of the sentences it accepts).
    if trigram_blocking:
        for section in 'IMRD':
            blocker = TrigramBlock()
            candidates = body.loc[(body['section'] == section) & body['predict']]
            ranked = candidates.sort_values(by='proba', ascending=False)
            for idx, text in ranked['text'].items():
                if blocker.check_overlap(text):
                    body.at[idx, 'predict'] = False

    return body, abstract

In [4]:
def process_article(pmcid, threshold,
                    model=MODEL, block_cols=BLOCK, json_dir=JSON_DIR, parquet_dir=PARQUET_DIR):
    """Score one article's body sentences with the extractive model and return the picks.

    Parameters
    ----------
    pmcid :
        Article id, used to build the file names ``<parquet_dir>/<pmcid>.parquet``
        (sentence features) and ``<json_dir>/<pmcid>.json`` (sentence text).
    threshold : float
        A sentence is extracted when its positive-class probability is strictly greater
        than this value.
    model :
        Classifier exposing ``predict_proba``; column 1 is taken as the positive class.
    block_cols : list of str
        Columns dropped from the feature table before prediction.
    json_dir, parquet_dir : str
        Directories holding the per-article files.

    Returns
    -------
    ext : pandas.DataFrame
        Rows of the body DataFrame built by ``convert_sentence_df`` whose 'predict' is True.
    abstract : pandas.DataFrame
        The article's abstract sentences (section, text).

    If no sentence clears ``threshold`` (including when ``threshold >= 1``), the
    highest-probability sentence(s) are extracted instead, so every non-empty article
    yields at least one sentence. The reported probabilities are left unchanged.
    """

    def predict(features):
        # Positive-class probability for each sentence.
        true_proba = model.predict_proba(features)[:, 1]
        pred = true_proba > threshold
        # Fallback: nothing above the threshold -> select the most probable sentence(s)
        # directly, without overwriting their probabilities.
        if true_proba.size and not pred.any():
            pred = true_proba == true_proba.max()
        return pred.astype('int'), true_proba

    # Read the sentence features and predict.
    df = pd.read_parquet(os.path.join(parquet_dir, f'{pmcid}.parquet'))
    pred, true_proba = predict(df.drop(columns=block_cols))

    # Read the sentence text and pair it with the predictions.
    # NOTE(review): assumes the parquet rows are in the same order as the JSON body
    # sentences (section by section, I-M-R-D); convert_sentence_df pairs them by position.
    sentJson = sent_json(os.path.join(json_dir, f'{pmcid}.json'))
    body, abstract = convert_sentence_df(sentJson, pred, true_proba)
    ext = body[body['predict']]

    return ext, abstract

In [5]:
def generate(ext, tokenizer, model, **generate_kwargs):
    """Abstractively rewrite the extracted sentences of each IMRD section.

    Parameters
    ----------
    ext : pandas.DataFrame
        Extracted sentences with at least 'section' and 'text' columns
        (as returned by ``process_article``).
    tokenizer :
        Hugging Face-style tokenizer: called with ``truncation=True, return_tensors='pt'``
        and used to ``decode`` the first generated sequence.
    model :
        Seq2seq model exposing ``generate`` and ``device``; inputs are moved to
        ``model.device`` so they always sit next to the weights.
    **generate_kwargs :
        Extra keyword arguments forwarded to ``model.generate`` (e.g. ``num_beams``).

    Returns
    -------
    dict
        Maps 'I', 'M', 'R', 'D' (in that order) to the generated text. A section with no
        extracted sentences is still sent through the model, as an empty string.
    """
    abstr = {}
    for section in 'IMRD':
        # This section's extracted sentences, joined in their original order.
        ext_text = ' '.join(ext.loc[ext['section'] == section, 'text'])
        # With no explicit max_length, truncation=True cuts at tokenizer.model_max_length.
        input_ids = tokenizer(ext_text, truncation=True, return_tensors='pt').input_ids
        outputs = model.generate(input_ids.to(model.device), **generate_kwargs)
        abstr[section] = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return abstr

In [6]:
def main(tokenizer, model, pmcid_file=IDS_PATH, threshold=0.5):
    """Run extract-then-abstract over every listed article and score it with ROUGE.

    For each PMC id in ``pmcid_file``, sentences are extracted with ``process_article``
    (using ``threshold``), each IMRD section is rewritten by ``generate``, and the result
    is compared with the article's abstract. Scores are computed for the whole summary
    ('ALL') and for each section separately.

    Returns a DataFrame of aggregated ROUGE scores (rows: ROUGE variants,
    columns: 'ALL', 'I', 'M', 'R', 'D'), rounded to 4 decimals.
    """
    scorer = evaluate.load('rouge')
    ids = read_pmcids(pmcid_file)
    columns = ['ALL', 'I', 'M', 'R', 'D']
    predictions = {col: [] for col in columns}
    references = {col: [] for col in columns}

    for pmcid in ids:
        extracted, abstract = process_article(pmcid, threshold)
        generated = generate(extracted, tokenizer, model)
        for col in columns:
            wanted = 'IMRD' if col == 'ALL' else col
            predictions[col].append(' '.join(generated[sec] for sec in wanted))
            if col == 'ALL':
                ref_sentences = abstract['text']
            else:
                ref_sentences = abstract.loc[abstract['section'] == col, 'text']
            references[col].append(' '.join(ref_sentences))

    scores = {}
    for col in columns:
        scores[col] = scorer.compute(predictions=predictions[col], references=references[col],
                                     use_stemmer=True, use_aggregator=True)
    return pd.DataFrame(scores).round(4)

## Evaluation Results

#### 1. LGBM+BART

In [7]:
# Trigram blocking in convert_sentence_df is off for this run.
set_trigram_blocking: bool = False

In [8]:
%%time
# Local BART checkpoint (7th saved checkpoint).
model_checkpoint = '../abstractive_summarizer/model/checkpoint_bart/checkpoint-39375'
TOKENIZER = AutoTokenizer.from_pretrained(model_checkpoint)
ABSTRMODEL = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
main(tokenizer=TOKENIZER, model=ABSTRMODEL)

CPU times: user 3h 20min 25s, sys: 11.2 s, total: 3h 20min 36s
Wall time: 3h 6min 32s


Unnamed: 0,ALL,I,M,R,D
rouge1,0.5583,0.459,0.4079,0.4326,0.4553
rouge2,0.2538,0.2253,0.167,0.186,0.2428
rougeL,0.3054,0.3176,0.2692,0.2684,0.3391
rougeLsum,0.3054,0.3176,0.2692,0.2685,0.3391


#### 2. LGBM+BioBART

In [None]:
# Trigram blocking in convert_sentence_df is off for this run.
set_trigram_blocking: bool = False

In [8]:
%%time
# Local BioBART checkpoint (3rd saved checkpoint); model_max_length=1024 sets the
# tokenizer's truncation limit.
model_checkpoint = '../abstractive_summarizer/model/checkpoint_biobart/checkpoint-16875'
TOKENIZER = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=1024)
ABSTRMODEL = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
main(tokenizer=TOKENIZER, model=ABSTRMODEL)

CPU times: user 3h 39min 37s, sys: 9.82 s, total: 3h 39min 47s
Wall time: 3h 25min 47s


Unnamed: 0,ALL,I,M,R,D
rouge1,0.5642,0.4657,0.4159,0.4392,0.4596
rouge2,0.2587,0.233,0.1725,0.1909,0.2457
rougeL,0.3093,0.3222,0.2747,0.2728,0.3434
rougeLsum,0.3092,0.3222,0.2746,0.2727,0.3433


#### 3. LGBM+TB+BioBART

In [7]:
# Trigram blocking in convert_sentence_df is ON for this run.
set_trigram_blocking: bool = True

In [8]:
%%time
# Local BioBART checkpoint from checkpoint_tb_biobart (presumably trained on
# trigram-blocked extracts), 3rd saved checkpoint.
model_checkpoint = '../abstractive_summarizer/model/checkpoint_tb_biobart/checkpoint-16875'
TOKENIZER = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=1024)
ABSTRMODEL = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
main(tokenizer=TOKENIZER, model=ABSTRMODEL)

CPU times: user 3h 15min 31s, sys: 8.19 s, total: 3h 15min 39s
Wall time: 3h 1min 35s


Unnamed: 0,ALL,I,M,R,D
rouge1,0.5419,0.4517,0.3959,0.3949,0.4434
rouge2,0.2345,0.2157,0.1551,0.1545,0.2269
rougeL,0.2913,0.3082,0.26,0.2453,0.3271
rougeLsum,0.2913,0.3081,0.26,0.2452,0.327


#### 4. LGBM+BioBART(base)

In [7]:
# Trigram blocking in convert_sentence_df is off for this run.
set_trigram_blocking: bool = False

In [8]:
# Off-the-shelf BioBART base checkpoint from the Hugging Face hub.
model_checkpoint = "GanjinZero/biobart-v2-base"
TOKENIZER = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=1024)
ABSTRMODEL = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

# Decoding settings used by model.generate() in this run.
generation_config = {
    'num_beams': 5,
    'max_length': 512,
    'min_length': 64,
    'length_penalty': 2.0,
    'early_stopping': True,
    'no_repeat_ngram_size': None,
}

# Older transformers read decoding defaults from model.config ...
ABSTRMODEL.config.update(generation_config)
# ... newer ones read model.generation_config. Editing only model.config after loading
# relies on a deprecated fallback, and its warning is hidden by the UserWarning filter,
# so set the values on generation_config too when it exists.
if getattr(ABSTRMODEL, 'generation_config', None) is not None:
    for key, value in generation_config.items():
        setattr(ABSTRMODEL.generation_config, key, value)

In [9]:
%%time
# nn.Module.to moves the weights in place and returns the same module object.
ABSTRMODEL.to(device)
main(tokenizer=TOKENIZER, model=ABSTRMODEL)

CPU times: user 10h 47min 3s, sys: 19.8 s, total: 10h 47min 22s
Wall time: 10h 32min 16s


Unnamed: 0,ALL,I,M,R,D
rouge1,0.3773,0.3103,0.2829,0.3453,0.364
rouge2,0.2041,0.161,0.1235,0.1699,0.2107
rougeL,0.2245,0.2156,0.1821,0.2184,0.2778
rougeLsum,0.2245,0.2156,0.182,0.2184,0.2777


#### 5. BioBART

In [7]:
# Kept for parity with the other runs; main_abs never calls convert_sentence_df,
# so this flag does not affect the full-text baseline.
set_trigram_blocking: bool = False

In [8]:
def generate_abs(body, tokenizer, model, **generate_kwargs):
    """Abstractively summarise each full IMRD body section (no extractive step).

    Parameters
    ----------
    body : iterable of dict
        Sentence objects with 'section' and 'text' keys (as produced by
        ``convert_sentence_obj``).
    tokenizer :
        Hugging Face-style tokenizer: called with ``truncation=True, return_tensors='pt'``
        and used to ``decode`` the first generated sequence.
    model :
        Seq2seq model exposing ``generate`` and ``device``; inputs are moved to
        ``model.device`` so they always sit next to the weights.
    **generate_kwargs :
        Extra keyword arguments forwarded to ``model.generate``.

    Returns
    -------
    dict
        Maps 'I', 'M', 'R', 'D' (in that order) to the generated text. A section with no
        sentences is still sent through the model, as an empty string.
    """
    abstr = {}
    for section in 'IMRD':
        text = ' '.join(sent['text'] for sent in body if sent['section'] == section)
        # With no explicit max_length, truncation=True cuts at tokenizer.model_max_length,
        # so long sections are summarised from their beginning only.
        input_ids = tokenizer(text, truncation=True, return_tensors='pt').input_ids
        outputs = model.generate(input_ids.to(model.device), **generate_kwargs)
        abstr[section] = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return abstr

In [9]:
def main_abs(tokenizer, model, pmcid_file=IDS_PATH, json_dir=JSON_DIR):
    """Abstractive-only baseline: summarise each full IMRD body section and score with ROUGE.

    For every PMC id in ``pmcid_file`` the article's sentence JSON is read from
    ``json_dir``; all body sentences of a section are fed to ``generate_abs`` (no
    extractive step) and the output is compared with the article's abstract, for the
    whole summary ('ALL') and for each section separately.

    Returns a DataFrame of aggregated ROUGE scores (rows: ROUGE variants,
    columns: 'ALL', 'I', 'M', 'R', 'D'), rounded to 4 decimals.
    """
    scorer = evaluate.load('rouge')
    ids = read_pmcids(pmcid_file)
    columns = ['ALL', 'I', 'M', 'R', 'D']
    predictions = {col: [] for col in columns}
    references = {col: [] for col in columns}

    for pmcid in ids:
        # Read the article's sentence JSON and flatten body / abstract into sentence objects.
        article = sent_json(f'{json_dir}/{pmcid}.json')
        body_sents = convert_sentence_obj(article['body'])
        generated = generate_abs(body_sents, tokenizer, model)
        abstract_sents = convert_sentence_obj(article['abstract'])

        for col in columns:
            wanted = 'IMRD' if col == 'ALL' else col
            predictions[col].append(' '.join(generated[sec] for sec in wanted))
            references[col].append(
                ' '.join(sent['text'] for sent in abstract_sents if sent['section'] in wanted)
            )

    scores = {
        col: scorer.compute(predictions=predictions[col], references=references[col],
                            use_stemmer=True, use_aggregator=True)
        for col in columns
    }
    return pd.DataFrame(scores).round(4)

In [11]:
%%time
# Same local BioBART checkpoint as run 2 (3rd saved checkpoint), but applied to the
# full body text via main_abs instead of to extracted sentences.
model_checkpoint = '../abstractive_summarizer/model/checkpoint_biobart/checkpoint-16875'
TOKENIZER = AutoTokenizer.from_pretrained(model_checkpoint, model_max_length=1024)
ABSTRMODEL = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint).to(device)
main_abs(tokenizer=TOKENIZER, model=ABSTRMODEL)

CPU times: user 3h 55min 35s, sys: 4.6 s, total: 3h 55min 40s
Wall time: 3h 55min 35s


Unnamed: 0,ALL,I,M,R,D
rouge1,0.5446,0.4529,0.406,0.4313,0.3389
rouge2,0.2342,0.2202,0.1663,0.1851,0.114
rougeL,0.2809,0.31,0.2679,0.2677,0.2155
rougeLsum,0.2809,0.31,0.2679,0.2678,0.2155
