In [None]:
import numpy as np
import pickle
import operator
from keras_transformer import get_model, decode
# Data root. On Google Colab, uncomment the Drive path below and comment out './'.
# main_path = '/content/drive/My Drive/Colab Notebooks/'    # Google Colab file path
main_path = './'
# Derive `path` from main_path only; overriding it with a literal here would make the
# Colab switch above have no effect.
path = main_path + 'middle_data/'


def _load_pickle(name):
    """Open ``path + name + '.pkl'`` in binary mode and return the unpickled object.

    NOTE(review): pickle.load can execute arbitrary code embedded in the file; only load
    artifacts produced by this project's own preprocessing step.
    """
    with open(path + name + '.pkl', 'rb') as f:
        return pickle.load(f)


# Preprocessed artifacts. Their exact contents are not visible here (presumably encoder/decoder
# id sequences and the Chinese/English vocabularies) — confirm against the preprocessing code.
encode_input = _load_pickle('encode_input')
decode_input = _load_pickle('decode_input')
decode_output = _load_pickle('decode_output')
source_token_dict = _load_pickle('source_token_dict')
target_token_dict = _load_pickle('target_token_dict')
source_tokens = _load_pickle('source_tokens')
print('Done')

In [None]:
# Sanity check: vocabulary size of each dictionary, then the number of encoder samples.
n_source_tokens = len(source_token_dict)
n_target_tokens = len(target_token_dict)
for size in (n_source_tokens, n_target_tokens, len(encode_input)):
    print(size)

# Build the model. One token_num that covers the larger of the two vocabularies is used;
# NOTE(review): presumably max() (rather than per-side sizes) is needed because the training
# cell feeds both languages through both the encoder and the decoder — confirm.
model = get_model(
    token_num=max(n_source_tokens, n_target_tokens),
    embed_dim=64,
    encoder_num=2,
    decoder_num=2,
    head_num=4,
    hidden_dim=256,
    dropout_rate=0.05,
    use_same_embed=False,  # different languages need different word embeddings
)
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
print('Done')

In [None]:
import numpy as np
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau
from keras.preprocessing.sequence import pad_sequences
import warnings
# NOTE(review): this silences every warning for the rest of the kernel session (including
# Keras/TensorFlow deprecation notices); consider narrowing it with `category=` / `module=`.
warnings.filterwarnings('ignore')


def clean_sequences(sequences):
    """Normalise a collection of token-id samples into a list of flat ``int`` lists.

    Each sample may be:
      * a scalar (Python or numpy number, or a 0-d numpy array) -> ``[int(sample)]``;
      * a sequence / numpy array of ids -> each id converted with ``int``;
      * a sequence whose elements are themselves sequences/arrays, e.g. ``[[3], [5]]``
        -> flattened by exactly one level to ``[3, 5]``.

    Args:
        sequences: iterable of samples.

    Returns:
        list[list[int]]: one flat list of Python ints per input sample.
    """
    cleaned = []
    for seq in sequences:
        # numpy arrays/scalars -> plain Python lists/scalars; a 0-d array becomes a scalar here,
        # which previously crashed the element loop below.
        if hasattr(seq, 'tolist'):
            seq = seq.tolist()
        if not hasattr(seq, '__len__') or isinstance(seq, (int, float)):
            cleaned.append([int(seq)])
            continue

        flat = []
        for x in seq:
            if hasattr(x, 'tolist'):
                x = x.tolist()
            if hasattr(x, '__len__') and not isinstance(x, str):
                # One level of nesting, e.g. a per-step [id] wrapper.
                flat.extend(int(i) for i in x)
            else:
                flat.append(int(x))
        cleaned.append(flat)
    return cleaned

# —— 1. Build bidirectional training data —— #
# Zh→En: encoder <- encode_input, decoder in <- decode_input, target <- decode_output
# En→Zh: encoder <- decode_input, decoder in <- encode_input, target <- encode_output (built below)
#
# Teacher forcing: a decoder target is its decoder input shifted left by one step, so position t
# is trained to predict token t+1. Appending <END> to encode_input instead would leave
# target[t] == input[t] and train the model to copy its input.
# Assumes each encode_input row begins with '<START>' (get_input builds encoder inputs that way at
# inference and decoding is seeded with <START>) — TODO confirm against the preprocessing step.
pad_src_id = source_token_dict['<PAD>']
encode_output = [list(seq[1:]) + [pad_src_id] for seq in encode_input]

