<a href="https://colab.research.google.com/github/brunotech/BioBERTpt/blob/master/meena_chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Initialization

In [None]:
# Mount Google Drive so the corpus, generated data and checkpoints (all under
# /gdrive/MyDrive/transformer-chatbot/) survive Colab runtime resets.
from google.colab import drive
drive.mount('/gdrive')

# TF 1.x is pinned because later cells use TF1-only APIs (e.g. tf.contrib.eager
# in the Predict section); tensorflow-datasets is pinned alongside it.
!pip install -q -U tensorflow-gpu==1.15.2
!pip install -q -U tensorflow-datasets==3.2.1
!pip install -q -U tensor2tensor

import tensorflow as tf
import os

# Root of all persistent artefacts; MODEL_DIR and DATASET_DIR end with "/" so
# later cells can build paths by plain string concatenation.
project_dir = "/gdrive/MyDrive/transformer-chatbot/"
MODEL_DIR = project_dir + "saved_model/t2t_chatbot/"
DATASET_DIR = project_dir + "conversational-dataset/"

# IPython shell escapes: create both directories (no error if they exist).
!mkdir -p $DATASET_DIR
!mkdir -p $MODEL_DIR

# Stop TF log records from propagating to the root logger (presumably to avoid
# duplicated log lines in the notebook output).
tf.get_logger().propagate = False

## Hyperparameters

In [None]:
# ---- Data & paths ----
MAX_SAMPLES = 40_000_000  # upper bound on raw corpus lines read during preprocessing
DATA_DIR = f"{MODEL_DIR}data"  # where generate_data writes the vocab and encoded data
TMP_DIR = f"{MODEL_DIR}tmp"
TRAIN_DIR = f"{MODEL_DIR}train"  # RunConfig model_dir: checkpoints live here
PROBLEM = 'chat_bot'  # registry name of the ChatBot problem class

# ---- Model ----
USE_TPU = False
MODEL = "evolved_transformer"
HPARAMS = "evolved_transformer_base"
NUM_ENCODER_LAYERS = 1
NUM_DECODER_LAYERS = 12
BATCH_SIZE = 4096
MAX_LENGTH = 40  # input ids are left-truncated to this; also hparams.max_length
VOCAB_SIZE = 2 ** 13  # approximate subword vocabulary size (8192)
START_LEARNING_RATE = 0.01  # becomes hparams.learning_rate_constant

# ---- Conversation context ----
CONVERSATION_TURNS = 3  # number of previous turns fed to the model as context

# ---- Training schedule ----
TRAIN_STEPS = 300_000  # total number of train steps for all epochs
EVAL_STEPS = 100  # number of steps to perform for each evaluation
SAVE_CHECKPOINTS_STEPS = 5_000
KEEP_CHECKPOINT_MAX = 1  # passed to RunConfig as keep_checkpoint_max

## Problem definition

In [None]:
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
from tensor2tensor.data_generators import text_encoder
from collections import deque
import re

# Decompressed raw corpus, and its cleaned one-sentence-per-line copy.
PATH_TO_DATASET = f"{DATASET_DIR}it"
PATH_TO_PREPROCESSED = f"{DATASET_DIR}preprocessed.txt"

def preprocess_sentence(sentence):
    """Normalise one raw corpus line for training.

    The line is lower-cased and trimmed. Every run of characters outside
    ``a-z A-Z 0-9 ? . ! , '`` and ``àèìòùáéíóú`` (whitespace included)
    collapses to a single space. Stand-alone apostrophes written as `` ' ``
    are then removed.

    Returns the cleaned line, or "" when nothing usable remains.
    """
    disallowed_run = r"[^a-zA-Z0-9?.!,àèìòùáéíóú']+"
    collapsed = re.sub(disallowed_run, " ", sentence.lower().strip())
    without_lone_quotes = collapsed.replace(" ' ", " ")
    return without_lone_quotes.strip()

# Download and decompress the Italian OpenSubtitles (OPUS v2018, monolingual)
# corpus once; skipped when the decompressed file already exists.
if not os.path.isfile(PATH_TO_DATASET):
    # An absolute fname is passed, so the archive is expected to land at
    # DATASET_DIR + "it.gz" (Keras get_file behaviour — confirm for TF 1.15).
    path_to_zip = tf.keras.utils.get_file(
        DATASET_DIR + "it.gz",
        origin='http://opus.nlpl.eu/download.php?f=OpenSubtitles/v2018/mono/OpenSubtitles.it.gz')

    # IPython shell escape: -d decompresses, -k keeps it.gz. The output is
    # expected to be PATH_TO_DATASET (DATASET_DIR + 'it').
    !gzip -dk $path_to_zip

