In [1]:
import numpy as np
import jieba
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader,Dataset
from collections import Counter

In [16]:
def read_data(file_path):
    src = []
    tgt = []
    with open(file_path , mode='r',encoding='utf-8') as f:
        for line in f:
            s,t = line.strip().lower().split('\t')
            s = s.split(' ')
            t = t.split(' ')
            src.append(s)
            tgt.append(t)
    return src,tgt

src_data ,tgt_data =  read_data(r'.\data\eng-fra.txt')

In [18]:
print(src_data[10:20])
print(tgt_data[10:20])

[['wait!'], ['wait!'], ['i', 'see.'], ['i', 'try.'], ['i', 'won!'], ['i', 'won!'], ['oh', 'no!'], ['attack!'], ['attack!'], ['cheers!']]
[['attends', '!'], ['attendez', '!'], ['je', 'comprends.'], ["j'essaye."], ["j'ai", 'gagné', '!'], ['je', "l'ai", 'emporté', '!'], ['oh', 'non', '!'], ['attaque', '!'], ['attaquez', '!'], ['santé', '!']]


In [22]:
# 构建词汇表
class Vocab:
    '''词汇表类，用于从文本数据中构建词汇表'''
    # 定义类属性
    # 定义填充符号、未识别符号、开始符合、结束符号
    PAD = '<pad>'
    UNK = '<unk>'
    SOS = '<SOS>'
    EOS = '<ens>'

    pad_idx = 0
    unk_idx = 1
    sos_idx = 2
    eos_idx = 3

    def __init__(self,text,max_vocab = 5000):
        '''初始化，构建词汇表
        parameter
        -----------------
        text:array of list
            包含文本的数据集
        max_vocab :int
            词汇表最大长度
        '''
        vocab = Counter()
        for text_line in text:
            vocab.update(text_line)

        self.word_index = {}
        c = self.__class__
        d = {c.PAD:c.pad_idx,
             c.UNK:c.unk_idx,
             c.SOS:c.sos_idx,
             c.EOS:c.eos_idx
             }
        self.word_index.update(d)

        for idx ,(word,count) in enumerate(vocab.most_common(max_vocab-4),start = 4):
            self.word_index[word]=idx

        self.index_word = {index:word for word,index in self.word_index.items()}
        self.vocab_size = len(self.word_index)

src_vocab = Vocab(src_data)
tgt_vocab = Vocab(tgt_data)
print(src_vocab.index_word)
print(tgt_vocab.index_word)