# The English encoder input is decode_input (wrapped in <START>/<END>, matching what get_input feeds
# at inference) rather than decode_output, which is the shifted target and presumably lacks <START>.
# NOTE(review): Chinese ids (source_token_dict) and English ids (target_token_dict) now share the same
# encoder and decoder embeddings; unless the dictionaries use disjoint id ranges, different words
# collide on the same id — confirm.
# list(...) makes `+` concatenate even if the pickles hold numpy arrays (where `+` adds element-wise).
encode_input_all = list(encode_input) + list(decode_input)
decode_input_all = list(decode_input) + list(encode_input)
decode_output_all = list(decode_output) + encode_output

# —— 2. MAX_LEN: length of the longest sample across all three sets —— #
MAX_LEN = max(len(seq) for seq in encode_input_all + decode_input_all + decode_output_all)
print("使用的 MAX_LEN =", MAX_LEN)

# —— 3. Normalise every sample to a flat list of ints —— #
# clean_sequences converts scalars / numpy arrays to int lists and flattens one level of nesting
# (presumably decode_output stores each step as a one-element list — confirm). Each set is cleaned
# exactly once; a second pass over already-flat int lists would be a no-op.
encode_input_all = clean_sequences(encode_input_all)
decode_input_all = clean_sequences(decode_input_all)
decode_output_all = clean_sequences(decode_output_all)

# —— 4. Post-pad (and, if needed, post-truncate) every set to MAX_LEN —— #
pad_enc_id = source_token_dict['<PAD>']
pad_dec_id = target_token_dict['<PAD>']

# NOTE(review): after the bidirectional merge the encoder set also holds English samples and the
# decoder sets hold Chinese samples, yet each set is padded with a single dictionary's <PAD> id.
# That is only correct if both dictionaries map '<PAD>' to the same id — confirm.
pad_kwargs = {'maxlen': MAX_LEN, 'padding': 'post', 'truncating': 'post'}
encode_input_all = pad_sequences(encode_input_all, value=pad_enc_id, **pad_kwargs)
decode_input_all, decode_output_all = [
    pad_sequences(batch, value=pad_dec_id, **pad_kwargs)
    for batch in (decode_input_all, decode_output_all)
]

# —— 5. Callbacks —— #
import os

# Save weights only when the training loss improves; epoch and loss are embedded in the filename.
filepath = "./models/W--{epoch:03d}-{loss:.4f}-.weights.h5"
# Create the checkpoint directory up front in case the Keras version in use does not create it,
# so a fresh checkout does not fail at the first save.
os.makedirs(os.path.dirname(filepath), exist_ok=True)
checkpoint = ModelCheckpoint(
    filepath,
    monitor='loss',
    verbose=1,
    save_best_only=True,
    mode='min',
    save_weights_only=True,
    save_freq='epoch',
)
# Multiply the learning rate by 0.2 when training loss plateaus (patience=2 epochs, min_delta=1e-4).
reduce_lr = ReduceLROnPlateau(
    monitor='loss', factor=0.2, patience=2,
    verbose=1, mode='min', min_delta=1e-4, min_lr=0,
)
callbacks_list = [checkpoint, reduce_lr]

# —— 6. Training —— #
EPOCHS = 10
BATCH_SIZE = 64

# Spot-check a few aligned (encoder, decoder-in, decoder-out) rows *before* training, so a data
# alignment bug is visible without waiting for every epoch. min() guards tiny datasets.
for i in range(min(3, len(encode_input_all))):
    print("Enc:", encode_input_all[i])
    print("Dec In:", decode_input_all[i])
    print("Dec Out:", decode_output_all[i])

history = model.fit(
    x=[encode_input_all, decode_input_all],
    y=decode_output_all,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    verbose=1,
    callbacks=callbacks_list,
)







In [None]:
# Load trained weights
import glob
import os

WEIGHTS_FILE = 'models/W--010-0.5277-.weights.h5'


def _lowest_loss_checkpoint(directory='models'):
    """Return the checkpoint in *directory* whose filename encodes the lowest loss, or None.

    Filenames are parsed using the training cell's ModelCheckpoint pattern
    ``W--{epoch:03d}-{loss:.4f}-.weights.h5``; files that do not fit it are skipped.
    """
    scored = []
    for candidate in glob.glob(os.path.join(directory, 'W--*-*-.weights.h5')):
        # 'W--010-0.5277-.weights.h5'.split('-') -> ['W', '', '010', '0.5277', '.weights.h5']
        parts = os.path.basename(candidate).split('-')
        try:
            scored.append((float(parts[3]), candidate))
        except (IndexError, ValueError):
            continue
    return min(scored)[1] if scored else None


