<a href="https://colab.research.google.com/github/ncr5012/EmailRL/blob/main/Chapter14_All_Libraries.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
#This is the code required to run the chatbot in chapter 14
#The purpose of this is to understand how NLP works in RL to facilitate the creation of a limited GI email agent

#Initialization Section - Import libraries, define utility functions

import collections
import os
import sys
import logging
import itertools
import pickle

import string
from nltk.translate import bleu_score
from nltk.tokenize import TweetTokenizer

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
import torch.nn.functional as F




def calc_bleu_many(cand_seq, ref_sequences):
    sf = bleu_score.SmoothingFunction()
    return bleu_score.sentence_bleu(ref_sequences, cand_seq,
                                    smoothing_function=sf.method1,
                                    weights=(0.5, 0.5))


def calc_bleu(cand_seq, ref_seq):
    return calc_bleu_many(cand_seq, [ref_seq])


def tokenize(s):
    return TweetTokenizer(preserve_case=False).tokenize(s)


def untokenize(words):
    to_pad = lambda t: not t.startswith("'") and \
                       t not in string.punctuation
    return "".join([
        (" " + i) if to_pad(i) else i
        for i in words
    ]).strip()





In [9]:
#load the data
url1 = 'https://raw.githubusercontent.com/ncr5012/EmailRL/main/cornell%20movie-dialogs%20corpus/movie_characters_metadata.txt'
movie_characters_metadata = pd.read_csv(url1,sep="+++$+++")
# Dataset is now stored in a Pandas Dataframe

  This is separate from the ipykernel package so we can avoid doing imports until


error: ignored

In [2]:
#cornell.py - low level data cleaner 
"""
Cornel Movies Dialogs Corpus
https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html
"""

log = logging.getLogger("cornell")
DATA_DIR = "data/cornell"
SEPARATOR = "+++$+++"


def load_dialogues(data_dir=DATA_DIR, genre_filter=''):
    """
    Load dialogues from cornell data
    :return: list of list of list of words
    """
    movie_set = None
    if genre_filter:
        movie_set = read_movie_set(data_dir, genre_filter)
        log.info("Loaded %d movies with genre %s", len(movie_set), genre_filter)
    log.info("Read and tokenise phrases...")
    lines = read_phrases(data_dir, movies=movie_set)
    log.info("Loaded %d phrases", len(lines))
    dialogues = load_conversations(data_dir, lines, movie_set)
    return dialogues


def iterate_entries(data_dir, file_name):
    with open(os.path.join(data_dir, file_name), "rb") as fd:
        for l in fd:
            l = str(l, encoding='utf-8', errors='ignore')
            yield list(map(str.strip, l.split(SEPARATOR)))


def read_movie_set(data_dir, genre_filter):
    res = set()
    for parts in iterate_entries(data_dir, "movie_titles_metadata.txt"):
        m_id, m_genres = parts[0], parts[5]
        if m_genres.find(genre_filter) != -1:
            res.add(m_id)
    return res


def read_phrases(data_dir, movies=None):
    res = {}
    for parts in iterate_entries(data_dir, "movie_lines.txt"):
        l_id, m_id, l_str = parts[0], parts[2], parts[4]
        if movies and m_id not in movies:
            continue
        tokens = utils.tokenize(l_str)
        if tokens:
            res[l_id] = tokens
    return res


def load_conversations(data_dir, lines, movies=None):
    res = []
    for parts in iterate_entries(data_dir, "movie_conversations.txt"):
        m_id, dial_s = parts[2], parts[3]
        if movies and m_id not in movies:
            continue
        l_ids = dial_s.strip("[]").split(", ")
        l_ids = list(map(lambda s: s.strip("'"), l_ids))
        dial = [lines[l_id] for l_id in l_ids if l_id in lines]
        if dial:
            res.append(dial)
    return res


def read_genres(data_dir):
    res = {}
    for parts in iterate_entries(data_dir, "movie_titles_metadata.txt"):
        m_id, m_genres = parts[0], parts[5]
        l_genres = m_genres.strip("[]").split(", ")
        l_genres = list(map(lambda s: s.strip("'"), l_genres))
        res[m_id] = l_genres
    return res


In [3]:
#data.py - high level data cleaner

UNKNOWN_TOKEN = '#UNK'
BEGIN_TOKEN = "#BEG"
END_TOKEN = "#END"
MAX_TOKENS = 20
MIN_TOKEN_FEQ = 10
SHUFFLE_SEED = 5871

EMB_DICT_NAME = "emb_dict.dat"
EMB_NAME = "emb.npy"

log = logging.getLogger("data")


def save_emb_dict(dir_name, emb_dict):
    with open(os.path.join(dir_name, EMB_DICT_NAME), "wb") as fd:
        pickle.dump(emb_dict, fd)


def load_emb_dict(dir_name):
    with open(os.path.join(dir_name, EMB_DICT_NAME), "rb") as fd:
        return pickle.load(fd)