{0: '<pad>', 1: '<unk>', 2: '<SOS>', 3: '<ens>', 4: 'je', 5: 'de', 6: 'pas', 7: 'que', 8: 'à', 9: 'ne', 10: 'le', 11: 'la', 12: '?', 13: 'vous', 14: 'il', 15: 'est', 16: 'ce', 17: 'un', 18: 'tu', 19: 'a', 20: 'nous', 21: 'en', 22: 'les', 23: 'tom', 24: 'une', 25: 'me', 26: "j'ai", 27: 'pour', 28: 'suis', 29: 'elle', 30: "c'est", 31: '!', 32: 'dans', 33: 'des', 34: 'plus', 35: 'te', 36: 'qui', 37: 'faire', 38: 'se', 39: 'du', 40: 'tout', 41: 'mon', 42: 'avec', 43: 'au', 44: 'veux', 45: 'si', 46: 'ça', 47: "qu'il", 48: 'son', 49: 'fait', 50: 'et', 51: 'y', 52: 'sont', 53: 'cette', 54: 'votre', 55: 'être', 56: 'très', 57: 'ma', 58: 'sur', 59: 'ils', 60: 'été', 61: 'pense', 62: 'lui', 63: 'pourquoi', 64: "n'est", 65: 'peux', 66: 'était', 67: 'comment', 68: "n'ai", 69: 'jamais', 70: 'ton', 71: 'sa', 72: 'où', 73: 'par', 74: 'quelque', 75: 'dit', 76: 'êtes', 77: 'temps', 78: 'vraiment', 79: 'chose', 80: 'tous', 81: 'sais', 82: 'sommes', 83: 'on', 84: 'ont', 85: 'beaucoup', 86: 'bien', 87: 'b

In [23]:
# 自定义数据集
class ParallelDataset(Dataset):
    '''自定义数据集类，获取平行数据'''
    def __init__(self,src_data,tgt_data,src_vocab,tgt_vocab,
                 max_src_length=None,max_tgt_length=None):
        if max_src_length is None:
             max_src_length = self.__get_max_seq_len__(src_data)
        if max_tgt_length is None:
             max_tgt_length = self.__get_max_seq_len__(tgt_data)  

        self.data = []    
        for src,tgt in zip(src_data,tgt_data):
            src_idx = [src_vocab.word_index.get(token,Vocab.unk_idx) for token in src]
            tgt_idx = [tgt_vocab.word_index.get(token,Vocab.unk_idx) for token in tgt]

            tgt_idx = [Vocab.sos_idx] + tgt_idx + [Vocab.eos_idx]
            # 填充或截断
            src_idx = self.__pad_or_truncatr__(src_idx,max_src_length)
            tgt_idx = self.__pad_or_truncatr__(tgt_idx,max_tgt_length)
            # 将序列转化为张量
            src_idx = torch.LongTensor(src_idx)
            tgt_idx = torch.LongTensor(tgt_idx)

            self.data.append((src_idx,tgt_idx))

    def __len__(self):
         return len(self.data)
    
    def __getitem__(self,index):
         return self.data[index]

    def __pad_or_truncatr__(self,seq,max_len):
        seq_len = len(seq)
        if seq_len>max_len:
            seq = seq[:max_len]
        else:
            seq = seq + [Vocab.pad_idx]*(max_len - seq_len)
        return seq

    def __get_max_seq_len__(self,text_data):
            max_len = max(len(d) for d in text_data)
            return max_len


In [25]:
print(src_data[0])
print(tgt_data[0])

['go.']
['va', '!']


In [24]:
dataset = ParallelDataset(src_data,tgt_data,src_vocab,tgt_vocab)
dataset[0]

(tensor([187,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0]),
 tensor([  2, 123,  31,   3,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0]))

In [None]:
# 位置编码

class PositionalEncodeing(nn.Module):
    def __init__(self,d_model,max_length = 1000):
        '''初始化方法
        
        parameter
        --------------
        d_model ：int
            嵌入向量维度
        max_length ：int
            最大序列长度    
        '''
        super().__init__()
        # 创建位置编码矩阵
        pe = torch.zeros(max_length,d_model)
        # 创建一个一维张量，其元素为从0到max_length-1，便是序列中的各个位置
        # 将形状转（max_length,)换为（max_length,1)，便于后续计算
        position = torch.arange(0,max_length,dtype=torch.float).unsqueeze(1)
        # exp(log(a)*b) = a^b

        div_trem = torch.exp(torch.arange(0,d_model,2) * (-np.log(10000.0)/d_model))
        # d_model必须为偶数，保证奇数长度与偶数长度相同
        # Position*div_trem.shape = (max_length,d_model/2)
        pe[:,0::2] = torch.sin(position * div_trem)
        pe[:,1::2] = torch.cos(position * div_trem)
        # 将pe注册为模型的缓冲区
        # 缓冲区时pytorch中的一种特殊属性，其不会被计算图追踪，不会更新梯度
        # 但是，成为缓冲区后，会成为state_dict的一部分，会随着模型一起保存和加载
        # 当注册缓冲区后，变量就会绑定当前对象，成为当前对象属性
        # 注册属性与绑定属性的区别:
            # 1、缓冲区会随着模型一起保存和加载，但是绑定属性无此功能
            # 2、缓冲区与模型参数一样，会随着模型一起迁移，但绑定属性无此功能
        self.register_buffer('pe',pe)
    
    def forward(self,x):
        # x.shape = (batch_size,seq_length,d_model)
        # 将词嵌入向量与位置张量相加
        x + self.pe[:x.size(1)]
        return x

In [None]:
class TransformerModel(nn.Module):
    '''transformer模型类
    pytorch中提供的tansformer类,不包含词嵌入和位置编码以及输出层
    '''
    def __init__(self,src_vocab_size,tgt_vocab_size)