# The pinned filename comes from one specific run and will not exist after retraining, so fall
# back to the best checkpoint on disk. If none is found, load_weights fails on the pinned path.
weights_path = WEIGHTS_FILE
if not os.path.exists(weights_path):
    weights_path = _lowest_loss_checkpoint() or WEIGHTS_FILE
model.load_weights(weights_path)
target_token_dict_inv = {v: k for k, v in target_token_dict.items()}
print('Done')

In [None]:
from keras.preprocessing import sequence
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import jieba
import requests

import re

# Matches a single character from the CJK Unified Ideographs block (U+4E00..U+9FFF).
_CJK_CHAR = re.compile(r'[\u4e00-\u9fff]')


def contains_chinese(text):
    """Return True when *text* contains at least one CJK Unified Ideograph (U+4E00..U+9FFF).

    Punctuation outside that block (e.g. '，' or '。') does not count.
    """
    return _CJK_CHAR.search(text) is not None

# Convert a raw sentence into encoder token ids (supports both Chinese and English input).
def get_input(seq, is_chinese, max_len=34):
    """Tokenise *seq* and map it to encoder token ids.

    Chinese text is segmented with jieba and looked up in ``source_token_dict``;
    other text is split on whitespace and looked up in ``target_token_dict``.
    Tokens are wrapped in ``<START>`` / ``<END>`` and right-padded with ``<PAD>``
    up to ``max_len`` (longer inputs are left unpadded, not truncated).

    Args:
        seq: raw input sentence.
        is_chinese: True to treat *seq* as Chinese, False for English.
        max_len: padded length. 34 is the value previously hard-coded here
            (presumably the training encoder length — confirm).

    Returns:
        ``(True, ids)`` on success, or ``(False, [])`` when the sentence has no
        tokens or contains a token missing from the vocabulary.
    """
    if is_chinese:
        # Split every jieba segment on whitespace and drop empty pieces. Joining with ' ' and then
        # split(' ') would turn whitespace segments into '' tokens that are never in the vocabulary.
        tokens = [piece for word in jieba.lcut(seq, cut_all=False) for piece in word.split()]
        token_dict = source_token_dict
    else:
        # split() with no argument collapses runs of whitespace instead of yielding '' tokens.
        tokens = seq.split()
        token_dict = target_token_dict

    if not tokens:
        return False, []

    tokens = ['<START>'] + tokens + ['<END>']
    tokens += ['<PAD>'] * (max_len - len(tokens))

    if any(token not in token_dict for token in tokens):
        return False, []
    return True, [token_dict[token] for token in tokens]

# Translate one encoded sentence and print the result.
def get_ans(seq_ids, is_chinese):
    """Decode *seq_ids* with the model and print the translation as space-joined tokens.

    For Chinese input (Zh→En) the output tokens come from ``target_token_dict``;
    for English input (En→Zh) they come from ``source_token_dict``.

    Args:
        seq_ids: encoder token ids, e.g. as produced by ``get_input``.
        is_chinese: True if *seq_ids* encodes a Chinese sentence.
    """
    out_dict = target_token_dict if is_chinese else source_token_dict
    if is_chinese:
        id_to_token = target_token_dict_inv
    else:
        id_to_token = {v: k for k, v in source_token_dict.items()}
    start_id = out_dict['<START>']
    end_id = out_dict['<END>']

    decoded = decode(
        model,
        [seq_ids],
        start_token=start_id,
        end_token=end_id,
        pad_token=out_dict['<PAD>'],
    )
    ids = list(decoded[0])
    # Strip the leading <START> and, only if present, the trailing <END>. Unconditional [1:-1]
    # would drop a real final token whenever decoding stops without emitting <END>.
    if ids and ids[0] == start_id:
        ids = ids[1:]
    if ids and ids[-1] == end_id:
        ids = ids[:-1]
    # The model was built with token_num = max of both vocabularies, so it can emit an id that
    # does not exist in the smaller dictionary; show '<UNK>' instead of raising KeyError.
    print(' '.join(id_to_token.get(int(i), '<UNK>') for i in ids))

# Interactive loop: translate each entered sentence until the user types 'x'.
# EOF (closed stdin) or an interrupt also ends the loop cleanly instead of raising.
while True:
    try:
        seq = input("请输入中英文句子 (输入 'x' 退出): ")
    except (EOFError, KeyboardInterrupt):
        break
    if seq.strip().lower() == 'x':
        break
    is_chinese = contains_chinese(seq)
    flag, seq_ids = get_input(seq, is_chinese)
    if flag:
        get_ans(seq_ids, is_chinese)
    else:
        # Unknown word(s) or empty input.
        print('听不懂呢。')