<a href="https://colab.research.google.com/github/muyuuuu/colab/blob/main/GAN/minmaxGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image

# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's global RNG (all devices) so weight init, DataLoader shuffling,
# dropout masks and latent samples repeat across Restart & Run All.
# NOTE: bit-exact determinism on GPU is not guaranteed by seeding alone.
SEED = 0
torch.manual_seed(SEED)

In [11]:
bs = 100  # mini-batch size

# MNIST Dataset
# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 / std 0.5 then maps
# them to [-1, 1], the same range as the generator's tanh output.
# Note the trailing commas: these are 1-element (single-channel) tuples.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])

train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transform, download=True)

# Data Loader (Input Pipeline)
# drop_last=True keeps every training batch at exactly `bs` samples, so any code
# that sizes label/noise tensors by `bs` stays consistent with the batch.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True, drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

In [3]:
class Generator(nn.Module):
    """Four-layer MLP mapping latent vectors to flattened images.

    Widths: g_input_dim -> 256 -> 512 -> 1024 -> g_output_dim.
    A LeakyReLU(0.2) follows each hidden layer and tanh is applied to the
    output, so every output value lies in [-1, 1].
    """

    def __init__(self, g_input_dim, g_output_dim):
        super().__init__()
        base_width = 256
        # Layers are created in fc1..fc4 order; these attribute names are also
        # the state_dict keys.
        self.fc1 = nn.Linear(g_input_dim, base_width)
        self.fc2 = nn.Linear(base_width, 2 * base_width)
        self.fc3 = nn.Linear(2 * base_width, 4 * base_width)
        self.fc4 = nn.Linear(4 * base_width, g_output_dim)

    def forward(self, x):
        """Map latent codes (one per row of x) to tanh-squashed outputs."""
        hidden = x
        for linear in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(linear(hidden), negative_slope=0.2)
        return torch.tanh(self.fc4(hidden))
    
