In [1]:
from theano import tensor
from toolz import merge
import os

import numpy

from fuel.datasets import IterableDataset
from fuel.transformers import Merge
from fuel.streams import DataStream

from blocks.bricks import (Tanh, Maxout, Linear, FeedforwardSequence,
                           Bias, Initializable, MLP)
from blocks.bricks.attention import SequenceContentAttention
from blocks.bricks.base import application
from blocks.bricks.lookup import LookupTable
from blocks.bricks.parallel import Fork
from blocks.bricks.recurrent import GatedRecurrent, Bidirectional
from blocks.bricks.sequence_generators import (
    LookupFeedback, Readout, SoftmaxEmitter,
    SequenceGenerator)
from blocks.roles import add_role, WEIGHT
from blocks.utils import shared_floatx_nans

from machine_translation.models import MinRiskSequenceGenerator

from picklable_itertools.extras import equizip


Using gpu device 0: GeForce GTX TITAN X (CNMeM is disabled, cuDNN 4007)


In [2]:
import logging

# Make this notebook's logger verbose, then install the default root handler
# so that the records it produces are actually emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logging.basicConfig()

In [3]:
# create an NMT decoder which has access to image features via the target-side initial state

# IDEA: subclass attention recurrent, and add one more context
# -- could directly push the context onto attention_recurrent.context_names?

# It's OK to add directly to the contexts of the recurrent transition, since that is what will consume them anyway.
# TEST 1: what happens when we directly add the image features to the kwargs that we pass to sequence_generator.cost?
# note this is similar to IMT, since we're trying to modify the decoder initial state

# The kwargs do get passed through to the recurrent transition, so this should work

# AttentionRecurrent gets created in the SequenceGenerator init(), which then calls BaseSequenceGenerator
# Subclass SequenceGenerator

In [4]:
# add one more source for the images

# get the MT datastream in the standard way, then add the new source using Merge
# -- the problem with this is all the operations we do on the stream beforehand

# as long as the arrays fit in memory, we should be able to use iterable dataset

# Directory holding one image-feature archive (.npz) per data split.
IMAGE_FEATURES_DIR = '/media/1tb_drive/multilingual-multimodal/flickr30k/img_features/f30k-translational-newsplits'

# Per-split archives; the directory is spelled once so the three paths cannot drift apart.
TRAIN_IMAGE_FEATURES = os.path.join(IMAGE_FEATURES_DIR, 'train.npz')
DEV_IMAGE_FEATURES = os.path.join(IMAGE_FEATURES_DIR, 'dev.npz')
TEST_IMAGE_FEATURES = os.path.join(IMAGE_FEATURES_DIR, 'test.npz')

In [5]:
# the prototype config for NMT experiment with images

In [14]:
# Root directory of the processed experiment data (vocabularies and training text).
BASEDIR = ('/media/1tb_drive/multilingual-multimodal/flickr30k/train/processed/BERTHA-TEST_Adam_wmt-multimodal_internal_data_dropout'
           '0.3_ff_noiseFalse_search_model_en2es_vocab20000_emb300_rec800_batch15/')
# best_bleu_model_1455464992_BLEU31.61.npz

# Prototype configuration for the NMT experiment with image context features.
exp_config = dict(
    # Model dimensions -------------------------------------------------------
    src_vocab_size=20000,
    trg_vocab_size=20000,
    enc_embed=300,
    dec_embed=300,
    enc_nhids=800,
    dec_nhids=800,

    # Vocabularies and parallel training data --------------------------------
    src_vocab=os.path.join(BASEDIR, 'vocab.en-de.en.pkl'),
    trg_vocab=os.path.join(BASEDIR, 'vocab.en-de.de.pkl'),
    src_data=os.path.join(BASEDIR, 'training_data/train.en.tok.shuf'),
    trg_data=os.path.join(BASEDIR, 'training_data/train.de.tok.shuf'),
    unk_id=1,
    # BLEU script that will be used (moses multi-bleu.perl in this case)
    bleu_script='/home/chris/projects/neural_mt/test_data/sample_experiment/tiny_demo_dataset/multi-bleu.perl',

    # Optimization related ---------------------------------------------------
    # Batch size
    batch_size=8,
    # This many batches will be read ahead and sorted
    sort_k_batches=2,
    # Optimization step rule
    step_rule='AdaDelta',
    # Gradient clipping threshold
    step_clipping=1.,
    # Std of weight initialization
    weight_scale=0.01,
    seq_len=40,
    # Beam size
    beam_size=10,

    # Maximum number of updates
    finish_after=1000000,

    # Reload model from files if they exist
    reload=False,

    # Save model after this many updates
    save_freq=500,

    # Show samples from model after this many updates
    sampling_freq=1000,

    # Show this many samples at each sampling
    hook_samples=5,

    # Validate BLEU after this many updates
    bleu_val_freq=10,
    # Normalize cost according to sequence length after beam search
    normalized_bleu=True,

    saveto='/media/1tb_drive/test_min_risk_model_save',
    model_save_directory='test_image_context_features_model_save',

    # Validation set source file
    val_set='/media/1tb_drive/multilingual-multimodal/flickr30k/train/processed/dev.en.tok',

    # Validation set gold file
    val_set_grndtruth='/media/1tb_drive/multilingual-multimodal/flickr30k/train/processed/dev.de.tok',

    # Print validation output to file
    output_val_set=True,

    # Validation output file
    val_set_out='/media/1tb_drive/test_min_risk_model_save/validation_out.txt',
    val_burn_in=0,

    #     'saved_parameters': '/media/1tb_drive/multilingual-multimodal/flickr30k/train/processed/BERTHA-TEST_wmt-multimodal_internal_data_dropout0.3_ff_noiseFalse_search_model_en2es_vocab20000_emb300_rec800_batch15/best_bleu_model_1455410311_BLEU30.38.npz',

    # NEW PARAMS FOR ADDING CONTEXT FEATURES
    context_features='/media/1tb_drive/multilingual-multimodal/flickr30k/img_features/f30k-translational-newsplits/train.npz',
    val_context_features='/media/1tb_drive/multilingual-multimodal/flickr30k/img_features/f30k-translational-newsplits/dev.npz',

    # NEW PARAM FOR MIN RISK
    #     'n_samples': 100
)


