In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset

import pytorch_lightning as pl

In [None]:
# Baseline hyper-parameters read by LSTM_LM; any key can be overridden
# through the `config` argument of its constructor.
DEFAULT_LSTM_CONFIG = dict(
    vocab_size=32008,
    embedding_size=128,
    embedding_factor_size=300,
    hidden_size=1024,
    n_layers=3,
)

# Baseline hyper-parameters read by LMDiscriminatorHead. The head multiplies
# `encoder_hidden_size` by two when sizing its first linear layer.
DEFAULT_DISCRIMINATOR_CONFIG = dict(
    encoder_hidden_size=1024,
    hidden_size=512,
)

# Baseline hyper-parameters read by LMGeneratorHead. The head multiplies
# `encoder_hidden_size` by two when sizing its first linear layer, and its
# last layer emits `max_sequence_length` values.
DEFAULT_GENERATOR_CONFIG = dict(
    encoder_hidden_size=768,
    hidden_size=256,
    max_sequence_length=256,
)

In [None]:
class LSTM_LM(nn.Module):
    """Bidirectional LSTM encoder that max-pools its outputs over time.

    Pipeline: token ids -> embedding -> linear projection -> bidirectional
    LSTM -> element-wise max over the sequence axis.

    Args:
        config: Optional overrides applied on top of ``DEFAULT_LSTM_CONFIG``.
            Keys read: ``vocab_size``, ``embedding_size``,
            ``embedding_factor_size``, ``hidden_size`` and ``n_layers``.
    """

    def __init__(self, config=None):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise the first `self.<layer> = ...` raises AttributeError.
        super().__init__()

        # Copy the defaults so neither the module-level dict nor the caller's
        # dict is mutated; instances must not leak overrides into each other.
        self.config = dict(DEFAULT_LSTM_CONFIG)
        if config:
            self.config.update(config)

        self.vocab_size = self.config['vocab_size']
        self.embedding_size = self.config['embedding_size']
        self.embedding_factor_size = self.config['embedding_factor_size']
        self.hidden_size = self.config['hidden_size']
        self.n_layers = self.config['n_layers']

        self.embedding = nn.Embedding(self.vocab_size, self.embedding_size)
        self.embedding_linear = nn.Linear(self.embedding_size, self.embedding_factor_size)
        self.rnn = nn.LSTM(
            self.embedding_factor_size,
            self.hidden_size,
            num_layers=self.n_layers,  # nn.LSTM's keyword is `num_layers`, not `n_layers`
            batch_first=True,
            bidirectional=True,
        )

    def forward(self, tokens):
        """Encode token ids into one fixed-size vector per sequence.

        Args:
            tokens: Integer tensor of token ids shaped ``(batch, seq_len)``,
                or ``(seq_len,)`` for a single un-batched sequence.

        Returns:
            Tensor shaped ``(batch, 2 * hidden_size)`` (the factor two comes
            from the two LSTM directions), or ``(2 * hidden_size,)`` when the
            input was un-batched.
        """
        single_sequence = tokens.dim() == 1
        if single_sequence:
            # Add a batch axis so the pooling axis below is always dim 1.
            tokens = tokens.unsqueeze(0)

        x = self.embedding(tokens)
        x = self.embedding_linear(x)
        x, _ = self.rnn(x)  # (batch, seq_len, 2 * hidden_size) because batch_first=True

        # Exact-zero activations are pushed to a large negative value so they
        # cannot win the max below. Done out of place so autograd never sees a
        # modified LSTM output.
        # NOTE(review): zeros only mark padding if padded steps are zeroed
        # upstream (e.g. via pad_packed_sequence) -- confirm with the caller.
        x = x.masked_fill(x == 0, -1e9)

        # Max-pool over the sequence axis (dim 1 with batch_first=True).
        x = x.max(dim=1)[0]

        return x.squeeze(0) if single_sequence else x

In [None]:
class LMDiscriminatorHead(nn.Module):
    """Head that maps encoder features to a single output value.

    The first layer expects ``2 * encoder_hidden_size`` input features.

    Args:
        config: Optional overrides applied on top of
            ``DEFAULT_DISCRIMINATOR_CONFIG``. Keys read:
            ``encoder_hidden_size`` and ``hidden_size``.
    """

    def __init__(self, config=None):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise `self.classifier = ...` raises AttributeError.
        super().__init__()

        # Copy the defaults so neither the module-level dict nor the caller's
        # dict is mutated by per-instance overrides.
        self.config = dict(DEFAULT_DISCRIMINATOR_CONFIG)
        if config:
            self.config.update(config)

        self.encoder_hidden_size = self.config['encoder_hidden_size']
        self.hidden_size = self.config['hidden_size']

        # NOTE(review): there is no activation between these Linear layers, so
        # the stack is mathematically a single affine map -- confirm whether
        # non-linearities were intended.
        self.classifier = nn.Sequential(
            nn.Linear(self.encoder_hidden_size * 2, self.hidden_size),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Linear(self.hidden_size, 1)
        )

    def forward(self, x):
        """Apply the classifier stack.

        Args:
            x: Tensor whose last dimension is ``2 * encoder_hidden_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        return self.classifier(x)

In [None]:
class LMGeneratorHead(nn.Module):
    """Head that maps encoder features to ``max_sequence_length`` outputs.

    The first layer expects ``2 * encoder_hidden_size`` input features.

    Args:
        config: Optional overrides applied on top of
            ``DEFAULT_GENERATOR_CONFIG``. Keys read: ``encoder_hidden_size``,
            ``hidden_size`` and ``max_sequence_length``.
    """

    def __init__(self, config=None):
        # nn.Module must be initialised before any sub-module is assigned,
        # otherwise `self.classifier = ...` raises AttributeError.
        super().__init__()

        # Copy the defaults so neither the module-level dict nor the caller's
        # dict is mutated by per-instance overrides.
        self.config = dict(DEFAULT_GENERATOR_CONFIG)
        if config:
            self.config.update(config)

        self.encoder_hidden_size = self.config['encoder_hidden_size']
        self.hidden_size = self.config['hidden_size']
        self.max_sequence_length = self.config['max_sequence_length']

        # NOTE(review): there is no activation between these Linear layers, so
        # the stack is mathematically a single affine map -- confirm whether
        # non-linearities were intended.
        self.classifier = nn.Sequential(
            nn.Linear(self.encoder_hidden_size * 2, self.hidden_size),
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.Linear(self.hidden_size, self.max_sequence_length)
        )

    def forward(self, x):
        """Apply the classifier stack.

        Args:
            x: Tensor whose last dimension is ``2 * encoder_hidden_size``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``max_sequence_length``.
        """
        return self.classifier(x)