In [29]:
from nltk.translate.bleu_score import SmoothingFunction
from nltk.translate import bleu
import numpy as np
from collections import Counter,defaultdict
import nltk
from nltk.lm.preprocessing import pad_both_ends, padded_everygram_pipeline
from nltk.lm.models import Laplace
from nltk.lm import Vocabulary
import pickle

from utils.generate_couplet import beam_search_decode
from transformers import (BertTokenizer,BertConfig,BertModel)

from model.fusionDataset import FusionDataset
from model.fusion_transformer import Fusion_Anchi_Trans_Decoder, Fusion_Anchi_Transformer, Anchi_Decoder,Anchi_Transformer

import sys,os,torch,json,time
import torch.nn.functional as F
import jieba
import random
import jieba.posseg as pseg


In [2]:
"""
python evaluation
"""
s1 = time.time()  # wall-clock start, reported by the final evaluation cell

# Pretrained AnchiBERT configuration, tokenizer and encoder.
# The BERT configuration is bound to its own name so that it is not confused
# with the model `config` dict defined at the bottom of this cell.
bert_config = BertConfig.from_pretrained('AnchiBERT')
tokenizer = BertTokenizer.from_pretrained('AnchiBERT')
Anchibert = BertModel.from_pretrained('AnchiBERT',config=bert_config)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Every text resource is opened as UTF-8 explicitly: the maps are keyed by
# Chinese text, and the platform default encoding is not UTF-8 everywhere.

# Glyph vocabulary: index 0 is shared by the special tokens, index 1 is the
# fallback for unknown glyphs, real glyphs start at 2 in file order.
with open('data/char_map.json','r',encoding='utf8') as f:
    ix2glyph = defaultdict(lambda : '_')  # unknown index -> '_'
    ix2glyph[0] = '[PAD]'
    glyph2ix = defaultdict(lambda : 1)
    glyph2ix.update({'[CLS]':0,'[SEP]':0,'[PAD]':0})
    for i, k in enumerate(json.load(f).keys(),2):
        glyph2ix[k] = i
        ix2glyph[i] = k

# Pinyin vocabulary, same index layout as the glyph vocabulary.
with open('data/pinyin_map.json','r',encoding='utf8') as f:
    pinyin2ix = defaultdict(lambda : 1)
    pinyin2ix.update({'[CLS]':0,'[SEP]':0,'[PAD]':0})
    for i,k in enumerate(json.load(f).keys(),2):
        pinyin2ix[k] = i

# POS tag -> id as stored in the file; unknown tags map to 0.
with open('data/pos_tags.json','r',encoding='utf8') as f:
    pos2ix = defaultdict(lambda : 0)
    pos2ix.update(json.load(f))

# Test inputs: one whitespace-tokenised line per row
# (presumably the first line of each couplet -- TODO confirm).
with open("couplet/test/in.txt",encoding='utf8') as f:
    te_in =  [row.strip().split() for row in f]

# Second (lower) line of each test couplet, kept as the raw stripped string.
with open("couplet/test/out.txt",encoding='utf8') as f:
    te_out = [row.strip() for row in f]

# Training inputs, tokenised like the test inputs.
with open("couplet/train/in.txt",encoding='utf8') as f:
    tr_in =  [row.strip().split() for row in f]

# Second (lower) line of each training couplet, kept as the raw stripped string.
with open("couplet/train/out.txt",encoding='utf8') as f:
    tr_out = [row.strip() for row in f]

###############################################################
#            Change this part based on model                  #
###############################################################
# NOTE(review): the next cell rebinds `config` (with 'output_dim': 9110)
# before the model is constructed, so the values below are not the ones the
# loaded checkpoint is built with.
config = { # for Fusion_Anchi_Trans_Decoder
    'max_position_embeddings':50,
    'hidden_size':768,
    'font_weight_path':'data/glyph_weight.npy',
    'pinyin_embed_dim':30, # trainable
    'pinyin_path':'data/pinyin_map.json',
    'tag_size':30,
    'tag_emb_dim':10, # trainable
    'layer_norm_eps':1e-12,
    'hidden_dropout':0.1,
    'nhead':12,
    'num_layers':6, # trainable
    'output_dim':21128,# fixed
    'device':device,
}



################################################################



