In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Random seed to make results deterministic and reproducible
torch.manual_seed(0)

sentence = ("if you want to build a ship, don't drum up people together to "
            "collect wood and don't assign them tasks and work, but rather "
            "teach them to long for the endless immensity of the sea.")
# make dictionary
char_set = list(set(sentence))
char_dic = {c: i for i, c in enumerate(char_set)}
char_dic

{"'": 0,
 'g': 1,
 'b': 2,
 't': 3,
 'm': 4,
 'i': 5,
 ' ': 6,
 ',': 7,
 'f': 8,
 's': 9,
 'a': 10,
 'e': 11,
 'l': 12,
 '.': 13,
 'd': 14,
 'r': 15,
 'k': 16,
 'u': 17,
 'w': 18,
 'y': 19,
 'h': 20,
 'c': 21,
 'p': 22,
 'n': 23,
 'o': 24}

In [2]:
# hyperparameter
dic_size = len(char_dic)
hidden_size = len(char_dic)
sequence_length = 10 # 임의 결정
learning_rate = 0.1

In [3]:
x_data = []
y_data = []

for i in range(0, len(sentence) - sequence_length):
    x_str = sentence[i : i + sequence_length]
    y_str = sentence[i + 1 : i + sequence_length + 1]
    print(i, x_str, '->', y_str)
    
    # get index to make one hot vector
    x_data.append([char_dic[c] for c in x_str])
    y_data.append([char_dic[c] for c in y_str])

x_one_hot = [np.eye(dic_size)[c] for c in x_data] # one hot encoding


0 if you wan -> f you want
1 f you want ->  you want 
2  you want  -> you want t
3 you want t -> ou want to
4 ou want to -> u want to 
5 u want to  ->  want to b
6  want to b -> want to bu
7 want to bu -> ant to bui
8 ant to bui -> nt to buil
9 nt to buil -> t to build
10 t to build ->  to build 
11  to build  -> to build a
12 to build a -> o build a 
13 o build a  ->  build a s
14  build a s -> build a sh
15 build a sh -> uild a shi
16 uild a shi -> ild a ship
17 ild a ship -> ld a ship,
18 ld a ship, -> d a ship, 
19 d a ship,  ->  a ship, d
20  a ship, d -> a ship, do
21 a ship, do ->  ship, don
22  ship, don -> ship, don'
23 ship, don' -> hip, don't
24 hip, don't -> ip, don't 
25 ip, don't  -> p, don't d
26 p, don't d -> , don't dr
27 , don't dr ->  don't dru
28  don't dru -> don't drum
29 don't drum -> on't drum 
30 on't drum  -> n't drum u
31 n't drum u -> 't drum up
32 't drum up -> t drum up 
33 t drum up  ->  drum up p
34  drum up p -> drum up pe
35 drum up pe -> rum up peo
36

In [4]:
len(x_one_hot)

170

In [5]:
X = torch.FloatTensor(x_one_hot)
Y = torch.LongTensor(y_data)
X.shape

  X = torch.FloatTensor(x_one_hot)


torch.Size([170, 10, 25])

In [6]:
# rnn = nn.RNN(dic_size, hidden_size, batch_first=True) # batch_first = True로 두면 batch dimension이 가장 앞으로 온다.
class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, layers):
        super().__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, num_layers=layers, batch_first=True) # num_layers를 주면 여러 층을 만들 수 있다.
        self.fc = nn.Linear(hidden_dim, hidden_dim, bias=True)
    def forward(self, x):
        x, _status = self.rnn(x)
        x = self.fc(x)
        return x

RNNnet = Net(dic_size, hidden_size, 2)
RNNnet

Net(
  (rnn): RNN(25, 25, num_layers=2, batch_first=True)
  (fc): Linear(in_features=25, out_features=25, bias=True)
)

In [7]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(RNNnet.parameters(), learning_rate)

In [16]:
epochs = 100
for i in range(epochs):
    optimizer.zero_grad()
    output = RNNnet(X)
    
    # loss = criterion(output, Y)
    loss = criterion(output.view(-1, dic_size), Y.view(-1))
    loss.backward()
    optimizer.step()
    
    print(output, output.shape)
    print(output.view(X.shape[0], -1, dic_size))
    result = torch.argmax(output.view(X.shape[0], -1, dic_size), dim=2)
    print(result, result.shape)
    
    predicted_str = ''
    for j, predicted_char in enumerate(result):
        if j == 0:
            predicted_str += ''.join([char_set[t] for t in predicted_char])
        else:
            predicted_str += char_set[predicted_char[-1]]
            
    print(predicted_str)

tensor([[[ -2.1967,   7.1988,  -2.9320,  ...,   7.1592,  -0.7761,   4.6289],
         [ -7.9078,  -5.0815, -11.7137,  ...,  -6.3805,   3.3785,   2.7797],
         [ -4.2328,  -5.1190,  -4.2326,  ...,   1.6589,   7.4987,   1.5819],
         ...,
         [  0.7704,  -7.7905,   5.7151,  ...,   1.1204,  -0.5623,   9.0222],
         [ -5.0527,   0.9883,  -1.9418,  ...,   0.9489,  14.2695,   8.2906],
         [ -0.2171,   6.9149,  -6.0076,  ...,   5.9867,  -8.4882,   2.8841]],

        [[-11.4244, -13.1455,  -6.2113,  ...,  -5.3407, -10.7743,  11.7500],
         [ -4.3102,   0.6500,  -4.6794,  ...,   0.2443,   7.4210,   3.3460],
         [  0.1170,   0.5491,   0.9560,  ...,   7.3470,  -3.6709,  11.7076],
         ...,
         [ -5.0703,   0.9722,  -1.9447,  ...,   0.9508,  14.2852,   8.3253],
         [ -0.3508,   6.8049,  -6.2003,  ...,   5.8477,  -8.3613,   2.8838],
         [ -3.6938,  -9.4163, -10.8702,  ..., -10.1188,  -2.7014,  -6.6086]],

        [[ -3.4656,  -0.2938,   3.8264,  ...