In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

import torch.utils.data
from torch import nn, optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelBinarizer
import imageio

In [None]:
project_root = os.path.realpath('.')
print(project_root)
os.chdir(project_root)  # effectively a no-op: realpath('.') is the current directory

no_cuda = False
cuda_available = not no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
BATCH_SIZE = 64
EPOCH = 100
SEED = 8

# Apply the seed (torch.manual_seed seeds the CPU and all CUDA generators) so
# weight init, noise sampling and shuffling are reproducible across runs.
torch.manual_seed(SEED)

# The generator ends in tanh (outputs in [-1, 1]), so scale real images from
# ToTensor's [0, 1] to [-1, 1] too; otherwise D can tell real from fake simply
# by the value range of the pixels.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda_available else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./MNIST_data', train=True, download=True,
                   transform=mnist_transform),
    batch_size=BATCH_SIZE, shuffle=True, **kwargs)

# Held-out split: a fixed (unshuffled) order is the conventional choice for evaluation.
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./MNIST_data', train=False, transform=mnist_transform),
    batch_size=BATCH_SIZE, shuffle=False, **kwargs)

In [None]:
n_classes = 10
# Model params
g_input_size = 100       # Dimension of the random noise vector fed to the generator
g_output_size = 28 * 28  # Size of a generated sample: a flattened 28x28 MNIST image (784)
d_input_size = 28 * 28   # Size of a discriminator input: a flattened 28x28 image (784), not the minibatch size
d_output_size = 1        # Single dimension for 'real' vs. 'fake'

# NOTE(review): the settings below are not used by any visible cell; the
# optimizer cell builds Adam with lr=0.0002 and betas=(0.5, 0.999) instead.
d_learning_rate = 2e-4
g_learning_rate = 2e-4
optim_betas = (0.9, 0.999)

print_interval = 200     # NOTE(review): unused; train() prints once per epoch.

d_steps = 1  # 'k' steps in the original GAN paper: discriminator updates per batch (can exceed g_steps)
g_steps = 1  # generator updates per batch

In [None]:
class Generator(nn.Module):
    """Conditional MLP generator: maps (noise, one-hot label) to a flattened image.

    The noise and the label are each projected to 256 features (Linear +
    BatchNorm1d + ReLU), concatenated, passed through two more hidden layers,
    and squashed by tanh so every output lies in [-1, 1].

    Args:
        noise_dim: Width of the noise input (default 100).
        n_classes: Width of the one-hot label input (default 10).
        output_dim: Width of the generated sample (default 784 = 28 * 28).
    """

    def __init__(self, noise_dim=100, n_classes=10, output_dim=784):
        super().__init__()
        self.fc1_1 = nn.Linear(noise_dim, 256)
        self.fc1_1_bn = nn.BatchNorm1d(256)
        self.fc1_2 = nn.Linear(n_classes, 256)
        self.fc1_2_bn = nn.BatchNorm1d(256)
        self.fc2 = nn.Linear(512, 512)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc3_bn = nn.BatchNorm1d(1024)
        self.fc4 = nn.Linear(1024, output_dim)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
            # The normalisation layers are BatchNorm1d; checking BatchNorm2d
            # (as before) never matched, leaving this branch dead.
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, input, label):
        """Generate samples.

        Args:
            input: Noise tensor of shape ``(batch, noise_dim)``.
            label: One-hot label tensor of shape ``(batch, n_classes)``.

        Returns:
            Tensor of shape ``(batch, output_dim)`` with values in [-1, 1].
        """
        x = F.relu(self.fc1_1_bn(self.fc1_1(input)))
        y = F.relu(self.fc1_2_bn(self.fc1_2(label)))
        x = torch.cat([x, y], 1)
        x = F.relu(self.fc2_bn(self.fc2(x)))
        x = F.relu(self.fc3_bn(self.fc3(x)))
        return torch.tanh(self.fc4(x))