def encode_words(words, emb_dict):
    """
    Convert list of words into list of embeddings indices, adding our tokens
    :param words: list of strings
    :param emb_dict: embeddings dictionary
    :return: list of IDs
    """
    res = [emb_dict[BEGIN_TOKEN]]
    unk_idx = emb_dict[UNKNOWN_TOKEN]
    for w in words:
        idx = emb_dict.get(w.lower(), unk_idx)
        res.append(idx)
    res.append(emb_dict[END_TOKEN])
    return res


def encode_phrase_pairs(phrase_pairs, emb_dict, filter_unknows=True):
    """
    Convert list of phrase pairs to training data
    :param phrase_pairs: list of (phrase, phrase)
    :param emb_dict: embeddings dictionary (word -> id)
    :return: list of tuples ([input_id_seq], [output_id_seq])
    """
    unk_token = emb_dict[UNKNOWN_TOKEN]
    result = []
    for p1, p2 in phrase_pairs:
        p = encode_words(p1, emb_dict), encode_words(p2, emb_dict)
        if unk_token in p[0] or unk_token in p[1]:
            continue
        result.append(p)
    return result


def group_train_data(training_data):
    """
    Group training pairs by first phrase
    :param training_data: list of (seq1, seq2) pairs
    :return: list of (seq1, [seq*]) pairs
    """
    groups = collections.defaultdict(list)
    for p1, p2 in training_data:
        l = groups[tuple(p1)]
        l.append(p2)
    return list(groups.items())


def iterate_batches(data, batch_size):
    assert isinstance(data, list)
    assert isinstance(batch_size, int)

    ofs = 0
    while True:
        batch = data[ofs*batch_size:(ofs+1)*batch_size]
        if len(batch) <= 1:
            break
        yield batch
        ofs += 1


def load_data(genre_filter, max_tokens=MAX_TOKENS, min_token_freq=MIN_TOKEN_FEQ):
    dialogues = cornell.load_dialogues(genre_filter=genre_filter)
    if not dialogues:
        log.error("No dialogues found, exit!")
        sys.exit()
    log.info("Loaded %d dialogues with %d phrases, generating training pairs",
             len(dialogues), sum(map(len, dialogues)))
    phrase_pairs = dialogues_to_pairs(dialogues, max_tokens=max_tokens)
    log.info("Counting freq of words...")
    word_counts = collections.Counter()
    for dial in dialogues:
        for p in dial:
            word_counts.update(p)
    freq_set = set(map(lambda p: p[0], filter(lambda p: p[1] >= min_token_freq, word_counts.items())))
    log.info("Data has %d uniq words, %d of them occur more than %d",
             len(word_counts), len(freq_set), min_token_freq)
    phrase_dict = phrase_pairs_dict(phrase_pairs, freq_set)
    return phrase_pairs, phrase_dict


def phrase_pairs_dict(phrase_pairs, freq_set):
    """
    Return the dict of words in the dialogues mapped to their IDs
    :param phrase_pairs: list of (phrase, phrase) pairs
    :return: dict
    """
    res = {UNKNOWN_TOKEN: 0, BEGIN_TOKEN: 1, END_TOKEN: 2}
    next_id = 3
    for p1, p2 in phrase_pairs:
        for w in map(str.lower, itertools.chain(p1, p2)):
            if w not in res and w in freq_set:
                res[w] = next_id
                next_id += 1
    return res


def dialogues_to_pairs(dialogues, max_tokens=None):
    """
    Convert dialogues to training pairs of phrases
    :param dialogues:
    :param max_tokens: limit of tokens in both question and reply
    :return: list of (phrase, phrase) pairs
    """
    result = []
    for dial in dialogues:
        prev_phrase = None
        for phrase in dial:
            if prev_phrase is not None:
                if max_tokens is None or (len(prev_phrase) <= max_tokens and len(phrase) <= max_tokens):
                    result.append((prev_phrase, phrase))
            prev_phrase = phrase
    return result


def decode_words(indices, rev_emb_dict):
    return [rev_emb_dict.get(idx, UNKNOWN_TOKEN) for idx in indices]


def trim_tokens_seq(tokens, end_token):
    res = []
    for t in tokens:
        res.append(t)
        if t == end_token:
            break
    return res


def split_train_test(data, train_ratio=0.95):
    count = int(len(data) * train_ratio)
    return data[:count], data[count:]


In [None]:
#model.py - used to...

import numpy as np

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
import torch.nn.functional as F

from . import utils

HIDDEN_STATE_SIZE = 512
EMBEDDING_DIM = 50