# Clean the raw corpus once. At most MAX_SAMPLES raw lines are read, and lines
# that become empty after preprocess_sentence are dropped. Output goes to a
# temporary file that is renamed into place only once it is complete. An
# interrupted run (e.g. a Colab disconnect) therefore cannot leave a truncated
# file that the existence check would later mistake for a finished one.
import itertools

if not os.path.isfile(PATH_TO_PREPROCESSED):
    partial_path = PATH_TO_PREPROCESSED + ".partial"
    with open(PATH_TO_DATASET, 'r') as dataset_file, \
            open(partial_path, 'w') as preprocessed_file:
        for raw_line in itertools.islice(dataset_file, MAX_SAMPLES):
            line = preprocess_sentence(raw_line)
            if line:
                preprocessed_file.write(line + '\n')
    # Atomic rename: PATH_TO_PREPROCESSED only ever appears fully written.
    os.replace(partial_path, PATH_TO_PREPROCESSED)
else:
    print("preprocessed dataset already exists")
        

@registry.register_problem
class ChatBot(text_problems.Text2TextProblem):
    """Multi-turn chatbot problem over the preprocessed subtitle corpus.

    Every line of PATH_TO_PREPROCESSED after the first becomes a target. The
    up to CONVERSATION_TURNS lines before it become its context. Context turns
    are encoded back to back, separated by SENTENCE_SEPARATOR.
    """

    # Reserved token placed between consecutive context turns.
    SENTENCE_SEPARATOR = "<SEP>"
    # Id the vocabulary is expected to assign to SENTENCE_SEPARATOR. It is the
    # only additional reserved token and follows T2T's built-in reserved ids
    # (EOS is 1 elsewhere in this notebook). Confirm if reserved tokens change.
    SENTENCE_SEPARATOR_ID = 2

    @property
    def approx_vocab_size(self):
        """Approximate subword vocabulary size to build (VOCAB_SIZE)."""
        return VOCAB_SIZE

    @property
    def is_generate_per_split(self):
        """False: one sample stream is shared by all splits in dataset_splits."""
        return False

    @property
    def dataset_splits(self):
        """Shard layout: 9 training shards and 1 evaluation shard."""
        train_split = {"split": problem.DatasetSplit.TRAIN, "shards": 9}
        eval_split = {"split": problem.DatasetSplit.EVAL, "shards": 1}
        return [train_split, eval_split]

    @property
    def additional_reserved_tokens(self):
        """Extra reserved vocabulary entries: only the turn separator."""
        return [self.SENTENCE_SEPARATOR]

    def generate_samples(self, data_dir, tmp_dir, dataset_split):
        """Yield {'inputs': [context turns], 'targets': reply} dicts.

        Corpus lines are treated as consecutive dialogue turns. No attempt is
        made to detect boundaries between separate conversations. The context
        holds up to CONVERSATION_TURNS preceding lines, oldest first.
        `data_dir`, `tmp_dir` and `dataset_split` are not used.
        """
        # A bounded deque keeps the context plus the current reply.
        window = deque(maxlen=CONVERSATION_TURNS + 1)
        with open(PATH_TO_PREPROCESSED, 'r') as corpus:
            lines = iter(corpus)
            window.append(next(lines, '').rstrip())
            for raw_line in lines:
                window.append(raw_line.rstrip())
                *context, reply = window
                yield {
                    'inputs': context,
                    'targets': reply
                }

    def generate_text_for_vocab(self, data_dir, tmp_dir):
        """Yield every preprocessed corpus line, whitespace-stripped."""
        with open(PATH_TO_PREPROCESSED, 'r') as corpus:
            for raw_line in corpus:
                yield raw_line.strip()

    def generate_encoded_samples(self, data_dir, tmp_dir, dataset_split):
        """Return a generator of samples encoded to subword id lists.

        - inputs: the context turns joined by SENTENCE_SEPARATOR_ID and ended
          with EOS. The sequence is left-truncated to the last MAX_LENGTH ids,
          so the most recent context survives.
        - targets: the encoded reply followed by EOS (not truncated here).

        Samples 101-109 are printed decoded as a sanity check. Totals are
        printed once the stream is exhausted.
        """
        raw_samples = self.generate_samples(data_dir, tmp_dir, dataset_split)
        vocab = self.get_or_create_vocab(data_dir, tmp_dir)
        separator_id = self.SENTENCE_SEPARATOR_ID
        eos_id = text_encoder.EOS_ID

        def encode_stream(samples, encoder):
            seen = 0
            subword_total = 0
            for sample in samples:
                ids = []
                for turn in sample["inputs"]:
                    ids += encoder.encode(turn)
                    ids.append(separator_id)
                # The trailing separator after the last turn becomes EOS.
                ids[-1] = eos_id
                sample["inputs"] = ids[-MAX_LENGTH:]
                sample["targets"] = encoder.encode(sample["targets"]) + [eos_id]
                if 100 < seen < 110:
                    print("_______INPUT_______")
                    print(encoder.decode(sample["inputs"]))
                    print("_______TARGET_______")
                    print(encoder.decode(sample["targets"]))
                seen += 1
                subword_total += max(len(sample["inputs"]), len(sample["targets"]))
                yield sample
            print(f"Num samples: {seen}")
            print(f"Tot number of subwords in the dataset: {subword_total}")

        return encode_stream(raw_samples, vocab)

