#TransGAN

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Attention(nn.Module):
    """Multi-head dot-product attention.

    The embedding of width ``D`` is split into ``heads`` chunks of width
    ``D // heads``. One bias-free query, key and value projection (each
    ``head_dim -> head_dim``) is shared by all heads, and the concatenated
    head outputs are passed through a final ``D -> D`` linear layer.

    Args:
        D: Embedding width of the inputs and of the output.
        heads: Number of attention heads; must divide ``D``.

    Raises:
        AssertionError: If ``D`` is not divisible by ``heads``.
    """

    def __init__(self, D, heads=8):
        super().__init__()
        self.D = D
        self.heads = heads

        assert (D % heads == 0), "Embedding size should be divisble by number of heads"
        self.head_dim = self.D // heads

        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.H = nn.Linear(self.D, self.D)

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            Q: Tensor of shape ``(batch, q_len, D)``.
            K: Tensor of shape ``(batch, k_len, D)``.
            V: Tensor of shape ``(batch, k_len, D)``.
            mask: Optional tensor broadcastable to
                ``(batch, heads, q_len, k_len)``; positions where it equals 0
                are excluded from the softmax. ``None`` disables masking.

        Returns:
            Tensor of shape ``(batch, q_len, D)``.
        """
        batch_size = Q.shape[0]
        q_len, k_len, v_len = Q.shape[1], K.shape[1], V.shape[1]

        # Split the embedding into heads, then apply the per-head projections
        # (they were registered in __init__ but previously never used).
        Q = self.queries(Q.reshape(batch_size, q_len, self.heads, self.head_dim))
        K = self.keys(K.reshape(batch_size, k_len, self.heads, self.head_dim))
        V = self.values(V.reshape(batch_size, v_len, self.heads, self.head_dim))

        # Batched dot product of every query with every key, per head.
        raw_scores = torch.einsum("bqhd,bkhd->bhqk", Q, K)

        # `mask is not None` instead of truthiness: bool() of a tensor with
        # more than one element raises.
        if mask is not None:
            scores = raw_scores.masked_fill(mask == 0, float("-inf"))
        else:
            scores = raw_scores

        # NOTE(review): the scale is sqrt(D), not the per-head sqrt(head_dim)
        # that scaled dot-product attention usually uses; kept as written.
        attn = torch.softmax(scores / (self.D ** 0.5), dim=3)
        attn_output = torch.einsum("bhql,blhd->bqhd", attn, V)
        attn_output = attn_output.reshape(batch_size, q_len, self.D)

        return self.H(attn_output)

In [None]:
class EncoderBlock(nn.Module):
    """Transformer encoder block: attention followed by a two-layer MLP.

    Each sub-layer is wrapped as ``dropout(LayerNorm(sublayer(x) + x))``.

    Args:
        D: Embedding width.
        heads: Number of attention heads passed to ``Attention``
            (default 8, the same default ``Attention`` uses).
        p: Dropout probability (default 0.0, i.e. dropout disabled).
        fwd_exp: Hidden width of the MLP as a multiple of ``D`` (default 4).

    The defaults exist so the block can be built as ``EncoderBlock(D)``,
    which is how the other modules in this file construct it.
    """

    def __init__(self, D, heads=8, p=0.0, fwd_exp=4):
        super().__init__()
        self.mha = Attention(D, heads)
        self.drop_prob = p
        self.n1 = nn.LayerNorm(D)
        self.n2 = nn.LayerNorm(D)
        self.mlp = nn.Sequential(
            nn.Linear(D, fwd_exp*D),
            nn.ReLU(),
            nn.Linear(fwd_exp*D, D),
        )
        self.dropout = nn.Dropout(p)

    def forward(self, Q, K=None, V=None, mask=None):
        """Run the block.

        Args:
            Q: Query tensor; also the residual input of the attention layer.
            K: Key tensor. Defaults to ``Q`` (self-attention).
            V: Value tensor. Defaults to ``K``.
            mask: Optional attention mask forwarded to ``Attention``.

        Returns:
            Tensor with the same shape as ``Q``.
        """
        # Allow the single-argument call `block(x)` for self-attention.
        if K is None:
            K = Q
        if V is None:
            V = K

        attn = self.mha(Q, K, V, mask)

        # Layer normalisation with residual connections
        x = self.n1(attn + Q)
        x = self.dropout(x)
        forward = self.mlp(x)
        x = self.n2(forward + x)
        out = self.dropout(x)

        return out

class MLP(nn.Module):
    """Single bias-free linear layer that widens a flattened input 64-fold.

    Args:
        noise_w: Width of the input grid.
        noise_h: Height of the input grid.
        channels: Channels per grid cell.

    The layer maps ``noise_w * noise_h * channels`` features to
    ``8 * 8`` times that many.
    """

    def __init__(self, noise_w, noise_h, channels):
        super().__init__()
        in_features = noise_w * noise_h * channels
        out_features = (8 * 8) * in_features
        self.l1 = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        """Apply the linear expansion to ``x`` (last dim = input features)."""
        return self.l1(x)

class Embedding(nn.Module):
    """Single bias-free linear layer that widens a flattened embedding 64-fold.

    Args:
        emb_w: Width of the embedding grid.
        emb_h: Height of the embedding grid.
        channels: Channels per grid cell.

    The layer maps ``emb_w * emb_h * channels`` features to ``8 * 8`` times
    that many.
    """

    def __init__(self, emb_w, emb_h, channels):
        super().__init__()
        in_features = emb_w * emb_h * channels
        out_features = (8 * 8) * in_features
        self.l1 = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        """Apply the linear expansion to ``x`` (last dim = input features)."""
        return self.l1(x)

class PixelShuffle(nn.Module):
    """Upsample a square grid of tokens with a pixel shuffle.

    The input ``(batch, tokens, channels)`` is interpreted as a row-major
    ``side x side`` grid (``side * side == tokens``). The shuffle trades
    channels for resolution: the output has ``tokens * r**2`` tokens of width
    ``channels // r**2``, where ``r`` is ``upscale_factor``.

    Args:
        upscale_factor: Factor by which each spatial side grows (default 2).
    """

    def __init__(self, upscale_factor=2):
        super().__init__()
        self.upscale_factor = upscale_factor

    def forward(self, x):
        """Return the upsampled token sequence.

        Args:
            x: Tensor of shape ``(batch, tokens, channels)``.

        Returns:
            Tensor of shape ``(batch, tokens * r**2, channels // r**2)``.

        Raises:
            ValueError: If ``x`` is not 3-D, ``tokens`` is not a perfect
                square, or ``channels`` is not divisible by ``r**2``.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected a (batch, tokens, channels) tensor, got {x.dim()} dims"
            )
        batch, tokens, channels = x.shape

        side = int(round(tokens ** 0.5))
        if side * side != tokens:
            raise ValueError(f"token count {tokens} is not a perfect square")

        ratio = self.upscale_factor ** 2
        if channels % ratio != 0:
            raise ValueError(
                f"channel count {channels} is not divisible by {ratio}"
            )

        # (batch, tokens, channels) -> (batch, channels, side, side)
        grid = x.permute(0, 2, 1).reshape(batch, channels, side, side)
        grid = F.pixel_shuffle(grid, self.upscale_factor)
        # (batch, channels / r^2, side*r, side*r) -> back to a token sequence
        return grid.flatten(2).permute(0, 2, 1)

