In [375]:
import torch
import torch.nn as nn

In [376]:
class LongShortTermMemoryModel(nn.Module):
    def __init__(self, encoding_size, label_size):
        super(LongShortTermMemoryModel, self).__init__()

        self.lstm = nn.LSTM(encoding_size, 128)  # 128 is the state size
        self.dense = nn.Linear(128, label_size)  # 128 is the state size

    def reset(self):  # Reset states prior to new input sequence
        zero_state = torch.zeros(1, 1, 128)  # Shape: (number of layers, batch size, state size)
        self.hidden_state = zero_state
        self.cell_state = zero_state

    def logits(self, x):  # x shape: (sequence length, batch size, encoding size)
        out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state))
        return self.dense(out.reshape(-1, 128))

    def f(self, x):  # x shape: (sequence length, batch size, encoding size)
        return torch.softmax(self.logits(x), dim=1)

    def loss(self, x, y):  # x shape: (sequence length, batch size, encoding size), y shape: (sequence length, encoding size)
        return nn.functional.cross_entropy(self.logits(x), y.argmax(1))

In [377]:
c = [ #alphabetical order
    [1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],  # ' ' 00
    [0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],  # 'a' 01
    [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],  # 'c' 02
    [0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],  # 'f' 03
    [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],  # 'h' 04
    [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],  # 'l' 05
    [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],  # 'm' 06
    [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],  # 'n' 07
    [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],  # 'o' 08
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],  # 'p' 09
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],  # 'r' 10
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],  # 's' 11
    [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],  # 't' 12
]

e = [
    [1., 0., 0., 0., 0., 0., 0.],  # '🎩' 13
    [0., 1., 0., 0., 0., 0., 0.],  # '🐀' 14
    [0., 0., 1., 0., 0., 0., 0.],  # '🐈' 15
    [0., 0., 0., 1., 0., 0., 0.],  # '🏢' 16
    [0., 0., 0., 0., 1., 0., 0.],  # '🧑‍🦰' 17
    [0., 0., 0., 0., 0., 1., 0.],  # '🧢' 18
    [0., 0., 0., 0., 0., 0., 1.],  # '🧒' 19
]

encoding_size = len(c)
index_to_emoji = ['🎩', '🐀', '🐈', '🏢', '🧑‍🦰', '🧢', '🧒']
index_to_char = [' ', 'a', 'c', 'f', 'h', 'l', 'm', 'n', 'o', 'p', 'r', 's', 't']

x_train = torch.tensor([
    [[c[4]],    [c[1]], [c[12]],    [c[0]]],      # 'hat '
    [[c[10]],   [c[1]], [c[12]],    [c[0]]],     # 'rat '
    [[c[2]],    [c[1]], [c[12]],    [c[0]]],      # 'cat '
    [[c[3]],    [c[5]], [c[1]],     [c[12]]],      # 'flat'
    [[c[6]],    [c[1]], [c[12]],    [c[12]]],     # 'matt'
    [[c[2]],    [c[1]], [c[9]],     [c[0]]],       # 'cap '
    [[c[11]],   [c[8]], [c[7]],     [c[0]]],      # 'son '
])

y_train = torch.tensor([
    [e[0], e[0], e[0], e[0]],
    [e[1], e[1], e[1], e[1]],
    [e[2], e[2], e[2], e[2]],
    [e[3], e[3], e[3], e[3]],
    [e[4], e[4], e[4], e[4]],
    [e[5], e[5], e[5], e[5]],
    [e[6], e[6], e[6], e[6]],
])

In [378]:
model = LongShortTermMemoryModel(encoding_size, len(e))
optimizer = torch.optim.RMSprop(model.parameters(), 0.001)

for epoch in range(500):
    for i in range(len(x_train)):
        model.reset()
        model.loss(x_train[i], y_train[i]).backward()
        optimizer.step()
        optimizer.zero_grad()


In [379]:
def run(arg: str):
    model.reset()
    y = ''
    for char in arg:
        y = model.f(torch.tensor([[c[index_to_char.index(char)]]]))
    print(index_to_emoji[y.argmax()])

user_input = input("Emoji search:")
while user_input != "exit":
    run(user_input)
    user_input = input("Emoji search:")

🐈
🐀
🧒
🧑‍🦰
🧑‍🦰
🧑‍🦰
🎩
🐈
🎩
🎩
🎩
🎩
🐀
🐀


AttributeError: 'str' object has no attribute 'argmax'