In [71]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.utils import save_image
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

# Loading data
train_dataset = torchvision.datasets.MNIST(root="/deep_takaya/self_study/deep_learning/data", train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root="/deep_takaya/self_study/deep_learning/data", train=False, transform=transforms.ToTensor(), download=True)
image, label = train_dataset[0]

# Difinition of Dataloader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=50, shuffle=True, num_workers=2)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=50, shuffle=False, num_workers=2)

In [57]:
# Difinition of Network
# discriminatorとgeneratorは別々でクラスを作る
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.z_dims = 100
        
        self.fc1 = nn.Linear(self.z_dims, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, 784)
        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(512)
        self.bn3 = nn.BatchNorm1d(1024)
    
    def forward(self, x):
        x = F.leaky_relu(self.fc1(x), negative_slope=0.2)
        x = self.bn1(x)
        x = F.leaky_relu(self.fc2(x), negative_slope=0.2)
        x = self.bn2(x)
        x = F.leaky_relu(self.fc3(x), negative_slope=0.2)
        x = self.bn3(x)
        x = self.fc4(x)
        return x


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)
        
    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.leaky_relu(self.fc1(x), negative_slope=0.2)
        x = F.leaky_relu(self.fc2(x), negative_slope=0.2)
        x = self.fc3(x)
        return x

In [58]:
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# setting networks
gen = Generator().to(device)
dis = Discriminator().to(device)

# setting optimizers
o_gen = optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=0.00001)
o_dis = optim.Adam(dis.parameters(), lr=0.0002, betas=(0.5, 0.999), weight_decay=0.00001)

# setting loss function
criterion = nn.MSELoss()

In [None]:
# training
num_epochs = 20
batch_size = 100
save_interval = 50

half_batch = int(batch_size / 2)

train_loss_list = []
train_acc_list = []
val_loss_list = []
val_acc_list = []

# labels
y_real = torch.ones(half_batch, 1)
y_fake = torch.zeros(half_batch, 1)
y_real = y_real.to(device)
y_fake = y_fake.to(device)

for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    val_loss = 0
    val_acc = 0
    
    D_running_loss = 0
    G_running_loss = 0
    
    
    #train discriminator
    gen.train()
    dis.train()
    
    for i, (images, labels) in enumerate(train_loader):
        # Don't be adapted view()
        real_images, labels = images.to(device), labels.to(device)
        z = torch.randn(half_batch, 100)
        z = z.to(device)
        
        o_gen.zero_grad()
        
        # train generator
        fake_images = gen(z)
        D_fake = dis(fake_images)
        G_loss = criterion(D_fake, y_real)
        G_loss.backward()
        o_gen.step()
        G_running_loss += G_loss.data
        #print("G_loss:" + str(G_loss.data))
        
        o_dis.zero_grad()
        
        # train discriminator
        D_real = dis(real_images)
        D_real_loss = criterion(D_real, y_real)
        
        fake_images = gen(z)
        D_fake = dis(fake_images.detach())
        D_fake_loss = criterion(D_fake, y_fake)
        
        D_loss = D_real_loss + D_fake_loss
        D_loss.backward()
        o_dis.step()
        D_running_loss += D_loss.data
        #print("D_loss:" + str(D_loss.data))
        
        
    # eval generator
    gen.eval()
        
    sample_z = torch.rand((64, 100))
    sample_z = sample_z.to(device)
    samples = gen(sample_z).data.cpu()
    samples = samples.view(-1, 28)
    save_image(samples, "gan_result" + str(epoch+1) + ".png")
        
    G_running_loss /= len(train_loader)
    D_running_loss /= len(train_loader)
    print("epoch" + str(epoch+1))
    print("G_loss:" + str(G_running_loss))
    print("D_loss:" + str(D_running_loss))

epoch1
G_loss:tensor(0.9638, device='cuda:0')
D_loss:tensor(0.1504, device='cuda:0')
epoch2
G_loss:tensor(0.9801, device='cuda:0')
D_loss:tensor(0.1444, device='cuda:0')
epoch3
G_loss:tensor(0.9799, device='cuda:0')
D_loss:tensor(0.1434, device='cuda:0')
epoch4
G_loss:tensor(0.9889, device='cuda:0')
D_loss:tensor(0.1377, device='cuda:0')
epoch5
G_loss:tensor(1.0002, device='cuda:0')
D_loss:tensor(0.1324, device='cuda:0')
epoch6
G_loss:tensor(1.0075, device='cuda:0')
D_loss:tensor(0.1245, device='cuda:0')
epoch7
G_loss:tensor(1.0117, device='cuda:0')
D_loss:tensor(0.1180, device='cuda:0')


In [79]:
save_image(fake_images[0].view(-1, 28), "gan_result.png")

In [78]:
fake_images[0].view(-1, 28).shape

torch.Size([28, 28])

In [21]:
z = torch.randn(64, 100)

In [23]:
z.shape

torch.Size([64, 100])