In [1]:
from allennlp.models.encoder_decoders.composed_seq2seq import ComposedSeq2Seq
from typing import Dict, Tuple

from allennlp.models.encoder_decoders.simple_seq2seq import SimpleSeq2Seq
 
import torch
import torch.nn as nn

from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.attention.additive_attention import AdditiveAttention
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding


from typing import Dict
import csv

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import LabelField, TextField, Field, MetadataField
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers.token import Token
from allennlp.data.tokenizers.character_tokenizer import CharacterTokenizer
from allennlp.data.tokenizers import Tokenizer
from allennlp.data.iterators import BucketIterator, BasicIterator

from allennlp.models.encoder_decoders.simple_seq2seq import SimpleSeq2Seq

import torch
import torch.nn as nn

from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.attention.additive_attention import AdditiveAttention
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.data.vocabulary import Vocabulary
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.training.trainer import Trainer

In [2]:
class LovelyModel(SimpleSeq2Seq):

    def _prepare_output_projections(self,
                                    last_predictions: torch.Tensor,
                                    state: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        encoder_outputs = state["encoder_outputs"]
        source_mask = state["source_mask"]
        decoder_hidden = state["decoder_hidden"]
        decoder_context = state["decoder_context"]
        embedded_input = self._source_embedder._token_embedders['tokens'](last_predictions)
        if self._attention:
            attended_input = self._prepare_attended_input(decoder_hidden, encoder_outputs, source_mask)
            decoder_input = torch.cat((attended_input, embedded_input), -1)
        else:
            decoder_input = embedded_input
        decoder_hidden, decoder_context = self._decoder_cell(
                decoder_input,
                (decoder_hidden, decoder_context))

        state["decoder_hidden"] = decoder_hidden
        state["decoder_context"] = decoder_context

        output_projections = self._output_projection_layer(decoder_hidden)
        return output_projections, state

    def forward(self,  # type: ignore
                source_tokens: Dict[str, torch.LongTensor],
                target_tokens: Dict[str, torch.LongTensor] = None, **kwargs):
        del kwargs
        return super().forward(source_tokens, target_tokens)


In [3]:
def get_baseline_model(vocab: Vocabulary) -> SimpleSeq2Seq:
    emb_dim = 64
    hidden_dim = 32
    token_embedding = Embedding(
        num_embeddings=vocab.get_vocab_size('tokens'),
        embedding_dim=emb_dim
    )

    word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})
    lstm = PytorchSeq2SeqWrapper(nn.LSTM(emb_dim, hidden_dim, batch_first=True))

    model = LovelyModel(
        vocab=vocab,
        source_embedder=word_embeddings,
        encoder=lstm,
        max_decoding_steps=20,
        # attention=AdditiveAttention(vector_dim=hidden_dim, matrix_dim=hidden_dim)
    )

    return model

In [4]:
class MyReader(DatasetReader):

    def _read(self, file_path):
        with open(cached_path(file_path), "r") as file:
            for line in file:
                yield self.text_to_instance(line.strip())

    def text_to_instance(
        self,
        text: str
    ) -> Instance:
        fields: Dict[str, Field] = {}
        tokenized = text.split()
        fields["source_tokens"] = TextField([Token(word) for word in tokenized], {"tokens": SingleIdTokenIndexer()})
        fields["target_tokens"] = fields["source_tokens"]
        return Instance(fields)

In [5]:
train_path = '/Users/fursovia/Documents/texar/examples/text_style_transfer/data/insurance_cropped/insurance.train.text'
test_path = '/Users/fursovia/Documents/texar/examples/text_style_transfer/data/insurance_cropped/insurance.test.text'



In [6]:
reader = MyReader()

train_dataset = reader.read(train_path)
test_dataset = reader.read(test_path)

266051it [00:07, 35026.94it/s]
266051it [00:08, 32787.92it/s]


In [7]:
vocab = Vocabulary.from_instances(train_dataset + test_dataset)

100%|██████████| 532102/532102 [00:05<00:00, 97137.39it/s] 


In [8]:
iterator = BasicIterator(batch_size=256)
iterator.index_with(vocab)

In [9]:
model = get_baseline_model(vocab)

In [10]:
model

LovelyModel(
  (_source_embedder): BasicTextFieldEmbedder(
    (token_embedder_tokens): Embedding()
  )
  (_encoder): PytorchSeq2SeqWrapper(
    (_module): LSTM(64, 32, batch_first=True)
  )
  (_target_embedder): Embedding()
  (_decoder_cell): LSTMCell(64, 32)
  (_output_projection_layer): Linear(in_features=32, out_features=2110, bias=True)
)

In [11]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [12]:
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_dataset,
    validation_dataset=test_dataset,
    patience=2,
    num_epochs=10
)

In [13]:
results = trainer.train()

loss: 4.8792 ||: 100%|██████████| 1040/1040 [05:50<00:00,  2.97it/s]
BLEU: 0.0006, loss: 3.8495 ||:   7%|▋         | 76/1040 [01:20<17:44,  1.10s/it]

KeyboardInterrupt: 