In [None]:
from enum import Enum 

class ModelSettings(Enum):
    """Notebook-wide settings; every setting is read through its `.value`."""

    # Path handed to the RNN decoder as `load_path` and used again when `train` saves it.
    MODEL_PATH = "token_rnn/saved_models/2apr2021"
    # NOTE(review): not referenced by any cell in this notebook.
    MAX_SEQUENCE_LENGTH = 30
    # Prefix concatenated in front of 'conditions.txt' / 'sentences.txt';
    # the empty string means the current working directory.
    DATAPATH = ""

In [None]:
# Toy training corpus. Line i of conditions.txt is the "<style> <sentiment> <keywords>"
# condition for line i of sentences.txt, so both lists must keep the same length and order.
training_conditions = [
    "question positive cat",
    "question neutral cat",
    "question negative cat",
    "statement positive cat",
    "statement neutral cat",
    "statement negative cat",
]
training_sentences = [
    "do you like cute cats?",
    "do you have a cat?",
    "cats are so annoying, aren't they?",
    "i love cats",
    "i have a cat",
    "i really can't stand cats",
]
assert len(training_conditions) == len(training_sentences), "every sentence needs exactly one condition"

# Explicit utf-8 so the files do not depend on the platform's locale encoding;
# each file ends with a newline like a conventional text file.
with open(ModelSettings.DATAPATH.value+'conditions.txt','w', encoding='utf-8') as condition_file:
    condition_file.write("\n".join(training_conditions) + "\n")
with open(ModelSettings.DATAPATH.value+'sentences.txt','w', encoding='utf-8') as sentence_file:
    sentence_file.write("\n".join(training_sentences) + "\n")

In [None]:
from typing import List

from numpy import ndarray
from ffast import load

class Encoder:
    """Wraps the ffast 'poincare' tokeniser together with the sentence-boundary tokens.

    Registers `<bos>` / `<eos>` as special tokens and exposes helpers that turn token ids
    and condition strings into column vectors (reshaped to `(-1, 1)`).
    """

    def __init__(self) -> None:
        self.BOS_TOKEN = "<bos>"
        self.EOS_TOKEN = "<eos>"
        self.tokeniser = load('poincare')
        self.tokeniser.add_special_token(self.BOS_TOKEN)
        self.tokeniser.add_special_token(self.EOS_TOKEN)
        # Find the ids of the newly registered markers by encoding them together;
        # each is expected to map to exactly one id.
        special_ids = list(self.tokeniser.encode(f"{self.BOS_TOKEN} {self.EOS_TOKEN}").ids)
        if len(special_ids) != 2:
            raise ValueError(
                f"expected '{self.BOS_TOKEN} {self.EOS_TOKEN}' to encode to exactly 2 ids, "
                f"got {special_ids!r}"
            )
        self.BOS_TOKEN_ID, self.EOS_TOKEN_ID = special_ids
        # NOTE(review): assumes the unknown-token id is the last vocabulary slot. The two
        # special tokens were added just above, so confirm this is not one of their ids.
        self.UNKNOWN_TOKEN_ID = len(self.tokeniser)-1

    def token_ids_vectoriser(self, token_ids:List[int]) -> List[ndarray]:
        """Decode `token_ids` and return each token's `semantics()` vector as a column vector."""
        return [
            token_vector.reshape((-1, 1))
            for token_vector in self.tokeniser.decode(token_ids).semantics()
        ]

    def condition_vectoriser(self, text:str) -> ndarray:
        """Encode `text` and return its `.vector` reshaped to a single column."""
        return self.tokeniser.encode(text).vector.reshape((-1, 1))

    def format_sentence(self, sentence:str) -> str:
        """Surround `sentence` with the BOS and EOS marker tokens, space-separated."""
        return f"{self.BOS_TOKEN} {sentence} {self.EOS_TOKEN}"

    @staticmethod
    def format_condition(style:str,sentiment:str,keywords:List[str]) -> str:
        """Build the condition string "<style> <sentiment> <keyword> ...".

        A bare string passed as `keywords` is treated as a single keyword; otherwise `join`
        would split it into characters ("c a t"). An empty keyword list produces no trailing
        space.
        """
        keyword_list = [keywords] if isinstance(keywords, str) else list(keywords)
        return " ".join([f"{style}", f"{sentiment}", *keyword_list])


In [None]:
from typing import Generator, Optional

from windML import RNN