In [21]:
from machine_translation.stream import _ensure_special_tokens, _length, PaddingWithEOS, _oov_to_unk, _too_long

class _SentencePairFilter(object):
    """Apply a (source, target) predicate to the first two fields of an example.

    Lets a predicate written for sentence pairs be used on examples that
    carry additional fields (here: the context features).
    """

    def __init__(self, pair_predicate):
        self.pair_predicate = pair_predicate

    def __call__(self, example):
        return self.pair_predicate(example[:2])


class _SentencePairMapping(object):
    """Apply a (source, target) mapping and keep any remaining fields unchanged."""

    def __init__(self, pair_mapping):
        self.pair_mapping = pair_mapping

    def __call__(self, example):
        return tuple(self.pair_mapping(example[:2])) + tuple(example[2:])


def get_tr_stream_with_context_features(src_vocab, trg_vocab, src_data, trg_data, context_features,
                  src_vocab_size=30000, trg_vocab_size=30000, unk_id=1,
                  seq_len=50, batch_size=80, sort_k_batches=12, **kwargs):
    """Prepare the training data stream, including per-example context features.

    Parameters
    ----------
    src_vocab, trg_vocab : dict or str
        A vocabulary dict, or the path of a pickled vocabulary dict.
    src_data, trg_data : str
        Paths of the source and target text files.
    context_features : str
        Path of an ``.npz`` archive whose ``arr_0`` entry holds one feature
        row per training example, in the same order as the text files.
    src_vocab_size, trg_vocab_size : int
        Vocabulary sizes; the last index of each is used as the EOS id.
    unk_id : int
        Index of the unknown-word token.
    seq_len : int
        Passed to ``_too_long`` to filter out over-long sentence pairs.
    batch_size : int
        Number of examples per final batch.
    sort_k_batches : int
        Number of batches read ahead and sorted together.
    **kwargs
        Ignored; allows calling with ``**config_dict``.

    Returns
    -------
    The padded stream with sources
    ``('source', 'source_mask', 'target', 'target_mask', 'initial_contexts')``.
    """
    # Local imports: these names are not imported at notebook level.
    import pickle
    from fuel.datasets import TextFile
    from fuel.schemes import ConstantScheme
    from fuel.transformers import Batch, Filter, Mapping, SortMapping, Unpack

    def _load_vocab(vocab):
        # NOTE: unpickling executes arbitrary code; only load trusted vocabulary files.
        if isinstance(vocab, dict):
            return vocab
        with open(vocab, 'rb') as vocab_file:
            return pickle.load(vocab_file)

    # Load dictionaries and ensure special tokens exist
    src_vocab = _ensure_special_tokens(
        _load_vocab(src_vocab),
        bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id)
    trg_vocab = _ensure_special_tokens(
        _load_vocab(trg_vocab),
        bos_idx=0, eos_idx=trg_vocab_size - 1, unk_idx=unk_id)

    # Get text files from both source and target
    src_dataset = TextFile([src_data], src_vocab, None)
    trg_dataset = TextFile([trg_data], trg_vocab, None)

    # As long as the array fits in memory, an IterableDataset is enough for the features
    context_dataset = IterableDataset(numpy.load(context_features)['arr_0'])

    # Merge source, target and context features BEFORE filtering. Merging the
    # features after the filter would pair every example that follows a dropped
    # sentence pair with the wrong feature row.
    stream = Merge([src_dataset.get_example_stream(),
                    trg_dataset.get_example_stream(),
                    DataStream(context_dataset)],
                   ('source', 'target', 'initial_contexts'))

    # Filter sequences that are too long (only the sentence pair is inspected)
    stream = Filter(stream,
                    predicate=_SentencePairFilter(_too_long(seq_len=seq_len)))

    # Replace out of vocabulary tokens with unk token
    # TODO: doesn't the TextFile stream do this anyway?
    stream = Mapping(stream,
                     _SentencePairMapping(
                         _oov_to_unk(src_vocab_size=src_vocab_size,
                                     trg_vocab_size=trg_vocab_size,
                                     unk_id=unk_id)))

    # Build a batched version of stream to read k batches ahead
    stream = Batch(stream,
                   iteration_scheme=ConstantScheme(
                       batch_size*sort_k_batches))

    # Sort all samples in the read-ahead batch
    stream = Mapping(stream, SortMapping(_length))

    # Convert it into a stream again
    stream = Unpack(stream)

    # Construct batches from the stream with specified batch size
    stream = Batch(
        stream, iteration_scheme=ConstantScheme(batch_size))

    # Pad sequences that are short; the context features are left unmasked
    masked_stream = PaddingWithEOS(
        stream, [src_vocab_size - 1, trg_vocab_size - 1], mask_sources=('source', 'target'))

    return masked_stream


