In [35]:
from src.fis import *
import torch
import torch.nn as nn 
import pandas
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from keras.preprocessing.text import tokenizer_from_json
import joblib
from sklearn.preprocessing import StandardScaler
import numpy as np

class NLG_Genesys():
    """Text generator that combines a fuzzy inference system with a seq2seq LSTM.

    The constructor loads every artifact needed for inference from disk
    (tokenizer, fitted scaler and model weights); `predict` then turns one
    list of numeric inputs into a generated string.
    """

    # Antecedent queried for each input position. Positions 2 and 3 both map
    # to 'salida', exactly as in the original if/elif chain.
    # NOTE(review): confirm that position 3 is really meant to reuse 'salida'.
    _ANTECEDENTS = ('proporcional', 'derivativo', 'salida', 'salida')

    def __init__(self):
        """Load the fuzzy system, tokenizer, scaler and trained model weights."""
        self.fis = fis()

        with open('data/tokenizer.json', 'r', encoding='utf-8') as f:
            self.tokenizer = tokenizer_from_json(f.read())

        # Forward slashes work on every platform; the previous backslash paths
        # only resolved on Windows and contained invalid escape sequences.
        self.Scaler = joblib.load('data/scaler.joblib')
        # +1 because Keras word indices start at 1 (index 0 is the padding id).
        self.num_decoder_tokens = len(self.tokenizer.word_index) + 1

        self.model = seq2seqLSTM(4, 64, self.num_decoder_tokens)
        # map_location='cpu' lets a checkpoint saved on GPU load on CPU-only
        # machines; inference below runs on CPU tensors anyway.
        state_dict = torch.load(
            'Modelos/pytorch_lim_vocab/seq2seqLSTM_model.pt',
            map_location='cpu',
        )
        self.model.load_state_dict(state_dict)

    def predict(self,input:list):
        """Generate text for one sample.

        Args:
            input: sequence of numeric values. The first four are looked up in
                the fuzzy system (see `_ANTECEDENTS`); the whole sequence is
                passed to the scaler.

        Returns:
            A single string with the predicted words joined by spaces.
        """
        # Fuzzy set name for each input position. zip stops at the shorter
        # sequence, so values past the fourth are ignored here.
        fuzzy_set = []
        for antecedent, value in zip(self._ANTECEDENTS, input):
            name, _ = self.fis.get_membership(antecedent_name=antecedent, value=value)
            fuzzy_set.append(name)

        # Normalize the raw values and shape them as a (1, 1, n) encoder input.
        enc_test_in = self.Scaler.transform(np.array(input).reshape(1, -1))
        enc_test_in = torch.tensor(enc_test_in.reshape(1, 1, -1)).float()

        # Tokenize the fuzzy set names and pad to a fixed length.
        dec_test_in = self.tokenizer.texts_to_sequences([' '.join(fuzzy_set)])
        dec_test_in = pad_sequences(
            dec_test_in,
            maxlen=self.num_decoder_tokens,
            padding='post',
            truncating='post',
        )
        dec_test_in = torch.tensor(dec_test_in).long()

        self.model.eval()
        with torch.no_grad():
            preds = self.model(enc_test_in, dec_test_in)
            # NOTE(review): this reduces over dim 1 of a (batch, steps, vocab)
            # tensor. Kept as-is because it may mirror how the model was
            # trained -- confirm against the training code before changing.
            _, predicts = torch.max(preds, 1)

        return self.__predictions(predicts.numpy())

    def __predictions(self,preds):
        """Map a 2-D array of indices to words and join everything with spaces.

        Indices that are not present in the tokenizer's `index_word` mapping
        (for example the padding id 0) are skipped.
        """
        index_word = self.tokenizer.index_word
        rows = []
        for row in preds:
            rows.append(' '.join(index_word[idx] for idx in row if idx in index_word))
        return ' '.join(rows)