Some weights of the model checkpoint at AnchiBERT were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [3]:
# Hyper-parameters handed to Fusion_Anchi_Trans_Decoder. They are expected to
# match the checkpoint loaded below (load_state_dict is strict by default).
config = { # for Fusion_Anchi_Trans_Decoder
    'max_position_embeddings':50,
    'hidden_size':768,
    'font_weight_path':'data/glyph_weight.npy',
    'pinyin_embed_dim':30, # trainable
    'pinyin_path':'data/pinyin_map.json',
    'tag_size':30,
    'tag_emb_dim':10, # trainable
    'layer_norm_eps':1e-12,
    'hidden_dropout':0.1,
    'nhead':12,
    'num_layers':6 , #6, trainable
    'output_dim':9110,# fixed use glyph dim as output
    'device':device,
}

# <model_name>_<optim>_<batch_num>_<lr>_<epoch>_<pinyin_embed_dim>_<tag_emb_dim>_<encoder layer>_<decoder layer>_<train_data_size>
name = 'fu_anchi_de_Adam_128_0001_60_6_30_10_193k'
model = Fusion_Anchi_Trans_Decoder(config)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU can still be loaded when this notebook runs on a CPU-only machine.
model.load_state_dict(torch.load(f'result/{name}.pt', map_location=device))

<All keys matched successfully>

