In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset, Dataset

from datasets import load_dataset

import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

from nltk import word_tokenize
import numpy as np
import random
import pickle
import os
import errno

In [2]:
# constants
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Default mini-batch size (data loaders, RelGAN.init_hidden, generator sampling).
BATCH_SIZE = 64
# Sequence length used when encoding sentences and when sampling from the generator.
MAX_SEQ_LEN = 37
# paths
# Text corpus that load_file splits into one sentence per line.
COCO_DATASET = '../../TextGAN/dataset/image_coco.txt'
# Jigsaw normal/toxic training files; not referenced by the code visible in this notebook.
JIGSAW_DATASET = ['../../detox/emnlp2021/data/train/train_normal',
                  '../../detox/emnlp2021/data/train/train_toxic']

### First reproduce RelGAN, then adapt it to the Jigsaw toxic data

In [3]:
# build vocabulary
def load_file(path):
    '''
    Read a text file and return its lines.

    :params path: path of a UTF-8 text file holding one sentence per line
    :return: list of lines without their newline characters; the empty entry
        caused by a newline at the very end of the file is dropped, blank
        lines inside the file are kept
    '''
    # explicit encoding so the result does not depend on the platform default
    with open(path, 'r', encoding='utf-8') as r:
        lines = r.read().split('\n')

    # a file ending with '\n' would otherwise add a spurious empty sentence,
    # which is later encoded as an all-<PAD> training sample
    if lines and lines[-1] == '':
        lines.pop()

    return lines

def build_vocab(sentences):
    '''
    Build the id <-> token lookup tables from raw sentences.

    :params sentences: list of merged sentences
    :return: tuple (vocab_index_word, vocab_word_index):
        id -> token and token -> id, in that order
    '''
    # tokenize sentences, keeping only the distinct tokens
    tokens = set()
    for s in sentences:
        tokens.update(word_tokenize(s))

    # ids 0..2 are reserved for the special tokens
    vocab_index_word = {0: '<PAD>', 1: '<BOS>', 2: '<EOS>'}
    # sort before numbering: plain set iteration order changes between Python
    # sessions (string hash randomisation), which would make token ids -- and
    # any checkpoint trained on them -- irreproducible
    vocab_index_word.update(enumerate(sorted(tokens), 3))

    # map back token to id
    vocab_word_index = {w: i for i, w in vocab_index_word.items()}

    return vocab_index_word, vocab_word_index

def encode(sentences, vocab, padding=True, max_seq_len='max'):
    '''
    Tokenize sentences and map every token to its vocabulary id.

    :params sentences: list of raw sentences
    :params vocab: token -> id mapping (second value returned by build_vocab)
    :params padding: if True, return a LongTensor of size
        len(sentences) * seq_len, right-padded with 0 (<PAD>) and truncated to
        seq_len; if False, return a list of encoded token lists
    :params max_seq_len: target length, or 'max' to use the longest
        tokenized sentence
    :raises KeyError: if a token is missing from vocab
    '''
    # encode sentences
    encoded = [[vocab[token] for token in word_tokenize(s)] for s in sentences]

    if not padding:
        return encoded

    if max_seq_len == 'max':
        # measure in tokens (not characters); default=0 tolerates an empty list
        seq_len = max((len(row) for row in encoded), default=0)
    else:
        seq_len = max_seq_len

    # create empty tensor with size len(sentences) * seq_len
    features = torch.zeros((len(encoded), seq_len)).long()
    # now fill tensor, truncating rows that are longer than seq_len
    for i, row in enumerate(encoded):
        row = row[:seq_len]
        if row:
            features[i, :len(row)] = torch.tensor(row)

    return features

def prepare_generator_dataset(samples):
    '''
    Build next-token prediction pairs for the generator.

    :params samples: 2-D tensor of token ids, one sequence per row
    :return: (inp, target) where target is `samples` itself and inp is a new
        long tensor holding each row shifted one step to the right with
        id 1 (<BOS>) in the first column
    '''
    target = samples
    # rotate every row by one position; the wrapped-around last token lands in
    # column 0 and is immediately replaced by the <BOS> id
    inp = torch.roll(target, shifts=1, dims=1).long()
    inp[:, 0] = 1

    return inp, target

def prepare_discriminator_dataset(samples_pos, samples_neg):
    '''
    Merge positive and negative samples into one shuffled, labelled set.

    :params samples_pos: tensor of positive sequences, one per row
    :params samples_neg: tensor of negative sequences, one per row
    :return: (inp, target): inp is the detached long tensor of all rows in
        random order, target the matching float labels
        (1 = positive, 0 = negative)
    '''
    inp = torch.cat([samples_pos, samples_neg], dim=0).long().detach()

    # positives come first in the concatenation and get label 1; the previous
    # version labelled them 0, i.e. the classes were inverted
    target = torch.zeros(len(inp)).float()
    target[:len(samples_pos)] = 1

    # shuffle rows and labels with the same permutation
    perm = torch.randperm(len(inp))
    inp = inp[perm]
    target = target[perm]

    return inp, target

def _encode_file(path):
    '''Load a text file, build its vocabulary and return the padded id tensor.'''
    sentences = load_file(path)
    _, word_index = build_vocab(sentences)
    return encode(sentences, word_index, max_seq_len=MAX_SEQ_LEN)