# Remember that the BleuValidator does hackish stuff to get target set information from the main_loop data_stream
# using all kwargs here makes it more clear that this function is always called with get_dev_stream(**config_dict)
def get_dev_stream_with_context_features(val_context_features=None, val_set=None, src_vocab=None,
                                         src_vocab_size=30000, unk_id=1, **kwargs):
    """Set up the development set stream if necessary.

    Parameters
    ----------
    val_context_features : str, optional
        Path of the validation context features. Currently NOT used: the
        returned stream only contains the source text.
    val_set : str, optional
        Path of the validation source text file.
    src_vocab : dict or str, optional
        A vocabulary dict, or the path of a pickled vocabulary dict.
    src_vocab_size : int
        Vocabulary size; the last index is used as the EOS id.
    unk_id : int
        Index of the unknown-word token.
    **kwargs
        Ignored; allows calling with ``**config_dict``.

    Returns
    -------
    A ``DataStream`` over the validation source text, or ``None`` when
    ``val_set`` or ``src_vocab`` is missing.
    """
    if val_set is None or src_vocab is None:
        return None

    # Local imports: these names are not imported at notebook level.
    import pickle
    from fuel.datasets import TextFile

    if not isinstance(src_vocab, dict):
        # NOTE: unpickling executes arbitrary code; only load trusted vocabulary files.
        with open(src_vocab, 'rb') as vocab_file:
            src_vocab = pickle.load(vocab_file)

    src_vocab = _ensure_special_tokens(
        src_vocab,
        bos_idx=0, eos_idx=src_vocab_size - 1, unk_idx=unk_id)

    # TODO: how is the dev dataset used without the context features?
    dev_dataset = TextFile([val_set], src_vocab, None)

    return DataStream(dev_dataset)


In [22]:
# WORKING: how do BLEU validator and beam search need to be modified to account for the new context? 
# TODO: How does sampling change?

In [23]:
# setting up the experiment
    
# args = parser.parse_args()
# arg_dict = vars(args)
# configuration_file = arg_dict['exp_config']
# mode = arg_dict['mode']

# The command-line driver is disabled in the notebook, so the mode is fixed here.
mode = 'train'
logger.info('Running Neural Machine Translation in mode: {}'.format(mode))
# config_obj = configurations.get_config(configuration_file)
# Use the prototype configuration directly; config_obj is an alias of
# exp_config (the same dict object), not a copy.
config_obj = exp_config

# add the config file name into config_obj
# config_obj['config_file'] = configuration_file
# logger.info("Model Configuration:\n{}".format(pprint.pformat(config_obj)))

# Both stream factories accept **kwargs, so the whole config can be splatted
# in and each one picks out the keys it needs.
train_stream = get_tr_stream_with_context_features(**config_obj)
dev_stream = get_dev_stream_with_context_features(**config_obj)


# bokeh = True

# if mode == 'train':
    # Get data streams and call main
#     main(config_obj, get_tr_stream(**config_obj),
#          get_dev_stream(**config_obj), bokeh)

INFO:__main__:Running Neural Machine Translation in mode: train


In [25]:
# Pull the first batch from a fresh epoch iterator to inspect what the stream yields.
b = next(train_stream.get_epoch_iterator())

In [26]:
# Shape of every array in the batch (presumably in the order of train_stream.sources).
[i.shape for i in b]

[(8, 16), (8, 16), (8, 13), (8, 13), (8, 4096)]

In [24]:
# Source names provided by the training stream; 'initial_contexts' is the added feature source.
train_stream.sources

('source', 'source_mask', 'target', 'target_mask', 'initial_contexts')