In [39]:
def beam_search_decode_test(model,k,bert,tokenizer,
                      sent,glyph2ix,
                      pinyin2ix,pos2ix,
                      ix2glyph,device,
                      max_iter=15,verbose=True):
    """
    Beam-search decode a line for `sent`, re-decoding with banned characters
    until the result's jieba word segmentation matches that of `sent`.

    Each round runs a width-`k` beam search for `len(sent)` steps. A candidate
    token is rejected when it already occurs in its own prefix or when it is
    banned. If the best beam's segmentation pattern (word lengths produced by
    `jieba.cut`) differs from the input's, the character at the first position
    whose POS id differs from the input's is banned and decoding starts over.
    Decoding stops when the patterns match, when no POS mismatch is found, or
    after `max_iter` rounds.

    Args:
        model: network exposing `encode`, `decode` and `Linear`, as used below.
        k: beam width.
        bert: encoder whose output is indexed with 'last_hidden_state'.
        tokenizer, glyph2ix, pinyin2ix: forwarded to
            `FusionDataset.prepare_sequence`.
        sent: sequence of tokens; joined with '' for segmentation
            (presumably single characters -- TODO confirm).
        pos2ix: forwarded to `prepare_sequence`, and also used here to map
            lower-cased jieba POS flags to ids.
        ix2glyph: maps an output index to its token.
        device: torch device that `bert` and `model` are moved to.
        max_iter: maximum number of decoding rounds (default 15).
        verbose: print every rejected prediction (default True).

    Returns:
        The beams of the last round: a list of `[tokens, score]` pairs sorted
        by ascending score, where score is the accumulated negative
        log-probability (a 0-dim tensor once a token has been added).
    """
    def get_last_log_prob(prefix, memory):
        """
        Return the log-probabilities over the output vocabulary for the token
        that follows `prefix`, given the encoder output `memory`.
        """
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            Ysents_input_ids,Ysents_token_type_ids,\
            Ysents_attention_mask,Ysents_pinyin_ids,\
            Ysents_glyph_ids,Ysents_pos_ids,\
            trueY,y_mask_ids\
            = FusionDataset.prepare_sequence(sents=[prefix],
                                             tokenizer=tokenizer,
                                             glyph2ix=glyph2ix,
                                             pinyin2ix=pinyin2ix,
                                             pos2ix=pos2ix,
                                             encode=False,
                                             skip_error=False,
                                             device=device)

            # NOTE(review): transformers' BertModel.forward takes
            # (input_ids, attention_mask, token_type_ids) positionally, so the
            # second and third arguments look swapped here. Left unchanged
            # because the checkpoint may have been trained with this same
            # call -- confirm against the training code before changing it.
            Yword_embeddings = bert(Ysents_input_ids,
                                    Ysents_token_type_ids,
                                    Ysents_attention_mask
                                   )['last_hidden_state'].detach()

            decode_input = {
                            'memory':memory,
                            'Xpad_hidden_mask':None,
                            'Yword_embeddings':Yword_embeddings,
                            'Ysents_pinyin_ids':Ysents_pinyin_ids,
                            'Ysents_glyph_ids':Ysents_glyph_ids,
                            'Ysents_pos_ids':Ysents_pos_ids,
                            'Ypad_hidden_mask':None,
                            'tgt_mask':y_mask_ids}

            out = model.decode(**decode_input)
            # Last position of the first (only) sentence
            # (assumes a (tgt_len, batch, hidden) layout -- TODO confirm).
            out = model.Linear(out[-1,0,:])
            return F.log_softmax(out,dim=-1)

    def pos_ids(text):
        """Return one POS id per character of `text`, using jieba's paddle tagger."""
        ids = []
        for word in pseg.cut(text, use_paddle=True):
            ids.extend([pos2ix[word.flag.lower()]]*len(word.word))
        return ids

    bert.to(device)
    model.to(device)
    model.eval()

    # Encode the input once; the result is reused for every decoding step.
    with torch.no_grad():
        Xsents_input_ids,Xsents_token_type_ids, \
        Xsents_attention_mask,Xsents_pinyin_ids,\
        Xsents_glyph_ids,\
        Xsents_pos_ids = FusionDataset.prepare_sequence(sents=[sent],
                                                        tokenizer=tokenizer,
                                                        glyph2ix=glyph2ix,
                                                        pinyin2ix=pinyin2ix,
                                                        pos2ix=pos2ix,
                                                        encode=True,
                                                        skip_error=False,
                                                        device=device)

        # NOTE(review): same positional-argument order question as in
        # get_last_log_prob above.
        Xword_embeddings = bert(Xsents_input_ids,
                                Xsents_token_type_ids,
                                Xsents_attention_mask
                                )['last_hidden_state'].detach()
        encode_input = {'Xword_embeddings':Xword_embeddings,
                        'Xsents_pinyin_ids':Xsents_pinyin_ids,
                        'Xsents_glyph_ids':Xsents_glyph_ids,
                        'Xsents_pos_ids':Xsents_pos_ids,
                        'Xpad_hidden_mask':None}
        memory = model.encode(**encode_input)

    # Targets the prediction is compared against.
    sentence_data = ''.join(sent)
    sent_pattern = [len(word) for word in jieba.cut(sentence_data)]
    sent_pos_ids = pos_ids(sentence_data)

    cache_word = []  # characters banned from all later rounds
    sequences = [[list(),0.0]]

    for _round in range(max_iter):
        sequences = [[list(),0.0]]
        for _step in range(len(sent)):
            all_candidates = []

            # Expand every beam with its best admissible continuations.
            for seq, score in sequences:
                last_word_prob = get_last_log_prob(seq,memory)

                # Over-fetch by the number of tokens that may be filtered out
                # (at most len(sent) prefix tokens plus the banned
                # characters), so that filtering does not empty the beam.
                n_top = min(k+len(sent)+len(cache_word), last_word_prob.size(-1))
                values, indices = last_word_prob.topk(n_top)
                for value, index in zip(values, indices.tolist()):
                    w = ix2glyph[index]
                    if w in cache_word or w in seq:
                        continue
                    all_candidates.append([seq+[w], score - value])

            if not all_candidates:
                # Nothing admissible is left; keep the beams built so far
                # instead of failing on an empty beam below.
                break

            # Lowest accumulated negative log-probability first; keep k best.
            sequences = sorted(all_candidates, key=lambda t:t[1])[:k]

        pred_sent = ''.join(sequences[0][0])
        pred_pattern = [len(word) for word in jieba.cut(pred_sent)]
        if pred_pattern == sent_pattern:
            break

        # Ban the first character whose POS id disagrees with the input's.
        # zip() stops at the shortest sequence, so a prediction whose text
        # length differs from the input's cannot index out of range.
        offending = next((ch for ch, pred_id, sent_id
                          in zip(pred_sent, pos_ids(pred_sent), sent_pos_ids)
                          if pred_id != sent_id), None)
        if offending is None:
            break
        cache_word.append(offending)

        if verbose:
            print(pred_sent)
    return sequences