In [36]:
class seq2seqLSTM(nn.Module):
    """Encoder-decoder LSTM that scores decoder tokens at every time step.

    Args:
        input_size: number of features per encoder time step.
        hidden_size: hidden size shared by the encoder and decoder LSTMs.
        num_decoder_tokens: vocabulary size (embedding rows and output width).
        num_layers: number of stacked encoder LSTM layers. The decoder always
            has a single layer.
    """

    def __init__(self, input_size, hidden_size, num_decoder_tokens, num_layers=1):
        super().__init__()
        self.encoder_lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.decoder_embedding = nn.Embedding(num_decoder_tokens, 128)
        self.decoder_lstm = nn.LSTM(128, hidden_size, num_layers=1, batch_first=True)
        self.linear_layer = nn.Linear(hidden_size, num_decoder_tokens)

    def forward(self, inputs, dec_inputs):
        """Return scores of shape (batch, decoder_steps, num_decoder_tokens).

        Args:
            inputs: float tensor (batch, encoder_steps, input_size).
            dec_inputs: long tensor (batch, decoder_steps) of token ids.
        """
        _, (state_h, state_c) = self.encoder_lstm(inputs)
        # The decoder is single-layer, so seed it with the final state of the
        # top encoder layer only. With num_layers == 1 this is the full state;
        # with more layers, passing the full state raised a shape error.
        state_h = state_h[-1:].contiguous()
        state_c = state_c[-1:].contiguous()

        embedding = self.decoder_embedding(dec_inputs)
        decoder_output, _ = self.decoder_lstm(embedding, (state_h, state_c))
        # nn.Linear acts on the last dimension, so the output keeps the same
        # number of time steps as dec_inputs without any manual reshaping.
        return self.linear_layer(decoder_output)

In [37]:
# Build the generator once; the constructor reads the tokenizer, the scaler
# and the trained model weights from disk.
modelo = NLG_Genesys()

