### 在mnist数据集下的条件GAN
- 条件GAN：在生成器和判别器的输入中都加入条件信息 y（这里是类别标签），使模型学习条件分布 $p(x \mid y)$，从而可以按指定的标签生成样本。

- 实现

本例使用mnist数据集，在GAN的基础上将标签y进行one-hot编码，之后和数据样本进行**连接**作为输入：判别器的输入为样本（真实样本或生成样本），每个样本后面加上对应标签y的one-hot编码；生成器的输入为随机采样的噪声向量，每个向量后同样附加标签y的one-hot编码。以这种方式实现上述的条件概率分布。

- 改进

CGAN_mnist-02.ipynb 先将y和原始输入分别映射成高维向量A、B，再将A、B进行连接，以此实现条件GAN。

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import torch.utils.data
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelBinarizer
import imageio

In [None]:
# Working directory used as the root for the results folder created later.
project_root = os.path.realpath('.')
print(project_root)

no_cuda = False
cuda_available = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
BATCH_SIZE = 64
EPOCH = 100
SEED = 8
# SEED used to be declared but never applied; seed torch's RNGs so weight
# initialisation, noise sampling and data shuffling are repeatable.
torch.manual_seed(SEED)

# Worker / pinned-memory options only help when batches are copied to a GPU.
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda_available else {}
# ToTensor() yields pixel values in [0, 1].
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./MNIST_data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

# Evaluation data does not need a random order.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./MNIST_data', train=False, transform=transforms.ToTensor()),
    batch_size=BATCH_SIZE, shuffle=False, **kwargs)

In [None]:
n_classes = 10  # number of label classes (MNIST digits 0-9)
g_input_size = 100 + n_classes  # generator input: 100-d noise vector + one-hot label
g_output_size = 784  # generator output: flattened 28x28 image

d_input_size = 784 + n_classes  # discriminator input: flattened 28x28 image + one-hot label

d_output_size = 1  # single probability: 'real' vs. 'fake'


d_learning_rate = 2e-4  # Adam learning rate for the discriminator
g_learning_rate = 2e-4  # Adam learning rate for the generator
optim_betas = (0.9, 0.999)  # Adam (beta1, beta2)

print_interval = 200  # currently not read by train()

d_steps = 1  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator
g_steps = 1  # generator updates per batch

In [None]:
class Generator(nn.Module):
    """Fully connected conditional generator.

    Maps each (noise + one-hot label) input vector to a flattened image
    vector through three LeakyReLU hidden layers (256, 512, 1024 units).

    Args:
        input_size: length of each input vector; defaults to the
            module-level ``g_input_size`` (noise dim + ``n_classes``).
        output_size: length of each generated vector; defaults to the
            module-level ``g_output_size`` (a flattened 28x28 image).
    """

    def __init__(self, input_size=None, output_size=None):
        super().__init__()
        self.input_size = g_input_size if input_size is None else input_size
        self.output_size = g_output_size if output_size is None else output_size
        self.model = nn.Sequential(
            nn.Linear(self.input_size, 256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),

            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),

            nn.Linear(1024, self.output_size),
            # Real images come from transforms.ToTensor(), i.e. pixels in
            # [0, 1]. A Sigmoid keeps generated pixels in that same range;
            # Tanh would allow negative values that real data never has.
            nn.Sigmoid(),
        )

        # Xavier-normal init for every linear layer. The network has no
        # normalisation layers, so nothing else needs custom initialisation.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x):
        """Return generated samples of shape (batch, output_size).

        ``x`` is reshaped to (batch, input_size), so each sample must hold
        exactly ``input_size`` elements.
        """
        x = x.view(x.size(0), self.input_size)
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """Fully connected conditional discriminator.

    Scores each (flattened image + one-hot label) vector with the
    probability that it is a real sample. There are three LeakyReLU hidden
    layers (1024, 512, 256 units), each followed by Dropout(0.3).

    Args:
        input_size: length of each input vector; defaults to the
            module-level ``d_input_size`` (784 pixels + ``n_classes``).
    """

    def __init__(self, input_size=None):
        super().__init__()
        self.input_size = d_input_size if input_size is None else input_size
        self.model = nn.Sequential(
            nn.Linear(self.input_size, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3),

            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

        # Xavier-normal init for every linear layer. The network has no
        # normalisation layers, so nothing else needs custom initialisation.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)

    def forward(self, x):
        """Return real-vs-fake probabilities of shape (batch, 1).

        ``x`` is reshaped to (batch, input_size), so each sample must hold
        exactly ``input_size`` elements.
        """
        x = x.view(x.size(0), self.input_size)
        # The final Linear(256, 1) + Sigmoid already yields (batch, 1).
        return self.model(x)

In [None]:

# Build both networks on the selected device (discriminator first, then
# generator) and print their layer structure.
D, G = Discriminator().to(device), Generator().to(device)
for network in (D, G):
    print(network)

In [None]:
# Binary cross-entropy on the discriminator's sigmoid output
# (target 1 = real, 0 = fake).
criterion = nn.BCELoss()
# Use the learning rates / betas declared with the other hyper-parameters
# rather than re-hard-coding them, so editing them there takes effect.
lr = d_learning_rate  # name kept for backward compatibility
d_optimizer = torch.optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)
g_optimizer = torch.optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)