class Generator(nn.Module):
    """Three-stage transformer generator conditioned on an embedding.

    Pipeline: project the noise (``MLP``) and the conditioning embedding
    (``Embedding``), concatenate them, run 5 + 4 + 2 encoder blocks with a
    ``PixelShuffle`` between stages, and finish with a linear layer.

    Args:
        heads: Attention heads per encoder block (default 8).
        p: Dropout probability per encoder block (default 0.0).
        fwd_exp: MLP expansion factor per encoder block (default 4).

    NOTE(review): the forward pass still cannot run end to end as sized here:
      * for 2-D inputs the concatenation has width 64*1024 + 64*384, not the
        1024*8*8 the stage-1 blocks are built for;
      * ``Attention`` reshapes its inputs as ``(batch, length, D)``, so the
        activations need a token axis. The stage sizes read like
        (tokens x channels) = (8*8 x 1024), (16*16 x 256), (32*32 x 64);
        confirm the intended layout;
      * with ``D = 1024*8*8`` every block holds ``D x D`` weight matrices.
    """

    def __init__(self, heads=8, p=0.0, fwd_exp=4):
        super().__init__()
        self.mlp = MLP(32, 32, 1)
        self.emb = Embedding(384, 1, 1)

        # stage 1
        self.s1_enc = nn.ModuleList([
            EncoderBlock(1024*8*8, heads, p, fwd_exp)
            for _ in range(5)
        ])

        # stage 2
        self.s2_pix_shuffle = PixelShuffle()
        self.s2_enc = nn.ModuleList([
            EncoderBlock(256*16*16, heads, p, fwd_exp)
            for _ in range(4)
        ])

        # stage 3
        self.s3_pix_shuffle = PixelShuffle()
        self.s3_enc = nn.ModuleList([
            EncoderBlock(64*32*32, heads, p, fwd_exp)
            for _ in range(2)
        ])

        # stage 4
        self.linear = nn.Linear(32*32*64, 32*32*3)

    def forward(self, noise, embedding):
        """Generate a sample from ``noise`` conditioned on ``embedding``.

        Args:
            noise: Input to ``self.mlp`` (last dim 32*32*1).
            embedding: Input to ``self.emb`` (last dim 384).

        Returns:
            Output of the final linear layer.
        """
        x = self.mlp(noise)
        embedding = self.emb(embedding)
        x = torch.cat([x, embedding], 1)
        # Encoder blocks are used as self-attention: Q = K = V = x, no mask.
        for layer in self.s1_enc:
            x = layer(x, x, x, None)

        x = self.s2_pix_shuffle(x)
        for layer in self.s2_enc:
            x = layer(x, x, x, None)

        # Was `x - self.s3_pix_shuffle(x)`, which discarded the result.
        x = self.s3_pix_shuffle(x)
        for layer in self.s3_enc:
            x = layer(x, x, x, None)

        text = self.linear(x)

        return text

