In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
sentence = (
    "if you want to build a ship, don't drum up people together to "
    "collect wood and don't assign them tasks and work, but rather "
    "teach them to long for the endless immensity of the sea."
)

In [3]:
char_set = list(set(sentence))
char_dic = {c: i for i, c in enumerate(char_set)}
char_dic

{'u': 0,
 "'": 1,
 'i': 2,
 'f': 3,
 'a': 4,
 ' ': 5,
 'c': 6,
 'e': 7,
 'g': 8,
 '.': 9,
 'w': 10,
 'p': 11,
 ',': 12,
 'k': 13,
 'y': 14,
 'b': 15,
 'd': 16,
 't': 17,
 's': 18,
 'm': 19,
 'h': 20,
 'l': 21,
 'n': 22,
 'r': 23,
 'o': 24}

In [4]:
dic_size = len(char_dic)
dic_size

25

In [5]:
hidden_size = dic_size
sequence_length = 10
lr = 0.1

In [6]:
x_data = []
y_data = []

for i in range(0, len(sentence) - sequence_length):
    x_str = sentence[i : i + sequence_length]
    y_str = sentence[i + 1 : i + 1 + sequence_length]
    print(i, x_str, "->", y_str)

    x_data.append([char_dic[c] for c in x_str])
    y_data.append([char_dic[c] for c in y_str])

x_data[0], y_data[0]

0 if you wan -> f you want
1 f you want ->  you want 
2  you want  -> you want t
3 you want t -> ou want to
4 ou want to -> u want to 
5 u want to  ->  want to b
6  want to b -> want to bu
7 want to bu -> ant to bui
8 ant to bui -> nt to buil
9 nt to buil -> t to build
10 t to build ->  to build 
11  to build  -> to build a
12 to build a -> o build a 
13 o build a  ->  build a s
14  build a s -> build a sh
15 build a sh -> uild a shi
16 uild a shi -> ild a ship
17 ild a ship -> ld a ship,
18 ld a ship, -> d a ship, 
19 d a ship,  ->  a ship, d
20  a ship, d -> a ship, do
21 a ship, do ->  ship, don
22  ship, don -> ship, don'
23 ship, don' -> hip, don't
24 hip, don't -> ip, don't 
25 ip, don't  -> p, don't d
26 p, don't d -> , don't dr
27 , don't dr ->  don't dru
28  don't dru -> don't drum
29 don't drum -> on't drum 
30 on't drum  -> n't drum u
31 n't drum u -> 't drum up
32 't drum up -> t drum up 
33 t drum up  ->  drum up p
34  drum up p -> drum up pe
35 drum up pe -> rum up peo
36

([2, 3, 5, 14, 24, 0, 5, 10, 4, 22], [3, 5, 14, 24, 0, 5, 10, 4, 22, 17])

In [7]:
x_one_hot = [np.eye(dic_size)[x] for x in x_data]
X = torch.FloatTensor(x_one_hot)
y = torch.LongTensor(y_data)
X.shape, y.shape

  X = torch.FloatTensor(x_one_hot)


(torch.Size([170, 10, 25]), torch.Size([170, 10]))

In [8]:
class Net(nn.Module):
    def __init__(self, input_dim, hidden_dim, layers):
        super().__init__()
        self.rnn = nn.RNN(input_dim, hidden_dim, num_layers=layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x):
        x, _status = self.rnn(x)
        x = self.fc(x)
        return x

In [9]:
net = Net(dic_size, hidden_size, 2)

In [10]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=lr)

In [11]:
for i in range(100):
    optimizer.zero_grad()
    outputs = net(X)
    loss = criterion(outputs.view(-1, dic_size), y.view(-1))
    loss.backward()
    optimizer.step()

    results = outputs.argmax(dim=2)
    predict_str = ""
    for j, result in enumerate(results):
        if j == 0:
            predict_str += "".join([char_set[t] for t in result])
        else:
            predict_str += char_set[result[-1]]

    print(predict_str, end="\n\n")

bu''''''''t'''''o''o'''''bu''u''a'''''c'uo'u''u'''''''''''''''''''''''''''''''''''a''''cb''''''''''''c''''''''a''uo'''''''''''''''t'''''''''''''''o'''''''''ccbt''''c''''''''''''''

ttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttttt

o,t o,t o,,l,l,lg,,ll,,,g,,,ll,lgl,,og,,l,,ol,,lll,,,o,,,gl,,lo,,lll,,lg,,,l,,l,,,lg,,lg,,ll,,,gl,,,,,,g,lll,ocllll,ll,,,lg,,lgl,,ol,,l,,lo,,lo,,ll,,.gl,,ll,,l,og,,lgl,lg,,ol,b,,,

oooooooooo ooo o oo o o o o o o o oooooooo o ooo ooo o oo ooooooooooooooooo o o o o o ooooo o o ooo o ooo o o o o o o oooo ooooo oooooooooo o ooo o ooooo o o o o o  o o ooooooooo 

  d s d d d d  d d      d d d d h d d d     dd     d            d d           d   d d d d     d d d   d     d d d d d d d      d d d   d    d     d   n d d     d d   d    d d  hd 

   u ute ttttt t t  tt tt t ttttttt tut    ttt  t  t      tt    t ttt t       t t tnt tut    tt