<img src='ex12-2.png'>

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

In [3]:
one_hot_lookup = [
    [1, 0, 0, 0, 0], # 0 h
    [0, 1, 0, 0, 0], # 1 i
    [0, 0, 1, 0, 0], # 2 e
    [0, 0, 0, 1, 0], # 3 l
    [0, 0, 0, 0, 1], # 4 o
]
x_data = [0, 1, 0, 2, 3, 3] # hihell
y_data = [1, 0, 2, 3, 3, 4] # ihello
x_one_hot = [one_hot_lookup[i] for i in x_data]

# (2) Parameters

In [7]:
num_classes = 5
input_size = 5  # one_hot size
hidden_size = 5 # output from the LSTM. 5 to directly predict one-hot
batch_size = 1  # one sentence
sequence_length = 1 # Let's do one by one
num_layers = 1  # one-layer rnn

In [9]:
inputs = torch.tensor(x_one_hot, dtype=torch.float)
labels = torch.tensor(y_data, dtype=torch.long)

# 1. Model

In [22]:
class Model(nn.Module):
    def __init__(self,
                input_size=5,
                hidden_size=5,
                num_layers=1,
                batch_size=1,
                sequence_length=1,
                num_classes=5):
        super().__init__()
        self.rnn = nn.RNN(input_size=input_size,
                         hidden_size=hidden_size,
                         batch_first=True)
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.num_classes = num_classes
        
        # Fully-Connected layer
        self.fc = nn.Linear(num_classes, num_classes)

    def forward(self, x, hidden):
        # Reshape input in (batch_size, sequence_length, input_size)
        x = x.view(self.batch_size, self.sequence_length, self.input_size)

        out, hidden = self.rnn(x, hidden)
        out = self.fc(out) # Add here
        out = out.view(-1, self.num_classes)
        return hidden, out
    
    def init_hidden(self):
        return torch.zeros(self.num_layers, self.batch_size, self.hidden_size)


# 2. Criterion & Loss

In [23]:
model = Model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.1)

# 3. Training

In [24]:
model = Model(input_size=5, hidden_size=5, num_layers=1, batch_size=1, sequence_length=6, num_classes=5)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.1)

In [25]:
hidden = model.init_hidden()
loss = 0

In [26]:
idx2char = ['h', 'i', 'e', 'l', 'o']
x_data = [0, 1, 0, 2, 3, 3] # hihell
one_hot_dict = {
    'h': [1, 0, 0, 0, 0],
    'i': [0, 1, 0, 0, 0],
    'e': [0, 0, 1, 0, 0],
    'l': [0, 0, 0, 1, 0],
    'o': [0, 0, 0, 0, 1],
}
one_hot_lookup = [
    [1, 0, 0, 0, 0], # 0 h
    [0, 1, 0, 0, 0], # 1 i
    [0, 0, 1, 0, 0], # 2 e
    [0, 0, 0, 1, 0], # 3 l
    [0, 0, 0, 0, 1], # 4 o
]
y_data = [1, 0, 2, 3, 3, 4] # ihello
x_one_hot = [one_hot_lookup[x] for x in x_data]

In [27]:
inputs = torch.tensor(x_one_hot, dtype=torch.float)
labels = torch.tensor(y_data, dtype=torch.long)

In [28]:
inputs

tensor([[1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0.],
        [0., 0., 0., 1., 0.]])

In [29]:
labels

tensor([1, 0, 2, 3, 3, 4])

In [30]:
for epoch in range(0, 15 + 1):
    hidden.detach_()
    hidden = hidden.detach()
    hidden = hidden.clone().detach().requires_grad_(True) # New syntax from `1.0`
    
    hidden, outputs = model(inputs, hidden)
    optimizer.zero_grad()
    loss = criterion(outputs, labels) # It wraps for-loop in here
    loss.backward()
    optimizer.step()
    _, idx = outputs.max(1)
    idx = idx.data.numpy()
    result_str = [idx2char[c] for c in idx.squeeze()]
    print(f"epoch: {epoch}, loss: loss.data")
    print(f"Predicted string: {''.join(result_str)}")

epoch: 0, loss: loss.data
Predicted string: oooooo
epoch: 1, loss: loss.data
Predicted string: llllll
epoch: 2, loss: loss.data
Predicted string: lhllll
epoch: 3, loss: loss.data
Predicted string: ihilll
epoch: 4, loss: loss.data
Predicted string: ihilll
epoch: 5, loss: loss.data
Predicted string: ihilll
epoch: 6, loss: loss.data
Predicted string: ihilll
epoch: 7, loss: loss.data
Predicted string: ihelll
epoch: 8, loss: loss.data
Predicted string: ihelll
epoch: 9, loss: loss.data
Predicted string: ehelll
epoch: 10, loss: loss.data
Predicted string: ehello
epoch: 11, loss: loss.data
Predicted string: ihello
epoch: 12, loss: loss.data
Predicted string: ihello
epoch: 13, loss: loss.data
Predicted string: ihello
epoch: 14, loss: loss.data
Predicted string: ihello
epoch: 15, loss: loss.data
Predicted string: ihello
