In [None]:
import matplotlib.pyplot as plt
import os,glob
import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from  torchvision import transforms
from torchsummary import summary

In [None]:
# 하이퍼 파라메터
EPOCHS = 5
BATCH_SIZE = 500
NOISE_SIZE = 100

In [None]:
# 출력이미지 폴더
output_foler = './output'
os.makedirs(output_foler, exist_ok=True)
for f in glob.glob(output_foler+"/*"):
  os.remove(f)

In [None]:
# 데이터준비
# 변환방식 지정
trans =  transforms.Compose([
    transforms.ToTensor()
])

In [None]:
from matplotlib.colors import BASE_COLORS
# 학습용데이터 로드
train_loader= DataLoader(
    MNIST('mnist', train=True, download=True, transform=trans),
    batch_size = BATCH_SIZE,
    shuffle=True
)
val_loader= DataLoader(
    MNIST('mnist', train=False, download=True, transform=trans),
    batch_size = BATCH_SIZE
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting mnist/MNIST/raw/train-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting mnist/MNIST/raw/train-labels-idx1-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw



In [7]:
# 감별자 클래스
class discriminator_model(nn.Module):
  def __init__(self):
    super(discriminator_model,self).__init__()
    # mnist는 흑백이고 채널정보가 1차원
    self.conv2d1 = nn.Conv2d(1, out_channels=64, kernel_size=3,bias=False)
    self.conv2d2 = nn.Conv2d(64, out_channels=64, kernel_size=3,bias=False)
    self.conv2d3 = nn.Conv2d(64, out_channels=64, kernel_size=3,bias=False)
    self.conv2d4 = nn.Conv2d(64, out_channels=1, kernel_size=3,bias=False)
    
    self.bn = nn.BatchNorm2d(64)        
    self.sigmoid = nn.Sigmoid()
    self.leakyRelu = nn.LeakyReLU()
  def forward(self, x):
    x = self.leakyRelu(self.conv2d1(x))
    x = self.leakyRelu(self.bn(self.conv2d2(x)))
    x = self.leakyRelu(self.bn(self.conv2d3(x)))
    x = self.sigmoid(self.conv2d4(x))
    return x   

In [8]:
#감별자 생성 함수
def build_discriminator():
  discriminator = discriminator_model()
  summary(discriminator, (1,28,28))
  return discriminator

In [12]:
# 생성자 클래스
class Generator_model(nn.Module):
  def __init__(self):
    super(Generator_model,self).__init__()
    self.deconv1 = nn.ConvTranspose2d(NOISE_SIZE, 32,3)
    self.deconv2 = nn.ConvTranspose2d(32, 16,3)
    self.deconv3 = nn.ConvTranspose2d(16, 32,3)
    self.deconv4 = nn.ConvTranspose2d(32, 1,3)
    
    self.bn1 = nn.BatchNorm2d(32)
    self.bn2 = nn.BatchNorm2d(16)
    self.bn3 = nn.BatchNorm2d(1)

    self.Relu = nn.ReLU()
  def forward(self, x):
    x = self.Relu(self.bn1(self.deconv1(x)))
    x = self.Relu(self.bn2(self.deconv2(x)))
    x = self.Relu(self.bn1(self.deconv3(x)))
    x = self.Relu(self.bn3(self.deconv4(x)))
    return x


In [13]:
#생성자 생성 함수
def build_generator():
  generator = Generator_model()
  summary(generator, (NOISE_SIZE,1,1))
  return generator

In [14]:
discriminator = build_discriminator()
generator = build_generator()
d_optimizer = optim.Adam(discriminator.parameters(),lr = 1e-3)
g_optimizer = optim.Adam(generator.parameters(),lr = 1e-3)
criterion = nn.BCELoss()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 26, 26]             576
         LeakyReLU-2           [-1, 64, 26, 26]               0
            Conv2d-3           [-1, 64, 24, 24]          36,864
       BatchNorm2d-4           [-1, 64, 24, 24]             128
         LeakyReLU-5           [-1, 64, 24, 24]               0
            Conv2d-6           [-1, 64, 22, 22]          36,864
       BatchNorm2d-7           [-1, 64, 22, 22]             128
         LeakyReLU-8           [-1, 64, 22, 22]               0
            Conv2d-9            [-1, 1, 20, 20]             576
          Sigmoid-10            [-1, 1, 20, 20]               0
Total params: 75,136
Trainable params: 75,136
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 2.22
Params size (MB): 0.29
Estimated Tot

In [15]:
# 학습
# 감별자
def train_discriminator(images):
  # 진짜 이미지로 첫번재 학습
  decide = discriminator(images)
  all_1 = torch.ones(BATCH_SIZE)
  decide = decide.squeeze()
  real_loss = criterion(decide,all_1)

  # 가짜 이미지 생성 - 생성자
  noise = torch.randn((BATCH_SIZE,NOISE_SIZE,1,1))
  fake = generator(noise)

  # 가짜 이미지로 두번재 학습
  decide = discriminator(fake).squeeze()
  all_0 = torch.zeros(BATCH_SIZE)
  fake_loss =  criterion(decide, all_0)

  d_loss = real_loss + fake_loss

  d_optimizer.zero_grad()
  d_loss.backword()
  d_optimizer.step()
  return d_loss


In [16]:
# 생성자
def train_generator():
  # 가짜 이미지 생성
  noise = torch.randn((BATCH_SIZE,NOISE_SIZE,1,1))
  fake = generator(noise)

  # 가짜 이미지로 감별자 속임
  decide = discriminator(fake).squeeze()
  all_1 = torch.ones(BATCH_SIZE)
  g_loss = criterion(decide,all_1)

  g_optimizer.zero_grad()
  g_loss.backword()
  g_optimizer.step()
  return g_loss


In [23]:
# 생성자를 이용해서 샘플 이미지
# 노이즈 벡터를 이용해서 가짜 이미지 생성
def sample(epoch):
  grid = 5
  # 노이즈로 가짜 이미지 생성
  noise = torch.randn((grid*grid, NOISE_SIZE,1,1))

  # 가중치 고정
  with torch.no_grad():
    fake = generator(noise) 
  # plt.imshow(fake,cmap='gray')
  # plt.axis('off')

  for i in range(grid*grid):
    plt.subplot(5,5,i+1)
    plt.imshow(fake[i,0],cmap='gray')
    plt.axis('off')
  
  # 생성된 이미지를 PNG 파일로 저장
  path = f"{output_foler}/img-{epoch}"
  plt.savefig(path)
  plt.close()

In [25]:
sample(100)

In [35]:
def train_GAN():
  for epoch in range(EPOCHS):
    print(f"epoch : {epoch}")
    for images,_ in train_loader:      
      # print(images.shape)  # 500,1,28,28
      d_loss = train_discriminator(images)
      # g_loss = train_generator()
      # print(f"생성자  손실 : {g_loss.item():.4f}")
      # print(f"감별자  손실 : {d_loss.item():.4f}")

In [36]:
train_GAN()

epoch : 0


ValueError: ignored