class PhraseModel(nn.Module):
    def __init__(self, emb_size, dict_size, hid_size):
        super(PhraseModel, self).__init__()

        self.emb = nn.Embedding(
            num_embeddings=dict_size, embedding_dim=emb_size)
        self.encoder = nn.LSTM(
            input_size=emb_size, hidden_size=hid_size,
            num_layers=1, batch_first=True)
        self.decoder = nn.LSTM(
            input_size=emb_size, hidden_size=hid_size,
            num_layers=1, batch_first=True)
        self.output = nn.Linear(hid_size, dict_size)

    def encode(self, x):
        _, hid = self.encoder(x)
        return hid

    def get_encoded_item(self, encoded, index):
        # For RNN
        # return encoded[:, index:index+1]
        # For LSTM
        return encoded[0][:, index:index+1].contiguous(), \
               encoded[1][:, index:index+1].contiguous()

    def decode_teacher(self, hid, input_seq):
        # Method assumes batch of size=1
        out, _ = self.decoder(input_seq, hid)
        out = self.output(out.data)
        return out

    def decode_one(self, hid, input_x):
        out, new_hid = self.decoder(input_x.unsqueeze(0), hid)
        out = self.output(out)
        return out.squeeze(dim=0), new_hid

    def decode_chain_argmax(self, hid, begin_emb, seq_len,
                            stop_at_token=None):
        """
        Decode sequence by feeding predicted token to the net again. Act greedily
        """
        res_logits = []
        res_tokens = []
        cur_emb = begin_emb

        for _ in range(seq_len):
            out_logits, hid = self.decode_one(hid, cur_emb)
            out_token_v = torch.max(out_logits, dim=1)[1]
            out_token = out_token_v.data.cpu().numpy()[0]

            cur_emb = self.emb(out_token_v)

            res_logits.append(out_logits)
            res_tokens.append(out_token)
            if stop_at_token is not None:
                if out_token == stop_at_token:
                    break
        return torch.cat(res_logits), res_tokens

    def decode_chain_sampling(self, hid, begin_emb, seq_len,
                              stop_at_token=None):
        """
        Decode sequence by feeding predicted token to the net again.
        Act according to probabilities
        """
        res_logits = []
        res_actions = []
        cur_emb = begin_emb

        for _ in range(seq_len):
            out_logits, hid = self.decode_one(hid, cur_emb)
            out_probs_v = F.softmax(out_logits, dim=1)
            out_probs = out_probs_v.data.cpu().numpy()[0]
            action = int(np.random.choice(
                out_probs.shape[0], p=out_probs))
            action_v = torch.LongTensor([action])
            action_v = action_v.to(begin_emb.device)
            cur_emb = self.emb(action_v)

            res_logits.append(out_logits)
            res_actions.append(action)
            if stop_at_token is not None:
                if action == stop_at_token:
                    break
        return torch.cat(res_logits), res_actions


def pack_batch_no_out(batch, embeddings, device="cpu"):
    assert isinstance(batch, list)
    # Sort descending (CuDNN requirements)
    batch.sort(key=lambda s: len(s[0]), reverse=True)
    input_idx, output_idx = zip(*batch)
    # create padded matrix of inputs
    lens = list(map(len, input_idx))
    input_mat = np.zeros((len(batch), lens[0]), dtype=np.int64)
    for idx, x in enumerate(input_idx):
        input_mat[idx, :len(x)] = x
    input_v = torch.tensor(input_mat).to(device)
    input_seq = rnn_utils.pack_padded_sequence(
        input_v, lens, batch_first=True)
    # lookup embeddings
    r = embeddings(input_seq.data)
    emb_input_seq = rnn_utils.PackedSequence(
        r, input_seq.batch_sizes)
    return emb_input_seq, input_idx, output_idx


def pack_input(input_data, embeddings, device="cpu"):
    input_v = torch.LongTensor([input_data]).to(device)
    r = embeddings(input_v)
    return rnn_utils.pack_padded_sequence(
        r, [len(input_data)], batch_first=True)


def pack_batch(batch, embeddings, device="cpu"):
    emb_input_seq, input_idx, output_idx = pack_batch_no_out(
        batch, embeddings, device)

    # prepare output sequences, with end token stripped
    output_seq_list = []
    for out in output_idx:
        s = pack_input(out[:-1], embeddings, device)
        output_seq_list.append(s)
    return emb_input_seq, output_seq_list, input_idx, output_idx


def seq_bleu(model_out, ref_seq):
    model_seq = torch.max(model_out.data, dim=1)[1]
    model_seq = model_seq.cpu().numpy()
    return utils.calc_bleu(model_seq, ref_seq)

In [None]:
#cor_reader
import argparse
import collections

from libbots import cornell, data


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-g", "--genre", default='', help="Genre to show dialogs from")
    parser.add_argument("--show-genres", action='store_true', default=False, help="Display genres stats")
    parser.add_argument("--show-dials", action='store_true', default=False, help="Display dialogs")
    parser.add_argument("--show-train", action='store_true', default=False, help="Display training pairs")
    parser.add_argument("--show-dict-freq", action='store_true', default=False, help="Display dictionary frequency")
    args = parser.parse_args()

    if args.show_genres:
        genre_counts = collections.Counter()
        genres = cornell.read_genres(cornell.DATA_DIR)
        for movie, g_list in genres.items():
            for g in g_list:
                genre_counts[g] += 1
        print("Genres:")
        for g, count in genre_counts.most_common():
            print("%s: %d" % (g, count))

    if args.show_dials:
        dials = cornell.load_dialogues(genre_filter=args.genre)
        for d_idx, dial in enumerate(dials):
            print("Dialog %d with %d phrases:" % (d_idx, len(dial)))
            for p in dial:
                print(" ".join(p))
            print()

    if args.show_train or args.show_dict_freq:
        phrase_pairs, emb_dict = data.load_data(genre_filter=args.genre)

    if args.show_train:
        rev_emb_dict = {idx: word for word, idx in emb_dict.items()}
        train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
        train_data = data.group_train_data(train_data)
        unk_token = emb_dict[data.UNKNOWN_TOKEN]

        print("Training pairs (%d total)" % len(train_data))
        train_data.sort(key=lambda p: len(p[1]), reverse=True)
        for idx, (p1, p2_group) in enumerate(train_data):
            w1 = data.decode_words(p1, rev_emb_dict)
            w2_group = [data.decode_words(p2, rev_emb_dict) for p2 in p2_group]
            print("%d:" % idx, " ".join(w1))
            for w2 in w2_group:
                print("%s:" % (" " * len(str(idx))), " ".join(w2))

    if args.show_dict_freq:
        words_stat = collections.Counter()
        for p1, p2 in phrase_pairs:
            words_stat.update(p1)
        print("Frequency stats for %d tokens in the dict" % len(emb_dict))
        for token, count in words_stat.most_common():
            print("%s: %d" % (token, count))
    pass

