In [1]:
import os
from tqdm import tnrange, tqdm_notebook, tqdm
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.utils import save_image, make_grid
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from tensorboardX import SummaryWriter

In [2]:
from matplotlib import rcParams
rcParams['figure.figsize'] = (12, 8)

%matplotlib inline

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
batch_size = 128
num_epochs = 1000

z_dimension = 100
num_feature_x1 = 192
num_feature_x2 = 192

In [5]:
device_ids = [0, 1]

In [6]:
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5), (.5, .5, .5))
])

dataset = datasets.ImageFolder('./datas/faces', transform=img_transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

In [7]:
def normal_init(m, mean, std):
    if isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Conv2d):
        m.weight.data.normal_(mean, std)
        m.bias.data.zero_()

In [8]:
class Discriminator(nn.Module): # b 3 96 96
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1),
            nn.LeakyReLU(.2, True),
            nn.AvgPool2d(2, 2), 
        ) # b 32 48 48
        
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 3, padding=1),
            nn.LeakyReLU(.2, True),
            nn.AvgPool2d(2, 2),
        ) # b 64 24 24
        
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.LeakyReLU(.2, True),
            nn.AvgPool2d(4, 4),
        ) # b 64 6 6
        
        self.fc = nn.Sequential(
            nn.Linear(64 * 6 * 6, 1024),
            nn.LeakyReLU(.2, True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        ) # b 1
        
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)
    
    def forward(self, x): # b 1 28 28
        out = self.conv1(x)
        out = self.conv2(out)
        out = self.conv3(out)
        
        out = out.view(x.size(0), -1)
        return self.fc(out)

In [9]:
class Generator(nn.Module):
    def __init__(self, inp_dim, num_feature_x1, num_feature_x2):
        super(Generator, self).__init__()
        
        self.num_feature_x1 = num_feature_x1
        self.num_feature_x2 = num_feature_x2
        
        self.fc = nn.Sequential(
            nn.Linear(inp_dim, num_feature_x1 * num_feature_x2),
            # todo
            #nn.Sig
        ) # b h*w
        self.br = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.LeakyReLU(.2, True),
        ) # b 1 192 192
        
        self.downsample1 = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(.2, True),
        ) # b 64 192 192
        
        self.downsample2 = nn.Sequential(
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(.2, True),
        ) # b 32 192 192
        
        self.downsample3 = nn.Sequential(
            nn.Conv2d(32, 3, 3, padding=1, stride=2),
            nn.Tanh(),
        ) # b 3 96 96
        
    def weight_init(self, mean, std):
        for m in self._modules:
            normal_init(self._modules[m], mean, std)
            
    def forward(self, x):
        out = self.fc(x)
        out = out.view(x.size(0), 1, self.num_feature_x1, self.num_feature_x2)
        out = self.br(out)

        out = self.downsample1(out)
        out = self.downsample2(out)
        out = self.downsample3(out)
        return out

In [10]:
d = Discriminator()#.cuda(device_ids[0])
g = Generator(z_dimension, num_feature_x1, num_feature_x2)#.cuda(device_ids[0])

d.weight_init(0.0, 0.02)
g.weight_init(0.0, 0.02)

d = nn.DataParallel(d, device_ids=device_ids).to(device)
g = nn.DataParallel(g, device_ids=device_ids).to(device)


criterion = nn.BCELoss()

d_optimezer = optim.Adam(d.parameters(), lr=2e-4)
# d_optimezer = nn.DataParallel(d_optimezer, device_ids=device_ids)
g_optimezer = optim.Adam(g.parameters(), lr=2e-4)
# g_optimezer = nn.DataParallel(g_optimezer, device_ids=device_ids)

    There is an imbalance between your GPUs. You may want to exclude GPU 0 which
    has less than 75% of the memory or cores of GPU 1. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.


In [None]:
writer = SummaryWriter('./log/cnn_gan_faces_multip_gpu')