In [None]:
# One-hot encoder fitted on the class labels 0 .. n_classes-1
# (LabelBinarizer.fit returns the fitted encoder itself).
lb = LabelBinarizer().fit(list(range(n_classes)))
   #将标签进行one-hot编码
def to_categrical(y: torch.FloatTensor):
    y_n = y.numpy()
    
    y_one_hot = lb.transform(y_n)
    floatTensor = torch.FloatTensor(y_one_hot)
    return floatTensor

# Concatenate samples with their one-hot labels to build the conditional input.
def concanate_data_label(x, y):
    """Append the one-hot encoding of labels ``y`` to the rows of ``x``.

    Args:
        x: 2-D tensor (batch, features), e.g. flattened images, generator
            outputs or noise vectors. It may live on any device.
        y: class labels, one per row of ``x`` (encoded via ``to_categrical``).

    Returns:
        ``x`` and the one-hot labels concatenated along dim 1, moved to
        ``device``.
    """
    # Put the one-hot labels on x's device first. Generator outputs already
    # live on ``device`` while the labels come from the CPU data loader, and
    # torch.cat cannot mix tensors from different devices.
    y_one_hot = to_categrical(y).to(x.device)
    return torch.cat((x, y_one_hot), 1).to(device)


D_losses = []  # per-batch discriminator losses of the current epoch
G_losses = []  # per-batch generator losses of the current epoch


def train(epoch):
    """Run one training epoch of the conditional GAN over ``train_loader``.

    For every batch:
    - The discriminator ``D`` is updated ``d_steps`` times on real
      (image, label) pairs and on generator samples conditioned on the same
      labels.
    - The generator ``G`` is updated ``g_steps`` times so that ``D``
      classifies its samples as real.

    Per-batch losses (as Python floats) are collected in the module-level
    ``D_losses`` / ``G_losses`` lists, which are cleared at the start of each
    call. Their epoch means are printed at the end.

    Args:
        epoch: 1-based epoch index, used only in the progress message.
    """
    D.train()
    G.train()

    D_losses.clear()
    G_losses.clear()
    noise_size = g_input_size - n_classes
    for data, label in train_loader:
        batch_size = data.size(0)
        real_targets = torch.ones(batch_size, 1, device=device)  # ones = real
        fake_targets = torch.zeros(batch_size, 1, device=device)  # zeros = fake

        data_flatten = data.view(batch_size, d_input_size - n_classes)
        real_data = concanate_data_label(data_flatten, label)

        for _ in range(d_steps):
            # 1. Train D on real + fake.
            # Reset D's gradients on every step: zeroing only once per batch
            # would accumulate gradients across steps whenever d_steps > 1.
            D.zero_grad()

            # 1A: real samples, conditioned on their true labels.
            d_real_error = criterion(D(real_data), real_targets)
            d_real_error.backward()

            # 1B: generated samples, conditioned on the same labels.
            gen_input = concanate_data_label(torch.randn(batch_size, noise_size), label)
            d_fake_data = G(gen_input).detach()  # detach: no gradient flows into G here
            d_fake_decision = D(concanate_data_label(d_fake_data, label))
            d_fake_error = criterion(d_fake_decision, fake_targets)
            d_fake_error.backward()
            d_optimizer.step()  # updates only D's parameters

            # .item() stores a plain float, so np.mean works even on CUDA.
            D_losses.append((d_real_error + d_fake_error).item())

        for _ in range(g_steps):
            # 2. Train G on D's response (D's parameters are not stepped here).
            G.zero_grad()
            gen_input = concanate_data_label(torch.randn(batch_size, noise_size), label)
            g_fake_data = G(gen_input)
            dg_fake_decision = D(concanate_data_label(g_fake_data, label))
            # G wants D to label its samples as real.
            g_error = criterion(dg_fake_decision, real_targets)
            g_error.backward()
            g_optimizer.step()  # updates only G's parameters

            G_losses.append(g_error.item())

    print('[%d/%d]: loss_D: %.3f, loss_G: %.3f' % (
        epoch, EPOCH, np.mean(D_losses), np.mean(G_losses)))

