In [None]:
import os
import re

import numpy as np
import pandas as pd

import nltk
from nltk.tokenize import sent_tokenize
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu
from nltk.translate.nist_score import sentence_nist 
from nltk.translate.meteor_score import meteor_score 

In [None]:
def _is_chinese_char(uchar):
    """
    判断是否中文字符
    :param uchar: input char in unicode
    :return: whether the input char is a Chinese character.
    """
    _UCODE_RANGES = [
        (u'\u3400', u'\u4db5'),  # CJK Unified Ideographs Extension A, release 3.0
        (u'\u4e00', u'\u9fa5'),  # CJK Unified Ideographs, release 1.1
        (u'\u9fa6', u'\u9fbb'),  # CJK Unified Ideographs, release 4.1
        (u'\uf900', u'\ufa2d'),  # CJK Compatibility Ideographs, release 1.1
        (u'\ufa30', u'\ufa6a'),  # CJK Compatibility Ideographs, release 3.2
        (u'\ufa70', u'\ufad9'),  # CJK Compatibility Ideographs, release 4.1
        (u'\u20000', u'\u2a6d6'),  # (UTF16) CJK Unified Ideographs Extension B, release 3.1
        (u'\u2f800', u'\u2fa1d'),  # (UTF16) CJK Compatibility Supplement, release 3.1
        (u'\uff00', u'\uffef'),  # Full width ASCII, full width of English punctuation,
                                 # half width Katakana, half wide half width kana, Korean alphabet
        (u'\u2e80', u'\u2eff'),  # CJK Radicals Supplement
        (u'\u3000', u'\u303f'),  # CJK punctuation mark
        (u'\u31c0', u'\u31ef'),  # CJK stroke
        (u'\u2f00', u'\u2fdf'),  # Kangxi Radicals
        (u'\u2ff0', u'\u2fff'),  # Chinese character structure
        (u'\u3100', u'\u312f'),  # Phonetic symbols
        (u'\u31a0', u'\u31bf'),  # Phonetic symbols (Taiwanese and Hakka expansion)
        (u'\ufe10', u'\ufe1f'),
        (u'\ufe30', u'\ufe4f'),
        (u'\u2600', u'\u26ff'),
        (u'\u2700', u'\u27bf'),
        (u'\u3200', u'\u32ff'),
        (u'\u3300', u'\u33ff'),
    ]
    for start, end in _UCODE_RANGES:
        if start <= uchar <= end:
            return True
    return False

# Regex rules applied after CJK characters have been space-separated
# (adapted from SacreBLEU's tokenizer_zh). Compiled once at import time
# instead of on every call.
_ZH_TOKENIZE_RULES = [
    # language-dependent part (assuming Western languages)
    (re.compile(r'([\{-\~\[-\` -\&\(-\+\:-\@\/])'), r' \1 '),
    # tokenize period and comma unless preceded by a digit
    (re.compile(r'([^0-9])([\.,])'), r'\1 \2 '),
    # tokenize period and comma unless followed by a digit
    (re.compile(r'([\.,])([^0-9])'), r' \1 \2'),
    # tokenize dash when preceded by a digit
    (re.compile(r'([0-9])(-)'), r'\1 \2 '),
]


def tokenize_with_Chinese(line):
    """
    Tokenize mixed Chinese/English text.

    Every character for which `_is_chinese_char` is true is surrounded by
    spaces so it becomes its own token. The regex rules above then separate
    Western punctuation, and the result is split on whitespace.

    :param line: input sentence (str); leading/trailing whitespace is ignored
    :return: list of token strings
    """
    spaced = "".join(
        f" {char} " if _is_chinese_char(char) else char
        for char in line.strip()
    )
    for pattern, repl in _ZH_TOKENIZE_RULES:
        spaced = pattern.sub(repl, spaced)
    return spaced.split()


# The scoring cells call the tokenizer under this name; without the alias they
# raise NameError on a fresh kernel.
chinese_tokenized = tokenize_with_Chinese

In [None]:
def read(file_name):
    """
    Purpose:
        Read a UTF-8 text file line by line.
    Parameters:
        file_name(str) - path of the file to read
    Returns:
        list[str] - the file's lines, each keeping its trailing newline
        (exactly what ``readlines`` produces)
    """
    with open(file_name, mode="r", encoding="utf8") as handle:
        lines = handle.readlines()
    return lines


In [None]:
def cal_avg(l):
    """
    Purpose:
        Arithmetic mean of a list of numbers.
    Parameters:
        l(list) - the values to average
    Returns:
        sum(l) / len(l); an empty list raises ZeroDivisionError
    """
    total = sum(l)
    count = len(l)
    return total / count

