In [1]:
%%capture
!pip install transformers
!pip install pypinyin
!pip install jieba
!pip install paddlepaddle

In [None]:
import re,time,json,pickle
from collections import defaultdict
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
import torch.nn as nn

from transformers import (BertTokenizer,BertConfig,BertModel)

from model.fusionDataset import FusionDataset

config = BertConfig.from_pretrained('AnchiBERT')
tokenizer = BertTokenizer.from_pretrained('AnchiBERT')
Anchibert = BertModel.from_pretrained('AnchiBERT',config=config)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
with open('result/anchi_tra_Adam_128_00001_60_6_6_110k_losses.pt','rb') as f:
    losses = pickle.load(f)
zoom = 1
x = [_ for _ in range(zoom,len(losses[1])+1)]

plt.plot(x,losses[0][zoom-1:],label='Train loss')
plt.plot(x,losses[1][zoom-1:],label='valid loss')
plt.legend()

### Load Necessary preproceeded Data

In [None]:
with open('data/char_map.json','r') as f:
    ix2glyph = defaultdict(lambda : '_')
    ix2glyph[0] = '[PAD]'
    glyph2ix = defaultdict(lambda : 1)
    glyph2ix.update({'[CLS]':0,'[SEP]':0,'[PAD]':0})
    for i, k in enumerate(json.load(f).keys(),2):
        glyph2ix[k] = i
        ix2glyph[i] = k
with open('data/pinyin_map.json','r') as f:
    pinyin2ix = defaultdict(lambda : 1)
    pinyin2ix.update({'[CLS]':0,'[SEP]':0,'[PAD]':0})
    for i,k in enumerate(json.load(f).keys(),2):
        pinyin2ix[k] = i
with open('data/pos_tags.json','r') as f:
    pos2ix = defaultdict(lambda : 0)
    pos2ix.update(json.load(f))

# Decoder Section

In [None]:
from model.fusion_transformer import Fusion_Anchi_Trans_Decoder, Fusion_Anchi_Transformer, Anchi_Decoder,Anchi_Transformer

In [None]:
with open("couplet/test/in.txt",encoding='utf8') as f:
    te_in =  [row.strip().split() for row in f.readlines()]
# train 下联  
with open("couplet/test/out.txt",encoding='utf8') as f:
    te_out = [row.strip().split() for row in f.readlines()]

In [None]:
config = { # Anchi_Transformer
    'max_position_embeddings':50,
    'hidden_size':768,
    'layer_norm_eps':1e-12, 
    'hidden_dropout':0.1, 
    'nhead':12,
    'num_encoder_layers':6, # trainable
    'num_decoder_layers':6, # trainable
    'output_dim':9110,# fixed use glyph dim as output
    'dim_feedforward': 3072,
    'activation':'relu',
    'trans_dropout':0.1,
    'device':device
}
# <model_name>_<optim>_<batch_num>_<lr>_<epoch>_<encoder layer>_<decoder layer>_<train_data_size>
# name = 'anchi_tra_Adam_128_0001_60_6_6_110k'
model= Anchi_Transformer(config)
model.load_state_dict(torch.load('result/anchi_tra_Adam_128_00001_60_6_6_110k.pt'))

In [None]:
from utils.generate_couplet import greedy_decode,beam_search_decode

In [None]:
predict = greedy_decode(model=model,
                      bert=Anchibert,
                      tokenizer=tokenizer,
                      sent=te_in[0],
                      glyph2ix=glyph2ix,
                      pinyin2ix=pinyin2ix,
                      pos2ix=pos2ix,
                      ix2glyph=ix2glyph,
                        device=device)

In [None]:
' '.join(te_in[0]),' '.join(predict)

In [None]:
predict = beam_search_decode(model=model,
                            k=5,
                          bert=Anchibert,
                          tokenizer=tokenizer,
                          sent=te_in[0],
                          glyph2ix=glyph2ix,
                          pinyin2ix=pinyin2ix,
                          pos2ix=pos2ix,
                          ix2glyph=ix2glyph,
                            device=device)

In [17]:
for i, j in predict:
    print( ' '.join(i),j)

彴 镪 云 年 秬 吣 秬 斵 侓 趿 閱 声 焐 艋 何 褵 蜨 瘽 tensor(206.5308, device='cuda:0', grad_fn=<SubBackward0>)
彴 镪 云 年 秬 吣 秬 斵 侓 趿 閱 声 焐 艋 何 褵 蜨 梦 tensor(206.5351, device='cuda:0', grad_fn=<SubBackward0>)
彴 镪 云 年 秬 吣 秬 斵 侓 趿 閱 声 焐 艋 何 褵 崺 瘽 tensor(206.5449, device='cuda:0', grad_fn=<SubBackward0>)
彴 镪 云 年 秬 吣 秬 斵 侓 趿 閱 声 焐 艋 何 褵 崺 梦 tensor(206.5492, device='cuda:0', grad_fn=<SubBackward0>)
彴 镪 云 年 秬 吣 秬 斵 侓 趿 閱 声 焐 艋 何 褵 蜨 黒 tensor(206.5609, device='cuda:0', grad_fn=<SubBackward0>)
