In [1]:
import os
import sys
sys.path.append(os.path.join(os.path.dirname(""),".."))

import custom
import torch
import torch.nn as nn
import re


In [2]:
class Encoder(nn.Module):
    """Bidirectional LSTM encoder on top of frozen, pretrained embeddings.

    Args:
        embedding_tensor: 2-D tensor of pretrained embedding weights
            (rows are vocabulary entries, columns the embedding dimension).
            Row 0 is registered as the padding index and the weights are
            frozen.
    """

    def __init__(self, embedding_tensor):
        super().__init__()
        self.embedding = nn.Embedding.from_pretrained(
            embedding_tensor,
            freeze=True,
            padding_idx=0,
        )
        # The LSTM hidden size is tied to the embedding width.
        width = self.embedding.embedding_dim
        self.rnn = nn.LSTM(
            input_size=width,
            hidden_size=width,
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, x):
        """Embed the token indices ``x`` and run them through the LSTM.

        Returns:
            ``(output, (h_n, c_n))`` exactly as produced by ``nn.LSTM``.
        """
        embedded = self.embedding(x)
        return self.rnn(embedded)

class Decoder(nn.Module):
    """Step-by-step LSTM decoder conditioned on the encoder's final hidden state.

    At every step the decoder input is the embedding of the previous token
    concatenated with a fixed context vector (the encoder's final forward and
    backward hidden states). The LSTM output is concatenated with the same
    context again before being projected onto the vocabulary.

    Args:
        embedding_tensor: 2-D tensor of pretrained embedding weights
            (``shape[0]`` = vocabulary size, ``shape[1]`` = embedding dim).
            The weights are frozen and row 0 is the padding index.
    """

    def __init__(self, embedding_tensor):
        super().__init__()
        vocab_size, embed_dim = embedding_tensor.shape[0], embedding_tensor.shape[1]
        self.embedding = nn.Embedding.from_pretrained(embedding_tensor, freeze=True, padding_idx=0)
        # Input = context (2 * embed_dim) + token embedding (embed_dim).
        self.rnn = nn.LSTM(embed_dim * 3, embed_dim, batch_first=True, bidirectional=True)
        # Projection input = context (2 * embed_dim) + bidirectional output (2 * embed_dim).
        self.f = nn.Linear(embed_dim * 4, vocab_size)
        # Set by forward(); forward_sub() reads it on every step.
        self.encoder_h_context = None

    def forward(self, encoder_output, encoder_hc, t=None, max_len: int = 4):
        """Decode ``max_len`` tokens.

        Args:
            encoder_output: encoder output tensor; only its batch size
                (``shape[0]``) and device are used here.
            encoder_hc: ``(h, c)`` tuple from the encoder LSTM. ``h[0]`` and
                ``h[1]`` are taken as the forward/backward final states, and
                the whole tuple initialises the decoder LSTM state.
            t: optional target token indices for teacher forcing, indexed as
                ``t[:, i]``. It must have at least ``max_len`` columns.
                When ``None`` the decoder feeds back its own argmax.
            max_len: number of decoding steps (default 4, the previously
                hard-coded value).

        Returns:
            ``(logits, decoder_hc, None)`` where ``logits`` concatenates the
            per-step outputs along dim 1. The third element is always
            ``None`` and is kept for interface compatibility.

        Raises:
            ValueError: if ``max_len`` is smaller than 1.
        """
        if max_len < 1:
            raise ValueError(f"max_len must be >= 1, got {max_len}")

        h_forward = encoder_hc[0][0:1, :, :]
        h_backward = encoder_hc[0][1:2, :, :]
        # Concatenate on the feature axis, then swap the first two axes so the
        # context lines up with the batch_first, length-1 step inputs.
        self.encoder_h_context = torch.cat([h_forward, h_backward], dim=-1).transpose(0, 1)

        batch_size = encoder_output.shape[0]
        # Decoding starts from token index 0 (the same index used for padding).
        decoder_input = torch.zeros(batch_size, 1, dtype=torch.long, device=encoder_output.device)
        decoder_hc = encoder_hc
        step_outputs = []

        for i in range(max_len):
            decoder_output, decoder_hc = self.forward_sub(decoder_input, decoder_hc)
            step_outputs.append(decoder_output)

            if t is None:
                # Greedy decoding: feed back the most likely token.
                decoder_input = decoder_output.argmax(dim=-1).detach()
            else:
                # Teacher forcing: feed the ground-truth token for this step.
                decoder_input = t[:, i].unsqueeze(-1)

        return torch.cat(step_outputs, dim=1), decoder_hc, None

    def forward_sub(self, x, h):
        """Run one decoding step.

        Requires ``self.encoder_h_context`` to have been set by ``forward``.

        Args:
            x: token indices for the current step.
            h: ``(h, c)`` LSTM state carried over from the previous step.

        Returns:
            ``(logits, hc)``: vocabulary logits for this step and the updated
            LSTM state.
        """
        x = self.embedding(x)
        x = torch.cat([self.encoder_h_context, x], dim=-1)
        output, hc = self.rnn(x, h)
        output = torch.cat([self.encoder_h_context, output], dim=-1)
        return self.f(output), hc

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(security): weights_only=False unpickles arbitrary Python objects and can
# execute code embedded in the file; only load checkpoints from a trusted source.
# The result is used as a module below (`.to(device)`), so the files presumably
# hold whole pickled modules; if so, the Encoder/Decoder class definitions must
# already exist in this session — TODO confirm.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved on
# a GPU still loads on a CPU-only machine.
encoder = torch.load("num_encoder.pt", map_location=device, weights_only=False)
decoder = torch.load("num_decoder.pt", map_location=device, weights_only=False)

# Move to the target device and switch to evaluation mode for inference.
encoder = encoder.to(device).eval()
decoder = decoder.to(device).eval()

In [4]:
# Character -> token index. Index 0 (' ') is the padding token.
dic = {' ': 0, '1': 1, '2': 2, '3': 3, '4': 4, '5': 5, '6': 6, '7': 7, '8': 8, '9': 9, '0': 10, '+': 11}
# Inverse lookup (token index -> character), built once and ordered by index
# value rather than relying on dict insertion order.
index_to_char = sorted(dic, key=dic.get)

# Read expressions until the user types the exit word, printing the model's
# answer for each one on its own line.
while True :
    query = input("식을 입력하세요 (최대 길이 7) : ")
    if query == "종료" :
        break
    # Keep only digits and '+'; everything else is dropped.
    query = re.sub(r"[^0-9+]", "", string = query)
    query = list(query)
    # presumably maps characters to indices and pads to length 7 — confirm
    # against custom.word_vectorize.
    query = custom.word_vectorize(query, dic, 7, False, " ", " ")
    query = [query]  # add the batch dimension
    tensor = torch.tensor(query, dtype = torch.long, device = device)
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        y, h = encoder(tensor)
        ys, _, _ = decoder(y, h)
    predicted_indices = ys[0].argmax(dim = -1).tolist()
    # Print the whole answer at once, terminated by a newline so consecutive
    # results do not run together.
    print("".join(index_to_char[idx] for idx in predicted_indices))

식을 입력하세요 (최대 길이 7) :  종료