In [None]:
#Telegram_Bot
import os
import sys
import logging
import configparser
import argparse

try:
    import telegram.ext
except ImportError:
    print("You need python-telegram-bot package installed "
          "to start the bot")
    sys.exit()

from libbots import data, model, utils

import torch

# Configuration file with the following contents
# [telegram]
# api=API_KEY
CONFIG_DEFAULT = "~/.config/rl_ch12_bot.ini"

log = logging.getLogger("telegram")


if __name__ == "__main__":
    fmt = "%(asctime)-15s %(levelname)s %(message)s"
    logging.basicConfig(format=fmt, level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config", default=CONFIG_DEFAULT,
        help="Configuration file for the bot, default=" +
             CONFIG_DEFAULT)
    parser.add_argument(
        "-m", "--model", required=True, help="Model to load")
    parser.add_argument(
        "--sample", default=False, action='store_true',
        help="Enable sampling mode")
    prog_args = parser.parse_args()

    conf = configparser.ConfigParser()
    if not conf.read(os.path.expanduser(prog_args.config)):
        log.error("Configuration file %s not found",
                  prog_args.config)
        sys.exit()

    emb_dict = data.load_emb_dict(
        os.path.dirname(prog_args.model))
    log.info("Loaded embedded dict with %d entries",
             len(emb_dict))
    rev_emb_dict = {
        idx: word for word, idx in emb_dict.items()
    }
    end_token = emb_dict[data.END_TOKEN]

    net = model.PhraseModel(
        emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
        hid_size=model.HIDDEN_STATE_SIZE)
    net.load_state_dict(torch.load(prog_args.model))

    def bot_func(bot, update, args):
        text = " ".join(args)
        words = utils.tokenize(text)
        seq_1 = data.encode_words(words, emb_dict)
        input_seq = model.pack_input(seq_1, net.emb)
        enc = net.encode(input_seq)
        if prog_args.sample:
            _, tokens = net.decode_chain_sampling(
                enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
                stop_at_token=end_token)
        else:
            _, tokens = net.decode_chain_argmax(
                enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
                stop_at_token=end_token)
        if tokens[-1] == end_token:
            tokens = tokens[:-1]
        reply = data.decode_words(tokens, rev_emb_dict)
        if reply:
            reply_text = utils.untokenize(reply)
            bot.send_message(chat_id=update.message.chat_id,
                             text=reply_text)

    updater = telegram.ext.Updater(conf['telegram']['api'])
    updater.dispatcher.add_handler(
        telegram.ext.CommandHandler('bot', bot_func,
                                    pass_args=True))

    log.info("Bot initialized, started serving")
    updater.start_polling()
    updater.idle()


In [None]:
#use_model
import os
import argparse
import logging

from libbots import data, model, utils

import torch

log = logging.getLogger("use")


def words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=False):
    tokens = data.encode_words(words, emb_dict)
    input_seq = model.pack_input(tokens, net.emb)
    enc = net.encode(input_seq)
    end_token = emb_dict[data.END_TOKEN]
    if use_sampling:
        _, out_tokens = net.decode_chain_sampling(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
                                                  stop_at_token=end_token)
    else:
        _, out_tokens = net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
                                                stop_at_token=end_token)
    if out_tokens[-1] == end_token:
        out_tokens = out_tokens[:-1]
    out_words = data.decode_words(out_tokens, rev_emb_dict)
    return out_words


def process_string(s, emb_dict, rev_emb_dict, net, use_sampling=False):
    out_words = words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=use_sampling)
    print(" ".join(out_words))


if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", required=True, help="Model name to load")
    parser.add_argument("-s", "--string", help="String to process, otherwise will loop")
    parser.add_argument("--sample", default=False, action="store_true", help="Enable sampling generation instead of argmax")
    parser.add_argument("--self", type=int, default=1, help="Enable self-loop mode with given amount of phrases.")
    args = parser.parse_args()

    emb_dict = data.load_emb_dict(os.path.dirname(args.model))
    net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=model.HIDDEN_STATE_SIZE)
    net.load_state_dict(torch.load(args.model))

    rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

    while True:
        if args.string:
            input_string = args.string
        else:
            input_string = input(">>> ")
        if not input_string:
            break

        words = utils.tokenize(input_string)
        for _ in range(args.self):
            words = words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=args.sample)
            print(utils.untokenize(words))

        if args.string:
            break
    pass



