In [1]:
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torch.autograd import Variable
from tqdm import tqdm
from sklearn.preprocessing import OneHotEncoder

In [2]:
# GPU Device
gpu_id = '2'
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
use_cuda = torch.cuda.is_available()
print("GPU device %s:" %(gpu_id), use_cuda)

GPU device 2: True


In [3]:
epochs = 200

In [4]:
# Normalize data with mean=0.5, std=1.0
mnist_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [5]:
download_root = './MNIST_DATASET'

train_dataset = MNIST(download_root, transform=mnist_transform, train=True, download=True)
valid_dataset = MNIST(download_root, transform=mnist_transform, train=False, download=True)
test_dataset = MNIST(download_root, transform=mnist_transform, train=False, download=True)

In [6]:
batch_size = 64

train_loader = DataLoader(dataset=train_dataset, 
                         batch_size=batch_size,
                         shuffle=True,
                         drop_last=True)

valid_loader = DataLoader(dataset=test_dataset, 
                         batch_size=batch_size,
                         shuffle=True)

test_loader = DataLoader(dataset=test_dataset, 
                         batch_size=batch_size,
                         shuffle=True)

In [7]:
onehot = OneHotEncoder(10)

In [8]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.G_encoder = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=5, stride=2, padding=2, bias=False),
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=False),
            
            nn.Conv2d(32, 64, kernel_size=5, stride=2, padding=2, bias=False),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=False),
            
            nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2, bias=False),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=False)
        )
        
        
        self.G_decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=5, stride=2, padding=2, bias=False),
            nn.InstanceNorm2d(64),
            nn.LeakyReLU(0.2, inplace=False),
            
            nn.ConvTranspose2d(64, 32, kernel_size=5, stride=2, padding=2, bias=False),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.2, inplace=False),
            
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=0, bias=False),
            nn.InstanceNorm2d(1),
            nn.Tanh()
        )
        
        self.label = nn.Sequential(
            nn.Linear(10, 2048),
            nn.LeakyReLU(0.2),
        )
    
    def forward(self, x, y):
        encode = self.G_encoder(x)
        bottleneck = encode.view(encode.size(0), -1)
        label_encode = self.label(y)
        bottleneck += label_encode
        bottleneck = bottleneck.view(bottleneck.size(0), 128, 4, 4)
        out = self.G_decoder(bottleneck)
        return out

In [9]:
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.D = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(4),
            nn.LeakyReLU(0.2, inplace=False),
            
            nn.Conv2d(4, 8, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(8),
            nn.LeakyReLU(0.2, inplace=False),
            
            nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, inplace=False),
            
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
        
            nn.AdaptiveAvgPool2d(1)
        )
        
        self.label = nn.Sequential(
            nn.Linear(10, 32),
            nn.LeakyReLU(0.2)
        )
        
        self.sigmoid = nn.Sigmoid()
        
        self.top = nn.Linear(32, 1)
        
    def forward(self, x, y):
        out = self.D(x)
        out = out.view(out.size(0), -1)
        
        label = self.label(y)
        out += label
        out = self.top(out)
        out = self.sigmoid(out)
        return out

In [10]:
D = Discriminator().cuda()
G = Generator().cuda()

criterion = nn.BCELoss()
lr = 0.0002
D_optim = torch.optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
G_optim = torch.optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))

In [11]:
def train_D(D, x, real_label, fake_images, fake_labels, y):
    D.zero_grad()
    outputs = D(x, y)
    real_loss = criterion(outputs, real_labels)
    real_score = outputs
    
    outputs = D(fake_images, y)
    fake_loss = criterion(outputs, fake_labels)
    fake_score = fake_loss
    
    d_loss = real_loss + fake_loss
    d_loss.backward(retain_graph=True)
    D_optim.step()
    return d_loss, real_score, fake_score

def train_G(G, D_outputs, real_labels, y):
    G.zero_grad()
    g_loss = criterion(D_outputs, real_labels)
    g_loss.backward()
    G_optim.step()
    return g_loss

In [None]:

for epoch in range(epochs):
    bar = tqdm(total=len(train_loader), leave=False)
    D_loss, G_loss = 0.0, 0.0

    for batch_idx, (images, target) in enumerate(train_loader):
        images = images.cuda()
        target = torch.tensor(onehot.fit_transform(target.reshape([-1, 1]), y=10).toarray(), dtype=torch.float32)
        target = target.cuda()
        
        fake_images = G(images, target)

        real_labels = torch.ones(batch_size).cuda()
        fake_labels = torch.zeros(batch_size).cuda()

        d_loss, real_score, fake_score = train_D(D, images, real_labels, fake_images, fake_labels, target)
        D_loss += d_loss
        outputs = D(fake_images, target)
        g_loss = train_G(G, outputs, real_labels, target)
        G_loss += g_loss
        bar.set_description("Epochs {:d}/{:d}, D_loss {:f}, G_loss {:f}".format(epoch, epochs, D_loss/batch_idx, G_loss/batch_idx),refresh=True)
        bar.update()
    bar.close()

    
        
        
    
    

  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
Epochs0/200, D_loss1.414556, G_loss0.669459:  86%|████████▋ | 809/937 [00:27<00:03, 35.27it/s]

In [None]:
outputs = D(images, target)

In [None]:
images.shape