<a href="https://colab.research.google.com/github/karlmaji/pytorch_learning/blob/master/%E5%9F%BA%E4%BA%8EMNIST%E5%AE%9E%E7%8E%B0GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!mkdir -p img

In [None]:
import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision

# Number of samples drawn per training step.
batch_size = 64
# Shape of one MNIST image as (channels, height, width). Kept as an int64
# tensor because the model classes call ``img_size.prod()`` on it.
img_size = torch.tensor((1, 28, 28))
# Length of each random noise vector fed to the generator.
noise_dim = 128


class Generator(nn.Module):
    """MLP generator that maps a batch of noise vectors to image tensors.

    Three Linear -> BatchNorm1d -> GELU stages widen the features
    128 -> 256 -> 512. A final Linear projects to the flattened image size,
    followed by Tanh, so every output value lies in [-1, 1]. That matches
    the ``Normalize([0.5], [0.5])`` range of the real training images.

    Args:
        latent_dim: Length of each input noise vector. Defaults to the
            module-level ``noise_dim``.
        out_shape: Shape of one generated sample, e.g. ``(1, 28, 28)``. Any
            sequence of ints or a 1-D integer tensor is accepted. Defaults
            to the module-level ``img_size``.
    """

    def __init__(self, latent_dim=None, out_shape=None):
        super().__init__()
        if latent_dim is None:
            latent_dim = noise_dim
        if out_shape is None:
            out_shape = img_size
        # Normalize to plain Python ints. Passing 0-d tensors (such as
        # ``img_size.prod()``) as layer sizes or reshape dims only works
        # through implicit coercion and leaves tensors stored in attributes
        # like ``Linear.out_features``.
        self.out_shape = tuple(int(d) for d in out_shape)
        self.model = nn.Sequential(
            nn.Linear(int(latent_dim), 128),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.GELU(),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Linear(512, math.prod(self.out_shape)),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise of shape (batch, latent_dim) to (batch, *out_shape)."""
        x = self.model(x)
        return x.reshape(-1, *self.out_shape)
class Discrimnator(nn.Module):
    """MLP discriminator that outputs one probability per input sample.

    The training loop fits it toward 1 for real images and 0 for generated
    ones, using ``nn.BCELoss``. The misspelled class name is kept because
    the training code refers to it by this name.

    Each sample is flattened, then passed through Linear + GELU layers
    narrowing 512 -> 256 -> 128 -> 64. A final Linear -> Sigmoid produces
    an output of shape (batch, 1) with values in (0, 1).

    Args:
        in_shape: Shape of one input sample, e.g. ``(1, 28, 28)``. Any
            sequence of ints or a 1-D integer tensor is accepted. Defaults
            to the module-level ``img_size``.
    """

    def __init__(self, in_shape=None):
        super().__init__()
        if in_shape is None:
            in_shape = img_size
        # Plain int instead of the 0-d tensor ``img_size.prod()``, so the
        # layer's ``in_features`` is a real int.
        in_features = math.prod(int(d) for d in in_shape)
        self.model = nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(in_features, 512),
            nn.GELU(),
            nn.Linear(512, 256),
            nn.GELU(),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of samples; returns a (batch, 1) tensor in (0, 1)."""
        x = self.model(x)
        return x


# Training
# Create the output folder for sample grids. exist_ok makes this idempotent,
# so the script also works outside the notebook (no `!mkdir img` cell).
os.makedirs("img", exist_ok=True)

dataset = torchvision.datasets.MNIST(
    "mnist_data",
    train=True,
    download=True,
    transform=torchvision.transforms.Compose(
        [
            torchvision.transforms.Resize(28),
            torchvision.transforms.ToTensor(),
            # Map pixels from [0, 1] to [-1, 1], the generator's Tanh range.
            torchvision.transforms.Normalize([0.5], [0.5]),
        ]
    ),
)
# drop_last=True keeps every batch at exactly batch_size, which the
# fixed-size `one` / `zero` target tensors below rely on.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Move the models to the target device before building the optimizers, so
# the optimizers are constructed over the parameters in their final place.
generator = Generator().to(device)
discrimnator = Discrimnator().to(device)

g_opt = torch.optim.Adam(generator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)
d_opt = torch.optim.Adam(discrimnator.parameters(), lr=0.0003, betas=(0.4, 0.8), weight_decay=0.0001)

epoch = 200
one = torch.ones(batch_size, 1, device=device)
zero = torch.zeros(batch_size, 1, device=device)
loss_fn = torch.nn.BCELoss().to(device)

real_loss_log = []
fake_loss_log = []
g_loss_log = []
d_loss_log = []

for i in range(epoch):
    for idx, (real_img, _) in enumerate(dataloader):
        # Global step counter across epochs, used for logging and filenames.
        step = len(dataloader) * i + idx
        noise = torch.randn(batch_size, noise_dim, device=device)
        fake_img = generator(noise)

        # Train generator: push D to classify the fakes as real (target 1).
        # This backward also writes gradients into D's parameters;
        # d_opt.zero_grad() below discards them before D's own update.
        g_loss = loss_fn(discrimnator(fake_img), one)
        g_opt.zero_grad()
        g_loss.backward()
        g_opt.step()

        # Train discriminator: real -> 1, fake -> 0. detach() keeps D's loss
        # from back-propagating into the generator.
        real_img = real_img.to(device)
        fake_loss = loss_fn(discrimnator(fake_img.detach()), zero)
        real_loss = loss_fn(discrimnator(real_img), one)
        d_loss = fake_loss + real_loss
        d_opt.zero_grad()
        d_loss.backward()
        d_opt.step()

        real_loss_log.append(real_loss.item())
        fake_loss_log.append(fake_loss.item())
        d_loss_log.append(d_loss.item())
        g_loss_log.append(g_loss.item())

        # Watch real_loss and fake_loss: if they decrease together, reach
        # their minimum together and stay roughly equal, D has stabilized.

        if idx % 50 == 0:
            print(f"step:{step}, g_loss:{g_loss.item()}, d_loss:{d_loss.item()},real_loss:{real_loss.item()},fake_loss:{fake_loss.item()}")

        # NOTE(review): MNIST's 60000 training images give 937 full batches
        # of 64 per epoch. idx therefore never reaches 5000, and this saves
        # only at the first batch of each epoch. Confirm that is intended.
        if idx % 5000 == 0:
            # The generator output lies in [-1, 1] (Tanh), but save_image
            # maps [0, 1] to pixel values and clips the rest. Rescale first,
            # otherwise every negative value is saved as black.
            image = (fake_img[:16].detach() + 1) / 2
            torchvision.utils.save_image(image, f"./img/image_{step}.png", nrow=4)


