In [None]:
!pip install torchtext==0.7
!pip install torch==1.6
!pip install protobuf==3.20.0

In [None]:
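# NOTE(review): presumably an alternative version pair — torch 1.9 with
# torchtext 0.10. TODO confirm whether the notebook was also run with it.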

In [None]:
import torch.nn as nn
import mlflow
import torch
import os
from torch.autograd import Variable, grad
from torch import optim
from torchtext.data import Field, BucketIterator
from torchtext.datasets import IMDB
from mlflow import log_metric, log_param
from tqdm.notebook import tqdm

In [None]:
class ResBlock(nn.Module):
    """Residual block: (ReLU -> Conv1d) twice, added back to the input.

    Both convolutions keep the channel count at ``dim`` and use a kernel of
    width 5 with padding 2, so the sequence length is unchanged and the
    branch output can be summed with the block input.
    """

    def __init__(self, dim):
        super().__init__()
        layers = [
            nn.ReLU(inplace=True),
            nn.Conv1d(dim, dim, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.Conv1d(dim, dim, kernel_size=5, padding=2),
        ]
        self.res_block = nn.Sequential(*layers)

    def forward(self, input):
        """Return ``input + 0.3 * branch(input)``.

        The first ReLU of the branch is in-place, so it overwrites ``input``
        before the sum is formed: the skip connection therefore carries the
        rectified input, and the caller's tensor is modified as well.
        """
        branch = self.res_block(input)
        return input + 0.3 * branch


class Generator(nn.Module):
    """Maps a noise vector to per-position probabilities over the vocabulary.

    Args:
        dim: channel width of the residual stack.
        seq_len: number of sequence positions produced.
        vocab_size: number of classes in the output distribution.
        noise_dim: width of the input noise vector. Defaults to 128, the
            value that used to be hard-coded in the first linear layer.
    """

    def __init__(self, dim, seq_len, vocab_size, noise_dim=128):
        super(Generator, self).__init__()
        self.dim = dim
        self.seq_len = seq_len
        # Exposed so code that draws noise can read the expected width.
        self.noise_dim = noise_dim

        self.fc1 = nn.Linear(noise_dim, dim * seq_len)
        self.block = nn.Sequential(
            ResBlock(dim),
            ResBlock(dim),
            ResBlock(dim),
            ResBlock(dim),
            ResBlock(dim),
        )
        self.conv1 = nn.Conv1d(dim, vocab_size, 1)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, noise):
        """Return a (batch, seq_len, vocab_size) tensor of softmax outputs.

        ``noise`` is expected to be (batch, noise_dim).
        """
        output = self.fc1(noise)
        # (BATCH_SIZE, DIM, SEQ_LEN)
        output = output.view(-1, self.dim, self.seq_len)
        output = self.block(output)
        # (BATCH_SIZE, VOCAB_SIZE, SEQ_LEN)
        output = self.conv1(output)
        # (BATCH_SIZE, SEQ_LEN, VOCAB_SIZE); made contiguous so the result has
        # the same memory layout as before.
        output = output.transpose(1, 2).contiguous()
        # The softmax already normalises over the last axis, so flattening to
        # 2-D and reshaping back afterwards is unnecessary.
        return self.softmax(output)


class Discriminator(nn.Module):
    """Critic that scores a (batch, seq_len, vocab_size) sequence tensor.

    A 1x1 convolution projects the vocabulary axis to ``dim`` channels, five
    residual blocks process the sequence, and a linear layer reduces the
    flattened features to a single unbounded score per sample.
    """

    def __init__(self, dim, seq_len, vocab_size):
        super().__init__()
        self.dim, self.seq_len = dim, seq_len

        # Sub-modules are created in the order block -> conv1d -> linear.
        self.block = nn.Sequential(*[ResBlock(dim) for _ in range(5)])
        self.conv1d = nn.Conv1d(vocab_size, dim, kernel_size=1)
        self.linear = nn.Linear(seq_len * dim, 1)

    def forward(self, input):
        """Return one score per sample, shape (batch, 1)."""
        # Conv1d wants channels first: (BATCH_SIZE, VOCAB_SIZE, SEQ_LEN)
        channels_first = input.transpose(1, 2)
        features = self.block(self.conv1d(channels_first))
        flattened = features.view(-1, self.seq_len * self.dim)
        return self.linear(flattened)

In [None]:
def penalize_grad(D, real, fake, batch_size, lamb, use_cuda=True):
    """Compute the WGAN-GP gradient penalty for critic ``D``.

    Random convex combinations of ``real`` and ``fake`` are fed through ``D``
    and the deviation of the per-sample gradient norm from 1 is penalised.

    Args:
        D: critic; called on a tensor shaped like ``real``.
        real: batch of real samples (first axis is the batch axis).
        fake: batch of generated samples, same shape as ``real``.
        batch_size: kept for backward compatibility; the number of
            interpolation coefficients is taken from ``real.size(0)`` so a
            short batch does not break the computation.
        lamb: lambda, the penalty coefficient.
        use_cuda: if true, the interpolated batch is moved to the GPU.

    Returns:
        Scalar tensor ``lamb * mean((||grad||_2 - 1) ** 2)`` with a graph
        attached (``create_graph=True``) so it can be back-propagated.
    """
    n_samples = real.size(0)
    # One coefficient per sample, broadcast over all remaining axes.
    alpha_shape = [n_samples] + [1] * (real.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real.device)
    interpolates = alpha * real + ((1 - alpha) * fake)
    if use_cuda:
        interpolates = interpolates.cuda()
    # Fresh leaf tensor so gradients are taken w.r.t. the interpolated inputs.
    interpolates = interpolates.detach().requires_grad_(True)
    d_interpolates = D(interpolates)
    gradients = grad(outputs=d_interpolates, inputs=interpolates,
                     grad_outputs=torch.ones_like(d_interpolates),
                     create_graph=True,
                     retain_graph=True, only_inputs=True)[0]
    # The penalty is defined on the norm of the whole gradient of each sample.
    # Taking `norm(dim=1)` on the unflattened gradient would instead yield
    # separate norms along a single axis, so flatten per sample first.
    grad_norms = gradients.reshape(gradients.size(0), -1).norm(2, dim=1)
    grad_penalty = ((grad_norms - 1) ** 2).mean() * lamb
    return grad_penalty


def train_discriminator(D, G, optim_D, real, lamb, batch_size, use_cuda=True):
    """Run one critic update and return its loss terms.

    Gradients of ``-D(real).mean()``, ``D(fake).mean()`` and the gradient
    penalty are accumulated on ``D`` before a single optimizer step.

    Args:
        D: critic network.
        G: generator network; only used to produce fake samples (no gradient
            flows into it).
        optim_D: optimizer over ``D``'s parameters.
        real: batch of real samples.
        lamb: gradient-penalty coefficient.
        batch_size: kept for backward compatibility; the number of fake
            samples is taken from ``real.size(0)`` so both batches match.
        use_cuda: if true, the noise is moved to the GPU.

    Returns:
        ``(d_loss, wasserstein)`` as detached scalar tensors, where
        ``d_loss = D(fake) - D(real) + penalty`` and
        ``wasserstein = D(real) - D(fake)`` (batch means).
    """
    D.zero_grad()
    n_samples = real.size(0)

    # train with real: ascend on D(real), i.e. back-propagate its negation.
    # Negating locally avoids relying on module-level `one`/`mone` tensors.
    d_real = D(real).mean()
    (-d_real).backward()

    # train with fake
    noise = torch.randn(n_samples, getattr(G, "noise_dim", 128))
    if use_cuda:
        noise = noise.cuda()
    with torch.no_grad():  # freeze G: no graph is built through the generator
        fake = G(noise)
    d_fake = D(fake).mean()
    d_fake.backward()

    grad_penalty = penalize_grad(D, real.detach(), fake,
                                 n_samples, lamb, use_cuda)
    grad_penalty.backward()

    optim_D.step()

    d_loss = d_fake - d_real + grad_penalty
    wasserstein = d_real - d_fake
    # Detached: callers only read the values.
    return d_loss.detach(), wasserstein.detach()


def train_generator(D, G, optim_G, batch_size, use_cuda=True):
    """Run one generator update and return the generator loss.

    Args:
        D: critic network used to score the generated batch.
        G: generator network being trained.
        optim_G: optimizer over ``G``'s parameters.
        batch_size: number of noise vectors to draw.
        use_cuda: if true, the noise is moved to the GPU.

    Returns:
        Detached scalar tensor ``-D(G(noise)).mean()``.
    """
    G.zero_grad()
    noise = torch.randn(batch_size, getattr(G, "noise_dim", 128))
    if use_cuda:
        noise = noise.cuda()
    # Minimising -D(G(z)) is the same update as back-propagating D(G(z)) with
    # a gradient of -1, without needing a module-level `mone` tensor.
    g_loss = -D(G(noise)).mean()
    g_loss.backward()
    optim_G.step()
    return g_loss.detach()

In [None]:
def to_onehot(index, vocab_size):
    """One-hot encode a 2-D tensor of integer ids.

    Returns a float32 CPU tensor of shape
    ``(index.size(0), index.size(1), vocab_size)`` with a 1 at each id's
    position along the last axis and 0 elsewhere.
    """
    rows, cols = index.size(0), index.size(1)
    # scatter_ needs an index with the same rank as the target tensor.
    positions = index.data.cpu().unsqueeze(2)
    encoded = torch.zeros(rows, cols, vocab_size,
                          dtype=torch.float32, device="cpu")
    return encoded.scatter_(2, positions, 1)


def sample(G, TEXT, batch_size, seq_len, vocab_size, use_cuda=True):
    """Draw ``batch_size`` sequences from ``G`` and decode them to strings.

    Args:
        G: generator; called on a (batch_size, noise_dim) noise tensor. The
            noise width is read from ``G.noise_dim`` when present and
            defaults to 128 otherwise.
        TEXT: object whose ``vocab.itos`` maps a token id to its string.
        batch_size: number of sequences to generate.
        seq_len: sequence length the generator output is reshaped to.
        vocab_size: vocabulary size the generator output is reshaped to.
        use_cuda: if true, the noise is moved to the GPU.

    Returns:
        List of ``batch_size`` strings, one per generated sequence, built from
        the highest-scoring token at every position.
    """
    noise = torch.randn(batch_size, getattr(G, "noise_dim", 128))
    if use_cuda:
        noise = noise.cuda()
    with torch.no_grad():
        samples = G(noise)
    samples = samples.view(-1, seq_len, vocab_size)
    _, argmax = torch.max(samples, 2)
    # Convert once to nested Python ints instead of indexing the vocabulary
    # list with a 0-dim tensor for every character.
    token_ids = argmax.cpu().tolist()
    itos = TEXT.vocab.itos
    return ["".join(itos[token] for token in row) for row in token_ids]

In [None]:
# Training hyperparameters, read by the setup and training cells below.
batchs = 500_000   # number of outer training iterations (one generator update each)
critic_iters = 5   # discriminator updates per generator update
batch_size = 8     # samples per batch
seq_len = 280      # fixed sequence length in tokens
lamb = 10          # gradient-penalty coefficient (lambda)
lr = 1e-4          # Adam learning rate for both networks

In [None]:
use_cuda = torch.cuda.is_available()


# load datasets

print("[!] preparing dataset...")
# `tokenize=list` splits each review into single characters, and `fix_length`
# pads/truncates every example to `seq_len` tokens.
TEXT = Field(lower=True, fix_length=seq_len,
             tokenize=list, batch_first=True)
LABEL = Field(sequential=False)
train_data, test_data = IMDB.splits(TEXT, LABEL)
TEXT.build_vocab(train_data)
LABEL.build_vocab(train_data)
# repeat=True: the training iterator is consumed with next() for a fixed
# number of batches rather than per epoch.
train_iter, test_iter = BucketIterator.splits(
        (train_data, test_data), batch_size=batch_size, repeat=True)
vocab_size = len(TEXT.vocab)
print("[TRAIN]:%d (dataset:%d)\t[TEST]:%d (dataset:%d)\t[VOCAB]:%d"
      % (len(train_iter), len(train_iter.dataset),
         len(test_iter), len(test_iter.dataset), vocab_size))

# instantiate models
G = Generator(dim=512, seq_len=seq_len, vocab_size=vocab_size)
D = Discriminator(dim=512, seq_len=seq_len, vocab_size=vocab_size)
optim_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.9))
optim_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.9))

# Scalar +1 / -1 tensors kept at module level because the training helpers
# may look them up as globals (a `global` statement is a no-op at this scope,
# so none is needed).
one = torch.tensor(1, dtype=torch.float)
mone = one * -1
if use_cuda:
    G, D = G.cuda(), D.cuda()
    one, mone = one.cuda(), mone.cuda()

# Turn the iterator object into a plain Python iterator for next().
train_iter = iter(train_iter)

# Set once metric logging has failed, so the problem is reported a single
# time instead of being silently swallowed (or printed on every iteration).
metric_logging_failed = False

with mlflow.start_run():
    log_param("batch_size", batch_size)
    log_param("initial_lr", lr)
    log_param("max_epochs", batchs)
    for b in tqdm(range(1, batchs+1)):
        # (1) Update D network
        for p in D.parameters():  # reset requires_grad
            p.requires_grad = True
        for iter_d in range(critic_iters):  # CRITIC_ITERS
            batch = next(train_iter)
            real = to_onehot(batch.text, vocab_size)
            if use_cuda:
                real = real.cuda()
            d_loss, wasserstein = train_discriminator(
                    D, G, optim_D, real, lamb, batch_size, use_cuda)
        # (2) Update G network
        for p in D.parameters():
            p.requires_grad = False  # to avoid computation
        g_loss = train_generator(D, G, optim_G, batch_size, use_cuda)

        d_loss_val = d_loss.item()
        g_loss_val = g_loss.item()
        wasserstein_val = wasserstein.item()

        # Metric logging is best-effort: a tracking failure must not abort
        # training, but it should not go unnoticed either.
        try:
            mlflow.log_metric('Discriminator Loss', d_loss_val, step=b)
            mlflow.log_metric('Generator Loss', g_loss_val, step=b)
            mlflow.log_metric('Wasserstein distance', wasserstein_val, step=b)
        except Exception as exc:
            if not metric_logging_failed:
                metric_logging_failed = True
                print("[!] mlflow metric logging failed: %r" % (exc,))

        # Generate a sample only when it is actually written out.
        if b % 500 == 0:
            samples = sample(G, TEXT, 1, seq_len, vocab_size, use_cuda)
            # Explicit encoding: decoded text may contain non-ASCII characters.
            with open('text_samples.txt', 'a', encoding='utf-8') as f:
                f.write(f' ------------------------------------------------- \n epoch #{b}  D:{d_loss_val}  G:{g_loss_val}  W:{wasserstein_val}\n sample: {samples[0]}\n ------------------------------------------------- \n')

        if b % 5000 == 0:
            print("[!] saving model")
            os.makedirs(".save", exist_ok=True)
            torch.save(G.state_dict(), './.save/wgan_g_%d.pt' % (b))
            torch.save(D.state_dict(), './.save/wgan_d_%d.pt' % (b))