In [10]:
import torch
import torch.nn as nn
import numpy as np

class LongShortTermMemoryModel(nn.Module):

    def __init__(self, emoji_encoding, encoding_size):
        super(LongShortTermMemoryModel, self).__init__()

        self.lstm = nn.LSTM(encoding_size, 128, batch_first=True)  # 128 is the state size
        self.dense = nn.Linear(128, emoji_encoding)  # 128 is the state size

    def reset(self, batch_size):  # Reset states prior to new input sequence
        # takes in a batch_size so that the hidden_state, and cell_state is the correct size
        # when operating on different batch_sizes
        zero_state = torch.zeros(1, batch_size, 128)  # Shape: (batch_length, sequence_length, state size) Due to batch_first=True
        self.hidden_state = zero_state
        self.cell_state = zero_state

    def logits(self, x):  # x shape: (sequence length, batch size, encoding size)
        out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state))
        # lstm computes an output for each element in the seqeunce
        # this means that out is of the shape (batch_size, sequence_length, state_size)
        # ':' means take everything from this dimension
        # '-1' means take the last element form this dimension
        out = out[:, -1, :]
        return self.dense(out)

    def f(self, x):  # x shape: (sequence length, batch size, encoding size)
        return torch.softmax(self.logits(x), dim=1)

    def loss(self, x, y):  # x shape: (sequence length, batch size, encoding size), y shape: (sequence length, encoding size)
        return nn.functional.cross_entropy(self.logits(x), y.argmax(1))

    
index_to_char = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 
                 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 
                 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ']

index_to_emoji = ['\U0001F408', '\U0001F400', '\U0001F3A9']

encoding_size = len(index_to_char)
emoji_size    = len(index_to_emoji) 

char_encodings  = torch.eye(encoding_size).numpy()
emoji_encodings = torch.eye(emoji_size).numpy()

In [11]:
x_train = torch.tensor([
    [char_encodings[index_to_char.index('c')], char_encodings[index_to_char.index('a')], char_encodings[index_to_char.index('t')]],
    [char_encodings[index_to_char.index('r')], char_encodings[index_to_char.index('a')], char_encodings[index_to_char.index('t')]],
    [char_encodings[index_to_char.index('h')], char_encodings[index_to_char.index('a')], char_encodings[index_to_char.index('t')]]
])
    
y_train = torch.tensor([emoji_encodings[0], emoji_encodings[1], emoji_encodings[2]])

print(x_train.shape)
print(x_train.shape)
print(y_train.shape)
print(y_train.shape)

print("Batch_Size, Sequence_Length, Encoding_size:", x_train.shape)
print("Batch_Size, Emoji_Encoding_Size:", y_train.shape)

torch.Size([3, 3, 27])
torch.Size([3, 3, 27])
torch.Size([3, 3])
torch.Size([3, 3])
Batch_Size, Sequence_Length, Encoding_size: torch.Size([3, 3, 27])
Batch_Size, Emoji_Encoding_Size: torch.Size([3, 3])


In [12]:
model = LongShortTermMemoryModel(emoji_size, encoding_size)

In [13]:
optimizer = torch.optim.RMSprop(model.parameters(), 0.001)
for epoch in range(100):
    model.reset(3)
    model.loss(x_train, y_train).backward()
    optimizer.step()
    optimizer.zero_grad()

    if epoch % 10 == 9:
        # Generates emojis from the words "cat", "rat", and "hat"
        model.reset(1)
        y = model.f(torch.tensor(x_train[0].reshape(1, 3, encoding_size)))
        print("Cat:", index_to_emoji[y.argmax(1)])

        model.reset(1)
        y = model.f(torch.tensor(x_train[1].reshape(1, 3, encoding_size)))
        print("Rat:", index_to_emoji[y.argmax(1)])

        model.reset(1)
        y = model.f(torch.tensor(x_train[2].reshape(1, 3, encoding_size)))
        print("Hat:", index_to_emoji[y.argmax(1)])

Cat: 🐈
Rat: 🐀
Hat: 🎩
Cat: 🐈
Rat: 🐀
Hat: 🎩
Cat: 🐈
Rat: 🐀
Hat: 🎩
Cat: 🐈
Rat: 🐀
Hat: 🎩
Cat: 🐈
Rat: 🐀
Hat: 🎩
Cat: 🐈
Rat: 🐀
Hat: 🎩
Cat: 🐈
Rat: 🐀
Hat: 🎩
Cat: 🐈
Rat: 🐀
Hat: 🎩
Cat: 🐈
Rat: 🐀
Hat: 🎩
Cat: 🐈
Rat: 🐀
Hat: 🎩


  y = model.f(torch.tensor(x_train[0].reshape(1, 3, encoding_size)))
  y = model.f(torch.tensor(x_train[1].reshape(1, 3, encoding_size)))
  y = model.f(torch.tensor(x_train[2].reshape(1, 3, encoding_size)))


In [14]:
rt = torch.tensor([[char_encodings[index_to_char.index('h')], char_encodings[index_to_char.index('t')], char_encodings[index_to_char.index(' ')]]])
#rt = rt.reshape(1, 3, encoding_size)

model.reset(1)
y = model.f(torch.tensor(rt))
print(index_to_emoji[y.argmax(1)])

🎩


  y = model.f(torch.tensor(rt))


In [15]:
torch.tensor([torch.eye(9).numpy()[0]])

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [16]:
torch.tensor([char_encodings[0]])

tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.]])