def read_data(samples, samples_neg=None):
    '''
    Turn samples into a list of {'input', 'target'} training records.

    :params samples: path of a text file, or an already encoded tensor
    :params samples_neg: None for generator data; a path or an encoded tensor
        of negative samples for discriminator data
    :return: list of dicts with keys 'input' and 'target', one per sample

    Note: each path is encoded with a vocabulary built from that file alone,
    so ids are only comparable when both paths hold the same token set.
    '''
    # load the file that was actually requested (previously COCO_DATASET was
    # always read, whatever path was passed in)
    if isinstance(samples, str):
        samples = _encode_file(samples)

    # identity test: `== None` is ambiguous for tensor / array arguments
    if samples_neg is None:
        inp, tgt = prepare_generator_dataset(samples)
    else:
        if isinstance(samples_neg, str):
            samples_neg = _encode_file(samples_neg)

        inp, tgt = prepare_discriminator_dataset(samples, samples_neg)

    all_data = [{'input': i, 'target': t} for i, t in zip(inp, tgt)]

    return all_data


In [4]:
# build_vocab returns (id -> token, token -> id) in that order; the names were
# previously swapped, leaving vocab_w2i bound to the id -> token table
vocab_i2w, vocab_w2i = build_vocab(load_file(COCO_DATASET))
len(vocab_w2i)

4659

In [5]:
# dataloaders
class GANDataset(Dataset):
    """Map-style dataset serving pre-built samples from an in-memory sequence."""

    def __init__(self, data):
        # kept by reference; indexing and len() are delegated to it
        self.data = data

    def __len__(self):
        """Number of stored samples."""
        size = len(self.data)
        return size

    def __getitem__(self, index):
        """Return the sample stored at `index`."""
        sample = self.data[index]
        return sample

def get_dataloader(samples, samples_neg=None, batch_size=BATCH_SIZE, shuffle=True):
    '''
    Wrap the records produced by read_data in a DataLoader.

    :params samples: path or encoded tensor of (positive) samples
    :params samples_neg: None for a generator loader, otherwise the path or
        encoded tensor of negative samples for a discriminator loader
    :params batch_size: number of records per batch
    :params shuffle: whether the loader reshuffles on every pass
    :return: DataLoader yielding dicts with 'input' and 'target' batches
    '''
    # read_data already switches on samples_neg being None, so one call covers
    # both cases and no `== None` test (ambiguous for array-like arguments) is
    # needed here
    dataset = GANDataset(read_data(samples, samples_neg))
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

def get_random_batch(dataloader):
    '''
    Draw one random batch from the dataset behind `dataloader`.

    Samples are picked without replacement directly from dataloader.dataset
    and merged with the loader's own collate_fn. This avoids iterating and
    materialising the whole loader on every call, and always yields a full
    batch (never the shorter last batch of an epoch) as long as the dataset
    holds at least batch_size samples.

    :params dataloader: DataLoader over a map-style dataset
    :return: one collated batch, in the same format the loader yields
    :raises ValueError: if the dataset is empty
    '''
    dataset = dataloader.dataset
    total = len(dataset)
    if total == 0:
        raise ValueError('cannot draw a random batch from an empty dataloader')

    # batch_size is None when the loader was built with a custom batch_sampler
    batch_size = dataloader.batch_size or 1
    picks = random.sample(range(total), min(batch_size, total))
    return dataloader.collate_fn([dataset[i] for i in picks])

In [6]:
# Generator loader: next-token (input, target) pairs built from the COCO corpus.
genloader = get_dataloader(COCO_DATASET)
# Discriminator loader. NOTE(review): the same corpus is passed as both the
# positive and the negative set, and disloader is not used by the training
# cells visible in this notebook -- presumably a placeholder; confirm.
disloader = get_dataloader(COCO_DATASET, COCO_DATASET)

