<a href="https://colab.research.google.com/github/dohyeongkim97/papers/blob/master/gan%26cgan_for_text_augmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

LSTM-based GAN (generator & discriminator)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Toy stand-in for a tokenized corpus: every row is a sequence of integer token ids.
# For real use, replace this with text that has been tokenized into integer sequences.
vocab_size = 10  # number of distinct token ids
real_data = np.random.randint(low=0, high=vocab_size, size=(100, 10))

# generator model
class Generator(nn.Module):
    """LSTM generator that turns token-id sequences into per-position vocabulary scores.

    Args:
        vocab_size: size of the embedding table and of each output score vector.
        embedding_dim: width of the token embeddings fed to the LSTM.
        hidden_dim: width of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embedding_dim=10, hidden_dim=20):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: the LSTM consumes (batch, seq, feature) tensors.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, noise):
        """Return unnormalized scores of shape (batch, seq, vocab_size).

        ``noise`` is a batch of integer token ids laid out as (batch, seq).
        """
        hidden_states, _final_state = self.lstm(self.embedding(noise))
        return self.fc(hidden_states)

# discriminator model: scores a whole sequence as real or fake
class Discriminator(nn.Module):
    """LSTM discriminator that scores a token-id sequence with a single probability.

    Args:
        vocab_size: size of the embedding table.
        embedding_dim: width of the token embeddings fed to the LSTM.
        hidden_dim: width of the LSTM hidden state.
    """

    def __init__(self, vocab_size, embedding_dim=10, hidden_dim=20):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # batch_first=True: the LSTM consumes (batch, seq, feature) tensors.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, sequence):
        """Return a (batch, 1) tensor of sigmoid scores for ``sequence``.

        ``sequence`` is a batch of integer token ids laid out as (batch, seq).
        Only the LSTM output at the last time step is passed to the classifier.
        """
        hidden_states, _final_state = self.lstm(self.embedding(sequence))
        last_step = hidden_states[:, -1, :]
        score = self.fc(last_step)
        return self.sigmoid(score)

# hyperparameters and model instantiation
embedding_dim = 10
hidden_dim = 20
G = Generator(vocab_size, embedding_dim, hidden_dim)
D = Discriminator(vocab_size, embedding_dim, hidden_dim)

# loss and optimizers
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=0.001)
optimizer_D = optim.Adam(D.parameters(), lr=0.001)


def _straight_through_one_hot(logits):
    """Return one-hot argmax tokens whose gradient is that of softmax(logits).

    The forward value is the one-hot encoding of ``logits.argmax(-1)`` (up to
    floating-point rounding); the backward pass flows through the softmax, so
    the generator still receives a gradient despite the discrete choice.
    """
    probs = torch.softmax(logits, dim=-1)
    hard = nn.functional.one_hot(logits.argmax(dim=-1), num_classes=logits.size(-1)).to(probs.dtype)
    return hard + probs - probs.detach()


def _discriminate_one_hot(discriminator, one_hot):
    """Score (batch, seq, vocab) one-hot sequences with the LSTM discriminator.

    Multiplying the one-hot rows by the embedding matrix selects the same
    embeddings as an index lookup, but keeps the graph connected to ``one_hot``.
    """
    embedded = one_hot @ discriminator.embedding.weight
    lstm_out, _ = discriminator.lstm(embedded)
    return discriminator.sigmoid(discriminator.fc(lstm_out[:, -1, :]))


# GAN training loop
num_epochs = 1000
batch_size = 16

for epoch in range(num_epochs):
    # data preparation
    real_labels = torch.ones(batch_size, 1)
    fake_labels = torch.zeros(batch_size, 1)
    real_data_batch = torch.tensor(real_data[np.random.randint(0, len(real_data), batch_size)], dtype=torch.long)

    # generator step
    # random token sequence as generator input, same length as the real sequences
    noise = torch.randint(0, vocab_size, (batch_size, real_data.shape[1]), dtype=torch.long)
    logits = G(noise)
    generated_data = logits.argmax(dim=-1)  # discrete sequence produced by the generator
    # argmax alone is not differentiable: back-propagating through D(generated_data)
    # would never reach G. Feed a straight-through one-hot through D's layers instead.
    fake_output = _discriminate_one_hot(D, _straight_through_one_hot(logits))
    loss_G = criterion(fake_output, real_labels)  # generator wants fakes labelled as real

    optimizer_G.zero_grad()
    loss_G.backward()
    optimizer_G.step()

    # discriminator step
    real_output = D(real_data_batch)
    fake_output = D(generated_data.detach())  # integer tokens carry no graph back to G
    loss_D_real = criterion(real_output, real_labels)
    loss_D_fake = criterion(fake_output, fake_labels)
    loss_D = (loss_D_real + loss_D_fake) / 2

    # zero_grad also clears the gradients the generator step left on D's parameters
    optimizer_D.zero_grad()
    loss_D.backward()
    optimizer_D.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch}/{num_epochs} - Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")

Transformer-based GAN (generator & discriminator)

In [None]:
# hyperparameter setup for the Transformer-based models
seq_length = 10     # sequence length
vocab_size = 10     # number of distinct token ids
embedding_dim = 16  # token embedding width, also used as the Transformer d_model
hidden_dim = 32     # passed to the model constructors

# generator: Transformer-based sequence generation
class TransformerGenerator(nn.Module):
    """Transformer generator producing per-position vocabulary scores.

    Args:
        vocab_size: size of the embedding table and of each output score vector.
        embedding_dim: token embedding width, used as the Transformer ``d_model``.
        hidden_dim: accepted for signature parity; no layer here uses it.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.transformer = nn.Transformer(
            d_model=embedding_dim,
            nhead=4,
            num_encoder_layers=2,
            num_decoder_layers=2,
        )
        self.fc = nn.Linear(embedding_dim, vocab_size)

    def forward(self, noise):
        """Return scores of shape (batch, seq, vocab_size) for (batch, seq) token ids."""
        # The transformer is built without batch_first, so it expects (seq, batch, embed).
        seq_first = self.embedding(noise).transpose(0, 1)
        # The same tensor serves as both the source and the target sequence.
        decoded = self.transformer(seq_first, seq_first)
        # Back to (batch, seq, embed) before projecting onto the vocabulary.
        return self.fc(decoded.transpose(0, 1))

