In [1]:
# Vocabulary: token -> integer id. Ids follow the order of the list below, so
# '<PAD>' is 0, the digits '1'..'9','0' are 1..10, the month abbreviations are
# 11..22, the separators '-' and '/' are 23 and 24, and '<SOS>'/'<EOS>' are
# 25 and 26.
zidian = {
    token: code
    for code, token in enumerate([
        '<PAD>',
        '1', '2', '3', '4', '5', '6', '7', '8', '9', '0',
        'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
        'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec',
        '-', '/',
        '<SOS>', '<EOS>',
    ])
}

In [2]:
import torch
import torch.nn as nn
import numpy as np
import datetime
from torch.utils.data import Dataset, DataLoader


class DateDataset(Dataset):
    """Synthetic dataset of random dates rendered in two textual formats.

    Each item is a pair ``(source, target)`` of ``torch.LongTensor``:

    * ``source`` -- the date as ``yy-mm-dd`` (8 tokens), e.g. ``05-06-15``.
    * ``target`` -- the same date as ``dd/Mon/yyyy`` wrapped in ``<SOS>`` and
      ``<EOS>`` (11 tokens), e.g. ``15/Jun/2005``; the month abbreviation is
      encoded as a single token.

    Tokens are converted to ids with the module-level ``zidian`` vocabulary.
    A new random date is drawn from ``np.random`` on every access, so the
    index decides only whether an item exists, not which date is returned.
    """

    # English month abbreviations spelled out explicitly. ``strftime('%b')``
    # follows the process locale, so under a non-English LC_TIME it would
    # yield strings that are not keys of ``zidian``.
    _MONTHS = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
               'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec')

    def __init__(self, size=2000):
        """
        Args:
            size: number of items reported by ``len()``. Defaults to 2000.
        """
        self.size = size

    def __len__(self):
        """Return the nominal number of items in the dataset."""
        return self.size

    def __getitem__(self, index):
        """Return one freshly sampled ``(source, target)`` pair.

        Args:
            index: position in ``[-len(self), len(self))``. It is only
                range-checked; the sampled date does not depend on it.

        Returns:
            Tuple ``(LongTensor[8], LongTensor[11])``.

        Raises:
            IndexError: if ``index`` is out of range.
        """
        # Without a bound check, Python's sequence-iteration fallback
        # (``for item in dataset`` / ``list(dataset)``) would never stop.
        if not -self.size <= index < self.size:
            raise IndexError('DateDataset index out of range')

        # Draw a random POSIX timestamp and convert it to a local datetime.
        timestamp = np.random.randint(143835585, 2043835585)
        date = datetime.datetime.fromtimestamp(int(timestamp))

        # Source side, e.g. 05-06-15: every character is one token.
        date_cn = date.strftime("%y-%m-%d")
        date_cn_code = [zidian[ch] for ch in date_cn]

        # Target side, e.g. <SOS>15/Jun/2005<EOS>: digits and slashes are
        # single-character tokens, the month abbreviation is one token.
        day = date.strftime("%d")
        month = self._MONTHS[date.month - 1]
        year = date.strftime("%Y")

        date_en_code = [zidian['<SOS>']]
        date_en_code += [zidian[ch] for ch in day]
        date_en_code += [zidian['/'], zidian[month], zidian['/']]
        date_en_code += [zidian[ch] for ch in year]
        date_en_code += [zidian['<EOS>']]

        return torch.LongTensor(date_cn_code), torch.LongTensor(date_en_code)


# Batch the synthetic dates: 100 pairs per batch, reshuffled on every pass,
# with any incomplete trailing batch dropped.
dataloader = DataLoader(DateDataset(), batch_size=100, shuffle=True, drop_last=True)

# Pull the first batch out for inspection.
sample = next(iter(dataloader))
sample[0][:5], sample[0].shape, sample[1][:5], sample[1].shape

(tensor([[10, 10, 23,  1,  2, 23,  2,  8],
         [ 9,  3, 23,  1,  2, 23,  1,  7],
         [ 2, 10, 23, 10,  3, 23,  1,  2],
         [ 2, 10, 23,  1,  1, 23, 10,  1],
         [10, 10, 23, 10,  5, 23,  1,  3]]),
 torch.Size([100, 8]),
 tensor([[25,  2,  8, 24, 22, 24,  2, 10, 10, 10, 26],
         [25,  1,  7, 24, 22, 24,  1,  9,  9,  3, 26],
         [25,  1,  2, 24, 13, 24,  2, 10,  2, 10, 26],
         [25, 10,  1, 24, 21, 24,  2, 10,  2, 10, 26],
         [25,  1,  3, 24, 15, 24,  2, 10, 10, 10, 26]]),
 torch.Size([100, 11]))