In [None]:
#!/usr/bin/env python3
import argparse
import logging

from libbots import data, model, utils

import torch

log = logging.getLogger("data_test")

if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True,
                        help="Category to filter, empty string will use the full dataset")
    parser.add_argument("-m", "--model", required=True, help="Model name to load")
    args = parser.parse_args()

    phrase_pairs, emb_dict = data.load_data(args.data)
    log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
    train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
    train_data = data.group_train_data(train_data)
    rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

    net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=model.HIDDEN_STATE_SIZE)
    net.load_state_dict(torch.load(args.model))

    end_token = emb_dict[data.END_TOKEN]

    seq_count = 0
    sum_bleu = 0.0

    for seq_1, targets in train_data:
        input_seq = model.pack_input(seq_1, net.emb)
        enc = net.encode(input_seq)
        _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1],
                                            seq_len=data.MAX_TOKENS, stop_at_token=end_token)
        references = [seq[1:] for seq in targets]
        bleu = utils.calc_bleu_many(tokens, references)
        sum_bleu += bleu
        seq_count += 1

    log.info("Processed %d phrases, mean BLEU = %.4f", seq_count, sum_bleu / seq_count)


In [None]:
#data_test
import argparse
import logging

from libbots import data, model, utils

import torch

log = logging.getLogger("data_test")

if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True,
                        help="Category to filter, empty string will use the full dataset")
    parser.add_argument("-m", "--model", required=True, help="Model name to load")
    args = parser.parse_args()

    phrase_pairs, emb_dict = data.load_data(args.data)
    log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
    train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
    train_data = data.group_train_data(train_data)
    rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

    net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=model.HIDDEN_STATE_SIZE)
    net.load_state_dict(torch.load(args.model))

    end_token = emb_dict[data.END_TOKEN]

    seq_count = 0
    sum_bleu = 0.0

    for seq_1, targets in train_data:
        input_seq = model.pack_input(seq_1, net.emb)
        enc = net.encode(input_seq)
        _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1],
                                            seq_len=data.MAX_TOKENS, stop_at_token=end_token)
        references = [seq[1:] for seq in targets]
        bleu = utils.calc_bleu_many(tokens, references)
        sum_bleu += bleu
        seq_count += 1

    log.info("Processed %d phrases, mean BLEU = %.4f", seq_count, sum_bleu / seq_count)


In [None]:
#train_crossent
import os
import random
import argparse
import logging
import numpy as np
from tensorboardX import SummaryWriter

from libbots import data, model, utils

import torch
import torch.optim as optim
import torch.nn.functional as F

SAVES_DIR = "saves"

BATCH_SIZE = 32
LEARNING_RATE = 1e-3
MAX_EPOCHES = 100

log = logging.getLogger("train")

TEACHER_PROB = 0.5


def run_test(test_data, net, end_token, device="cpu"):
    bleu_sum = 0.0
    bleu_count = 0
    for p1, p2 in test_data:
        input_seq = model.pack_input(p1, net.emb, device)
        enc = net.encode(input_seq)
        _, tokens = net.decode_chain_argmax(
            enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
            stop_at_token=end_token)
        bleu_sum += utils.calc_bleu(tokens, p2[1:])
        bleu_count += 1
    return bleu_sum / bleu_count