In [None]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator: scores (image, one-hot label) pairs.

    The image and the label are each projected to 1024 features with
    LeakyReLU(0.2), concatenated, passed through two BatchNorm1d hidden layers,
    and reduced to a single sigmoid probability of the pair being real.

    Args:
        input_dim: Width of the flattened image input (default 784 = 28 * 28).
        n_classes: Width of the one-hot label input (default 10).
    """

    def __init__(self, input_dim=784, n_classes=10):
        super().__init__()
        self.fc1_1 = nn.Linear(input_dim, 1024)
        self.fc1_2 = nn.Linear(n_classes, 1024)
        self.fc2 = nn.Linear(2048, 512)
        self.fc2_bn = nn.BatchNorm1d(512)
        self.fc3 = nn.Linear(512, 256)
        self.fc3_bn = nn.BatchNorm1d(256)
        self.fc4 = nn.Linear(256, 1)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
            # The normalisation layers are BatchNorm1d; checking BatchNorm2d
            # (as before) never matched, leaving this branch dead.
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, input, label):
        """Score samples.

        Args:
            input: Flattened images of shape ``(batch, input_dim)``.
            label: One-hot labels of shape ``(batch, n_classes)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with probabilities in [0, 1]
            (suitable for ``nn.BCELoss``).
        """
        x = F.leaky_relu(self.fc1_1(input), 0.2)
        y = F.leaky_relu(self.fc1_2(label), 0.2)
        x = torch.cat([x, y], 1)
        x = F.leaky_relu(self.fc2_bn(self.fc2(x)), 0.2)
        x = F.leaky_relu(self.fc3_bn(self.fc3(x)), 0.2)
        return torch.sigmoid(self.fc4(x))

In [None]:

# Instantiate both networks on the selected device. D is built before G (the
# tuple is evaluated left to right), so random weight init consumes the RNG in
# the same order as before; then print each architecture on its own line.
D, G = Discriminator().to(device), Generator().to(device)
print(D, G, sep='\n')

In [None]:
# Binary cross-entropy on D's sigmoid output. Both networks use Adam with the
# same learning rate and beta1 = 0.5 (lower than Adam's default of 0.9).
criterion = nn.BCELoss()
lr = 2e-4
d_optimizer, g_optimizer = (
    torch.optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999))
    for net in (D, G)
)

In [None]:
# Scratch demo: one-hot encode a small batch of class ids with scatter_.
mini_batch = 5
y_ = torch.tensor([3, 2, 4, 1, 5], dtype=torch.long)
y_real_, y_fake_ = torch.ones(mini_batch), torch.zeros(mini_batch)
# scatter_ writes a 1 at column y_[i] of row i and returns the tensor itself.
y_label_ = torch.zeros(mini_batch, 10).scatter_(1, y_.unsqueeze(1), 1)
y_label_

In [None]:
lb = LabelBinarizer()
lb.fit(list(range(0, n_classes)))


def to_categrical(y: torch.Tensor) -> torch.Tensor:
    """One-hot encode integer class labels with the fitted ``lb`` binarizer.

    Args:
        y: 1-D tensor of integer class ids in ``0 .. n_classes-1``; may live on
            any device.

    Returns:
        A float32 tensor with one row per label (``n_classes`` columns when
        ``n_classes > 2``), on the same device as ``y``.
    """
    # .numpy() only works on CPU tensors without grad, so detach and move first.
    y_one_hot = lb.transform(y.detach().cpu().numpy())
    return torch.as_tensor(y_one_hot, dtype=torch.float32).to(y.device)


def concanate_data_label(x, y):
    """Concatenate samples with their one-hot labels to form the conditional input.

    Args:
        x: 2-D tensor of samples, one row per example.
        y: 1-D tensor of integer class ids, one per row of ``x``.

    Returns:
        ``torch.cat((x, one_hot(y)), 1)`` moved to the global ``device``.
    """
    # Both operands of torch.cat must be on the same device.
    y_one_hot = to_categrical(y).to(x.device)
    return torch.cat((x, y_one_hot), 1).to(device)

D_losses = []
G_losses = []


def _decay_learning_rate(optimizer, factor=10):
    """Divide the learning rate of every parameter group of ``optimizer`` by ``factor``."""
    for group in optimizer.param_groups:
        group['lr'] /= factor


def train(epoch):
    """Run one epoch of conditional-GAN training over ``train_loader``.

    For each batch, D is updated ``d_steps`` times on real samples (target 1)
    and generated samples (target 0). G is then updated ``g_steps`` times so
    that D labels its samples as real. Per-step losses are stored as Python
    floats in the module-level ``D_losses`` / ``G_losses`` lists, which are
    cleared at the start of each call. Their epoch means are printed at the end.

    Args:
        epoch: Epoch number, used for the learning-rate schedule and the
            progress message.
    """
    D.train()
    G.train()

    # Learning-rate decay: divide both learning rates by 10 when epoch + 1
    # reaches 30 and again at 40.
    # NOTE(review): with 1-based epochs and EPOCH=100 this fires at epochs 29
    # and 39, so most of training runs at lr/100 — confirm the schedule.
    if (epoch + 1) in (30, 40):
        _decay_learning_rate(g_optimizer)
        _decay_learning_rate(d_optimizer)
        print("learning rate change!")

    D_losses.clear()
    G_losses.clear()
    for batch_idx, (data, label) in enumerate(train_loader):
        batch_size = data.size(0)
        # Inputs must live on the same device as the models.
        data_flatten = data.view(batch_size, d_input_size).to(device)
        label_categorical = to_categrical(label).to(device)
        real_target = torch.ones(batch_size, 1, device=device)
        fake_target = torch.zeros(batch_size, 1, device=device)

        for d_index in range(d_steps):
            # 1. Train D on real + fake.
            # Reset D's gradients on every D step. They would otherwise
            # accumulate across d_steps, and they still hold gradients from
            # the previous G step's backward pass.
            D.zero_grad()

            # 1A: real samples should be classified as 1.
            d_real_decision = D(data_flatten, label_categorical)
            d_real_error = criterion(d_real_decision, real_target)
            d_real_error.backward()  # compute/store gradients, don't change params yet

            # 1B: generated samples should be classified as 0.
            d_gen_input = torch.randn(batch_size, g_input_size, device=device)
            # detach so this loss does not backpropagate into G
            d_fake_data = G(d_gen_input, label_categorical).detach()
            d_fake_decision = D(d_fake_data, label_categorical)
            d_fake_error = criterion(d_fake_decision, fake_target)
            d_fake_error.backward()
            d_optimizer.step()  # only updates D's parameters

            d_error = d_real_error + d_fake_error
            D_losses.append(d_error.item())

        for g_index in range(g_steps):
            # 2. Train G on D's response; D is not stepped here.
            G.zero_grad()

            gen_input = torch.randn(batch_size, g_input_size, device=device)
            g_fake_data = G(gen_input, label_categorical)
            dg_fake_decision = D(g_fake_data, label_categorical)
            # We want to fool D, so the target is "real".
            g_error = criterion(dg_fake_decision, real_target)

            g_error.backward()
            g_optimizer.step()  # only updates G's parameters

            G_losses.append(g_error.item())

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
        epoch, EPOCH, np.mean(D_losses), np.mean(G_losses)))

In [None]:
def test(epoch):
    """Evaluation hook (currently a stub).

    Only puts the generator into inference mode; ``epoch`` is accepted for
    symmetry with ``train`` but is not used.
    """
    G.train(False)  # same as G.eval()
    

In [None]:
def get_samples():
    '''Build a fixed 10x10 grid of generator inputs for visualisation.

    Returns ``(noise, labels)`` on ``device``:
      * ``noise``: ``(100, g_input_size)`` standard-normal noise.
      * ``labels``: ``(100, 10)`` one-hot rows. Sample ``i`` is conditioned on
        class ``i // n_classes``, so grid row ``r`` (samples ``10r .. 10r+9``)
        shows digit ``r``.
    '''
    sample = torch.randn(100, g_input_size)
    label = torch.zeros(100, 10)

    for i in range(100):
        # Exactly one 1 per row, in the column of that row's class. Previously
        # the value i // n_classes was written into column i % n_classes, which
        # gave all-zero and non-one-hot labels.
        label[i, i // n_classes] = 1

    return sample.to(device), label.to(device)

In [None]:
# Sanity check of the index -> grid-row mapping used for the 10x10 sample grid:
# sample k is drawn in row k // 10 (e.g. 11 // 10 == 1).
assert [k // 10 for k in (0, 9, 10, 11, 99)] == [0, 0, 1, 1, 9]

In [None]:
des_path = os.path.join(project_root, 'cgan_new_results/')
os.makedirs(des_path, exist_ok=True)  # no-op if the directory already exists

import math,  itertools
from IPython import display
# for plot: a 10x10 grid of axes, one per generated sample, without ticks
size_figure_grid = 10
fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(6, 6))
for i, j in itertools.product(range(size_figure_grid), range(size_figure_grid)):
    ax[i, j].get_xaxis().set_visible(False)
    ax[i, j].get_yaxis().set_visible(False)

# train
for epoch in range(1, EPOCH + 1):
    display.clear_output(wait=True)
    train(epoch)

    # Sample in eval mode. In train mode BatchNorm would normalise with this
    # batch's statistics and also update its running statistics. train()
    # switches G back to train mode at the start of the next epoch.
    G.eval()
    with torch.no_grad():
        sample, label = get_samples()
        sample = G(sample, label).cpu()
        # G ends in tanh, so map [-1, 1] -> [0, 1]; save_image would otherwise
        # clip every negative value to black.
        save_image((sample.view(100, 1, 28, 28) + 1) / 2,
                   f'{des_path}epoch_{epoch}.png', nrow=10)

        for k in range(100):
            i, j = divmod(k, 10)  # row = k // 10, column = k % 10
            ax[i, j].cla()
            ax[i, j].imshow(sample[k, :].numpy().reshape(28, 28), cmap='Greys')

        display.display(fig)

In [None]:

# Stitch the per-epoch sample grids (epochs 1..EPOCH, in order) into an animated GIF.
images = [imageio.imread(f'{des_path}epoch_{i}.png') for i in range(1, EPOCH + 1)]
imageio.mimsave(f'{des_path}cgan_new.gif', images, fps=5)