In [3]:
class Seq2Seq(nn.Module):
    """LSTM encoder-decoder trained with teacher forcing.

    The encoder reads the embedded source sequence and hands its final
    hidden/cell state to an ``LSTMCell`` decoder, which is stepped over the
    embedded target tokens and emits one 27-way logit vector per step.
    """

    def __init__(self):
        super().__init__()

        # Encoder: 27-token vocabulary embedded into 16-dim vectors.
        self.encoder_embed = nn.Embedding(num_embeddings=27, embedding_dim=16)

        # Encoder LSTM: 16-dim input, 32-dim hidden state.
        self.encoder = nn.LSTM(input_size=16,
                               hidden_size=32,
                               num_layers=1,
                               batch_first=True)

        # Decoder: 27-token vocabulary embedded into 16-dim vectors.
        self.decoder_embed = nn.Embedding(num_embeddings=27, embedding_dim=16)

        # Decoder cell: 16-dim input, 32-dim hidden state.
        self.decoder_cell = nn.LSTMCell(input_size=16, hidden_size=32)

        # Projection from the 32-dim hidden state to 27 class logits.
        self.out_fc = nn.Linear(in_features=32, out_features=27)

    def forward(self, x, y):
        """Compute next-token logits for every target position.

        Args:
            x: source token ids, shape ``[b, src_len]``.
            y: target token ids including the leading start token and the
                trailing end token, shape ``[b, tgt_len]``.

        Returns:
            Logits of shape ``[b, tgt_len - 1, 27]``; position ``t`` is the
            prediction for ``y[:, t + 1]``.
        """
        # Embed the source: [b,src_len] -> [b,src_len,16]
        x = self.encoder_embed(x)

        # Run the encoder and keep only the final state:
        # [b,src_len,16] -> (h, c), each [1,b,32]
        _, (h, c) = self.encoder(x, None)

        # Drop only the layer axis: [1,b,32] -> [b,32]. A bare squeeze()
        # would also remove the batch axis when b == 1 and break the
        # LSTMCell call below.
        h = h.squeeze(0)
        c = c.squeeze(0)

        # Teacher forcing feeds token t to predict token t+1, so the last
        # target token is never used as input: [b,tgt_len] -> [b,tgt_len-1]
        y = y[:, :-1]

        # Embed the decoder inputs: [b,tgt_len-1] -> [b,tgt_len-1,16]
        y = self.decoder_embed(y)

        # Step the cell over every decoder input. The first step starts from
        # the encoder's final state; each later step continues from the
        # state produced by the previous one.
        outs = []
        for step in range(y.size(1)):
            # [b,16] -> [b,32],[b,32]
            h, c = self.decoder_cell(y[:, step], (h, c))

            # Project the hidden state to vocabulary logits: [b,32] -> [b,27]
            outs.append(self.out_fc(h))

        # Assemble the per-step logits along the time axis: [b,tgt_len-1,27]
        return torch.stack(outs, dim=1)


# Instantiate the network and push the inspection batch through it once as a
# shape sanity check.
model = Seq2Seq()

# `sample` holds the two tensors (source ids, target ids) of one batch.
out = model(*sample)
out[0, :2], out.shape

(tensor([[ 0.1030, -0.1141, -0.1131,  0.0795, -0.0905, -0.0374,  0.1219, -0.1612,
           0.0579,  0.1664, -0.2122, -0.1027,  0.0780, -0.0654, -0.1115,  0.0588,
           0.0916, -0.1834,  0.0258, -0.0365, -0.0360, -0.0095, -0.0952,  0.0344,
           0.1635,  0.0589,  0.0351],
         [ 0.0419, -0.1616, -0.0572,  0.0688, -0.1074,  0.0127,  0.0630, -0.1885,
           0.0040,  0.1648, -0.1619, -0.0888,  0.1776, -0.0717, -0.1585,  0.1230,
           0.0006, -0.1844,  0.0451, -0.0121,  0.0281, -0.0370, -0.0675,  0.0560,
           0.0824,  0.0116,  0.0791]], grad_fn=<SliceBackward>),
 torch.Size([100, 10, 27]))

In [4]:
# Cross-entropy over the vocabulary logits, optimised with Adam.
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