In [None]:
def process_sentence(sent):
    """
    Purpose:
        Normalise one sentence before scoring: lower-case it and delete the
        ASCII and full-width (Chinese) punctuation characters listed below.
    Parameters:
        sent(str) - the sentence text
    Returns:
        str - the lower-cased sentence with that punctuation removed
    """
    punctuation = r"""!"#$%&'()*+,./:;<=>?@[\]^_`{|}~“”？，！【】（）、。：；’‘……￥·"""
    # Three-argument form: each character of the third argument is mapped to
    # None, i.e. deleted by translate().
    delete_table = str.maketrans("", "", punctuation)
    return sent.lower().translate(delete_table)


In [None]:
def cal_bleu(refs, candi):
    """
    Purpose:
        Sentence-level BLEU of one candidate translation against its references.
    Parameters:
        refs(List(str)) - reference translations [ref1, ref2, ref3]
        candi(str) - candidate translation "candidate"
    Returns:
        float - nltk ``sentence_bleu`` with uniform 1- to 4-gram weights,
        computed on `tokenize_with_Chinese` tokens
    """
    # tokenize_with_Chinese is the tokenizer defined in this notebook; the old
    # name `chinese_tokenized` was never defined and raised NameError.
    refs_token = [tokenize_with_Chinese(ref) for ref in refs]
    candi_token = tokenize_with_Chinese(candi)
    return sentence_bleu(refs_token, candi_token, weights=[0.25, 0.25, 0.25, 0.25])

In [None]:
def remove_num_marks(sentence):
    """
    Purpose:
        Strip the leading sentence-number marker (e.g. "3." or "3 ") from a line.
    Parameters:
        sentence(str) - the raw line to clean
    Returns:
        str - "" for an empty line, or for a line of at most two characters
        that starts with a digit (a number with no text). Otherwise: the text
        after the first "." if one occurs in the first three characters, else
        the line without its first two characters.
    """
    # An empty (or None) line used to hit a bare `except`, print itself and
    # return None, which then crashed the caller's `.strip()`.
    if not sentence:
        return ""
    # Lines such as "7\n": a number marker with no sentence text.
    if sentence[0].isdigit() and len(sentence) <= 2:
        return ""
    if "." not in sentence[:3]:
        return sentence[2:]
    return sentence[sentence.index(".") + 1:]

In [None]:
def cn_meteor_score(ref_sents, candi_sent):
    """
    Purpose:
        METEOR score of a candidate translation against its references,
        computed on `tokenize_with_Chinese` tokens.
    Parameters:
        ref_sents(List(str)) - reference translations
        candi_sent(str) - candidate translation
    Returns:
        float - nltk ``meteor_score`` of the candidate
    """
    ref_tokens = [tokenize_with_Chinese(ref_sent) for ref_sent in ref_sents]
    candi_tokens = tokenize_with_Chinese(candi_sent)
    try:
        # NLTK >= 3.6.6 requires pre-tokenized input (lists of tokens) and
        # raises TypeError when given plain strings.
        return meteor_score(ref_tokens, candi_tokens)
    except TypeError:
        # Older NLTK expects whitespace-joined strings and splits them itself.
        return meteor_score(
            [" ".join(tokens) for tokens in ref_tokens],
            " ".join(candi_tokens),
        )

In [None]:
def get_all_dir(path):
    """
    Purpose:
        List the entries of `path` that are not regular files (e.g. sub-directories).
    Parameters:
        path(str) - directory to inspect
    Returns:
        list[str] - entry names (not full paths), in ``os.listdir`` order
    """
    # Build a new list instead of calling `remove` on the list being iterated:
    # removal shifts later items left, so a file directly after another file
    # was skipped and leaked into the result.
    return [
        entry for entry in os.listdir(path)
        if not os.path.isfile(os.path.join(path, entry))
    ]

In [None]:
def read_txt(txt_path):
    """
    Purpose:
        Collect the transcript files directly inside `txt_path`.
    Parameters:
        txt_path(str) - directory to scan (callers pass it with a trailing "/")
    Returns:
        list[str] - sorted paths of the regular files in `txt_path` (not in
        its sub-directories) whose name ends with ".txt" and does not contain
        "没转写完" ("transcription unfinished")
    Raises:
        FileNotFoundError if `txt_path` does not exist
    """
    print(txt_path)
    # Only the top level is scanned. The previous os.walk loop kept the file
    # list of the *last* directory walked (wrong whenever sub-directories
    # exist) and raised UnboundLocalError for a missing path. Sorting makes
    # the candidate order, and thus the output rows, deterministic.
    txt_files = sorted(
        os.path.join(txt_path, name)
        for name in os.listdir(txt_path)
        if name.endswith(".txt")
        and "没转写完" not in name
        and os.path.isfile(os.path.join(txt_path, name))
    )
    print(txt_files)
    return txt_files

In [None]:
def _load_numbered_sentences(file_name):
    """
    Load one translation file as a list of cleaned sentences: keep only lines
    starting with a digit, strip the number marker with `remove_num_marks`,
    then strip whitespace and normalise with `process_sentence`.
    """
    lines = read(file_name)
    sentences = [remove_num_marks(line) for line in lines if line[:1].isdigit()]
    return [process_sentence(sent.strip()) for sent in sentences]


