In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter




In [2]:
class Discriminator(nn.Module):
    """Two-layer MLP that scores flattened images with a value in (0, 1).

    Args:
        img_dim: number of features per input row (the flattened image size).

    Input rows of width ``img_dim`` are mapped to a single Sigmoid-activated
    score each, i.e. (N, img_dim) -> (N, 1).
    """

    def __init__(self, img_dim):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(in_features=img_dim, out_features=hidden_units),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(in_features=hidden_units, out_features=1),
            nn.Sigmoid(),  # squashes the logit so it can be fed to nn.BCELoss
        ]
        # Attribute name kept as `disc` so state_dict keys remain "disc.<idx>.*".
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return one score in (0, 1) per row of ``x``."""
        return self.disc(x)

In [3]:
class Generator(nn.Module):
    """Two-layer MLP that maps latent vectors to flattened images.

    Args:
        z_dim: size of each latent input vector.
        img_dim: number of output features per sample (flattened image size).

    Shapes: (N, z_dim) -> (N, img_dim), with every output value squashed by
    Tanh into [-1, 1].
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(in_features=z_dim, out_features=hidden_units),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(in_features=hidden_units, out_features=img_dim),
            nn.Tanh(),  # bounded output, comparable to inputs normalized to [-1, 1]
        ]
        # Attribute name kept as `gen` so state_dict keys remain "gen.<idx>.*".
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Generate one flattened image per latent row of ``x``."""
        return self.gen(x)
    

In [4]:
# Configuration + construction of models, data pipeline, optimizers, loggers.
SEED = 0  # seed every torch RNG before model init so runs are reproducible
torch.manual_seed(SEED)

device = "cuda:0" if torch.cuda.is_available() else "cpu"
lr = 3e-4  # 0.0003 -- Adam learning rate shared by both networks
z_dim = 64  # length of the generator's latent noise vector
image_dim = 28 * 28 * 1  # flattened 1x28x28 MNIST image
batch_size = 512
num_epochs = 50

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
# The same latent batch is reused whenever samples are logged, so generated
# images are comparable across epochs. Created directly on `device`.
fixed_noise = torch.randn((batch_size, z_dim), device=device)
# Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, shifting [0, 1] pixels to
# [-1, 1] -- the same range as the generator's Tanh output.
t_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

dataset = datasets.MNIST('./data', download=True, train=True, transform=t_transforms)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()  # the discriminator ends in Sigmoid, so plain BCE applies

# Separate log dirs so real and generated grids appear as two TensorBoard runs.
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")
step = 0  # TensorBoard global step, advanced by the training loop




In [5]:
for epoch in range(num_epochs):
    for i, (real, _) in enumerate(loader):
        # Flatten each image to the vector form the MLPs expect.
        real = real.view(-1, image_dim).to(device)
        # Size of this mini-batch (the last one can be smaller than the
        # configured value). A local name avoids clobbering the global
        # `batch_size` config constant.
        cur_batch = real.shape[0]

        ## Train the discriminator: real -> 1, fake -> 0
        noise = torch.randn(cur_batch, z_dim, device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)  # shape (cur_batch,), scores in (0, 1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        # detach(): the discriminator update must not backprop into the
        # generator, so G's graph need not be traversed (or retained) here.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_fake + lossD_real) / 2

        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ## Train the generator: make the (just updated) discriminator output 1
        # on fakes, i.e. minimize -log D(G(z)).
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Report and log to TensorBoard once per epoch (first batch only).
        # Doing this every iteration renders a grid from `fixed_noise` plus two
        # image writes per step, which dominates runtime and bloats the logs.
        if i == 0:
            print(
                f"epoch [{epoch}/{num_epochs}] \\ ",
                f"loss d : {lossD.item():.3f}, lossg: {lossG.item():.3f}",
            )

            with torch.no_grad():
                fake_images = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_images = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_images, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_images, normalize=True)

                writer_fake.add_image("엠니스트 fake", img_grid_fake, global_step=step)
                # Real images get their own tag rather than reusing "fake".
                writer_real.add_image("엠니스트 real", img_grid_real, global_step=step)

                step += 1

epoch [0/50] \  loss d : 0.649, lossg: 0.677
epoch [1/50] \  loss d : 0.400, lossg: 1.059
epoch [2/50] \  loss d : 0.075, lossg: 2.505
epoch [3/50] \  loss d : 0.276, lossg: 1.162
epoch [4/50] \  loss d : 0.458, lossg: 0.830
epoch [5/50] \  loss d : 0.473, lossg: 0.945
epoch [6/50] \  loss d : 0.432, lossg: 1.138
epoch [7/50] \  loss d : 0.577, lossg: 0.813
epoch [8/50] \  loss d : 0.366, lossg: 1.377
epoch [9/50] \  loss d : 0.604, lossg: 0.986
epoch [10/50] \  loss d : 0.495, lossg: 1.117
epoch [11/50] \  loss d : 0.916, lossg: 0.608
epoch [12/50] \  loss d : 0.898, lossg: 0.597
epoch [13/50] \  loss d : 0.387, lossg: 1.387
epoch [14/50] \  loss d : 0.404, lossg: 1.317
epoch [15/50] \  loss d : 0.725, lossg: 0.772
epoch [16/50] \  loss d : 0.906, lossg: 0.555
epoch [17/50] \  loss d : 0.804, lossg: 0.700
epoch [18/50] \  loss d : 0.368, lossg: 1.310
epoch [19/50] \  loss d : 0.206, lossg: 1.952
epoch [20/50] \  loss d : 0.642, lossg: 0.784
epoch [21/50] \  loss d : 0.655, lossg: 0.84