In [1]:
!pip install -q parlai
!pip install -q subword_nmt

In [2]:
import torch.nn as nn
import torch.nn.functional as F
import parlai.core.torch_generator_agent as tga
from parlai.core.agents import register_agent, Agent
from parlai.scripts.train_model import TrainModel
from parlai.scripts.display_model import DisplayModel

In [3]:
class Encoder(nn.Module):
    """Single-layer, batch-first LSTM encoder over an injected embedding table.

    The embedding module is passed in rather than created here, so a caller can
    share one table between several components (``EncoderDecoder`` in this
    notebook does exactly that).

    Args:
        embeddings: module mapping token ids to vectors. Its output width is fed
            straight into the LSTM, so it must equal ``hidden_size``.
        hidden_size: LSTM input width and hidden width.
    """

    def __init__(self, embeddings, hidden_size):
        super().__init__()
        # Registration order (embeddings, then lstm) is kept stable so that
        # state_dict keys and parameter order match existing checkpoints.
        self.embeddings = embeddings
        lstm_config = dict(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,
        )
        self.lstm = nn.LSTM(**lstm_config)

    def forward(self, input_tokens):
        """Encode a batch of token ids and return only the final LSTM state.

        Args:
            input_tokens: token-id tensor, batch dimension first (the LSTM is
                built with ``batch_first=True``).

        Returns:
            The ``(h_n, c_n)`` tuple from ``nn.LSTM``. Per PyTorch, each is laid
            out as (num_layers, batch, hidden_size) even with ``batch_first=True``.
            Per-timestep outputs are discarded.
        """
        # NOTE(review): padded positions are not packed/masked, so they are
        # also consumed by the LSTM before the final state is taken.
        _per_step, final_state = self.lstm(self.embeddings(input_tokens))
        return final_state

class Decoder(nn.Module):
    """Single-layer, batch-first LSTM decoder over an injected embedding table.

    Args:
        embeddings: module mapping token ids to vectors. Its output width must
            equal ``hidden_size`` because it is the LSTM's input.
        hidden_size: LSTM input width and hidden width.
    """

    def __init__(self, embeddings, hidden_size):
        super().__init__()
        # Keep attribute names/order stable: they define the state_dict keys.
        self.embeddings = embeddings
        self.lstm = nn.LSTM(
            input_size=hidden_size, hidden_size=hidden_size, num_layers=1, batch_first=True
        )

    def forward(self, input, encoder_state, incr_state=None):
        """Decode ``input`` token ids starting from a given LSTM state.

        Args:
            input: token-id tensor, batch dimension first.
            encoder_state: ``(h, c)`` pair used as the initial state when no
                incremental state is supplied (i.e. on the first step).
            incr_state: ``(h, c)`` pair from a previous decoder call. When it is
                not ``None``, it takes precedence over ``encoder_state``.

        Returns:
            ``(output, incr_state)``: the LSTM's per-step outputs and its new
            ``(h, c)`` state, which can be fed back as ``incr_state``.
        """
        start_state = encoder_state if incr_state is None else incr_state
        step_outputs, next_state = self.lstm(self.embeddings(input), start_state)
        return step_outputs, next_state

In [4]:
class EncoderDecoder(tga.TorchGeneratorModel):
    """LSTM encoder-decoder whose single embedding table is shared three ways.

    The same ``nn.Embedding`` instance feeds the encoder and the decoder, and
    its weight matrix doubles as the (bias-free) output projection in ``output``.

    Args:
        dictionary: ParlAI dictionary; used for vocabulary size and the ids of
            the null/start/end/unk special tokens.
        hidden_size: embedding width and LSTM hidden width.
    """

    def __init__(self, dictionary, hidden_size=512):
        special_token_ids = {
            'padding_idx': dictionary[dictionary.null_token],
            'start_idx': dictionary[dictionary.start_token],
            'end_idx': dictionary[dictionary.end_token],
            'unknown_idx': dictionary[dictionary.unk_token],
        }
        super().__init__(**special_token_ids)
        # Build order (table, encoder, decoder) fixes parameter registration
        # order and RNG consumption; keep it stable for checkpoint compatibility.
        shared_table = nn.Embedding(len(dictionary), hidden_size)
        self.embeddings = shared_table
        self.encoder = Encoder(shared_table, hidden_size)
        self.decoder = Decoder(shared_table, hidden_size)

    def output(self, decoder_output):
        """Project decoder outputs to vocabulary logits using the tied embedding weights."""
        return F.linear(decoder_output, self.embeddings.weight)

    @staticmethod
    def _reorder_lstm_state(lstm_state, indices):
        """Select/reorder the batch entries of an LSTM ``(h, c)`` pair.

        ``nn.LSTM`` lays ``h`` and ``c`` out with the batch on dim 1 even when
        ``batch_first=True``, hence the indexing on the second axis.
        """
        hidden, cell = lstm_state
        return hidden[:, indices], cell[:, indices]

    def reorder_encoder_states(self, encoder_states, indices):
        """Reorder the encoder's final ``(h, c)`` along the batch axis."""
        return self._reorder_lstm_state(encoder_states, indices)

    def reorder_decoder_incremental_state(self, incr_state, indices):
        """Reorder the decoder's incremental ``(h, c)`` along the batch axis."""
        return self._reorder_lstm_state(incr_state, indices)

@register_agent("seq2seq_milestone2")
class S2S_MileStone2(tga.TorchGeneratorAgent):
    """ParlAI generator agent wrapping the LSTM ``EncoderDecoder`` model.

    Registered as ``seq2seq_milestone2`` so scripts can select it via
    ``model='seq2seq_milestone2'``.
    """

    # Width of the shared embedding table and of both LSTMs.
    HIDDEN_SIZE = 512

    def build_model(self):
        """Build the model and optionally seed its embeddings with pretrained vectors.

        Returns:
            A freshly constructed ``EncoderDecoder``.
        """
        model = EncoderDecoder(self.dict, self.HIDDEN_SIZE)
        embedding_type = self.opt.get('embedding_type', 'random')
        # NOTE(review): ParlAI's default embedding_type is 'random', which has no
        # pretrained vector set for `_copy_embeddings` to load (ParlAI's own
        # seq2seq tutorial guards the call the same way). Only copy for named
        # pretrained types and otherwise keep nn.Embedding's random init.
        if embedding_type != 'random':
            self._copy_embeddings(model.embeddings.weight, embedding_type)
        return model

In [None]:
# Train the agent registered above on PersonaChat. The checkpoint path below is
# the same one the display cell later loads from.
TrainModel.main(
    # model & checkpoint
    model='seq2seq_milestone2',
    model_file='seq2seq_milestone2/model',
    embedding_type='fasttext_cc',  # pretrained vectors copied in build_model
    # data
    task='personachat:both_revised',
    truncate=256,
    # batching & time budget
    batchsize=7,
    dynamic_batching='full',
    max_train_time=3600,  # ParlAI train_model option, in seconds (~1 hour)
    # evaluation
    metrics='ppl,bleu,rouge',
)

In [None]:
# Print sample conversations produced by the trained checkpoint on the same
# task it was trained on.
DisplayModel.main(
    model_file='seq2seq_milestone2/model',
    task='personachat:both_revised',
    num_examples=35,
)