class ConditionalNLG:
    """Condition-driven sentence generator: an `Encoder` for text plus a windML `RNN` decoder.

    A condition is the string "<style> <sentiment> <keywords...>" built by
    `Encoder.format_condition`. Training pairs line i of `sentences.txt` with line i of
    `conditions.txt`.
    """

    def __init__(self) -> None:
        # Accepted values for the first two words of a condition; checked by `generate`.
        self.STYLES = ("question","statement")
        self.SENTIMENTS = ("positive","neutral","negative")
        self.encoder = Encoder()
        # Decoder sizes come from the tokeniser. `load_path` is the same MODEL_PATH that
        # `train` saves to (presumably so previously saved weights are reloaded).
        self.decoder = RNN(
            load_path=ModelSettings.MODEL_PATH.value,
            token_vector_size=self.encoder.tokeniser.token_size,
            token_vocabulary_size=len(self.encoder.tokeniser),
            hidden_dimension=self.encoder.tokeniser.size
        )

    def train(self, epochs:int=100, data_path:str=ModelSettings.DATAPATH.value) -> None:
        """Fit the decoder on the paired sentence/condition files, then save it to MODEL_PATH.

        Args:
            epochs: number of epochs passed to `RNN.fit`.
            data_path: prefix concatenated in front of 'sentences.txt' / 'conditions.txt'
                (a directory must therefore end with a path separator).

        Raises:
            ValueError: if the two files hold a different number of lines.
        """
        sentence_token_ids = list(self.read_sentences(data_path))
        condition_vectors = list(self.read_conditions(data_path))
        # Examples are paired purely by line number, so differing counts mean the pairing
        # is wrong; fail loudly instead of training on misaligned data.
        if len(sentence_token_ids) != len(condition_vectors):
            raise ValueError(
                f"{data_path}sentences.txt has {len(sentence_token_ids)} lines but "
                f"{data_path}conditions.txt has {len(condition_vectors)}; "
                "they must pair up line by line"
            )
        self.decoder.fit(
            token_ids_vectoriser=self.encoder.token_ids_vectoriser,
            token_ids=sentence_token_ids,
            encoded_contexts=condition_vectors,
            epochs=epochs
        )
        self.decoder.save(ModelSettings.MODEL_PATH.value)

    def generate(self, style:str, sentiment:str, keywords:List[str], prompt:Optional[str]=None) -> str:
        """Generate text conditioned on `style`, `sentiment` and `keywords`.

        Args:
            style: one of `self.STYLES`.
            sentiment: one of `self.SENTIMENTS`.
            keywords: words appended to the condition string.
            prompt: optional text whose token ids are passed to the decoder as `prompt_ids`.

        Returns:
            The decoded text of the token ids returned by `RNN.generate`.

        Raises:
            ValueError: if `style` or `sentiment` is not one of the accepted values.
        """
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if style not in self.STYLES:
            raise ValueError(f"unknown style {style!r}; expected one of {self.STYLES}")
        if sentiment not in self.SENTIMENTS:
            raise ValueError(f"unknown sentiment {sentiment!r}; expected one of {self.SENTIMENTS}")
        prompt_ids = list() if prompt is None else self.encoder.tokeniser.encode(prompt).ids
        condition_vector = self.encoder.condition_vectoriser(
            self.encoder.format_condition(
                style=style,
                sentiment=sentiment,
                keywords=keywords
            )
        )
        generated_token_ids = self.decoder.generate(
            bos_id=self.encoder.BOS_TOKEN_ID,
            eos_id=self.encoder.EOS_TOKEN_ID,
            token_ids_vectoriser=self.encoder.token_ids_vectoriser,
            prompt_ids=prompt_ids,
            condition_vector=condition_vector
        )
        return str(self.encoder.tokeniser.decode(generated_token_ids))

    def read_conditions(self,data_path:str) -> Generator[ndarray,None,None]:
        """Yield one condition vector per whitespace-stripped line of `<data_path>conditions.txt`."""
        with open(data_path+'conditions.txt', encoding='utf-8') as condition_file:
            for condition in condition_file:
                yield self.encoder.condition_vectoriser(condition.strip())

    def read_sentences(self,data_path:str) -> Generator[List[int],None,None]:
        """Yield token ids for each line of `<data_path>sentences.txt`, wrapped in <bos>/<eos>.

        Lines are stripped first, as in `read_conditions`, so the trailing newline is not
        tokenised. The boundary markers are added here because `RNN.fit` receives no BOS/EOS
        ids, while `generate` hands those ids to the decoder (presumably as start/stop
        markers). The training sequences therefore have to contain them.
        """
        with open(data_path+'sentences.txt', encoding='utf-8') as sentence_file:
            for sentence in sentence_file:
                yield self.encoder.tokeniser.encode(
                    self.encoder.format_sentence(sentence.strip())
                ).ids

In [None]:
# Build the generator (tokeniser + RNN decoder) and train it on the toy corpus written
# above; `train` also saves the decoder back to ModelSettings.MODEL_PATH.
TRAINING_EPOCHS = 30
model = ConditionalNLG()
model.train(epochs=TRAINING_EPOCHS)

In [None]:
# Generate one sentence for every (style, sentiment) combination the model accepts.
# Iterating over the model's own STYLES / SENTIMENTS keeps this cell in sync with
# the values that `generate` validates against.
keywords = ["cat"]
for style in model.STYLES:
    for sentiment in model.SENTIMENTS:
        sentence = model.generate(style, sentiment, keywords)
        # Label each line so the reader can tell which condition produced it.
        print(f"{style} / {sentiment}: {sentence}")