# discriminator: Transformer-based sequence discrimination
class TransformerDiscriminator(nn.Module):
    """Transformer discriminator that scores a token-id sequence with one probability.

    Args:
        vocab_size: size of the embedding table.
        embedding_dim: token embedding width, used as the Transformer ``d_model``.
        hidden_dim: accepted for signature parity; no layer here uses it.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Only the encoder depth is set; the decoder depth stays at the library default.
        self.transformer = nn.Transformer(
            d_model=embedding_dim,
            nhead=4,
            num_encoder_layers=2,
        )
        self.fc = nn.Linear(embedding_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, sequence):
        """Return a (batch, 1) tensor of sigmoid scores for (batch, seq) token ids."""
        # The transformer is built without batch_first, so it expects (seq, batch, embed).
        seq_first = self.embedding(sequence).transpose(0, 1)
        decoded = self.transformer(seq_first, seq_first)
        # Classify from the output at the last sequence position only.
        final_position = decoded[-1]
        return self.sigmoid(self.fc(final_position))

# rebuild G and D as the Transformer-based generator / discriminator pair
G = TransformerGenerator(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
)
D = TransformerDiscriminator(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    hidden_dim=hidden_dim,
)

CGAN (Conditional GAN)

In [None]:
# hyperparameters for the conditional GAN
seq_length = 10     # sequence length
vocab_size = 10     # number of distinct word (token) types
embedding_dim = 16  # token/label embedding width, also used as the Transformer d_model
hidden_dim = 32     # passed to the model constructors

# Generator: Transformer-based, takes a label as an additional input
class ConditionalTransformerGenerator(nn.Module):
    """Label-conditioned Transformer generator producing per-position vocabulary scores.

    Args:
        vocab_size: size of the token embedding table and of each output score vector.
        embedding_dim: embedding width for tokens and labels, used as ``d_model``.
        hidden_dim: accepted for signature parity; no layer here uses it.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Two rows: one embedding per binary label value (0 or 1).
        self.label_embedding = nn.Embedding(2, embedding_dim)
        self.transformer = nn.Transformer(
            d_model=embedding_dim,
            nhead=4,
            num_encoder_layers=2,
            num_decoder_layers=2,
        )
        self.fc = nn.Linear(embedding_dim, vocab_size)

    def forward(self, noise, labels):
        """Return scores of shape (batch, seq, vocab_size).

        ``noise`` holds (batch, seq) token ids and ``labels`` one label id per
        batch row; the label embedding is added to every position's embedding.
        """
        token_vectors = self.embedding(noise)
        # (batch, embed) -> (batch, 1, embed) so the sum broadcasts over the sequence axis.
        label_vectors = self.label_embedding(labels).unsqueeze(1)
        # The transformer is built without batch_first, so it expects (seq, batch, embed).
        seq_first = (token_vectors + label_vectors).transpose(0, 1)
        decoded = self.transformer(seq_first, seq_first)
        return self.fc(decoded.transpose(0, 1))

