### Google Drive 마운트

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


### 라이브러리 Import 및 하이퍼파라미터 세팅

In [2]:
import torch.nn as nn
import torch
from tqdm import tqdm
from torchvision.utils import save_image

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LR = 0.0002
BATCH_SIZE = 128
IMAGE_SIZE = 32
IMG_CHANNELS = 1
LATENT_DIM = 100
NUM_EPOCHS = 20

### 모델 선언

In [4]:
class Generator(nn.Module):
  def __init__(self,IMAGE_SIZE=64, LATENT_DIM=100, IMG_CHANNELS=3):
    super(Generator, self).__init__()

    self.init_size = IMAGE_SIZE // 16
    # Fc layer
    # self.l1 = nn.Linear(latent_dim, 1024 * self.init_size ** 2)
    # Fc layer를 Convolution layer로 수정
    self.l1 = nn.ConvTranspose2d(in_channels=LATENT_DIM, out_channels=1024, kernel_size=self.init_size, stride=1, padding=0)

    self.blocks = nn.Sequential(
     self.block(1024, 512, 5, 2, 2, 1),
     self.block(512, 256, 5, 2, 2, 1),
     self.block(256, 128, 5, 2, 2, 1),
     nn.ConvTranspose2d(128, IMG_CHANNELS, 5, 2, 2, 1),
     nn.Tanh(),
    )

    self.initialize_weight()

  def initialize_weight(self):
    for m in self.modules():
      if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)

  def block(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, bias=False):
    result = nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.ConvTranspose2d(in_channels,
                           out_channels,
                           kernel_size,
                           stride,
                           padding,
                           output_padding,
                           bias=bias,
        ),
        nn.ReLU(),
    )
    return result

  def forward(self, z):
    output = self.l1(z)
    # reshape (project and reshape을 convolution layer로 대체해서 이 부분도 주석 처리)
    # output = output.view(-1, 1024, 4, 4)
    output = self.blocks(output)
    return output

In [5]:
class Discriminator(nn.Module):
  def __init__(self, IMAGE_SIZE, IMG_CHANNELS):
    super(Discriminator, self).__init__()

    self.disc_size = IMAGE_SIZE // 8

    self.disc_blocks = nn.Sequential(
        nn.Conv2d(IMG_CHANNELS, 128, 5, 2, 2),
        nn.LeakyReLU(0.2),
        self.disc_block(128, 256, 5, 2, 2),
        self.disc_block(256, 512, 5, 2, 2),
        nn.Conv2d(512, 1, self.disc_size),
        nn.Flatten(),
        nn.Sigmoid(),
    )

    self.initialize_weight()

  def initialize_weight(self):
    for m in self.modules():
      if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)

  def disc_block(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
    result = nn.Sequential(
        nn.BatchNorm2d(in_channels),
        nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=bias
        ),
        nn.LeakyReLU(0.2),
    )
    return result

  def forward(self, x):
    return self.disc_blocks(x)


### 데이터 세팅 및 학습준비

In [6]:
import torchvision.transforms as T
from torchvision import datasets

In [7]:
transforms = T.Compose(
  [
    T.Resize(IMAGE_SIZE),
    T.ToTensor(),
    T.Normalize((0.5), (0.5)),
  ]
)

dataset = datasets.MNIST(root="./dataset/", train=True, transform=transforms,
                       download=True)

In [8]:
dataloader = torch.utils.data.DataLoader(dataset, batch_size=BATCH_SIZE,
                                         shuffle=True, num_workers=4)

one_batch = next(iter(dataloader))

In [9]:
# 잘 데이터가 불러와졌는지 확인
for image,label in dataloader:
  break

In [11]:
image.shape

torch.Size([128, 1, 32, 32])

In [12]:
gen = Generator(IMAGE_SIZE, LATENT_DIM, IMG_CHANNELS).to(device)
disc = Discriminator(IMAGE_SIZE, IMG_CHANNELS).to(device)