## Generate data

In [None]:
from tensor2tensor import problems

# Instantiate the registered ChatBot problem and run its data generation.
# This builds or loads the vocabulary and writes the encoded samples into
# DATA_DIR.
t2t_problem = problems.problem(PROBLEM)
t2t_problem.generate_data(data_dir=DATA_DIR, tmp_dir=TMP_DIR)

## Training

In [None]:
from tensor2tensor.utils.trainer_lib import create_run_config, create_experiment
from tensor2tensor.utils.trainer_lib import create_hparams
from tensor2tensor.utils import registry, hparam, learning_rate
from tensor2tensor import models, problems
import json

# Start from the registered base hparams set for the chosen model.
hparams = create_hparams(HPARAMS)

# Overrides applied on top of the base set, in this order.
hparams_overrides = {
    "num_encoder_layers": NUM_ENCODER_LAYERS,
    "num_decoder_layers": NUM_DECODER_LAYERS,
    "batch_size": BATCH_SIZE,
    "max_length": MAX_LENGTH,
    "optimizer": 'Adafactor',
    "learning_rate_constant": START_LEARNING_RATE,
    "learning_rate_warmup_steps": 10000,
    "learning_rate_schedule": "constant*rsqrt_normalized_decay",
}
for hparam_name, hparam_value in hparams_overrides.items():
    setattr(hparams, hparam_name, hparam_value)

hparams_json = hparams.to_json()
print(hparams_json)

# Persist the final hparams; the Predict section reloads them from this file.
with open(MODEL_DIR + 'hparams.json', 'w') as hparams_file:
    hparams_file.write(hparams_json)

In [None]:
# Estimator run configuration: checkpoint directory, cadence and retention.
RUN_CONFIG = create_run_config(
    model_name=MODEL,
    model_dir=TRAIN_DIR,
    save_checkpoints_steps=SAVE_CHECKPOINTS_STEPS,
    keep_checkpoint_max=KEEP_CHECKPOINT_MAX,
)

# Train and evaluate alternately until TRAIN_STEPS is reached.
tensorflow_exp_fn = create_experiment(
    # What to train.
    model_name=MODEL,
    problem_name=PROBLEM,
    hparams=hparams,
    data_dir=DATA_DIR,
    # How to run it.
    run_config=RUN_CONFIG,
    use_tpu=USE_TPU,
    use_xla=True,  # For acceleration
    # Schedule.
    schedule="continuous_train_and_eval",
    train_steps=TRAIN_STEPS,
    eval_steps=EVAL_STEPS,
    eval_throttle_seconds=300,
)

tensorflow_exp_fn.continuous_train_and_eval()

## Test

In [None]:
# Optional stand-alone evaluation (disabled). Uncomment to rebuild the
# hparams and experiment with schedule="evaluate", then run 1000 eval steps
# on the checkpoints in TRAIN_DIR. Note it does not re-apply the optimizer or
# learning-rate overrides from the Training section.
# from tensor2tensor.utils.trainer_lib import create_run_config, create_experiment
# from tensor2tensor.utils.trainer_lib import create_hparams
# from tensor2tensor.utils import registry, hparam, learning_rate
# from tensor2tensor import models, problems

# hparams = create_hparams(HPARAMS)
# hparams.num_encoder_layers = NUM_ENCODER_LAYERS
# hparams.num_decoder_layers = NUM_DECODER_LAYERS
# hparams.batch_size = BATCH_SIZE
# hparams.max_length = MAX_LENGTH