In [None]:
class GRUInitialStateWithInitialStateContext(GatedRecurrent):
    """Gated Recurrent whose initial state also depends on an extra context.

    The initial state is computed by an MLP with tanh activations from the
    concatenation of (a) a slice of the attended encoder representation and
    (b) an additional per-example context vector (e.g. image features).

    Parameters
    ----------
    attended_dim : int
        Number of trailing features taken from the attended representation
        at the first time step (presumably the final state of the backward
        encoder -- TODO confirm).
    context_dim : int
        Dimensionality of the extra initial-state context.
    **kwargs
        Forwarded to ``GatedRecurrent``.

    """
    def __init__(self, attended_dim, context_dim, **kwargs):
        super(GRUInitialStateWithInitialStateContext, self).__init__(**kwargs)
        self.attended_dim = attended_dim
        self.context_dim = context_dim

        # The first layer consumes the concatenated [attended slice, context] vector.
        # NOTE(review): the hidden sizes 1000 and 500 are hard-coded.
        self.initial_transformer = MLP(activations=[Tanh(), Tanh(), Tanh()],
                                       dims=[attended_dim + context_dim, 1000, 500, self.dim],
                                       name='state_initializer')
        self.children.append(self.initial_transformer)

    # WORKING: add the images as another context to the recurrent transition
    # THINKING: how to best combine the image info with the source info?
    @application
    def initial_states(self, batch_size, *args, **kwargs):
        """Compute the initial state from ``attended`` and ``initial_state_context``.

        Both inputs are read from ``kwargs`` and raise ``KeyError`` if absent.
        """
        attended = kwargs['attended']
        context = kwargs['initial_state_context']
        attended_reverse_final_state = attended[0, :, -self.attended_dim:]
        # Feed BOTH parts to the MLP: its input dimension is
        # attended_dim + context_dim, so the attended slice alone would not fit.
        concat_attended_and_context = tensor.concatenate(
            [attended_reverse_final_state, context], axis=1)
        initial_state = self.initial_transformer.apply(
            concat_attended_and_context)
        return initial_state

    def _allocate(self):
        """Allocate the state-to-state and state-to-gates weight matrices."""
        self.parameters.append(shared_floatx_nans((self.dim, self.dim),
                               name='state_to_state'))
        self.parameters.append(shared_floatx_nans((self.dim, 2 * self.dim),
                               name='state_to_gates'))
        for i in range(2):
            if self.parameters[i]:
                add_role(self.parameters[i], WEIGHT)

In [None]:
class InitialContextDecoder(Initializable):
    """
    Decoder which incorporates context features into the target-side initial state

    NOTE: this class is work in progress. The constructor only forwards its
    arguments to ``Initializable``; it does not build ``self.sequence_generator``,
    which ``cost`` requires. That attribute must be set before ``cost`` is applied.

    Planned constructor parameters (not yet accepted by ``__init__``):
    -----------
    vocab_size: int
    embedding_dim: int
    representation_dim: int
    theano_seed: int
    loss_function: str : {'cross_entropy'(default) | 'min_risk'}

    """

    def __init__(self, *args, **kwargs):
        # `self` was missing from the signature, which made the super() call
        # below fail with a NameError.
        super(InitialContextDecoder, self).__init__(*args, **kwargs)

    @application(inputs=['representation', 'source_sentence_mask',
                         'target_sentence_mask', 'target_sentence', 'initial_state_context'],
                 outputs=['cost'])
    def cost(self, representation, source_sentence_mask,
             target_sentence, target_sentence_mask, initial_state_context):
        """Masked sequence cost, with an extra context for the decoder's initial state.

        The masks and the target are transposed before being passed to
        ``self.sequence_generator.cost_matrix`` (presumably from batch-major to
        time-major -- TODO confirm). ``initial_state_context`` is forwarded
        under the keyword of the same name.

        Returns the sum of the masked cost matrix divided by the second
        dimension of the transposed target mask.
        """
        source_sentence_mask = source_sentence_mask.T
        target_sentence = target_sentence.T
        target_sentence_mask = target_sentence_mask.T

        # Get the cost matrix
        cost = self.sequence_generator.cost_matrix(**{
            'mask': target_sentence_mask,
            'outputs': target_sentence,
            'attended': representation,
            'attended_mask': source_sentence_mask,
            'initial_state_context': initial_state_context}
        )

        return (cost * target_sentence_mask).sum() / \
            target_sentence_mask.shape[1]

    # Note: this requires the decoder to be using sequence_generator which implements expected cost
#     @application(inputs=['representation', 'source_sentence_mask',
#                          'target_samples_mask', 'target_samples', 'scores'],
#                  outputs=['cost'])
#     def expected_cost(self, representation, source_sentence_mask, target_samples, target_samples_mask, scores,
#                       **kwargs):
#         return self.sequence_generator.expected_cost(representation,
#                                                      source_sentence_mask,
#                                                      target_samples, target_samples_mask, scores, **kwargs)


#     @application
#     def generate(self, source_sentence, representation, **kwargs):
#         return self.sequence_generator.generate(
#             n_steps=2 * source_sentence.shape[1],
#             batch_size=source_sentence.shape[0],
#             attended=representation,
#             attended_mask=tensor.ones(source_sentence.shape).T,
#             **kwargs)