if __name__ == "__main__":
    fmt = "%(asctime)-15s %(levelname)s %(message)s"
    logging.basicConfig(format=fmt, level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data", required=True,
        help="Category to use for training. Empty "
             "string to train on full dataset")
    parser.add_argument(
        "--cuda", action='store_true', default=False,
        help="Enable cuda")
    parser.add_argument(
        "-n", "--name", required=True, help="Name of the run")
    args = parser.parse_args()
    device = torch.device("cuda" if args.cuda else "cpu")

    saves_path = os.path.join(SAVES_DIR, args.name)
    os.makedirs(saves_path, exist_ok=True)

    phrase_pairs, emb_dict = data.load_data(
        genre_filter=args.data)
    log.info("Obtained %d phrase pairs with %d uniq words",
             len(phrase_pairs), len(emb_dict))
    data.save_emb_dict(saves_path, emb_dict)
    end_token = emb_dict[data.END_TOKEN]
    train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
    rand = np.random.RandomState(data.SHUFFLE_SEED)
    rand.shuffle(train_data)
    log.info("Training data converted, got %d samples",
             len(train_data))
    train_data, test_data = data.split_train_test(train_data)
    log.info("Train set has %d phrases, test %d",
             len(train_data), len(test_data))

    net = model.PhraseModel(
        emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
        hid_size=model.HIDDEN_STATE_SIZE).to(device)
    log.info("Model: %s", net)

    writer = SummaryWriter(comment="-" + args.name)

    optimiser = optim.Adam(net.parameters(), lr=LEARNING_RATE)
    best_bleu = None
    for epoch in range(MAX_EPOCHES):
        losses = []
        bleu_sum = 0.0
        bleu_count = 0
        for batch in data.iterate_batches(train_data, BATCH_SIZE):
            optimiser.zero_grad()
            input_seq, out_seq_list, _, out_idx = \
                model.pack_batch(batch, net.emb, device)
            enc = net.encode(input_seq)

            net_results = []
            net_targets = []
            for idx, out_seq in enumerate(out_seq_list):
                ref_indices = out_idx[idx][1:]
                enc_item = net.get_encoded_item(enc, idx)
                if random.random() < TEACHER_PROB:
                    r = net.decode_teacher(enc_item, out_seq)
                    bleu_sum += model.seq_bleu(r, ref_indices)
                else:
                    r, seq = net.decode_chain_argmax(
                        enc_item, out_seq.data[0:1],
                        len(ref_indices))
                    bleu_sum += utils.calc_bleu(seq, ref_indices)
                net_results.append(r)
                net_targets.extend(ref_indices)
                bleu_count += 1
            results_v = torch.cat(net_results)
            targets_v = torch.LongTensor(net_targets).to(device)
            loss_v = F.cross_entropy(results_v, targets_v)
            loss_v.backward()
            optimiser.step()

            losses.append(loss_v.item())
        bleu = bleu_sum / bleu_count
        bleu_test = run_test(test_data, net, end_token, device)
        log.info("Epoch %d: mean loss %.3f, mean BLEU %.3f, "
                 "test BLEU %.3f", epoch, np.mean(losses),
                 bleu, bleu_test)
        writer.add_scalar("loss", np.mean(losses), epoch)
        writer.add_scalar("bleu", bleu, epoch)
        writer.add_scalar("bleu_test", bleu_test, epoch)
        if best_bleu is None or best_bleu < bleu_test:
            if best_bleu is not None:
                out_name = os.path.join(
                    saves_path, "pre_bleu_%.3f_%02d.dat" % (
                        bleu_test, epoch))
                torch.save(net.state_dict(), out_name)
                log.info("Best BLEU updated %.3f", bleu_test)
            best_bleu = bleu_test

        if epoch % 10 == 0:
            out_name = os.path.join(
                saves_path, "epoch_%03d_%.3f_%.3f.dat" % (
                    epoch, bleu, bleu_test))
            torch.save(net.state_dict(), out_name)

    writer.close()


In [None]:
#train_scst
import os
import random
import argparse
import logging
import numpy as np
from tensorboardX import SummaryWriter

from libbots import data, model, utils

import torch
import torch.optim as optim
import torch.nn.functional as F

import ptan

SAVES_DIR = "saves"

BATCH_SIZE = 16
LEARNING_RATE = 5e-4
MAX_EPOCHES = 10000

log = logging.getLogger("train")