In [38]:
# Sample inputs: each row holds four floats and is passed unchanged to
# `modelo.predict`. That method looks up position 0 as 'proporcional',
# position 1 as 'derivativo', and positions 2 and 3 as 'salida'.
inputs = [
[-0.973660277,0.685095798,59.56100462,-9.122009236],
[-0.377828726,0.568407083,46.67430895,12.77140921],
[0.01567219,0.379727079,-13.13443793,61.17541422],
[0.372804459,0.662840809,-84.56089181,87.96033443],
[0.990658137,0.313651328,-30.9341863,-28.1316274],
[-0.894200915,0.008267488,74.83362165,-24.45181324],
[-0.468190925,0.019575402,61.08238042,-10.26384242],
[0.019810769,0.000863419,23.91385244,32.84036038],
[0.444760626,-0.00742206,-71.69119038,67.5243902],
[0.975931682,-0.0353132,-30.37134636,-22.94441907],
[-0.925427849,-0.574096066,88.75713081,-38.75713081],
[-0.469932986,-0.261050124,81.16554977,-31.16554977],
[-0.007347862,-0.425361015,60.36739308,-10.36739308],
[0.373523294,-0.590534398,13.30958829,46.02849406],
[0.902971201,-0.492407206,-23.53141338,-17.06282676],
[-0.947488931,0.793020541,59.12481552,-8.249631049],
[-0.331723123,0.63642343,39.75846852,18.53460957],
[0.058486198,0.688639871,-21.69723966,64.38646487],
[0.354756966,0.90586868,-80.95139323,86.60677246],
[0.908604376,0.596051107,-39.13956237,-11.72087527],
[-0.940906397,-0.031286348,79.66125599,-29.66125599],
[-0.456236773,0.016182308,60.91704468,-10.22062591],
[-0.073795792,0.02487639,32.44004927,25.08653337],
[0.33504713,-0.02833159,-40.69294547,60.42721755],
[0.923057508,0.02508882,-41.1428986,-6.689031356],
[-0.976919282,-0.691433447,89.61532137,-39.61532137],
[-0.42459674,-0.36844626,80.40994567,-30.40994567],
[0.057579164,-0.773128795,52.80260447,-1.363125369],
[0.346327935,-0.707799717,16.70900818,41.94919018],
[0.896734243,-0.535691005,-23.11561621,-16.23123241],
[-0.962375937,0.36487035,59.37293229,-8.745864575],
[-0.347296585,0.490141157,42.0944878,16.58792684],
[0.06574558,0.987705804,-23.14911591,64.93091847],
[0.422075911,0.48106835,-87.79240891,85.58481781],
[0.929792685,0.943977194,-37.02073154,-15.95853692],
[-0.977643503,0.026914147,75.76141055,-25.41464388],
[-0.350922908,-0.039946403,60.25201708,-9.266838318],
[0.007901508,-0.008105208,28.69918332,29.77077815],
[0.465417625,0.037736427,-72.59796276,61.85667046],
[0.892911268,0.028127537,-43.5508124,-1.505378636],
[-0.964863637,-0.669478224,89.41439394,-39.41439394],
[-0.477627914,-0.280073077,81.29379857,-31.29379857],
[-0.009788409,-0.541308062,60.48942047,-10.48942047],
[0.460663168,-0.670877511,5.955788773,41.91157755],
[0.926652127,-0.977768927,-25.11014183,-20.22028366],
[-0.998165934,0.815672566,59.96943223,-9.938864454],
[-0.461571175,0.262724904,51.02618625,7.947627498],
[-0.0755047,0.240513311,1.32570504,50.56191247],
[0.453065454,0.223814272,-84.69345456,79.38690911],
[0.905463611,0.77398552,-39.4536389,-11.0927222],
[-0.994422301,0.014832287,78.08752049,-27.99625566],
[-0.462955959,0.005857598,61.70573688,-11.42906341],
[-0.040867728,0.006244355,32.29719893,26.44728877],
[0.4689018,-0.030901154,-59.35291792,55.68526387],
[0.927234615,-0.022514308,-31.27407199,-11.760462],
[-0.966609752,-0.637654896,89.44349586,-39.44349586],
[-0.39021046,-0.625833275,79.51052301,-29.51052301],
[0.035815084,-0.228802559,55.52311445,-4.627737338],
[0.402667348,-0.270504183,9.822176806,49.64435361],
[0.994986309,-0.897700303,-29.66575396,-29.33150792],
[-0.955195763,0.294938855,59.25326272,-8.506525431],
[-0.347003504,0.539915182,42.05052566,16.62456195],
[-0.011246419,0.43601646,-8.313037114,58.5941976],
[0.404427877,0.831800531,-89.55721226,89.11442453],
[0.963801889,0.899157357,-33.61981112,-22.76037776],
[-0.901902112,-0.013865397,77.7371003,-27.7371003],
[-0.47739777,0.022524856,61.18643913,-10.26724376],
[-0.062114991,0.0071577,33.67948762,24.53738359],
[0.37782183,-0.035464172,-53.15646836,68.22387882],
[0.907940958,0.035640476,-42.91450207,-2.99695021],
[-0.943215246,-0.946856167,89.05358744,-39.05358744],
[-0.488473636,-0.248855376,81.4745606,-31.4745606],
[-0.062393613,-0.668639532,63.11968067,-13.11968067],
[0.417029417,-0.276859738,8.864705558,47.72941112],
[0.944565199,-0.65933207,-26.30434662,-22.60869324]
]


In [None]:
# Run every sample through the model and write one generated sentence per line.
# UTF-8 is requested explicitly so the output does not depend on the platform's
# default encoding (the tokenizer file is read as UTF-8 as well).
with open('Modelos/pytorch_lim_vocab/results/resultados.txt', 'w', encoding='utf-8') as f:
    for muestra in inputs:
        f.write(modelo.predict(muestra) + '\n')