In [7]:
# network architecture (relational memory)
class RelationalMemory(nn.Module):
    """
    Constructs a `RelationalMemory` object.
    This class is same as the RMC from relational_rnn_models.py, but without language modeling-specific variables.
    Args:
      mem_slots: The total number of memory slots to use.
      head_size: The size of an attention head.
      input_size: The size of input per step. i.e. the dimension of each input vector
      num_heads: The number of attention heads to use. Defaults to 1.
      num_blocks: Number of times to compute attention per time step. Defaults
        to 1.
      forget_bias: Bias to use for the forget gate, assuming we are using
        some form of gating. Defaults to 1.
      input_bias: Bias to use for the input gate, assuming we are using
        some form of gating. Defaults to 0.
      gate_style: Whether to use per-element gating ('unit'),
        per-memory slot gating ('memory'), or no gating at all (None).
        Defaults to `unit`.
      attention_mlp_layers: Number of layers to use in the post-attention
        MLP. Defaults to 2.
      key_size: Size of vector to use for key & query vectors in the attention
        computation. Defaults to None, in which case we use `head_size`.
      return_all_outputs: Whether the model returns outputs for each step (like seq2seq) or only the final output.
    Raises:
      ValueError: gate_style not one of [None, 'memory', 'unit'].
      ValueError: num_blocks is < 1.
      ValueError: attention_mlp_layers is < 1.
    """

    def __init__(self, mem_slots, head_size, input_size, num_heads=1, num_blocks=1, forget_bias=1., input_bias=0.,
                 gate_style='unit', attention_mlp_layers=2, key_size=None, return_all_outputs=False):
        super(RelationalMemory, self).__init__()

        ########## generic parameters for RMC ##########
        self.mem_slots = mem_slots
        self.head_size = head_size
        self.num_heads = num_heads
        self.mem_size = self.head_size * self.num_heads

        # a new fixed params needed for pytorch port of RMC
        # +1 is the concatenated input per time step : we do self-attention with the concatenated memory & input
        # so if the mem_slots = 1, this value is 2
        self.mem_slots_plus_input = self.mem_slots + 1

        if num_blocks < 1:
            raise ValueError('num_blocks must be >=1. Got: {}.'.format(num_blocks))
        self.num_blocks = num_blocks

        if gate_style not in ['unit', 'memory', None]:
            raise ValueError(
                'gate_style must be one of [\'unit\', \'memory\', None]. got: '
                '{}.'.format(gate_style))
        self.gate_style = gate_style

        if attention_mlp_layers < 1:
            raise ValueError('attention_mlp_layers must be >= 1. Got: {}.'.format(
                attention_mlp_layers))
        self.attention_mlp_layers = attention_mlp_layers

        self.key_size = key_size if key_size else self.head_size

        ########## parameters for multihead attention ##########
        # value_size is same as head_size
        self.value_size = self.head_size
        # total size for query-key-value
        self.qkv_size = 2 * self.key_size + self.value_size
        self.total_qkv_size = self.qkv_size * self.num_heads  # denoted as F

        # one big projection for all heads is more efficient than one per head
        self.qkv_projector = nn.Linear(self.mem_size, self.total_qkv_size)
        self.qkv_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.total_qkv_size])

        # used for attend_over_memory function
        # build a separate Linear per layer: `[layer] * n` would repeat one
        # instance, so every MLP layer would share the same weights
        self.attention_mlp = nn.ModuleList(
            [nn.Linear(self.mem_size, self.mem_size) for _ in range(self.attention_mlp_layers)])
        self.attended_memory_layernorm = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size])
        self.attended_memory_layernorm2 = nn.LayerNorm([self.mem_slots_plus_input, self.mem_size])

        ########## parameters for initial embedded input projection ##########
        self.input_size = input_size
        self.input_projector = nn.Linear(self.input_size, self.mem_size)

        ########## parameters for gating ##########
        # input and forget gates are produced together, hence twice the gate size
        self.num_gates = 2 * self.calculate_gate_size()
        self.input_gate_projector = nn.Linear(self.mem_size, self.num_gates)
        self.memory_gate_projector = nn.Linear(self.mem_size, self.num_gates)
        # trainable scalar gate bias tensors
        self.forget_bias = nn.Parameter(torch.tensor(forget_bias, dtype=torch.float32))
        self.input_bias = nn.Parameter(torch.tensor(input_bias, dtype=torch.float32))

        ########## number of outputs returned #####
        self.return_all_outputs = return_all_outputs

    def repackage_hidden(self, h):
        """Wraps hidden states in new Tensors, to detach them from their history."""
        # needed for truncated BPTT, called at every batch forward pass
        if isinstance(h, torch.Tensor):
            return h.detach()
        else:
            return tuple(self.repackage_hidden(v) for v in h)

    def initial_state(self, batch_size, trainable=False):
        """
        Creates the initial memory.
        We should ensure each row of the memory is initialized to be unique,
        so initialize the matrix to be the identity. We then pad or truncate
        as necessary so that init_state is of size
        (batch_size, self.mem_slots, self.mem_size).
        Args:
          batch_size: The size of the batch.
          trainable: Accepted for interface compatibility; not used by this
            implementation (the returned state is a plain tensor).
        Returns:
          init_state: A truncated or padded matrix of size
            (batch_size, self.mem_slots, self.mem_size), created on the
            default (CPU) device.
        """
        init_state = torch.stack([torch.eye(self.mem_slots) for _ in range(batch_size)])

        # pad the matrix with zeros
        if self.mem_size > self.mem_slots:
            difference = self.mem_size - self.mem_slots
            pad = torch.zeros((batch_size, self.mem_slots, difference))
            init_state = torch.cat([init_state, pad], -1)

        # truncation. take the first 'self.mem_size' components
        elif self.mem_size < self.mem_slots:
            init_state = init_state[:, :, :self.mem_size]

        return init_state

    def multihead_attention(self, memory):
        """
        Perform multi-head attention from 'Attention is All You Need'.
        Implementation of the attention mechanism from
        https://arxiv.org/abs/1706.03762.
        Args:
          memory: Memory tensor to perform attention on.
        Returns:
          new_memory: New memory tensor.
        """

        # First, a simple linear projection is used to construct queries
        qkv = self.qkv_projector(memory)
        # apply layernorm for every dim except the batch dim
        qkv = self.qkv_layernorm(qkv)

        # mem_slots needs to be dynamically computed since mem_slots got concatenated with inputs
        # this is the same as self.mem_slots_plus_input, but defined to keep the sonnet implementation code style
        mem_slots = memory.shape[1]  # denoted as N

        # split the qkv to multiple heads H
        # [B, N, F] => [B, N, H, F/H]
        qkv_reshape = qkv.view(qkv.shape[0], mem_slots, self.num_heads, self.qkv_size)

        # [B, N, H, F/H] => [B, H, N, F/H]
        qkv_transpose = qkv_reshape.permute(0, 2, 1, 3)

        # [B, H, N, key_size], [B, H, N, key_size], [B, H, N, value_size]
        q, k, v = torch.split(qkv_transpose, [self.key_size, self.key_size, self.value_size], -1)

        # scale q with d_k, the dimensionality of the key vectors
        q = q * (self.key_size ** -0.5)

        # make it [B, H, N, N]
        dot_product = torch.matmul(q, k.permute(0, 1, 3, 2))
        weights = F.softmax(dot_product, dim=-1)

        # output is [B, H, N, V]
        output = torch.matmul(weights, v)

        # [B, H, N, V] => [B, N, H, V] => [B, N, H*V]
        output_transpose = output.permute(0, 2, 1, 3).contiguous()
        new_memory = output_transpose.view((output_transpose.shape[0], output_transpose.shape[1], -1))

        return new_memory

    @property
    def state_size(self):
        """Shape of one memory state without the batch dimension."""
        return [self.mem_slots, self.mem_size]

    @property
    def output_size(self):
        """Size of the flattened per-step output vector."""
        return self.mem_slots * self.mem_size

    def calculate_gate_size(self):
        """
        Calculate the gate size from the gate_style.
        Returns:
          The per sample, per head parameter size of each gate.
        """
        if self.gate_style == 'unit':
            return self.mem_size
        elif self.gate_style == 'memory':
            return 1
        else:  # self.gate_style == None
            return 0

    def create_gates(self, inputs, memory):
        """
        Create input and forget gates for this step using `inputs` and `memory`.
        Args:
          inputs: Tensor input of shape (batch, 1, features).
          memory: The current state of memory.
        Returns:
          input_gate: A LSTM-like insert gate.
          forget_gate: A LSTM-like forget gate.
        Raises:
          ValueError: inputs is not 3-D or its sequence length is not 1.
        """
        # We'll create the input and forget gates at once. Hence, calculate double
        # the gate size.

        # equation 8: since there is no output gate, h is just a tanh'ed m
        memory = torch.tanh(memory)

        # sonnet flattens the inputs unconditionally, which is only valid for a
        # single time step; reject longer sequences explicitly instead
        if len(inputs.shape) == 3:
            if inputs.shape[1] > 1:
                raise ValueError(
                    "input seq length is larger than 1. create_gate function is meant to be called for each step, with input seq length of 1")
            inputs = inputs.view(inputs.shape[0], -1)
            # matmul for equation 4 and 5
            # there is no output gate, so equation 6 is not implemented
            gate_inputs = self.input_gate_projector(inputs)
            gate_inputs = gate_inputs.unsqueeze(dim=1)
            gate_memory = self.memory_gate_projector(memory)
        else:
            raise ValueError("input shape of create_gate function is 2, expects 3")

        # this completes the equation 4 and 5
        gates = gate_memory + gate_inputs
        gates = torch.split(gates, split_size_or_sections=int(gates.shape[2] / 2), dim=2)
        input_gate, forget_gate = gates
        assert input_gate.shape[2] == forget_gate.shape[2]

        # to be used for equation 7
        input_gate = torch.sigmoid(input_gate + self.input_bias)
        forget_gate = torch.sigmoid(forget_gate + self.forget_bias)

        return input_gate, forget_gate

    def attend_over_memory(self, memory):
        """
        Perform multiheaded attention over `memory`.
            Args:
              memory: Current relational memory.
            Returns:
              The attended-over memory.
        """
        for _ in range(self.num_blocks):
            attended_memory = self.multihead_attention(memory)

            # Add a skip connection to the multiheaded attention's input.
            memory = self.attended_memory_layernorm(memory + attended_memory)

            # add a skip connection to the attention_mlp's input.
            attention_mlp = memory
            for layer in self.attention_mlp:
                attention_mlp = F.relu(layer(attention_mlp))
            memory = self.attended_memory_layernorm2(memory + attention_mlp)

        return memory

    def forward_step(self, inputs, memory, treat_input_as_matrix=False):
        """
        Forward step of the relational memory core.
        Args:
          inputs: Tensor input.
          memory: Memory output from the previous time step.
          treat_input_as_matrix: Optional, whether to treat `input` as a sequence
            of matrices. Default to False, in which case the input is flattened
            into a vector.
        Returns:
          output: This time step's output.
          next_memory: The next version of memory to use.
        """

        if treat_input_as_matrix:
            # keep (Batch, Seq, ...) dim (0, 1), flatten starting from dim 2
            inputs = inputs.view(inputs.shape[0], inputs.shape[1], -1)
            # apply linear layer for dim 2
            inputs_reshape = self.input_projector(inputs)
        else:
            # keep (Batch, ...) dim (0), flatten starting from dim 1
            inputs = inputs.view(inputs.shape[0], -1)
            # apply linear layer for dim 1
            inputs = self.input_projector(inputs)
            # unsqueeze the time step to dim 1
            inputs_reshape = inputs.unsqueeze(dim=1)

        memory_plus_input = torch.cat([memory, inputs_reshape], dim=1)
        next_memory = self.attend_over_memory(memory_plus_input)

        # cut out the concatenated input vectors from the original memory slots
        n = inputs_reshape.shape[1]
        next_memory = next_memory[:, :-n, :]

        if self.gate_style == 'unit' or self.gate_style == 'memory':
            # these gates are sigmoid-applied ones for equation 7
            input_gate, forget_gate = self.create_gates(inputs_reshape, memory)
            # equation 7 calculation
            next_memory = input_gate * torch.tanh(next_memory)
            next_memory += forget_gate * memory

        output = next_memory.view(next_memory.shape[0], -1)

        return output, next_memory

    def forward(self, inputs, memory, treat_input_as_matrix=False):
        """
        Run the core over a whole batch-first sequence, one step at a time.
        Args:
          inputs: Tensor of shape (batch, seq_len, input_size).
          memory: Initial memory, (batch, mem_slots, mem_size).
          treat_input_as_matrix: Accepted for interface compatibility; each
            step is always fed to forward_step as a flat vector.
        Returns:
          (outputs, memory): outputs is (batch, seq_len, mem_slots * mem_size)
          when return_all_outputs is set, otherwise only the last step with
          shape (batch, 1, mem_slots * mem_size); memory is the final memory.
        """
        # the memory is NOT detached here; callers that want truncated BPTT
        # should pass it through repackage_hidden first
        logits = []
        # shape[1] is seq_length T
        for idx_step in range(inputs.shape[1]):
            logit, memory = self.forward_step(inputs[:, idx_step], memory)
            logits.append(logit.unsqueeze(1))
        logits = torch.cat(logits, dim=1)

        if self.return_all_outputs:
            return logits, memory
        else:
            return logit.unsqueeze(1), memory