class Discriminator(nn.Module):
    """Transformer discriminator / classifier over a flattened text input.

    The input is projected to ``8*8 + 1`` tokens of width 384, passed through
    seven self-attention encoder blocks, and the first token is classified.

    Args:
        max_len_word: The input has ``max_len_word * 256`` features.
        out_num: Number of output classes.
        heads: Attention heads per encoder block (default 8).
        p: Dropout probability per encoder block (default 0.0).
        fwd_exp: MLP expansion factor per encoder block (default 4).

    NOTE(review): the original head had ``in_features = 1*384`` while the
    encoder width was written as ``(8*8+1)*284``. This version reads that as
    65 tokens of width 384 with token 0 feeding the head; confirm against
    the intended design.
    """

    def __init__(self, max_len_word, out_num, heads=8, p=0.0, fwd_exp=4):
        super().__init__()
        self.num_tokens = 8*8 + 1
        self.token_dim = 384

        self.l1 = nn.Linear(max_len_word*256, self.num_tokens*self.token_dim)
        self.s2_enc = nn.ModuleList([
            EncoderBlock(self.token_dim, heads, p, fwd_exp)
            for _ in range(7)
        ])

        self.classification_head = nn.Linear(1*self.token_dim, out_num)

    def forward(self, text):
        """Classify ``text`` and expose the encoded tokens.

        Args:
            text: Tensor of shape ``(batch, max_len_word * 256)``.

        Returns:
            Tuple ``(pred, embedding)``: ``pred`` holds class probabilities of
            shape ``(batch, out_num)``; ``embedding`` holds the encoded tokens
            of shape ``(batch, 65, 384)``.
        """
        x = self.l1(text)
        # Split the flat projection into tokens so attention has a sequence
        # axis and the head sees a single 384-wide token.
        x = x.reshape(x.shape[0], self.num_tokens, self.token_dim)
        for layer in self.s2_enc:
            x = layer(x, x, x, None)

        logits = self.classification_head(x[:, 0])
        # Explicit dim: the implicit-dim form of softmax is deprecated.
        pred = F.softmax(logits, dim=-1)
        embedding = x
        return pred, embedding

In [None]:
# Embedding (classifier) training
def train_classifier(trainloader, D, D_optimizer, loss_func, device):
    """Train ``D`` for one pass over ``trainloader`` using all-ones targets.

    Args:
        trainloader: Iterable of input batches (tensors).
        D: Model whose call returns a sequence; element 0 is the prediction.
        D_optimizer: Optimizer over ``D``'s parameters.
        loss_func: Callable ``loss_func(prediction, target)`` returning a
            scalar tensor.
        device: Device the batches and targets are moved to.

    Returns:
        Mean loss per batch as a float; ``0.0`` if the loader has no batches.
    """
    # set train mode
    D.train()

    total_loss = 0.0
    num_batches = 0

    for x in trainloader:
        # one batch of real data
        x = x.to(device)
        # every sample is real, so the target is all ones
        y_real = torch.ones(x.size(0), 1, device=device).float()

        D_optimizer.zero_grad()

        # D loss on real data
        d_real = D(x)[0]
        d_real_loss = loss_func(d_real.float(), y_real)
        d_real_loss.backward()
        D_optimizer.step()

        # Was `d_loss.item()`, an undefined name in this function.
        total_loss += d_real_loss.item()
        num_batches += 1

    # Count batches here so an empty loader does not divide by zero.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

In [None]:
# Embedding (classifier) pre-training
max_len_word = 32

# Adam lr and betas
learning_rate = 1e-4
betas = (0.5, 0.999)

batch_size = 10

device = torch.device('cpu')
# Placeholder: the string literal below ("training data") must be replaced by
# the real NumPy array before this cell can run. The deprecated `Variable`
# wrapper (never imported in this file) is dropped; tensors work directly.
train_value = torch.from_numpy('''训练数据''')
trainloader = torch.utils.data.DataLoader(train_value, batch_size=batch_size, shuffle=True)
bceloss = nn.BCELoss().to(device).double()