In [None]:
import logging

import os
import shutil
from collections import Counter
from theano import tensor
from toolz import merge
import numpy
import pickle
from subprocess import Popen, PIPE
import codecs

from blocks.algorithms import (GradientDescent, StepClipping,
                               CompositeRule, Adam, AdaDelta)
from blocks.extensions import FinishAfter, Printing, Timing
from blocks.extensions.monitoring import TrainingDataMonitoring
from blocks.filter import VariableFilter
from blocks.graph import ComputationGraph, apply_noise, apply_dropout
from blocks.initialization import IsotropicGaussian, Orthogonal, Constant
from blocks.main_loop import MainLoop
from blocks.model import Model
from blocks.select import Selector
from blocks.search import BeamSearch
from blocks_extras.extensions.plot import Plot

from machine_translation.checkpoint import CheckpointNMT, LoadNMT
from machine_translation.model import BidirectionalEncoder, Decoder
from machine_translation.sampling import BleuValidator, Sampler, SamplingBase
from machine_translation.stream import (get_tr_stream, get_dev_stream,
                                        _ensure_special_tokens)

# Optional plotting support: record whether the bokeh-based Plot extension can be imported.
# NOTE(review): Plot is also imported unconditionally earlier in this cell, so a
# missing blocks_extras raises there and this guard never reaches the except
# branch -- confirm which import is intended.
try:
    from blocks_extras.extensions.plot import Plot
    BOKEH_AVAILABLE = True
except ImportError:
    BOKEH_AVAILABLE = False

# Logger for this cell's code (same name as the notebook logger created earlier).
logger = logging.getLogger(__name__)


