# SS 2021 SEMINAR 11: Reinforcement Learning in Language Technology
## Reinforcement Learning Based Speech Agents I 

### Announcements

#### Grading

* Grading: percentage split: 50/50 -> 50/16/16/16

#### Homework
        
* DEADLINE: Today, June 23rd

#### Today

* Initial Example of RL Based Speech Agents

* General First Step when RL is applied to an NLP problem -> REINFORCE/Policy Gradient Methods

* Putting RL on top of Supervised Learning Based Speech Agents: Self-Critical Sequence Training

***

### A) Paper Presentation I: Nazia

### Title: Generating Visual Explanations

##### Link: 
[Generating Visual Explanations](https://www.semanticscholar.org/paper/Generating-Visual-Explanations-Hendricks-Akata/ecf551d532d0e9cfb252a1bea04d14db620bc488)

##### Summary:
The main goal is to produce textual explanations of a visual system's classification decision (e.g., image classification) for a non-expert user, using current language and vision models (here, an image captioning model). These textual explanations provide a rationale for the prediction made by an AI system and play a crucial role in developing user trust. The paper proposes a loss function based on sampling and reinforcement learning to generate class-relevant descriptions (explanations).

##### Task:
To generate a textual explanation along with the classification decision of a fine-grained visual system (bird image classification).
For example, the predicted class label (“This is a warbler...”), then an explanation conjunction (e.g., “because”), followed by the explanatory sentence fragment produced by the model (“this small bird has brown wings and yellow pointy beak”).

![alt_text](https://d3i71xaburhd42.cloudfront.net/7e6eb3091f84c10e65921e708f67f69c299ebd10/2-Figure1-1.png "")

##### Data:
+ Image data: CUB-200-2011, consisting of images of 200 North American bird species
+ Language data: Five sentence-level descriptions for each image

##### Visual Explanation Model:
* **Fine-grained image classifier**
    + Compact bilinear fine-grained classification model pre-trained on the CUB dataset

* **Language generation model / Visual description model (image captioning)**
    + Encoder: CNN extracts high-level visual features
    + Decoder: Two stacked LSTMs conditioned on visual features as well as category label

* **Sentence classifier**
    + Bird category classification using image descriptions

![alt text](https://d3i71xaburhd42.cloudfront.net/7e6eb3091f84c10e65921e708f67f69c299ebd10/5-Figure3-1.png "")

* **Loss functions**
* Relevance loss
    + Trains the model to predict each word in the ground-truth sentence
    + Produces sentences that correspond more closely to the image content

* Discriminative loss based on the **Reinforcement learning paradigm**
    + Based on the REINFORCE algorithm
    + Samples category-relevant descriptions using Monte Carlo sampling
        + __State space__: {image features, category label features}
        + __Action space__: the set of most relevant sentences for the 200 bird categories
        + __Policy__: sample the sentence most relevant to the predicted bird category (using the model's softmax output) so as to maximize the reward
        + __Reward__: high if the sentence classifier predicts the same class as the image classifier
        + [Helpful post](https://www.analyticsvidhya.com/blog/2020/11/reinforce-algorithm-taking-baby-steps-in-reinforcement-learning/) 
    
* The overall loss is a linear combination of the relevance and discriminative loss.
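
To make the combination concrete, here is a minimal, framework-free sketch of how the two terms could be added up for a single training example. It is not the authors' code: per-step word distributions are plain Python lists, the sentence classifier's reward is assumed to be given as a number, and the weight `lam` is a hypothetical name for the mixing coefficient.

```python
import math


def relevance_loss(step_probs, target_ids):
    """Cross-entropy on the ground-truth sentence: -sum_t log p(w_t | w_<t, image, class)."""
    return -sum(math.log(probs[t]) for probs, t in zip(step_probs, target_ids))


def discriminative_loss(sampled_log_probs, reward):
    """REINFORCE surrogate for one sampled sentence.

    sampled_log_probs: log-probabilities of the words the model sampled.
    reward: assumed to come from the (pre-trained) sentence classifier.
    Minimizing -reward * sum(log p) raises the probability of high-reward samples.
    """
    return -reward * sum(sampled_log_probs)


def explanation_loss(step_probs, target_ids, sampled_log_probs, reward, lam=1.0):
    """Linear combination of both terms; `lam` is a hypothetical weighting factor."""
    return relevance_loss(step_probs, target_ids) + lam * discriminative_loss(sampled_log_probs, reward)


# Toy example: a 2-word vocabulary and a 2-step sentence (all numbers made up).
loss = explanation_loss(
    step_probs=[[0.5, 0.5], [0.25, 0.75]],  # model distribution at each step
    target_ids=[0, 1],                      # ground-truth word ids
    sampled_log_probs=[-0.7, -0.3],         # log p of the words of a sampled sentence
    reward=0.9,                             # e.g. the sentence classifier's confidence
    lam=0.5,
)
print(round(loss, 4))
```

In an autodiff framework the reward would be treated as a constant, so only the log-probabilities of the sampled words receive gradients.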

##### Results:
* **Metrics**
    + Automatic metrics (CIDEr and METEOR)
    + Human Evaluation
    + Higher METEOR and CIDEr scores should indicate that the produced sentences are more image- and class-relevant.

![alt_text](https://d3i71xaburhd42.cloudfront.net/7e6eb3091f84c10e65921e708f67f69c299ebd10/10-Table1-1.png "")

* **Models (comparison of explanation, baselines and ablations)**
    + Definition: Using only category label as input
    + Description: Using only image features as input
    + Explanation Label: not trained with discriminative loss
    + Explanation Dis: not conditioned on category label
    + Explanation: using both image features and category label as input

* **Findings**
    + The explanation model has higher METEOR and CIDEr scores than the baselines
    + The explanation model has a higher class similarity score
    + The explanation model has a lower class rank, which suggests that the generated sentences closely resemble the correct class
    + In the human evaluation, the explanation model also achieves a lower (better) rank; all models conditioned on the label rank lower


##### Critical Discussion: 
* ++ Logical structure
* ++ Natural language explanation module trained jointly with classification model. More intuitive to non-expert users of the system
* -- The class similarity metric is not entirely a proper measure: higher CIDEr scores can also mean that the sentences are better overall, even when they are not class-relevant
* -- A few model parameters are not reported, which could be important for reproducibility

***

### Discussion



***

## B) Theory

 **Putting RL on top of Supervised ML Problems in NLP**


Okay, what do we have so far from an NLP perspective:

### Sequence2Sequence Models

The idea was to model language as a sequence of tokens and to build language models, where a language model is a model that, given the preceding tokens, predicts the next token.
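
The "predict the next token" interface can be illustrated without any neural network. The sketch below is a count-based bigram model, a deliberately simple stand-in for the RNN and Transformer models listed next; the `<s>`/`</s>` markers are assumptions that play the role of the `#BEG`/`#END` tokens used in the preprocessing code later on.

```python
from collections import Counter, defaultdict


def train_bigram_lm(sentences):
    """Count, for every token, how often each other token follows it."""
    counts = defaultdict(Counter)
    for tokens in sentences:
        seq = ["<s>"] + tokens + ["</s>"]
        for prev, nxt in zip(seq, seq[1:]):
            counts[prev][nxt] += 1
    return counts


def next_token_distribution(counts, prev):
    """P(next | prev) estimated from relative frequencies."""
    total = sum(counts[prev].values())
    return {tok: c / total for tok, c in counts[prev].items()}


def predict_next(counts, prev):
    """Most probable next token (greedy prediction)."""
    return counts[prev].most_common(1)[0][0]


corpus = [["i", "am", "fine"], ["i", "am", "tired"], ["i", "was", "fine"]]
lm = train_bigram_lm(corpus)
print(next_token_distribution(lm, "i"))
print(predict_next(lm, "i"))
```

A neural language model replaces the count table with a learned function of the whole prefix, but it exposes the same interface: a distribution over the next token.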

* SOTA: RNN based 

<img src="https://miro.medium.com/max/3972/1*1JcHGUU7rFgtXC_mydUA_Q.jpeg" width="500"/>

* SOTA: Transformer based

<img src="https://jalammar.github.io/images/t/transformer_resideual_layer_norm_3.png" width="500"/>

***

### Word Embeddings

To obtain better token representations that capture semantic similarity, we use word embeddings.

<img src="https://ruder.io/content/images/size/w2000/2016/04/word_embeddings_colah.png" width="500"/>
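
A small sketch of what "semantically representing" means in practice: similar words get vectors that point in similar directions, which is usually measured with cosine similarity. The 3-dimensional vectors below are invented purely for illustration; the `PhraseModel` in the practice part learns 50-dimensional embeddings (`EMBEDDING_DIM = 50`) through `nn.Embedding`.

```python
import math

# Toy embeddings, made up for illustration only (not trained vectors).
EMB = {
    "king":  [0.8, 0.6, 0.1],
    "queen": [0.7, 0.7, 0.1],
    "apple": [0.1, 0.2, 0.9],
}


def cosine(u, v):
    """Cosine similarity: 1.0 for identical directions, ~0 for unrelated ones."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def most_similar(word):
    """Nearest neighbour of `word` among the other vocabulary entries."""
    return max((w for w in EMB if w != word), key=lambda w: cosine(EMB[word], EMB[w]))


print(round(cosine(EMB["king"], EMB["queen"]), 3))
print(round(cosine(EMB["king"], EMB["apple"]), 3))
print(most_similar("king"))
```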

***

### Connection of RL and NLP

How can we connect these ideas of supervised learning based NLP approaches with Reinforcement Learning? 

* The connection of RL and NLP happens in the **Seq2Seq-Training Approach**

* RECAP: How do we train **supervised?** 
* => **LOG LIKELIHOOD TRAINING = given this input (and this past of length n steps), what is the most likely output token**

<img src="https://miro.medium.com/max/1838/1*ooyV-A6O7so_EhvkGZMZ3g.png" width="500"/>

* WHY is that NOT really satisfying? 
* => If we take a closer look, the supervised approach is treated as a **STEP WISE CLASSIFICATION PROBLEM**, that means for every decoding step of a sequence (e.g. a natural language sentence) we compute the log likelihood over ALL OUTPUT TOKENS (=vocabulary) and then classify the token with the highest log probability. 

<img src="https://miro.medium.com/max/5412/1*NKhwsOYNUT5xU7Pyf6Znhg.png" width="500"/>
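
The step-wise classification view can be written down directly: at each decoding step the decoder produces one score (logit) per vocabulary entry, and log-likelihood training is a cross-entropy classification against the gold token of that step. The following is a plain-Python sketch with made-up logits; in PyTorch, `F.cross_entropy` on the stacked per-step logits computes the same mean negative log-probability.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over the vocabulary."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def sequence_nll(step_logits, gold_ids):
    """Teacher-forced log-likelihood loss: one classification per decoding step,
    averaged over the steps (mean negative log-probability of the gold token)."""
    losses = [-log_softmax(logits)[gold] for logits, gold in zip(step_logits, gold_ids)]
    return sum(losses) / len(losses)


# Two decoding steps over a 3-token vocabulary (illustrative numbers).
print(sequence_nll([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]], [0, 2]))
```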

* OPTIMIZATION: We can train this supervised multi step classification problem based on **SGD** in combination with cross-entropy loss/log-loss.  

* HOW DO WE TRAIN?: 
* => it turns out that we have different possibilities to train our models, especially our DECODERS.
* The first, naive approach is to train the decoder on its own output, which is called ARGMAX training, or SELF-FORCING. This leads to LONG training times.
* To improve on this, we use **TEACHER FORCING**: we compute the token-wise loss against a gold sentence and **FEED THE GOLD TOKEN** (not the model's own prediction) as input into the decoder to produce the next token. This shortens training, but the resulting models are less robust at inference time, and they tend to give more or less average answers that do not really adapt to the input.
*  Solution: => somewhere in the middle: **CURRICULUM learning** = start with teacher forcing and, the more the model has learned, the more often use its own generated outputs as inputs for the next step

<img src="https://miro.medium.com/max/1078/1*lmjxIAhR3WmPKY_6bSgM1w.png" width="500"/>
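
A minimal sketch of the curriculum idea: at every decoding step we flip a coin whose bias decays over training and feed either the gold token (teacher forcing) or the model's own previous prediction. When this decision is made per token, the technique is also known as scheduled sampling. The linear schedule, `decay_epochs` and the function names are assumptions for illustration, not taken from the notebook.

```python
import random


def teacher_forcing_prob(epoch, decay_epochs=10):
    """Hypothetical linear schedule: 1.0 at epoch 0, 0.0 from `decay_epochs` on."""
    return max(0.0, 1.0 - epoch / decay_epochs)


def next_decoder_input(gold_token, predicted_token, epoch, rng=random):
    """Feed the gold token with probability teacher_forcing_prob(epoch),
    otherwise the model's own previous prediction (self-forcing)."""
    if rng.random() < teacher_forcing_prob(epoch):
        return gold_token
    return predicted_token


for epoch in (0, 5, 10):
    print(epoch, teacher_forcing_prob(epoch))
```

Early in training the decoder almost always sees gold tokens; later it increasingly has to cope with its own mistakes, which is exactly the situation it faces at inference time.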

***

Let's have a closer look at the decoding:

* DECODING:

* PROBLEM 1: During decoding, there is NO GOLD TOKEN anymore, so we have to use the token from the last step as input to the next step. When we do this, our model (most of the time) produces very strange utterances, because HUMANS do not produce sentences based on **what could be the most probable next token**; they rather seem to look ahead over the next N tokens (3-8) and distribute the main point they want to convey over these tokens.

* PROBLEM 2: Decoding also has a single point of failure: if we decode one token wrong, it can lead the whole output sequence in the wrong direction. During (teacher-forced) training, we never asked our decoder RNN to use its own output as input, so a single mistake during generation may confuse the decoder and lead to garbage output.

* SOLUTION: In practice, we apply BEAM SEARCH as a trick to find the most suitable sentence, one that (hopefully) sounds natural to humans. Instead of committing to the single most probable token at each step, beam search keeps the N most probable partial sequences (the beam width, most of the time between 5 and 10) at every step, extends each of them, and finally picks the complete sequence with the highest overall probability.

<img src="https://www.researchgate.net/profile/Basura-Fernando/publication/322582680/figure/fig2/AS:659877194113026@1534338391439/Example-of-constrained-beam-search-decoding-Each-output-sequence-must-include-the-words.png" width="500"/>
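
A compact sketch of beam search over an abstract next-token model. The interface `step_fn(prefix) -> {token: probability}` and the toy model are assumptions for illustration (the notebook's decoder works on embeddings and hidden states instead). The toy model is constructed so that the greedy choice is a trap: the most probable first token leads to a less probable sentence overall.

```python
import math


def beam_search(step_fn, start_token, end_token, beam_width=3, max_len=10):
    """Keep the `beam_width` most probable partial sequences at every step.

    step_fn(prefix) -> {token: probability of that token following prefix}
    Returns (tokens, log_probability) of the best finished sequence.
    """
    beams = [([start_token], 0.0)]  # (tokens so far, cumulative log-probability)
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            for tok, p in step_fn(tokens).items():
                if p > 0:
                    candidates.append((tokens + [tok], score + math.log(p)))
        # Python's sort is stable, so ties keep their insertion order.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for tokens, score in candidates[:beam_width]:
            if tokens[-1] == end_token:
                finished.append((tokens, score))
            else:
                beams.append((tokens, score))
        if not beams:
            break
    finished.extend(beams)  # hypotheses cut off by max_len still compete
    return max(finished, key=lambda c: c[1])


def toy_step(prefix):
    """Hypothetical next-token model where the greedy first choice is a trap."""
    if prefix == ["<s>"]:
        return {"a": 0.6, "b": 0.4}
    if prefix == ["<s>", "a"]:
        return {"x": 0.5, "y": 0.5}
    if prefix == ["<s>", "b"]:
        return {"z": 0.9, "w": 0.1}
    return {"</s>": 1.0}


greedy, _ = beam_search(toy_step, "<s>", "</s>", beam_width=1)  # width 1 = greedy decoding
best, _ = beam_search(toy_step, "<s>", "</s>", beam_width=2)
print(greedy)  # path through "a": 0.6 * 0.5 = 0.30
print(best)    # path through "b": 0.4 * 0.9 = 0.36
```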

* PROBLEM 3: but wait, how do we evaluate the goodness of a sentence at all? 
* => Pure LOSS is not a very good objective to optimize against, because it leads to memorizing the training data rather than generalizing well. Alternatives: BLEU, ROUGE (and CIDEr and SPICE for image descriptions). In a nutshell: n-gram overlap metrics, weighted n-gram overlap metrics (unigram, bigram, ...; CIDEr uses TF-IDF weighting), and scene-graph metrics (SPICE). These metrics compute the overlap of our output with a SET of reference outputs that we know are good/solve the problem.

<img src="https://slidetodoc.com/presentation_image_h/09d128ca66b95f499e3c32f2446b960e/image-43.jpg" width="500"/>
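
The core ingredient of overlap metrics like BLEU is the clipped (modified) n-gram precision against a set of references: each candidate n-gram counts at most as often as it appears in any single reference. The sketch below shows only that part; full BLEU additionally combines several n-gram orders with a geometric mean and a brevity penalty (the notebook's `calc_bleu` uses NLTK's `sentence_bleu` with `weights=(0.5, 0.5)`, i.e. unigrams and bigrams).

```python
from collections import Counter


def ngrams(tokens, n):
    """Multiset of the n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def clipped_precision(candidate, references, n=1):
    """Share of candidate n-grams that also occur in a reference,
    with each n-gram count clipped to its maximum count in any single reference."""
    cand_counts = ngrams(candidate, n)
    if not cand_counts:
        return 0.0
    max_ref = Counter()
    for ref in references:
        for gram, count in ngrams(ref, n).items():
            max_ref[gram] = max(max_ref[gram], count)
    overlap = sum(min(count, max_ref[gram]) for gram, count in cand_counts.items())
    return overlap / sum(cand_counts.values())


refs = [["the", "cat", "is", "on", "the", "mat"],
        ["there", "is", "a", "cat", "on", "the", "mat"]]
print(clipped_precision(["the"] * 7, refs))  # clipping punishes degenerate repetition
```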

If you have a novel idea for scoring: YOU'RE WELCOME :)

Seems to be a lot of dirty magic going on, ha ? :)

***

In fact, **curriculum learning** and **beam search** might not be the only ways to decode a sequence. 

This is where RL joins the game. 

* QUESTION: How to form the DECODING into an RL problem?

* FIRST STEP: The first thing to note is that our decoder outputs the **probability distribution at every step**, which is very similar to policy gradient models. From this perspective, our decoder could be seen as an agent trying to decide which token to produce at every step.

* ADVANTAGES (of framing decoding like this): 

* 1. GOAL DIRECTED: the goal is to convey certain information, and a solution is a sequence of tokens (a sentence), so there are multiple possible sentences that reach the goal (e.g., there are many possible replies to 'how are you?'). By optimizing the log-likelihood objective (supervised learning), our model will try to learn some **average of all those replies**, which will mostly end up as **'I'm fine, thanks'!** and is not necessarily a meaningful reply. By returning a probability distribution and sampling the next token from it, our agent could potentially learn to produce all the possible variants instead of an averaged answer.

* 2. SCORE DRIVEN: The second benefit is optimizing the objective that we actually care about. In log-likelihood training, we minimize the cross-entropy between the produced tokens and the reference tokens, but in machine translation and many other NLP problems we don't really care about log-likelihood: we want to maximize the BLEU score of the produced sequence. Unfortunately, the BLEU score is not differentiable, so we can't backpropagate through it. However, policy gradient methods such as REINFORCE (see "Policy Gradients: an Alternative") work even when the reward is not differentiable: we just push up the probabilities of successful episodes and push down the worse ones.

* 3. SAMPLING: The third advantage we can exploit is that our sequence generation process is defined by us and we know its internals. By introducing stochasticity into the decoding process, we can repeat the decoding several times and gather different decoding scenarios from a single training sample. This is beneficial when our training dataset is limited, which is almost always the case. It can also be used to BOOTSTRAP our models by generating our own training data.
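
A small sketch of advantages 1 and 3: a decoder that outputs a distribution can be sampled repeatedly, giving several different episodes for one input, whereas argmax decoding always returns the same "average" reply. `reply_policy` is a hypothetical stand-in for a trained decoder (prefix in, next-token distribution out); `#END` mirrors the `END_TOKEN` used in the notebook.

```python
import random


def reply_policy(prefix):
    """Hypothetical decoder: distribution over the next token given the tokens so far."""
    if not prefix:
        return {"fine": 0.5, "good": 0.3, "tired": 0.2}
    return {"#END": 1.0}


def sample_episode(policy, max_len=5, rng=random):
    """Stochastic decoding: draw each token from the policy's distribution."""
    tokens = []
    for _ in range(max_len):
        dist = policy(tokens)
        tok = rng.choices(list(dist), weights=list(dist.values()))[0]
        tokens.append(tok)
        if tok == "#END":
            break
    return tokens


def argmax_episode(policy, max_len=5):
    """Deterministic decoding: always take the most probable token."""
    tokens = []
    for _ in range(max_len):
        dist = policy(tokens)
        tok = max(dist, key=dist.get)
        tokens.append(tok)
        if tok == "#END":
            break
    return tokens


rng = random.Random(0)
sampled = {tuple(sample_episode(reply_policy, rng=rng)) for _ in range(200)}
print(argmax_episode(reply_policy))  # always the same single reply
print(sorted(sampled))               # several different episodes for the same input
```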

***
***

EXCURSION: POLICY GRADIENT METHODS IN A NUTSHELL

We need an algorithm that allows us to estimate the goodness of a generated sequence (e.g., a sentence) and then iteratively adapt to get better at generating new sequences.

There is an algorithm that solves this sequence generation problem: REINFORCE.

REINFORCE, as opposed to e.g. Q-Learning, is a **Policy Gradient Method**. The idea behind value-based methods like Q-learning was to evaluate every ACTION in every STATE according to its **contribution to the overall return** and then, in every state, choose the ACTION with the highest Q-value.

We can therefore say that in Q-Learning we do not build up a policy directly; we merely build up a Q-Table/Q-Function and **USE THE Q-TABLE/FUNCTION** in a **GREEDY WAY** (by always picking the action with the highest Q-value) to **EXTRACT THE (OPTIMAL) POLICY**.

<img src="https://cdn.analyticsvidhya.com/wp-content/uploads/2019/04/Screenshot-2019-04-16-at-5.46.01-PM.png" width="500"/>
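
Extracting the policy from a Q-table is just a per-state argmax, as the following sketch shows. The Q-values are hypothetical; nothing here is learned, the point is only that the policy is derived from the table rather than represented on its own.

```python
def greedy_policy(q_table):
    """Derive a deterministic policy from a Q-table: pick argmax_a Q(s, a) in every state."""
    return {state: max(q_values, key=q_values.get) for state, q_values in q_table.items()}


# Hypothetical Q-table: state -> {action: Q-value}
Q = {
    "s0": {"left": 0.1, "right": 0.7},
    "s1": {"left": 0.4, "right": -0.2},
}
print(greedy_policy(Q))
```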

In policy gradient methods we follow a different approach: 

* We directly build up a policy, which is a FUNCTION that takes in a state and returns (a probability distribution over) ACTION(S)

<img src="https://mohitd.github.io/images/deep-rl-policy-methods/policy-network.svg" width="500"/>

* WE update the parameters of our current policy with the gradient of the **expected return of an EPISODE** with respect to the parameters of the policy π we are following

* We obtain this gradient by maximizing the reward of our trajectories, which is like moving in the direction of the highest reward: instead of moving in the direction of the negative gradient, as in gradient descent and most supervised learning tasks, we now use gradient ascent and move in the direction of the positive gradient.

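* FUNDAMENTAL: The basis of policy gradient methods is the **policy gradient theorem**, which gives an expression for this gradient that can be estimated from sampled trajectories: `∇θ J(θ) = E_π[∇θ log πθ(a|s) · Qπ(s,a)]`. Following this gradient with suitable step sizes converges to a (local) optimum of the expected return, not necessarily to the globally optimal policy.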

* REMEMBER: a policy induces a trajectory distribution (= from every state it induces a sequence of actions towards the goal state, which can be deterministic or stochastic); therefore, a policy also induces a reward distribution (over the different trajectories).

This looks like this: 

<img src="https://miro.medium.com/max/3464/1*zkOBQ9Izq28yXCANTmdKtA.png" width="500"/>

Explained in a few words the idea is: 
1. Come up with a trajectory as your baseline (e.g., the ARGMAX baseline).
2. Compute the reward of that trajectory.
3. Sample another trajectory.
4. Compute the reward of that trajectory.
5. Compute the gradients, scaling the log-probability of every step by the reward relative to the baseline (averaging over the batch, in this case of size 2).
6. Apply gradient ascent and repeat until convergence.

This can be done not only over two trajectories, but also over a batch of N trajectories and then it can be averaged, analogously to what we do in supervised learning. 

<img src="https://static.horiba.com/fileadmin/Horiba/_processed_/a/2/csm_Validated_Performance_Logo_MM_Picture_fr_751ed2e239.png" width="500"/>

In contrast to supervised learning, where we compute the difference between our output and the gold standard over n examples, average it and move in the direction of the negative gradient to reach the minimum distance, in policy gradient methods we compute the return over n trajectories, average it and move in the direction of the gradient to reach the maximum reward.

Over the course of training, this ensures that, for every time step, the actions that were GOOD (= produced good returns) become more likely under the policy we are building up.
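
The six steps can be sketched on the simplest possible case, a one-step episode (a bandit) with a softmax policy over three actions. This is an illustrative sketch, not the notebook's code: the rewards are hypothetical, the greedy (argmax) action's reward serves as the baseline as in step 1, and the gradient of log π for a softmax policy is computed by hand instead of by autodiff.

```python
import math
import random


def softmax(prefs):
    """Turn action preferences into a probability distribution (numerically stable)."""
    m = max(prefs)
    exps = [math.exp(p - m) for p in prefs]
    z = sum(exps)
    return [e / z for e in exps]


def reinforce_step(theta, reward_fn, lr=0.1, batch_size=8, rng=random):
    """One REINFORCE update of a single-step softmax policy with an argmax baseline."""
    n = len(theta)
    probs = softmax(theta)
    # Steps 1-2: the greedy (argmax) action and its reward serve as the baseline.
    greedy = max(range(n), key=lambda a: theta[a])
    baseline = reward_fn(greedy)
    grad = [0.0] * n
    for _ in range(batch_size):
        # Steps 3-4: sample an action from the current policy and score it.
        a = rng.choices(range(n), weights=probs)[0]
        advantage = reward_fn(a) - baseline
        # Step 5: for a softmax policy, d log pi(a) / d theta_k = 1[k == a] - pi(k).
        for k in range(n):
            grad[k] += advantage * ((1.0 if k == a else 0.0) - probs[k])
    # Step 6: gradient ASCENT on the expected reward, averaged over the batch.
    return [t + lr * g / batch_size for t, g in zip(theta, grad)]


REWARDS = [0.0, 1.0, 0.2]  # hypothetical reward of each action
rng = random.Random(0)
theta = [0.0, 0.0, 0.0]
for _ in range(500):
    theta = reinforce_step(theta, lambda a: REWARDS[a], rng=rng)
print([round(p, 3) for p in softmax(theta)])  # probability mass moves to action 1
```

Subtracting the baseline does not change the expected gradient, but it reduces its variance: once the best action is the greedy one, only samples that are worse than greedy produce a (negative) signal.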

Applied to DEEP REINFORCEMENT LEARNING policy gradients in action look like this: 

1. we start with a randomly initialized neural network which we call our policy network. 
2. we feed our policy network a STATE REPRESENTATION of the current TIME STEP T and it produces a DISTRIBUTION over ACTIONS as output, 
3. Due to our exploration-exploitation trade-off we sample from this distribution to get the action for that state and timestep, 
4. we collect a REWARD for that state-action sequence at timestep T,
5. we get the next state, next reward, next state, next reward -> we collect the SEQUENCE of (s,a,r,s') as a trajectory. 
6. We then take the whole SEQUENCE of steps as an EXPERIENCE and apply the policy gradient theorem to our batch of experiences, using the accumulated reward of the whole episode as our SCORING function. This also makes it suitable for sparse reward problems.
7. We repeat this until convergence. Over time, the actions that lead to negative rewards are slowly filtered out and, for every state, the actions that lead to higher returns at the end of an episode become more likely, because the sampled gradient points, in expectation, in the direction of a higher return.

<img src="https://abhishm.github.io/assets/images/2017-05-26-policy-gradient-with-RNN/mlp_policy.png" width="500"/>
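
Step 6 boils down to weighting each step's log-probability by a return. The sketch below computes discounted returns-to-go and the corresponding surrogate loss whose gradient is the REINFORCE estimate. Plain floats stand in for the log π(a_t|s_t) values a policy network would output; `gamma` is an assumption of this sketch. With `gamma = 1` and a reward only at the end of the episode, every step receives the whole-episode return, as described in step 6.

```python
def discounted_returns(rewards, gamma=0.99):
    """Return-to-go G_t = r_t + gamma * G_{t+1}, computed backwards over one episode."""
    returns = []
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    return returns[::-1]


def reinforce_loss(log_probs, rewards, gamma=0.99):
    """Surrogate loss -sum_t G_t * log pi(a_t | s_t); its gradient is the REINFORCE estimate."""
    return -sum(lp * g for lp, g in zip(log_probs, discounted_returns(rewards, gamma)))


# Sparse-reward episode: only the final step is rewarded.
print(discounted_returns([0.0, 0.0, 1.0], gamma=1.0))  # every step gets the episode return
print(discounted_returns([0.0, 0.0, 1.0], gamma=0.5))  # earlier steps get less credit
```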

IMPORTANT: The difference from DEEP Q-LEARNING is that in a Deep Q-Network we learn THE Q-VALUES of all ACTIONS for every state, whereas in a POLICY GRADIENT NETWORK we learn a PROBABILITY DISTRIBUTION over the ACTIONS for every state.

That means we are optimizing towards different things, and our updates are based on different signals: in deep Q-learning we update towards a bootstrapped target (the reward plus the discounted value of the best action in the next state, as an approximation of the RETURN), whereas in a policy gradient network we increase the probability of the SEQUENCE of actions in proportion to the reward it led to.
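
For contrast, here is the Q-learning side of that comparison as a sketch: the regression target for Q(s, a) is built from the immediate reward and the current estimate of the next state, not from the full observed return that a policy gradient method uses. The Q-values below are hypothetical numbers.

```python
def q_learning_target(reward, next_q_values, gamma=0.99, done=False):
    """Bootstrapped 1-step target of (deep) Q-learning: r + gamma * max_a' Q(s', a').

    At the end of an episode there is no next state, so the target is just r.
    """
    if done:
        return reward
    return reward + gamma * max(next_q_values)


print(q_learning_target(1.0, [0.5, 2.0], gamma=0.5))  # 1.0 + 0.5 * 2.0
print(q_learning_target(1.0, [0.5, 2.0], done=True))
```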

DISADVANTAGES: 

* CREDIT ASSIGNMENT PROBLEM: policy gradient methods like REINFORCE use the accumulated reward at the end of an episode to evaluate the whole trajectory. This is well suited to sparse reward problems, BUT it also assumes that in an episode (e.g., of the game Pong) that the agent LOST, ALL ACTIONS must have been bad actions... which is obviously NOT ALWAYS THE CASE! This is the CREDIT ASSIGNMENT PROBLEM: the problem of assigning rewards to specific actions (= which action, or which combination of actions, caused the reward). In practice, this means that these algorithms need a LOT of episodes as EXAMPLES to learn from, i.e., they have low sample efficiency and long training times.

* REWARD SHAPING: Sparse reward problems lead to REWARD SHAPING = introducing additional rewards at intermediate time steps to overcome the sparsity. This leads to approaches like CURIOSITY or INTRINSIC MOTIVATION, where agents are rewarded e.g. for finding new states they have never seen before. BUT artificially created reward functions can cause a lot of problems when done wrong (e.g., an agent that learns to exploit the intermediate rewards, gets addicted to them and never reaches the final reward). So training in sparse reward settings is very hard, and at the same time reward shaping is tricky but crucial for success.


***

***

OKAY, how to apply all this now to a Seq2Seq generation problem in NLP?

SOLUTION:

REINFORCE for seq2seq in a nutshell:

In practice, REINFORCE for seq2seq training could be written as the following algorithm:

1. For every dialogue input in the dataset, obtain the encoded representation using e.g. an RNN
2. Initialize the current token with the special begin token: T = 'BEG' 
3. Initialize the output sequence with the empty sequence: Out = []
4. While END_TOKEN not reached:
    * Get the probability distribution of the tokens and the new hidden state, passing the current token and the hidden state to the decoder
    * Sample the output token from the probability distribution
    * Remember the probability distribution 
    * Append the output token to the output sequence
    * Set the current token to the last output token
5. Calculate BLEU or another metric between Out and the reference sequences: Q = BLEU(Out, Outref)
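6. Estimate the policy gradient: scale the log-probabilities of the sampled tokens by Q
7. Update the model using SGD (gradient ascent on the expected reward)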
8. repeat until convergence
    
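* This process is plain REINFORCE. When, in step 6, the BLEU score of the greedy (ARGMAX) decoding of the same input is subtracted from Q as a baseline, so that only sampled sequences that score better than the model's own greedy output are reinforced, the process is called **SELF-CRITICAL SEQUENCE TRAINING**.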

<img src="https://miro.medium.com/max/3200/1*twlT4no5KP5Bw3TiDOf1kw.png" width="500"/>

<br>

<img src="https://planspace.org/20171114-deep_rl/img/planning2.png" width="500"/>

<br>
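
A minimal sketch of the self-critical loss for one input. The argmax reply and the sampled reply correspond to what `decode_chain_argmax` and `decode_chain_sampling` in the practice part produce; the token lists and log-probabilities here are made up. To stay dependency-free, a simple unigram-overlap score (`unigram_overlap`, a hypothetical helper) stands in for the notebook's `calc_bleu`. In a real training loop the log-probabilities would be tensors, the loss would be averaged over the batch and then backpropagated.

```python
def unigram_overlap(candidate, reference):
    """Hypothetical stand-in for calc_bleu(): share of candidate tokens found in the reference."""
    if not candidate:
        return 0.0
    ref = set(reference)
    return sum(tok in ref for tok in candidate) / len(candidate)


def scst_loss(sample_log_probs, sample_reward, greedy_reward):
    """Self-critical loss for one input sequence.

    The greedy (argmax) decoding's reward is the baseline, so a sampled reply is
    reinforced only if it scores better than what the model would output anyway.
    """
    advantage = sample_reward - greedy_reward
    return -advantage * sum(sample_log_probs)


reference = ["i", "am", "fine"]
greedy_reply = ["i", "am", "not"]      # argmax decoding
sampled_reply = ["i", "am", "fine"]    # stochastic decoding
sample_log_probs = [-0.2, -0.4, -0.9]  # log p of each sampled token (made up)

loss = scst_loss(sample_log_probs,
                 unigram_overlap(sampled_reply, reference),
                 unigram_overlap(greedy_reply, reference))
print(round(loss, 3))  # positive: minimizing it raises the sampled reply's probability
```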

* DISADVANTAGES: 
* VOCABULARY SIZE: the vast number of possible output sequences turns this into a sparse reward problem, so it's almost useless to train from scratch. Even for simple dialogues, the output sequence usually has at least five words, each taken from a dictionary of several thousand words. The number of different five-word phrases with a dictionary of 1,000 words equals 1000^5 = 10^15. So the probability of obtaining the correct reply at the beginning of training (when the weights of both the encoder and the decoder are random) is negligibly small.

* ADVANTAGES:
* SEQUENCE SELECTION: Better handling of multiple target sequences. For example, 'hi' could be answered with 'hi', 'hello', 'not interested', or something else. The RL point of view treats our decoder as a process of selecting actions, where every action is a token to be generated, which fits the problem better.

* BLEU-DIRECTED: Optimizing the BLEU score directly instead of cross-entropy loss. Using the BLEU score for the generated sequence as a gradient scale, we can push our model toward the successful sequences and decrease the probability of unsuccessful ones.

* SAMPLING: By repeating the decoding process, we can generate more episodes to train on, which will lead to better gradient estimation.

* SOLUTION: 
* => Combine the best of both worlds => Use Curriculum Learning (at the beginning teacher forcing, then more and more self-forcing) and then switch to REINFORCE to fine tune the model

* PHILOSOPHICAL ASPECT: This is also extremely interesting from a philosophical standpoint and might explain why TEACHERS and TEACHING are useful in a group of agents and even across generations of agents. In general, this could be seen as a uniform approach to complex RL problems where a large action space makes it infeasible to start with a randomly behaving agent, as the chance of such an agent randomly reaching the goal is negligible. => TEACHERS are used to BOOTSTRAP new agents quickly to a certain level until they can REINFORCE themselves to find the best solutions. Sounds like the real world, ha ?
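
To put the vocabulary-size argument into numbers: with a 1,000-word dictionary and five-word replies there are 1000^5 = 10^15 possible replies. The sketch assumes, as a simplification, that an untrained decoder picks every token uniformly at random.

```python
import math

VOCAB_SIZE = 1000  # dictionary size used in the argument
REPLY_LEN = 5      # reply length in words

num_replies = VOCAB_SIZE ** REPLY_LEN
# Probability that a uniformly random decoder emits one specific reply.
p_hit = 1 / num_replies

print(f"{num_replies:.1e} possible replies, p = {p_hit:.1e}")
print(f"about 10^{math.log10(num_replies):.0f} replies")
```

Even with thousands of sampled episodes per input, a randomly initialized agent would practically never see a positive reward, which is why supervised pre-training (teacher forcing/curriculum learning) comes first and REINFORCE is used for fine-tuning.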


## C) Practice

How can we implement **SELF-CRITICAL-SEQUENCE-TRAINING** in python?

We already know the supervised-learning based dialogue agent approach from the last seminar. 

Starting from this model, we put our REINFORCE algorithm on top of the supervised learning based approach. 

This helps us in selecting better sounding sequences and works as an alternative to BEAM-SEARCH decoding. 

I'm basing my examples on Maxim Lapan's book: https://www.amazon.de/-/en/Maxim-Lapan/dp/1838826998

In [13]:
# DATA PREPROCESSING
"""
Cornell Movie-Dialogs Corpus
https://www.cs.cornell.edu/~cristian/Cornell_Movie-Dialogs_Corpus.html
"""
import os
import logging

# UTILS
import string
from nltk.translate import bleu_score
from nltk.tokenize import TweetTokenizer

def calc_bleu_many(cand_seq, ref_sequences):
    sf = bleu_score.SmoothingFunction()
    return bleu_score.sentence_bleu(ref_sequences, cand_seq,
                                    smoothing_function=sf.method1,
                                    weights=(0.5, 0.5))


def calc_bleu(cand_seq, ref_seq):
    return calc_bleu_many(cand_seq, [ref_seq])


def tokenize(s):
    return TweetTokenizer(preserve_case=False).tokenize(s)


def untokenize(words):
    to_pad = lambda t: not t.startswith("'") and \
                       t not in string.punctuation
    return "".join([
        (" " + i) if to_pad(i) else i
        for i in words
    ]).strip()


# CORNELL MOVIE DATASET DIALOGUE LOADERS
log = logging.getLogger("cornell")
DATA_DIR = "data/"
SEPARATOR = "+++$+++"

def load_dialogues(data_dir=DATA_DIR, genre_filter=''):
    """
    Load dialogues from cornell data
    :return: list of list of list of words
    """
    movie_set = None
    if genre_filter:
        movie_set = read_movie_set(data_dir, genre_filter)
        log.info("Loaded %d movies with genre %s", len(movie_set), genre_filter)
    log.info("Read and tokenise phrases...")
    lines = read_phrases(data_dir, movies=movie_set)
    log.info("Loaded %d phrases", len(lines))
    dialogues = load_conversations(data_dir, lines, movie_set)
    return dialogues


def iterate_entries(data_dir, file_name):
    with open(os.path.join(data_dir, file_name), "rb") as fd:
        for l in fd:
            l = str(l, encoding='utf-8', errors='ignore')
            yield list(map(str.strip, l.split(SEPARATOR)))


def read_movie_set(data_dir, genre_filter):
    res = set()
    for parts in iterate_entries(data_dir, "movie_titles_metadata.txt"):
        m_id, m_genres = parts[0], parts[5]
        if m_genres.find(genre_filter) != -1:
            res.add(m_id)
    return res


def read_phrases(data_dir, movies=None):
    res = {}
    for parts in iterate_entries(data_dir, "movie_lines.txt"):
        l_id, m_id, l_str = parts[0], parts[2], parts[4]
        if movies and m_id not in movies:
            continue
        tokens = tokenize(l_str)
        if tokens:
            res[l_id] = tokens
    return res


def load_conversations(data_dir, lines, movies=None):
    res = []
    for parts in iterate_entries(data_dir, "movie_conversations.txt"):
        m_id, dial_s = parts[2], parts[3]
        if movies and m_id not in movies:
            continue
        l_ids = dial_s.strip("[]").split(", ")
        l_ids = list(map(lambda s: s.strip("'"), l_ids))
        dial = [lines[l_id] for l_id in l_ids if l_id in lines]
        if dial:
            res.append(dial)
    return res


def read_genres(data_dir):
    res = {}
    for parts in iterate_entries(data_dir, "movie_titles_metadata.txt"):
        m_id, m_genres = parts[0], parts[5]
        l_genres = m_genres.strip("[]").split(", ")
        l_genres = list(map(lambda s: s.strip("'"), l_genres))
        res[m_id] = l_genres
    return res



In [14]:
# PREPROCESSING
import collections
import os
import sys
import logging
import itertools
import pickle

UNKNOWN_TOKEN = '#UNK'
BEGIN_TOKEN = "#BEG"
END_TOKEN = "#END"
MAX_TOKENS = 20
MIN_TOKEN_FEQ = 10
SHUFFLE_SEED = 5871

EMB_DICT_NAME = "emb_dict.dat"
EMB_NAME = "emb.npy"

log = logging.getLogger("data")


def save_emb_dict(dir_name, emb_dict):
    with open(os.path.join(dir_name, EMB_DICT_NAME), "wb") as fd:
        pickle.dump(emb_dict, fd)


def load_emb_dict(dir_name):
    with open(os.path.join(dir_name, EMB_DICT_NAME), "rb") as fd:
        return pickle.load(fd)


def encode_words(words, emb_dict):
    """
    Convert list of words into list of embeddings indices, adding our tokens
    :param words: list of strings
    :param emb_dict: embeddings dictionary
    :return: list of IDs
    """
    res = [emb_dict[BEGIN_TOKEN]]
    unk_idx = emb_dict[UNKNOWN_TOKEN]
    for w in words:
        idx = emb_dict.get(w.lower(), unk_idx)
        res.append(idx)
    res.append(emb_dict[END_TOKEN])
    return res


def encode_phrase_pairs(phrase_pairs, emb_dict, filter_unknows=True):
    """
    Convert list of phrase pairs to training data
    :param phrase_pairs: list of (phrase, phrase)
    :param emb_dict: embeddings dictionary (word -> id)
    :return: list of tuples ([input_id_seq], [output_id_seq])
    """
    unk_token = emb_dict[UNKNOWN_TOKEN]
    result = []
    for p1, p2 in phrase_pairs:
        p = encode_words(p1, emb_dict), encode_words(p2, emb_dict)
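        # Skip pairs containing out-of-vocabulary words (only if filtering is requested)
        if filter_unknows and (unk_token in p[0] or unk_token in p[1]):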
            continue
        result.append(p)
    return result


def group_train_data(training_data):
    """
    Group training pairs by first phrase
    :param training_data: list of (seq1, seq2) pairs
    :return: list of (seq1, [seq*]) pairs
    """
    groups = collections.defaultdict(list)
    for p1, p2 in training_data:
        l = groups[tuple(p1)]
        l.append(p2)
    return list(groups.items())


def iterate_batches(data, batch_size):
    assert isinstance(data, list)
    assert isinstance(batch_size, int)

    ofs = 0
    while True:
        batch = data[ofs*batch_size:(ofs+1)*batch_size]
        if len(batch) <= 1:
            break
        yield batch
        ofs += 1


def load_data(genre_filter, max_tokens=MAX_TOKENS, min_token_freq=MIN_TOKEN_FEQ):
    dialogues = load_dialogues(genre_filter=genre_filter)
    if not dialogues:
        log.error("No dialogues found, exit!")
        sys.exit()
    log.info("Loaded %d dialogues with %d phrases, generating training pairs",
             len(dialogues), sum(map(len, dialogues)))
    phrase_pairs = dialogues_to_pairs(dialogues, max_tokens=max_tokens)
    log.info("Counting freq of words...")
    word_counts = collections.Counter()
    for dial in dialogues:
        for p in dial:
            word_counts.update(p)
    freq_set = set(map(lambda p: p[0], filter(lambda p: p[1] >= min_token_freq, word_counts.items())))
    log.info("Data has %d unique words, %d of them occur at least %d times",
             len(word_counts), len(freq_set), min_token_freq)
    phrase_dict = phrase_pairs_dict(phrase_pairs, freq_set)
    return phrase_pairs, phrase_dict


def phrase_pairs_dict(phrase_pairs, freq_set):
    """
    Return the dict of words in the dialogues mapped to their IDs
    :param phrase_pairs: list of (phrase, phrase) pairs
    :return: dict
    """
    res = {UNKNOWN_TOKEN: 0, BEGIN_TOKEN: 1, END_TOKEN: 2}
    next_id = 3
    for p1, p2 in phrase_pairs:
        for w in map(str.lower, itertools.chain(p1, p2)):
            if w not in res and w in freq_set:
                res[w] = next_id
                next_id += 1
    return res


def dialogues_to_pairs(dialogues, max_tokens=None):
    """
    Convert dialogues to training pairs of phrases
    :param dialogues:
    :param max_tokens: limit of tokens in both question and reply
    :return: list of (phrase, phrase) pairs
    """
    result = []
    for dial in dialogues:
        prev_phrase = None
        for phrase in dial:
            if prev_phrase is not None:
                if max_tokens is None or (len(prev_phrase) <= max_tokens and len(phrase) <= max_tokens):
                    result.append((prev_phrase, phrase))
            prev_phrase = phrase
    return result


def decode_words(indices, rev_emb_dict):
    return [rev_emb_dict.get(idx, UNKNOWN_TOKEN) for idx in indices]


def trim_tokens_seq(tokens, end_token):
    res = []
    for t in tokens:
        res.append(t)
        if t == end_token:
            break
    return res


def split_train_test(data, train_ratio=0.95):
    count = int(len(data) * train_ratio)
    return data[:count], data[count:]

In [15]:
# MODEL
import numpy as np

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn_utils
import torch.nn.functional as F

HIDDEN_STATE_SIZE = 512
EMBEDDING_DIM = 50


class PhraseModel(nn.Module):
    def __init__(self, emb_size, dict_size, hid_size):
        super(PhraseModel, self).__init__()

        self.emb = nn.Embedding(
            num_embeddings=dict_size, embedding_dim=emb_size)
        self.encoder = nn.LSTM(
            input_size=emb_size, hidden_size=hid_size,
            num_layers=1, batch_first=True)
        self.decoder = nn.LSTM(
            input_size=emb_size, hidden_size=hid_size,
            num_layers=1, batch_first=True)
        self.output = nn.Linear(hid_size, dict_size)

    def encode(self, x):
        _, hid = self.encoder(x)
        return hid

    def get_encoded_item(self, encoded, index):
        # For RNN
        # return encoded[:, index:index+1]
        # For LSTM
        return encoded[0][:, index:index+1].contiguous(), \
               encoded[1][:, index:index+1].contiguous()

    def decode_teacher(self, hid, input_seq):
        # Method assumes batch of size=1
        out, _ = self.decoder(input_seq, hid)
        out = self.output(out.data)
        return out

    def decode_one(self, hid, input_x):
        out, new_hid = self.decoder(input_x.unsqueeze(0), hid)
        out = self.output(out)
        return out.squeeze(dim=0), new_hid

    def decode_chain_argmax(self, hid, begin_emb, seq_len,
                            stop_at_token=None):
        """
        Decode sequence by feeding predicted token to the net again. Act greedily
        """
        res_logits = []
        res_tokens = []
        cur_emb = begin_emb

        for _ in range(seq_len):
            out_logits, hid = self.decode_one(hid, cur_emb)
            out_token_v = torch.max(out_logits, dim=1)[1]
            out_token = out_token_v.data.cpu().numpy()[0]

            cur_emb = self.emb(out_token_v)

            res_logits.append(out_logits)
            res_tokens.append(out_token)
            if stop_at_token is not None:
                if out_token == stop_at_token:
                    break
        return torch.cat(res_logits), res_tokens

    def decode_chain_sampling(self, hid, begin_emb, seq_len,
                              stop_at_token=None):
        """
        Decode sequence by feeding predicted token to the net again.
        Act according to probabilities
        """
        res_logits = []
        res_actions = []
        cur_emb = begin_emb

        for _ in range(seq_len):
            out_logits, hid = self.decode_one(hid, cur_emb)
            out_probs_v = F.softmax(out_logits, dim=1)
            out_probs = out_probs_v.data.cpu().numpy()[0]
            action = int(np.random.choice(
                out_probs.shape[0], p=out_probs))
            action_v = torch.LongTensor([action])
            action_v = action_v.to(begin_emb.device)
            cur_emb = self.emb(action_v)

            res_logits.append(out_logits)
            res_actions.append(action)
            if stop_at_token is not None:
                if action == stop_at_token:
                    break
        return torch.cat(res_logits), res_actions


def pack_batch_no_out(batch, embeddings, device="cpu"):
    assert isinstance(batch, list)
    # Sort by input length, descending: pack_padded_sequence expects this by default (enforce_sorted=True)
    batch.sort(key=lambda s: len(s[0]), reverse=True)
    input_idx, output_idx = zip(*batch)
    # create padded matrix of inputs
    lens = list(map(len, input_idx))
    input_mat = np.zeros((len(batch), lens[0]), dtype=np.int64)
    for idx, x in enumerate(input_idx):
        input_mat[idx, :len(x)] = x
    input_v = torch.tensor(input_mat).to(device)
    input_seq = rnn_utils.pack_padded_sequence(
        input_v, lens, batch_first=True)
    # lookup embeddings
    r = embeddings(input_seq.data)
    emb_input_seq = rnn_utils.PackedSequence(
        r, input_seq.batch_sizes)
    return emb_input_seq, input_idx, output_idx


def pack_input(input_data, embeddings, device="cpu"):
    input_v = torch.LongTensor([input_data]).to(device)
    r = embeddings(input_v)
    return rnn_utils.pack_padded_sequence(
        r, [len(input_data)], batch_first=True)


def pack_batch(batch, embeddings, device="cpu"):
    emb_input_seq, input_idx, output_idx = pack_batch_no_out(
        batch, embeddings, device)

    # prepare output sequences, with end token stripped
    output_seq_list = []
    for out in output_idx:
        s = pack_input(out[:-1], embeddings, device)
        output_seq_list.append(s)
    return emb_input_seq, output_seq_list, input_idx, output_idx


def seq_bleu(model_out, ref_seq):
    model_seq = torch.max(model_out.data, dim=1)[1]
    model_seq = model_seq.cpu().numpy()
    return calc_bleu(model_seq, ref_seq)
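`decode_chain_argmax` and `decode_chain_sampling` differ only in how the next token is picked from each step's logits. The dependency-free sketch below isolates that single step in pure Python instead of PyTorch. The toy logits are made up for illustration.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pick_argmax(logits):
    """Greedy choice, like torch.max(out_logits, dim=1)[1] for one row."""
    return max(range(len(logits)), key=lambda i: logits[i])


def pick_sample(logits, rng):
    """Stochastic choice, like np.random.choice(..., p=softmax(logits))."""
    probs = softmax(logits)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.1]  # toy scores for a 3-word vocabulary
rng = random.Random(0)

print(pick_argmax(logits))  # 0, on every call
counts = [0, 0, 0]
for _ in range(10_000):
    counts[pick_sample(logits, rng)] += 1
# Empirical frequencies; these should be close to softmax(logits).
print([c / 10_000 for c in counts])
print([round(p, 3) for p in softmax(logits)])
```

Greedy decoding is deterministic for a given input, which makes it usable as a fixed baseline in the SCST cell. Sampling also visits lower-probability replies, which is what gives the policy gradient something to compare against that baseline.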

In [19]:
# TRAINING I: Supervised Curriculum Learning (Teacher Forcing + Argmax/Self-Forcing)
import os
import random
import argparse
import logging
import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F

SAVES_DIR = "data"

BATCH_SIZE = 32
LEARNING_RATE = 1e-3
MAX_EPOCHES = 100

log = logging.getLogger("train")

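TEACHER_PROB = 0.5  # per-sample probability of teacher forcing; otherwise the decoder is fed its own argmax predictions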

def run_test(test_data, net, end_token, device="cpu"):
    bleu_sum = 0.0
    bleu_count = 0
    for p1, p2 in test_data:
        input_seq = pack_input(p1, net.emb, device)
        enc = net.encode(input_seq)
        _, tokens = net.decode_chain_argmax(
            enc, input_seq.data[0:1], seq_len=MAX_TOKENS,
            stop_at_token=end_token)
        bleu_sum += calc_bleu(tokens, p2[1:])
        bleu_count += 1
    return bleu_sum / bleu_count


# run training
fmt = "%(asctime)-15s %(levelname)s %(message)s"
logging.basicConfig(format=fmt, level=logging.INFO)
data_genre = "comedy"
cuda_available = True
name_of_the_run = 'comedy_conversations'
device = torch.device("cuda" if cuda_available else "cpu")
saves_path = os.path.join(SAVES_DIR, name_of_the_run)
os.makedirs(saves_path, exist_ok=True)

phrase_pairs, emb_dict = load_data(genre_filter=data_genre)
log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))

save_emb_dict(saves_path, emb_dict)
end_token = emb_dict[END_TOKEN]

train_data = encode_phrase_pairs(phrase_pairs, emb_dict)
rand = np.random.RandomState(SHUFFLE_SEED)
rand.shuffle(train_data)
log.info("Training data converted, got %d samples", len(train_data))
train_data, test_data = split_train_test(train_data)
log.info("Train set has %d phrases, test %d", len(train_data), len(test_data))

net = PhraseModel(emb_size=EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=HIDDEN_STATE_SIZE).to(device)
log.info("Model: %s", net)

optimiser = optim.Adam(net.parameters(), lr=LEARNING_RATE)
best_bleu = None

for epoch in range(MAX_EPOCHES):
    losses = []
    bleu_sum = 0.0
    bleu_count = 0
    for batch in iterate_batches(train_data, BATCH_SIZE):
        optimiser.zero_grad()
        input_seq, out_seq_list, _, out_idx = pack_batch(batch, net.emb, device)
        enc = net.encode(input_seq)

        net_results = []
        net_targets = []
        for idx, out_seq in enumerate(out_seq_list):
            ref_indices = out_idx[idx][1:]
            enc_item = net.get_encoded_item(enc, idx)
            if random.random() < TEACHER_PROB:
                r = net.decode_teacher(enc_item, out_seq)
                bleu_sum += seq_bleu(r, ref_indices)
            else:
                r, seq = net.decode_chain_argmax(enc_item, out_seq.data[0:1], len(ref_indices))
                bleu_sum += calc_bleu(seq, ref_indices)
            net_results.append(r)
            net_targets.extend(ref_indices)
            bleu_count += 1
        results_v = torch.cat(net_results)
        targets_v = torch.LongTensor(net_targets).to(device)
        loss_v = F.cross_entropy(results_v, targets_v)
        loss_v.backward()
        optimiser.step()

        losses.append(loss_v.item())
    bleu = bleu_sum / bleu_count
    bleu_test = run_test(test_data, net, end_token, device)
    log.info("Epoch %d: mean loss %.3f, mean BLEU %.3f, test BLEU %.3f", epoch, np.mean(losses), bleu, bleu_test)
    if best_bleu is None or best_bleu < bleu_test:
        if best_bleu is not None:
            out_name = os.path.join(
                saves_path, "pre_bleu_%.3f_%02d.dat" % (
                    bleu_test, epoch))
            torch.save(net.state_dict(), out_name)
            log.info("Best BLEU updated %.3f", bleu_test)
        best_bleu = bleu_test

    if epoch % 10 == 0:
        out_name = os.path.join(
            saves_path, "epoch_%03d_%.3f_%.3f.dat" % (
                epoch, bleu, bleu_test))
        torch.save(net.state_dict(), out_name)


2021-06-23 09:15:36,536 INFO Loaded 159 movies with genre comedy
2021-06-23 09:15:36,539 INFO Read and tokenise phrases...
2021-06-23 09:15:44,560 INFO Loaded 93039 phrases
2021-06-23 09:15:44,628 INFO Loaded 4445 dialogues with 17107 phrases, generating training pairs
2021-06-23 09:15:44,635 INFO Counting freq of words...
2021-06-23 09:15:44,683 INFO Data has 11807 uniq words, 1442 of them occur more than 10
2021-06-23 09:15:44,723 INFO Obtained 8885 phrase pairs with 1441 uniq words
2021-06-23 09:15:44,783 INFO Training data converted, got 3116 samples
2021-06-23 09:15:44,783 INFO Train set has 2960 phrases, test 156
2021-06-23 09:15:44,831 INFO Model: PhraseModel(
  (emb): Embedding(1441, 50)
  (encoder): LSTM(50, 512, batch_first=True)
  (decoder): LSTM(50, 512, batch_first=True)
  (output): Linear(in_features=512, out_features=1441, bias=True)
)
2021-06-23 09:15:57,827 INFO Epoch 0: mean loss 5.041, mean BLEU 0.165, test BLEU 0.116
2021-06-23 09:16:11,227 INFO Epoch 1: mean loss 4
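In the training loop above, each sample is decoded either with teacher forcing (`decode_teacher`, fed the reference reply) or free-running (`decode_chain_argmax`, fed its own previous prediction), and the mode is chosen per sample with probability `TEACHER_PROB`. The sketch below shows only *which tokens the decoder receives* at each step in the two modes. `toy_predict` is a hypothetical stand-in for the network, and the token ids are made up.

```python
import random

BEGIN = 1  # assumed id of the BEGIN token, as reserved in phrase_pairs_dict


def toy_predict(prev_token):
    """Hypothetical 'network': always predicts prev_token + 1."""
    return prev_token + 1


def decoder_inputs(reference, teacher_forcing):
    """Return the token fed to the decoder at each step.

    reference: target tokens without BEGIN (like out_idx[idx][1:]).
    """
    inputs, prev = [], BEGIN
    for ref_token in reference:
        inputs.append(prev)
        pred = toy_predict(prev)
        # Teacher forcing feeds the ground truth; free-running feeds the prediction.
        prev = ref_token if teacher_forcing else pred
    return inputs


reference = [5, 9, 4, 2]
print(decoder_inputs(reference, teacher_forcing=True))   # [1, 5, 9, 4]
print(decoder_inputs(reference, teacher_forcing=False))  # [1, 2, 3, 4]

# Per-sample choice of mode, as in the training loop.
TEACHER_PROB = 0.5
rng = random.Random(42)
modes = [rng.random() < TEACHER_PROB for _ in range(8)]
print(modes)
```

In both modes the decoder runs exactly `len(ref_indices)` steps, which is what allows `F.cross_entropy` to compare the concatenated logits with `net_targets` position by position.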

In [20]:
# TRAINING II: Self-Critical Sequence Training (SCST)
import os
import random
import argparse
import logging
import numpy as np

import torch
import torch.optim as optim
import torch.nn.functional as F

SAVES_DIR = "data"

BATCH_SIZE = 16
LEARNING_RATE = 5e-4
MAX_EPOCHES = 10000

log = logging.getLogger("train")


def run_test(test_data, net, end_token, device="cpu"):
    bleu_sum = 0.0
    bleu_count = 0
    for p1, p2 in test_data:
        input_seq = pack_input(p1, net.emb, device)
        enc = net.encode(input_seq)
        _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=MAX_TOKENS, stop_at_token=end_token)
        # p2 holds all reference replies for this input; drop the BEGIN token of each
        ref_indices = [indices[1:] for indices in p2]
        bleu_sum += calc_bleu_many(tokens, ref_indices)
        bleu_count += 1
    return bleu_sum / bleu_count


fmt = "%(asctime)-15s %(levelname)s %(message)s"
logging.basicConfig(format=fmt, level=logging.INFO)
data_genre = "comedy"
cuda_available = True 
name_of_the_run = 'comedy_conversations'
model_path = "data/comedy_conversations/epoch_090_0.908_0.113.dat"
number_of_samples_in_prob_mode = 4
disable_skip = False  # If True, keep samples whose argmax BLEU is already > 0.99 instead of skipping them
device = torch.device("cuda" if cuda_available else "cpu")
saves_path = os.path.join(SAVES_DIR, name_of_the_run)
os.makedirs(saves_path, exist_ok=True)

phrase_pairs, emb_dict = load_data(genre_filter=data_genre)
log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
save_emb_dict(saves_path, emb_dict)
end_token = emb_dict[END_TOKEN]
train_data = encode_phrase_pairs(phrase_pairs, emb_dict)
rand = np.random.RandomState(SHUFFLE_SEED)
rand.shuffle(train_data)
train_data, test_data = split_train_test(train_data)
log.info("Training data converted, got %d samples", len(train_data))
train_data = group_train_data(train_data)
test_data = group_train_data(test_data)
log.info("Train set has %d phrases, test %d", len(train_data), len(test_data))

rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

net = PhraseModel(emb_size=EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=HIDDEN_STATE_SIZE).to(device)
log.info("Model: %s", net)

net.load_state_dict(torch.load(model_path))
log.info("Model loaded from %s, continue training in RL mode...", model_path)

# BEGIN token
beg_token = torch.LongTensor([emb_dict[BEGIN_TOKEN]])
beg_token = beg_token.to(device)

optimiser = optim.Adam(net.parameters(), lr=LEARNING_RATE, eps=1e-3)
batch_idx = 0
best_bleu = None
for epoch in range(MAX_EPOCHES):
    random.shuffle(train_data)
    dial_shown = False
    total_samples = 0
    skipped_samples = 0
    bleus_argmax = []
    bleus_sample = []

    for batch in iterate_batches(train_data, BATCH_SIZE):
        batch_idx += 1
        optimiser.zero_grad()
        input_seq, input_batch, output_batch = pack_batch_no_out(batch, net.emb, device)
        enc = net.encode(input_seq)

        net_policies = []
        net_actions = []
        net_advantages = []
        beg_embedding = net.emb(beg_token)

        for idx, inp_idx in enumerate(input_batch):
            total_samples += 1
            ref_indices = [indices[1:] for indices in output_batch[idx]]
            item_enc = net.get_encoded_item(enc, idx)
            # Greedy decode: its BLEU serves as the self-critical baseline
            r_argmax, actions = net.decode_chain_argmax(item_enc, beg_embedding, MAX_TOKENS, stop_at_token=end_token)
            argmax_bleu = calc_bleu_many(actions, ref_indices)
            bleus_argmax.append(argmax_bleu)

            if not disable_skip:
                if argmax_bleu > 0.99:
                    skipped_samples += 1
                    continue

            if not dial_shown:
                w = decode_words(inp_idx, rev_emb_dict)
                log.info("Input: %s", untokenize(w))
                ref_words = [untokenize(decode_words(ref, rev_emb_dict)) for ref in ref_indices]
                ref = " ~~|~~ ".join(ref_words)
                log.info("Refer: %s", ref)
                w = decode_words(actions, rev_emb_dict)
                log.info("Argmax: %s, bleu=%.4f", untokenize(w), argmax_bleu)

            for _ in range(number_of_samples_in_prob_mode):
                r_sample, actions = net.decode_chain_sampling(item_enc, beg_embedding, MAX_TOKENS, stop_at_token=end_token)
                sample_bleu = calc_bleu_many(actions, ref_indices)

                if not dial_shown:
                    w = decode_words(actions, rev_emb_dict)
                    log.info("Sample: %s, bleu=%.4f", untokenize(w), sample_bleu)

                net_policies.append(r_sample)
                net_actions.extend(actions)
                # Advantage: reward of the sampled reply minus the greedy baseline
                adv = sample_bleu - argmax_bleu
                net_advantages.extend([adv]*len(actions))
                bleus_sample.append(sample_bleu)
            dial_shown = True

        if not net_policies:
            continue

        policies_v = torch.cat(net_policies)
        actions_t = torch.LongTensor(net_actions).to(device)
        adv_v = torch.FloatTensor(net_advantages).to(device)
        log_prob_v = F.log_softmax(policies_v, dim=1)
        lp_a = log_prob_v[range(len(net_actions)), actions_t]
        log_prob_actions_v = adv_v * lp_a
        loss_policy_v = -log_prob_actions_v.mean()

        loss_v = loss_policy_v
        loss_v.backward()
        optimiser.step()

    bleu_test = run_test(test_data, net, end_token, device)
    bleu = np.mean(bleus_argmax)
    log.info("Epoch %d, test BLEU: %.3f", epoch, bleu_test)
    if best_bleu is None or best_bleu < bleu_test:
        best_bleu = bleu_test
        log.info("Best bleu updated: %.4f", bleu_test)
        torch.save(net.state_dict(), os.path.join(saves_path, "bleu_%.3f_%02d.dat" % (bleu_test, epoch)))
    if epoch % 10 == 0:
        torch.save(net.state_dict(), os.path.join(saves_path, "epoch_%03d_%.3f_%.3f.dat" % (epoch, bleu, bleu_test)))

2021-06-23 09:43:28,761 INFO Loaded 159 movies with genre comedy
2021-06-23 09:43:28,761 INFO Read and tokenise phrases...
2021-06-23 09:43:36,947 INFO Loaded 93039 phrases
2021-06-23 09:43:37,026 INFO Loaded 4445 dialogues with 17107 phrases, generating training pairs
2021-06-23 09:43:37,032 INFO Counting freq of words...
2021-06-23 09:43:37,083 INFO Data has 11807 uniq words, 1442 of them occur more than 10
2021-06-23 09:43:37,124 INFO Obtained 8885 phrase pairs with 1441 uniq words
2021-06-23 09:43:37,183 INFO Training data converted, got 2960 samples
2021-06-23 09:43:37,187 INFO Train set has 2597 phrases, test 153
2021-06-23 09:43:37,232 INFO Model: PhraseModel(
  (emb): Embedding(1441, 50)
  (encoder): LSTM(50, 512, batch_first=True)
  (decoder): LSTM(50, 512, batch_first=True)
  (output): Linear(in_features=512, out_features=1441, bias=True)
)
2021-06-23 09:43:37,236 INFO Model loaded from data/comedy_conversations/epoch_090_0.908_0.113.dat, continue training in RL mode...
2021-

KeyboardInterrupt: 
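The core of the SCST cell is the advantage `adv = sample_bleu - argmax_bleu`. The greedy decode of the same input serves as the baseline, so a sampled reply is reinforced only if it scores higher than what the model would produce greedily. The loss is the REINFORCE objective $-\frac{1}{N}\sum_t \left(r(w^s) - r(\hat w)\right)\log p_\theta(w^s_t)$. Here $r$ is BLEU, $w^s$ the sampled reply, $\hat w$ the greedy reply, and $N$ the number of sampled tokens in the batch. The pure-Python sketch below recomputes that loss for one sampled sequence with made-up logits and rewards. It also checks the gradient direction analytically, using the softmax identity $\partial \log p(a)/\partial z_j = \mathbb{1}[j=a] - p_j$.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def scst_loss(step_logits, actions, sample_reward, greedy_reward):
    """Self-critical policy-gradient loss for one sampled sequence.

    step_logits: one list of vocabulary logits per decoding step
    actions: sampled token id at each step
    Mirrors the cell above: adv * log_prob(action), then the negative mean.
    """
    adv = sample_reward - greedy_reward  # greedy decode is the baseline
    terms = [adv * log_softmax(logits)[a]
             for logits, a in zip(step_logits, actions)]
    return -sum(terms) / len(terms)


def grad_wrt_logits(logits, action, adv, n_tokens):
    """Analytic gradient of one step's loss term w.r.t. that step's logits."""
    probs = [math.exp(lp) for lp in log_softmax(logits)]
    return [-(adv / n_tokens) * ((1.0 if j == action else 0.0) - p)
            for j, p in enumerate(probs)]


step_logits = [[1.0, 0.5, -0.5], [0.2, 0.1, 2.0]]  # toy logits, 2 steps
actions = [1, 2]                                   # toy sampled tokens

# The sample beat the greedy reply: positive advantage.
print(scst_loss(step_logits, actions, sample_reward=0.4, greedy_reward=0.1))
g = grad_wrt_logits(step_logits[0], actions[0], adv=0.3, n_tokens=2)
# A gradient-descent step (logits -= lr * g) raises the sampled token's logit.
print(g[actions[0]] < 0)  # True

# The sample lost to the greedy reply: the same token is pushed down instead.
g = grad_wrt_logits(step_logits[0], actions[0], adv=-0.3, n_tokens=2)
print(g[actions[0]] > 0)  # True
```

When a sample and the greedy reply get the same BLEU, the advantage is zero and that sample contributes no gradient. Likewise, inputs skipped via the `argmax_bleu > 0.99` check contribute nothing.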

In [21]:
# EVALUATION
import argparse
import logging

import torch

log = logging.getLogger("data_test")

logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
data_genre = "comedy"
cuda_available = True 
name_of_the_run = 'comedy_conversations'
model_path = "data/comedy_conversations/epoch_160_0.991_0.108.dat"
device = torch.device("cuda" if cuda_available else "cpu")


phrase_pairs, emb_dict = load_data(data_genre)
log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
train_data = encode_phrase_pairs(phrase_pairs, emb_dict)
train_data = group_train_data(train_data)
rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

net = PhraseModel(emb_size=EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=HIDDEN_STATE_SIZE)
net.load_state_dict(torch.load(model_path, map_location="cpu"))  # the net stays on CPU in this cell

end_token = emb_dict[END_TOKEN]

seq_count = 0
sum_bleu = 0.0

for seq_1, targets in train_data:
    input_seq = pack_input(seq_1, net.emb)
    enc = net.encode(input_seq)
    _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1],
                                        seq_len=MAX_TOKENS, stop_at_token=end_token)
    references = [seq[1:] for seq in targets]
    bleu = calc_bleu_many(tokens, references)
    sum_bleu += bleu
    seq_count += 1

log.info("Processed %d phrases, mean BLEU = %.4f", seq_count, sum_bleu / seq_count)

2021-06-23 10:24:58,048 INFO Loaded 159 movies with genre comedy
2021-06-23 10:24:58,049 INFO Read and tokenise phrases...
2021-06-23 10:25:05,956 INFO Loaded 93039 phrases
2021-06-23 10:25:06,023 INFO Loaded 4445 dialogues with 17107 phrases, generating training pairs
2021-06-23 10:25:06,083 INFO Counting freq of words...
2021-06-23 10:25:06,131 INFO Data has 11807 uniq words, 1442 of them occur more than 10
2021-06-23 10:25:06,170 INFO Obtained 8885 phrase pairs with 1441 uniq words
2021-06-23 10:25:16,668 INFO Processed 2730 phrases, mean BLEU = 0.9478
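The SCST and evaluation cells both call `group_train_data` and `calc_bleu_many`, which are defined outside this excerpt. The sketch below is an *assumption* about what they do, not their actual code. `group_pairs` groups every reply that follows the same input phrase into one list of references, and `bleu_many` scores a candidate with a simplified multi-reference BLEU: clipped unigram and bigram precision, geometric mean, a brevity penalty against the closest reference length, and no smoothing. Library implementations such as NLTK's `sentence_bleu` default to up to 4-grams and offer smoothing, so their absolute values will differ.

```python
import math
from collections import Counter, defaultdict


def group_pairs(pairs):
    """Hypothetical group_train_data: one entry per input with all its replies."""
    groups = defaultdict(list)
    for question, reply in pairs:
        groups[tuple(question)].append(reply)
    return [(list(q), replies) for q, replies in groups.items()]


def ngrams(tokens, n):
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def bleu_many(candidate, references, max_n=2):
    """Simplified multi-reference sentence BLEU (no smoothing)."""
    if not candidate:
        return 0.0
    log_prec = 0.0
    for n in range(1, max_n + 1):
        cand = ngrams(candidate, n)
        total = sum(cand.values())
        if total == 0:
            return 0.0
        # Clip each n-gram count by its maximum count in any single reference.
        max_ref = Counter()
        for ref in references:
            for gram, cnt in ngrams(ref, n).items():
                max_ref[gram] = max(max_ref[gram], cnt)
        clipped = sum(min(cnt, max_ref[gram]) for gram, cnt in cand.items())
        if clipped == 0:
            return 0.0
        log_prec += math.log(clipped / total) / max_n
    # Brevity penalty against the reference length closest to the candidate's.
    c = len(candidate)
    r = min((abs(len(ref) - c), len(ref)) for ref in references)[1]
    bp = 1.0 if c > r else math.exp(1 - r / c)
    return bp * math.exp(log_prec)


pairs = [(["hi"], ["hello", "there"]),
         (["hi"], ["hey", "there"]),
         (["bye"], ["see", "you"])]
grouped = group_pairs(pairs)
print(grouped[0])  # (['hi'], [['hello', 'there'], ['hey', 'there']])

print(bleu_many(["hello", "there"], [["hey", "there"]]))  # 0.0
print(bleu_many(["hello", "there"], grouped[0][1]))       # 1.0
```

Scoring against all replies to the same input is more lenient than scoring against a single one, since any plausible reply in the group counts. This is also why `ref_indices` in the SCST cell is a list of token lists rather than a single sequence.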


In [24]:
# USER TEST
import os
import argparse
import logging

import torch

log = logging.getLogger("use")

def words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=False):
    tokens = encode_words(words, emb_dict)
    input_seq = pack_input(tokens, net.emb)
    enc = net.encode(input_seq)
    end_token = emb_dict[END_TOKEN]
    if use_sampling:
        _, out_tokens = net.decode_chain_sampling(enc, input_seq.data[0:1], seq_len=MAX_TOKENS,
                                                  stop_at_token=end_token)
    else:
        _, out_tokens = net.decode_chain_argmax(enc, input_seq.data[0:1], seq_len=MAX_TOKENS,
                                                stop_at_token=end_token)
    if out_tokens[-1] == end_token:
        out_tokens = out_tokens[:-1]
    out_words = decode_words(out_tokens, rev_emb_dict)
    return out_words


def process_string(s, emb_dict, rev_emb_dict, net, use_sampling=False):
    words = tokenize(s)
    out_words = words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=use_sampling)
    print(" ".join(out_words))


# run live test
logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
data_genre = "comedy"
dictionary_path = "data/comedy_conversations/"
cuda_available = True 
name_of_the_run = 'comedy_conversations'
model_path = "data/comedy_conversations/epoch_160_0.991_0.108.dat"
device = torch.device("cuda" if cuda_available else "cpu")
sampling_enabled = False # Enable sampling generation instead of argmax
current_string_to_process = "" # String to process, otherwise will loop
self_loop_length = 1 # Self-loop mode: number of turns, each reply is fed back as the next input


emb_dict = load_emb_dict(os.path.dirname(dictionary_path))
net = PhraseModel(emb_size=EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=HIDDEN_STATE_SIZE)
net.load_state_dict(torch.load(model_path, map_location="cpu"))  # the net stays on CPU in this cell

rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

while True:
    if current_string_to_process:
        input_string = current_string_to_process
    else:
        input_string = input(">>> ")
    if not input_string:
        break

    words = tokenize(input_string)
    for _ in range(self_loop_length):
        words = words_to_words(words, emb_dict, rev_emb_dict, net, use_sampling=sampling_enabled)
        print(untokenize(words))

    if current_string_to_process:
        break

>>>  hey, how are you doing ?


not exactly, it.


>>>  oh, that sounds funny actually :)


i don't want you to say this on.


>>>  ha ?


you never take him.


>>>  true.


it's none of your business.


>>>  maybe not. i guess.


oh, what did you say?


>>>  maybe not, i said!


i? oh to no!


>>>  yes.


here.


>>>  


# TODOs

1. Send your finished presentations (+ possibly annotated paper) by **Monday 12.00 AM/midnight** via email to henrik.voigt@uni-jena.de

***