def run_test(test_data, net, end_token, device="cpu"):
    bleu_sum = 0.0
    bleu_count = 0
    for p1, p2 in test_data:
        input_seq = model.pack_input(p1, net.emb, device)
        enc = net.encode(input_seq)
        _, tokens = net.decode_chain_argmax(
            enc, input_seq.data[0:1], seq_len=data.MAX_TOKENS,
            stop_at_token=end_token)
        ref_indices = [
            indices[1:]
            for indices in p2
        ]
        bleu_sum += utils.calc_bleu_many(tokens, ref_indices)
        bleu_count += 1
    return bleu_sum / bleu_count


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    fmt = "%(asctime)-15s %(levelname)s %(message)s"
    logging.basicConfig(format=fmt, level=logging.INFO)
    parser.add_argument(
        "--data", required=True,
        help="Category to use for training. Empty "
             "string to train on full dataset")
    parser.add_argument(
        "--cuda", action='store_true', default=False,
        help="Enable cuda")
    parser.add_argument(
        "-n", "--name", required=True,
        help="Name of the run")
    parser.add_argument(
        "-l", "--load", required=True,
        help="Load model and continue in RL mode")
    parser.add_argument(
        "--samples", type=int, default=4,
        help="Count of samples in prob mode")
    parser.add_argument(
        "--disable-skip", default=False, action='store_true',
        help="Disable skipping of samples with high argmax BLEU")
    args = parser.parse_args()
    device = torch.device("cuda" if args.cuda else "cpu")

    saves_path = os.path.join(SAVES_DIR, args.name)
    os.makedirs(saves_path, exist_ok=True)

    phrase_pairs, emb_dict = \
        data.load_data(genre_filter=args.data)
    log.info("Obtained %d phrase pairs with %d uniq words",
             len(phrase_pairs), len(emb_dict))
    data.save_emb_dict(saves_path, emb_dict)
    end_token = emb_dict[data.END_TOKEN]
    train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
    rand = np.random.RandomState(data.SHUFFLE_SEED)
    rand.shuffle(train_data)
    train_data, test_data = data.split_train_test(train_data)
    log.info("Training data converted, got %d samples",
             len(train_data))
    train_data = data.group_train_data(train_data)
    test_data = data.group_train_data(test_data)
    log.info("Train set has %d phrases, test %d",
             len(train_data), len(test_data))

    rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

    net = model.PhraseModel(
        emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
        hid_size=model.HIDDEN_STATE_SIZE).to(device)
    log.info("Model: %s", net)

    writer = SummaryWriter(comment="-" + args.name)
    net.load_state_dict(torch.load(args.load))
    log.info("Model loaded from %s, continue "
             "training in RL mode...", args.load)

    # BEGIN token
    beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]])
    beg_token = beg_token.to(device)

    with ptan.common.utils.TBMeanTracker(
            writer, 100) as tb_tracker:
        optimiser = optim.Adam(
            net.parameters(), lr=LEARNING_RATE, eps=1e-3)
        batch_idx = 0
        best_bleu = None
        for epoch in range(MAX_EPOCHES):
            random.shuffle(train_data)
            dial_shown = False

            total_samples = 0
            skipped_samples = 0
            bleus_argmax = []
            bleus_sample = []

            for batch in data.iterate_batches(
                    train_data, BATCH_SIZE):
                batch_idx += 1
                optimiser.zero_grad()
                input_seq, input_batch, output_batch = \
                    model.pack_batch_no_out(batch, net.emb, device)
                enc = net.encode(input_seq)

                net_policies = []
                net_actions = []
                net_advantages = []
                beg_embedding = net.emb(beg_token)

                for idx, inp_idx in enumerate(input_batch):
                    total_samples += 1
                    ref_indices = [
                        indices[1:]
                        for indices in output_batch[idx]
                    ]
                    item_enc = net.get_encoded_item(enc, idx)
                    r_argmax, actions = net.decode_chain_argmax(
                        item_enc, beg_embedding, data.MAX_TOKENS,
                        stop_at_token=end_token)
                    argmax_bleu = utils.calc_bleu_many(
                        actions, ref_indices)
                    bleus_argmax.append(argmax_bleu)

                    if not args.disable_skip:
                        if argmax_bleu > 0.99:
                            skipped_samples += 1
                            continue

                    if not dial_shown:
                        w = data.decode_words(
                            inp_idx, rev_emb_dict)
                        log.info("Input: %s", utils.untokenize(w))
                        ref_words = [
                            utils.untokenize(
                                data.decode_words(
                                    ref, rev_emb_dict))
                            for ref in ref_indices
                        ]
                        ref = " ~~|~~ ".join(ref_words)
                        log.info("Refer: %s", ref)
                        w = data.decode_words(
                            actions, rev_emb_dict)
                        log.info("Argmax: %s, bleu=%.4f",
                                 utils.untokenize(w), argmax_bleu)

                    for _ in range(args.samples):
                        r_sample, actions = \
                            net.decode_chain_sampling(
                                item_enc, beg_embedding,
                                data.MAX_TOKENS,
                                stop_at_token=end_token)
                        sample_bleu = utils.calc_bleu_many(
                            actions, ref_indices)

                        if not dial_shown:
                            w = data.decode_words(
                                actions, rev_emb_dict)
                            log.info("Sample: %s, bleu=%.4f",
                                     utils.untokenize(w),
                                     sample_bleu)

                        net_policies.append(r_sample)
                        net_actions.extend(actions)
                        adv = sample_bleu - argmax_bleu
                        net_advantages.extend(
                            [adv]*len(actions))
                        bleus_sample.append(sample_bleu)
                    dial_shown = True

                if not net_policies:
                    continue

                policies_v = torch.cat(net_policies)
                actions_t = torch.LongTensor(
                    net_actions).to(device)
                adv_v = torch.FloatTensor(
                    net_advantages).to(device)
                log_prob_v = F.log_softmax(policies_v, dim=1)
                lp_a = log_prob_v[range(len(net_actions)),
                                  actions_t]
                log_prob_actions_v = adv_v * lp_a
                loss_policy_v = -log_prob_actions_v.mean()

                loss_v = loss_policy_v
                loss_v.backward()
                optimiser.step()

                tb_tracker.track("advantage", adv_v, batch_idx)
                tb_tracker.track("loss_policy", loss_policy_v,
                                 batch_idx)
                tb_tracker.track("loss_total", loss_v, batch_idx)

            bleu_test = run_test(test_data, net,
                                 end_token, device)
            bleu = np.mean(bleus_argmax)
            writer.add_scalar("bleu_test", bleu_test, batch_idx)
            writer.add_scalar("bleu_argmax", bleu, batch_idx)
            writer.add_scalar("bleu_sample",
                              np.mean(bleus_sample), batch_idx)
            writer.add_scalar("skipped_samples",
                              skipped_samples / total_samples,
                              batch_idx)
            writer.add_scalar("epoch", batch_idx, epoch)
            log.info("Epoch %d, test BLEU: %.3f",
                     epoch, bleu_test)
            if best_bleu is None or best_bleu < bleu_test:
                best_bleu = bleu_test
                log.info("Best bleu updated: %.4f", bleu_test)
                torch.save(net.state_dict(), os.path.join(
                    saves_path, "bleu_%.3f_%02d.dat" % (
                        bleu_test, epoch)))
            if epoch % 10 == 0:
                torch.save(net.state_dict(), os.path.join(
                    saves_path, "epoch_%03d_%.3f_%.3f.dat" % (
                        epoch, bleu, bleu_test)))

    writer.close()



In [None]:
from unittest import TestCase

import libbots.data
from libbots import data, subtitles