def main(config, tr_stream, dev_stream, use_bokeh=False):
    """Build the encoder-decoder model and run the training main loop.

    Parameters
    ----------
    config : dict
        Experiment configuration. ``dropout`` (default 1.0, i.e. disabled),
        ``weight_noise_ff`` (default 0.0, i.e. disabled) and ``config_file``
        (default: nothing is copied) are optional; all other keys read below
        are required.
    tr_stream
        Training data stream, used by the main loop and the sampler.
    dev_stream
        Validation data stream, used by the BLEU validator.
    use_bokeh : bool
        Add the ``Plot`` extension when it is available.

    NOTE: the image context is not yet wired into this model; it still uses
    the plain ``Decoder``.
    """
    # Optional regularization settings; missing keys mean "disabled".
    dropout = config.get('dropout', 1.0)
    weight_noise_ff = config.get('weight_noise_ff', 0.0)

    # Create Theano variables
    logger.info('Creating theano variables')
    source_sentence = tensor.lmatrix('source')
    source_sentence_mask = tensor.matrix('source_mask')
    target_sentence = tensor.lmatrix('target')
    target_sentence_mask = tensor.matrix('target_mask')
    # Placeholder for the image context; not used by the graph below yet.
    initial_context = tensor.matrix('initial_context')

    sampling_input = tensor.lmatrix('input')

    # Construct model
    logger.info('Building RNN encoder-decoder')
    encoder = BidirectionalEncoder(
        config['src_vocab_size'], config['enc_embed'], config['enc_nhids'])
    decoder = Decoder(
        config['trg_vocab_size'], config['dec_embed'], config['dec_nhids'],
        config['enc_nhids'] * 2)
    cost = decoder.cost(
        encoder.apply(source_sentence, source_sentence_mask),
        source_sentence_mask, target_sentence, target_sentence_mask)

    logger.info('Creating computational graph')
    cg = ComputationGraph(cost)

    # Initialize model
    logger.info('Initializing model')
    encoder.weights_init = decoder.weights_init = IsotropicGaussian(
        config['weight_scale'])
    encoder.biases_init = decoder.biases_init = Constant(0)
    encoder.push_initialization_config()
    decoder.push_initialization_config()
    encoder.bidir.prototype.weights_init = Orthogonal()
    decoder.transition.weights_init = Orthogonal()
    encoder.initialize()
    decoder.initialize()

    # apply dropout for regularization
    if dropout < 1.0:
        # dropout is applied to the output of maxout in ghog
        # this is the probability of dropping out, so you probably want to make it <=0.5
        logger.info('Applying dropout')
        dropout_inputs = [x for x in cg.intermediary_variables
                          if x.name == 'maxout_apply_output']
        cg = apply_dropout(cg, dropout_inputs, dropout)

    # Apply weight noise for regularization
    if weight_noise_ff > 0.0:
        logger.info('Applying weight noise to ff layers')
        noisy_bricks = [
            encoder.lookup,
            encoder.fwd_fork,
            encoder.back_fork,
            decoder.sequence_generator.readout,
            decoder.sequence_generator.fork,
            decoder.transition.initial_transformer,
        ]
        # Collect into a real list: concatenating dict.values() results
        # directly only works on Python 2.
        noisy_params = []
        for brick in noisy_bricks:
            noisy_params.extend(Selector(brick).get_parameters().values())
        cg = apply_noise(cg, noisy_params, weight_noise_ff)

    # TODO: weight noise for recurrent params isn't currently implemented -- see config['weight_noise_rec']
    # Print shapes
    shapes = [param.get_value().shape for param in cg.parameters]
    logger.info("Parameter shapes: ")
    for shape, count in Counter(shapes).most_common():
        # str() keeps the width specifier valid for tuples on every Python version
        logger.info('    {:15}: {}'.format(str(shape), count))
    logger.info("Total number of parameters: {}".format(len(shapes)))

    # Print parameter names
    enc_dec_param_dict = merge(Selector(encoder).get_parameters(),
                               Selector(decoder).get_parameters())
    logger.info("Parameter names: ")
    for name, value in enc_dec_param_dict.items():
        logger.info('    {:15}: {}'.format(str(value.get_value().shape), name))
    logger.info("Total number of parameters: {}"
                .format(len(enc_dec_param_dict)))

    # Set up training model
    logger.info("Building model")
    training_model = Model(cost)

    # create the training directory, and copy this config there if directory doesn't exist
    if not os.path.isdir(config['saveto']):
        os.makedirs(config['saveto'])
        # The config file name is only known when the config was loaded from a file.
        config_file = config.get('config_file')
        if config_file:
            shutil.copy(config_file, config['saveto'])

    # Set extensions
    logger.info("Initializing extensions")
    extensions = [
        FinishAfter(after_n_batches=config['finish_after']),
        TrainingDataMonitoring([cost], after_batch=True),
        Printing(after_batch=True),
        CheckpointNMT(config['saveto'],
                      every_n_batches=config['save_freq'])
    ]

    # Set up beam search and sampling computation graphs if necessary
    if config['hook_samples'] >= 1 or config['bleu_script'] is not None:
        logger.info("Building sampling model")
        sampling_representation = encoder.apply(
            sampling_input, tensor.ones(sampling_input.shape))
        # TODO: the generated output actually contains several more values, ipdb to see what they are
        generated = decoder.generate(sampling_input, sampling_representation)
        search_model = Model(generated)
        _, samples = VariableFilter(
            bricks=[decoder.sequence_generator], name="outputs")(
                ComputationGraph(generated[1]))  # generated[1] is next_outputs

    # Add sampling
    if config['hook_samples'] >= 1:
        logger.info("Building sampler")
        extensions.append(
            Sampler(model=search_model, data_stream=tr_stream,
                    hook_samples=config['hook_samples'],
                    every_n_batches=config['sampling_freq'],
                    src_vocab_size=config['src_vocab_size']))

    # Add early stopping based on bleu
    if config['bleu_script'] is not None:
        logger.info("Building bleu validator")
        extensions.append(
            BleuValidator(sampling_input, samples=samples, config=config,
                          model=search_model, data_stream=dev_stream,
                          normalize=config['normalized_bleu'],
                          every_n_batches=config['bleu_val_freq']))

    # Reload model if necessary
    if config['reload']:
        extensions.append(LoadNMT(config['saveto']))

    # Plot cost in bokeh if necessary
    if use_bokeh and BOKEH_AVAILABLE:
        extensions.append(
            Plot(config['model_save_directory'], channels=[['decoder_cost_cost'], ['validation_set_bleu_score']],
                 every_n_batches=10))

    # Set up training algorithm
    logger.info("Initializing training algorithm")
    # if there is dropout or random noise, we need to use the output of the modified graph
    if dropout < 1.0 or weight_noise_ff > 0.0:
        training_cost = cg.outputs[0]
    else:
        training_cost = cost
    # NOTE(review): eval() resolves the step rule by name (e.g. 'AdaDelta', 'Adam');
    # never feed it a configuration from an untrusted source.
    step_rule = CompositeRule([StepClipping(config['step_clipping']),
                               eval(config['step_rule'])()])
    algorithm = GradientDescent(
        cost=training_cost, parameters=cg.parameters,
        step_rule=step_rule
    )

    # enrich the logged information
    extensions.append(
        Timing(every_n_batches=100)
    )

    # Initialize main loop
    logger.info("Initializing main loop")
    main_loop = MainLoop(
        model=training_model,
        algorithm=algorithm,
        data_stream=tr_stream,
        extensions=extensions
    )

    # Train!
    main_loop.run()