def run_calc(candi_files, ref_files, task, dataset):
    """
    Score every candidate translation file against the reference files.

    Sentences are aligned by position: the i-th numbered sentence of a
    candidate is compared with the i-th numbered sentence of every reference.
    Per-sentence BLEU, NIST and METEOR are averaged into one value per
    candidate file.

    Parameters:
        candi_files(List(str)) - candidate translation file paths
        ref_files(List(str)) - reference translation file paths (at least one)
        task(str) - task name, used in the ID and column prefixes
        dataset(str) - dataset name, used in the ID and column prefixes
    Returns:
        pandas.DataFrame - one row per candidate file, with columns "ID",
        prefix+"BLEU", prefix+"NIST", prefix+"METEOR", where
        prefix = f"{task.strip()}_{dataset}_"
    Raises:
        ValueError - if `ref_files` is empty
        AssertionError - if a file's sentence count differs from the first
        reference file's
    """
    # Step 1. Load the reference translations of this task.
    references = [_load_numbered_sentences(f) for f in ref_files]
    if not references:
        raise ValueError(f"no reference files for task {task!r}, dataset {dataset!r}")
    # Sentences per file; every file must match the first reference. This was
    # previously assigned only after the loop that asserted against it.
    sentence_num = len(references[0])
    for f, sentences in zip(ref_files, references):
        assert len(sentences) == sentence_num, (
            f"{f}: {len(sentences)} sentences, expected {sentence_num}")

    # Step 2. Load the candidate translations.
    candidates = []
    for f in candi_files:
        sentences = _load_numbered_sentences(f)
        print(f"参考译文句子个数:{sentence_num},候选译文句子个数：{len(sentences)},文件名：{f}")
        assert len(sentences) == sentence_num, (
            f"{f}: {len(sentences)} sentences, expected {sentence_num}")
        candidates.append(sentences)

    # Reference tokens for NIST, computed once per sentence rather than once
    # per candidate (cal_bleu / cn_meteor_score tokenize internally).
    nist_refs = [
        [tokenize_with_Chinese(ref[idx]) for ref in references]
        for idx in range(sentence_num)
    ]

    # Step 3. Average the per-sentence scores of each candidate file.
    bleu_scores, nist_scores, meteor_scores = [], [], []
    for candi_doc in candidates:
        sent_bleu, sent_nist, sent_meteor = [], [], []
        for idx, candi_sent in enumerate(candi_doc):
            ref_sents = [ref[idx] for ref in references]
            sent_bleu.append(cal_bleu(ref_sents, candi_sent))
            sent_meteor.append(cn_meteor_score(ref_sents, candi_sent))
            # NIST divides internally, so an empty sentence raises
            # ZeroDivisionError; such sentences score 0.
            try:
                score_nist = sentence_nist(nist_refs[idx], tokenize_with_Chinese(candi_sent))
            except ZeroDivisionError:
                score_nist = 0
            sent_nist.append(score_nist)
        bleu_scores.append(cal_avg(sent_bleu))
        nist_scores.append(cal_avg(sent_nist))
        meteor_scores.append(cal_avg(sent_meteor))

    # Defined outside the loop so an empty candidate list cannot leave it unbound.
    taskname = f"{task.strip()}_{dataset}_"
    ids = [taskname + os.path.splitext(os.path.basename(c))[0] for c in candi_files]
    single_result = {
        "ID": ids,
        taskname + "BLEU": bleu_scores,
        taskname + "NIST": nist_scores,
        taskname + "METEOR": meteor_scores,
    }
    return pd.DataFrame.from_dict(single_result)

In [None]:
# Score every "EC" dataset of every task and write one spreadsheet per task.
# pandas >= 2.0 can no longer write legacy .xls files (the xlwt engine was
# removed), so results are written as .xlsx (requires openpyxl).
tasks = ["1_SI", "2_ST", "3_ST", "4_ST", "5_CI", "6_CI", "7_CI", "8_CI", "9_CI"]
result_dir = "./result/sep"
os.makedirs(result_dir, exist_ok=True)
for task in tasks:
    print("#"*10, task, "#"*10)
    # Only dataset directories whose name contains "EC" are scored.
    datasets = [dataset for dataset in get_all_dir(f"../{task}") if "EC" in dataset]
    print(datasets)
    if not datasets:
        continue
    task_result_list = []
    for dataset in datasets:
        print("#"*5, dataset, "#"*5)
        candi_files = read_txt(f"../{task}/{dataset}/candidates/")
        ref_files = read_txt(f"../{task}/{dataset}/references/")
        task_result_list.append(run_calc(candi_files, ref_files, task, dataset))
    # Datasets are placed side by side (axis=1), one block of columns each.
    task_df = pd.concat(task_result_list, axis=1)
    task_df.to_excel(os.path.join(result_dir, f"{task}.EC.BLEU.NIST.METEOR.xlsx"), index=False)