<a href="https://colab.research.google.com/github/bronya-y/GAN/blob/main/GAN_mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.utils import save_image
import torchvision
import os

In [2]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Number of images per mini-batch handed out by the data loader.
batch_size = 1_000

In [4]:

# Preprocessing: convert each image to a tensor, then shift/scale it with
# mean 0.5 and std 0.5 so pixel values land in [-1, 1].
# MNIST is single-channel, hence one mean/std entry; an RGB dataset would need
# three of each, e.g. mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# MNIST training split, downloaded into the current directory when missing.
mnist = torchvision.datasets.MNIST(
    root='./',
    train=True,
    transform=transform,
    download=True,
)

# Shuffled mini-batches of `batch_size` images.
data_loader = torch.utils.data.DataLoader(
    dataset=mnist,
    batch_size=batch_size,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [5]:
class Generator(nn.Module):
    """Three-layer fully connected generator.

    Maps a batch of noise vectors to a batch of flat samples.

    Args:
        input_size: length of each input noise vector.
        hidden_size: width of the two hidden layers.
        output_size: length of each generated sample.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(Generator, self).__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        # Hidden-layer activation.
        self.xfer = torch.nn.SELU()

    def forward(self, x):
        """Return generated samples of shape (batch, output_size) in [-1, 1]."""
        x = self.xfer(self.map1(x))
        x = self.xfer(self.map2(x))
        # Bound the output with tanh: the real images are normalized to
        # [-1, 1], whereas SELU is unbounded above and would let the generator
        # emit values no real image can take.
        return torch.tanh(self.map3(x))

In [6]:
class Discriminator(nn.Module):
    """Three-layer fully connected discriminator.

    Args:
        input_size: number of features per sample after flattening.
        hidden_size: width of the two hidden layers.
        output_size: number of output units.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.map1 = nn.Linear(input_size, hidden_size)
        self.map2 = nn.Linear(hidden_size, hidden_size)
        self.map3 = nn.Linear(hidden_size, output_size)
        self.elu = torch.nn.ELU()

    def forward(self, x):
        """Return sigmoid scores in (0, 1), one row per input sample."""
        # Collapse every dimension after the batch axis into one feature vector.
        flat = x.flatten(start_dim=1)
        hidden = self.elu(self.map1(flat))
        hidden = self.elu(self.map2(hidden))
        logits = self.map3(hidden)
        return logits.sigmoid()

In [7]:
from torchsummary import summary
# Layer-by-layer summary of a throwaway Discriminator. Its sizes (hidden 360,
# output 784) are for inspection only and differ from the D that is trained.
# Use the device selected earlier rather than hard-coding CUDA, so this cell
# also runs on a CPU-only runtime.
gg = Discriminator(784, 360, 784).to(device)
summary(gg, input_size=(28, 28), device=device.type)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                  [-1, 360]         282,600
               ELU-2                  [-1, 360]               0
            Linear-3                  [-1, 360]         129,960
               ELU-4                  [-1, 360]               0
            Linear-5                  [-1, 784]         283,024
Total params: 695,584
Trainable params: 695,584
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.02
Params size (MB): 2.65
Estimated Total Size (MB): 2.67
----------------------------------------------------------------


In [8]:
epochs = 10
G_in = 100  # length of the generator's input noise vector
G_hid = 520
G_out = 784  # 28 * 28 pixels of a flattened MNIST image
D_in = 784  # must equal G_out: D consumes what G produces
D_hid = 128
D_out = 1
d_learning_rate = 3e-3
g_learning_rate = 8e-3

# Both networks are fed batches of the data loader's size.
d_batch_size = batch_size
g_batch_size = batch_size

def get_noise_sampler():
    """Return a function (m, n) -> tensor of shape (m, n), uniform on [0, 1).

    The tensor is created directly on `device`. It does not require grad:
    nothing in training reads gradients with respect to the noise.
    """
    return lambda m, n: torch.rand(m, n, device=device)

noise_data = get_noise_sampler()  # callable used to draw generator input

# Place both networks on the selected device instead of hard-coding CUDA.
G = Generator(G_in, G_hid, G_out).to(device)
D = Discriminator(D_in, D_hid, D_out).to(device)

criterion = nn.BCELoss()  # binary cross-entropy on D's sigmoid output
d_optimizer = optim.SGD(D.parameters(), lr=d_learning_rate)
g_optimizer = optim.SGD(G.parameters(), lr=g_learning_rate)

In [10]:
# Adversarial training. Every mini-batch performs three updates, in order:
#   1. D on real images (target 1),
#   2. D on generated images (target 0),
#   3. G, which is rewarded when D scores its samples as real (target 1).
for epoch in range(epochs):
  print(epoch)
  for r,y in data_loader:  # y (the digit labels) is not used
    r = r.to(device)

    # 1. Discriminator step on real data.
    d_optimizer.zero_grad()
    d_r_pred = D(r)
    # Size the target from the batch itself so a short final batch (dataset
    # size not a multiple of the batch size) cannot cause a shape mismatch.
    r_target = torch.ones(r.size(0), 1, device=device)
    d_r_loss = criterion(d_r_pred, r_target)
    d_r_loss.backward()
    d_optimizer.step()

    # 2. Discriminator step on generated data.
    d_optimizer.zero_grad()
    noise = noise_data( d_batch_size, G_in ).to(device)
    # detach(): only D is updated here, so do not backpropagate through G.
    fake_data = G( noise ).detach()
    fake_target = torch.zeros(d_batch_size, 1, device=device)  # zeros = fake
    d_g_pred = D( fake_data )
    d_g_loss = criterion( d_g_pred, fake_target )
    d_g_loss.backward()
    d_optimizer.step()

    # 3. Generator step: we want to fool D, so the target is "all genuine".
    g_optimizer.zero_grad()
    noise = noise_data( g_batch_size, G_in ).to(device)
    fake_data = G( noise )
    fake_decision = D( fake_data )
    error = criterion( fake_decision, torch.ones(g_batch_size, 1, device=device) )
    error.backward()
    g_optimizer.step()

# torch.save does not create missing directories, so make sure the target exists.
os.makedirs("/content/model", exist_ok=True)
torch.save(D, "/content/model/D.pth") 
torch.save(G, "/content/model/G.pth")

0
1
2
3
4
5
6
7
8
9


In [11]:
# torch.load unpickles the whole module object, so the Generator class must be
# defined in this session, and only trusted checkpoint files should be loaded.
# map_location lets a checkpoint saved from the GPU load on a CPU-only machine.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# which rejects full-module pickles; pass weights_only=False there if needed.
GG = torch.load("/content/model/G.pth", map_location=device)
noisee = noise_data( 1, G_in ).to(device)
# Generate from the freshly drawn `noisee`, not the `noise` left over from the
# training loop. Gradients are not needed for sampling, so disable autograd.
with torch.no_grad():
    fake_dataa = GG(noisee)
nnp = fake_dataa[0]
print(nnp.shape)


torch.Size([784])


In [16]:
def denorm(x):
    """Map values from the normalized [-1, 1] range to [0, 1].

    Anything that falls outside [0, 1] after rescaling is clamped to the
    nearest bound.
    """
    rescaled = (x + 1) / 2
    return torch.clamp(rescaled, min=0, max=1)

# Restore the flat 784-value sample to a 28x28 image, then add batch and
# channel axes to form a (1, 1, 28, 28) mini-batch for save_image.
nnp = nnp.reshape((28,28))
fake_images = nnp.reshape(1, 1, 28, 28)
# Label the file with the configured epoch count rather than the loop variable
# `epoch`, which only exists if the training cell ran in this session. After a
# full run the two agree (epoch + 1 == epochs).
save_image(denorm(fake_images), os.path.join('./', 'fake_images-{}.png'.format(epochs)))