In [None]:
total_count = len(dataloader)
for epoch in tqdm_notebook(xrange(num_epochs)):
    
    d_loss_total = .0
    g_loss_total = .0
    _step = epoch * total_count
    for i, (img, _) in enumerate(dataloader):
        
        real_img = img.cuda()
        real_labels = torch.ones(img.size(0), 1).cuda()
        fake_labels = torch.zeros(img.size(0), 1).cuda()
        
        real_out = d(real_img)
        d_loss_real = criterion(real_out, real_labels)
        real_scores = real_out
        
        z = torch.randn(img.size(0), z_dimension).cuda()
        fake_img = g(z)
        fake_out = d(fake_img)
        d_loss_fake = criterion(fake_out, fake_labels)
        fake_scores = fake_out
        
        d_loss = d_loss_real + d_loss_fake
        d_optimezer.zero_grad()
        d_loss.backward()
        d_optimezer.step()
        
        z = torch.randn(img.size(0), z_dimension).cuda()
        fake_img = g(z)
        fake_out = d(fake_img)
        g_loss = criterion(fake_out, real_labels)
        
        g_optimezer.zero_grad()
        g_loss.backward()
        g_optimezer.step()
        
        d_loss_total += d_loss.item() * img.size(0)
        g_loss_total += g_loss.item() * img.size(0)
        
        step = _step + i + 1
        
        if (i + 1) % 100 == 0:
            writer.add_scalar('Discriminator Real Loss', d_loss_real.item(), step)
            writer.add_scalar('Discriminator Fake Loss', d_loss_fake.item(), step)
            writer.add_scalar('Discriminator Loss', d_loss.item(), step)
            writer.add_scalar('Generator Loss', g_loss.item(), step)
        
        
        if (i + 1) % 200 == 0:
            tqdm.write('Epoch [{}/{}], Step: {:6d}, d_loss: {:.6f}, g_loss: {:.6f}, real_scores: {:.6f}' \
', fake_scores: {:.6f}'.format(epoch+1, num_epochs, (i+1) * batch_size, d_loss, g_loss, real_scores.mean(), fake_scores.mean()))
    
    _d_loss_total = d_loss_total / (total_count * (epoch + 1))
    _g_loss_total = g_loss_total / (total_count * (epoch + 1))
    
    setp = (epoch + 1) * total_count
    writer.add_scalar('Discriminator Total Loss', _d_loss_total, step)
    writer.add_scalar('Generator Total Loss', _g_loss_total, step)
    tqdm.write("Finish Epoch [{}/{}], D Loss: {:.6f}, G Loss: {:.6f}".format(epoch+1, 
                                                                             num_epochs, 
                                                                             _d_loss_total,
                                                                             _g_loss_total, ))
    if epoch == 0:
        real_images = real_img.view(-1, 3, 96, 96).cpu().data
        save_image(real_images, './cnn_gan_faces_multip_gpu/real_images.png')

    fake_images = fake_img.view(-1, 3, 96, 96).cpu().data
    save_image(fake_images, './cnn_gan_faces_multip_gpu/fake_images-{}.png'.format(epoch+1))

Epoch [1/1000], Step:  25600, d_loss: 0.239176, g_loss: 4.410129, real_scores: 0.907127, fake_scores: 0.049808
Epoch [1/1000], Step:  51200, d_loss: 1.136935, g_loss: 1.091598, real_scores: 0.541926, fake_scores: 0.350372
Finish Epoch [1/1000], D Loss: 53.956662, G Loss: 483.101966
Epoch [2/1000], Step:  25600, d_loss: 0.165761, g_loss: 3.087519, real_scores: 0.964212, fake_scores: 0.066160
Epoch [2/1000], Step:  51200, d_loss: 0.201467, g_loss: 10.828316, real_scores: 0.978207, fake_scores: 0.037622
Finish Epoch [2/1000], D Loss: 24.980865, G Loss: 219.678303
Epoch [3/1000], Step:  25600, d_loss: 0.383713, g_loss: 3.331735, real_scores: 0.844418, fake_scores: 0.132665
Epoch [3/1000], Step:  51200, d_loss: 0.588216, g_loss: 2.137198, real_scores: 0.803894, fake_scores: 0.199470
Finish Epoch [3/1000], D Loss: 18.686692, G Loss: 151.140767
Epoch [4/1000], Step:  25600, d_loss: 1.256347, g_loss: 2.006586, real_scores: 0.565093, fake_scores: 0.212506
Epoch [4/1000], Step:  51200, d_loss: 1

In [None]:
writer.close()

In [None]:
torch.save(d.state_dict(), './ser/faces_discriminator_m_gpu.pkl')
torch.save(g.state_dict(), './ser/faces_generator_m_gpu.pkl')

In [None]:
z = torch.randn(4, z_dimension).to(device)
images = g(z)
# save_image(images, 'xx.png')
plt.imshow(Image.fromarray(make_grid(images).mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()))
plt.show()