# NOTE(review): D is built with out_num=2 while train_classifier builds
# targets of shape (batch, 1); nn.BCELoss expects matching shapes. Confirm
# the intended head size / loss.
D = Discriminator(max_len_word, 2).to(device).double()
# D optimizer, use Adam or SGD (`optim` was never imported; use torch.optim)
D_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate, betas=betas)

n_epochs = 1800

# Mean classifier loss of every epoch.
d_loss_hist = []
for epoch in range(n_epochs):
    # `loss_func` was undefined at this level; the loss built above is `bceloss`.
    d_loss_hist.append(train_classifier(trainloader, D, D_optimizer, bceloss, device))

In [None]:
# GAN training
def train(trainloader, G, D, G_optimizer, D_optimizer, loss_func, device):
    """Run one adversarial training pass over ``trainloader``.

    For every batch the discriminator is updated on real data (target 1) and
    generated data (target 0), then the generator is updated so that the
    discriminator scores its output as real.

    Args:
        trainloader: Iterable of real-data batches (tensors).
        G: Generator; called with no arguments to produce a fake batch.
        D: Discriminator whose call returns a sequence; element 0 is the
            prediction.
        G_optimizer: Optimizer over ``G``'s parameters.
        D_optimizer: Optimizer over ``D``'s parameters.
        loss_func: Callable ``loss_func(prediction, target)`` returning a
            scalar tensor.
        device: Device the batches and targets are moved to.

    Returns:
        Tuple ``(mean D loss, mean G loss)`` per batch; ``(0.0, 0.0)`` if the
        loader has no batches.

    NOTE(review): ``G()`` is invoked without inputs, whereas
    ``Generator.forward`` in this file takes ``(noise, embedding)``; confirm
    how the generator inputs are meant to be supplied.
    """
    # set train mode
    D.train()
    G.train()

    d_total = 0.0
    g_total = 0.0
    num_batches = 0

    for x in trainloader:
        # one batch of real data
        x = x.to(device)
        # real label and fake label
        y_real = torch.ones(x.size(0), 1, device=device).float()
        y_fake = torch.zeros(x.size(0), 1, device=device).float()

        # ---- update D network ----
        D_optimizer.zero_grad()

        # D real loss from real data
        d_real_loss = loss_func(D(x)[0].float(), y_real)

        # D fake loss from data generated by G; detached because the D step
        # does not need gradients flowing back into G.
        g_z = G().detach()
        d_fake_loss = loss_func(D(g_z)[0].float(), y_fake)

        # D backward and step
        d_loss = d_real_loss + d_fake_loss
        d_loss.backward()
        D_optimizer.step()

        # ---- update G network ----
        G_optimizer.zero_grad()

        # G loss: G wants D to label its samples as real
        g_z = G()
        g_loss = loss_func(D(g_z)[0].float(), y_real)

        # G backward and step
        g_loss.backward()
        G_optimizer.step()

        d_total += d_loss.item()
        g_total += g_loss.item()
        num_batches += 1

    # Count batches here so an empty loader does not divide by zero.
    if num_batches == 0:
        return 0.0, 0.0
    return d_total / num_batches, g_total / num_batches

In [None]:
# cTransGAN training
max_len_word = 32

# Adam lr and betas
learning_rate = 1e-4
betas = (0.5, 0.999)

batch_size = 10

device = torch.device('cpu')
# Placeholder: the string literal below ("training data") must be replaced by
# the real NumPy array before this cell can run. The deprecated `Variable`
# wrapper (never imported in this file) is dropped; tensors work directly.
train_value = torch.from_numpy('''训练数据''')
trainloader = torch.utils.data.DataLoader(train_value, batch_size=batch_size, shuffle=True)
bceloss = nn.BCELoss().to(device).double()

# G and D model
G = Generator().to(device).double()
# NOTE(review): D is built with out_num=2 while train builds targets of
# shape (batch, 1); nn.BCELoss expects matching shapes. Confirm the intended
# head size / loss.
D = Discriminator(max_len_word, 2).to(device).double()
# G and D optimizer, use Adam or SGD (`optim` was never imported; use torch.optim)
G_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate, betas=betas)
D_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate, betas=betas)

n_epochs = 1800

# Mean discriminator / generator loss of every epoch.
d_loss_hist = []
g_loss_hist = []
for epoch in range(n_epochs):
    # `loss_func` and `z_dim` were undefined at this level, and `train` takes
    # seven arguments; pass the loss built above and drop the extra argument.
    d_loss, g_loss = train(trainloader, G, D, G_optimizer, D_optimizer, bceloss, device)
    d_loss_hist.append(d_loss)
    g_loss_hist.append(g_loss)