In [10]:
import numpy as np
import torch
from torch import nn

import matplotlib.pyplot as plt

In [16]:
# Build a synthetic sequence-classification dataset (random token-id sequences,
# not images): every sample is `seq_len` integer token ids drawn from
# [0, word_num) and carries a random class label drawn from [0, n_classes).

sample_num, word_num, n_classes = 100, 30, 3
seq_len = 20

# A seeded local generator makes the data identical on every
# "Restart Kernel -> Run All" without touching NumPy's global RNG state.
SEED = 0
rng = np.random.default_rng(SEED)

x = rng.integers(low=0, high=word_num, size=(sample_num, seq_len), dtype=np.int32)
y = rng.integers(low=0, high=n_classes, size=(sample_num, ), dtype=np.int32)

print(f"x shape {x.shape}, y shape {y.shape}")

x shape (100, 20), y shape (100,)


In [12]:
# Display the generated token-id matrix (one row per sample, one column per time step).
x

array([[18,  6, 18, ..., 16, 26,  6],
       [25, 19,  9, ..., 24,  1,  3],
       [28, 23,  7, ..., 26, 17, 25],
       ...,
       [17,  1, 17, ..., 21,  5, 11],
       [14, 13, 29, ..., 27, 16, 19],
       [ 0, 25, 15, ..., 13, 26, 10]], dtype=int32)

In [64]:
class RNN(nn.Module):
    """Sequence classifier: embedding -> stacked vanilla RNN -> linear head.

    Args:
        vocab_size: Number of distinct token ids; inputs must lie in
            ``[0, vocab_size)``.
        seq_len: Length of the input sequences. Accepted for interface
            compatibility only; the network itself is length-agnostic and
            never reads this value.
        num_classes: Number of output classes (size of the logit vector).
        embedding_dim: Size of each token embedding.
        hidden_size: Number of hidden units in every RNN layer.
        num_layers: Number of stacked RNN layers (defaults to 3, the value
            that used to be hard-coded).
    """

    def __init__(self, vocab_size, seq_len, num_classes, embedding_dim=8, hidden_size=10,
                 num_layers=3):
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)

        self.rnn = nn.RNN(
            input_size=embedding_dim,
            hidden_size=hidden_size,   # hidden units per layer
            num_layers=num_layers,     # number of stacked RNN layers
            batch_first=True,          # tensors are laid out as (batch, time_step, feature)
        )
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return unnormalised class logits for a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids with shape ``(batch, time_step)``.

        Returns:
            Float tensor of shape ``(batch, num_classes)``.
        """
        out = self.embedding(x)
        out, _ = self.rnn(out)
        # Keep only the top layer's output at the last time step, i.e. one
        # summary vector per sequence: (batch, time_step, hidden) -> (batch, hidden).
        out = out[:, -1, :]
        out = self.fc(out)
        return out

In [65]:
# Instantiate the classifier and show its layer summary.
model = RNN(vocab_size=word_num, seq_len=seq_len, num_classes=n_classes)
print(model)

# Use the GPU when one is available; otherwise the model stays on the CPU.
use_gpu = torch.cuda.is_available()
model = model.cuda() if use_gpu else model

# Loss and optimizer. The optimizer is created after the (optional) device
# move so that it holds the parameters the model actually trains.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-3)

RNN(
  (embedding): Embedding(30, 8)
  (rnn): RNN(8, 10, num_layers=3, batch_first=True)
  (fc): Linear(in_features=10, out_features=3, bias=True)
)


In [66]:
# Training: full-batch gradient descent on the whole synthetic dataset.
# The tensors are created on the device the model lives on; leaving them on
# the CPU while the model sits on the GPU raises a device-mismatch error.
device = next(model.parameters()).device
x_train = torch.as_tensor(x, dtype=torch.long, device=device)  # token ids must be int64 for nn.Embedding
y_train = torch.as_tensor(y, dtype=torch.long, device=device)  # class indices must be int64 for CrossEntropyLoss

num_epochs = 500
for epoch in range(num_epochs):
    # forward pass
    out = model(x_train)
    loss = criterion(out, y_train)
    # backward pass
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()        # back-propagate
    optimizer.step()       # update parameters

    # report the loss every 20 epochs
    if (epoch+1) % 20 == 0:
        print(f'Epoch[{epoch+1}/{num_epochs}], loss: {loss.item():.6f}')

Epoch[20/500], loss: 1.110524
Epoch[40/500], loss: 1.110019
Epoch[60/500], loss: 1.109526
Epoch[80/500], loss: 1.109042
Epoch[100/500], loss: 1.108569
Epoch[120/500], loss: 1.108105
Epoch[140/500], loss: 1.107652
Epoch[160/500], loss: 1.107207
Epoch[180/500], loss: 1.106772
Epoch[200/500], loss: 1.106344
Epoch[220/500], loss: 1.105926
Epoch[240/500], loss: 1.105516
Epoch[260/500], loss: 1.105114
Epoch[280/500], loss: 1.104720
Epoch[300/500], loss: 1.104333
Epoch[320/500], loss: 1.103954
Epoch[340/500], loss: 1.103582
Epoch[360/500], loss: 1.103217
Epoch[380/500], loss: 1.102859
Epoch[400/500], loss: 1.102507
Epoch[420/500], loss: 1.102162
Epoch[440/500], loss: 1.101823
Epoch[460/500], loss: 1.101490
Epoch[480/500], loss: 1.101163
Epoch[500/500], loss: 1.100842


In [67]:
# Inference on the training inputs; gradients are not needed here.
with torch.no_grad():
    # Feed the inputs from the model's own device so this also works on a GPU.
    model_device = next(model.parameters()).device
    predict = model(x_train.to(model_device))
# Tensor.numpy() only works on CPU tensors, so copy back to the host first
# (a no-op when the tensor is already on the CPU). `.data` is not needed:
# the tensor was produced under no_grad and carries no autograd history.
predict = predict.cpu().numpy()

In [68]:
# Display the raw (pre-softmax) class logits, one row per sample.
predict

array([[-0.07068527, -0.22012332, -0.01737164],
       [ 0.0863782 , -0.0467135 ,  0.13289216],
       [ 0.02278072, -0.0218434 ,  0.18174088],
       [-0.14376715, -0.27758366, -0.09171807],
       [-0.25794128, -0.31541547, -0.21167192],
       [ 0.12048702, -0.12473196,  0.00445241],
       [-0.04920324, -0.11910459,  0.03360915],
       [-0.12488468, -0.13310745, -0.00180935],
       [ 0.14338186, -0.02986543,  0.06399147],
       [-0.13030346, -0.12371926,  0.00507882],
       [-0.10083167, -0.13477316,  0.0219079 ],
       [ 0.07371041, -0.05099773,  0.22069904],
       [ 0.18935221,  0.02289301,  0.22497487],
       [-0.01867209, -0.07139479,  0.11718747],
       [ 0.14879782, -0.16106819,  0.0092022 ],
       [ 0.00520992, -0.06285569,  0.16947648],
       [ 0.16829942,  0.02081484,  0.14616325],
       [ 0.01529247, -0.07988949,  0.12312925],
       [-0.04082021, -0.08008818,  0.06940278],
       [ 0.28235388, -0.03930925,  0.1685369 ],
       [ 0.22088383, -0.06033388,  0.157