diff --git a/README.md b/README.md
index 349220f..c9c7fde 100644
--- a/README.md
+++ b/README.md
@@ -1,5 +1,321 @@
 ## Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned
-This is a repo for the ACL 2019 paper ["Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned"](https://arxiv.org/abs/1905.09418).
-Code of the model will appear by the time of publication.
+
+
+
+This is the official repo for the ACL 2019 paper ["Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned"](https://arxiv.org/abs/1905.09418).
+
+It may be worth first looking at the [blog post](https://lena-voita.github.io/posts/acl19_heads.html).
+
+#### Bibtex
+```
+@inproceedings{voita-etal-2019-analyzing,
+    title = "Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned",
+    author = "Voita, Elena and
+      Talbot, David and
+      Moiseev, Fedor and
+      Sennrich, Rico and
+      Titov, Ivan",
+    booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers)",
+    month = jul,
+    year = "2019",
+    address = "Florence, Italy",
+    publisher = "Association for Computational Linguistics",
+}
+```
+
+## Introduction
+
+Multi-head self-attention is a key component of the Transformer, a state-of-the-art architecture for neural machine translation. In this work we evaluate the contribution made by individual attention heads in the encoder to the overall performance of the model and analyze the roles played by them. We find that the most important and confident heads play consistent and often linguistically-interpretable roles. When pruning heads using a method based on stochastic gates and a differentiable relaxation of the L0 penalty, we observe that specialized heads are the last to be pruned. Our novel pruning method removes the vast majority of heads without seriously affecting performance. For example, on the English-Russian WMT dataset, pruning 38 out of 48 encoder heads results in a drop of only 0.15 BLEU.
+
+In this repo, we provide code and describe the steps needed to reproduce our experiments with L0 head pruning.
+
+## Pruning Attention Heads
+
+In the standard Transformer, the results of different attention heads in a layer are concatenated:
+
+```MultiHead(Q, K, V) = Concat(head_i)W^O.```
+
+We would like to disable less important heads completely, i.e. ideally apply `L0` regularization to the number of heads. We modify the original Transformer architecture by multiplying the representation computed by each `head_i` by a scalar gate `g_i`:
+
+```MultiHead(Q, K, V) = Concat(g_i * head_i)W^O.```
+
+Unlike usual gates, `g_i` are parameters specific to heads and are independent of the input (i.e. the sentence). Each gate `g_i` is a random variable drawn independently from a head-specific [Hard Concrete distribution](https://openreview.net/pdf?id=H1Y8hhg0b). The distributions have non-zero probability mass at 0 and 1; see the illustration below.
+
+![concrete_gif](./resources/concrete_crop.gif)
+
+We use the sum of the probabilities of heads being non-zero (`L_C`) as a stochastic relaxation of the non-differentiable `L0` norm. The resulting training objective is:
+
+```L = L_xent + λ * L_C.```
+
+By varying the coefficient `λ` in the optimized objective, we obtain models with different numbers of retained heads.
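+
+To make this concrete, below is a small NumPy sketch (an illustration only, not the repo's code; the actual TensorFlow implementation lives in `lib/layers/concrete_gate.py`) of how a Hard Concrete gate can be sampled and how the probability of a head being non-zero, which is summed into `L_C`, can be computed. The temperature and stretch limits match the defaults used in the code; the function names here are illustrative.
+
+```
+import numpy as np
+
+def sample_gate(log_a, temperature=0.33, limits=(-0.1, 1.1), eps=1e-6):
+    """Sample a stretched-and-clipped concrete gate g in [0, 1] for one head."""
+    low, high = limits
+    u = np.random.uniform(eps, 1.0 - eps)  # uniform noise for the concrete relaxation
+    s = 1.0 / (1.0 + np.exp(-(np.log(u) - np.log(1.0 - u) + log_a) / temperature))
+    return float(np.clip(s * (high - low) + low, 0.0, 1.0))  # has mass exactly at 0 and 1
+
+def p_head_open(log_a, temperature=0.33, limits=(-0.1, 1.1)):
+    """P(g_i != 0): the per-head term that is summed into the L_C regularizer."""
+    low, high = limits
+    return 1.0 / (1.0 + np.exp(-(log_a - temperature * np.log(-low / high))))
+
+# One layer with 8 heads; log_a are the learned per-head parameters (zeros here for illustration).
+log_a_per_head = [0.0] * 8
+L_C = sum(p_head_open(a) for a in log_a_per_head)
+# The training loss is L = L_xent + λ * L_C: increasing λ pushes log_a down for
+# unimportant heads, driving their gates (and hence those heads) to exactly zero.
+```
+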
+Below we show how the probabilities of encoder heads being completely closed (P(g_i = 0)) change during training for different values of `λ` (pruning starts from a converged model). White color denotes P(g_i = 0) = 1, which means that a head is completely removed from the model.
+
+![enc_head_gif](./resources/enc_head_gif_delay7-min.gif)
+
+(The gif is for a model trained on EN-RU WMT. For other datasets, suitable values of `λ` can be different.)
+
+We observe that the model converges to solutions where gates are either almost completely closed or completely open. This means that at test time we can treat the model as a standard Transformer and use only a subset of heads.
+
+---
+# Experiments
+
+## Requirements
+
+__Operating System:__ This implementation works on the most popular Linux distributions (tested on Ubuntu 14, 16). It is also likely to work on Mac OS. For other operating systems we recommend using Docker.
+
+__Hardware:__ The model can be trained on one or several GPUs. Training on CPU is also supported.
+
+__OpenMPI (optional):__ To train on several GPUs, you have to install OpenMPI. The code was tested on [OpenMPI 3.1.2 (download)](https://download.open-mpi.org/release/open-mpi/v3.1/openmpi-3.1.2.tar.gz). See build instructions [here](https://www.open-mpi.org/faq/?category=building#easy-build).
+
+__Python:__ The code works with Python 3.5 and 3.6; we recommend using [anaconda](https://www.anaconda.com/). Install the rest of the Python packages with `pip install -r requirements.txt`. If you haven't built OpenMPI, remove `horovod` from the list of requirements.
+
+## Data preprocessing
+The model training config requires the data to be preprocessed, i.e. tokenized and BPE-ized.
+### Tokenization
+Here is an example of how to tokenize (and lowercase) your data:
+```
+cat text_lines.en | moses-tokenizer en | python3 -c "import sys; print(sys.stdin.read().lower())" > text_lines.en.tok
+```
+
+For the OpenSubtitles18 dataset, you do not need this step since the data is already tokenized.
+
+### BPE-ization
+Learn BPE rules:
+```
+subword-nmt learn-bpe -s 32000 < text_lines.en.tok > bpe_rules.en
+```
+Apply BPE rules to your data:
+```
+/path_to_this_repo/lib/tools/apply_bpe.py --bpe_rules ./bpe_rules.en < text_lines.en.tok > text_lines.en.bpeized
+```
+---
+## Model training
+
+In the [scripts](./scripts) folder you can find the files `train_baseline.sh`, `train_concrete_heads.sh` and `train_fixed_alive_heads.sh` with configs for training the baseline, the model with head pruning using the relaxation of the L0 penalty, and the model with a fixed configuration of open and closed heads.
+
+To launch an experiment, do the following (the example is for the head pruning experiment):
+```
+mkdir exp_dir_name && cd exp_dir_name
+cp the-story-of-heads_dir/scripts/train_concrete_heads.sh .
+bash train_concrete_heads.sh
+```
+
+After that, checkpoints will be in the `exp_dir_name/build/checkpoint` directory, the summary for TensorBoard in `exp_dir_name/build/summary`, and translations of the dev set for checkpoints (if specified; see below) in `exp_dir_name/build/translations`.
+
+---
+## Notebooks: how to use a model
+In the [notebooks](./notebooks) folder you can find notebooks showing how to work with your trained model. Each notebook's name should make its content clear, but here is a short summary just in case.
+
+[1_Load_model_and_translate](./notebooks/1_Load_model_and_translate.ipynb) - how to load a model and translate sentences;
+
+[2_Look_at_attention_maps](./notebooks/2_Look_at_attention_maps.ipynb) - how to draw attention maps for encoder heads;
+
+[3_Look_which_heads_are_dead](./notebooks/3_Look_which_heads_are_dead.ipynb) - if you are pruning heads, you might want to know which ones ended up dead; this notebook shows you how to find out.
+
+
+---
+## Training config tour
+
+Each training script contains a thorough description of the parameters and an explanation of what you need to change for your experiment. Here we provide a tour of the config files and explain the parameters once again.
+
+
+### Data
+First, you need to specify the directory of the [the-story-of-heads](./) repo, your data directory, and the train/dev file names.
+```
+REPO_DIR="../" # insert the path to the the-story-of-heads repo
+DATA_DIR="../" # insert your datadir
+
+NMT="${REPO_DIR}/scripts/nmt.py"
+
+# path to preprocessed data (tokenized, bpe-ized)
+train_src="${DATA_DIR}/train.src"
+train_dst="${DATA_DIR}/train.dst"
+dev_src="${DATA_DIR}/dev.src"
+dev_dst="${DATA_DIR}/dev.dst"
+```
+After that, in the config you'll see the code for creating vocabularies from your data and shuffling the data.
+
+---
+### Model
+```
+params=(
+...
+--model lib.task.seq2seq.models.transformer_head_gates.Model
+...)
+```
+This is the Transformer model with extra options for attention head gates: stochastic gates, fixed gates, or no extra parameters for the baseline. Model hyperparameters are split into groups:
+* main model hp,
+* minor model hp (you probably do not want to change these),
+* regularization and label smoothing,
+* inference params (beam search with a beam of 4),
+* head gate parameters (nothing here for the baseline).
+
+For the baseline, the parameters are as follows:
+```
+hp = {
+     "num_layers": 6,
+     "num_heads": 8,
+     "ff_size": 2048,
+     "ffn_type": "conv_relu",
+     "hid_size": 512,
+     "emb_size": 512,
+     "res_steps": "nlda",
+
+     "rescale_emb": True,
+     "inp_emb_bias": True,
+     "normalize_out": True,
+     "share_emb": False,
+     "replace": 0,
+
+     "relu_dropout": 0.1,
+     "res_dropout": 0.1,
+     "attn_dropout": 0.1,
+     "label_smoothing": 0.1,
+
+     "translator": "ingraph",
+     "beam_size": 4,
+     "beam_spread": 3,
+     "len_alpha": 0.6,
+     "attn_beta": 0,
+  }
+```
+This set of parameters corresponds to Transformer-base [(Vaswani et al., 2017)](https://papers.nips.cc/paper/7181-attention-is-all-you-need).
+
+To train a model with head pruning, you need to specify the types of attention heads you want to prune. For encoder self-attention heads only,
+```
+     "concrete_heads": {"enc-self"},
+```
+and for all attention types, it's
+```
+     "concrete_heads": {"enc-self", "dec-self", "dec-enc"},
+```
+
+For a fixed head configuration, specify the gate values for each head:
+```
+     "alive_heads": {"enc-self": [[1,0,1,0,1,0,1,0],
+                                  [1,1,1,1,1,1,1,1],
+                                  [0,0,0,0,0,0,0,0],
+                                  [1,1,1,0,0,1,0,0],
+                                  [0,0,0,0,1,1,1,1],
+                                  [0,0,1,1,0,0,1,1]],
+                    },
+```
+In this case, only encoder self-attention heads will be masked. For all attention types, specify all gates:
+```
+     "alive_heads": {"enc-self": [[1,0,1,0,1,0,1,0],
+                                  [1,1,1,1,1,1,1,1],
+                                  ...
+                                  [0,0,1,1,0,0,1,1]],
+                     "dec-self": [[...],
+                                  ...,
+                                  [...]],
+                     "dec-enc": [[...],
+                                 ...,
+                                 [...]],
+                    },
+```
+---
+### Problem (loss function)
+You need to set the training objective for your model. For the baseline and the fixed head configuration, it's the standard cross-entropy loss with no extra options:
+```
+params=(
+    ...
+    --problem lib.task.seq2seq.problems.default.DefaultProblem
+    --problem-opts '{}'
+    ...)
+```
+For pruning heads, the loss function is `L = L_xent + λ * L_C`. You need to set a different problem and specify the value of `λ`:
+```
+params=(
+    ...
+    --problem lib.task.seq2seq.problems.concrete.ConcreteProblem
+    --problem-opts '{'"'"'concrete_coef'"'"': 0.1,}'
+    ...)
+```
+---
+### Starting checkpoint
+If you start model training from an already trained model (for example, we start pruning heads from the trained baseline model), specify the initial checkpoint:
+```
+params=(
+    ...
+    --pre-init-model-checkpoint 'dir_to_your_trained_baseline_checkpoint.npz'
+    ...)
+```
+You do not need this if you start from scratch.
+
+---
+### Variables to optimize
+If you want to freeze some sets of parameters in the model (for example, when pruning encoder heads we freeze the decoder parameters to ensure that head functions do not move to the decoder), you have to specify which parameters you **want** to optimize. To optimize only the encoder, add `variables` to `--optimizer-opts`:
+```
+params=(
+    ...
+    --optimizer-opts '{'"'"'beta1'"'"': 0.9, '"'"'beta2'"'"': 0.998,
+                       '"'"'variables'"'"': ['"'"'mod/emb_inp*'"'"',
+                                             '"'"'mod/enc*'"'"',],}'
+    ...)
+```
+(Here `beta1` and `beta2` are parameters of the Adam optimizer.)
+
+---
+### Batch size
+It has been shown that the Transformer's performance depends heavily on batch size (see, for example, [Popel and Bojar, 2018](https://content.sciendo.com/view/journals/pralin/110/1/article-p43.xml)), and we chose a large batch size to ensure that models show their best performance. In our experiments, each training batch contained a set of translation pairs with approximately 16000 source tokens. This can be achieved either by using several GPUs or by accumulating the gradients for several batches and then making an update. Our implementation supports both options.
+
+The batch size per GPU is set like this:
+```
+params=(
+    ...
+    --batch-len 4000
+    ...)
+```
+The effective batch size will then be `batch-len * num_gpus`. For example, with `--batch-len 4000` and 4 GPUs you would get the desired batch size of 16000.
+
+If you do not have several GPUs (often, neither do we :) ), you can still train models of proper quality by accumulating the gradients for several batches and then making an update. Add `average_grads: True` and `sync_every_steps: N` to the optimizer options like this:
+```
+params=(
+    ...
+    --optimizer-opts '{'"'"'beta1'"'"': 0.9, '"'"'beta2'"'"': 0.998,
+                       '"'"'sync_every_steps'"'"': 4,
+                       '"'"'average_grads'"'"': True, }'
+    ...)
+```
+The effective batch size will then be `batch-len * sync_every_steps`. For example, with `--batch-len 4000` and `sync_every_steps: 4` you would get the desired batch size of 16000.
+
+
+---
+### Other options
+If you want to see the dev BLEU score on your TensorBoard:
+```
+params=(
+    ...
+    --translate-dev
+    --translate-dev-every 2048
+    ...)
+```
+Specify how often you want to save a checkpoint:
+```
+params=(
+    ...
+    --checkpoint-every-steps 2048
+    ...)
+```
+Specify how often you want to score the dev set (eval loss values):
+```
+params=(
+    ...
+    --score-dev-every 256
+    ...)
+```
+How many of the latest checkpoints to keep:
+```
+params=(
+    ...
+    --keep-checkpoints-max 10
+    ...)
+``` + +--- + +# Comments +* `lib.task.seq2seq.models.transformer_head_gates` model enables you to train baseline as well as other versions, but if you want Transformer model without any modifications, you can find it here: `lib.task.seq2seq.models.transformer`. diff --git a/lib/__init__.py b/lib/__init__.py new file mode 100644 index 0000000..a700fb1 --- /dev/null +++ b/lib/__init__.py @@ -0,0 +1,6 @@ +from . import train +from .meta import * +from .util import * +from .ops import * +from .task import * +from .layers import * diff --git a/lib/data.py b/lib/data.py new file mode 100644 index 0000000..8599ceb --- /dev/null +++ b/lib/data.py @@ -0,0 +1,270 @@ +import numpy as np +import tensorflow as tf + +import bintrees +import os +import sys +import random +import threading +import itertools +from .util import nested_pack, nested_flatten +from .ops import mpi + + +class TfUploader: + + def __init__(self, iterator, capacity, dtypes=None, shapes=None, session=None): + self.session = session if session is not None else tf.get_default_session() + self.empty = False + + # Detect dtypes from first iterator element + if dtypes is None or shapes is None: + # We need to wrap iterator access in session because it may call TF operations + with self.session.as_default(): + try: + first = next(iterator) + except StopIteration: + self.empty = True + return + + self.structure = first + self.dtypes = tuple(e.dtype for e in nested_flatten(first)) + self.shapes = tuple(tuple(map(lambda x: None, e.shape)) for e in nested_flatten(first)) + self.iterator = itertools.chain([first], iterator) + else: + self.structure = dtypes + self.dtypes = tuple(nested_flatten(dtypes)) + self.shapes = tuple(nested_flatten(shapes)) + self.iterator = iterator + + self.session_close_lock = threading.Lock() + self.session_closed = False + + with tf.name_scope("uploader"): + self.queue = tf.FIFOQueue(dtypes=self.dtypes, capacity=capacity) + + self.enqueue_inputs = [tf.placeholder(dtype=dt) for dt in self.dtypes] + self.enqueue_op = self.queue.enqueue(self.enqueue_inputs) + self.close_op = self.queue.close() + + def __enter__(self): + if not self.empty: + self.thread = threading.Thread(target=self._thread_main) + self.thread.daemon = True + self.thread.start() + return self + + def __exit__(self, *args): + if not self.empty: + with self.session_close_lock: + if not self.session_closed: + self.session.run(self.queue.close(True)) + self.session_closed = True + self.thread.join(1) + return False + + def get_next(self): + if self.empty: + raise tf.errors.OutOfRangeError(None, None, "Queue is empty") + res = self.queue.dequeue() + if isinstance(res, list): + for t, sh in zip(res, self.shapes): + t.set_shape(sh) + res = nested_pack(res, self.structure) + return res + + def _thread_main(self): + try: + # We need to wrap iterator access in session because it may call TF operations + with self.session.graph.as_default(), self.session.as_default(): + for t in self.iterator: + self.session.run(self.enqueue_op, feed_dict=dict(zip(self.enqueue_inputs, tuple(nested_flatten(t))))) + + with self.session_close_lock: + self.session.run(self.close_op) + self.session_closed = True + except tf.errors.CancelledError: + pass + + +class LastElement(object): + """ + Class wrapping last element in RoundRobinIterator + """ + def __init__(self, element=None): + self.element = element + +class RoundRobinIterator(object): + """ + Class implementing Round-Robin iterator between coordinator and workers + """ + def __init__(self, iterator=None, is_train=True, 
with_cost=False): + self.iterator = iterator + self.is_train = is_train + self.with_cost = with_cost + self.mpi_rank = os.getenv('OMPI_COMM_WORLD_RANK') or '0' + self.mpi_size = os.getenv('OMPI_COMM_WORLD_SIZE') or '1' + self.finish = False + + def __iter__(self): + self.finish = False + return self + + def __next__(self): + if self.finish: + # we should quit iterator + raise StopIteration + + buf = None + if self.mpi_rank == '0' or self.mpi_rank is None: + # fill buffer with elements to scatter + try: + buf = [] + for _ in range(int(self.mpi_size)): + batch = next(self.iterator) + if self.with_cost: + if len(buf) == 0: # On first element save coordinator cost + coord_cost = batch[-1] + batch = batch + (coord_cost,) + buf.append(batch) + except StopIteration: + # if iterator is out, scatter None values (during training) and + # add None to missing workers + if self.is_train: + buf = [LastElement()] * int(self.mpi_size) + else: + for i in range(len(buf)): + buf[i] = LastElement(buf[i]) + buf += [LastElement()] * (int(self.mpi_size) - len(buf)) + + # scatter objects between workers + value = mpi.scatter_obj(buf) + if isinstance(value, LastElement): + if value.element is None: + raise StopIteration + # remember to quit iterator at the next step + self.finish = True + value = value.element + return value + + +class CostBufferIterator(object): + """ + Class implementing CostBuffer iterator for fast finding of the batch with + desired cost (useful for balancing batches) + + We assume inputs from the iterator passed in the constructor in the form: + + """ + def __init__(self, iterator=None, buf_size=1000): + self.iterator = iterator + self.buf_size = buf_size + self.tree = bintrees.FastRBTree() + self.coord_costs = [] + self.rng = random.Random(42) + + def __iter__(self): + self.tree = bintrees.FastRBTree() + self.coord_costs = [] + self.rng = random.Random(42) + return self + + def __next__(self): + # Warming up + while len(self.coord_costs) < self.buf_size: + try: + batch, cost, coord_cost = next(self.iterator) + except StopIteration: + break + if cost in self.tree: + self.tree[cost].append(batch) + else: + self.tree[cost] = [batch] + self.coord_costs.append(coord_cost) + + # No elements left - finish iteration + if len(self.coord_costs) == 0: + raise StopIteration + + # generate cost to choose and choose relevant batch + index = self.rng.randrange(len(self.coord_costs)) + best_cost = self._find_best_match(self.coord_costs[index]) + batch = self.tree[best_cost][0] + + # remove selected items from structures + del self.coord_costs[index] + del self.tree[best_cost][0] + if len(self.tree[best_cost]) == 0: + del self.tree[best_cost] + + return batch + + def _find_best_match(self, cost): + min_cost = self.tree.min_key() + max_cost = self.tree.max_key() + if cost <= min_cost: + return min_cost + if cost >= max_cost: + return max_cost + floor_cost = self.tree.floor_key(cost) + ceil_cost = self.tree.ceiling_key(cost) + return floor_cost if abs(ceil_cost - cost) < abs(floor_cost - cost) else ceil_cost + + +class ShuffleIterator(object): + """ + Class implementing shuffling iterator via auxiliary buffer + """ + def __init__(self, iterator, buf_size=1000): + self.iterator = iterator + self.buf_size = buf_size + self.buf = [] + self.rng = random.Random(42) + + def __iter__(self): + self.buf = [] + self.rng = random.Random(42) + return self + + def __next__(self): + # Return element from the previously shuffled buffer + if len(self.buf) > 0: + value = self.buf.pop() + return value + + # Keep elements in the 
buffer + while len(self.buf) < self.buf_size: + try: + value = next(self.iterator) + except StopIteration: + break + self.buf.append(value) + + # No elements left - finish iteration + if len(self.buf) == 0: + raise StopIteration + + # Shuffle and return element from the buffer + self.rng.shuffle(self.buf) + value = self.buf.pop() + return value + + +def pad_seq_list(array, sentinel): + """ + Add padding, compose lengths + """ + # Compute max length. + maxlen = 0 + for seq in array: + maxlen = max(maxlen, len(seq)) + + # Pad. + padded = [] + lens = [] + for seq in array: + padding = maxlen - len(seq) + padded.append(seq + [sentinel] * padding) + lens.append(len(seq)) + + return padded, lens diff --git a/lib/layers/__init__.py b/lib/layers/__init__.py new file mode 100644 index 0000000..877360d --- /dev/null +++ b/lib/layers/__init__.py @@ -0,0 +1,3 @@ +from .basic import * +from .attn import * +from .lrp import * diff --git a/lib/layers/attn.py b/lib/layers/attn.py new file mode 100644 index 0000000..aae72ab --- /dev/null +++ b/lib/layers/attn.py @@ -0,0 +1,374 @@ + +import tensorflow as tf +import math + +import lib +from lib.ops.basic import is_dropout_enabled, dropout +from lib.ops import record_activations as rec +from .basic import Dense, LRP + +from lib.layers.concrete_gate import ConcreteGate + + +class MultiHeadAttn: + """ + Multihead scaled-dot-product attention with input/output transformations + """ + ATTN_BIAS_VALUE = -1e9 + + def __init__( + self, name, inp_size, + key_depth, value_depth, output_depth, + num_heads, attn_dropout, attn_value_dropout, + kv_inp_size=None, _format='combined' + ): + self.name = name + self.key_depth = key_depth + self.value_depth = value_depth + self.num_heads = num_heads + self.attn_dropout = attn_dropout + self.attn_value_dropout = attn_value_dropout + self.format = _format + kv_inp_size = kv_inp_size or inp_size + + with tf.variable_scope(name) as scope: + self.scope = scope + + if self.format == 'use_kv': + self.query_conv = Dense( + 'query_conv', + inp_size, key_depth, + activ=lambda x: x, + bias_initializer=tf.zeros_initializer(), + ) + + self.kv_conv = Dense( + 'mem_conv', + kv_inp_size, key_depth + value_depth, + activ=lambda x: x, + bias_initializer=tf.zeros_initializer(), + ) + + if kv_inp_size == inp_size: + self.combined_conv = Dense( + 'combined_conv', + inp_size, key_depth * 2 + value_depth, + activ=lambda x: x, + matrix=tf.concat([self.query_conv.W, self.kv_conv.W], axis=1), + bias=tf.concat([self.query_conv.b, self.kv_conv.b], axis=0), + ) + + elif self.format == 'combined': + assert inp_size == kv_inp_size, 'combined format is only supported when inp_size == kv_inp_size' + self.combined_conv = Dense( + 'mem_conv', # old name for compatibility + inp_size, key_depth * 2 + value_depth, + activ=lambda x: x, + bias_initializer=tf.zeros_initializer()) + + self.query_conv = Dense( + 'query_conv', + inp_size, key_depth, + activ=lambda x: x, + matrix=self.combined_conv.W[:, :key_depth], + bias=self.combined_conv.b[:key_depth], + ) + + self.kv_conv = Dense( + 'kv_conv', + kv_inp_size, key_depth + value_depth, + activ=lambda x: x, + matrix=self.combined_conv.W[:, key_depth:], + bias=self.combined_conv.b[key_depth:], + ) + else: + raise Exception("Unexpected format: " + self.format) + + self.out_conv = Dense( + 'out_conv', + value_depth, output_depth, + activ=lambda x: x, + bias_initializer=tf.zeros_initializer()) + + def attention_core(self, q, k, v, attn_mask): + """ Core math operations of multihead attention layer """ + q = 
self._split_heads(q) # [batch_size * n_heads * n_q * (k_dim/n_heads)] + k = self._split_heads(k) # [batch_size * n_heads * n_kv * (k_dim/n_heads)] + v = self._split_heads(v) # [batch_size * n_heads * n_kv * (v_dim/n_heads)] + + key_depth_per_head = self.key_depth / self.num_heads + q = q / math.sqrt(key_depth_per_head) + + # Dot-product attention + # logits: (batch_size * n_heads * n_q * n_kv) + attn_bias = MultiHeadAttn.ATTN_BIAS_VALUE * (1 - attn_mask) + logits = tf.matmul( + q, + tf.transpose(k, perm=[0, 1, 3, 2])) + attn_bias + weights = tf.nn.softmax(logits) + + tf.add_to_collection("AttnWeights", weights) + tf.add_to_collection(lib.meta.ATTENTIONS, lib.meta.Attention(self.scope, weights, logits, attn_mask)) + + if is_dropout_enabled(): + weights = dropout(weights, 1.0 - self.attn_dropout) + x = tf.matmul( + weights, # [batch_size * n_heads * n_q * n_kv] + v # [batch_size * n_heads * n_kv * (v_deph/n_heads)] + ) + combined_x = self._combine_heads(x) + + if is_dropout_enabled(): + combined_x = dropout(combined_x, 1.0 - self.attn_value_dropout) + return combined_x + + def __call__(self, query_inp, attn_mask, kv_inp=None, kv=None): + """ + query_inp: [batch_size * n_q * inp_dim] + attn_mask: [batch_size * 1 * n_q * n_kv] + kv_inp: [batch_size * n_kv * inp_dim] + ----------------------------------------------- + results: [batch_size * n_q * output_depth] + """ + assert kv is None or kv_inp is None, "please only feed one of kv or kv_inp" + + with tf.variable_scope(self.scope), tf.name_scope(self.name) as scope: + rec.save_activation('kv', kv) + if kv_inp is not None or kv is not None: + q = self.query_conv(query_inp) + if kv is None: + kv = self.kv_conv(kv_inp) + k, v = tf.split(kv, [self.key_depth, self.value_depth], axis=2) + rec.save_activation('is_combined', False) + else: + combined = self.combined_conv(query_inp) + q, k, v = tf.split(combined, [self.key_depth, self.key_depth, self.value_depth], axis=2) + rec.save_activation('is_combined', True) + + rec.save_activations(q=q, k=k, v=v, attn_mask=attn_mask) + combined_x = self.attention_core(q, k, v, attn_mask) + outputs = self.out_conv(combined_x) + + return outputs + + def relprop(self, R): + with tf.variable_scope(self.scope): + assert rec.get_activation('kv') is None, "relprop through translatemodelfast is not implemented" + R = self.out_conv.relprop(R) + q, k, v, attn_mask = rec.get_activations('q', 'k', 'v', 'attn_mask') + # TODO relprop with taylor expansion? 
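+            # Reuse the forward attention computation inside the generic LRP rule (see lib/layers/lrp.py)
+            # to redistribute the output relevance R over q, k and v.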
+ Rq, Rk, Rv = LRP.relprop(lambda q, k, v: self.attention_core(q, k, v, attn_mask), None, R, q, k, v) + + if rec.get_activation('is_combined'): + Rqkv = tf.concat([Rq, Rk, Rv], axis=2) # [batch, time, 3 * hid_size] + Rinp = self.combined_conv.relprop(Rqkv) + return Rinp + else: + Rkv = tf.concat([Rk, Rv], axis=2) # [batch, time, 2 * hid_size] + Rkvinp = self.kv_conv.relprop(Rkv) + Rqinp = self.query_conv.relprop(Rq) + return {'query_inp': Rqinp, 'kv_inp': Rkvinp} + + def _split_heads(self, x): + """ + Split channels (dimension 3) into multiple heads (dimension 1) + input: (batch_size * ninp * inp_dim) + output: (batch_size * n_heads * ninp * (inp_dim/n_heads)) + """ + old_shape = x.get_shape().dims + dim_size = old_shape[-1] + new_shape = old_shape[:-1] + [self.num_heads] + [dim_size // self.num_heads if dim_size else None] + ret = tf.reshape(x, tf.concat([tf.shape(x)[:-1], [self.num_heads, tf.shape(x)[-1] // self.num_heads]], 0)) + ret.set_shape(new_shape) + return tf.transpose(ret, [0, 2, 1, 3]) # [batch_size * n_heads * ninp * (hid_dim//n_heads)] + + def _combine_heads(self, x): + """ + Inverse of split heads + input: (batch_size * n_heads * ninp * (inp_dim/n_heads)) + out: (batch_size * ninp * inp_dim) + """ + x = tf.transpose(x, [0, 2, 1, 3]) + old_shape = x.get_shape().dims + a, b = old_shape[-2:] + new_shape = old_shape[:-2] + [a * b if a and b else None] + ret = tf.reshape(x, tf.concat([tf.shape(x)[:-2], [tf.shape(x)[-2] * tf.shape(x)[-1]]], 0)) + ret.set_shape(new_shape) + return ret + + + +class MultiHeadAttnConcrete(MultiHeadAttn): + """ + Multihead scaled-dot-product attention with input/output transformations. + This is the modification with scalar gates to each head, which enables head pruning introduced in https://arxiv.org/abs/1905.09418 + """ + + def __init__( + self, name, inp_size, + key_depth, value_depth, output_depth, + num_heads, attn_dropout, attn_value_dropout, + kv_inp_size=None, _format='combined', + gate_hp={'l0_penalty': 1.0}, + ): + super().__init__(name, inp_size, + key_depth, value_depth, output_depth, + num_heads, attn_dropout, attn_value_dropout, + kv_inp_size=kv_inp_size, _format=_format) + + self.gate_hp = gate_hp + + with tf.variable_scope(name): + self.scope = tf.get_variable_scope() + self.gate = ConcreteGate('gate', shape=[1, self.num_heads, 1, 1], **self.gate_hp) + + + def __call__(self, query_inp, attn_mask, kv_inp=None, kv=None): + """ + query_inp: [batch_size * n_q * inp_dim] + attn_mask: [batch_size * 1 * n_q * n_kv] + kv_inp: [batch_size * n_kv * inp_dim] + ----------------------------------------------- + results: [batch_size * n_q * output_depth] + """ + assert kv is None or kv_inp is None, "please only feed one of kv or kv_inp" + with tf.name_scope(self.name) as scope: + if kv_inp is not None or kv is not None: + q = self.query_conv(query_inp) + if kv is None: + kv = self.kv_conv(kv_inp) + k, v = tf.split(kv, [self.key_depth, self.value_depth], axis=2) + else: + combined = self.combined_conv(query_inp) + q, k, v = tf.split(combined, [self.key_depth, self.key_depth, self.value_depth], axis=2) + q = self._split_heads(q) # [batch_size * n_heads * n_q * (k_dim/n_heads)] + k = self._split_heads(k) # [batch_size * n_heads * n_kv * (k_dim/n_heads)] + v = self._split_heads(v) # [batch_size * n_heads * n_kv * (v_dim/n_heads)] + + key_depth_per_head = self.key_depth / self.num_heads + q = q / math.sqrt(key_depth_per_head) + + # Dot-product attention + # logits: (batch_size * n_heads * n_q * n_kv) + attn_bias = MultiHeadAttn.ATTN_BIAS_VALUE * (1 - 
attn_mask) + logits = tf.matmul( + q, + tf.transpose(k, perm=[0, 1, 3, 2])) + attn_bias + weights = tf.nn.softmax(logits) + + tf.add_to_collection("AttnWeights", weights) + + tf.add_to_collection(lib.meta.ATTENTIONS, lib.meta.Attention(scope, weights, logits, attn_mask)) + + if is_dropout_enabled(): + weights = dropout(weights, 1.0 - self.attn_dropout) + x = tf.matmul( + weights, # [batch_size * n_heads * n_q * n_kv] + v # [batch_size * n_heads * n_kv * (v_deph/n_heads)] + ) + # x: [batch, n_heads, n_q, (v_deph/n_heads)] + + # ======================== apply the gate ======================== + gated_x = self.gate(x) + + tf.add_to_collection("CONCRETE", self.gate.get_sparsity_rate()) + tf.add_to_collection("GATEVALUES", self.gate.get_gates(False)) + # ================================================================== + + combined_x = self._combine_heads(gated_x) + + if is_dropout_enabled(): + combined_x = dropout(combined_x, 1.0 - self.attn_value_dropout) + + outputs = self.out_conv(combined_x) + + return outputs + + + +class MultiHeadAttnFixedAliveHeads(MultiHeadAttn): + """ + Multihead scaled-dot-product attention with input/output transformations. + This is the modification with constant binary gates for each head, + which specify which heads are present. + Need to pass 'head_gate' parameter, which the list of num_heads values, one for each head. + """ + + def __init__( + self, name, inp_size, + key_depth, value_depth, output_depth, + num_heads, attn_dropout, attn_value_dropout, + kv_inp_size=None, _format='combined', + head_gate=None, + ): + super().__init__(name, inp_size, + key_depth, value_depth, output_depth, + num_heads, attn_dropout, attn_value_dropout, + kv_inp_size=kv_inp_size, _format=_format) + + assert head_gate is not None, "You must feed values for head gates" + self.head_gate = head_gate + + with tf.variable_scope(name): + self.scope = tf.get_variable_scope() + self.gate = tf.constant(self.head_gate, dtype=tf.float32)[None, :, None, None] + + + def __call__(self, query_inp, attn_mask, kv_inp=None, kv=None): + """ + query_inp: [batch_size * n_q * inp_dim] + attn_mask: [batch_size * 1 * n_q * n_kv] + kv_inp: [batch_size * n_kv * inp_dim] + ----------------------------------------------- + results: [batch_size * n_q * output_depth] + """ + assert kv is None or kv_inp is None, "please only feed one of kv or kv_inp" + with tf.name_scope(self.name) as scope: + if kv_inp is not None or kv is not None: + q = self.query_conv(query_inp) + if kv is None: + kv = self.kv_conv(kv_inp) + k, v = tf.split(kv, [self.key_depth, self.value_depth], axis=2) + else: + combined = self.combined_conv(query_inp) + q, k, v = tf.split(combined, [self.key_depth, self.key_depth, self.value_depth], axis=2) + q = self._split_heads(q) # [batch_size * n_heads * n_q * (k_dim/n_heads)] + k = self._split_heads(k) # [batch_size * n_heads * n_kv * (k_dim/n_heads)] + v = self._split_heads(v) # [batch_size * n_heads * n_kv * (v_dim/n_heads)] + + key_depth_per_head = self.key_depth / self.num_heads + q = q / math.sqrt(key_depth_per_head) + + # Dot-product attention + # logits: (batch_size * n_heads * n_q * n_kv) + attn_bias = MultiHeadAttn.ATTN_BIAS_VALUE * (1 - attn_mask) + logits = tf.matmul( + q, + tf.transpose(k, perm=[0, 1, 3, 2])) + attn_bias + weights = tf.nn.softmax(logits) + + tf.add_to_collection("AttnWeights", weights) + + tf.add_to_collection(lib.meta.ATTENTIONS, lib.meta.Attention(scope, weights, logits, attn_mask)) + + if is_dropout_enabled(): + weights = dropout(weights, 1.0 - self.attn_dropout) + x = 
tf.matmul( + weights, # [batch_size * n_heads * n_q * n_kv] + v # [batch_size * n_heads * n_kv * (v_deph/n_heads)] + ) + # x: [batch, n_heads, n_q, (v_deph/n_heads)] + + # ======================== apply the gate ======================== + gated_x = self.gate * x + # ================================================================== + + combined_x = self._combine_heads(gated_x) + + if is_dropout_enabled(): + combined_x = dropout(combined_x, 1.0 - self.attn_value_dropout) + + outputs = self.out_conv(combined_x) + + return outputs + diff --git a/lib/layers/basic.py b/lib/layers/basic.py new file mode 100644 index 0000000..51fe6a7 --- /dev/null +++ b/lib/layers/basic.py @@ -0,0 +1,392 @@ +# Basic NN layers + +import lib + +import tensorflow as tf +from ..util import nop_ctx +from ..ops import record_activations as rec +from .lrp import LRP +from ..ops.basic import * + +############################################################################### +# # +# LAYERS # +# # +############################################################################### + + + +## ---------------------------------------------------------------------------- +# Dense +class Dense: + def __init__( + self, name, + inp_size, out_size, activ=tf.tanh, + matrix=None, bias=None, + matrix_initializer=None, bias_initializer=None): + + """ + /W + /b + + User can explicitly specify matrix to use instead of W (/W is + not created then), but this is not recommended to external users. + """ + self.name = name + self.activ = activ + self.inp_size = inp_size + self.out_size = out_size + + with tf.variable_scope(name) as self.scope: + if matrix is None: + self.W = get_model_variable('W', shape=[inp_size, out_size], initializer=matrix_initializer) + else: + self.W = matrix + + if bias is None: + self.b = get_model_variable('b', shape=[out_size], initializer=bias_initializer) + else: + self.b = bias + + def __call__(self, inp): + """ + inp: [..., inp_size] + -------------------- + Ret: [..., out_size] + """ + with tf.variable_scope(self.scope): + out = self.activ(dot(inp, self.W) + self.b) + out.set_shape([None] * (out.shape.ndims - 1) + [self.out_size]) + if rec.is_recorded(): + rec.save_activations(inp=inp, out=out) + return out + + def relprop(self, output_relevance): + """ + computes input relevance given output_relevance + :param output_relevance: relevance w.r.t. 
layer output, [*dims, out_size] + notation from DOI:10.1371/journal.pone.0130140, Eq 60 + """ + # make two copies of the layer: one with positive params and one with negative + clone_self = lambda W, b, activ: Dense(self.name, self.inp_size, self.out_size, + activ, matrix=W, bias=b) + f_positive = clone_self(tf.maximum(self.W, LRP.eps), b=0, activ=nop) + f_negative = clone_self(tf.minimum(self.W, -LRP.eps), b=0, activ=nop) # the dark side of me + + with tf.variable_scope(self.scope): + inp, out = rec.get_activations('inp', 'out') + # inp: [*dims, inp_size], out: [*dims, out_size] + input_relevance = LRP.relprop(f_positive, f_negative, output_relevance, inp) + return input_relevance + + @property + def input_size(self): + return self.inp_size + + @property + def output_size(self): + return self.out_size + +## ---------------------------------------------------------------------------- +# Embedding + +class Embedding: + def __init__(self, name, voc_size, emb_size, matrix=None, initializer=None, device=''): + """ + Parameters: + + /mat + """ + self.name = name + self.voc_size = voc_size + self.emb_size = emb_size + self.device = device + + if matrix is not None: + self.mat = matrix + else: + with tf.variable_scope(name), (tf.device(device) if device is not None else nop_ctx()): + self.mat = get_model_variable('mat', shape=[voc_size, emb_size], initializer=initializer) + + def __call__(self, inp, gumbel=False): + """ + inp: [...] + -------------------- + Ret: [..., emb_size] + """ + with tf.name_scope(self.name), (tf.device(self.device) if self.device is not None else nop_ctx()): + return tf.gather(self.mat, inp) if not gumbel else dot(inp, self.mat) + +## ---------------------------------------------------------------------------- +# LayerNorm + +class LayerNorm: + """ + Performs Layer Normalization + """ + def __init__(self, name, inp_size, epsilon=1e-6): + self.name = name + self.epsilon = epsilon + + with tf.variable_scope(name): + self.scale = get_model_variable('scale', shape=[inp_size], initializer=tf.ones_initializer()) + self.bias = get_model_variable('bias', shape=[inp_size], initializer=tf.zeros_initializer()) + + def __call__(self, inp): + with tf.variable_scope(self.name): + mean = tf.reduce_mean(inp, axis=[-1], keep_dims=True) + variance = tf.reduce_mean(tf.square(inp - mean), axis=[-1], keep_dims=True) + norm_x = (inp - mean) * tf.rsqrt(variance + self.epsilon) + return norm_x * self.scale + self.bias + + def relprop(self, R): + #TODO find out the "canonic" way to relrop through layernorm + return R + +## ---------------------------------------------------------------------------- +# ResidualWrapper + + +class Wrapper: + """ Reflection-style wrapper, code from http://code.activestate.com/recipes/577555-object-wrapper-class/ """ + def __init__(self, wrapped_layer): + self.wrapped_layer = wrapped_layer + + def __getattr__(self, attr): + if attr in self.__dict__: + return getattr(self, attr) + return getattr(self.wrapped_layer, attr) + + +class ResidualLayerWrapper(Wrapper): + def __init__(self, name, wrapped_layer, inp_size, out_size, steps='ldan', dropout=0, dropout_seed=None): + """ + Applies any number of residual connection, dropout and/or layer normalization before or after wrapped layer + :param steps: a sequence of operations to perform, containing any combination of: + - 'l' - call wrapped [l]ayer, this operation should be used exactly once + - 'd' - apply [d]ropout with p = dropout and seed = dropout_seed + - 'a' - [a]dd inputs to output (residual connection) + - 'n' - 
apply layer [n]ormalization here, can only be done once + """ + assert steps.count('l') == 1, "residual wrapper must call wrapped layer exactly once" + assert steps.count('n') <= 1, "in the current implementaion, there can be at most one layer normalization step" + assert inp_size == out_size or 'a' not in steps, "residual step only works if inp_size == out_size" + self.name = name + self.wrapped_layer = wrapped_layer + + if 'n' in steps: + ln_size = inp_size if steps.index('n') < steps.index('l') else out_size + with tf.variable_scope(name) as self.scope: + self.norm_layer = LayerNorm("layer_norm", ln_size) + + self.steps = steps + self.preprocess_steps = steps[:steps.index('l')] + self.postprocess_steps = steps[steps.index('l') + 1:] + self.dropout = dropout + self.dropout_seed = dropout_seed + + def __call__(self, inp, *args, **kwargs): + out = self.preprocess(inp) + out = self.wrapped_layer(out, *args, **kwargs) + out = self.postprocess(out, inp) + return out + + def preprocess(self, inp): + return self._perform(self.preprocess_steps, inp) + + def postprocess(self, out, inp=None): + return self._perform(self.postprocess_steps, out, inp=inp) + + def _perform(self, steps, out, inp=None): + with tf.variable_scope(self.scope): + if inp is None: + inp = out + + for s in steps: + if s == 'd': + if is_dropout_enabled(): + out = lib.ops.dropout(out, 1.0 - self.dropout, seed=self.dropout_seed) + elif s == 'a': + rec.save_activations(inp=inp, out_pre_residual=out) + out += inp + elif s == 'n': + out = self.norm_layer(out) + else: + raise RuntimeError("Unknown process step: %s" % s) + return out + + def relprop(self, R, main_key=None): + original_scale = tf.reduce_sum(abs(R)) + with tf.variable_scope(self.scope): + Rinp_residual = 0.0 + for s in self.steps[::-1]: + if s == 'l': + R = self.wrapped_layer.relprop(R) + if isinstance(R, dict): + assert main_key is not None + R_dict = R + R = R_dict[main_key] + elif s == 'a': + inp, out = rec.get_activations('inp', 'out_pre_residual') + Rinp_residual, R = LRP.relprop(lambda a, b: a + b, None, R, inp, out) + elif s == 'n': + R = self.norm_layer.relprop(R) + + pre_residual_scale = tf.reduce_sum(abs(R) + abs(Rinp_residual)) + + R = R + Rinp_residual + R = R * pre_residual_scale / tf.reduce_sum(tf.abs(R)) + if main_key is not None: + R_dict = dict(R_dict) + R_dict[main_key] = R + total_scale = sum(tf.reduce_sum(abs(relevance)) for relevance in R_dict.values()) + R_dict = {key: value * original_scale / total_scale + for key, value in R_dict.items()} + return R_dict + else: + return R + + +############################################################################### +# # +# SEQUENCE LOSSES # +# # +############################################################################### + + +class SequenceLossBase: + def rdo_to_logits(self, *args, **kwargs): + raise NotImplementedError() + + def rdo_to_logits__predict(self, *args, **kwargs): + return self.rdo_to_logits(*args, **kwargs) + + +class LossXent(SequenceLossBase): + def __init__( + self, name, rdo_size, voc, hp, + matrix=None, bias=None, + matrix_initializer=None, bias_initializer=tf.zeros_initializer(), + ): + """ + Parameters: + + Dense: /logits + """ + if 'lm_path' in hp: + raise NotImplementedError("LM fusion not implemented") + + self.name = name + self.rdo_size = rdo_size + self.voc_size = voc.size() + + self.bos = voc.bos + self.label_smoothing = hp.get('label_smoothing', 0) + + with tf.variable_scope(name): + self._rdo_to_logits = Dense( + 'logits', rdo_size, self.voc_size, activ=nop, + 
matrix=matrix, bias=bias, + matrix_initializer=matrix_initializer, bias_initializer=bias_initializer) + + def __call__(self, rdo, out, out_len): + """ + rdo: [batch_size, ninp, rdo_size] + out: [batch_size, ninp], dtype=int + out_len: [batch_size] + inp_words: [batch_size, ninp], dtype=string + attn_P_argmax: [batch_size, ninp], dtype=int + -------------------------- + Ret: [batch_size] + """ + logits = self.rdo_to_logits(rdo, out, out_len) # [batch_size, ninp, voc_size] + return self.logits2loss(logits, out, out_len) + + def rdo_to_logits(self, rdo, out, out_len): + """ + compute logits in training mode + :param rdo: pre-final activations float32[batch, num_outputs, hid_size] + :param out: output sequence, padded with EOS int64[batch, num_outputs] + :param out_len: lengths of outputs in :out: excluding padding, int64[batch] + """ + return self._rdo_to_logits(rdo) + + def logits2loss(self, logits, out, out_len, reduce_rows=True): + if self.label_smoothing: + voc_size = tf.shape(logits)[-1] + smooth_positives = 1.0 - self.label_smoothing + smooth_negatives = self.label_smoothing / tf.to_float(voc_size - 1) + onehot_labels = tf.one_hot(out, depth=voc_size, on_value=smooth_positives, off_value=smooth_negatives) + + losses = tf.nn.softmax_cross_entropy_with_logits( + labels=onehot_labels, + logits=logits, + name="xentropy") + + # Normalizing constant is the best cross-entropy value with soft targets. + # We subtract it just for readability, makes no difference on learning. + normalizing = -(smooth_positives * tf.log(smooth_positives) + + tf.to_float(voc_size - 1) * smooth_negatives * tf.log(smooth_negatives + 1e-20)) + losses -= normalizing + else: + losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits, labels=out) + + losses *= tf.sequence_mask(out_len, maxlen=tf.shape(out)[1], dtype=logits.dtype) + + if reduce_rows: + return tf.reduce_sum(losses, axis=1) + else: + return losses + + def rdo_to_logits__predict(self, rdo, prefix): + """ like rdo_to_logits, but used in beam search """ + return self._rdo_to_logits(rdo) + + +LossXentLm = LossXent # alias + + +class FFN: + """ + Feed-forward layer + """ + + def __init__(self, name, + inp_size, hid_size, out_size, + relu_dropout): + assert isinstance(hid_size, int), "List of hidden sizes not is not supported" + self.name = name + self.relu_dropout = relu_dropout + + with tf.variable_scope(name): + self.first_conv = Dense( + 'conv1', + inp_size, hid_size, + activ=tf.nn.relu, + bias_initializer=tf.zeros_initializer()) + + self.second_conv = Dense( + 'conv2', + hid_size, out_size, + activ=lambda x: x, + bias_initializer=tf.zeros_initializer()) + + def __call__(self, inputs, summarize_preactivations=False): + """ + inp: [batch_size * ninp * inp_dim] + --------------------------------- + out: [batch_size * ninp * out_dim] + """ + with tf.variable_scope(self.name): + hidden = self.first_conv(inputs) + if is_dropout_enabled(): + hidden = dropout(hidden, 1.0 - self.relu_dropout) + + outputs = self.second_conv(hidden) + + return outputs + + def relprop(self, R): + R = self.second_conv.relprop(R) + R = self.first_conv.relprop(R) + return R \ No newline at end of file diff --git a/lib/layers/concrete_gate.py b/lib/layers/concrete_gate.py new file mode 100644 index 0000000..b5d5a01 --- /dev/null +++ b/lib/layers/concrete_gate.py @@ -0,0 +1,97 @@ +import tensorflow as tf +from warnings import warn +import lib + + +class ConcreteGate: + """ + A gate made of stretched concrete distribution (using experimental Stretchable Concrete™) + Can be 
applied to sparsify neural network activations or weights. + Example usage: https://gist.github.com/justheuristic/1118a14a798b2b6d47789f7e6f511abd + :param shape: shape of gate variable. can be broadcasted. + e.g. if you want to apply gate to tensor [batch, length, units] over units axis, + your shape should be [1, 1, units] + :param temperature: concrete sigmoid temperature, should be in (0, 1] range + lower values yield better approximation to actual discrete gate but train longer + :param stretch_limits: min and max value of gate before it is clipped to [0, 1] + min value should be negative in order to compute l0 penalty as in https://arxiv.org/pdf/1712.01312.pdf + however, you can also use tf.nn.sigmoid(log_a) as regularizer if min, max = 0, 1 + :param l0_penalty: coefficient on the regularizer that minimizes l0 norm of gated value + :param l2_penalty: coefficient on the regularizer that minimizes l2 norm of gated value + :param eps: a small additive value used to avoid NaNs + :param hard: if True, gates are binarized to {0, 1} but backprop is still performed as if they were concrete + :param local_rep: if True, samples a different gumbel noise tensor for each sample in batch, + by default, noise is sampled using shape param as size. + + """ + + def __init__(self, name, shape, temperature=0.33, stretch_limits=(-0.1, 1.1), + l0_penalty=0.0, l2_penalty=0.0, eps=1e-6, hard=False, local_rep=False): + self.name = name + self.temperature, self.stretch_limits, self.eps = temperature, stretch_limits, eps + self.l0_penalty, self.l2_penalty = l0_penalty, l2_penalty + self.hard, self.local_rep = hard, local_rep + with tf.variable_scope(name): + self.log_a = lib.ops.get_model_variable("log_a", shape=shape) + + def __call__(self, values, is_train=None, axis=None, reg_collection=tf.GraphKeys.REGULARIZATION_LOSSES): + """ applies gate to values, if is_train, adds regularizer to reg_collection """ + is_train = lib.layers.basic.is_dropout_enabled() if is_train is None else is_train + gates = self.get_gates(is_train, shape=tf.shape(values) if self.local_rep else None) + + if self.l0_penalty != 0 or self.l2_penalty != 0: + reg = self.get_penalty(values=values, axis=axis) + tf.add_to_collection(reg_collection, tf.identity(reg, name='concrete_gate_reg')) + return values * gates + + def get_gates(self, is_train, shape=None): + """ samples gate activations in [0, 1] interval """ + low, high = self.stretch_limits + with tf.name_scope(self.name): + if is_train: + shape = tf.shape(self.log_a) if shape is None else shape + noise = tf.random_uniform(shape, self.eps, 1.0 - self.eps) + concrete = tf.nn.sigmoid((tf.log(noise) - tf.log(1 - noise) + self.log_a) / self.temperature) + else: + concrete = tf.nn.sigmoid(self.log_a) + + stretched_concrete = concrete * (high - low) + low + clipped_concrete = tf.clip_by_value(stretched_concrete, 0, 1) + if self.hard: + hard_concrete = tf.to_float(tf.greater(clipped_concrete, 0.5)) + clipped_concrete = clipped_concrete + tf.stop_gradient(hard_concrete - clipped_concrete) + return clipped_concrete + + def get_penalty(self, values=None, axis=None): + """ + Computes l0 and l2 penalties. 
For l2 penalty one must also provide the sparsified values + (usually activations or weights) before they are multiplied by the gate + Returns the regularizer value that should to be MINIMIZED (negative logprior) + """ + if self.l0_penalty == self.l2_penalty == 0: + warn("get_penalty() is called with both penalties set to 0") + low, high = self.stretch_limits + assert low < 0.0, "p_gate_closed can be computed only if lower stretch limit is negative" + with tf.name_scope(self.name): + # compute p(gate_is_closed) = cdf(stretched_sigmoid < 0) + p_open = tf.nn.sigmoid(self.log_a - self.temperature * tf.log(-low / high)) + p_open = tf.clip_by_value(p_open, self.eps, 1.0 - self.eps) + + total_reg = 0.0 + if self.l0_penalty != 0: + if values != None and self.local_rep: + p_open += tf.zeros_like(values) # broadcast shape to account for values + l0_reg = self.l0_penalty * tf.reduce_sum(p_open, axis=axis) + total_reg += tf.reduce_mean(l0_reg) + + if self.l2_penalty != 0: + assert values is not None + l2_reg = 0.5 * self.l2_penalty * p_open * tf.reduce_sum(values ** 2, axis=axis) + total_reg += tf.reduce_mean(l2_reg) + + return total_reg + + def get_sparsity_rate(self, is_train=False): + """ Computes the fraction of gates which are now active (non-zero) """ + is_nonzero = tf.not_equal(self.get_gates(is_train), 0.0) + return tf.reduce_mean(tf.to_float(is_nonzero)) \ No newline at end of file diff --git a/lib/layers/lrp.py b/lib/layers/lrp.py new file mode 100644 index 0000000..566da1c --- /dev/null +++ b/lib/layers/lrp.py @@ -0,0 +1,58 @@ +import tensorflow as tf +from ..ops import record_activations as rec + + +class LRP: + """ Helper class for layerwise relevance propagation """ + alpha = 1.0 + beta = 0.0 + eps = 1e-7 + crop_function = abs + + @classmethod + def relprop(cls, f_positive, f_negative, output_relevance, *inps): + """ + computes input relevance given output_relevance using z+ rule + works for linear layers, convolutions, poolings, etc. + notation from DOI:10.1371/journal.pone.0130140, Eq 60 + :param f_positive: forward function with positive weights (if any) and no nonlinearities + :param f_negative: forward function with negative weights and no nonlinearities + if there's no weights, set f_negative to None. Only used for alpha-beta LRP + :param output_relevance: relevance w.r.t. 
layer output + :param inps: a list of layer inputs + """ + assert len(inps) > 0, "please provide at least one input" + with rec.do_not_record(): + alpha, beta, eps = cls.alpha, cls.beta, cls.eps + inps = [inp + eps for inp in inps] + + # ouput relevance: [*dims, out_size] + z_positive = f_positive(*inps) + s_positive = cls.alpha * output_relevance / z_positive # [*dims, out_size] + positive_relevances = tf.gradients(z_positive, inps, grad_ys=s_positive) + # ^-- list of [*dims, inp_size] + + if cls.beta != 0 and f_negative is not None: + z_negative = f_negative(*inps) + s_negative = -cls.beta * output_relevance / z_negative # [*dims, out_size] + negative_relevances = tf.gradients(z_negative, inps, grad_ys=s_negative) + # ^-- list of [*dims, inp_size] + else: + negative_relevances = [0.0] * len(inps) + + inp_relevances = [ + inp * (rel_pos + rel_neg) + for inp, rel_pos, rel_neg in zip(inps, positive_relevances, negative_relevances) + ] + + return cls.rescale(output_relevance, *inp_relevances) + + + @classmethod + def rescale(cls, reference, *inputs, axis=None): + inputs = [cls.crop_function(inp) for inp in inputs] + ref_scale = tf.reduce_sum(reference, axis=axis, keep_dims=axis is not None) + inp_scales = [tf.reduce_sum(inp, axis=axis, keep_dims=axis is not None) for inp in inputs] + total_inp_scale = sum(inp_scales) + cls.eps + inputs = [inp * (ref_scale / total_inp_scale) for inp in inputs] + return inputs[0] if len(inputs) == 1 else inputs diff --git a/lib/meta.py b/lib/meta.py new file mode 100644 index 0000000..5d57e60 --- /dev/null +++ b/lib/meta.py @@ -0,0 +1,46 @@ +import tensorflow as tf +import sys + +from collections import namedtuple +from contextlib import contextmanager + +## Collection keys + +# Collection of tensors representing layer activations in network +ACTIVATIONS = tf.GraphKeys.ACTIVATIONS + +# Collection of Attention objects +ATTENTIONS = "attentions" +SUMMARIES_ZOO = "summaries_zoo" +PARAMS_SUMMARIES = "params_summaries" + + +Attention = namedtuple('Attention', ['name', 'weights', 'logits', 'mask']) + + +def get_indexed_collection(coll, scope, root_scope=None): + if root_scope is None: + root_scope = tf.contrib.framework.get_name_scope() + + full_scope = root_scope + '/' + scope + + def normalize_name(n): + n = n[len(full_scope)+1:] + if n.endswith(':0'): + n = n[:-2] + if n.endswith('/'): + n = n[:-1] + return n + + return dict((normalize_name(t.name), t) for t in tf.get_collection(coll, full_scope + '/.*')) + + +@contextmanager +def lock_collections(collections): + collection_states = [tf.get_collection(coll) for coll in collections] + yield + for coll, old_coll_state in zip(collections, collection_states): + new_coll_state = tf.get_collection_ref(coll) + if old_coll_state != new_coll_state: + print("! Changes in collection %s will be ignored!" % coll, flush=True, file=sys.stderr) + new_coll_state[:] = old_coll_state # Replace collection state with old one diff --git a/lib/ops/__init__.py b/lib/ops/__init__.py new file mode 100644 index 0000000..ba74a64 --- /dev/null +++ b/lib/ops/__init__.py @@ -0,0 +1,2 @@ +from . 
import basic, mpi, sliced_argmax, devices, record_activations +from .basic import * \ No newline at end of file diff --git a/lib/ops/basic.py b/lib/ops/basic.py new file mode 100644 index 0000000..63b5e84 --- /dev/null +++ b/lib/ops/basic.py @@ -0,0 +1,164 @@ +# Basic TF operations +import threading +from contextlib import contextmanager + +import tensorflow as tf +import hashlib +from copy import copy + + +def get_seed_from_name(name): + full_name = '/'.join([tf.get_variable_scope().name, name]) + return int(hashlib.md5(full_name.encode()).hexdigest()[:8], 16) + + +def default_initializer(seed, dtype): + scope_initializer = tf.get_variable_scope().initializer + if scope_initializer is not None: + return scope_initializer + try: + return tf.initializers.glorot_uniform(seed, dtype) + except: + return tf.glorot_uniform_initializer(seed, dtype) + + +def get_model_variable(name, **kwargs): + """ Get variable from MODEL_VARIABLES collection with initializer seeded from its name, not id """ + + if kwargs.get('initializer') is None: + kwargs['initializer'] = default_initializer(seed=get_seed_from_name(name), dtype=kwargs.get('dtype', tf.float32)) + elif hasattr(kwargs['initializer'], 'seed') and kwargs['initializer'].seed is None: + kwargs['initializer'] = copy(kwargs['initializer']) + kwargs['initializer'].seed = get_seed_from_name(name) + + return tf.contrib.framework.model_variable(name, **kwargs) + + +def dot(x, y): + """ + x: [..., a] + y: [a, ...] + ------------- + Ret: [..., ...] + """ + x_ndim = x.get_shape().ndims + y_ndim = y.get_shape().ndims + etc_x = tf.slice(tf.shape(x), [0], [x_ndim-1]) + etc_y = tf.slice(tf.shape(y), [1], [-1]) + a = tf.shape(y)[0] + + # Reshape forth. + if x_ndim != 2: + x = tf.reshape(x, [-1, a]) + if y_ndim != 2: + y = tf.reshape(y, [a, -1]) + + # Compute + ret = tf.matmul(x, y) + + # Reshape back. 
+ if x_ndim != 2 or y_ndim != 2: + ret = tf.reshape(ret, tf.concat([etc_x, etc_y], 0)) + + return ret + + +def sequence_mask(lengths, dtype, maxlen=None): + """ + WARNING: THis func produces Time-major tensor + lengths: [batch_size] + ------- + out: [maxlen, batch_size] + """ + lengths = tf.cast(lengths, tf.int32) + if maxlen is not None: + maxlen = tf.cast(maxlen, tf.int32) + return tf.transpose(tf.sequence_mask(lengths, dtype=dtype, maxlen=maxlen)) + + +def infer_length(seq, eos=1, time_major=False): + """ + compute length given output indices and eos code + :param seq: tf matrix [time,batch] if time_major else [batch,time] + :param eos: integer index of end-of-sentence token + :returns: lengths, int32 vector of [batch_size] + """ + axis = 0 if time_major else 1 + is_eos = tf.cast(tf.equal(seq, eos), 'int32') + count_eos = tf.cumsum(is_eos, axis=axis, exclusive=True) + lengths = tf.reduce_sum(tf.cast(tf.equal(count_eos, 0), 'int32'), axis=axis) + return lengths + + +def infer_mask(seq, eos=1, time_major=False, dtype=tf.bool): + """ + compute mask + :param seq: tf matrix [time,batch] if time_major else [batch,time] + :param eos: integer index of end-of-sentence token + :returns: mask, matrix of same shape as seq and of given dtype (bool by default) + """ + lengths = infer_length(seq, eos=eos, time_major=time_major) + mask_fn = sequence_mask if time_major else tf.sequence_mask + maxlen = tf.shape(seq)[0 if time_major else 1] + return mask_fn(lengths, dtype=dtype, maxlen=maxlen) + + +def dropout(x, keep_prob, *args, **kwargs): + """This is a hack to save memory if there is no dropout""" + if keep_prob >= 1: + return x + return tf.nn.dropout(x, keep_prob, *args, **kwargs) + + +def group(*ops): + """ + Like tf.group(), but returns tf.constant(0) instead of tf.no_op(), + which makes it suitable for use in tf.cond(). + """ + with tf.control_dependencies(ops): + return tf.constant(0) + + +def select_values_over_last_axis(values, indices): + """ + Auxiliary function to select logits corresponding to chosen tokens. 
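+    For example (illustrative shapes): for values of shape [2, 3, 4] and indices of
+    shape [2, 3], the result r has shape [2, 3] with r[b, t] == values[b, t, indices[b, t]].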
+ :param values: logits for all actions: float32[batch,tick,action] + :param indices: action ids int32[batch,tick] + :returns: values selected for the given actions: float[batch,tick] + """ + assert values.shape.ndims == 3 and indices.shape.ndims == 2 + batch_size, seq_len = tf.shape(indices)[0], tf.shape(indices)[1] + + time_i, batch_i = tf.meshgrid(tf.range(0, seq_len, dtype=indices.dtype), + tf.range(0, batch_size, dtype=indices.dtype)) + + indices_nd = tf.stack([batch_i, time_i, indices], axis=-1) + + return tf.gather_nd(values, indices_nd) + + +def nop(x): + return x + + +def kl_divergence_with_logits(p_logits, q_logits): + return tf.reduce_sum(tf.nn.softmax(p_logits) * (tf.nn.log_softmax(p_logits) - tf.nn.log_softmax(q_logits)), axis=-1) + + +_tls = threading.local() + + +def is_dropout_enabled(): + if not hasattr(_tls, 'dropout_enabled'): + _tls.dropout_enabled = True + return _tls.dropout_enabled + + +@contextmanager +def dropout_scope(enabled): + was_enabled = is_dropout_enabled() + _tls.dropout_enabled = enabled + try: + yield + finally: + _tls.dropout_enabled = was_enabled \ No newline at end of file diff --git a/lib/ops/devices.py b/lib/ops/devices.py new file mode 100644 index 0000000..08936bf --- /dev/null +++ b/lib/ops/devices.py @@ -0,0 +1,15 @@ +import tensorflow as tf + + +def list_devices(session=None): + if session is None: + session = session or tf.get_default_session() + return session.list_devices() + + +def list_gpu_devices(session=None): + return [x for x in list_devices(session) if x.device_type == 'GPU'] + + +def have_gpu(): + return len(list_gpu_devices()) != 0 diff --git a/lib/ops/mpi/__init__.py b/lib/ops/mpi/__init__.py new file mode 100644 index 0000000..2ed8327 --- /dev/null +++ b/lib/ops/mpi/__init__.py @@ -0,0 +1,372 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# pylint: disable=g-short-docstring-punctuation +"""## Communicating Between Processes with MPI + +TensorFlow natively provides inter-device communication through send and +receive ops and inter-node communication through Distributed TensorFlow, based +on the same send and receive abstractions. On HPC clusters where Infiniband or +other high-speed node interconnects are available, these can end up being +insufficient for synchronous data-parallel training (without asynchronous +gradient descent). This module implements a variety of MPI ops which can take +advantage of hardware-specific MPI libraries for efficient communication. + +In order to use this module, TensorFlow must be built with an MPI library, +which can be provided to the `./configure` script at build time. As a user of +TensorFlow, you will need to build TensorFlow yourself to select the MPI +library to use; to do so, follow the [instructions for building TensorFlow from +source](https://www.tensorflow.org/get_started/os_setup#installing_from_sources). 
+ +### Utility Ops + +In addition to reductions and gathers, this module provides utility operations +for detecting the running MPI configuration. + +Example: + +```python +from tensorflow.contrib import mpi + +# Use `mpi.Session` instead of `tf.Session` +with mpi.Session() as session: + rank = session.run(mpi.rank()) + print("My MPI Rank:", rank) + + if rank == 0: + print("MPI Size:", session.run(mpi.size())) +``` + +@@rank +@@size + +### Ring Allreduce and Allgather + +When summing or averaging tensors across many processes, communication can +easily become a bottleneck. A naive implementation will send all the tensor +values to the same process, perform the reduction, and then broadcast the +values back to all other processes, effectively creating a synchronous +parameter server in one process. However, the process responsible for +performing the reduction will have to receive and send a massive amount of data +which scales with the number of processes *and* the number of parameters in the +model. + +Instead of centralizing the reduction and having one primary reducer, we can +implement a distributed allreduce or allgather. A bandwidth-optimal allreduce +will end up sending 2(N - 1) values for every value in the input tensor, +and can be implemented with a ring allreduce [1]. (Intuitively, a linear reduce +requires at least (N - 1) sends between the different nodes, and a broadcast of +the result also requires (N - 1) sends, for a total of 2 (N - 1); these two +steps cannot be combined in a clever way to reduce the number of required +sends.) This module implements bandwidth-optimal ring allreduce and ring +allgather operations using MPI; by choosing a hardware-appropriate MPI +implementation (such as OpenMPI with CUDA-IPC support), you can train large +models with synchronous gradient descent with minimal communication overhead. + +In addition to the `allreduce` and `allgather` functions, a convenience +`DistributedOptimizer` wrapper is provided to simplify using these functions +for reducing model gradients. + +Example: + +```python +import tensorflow as tf +from tensorflow.contrib import mpi + +# Construct a simple linear regression model to optimize +W = tf.get_variable("W", shape=[20, 1], dtype=tf.float32) +B = tf.get_variable("B", shape=[1, 1], dtype=tf.float32) +inputs = tf.placeholder("Inputs", shape=[None, 20]) +outputs = tf.placeholder("Outputs", shape=[None, 1]) +loss = tf.nn.l2_loss(tf.matmul(inputs, W) + B - outputs) + +# Training using MPI allreduce with DistributedOptimizer +optimizer = mpi.DistributedOptimizer(tf.train.AdamOptimizer()) +train = optimizer.minimize(loss) + +# Average loss over all ranks, for printing. +# Do not pass this to an optimizer! +avg_loss = mpi.allreduce(loss) + +# On different ranks, feed different input data. +with mpi.Session() as session: + rank = session.run(mpi.rank()) + batch_inputs, batch_outputs = construct_batch_for_rank(rank) + feed_dict = {inputs: batch_inputs, outputs: batch_outputs} + _, l = session.run([train, avg_loss], feed_dict=feed_dict) + print("Average Loss:", l) +``` + +[1] Patarasuk, Pitch and Yuan, Xin. "Bandwidth Optimal All-reduce Algorithms +for Clusters of Workstations". 
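+
+(Concrete illustration of the 2(N - 1) figure above: with N = 4 processes and a tensor of
+M values, the centralized scheme makes one process receive (N - 1) * M = 3M values and then
+broadcast another 3M, whereas the ring allreduce spreads the same 2(N - 1) sends per value
+around the ring, so each process transfers only about 2(N - 1) / N * M = 1.5M values, a
+quantity that stays below 2M however many processes participate.)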
+ +@@Session +@@DistributedOptimizer +@@allreduce +@@allgather +""" + +import tensorflow as tf + +import threading +import importlib +import os + + +_provider = None + + +def get_provider(): + global _provider + if _provider is None: + set_provider('horovod' if os.getenv('OMPI_COMM_WORLD_SIZE') is not None else 'dummy') + return _provider + + +def set_provider(provider, force=False): + global _provider + if _provider is not None and not force: + raise RuntimeError("%r already set as provider" % _provider) + _provider = importlib.import_module('lib.ops.mpi.%s_provider' % provider) + + +def is_master(): + """ + Helper function to identify master + """ + mpi_rank = os.getenv('OMPI_COMM_WORLD_RANK') + return mpi_rank is None or mpi_rank == '0' + + +def is_distributed(): + """ + Helper function to identify if we are in distributed mode + """ + mpi_size = os.getenv('OMPI_COMM_WORLD_SIZE') + return mpi_size is not None and int(mpi_size) > 1 + + +class Session(tf.Session): + """A class for running TensorFlow operations, with copies of the same graph + running distributed across different MPI nodes. + + The primary difference between `tf.Session` and `tf.contrib.mpi.Session` is + that the MPI `Session` ensures that the `Session` options are correct for + use with `tf.contrib.mpi`, and initializes MPI immediately upon the start + of the session. + """ + + def __init__(self, gpu_group=None, gpu_group_size=1, target='', graph=None, config=None): + """Creates a new TensorFlow MPI session. + + Unlike a normal `tf.Session`, an MPI Session may only use a single GPU, + which must be specified in advance before the session is initialized. + In addition, it only uses a single graph evaluation thread, and + initializes MPI immediately upon starting. + + If no `graph` argument is specified when constructing the session, + the default graph will be launched in the session. If you are + using more than one graph (created with `tf.Graph()` in the same + process, you will have to use different sessions for each graph, + but each graph can be used in multiple sessions. In this case, it + is often clearer to pass the graph to be launched explicitly to + the session constructor. + + Args: + gpu: (Optional.) The GPU index to use, or None for CPU only MPI. + graph: (Optional.) The `Graph` to be launched (described above). + config: (Optional.) A `ConfigProto` protocol buffer with configuration + options for the session. + """ + if config is None: + config = tf.ConfigProto() + + if gpu_group is not None: + config.gpu_options.visible_device_list = ','.join(str(gpu_group*gpu_group_size + d) for d in range(gpu_group_size)) + + super(Session, self).__init__(target, graph, config=config) + + # Initialize MPI on the relevant device. + with self.as_default(): + self.run(init()) + + # Setup finalize status and lock to prevent double finalize call + self._mpi_finalized = False + self._mpi_finalize_lock = threading.Lock() + + def close(self): + with self._mpi_finalize_lock: + if not self._mpi_finalized: + # Finalize MPI on the relevant device + self.run(finalize()) + self._mpi_finalized = True + + super(Session, self).close() + + +############################################################################### +# +# TensorFlow MPI operations +# +############################################################################### + + +def size(name=None): + """An op which returns the number of MPI processes. + + This is equivalent to running `MPI_Comm_size(MPI_COMM_WORLD, ...)` to get the + size of the global communicator. 
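+  For example, when launched with `mpirun -np 4` (so OMPI_COMM_WORLD_SIZE=4),
+  `session.run(size())` evaluates to 4; without MPI the dummy provider returns 1.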
+ + Returns: + An integer scalar containing the number of MPI processes. + """ + return get_provider().size(name) + + +def rank(name=None): + """An op which returns the MPI rank of the calling process. + + This is equivalent to running `MPI_Comm_rank(MPI_COMM_WORLD, ...)` to get the + rank of the current process in the global communicator. + + Returns: + An integer scalar with the MPI rank of the calling process. + """ + return get_provider().rank(name) + + +def local_rank(name=None): + """An op which returns the local MPI rank of the calling process, within the + node that it is running on. For example, if there are seven processes running + on a node, their local ranks will be zero through six, inclusive. + + This is equivalent to running `MPI_Comm_rank(...)` on a new communicator + which only includes processes on the same node. + + Returns: + An integer scalar with the local MPI rank of the calling process. + """ + return get_provider().local_rank(name=name) + + +def init(name=None): + """An op which initializes MPI on the device on which it is run. + + All future MPI ops must be run on the same device that the `init` op was run + on. + """ + return get_provider().init(name) + + +def finalize(name=None): + """An op which finalizes MPI on the device on which it is run. + + No future MPI ops must be run on the same device that the `finalize` op was run + on. + """ + return get_provider().finalize(name=name) + + +def allreduce(tensor, average=True, name=None): + """Perform an MPI allreduce on a tf.Tensor or tf.IndexedSlices. + + Arguments: + tensor: tf.Tensor, tf.Variable, or tf.IndexedSlices to reduce. + The shape of the input must be identical across all ranks. + average: If True, computes the average over all ranks. + Otherwise, computes the sum over all ranks. + + This function performs a bandwidth-optimal ring allreduce on the input + tensor. If the input is an tf.IndexedSlices, the function instead does an + allgather on the values and the indices, effectively doing an allreduce on + the represented tensor. + """ + return get_provider().allreduce(tensor, average, name) + + +def allgather(tensor, name=None): + """An op which concatenates the input tensor with the same input tensor on + all other MPI processes. + + The concatenation is done on the first dimension, so the input tensors on the + different processes must have the same rank and shape, except for the first + dimension, which is allowed to be different. + + Returns: + A tensor of the same type as `tensor`, concatenated on dimension zero + across all processes. The shape is identical to the input shape, except for + the first dimension, which may be greater and is the sum of all first + dimensions of the tensors in different MPI processes. + """ + return get_provider().allgather(tensor, name) + + +def broadcast(tensor, name=None): + """Broadcasts value of given tensor from coordinator node to all the others. + + Returns: + Result of broadcast, same shape as `tensor` + """ + return get_provider().broadcast(tensor, name=name) + + +def broadcast_var(ref, allow_uninitialized=False, name=None): + """Broadcasts value of given variable from coordinator node to all the others. 
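+  (Typically used to synchronize initial parameter values from rank 0 to all other
+  workers before training, so that every replica starts from the same model.)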
+ + Returns: + A mutable `tensor`, same as `ref` + """ + return get_provider().broadcast_var(ref, allow_uninitialized=allow_uninitialized, name=name) + + +############################################################################### +# +# Specific MPI operations on Python objects +# +############################################################################### + + +def broadcast_obj(obj, name=None): + """ + Returns: + Broadcasted object, same as input + """ + return get_provider().broadcast_obj(obj, name) + + +def gather_obj(obj, name=None): + """Gathers given Python object from all workers on the coordinator + + Returns: + Gathered object on the coordinator (on all other workers None) + """ + return get_provider().gather_obj(obj, name) + + +def scatter_obj(obj_array, name=None): + """Scatters given array of Python objects to all workers from the coordinator + + Returns: + Object on each worker + """ + return get_provider().scatter_obj(obj_array, name) + + +def allgather_obj(obj, name=None): + """Performs ALLGATHER on the given Python object + + Returns: + Gathered object on all workers + """ + return get_provider().allgather_obj(obj, name) diff --git a/lib/ops/mpi/dummy_provider.py b/lib/ops/mpi/dummy_provider.py new file mode 100644 index 0000000..5faaf77 --- /dev/null +++ b/lib/ops/mpi/dummy_provider.py @@ -0,0 +1,66 @@ +import tensorflow as tf + +############################################################################### +# +# TensorFlow MPI operations +# +############################################################################### + + +def size(name=None): + return tf.constant(1, name=name) + + +def rank(name=None): + return tf.constant(0, name=name) + + +def local_rank(name=None): + return tf.constant(0, name=name) + + +def init(name=None): + return tf.no_op(name=name) + + +def finalize(name=None): + return tf.no_op(name=name) + + +def allreduce(tensor, average=True, name=None): + return tf.stop_gradient(tensor, name=name) # Stop gradient propagation, as in distributed mode + + +def allgather(tensor, name=None): + return tf.stop_gradient(tensor, name=name) + + +def broadcast(tensor, name=None): + return tf.stop_gradient(tensor, name=name) + + +def broadcast_var(ref, allow_uninitialized=False, name=None): + return ref + + +############################################################################### +# +# Specific MPI operations on Python objects +# +############################################################################### + + +def broadcast_obj(obj, name=None): + return obj + + +def gather_obj(obj, name=None): + return [obj] + + +def scatter_obj(obj_array, name=None): + return obj_array[0] + + +def allgather_obj(obj, name=None): + return [obj] diff --git a/lib/ops/mpi/horovod_provider.py b/lib/ops/mpi/horovod_provider.py new file mode 100644 index 0000000..2623bd8 --- /dev/null +++ b/lib/ops/mpi/horovod_provider.py @@ -0,0 +1,152 @@ +import tensorflow as tf +import horovod.tensorflow as hvd +import pickle +import os +import threading + + +############################################################################### +# +# Horovod MPI operations +# +############################################################################### + + +def size(name=None): + return tf.constant(int(os.getenv('OMPI_COMM_WORLD_SIZE', 1)), name=name, dtype=tf.int32) + + +def rank(name=None): + return tf.constant(int(os.getenv('OMPI_COMM_WORLD_RANK', 0)), name=name, dtype=tf.int32) + + +def local_rank(name=None): + return tf.constant(int(os.getenv('OMPI_COMM_WORLD_LOCAL_RANK', 0)), 
name=name, dtype=tf.int32) + + +def init(name=None): + hvd.init() + return tf.no_op(name=name) + + +def finalize(name=None): + return tf.no_op(name=name) + + +def allreduce(tensor, average=True, name=None): + return hvd.allreduce(tensor, average=average) + + +def allgather(tensor, name=None): + return hvd.allgather(tensor) + + +def broadcast(tensor, name=None): + return hvd.broadcast(tensor, root_rank=0) + + +def broadcast_var(ref, allow_uninitialized=False, name=None): + if allow_uninitialized: + raise RuntimeError("allow_uninitialized is not supported in Horovod implementation") + return tf.assign(ref, broadcast(ref)) + + +############################################################################### +# +# Specific MPI operations on Python objects +# +############################################################################### + + +def broadcast_obj(obj, name=None): + if name is None: + name = 'broadcast_obj' + return allgather_obj(obj, name)[0] + + +def gather_obj(obj, name=None): + if name is None: + name = 'gather_obj' + + res = allgather_obj(obj, name) + if int(os.getenv('OMPI_COMM_WORLD_RANK', 0)) == 0: + return res + else: + return None + + +def scatter_obj(inps, name=None): + if name is None: + name = 'scatter_obj' + + if int(os.getenv('OMPI_COMM_WORLD_RANK', 0)) == 0: + assert len(inps) == int(os.getenv('OMPI_COMM_WORLD_SIZE', 1)) + else: + inps = None + + outs = allgather_obj(inps, name) + return outs[0][int(os.getenv('OMPI_COMM_WORLD_RANK', 0))] + + +def allgather_obj(obj, name=None): + if name is None: + name = 'allgather_obj' + + encoded = _encode_obj(obj) + encoded_size = len(encoded) + + graph_ops = _get_graph_ops(name) + + sizes, encoded_res = tf.get_default_session().run([graph_ops.allgather_obj_size_result, graph_ops.allgather_obj_result], feed_dict={ + graph_ops.allgather_obj_size_inp: [encoded_size], + graph_ops.allgather_obj_inp: encoded + }) + + res = [] + pos = 0 + for sz in sizes: + res.append(_decode_obj(encoded_res[pos:pos+sz])) + pos += sz + return res + + +## Implementation details + +class _GraphOps: + def __init__(self, name): + self.name = name + + with tf.name_scope("horovod_python_ops/" + name): + self.allgather_obj_size_inp = tf.placeholder(name="allgather_obj_size", dtype=tf.int32, shape=[None]) + self.allgather_obj_inp = tf.placeholder(name="allgather_obj", dtype=tf.uint8, shape=[None]) + + self.allgather_obj_size_result = hvd.allgather(self.allgather_obj_size_inp) + self.allgather_obj_result = hvd.allgather(self.allgather_obj_inp) + + +_graph_ops_collection = "HOROVOD_GRAPH_OPS" +_graph_ops_lock = threading.Lock() + + +def _encode_obj(obj): + return list(pickle.dumps(obj)) + + +def _decode_obj(data): + return pickle.loads(bytes(data)) + + +def _get_graph_ops(name): + """ + Returns lazy-initialized hash of graph operations required to implement allgather_obj/scatter_obj. 
+ These operations stored in graph collection to avoid binding parallelism to specific graph + """ + + found = tf.get_collection(_graph_ops_collection, name) + if len(found) > 0: + return found[0] + + with _graph_ops_lock: + ops = _GraphOps(name) + tf.add_to_collection(_graph_ops_collection, ops) + return ops diff --git a/lib/ops/record_activations.py b/lib/ops/record_activations.py new file mode 100644 index 0000000..05ca2c5 --- /dev/null +++ b/lib/ops/record_activations.py @@ -0,0 +1,93 @@ +from warnings import warn +from collections import defaultdict +from contextlib import contextmanager +import tensorflow as tf + +# Idea: we need to store layer activations to do things like relevance propagation, +# let's build a single-use collection that one can store layer-wise activations in +# Here's how it should work: +# with record_activations() as saved_activations: +# y = model(x) # saves activations in... saved_activations +# x_rel = model.relprop(y) # uses activations stored on forward pass +# +# print('btw, activation tensors are', activations) +# note: why not just use tf collections? because they are global and you can never be sure +# what's left in there since previous run + +# this will be a dictionary: { layer name -> a dict of saved activations } +RECORDED_ACTIVATIONS = None +WARN_IF_NO_COLLECTION = False + + +@contextmanager +def recording_activations(existing_state_dict=None, subscope_key=None): + """ A special context that allows you to store any forward pass activations """ + assert isinstance(existing_state_dict, (dict, type(None))) + global RECORDED_ACTIVATIONS + prev_collection = RECORDED_ACTIVATIONS + RECORDED_ACTIVATIONS = existing_state_dict or defaultdict(dict) + if subscope_key: + assert is_recorded() and existing_state_dict is None + prev_collection[subscope_key] = RECORDED_ACTIVATIONS + + try: + yield RECORDED_ACTIVATIONS + finally: + RECORDED_ACTIVATIONS = prev_collection + + +@contextmanager +def do_not_record(): + """ Temporarily disables recording activations within context """ + global RECORDED_ACTIVATIONS + prev_collection = RECORDED_ACTIVATIONS + RECORDED_ACTIVATIONS = None + try: + yield + finally: + RECORDED_ACTIVATIONS = prev_collection + + +def is_recorded(): + return RECORDED_ACTIVATIONS is not None + + +def save_activation(key, value, scope=None, overwrite=False): + """ Saves value in current recorded activations (if it exists) under current name scope """ + scope = scope or tf.get_variable_scope().name or tf.contrib.framework.get_name_scope() + if is_recorded(): + if scope in RECORDED_ACTIVATIONS and key in RECORDED_ACTIVATIONS[scope] and not overwrite: + raise ValueError('Recorded activations already contain key "{}" for scope "{}". ' + 'Make sure you run your network only once inside recording_activations context. ' + 'If a layer is called multiple times, make sure each call happens in a separate ' + ' tf.name_scope .'.format(key, scope)) + + RECORDED_ACTIVATIONS[scope][key] = value + elif WARN_IF_NO_COLLECTION: + warn('Tried to save under key "{}" in scope "{}" without recording_activations context. ' + 'As the fox says, the context is important'.format(key, scope)) + + +def save_activations(**kwargs): + """ convenience function to save multiple activations. 
see save_activation """ + scope, overwrite = kwargs.pop('scope', None), kwargs.pop('overwrite', False) + assert isinstance(scope, (str, type(None))) + assert isinstance(overwrite, bool) + for key, value in kwargs.items(): + save_activation(key, value, scope=scope, overwrite=overwrite) + + +def get_activation(key, scope=None): + """ gets one activation from current scope or freaks out if there isn't any """ + scope = scope or tf.get_variable_scope().name or tf.contrib.framework.get_name_scope() + assert is_recorded(), "can't get activations if used outside recording_activations context." + assert scope in RECORDED_ACTIVATIONS, 'no saved activations in scope "{}". Is scope name correct?'.format(scope) + assert key in RECORDED_ACTIVATIONS[scope], 'no saved activation for "{}" in scope "{}". Existing keys: {}'.format( + key, scope, list(RECORDED_ACTIVATIONS[scope].keys()) + ) + return RECORDED_ACTIVATIONS[scope][key] + + +def get_activations(*keys, scope=None): + """ convenience function to get multiple activations from current scope, see get_activation """ + return [get_activation(key, scope=scope) for key in keys] diff --git a/lib/ops/sliced_argmax.py b/lib/ops/sliced_argmax.py new file mode 100644 index 0000000..c0409b6 --- /dev/null +++ b/lib/ops/sliced_argmax.py @@ -0,0 +1,168 @@ +import numpy as np +import tensorflow as tf + + +def hypo_to_batch_index(n_hypos, slices): + """ + Computes index in batch (input sequence index) for each hypothesis given slices. + :param n_hypos: number of hypotheses (tf int scalar) + :param slices: indices of first hypo for each input in batch + It should guaranteed that + - slices[0]==0 (first hypothesis starts at index 0), otherwise output[:slices[0]] will be -1 + - if batch[i] is terminated, then batch[i]==batch[i+1] + """ + is_next_sent_at_t = tf.bincount(slices, minlength=n_hypos, maxlength=n_hypos) + hypo_to_index = tf.cumsum(is_next_sent_at_t) - 1 + return hypo_to_index + + +def sliced_argmax_naive(logits, slices, k): + """ + Computes top-k of values in each slice. + :param values: matrix of shape [m,n] + :param slices: vector of shape [m] containing start indices for each slice. + :param k: take this many elements with largest values from each slice + :returns: batch_scores,batch_indices: + - batch_scores[m,k] - top-beam_size values from logP corresponding to + - batch_indices[m,k] - indices of batch_scores in each respective slice (first value in each slice has index 0!) + + For any slice contains less than k elements, batch_scores would be padded with -inf, batch_indices - with -1 + If values.shape[1] != 1, batch_indices will still be 1-dimensional, satisfying the following property: + - batch_scores,batch_indices = sliced_argmax(values,slices,k) + - start, end = slices[i], slices[i+1] + - tf.equals(batch_scores == tf.reshape(values[start:end,:],[-1])[batch_indices]) #this is True for all indices + + Examples + -------- + >>> logp = tf.constant(np.array([[1, 2, 3, 4, 5, 6], + [6, 5, 4, 3, 2, 1]],'float32').T) + >>> slices = tf.constant([0,2,5]) + >>> best_scores, best_indices = sliced_argmax(logp,slices,tf.constant(4)) + >>> print('scores:\n%s\nindices:\n%s'%(best_scores.eval(), best_indices.eval())) + scores: + [[ 6. 5. 2. 1.] + [ 5. 4. 4. 3.] + [ 6. 1. 
-inf -inf]] + indices: + [[ 1 3 2 0] + [ 4 1 2 3] + [ 0 1 -1 -1]] + """ + + assert logits.shape.ndims == 2, "logits must be [batch*beam, num_tokens]" + assert slices.shape.ndims == 1, "slices must be 1d indices" + n_slices, n_hypos, voc_size = tf.shape(slices)[0], tf.shape(logits)[0], tf.shape(logits)[1] + slices_incl = tf.concat([slices, [n_hypos]], axis=0) + offsets = slices_incl[1:] - slices_incl[:-1] + slice_indices = hypo_to_batch_index(n_hypos, slices) # [n_hypos], index of slice the value belongs to + + # step 1: flatten logits[n_hypos, voc_size] into [n_slices, max_slice_length * voc_size] + # by putting all logits within slice on the same row and padding with -inf + flat_shape = [n_slices, (tf.reduce_max(offsets)) * voc_size] + flat_row_index = tf.reshape(tf.tile(slice_indices[:, None], [1, voc_size]), [-1]) + flat_col_index = tf.range(n_hypos * voc_size) - tf.gather(slices_incl * voc_size, flat_row_index) + flat_index_2d = tf.stack([flat_row_index, flat_col_index], axis=1) + mask = tf.less(tf.range(flat_shape[1]), (offsets * voc_size)[:, None]) + flat_logits = tf.where(mask, + tf.scatter_nd(flat_index_2d, tf.reshape(logits, [-1]), flat_shape), + tf.fill(flat_shape, -float('inf')) + ) # shape: [n_slices, max_slice_length * voc_size] + + flat_indices = tf.where(mask, + tf.scatter_nd(flat_index_2d, flat_col_index, flat_shape), + tf.fill(flat_shape, -1) + ) # shape: [n_slices, max_slice_length * voc_size] + + # step 2: top-k for each slice and gather respectrive indices + sliced_top_k = tf.nn.top_k(flat_logits, k=k) + original_values = sliced_top_k.values + + original_indices_flat = tf.gather_nd(flat_indices, + tf.stack([tf.range(n_slices * k) // k, + tf.reshape(sliced_top_k.indices, [-1])], axis=1)) + original_indices = tf.reshape(original_indices_flat, tf.shape(original_values)) + + # set shapes + out_shape = (logits.shape[0], k if isinstance(k, int) else None) + original_values.set_shape(out_shape) + original_indices.set_shape(out_shape) + return original_values, original_indices + + +def sliced_argmax(logits, slices, k, staged=None): + """ + Computes top-k of values in each slice. + :param values: matrix of shape [m,n] + :param slices: vector of shape [m] containing start indices for each slice. + :param k: take this many elements with largest values from each slice + :param staged: if True, computes sliced argmax in two stages: + (1) select top-k for each row and + (2) global top-k among all rows in slice + if False, runs second stage only + if None (default), defaults to True unless logits.shape[1] / k < 10 + :returns: batch_scores,batch_indices: + - batch_scores[m,k] - top-beam_size values from logP corresponding to + - batch_indices[m,k] - indices of batch_scores in each respective slice (first value in each slice has index 0!) + + For any slice contains less than k elements, batch_scores would be padded with -inf, batch_indices - with -1 + If values.shape[1] != 1, batch_indices will still be 1-dimensional, satisfying the following property: + - batch_scores,batch_indices = sliced_argmax(values,slices,k) + - start, end = slices[i], slices[i+1] + - tf.equals(batch_scores == tf.reshape(values[start:end,:],[-1])[batch_indices]) #this is True for all indices + + Examples + -------- + >>> logp = tf.constant(np.array([[1, 2, 3, 4, 5, 6], + [6, 5, 4, 3, 2, 1]],'float32').T) + >>> slices = tf.constant([0,2,5]) + >>> best_scores, best_indices = sliced_argmax(logp,slices,tf.constant(4)) + >>> print('scores:\n%s\nindices:\n%s'%(best_scores.eval(), best_indices.eval())) + scores: + [[ 6. 
5. 2. 1.] + [ 5. 4. 4. 3.] + [ 6. 1. -inf -inf]] + indices: + [[ 1 3 2 0] + [ 4 1 2 3] + [ 0 1 -1 -1]] + """ + + assert logits.shape.ndims == 2, "logits must be [batch*beam, num_tokens]" + assert slices.shape.ndims == 1, "slices must be 1d indices" + if staged is None: + staged = (logits.shape[1].value is None) or (float(logits.shape[1].value) / k >= 10.0) + + if staged: + # two-step process: (1) select top-k for each row and (2) global top-k among all rows in slice + # this version is slightly slower but a lot more memory-efficient + logits_topk = tf.nn.top_k(logits, k=k) # [n_hypos, k] + best_values, best_indices_in_top = sliced_argmax_naive(logits_topk.values, slices, k=k) + + best_hypo_ix = tf.where(tf.not_equal(best_indices_in_top, -1), + best_indices_in_top // k + slices[:, None], + best_indices_in_top) + + best_token_ix_in_top = tf.where(tf.not_equal(best_indices_in_top, -1), + best_indices_in_top % k, + best_indices_in_top) + + best_token_indices_original = tf.gather_nd( + logits_topk.indices, + tf.maximum(0, tf.reshape(tf.stack([best_hypo_ix, best_token_ix_in_top], axis=-1), [-1, 2])) + ) + best_token_indices_original = tf.where(tf.not_equal(tf.reshape(best_hypo_ix, [-1]), -1), + best_token_indices_original, + tf.fill(tf.shape(best_token_indices_original), -1)) + + best_token_indices_original = tf.reshape(best_token_indices_original, + tf.shape(best_token_ix_in_top)) + best_hypo_ix_within_slice = tf.where( + tf.not_equal(best_indices_in_top, -1), + best_indices_in_top // k, + tf.zeros_like(best_indices_in_top, dtype=best_indices_in_top.dtype)) + # ^-- use 0 cuz best_token_indices_original is already -1 and they are added + + best_indices_original = best_token_indices_original + best_hypo_ix_within_slice * tf.shape(logits)[1] + return best_values, best_indices_original + else: + return sliced_argmax_naive(logits, slices, k) diff --git a/lib/session.py b/lib/session.py new file mode 100644 index 0000000..1dbc1c5 --- /dev/null +++ b/lib/session.py @@ -0,0 +1,181 @@ +import tensorflow as tf +from tensorflow.python import ops +import lib +import sys +import os +import threading +from contextlib import contextmanager +from tensorflow.python.framework import * +from tensorflow.contrib.tfprof import * +from tensorflow.python.client import timeline, session +from collections import namedtuple + +# tfprof-oriented Session object. +# More about tfprof: +# https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/tfprof +# +# !! Attention !! 
For using need to append +# /usr/local/cuda-8.0/extras/CUPTI/lib64 to $LD_LIBRARY_PATH + + +PROFILE_SUPER_VERBOSE = 666 + +_tls = threading.local() +def get_profile_level(): + if not hasattr(_tls, 'profile_level'): + _tls.profile_level = PROFILE_SUPER_VERBOSE # Never profile most of sess.run + return _tls.profile_level + +def set_profile_level(level): + _tls.profile_level = level + +@contextmanager +def profile_scope(level=1): + prev_level = get_profile_level() + _tls.profile_level = level + try: + yield + finally: + _tls.profile_level = prev_level + + +MemTimelineRecord = namedtuple('MemTimelineRecord', ['ts', 'node_name', 'bytes_in_use', 'live_bytes']) + + +class SessionWrapper(session.SessionInterface): + + def __init__(self, session): + self._sess = session + + @property + def graph(self): + return self._sess.graph + + @property + def sess_str(self): + return self._sess.sess_str + + def run(self, *a, **kwa): + return self._sess.run(*a, **kwa) + + def partial_run_setup(self, *a, **kwa): + raise RuntimeError("Not supported in session wrapper") + + def partial_run(self, *a, **kwa): + raise RuntimeError("Not supported in session wrapper") + + def make_callable(self, *a, **kwa): + raise RuntimeError("Not supported in session wrapper") + + def as_default(self): + return ops.default_session(self) + + def __getattr__(self, attr): + return getattr(self._sess, attr) + + def __enter__(self): + if self._default_session_context_manager is None: + self._default_session_context_manager = self.as_default() + return self._default_session_context_manager.__enter__() + + def __exit__(self, *exc): + self._default_session_context_manager.__exit__(*exc) + + def __del__(self): + self._sess.__del__() + + +class ProfilableSessionWrapper(SessionWrapper): + def __init__(self, session, log_dir, skip_first_nruns=0, profile_level=0): + super(ProfilableSessionWrapper, self).__init__(session) + + self.run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + self.run_metadata = tf.RunMetadata() + self.run_counter = 0 + self.nruns_threshold = skip_first_nruns + self.profile_level = profile_level + + self.log_dir = log_dir + os.makedirs(log_dir, exist_ok=True) + + self.op_log = None + tf.profiler.write_op_log( + tf.get_default_graph(), + log_dir=log_dir, + op_log=self.op_log, + run_meta=self.run_metadata + ) + + def _write_log(self): + print("* --------------------------------------", file=sys.stderr) + print("* RUN: %d" % self.run_counter, file=sys.stderr) + + # 1. Fetch memory usage and timing stat + time_stat_options = model_analyzer.PRINT_ALL_TIMING_MEMORY + time_stat_options['output'] = 'file:outfile=%s/time_stat.run_%d.txt' % (self.log_dir, self.run_counter) + time_stat_options['select'] = ['device', 'micros', 'bytes'] + time_stat_options['order_by'] = 'micros' + tf.profiler.profile( + tf.get_default_graph(), + run_meta=self.run_metadata, + op_log=self.op_log, + options=time_stat_options + ) + + # 2. Create timeline.json file. It can be load in chrome://tracing + time_data = timeline.Timeline(self.run_metadata.step_stats) + trace = time_data.generate_chrome_trace_format(show_memory=True) + timeline_fname = '%s/timeline.run_%d.json' % (self.log_dir, self.run_counter) + with open(timeline_fname, 'w') as f: + f.write(trace) + + # 3. Get peak memory + mem_timelines = self._build_memory_timelines() + peak_memory = self._compute_peak_memory(mem_timelines) + print("Peak memory: %s" % str(peak_memory), file=sys.stderr) + + # 4. 
Print memory timelines + for allocator, tl in mem_timelines.items(): + memory_fname = '%s/memory.%s.run_%d.txt' % (self.log_dir, allocator, self.run_counter) + with open(memory_fname, 'w') as f: + print("ts,node_name,bytes_in_use,live_bytes", file=f) + for r in tl: + print("%d,%s,%d,%d" % (r.ts, r.node_name, r.bytes_in_use, r.live_bytes), file=f) + + def run(self, fetches, feed_dict=None, options=None, run_metadata=None): + do_profile = self.run_counter >= self.nruns_threshold and self.profile_level >= get_profile_level() + result = super(ProfilableSessionWrapper, self).run( + fetches, feed_dict, + options=self.run_options if do_profile else None, + run_metadata=self.run_metadata if do_profile else None + ) + # For earch invocation of `run()` or `eval()` methods dump log to new file + if do_profile and lib.ops.mpi.is_master(): + self._write_log() + self.run_counter += 1 + return result + + def _compute_peak_memory(self, mem_timelines): + res = {} + for k, tl in mem_timelines.items(): + res[k] = max([r.bytes_in_use for r in tl]) + return res + + def _build_memory_timelines(self): + timelines = {} + + for dev in self.run_metadata.step_stats.dev_stats: + for node in dev.node_stats: + ts = node.all_start_micros + for mem in node.memory: + if mem.allocator_name not in timelines: + timelines[mem.allocator_name] = [] + timelines[mem.allocator_name].append(MemTimelineRecord(ts, node.node_name, mem.allocator_bytes_in_use, mem.live_bytes)) + + for tl in timelines.values(): + tl.sort() + + return timelines + + def _simplify_device_name(self, device_name): + return '/' + device_name.split('device:')[1] diff --git a/lib/task/__init__.py b/lib/task/__init__.py new file mode 100644 index 0000000..a6c3bca --- /dev/null +++ b/lib/task/__init__.py @@ -0,0 +1 @@ +from . import seq2seq diff --git a/lib/task/seq2seq/__init__.py b/lib/task/seq2seq/__init__.py new file mode 100644 index 0000000..6590432 --- /dev/null +++ b/lib/task/seq2seq/__init__.py @@ -0,0 +1 @@ +from . 
import inference, problems, models, data, bleu, summary, tickers, voc \ No newline at end of file diff --git a/lib/task/seq2seq/bleu.py b/lib/task/seq2seq/bleu.py new file mode 100644 index 0000000..c91af9b --- /dev/null +++ b/lib/task/seq2seq/bleu.py @@ -0,0 +1,266 @@ +#!/usr/bin/env python3 +# coding: utf-8 + +import argparse +from collections import Counter, namedtuple +import math +import os.path +import sys +import numpy as np + +sys.path += [os.path.dirname(sys.argv[0])] + +from .strutils import tokenize, al_num, all_chars__punct_tokens, all_chars__punct_tokens__foldcase +from .strutils import al_num__foldcase, chinese_tok, split_by_char_tok, equal_to_framework + +SEED = 51 +BleuResult = namedtuple('BleuResult', 'BLEU brevity_penalty ratio hyp_len ref_len BLEU_for_ngrams') + + +def best_match_length(references, cand, verbose=False): + spl_cand_length = len(cand) + diff = sys.maxsize + for ref in references: + spl_ref_length = len(ref) + if not spl_ref_length: + continue + if spl_ref_length == spl_cand_length: + return spl_ref_length + elif abs(diff) == abs(spl_cand_length - spl_ref_length): + diff = max(diff, spl_cand_length - spl_ref_length) + elif abs(diff) > abs(spl_cand_length - spl_ref_length): + diff = spl_cand_length - spl_ref_length + best_len = max(spl_cand_length - diff, 0) + if verbose and not best_len: + print('WARNING: empty reference: ', repr((references, cand)), file=sys.stderr) + return best_len + + +def brev_penalty(cand_length, best_match_length): + if cand_length > best_match_length: + return 1 + else: + return math.exp(1 - float(best_match_length) / float(cand_length)) + + +def split_into_ngrams(text, n): + if n <= 0: + raise ValueError('n should be a positive number!') + return [tuple(text[i:i+n]) for i in range(len(text) - n + 1)] + + +def compute_length_for_n(text, n_for_ngram): + ''' + # split into words and count: + # count - n + ''' + unigram_count = len(text) + if n_for_ngram > unigram_count: + return 0 + else: + return unigram_count - n_for_ngram + 1 + + +def mod_precision_for_n(refs, cand, n, smoothed=False): + cand_counter = Counter(split_into_ngrams(cand, n)) + ref_counters = [Counter(split_into_ngrams(ref, n)) for ref in refs] + total_sum = 0 + for ngram, count_in_cand in cand_counter.items(): + max_count_in_refs = max(counter[ngram] for counter in ref_counters) + total_sum += min(max_count_in_refs, count_in_cand) + if smoothed and n > 1: + return total_sum + 1, compute_length_for_n(cand, n) + 1 + return total_sum, compute_length_for_n(cand, n) + + +def logarithm(x): + if x == 0: + return -sys.maxsize - 1 + else: + return math.log(x) + + +def print_summary(bleu_vals): + bleu_mean, bleu_std = np.mean(bleu_vals), np.std(bleu_vals) + summary_string = ("Mean BLEU: %.4f; 95%% CI: [%.4f, %.4f]; std=%.4f" % + (bleu_mean, bleu_mean - 1.96 * bleu_std, bleu_mean + 1.96 * bleu_std, bleu_std)) + print(summary_string) + + +class Bleu(object): + def __init__(self, normalize_func=None, smoothed=False, cached=False, language=None, verbose=False): + self.cand_len = 0 + self.best_ref_len = 0 + self.brevity_penalty = 0 + self.mod_precision = [[0, 0], [0, 0], [0, 0], [0, 0]] + self.normalize_func = normalize_func + self.smoothed = smoothed + self.cached = cached + self.language = language + if cached: + self.cand_len_vals = [] + self.best_ref_len_vals = [] + self.mod_precision_vals = [] + self.verbose = verbose + + def process_next(self, cand, refs, **kwargs): + if self.normalize_func is not None: + cand = tokenize(self.normalize_func(cand, self.language)) + refs = 
[tokenize(self.normalize_func(ref, self.language)) for ref in refs] + else: + cand = tokenize(cand) + refs = [tokenize(ref) for ref in refs] + self.last__cand_len = compute_length_for_n(cand, 1) + self.cand_len += self.last__cand_len + self.last__best_ref_len = best_match_length(refs, cand, verbose=self.verbose) + self.best_ref_len += self.last__best_ref_len + self.last_mp = [] + for i in range(4): + self.last_mp.append(mod_precision_for_n(refs, cand, i + 1, smoothed=self.smoothed)) + self.mod_precision[i][0] += self.last_mp[i][0] + self.mod_precision[i][1] += self.last_mp[i][1] + + if self.cached: + self.cand_len_vals.append(self.last__cand_len) + self.best_ref_len_vals.append(self.last__best_ref_len) + self.mod_precision_vals.append(self.last_mp) + + def _compute_bleu(self, cand_len, best_ref_len, mod_precision, sentence_level=False): + brevity_penalty = brev_penalty(cand_len, best_ref_len) + bleu_for_ngram = [0, 0, 0, 0] + for i in range(4): + if mod_precision[i][0] > 0.0 and mod_precision[i][1] > 0.0 : + bleu_for_ngram[i] = round(float(mod_precision[i][0]) / float(mod_precision[i][1]), 4) + else: + bleu_for_ngram[i] = 0.0 + average = 0 + for i in range(4): + if sentence_level: + nonzero = mod_precision[i][1] > 0.0 + else: + nonzero = mod_precision[i][0] > 0.0 and mod_precision[i][1] > 0.0 + if not nonzero: + average += 0.25 * (-sys.maxsize) + if nonzero: + average += 0.25 * logarithm(float(mod_precision[i][0]) / float(mod_precision[i][1])) + total_bleu = round(brevity_penalty * math.exp(average), 4) + return BleuResult(total_bleu, brevity_penalty, round(float(cand_len) / float(best_ref_len), 4), cand_len, best_ref_len, bleu_for_ngram) + + def result_for_last(self): + return self._compute_bleu(self.last__cand_len, self.last__best_ref_len, self.last_mp, True) + + def total(self): + return self._compute_bleu(self.cand_len, self.best_ref_len, self.mod_precision) + + def bootstrap_sample(self, n_times=1000, seed=None): + rng = np.random.RandomState(seed) + if not self.cached: + return None + bleu_vals = [] + for i in range(n_times): + inds = rng.randint(0, len(self.cand_len_vals), len(self.cand_len_vals)) + cand_len = sum([self.cand_len_vals[i] for i in inds]) + best_ref_len = sum([self.best_ref_len_vals[i] for i in inds]) + mod_precision = sum([np.array(self.mod_precision_vals[i]) for i in inds]) + bleu_vals.append(self._compute_bleu(cand_len, best_ref_len, mod_precision)[0]) + return np.array(bleu_vals) + + +if __name__ == '__main__': + t_options = {'simple': al_num__foldcase, + 'case-sensitive': al_num, + 'punctuation': all_chars__punct_tokens__foldcase, + 'c-s-punctuation': all_chars__punct_tokens, + 'ch': chinese_tok, + 'split-by-char': split_by_char_tok, + 'framework': equal_to_framework} + parser = argparse.ArgumentParser() + parser.add_argument('-t', '--tokenization', help='''Tokenization options: + default - split text by spaces + simple = alphanumerics only, + case-sensitive = with small letters, + punctuation = with punctuation marks as separate tokens, + c-s-punctuation = case-sensitive + punctuation, + split-by-char = set space between all characters, + framework = lang-specific replacements + unicode category tokenization''', + choices=t_options.keys()) + parser.add_argument('-c', '--candidate', type=int, nargs='+', help='Hypothesis column number.', required=True) + parser.add_argument('-r', '--reference', help='Reference column number (range or int)') + parser.add_argument('--all', help='Bleu scores for all queries.', action='store_true') + parser.add_argument('-s', 
'--smoothed', action='store_true', default=False, help='Use to compute smoothed BLEU') + parser.add_argument('-l', '--language', help='Dst-side language') + parser.add_argument('--bootstrap-sampling-n', type=int, + help='Run bootstrap sampling n times for BLEU CI estimate.', default=0) + parser.add_argument('--compare', help='Compare Bleu scores for two MT systems', action='store_true') + args = parser.parse_args() + + if args.compare and len(args.candidate) != 2: + raise AssertionError('It should specify 2 hypothesis columns if `--compare` flag used') + if args.compare and args.all: + raise AssertionError('Could not evaluate BLEU score for each query if `--compare` flag used') + + if ':' in args.reference: + r_start, r_end = args.reference.split(':') + reference = slice(int(r_start), int(r_end) if len(r_end) > 0 else None) + else: + reference = int(args.reference) + + bleu_opts = { + 'normalize_func': t_options[args.tokenization] if args.tokenization else None, + 'smoothed': args.smoothed, + 'cached': bool(args.bootstrap_sampling_n) or args.compare, + 'language': args.language, + } + + b_first = Bleu(verbose=True, **bleu_opts) + if args.compare: + b_second = Bleu(verbose=True, **bleu_opts) + + for i, line in enumerate(sys.stdin): # for candidate and set of references in corpus compute process_next + line = line.rstrip('\n') + if not line: + continue + text_data = line.rstrip().split('\t') + refs = [text_data[reference]] if isinstance(reference, int) else text_data[reference] + refs = [ref for ref in refs if ref] + if not refs: + print('Error: no data found in {} column, line {}'.format(args.reference, i + 1), file=sys.stderr) + continue + if len(text_data) < 2: + text_data += [''] + cand_first = text_data[args.candidate[0]] + b_first.process_next(cand_first, refs) + if args.compare: + cand_second = text_data[args.candidate[1]] + b_second.process_next(cand_second, refs) + + #if not cand: + # print >> sys.stderr, 'Error: no data found in %d column, line %i' % (args.r, i + 1) + # sys.exit(1) + if args.all: + print(i, b_first.result_for_last()[0]) + + if not args.compare: + print(b_first.total()) + if args.bootstrap_sampling_n: + bleu_vals = b_first.bootstrap_sample(args.bootstrap_sampling_n, seed=SEED) + print_summary(bleu_vals) + else: + sampling_n = args.bootstrap_sampling_n if args.bootstrap_sampling_n > 0 else 1000 + + print("---\nFirst system stats:" ) + print(b_first.total()) + bleu_vals_first = b_first.bootstrap_sample(sampling_n, seed=SEED) + print_summary(bleu_vals_first) + + print("---\nSecond system stats:" ) + print(b_second.total()) + bleu_vals_second = b_second.bootstrap_sample(sampling_n, seed=SEED) + print_summary(bleu_vals_second) + + delta = bleu_vals_first - bleu_vals_second + bootstrap_p_value = np.mean(delta > 0) + print("---\nSystem %d is better. 
Significance test results:" % (1 if bootstrap_p_value > 0.5 else 2)) + print("Paired boostrap p-value = %.3f" % min(bootstrap_p_value, 1 - bootstrap_p_value)) + + diff --git a/lib/task/seq2seq/data.py b/lib/task/seq2seq/data.py new file mode 100644 index 0000000..c22e73a --- /dev/null +++ b/lib/task/seq2seq/data.py @@ -0,0 +1,324 @@ +import sys +import random +from sortedcontainers import SortedList +import numpy as np +import math +import itertools +import tensorflow as tf + +from lib.data import pad_seq_list + + +def srclen(item): + return item[0].count(' ') + 1 + + +def dstlen(item): + return item[1].count(' ') + 1 + + +def maxlen(item): + return max(srclen(item), dstlen(item)) + + +def sumlen(item): + return srclen(item) + dstlen(item) + + +def form_batches(data, batch_size): + seq = iter(data) + done = False + while not done: + batch = [] + for _ in range(batch_size): + try: + batch.append(next(seq)) + except StopIteration: + done = True + if batch: + yield batch + + +def locally_sorted_by_len(seq, window, weight_func=maxlen, alterate=False): + reverse = False + for batch in form_batches(seq, window): + batch = sorted(batch, key=weight_func, reverse=reverse) + for x in batch: + yield x + if alterate: + reverse = not reverse + + +def form_adaptive_batches(data, batch_len, batch_size_max=0): + seq = iter(data) + prev = [] + max_len = 0 + done = False + while not done: + batch = prev + try: + while True: + item = next(seq) + max_len = max(max_len, maxlen(item)) + if (len(batch) + 1) * max_len > batch_len or (batch_size_max and len(batch) >= batch_size_max): + prev, max_len = [item], maxlen(item) + break + batch.append(item) + except StopIteration: + done = True + if batch: + yield batch + + +def form_adaptive_batches_windowed(data, weight_func=maxlen, max_size=5000, split_len=10000, batch_size_max=0): + rng = random.Random(42) + buf = [] + last_chunk = [] + reverse = False + for p in data: + if len(buf) >= split_len: + # Last chunk may contain fewer sentences than others - let's return in to the miller + buf += last_chunk + + buf = sorted(buf, key=weight_func, reverse=reverse) + chunks = list(form_adaptive_batches(buf, max_size, batch_size_max=batch_size_max)) + + last_chunk = chunks.pop() + buf = [] + + reverse = not reverse + + rng.shuffle(chunks) + for chunk in chunks: + yield chunk + buf.append(p) + + buf += last_chunk + buf = sorted(buf, key=weight_func, reverse=reverse) + chunks = list(form_adaptive_batches(buf, max_size, batch_size_max=batch_size_max)) + rng.shuffle(chunks) + for chunk in chunks: + yield chunk + + +def batch_cost(x): + return len(x) + 2 * (max(len(i[0]) for i in x) + max(len(i[1]) for i in x)) + + +def load_parallel(src, dst, cycle=False): + # Load data. 
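+    # Each pass below re-opens both files and yields aligned (src_line, dst_line) pairs
+    # with trailing newlines stripped; with cycle=True the corpus is repeated indefinitely.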
+ for i in itertools.count(): + if i > 0 and not cycle: + break + data = zip(open(src), open(dst)) + for l, r in data: + yield l.rstrip('\n'), r.rstrip('\n') + + +def filter_by_len(data, max_srclen=sys.maxsize, max_dstlen=sys.maxsize, batch_len=None): + def item_ok(item): + ok = (srclen(item) <= max_srclen and dstlen(item) <= max_dstlen) + if batch_len is not None: + return ok and maxlen(item) <= batch_len + return ok + + return filter(item_ok, data) + + +# ======= Random block reader ===================== +def _read_file_part(fd, file_size, part_id, nparts): + begin_pos = (part_id * file_size) // nparts + end_pos = ((part_id + 1) * file_size) // nparts + fd.seek(begin_pos) + + if part_id > 0: + _prefix = fd.readline() # skip first line + + current_pos = fd.tell() # get offset + if current_pos < end_pos: + body = fd.readlines(end_pos - current_pos) + elif current_pos == end_pos and end_pos < file_size: + body = [fd.readline()] + else: + return [] + + return body + + +def _grouper(data, n=5): + pool = [] + for item in data: + if len(pool) == n: + yield pool + pool = [] + pool.append(item) + yield pool + + +def random_block_reader(fname, part_size=64 * 1024, parallel=5, infinite=False, seed=42, encoding='utf-8'): + fd = open(fname, 'rb') + fd.seek(0, 2) + file_size = fd.tell() + + nparts = math.ceil(file_size / part_size) + rng = np.random.RandomState(seed) + if infinite: + part_ids = (rng.randint(0, nparts) for _ in itertools.count()) + else: + part_ids = np.arange(nparts) + rng.shuffle(part_ids) + + for group in _grouper(part_ids, parallel): + lines = [] + for part_id in group: + part_lines = _read_file_part(fd, file_size, part_id, nparts) + lines += part_lines + + rng.shuffle(lines) + + for line in lines: + yield line.decode(encoding).rstrip('\n') + + fd.close() + + +# ======= Adaptive batches with sorted data ======= +class FastSplitter2d: + def __init__(self, max_size=5000, chunk_count=5): + self.max_size = max_size + self.max_x = 0 + self.points = SortedList(key=lambda p: -p[1]) + self.chunk_count = chunk_count + + def add_to_pack(self, p): + self.max_x = max(self.max_x, p[0]) + new_pos = self.points.bisect_right(p) + self.points.insert(new_pos, p) + + offset = 0 + bs_vec = [] + while offset < len(self.points): + bs = self.max_size // (self.max_x + self.points[offset][1]) + bs = min(len(self.points) - offset, bs) + bs_vec.append(bs) + offset += bs + + return new_pos, bs_vec + + def make_chunk_gen(self, points): + prev_bs_vec = [0] + for p in sorted(list(points), key=lambda p: p[0], reverse=True): + new_pos, bs_vec = self.add_to_pack(p) + + if len(bs_vec) > len(prev_bs_vec): + if len(prev_bs_vec) >= self.chunk_count: + self.points.pop(new_pos) + offset = 0 + for sz in prev_bs_vec: + yield self.points[offset:offset + sz] + offset += sz + self.points.clear() + self.points.add(p) + prev_bs_vec = [1] + self.max_x = p[0] + prev_bs_vec = bs_vec + offset = 0 + for sz in prev_bs_vec: + yield self.points[offset:offset + sz] + offset += sz + + +def _in_conv(item, lead_inp_len=False): + x_len = item[0].count(' ') + 1 + y_len = item[1].count(' ') + 1 + if lead_inp_len: + return item.__class__((x_len, y_len)) + item + else: + return item.__class__((y_len, x_len)) + item + + +def _out_chunk_conv(chunk): + return [item[2:] for item in chunk] + + +def form_adaptive_batches_split2d(data, max_size=5000, split_len=10000, chunk_count=5, lead_inp_len=False): + rng = random.Random(42) + buf = [] + for p in data: + if len(buf) >= split_len: + splitter = FastSplitter2d(max_size=max_size, 
chunk_count=chunk_count) + chunks = list(splitter.make_chunk_gen(buf)) + rng.shuffle(chunks) + for chunk in chunks: + if len(chunk) == 0: + print("SPLIT2D: empty chunk", file=sys.stderr) + continue + yield _out_chunk_conv(chunk) + buf = [] + + buf.append(_in_conv(p, lead_inp_len=lead_inp_len)) + + splitter = FastSplitter2d(max_size=max_size, chunk_count=chunk_count) + chunks = list(splitter.make_chunk_gen(buf)) + rng.shuffle(chunks) + for chunk in chunks: + if len(chunk) == 0: + print("SPLIT2D: empty chunk", file=sys.stderr) + continue + yield _out_chunk_conv(chunk) + + +## ============================================================================ +# Integration + +def words_from_line(line, voc, bos=0, eos=1): + line = line.rstrip('\n') + words = [token for token in line.split(' ') if token] + return voc.words([voc.bos]) * bos + words + voc.words([voc.eos]) * eos + + +def words_from_ids(ids, voc): + return [ + word if (id not in [voc.bos, voc.eos]) else None + for id, word in zip(ids, voc.words(ids)) + ] + + +def lines2ids(lines, voc, **kwargs): + # Read as-is, without padding. + ids_all = [] + for line in lines: + words = words_from_line(line, voc, **kwargs) + ids = voc.ids(words) + ids_all.append(ids) + + # Pad and transpose. + ids_all, ids_len = pad_seq_list(ids_all, voc.eos) + return ids_all, ids_len + + +def make_batch_data(batch, inp_voc, out_voc, force_bos=False, **kwargs): + inp_lines, out_lines = zip(*batch) + inp, inp_len = lines2ids(inp_lines, inp_voc, bos=int(force_bos)) + out, out_len = lines2ids(out_lines, out_voc, bos=int(force_bos)) + + batch_data = dict( + inp=np.array(inp, dtype=np.int32), + inp_len=np.array(inp_len, dtype=np.int32), + out=np.array(out, dtype=np.int32), + out_len=np.array(out_len, dtype=np.int32)) + + return batch_data + + +def make_batch_placeholder(batch_data): + batch_placeholder = { + k: tf.placeholder(v.dtype, [None] * len(v.shape)) + for k, v in batch_data.items()} + return batch_placeholder + + +class BatchIndexer: + pass + + diff --git a/lib/task/seq2seq/inference.py b/lib/task/seq2seq/inference.py new file mode 100644 index 0000000..407b796 --- /dev/null +++ b/lib/task/seq2seq/inference.py @@ -0,0 +1,1046 @@ +import sys +from collections import namedtuple +from warnings import warn + +import tensorflow as tf + +import lib.util +from lib.ops import infer_length, infer_mask +from lib.ops.sliced_argmax import sliced_argmax +from lib.util import nested_map, is_scalar +import numpy as np + + +def translate_lines(lines, translator, model, out_voc, replace_unk=False, unbpe=False, dumper=None): + """ + tokenize, translate and detokenize strings using specified model and translator + :param lines: an iterable of strings + :type translator: something that can .translate_batch(batch_dict) -> out, attnP, ... 
+ :param model: a model from lib.task.seq2seq.models.ModelBase + :param out_voc: destination language dictionary + :param replace_unk: if True, forbids sampling UNK from the model + :param unbpe: if True, concatenates bpe subword units together + :return: yields one translation line at a time + """ + batch = [(l, "") for l in lines] + batch_data = model.make_feed_dict(batch, add_inp_words=True) + kwargs = {} + + if dumper is not None: + kwargs['batch_dumpers'] = dumper.create_batch_dumpers(batch) + + out_ids, attnP = translator.translate_batch(batch_data, **kwargs)[:2] + + for i in range(len(out_ids)): + ids = list(out_ids[i]) + words = out_voc.words(ids) + words = [w for w, out_id in zip(words, ids) if out_id not in [out_voc.bos, out_voc.eos]] + + if replace_unk: + where = [(w and '_UNK_' in w) for w in words] + if any(w for w in where): + inp_words = batch_data['inp_words'][i][:batch_data['inp_len'][i]] + + # select attention weights for non-BOS/EOS tokens, shape=[num_outputs, num_inputs] + attns = np.array([a for a, out_id in zip(attnP[i], ids) + if out_id not in [out_voc.bos, out_voc.eos]])[:len(words), :len(inp_words)] + + # forbid attns to special tokens if there are normal tokens in inp + inp_mask = np.array([w not in ['_BOS_', '_EOS_'] for w in inp_words]) + attns = np.where(inp_mask[None, :], attns, -np.inf) + + words = copy_argmax(inp_words, words, attns, where) + + out_line = " ".join(words) + if unbpe: + out_line = out_line.replace('` ', '') + yield out_line + + +def copy_argmax(inp, out, attnP, where): + """ + inp: [ninp] + out: [nout] + attnP: [nout, ninp] + where: [nout] + """ + # Check shapes. + if len(inp) != attnP.shape[1]: + msg = 'len(inp) is %i, but attnP.shape[1] is %i' + raise ValueError(msg % (len(inp), attnP.shape[1])) + if len(out) != attnP.shape[0]: + msg = 'len(out) is %i, but attnP.shape[0] is %i' + raise ValueError(msg % (len(out), attnP.shape[0])) + + # Copy in every requested position. + new_out = [] + for o in range(len(out)): + # Output as-is. + if not where[o]: + new_out.append(out[o]) + continue + + # Copy from input. + i = np.argmax(attnP[o]) + new_out.append(inp[i]) + + return new_out + + +class TranslateModel: + + def __init__(self, name, inp_voc, out_voc, loss, **hp): + """ Each model must have name, vocabularies and a hyperparameter dict """ + self.name = name + self.inp_voc = inp_voc + self.out_voc = out_voc + self.loss = loss + self.hp = hp + + def encode(self, batch, **flags): + """ + Encodes symbolic input and returns initial state of decode + :param batch: { + inp: int32 matrix [batch,time] or whatever your model can encode + inp_len: int vector [batch_size] + } + -------------------------------------------------- + :returns: dec_state, nested structure of tensors, batch-major + """ + raise NotImplementedError() + + def decode(self, dec_state, words, **flags): + """ + Performs decode step on given words. + + dec_state: nested structure of tensors, batch-major + words: int vector [batch_size] + ------------------------------------------------------ + :returns: new_dec_state, nested structure of tensors, batch-major + """ + raise NotImplementedError() + + def shuffle(self, dec_state, hypo_indices): + """ + Selects hypotheses from model decoder state by given indices. 
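+ (illustrative example) hypo_indices=[0, 0, 2] would keep two copies of hypothesis 0 and one copy of hypothesis 2,
+ since every tensor in dec_state is gathered along its first (batch x beam) axis.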
+ :param dec_state: a nested structure of tensors representing model state + :param hypo_indices: int32 vector of indices to select + :returns: dec state elements for given flat_indices only + """ + return nested_map(lambda x: tf.gather(x, hypo_indices), dec_state) + + def switch(self, condition, state_on_true, state_on_false): + """ + Composes a new stack.best_dec_state out of new dec state when new_is_better and old dec state otherwise + :param condition: a boolean condition vector of shape [batch_size] + """ + return nested_map(lambda x, y: tf.where(condition, x, y), state_on_true, state_on_false) + + def sample(self, dec_state, base_scores, slices, k, sampling_strategy='greedy', sampling_temperature=None, **flags): + """ + Samples top-K new words for each hypothesis from a beam. + Decoder states and base_scores of hypotheses for different inputs are concatenated like this: + [x0_hypo0, x0_hypo1, ..., x0_hypoN, x1_hypo0, ..., x1_hypoN, ..., xM_hypoN + + :param dec_state: nested structure of tensors, batch-major + :param base_scores: [batch_size], log-probabilities of hypotheses in dec_state with additive penalties applied + :param slices: start indices of each input + :param k: [], int, how many hypotheses to sample per input + :returns: best_hypos, words, scores, + best_hypos: in-beam hypothesis index for each sampled token, [batch_size / slice_size, k], int + words: new tokens for each hypo, [batch_size / slice_size, k], int + scores: log P(words | best_hypos), [batch_size / slice_size, k], float32 + """ + rdo = self.get_rdo(dec_state) + if isinstance(rdo, (tuple, list)) or lib.util.is_namedtuple(rdo): + logits = self.loss.rdo_to_logits__predict(*rdo) + else: + logits = self.loss.rdo_to_logits__predict(rdo) + + n_hypos, voc_size = tf.shape(logits)[0], tf.shape(logits)[1] + batch_size = tf.shape(slices)[0] + + if sampling_strategy == 'random': + if sampling_temperature is not None: + logits /= sampling_temperature + + logp = tf.nn.log_softmax(logits, 1) + + best_hypos = tf.range(0, n_hypos)[:, None] + + best_words = tf.cast(tf.multinomial(logp, k), tf.int32) + best_words_flat = (tf.range(0, batch_size) * voc_size)[:, None] + best_words + + best_delta_scores = tf.gather(tf.reshape(logp, [-1]), best_words_flat) + + elif sampling_strategy == 'greedy': + logp = tf.nn.log_softmax(logits, 1) + base_scores[:, None] + best_scores, best_indices = sliced_argmax(logp, slices, k) + + # If best_hypos == -1, best_scores == -inf, set best_hypos to 0 to avoid runtime IndexError + best_hypos = tf.where(tf.not_equal(best_indices, -1), + tf.floordiv(best_indices, voc_size) + slices[:, None], + tf.fill(tf.shape(best_indices), -1)) + best_words = tf.where(tf.not_equal(best_indices, -1), + tf.mod(best_indices, voc_size), + tf.fill(tf.shape(best_indices), -1)) + + best_delta_scores = best_scores - tf.gather(base_scores, tf.maximum(0, best_hypos)) + else: + raise ValueError("sampling_strategy must be in ['random','greedy']") + + return (best_hypos, best_words, best_delta_scores) + + def get_rdo(self, dec_state): + if hasattr(dec_state, 'rdo'): + return dec_state.rdo + raise NotImplementedError() + + def get_attnP(self, dec_state): + """ + Returns attnP + + dec_state: [..., batch_size, ...] + --------------------------------- + Ret: attnP + attnP: [batch_size, ninp] + """ + if hasattr(dec_state, 'attnP'): + return dec_state.attnP + raise NotImplementedError() + + +class GreedyDecoder: + """ + Inference that encodes input sequence, then iteratively samples and decodes output sequence. 
+ :type model: lib.task.seq2seq.inference.translate_model.TranslateModel + :param batch: a dictionary that contains symbolic tensor {'inp': input token ids, shape [batch_size,time]} + :param max_len: maximum length of output sequence, symbolic or numeric integer + if scalar, sets global length; if vector[batch_size], sets length for each input; + if None, defaults to 2*inp_len + 3 + :param force_bos: if True, forces zero-th output to be model.out_voc.bos. Otherwise lets model decide. + :param force_eos: if True, any token past initial EOS is guaranteed to be EOS + :param get_tracked_outputs: callback that returns whatever tensor(s) you want to track on each time-step + :param crop_last_step: if True, does not perform additional decode __after__ last eos + ensures all tensors have equal time axis + :param back_prop: see tf.while_loop back_prop param + :param swap_memory: see tf.while_loop swap_memory param + :param **flags: you can add any amount of tags that encode and decode understands. + e.g. greedy=True or is_train=True + + """ + + Stack = namedtuple('Stack', + ['out', 'out_len', 'scores', 'finished', 'dec_state', 'attnP', 'tracked']) + + def __init__(self, model, batch_placeholder, max_len=None, force_bos=False, force_eos=True, + get_tracked_outputs=lambda dec_state: [], crop_last_step=True, + back_prop=True, swap_memory=False, **flags): + self.batch_placeholder = batch_placeholder + self.get_tracked_outputs = get_tracked_outputs + + inp_len = batch_placeholder.get('inp_len', infer_length(batch_placeholder['inp'], model.out_voc.eos)) + max_len = max_len if max_len is not None else (2 * inp_len + 3) + + first_stack = self.create_initial_stack(model, batch_placeholder, force_bos=force_bos, **flags) + shape_invariants = nested_map(lambda v: tf.TensorShape([None for _ in v.shape]), first_stack) + + # Actual decoding + def should_continue_translating(*stack): + stack = self.Stack(*stack) + return tf.reduce_any(tf.less(stack.out_len, max_len)) & tf.reduce_any(~stack.finished) + + def inference_step(*stack): + stack = self.Stack(*stack) + return self.greedy_step(model, stack, **flags) + + final_stack = tf.while_loop( + cond=should_continue_translating, + body=inference_step, + loop_vars=first_stack, + shape_invariants=shape_invariants, + swap_memory=swap_memory, + back_prop=back_prop, + ) + + outputs, _, scores, _, dec_states, attnP, tracked_outputs = final_stack + if crop_last_step: + attnP = attnP[:, :-1] + tracked_outputs = nested_map(lambda out: out[:, :-1], tracked_outputs) + + if force_eos: + out_mask = infer_mask(outputs, model.out_voc.eos) + outputs = tf.where(out_mask, outputs, tf.fill(tf.shape(outputs), model.out_voc.eos)) + + self.best_out = outputs + self.best_attnP = attnP + self.best_scores = scores + self.dec_states = dec_states + self.tracked_outputs = tracked_outputs + + def translate_batch(self, batch_data, **optional_feed): + """ + Translates NUMERIC batch of data + :param batch_data: dict {'inp':np.array int32[batch,time]} + :optional_feed: any additional values to be fed into graph. e.g. 
if you used placeholder for max_len at __init__ + :return: best hypotheses' outputs[batch, out_len] and attnP[batch, out_len, inp_len] + """ + feed_dict = {placeholder: batch_data[k] for k, placeholder in self.batch_placeholder.items()} + for k, v in optional_feed.items(): + feed_dict[k] = v + + out_ids, attnP = tf.get_default_session().run( + [self.best_out, self.best_attnP], + feed_dict=feed_dict) + + return out_ids, attnP + + def create_initial_stack(self, model, batch_placeholder, force_bos=False, **flags): + inp = batch_placeholder['inp'] + batch_size = tf.shape(inp)[0] + + initial_state = model.encode(batch_placeholder, **flags) + initial_attnP = model.get_attnP(initial_state)[:, None] + initial_tracked = nested_map(lambda x: x[:, None], self.get_tracked_outputs(initial_state)) + + if force_bos: + initial_outputs = tf.cast(tf.fill((batch_size, 1), model.out_voc.bos), inp.dtype) + initial_state = model.decode(initial_state, initial_outputs[:, 0], **flags) + second_attnP = model.get_attnP(initial_state)[:, None] + initial_attnP = tf.concat([initial_attnP, second_attnP], axis=1) + initial_tracked = nested_map(lambda x, y: tf.concat([x, y[:, None]], axis=1), + initial_tracked, + self.get_tracked_outputs(initial_state),) + else: + initial_outputs = tf.zeros((batch_size, 0), dtype=inp.dtype) + + initial_scores = tf.zeros([batch_size], dtype='float32') + initial_finished = tf.zeros_like([batch_size], dtype='bool') + initial_len = tf.shape(initial_outputs)[1] + + return self.Stack(initial_outputs, initial_len, initial_scores, initial_finished, + initial_state, initial_attnP, initial_tracked) + + def greedy_step(self, model, stack, **flags): + """ + :type model: lib.task.seq2seq.inference.translate_model.TranslateModel + :param stack: beam search stack + :return: new beam search stack + """ + out_seq, out_len, scores, finished, dec_states, attnP, tracked = stack + + # 1. sample + batch_size = tf.shape(out_seq)[0] + phony_slices = tf.range(batch_size) + _, new_outputs, logp_next = model.sample(dec_states, scores, phony_slices, k=1, **flags) + + out_seq = tf.concat([out_seq, new_outputs], axis=1) + scores = scores + logp_next[:, 0] * tf.cast(~finished, 'float32') + is_eos = tf.equal(new_outputs[:, 0], model.out_voc.eos) + finished = tf.logical_or(finished, is_eos) + + # 2. decode + new_states = model.decode(dec_states, new_outputs[:, 0], **flags) + attnP = tf.concat([attnP, model.get_attnP(new_states)[:, None]], axis=1) + tracked = nested_map(lambda seq, new: tf.concat([seq, new[:, None]], axis=1), + tracked, self.get_tracked_outputs(new_states) + ) + return self.Stack(out_seq, out_len + 1, scores, finished, new_states, attnP, tracked) + + +class BeamSearchDecoder: + """ + Performs ingraph beam search for given input sequences (inp) + Supports custom penalizing, pruning against best score and best score in beam (via beam_spread) + :param model: something that implements TranslateModel + :param batch_placeholder: whatever model can .encode, + by default should be {'inp': int32 matrix [batch_size x time]} + :param max_len: maximum hypothesis length to consider, symbolic or numeric integer + if scalar, sets global length; if vector[batch_size], sets length for each input; + if None, defaults to 2*inp_len + 3; float('inf') means unlimited + :param min_len: minimum valid output length. None means min_len=inp_len // 4 - 1; Same type as min_len + :param beam_size: maximum number of hypotheses that can pass from one beam search step to another. + The rest is pruned. 
+ :param beam_spread: maximum difference in score between a hypothesis and current best hypothesis. + Anything below that is pruned. + :param force_bos: if True, forces zero-th output to be model.out_voc.bos. Otherwise lets model decide. + :param if_no_eos: if 'last', will return unfinished hypos if there are no finished hypos by max_len + elif 'initial', returns empty hypothesis + :param back_prop: see tf.while_loop back_prop param + :param swap_memory: see tf.while_loop swap_memory param + + :param **flags: whatever else you want to feed into model. This will be passed to encode, decode, etc. + is_train - if True (default), enables dropouts and similar training-only stuff + sampling_strategy - if "random", samples hypotheses proportionally to softmax(logits) + otherwise(default) - takes top K hypotheses + sampling_temperature - if sampling_strategy == "random", + performs sampling ~ softmax(logits/sampling_temperature) + + """ + Stack = namedtuple('Stack', [ + # per hypo values + 'out', # [batch_size x beam_size, nout], int + 'scores', # [batch_size x beam_size ] + 'raw_scores', # [batch_size x beam_size ] + 'attnP', # [batch_size x beam_size, nout+1, ninp] + 'dec_state', # TranslateModel DecState nested structure of [batch_size x beam_size, ...] + + # per beam values + 'slices', # indices of first hypo for each sentence [batch_size ] + 'out_len', # total (maximum) length of a stack [], int + 'best_out', # [batch_size, nout], int, padded with EOS + 'best_scores', # [batch_size] + 'best_raw_scores', # [batch_size] + 'best_attnP', # [batch_size, nout+1, ninp], padded with EOS + 'best_dec_state', # TranslateModel DecState; nested structure of [batch_size, ...] + + # Auxilary data for extension classes. + 'ext' # Dict[StackExtType, StackExtType()] + ]) + + def __init__(self, model, batch_placeholder, min_len=None, max_len=None, + beam_size=12, beam_spread=3, beam_spread_raw=None, force_bos=False, + if_no_eos='last', back_prop=True, swap_memory=False, **flags + ): + assert if_no_eos in ['last', 'initial'] + assert np.isfinite(beam_spread) or max_len != float('inf'), "Must set maximum length if beam_spread is infinite" + # initialize fields + self.batch_placeholder = batch_placeholder + inp_len = batch_placeholder.get('inp_len', infer_length(batch_placeholder['inp'], model.out_voc.eos)) + self.min_len = min_len if min_len is not None else inp_len // 4 - 1 + self.max_len = max_len if max_len is not None else 2 * inp_len + 3 + self.beam_size, self.beam_spread = beam_size, beam_spread + if beam_spread_raw is None: + self.beam_spread_raw = beam_spread + else: + self.beam_spread_raw = beam_spread_raw + self.force_bos, self.if_no_eos = force_bos, if_no_eos + + # actual beam search + first_stack = self.create_initial_stack(model, batch_placeholder, force_bos=force_bos, **flags) + shape_invariants = nested_map(lambda v: tf.TensorShape([None for _ in v.shape]), first_stack) + + def should_continue_translating(*stack): + stack = self.Stack(*stack) + should_continue = self.should_continue_translating(model, stack) + return tf.reduce_any(should_continue) + + def expand_hypos(*stack): + return self.beam_search_step(model, self.Stack(*stack), **flags) + + last_stack = tf.while_loop( + cond=should_continue_translating, + body=expand_hypos, + loop_vars=first_stack, + shape_invariants=shape_invariants, + back_prop=back_prop, + swap_memory=swap_memory, + ) + + # crop unnecessary EOSes that occur if no hypothesis is updated on several last steps + actual_length = infer_length(last_stack.best_out, 
model.out_voc.eos) + max_length = tf.reduce_max(actual_length) + last_stack = last_stack._replace(best_out=last_stack.best_out[:, :max_length]) + + self.best_out = last_stack.best_out + self.best_attnP = last_stack.best_attnP + self.best_scores = last_stack.best_scores + self.best_raw_scores = last_stack.best_raw_scores + self.best_state = last_stack.best_dec_state + + def translate_batch(self, batch_data, **optional_feed): + """ + Translates NUMERIC batch of data + :param batch_data: dict {'inp':np.array int32[batch,time]} + :optional_feed: any additional values to be fed into graph. e.g. if you used placeholder for max_len at __init__ + :return: best hypotheses' outputs[batch, out_len] and attnP[batch, out_len, inp_len] + """ + feed_dict = {placeholder: batch_data[k] for k, placeholder in self.batch_placeholder.items()} + for k, v in optional_feed.items(): + feed_dict[k] = v + + out_ids, attnP = tf.get_default_session().run( + [self.best_out, self.best_attnP], + feed_dict=feed_dict) + + return out_ids, attnP + + def create_initial_stack(self, model, batch, **flags): + """ + Creates initial stack for beam search by encoding inp and optionally forcing BOS as first output + :type model: lib.task.seq2seq.inference.TranslateModel + :param batch: model inputs - whatever model can eat for self.encode(batch,**tags) + :param force_bos: if True, forces zero-th output to be model.out_voc.bos. Otherwise lets model decide. + """ + + dec_state = dec_state_0 = model.encode(batch, **flags) + attnP_0 = model.get_attnP(dec_state_0) + batch_size = tf.shape(attnP_0)[0] + + out_len = tf.constant(0, 'int32') + out = tf.zeros(shape=(batch_size, 0), dtype=tf.int32) # [batch_size, nout = 0] + + if self.force_bos: + bos = tf.fill(value=model.out_voc.bos, dims=(batch_size,)) + dec_state = dec_state_1 = model.decode(dec_state_0, bos, **flags) + attnP_1 = model.get_attnP(dec_state_1) + attnP = tf.stack([attnP_0, attnP_1], axis=1) # [batch_size, 2, ninp] + out_len += 1 + out = tf.concat([out, bos[:, None]], axis=1) + + else: + attnP = attnP_0[:, None, :] # [batch_size, 1, ninp] + + slices = tf.range(0, batch_size) + empty_out = tf.fill(value=model.out_voc.eos, dims=(batch_size, tf.shape(out)[1])) + + # Create stack. + stack = self.Stack( + out=out, + scores=tf.zeros(shape=(batch_size,)), + raw_scores=tf.zeros(shape=(batch_size,)), + attnP=attnP, + dec_state=dec_state, + slices=slices, + out_len=out_len, + best_out=empty_out, + best_scores=tf.fill(value=-float('inf'), dims=(batch_size,)), + best_raw_scores=tf.fill(value=-float('inf'), dims=(batch_size,)), + best_attnP=attnP, + best_dec_state=dec_state, + ext={} + ) + + return stack + + def should_continue_translating(self, model, stack): + """ + Returns a bool vector for all hypotheses where True means hypo should be kept, 0 means it should be dropped. + A hypothesis is dropped if it is either finished or pruned by beam_spread or by beam_size + Note: this function assumes hypotheses for each input sample are sorted by scores(best first)!!! 
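+ Illustration (hypothetical numbers): with beam_spread=3, once the best finished hypothesis for an input
+ scores -4.0, any live hypothesis for that input scoring below -7.0 satisfies scores + beam_spread < best_scores
+ and is pruned.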
+ """ + + # drop finished hypotheses + should_keep = tf.logical_not( + tf.reduce_any(tf.equal(stack.out, model.out_voc.eos), axis=-1)) # [batch_size x beam_size] + + n_hypos = tf.shape(stack.out)[0] + batch_size = tf.shape(stack.best_out)[0] + batch_indices = hypo_to_batch_index(n_hypos, stack.slices) + + # prune by length + if self.max_len is not None: + within_max_length = tf.less_equal(stack.out_len, self.max_len) + + # if we're given one max_len per each sentence, repeat it for each batch + if not is_scalar(self.max_len): + within_max_length = tf.gather(within_max_length, batch_indices) + + should_keep = tf.logical_and( + should_keep, + within_max_length, + ) + + # prune by beam spread + if self.beam_spread is not None: + best_scores_for_hypos = tf.gather(stack.best_scores, batch_indices) + pruned_by_spread = tf.less(stack.scores + self.beam_spread, best_scores_for_hypos) + should_keep = tf.logical_and(should_keep, tf.logical_not(pruned_by_spread)) + + if self.beam_spread_raw: + best_raw_scores_for_hypos = tf.gather(stack.best_raw_scores, batch_indices) + pruned_by_raw_spread = tf.less(stack.raw_scores + self.beam_spread_raw, best_raw_scores_for_hypos) + should_keep = tf.logical_and(should_keep, + tf.logical_not(pruned_by_raw_spread)) + + + # pruning anything exceeding beam_size + if self.beam_size is not None: + # This code will use a toy example to explain itself: slices=[0,2,5,5,8], n_hypos=10, beam_size=2 + # should_keep = [1,1,1,0,1,1,1,1,0,1] (two hypotheses have been pruned/finished) + + # 1. compute index of each surviving hypothesis globally over full batch, [0,1,2,3,3,4,5,6,7,7] + survived_hypo_id = tf.cumsum(tf.cast(should_keep, 'int32'), exclusive=True) + # 2. compute number of surviving hypotheses for each batch sample, [2,2,3,1] + survived_hypos_per_input = tf.bincount(batch_indices, weights=tf.cast(should_keep, 'int32'), + minlength=batch_size, maxlength=batch_size) + # 3. compute the equivalent of slices for hypotheses excluding pruned: [0,2,4,4,7] + slices_exc_pruned = tf.cumsum(survived_hypos_per_input, exclusive=True) + # 4. compute index of surviving hypothesis within one sample (for each sample) + # index of input sentence in batch: inp0 /inp_1\ /inp_2\, /inp_3\ + # index of hypothesis within input: [0, 1, 0, 1, 1, 0, 1, 2, 0, 0, 1] + # 'e' = pruned earlier, 'x' - pruned now: 'e' 'x' 'e' + beam_index = survived_hypo_id - tf.gather(slices_exc_pruned, batch_indices) + + # 5. prune hypotheses with index exceeding beam_size + pruned_by_beam_size = tf.greater_equal(beam_index, self.beam_size) + should_keep = tf.logical_and(should_keep, tf.logical_not(pruned_by_beam_size)) + + return should_keep + + def beam_search_step_expand_hypos(self, model, stack, **flags): + """ + Performs one step of beam search decoding. Samples new hypothesis to stack. + :type model: lib.task.seq2seq.inference.TranslateModel + :type stack: BeamSearchDecoder.BeamSearchStack + """ + + # Prune + # - Against best completed hypo + # - Against best hypo in beam + # - EOS translations + # - Against beam size + + should_keep = self.should_continue_translating(model, stack) + + hypo_indices = tf.where(should_keep)[:, 0] + stack = self.shuffle_beam(model, stack, hypo_indices) + + # Compute penalties, if any + base_scores = self.compute_base_scores(model, stack, **flags) + + # Get top-beam_size new hypotheses for each input. + # Note: we assume sample returns hypo_indices from highest score to lowest, therefore hypotheses + # are automatically sorted by score within each slice. 
+ hypo_indices, words, delta_raw_scores = model.sample(stack.dec_state, base_scores, stack.slices, + self.beam_size, **flags + ) + + # hypo_indices, words and delta_raw_scores may contain -1/-1/-inf triples for non-available hypotheses. + # This can only happen if for some input there were 0 surviving hypotheses OR beam_size > n_hypos*vocab_size + # In either case, we want to prune such hypotheses + valid_indices = tf.where(tf.not_equal(tf.reshape(hypo_indices, [-1]), -1))[:, 0] + hypo_indices = tf.gather(tf.reshape(hypo_indices, [-1]), valid_indices) + words = tf.gather(tf.reshape(words, [-1]), valid_indices) + delta_raw_scores = tf.gather(tf.reshape(delta_raw_scores, [-1]), valid_indices) + + stack = self.shuffle_beam(model, stack, hypo_indices) + dec_state = model.decode(stack.dec_state, words, **flags) + step_attnP = model.get_attnP(dec_state) + # step_attnP shape: [batch_size * beam_size, ninp] + + # collect stats for the next step + attnP = tf.concat([stack.attnP, step_attnP[:, None, :]], axis=1) # [batch * beam_size, nout, ninp] + out = tf.concat([stack.out, words[..., None]], axis=-1) + out_len = stack.out_len + 1 + + raw_scores = stack.raw_scores + delta_raw_scores + + return stack._replace( + out=out, + raw_scores=raw_scores, + attnP=attnP, + out_len=out_len, + dec_state=dec_state, + ) + + def beam_search_step_update_best(self, model, stack, maintain_best_state=False, **flags): + """ + Performs one step of beam search decoding. Removes hypothesis from (beam_size ** 2) stack. + :type model: lib.task.seq2seq.inference.TranslateModel + :type stack: BeamSearchDecoder.BeamSearchStack + """ + + # Compute sample id for each hypo in stack + n_hypos = tf.shape(stack.out)[0] + batch_indices = hypo_to_batch_index(n_hypos, stack.slices) + + # Mark finished hypos + finished = tf.equal(stack.out[:, -1], model.out_voc.eos) + + if self.min_len is not None: + below_min_length = tf.less(stack.out_len, self.min_len) + if not is_scalar(self.min_len): + below_min_length = tf.gather(below_min_length, batch_indices) + + finished = tf.logical_and(finished, tf.logical_not(below_min_length)) + + if self.if_no_eos == 'last': + # No hypos finished with EOS, but len == max_len, consider unfinished hypos + reached_max_length = tf.equal(stack.out_len, self.max_len) + if not is_scalar(self.max_len): + reached_max_length = tf.gather(reached_max_length, batch_indices) + + have_best_out = tf.reduce_any(tf.not_equal(stack.best_out, model.out_voc.eos), 1) + no_finished_alternatives = tf.gather(tf.logical_not(have_best_out), batch_indices) + allow_unfinished_hypo = tf.logical_and(reached_max_length, no_finished_alternatives) + + finished = tf.logical_or(finished, allow_unfinished_hypo) + + # select best finished hypo for each input in batch (if any) + finished_scores = tf.where(finished, stack.scores, tf.fill(tf.shape(stack.scores), -float('inf'))) + best_scores, best_indices = sliced_argmax(finished_scores[:, None], stack.slices, 1) + best_scores, best_indices = best_scores[:, 0], stack.slices + best_indices[:, 0] + best_indices = tf.clip_by_value(best_indices, 0, tf.shape(stack.out)[0] - 1) + + stack_is_nonempty = tf.not_equal(tf.shape(stack.out)[0], 0) + + # take the better one of new best hypotheses or previously existing ones + new_is_better = tf.greater(best_scores, stack.best_scores) + best_scores = tf.where(new_is_better, best_scores, stack.best_scores) + + new_best_raw_scores = tf.cond(stack_is_nonempty, + lambda:tf.gather(stack.raw_scores, best_indices), + lambda:stack.best_raw_scores) + + best_raw_scores = 
tf.where(new_is_better, new_best_raw_scores, stack.best_raw_scores) + + + batch_size = tf.shape(stack.best_out)[0] + eos_pad = tf.fill(value=model.out_voc.eos, dims=(batch_size, 1)) + padded_best_out = tf.concat([stack.best_out, eos_pad], axis=1) + new_out = tf.cond(stack_is_nonempty, + lambda: tf.gather(stack.out, best_indices), + lambda: tf.gather(padded_best_out, best_indices) # dummy out, best indices are zeros + ) + best_out = tf.where(new_is_better, new_out, padded_best_out) + + zero_attnP = tf.zeros_like(stack.best_attnP[:, :1, :]) + padded_best_attnP = tf.concat([stack.best_attnP, zero_attnP], axis=1) + new_attnP = tf.cond(stack_is_nonempty, + lambda: tf.gather(stack.attnP, best_indices), + lambda: tf.gather(padded_best_attnP, best_indices), # dummy attnP, best indices are zeros + ) + best_attnP = tf.where(new_is_better, new_attnP, padded_best_attnP) + + # if better translation is reached, update it's state too + best_dec_state = stack.best_dec_state + if maintain_best_state: + new_best_dec_state = model.shuffle(stack.dec_state, best_indices) + best_dec_state = model.switch(new_is_better, new_best_dec_state, stack.best_dec_state) + + return stack._replace( + best_out=best_out, + best_scores=best_scores, + best_attnP=best_attnP, + best_raw_scores=best_raw_scores, + best_dec_state=best_dec_state, + ) + + def beam_search_step(self, model, stack, **flags): + stack = self.beam_search_step_expand_hypos(model, stack, **flags) + stack = stack._replace( + scores=self.compute_scores(model, stack, **flags) + ) + is_beam_not_empty = tf.not_equal(tf.shape(stack.raw_scores)[0], 0) + return self.beam_search_step_update_best(model, stack, **flags) + + def compute_scores(self, model, stack, **flags): + """ + Compute hypothesis scores given beam search stack. Applies any penalties necessary. + For quick prototyping, you can store whatever penalties you need in stack.dec_state + :type model: lib.task.seq2seq.inference.TranslateModel + :type stack: BeamSearchDecoder.BeamSearchStack + :return: float32 vector (one score per hypo) + """ + return stack.raw_scores + + def compute_base_scores(self, model, stack, **flags): + """ + Compute hypothesis scores to be used as base_scores for model.sample. + This is usually same as compute_scores but scaled to the magnitude of log-probabilities + :type model: lib.task.seq2seq.inference.TranslateModel + :type stack: BeamSearchDecoder.BeamSearchStack + :return: float32 vector (one score per hypo) + """ + return self.compute_scores(model, stack, **flags) + + def shuffle_beam(self, model, stack, flat_indices): + """ + Selects hypotheses by index from entire BeamSearchStack + Note: this method assumes that both stack and flat_indices are sorted by sample index + (i.e. first are indices for input0 are, then indices for input1, then 2, ... 
then input[batch_size-1] + """ + n_hypos = tf.shape(stack.out)[0] + batch_size = tf.shape(stack.best_out)[0] + + # compute new slices: + # step 1: get index of inptut sequence (in batch) for each hypothesis in flat_indices + sample_ids_for_slices = tf.gather(hypo_to_batch_index(n_hypos, stack.slices), flat_indices) + # step 2: compute how many hypos per flat_indices + n_hypos_per_sample = tf.bincount(sample_ids_for_slices, minlength=batch_size, maxlength=batch_size) + # step 3: infer slice start indices + new_slices = tf.cumsum(n_hypos_per_sample, exclusive=True) + + # shuffle everything else + return stack._replace( + out=tf.gather(stack.out, flat_indices), + scores=tf.gather(stack.scores, flat_indices), + raw_scores=tf.gather(stack.raw_scores, flat_indices), + attnP=tf.gather(stack.attnP, flat_indices), + dec_state=model.shuffle(stack.dec_state, flat_indices), + ext=nested_map(lambda x: tf.gather(x, flat_indices), stack.ext), + slices=new_slices, + ) + + +class PenalizedBeamSearchDecoder(BeamSearchDecoder): + """ + Performs ingraph beam search for given input sequences (inp) + Implements length and coverage penalties + """ + PenalizedExt = namedtuple('PenalizedExt', [ + 'attnP_sum', # [batch_size x beam_size, ninp] + ]) + + def beam_search_step_expand_hypos(self, model, stack, **flags): + new_stack = super().beam_search_step_expand_hypos(model, stack, **flags) + new_stack_ext = new_stack.ext[self.PenalizedExt] + + step_attnP = model.get_attnP(new_stack.dec_state) + # step_attnP shape: [batch_size * beam_size, ninp] + + new_stack.ext[self.PenalizedExt] = new_stack_ext._replace( + attnP_sum=new_stack_ext.attnP_sum + step_attnP) + return new_stack + + def create_initial_stack(self, model, batch, **flags): + stack = super().create_initial_stack(model, batch, **flags) + stack.ext[self.PenalizedExt] = self.PenalizedExt( + attnP_sum=tf.reduce_sum(stack.attnP, axis=1)) + return stack + + def compute_scores(self, model, stack, len_alpha=1, attn_beta=0, **flags): + """ + Computes scores after length and coverage penalty + :param len_alpha: coefficient for length penalty, score / ( [5 + len(output_sequence)] / 6) ^ len_alpha + :param attn_beta: coefficient for coverage penalty (additive) + attn_beta * sum_i {log min(1.0, sum_j {attention_p[x_i,y_j] } )} + :return: float32 vector (one score per hypo) + """ + stack_ext = stack.ext[self.PenalizedExt] + + if attn_beta: + warn("whenever attn_beta !=0, this code works as in http://bit.ly/2ziK5a8," + "which may or may not be correct depending on your definition.") + + scores = stack.raw_scores + if len_alpha: + length_penalty = tf.pow((1. + tf.to_float(stack.out_len) / 6.), len_alpha) + scores /= length_penalty + if attn_beta: + times_translated = tf.minimum(stack_ext.attnP_sum, 1) + coverage_penalty = tf.reduce_sum( + tf.log(times_translated + sys.float_info.epsilon), + axis=-1) * attn_beta + scores += coverage_penalty + return scores + + def compute_base_scores(self, model, stack, len_alpha=1, **flags): + """ + Compute hypothesis scores to be used as base_scores for model.sample + :return: float32 vector (one score per hypo) + """ + scores = self.compute_scores(model, stack, len_alpha=len_alpha, **flags) + if len_alpha: + length_penalty = tf.pow((1. + tf.to_float(stack.out_len) / 6.), len_alpha) + scores *= length_penalty + return scores + + +def get_words_attnP(step_attnP, inp_words_mask, slices, src_word_attn_aggregation='max'): + # Helper function to extract word-level alignment aggregation on src. 
+ # For parameter description see AlignmentPenaltyBeamSearchDecoder.AlignmentPenaltyExt + + def _get_words_attnP(step_attnP, inp_words_mask, slices): + max_words_len = np.max(np.sum(inp_words_mask, axis=1)) + words_attnP = np.zeros((step_attnP.shape[0], max_words_len)) + slices = slices.tolist() + [step_attnP.shape[0]] + for words_mask, (b, e) in zip(inp_words_mask, + zip(slices[:-1], slices[1:])): + words_ind = np.where(words_mask)[0].tolist() + [len(words_mask)] + for i, (wb, we) in enumerate(zip(words_ind[:-1], words_ind[1:])): + if src_word_attn_aggregation == 'max': + words_attnP[b:e, i] = np.max(step_attnP[b:e, wb:we], axis=1) + elif src_word_attn_aggregation == 'sum': + words_attnP[b:e, i] = np.sum(step_attnP[b:e, wb:we], axis=1) + else: + raise ValueError('Unknown src_word_attn_aggregation mode: %s' % src_word_attn_aggregation) + return words_attnP.astype(np.float32) + + words_attnP = tf.py_func(_get_words_attnP, [step_attnP, inp_words_mask, slices], tf.float32, stateful=False) + words_attnP.set_shape([None, None]) + return tf.stop_gradient(words_attnP) + + +class AlignmentPenaltyBeamSearchDecoder(BeamSearchDecoder): + AlignmentPenaltyExt = namedtuple('AlignmentPenaltyExt', [ + 'attnP_aggregated_src', # [batch_size x beam_size, ninp|ninp_words] + 'attnP_aggregated_dst', # [batch_size x beam_size, nout] + + 'inp_words_mask', # Does bpe token start a new word? [batch_size, ninp], bool + ]) + + def __init__(self, *args, + len_alpha=1, + attn_beta=0, src_attn_aggregation='max', + src_word_attn_aggregation=None, + dst_attn_beta=0, dst_attn_aggregation='max', + **kwargs ): + # We need to initialize them all to create initial stack + self.len_alpha = len_alpha + self.attn_beta = attn_beta + self.src_attn_aggregation = src_attn_aggregation + self.src_word_attn_aggregation = src_word_attn_aggregation + self.dst_attn_beta = dst_attn_beta + self.dst_attn_aggregation = dst_attn_aggregation + super().__init__(*args, **kwargs) + + def beam_search_step_expand_hypos(self, model, stack, **flags): + stack = super().beam_search_step_expand_hypos(model, stack, **flags) + stack_ext = stack.ext[self.AlignmentPenaltyExt] + + step_attnP = model.get_attnP(stack.dec_state) + # step_attnP shape: [batch_size * beam_size, ninp] + + # updating attnP_aggregated_src + step_attnP_word = step_attnP + if self.src_word_attn_aggregation: + step_attnP_word = get_words_attnP( + step_attnP_word, stack_ext.inp_words_mask, + stack.slices, self.src_word_attn_aggregation) + + max_words_num = tf.shape(stack_ext.attnP_aggregated_src)[1] + paddings = max_words_num - tf.shape(step_attnP_word)[1] + step_attnP_word = tf.pad(step_attnP_word, [[0, 0], [0, paddings]]) + + if self.attn_beta: + if self.src_attn_aggregation == 'max': + attnP_aggregated_src = tf.maximum(stack_ext.attnP_aggregated_src, + step_attnP_word) + elif self.src_attn_aggregation == 'sum': + attnP_aggregated_src = stack_ext.attnP_aggregated_src + step_attnP_word + else: + raise ValueError + else: + attnP_aggregated_src = stack_ext.attnP_aggregated_src + + # updating attnP_aggregated_dst + if self.dst_attn_beta: + if self.dst_attn_aggregation == 'max': + dst_attnP_aggregated = tf.reduce_max(step_attnP_word, axis=-1)[:, None] + elif self.dst_attn_aggregation == 'sum': + dst_attnP_aggregated = tf.reduce_sum(step_attnP_word, axis=-1)[:, None] + else: + raise ValueError + + attnP_aggregated_dst = tf.concat( + [stack_ext.attnP_aggregated_dst, dst_attnP_aggregated], + axis=1) + else: + attnP_aggregated_dst = stack_ext.attnP_aggregated_dst + + 
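+ # write the updated per-source and per-target aggregated attention back into the stack extension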
stack.ext[self.AlignmentPenaltyExt] = stack_ext._replace( + attnP_aggregated_src=attnP_aggregated_src, + attnP_aggregated_dst=attnP_aggregated_dst) + return stack + + def create_initial_stack(self, model, batch, **flags): + stack = super().create_initial_stack(model, batch, **flags) + + words_attnP = tf.squeeze(stack.attnP, axis=1) + + # Calc inp_words_mask and aggregate data. + if self.src_word_attn_aggregation: + def is_new_word(inp_words): + return np.array([[not v.startswith(b'`') for v in l] for l in inp_words]) + + inp_words_mask = tf.py_func(is_new_word, [batch['inp_words']], bool, stateful=False) + inp_words_mask.set_shape(batch['inp_words'].shape) + inp_words_mask = tf.stop_gradient(inp_words_mask) + + words_attnP = get_words_attnP( + words_attnP, inp_words_mask, stack.slices, + self.src_word_attn_aggregation) + else: + inp_words_mask = tf.fill(tf.shape(batch['inp']), 1.0) + + + if self.attn_beta: + if self.src_attn_aggregation in ('max', 'sum'): + attnP_aggregated_src = words_attnP + else: + raise ValueError + else: + attnP_aggregated_src = tf.fill(tf.shape(batch['inp']), 0.0) + + # Calc attnP_aggregated_dst + if self.dst_attn_beta: + if self.dst_attn_aggregation == 'max': + attnP_aggregated_dst = tf.reduce_max(words_attnP, axis=-1) + elif self.dst_attn_aggregation == 'sum': + attnP_aggregated_dst = tf.reduce_sum(words_attnP, axis=-1) + else: + raise ValueError + else: + attnP_aggregated_dst = tf.fill((tf.shape(batch['inp'])[0],), 0.0) + attnP_aggregated_dst = attnP_aggregated_dst[:, None] + + stack.ext[self.AlignmentPenaltyExt] = self.AlignmentPenaltyExt( + attnP_aggregated_src=attnP_aggregated_src, + attnP_aggregated_dst=attnP_aggregated_dst, + inp_words_mask=inp_words_mask + ) + return stack + + def compute_scores(self, model, stack, **flags): + """ + Computes scores after length and coverage penalty + :param len_alpha: coefficient for length penalty, score / ( [5 + len(output_sequence)] / 6) ^ len_alpha + :param attn_beta: coefficient for coverage penalty (additive) + attn_beta * sum_i {log min(1.0, {src_attn_aggregation}_j {attention_p[x_i,y_j] } )} + :param src_attn_aggregation: aggregation for src coverage penalty. + Possible values are 'max', 'sum'. + :param src_word_attn_aggregation: should we aggregate src coverage penalty by words? + Possible values are None/max/sum. + :param dst_attn_beta: coefficient for coverage penalty on dst side: + attn_beta * sum_j {log min(1.0, {dst_attn_aggregation}_i {attention_p[x_i,y_j] } )} + :param dst_attn_aggregation: aggregation for dst coverage penalty. + Possible values are 'max', 'sum'. + :return: float32 vector (one score per hypo) + """ + + stack_ext = stack.ext[self.AlignmentPenaltyExt] + + scores = stack.raw_scores + if self.len_alpha: + length_penalty = tf.pow((1. 
+ tf.to_float(stack.out_len) / 6.), self.len_alpha) + scores /= length_penalty + if self.attn_beta: + coverage_penalty = tf.reduce_sum( + tf.log(tf.minimum(stack_ext.attnP_aggregated_src, 1) + sys.float_info.epsilon), + axis=-1) + scores += coverage_penalty * self.attn_beta + if self.dst_attn_beta: + coverage_penalty = tf.reduce_sum( + tf.log(tf.minimum(stack_ext.attnP_aggregated_dst, 1) + sys.float_info.epsilon), + axis=-1) + scores += coverage_penalty * self.dst_attn_beta + + return scores + + def compute_base_scores(self, model, stack, **flags): + """ + Compute hypothesis scores to be used as base_scores for model.sample + :return: float32 vector (one score per hypo) + """ + scores = self.compute_scores(model, stack, **flags) + if self.len_alpha: + length_penalty = tf.pow((1. + tf.to_float(stack.out_len) / 6.), self.len_alpha) + scores *= length_penalty + return scores + + +def hypo_to_batch_index(n_hypos, slices): + """ + Computes index in batch (input sequence index) for each hypothesis given slices. + :param n_hypos: number of hypotheses (tf int scalar) + :param slices: indices of first hypo for each input in batch + + It should guaranteed that + - slices[0]==0 (first hypothesis starts at index 0), otherwise output[:slices[0]] will be -1 + - if batch[i] is terminated, then batch[i]==batch[i+1] + """ + is_next_sent_at_t = tf.bincount(slices, minlength=n_hypos, maxlength=n_hypos) + hypo_to_index = tf.cumsum(is_next_sent_at_t) - 1 + return hypo_to_index diff --git a/lib/task/seq2seq/models/__init__.py b/lib/task/seq2seq/models/__init__.py new file mode 100644 index 0000000..7a93f7e --- /dev/null +++ b/lib/task/seq2seq/models/__init__.py @@ -0,0 +1,95 @@ +from ..inference import translate_lines +from lib.task.seq2seq.inference import TranslateModel, GreedyDecoder, PenalizedBeamSearchDecoder +from ..data import make_batch_data, make_batch_placeholder +from functools import lru_cache +from itertools import chain, islice + + +class ModelBase: + def encode_decode(self, batch, is_train): + """ Encode input sequence and decode rdo for output sequence """ + raise NotImplementedError() + + def _get_batch_sample(self): + return [("i saw a cat", "i write the code")] + + def make_feed_dict(self, batch, **kwargs): + batch_data = make_batch_data(batch, self.inp_voc, self.out_voc, force_bos=self.hp.get('force_bos', True), **kwargs) + return batch_data + + +class TranslateModelBase(TranslateModel, ModelBase): + """ + A base class that most seq2seq models depend on. + Must have following fields: name, inp_voc, out_voc, loss + """ + def translate_lines(self, lines, ingraph=True, ingraph_mode='beam_search', + unbpe=True, batch_size=None, dumper=None, **flags): + """ Translate multiple lines with the model """ + if ingraph: + translator = self.get_ingraph_translator(mode=ingraph_mode, back_prop=False, **flags) + else: + translator = self.get_translator(**flags) + + replace_unk = flags.get('replace', self.hp.get('replace', False)) + + if batch_size is None: + lines_batched = [lines] + else: + lines = iter(lines) + lines_batched = list(iter(lambda: tuple(islice(lines, batch_size)), ())) + + outputs = (translate_lines(batch_lines, translator, self, self.out_voc, replace_unk, unbpe, dumper=dumper) + for batch_lines in lines_batched) + + return list(chain(*outputs)) + + def predict(self): + self.get_predictor().main() + + @lru_cache() + def get_ingraph_translator(self, mode='beam_search', **flags): + """ + Creates a symbolic translation graph on a batch of placeholders. + Used to translate numeric data. 
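+ Usage sketch (mirrors translate_lines above; batch_data is a numeric feed dict):
+ translator = model.get_ingraph_translator(mode='beam_search', back_prop=False)
+ out_ids, attnP = translator.translate_batch(batch_data)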
+ :param mode: 'greedy', 'sample', or 'beam_search' + :param flags: anything else you want to pass to decoder, encode, decode, sample, etc. + :return: a class with .best_out, .best_scores containing symbolic tensors for translations + """ + batch_data_sample = self.make_feed_dict(self._get_batch_sample()) + batch_placeholder = make_batch_placeholder(batch_data_sample) + return self.symbolic_translate(batch_placeholder, mode, **flags) + + def symbolic_translate(self, batch_placeholder, mode='beam_search', **flags): + """ + A function that takes a dict of symbolic inputs and outputs symolic translations + :param batch_placeholder: a dict of symbolic inputs {'inp':int32[batch, time]} + :param mode: str: 'greedy', 'sample', 'beam_search' or a decoder class + :param flags: anything else you want to pass to decoder, encode, decode, sample, etc. + :return: a class with .best_out, .best_scores containing symbolic tensors for translations + """ + flags = dict(self.hp, **flags) + + if mode in ('greedy', 'sample'): + flags['sampling_strategy'] = 'random' if mode == 'sample' else 'greedy' + return GreedyDecoder( + model=self.get_translate_model(), + batch_placeholder=batch_placeholder, + **flags + ) + elif mode == 'beam_search': + return PenalizedBeamSearchDecoder( + model=self.get_translate_model(), + batch_placeholder=batch_placeholder, + **flags + ) + elif callable(mode): + return mode(self.get_translate_model(), batch_placeholder, **flags) + else: + raise ValueError("Invalid mode : %s" % mode) + + def get_translate_model(self): + if hasattr(self, 'translate_model'): + return self.translate_model + + return self diff --git a/lib/task/seq2seq/models/transformer.py b/lib/task/seq2seq/models/transformer.py new file mode 100644 index 0000000..4c7015b --- /dev/null +++ b/lib/task/seq2seq/models/transformer.py @@ -0,0 +1,623 @@ +#!/usr/bin/env python3 +from lib.layers import * +from lib.ops import * +from ..models import TranslateModelBase, TranslateModel +from ..data import * +from collections import namedtuple + + +class Transformer: + def __init__( + self, name, + inp_voc, out_voc, + *_args, + emb_size=None, hid_size=512, + key_size=None, value_size=None, + inner_hid_size=None, # DEPRECATED. 
Left for compatibility with older experiments + ff_size=None, + num_heads=8, num_layers=6, + attn_dropout=0.0, attn_value_dropout=0.0, relu_dropout=0.0, res_dropout=0.1, + share_emb=False, inp_emb_bias=False, rescale_emb=False, + dst_reverse=False, dst_rand_offset=False, summarize_preactivations=False, + res_steps='ldan', normalize_out=False, multihead_attn_format='v1', + emb_inp_device='', emb_out_device='', + **_kwargs + ): + + if isinstance(ff_size, str): + ff_size = [int(i) for i in ff_size.split(':')] + + if _args: + raise Exception("Unexpected positional arguments") + + emb_size = emb_size if emb_size else hid_size + key_size = key_size if key_size else hid_size + value_size = value_size if value_size else hid_size + if key_size % num_heads != 0: + raise Exception("Bad number of heads") + if value_size % num_heads != 0: + raise Exception("Bad number of heads") + + self.name = name + self.num_layers_enc = num_layers + self.num_layers_dec = num_layers + self.res_dropout = res_dropout + self.emb_size = emb_size + self.hid_size = hid_size + self.rescale_emb = rescale_emb + self.summarize_preactivations = summarize_preactivations + self.dst_reverse = dst_reverse + self.dst_rand_offset = dst_rand_offset + self.normalize_out = normalize_out + + with tf.variable_scope(name): + max_voc_size = max(inp_voc.size(), out_voc.size()) + + self.emb_inp = Embedding( + 'emb_inp', max_voc_size if share_emb else inp_voc.size(), emb_size, + initializer=tf.random_normal_initializer(0, emb_size ** -.5), + device=emb_inp_device) + + self.emb_out = Embedding( + 'emb_out', max_voc_size if share_emb else out_voc.size(), emb_size, + matrix=self.emb_inp.mat if share_emb else None, + initializer=tf.random_normal_initializer(0, emb_size ** -.5), + device=emb_out_device) + + self.emb_inp_bias = 0 + if inp_emb_bias: + self.emb_inp_bias = get_model_variable('emb_inp_bias', shape=[1, 1, emb_size]) + + def get_layer_params(layer_prefix, layer_idx): + layer_name = '%s-%i' % (layer_prefix, layer_idx) + inp_out_size = emb_size if layer_idx == 0 else hid_size + return layer_name, inp_out_size + + def attn_layer(layer_prefix, layer_idx, **kwargs): + layer_name, inp_out_size = get_layer_params(layer_prefix, layer_idx) + return ResidualLayerWrapper( + layer_name, + MultiHeadAttn( + layer_name, + inp_size=inp_out_size, + key_depth=key_size, + value_depth=value_size, + output_depth=hid_size, + num_heads=num_heads, + attn_dropout=attn_dropout, + attn_value_dropout=attn_value_dropout, + **kwargs), + inp_size=inp_out_size, + out_size=inp_out_size, + steps=res_steps, + dropout=res_dropout) + + def ffn_layer(layer_prefix, layer_idx, ffn_hid_size): + layer_name, inp_out_size = get_layer_params(layer_prefix, layer_idx) + return ResidualLayerWrapper( + layer_name, + FFN( + layer_name, + inp_size=inp_out_size, + hid_size=ffn_hid_size, + out_size=hid_size, + relu_dropout=relu_dropout), + inp_size=inp_out_size, + out_size=hid_size, + steps=res_steps, + dropout=res_dropout) + + # Encoder/decoder layer params + enc_ffn_hid_size = ff_size if ff_size else (inner_hid_size if inner_hid_size else hid_size) + dec_ffn_hid_size = ff_size if ff_size else hid_size + dec_enc_attn_format = 'use_kv' if multihead_attn_format == 'v1' else 'combined' + + # Encoder Layers + self.enc_attn = [attn_layer('enc_attn', i) for i in range(self.num_layers_enc)] + + self.enc_ffn = [ffn_layer('enc_ffn', i, enc_ffn_hid_size) for i in range(self.num_layers_enc)] + + if self.normalize_out: + self.enc_out_norm = LayerNorm('enc_out_norm', + inp_size=emb_size if 
self.num_layers_enc == 0 else hid_size) + + # Decoder layers + self.dec_attn = [attn_layer('dec_attn', i) for i in range(self.num_layers_dec)] + self.dec_enc_attn = [attn_layer('dec_enc_attn', i, _format=dec_enc_attn_format) for i in + range(self.num_layers_dec)] + + self.dec_ffn = [ffn_layer('dec_ffn', i, dec_ffn_hid_size) for i in range(self.num_layers_dec)] + + if self.normalize_out: + self.dec_out_norm = LayerNorm('dec_out_norm', + inp_size=emb_size if self.num_layers_dec == 0 else hid_size) + + def encode(self, inp, inp_len, is_train): + with dropout_scope(is_train), tf.name_scope('mod_enc') as scope: + + # Embeddings + emb_inp = self.emb_inp(inp) # [batch_size * ninp * emb_dim] + if self.rescale_emb: + emb_inp *= self.emb_size ** .5 + emb_inp += self.emb_inp_bias + + # Prepare decoder + enc_attn_mask = self._make_enc_attn_mask(inp, inp_len) # [batch_size * 1 * 1 * ninp] + + enc_inp = self._add_timing_signal(emb_inp) + + # Apply dropouts + if is_dropout_enabled(): + enc_inp = tf.nn.dropout(enc_inp, 1.0 - self.res_dropout) + + tf.add_to_collection("LayerEmbeddings", enc_inp) + + # Encoder + for layer in range(self.num_layers_enc): + enc_inp = self.enc_attn[layer](enc_inp, enc_attn_mask) + enc_inp = self.enc_ffn[layer](enc_inp, summarize_preactivations=self.summarize_preactivations) + tf.add_to_collection("LayerEmbeddings", enc_inp) + + if self.normalize_out: + enc_inp = self.enc_out_norm(enc_inp) + + tf.add_to_collection(lib.meta.ACTIVATIONS, tf.identity(enc_inp, name=scope)) + + return enc_inp, enc_attn_mask + + def decode(self, out, out_len, out_reverse, enc_out, enc_attn_mask, is_train): + with dropout_scope(is_train), tf.name_scope('mod_dec') as scope: + # Embeddings + emb_out = self.emb_out(out) # [batch_size * nout * emb_dim] + if self.rescale_emb: + emb_out *= self.emb_size ** .5 + + # Shift right; drop embedding for last word + emb_out = tf.pad(emb_out, [[0, 0], [1, 0], [0, 0]])[:, :-1, :] + + # Prepare decoder + dec_attn_mask = self._make_dec_attn_mask(out) # [1 * 1 * nout * nout] + + offset = 'random' if self.dst_rand_offset else 0 + dec_inp = self._add_timing_signal(emb_out, offset=offset, inp_reverse=out_reverse) + # Apply dropouts + if is_dropout_enabled(): + dec_inp = dropout(dec_inp, 1.0 - self.res_dropout) + + # bypass info from Encoder to avoid None gradients for num_layers_dec == 0 + if self.num_layers_dec == 0: + inp_mask = tf.squeeze(tf.transpose(enc_attn_mask, perm=[3, 1, 2, 0]), 3) + dec_inp += tf.reduce_mean(enc_out * inp_mask, axis=[0, 1], keep_dims=True) + + # Decoder + for layer in range(self.num_layers_dec): + dec_inp = self.dec_attn[layer](dec_inp, dec_attn_mask) + dec_inp = self.dec_enc_attn[layer](dec_inp, enc_attn_mask, enc_out) + dec_inp = self.dec_ffn[layer](dec_inp, summarize_preactivations=self.summarize_preactivations) + + if self.normalize_out: + dec_inp = self.dec_out_norm(dec_inp) + + tf.add_to_collection(lib.meta.ACTIVATIONS, tf.identity(dec_inp, name=scope)) + + return dec_inp + + def relprop_decode(self, R): + """ propagates relevances from rdo to output embeddings and encoder state """ + R_enc = 0.0 + R_enc_scale = 0.0 + for layer in range(self.num_layers_dec)[::-1]: + R = self.dec_ffn[layer].relprop(R) + + relevance_dict = self.dec_enc_attn[layer].relprop(R, main_key='query_inp') + R = relevance_dict['query_inp'] + R_enc += relevance_dict['kv_inp'] + R_enc_scale += tf.reduce_sum(abs(relevance_dict['kv_inp'])) + + R = self.dec_attn[layer].relprop(R) + + # shift left: compensate for right shift + R = LRP.rescale(R, tf.pad(R, [[0, 0], [0, 1], 
[0, 0]])[:, 1:, :]) + return {'emb_out': R, 'enc_out': R_enc * R_enc_scale / tf.reduce_sum(abs(R_enc))} + + def relprop_encode(self, R): + """ propagates relevances from enc_out to emb_inp """ + for layer in range(self.num_layers_enc)[::-1]: + R = self.enc_ffn[layer].relprop(R) + R = self.enc_attn[layer].relprop(R) + return R + + def relprop_encode_decode(self, R): + """ propagates relevances from rdo to input and optput embeddings """ + relevances = self.relprop_decode(R) + relevances['emb_inp'] = self.relprop_encode(relevances['enc_out']) + return relevances + + def _make_enc_attn_mask(self, inp, inp_len, dtype=tf.float32): + """ + inp = [batch_size * ninp] + inp_len = [batch_size] + + attn_mask = [batch_size * 1 * 1 * ninp] + """ + with tf.variable_scope("make_enc_attn_mask"): + inp_mask = tf.sequence_mask(inp_len, dtype=dtype, maxlen=tf.shape(inp)[1]) + + attn_mask = inp_mask[:, None, None, :] + return attn_mask + + def _make_dec_attn_mask(self, out, dtype=tf.float32): + """ + out = [baatch_size * nout] + + attn_mask = [1 * 1 * nout * nout] + """ + with tf.variable_scope("make_dec_attn_mask"): + length = tf.shape(out)[1] + lower_triangle = tf.matrix_band_part(tf.ones([length, length], dtype=dtype), -1, 0) + attn_mask = tf.reshape(lower_triangle, [1, 1, length, length]) + return attn_mask + + def _add_timing_signal(self, inp, min_timescale=1.0, max_timescale=1.0e4, offset=0, inp_reverse=None): + """ + inp: (batch_size * ninp * hid_dim) + :param offset: add this number to all character positions. + if offset == 'random', picks this number uniformly from [-32000,32000] integers + :type offset: number, tf.Tensor or 'random' + """ + with tf.variable_scope("add_timing_signal"): + ninp = tf.shape(inp)[1] + hid_size = tf.shape(inp)[2] + + position = tf.to_float(tf.range(ninp))[None, :, None] + + if offset == 'random': + BIG_LEN = 32000 + offset = tf.random_uniform(tf.shape(position), minval=-BIG_LEN, maxval=BIG_LEN, dtype=tf.int32) + + # force broadcasting over batch axis + if isinstance(offset * 1, tf.Tensor): # multiply by 1 to also select variables, special generators, etc. 
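+ # reshape offset to rank 3 by appending singleton dims so it broadcasts against position of shape [1, ninp, 1]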
+ assert offset.shape.ndims in (0, 1, 2) + new_shape = [tf.shape(offset)[i] for i in range(offset.shape.ndims)] + new_shape += [1] * (3 - len(new_shape)) + offset = tf.reshape(offset, new_shape) + + position += tf.to_float(offset) + + if inp_reverse is not None: + position = tf.multiply( + position, + tf.where( + tf.equal(inp_reverse, 0), + tf.ones_like(inp_reverse, dtype=tf.float32), + -1.0 * tf.ones_like(inp_reverse, dtype=tf.float32) + )[:, None, None] # (batch_size * ninp * dim) + ) + num_timescales = hid_size // 2 + log_timescale_increment = ( + math.log(float(max_timescale) / float(min_timescale)) / + (tf.to_float(num_timescales) - 1)) + inv_timescales = min_timescale * tf.exp( + tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) + + # scaled_time: [ninp * hid_dim] + scaled_time = position * inv_timescales[None, None, :] + signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=-1) + signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(hid_size, 2)]]) + return inp + signal + + +# ============================================================================ +# Transformer model + +class Model(TranslateModelBase): + + def __init__(self, name, inp_voc, out_voc, **hp): + self.name = name + self.inp_voc = inp_voc + self.out_voc = out_voc + self.hp = hp + + # Parameters + self.transformer = Transformer(name, inp_voc, out_voc, **hp) + + projection_matrix = None + if hp.get('dwwt', False): + projection_matrix = tf.transpose(self.transformer.emb_out.mat) + + self.loss = LossXent( + hp.get('loss_name', 'loss_xent_lm'), + hp['hid_size'], + out_voc, + hp, + matrix=projection_matrix, + bias=None if hp.get("loss_bias", False) else 0) + + inference_mode = hp.get("inference_mode", "fast") + if inference_mode == 'fast': + self.translate_model = TranslateModelFast(self.name, self.transformer, self.loss, self.inp_voc, + self.out_voc) + elif inference_mode == 'lazy': + self.translate_model = TranslateModelLazy(self.name, self.transformer, self.loss, self.inp_voc, + self.out_voc) + else: + raise NotImplementedError("inference_mode %s is not supported" % inference_mode) + + # Train interface + def encode_decode(self, batch, is_train, score_info=False): + inp = batch['inp'] # [batch_size * ninp] + out = batch['out'] # [batch_size * nout] + inp_len = batch.get('inp_len', infer_length(inp, self.inp_voc.eos, time_major=False)) # [batch] + out_len = batch.get('out_len', infer_length(out, self.out_voc.eos, time_major=False)) # [batch] + + out_reverse = tf.zeros_like(inp_len) # batch['out_reverse'] + + # rdo: [batch_size * nout * hid_dim] + enc_out, enc_attn_mask = self.transformer.encode(inp, inp_len, is_train) + rdo = self.transformer.decode(out, out_len, out_reverse, enc_out, enc_attn_mask, is_train) + + return rdo + + def make_feed_dict(self, batch, **kwargs): + feed_dict = make_batch_data(batch, self.inp_voc, self.out_voc, + force_bos=self.hp.get('force_bos', False), + **kwargs) + return feed_dict + + # ======== TranslateModel for Inference ============ + def encode(self, batch, **flags): + """ + :param batch: a dict of {string:symbolic tensor} that model understands. + By default it should accept {'inp': int32 matrix[batch,time]} + :return: initial decoder state + """ + return self.translate_model.encode(batch, **flags) + + def decode(self, dec_state, words=None, **flags): + """ + Performs decoding step given words and previous state. + :param words: previous output tokens, int32[batch_size]. 
if None, uses zero embeddings (first step) + :returns: next state + """ + return self.translate_model.decode(dec_state, words, **flags) + + def sample(self, dec_state, base_scores, slices, k, **kwargs): + return self.translate_model.sample(dec_state, base_scores, slices, k, **kwargs) + + def get_rdo(self, dec_state, **kwargs): + return self.translate_model.get_rdo(dec_state, **kwargs) + + def get_attnP(self, dec_state, **kwargs): + return self.translate_model.get_attnP(dec_state, **kwargs) + + +class ScopedModel(Model): + + def __init__(self, name, inp_voc, out_voc, **hp): + with tf.variable_scope(name): + super(ScopedModel, self).__init__(name, inp_voc, out_voc, **hp) + + def encode_decode(self, *args, **kwargs): + with tf.name_scope(self.name): + return super(ScopedModel, self).encode_decode(*args, **kwargs) + + def encode(self, *args, **kwargs): + with tf.name_scope(self.name): + return super(ScopedModel, self).encode(*args, **kwargs) + + def decode(self, *args, **kwargs): + with tf.name_scope(self.name): + return super(ScopedModel, self).decode(*args, **kwargs) + + def sample(self, *args, **kwargs): + with tf.name_scope(self.name): + return super(ScopedModel, self).sample(*args, **kwargs) + + +# ============================================================================ +# Transformer inference + +class TranslateModelFast(TranslateModel): + DecState = namedtuple("transformer_state", ['enc_out', 'enc_attn_mask', 'attnP', 'rdo', 'out_seq', 'offset', + 'emb', 'dec_layers', 'dec_enc_kv', 'dec_dec_kv']) + + def __init__(self, name, transformer, loss, inp_voc, out_voc): + """ + A translation model that performs quick (n^2) inference for transformer + with manual implementation of 1-step decoding + """ + self.name = name + self.transformer = transformer + self.loss = loss + self.inp_voc = inp_voc + self.out_voc = out_voc + + def encode(self, batch, is_train=False, **kwargs): + """ + :param batch: a dict containing 'inp':int32[batch_size * ninp] and optionally inp_len:int32[batch_size] + :param is_train: if True, enables dropouts + """ + inp = batch['inp'] + inp_len = batch.get('inp_len', infer_length(inp, self.inp_voc.eos, time_major=False)) + with dropout_scope(is_train), tf.name_scope(self.transformer.name): + # Encode. + enc_out, enc_attn_mask = self.transformer.encode(inp, inp_len, is_train=False) + + # Decoder dummy input/output + ninp = tf.shape(inp)[1] + batch_size = tf.shape(inp)[0] + hid_size = tf.shape(enc_out)[-1] + out_seq = tf.zeros([batch_size, 0], dtype=inp.dtype) + rdo = tf.zeros([batch_size, hid_size], dtype=enc_out.dtype) + + attnP = tf.ones([batch_size, ninp]) / tf.to_float(inp_len)[:, None] + + offset = tf.zeros((batch_size,)) + if self.transformer.dst_rand_offset: + BIG_LEN = 32000 + random_offset = tf.random_uniform(tf.shape(offset), minval=-BIG_LEN, maxval=BIG_LEN, dtype=tf.int32) + offset += tf.to_float(random_offset) + + trans = self.transformer + empty_emb = tf.zeros([batch_size, 0, trans.emb_size]) + empty_dec_layers = [tf.zeros([batch_size, 0, trans.hid_size])] * trans.num_layers_dec + input_layers = [empty_emb] + empty_dec_layers[:-1] + + # prepare kv parts for all decoder attention layers. 
Note: we do not preprocess enc_out + # for each layer because ResidualLayerWrapper only preprocesses first input (query) + dec_enc_kv = [layer.kv_conv(enc_out) + for i, layer in enumerate(trans.dec_enc_attn)] + dec_dec_kv = [layer.kv_conv(layer.preprocess(input_layers[i])) + for i, layer in enumerate(trans.dec_attn)] + + new_state = self.DecState(enc_out, enc_attn_mask, attnP, rdo, out_seq, offset, + empty_emb, empty_dec_layers, dec_enc_kv, dec_dec_kv) + + # perform initial decode (instead of force_bos) with zero embeddings + new_state = self.decode(new_state, is_train=is_train) + return new_state + + def decode(self, dec_state, words=None, is_train=False, **kwargs): + """ + Performs decoding step given words and previous state. + Returns next state. + + :param words: previous output tokens, int32[batch_size]. if None, uses zero embeddings (first step) + :param is_train: if True, enables dropouts + """ + trans = self.transformer + enc_out, enc_attn_mask, attnP, rdo, out_seq, offset, prev_emb = dec_state[:7] + prev_dec_layers = dec_state.dec_layers + dec_enc_kv = dec_state.dec_enc_kv + dec_dec_kv = dec_state.dec_dec_kv + + batch_size = tf.shape(rdo)[0] + if words is not None: + out_seq = tf.concat([out_seq, tf.expand_dims(words, 1)], 1) + + with dropout_scope(is_train), tf.name_scope(trans.name): + # Embeddings + if words is None: + # initial step: words are None + emb_out = tf.zeros((batch_size, 1, trans.emb_size)) + else: + emb_out = trans.emb_out(words[:, None]) # [batch_size * 1 * emb_dim] + if trans.rescale_emb: + emb_out *= trans.emb_size ** .5 + + # Prepare decoder + dec_inp_t = trans._add_timing_signal(emb_out, offset=offset) + # Apply dropouts + if is_dropout_enabled(): + dec_inp_t = tf.nn.dropout(dec_inp_t, 1.0 - trans.res_dropout) + + # bypass info from Encoder to avoid None gradients for num_layers_dec == 0 + if trans.num_layers_dec == 0: + inp_mask = tf.squeeze(tf.transpose(enc_attn_mask, perm=[3, 1, 2, 0]), 3) + dec_inp_t += tf.reduce_mean(enc_out * inp_mask, axis=[0, 1], keep_dims=True) + + # Decoder + new_emb = tf.concat([prev_emb, dec_inp_t], axis=1) + _out = tf.pad(out_seq, [(0, 0), (0, 1)]) + dec_attn_mask = trans._make_dec_attn_mask(_out)[:, :, -1:, :] # [1, 1, n_q=1, n_kv] + + new_dec_layers = [] + new_dec_dec_kv = [] + + for layer in range(trans.num_layers_dec): + # multi-head self-attention: use only the newest time-step as query, + # but all time-steps up to newest one as keys/values + next_dec_kv = trans.dec_attn[layer].kv_conv(trans.dec_attn[layer].preprocess(dec_inp_t)) + new_dec_dec_kv.append(tf.concat([dec_dec_kv[layer], next_dec_kv], axis=1)) + dec_inp_t = trans.dec_attn[layer](dec_inp_t, dec_attn_mask, kv=new_dec_dec_kv[layer]) + + dec_inp_t = trans.dec_enc_attn[layer](dec_inp_t, enc_attn_mask, kv=dec_enc_kv[layer]) + dec_inp_t = trans.dec_ffn[layer](dec_inp_t, summarize_preactivations=trans.summarize_preactivations) + + new_dec_inp = tf.concat([prev_dec_layers[layer], dec_inp_t], axis=1) + new_dec_layers.append(new_dec_inp) + + if trans.normalize_out: + dec_inp_t = trans.dec_out_norm(dec_inp_t) + + rdo = dec_inp_t[:, -1] + + new_state = self.DecState(enc_out, enc_attn_mask, attnP, rdo, out_seq, offset + 1, + new_emb, new_dec_layers, dec_enc_kv, new_dec_dec_kv) + return new_state + + def get_rdo(self, dec_state, **kwargs): + return dec_state.rdo, dec_state.out_seq + + def get_attnP(self, dec_state, **kwargs): + return dec_state.attnP + + +class TranslateModelLazy(TranslateModel): + def __init__(self, name, transformer, loss, inp_voc, out_voc): + """ + 
Automatically implements O(n^3) decoding by using trans.decode + """ + self.name = name + self.transformer = transformer + self.loss = loss + self.inp_voc = inp_voc + self.out_voc = out_voc + + def encode(self, batch, is_train=False, **kwargs): + """ + :param batch: a dict of placeholders + inp: [batch_size * ninp] + inp_len; [batch_size] + """ + inp = batch['inp'] + inp_len = batch['inp_len'] + with dropout_scope(is_train), tf.name_scope(self.transformer.name): + # Encode. + enc_out, enc_attn_mask = self.transformer.encode(inp, inp_len, is_train=False) + + # Decoder dummy input/output + ninp = tf.shape(inp)[1] + batch_size = tf.shape(inp)[0] + hid_size = tf.shape(enc_out)[-1] + out_seq = tf.zeros([batch_size, 0], dtype=inp.dtype) + rdo = tf.zeros([batch_size, hid_size], dtype=enc_out.dtype) + + attnP = tf.ones([batch_size, ninp]) / tf.to_float(inp_len)[:, None] + + return self._decode_impl((enc_out, enc_attn_mask, attnP, out_seq, rdo), **kwargs) + + def decode(self, dec_state, words, **kwargs): + """ + Performs decoding step given words + + words: [batch_size] + """ + with tf.name_scope(self.transformer.name): + (enc_out, enc_attn_mask, attnP, prev_out_seq, rdo) = dec_state + out_seq = tf.concat([prev_out_seq, tf.expand_dims(words, 1)], 1) + return self._decode_impl((enc_out, enc_attn_mask, attnP, out_seq, rdo), **kwargs) + + def _decode_impl(self, dec_state, is_train=False, **kwargs): + (enc_out, enc_attn_mask, attnP, out_seq, rdo) = dec_state + + with dropout_scope(is_train): + out = tf.pad(out_seq, [(0, 0), (0, 1)]) + out_len = tf.fill(dims=(tf.shape(out)[0],), value=tf.shape(out_seq)[1]) + out_reverse = tf.zeros_like(out_len) # batch['out_reverse'] + dec_out = self.transformer.decode(out, out_len, out_reverse, enc_out, enc_attn_mask, is_train=False) + rdo = dec_out[:, -1, :] # [batch_size * hid_dim] + + attnP = enc_attn_mask[:, 0, 0, :] # [batch_size * ninp ] + attnP /= tf.reduce_sum(attnP, axis=1, keep_dims=True) + + return (enc_out, enc_attn_mask, attnP, out_seq, rdo) + + def get_rdo(self, dec_state, **kwargs): + rdo = dec_state[4] + out = dec_state[3] + return rdo, out + + def get_attnP(self, dec_state, **kwargs): + return dec_state[2] + diff --git a/lib/task/seq2seq/models/transformer_head_gates.py b/lib/task/seq2seq/models/transformer_head_gates.py new file mode 100644 index 0000000..1f0f63c --- /dev/null +++ b/lib/task/seq2seq/models/transformer_head_gates.py @@ -0,0 +1,651 @@ +#!/usr/bin/env python3 +from lib.layers import * +from lib.ops import * +from ..models import TranslateModelBase, TranslateModel +from ..data import * +from collections import namedtuple + + +class Transformer: + def __init__( + self, name, + inp_voc, out_voc, + *_args, + emb_size=None, hid_size=512, + key_size=None, value_size=None, + inner_hid_size=None, # DEPRECATED. 
Left for compatibility with older experiments + ff_size=None, + num_heads=8, num_layers=6, + attn_dropout=0.0, attn_value_dropout=0.0, relu_dropout=0.0, res_dropout=0.1, + share_emb=False, inp_emb_bias=False, rescale_emb=False, + dst_reverse=False, dst_rand_offset=False, summarize_preactivations=False, + res_steps='nlda', normalize_out=False, multihead_attn_format='v1', + emb_inp_device='', emb_out_device='', + concrete_heads={}, # any subset of {enc-self, dec-self, dec-enc} + alive_heads={}, # {enc-self: [[1,1,1,0,1,0,0,0], [0,0,0,1,0,1,0,1], ..., [0,1,0,1,0,0,0,0]], + # dec-self: [...], + # dec-enc: [...]} + num_layers_enc=0, + num_layers_dec=0, + **_kwargs + ): + + for attn_type in ['enc-self', 'dec-self', 'dec-enc']: + assert not (attn_type in concrete_heads and attn_type in alive_heads),\ + "'{}' is passed as both with trainable concrete gates heads and fixed gates".format(attn_type) + + if isinstance(ff_size, str): + ff_size = [int(i) for i in ff_size.split(':')] + + if _args: + raise Exception("Unexpected positional arguments") + + emb_size = emb_size if emb_size else hid_size + key_size = key_size if key_size else hid_size + value_size = value_size if value_size else hid_size + if key_size % num_heads != 0: + raise Exception("Bad number of heads") + if value_size % num_heads != 0: + raise Exception("Bad number of heads") + + self.name = name + self.num_layers_enc = num_layers if num_layers_enc == 0 else num_layers_enc + self.num_layers_dec = num_layers if num_layers_dec == 0 else num_layers_dec + self.res_dropout = res_dropout + self.emb_size = emb_size + self.hid_size = hid_size + self.rescale_emb = rescale_emb + self.summarize_preactivations = summarize_preactivations + self.dst_reverse = dst_reverse + self.dst_rand_offset = dst_rand_offset + self.normalize_out = normalize_out + + with tf.variable_scope(name): + max_voc_size = max(inp_voc.size(), out_voc.size()) + + self.emb_inp = Embedding( + 'emb_inp', max_voc_size if share_emb else inp_voc.size(), emb_size, + initializer=tf.random_normal_initializer(0, emb_size ** -.5), + device=emb_inp_device) + + self.emb_out = Embedding( + 'emb_out', max_voc_size if share_emb else out_voc.size(), emb_size, + matrix=self.emb_inp.mat if share_emb else None, + initializer=tf.random_normal_initializer(0, emb_size ** -.5), + device=emb_out_device) + + self.emb_inp_bias = 0 + if inp_emb_bias: + self.emb_inp_bias = get_model_variable('emb_inp_bias', shape=[1, 1, emb_size]) + + def get_layer_params(layer_prefix, layer_idx): + layer_name = '%s-%i' % (layer_prefix, layer_idx) + inp_out_size = emb_size if layer_idx == 0 else hid_size + return layer_name, inp_out_size + + def attn_layer(layer_prefix, layer_idx, **kwargs): + layer_name, inp_out_size = get_layer_params(layer_prefix, layer_idx) + return ResidualLayerWrapper( + layer_name, + MultiHeadAttn( + layer_name, + inp_size=inp_out_size, + key_depth=key_size, + value_depth=value_size, + output_depth=hid_size, + num_heads=num_heads, + attn_dropout=attn_dropout, + attn_value_dropout=attn_value_dropout, + **kwargs), + inp_size=inp_out_size, + out_size=inp_out_size, + steps=res_steps, + dropout=res_dropout) + + def attn_layer_concrete_heads(layer_prefix, layer_idx, **kwargs): + layer_name, inp_out_size = get_layer_params(layer_prefix, layer_idx) + return ResidualLayerWrapper( + layer_name, + MultiHeadAttnConcrete( + layer_name, + inp_size=inp_out_size, + key_depth=key_size, + value_depth=value_size, + output_depth=hid_size, + num_heads=num_heads, + attn_dropout=attn_dropout, + 
attn_value_dropout=attn_value_dropout, + **kwargs), + inp_size=inp_out_size, + out_size=inp_out_size, + steps=res_steps, + dropout=res_dropout) + + def attn_layer_fixed_alive_heads(layer_prefix, layer_idx, head_gate, **kwargs): + layer_name, inp_out_size = get_layer_params(layer_prefix, layer_idx) + return ResidualLayerWrapper( + layer_name, + MultiHeadAttnFixedAliveHeads( + layer_name, + inp_size=inp_out_size, + key_depth=key_size, + value_depth=value_size, + output_depth=hid_size, + num_heads=num_heads, + attn_dropout=attn_dropout, + attn_value_dropout=attn_value_dropout, + head_gate=head_gate, + **kwargs), + inp_size=inp_out_size, + out_size=inp_out_size, + steps=res_steps, + dropout=res_dropout) + + def ffn_layer(layer_prefix, layer_idx, ffn_hid_size): + layer_name, inp_out_size = get_layer_params(layer_prefix, layer_idx) + return ResidualLayerWrapper( + layer_name, + FFN( + layer_name, + inp_size=inp_out_size, + hid_size=ffn_hid_size, + out_size=hid_size, + relu_dropout=relu_dropout), + inp_size=inp_out_size, + out_size=hid_size, + steps=res_steps, + dropout=res_dropout) + + # Encoder/decoder layer params + enc_ffn_hid_size = ff_size if ff_size else (inner_hid_size if inner_hid_size else hid_size) + dec_ffn_hid_size = ff_size if ff_size else hid_size + dec_enc_attn_format = 'use_kv' if multihead_attn_format == 'v1' else 'combined' + + # Encoder Layers + self.enc_attn = [attn_layer_concrete_heads('enc_attn', i) if 'enc-self' in concrete_heads else + attn_layer('enc_attn', i) if not 'enc-self' in alive_heads else + attn_layer_fixed_alive_heads('enc_attn', i, alive_heads['enc-self'][i]) + for i in range(self.num_layers_enc)] + + self.enc_ffn = [ffn_layer('enc_ffn', i, enc_ffn_hid_size) for i in range(self.num_layers_enc)] + + if self.normalize_out: + self.enc_out_norm = LayerNorm('enc_out_norm', + inp_size=emb_size if self.num_layers_enc == 0 else hid_size) + + # Decoder layers + self.dec_attn = [attn_layer_concrete_heads('dec_attn', i) if 'dec-self' in concrete_heads else + attn_layer('dec_attn', i) if not 'dec-self' in alive_heads else + attn_layer_fixed_alive_heads('dec_attn', i, alive_heads['dec-self'][i]) + for i in range(self.num_layers_dec)] + + self.dec_enc_attn = [attn_layer_concrete_heads('dec_enc_attn', i, _format=dec_enc_attn_format) \ + if 'dec-enc' in concrete_heads else \ + attn_layer('dec_enc_attn', i, _format=dec_enc_attn_format) if \ + not 'dec-enc' in alive_heads else \ + attn_layer_fixed_alive_heads('dec_enc_attn', i, alive_heads['dec-enc'][i], _format=dec_enc_attn_format) + for i in range(self.num_layers_enc)] + + self.dec_ffn = [ffn_layer('dec_ffn', i, dec_ffn_hid_size) for i in range(self.num_layers_dec)] + + if self.normalize_out: + self.dec_out_norm = LayerNorm('dec_out_norm', + inp_size=emb_size if self.num_layers_dec == 0 else hid_size) + + def encode(self, inp, inp_len, is_train): + with dropout_scope(is_train), tf.name_scope('mod_enc') as scope: + + # Embeddings + emb_inp = self.emb_inp(inp) # [batch_size * ninp * emb_dim] + if self.rescale_emb: + emb_inp *= self.emb_size ** .5 + emb_inp += self.emb_inp_bias + + # Prepare decoder + enc_attn_mask = self._make_enc_attn_mask(inp, inp_len) # [batch_size * 1 * 1 * ninp] + + enc_inp = self._add_timing_signal(emb_inp) + + # Apply dropouts + if is_dropout_enabled(): + enc_inp = tf.nn.dropout(enc_inp, 1.0 - self.res_dropout) + + # Encoder + for layer in range(self.num_layers_enc): + enc_inp = self.enc_attn[layer](enc_inp, enc_attn_mask) + enc_inp = self.enc_ffn[layer](enc_inp, 
summarize_preactivations=self.summarize_preactivations) + + if self.normalize_out: + enc_inp = self.enc_out_norm(enc_inp) + + tf.add_to_collection(lib.meta.ACTIVATIONS, tf.identity(enc_inp, name=scope)) + + return enc_inp, enc_attn_mask + + def decode(self, out, out_len, out_reverse, enc_out, enc_attn_mask, is_train): + with dropout_scope(is_train), tf.name_scope('mod_dec') as scope: + # Embeddings + emb_out = self.emb_out(out) # [batch_size * nout * emb_dim] + if self.rescale_emb: + emb_out *= self.emb_size ** .5 + + # Shift right; drop embedding for last word + emb_out = tf.pad(emb_out, [[0, 0], [1, 0], [0, 0]])[:, :-1, :] + + # Prepare decoder + dec_attn_mask = self._make_dec_attn_mask(out) # [1 * 1 * nout * nout] + + offset = 'random' if self.dst_rand_offset else 0 + dec_inp = self._add_timing_signal(emb_out, offset=offset, inp_reverse=out_reverse) + # Apply dropouts + if is_dropout_enabled(): + dec_inp = tf.nn.dropout(dec_inp, 1.0 - self.res_dropout) + + # bypass info from Encoder to avoid None gradients for num_layers_dec == 0 + if self.num_layers_dec == 0: + inp_mask = tf.squeeze(tf.transpose(enc_attn_mask, perm=[3, 1, 2, 0]), 3) + dec_inp += tf.reduce_mean(enc_out * inp_mask, axis=[0, 1], keep_dims=True) + + # Decoder + for layer in range(self.num_layers_dec): + dec_inp = self.dec_attn[layer](dec_inp, dec_attn_mask) + dec_inp = self.dec_enc_attn[layer](dec_inp, enc_attn_mask, enc_out) + dec_inp = self.dec_ffn[layer](dec_inp, summarize_preactivations=self.summarize_preactivations) + + if self.normalize_out: + dec_inp = self.dec_out_norm(dec_inp) + + tf.add_to_collection(lib.meta.ACTIVATIONS, tf.identity(dec_inp, name=scope)) + + return dec_inp + + def _make_enc_attn_mask(self, inp, inp_len, dtype=tf.float32): + """ + inp = [batch_size * ninp] + inp_len = [batch_size] + + attn_mask = [batch_size * 1 * 1 * ninp] + """ + with tf.variable_scope("make_enc_attn_mask"): + inp_mask = tf.sequence_mask(inp_len, dtype=dtype, maxlen=tf.shape(inp)[1]) + + attn_mask = inp_mask[:, None, None, :] + return attn_mask + + def _make_dec_attn_mask(self, out, dtype=tf.float32): + """ + out = [baatch_size * nout] + + attn_mask = [1 * 1 * nout * nout] + """ + with tf.variable_scope("make_dec_attn_mask"): + length = tf.shape(out)[1] + lower_triangle = tf.matrix_band_part(tf.ones([length, length], dtype=dtype), -1, 0) + attn_mask = tf.reshape(lower_triangle, [1, 1, length, length]) + return attn_mask + + def _add_timing_signal(self, inp, min_timescale=1.0, max_timescale=1.0e4, offset=0, inp_reverse=None): + """ + inp: (batch_size * ninp * hid_dim) + :param offset: add this number to all character positions. + if offset == 'random', picks this number uniformly from [-32000,32000] integers + :type offset: number, tf.Tensor or 'random' + """ + with tf.variable_scope("add_timing_signal"): + ninp = tf.shape(inp)[1] + hid_size = tf.shape(inp)[2] + + position = tf.to_float(tf.range(ninp))[None, :, None] + + if offset == 'random': + BIG_LEN = 32000 + offset = tf.random_uniform(tf.shape(position), minval=-BIG_LEN, maxval=BIG_LEN, dtype=tf.int32) + + # force broadcasting over batch axis + if isinstance(offset * 1, tf.Tensor): # multiply by 1 to also select variables, special generators, etc. 
+ assert offset.shape.ndims in (0, 1, 2) + new_shape = [tf.shape(offset)[i] for i in range(offset.shape.ndims)] + new_shape += [1] * (3 - len(new_shape)) + offset = tf.reshape(offset, new_shape) + + position += tf.to_float(offset) + + if inp_reverse is not None: + position = tf.multiply( + position, + tf.where( + tf.equal(inp_reverse, 0), + tf.ones_like(inp_reverse, dtype=tf.float32), + -1.0 * tf.ones_like(inp_reverse, dtype=tf.float32) + )[:, None, None] # (batch_size * ninp * dim) + ) + num_timescales = hid_size // 2 + log_timescale_increment = ( + math.log(float(max_timescale) / float(min_timescale)) / + (tf.to_float(num_timescales) - 1)) + inv_timescales = min_timescale * tf.exp( + tf.to_float(tf.range(num_timescales)) * -log_timescale_increment) + + # scaled_time: [ninp * hid_dim] + scaled_time = position * inv_timescales[None, None, :] + signal = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=-1) + signal = tf.pad(signal, [[0, 0], [0, 0], [0, tf.mod(hid_size, 2)]]) + return inp + signal + + +# ============================================================================ +# Transformer model + +class Model(TranslateModelBase): + + def __init__(self, name, inp_voc, out_voc, **hp): + self.name = name + self.inp_voc = inp_voc + self.out_voc = out_voc + self.hp = hp + + # Parameters + self.transformer = Transformer(name, inp_voc, out_voc, **hp) + + projection_matrix = None + if hp.get('dwwt', False): + projection_matrix = tf.transpose(self.transformer.emb_out.mat) + + self.loss = LossXent( + hp.get('loss_name', 'loss_xent_lm'), + hp['hid_size'], + out_voc, + hp, + matrix=projection_matrix, + bias=None if hp.get("loss_bias", False) else 0) + + inference_mode = hp.get("inference_mode", "fast") + if inference_mode == 'fast': + self.translate_model = TranslateModelFast(self.name, self.transformer, self.loss, self.inp_voc, + self.out_voc) + elif inference_mode == 'lazy': + self.translate_model = TranslateModelLazy(self.name, self.transformer, self.loss, self.inp_voc, + self.out_voc) + else: + raise NotImplementedError("inference_mode %s is not supported" % inference_mode) + + # Train interface + def encode_decode(self, batch, is_train, score_info=False): + inp = batch['inp'] # [batch_size * ninp] + out = batch['out'] # [batch_size * nout] + inp_len = batch.get('inp_len', infer_length(inp, self.inp_voc.eos, time_major=False)) # [batch] + out_len = batch.get('out_len', infer_length(out, self.out_voc.eos, time_major=False)) # [batch] + + out_reverse = tf.zeros_like(inp_len) # batch['out_reverse'] + + # rdo: [batch_size * nout * hid_dim] + enc_out, enc_attn_mask = self.transformer.encode(inp, inp_len, is_train) + rdo = self.transformer.decode(out, out_len, out_reverse, enc_out, enc_attn_mask, is_train) + + return rdo + + def make_feed_dict(self, batch, **kwargs): + feed_dict = make_batch_data(batch, self.inp_voc, self.out_voc, + force_bos=self.hp.get('force_bos', False), + **kwargs) + return feed_dict + + + + # ======== TranslateModel for Inference ============ + def encode(self, batch, **flags): + """ + :param batch: a dict of {string:symbolic tensor} that model understands. + By default it should accept {'inp': int32 matrix[batch,time]} + :return: initial decoder state + """ + return self.translate_model.encode(batch, **flags) + + def decode(self, dec_state, words=None, **flags): + """ + Performs decoding step given words and previous state. + :param words: previous output tokens, int32[batch_size]. 
if None, uses zero embeddings (first step) + :returns: next state + """ + return self.translate_model.decode(dec_state, words, **flags) + + def sample(self, dec_state, base_scores, slices, k, **kwargs): + return self.translate_model.sample(dec_state, base_scores, slices, k, **kwargs) + + def get_rdo(self, dec_state, **kwargs): + return self.translate_model.get_rdo(dec_state, **kwargs) + + def get_attnP(self, dec_state, **kwargs): + return self.translate_model.get_attnP(dec_state, **kwargs) + + +class ScopedModel(Model): + + def __init__(self, name, inp_voc, out_voc, **hp): + with tf.variable_scope(name): + super(ScopedModel, self).__init__(name, inp_voc, out_voc, **hp) + + def encode_decode(self, *args, **kwargs): + with tf.name_scope(self.name): + return super(ScopedModel, self).encode_decode(*args, **kwargs) + + def encode(self, *args, **kwargs): + with tf.name_scope(self.name): + return super(ScopedModel, self).encode(*args, **kwargs) + + def decode(self, *args, **kwargs): + with tf.name_scope(self.name): + return super(ScopedModel, self).decode(*args, **kwargs) + + def sample(self, *args, **kwargs): + with tf.name_scope(self.name): + return super(ScopedModel, self).sample(*args, **kwargs) + + +# ============================================================================ +# Transformer inference + +class TranslateModelFast(TranslateModel): + DecState = namedtuple("transformer_state", ['enc_out', 'enc_attn_mask', 'attnP', 'rdo', 'out_seq', 'offset', + 'emb', 'dec_layers', 'dec_enc_kv', 'dec_dec_kv']) + + def __init__(self, name, transformer, loss, inp_voc, out_voc): + """ + A translation model that performs quick (n^2) inference for transformer + with manual implementation of 1-step decoding + """ + self.name = name + self.transformer = transformer + self.loss = loss + self.inp_voc = inp_voc + self.out_voc = out_voc + + def encode(self, batch, is_train=False, **kwargs): + """ + :param batch: a dict containing 'inp':int32[batch_size * ninp] and optionally inp_len:int32[batch_size] + :param is_train: if True, enables dropouts + """ + inp = batch['inp'] + inp_len = batch.get('inp_len', infer_length(inp, self.inp_voc.eos, time_major=False)) + with dropout_scope(is_train), tf.name_scope(self.transformer.name): + # Encode. + enc_out, enc_attn_mask = self.transformer.encode(inp, inp_len, is_train=False) + + # Decoder dummy input/output + ninp = tf.shape(inp)[1] + batch_size = tf.shape(inp)[0] + hid_size = tf.shape(enc_out)[-1] + out_seq = tf.zeros([batch_size, 0], dtype=inp.dtype) + rdo = tf.zeros([batch_size, hid_size], dtype=enc_out.dtype) + + attnP = tf.ones([batch_size, ninp]) / tf.to_float(inp_len)[:, None] + + offset = tf.zeros((batch_size,)) + if self.transformer.dst_rand_offset: + BIG_LEN = 32000 + random_offset = tf.random_uniform(tf.shape(offset), minval=-BIG_LEN, maxval=BIG_LEN, dtype=tf.int32) + offset += tf.to_float(random_offset) + + trans = self.transformer + empty_emb = tf.zeros([batch_size, 0, trans.emb_size]) + empty_dec_layers = [tf.zeros([batch_size, 0, trans.hid_size])] * trans.num_layers_dec + input_layers = [empty_emb] + empty_dec_layers[:-1] + + # prepare kv parts for all decoder attention layers. 
Note: we do not preprocess enc_out + # for each layer because ResidualLayerWrapper only preprocesses first input (query) + dec_enc_kv = [layer.kv_conv(enc_out) + for i, layer in enumerate(trans.dec_enc_attn)] + dec_dec_kv = [layer.kv_conv(layer.preprocess(input_layers[i])) + for i, layer in enumerate(trans.dec_attn)] + + new_state = self.DecState(enc_out, enc_attn_mask, attnP, rdo, out_seq, offset, + empty_emb, empty_dec_layers, dec_enc_kv, dec_dec_kv) + + # perform initial decode (instead of force_bos) with zero embeddings + new_state = self.decode(new_state, is_train=is_train) + return new_state + + def decode(self, dec_state, words=None, is_train=False, **kwargs): + """ + Performs decoding step given words and previous state. + Returns next state. + + :param words: previous output tokens, int32[batch_size]. if None, uses zero embeddings (first step) + :param is_train: if True, enables dropouts + """ + trans = self.transformer + enc_out, enc_attn_mask, attnP, rdo, out_seq, offset, prev_emb = dec_state[:7] + prev_dec_layers = dec_state.dec_layers + dec_enc_kv = dec_state.dec_enc_kv + dec_dec_kv = dec_state.dec_dec_kv + + batch_size = tf.shape(rdo)[0] + if words is not None: + out_seq = tf.concat([out_seq, tf.expand_dims(words, 1)], 1) + + with dropout_scope(is_train), tf.name_scope(trans.name): + # Embeddings + if words is None: + # initial step: words are None + emb_out = tf.zeros((batch_size, 1, trans.emb_size)) + else: + emb_out = trans.emb_out(words[:, None]) # [batch_size * 1 * emb_dim] + if trans.rescale_emb: + emb_out *= trans.emb_size ** .5 + + # Prepare decoder + dec_inp_t = trans._add_timing_signal(emb_out, offset=offset) + # Apply dropouts + if is_dropout_enabled(): + dec_inp_t = tf.nn.dropout(dec_inp_t, 1.0 - trans.res_dropout) + + # bypass info from Encoder to avoid None gradients for num_layers_dec == 0 + if trans.num_layers_dec == 0: + inp_mask = tf.squeeze(tf.transpose(enc_attn_mask, perm=[3, 1, 2, 0]), 3) + dec_inp_t += tf.reduce_mean(enc_out * inp_mask, axis=[0, 1], keep_dims=True) + + # Decoder + new_emb = tf.concat([prev_emb, dec_inp_t], axis=1) + _out = tf.pad(out_seq, [(0, 0), (0, 1)]) + dec_attn_mask = trans._make_dec_attn_mask(_out)[:, :, -1:, :] # [1, 1, n_q=1, n_kv] + + new_dec_layers = [] + new_dec_dec_kv = [] + + for layer in range(trans.num_layers_dec): + # multi-head self-attention: use only the newest time-step as query, + # but all time-steps up to newest one as keys/values + next_dec_kv = trans.dec_attn[layer].kv_conv(trans.dec_attn[layer].preprocess(dec_inp_t)) + new_dec_dec_kv.append(tf.concat([dec_dec_kv[layer], next_dec_kv], axis=1)) + dec_inp_t = trans.dec_attn[layer](dec_inp_t, dec_attn_mask, kv=new_dec_dec_kv[layer]) + + dec_inp_t = trans.dec_enc_attn[layer](dec_inp_t, enc_attn_mask, kv=dec_enc_kv[layer]) + dec_inp_t = trans.dec_ffn[layer](dec_inp_t, summarize_preactivations=trans.summarize_preactivations) + + new_dec_inp = tf.concat([prev_dec_layers[layer], dec_inp_t], axis=1) + new_dec_layers.append(new_dec_inp) + + if trans.normalize_out: + dec_inp_t = trans.dec_out_norm(dec_inp_t) + + rdo = dec_inp_t[:, -1] + + new_state = self.DecState(enc_out, enc_attn_mask, attnP, rdo, out_seq, offset + 1, + new_emb, new_dec_layers, dec_enc_kv, new_dec_dec_kv) + return new_state + + def get_rdo(self, dec_state, **kwargs): + return dec_state.rdo, dec_state.out_seq + + def get_attnP(self, dec_state, **kwargs): + return dec_state.attnP + + +class TranslateModelLazy(TranslateModel): + def __init__(self, name, transformer, loss, inp_voc, out_voc): + """ + 
Automatically implements O(n^3) decoding by using trans.decode + """ + self.name = name + self.transformer = transformer + self.loss = loss + self.inp_voc = inp_voc + self.out_voc = out_voc + + def encode(self, batch, is_train=False, **kwargs): + """ + :param batch: a dict of placeholders + inp: [batch_size * ninp] + inp_len; [batch_size] + """ + inp = batch['inp'] + inp_len = batch['inp_len'] + with dropout_scope(is_train), tf.name_scope(self.transformer.name): + # Encode. + enc_out, enc_attn_mask = self.transformer.encode(inp, inp_len, is_train=False) + + # Decoder dummy input/output + ninp = tf.shape(inp)[1] + batch_size = tf.shape(inp)[0] + hid_size = tf.shape(enc_out)[-1] + out_seq = tf.zeros([batch_size, 0], dtype=inp.dtype) + rdo = tf.zeros([batch_size, hid_size], dtype=enc_out.dtype) + + attnP = tf.ones([batch_size, ninp]) / tf.to_float(inp_len)[:, None] + + return self._decode_impl((enc_out, enc_attn_mask, attnP, out_seq, rdo), **kwargs) + + def decode(self, dec_state, words, **kwargs): + """ + Performs decoding step given words + + words: [batch_size] + """ + with tf.name_scope(self.transformer.name): + (enc_out, enc_attn_mask, attnP, prev_out_seq, rdo) = dec_state + out_seq = tf.concat([prev_out_seq, tf.expand_dims(words, 1)], 1) + return self._decode_impl((enc_out, enc_attn_mask, attnP, out_seq, rdo), **kwargs) + + def _decode_impl(self, dec_state, is_train=False, **kwargs): + (enc_out, enc_attn_mask, attnP, out_seq, rdo) = dec_state + + with dropout_scope(is_train): + out = tf.pad(out_seq, [(0, 0), (0, 1)]) + out_len = tf.fill(dims=(tf.shape(out)[0],), value=tf.shape(out_seq)[1]) + out_reverse = tf.zeros_like(out_len) # batch['out_reverse'] + dec_out = self.transformer.decode(out, out_len, out_reverse, enc_out, enc_attn_mask, is_train=False) + rdo = dec_out[:, -1, :] # [batch_size * hid_dim] + + attnP = enc_attn_mask[:, 0, 0, :] # [batch_size * ninp ] + attnP /= tf.reduce_sum(attnP, axis=1, keep_dims=True) + + return (enc_out, enc_attn_mask, attnP, out_seq, rdo) + + def get_rdo(self, dec_state, **kwargs): + rdo = dec_state[4] + out = dec_state[3] + return rdo, out + + def get_attnP(self, dec_state, **kwargs): + return dec_state[2] + diff --git a/lib/task/seq2seq/problems/__init__.py b/lib/task/seq2seq/problems/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/task/seq2seq/problems/concrete.py b/lib/task/seq2seq/problems/concrete.py new file mode 100644 index 0000000..4298799 --- /dev/null +++ b/lib/task/seq2seq/problems/concrete.py @@ -0,0 +1,131 @@ + +from ..summary import * +from lib.layers.basic import * +from lib.train.problem import Problem +from lib.task.seq2seq.problems.default import word_dropout + + +class ConcreteProblem(Problem): + def __init__(self, models, dump_dir=None, dump_first_n=None, sum_loss=False, use_small_batch_multiplier=False, + inp_word_dropout=0, out_word_dropout=0, word_dropout_method='unk', concrete_coef=1., + ): + assert len(models) == 1 + + self.models = models + self.model = list(self.models.values())[0] + + self.inp_voc = self.model.inp_voc + self.out_voc = self.model.out_voc + + self.dump_dir = dump_dir + self.dump_first_n = dump_first_n + self.sum_loss = sum_loss + self.use_small_batch_multiplier = use_small_batch_multiplier + + self.inp_word_dropout = inp_word_dropout + self.out_word_dropout = out_word_dropout + self.word_dropout_method = word_dropout_method + + # ======================== for concrete gates ========================================= + self.concrete_coef = concrete_coef + # 
======================================================================================== + + if self.use_small_batch_multiplier: + self.max_batch_size_var = tf.get_variable("max_batch_size", shape=[], initializer=tf.ones_initializer(), + trainable=False) + + def _make_encdec_batch(self, batch, is_train): + encdec_batch = copy(batch) + + if is_train and self.inp_word_dropout > 0: + encdec_batch['inp'] = word_dropout(encdec_batch['inp'], encdec_batch['inp_len'], self.inp_word_dropout, + self.word_dropout_method, self.model.inp_voc) + + if is_train and self.out_word_dropout > 0: + encdec_batch['out'] = word_dropout(encdec_batch['out'], encdec_batch['out_len'], self.out_word_dropout, + self.word_dropout_method, self.model.out_voc) + + return encdec_batch + + def batch_counters(self, batch, is_train): + if hasattr(self.model, 'batch_counters'): + return self.model.batch_counters(batch, is_train) + + # ======================== for concrete gates ========================================= + tf.get_default_graph().clear_collection("CONCRETE") + tf.get_default_graph().clear_collection(tf.GraphKeys.REGULARIZATION_LOSSES) + + rdo = self.model.encode_decode(self._make_encdec_batch(batch, is_train), is_train) + + sparsity_rate = tf.reduce_mean(tf.get_collection("CONCRETE")) + concrete_reg = tf.reduce_mean(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + # ======================================================================================== + + with lib.layers.basic.dropout_scope(is_train): + logits = self.model.loss.rdo_to_logits(rdo, batch['out'], + batch['out_len']) # [batch_size * nout * ovoc_size] + loss_values = self.model.loss.logits2loss(logits, batch['out'], batch['out_len']) + # loss_values /= math.log(2.0) # TODO: move to loss or to model + + if self.dump_dir: + dump_map = batch + + loss_values = tf_dump( + loss_values, + dump_map, + self.dump_dir + '/batch_dump_{}.npz', + first_n=self.dump_first_n) + + counters = dict( + loss=tf.reduce_sum(loss_values), + out_len=tf.to_float(tf.reduce_sum(batch['out_len'])), + # ======================== for concrete gates ========================================= + sparsity_rate=sparsity_rate, + concrete_reg=concrete_reg, + # ======================================================================================== + ) + append_counters_common_metrics(counters, logits, batch['out'], batch['out_len'], is_train) + append_counters_xent(counters, loss_values, batch['out_len']) + append_counters_io(counters, batch['inp'], batch['out'], batch['inp_len'], batch['out_len']) + return counters + + def loss_multibatch(self, counters, is_train): + if self.sum_loss: + value = tf.reduce_sum(counters['loss']) + else: + value = tf.reduce_sum(counters['loss']) / tf.reduce_sum(counters['out_len']) + + if self.use_small_batch_multiplier and is_train: + batch_size = tf.reduce_sum(counters['out_len']) + max_batch_size = tf.maximum(self.max_batch_size_var, batch_size) + with tf.control_dependencies([tf.assign(self.max_batch_size_var, max_batch_size)]): + small_batch_multiplier = batch_size / max_batch_size + value = value * small_batch_multiplier + + # ======================== for concrete gates ========================================= + value += self.concrete_coef * tf.reduce_mean(counters['concrete_reg']) + # ======================================================================================== + + return value + + def summary_multibatch(self, counters, prefix, is_train): + res = [] + # ======================== for concrete gates 
========================================= + res.append(tf.summary.scalar(prefix + "/concrete_reg", tf.reduce_mean(counters['concrete_reg']))) + res.append(tf.summary.scalar(prefix + "/sparsity_rate", tf.reduce_mean(counters['sparsity_rate']))) + # ======================================================================================== + + res += summarize_common_metrics(counters, prefix) + res += summarize_xent(counters, prefix) + res += summarize_io(counters, prefix) + return res + + def params_summary(self): + if hasattr(self.model, 'params_summary'): + return self.model.params_summary() + return [] + + def make_feed_dict(self, batch, **kwargs): + return self.model.make_feed_dict(batch, **kwargs) + + diff --git a/lib/task/seq2seq/problems/default.py b/lib/task/seq2seq/problems/default.py new file mode 100644 index 0000000..04ad923 --- /dev/null +++ b/lib/task/seq2seq/problems/default.py @@ -0,0 +1,122 @@ +from ..summary import * +from lib.layers.basic import * +from lib.train.problem import Problem + + +def word_dropout(inp, inp_len, dropout, method, voc): + inp_shape = tf.shape(inp) + + border = tf.fill([inp_shape[0], 1], False) + + mask = tf.sequence_mask(inp_len - 2, inp_shape[1] - 2) + mask = tf.concat((border, mask, border), axis=1) + mask = tf.logical_and(mask, tf.random_uniform(inp_shape) < dropout) + + if method == 'unk': + replacement = tf.fill(inp_shape, tf.cast(voc._unk, inp.dtype)) + elif method == 'random_word': + replacement = tf.random_uniform(inp_shape, minval=max(voc.bos, voc.eos, voc._unk)+1, maxval=voc.size(), dtype=inp.dtype) + else: + raise ValueError("Unknown word dropout method: %r" % method) + + return tf.where(mask, replacement, inp) + + +class DefaultProblem(Problem): + + def __init__(self, models, dump_dir=None, dump_first_n=None, sum_loss=False, use_small_batch_multiplier=False, + inp_word_dropout=0, out_word_dropout=0, word_dropout_method='unk', + ): + assert len(models) == 1 + + self.models = models + self.model = list(self.models.values())[0] + + self.inp_voc = self.model.inp_voc + self.out_voc = self.model.out_voc + + self.dump_dir = dump_dir + self.dump_first_n = dump_first_n + self.sum_loss = sum_loss + self.use_small_batch_multiplier = use_small_batch_multiplier + + self.inp_word_dropout = inp_word_dropout + self.out_word_dropout = out_word_dropout + self.word_dropout_method = word_dropout_method + + if self.use_small_batch_multiplier: + self.max_batch_size_var = tf.get_variable("max_batch_size", shape=[], initializer=tf.ones_initializer(), trainable=False) + + def _make_encdec_batch(self, batch, is_train): + encdec_batch = copy(batch) + + if is_train and self.inp_word_dropout > 0: + encdec_batch['inp'] = word_dropout(encdec_batch['inp'], encdec_batch['inp_len'], self.inp_word_dropout, self.word_dropout_method, self.model.inp_voc) + + if is_train and self.out_word_dropout > 0: + encdec_batch['out'] = word_dropout(encdec_batch['out'], encdec_batch['out_len'], self.out_word_dropout, self.word_dropout_method, self.model.out_voc) + + return encdec_batch + + def batch_counters(self, batch, is_train): + if hasattr(self.model, 'batch_counters'): + return self.model.batch_counters(batch, is_train) + + rdo = self.model.encode_decode(self._make_encdec_batch(batch, is_train), is_train) + + with dropout_scope(is_train): + logits = self.model.loss.rdo_to_logits(rdo, batch['out'], batch['out_len']) # [batch_size * nout * ovoc_size] + loss_values = self.model.loss.logits2loss(logits, batch['out'], batch['out_len']) + + counters = dict( + 
loss=tf.reduce_sum(loss_values), + out_len=tf.to_float(tf.reduce_sum(batch['out_len'])), + ) + append_counters_common_metrics(counters, logits, batch['out'], batch['out_len'], is_train) + append_counters_xent(counters, loss_values, batch['out_len']) + append_counters_io(counters, batch['inp'], batch['out'], batch['inp_len'], batch['out_len']) + return counters + + def get_xent(self, batch, is_train): + if hasattr(self.model, 'batch_counters'): + return self.model.batch_counters(batch, is_train) + + rdo = self.model.encode_decode(self._make_encdec_batch(batch, is_train), is_train) + + with dropout_scope(is_train): + logits = self.model.loss.rdo_to_logits(rdo, batch['out'], + batch['out_len']) # [batch_size * nout * ovoc_size] + loss_values = self.model.loss.logits2loss(logits, batch['out'], batch['out_len']) + + return loss_values + + def loss_multibatch(self, counters, is_train): + if self.sum_loss: + value = tf.reduce_sum(counters['loss']) + else: + value = tf.reduce_sum(counters['loss']) / tf.reduce_sum(counters['out_len']) + + if self.use_small_batch_multiplier and is_train: + batch_size = tf.reduce_sum(counters['out_len']) + max_batch_size = tf.maximum(self.max_batch_size_var, batch_size) + with tf.control_dependencies([tf.assign(self.max_batch_size_var, max_batch_size)]): + small_batch_multiplier = batch_size / max_batch_size + value = value * small_batch_multiplier + + return value + + def summary_multibatch(self, counters, prefix, is_train): + res = [] + res += summarize_common_metrics(counters, prefix) + res += summarize_xent(counters, prefix) + res += summarize_io(counters, prefix) + return res + + def params_summary(self): + if hasattr(self.model, 'params_summary'): + return self.model.params_summary() + + return [] + + def make_feed_dict(self, batch, **kwargs): + return self.model.make_feed_dict(batch, **kwargs) diff --git a/lib/task/seq2seq/strutils.py b/lib/task/seq2seq/strutils.py new file mode 100644 index 0000000..f9c5e81 --- /dev/null +++ b/lib/task/seq2seq/strutils.py @@ -0,0 +1,134 @@ +# coding: utf-8 + +from codecs import iterdecode +import re +import sys +import unicodedata + + +def normalize_table_lang(text, lang=None): + """According to normalization done in framework""" + if lang == 'ru': + # replace capital and small letters IO -> IE + return text.replace(u'\u0401', u'\u0415').replace(u'\u0451', u'\u0435') + elif lang == 'ro': + # replace capital and small letters S and T with cedilla -> comma below + return text.replace(u'\u015F', u'\u0219').replace(u'\u015E', + u'\u0218').replace(u'\u0163', u'\u021b').replace(u'\u0162', u'\u021a') + elif lang == 'tr': + # replace capital and small letters with circumflex + return text.replace(u'\u00C2', u'\u0041').replace(u'\u00E2', + u'\u0061').replace(u'\u00CE', u'\u0049').replace(u'\u00EE', + u'\u0069').replace(u'\u00DB', u'\u0055').replace(u'\u00FB', u'\u0075') + else: + return text + + +def unicode_category_tokenize(text, lang=None): + import regex + re_for_split = regex.compile( + u'(?u)[\p{Punctuation}\p{Separator}\p{Other}\p{Sm}\p{So}\p{Sc}]+') + return u' '.join(tok for tok in re_for_split.split(text) if tok) + + +def chinese_tok(text, lang=None): + """ставит между всеми символами пробелы""" + # '''from meteor_ext import make_tmp_file, get_random_filename + # import os + # + # tmpfile = make_tmp_file(pre=get_random_filename()) + # tmpfile.write(text.encode('utf-8')) + # args = ['/place/framework/metrics/stanford-segmenter/segment.sh', 'ctb', tmpfile.name, encoding, '0'] + # p = Popen(args, stdout=subprocess.PIPE, 
stderr=subprocess.PIPE) + # out_data, err_data = p.communicate() + # tmpfile.close() + # os.unlink(tmpfile.name) + # return out_data''' + # from itertools import cycle + return ' '.join(text) + +def split_by_char_tok(text, lang=None): + return ' '.join(text) + +def tokenize(text): + return text.split() + +def lower(text, lang=None): + return text.lower() + +def upper(text, lang=None): + return text.upper() + +def foldcase(text, lang=None): + """приводит текст к одному регистру (верхнему)""" + # folds case according to language + # TODO: set locale by lang so that some letters are folded correctly + # (e.g. turkish i without dot) + return upper(text) + +def join_tokens(text): + return u' '.join(tokenize(text)) + + +def separate_punctuation(text, lang=None): + """отделяет пунктуацию и символы (Po/So/Ps/Pe/Sc/-) двумя пробелами, после ' ставит один пробел""" + new_chars = [] + for character in text: + if character == u"'": + new_chars.append(character + u' ') + elif unicodedata.category(character) in ('Po', 'So', 'Ps', 'Pe', 'Sc')\ + or character == u"-": + new_chars.append(u' ' + character + u' ') + else: + new_chars.append(character) + return "".join(new_chars) + + +def alphanum(text, lang=None): + """заменяет все не-alphanumeric (\W, Unicode) на пробел""" + #TODO: do not remove currency signs + non_alphanum = re.compile(u'\W', re.UNICODE) + text = non_alphanum.sub(' ', text) + return text + + +def func_chain(*funcs): + """Returns a function that chains parameter functions""" + def result_func(text, lang=None): + result = text + for func in funcs: + result = func(result, lang) + return result + return result_func + + +def normalize_space(u_text, lang=None): + """стирает пробельные символы, заменяя их на один пробел""" + return ' '.join(u_text.split()) + +def xlines(fileobj, encoding='utf_8_sig', keepends=False): + for line in iterdecode(fileobj, encoding): + if not keepends: + line = line.rstrip('\r\n') + yield line + +# only alphanumeric characters are kept +al_num = func_chain(alphanum, normalize_space) +# only alphanumeric characters are kept, the rest is case-folded +al_num__foldcase = func_chain(foldcase, alphanum, normalize_space) +all_chars__as_is = func_chain() +# all characters are folded in case +all_chars__foldcase = func_chain(foldcase, normalize_space) +# punctuation becomes separate tokens +all_chars__punct_tokens = func_chain(separate_punctuation, normalize_space) +# punctuation becomes separate tokens, all characters are folded in case +all_chars__punct_tokens__foldcase = func_chain(foldcase, separate_punctuation, normalize_space) +# as is in eval framework +equal_to_framework = func_chain(normalize_table_lang, foldcase, unicode_category_tokenize) + +if __name__ == '__main__': + funcs = {'-s': al_num__foldcase, '-p': all_chars__punct_tokens__foldcase, + '-cs': al_num, '-csp': all_chars__punct_tokens} + a = sys.argv[1] + for line in xlines(sys.stdin): + print(u''.join(map(funcs[a], line.split('\t')))) \ No newline at end of file diff --git a/lib/task/seq2seq/summary.py b/lib/task/seq2seq/summary.py new file mode 100644 index 0000000..fa31885 --- /dev/null +++ b/lib/task/seq2seq/summary.py @@ -0,0 +1,159 @@ +import tensorflow as tf +from ...ops.basic import select_values_over_last_axis + + +def append_counters_accuracy(counters, logits, out, out_len): + with tf.variable_scope("summary_accuracy"): + predictions = tf.argmax(logits, axis=2) + acc_values = predictions2accuracy(predictions, out, out_len) + acc_top5_values = logits2accuracy_top_k(logits, out, out_len, k=5) + 
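+        # each acc_* tensor is a 0/1 indicator masked by out_len (per token, per-token top-5,
+        # and per whole sequence); the sums collected below are normalized in summarize_accuracy()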
acc_per_seq_values = predictions2accuracy_per_sequence(predictions, out, out_len) + + node = dict( + accuracy=tf.reduce_sum(acc_values), + accuracy_top5=tf.reduce_sum(acc_top5_values), + accuracy_per_sequence=tf.reduce_sum(acc_per_seq_values), + out_len=tf.to_float(tf.reduce_sum(out_len)), + seqs=tf.to_float(tf.shape(out_len)[0]), + ) + + _append_counters(counters, "summarize_accuracy", node) + + +def append_counters_common_metrics(counters, logits, out, out_len, is_train): + append_counters_accuracy(counters, logits, out, out_len) + + +def append_counters_xent(counters, xent_values, out_len): + with tf.variable_scope("summary_xent"): + node = dict( + xent=tf.reduce_sum(xent_values), + out_len=tf.to_float(tf.reduce_sum(out_len)), + ) + _append_counters(counters, "summarize_xent", node) + + +def append_counters_io(counters, inp, out, inp_len, out_len): + with tf.variable_scope("summary_io"): + node = dict( + batch_size=tf.to_float(tf.shape(inp))[0], + inp_len=tf.to_float(tf.reduce_sum(inp_len)), + out_len=tf.to_float(tf.reduce_sum(out_len)), + ninp=tf.to_float(tf.shape(inp)[1]), + nout=tf.to_float(tf.shape(out)[1]), + ) + _append_counters(counters, "summarize_io", node) + + +def summarize_accuracy(counters, prefix): + node = counters['summarize_accuracy'] + summaries = [ + tf.summary.scalar("%s_metrics/Acc" % prefix, tf.reduce_sum(node['accuracy']) / tf.reduce_sum(node['out_len'])), + tf.summary.scalar("%s_metrics/AccTop5" % prefix, tf.reduce_sum(node['accuracy_top5']) / tf.reduce_sum(node['out_len'])), + tf.summary.scalar("%s_metrics/AccPerSeq" % prefix, tf.reduce_sum(node['accuracy_per_sequence']) / tf.reduce_sum(node['seqs'])), + ] + return summaries + + +def summarize_common_metrics(counters, prefix): + return summarize_accuracy(counters, prefix) + + +def summarize_xent(counters, prefix): + node = counters['summarize_xent'] + return [ + tf.summary.scalar("%s_metrics/Xent" % prefix, tf.reduce_sum(node['xent']) / tf.reduce_sum(node['out_len'])), + ] + + +def summarize_io(counters, prefix): + node = counters['summarize_io'] + return [ + tf.summary.scalar("%s_IO/BatchSize" % prefix, tf.reduce_sum(node['batch_size'])), + tf.summary.scalar("%s_IO/InpLenAvg" % prefix, tf.reduce_sum(node['inp_len']) / tf.reduce_sum(node['batch_size'])), + tf.summary.scalar("%s_IO/OutLenAvg" % prefix, tf.reduce_sum(node['out_len']) / tf.reduce_sum(node['batch_size'])), + tf.summary.scalar("%s_IO/InpLenSum" % prefix, tf.reduce_sum(node['inp_len'])), + tf.summary.scalar("%s_IO/OutLenSum" % prefix, tf.reduce_sum(node['out_len'])), + + tf.summary.scalar( + "%s_IO/InpNoPadRate" % prefix, + tf.reduce_sum(node['inp_len']) / tf.reduce_sum(node['ninp'] * node['batch_size'])), + tf.summary.scalar( + "%s_IO/OutNoPadRate" % prefix, + tf.reduce_sum(node['out_len']) / tf.reduce_sum(node['nout'] * node['batch_size'])), + ] + + +def _append_counters(counters, key, value): + if isinstance(counters, dict): + if key in counters: + raise Exception('Duplicate key "{}" in counters'.format(key)) + counters[key] = value + else: + raise Exception('Unexpected type: {}. 
Counters should be dict'.format(counters.__class__.__name__)) + + +def logits2accuracy(logits, out, out_len, dtype=tf.float32): + """ + logits : [batch_size * nout * voc_size] + out : [batch_size * nout] + out_len: [batch_size] + + results: [batch_size * nout] + """ + predictions = tf.argmax(logits, axis=2) + return predictions2accuracy(predictions, out, out_len, dtype=dtype) + + +def predictions2accuracy(predictions, out, out_len, dtype=tf.float32): + """ + predictions: [batch_size * nout] + out : [batch_size * nout] + out_len: [batch_size] + + results: [batch_size * nout] + """ + out_equals = tf.equal(tf.cast(predictions, dtype=out.dtype), out) + out_mask = tf.sequence_mask(out_len, dtype=dtype, maxlen=tf.shape(out)[1]) + acc_values = tf.cast(out_equals, dtype=dtype) * out_mask + + return acc_values + + +def logits2accuracy_top_k(logits, out, out_len, k, dtype=tf.float32): + """ + logits: [batch_size * nout * ntokens] + out : [batch_size * nout] + out_len: [batch_size] + + results: [batch_size * nout] + """ + out_logits = select_values_over_last_axis(logits, tf.to_int32(out)) + out_logits = tf.expand_dims(out_logits, axis=-1) + + greater_mask = tf.greater(logits, out_logits) + greater_ranks = tf.reduce_sum(tf.to_int32(greater_mask), axis=-1) + hit_mask = greater_ranks < k + out_mask = tf.sequence_mask(out_len, dtype=dtype, maxlen=tf.shape(out)[1]) + acc_values = tf.to_float(hit_mask) * out_mask + + return acc_values + + +def predictions2accuracy_per_sequence(predictions, out, out_len, dtype=tf.float32): + """ + predictions: [batch_size * nout] + out: [batch_size * nout] + out_len: [batch_size] + + results: [batch_size] + """ + not_correct = tf.not_equal(tf.cast(predictions, dtype=out.dtype), out) + out_mask = tf.sequence_mask(out_len, dtype=dtype, maxlen=tf.shape(out)[1]) + correct_seq = 1.0 - tf.minimum(1.0, tf.reduce_sum(tf.cast(not_correct, dtype=dtype) * out_mask, axis=1)) + return tf.cast(correct_seq, dtype=dtype) + + + + + diff --git a/lib/task/seq2seq/tickers.py b/lib/task/seq2seq/tickers.py new file mode 100644 index 0000000..0b158bf --- /dev/null +++ b/lib/task/seq2seq/tickers.py @@ -0,0 +1,107 @@ +import os +import sys +import tensorflow as tf + +from ...train.tickers import DistributedTicker, _IsItTimeYet +import lib +from .bleu import Bleu + +# - TranslateTicker - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + +def unbpe(sent): + return sent.replace(' `', '') + + +class TranslateTicker(DistributedTicker): + """ + - Translate devset once in a while. + - Print BLEU to stderr after each translation. 
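+    - If `folder` is given, dump translations to translations<suffix>_<step>.txt there.
+    - Write BLEU and the translations to the tensorboard summary.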
+ """ + def __init__(self, model_name, devset, name='Dev', every_steps=None, every_minutes=None, initial=False, folder=None, + suffix=None, device=None, parallel=True): + self.model_name = model_name + self.devset = devset + self.every_steps = every_steps + self.every_minutes = every_minutes + self.folder = folder + self.name = name + self.initial = initial + self.device = device + self.parallel = parallel + self.suffix = suffix if suffix is not None else model_name + if self.suffix: # add underscore if we add suffix + self.suffix = '_' + self.suffix + + def on_started(self, context): + self.devset_batches = list(self.devset) + self.context = context + self.model = context.get_model(self.model_name) + + self.bleu = tf.placeholder(tf.float32) + self.translations = tf.placeholder(tf.string, shape=[None]) + + self.summary = tf.summary.merge([ + tf.summary.scalar(("%s/BLEU" % self.name) + self.suffix, self.bleu), + tf.summary.text(("%s/Translations" % self.name) + self.suffix, self.translations)]) + + self.is_it_time_yet = _IsItTimeYet( + context, self.every_steps, self.every_minutes) + + # Score devset after initialization if option passed (and we are not loading some non-init checkpoint) + if self.initial and context.get_global_step() == 0: + self._score() + + def after_train_batch(self, ingraph_result): + if self.is_it_time_yet(): + self._score() + + def _score(self): + if lib.ops.mpi.is_master(): + print('Translating', end='', file=sys.stderr, flush=True) + + translations = None + + if self.parallel or lib.ops.mpi.is_master(): + translations = [] + with tf.device(self.device) if self.device is not None else lib.util.nop_ctx(): + for batch in self.devset_batches: + trans = self.model.translate_lines([line[0] for line in batch]) + for index in range(len(batch)): + src = unbpe(batch[index][0]) + ethalon = unbpe(batch[index][1]) + translations.append(src + '\t' + ethalon + '\t' + unbpe(trans[index])) + + if self.parallel: + translations = lib.ops.mpi.gather_obj(translations) + if translations is not None: + translations = [x for t in translations for x in t] + + if translations is not None: + # compute BLEU only on the master + + if self.folder is not None: + global_step = self.context.get_global_step() + self._dump_translations( + translations, + fname='translations{}_{}.txt'.format(self.suffix, global_step) + ) + + bleu = Bleu() + for translation in translations: + src, ethalon, trans = translation.split('\t') + bleu.process_next(trans, [ethalon]) + bleu_value = 100 * (bleu.total()[0]) + + print('BLEU %f' % bleu_value, file=sys.stderr, flush=True) + + summary = tf.get_default_session().run(self.summary, feed_dict={self.bleu: bleu_value, + self.translations: translations}) + + self.context.get_summary_writer().add_summary(summary, self.context.get_global_step()) + + def _dump_translations(self, translations, fname): + if not os.path.isdir(self.folder): + os.mkdir(self.folder) + fout = open(os.path.join(self.folder, fname), 'w') + for translation in translations: + print(translation, file=fout) diff --git a/lib/task/seq2seq/voc.py b/lib/task/seq2seq/voc.py new file mode 100644 index 0000000..1ec3aed --- /dev/null +++ b/lib/task/seq2seq/voc.py @@ -0,0 +1,105 @@ +import collections +import sys + + +class BaseVoc: + @property + def bos(self): + raise NotImplementedError() + + @property + def eos(self): + raise NotImplementedError() + + def ids(self, words): + raise NotImplementedError() + + def words(self, ids): + raise NotImplementedError() + + def size(self): + raise NotImplementedError() + + 
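+# Hypothetical usage sketch for the Voc class below (the corpus path is a
+# placeholder, not a file from this repo):
+#
+#   voc = Voc.compile('corpus.en.bpeized', max_words=32000)
+#   ids = voc.ids('a transformer model'.split())   # OOV words map to voc._unk (id 2)
+#   words = voc.words(ids)
+#
+# Ids 0 and 1 are reserved for _BOS_ and _EOS_; id 2 is _UNK_.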
+class Voc: + @property + def bos(self): + return 0 + + @property + def eos(self): + return 1 + + @property + def _unk(self): + return 2 + + def ids(self, words): + if isinstance(words, (list, tuple)): + return [self.ids(word) for word in words] + return self._voc.get(words, self._unk) + + def words(self, ids): + if isinstance(ids, (list, tuple)): + return [self.words(id) for id in ids] + return self._ivoc[ids] + + def size(self): + return self._size + + @staticmethod + def compile(corpus_filename, max_words, index=0): + # Accumulate frequencies. + freqs = collections.defaultdict(int) + with open(corpus_filename) as corpus: + for line in corpus: + line = line.strip('\n') + if not line: + continue + for word in line.split(' '): + freqs[word.split('|||')[index]] += 1 + + # Sort by frequency. + freq_and_word = lambda item: item[::-1] + most_frequent = sorted(freqs.items(), key=freq_and_word, reverse=True) + + # Create voc. + obj = Voc() + voc = { '_BOS_': obj.bos, '_EOS_': obj.eos } + id = 3 + total_covered_freq = 0 + for word, freq in most_frequent[:max_words]: + voc[word] = id + id += 1 + total_covered_freq += freq + + # Report coverage. + total_freq = sum(freqs.values()) + msg = 'Voc %r: %i words, %.3f%% coverage' % ( + corpus_filename, + id, + total_covered_freq * 100 / total_freq, + ) + print(msg, file=sys.stderr, flush=True) + + # Return. + obj.__setstate__((voc,)) + return obj + + def __getstate__(self): + return self._voc, + + def __setstate__(self, state): + # Load direct vocabulary. + self._voc, = state + + # Fill inverse vocabulary. + self._ivoc = {} + for k, v in self._voc.items(): + self._ivoc[v] = k + self._ivoc[self.bos] = '_BOS_' + self._ivoc[self.eos] = '_EOS_' + self._ivoc[self._unk] = '_UNK_' + + # Compute size + self._size = max(self._voc.values()) + 1 diff --git a/lib/tools/__init__.py b/lib/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/lib/tools/apply_bpe.py b/lib/tools/apply_bpe.py new file mode 100755 index 0000000..497daa8 --- /dev/null +++ b/lib/tools/apply_bpe.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python3 +import argparse +import sys +import numpy as np +from collections import defaultdict + + +class BPEizer: + def __init__(self, path, separator=' `'): + """ + A tool that converts tokenized strings into BPE units given bpe rules + Works by iteratively merging subword pairs with lowest priority, starting from individual characters + :param path: path to a file with bpe merging rules. 
+            Either from subword_nmt or yandex internal bpe tool
+            subword_nmt: file should start with #version: {some version} header and contain "{left_part right_part}" rules
+            yandex internal: file should contain lines with "{left_part}\t{right_part}\t{priority}"
+        :param separator: a string that separates segments of a word;
+            Note: subword_nmt's default separator is "@@ " (mind the space)
+
+        Usage:
+        >>> bpeizer = BPEizer(path='./data/ru.bpe.voc')
+        >>> bpeizer.bpeize_token('транспонировали')
+        'тран `сп `он `ир `овали'
+        >>> bpeizer(['тридцать три треугольных матрицы транспонировали - транспонировали', ', да не вытранспонировали !'])
+        ['тридцать три треуголь `ных мат `рицы тран `сп `он `ир `овали - тран `сп `он `ир `овали',
+         ', да не выт `ран `сп `он `ир `овали !']
+        """
+        self.bpe_rules = defaultdict(lambda: float('inf'))
+        self.separator = separator
+
+        if self.is_yandex_bpe(path):
+            self.mode = 'yandex'
+            self.begin, self.end = '^$'
+            for left, right, index in map(str.split, open(path)):
+                self.bpe_rules[left, right] = int(index)
+
+        elif self.is_rsenrich_bpe(path):
+            self.mode = 'rsenrich'
+            self.begin, self.end = '', ''
+            f_rules = open(path)
+            f_rules.readline()
+            for i, (left, right) in enumerate(map(str.split, f_rules)):
+                self.bpe_rules[left, right] = i
+        else:
+            raise NotImplementedError("BPE rules header is compatible with neither subword_nmt nor yandex bpe")
+
+        self.escape_chars = {self.begin: chr(0x110000 - 2), self.end: chr(0x110000 - 1)}
+        self.unescape_chars = {v: k for k, v in self.escape_chars.items()}
+
+    def bpeize_token(self, chars: str):
+        """ split a single token (str) into bpe units """
+        tokens = [self.begin] + [self.escape_chars.get(c, c) for c in chars] + [self.end]
+        if self.mode == 'rsenrich':
+            last = tokens.pop()
+            tokens[-1] += last  # automatically merge with previous token
+
+        while len(tokens) > 1:
+            # find the highest-priority (lowest-index) rule that applies to some adjacent pair
+            bpe_rule_priorities = [self.bpe_rules[prev, cur] for prev, cur in zip(tokens[:-1], tokens[1:])]
+
+            chosen_ix = np.argmin(bpe_rule_priorities)
+            if bpe_rule_priorities[chosen_ix] == float('inf'):
+                break  # no applicable merge rule is left
+
+            # apply it
+            tokens = tokens[:chosen_ix] + [tokens[chosen_ix] + tokens[chosen_ix + 1]] + tokens[chosen_ix + 2:]
+
+        assert tokens[0].startswith(self.begin) and tokens[-1].endswith(self.end)
+        tokens[0] = tokens[0][len(self.begin):]
+        tokens[-1] = tokens[-1][:-len(self.end) or None]  # `or None` keeps the last token intact when self.end == ''
+        tokens = [''.join([self.unescape_chars.get(c, c) for c in bpe])
+                  for bpe in tokens if len(bpe) != 0]
+        return self.separator.join(filter(len, tokens))
+
+    def bpeize_line(self, line: str):
+        """ converts a tokenized line into a bpe-ized line """
+        return ' '.join(map(self.bpeize_token, line.split()))
+
+    def __call__(self, text):
+        if isinstance(text, (list, tuple)):
+            return list(map(self, text))
+        elif isinstance(text, str):
+            return self.bpeize_line(text)
+        else:
+            raise ValueError("Expected string or list/tuple of strings but found {}".format(type(text)))
+
+    @staticmethod
+    def is_rsenrich_bpe(bpe_rules_path):
+        """ Check if bpe rules were learned by https://github.com/rsennrich/subword-nmt """
+        header = open(bpe_rules_path).readline()
+        return header.startswith('#version:')
+
+    @staticmethod
+    def is_yandex_bpe(bpe_rules_path):
+        """ Check if bpe rules were learned by internal Yandex tool """
+        try:
+            header = open(bpe_rules_path).readline()
+            l, r, i = header.split('\t')  # header must consist of exactly three tab-separated fields
+            return True
+        except Exception:
+            return False
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--bpe_rules', required=True)
+    args = parser.parse_args()
+
+    bpeizer = BPEizer(args.bpe_rules)
+    for l in sys.stdin:
+        print(bpeizer.bpeize_line(l))
diff --git a/lib/tools/average_npz.py b/lib/tools/average_npz.py
new file mode 100755
index 0000000..e325a4c
--- /dev/null
+++ b/lib/tools/average_npz.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+
+"""
+Averaging NPZ files
+"""
+
+import argparse
+import os
+import numpy as np
+
+
+def get_last_checkpoints(folder, num_checkpoints):
+    labels = []
+    for fname in os.listdir(folder):
+        if not fname.startswith('model-') or not fname.endswith('.npz'):
+            continue
+        label = fname[len('model-'):-len('.npz')]
+        if not label.isdigit():
+            continue
+        labels.append(int(label))
+    labels = sorted(labels, reverse=True)[0:num_checkpoints]
+
+    files = []
+    for label in labels:
+        filename = os.path.join(folder, 'model-%d.npz' % label)
+        files += [filename]
+    return files
+
+
+def average_npzs(files):
+    out = {}
+    for filename in files:
+        model = np.load(filename)
+        for var in model:
+            if var in out:
+                out[var] += model[var]
+            else:
+                out[var] = model[var]
+    for var in out:
+        out[var] /= len(files)
+
+    return out
+
+
+def _parse_args():
+    p = argparse.ArgumentParser()
+    p.add_argument('--oname', '-O', required=True, help='output file name')
+    p.add_argument('--ncheckpoints', '-n', type=int, help='number of checkpoints to use')
+    p.add_argument('--folder', type=str, help='path to checkpoints')
+    p.add_argument('files', nargs='*')
+
+    args = p.parse_args()
+    if (args.folder is None) != (args.ncheckpoints is None):
+        raise Exception("--folder and --ncheckpoints should be specified together")
+
+    if (args.folder is not None) and len(args.files):
+        raise Exception("Use one of two modes:\n