model.train()
for epoch in range(200):
    for x, y in dataloader:
        optimizer.zero_grad()

        # Teacher-forced forward pass over the batch.
        y_pred = model(x, y)

        # The model consumes each target token to predict the next one, so
        # the first token of y has no matching prediction and is dropped.
        # Both sides are then flattened so the loss can be computed:
        # logits [b,10,27] -> [b*10,27], labels [b,11] -> [b,10] -> [b*10]
        loss = loss_func(y_pred.reshape(-1, 27), y[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()

    # Report the loss of the epoch's final batch every tenth epoch.
    if epoch % 10 == 0:
        print(epoch, loss.item())

0 1.8169605731964111
10 0.3342939019203186
20 0.09271496534347534
30 0.013728859834372997
40 0.005192959681153297
50 0.0030120087321847677
60 0.0020933125633746386
70 0.0014203935861587524
80 0.001078095636330545
90 0.0008668229565955698
100 0.0006723636179231107
110 0.0005770153948105872
120 0.0004773383552674204
130 0.00037985510425642133
140 0.00029425040702335536
150 0.0002678190066944808
160 0.0002279395266668871
170 0.00021694882889278233
180 0.00017270594253204763
190 0.00015532056568190455


In [5]:
# Build the inverse vocabulary: integer id -> token.
reverse_zidian = {code: token for token, code in zidian.items()}
reverse_zidian


# Turn an encoded sequence back into its string form.
def seq_to_str(seq):
    """Decode a 1-D tensor of token ids into the concatenated token string.

    Each id is looked up in the module-level ``reverse_zidian`` mapping.
    """
    token_ids = seq.detach().numpy()
    pieces = []
    for token_id in token_ids:
        pieces.append(reverse_zidian[token_id])
    return ''.join(pieces)


# Decode the first row of each tensor in the inspection batch.
tuple(seq_to_str(batch[0]) for batch in sample)

('00-12-28', '<SOS>28/Dec/2000<EOS>')

In [6]:
# Inference
def predict(x, num_steps=9):
    """Greedily decode target tokens for a batch of source sequences.

    Uses the module-level ``model`` and ``zidian``. The model is switched to
    eval mode and left in that mode.

    Args:
        x: source token ids, shape ``[b, src_len]``.
        num_steps: number of tokens to generate per sequence (at least 1).
            The default of 9 yields the date body without the start and end
            markers, which carry no information worth predicting.

    Returns:
        Tensor of predicted token ids, shape ``[b, num_steps]``, dtype int64.
    """
    model.eval()

    # Inference needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        batch_size = x.size(0)

        # Embed the source: [b,src_len] -> [b,src_len,16]
        x = model.encoder_embed(x)

        # Run the encoder and keep its final state, each [1,b,32].
        _, (h, c) = model.encoder(x, None)

        # Drop only the layer axis: [1,b,32] -> [b,32]. A bare squeeze()
        # would also remove the batch axis when b == 1 and break the
        # LSTMCell call below.
        h = h.squeeze(0)
        c = c.squeeze(0)

        # Every target starts with <SOS>, so decoding starts from that
        # token, created on the same device as the input: [b]
        token = torch.full((batch_size, ),
                           zidian['<SOS>'],
                           dtype=torch.int64,
                           device=x.device)

        outs = []
        for _ in range(num_steps):
            # Embed the previous token: [b] -> [b,16]
            step_input = model.decoder_embed(token)

            # The first step starts from the encoder's final state; each
            # later step continues from the previous step's state.
            # [b,16] -> [b,32],[b,32]
            h, c = model.decoder_cell(step_input, (h, c))

            # Pick the most likely token: [b,32] -> [b,27] -> [b].
            # It becomes the input of the next step.
            token = model.out_fc(h).argmax(dim=1)
            outs.append(token)

        # Assemble the generated tokens along the time axis: [b,num_steps]
        return torch.stack(outs, dim=1)


# Smoke test: decode one batch and print source, reference and prediction
# side by side.
x, y = next(iter(dataloader))
y_pred = predict(x)
for src, ref, hyp in zip(x, y, y_pred):
    print(seq_to_str(src), seq_to_str(ref), seq_to_str(hyp))

80-12-27 <SOS>27/Dec/1980<EOS> 27/Dec/1980
30-04-05 <SOS>05/Apr/2030<EOS> 05/Apr/2030
23-05-10 <SOS>10/May/2023<EOS> 10/May/2023
87-01-18 <SOS>18/Jan/1987<EOS> 18/Jan/1987
22-07-10 <SOS>10/Jul/2022<EOS> 10/Jul/2022
92-05-24 <SOS>24/May/1992<EOS> 24/May/1992
25-11-10 <SOS>10/Nov/2025<EOS> 10/Nov/2025
08-01-28 <SOS>28/Jan/2008<EOS> 28/Jan/2008
97-09-12 <SOS>12/Sep/1997<EOS> 12/Sep/1997
93-09-25 <SOS>25/Sep/1993<EOS> 25/Sep/1993
95-11-06 <SOS>06/Nov/1995<EOS> 06/Nov/1995
23-09-29 <SOS>29/Sep/2023<EOS> 29/Sep/2023
89-11-04 <SOS>04/Nov/1989<EOS> 04/Nov/1989
01-08-05 <SOS>05/Aug/2001<EOS> 05/Aug/2001
80-02-22 <SOS>22/Feb/1980<EOS> 22/Feb/1980
98-01-28 <SOS>28/Jan/1998<EOS> 28/Jan/1998
24-12-03 <SOS>03/Dec/2024<EOS> 03/Dec/2024
95-05-19 <SOS>19/May/1995<EOS> 19/May/1995
22-03-27 <SOS>27/Mar/2022<EOS> 27/Mar/2022
05-10-07 <SOS>07/Oct/2005<EOS> 07/Oct/2005
88-08-01 <SOS>01/Aug/1988<EOS> 01/Aug/1988
27-08-28 <SOS>28/Aug/2027<EOS> 28/Aug/2027
91-05-31 <SOS>31/May/1991<EOS> 31/May/1991
75-11-07 <S