In [None]:
def test(epoch):
    """Evaluation hook (not implemented yet): only switches ``G`` to eval mode.

    ``epoch`` is accepted for symmetry with ``train`` but is not used.
    """
    G.train(False)  # same effect as G.eval()
    

In [None]:
def get_fixed_sample():
    """Build 100 conditional generator inputs for a 10x10 preview grid.

    Row ``i`` holds Gaussian noise followed by the one-hot encoding of label
    ``i % n_classes``. When the 100 outputs are laid out 10 per row, column
    ``j`` therefore shows label ``j`` and each row shows labels 0-9.

    Returns:
        Tensor of shape (100, g_input_size) on ``device``.
    """
    n_samples = 100
    # Draw noise for the full input width, then overwrite the trailing
    # n_classes columns with the one-hot labels.
    sample = torch.randn(n_samples, g_input_size)
    labels = torch.arange(n_samples) % n_classes
    sample[:, g_input_size - n_classes:] = to_categrical(labels)
    return sample.to(device)


# print(fixed_sample.shape)

In [None]:
# Directory for the per-epoch preview images (trailing '/' is relied on by
# the f-string paths below).
des_path = os.path.join(project_root, 'cgan_results/')
os.makedirs(des_path, exist_ok=True)

import itertools
import math
from IPython import display

size_figure_grid = 10
fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(8, 8))  # 10 rows x 10 columns
for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
    ax[i, j].get_xaxis().set_visible(False)
    ax[i, j].get_yaxis().set_visible(False)

# Draw the preview inputs once, before training. Every epoch's grid then
# renders the same latent points, so the saved frames show how G evolves.
# Redrawing inside the loop produced fresh noise each epoch.
fixed_sample = get_fixed_sample()

for epoch in range(1, EPOCH + 1):
    display.clear_output(wait=True)
    train(epoch)
    # test(epoch)

    G.eval()  # sampling only; train() puts G back into training mode
    with torch.no_grad():
        sample = G(fixed_sample).cpu()

    save_image(sample.view(100, 1, 28, 28),
               f'{des_path}epoch_{epoch}.png', nrow=10)

    for k in range(100):
        i, j = divmod(k, size_figure_grid)
        ax[i, j].cla()
        ax[i, j].imshow(sample[k].numpy().reshape(28, 28), cmap='Greys')

    display.display(fig)

In [None]:
# sample_old = G(fixed_sample).cpu()
# print(sample_old.shape)
# sample_new = sample_old[:, 0:sample_old.shape[1]-10]
# print(sample_new.shape)
# v = sample_new.view(sample.size(0), 1, 28, 28)
# v.shape
# ax[1,1].imshow(sample_new[1,:].data.cpu().numpy().reshape(28, 28),cmap='Greys')

In [None]:
# Assemble the per-epoch preview images written by the training loop into an
# animated GIF. Instead of a hard-coded frame count, cover every configured
# epoch and skip frames that were never written (e.g. training stopped early).
frame_paths = [f'{des_path}epoch_{i}.png' for i in range(1, EPOCH + 1)]
images = [imageio.imread(path) for path in frame_paths if os.path.exists(path)]
if images:
    imageio.mimsave(f'{des_path}cgan.gif', images, fps=5)
else:
    print(f'No epoch images found in {des_path}; GIF not written.')