# RUN_CONFIG = create_run_config(
#       model_dir=TRAIN_DIR,
#       model_name=MODEL,
#       save_checkpoints_steps = SAVE_CHECKPOINTS_STEPS,
#       keep_checkpoint_max = KEEP_CHECKPOINT_MAX
# )

# tensorflow_exp_fn = create_experiment(
#         run_config=RUN_CONFIG,
#         hparams=hparams,
#         model_name=MODEL,
#         problem_name=PROBLEM,
#         data_dir=DATA_DIR, 
#         train_steps=TRAIN_STEPS, 
#         eval_steps=1000, 
#         use_tpu=USE_TPU,
#         schedule="evaluate",
#     ) 

# tensorflow_exp_fn.evaluate()

## Predict

In [None]:
import tensorflow as tf
from tensor2tensor import models
from tensor2tensor import problems
from tensor2tensor.utils import hparams_lib
from tensor2tensor.utils import registry
from tensor2tensor.data_generators import text_problems
import numpy as np
import re
 
# Sampling parameters.
SAMPLING_TEMPERATURE = 0.88  # becomes hparams.sampling_temp below
NUM_SAMPLES = 5              # candidates drawn per round in predict()
MAX_LCS_RATIO = 0.8          # candidates at/above this context overlap are rejected

tfe = tf.contrib.eager
# NOTE: TF 1.x only allows eager execution to be enabled before any graph has
# been built. Run this section in a fresh runtime, re-running the setup,
# hyperparameter and problem-definition cells first, not right after training.
tfe.enable_eager_execution()
Modes = tf.estimator.ModeKeys

chat_bot_problem = problems.problem(PROBLEM)
ckpt_path = tf.train.latest_checkpoint(TRAIN_DIR)
if ckpt_path is None:
    # Fail fast here rather than with an obscure error inside predict().
    raise FileNotFoundError(
        f"No checkpoint found in {TRAIN_DIR}; run the Training section first.")
encoders = chat_bot_problem.feature_encoders(DATA_DIR)
# Rebuild the exact hparams used for training (saved by the Training section).
hparams = hparams_lib.create_hparams_from_json(MODEL_DIR + 'hparams.json')
hparams.data_dir = DATA_DIR
hparams_lib.add_problem_hparams(hparams, PROBLEM)
hparams.sampling_method = "random"
hparams.sampling_temp = SAMPLING_TEMPERATURE

chatbot_model = registry.model(MODEL)(hparams, Modes.PREDICT)
 
def preprocess_sentence(sentence):
    """Normalise a user-typed turn before encoding it for inference.

    Steps, applied in order after lower-casing and trimming:
    1. pad ``? . ! ,`` with spaces, e.g. "he is a boy." -> "he is a boy .";
    2. put a space after every apostrophe, e.g. "l'uomo" -> "l' uomo";
    3. collapse runs of spaces and double quotes into one space;
    4. replace runs of characters outside the allowed alphabet with a space.
    """
    rewrite_rules = (
        (r"([?.!,])", r" \1 "),
        (r"'", "' "),
        (r'[" "]+', " "),
        (r"[^a-zA-Z0-9?.!,àèìòùáéíóú']+", " "),
    )
    text = sentence.lower().strip()
    for pattern, replacement in rewrite_rules:
        text = re.sub(pattern, replacement, text)
    return text.strip()
 
def postprocess_sentence(sentence):
    """Tidy a decoded reply for display.

    Trailing spaces and periods are stripped, then any whitespace run that
    precedes a non-word character is removed ("stai ?" -> "stai?").
    """
    trimmed = sentence.rstrip(" .")
    return re.sub(r"\s+(\W)", lambda match: match.group(1), trimmed)
 
def encode(conversation, output_str=None):
    """Turn preprocessed conversation turns into a features dict for inference.

    The layout matches ChatBot.generate_encoded_samples, so inference sees
    what training saw:
    - turns are joined with the problem's SENTENCE_SEPARATOR_ID;
    - the trailing separator is replaced with EOS;
    - the sequence is left-truncated to hparams.max_length, keeping the
      newest ids.

    Args:
        conversation: preprocessed turn strings, oldest first. It must be
            non-empty; an empty list raises IndexError.
        output_str: unused; kept for backward compatibility.

    Returns:
        dict with key "inputs" holding the ids reshaped to [1, length, 1].
    """
    # Local import so this helper does not depend on an earlier cell's imports.
    from tensor2tensor.data_generators import text_encoder

    # Use the constants defined by the problem/T2T instead of magic numbers,
    # so the encoding cannot drift from the one used for training.
    separator_id = chat_bot_problem.SENTENCE_SEPARATOR_ID
    encoded_inputs = []
    for conversation_turn in conversation:
        encoded_inputs += encoders["inputs"].encode(conversation_turn) + [separator_id]
    encoded_inputs.pop()  # drop the separator after the last turn
    encoded_inputs.append(text_encoder.EOS_ID)
    if len(encoded_inputs) > hparams.max_length:
        encoded_inputs = encoded_inputs[-hparams.max_length:]
    batch_inputs = tf.reshape(encoded_inputs, [1, -1, 1])  # Make it 3D.
    return {"inputs": batch_inputs}
 
