# Char RNN

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

## 1. Preprocessing the training data

In [2]:
# Toy next-character task: the label string is the input string shifted left
# by one character, with '!' as the final target.
input_str = 'apple'
label_str = 'pple!'

# Vocabulary: the distinct characters of both strings, in sorted order.
char_vocab = sorted(set(input_str + label_str))
print(char_vocab)

['!', 'a', 'e', 'l', 'p']


In [3]:
# Number of distinct characters in the vocabulary.
vocab_size = len(char_vocab)

We represent every character as a one-hot vector, so each character becomes a vector of length `vocab_size` with a single 1 at that character's index.

In [4]:
# Model and optimiser hyperparameters.
input_size = vocab_size    # one-hot input: one feature per vocabulary character
hidden_size = 5            # size of the RNN hidden state (free to tune)
output_size = vocab_size   # one score per vocabulary character (was hard-coded to 5)
learning_rate = 0.1        # learning rate handed to the optimiser

In [5]:
# Map each character to its position in the sorted vocabulary.
char_to_index = {ch: idx for idx, ch in enumerate(char_vocab)}
print(char_to_index)

{'!': 0, 'a': 1, 'e': 2, 'l': 3, 'p': 4}


In [6]:
# Inverse lookup (index -> character). Built with a comprehension so that no
# loop variables are left behind in the notebook's global namespace.
index_to_char = {idx: ch for ch, idx in char_to_index.items()}
print(index_to_char)

{0: '!', 1: 'a', 2: 'e', 3: 'l', 4: 'p'}


In [7]:
# Encode both strings as lists of vocabulary indices.
x_data = [char_to_index[ch] for ch in input_str]
y_data = [char_to_index[ch] for ch in label_str]
print(x_data, y_data, sep="\n")

[1, 4, 4, 3, 2]
[4, 4, 3, 2, 0]


In [16]:
# One-hot encode the input indices: row i of the identity matrix is the
# one-hot vector for index i. The added leading axis of size 1 serves as the
# batch dimension, giving an array of shape (1, len(x_data), vocab_size).
x_one_hot = np.eye(vocab_size)[x_data][np.newaxis, ...]

# Build the tensors from a single ndarray / plain list with explicit dtypes.
# (Passing a Python list of ndarrays to a tensor constructor is slow and makes
# PyTorch emit a warning.)
X = torch.tensor(x_one_hot, dtype=torch.float32)
# Targets stay as integer class indices, which is what CrossEntropyLoss takes.
Y = torch.tensor(y_data, dtype=torch.long)

In [17]:
class Net(nn.Module):
    """Single-layer RNN followed by a linear layer applied at every time step."""

    def __init__(self, input_size, hidden_size, output_size):
        """Create the recurrent layer and the output projection.

        Args:
            input_size: number of features in each input step.
            hidden_size: size of the RNN hidden state.
            output_size: number of values produced for each step.
        """
        super().__init__()
        # batch_first=True: tensors are laid out as (batch, sequence, feature).
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.fc = nn.Linear(
            in_features=hidden_size,
            out_features=output_size,
            bias=True,
        )

    def forward(self, x):
        """Run the RNN over `x` and project each step's hidden state through `fc`."""
        # The second return value (final hidden state) is not needed here.
        hidden_seq, _ = self.rnn(x)
        return self.fc(hidden_seq)

In [18]:
# Seed PyTorch's global RNG before the model is built so the initial weights,
# and therefore the training run, are the same on every Restart & Run All.
torch.manual_seed(0)
net = Net(input_size, hidden_size, output_size)

In [19]:
# Show the input shape, then run a forward pass through the network.
print(X.shape)
outputs = net(X)

torch.Size([1, 5, 5])


In [20]:
# Shape of the network output from the forward pass above.
print(outputs.shape)

torch.Size([1, 5, 5])


In [21]:
# Collapse the leading axes so each row holds the scores for one time step.
# The last axis of `outputs` has size output_size (the width of the final
# linear layer), so that is the value to reshape with, not input_size.
print(outputs.view(-1, output_size).shape)

torch.Size([5, 5])


In [22]:
# Cross-entropy loss on the per-step scores, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=net.parameters(), lr=learning_rate)

In [24]:
# Train for 100 optimisation steps on the single training sequence.
for i in range(100):
    optimizer.zero_grad()  # clear gradients accumulated by the previous step
    outputs = net(X)
    # Flatten the scores to (rows, output_size) and the targets to 1-D so that
    # each row is paired with one target index. output_size replaces the
    # previously hard-coded 5.
    loss = criterion(outputs.view(-1, output_size), Y.view(-1))
    loss.backward()
    optimizer.step()

    # .item() extracts the Python float instead of formatting the tensor object.
    print("{}: loss:{}".format(i, loss.item()))
    

0: loss:1.3171494007110596
1: loss:1.1435484886169434
2: loss:0.9743162393569946
3: loss:0.7830328345298767
4: loss:0.6111847758293152
5: loss:0.473209947347641
6: loss:0.35816025733947754
7: loss:0.2656562924385071
8: loss:0.19771064817905426
9: loss:0.14820870757102966
10: loss:0.1113034039735794
11: loss:0.08416984975337982
12: loss:0.06493230164051056
13: loss:0.05044865608215332
14: loss:0.038754016160964966
15: loss:0.029881086200475693
16: loss:0.023426655679941177
17: loss:0.018636813387274742
18: loss:0.014998664148151875
19: loss:0.012214486487209797
20: loss:0.010073183104395866
21: loss:0.008397890254855156
22: loss:0.007055474910885096
23: loss:0.005970596801489592
24: loss:0.0051023224368691444
25: loss:0.0044135539792478085
26: loss:0.003865160048007965
27: loss:0.0034221361856907606
28: loss:0.0030574225820600986
29: loss:0.0027518807910382748
30: loss:0.002491906750947237
31: loss:0.002268292475491762
32: loss:0.0020741198677569628
33: loss:0.0019046490779146552
34: lo