# ########## DEBUG: unit test code ##########
# input_size = 32
# seq_length = 20
# batch_size = 32
# num_tokens = 5000
# model = RelationalMemory(mem_slots=1, head_size=512, input_size=input_size, num_heads=2)
# model_memory = model.initial_state(batch_size=batch_size)
#
# # random input
# random_input = torch.randn((32, seq_length, input_size))
# # random targets
# random_targets = torch.randn((32, seq_length, input_size))
#
# # take a one step forward
# logit, next_memory = model(random_input, model_memory)
# print(next_memory.shape)
# print(logit.shape)


In [8]:
# network architecture (LSTM/RMC Generator)
class RelGAN(nn.Module):
    """
    RelGAN generator: token embedding -> relational memory core -> vocabulary
    projection, with Gumbel noise added to the logits when sampling.

    :param mem_slots: number of memory slots of the relational memory
    :param num_heads: number of attention heads
    :param head_size: size of each attention head
    :param embedding_dim: token embedding size (input size of the memory core)
    :param hidden_dim: kept for interface compatibility; the effective hidden
        size is always mem_slots * num_heads * head_size
    :param vocab_size: number of tokens in the vocabulary
    :param max_seq_len: length of the sequences produced by `sample`
    :param padding_idx: id of the padding token (passed to nn.Embedding)
    :param device: device the module is moved to on construction
    """

    def __init__(self, mem_slots, num_heads, head_size, embedding_dim, hidden_dim, vocab_size, max_seq_len, padding_idx, device=DEVICE):
        super(RelGAN, self).__init__()

        # ---------------- #
        # model properties #
        # ---------------- #
        self.name = 'RelGAN'
        self.vocab_size = vocab_size
        self.embed_size = embedding_dim
        self.max_seq_len = max_seq_len
        self.padding_idx = padding_idx
        self.mem_slots = mem_slots
        self.num_heads = num_heads
        self.head_size = head_size
        self.temperature = 1.0
        self.theta = None
        self.device = device
        # size of the flattened memory output; this replaces hidden_dim
        self.hidden_size = mem_slots * num_heads * head_size

        # ------------ #
        # model layers #
        # ------------ #
        # the attribute is still called `lstm`, but it holds the relational memory core
        self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.lstm = RelationalMemory(mem_slots=mem_slots, head_size=head_size, input_size=embedding_dim,
                                     num_heads=num_heads, return_all_outputs=True)
        self.lstm2out = nn.Linear(self.hidden_size, self.vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

        self.to(self.device)


    def forward(self, x, h, need_hidden=False):
        '''
        Teacher-forced forward pass.
        :param x: token ids, [batch_size, seq_len] (or [batch_size] for one step)
        :param h: memory of the relational memory core
        :param need_hidden: also return the final memory
        :return: log-probabilities of size (batch_size * seq_len) * vocab_size,
            plus the final memory when need_hidden is True
        '''
        emb = self.embeddings(x)
        if len(x.size()) == 1:
            emb = emb.unsqueeze(1)  # batch_size * 1 * embedding_dim
        out, hidden = self.lstm(emb, h)  # out: batch_size * seq_len * hidden_size
        out = out.contiguous().view(-1, self.hidden_size)  # (batch_size * seq_len) * hidden_size
        out = self.lstm2out(out)  # (batch_size * seq_len) * vocab_size
        out = self.temperature * out  # temperature
        pred = self.softmax(out)

        if need_hidden:
            return pred, hidden
        else:
            return pred


    def step(self, x, h):
        '''
        RelGAN step forward
        :param x: [batch_size], current tokens
        :param h: memory of the relational memory core
        :return: pred, hidden, next_token, next_token_onehot, next_o
            - pred: batch_size * vocab_size, use for adversarial training backward
            - hidden: next hidden
            - next_token: [batch_size], next sentence token
            - next_token_onehot: always None, not used yet
            - next_o: always None, not used yet
        '''
        emb = self.embeddings(x).unsqueeze(1)
        out, hidden = self.lstm(emb, h)
        gumbel_t = self.add_gumbel(self.lstm2out(out.squeeze(1)))
        next_token = torch.argmax(gumbel_t, dim=1).detach()
        next_token_onehot = None

        pred = F.softmax(gumbel_t * self.temperature, dim=-1)  # batch_size * vocab_size
        next_o = None

        return pred, hidden, next_token, next_token_onehot, next_o


    def sample(self, num_samples, batch_size, one_hot=False, start_letter=1):
        """
        Sample from RelGAN Generator
        - one_hot: if return pred of RelGAN, used for adversarial training
        :return:
            - all_preds: batch_size * seq_len * vocab_size, only use for a batch
              (holds the predictions of the last generated batch)
            - samples: all samples, num_samples * seq_len
        """
        # ceil division, at least one batch; the previous formula generated one
        # unnecessary extra batch whenever num_samples was an exact multiple of
        # batch_size (other than batch_size itself)
        num_batch = max(1, -(-num_samples // batch_size))
        samples = torch.zeros(num_batch * batch_size, self.max_seq_len, device=self.device).long()
        if one_hot:
            all_preds = torch.zeros(batch_size, self.max_seq_len, self.vocab_size, device=self.device)

        for b in range(num_batch):
            hidden = self.init_hidden(batch_size)
            inp = torch.LongTensor([start_letter] * batch_size).to(self.device)

            for i in range(self.max_seq_len):
                pred, hidden, next_token, _, _ = self.step(inp, hidden)
                samples[b * batch_size:(b + 1) * batch_size, i] = next_token
                if one_hot:
                    all_preds[:, i] = pred
                inp = next_token
        samples = samples[:num_samples]  # num_samples * seq_len

        if one_hot:
            return all_preds  # batch_size * seq_len * vocab_size
        return samples


    def init_hidden(self, batch_size=BATCH_SIZE):
        """init RMC memory: fresh, detached initial state on the model's device"""
        memory = self.lstm.initial_state(batch_size)
        memory = self.lstm.repackage_hidden(memory)  # detach memory at first
        return memory.to(self.device)


    def add_gumbel(self, o_t, eps=1e-10):
        """Add o_t by a vector sampled from Gumbel(0,1)"""
        # allocate the noise next to the logits so the addition can never mix devices
        u = torch.zeros(o_t.size(), device=o_t.device)

        u.uniform_(0, 1)
        # eps keeps both logarithms away from log(0)
        g_t = -torch.log(-torch.log(u + eps) + eps)
        gumbel_t = o_t + g_t
        return gumbel_t

In [9]:
# Generator (RelGAN) hyper-parameters.
MEM_SLOTS = 1        # number of relational-memory slots
NUM_HEADS = 2        # attention heads
HEAD_SIZE = 256      # size of each attention head
GEN_EMBED_DIM = 32   # token embedding size
# Passed as hidden_dim, but RelGAN overwrites its hidden size with
# mem_slots * num_heads * head_size, so this value has no effect.
GEN_HIDDEN_DIM = 32

In [10]:
# Build the generator; every argument is spelled out by keyword.
gen = RelGAN(
    mem_slots=MEM_SLOTS,
    num_heads=NUM_HEADS,
    head_size=HEAD_SIZE,
    embedding_dim=GEN_EMBED_DIM,
    hidden_dim=GEN_HIDDEN_DIM,
    vocab_size=len(vocab_w2i),
    max_seq_len=MAX_SEQ_LEN,
    padding_idx=0,
)

In [17]:
# network architecture discriminator
class CNNDiscriminator(nn.Module):
    """
    Convolutional discriminator over vocabulary-sized token vectors.

    A bias-free linear layer embeds each token vector, several banks of 2-D
    convolutions (one per filter size) slide over the sequence, each bank is
    max-pooled over time, and a highway layer followed by two linear layers
    produces the logits.
    """

    def __init__(self, embed_dim, max_seq_len, num_rep, vocab_size, padding_idx,
                 filter_sizes, num_filters, dropout=0.25, device=DEVICE):
        super().__init__()

        # ---------------- #
        # model properties #
        # ---------------- #
        self.name = 'CNNDiscriminator'
        self.embed_size = embed_dim
        self.max_seq_len = max_seq_len
        self.feature_dim = sum(num_filters)
        # width of one embedding slice; also the horizontal stride of every conv
        self.emb_dim_single = int(embed_dim / num_rep)
        self.vocab_size = vocab_size
        self.padding_idx = padding_idx

        # ------------ #
        # model layers #
        # ------------ #
        # a Linear layer (not nn.Embedding) so that soft one-hot inputs work
        self.embeddings = nn.Linear(vocab_size, embed_dim, bias=False)

        slice_width = self.emb_dim_single
        conv_banks = []
        for out_channels, window in zip(num_filters, filter_sizes):
            conv_banks.append(
                nn.Conv2d(1, out_channels, (window, slice_width), stride=(1, slice_width))
            )
        self.convs = nn.ModuleList(conv_banks)

        self.highway = nn.Linear(self.feature_dim, self.feature_dim)
        self.feature2out = nn.Linear(self.feature_dim, 100)
        self.out2logits = nn.Linear(100, 1)
        self.dropout = nn.Dropout(dropout)

        self.to(device)

    def forward(self, x):
        """
        Score a batch of sequences.

        :param x: float tensor whose last dimension is vocab_size
            (presumably batch * seq_len * vocab_size one-hot / soft one-hot)
        :return: 1-D tensor of logits, one per row of the flattened features
        """
        # add the single input channel expected by Conv2d
        emb = self.embeddings(x).unsqueeze(1)

        pooled = []
        for conv in self.convs:
            fmap = F.relu(conv(emb))
            # max over the whole time axis, then drop that axis
            pooled.append(F.max_pool2d(fmap, (fmap.size(2), 1)).squeeze(2))

        features = torch.cat(pooled, 1)
        features = features.permute(0, 2, 1).contiguous().view(-1, self.feature_dim)

        # highway layer: gate * transformed + (1 - gate) * carried
        transformed = self.highway(features)
        gate = torch.sigmoid(transformed)
        features = gate * F.relu(transformed) + (1. - gate) * features

        hidden = self.feature2out(self.dropout(features))
        return self.out2logits(hidden).squeeze(1)

In [18]:
# Discriminator hyper-parameters.
DIS_EMBED_DIM = 64   # output size of the discriminator's embedding layer
# The embedding is split into slices of width DIS_EMBED_DIM / NUM_REP
# (CNNDiscriminator.emb_dim_single), which is the conv kernel width and stride.
NUM_REP = 64

# One bank of convolution filters per entry: filter_sizes[k] is the kernel
# height, num_filters[k] the number of output channels of that bank.
filter_sizes = [2, 3, 4, 5]
num_filters = [300, 300, 300, 300]

In [19]:
# <PAD> has id 0 in the vocabulary (id 2 is <EOS>), matching the generator's
# padding_idx; the discriminator currently only stores this value
disc = CNNDiscriminator(DIS_EMBED_DIM, MAX_SEQ_LEN, NUM_REP, len(vocab_w2i), padding_idx=0,
                        filter_sizes=filter_sizes, num_filters=num_filters)

In [14]:
# pretraining generator
# MLE pretraining: teacher-forced next-token prediction; the checkpoint with
# the lowest average epoch loss is kept on disk
gen_criterion = nn.NLLLoss()
gen_optim = Adam(gen.parameters(), lr=0.01)
gen_pretrain_epochs = 150
gen_clip_norm = 5.0
# float('inf') instead of torch.inf, which older torch releases do not provide
gen_loss_min = float('inf')

# make sure the checkpoint directory exists so torch.save cannot fail after a
# full epoch of training
gen_pretrain_path = f'../models/gan/pretrain_{gen.name}.pt'
os.makedirs(os.path.dirname(gen_pretrain_path), exist_ok=True)

gen.to(DEVICE)

epochloop = tqdm(range(gen_pretrain_epochs), position=0, desc='Training...', leave=True)

for e in epochloop:

    gen.train()
    train_loss = 0

    for i, batch in enumerate(genloader):
        feature, target = batch['input'].to(DEVICE), batch['target'].to(DEVICE)
        # use length of feature as batch_size to handle a smaller last batch;
        # init_hidden already returns the memory on the model's device
        hidden = gen.init_hidden(batch_size=len(feature))

        # pred: (batch_size * seq_len) * vocab_size log-probabilities
        pred = gen(feature, hidden)

        loss = gen_criterion(pred, target.view(-1))

        gen_optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(gen.parameters(), gen_clip_norm)
        gen_optim.step()

        train_loss += loss.item()

        epochloop.set_postfix_str(f'Batch: {i+1}/{len(genloader)} | Loss: {train_loss/(i+1):.3f}')

    avg_loss = train_loss / len(genloader)

    print(f'Epoch: {e+1}/{gen_pretrain_epochs} | Loss: {avg_loss}')
    if avg_loss <= gen_loss_min:
        torch.save(gen.state_dict(), gen_pretrain_path)
        gen_loss_min = avg_loss
    else:
        print(f'[WARN] Loss didn\'t improving ({gen_loss_min:.4f} --> {avg_loss:.4f})')

Training...:   0%|          | 0/3 [00:00<?, ?it/s]

Epoch: 1/3 | Loss: 1.7589457589349928
Epoch: 2/3 | Loss: 1.1059965510277232
Epoch: 3/3 | Loss: 0.9907671475106743


In [15]:
# helpers
def get_losses(d_out_real, d_out_fake, loss_type='JS'):
    """
    Compute the generator and discriminator losses of a GAN objective.

    :param d_out_real: discriminator logits for real samples
    :param d_out_fake: discriminator logits for generated samples
    :param loss_type: one of 'standard', 'JS', 'KL', 'hinge', 'rsgan'
    :return: (g_loss, d_loss) scalar tensors
    :raises ValueError: if loss_type is not one of the supported names
    """
    bce_loss = nn.BCEWithLogitsLoss()

    if loss_type in ('standard', 'JS', 'KL'):
        # these three share the discriminator loss and differ only in g_loss
        d_loss_real = bce_loss(d_out_real, torch.ones_like(d_out_real))
        d_loss_fake = bce_loss(d_out_fake, torch.zeros_like(d_out_fake))
        d_loss = d_loss_real + d_loss_fake

        if loss_type == 'standard':
            g_loss = bce_loss(d_out_fake, torch.ones_like(d_out_fake))
        elif loss_type == 'JS':
            g_loss = -d_loss_fake
        else:  # 'KL'
            g_loss = torch.mean(-d_out_fake)

    elif loss_type == 'hinge':
        # F.relu applies the function; nn.ReLU(...) would only construct a
        # module (and the old `1,0` typo passed two arguments to it)
        d_loss_real = torch.mean(F.relu(1.0 - d_out_real))
        d_loss_fake = torch.mean(F.relu(1.0 + d_out_fake))
        d_loss = d_loss_real + d_loss_fake

        g_loss = -torch.mean(d_out_fake)

    elif loss_type == 'rsgan':
        # relativistic standard GAN: requires equally sized real/fake batches
        d_loss = bce_loss(d_out_real - d_out_fake, torch.ones_like(d_out_real))
        g_loss = bce_loss(d_out_fake - d_out_real, torch.ones_like(d_out_fake))

    else:
        raise ValueError(f'Unknown loss_type: {loss_type!r}')

    return g_loss, d_loss


def get_fixed_temperature(temper, i, N, adapt='exp'):
    """
    A function to set up different temperature control policies.

    :param temper: target (maximum) temperature
    :param i: current step, expected in [0, N]
    :param N: total number of steps of the schedule (previously overwritten
        with a hard-coded 5000 inside the function, so the argument was ignored)
    :param adapt: schedule name: 'no', 'lin', 'exp', 'log', 'sigmoid',
        'quad' or 'sqrt'
    :return: temperature for step i
    :raises ValueError: if adapt is not a known schedule name
    """
    if adapt == 'no':
        temper_var_np = 1.0  # no increase, origin: temper
    elif adapt == 'lin':
        temper_var_np = 1 + i / (N - 1) * (temper - 1)  # linear increase
    elif adapt == 'exp':
        temper_var_np = temper ** (i / N)  # exponential increase
    elif adapt == 'log':
        temper_var_np = 1 + (temper - 1) / np.log(N) * np.log(i + 1)  # logarithm increase
    elif adapt == 'sigmoid':
        temper_var_np = (temper - 1) * 1 / (1 + np.exp((N / 2 - i) * 20 / N)) + 1  # sigmoid increase
    elif adapt == 'quad':
        temper_var_np = (temper - 1) / (N - 1) ** 2 * i ** 2 + 1
    elif adapt == 'sqrt':
        temper_var_np = (temper - 1) / np.sqrt(N - 1) * np.sqrt(i) + 1
    else:
        # ValueError is still an Exception, so existing handlers keep working
        raise ValueError("Unknown adapt type!")

    return temper_var_np

In [None]:
# adversarial training
# hyper-parameters of the adversarial phase
adv_train_epoch = 2000
adv_gen_step = 1    # generator updates per epoch
adv_disc_step = 5   # discriminator updates per epoch
gen_adv_optim = Adam(gen.parameters(), lr=1e-4)
disc_adv_optim = Adam(disc.parameters(), lr=1e-4)
adv_clip_norm = 5.0
loss_type = 'rsgan'
adv_vocab_size = len(vocab_w2i)

adv_epochloop = tqdm(range(adv_train_epoch))

for e in adv_epochloop:
    # adv train generator
    adv_gen_loss = 0
    for i in range(adv_gen_step):
        # get random real sample
        real_sample = get_random_batch(genloader)['target'].to(DEVICE)
        # generate exactly as many fakes as there are real samples: the
        # relativistic loss subtracts the two discriminator outputs elementwise
        n_real = len(real_sample)
        gen_sample = gen.sample(n_real, n_real, one_hot=True).to(DEVICE)
        real_sample = F.one_hot(real_sample, adv_vocab_size).float()

        # train
        d_out_real = disc(real_sample)
        d_out_fake = disc(gen_sample)
        g_loss, _ = get_losses(d_out_real, d_out_fake, loss_type=loss_type)

        # optimize
        gen_adv_optim.zero_grad()
        g_loss.backward()
        torch.nn.utils.clip_grad_norm_(gen.parameters(), adv_clip_norm)
        gen_adv_optim.step()  # without this call the generator never changes
        adv_gen_loss += g_loss.item()

        # show the running mean of the losses seen so far in this epoch
        adv_epochloop.set_description(f'Generator step: {i+1}/{adv_gen_step} | g_loss: {(adv_gen_loss/(i+1)):.3f}')

    adv_gen_loss = adv_gen_loss / adv_gen_step if adv_gen_step != 0 else 0

    # adv train discriminator
    adv_disc_loss = 0
    for i in range(adv_disc_step):
        # get random real sample
        real_sample = get_random_batch(genloader)['target'].to(DEVICE)
        n_real = len(real_sample)
        # the generator is not updated in this step, so skip building its graph
        with torch.no_grad():
            gen_sample = gen.sample(n_real, n_real, one_hot=True).to(DEVICE)
        real_sample = F.one_hot(real_sample, adv_vocab_size).float()

        # train
        d_out_real = disc(real_sample)
        d_out_fake = disc(gen_sample)
        _, d_loss = get_losses(d_out_real, d_out_fake, loss_type=loss_type)

        # optimize
        disc_adv_optim.zero_grad()
        d_loss.backward()
        # torch.nn.utils.clip_grad is a module, not a function; use clip_grad_norm_
        torch.nn.utils.clip_grad_norm_(disc.parameters(), adv_clip_norm)
        disc_adv_optim.step()  # without this call the discriminator never changes
        adv_disc_loss += d_loss.item()

        adv_epochloop.set_description(f'Discriminator step: {i+1}/{adv_disc_step} | d_loss: {(adv_disc_loss/(i+1)):.3f}')

    adv_disc_loss = adv_disc_loss / adv_disc_step if adv_disc_step != 0 else 0

    # update generator temperature
    # NOTE(review): with a target temperature of 1 the 'exp' schedule
    # (temper ** (i / N)) stays at 1.0 for every epoch -- confirm this is intended
    gen.temperature = get_fixed_temperature(1, e, adv_train_epoch, adapt='exp')

    print(f'[ADV] Epoch: {e}/{adv_train_epoch} | g_loss: {adv_gen_loss:.4f}, d_loss: {adv_disc_loss:.4f}, temperature: {gen.temperature}')