In [51]:
# Scratch inspection of a previous decoding result.
# NOTE(review): `predict` is only assigned by the decoding loop in a later
# cell, so this cell fails on a fresh Restart & Run All. The recorded output
# (a list of characters) also does not match what that loop currently stores
# in `predict` -- presumably it was run against an earlier binding.
predict[0][0]

['即', '亦', '既', '也', '众', '乃', '敌']

In [41]:
idx = 3  # position of the example to decode
# Keep only example `idx` from each side, each wrapped in a one-element list:
# dataset[0] holds the input line, dataset[1] the reference line.
dataset = [[split[idx]] for split in (te_in, te_out)]

predicts = []
for sent in dataset[0]:
    beams = beam_search_decode_test(
        model=model,
        k=2,
        bert=Anchibert,
        tokenizer=tokenizer,
        sent=sent,
        glyph2ix=glyph2ix,
        pinyin2ix=pinyin2ix,
        pos2ix=pos2ix,
        ix2glyph=ix2glyph,
        device=device,
    )
    predict = beams[0][0]  # token list of the first returned beam
    predicts.append(''.join(predict))

# Show input, prediction and reference for every decoded example.
for upper, predicted, gold in zip(dataset[0], predicts, dataset[1]):
    print('top:', ''.join(upper))
    print('predict:', predicted)
    print('gold:', gold)


Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 337.68it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, ?it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.62it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.37it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.62it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1012.87it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1012.87it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.50it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.74i

即其随适就顺从


1it [00:00, 506.56it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1012.87it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.68it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1013.61it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1013.36it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.74it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.74it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.68it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.62it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successf

斯哪木个是只张


1it [00:00, ?it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1013.36it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.80it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1013.61it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1013.12it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.80it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 507.05it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.93it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.68it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully.

其如乃则若条仍


1it [00:00, 1013.36it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.99it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.37it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.62it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 507.05it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.80it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1013.61it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.62it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 337.65it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfu

既已且自亦也又


1it [00:00, 1013.36it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.68it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.68it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.80it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1013.85it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1014.10it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.74it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.19it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.74it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successf

小醇浓酸辣黑椒


1it [00:00, ?it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1014.83it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1013.61it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1012.63it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.80it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.68it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.31it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1013.85it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.74it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully

少幼童稚绛髫纯


1it [00:00, 1013.61it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.62it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.50it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.74it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.68it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.74it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.68it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.80it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.68it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successful

己已自本是只个


1it [00:00, 1013.61it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1013.61it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1013.61it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.68it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 1012.87it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.86it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.62it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.86it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled successfully......
1it [00:00, 506.62it/s]
Paddle enabled successfully......
DEBUG:jieba._compat:Paddle enabled success

top: 晋世文章昌二陆
predict: 自由生活保权存
gold: 魏 家 词 赋 重 三 曹


In [None]:
制谓词滥，准型反类似作样？或仗例典图

In [21]:
# Inspect how jieba segments this line (the word boundaries drive the
# segmentation-pattern check in beam_search_decode_test).
[*jieba.cut('一句相思吟岁月')]

['一句', '相思', '吟', '岁月']

In [20]:
# Inspect how jieba segments this line, for comparing word lengths by eye.
[*jieba.cut('二三或少也稀多')]

['二三', '或少', '也', '稀多']

In [None]:
# Persist the predictions, one per line. The encoding is explicit because the
# predictions are Chinese text and the platform default is not always UTF-8.
# NOTE(review): this cell writes under '../result/' while the checkpoint is
# loaded from 'result/' -- confirm which directory is intended.
with open(f'../result/{name}_predict.txt','w',encoding='utf8') as f:
    for i in predicts:
        f.write(i+'\n')

# NOTE(review): `evaluate_pred` is not defined in the visible cells, so this
# relies on kernel state or a cell that is not shown. Also, `te_out` holds the
# whole test set while `predicts` holds only what the decoding cell produced
# -- confirm that the two are aligned before trusting the score.
res = evaluate_pred(te_out, predicts)
print('result',res)
print('time:',time.time()-s1)
# `res` is indexed as res[0], res[1][0], res[1][1]; the three values are
# stored tab-separated.
with open(f'../result/{name}.txt','w',encoding='utf8') as f:
    f.write(f'{res[0]}\t{res[1][0]}\t{res[1][1]}')