class Discriminator(nn.Module):
    """MLP that scores each flattened input row with a probability of being real.

    Widths: d_input_dim -> 1024 -> 512 -> 256 -> 1. Each hidden layer is
    followed by LeakyReLU(0.2) and dropout; the output goes through a sigmoid.

    Args:
        d_input_dim: length of each flattened input vector.
        dropout: drop probability after each hidden layer (default 0.3).
    """

    def __init__(self, d_input_dim, dropout=0.3):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(d_input_dim, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, self.fc1.out_features//2)
        self.fc3 = nn.Linear(self.fc2.out_features, self.fc2.out_features//2)
        self.fc4 = nn.Linear(self.fc3.out_features, 1)
        self.p_drop = dropout

    def forward(self, x):
        """Return one sigmoid score in [0, 1] per input row."""
        for layer in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(layer(x), 0.2)
            # F.dropout defaults to training=True, which would keep dropping
            # units even after D.eval(); tie it to the module's mode instead.
            x = F.dropout(x, self.p_drop, training=self.training)
        return torch.sigmoid(self.fc4(x))

In [4]:
z_dim = 100  # length of the latent noise vector fed to the generator
# Flattened image size (height * width). `.data` is the current name of the
# deprecated `.train_data` attribute on torchvision's MNIST dataset.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

G = Generator(g_input_dim=z_dim, g_output_dim=mnist_dim).to(device)
D = Discriminator(mnist_dim).to(device)



In [5]:
# loss
criterion = nn.BCELoss() 

# optimizer
lr = 0.0002 
G_optimizer = optim.Adam(G.parameters(), lr = lr)
D_optimizer = optim.Adam(D.parameters(), lr = lr)

In [14]:
def D_train(x):
    """Run one discriminator update on a batch of real images.

    Args:
        x: batch of real images from ``train_loader``; flattened to
           (batch, mnist_dim) before being scored.

    Returns:
        float: BCE(D(real), 1) + BCE(D(G(z)), 0) for this batch.
    """
    #=======================Train the discriminator=======================#
    D.zero_grad()

    # train discriminator on real (target 1)
    x_real = x.view(-1, mnist_dim).to(device)
    # Size labels/noise from the actual batch rather than the global `bs`, so a
    # smaller batch cannot cause a target/output shape mismatch.
    batch_size = x_real.size(0)
    y_real = torch.ones(batch_size, 1, device=device)
    D_real_loss = criterion(D(x_real), y_real)

    # train discriminator on fake (target 0). detach() keeps this backward pass
    # from computing gradients through G, which this step does not update.
    z = torch.randn(batch_size, z_dim, device=device)
    x_fake = G(z).detach()
    y_fake = torch.zeros(batch_size, 1, device=device)
    D_fake_loss = criterion(D(x_fake), y_fake)

    # gradient backprop & optimize ONLY D's parameters
    D_loss = D_real_loss + D_fake_loss
    D_loss.backward()
    D_optimizer.step()

    return D_loss.item()

In [15]:
def G_train(x):
    """Run one generator update.

    Args:
        x: the current real batch; only its size is used, to decide how many
           latent samples to draw.

    Returns:
        float: BCE(D(G(z)), 1), i.e. the batch mean of -log D(G(z)).
    """
    #=======================Train the generator=======================#
    G.zero_grad()

    batch_size = x.size(0)
    z = torch.randn(batch_size, z_dim, device=device)
    # Fakes are labelled "real": G's loss falls as D scores its samples nearer 1.
    y = torch.ones(batch_size, 1, device=device)

    G_loss = criterion(D(G(z)), y)

    # gradient backprop & optimize ONLY G's parameters. This backward also fills
    # D's .grad buffers; D_train clears them with D.zero_grad() before D steps.
    G_loss.backward()
    G_optimizer.step()

    return G_loss.item()

In [19]:
n_epoch = 200
D_STEPS = 3        # discriminator updates per generator update (same real batch)
SAMPLE_EVERY = 10  # save a grid of generated digits every N epochs

for epoch in range(1, n_epoch + 1):
    D_losses, G_losses = [], []
    for x, _labels in train_loader:
        # Extra D updates on the same batch; only the last one's loss is logged.
        for _ in range(D_STEPS - 1):
            D_train(x)
        D_losses.append(D_train(x))
        G_losses.append(G_train(x))

    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (
            epoch, n_epoch, torch.tensor(D_losses).mean(), torch.tensor(G_losses).mean()))

    if epoch % SAMPLE_EVERY == 0:
        with torch.no_grad():
            test_z = torch.randn(bs, z_dim).to(device)
            generated = G(test_z)
            # G ends in tanh, so samples are in [-1, 1]. save_image expects [0, 1]
            # and clamps anything outside it, so invert the x*0.5+0.5 mapping of
            # the input Normalize(0.5, 0.5) before saving.
            images = generated.view(generated.size(0), 1, 28, 28) * 0.5 + 0.5
            save_image(images, 'sample_3_{}'.format(epoch) + '.png')

[1/200]: loss_d: 1.183, loss_g: 1.026
[2/200]: loss_d: 1.186, loss_g: 1.020
[3/200]: loss_d: 1.188, loss_g: 1.009
[4/200]: loss_d: 1.176, loss_g: 1.033
[5/200]: loss_d: 1.176, loss_g: 1.027
[6/200]: loss_d: 1.175, loss_g: 1.034
[7/200]: loss_d: 1.160, loss_g: 1.065
[8/200]: loss_d: 1.159, loss_g: 1.068
[9/200]: loss_d: 1.156, loss_g: 1.061
[10/200]: loss_d: 1.158, loss_g: 1.066
[11/200]: loss_d: 1.153, loss_g: 1.086
[12/200]: loss_d: 1.150, loss_g: 1.082
[13/200]: loss_d: 1.153, loss_g: 1.076
[14/200]: loss_d: 1.140, loss_g: 1.105
[15/200]: loss_d: 1.155, loss_g: 1.070
[16/200]: loss_d: 1.140, loss_g: 1.095
[17/200]: loss_d: 1.132, loss_g: 1.117
[18/200]: loss_d: 1.131, loss_g: 1.117
[19/200]: loss_d: 1.142, loss_g: 1.097
[20/200]: loss_d: 1.139, loss_g: 1.114
[21/200]: loss_d: 1.144, loss_g: 1.097
[22/200]: loss_d: 1.134, loss_g: 1.112
[23/200]: loss_d: 1.133, loss_g: 1.115
[24/200]: loss_d: 1.140, loss_g: 1.097
[25/200]: loss_d: 1.131, loss_g: 1.119
[26/200]: loss_d: 1.144, loss_g: 1

KeyboardInterrupt: ignored

In [20]:
# Bundle every PNG in the working directory into image.zip for download.
# NOTE(review): the *.png glob also picks up sample_*.png files left over from
# earlier runs (visible in the listing below); use `sample_3_*.png` to archive
# only this run's samples.
!zip image.zip *.png

  adding: sample_100.png (deflated 3%)
  adding: sample_10.png (deflated 3%)
  adding: sample_20.png (deflated 3%)
  adding: sample_30.png (deflated 4%)
  adding: sample_3_100.png (deflated 4%)
  adding: sample_3_10.png (deflated 4%)
  adding: sample_3_110.png (deflated 4%)
  adding: sample_3_120.png (deflated 4%)
  adding: sample_3_130.png (deflated 3%)
  adding: sample_3_140.png (deflated 4%)
  adding: sample_3_20.png (deflated 3%)
  adding: sample_3_30.png (deflated 4%)
  adding: sample_3_40.png (deflated 4%)
  adding: sample_3_50.png (deflated 4%)
  adding: sample_3_60.png (deflated 4%)
  adding: sample_3_70.png (deflated 4%)
  adding: sample_3_80.png (deflated 4%)
  adding: sample_3_90.png (deflated 4%)
  adding: sample_40.png (deflated 3%)
  adding: sample_50.png (deflated 3%)
  adding: sample_60.png (deflated 3%)
  adding: sample_70.png (deflated 3%)
  adding: sample_80.png (deflated 3%)
  adding: sample_90.png (deflated 3%)


In [21]:
# List the working directory to check that image.zip and the sample PNGs exist.
!ls

image.zip	  sample_3_10.png   sample_3_40.png  sample_50.png
mnist_data	  sample_3_110.png  sample_3_50.png  sample_60.png
sample_100.png	  sample_3_120.png  sample_3_60.png  sample_70.png
sample_10.png	  sample_3_130.png  sample_3_70.png  sample_80.png
sample_20.png	  sample_3_140.png  sample_3_80.png  sample_90.png
sample_30.png	  sample_3_20.png   sample_3_90.png  sample_data
sample_3_100.png  sample_3_30.png   sample_40.png