# Discriminator: Transformer-based, takes a label as an additional input
class ConditionalTransformerDiscriminator(nn.Module):
    """Label-conditioned Transformer discriminator returning one probability per sequence.

    Args:
        vocab_size: size of the token embedding table.
        embedding_dim: embedding width for tokens and labels, used as ``d_model``.
        hidden_dim: accepted for signature parity; no layer here uses it.
    """

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super(ConditionalTransformerDiscriminator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Two rows: one embedding per binary label value (0 or 1).
        self.label_embedding = nn.Embedding(2, embedding_dim)
        # Only the encoder depth is set; the decoder depth stays at the library default.
        self.transformer = nn.Transformer(d_model=embedding_dim, nhead=4, num_encoder_layers=2)
        self.fc = nn.Linear(embedding_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, sequence, labels):
        """Return a (batch, 1) tensor of sigmoid scores.

        ``sequence`` holds (batch, seq) token ids and ``labels`` one label id per
        batch row. The sequence length is taken from the input itself rather
        than from a module-level ``seq_length`` variable, so any length works.
        """
        # The transformer is built without batch_first, so it expects (seq, batch, embed).
        embedded = self.embedding(sequence).permute(1, 0, 2)
        # (batch, embed) -> (1, batch, embed); broadcasting adds the label to every position.
        label_embedded = self.label_embedding(labels).unsqueeze(0)
        conditional_input = embedded + label_embedded
        transformer_out = self.transformer(conditional_input, conditional_input)
        output = self.sigmoid(self.fc(transformer_out[-1, :, :]))  # use the last position's output
        return output

# model instantiation
G = ConditionalTransformerGenerator(vocab_size, embedding_dim, hidden_dim)
D = ConditionalTransformerDiscriminator(vocab_size, embedding_dim, hidden_dim)

# example training data: random token sequences with random binary labels
real_data = torch.randint(0, vocab_size, (32, seq_length), dtype=torch.long)
labels = torch.randint(0, 2, (32,), dtype=torch.long)

# optimizers and loss for the example training loop
optimizer_G = optim.Adam(G.parameters(), lr=0.001)
optimizer_D = optim.Adam(D.parameters(), lr=0.001)
criterion = nn.BCELoss()


def _straight_through_one_hot(logits):
    """Return one-hot argmax tokens whose gradient is that of softmax(logits).

    The forward value is the one-hot encoding of ``logits.argmax(-1)`` (up to
    floating-point rounding); the backward pass flows through the softmax, so
    the generator still receives a gradient despite the discrete choice.
    """
    probs = torch.softmax(logits, dim=-1)
    hard = nn.functional.one_hot(logits.argmax(dim=-1), num_classes=logits.size(-1)).to(probs.dtype)
    return hard + probs - probs.detach()


def _conditional_discriminate_one_hot(discriminator, one_hot, labels):
    """Score (batch, seq, vocab) one-hot sequences with the conditional discriminator.

    Uses the discriminator's own layers, but obtains the token embeddings by a
    matrix product with the embedding table instead of an index lookup so the
    graph stays connected to ``one_hot``.
    """
    embedded = (one_hot @ discriminator.embedding.weight).permute(1, 0, 2)
    # (batch, embed) -> (1, batch, embed); broadcasting adds the label to every position.
    label_embedded = discriminator.label_embedding(labels).unsqueeze(0)
    conditional_input = embedded + label_embedded
    transformer_out = discriminator.transformer(conditional_input, conditional_input)
    return discriminator.sigmoid(discriminator.fc(transformer_out[-1, :, :]))


# Generator & Discriminator training
for epoch in range(1000):
    # target labels for real and generated batches
    real_labels = torch.ones(32, 1)
    fake_labels = torch.zeros(32, 1)

    # 1. discriminator step
    noise = torch.randint(0, vocab_size, (32, seq_length), dtype=torch.long)
    logits = G(noise, labels)
    fake_data = logits.argmax(dim=-1)  # generated (fake) token sequences
    real_output = D(real_data, labels)
    fake_output = D(fake_data.detach(), labels)

    loss_D = (criterion(real_output, real_labels) + criterion(fake_output, fake_labels)) / 2
    optimizer_D.zero_grad()
    loss_D.backward()
    optimizer_D.step()

    # 2. generator step
    # argmax alone is not differentiable: back-propagating through D(fake_data, labels)
    # would never reach G. Feed a straight-through one-hot through D's layers instead.
    fake_output = _conditional_discriminate_one_hot(D, _straight_through_one_hot(logits), labels)
    loss_G = criterion(fake_output, real_labels)  # generator wants fakes labelled as real

    optimizer_G.zero_grad()
    loss_G.backward()
    optimizer_G.step()

    if epoch % 100 == 0:
        print(f"Epoch {epoch} - Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")