In [13]:
gen_optim = torch.optim.Adam(gen.parameters(), lr=LR, betas=(0.5, 0.999))
disc_optim = torch.optim.Adam(disc.parameters(), lr=LR, betas=(0.5, 0.999))
criterion = nn.BCELoss()

In [14]:
# train모드
gen.train()
disc.train()

Discriminator(
  (disc_blocks): Sequential(
    (0): Conv2d(1, 128, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Conv2d(128, 256, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Conv2d(256, 512, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), bias=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
    (5): Flatten(start_dim=1, end_dim=-1)
    (6): Sigmoid()
  )
)

In [15]:
! mkdir images

### 학습

In [16]:
save_interval = 100
sample_interval = len(dataloader)
save_path_gen_template = '/content/drive/MyDrive/generator_epoch_%d.pth'
save_path_disc_template = '/content/drive/MyDrive/discriminator_epoch_%d.pth'

for epoch in tqdm(range(NUM_EPOCHS)):
  for idx, (real_img, label) in enumerate(dataloader):
    """
    real : 1
    fake : 0
    """
    real_img = real_img.to(device)
    latent_z_batch = torch.rand(BATCH_SIZE, LATENT_DIM, 1, 1).to(device)

    fake_img = gen(latent_z_batch)

    real_disc_pred = disc(real_img)
    real_disc_loss = criterion(real_disc_pred, torch.ones_like(real_disc_pred))

    fake_img_detach = fake_img.detach()
    fake_disc_pred = disc(fake_img_detach)
    fake_disc_loss = criterion(fake_disc_pred, torch.zeros_like(fake_disc_pred))

    disc_loss = (real_disc_loss + fake_disc_loss) / 2

    disc.zero_grad()
    disc_loss.backward()
    disc_optim.step()

    #---------------------------------

    fake_disc_pred_for_gen = disc(fake_img)
    gen_loss = criterion(fake_disc_pred_for_gen, torch.ones_like(fake_disc_pred_for_gen))
    gen.zero_grad()
    gen_loss.backward()
    gen_optim.step()

    #---------------------------------
    
    if idx%100 == 0:
      print(f"Batch {idx}/{len(dataloader)}, Disc_loss: {disc_loss.item()}, Gen_loss: {gen_loss.item()}")

    #---------------------------------

    # GAN의 Inference : Random Latent z에서 실제같은 이미지 만드는 것
    batches_done = epoch * len(dataloader) + idx
    if batches_done % sample_interval == 0:
      save_image(fake_img.data[:25], "images/%d.png" % batches_done, nrow=5, normalize=True)

  if (epoch + 1) % save_interval == 0:
    gen_save_path = save_path_gen_template % (epoch + 1)
    disc_save_path = save_path_disc_template % (epoch + 1)
    torch.save(gen.state_dict(), gen_save_path)
    torch.save(disc.state_dict(), disc_save_path)
    print(f"Epoch {epoch+1} : model saved")


  0%|          | 0/20 [00:00<?, ?it/s]

Batch 0/469, Disc_loss: 0.6994179487228394, Gen_loss: 0.8490980863571167
Batch 100/469, Disc_loss: 0.06756505370140076, Gen_loss: 4.005115509033203
Batch 200/469, Disc_loss: 0.17384669184684753, Gen_loss: 3.1917619705200195
Batch 300/469, Disc_loss: 0.39388054609298706, Gen_loss: 2.8187153339385986
Batch 400/469, Disc_loss: 0.14883139729499817, Gen_loss: 3.281489610671997


  5%|▌         | 1/20 [00:07<02:16,  7.19s/it]

Batch 0/469, Disc_loss: 0.12972311675548553, Gen_loss: 3.0669589042663574
Batch 100/469, Disc_loss: 0.2862199544906616, Gen_loss: 1.2224438190460205
Batch 200/469, Disc_loss: 0.06847476959228516, Gen_loss: 3.8000006675720215
Batch 300/469, Disc_loss: 0.19573067128658295, Gen_loss: 3.616098403930664
Batch 400/469, Disc_loss: 0.08981189131736755, Gen_loss: 3.4183297157287598


 10%|█         | 2/20 [00:13<01:59,  6.63s/it]

Batch 0/469, Disc_loss: 0.10412613302469254, Gen_loss: 4.296972274780273
Batch 100/469, Disc_loss: 0.12356717139482498, Gen_loss: 4.23677921295166
Batch 200/469, Disc_loss: 0.05944022536277771, Gen_loss: 4.885835647583008
Batch 300/469, Disc_loss: 0.04510992765426636, Gen_loss: 4.371728897094727
Batch 400/469, Disc_loss: 0.05656847357749939, Gen_loss: 5.886724948883057


 15%|█▌        | 3/20 [00:19<01:50,  6.49s/it]

Batch 0/469, Disc_loss: 0.24275171756744385, Gen_loss: 4.116342544555664
Batch 100/469, Disc_loss: 0.9987753629684448, Gen_loss: 0.9030971527099609
Batch 200/469, Disc_loss: 0.2797122299671173, Gen_loss: 2.202960968017578
Batch 300/469, Disc_loss: 0.5432805418968201, Gen_loss: 3.112147808074951
Batch 400/469, Disc_loss: 0.38256317377090454, Gen_loss: 3.445456027984619


 20%|██        | 4/20 [00:25<01:41,  6.32s/it]

Batch 0/469, Disc_loss: 0.3171638548374176, Gen_loss: 1.7489773035049438
Batch 100/469, Disc_loss: 0.31689465045928955, Gen_loss: 2.1401424407958984
Batch 200/469, Disc_loss: 0.4648464322090149, Gen_loss: 2.6843981742858887
Batch 300/469, Disc_loss: 0.6510007381439209, Gen_loss: 1.192718744277954
Batch 400/469, Disc_loss: 0.5304203033447266, Gen_loss: 1.4327306747436523


 25%|██▌       | 5/20 [00:32<01:34,  6.31s/it]

Batch 0/469, Disc_loss: 0.2602173089981079, Gen_loss: 2.8634660243988037
Batch 100/469, Disc_loss: 0.3898589015007019, Gen_loss: 1.2956808805465698
Batch 200/469, Disc_loss: 0.41663244366645813, Gen_loss: 1.130537986755371
Batch 300/469, Disc_loss: 0.3341796100139618, Gen_loss: 1.2610831260681152
Batch 400/469, Disc_loss: 0.2136826068162918, Gen_loss: 3.4629456996917725


 30%|███       | 6/20 [00:38<01:26,  6.20s/it]

Batch 0/469, Disc_loss: 0.4282756447792053, Gen_loss: 1.3832038640975952
Batch 100/469, Disc_loss: 0.19297753274440765, Gen_loss: 2.289973497390747
Batch 200/469, Disc_loss: 0.22946426272392273, Gen_loss: 1.6220595836639404
Batch 300/469, Disc_loss: 0.2369879186153412, Gen_loss: 3.3872509002685547
Batch 400/469, Disc_loss: 0.16890579462051392, Gen_loss: 2.7938101291656494


 35%|███▌      | 7/20 [00:44<01:20,  6.22s/it]

Batch 0/469, Disc_loss: 0.15936288237571716, Gen_loss: 1.4269262552261353
Batch 100/469, Disc_loss: 0.26746129989624023, Gen_loss: 2.156959056854248
Batch 200/469, Disc_loss: 0.20583347976207733, Gen_loss: 2.8312883377075195
Batch 300/469, Disc_loss: 0.4907049834728241, Gen_loss: 2.7548866271972656
Batch 400/469, Disc_loss: 0.20907840132713318, Gen_loss: 2.787937641143799


 40%|████      | 8/20 [00:50<01:14,  6.19s/it]

Batch 0/469, Disc_loss: 0.3057018518447876, Gen_loss: 1.8382794857025146
Batch 100/469, Disc_loss: 0.20217958092689514, Gen_loss: 1.6619361639022827
Batch 200/469, Disc_loss: 0.7446690797805786, Gen_loss: 0.9311144351959229
Batch 300/469, Disc_loss: 0.2124226987361908, Gen_loss: 5.096282958984375
Batch 400/469, Disc_loss: 0.8087772130966187, Gen_loss: 6.3036088943481445


 45%|████▌     | 9/20 [00:56<01:08,  6.23s/it]

Batch 0/469, Disc_loss: 0.09174183011054993, Gen_loss: 3.28977370262146
Batch 100/469, Disc_loss: 0.31917303800582886, Gen_loss: 8.224142074584961
Batch 200/469, Disc_loss: 0.2190232127904892, Gen_loss: 4.178989410400391
Batch 300/469, Disc_loss: 0.11744855344295502, Gen_loss: 3.5141072273254395
Batch 400/469, Disc_loss: 0.12562435865402222, Gen_loss: 4.631982803344727


 50%|█████     | 10/20 [01:02<01:01,  6.15s/it]

Batch 0/469, Disc_loss: 0.09373334050178528, Gen_loss: 4.329319953918457
Batch 100/469, Disc_loss: 0.037006013095378876, Gen_loss: 5.170007705688477
Batch 200/469, Disc_loss: 0.06686758995056152, Gen_loss: 3.8150782585144043
Batch 300/469, Disc_loss: 0.06539321690797806, Gen_loss: 4.289289474487305
Batch 400/469, Disc_loss: 0.05909344553947449, Gen_loss: 3.2036094665527344


 55%|█████▌    | 11/20 [01:09<00:55,  6.18s/it]

Batch 0/469, Disc_loss: 0.06427718698978424, Gen_loss: 3.7931809425354004
Batch 100/469, Disc_loss: 0.08220963180065155, Gen_loss: 3.598168134689331
Batch 200/469, Disc_loss: 0.07540608942508698, Gen_loss: 3.48939847946167
Batch 300/469, Disc_loss: 0.09565527737140656, Gen_loss: 2.5125067234039307
Batch 400/469, Disc_loss: 0.05437032878398895, Gen_loss: 4.154994964599609


 60%|██████    | 12/20 [01:15<00:49,  6.16s/it]

Batch 0/469, Disc_loss: 0.08984596282243729, Gen_loss: 3.5852041244506836
Batch 100/469, Disc_loss: 0.050648972392082214, Gen_loss: 5.742857933044434
Batch 200/469, Disc_loss: 0.09427456557750702, Gen_loss: 3.9992218017578125
Batch 300/469, Disc_loss: 0.09822800010442734, Gen_loss: 5.534101486206055
Batch 400/469, Disc_loss: 0.05815998837351799, Gen_loss: 3.3723502159118652


 65%|██████▌   | 13/20 [01:21<00:43,  6.18s/it]

Batch 0/469, Disc_loss: 0.10117679089307785, Gen_loss: 3.1897144317626953
Batch 100/469, Disc_loss: 0.027712387964129448, Gen_loss: 4.35353946685791
Batch 200/469, Disc_loss: 0.017354171723127365, Gen_loss: 4.780834197998047
Batch 300/469, Disc_loss: 0.1028783768415451, Gen_loss: 2.9963183403015137
Batch 400/469, Disc_loss: 0.03973688185214996, Gen_loss: 4.931893348693848


 70%|███████   | 14/20 [01:27<00:36,  6.14s/it]

Batch 0/469, Disc_loss: 0.02958538383245468, Gen_loss: 5.097036361694336
Batch 100/469, Disc_loss: 0.06792496889829636, Gen_loss: 5.148562908172607
Batch 200/469, Disc_loss: 0.01414414867758751, Gen_loss: 5.293487548828125
Batch 300/469, Disc_loss: 0.09367452561855316, Gen_loss: 4.773387432098389
Batch 400/469, Disc_loss: 0.055371593683958054, Gen_loss: 5.9789276123046875


 75%|███████▌  | 15/20 [01:33<00:30,  6.19s/it]

Batch 0/469, Disc_loss: 0.02032606303691864, Gen_loss: 6.184516906738281
Batch 100/469, Disc_loss: 0.05616174638271332, Gen_loss: 5.978659629821777
Batch 200/469, Disc_loss: 0.012250425294041634, Gen_loss: 5.17723274230957
Batch 300/469, Disc_loss: 0.06044817343354225, Gen_loss: 3.993567943572998
Batch 400/469, Disc_loss: 0.06429685652256012, Gen_loss: 3.851451873779297


 80%|████████  | 16/20 [01:39<00:24,  6.16s/it]

Batch 0/469, Disc_loss: 0.031288906931877136, Gen_loss: 4.620327949523926
Batch 100/469, Disc_loss: 0.03279920667409897, Gen_loss: 3.896106719970703
Batch 200/469, Disc_loss: 0.02369607798755169, Gen_loss: 4.644407272338867
Batch 300/469, Disc_loss: 0.03380569815635681, Gen_loss: 6.000030994415283
Batch 400/469, Disc_loss: 0.013360006734728813, Gen_loss: 5.372875213623047


 85%|████████▌ | 17/20 [01:45<00:18,  6.17s/it]

Batch 0/469, Disc_loss: 0.011145888827741146, Gen_loss: 5.095773696899414
Batch 100/469, Disc_loss: 0.05702926963567734, Gen_loss: 5.051337718963623
Batch 200/469, Disc_loss: 0.04109695553779602, Gen_loss: 5.800652027130127
Batch 300/469, Disc_loss: 0.021539580076932907, Gen_loss: 4.917031288146973
Batch 400/469, Disc_loss: 0.01149340532720089, Gen_loss: 5.172658920288086


 90%|█████████ | 18/20 [01:52<00:12,  6.16s/it]

Batch 0/469, Disc_loss: 0.11134392768144608, Gen_loss: 7.264865875244141
Batch 100/469, Disc_loss: 0.01354639045894146, Gen_loss: 5.044279098510742
Batch 200/469, Disc_loss: 0.021752823144197464, Gen_loss: 4.854220867156982
Batch 300/469, Disc_loss: 1.2665455341339111, Gen_loss: 24.23818588256836
Batch 400/469, Disc_loss: 0.02634812705218792, Gen_loss: 4.975741386413574


 95%|█████████▌| 19/20 [01:58<00:06,  6.14s/it]

Batch 0/469, Disc_loss: 0.04852693900465965, Gen_loss: 4.621677398681641
Batch 100/469, Disc_loss: 0.02612009085714817, Gen_loss: 4.485849857330322
Batch 200/469, Disc_loss: 0.06342855095863342, Gen_loss: 7.510012626647949
Batch 300/469, Disc_loss: 0.025521544739603996, Gen_loss: 4.9880523681640625
Batch 400/469, Disc_loss: 0.012526042759418488, Gen_loss: 5.47385311126709


100%|██████████| 20/20 [02:04<00:00,  6.22s/it]


In [17]:
! tar -cvf images.tar ./images

./images/
./images/1407.png
./images/8442.png
./images/7504.png
./images/0.png
./images/469.png
./images/5628.png
./images/3752.png
./images/4221.png
./images/8911.png
./images/7035.png
./images/3283.png
./images/7973.png
./images/2345.png
./images/6566.png
./images/6097.png
./images/5159.png
./images/2814.png
./images/938.png
./images/4690.png
./images/1876.png