class TestData(TestCase):
    emb_dict = {
        data.BEGIN_TOKEN: 0,
        data.END_TOKEN: 1,
        data.UNKNOWN_TOKEN: 2,
        'a': 3,
        'b': 4
    }

    def test_encode_words(self):
        res = data.encode_words(['a', 'b', 'c'], self.emb_dict)
        self.assertEqual(res, [0, 3, 4, 2, 1])

    # def test_dialogues_to_train(self):
    #     dialogues = [
    #         [
    #             libbots.data.Phrase(words=['a', 'b'], time_start=0, time_stop=1),
    #             libbots.data.Phrase(words=['b', 'a'], time_start=2, time_stop=3),
    #             libbots.data.Phrase(words=['b', 'a'], time_start=2, time_stop=3),
    #         ],
    #         [
    #             libbots.data.Phrase(words=['a', 'b'], time_start=0, time_stop=1),
    #         ]
    #     ]
    #
    #     res = data.dialogues_to_train(dialogues, self.emb_dict)
    #     self.assertEqual(res, [
    #         ([0, 3, 4, 1], [0, 4, 3, 1]),
    #         ([0, 4, 3, 1], [0, 4, 3, 1]),
    #     ])


In [None]:
#test_subtitles

import datetime
from unittest import TestCase

import libbots.data
from libbots import subtitles


class TestPhrases(TestCase):
    def test_split_phrase(self):
        phrase = libbots.data.Phrase(words=["a", "b", "c"], time_start=datetime.timedelta(seconds=0),
                                     time_stop=datetime.timedelta(seconds=10))
        res = subtitles.split_phrase(phrase)
        self.assertIsInstance(res, list)
        self.assertEqual(len(res), 1)
        self.assertEqual(res[0], phrase)

        phrase = libbots.data.Phrase(words=["a", "b", "-", "c"], time_start=datetime.timedelta(seconds=0),
                                     time_stop=datetime.timedelta(seconds=10))
        res = subtitles.split_phrase(phrase)
        self.assertEqual(len(res), 2)
        self.assertEqual(res[0].words, ["a", "b"])
        self.assertEqual(res[1].words, ["c"])
        self.assertAlmostEqual(res[0].time_start.total_seconds(), 0)
        self.assertAlmostEqual(res[0].time_stop.total_seconds(), 5)
        self.assertAlmostEqual(res[1].time_start.total_seconds(), 5)
        self.assertAlmostEqual(res[1].time_stop.total_seconds(), 10)

        phrase = libbots.data.Phrase(words=['-', 'Wait', 'a', 'sec', '.', '-'], time_start=datetime.timedelta(0, 588, 204000),
                                     time_stop=datetime.timedelta(0, 590, 729000))
        res = subtitles.split_phrase(phrase)
        self.assertEqual(res[0].words, ["Wait", "a", "sec", "."])


class TestUtils(TestCase):
    def test_parse_time(self):
        self.assertEqual(subtitles.parse_time("00:00:33,074"),
                         datetime.timedelta(seconds=33, milliseconds=74))

    def test_remove_braced_words(self):
        self.assertEqual(subtitles.remove_braced_words(['a', 'b', 'c']),
                         ['a', 'b', 'c'])
        self.assertEqual(subtitles.remove_braced_words(['a', '[', 'b', ']', 'c']),
                         ['a', 'c'])
        self.assertEqual(subtitles.remove_braced_words(['a', '[', 'b', 'c']),
                         ['a'])
        self.assertEqual(subtitles.remove_braced_words(['a', ']', 'b', 'c']),
                         ['a', 'b', 'c'])
        self.assertEqual(subtitles.remove_braced_words(['a', '(', 'b', ']', 'c']),
                         ['a', 'c'])
        self.assertEqual(subtitles.remove_braced_words(['a', '(', 'b', 'c']),
                         ['a'])
        self.assertEqual(subtitles.remove_braced_words(['a', ')', 'b', 'c']),
                         ['a', 'b', 'c'])


In [None]:
#data_test

import libbots.data
from libbots import data, subtitles


class TestData(TestCase):
    emb_dict = {
        data.BEGIN_TOKEN: 0,
        data.END_TOKEN: 1,
        data.UNKNOWN_TOKEN: 2,
        'a': 3,
        'b': 4
    }

    def test_encode_words(self):
        res = data.encode_words(['a', 'b', 'c'], self.emb_dict)
        self.assertEqual(res, [0, 3, 4, 2, 1])

    # def test_dialogues_to_train(self):
    #     dialogues = [
    #         [
    #             libbots.data.Phrase(words=['a', 'b'], time_start=0, time_stop=1),
    #             libbots.data.Phrase(words=['b', 'a'], time_start=2, time_stop=3),
    #             libbots.data.Phrase(words=['b', 'a'], time_start=2, time_stop=3),
    #         ],
    #         [
    #             libbots.data.Phrase(words=['a', 'b'], time_start=0, time_stop=1),
    #         ]
    #     ]
    #
    #     res = data.dialogues_to_train(dialogues, self.emb_dict)
    #     self.assertEqual(res, [
    #         ([0, 3, 4, 1], [0, 4, 3, 1]),
    #         ([0, 4, 3, 1], [0, 4, 3, 1]),
    #     ])

