In [105]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.tensorboard import SummaryWriter
# Run counter for TensorBoard: the training cell increments it and logs to
# 'cgan/<log>'. Start after the highest existing numbered run directory so that a
# kernel restart does not make a new run write into an old run's folder (which
# TensorBoard would then display as a single, interleaved run).
log = max(
    (int(p.name) for p in Path('cgan').iterdir() if p.is_dir() and p.name.isdecimal()),
    default=0,
) if Path('cgan').is_dir() else 0

In [None]:
class Generator(nn.Module):
    """Conditional MLP generator mapping (noise, class label) to a flattened image.

    Two branches are fused: the class label is looked up in an embedding table
    and projected to 128 features (``fc1``); the noise vector is projected to
    128 features (``fc2``). Their concatenation (256 features) is decoded by
    ``fc3`` into 784 = 28 * 28 pixel values.

    Args:
        latent: length of the noise vector passed to ``forward``.
        n_class: number of classes; labels must be integers in ``[0, n_class)``.

    The output layer is a sigmoid, so every pixel lies in [0, 1] -- the same
    range as the ``ToTensor()``-converted MNIST images used as "real" samples
    in this notebook. A tanh output spans [-1, 1]; any negative pixel is a
    value real images never contain, which the discriminator can exploit.
    """

    def __init__(self, latent, n_class):
        super(Generator, self).__init__()
        # One learnable n_class-dimensional vector per class label.
        self.label = nn.Embedding(n_class, n_class)

        # Label branch: embedding -> 128 features.
        self.fc1 = nn.Sequential(
            nn.Linear(n_class, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Noise branch: latent -> 256 -> 128 features.
        self.fc2 = nn.Sequential(
            nn.Linear(latent, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Decoder over the concatenated 128 + 128 = 256 features -> 784 pixels.
        self.fc3 = nn.Sequential(
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 784),
            # Pixels in [0, 1] to match ToTensor() data (previously Tanh: [-1, 1]).
            nn.Sigmoid(),
        )

    def forward(self, x, label):
        """Generate images for noise ``x`` conditioned on ``label``.

        Args:
            x: noise of shape (batch, latent).
            label: integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, 784) with values in [0, 1].

        Note: the BatchNorm1d layers cannot compute batch statistics from a
        single sample in train mode; call ``eval()`` before sampling one image.
        """
        out = torch.cat((self.fc1(self.label(label)), self.fc2(x)), -1)
        return self.fc3(out)

class Discrimitor(nn.Module):
    """Conditional MLP discriminator: (flattened image, class label) -> P(real).

    The label embedding and the image are each projected to 128 features, the
    two feature vectors are concatenated (label first), and a small head maps
    the 256 joint features to a single sigmoid probability.

    Args:
        n_class: number of classes; labels must be integers in ``[0, n_class)``.
    """

    def __init__(self, n_class):
        super().__init__()
        # One learnable n_class-dimensional vector per class label.
        self.label = nn.Embedding(n_class, n_class)

        # Label branch: embedding -> 128 features.
        self.fc1 = nn.Sequential(nn.Linear(n_class, 128), nn.LeakyReLU(0.2, inplace=True))

        # Image branch: 784 flattened pixels -> 128 features.
        self.fc2 = nn.Sequential(nn.Linear(784, 128), nn.LeakyReLU(0.2, inplace=True))

        # Joint head over the 128 + 128 = 256 concatenated features.
        head_layers = [
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]
        self.fc3 = nn.Sequential(*head_layers)

    def forward(self, x, label):
        """Score images ``x`` of shape (batch, 784) given class ``label`` (batch,).

        Returns a (batch, 1) tensor of probabilities that each image is real.
        """
        label_features = self.fc1(self.label(label))
        image_features = self.fc2(x)
        joint = torch.cat([label_features, image_features], dim=-1)
        return self.fc3(joint)

In [138]:
# MNIST training split; ToTensor() yields float images in [0, 1].
# download=True is a no-op when the files already exist under ../dataset and lets
# the notebook run end-to-end on a fresh machine instead of failing here.
train_data = torchvision.datasets.MNIST(
    root='../dataset',
    download=True,
    train=True,
    transform=torchvision.transforms.ToTensor(),
)
train_iter = torch.utils.data.DataLoader(train_data, batch_size=128, shuffle=True)

In [139]:
# Hyper-parameters shared by the model, optimizer and training cells.
latent = 64    # length of the generator's noise vector
lr = 0.002     # RMSprop learning rate, used for both G and D
epochs = 100
n_class = 10   # number of label classes (MNIST digits 0-9)
seed = 0       # seeds torch's and numpy's global RNGs so runs are repeatable
torch.manual_seed(seed)
np.random.seed(seed)

In [140]:
# Build the conditional generator/discriminator pair, the loss, and one
# RMSprop optimizer per network.
G = Generator(latent=latent, n_class=n_class)
D = Discrimitor(n_class=n_class)
criterion = nn.BCELoss()  # D ends in a sigmoid, so it outputs probabilities
optimizer_G, optimizer_D = (torch.optim.RMSprop(net.parameters(), lr=lr) for net in (G, D))

In [None]:
# Fixed noise and class labels, reused after every epoch so progress is shown on
# the same 32 samples. Labels use n_class (not a hard-coded 10) and torch's RNG,
# which always returns int64 (numpy's default integer dtype is platform-dependent).
test = torch.normal(0, 1, (32, latent))
test_label = torch.randint(0, n_class, (32,))
test_label.reshape(4, 8)

In [None]:
# Train the conditional GAN. Each execution of this cell logs to a fresh
# TensorBoard directory 'cgan/<log>'.
log += 1
writer = SummaryWriter(log_dir='cgan/' + str(log))
log_every = 100  # iterations between loss prints / TensorBoard scalars
num = 0          # global step for the averaged loss scalars
try:
    for epoch in range(epochs):
        # A later cell may have left G in eval mode; its BatchNorm layers must
        # use batch statistics while training.
        G.train()
        D.train()
        avg_loss_g = 0.0
        avg_loss_d = 0.0
        for i, (real_img, label) in enumerate(train_iter):
            batch_size = real_img.shape[0]
            real_img = real_img.view(batch_size, -1)  # flatten to (batch, 784)

            # BCE targets: 1 = "real", 0 = "fake" (distinct from the class `label`).
            real_target = torch.ones(batch_size, 1)
            fake_target = torch.zeros(batch_size, 1)

            z = torch.normal(0, 1, (batch_size, latent))
            fake_img = G(z, label)

            # --- Train the discriminator ---
            # detach(): the D loss must not backpropagate into G.
            d_real_loss = criterion(D(real_img, label), real_target)
            d_fake_loss = criterion(D(fake_img.detach(), label), fake_target)
            d_loss = d_real_loss + d_fake_loss

            optimizer_D.zero_grad()
            d_loss.backward()
            optimizer_D.step()

            # --- Train the generator ---
            # Reuse fake_img: G's parameters have not changed since it was produced,
            # so a second forward pass would return the same images while updating
            # G's BatchNorm running statistics a second time.
            g_loss = criterion(D(fake_img, label), real_target)

            optimizer_G.zero_grad()
            g_loss.backward()
            optimizer_G.step()

            avg_loss_g += g_loss.item()
            avg_loss_d += d_loss.item()

            if (i + 1) % log_every == 0:
                mean_d = avg_loss_d / log_every
                mean_g = avg_loss_g / log_every
                print("Epoch:{}, Loss_D:{}, Loss_G:{}".format(epoch, mean_d, mean_g))
                writer.add_scalar('Loss_D', mean_d, num)
                writer.add_scalar('Loss_G', mean_g, num)
                avg_loss_g = 0.0
                avg_loss_d = 0.0
                num += 1

        # Snapshot the fixed preview batch. eval() makes G's BatchNorm use its
        # running statistics instead of updating them from this preview batch
        # (no_grad() alone does not stop running-statistic updates).
        G.eval()
        with torch.no_grad():
            show = torch.clamp(G(test, test_label), 0, 1).reshape(-1, 1, 28, 28)
        G.train()
        writer.add_images("Epoch-" + str(epoch), show, 0)
finally:
    # Close even if training is interrupted so buffered events are flushed.
    writer.close()

In [None]:
# Sample and display one generated image of a chosen digit.
digit = 4
was_training = G.training
# BatchNorm1d cannot compute batch statistics from a single sample in train mode,
# so sample with the running statistics.
G.eval()
with torch.no_grad():
    z = torch.normal(0, 1, (1, latent))
    image = G(z, torch.tensor([digit]))
# Restore the previous mode so re-running the training cell is unaffected.
G.train(was_training)

image = torch.clamp(image, 0, 1).reshape(28, 28)
fig, ax = plt.subplots()
ax.imshow(image.numpy(), cmap='gray', vmin=0, vmax=1)
ax.set_title("Generated digit: {}".format(digit))
ax.axis('off')
plt.show()