In [None]:
def create_model(encoder, decoder):
    """Build the training cost for a decoder that takes an initial-state context.

    Parameters
    ----------
    encoder
        Brick whose ``apply(source, source_mask)`` returns the source representation.
    decoder
        Brick whose ``cost`` accepts
        ``(representation, source_sentence_mask, target_sentence,
        target_sentence_mask, initial_state_context)``, e.g. ``InitialContextDecoder``.

    Returns
    -------
    The Theano cost variable.

    NOTE: a later cell of this notebook defines another ``create_model``
    (the expected-cost variant), which replaces this one once it is run.
    """
    # Create Theano variables. Blocks matches data stream sources to input
    # variables by name, so the names below must match the sources of the
    # training stream.
    logger.info('Creating theano variables')
    source_sentence = tensor.lmatrix('source')
    source_sentence_mask = tensor.matrix('source_mask')
    target_sentence = tensor.lmatrix('target')
    target_sentence_mask = tensor.matrix('target_mask')
    initial_context = tensor.matrix('initial_contexts')

    # the name is important to make sure pre-trained params get loaded correctly
    # decoder.name = 'decoder'

    # The extra context is what distinguishes this cost from the standard decoder cost
    cost = decoder.cost(
        encoder.apply(source_sentence, source_sentence_mask),
        source_sentence_mask, target_sentence, target_sentence_mask,
        initial_context)

    return cost

In [12]:
# Inspect one example from the image feature stream.
# NOTE(review): `train_image_stream` is only created as a local variable inside
# get_tr_stream_with_context_features, so this cell relies on leftover kernel
# state and fails on a fresh kernel.
x = next(train_image_stream.get_epoch_iterator())

In [15]:
# Only the last expression of a cell is displayed, so the type() result is discarded.
type(x[0])
x[0].shape

(4096,)

In [None]:


# Compiled sampling function used by sampling_func below.
# NOTE(review): `sample_model` is not defined anywhere in the visible notebook -- confirm where it comes from.
theano_sample_func = sample_model.get_theano_function()

# close over the sampling func and the trg_vocab to standardize the interface
# TODO: actually this should be a callable class with params (sampling_func, trg_vocab)
# TODO: we may be able to make this function faster by passing multiple sources for sampling at the same damn time
# TODO: or by avoiding the for loop somehow
def sampling_func(source_seq, num_samples=1):
    """Draw ``num_samples`` translations of one source sequence.

    Relies on the notebook-level names ``theano_sample_func`` (called with a
    ``(num_samples, len(source_seq))`` array and expected to return five
    values, the second being the sampled outputs as ``[seq_len, batch]``)
    and ``trg_vocab`` (must contain the ``'</S>'`` token).

    Parameters
    ----------
    source_seq : numpy.ndarray
        1-d array of source token ids.
    num_samples : int
        Number of samples to draw.

    Returns
    -------
    list of list of int
        One sample per row, truncated just after its first ``'</S>'``.
        A sample that never emits ``'</S>'`` is returned in full.
    """

    def _get_true_length(seqs, vocab):
        eos_id = vocab['</S>']
        full_length = seqs.shape[1]
        lens = []
        for row in seqs.tolist():
            # Fall back per row: a sample without EOS must not change the
            # lengths of the samples that do contain one.
            try:
                lens.append(row.index(eos_id) + 1)
            except ValueError:
                lens.append(full_length)
        return lens

    # Sample all copies of the source in a single call instead of looping
    inputs = numpy.tile(source_seq[None, :], (num_samples, 1))
    # the output is [seq_len, batch]
    _1, outputs, _2, _3, _costs = theano_sample_func(inputs)
    outputs = outputs.T

    # TODO: this step could be avoided by computing the samples mask in a different way
    lens = _get_true_length(outputs, trg_vocab)
    samples = [s[:l] for s, l in zip(outputs.tolist(), lens)]

    return samples


# Source-side text stream.
# NOTE(review): `get_textfile_stream` is not defined or imported in the visible notebook.
src_stream = get_textfile_stream(source_file=exp_config['src_data'], src_vocab=exp_config['src_vocab'],
                                         src_vocab_size=exp_config['src_vocab_size'])

# test_source_stream.sources = ('sources',)
# Target-side text stream; the helper's keyword names say "src", but the target vocabulary is passed.
trg_stream = get_textfile_stream(source_file=exp_config['trg_data'], src_vocab=exp_config['trg_vocab'],
                                         src_vocab_size=exp_config['trg_vocab_size'])

# Merge them to get a source, target pair
training_stream = Merge([src_stream,
                         trg_stream],
                         ('source', 'target'))

# Filter sequences that are too long
training_stream = Filter(training_stream,
                         predicate=_too_long(seq_len=exp_config['seq_len']))

# sampling_transformer = MTSampleStreamTransformer(sampling_func, fake_score, num_samples=5)
# NOTE(review): 'n_samples' is commented out in exp_config, so this lookup raises
# KeyError as written; `MTSampleStreamTransformer` and `sentence_level_bleu`
# are not defined in the visible notebook either.
sampling_transformer = MTSampleStreamTransformer(sampling_func, sentence_level_bleu, num_samples=exp_config['n_samples'])

# Add the generated samples and their scores as two new sources of each example.
training_stream = Mapping(training_stream, sampling_transformer, add_sources=('samples', 'scores'))