def decode(integers):
    """Convert token ids (tensor, array or list) into a display string.

    Ids are flattened, so any [1, length, 1]-style batch dims are dropped.
    They are then cut at the first EOS, detokenized with the inputs encoder
    (used here for model outputs as well) and cleaned by postprocess_sentence.
    """
    # Local import so this helper does not depend on an earlier cell's imports.
    from tensor2tensor.data_generators import text_encoder

    # np.ravel always yields a 1-D array. np.squeeze turned single-id inputs
    # (e.g. an empty turn that encodes to just EOS) into a 0-d array, which
    # list() cannot iterate.
    ids = np.ravel(integers).tolist()
    if text_encoder.EOS_ID in ids:
        ids = ids[:ids.index(text_encoder.EOS_ID)]
    decoded = encoders["inputs"].decode(ids)
    return postprocess_sentence(decoded)
 
def lcs_ratio(context, predicted):
    """Return len(LCS(context, predicted)) / len(predicted).

    This measures how much of `predicted` is a (not necessarily contiguous)
    copy of `context`: 1.0 means fully contained, 0.0 means nothing in common.
    Works on any sequences whose items support ==, e.g. two strings.

    An empty `predicted` returns 1.0, i.e. it is treated as a full copy.
    The ratio is undefined in that case and previously raised
    ZeroDivisionError; 1.0 makes a maximum-ratio filter reject empty replies.

    Runs in O(len(context) * len(predicted)) time and O(len(predicted)) memory.
    """
    if not predicted:
        return 1.0
    # Classic LCS dynamic programme, keeping only the previous row.
    previous_row = [0] * (len(predicted) + 1)
    for context_item in context:
        current_row = [0]
        for j, predicted_item in enumerate(predicted, start=1):
            if context_item == predicted_item:
                current_row.append(previous_row[j - 1] + 1)
            else:
                current_row.append(max(previous_row[j], current_row[j - 1]))
        previous_row = current_row
    return previous_row[-1] / len(predicted)
 
def predict(conversation):
    """Sample a reply to `conversation` that does not just parrot the context.

    Each round draws NUM_SAMPLES candidates from the model and prints them
    best score first. It returns the first candidate whose lcs_ratio against
    the space-joined preprocessed context is below MAX_LCS_RATIO. Rounds
    repeat with no retry limit until a candidate passes.
    """
    cleaned_turns = [preprocess_sentence(turn) for turn in conversation]
    features = encode(cleaned_turns)
    print("decoded input: " + decode(features["inputs"]))
    context = " ".join(cleaned_turns)
    with tfe.restore_variables_on_create(ckpt_path):
        while True:
            ranked = sorted(
                (chatbot_model.infer(features) for _ in range(NUM_SAMPLES)),
                key=lambda candidate: -float(candidate["scores"]),
            )
            # Decode each candidate once; reused for printing and filtering.
            scored_replies = [
                (float(candidate["scores"]), decode(candidate["outputs"]))
                for candidate in ranked
            ]

            for score, reply in scored_replies:
                print(str(score) + "\t" + reply)

            for _, reply in scored_replies:
                if lcs_ratio(context, reply) < MAX_LCS_RATIO:
                    return reply
 
 
# Interactive chat: the model sees at most the last CONVERSATION_TURNS turns
# (user and bot alike). Interrupt the cell to stop.
conversation = []
while True:
    sentence = input("Input: ")
    conversation.append(sentence)
    # Drop the oldest turns until only the context window remains.
    while len(conversation) > CONVERSATION_TURNS:
        del conversation[0]
    response = predict(conversation)
    conversation.append(response)
    print(response)

## Tensorboard

In [None]:
# Launch TensorBoard inside the notebook, pointed at TRAIN_DIR (the RunConfig
# model_dir used for training).
%load_ext tensorboard
%tensorboard --logdir $TRAIN_DIR