In [2]:
CUDA_LAUNCH_BLOCKING = "1"

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import IMDB
from torchtext.vocab import build_vocab_from_iterator
from torchtext.data.functional import to_map_style_dataset

import time
import os

import numpy as np

from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer

In [3]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 2
MIN_FREQ = 10

UNK_IDX = 0
BOS_IDX = 1
EOS_IDX = 2
PAD_IDX = 3
SPEC_TOKENS = ['<UNK>', '<BOS>', '<EOS>', '<PAD>']

lr = 1e-3
lr_decay_every = 1000000
epochs = 10
log_interval = 10

In [4]:
class WordDataset:
    def __init__(self):
        self.tokenizer = get_tokenizer('basic_english')

        train_dataset, test_dataset = iter(IMDB(split=('train', 'test')))
        train_dataset, test_dataset = to_map_style_dataset(train_dataset), to_map_style_dataset(test_dataset)

        self.vocab = build_vocab_from_iterator(self.build_vocab([train_dataset, test_dataset]), specials=SPEC_TOKENS)
        self.vocab.set_default_index(self.vocab['<UNK>'])
        self.vocab_length = len(self.vocab.get_itos())

        self.vectorizer = CountVectorizer(vocabulary=self.vocab.get_itos(), tokenizer=self.tokenizer)

        self.text_transform = lambda x: self.vocab(self.tokenizer(x))

        self.train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=self.vectorize_batch)
        self.test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, collate_fn=self.vectorize_batch)

    def build_vocab(self, datasets):
        for dataset in datasets:
            for _, text in dataset:
                yield self.tokenizer(text)

    def vectorize_batch(self, batch):
        label_list, text_list, offsets = [], [], []
        for Y, X in batch:
          label_list.append(self.text_transform(Y))
          tmp_X = torch.tensor(self.text_transform(X), dtype=torch.int64)
          text_list.append(torch.cat([torch.tensor([BOS_IDX]), tmp_X, torch.tensor([EOS_IDX])]))
        label_list = torch.tensor(label_list, dtype=torch.int64)
        text_list = torch.cat(text_list)
        return label_list.to(DEVICE), text_list.to(DEVICE)

    def text_translate(self, x):
        return ' '.join([self.vocab.get_itos()[i] for i in x])

In [9]:
class Generator(nn.Module):
  def __init__(self, emb_dim, h_dim, vocab_len, max_seq_len):
        super(Generator, self).__init__()

        self.h_dim = h_dim
        self.emb_dim = emb_dim
        self.max_seq_len = max_seq_len

        # Word embeddings layer
        self.embedding = nn.Embedding(vocab_len, emb_dim, padding_idx=PAD_IDX)

        # Generator
        self.lstm = nn.LSTM(emb_dim, h_dim, batch_first=True)
        self.out = nn.Linear(h_dim, vocab_len)
        self.softmax = nn.LogSoftmax(dim=-1)

  def forward(self, text, hidden):
        emb = self.embedding(text)
        if len(text.size()) == 1:
            emb = emb.unsqueeze(1)

        out, hidden = self.lstm(emb, hidden)
        out = out.contiguous().view(-1, self.h_dim)
        out = self.out(out)
        
        pred = self.softmax(out)
        
        return pred

In [6]:
class Discriminator(nn.Module):
  def __init__(self, h_dim, c_dim, emb_dim, vocab_len, dropout=0.3):
    super(Discriminator, self).__init__()
    
    self.h_dim = h_dim
    self.emb_dim = emb_dim

    # Word embeddings layer
    self.embedding = nn.Embedding(vocab_len, emb_dim, padding_idx=PAD_IDX)

    # Disriminator
    self.gru = nn.GRU(emb_dim, h_dim, num_layers=2, bidirectional=True, dropout=dropout)
    self.hidden = nn.Linear(4 * h_dim, c_dim)
    self.out = nn.Linear(c_dim, 2)
    self.dropout = nn.Dropout(dropout)

  def forward(self, inputs):
    label, text = inputs

    label_output = self.embedding(label)
    text_output = self.embedding(text)
    
    concat = torch.cat((text_output, label_output), dim=0)

    hidden = torch.zeros(4, concat.size(0), self.h_dim).to(DEVICE)

    _, hidden = self.gru(concat.unsqueeze(0), hidden)
    hidden = hidden.permute(1, 0, 2).contiguous()

    out = self.hidden(hidden.view(-1, 4 * self.h_dim))

    feature = torch.tanh(out)

    pred = self.out(self.dropout(feature))
    
    # pred = self.model(concat)
    
    return pred


In [7]:
dataset = WordDataset()

100%|██████████| 84.1M/84.1M [00:24<00:00, 3.45MB/s]


=== TRAINING ===

In [10]:
discriminator = Discriminator(
    vocab_len=dataset.vocab_length,
    emb_dim=64,
    h_dim=64,
    c_dim=2
).to(DEVICE)

generator = Generator(
    vocab_len=dataset.vocab_length,
    max_seq_len=15,
    emb_dim=64,
    h_dim=64
).to(DEVICE)

generator_optimizer = optim.Adam(generator.parameters(), lr=lr)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

for e in range(epochs):
    discriminator.train()
    generator.train()

    discriminator_loss_list = []
    generator_loss_list = []

    for idx, (labels, texts) in enumerate(dataset.train_loader):
      print('L:', labels.shape)
      print('T:', texts.shape)

      real_label = torch.ones(texts.size(0) + BATCH_SIZE, 2).to(DEVICE)
      fake_label = torch.zeros(texts.size(0) + BATCH_SIZE, 2).to(DEVICE)

      noise = torch.randn([texts.size(0) + BATCH_SIZE, 100]).to(DEVICE)
      conditional = torch.randint(6, 7, (texts.size(0) + BATCH_SIZE,)).to(DEVICE)

      # ==== FAKE ====
      discriminator_optimizer.zero_grad()

      discriminator_real_loss = F.binary_cross_entropy(discriminator((labels.view(-1), texts)), real_label)
      
      noise_vector = torch.randn(texts.size(0), 64, device=DEVICE)
      noise_vector = noise_vector.to(DEVICE)

      generated_text = generator((noise_vector, labels))

      output = discriminator((generated_text.detach(), labels))
      
      discriminator_fake_loss = F.binary_cross_entropy(output,  fake_label)

      discriminator_total_loss = (discriminator_real_loss + discriminator_fake_loss) / 2
      
      discriminator_loss_list.append(discriminator_total_loss)
      
      discriminator_total_loss.backward()
      
      discriminator_optimizer.step()

      # ==== REAL ====
      generator_optimizer.zero_grad()
      
      generator_loss = F.binary_cross_entropy(discriminator((generated_text, labels)), real_label)
      
      generator_loss_list.append(generator_loss)
      
      generator_loss.backward()
      
      generator_optimizer.step()

L: torch.Size([2, 1])
T: torch.Size([575])


RuntimeError: ignored

In [None]:
!pip install torchtext==0.11.0