class FlattenSamples(Transformer):
    """Flatten the per-instance sample lists of the ``samples`` source.

    In an incoming batch the ``samples`` source holds, for every instance,
    a collection of samples. This transformer replaces it with a single
    flat list containing all samples of all instances, in order. Every
    other source is passed through unchanged, and the source names are the
    same as those of the wrapped stream.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` instance
        The data stream to wrap. It must produce batches, not examples.

    """
    def __init__(self, data_stream, **kwargs):
        if data_stream.produces_examples:
            raise ValueError('the wrapped data stream must produce batches of '
                             'examples, not examples')
        super(FlattenSamples, self).__init__(
            data_stream, produces_examples=False, **kwargs)

    @property
    def sources(self):
        return self.data_stream.sources

    def transform_batch(self, batch):
        """Return ``batch`` with the ``samples`` source flattened by one level."""
        batch_with_flattened_samples = []
        for source, source_batch in zip(self.data_stream.sources, batch):
            if source == 'samples':
                batch_with_flattened_samples.append(
                    [sample for instance in source_batch for sample in instance])
            else:
                batch_with_flattened_samples.append(source_batch)

        return tuple(batch_with_flattened_samples)


class CopySourceNTimes(Transformer):
    """Repeat every source sequence so that there is one per generated sample.

    The attention model expects one source sequence for each target
    sequence, but with sampling there are effectively
    (instances * n_samples) target sequences in a batch.

    Parameters
    ----------
    data_stream : :class:`AbstractDataStream` instance
        The data stream to wrap
    n_samples : int -- the number of samples that were generated for each source sequence

    """
    def __init__(self, data_stream, n_samples=5, **kwargs):
        if data_stream.produces_examples:
            raise ValueError('the wrapped data stream must produce batches of '
                             'examples, not examples')
        self.n_samples = n_samples

        super(CopySourceNTimes, self).__init__(
            data_stream, produces_examples=False, **kwargs)

    @property
    def sources(self):
        return self.data_stream.sources

    def transform_batch(self, batch):
        """Return ``batch`` with each entry of ``source`` repeated ``n_samples`` times."""

        def _repeat_each(instances):
            # Copy each source sequence self.n_samples times, keeping the result a flat list
            repeated = []
            for instance in instances:
                repeated.extend([instance] * self.n_samples)
            return repeated

        return tuple(
            _repeat_each(data) if name == 'source' else data
            for name, data in zip(self.data_stream.sources, batch))



# Replace out of vocabulary tokens with unk token
# training_stream = Mapping(training_stream,
#                  _oov_to_unk(src_vocab_size=exp_config['src_vocab_size'],
#                              trg_vocab_size=exp_config['trg_vocab_size'],
#                              unk_id=exp_config['unk_id']))

# Build a batched version of the stream that reads sort_k_batches batches ahead
training_stream = Batch(training_stream,
               iteration_scheme=ConstantScheme(
                   exp_config['batch_size']*exp_config['sort_k_batches']))

# Sort all samples in the read-ahead batch
training_stream = Mapping(training_stream, SortMapping(_length))

# Convert it into a stream again
training_stream = Unpack(training_stream)

# Construct batches from the stream with specified batch size
training_stream = Batch(
    training_stream, iteration_scheme=ConstantScheme(exp_config['batch_size']))

# Pad sequences that are short
# IDEA: add a transformer which flattens the target samples before we add the mask
flat_sample_stream = FlattenSamples(training_stream)

# Repeat each source once per sample so sources and flattened samples line up.
# NOTE(review): 'n_samples' is commented out in exp_config, so this lookup raises KeyError as written.
expanded_source_stream = CopySourceNTimes(flat_sample_stream, n_samples=exp_config['n_samples'])

# TODO: some sources can be excluded from the padding Op, but since blocks matches sources with input variable
# TODO: names, it's not critical
# No mask_sources is given here, unlike in get_tr_stream_with_context_features.
masked_stream = PaddingWithEOS(
    expanded_source_stream, [exp_config['src_vocab_size'] - 1, exp_config['trg_vocab_size'] - 1])


def create_model(encoder, decoder):
    """Build the expected-cost training graph from sampled translations.

    ``decoder`` must provide ``expected_cost``; ``encoder.apply`` produces the
    source representation. Returns the Theano cost variable.
    """
    logger.info('Creating theano variables')

    # Input variables; the names are the ones reproduced from the stream sources
    src = tensor.lmatrix('source')
    src_mask = tensor.matrix('source_mask')
    sampled_targets = tensor.lmatrix('samples')
    sampled_targets_mask = tensor.matrix('samples_mask')
    # scores is (batch, samples); no mask is needed because the number of
    # samples per instance is a fixed hyperparameter of the model
    sample_scores = tensor.matrix('scores')

    # the name is important to make sure pre-trained params get loaded correctly
    #     decoder.name = 'decoder'

    # This is the part that is different for the MinimumRiskSequenceGenerator
    encoded_source = encoder.apply(src, src_mask)
    return decoder.expected_cost(
        encoded_source, src_mask,
        sampled_targets, sampled_targets_mask, sample_scores)