In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision.utils import save_image

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
# `device` is reused by the model / tensor creation code in later cells.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device = ', device)

device =  cuda


## Vanilla GAN

In [3]:
## learning hyperparameter

num_epochs = 1000
batch_size = 100
lr = 0.0002
img_size = 28*28          # flattened image length fed to the fully connected networks
num_channel = 1
dir_name = "GAN_result"   # directory that receives the sample grids written during training

## generator hyperparameter

noise_size = 100
hidden_size1 = 256
hidden_size2 = 512

# exist_ok=True replaces the exists()/makedirs() pair: it is safe to re-run
# and cannot fail if the directory appears between the check and the call.
os.makedirs(dir_name, exist_ok=True)

In [4]:
## MNIST

# Convert each image to a tensor, then normalize with mean 0.5 and std 0.5,
# i.e. (x - 0.5) / 0.5.
# NOTE(review): presumably chosen to match the generator's Tanh output range — confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)
])

# Training split of MNIST, downloaded into the current directory when missing.
MNIST_data = datasets.MNIST(root='./',
                            train=True,
                            transform=transform,
                            download=True)

# Shuffled mini-batches of `batch_size` images for the training loop.
data_loader = DataLoader(dataset=MNIST_data,
                         batch_size=batch_size,
                         shuffle=True)

## Discriminator

In [5]:
class Discriminator(nn.Module):
  """Fully connected discriminator for flattened images.

  Maps a (batch, img_size) input through two LeakyReLU(0.2) hidden layers
  (hidden_size2, then hidden_size1 units) to a single sigmoid output per
  sample. The sizes are read from the notebook-level hyperparameters.
  """

  def __init__(self):
    super().__init__()
    # Widths shrink towards the single output unit.
    self.linear1 = nn.Linear(img_size, hidden_size2)
    self.linear2 = nn.Linear(hidden_size2, hidden_size1)
    self.linear3 = nn.Linear(hidden_size1, 1)
    # Stateless activations shared by the forward pass.
    self.leaky_relu = nn.LeakyReLU(0.2)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    """Return the sigmoid score of shape (batch, 1) for the input batch `x`."""
    hidden = x
    for layer in (self.linear1, self.linear2):
      hidden = self.leaky_relu(layer(hidden))
    return self.sigmoid(self.linear3(hidden))

## Generator

In [6]:
class Generator(nn.Module):
  """Fully connected generator producing flattened images from noise.

  Maps a (batch, noise_size) input through two ReLU hidden layers
  (hidden_size1, then hidden_size2 units) to img_size outputs squashed by
  Tanh. The sizes are read from the notebook-level hyperparameters.
  """

  def __init__(self):
    super().__init__()

    # Widths grow from the noise vector up to the flattened image length.
    self.linear1 = nn.Linear(noise_size, hidden_size1)
    self.linear2 = nn.Linear(hidden_size1, hidden_size2)
    self.linear3 = nn.Linear(hidden_size2, img_size)
    # Stateless activations shared by the forward pass.
    self.relu = nn.ReLU()
    self.tanh = nn.Tanh()

  def forward(self, x):
    """Return generated samples of shape (batch, img_size) for noise `x`."""
    hidden = x
    for layer in (self.linear1, self.linear2):
      hidden = self.relu(layer(hidden))
    return self.tanh(self.linear3(hidden))

In [7]:
# Instantiate both vanilla-GAN networks and move their parameters to `device`.
discriminator = Discriminator().to(device)
generator = Generator().to(device)

## Vanilla GAN 학습

In [8]:
# Binary cross-entropy applied to the discriminator's sigmoid output.
criterion = nn.BCELoss()
# One Adam optimizer per network, both using the shared learning rate `lr`.
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
g_optimizer = optim.Adam(generator.parameters(), lr=lr)

In [20]:
# Histories sampled every 200 batches. Values are stored as plain Python
# floats so the lists do not keep autograd graphs or device tensors alive.
d_perf_list = []
g_perf_list = []
d_loss_list = []
g_loss_list = []

for epoch in range(num_epochs):
  for i, (images, _) in enumerate(data_loader):
    # Use the actual batch length rather than the global batch_size so a
    # shorter final batch cannot break the reshape or the label shapes.
    cur_batch = images.size(0)

    # Targets are created directly on the device (no CPU allocation + copy).
    real_label = torch.ones((cur_batch, 1), dtype=torch.float32, device=device)
    fake_label = torch.zeros((cur_batch, 1), dtype=torch.float32, device=device)

    # Flatten each image to a vector for the fully connected discriminator.
    real_images = images.view(cur_batch, -1).to(device)

    ## generator training
    # The generator is rewarded when the discriminator labels its samples "real".
    z = torch.randn(cur_batch, noise_size, device=device)
    fake_images = generator(z)
    g_loss = criterion(discriminator(fake_images), real_label)

    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    ## discriminator training
    # Fresh fakes from the just-updated generator. They are produced under
    # no_grad because the discriminator step must not back-propagate into
    # the generator; building that graph would only waste time and memory.
    z = torch.randn(cur_batch, noise_size, device=device)
    with torch.no_grad():
      fake_images = generator(z)

    fake_prediction = discriminator(fake_images)
    real_prediction = discriminator(real_images)

    fake_loss = criterion(fake_prediction, fake_label)
    real_loss = criterion(real_prediction, real_label)

    d_loss = (fake_loss + real_loss) / 2

    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    if (i % 200) == 0:
      # Monitoring only: mean discriminator score on fake / real batches after
      # the update. Computed just when it is logged, without autograd.
      with torch.no_grad():
        g_perf = discriminator(fake_images).mean().item()
        d_perf = discriminator(real_images).mean().item()

      print(f"Epoch = [{epoch}/{num_epochs}] Batch = [{i}/{len(data_loader)}]")
      print(f"d_loss = {d_loss.item():.3f}, g_loss = {g_loss.item():.3f}, d_perf = {d_perf:.3f}, g_perf = {g_perf:.3f}")

      d_perf_list.append(d_perf)
      g_perf_list.append(g_perf)
      d_loss_list.append(d_loss.item())
      g_loss_list.append(g_loss.item())

      # The generator ends in Tanh, so samples lie in [-1, 1]; normalize=True
      # rescales them into the image range instead of clamping negatives to 0.
      # The file name depends on the epoch only, so later snapshots within an
      # epoch overwrite earlier ones.
      samples = fake_images.reshape(cur_batch, 1, 28, 28)
      save_image(samples, os.path.join(dir_name, f"GAN_fake_{epoch}.png"), normalize=True)

Epoch = [0/1000] Batch = [0/600]
d_loss = 0.687, g_loss = 0.691, d_perf = 0.611, g_perf = 0.505
Epoch = [0/1000] Batch = [200/600]
d_loss = 0.013, g_loss = 4.657, d_perf = 0.996, g_perf = 0.012
Epoch = [0/1000] Batch = [400/600]
d_loss = 0.041, g_loss = 7.510, d_perf = 0.956, g_perf = 0.001
Epoch = [1/1000] Batch = [0/600]
d_loss = 0.116, g_loss = 3.755, d_perf = 0.905, g_perf = 0.018
Epoch = [1/1000] Batch = [200/600]
d_loss = 0.061, g_loss = 7.039, d_perf = 0.926, g_perf = 0.002
Epoch = [1/1000] Batch = [400/600]
d_loss = 0.025, g_loss = 7.483, d_perf = 0.989, g_perf = 0.009
Epoch = [2/1000] Batch = [0/600]
d_loss = 0.088, g_loss = 4.989, d_perf = 0.921, g_perf = 0.012
Epoch = [2/1000] Batch = [200/600]
d_loss = 0.177, g_loss = 4.617, d_perf = 0.873, g_perf = 0.040
Epoch = [2/1000] Batch = [400/600]
d_loss = 0.234, g_loss = 3.939, d_perf = 0.833, g_perf = 0.042
Epoch = [3/1000] Batch = [0/600]
d_loss = 0.240, g_loss = 2.454, d_perf = 0.887, g_perf = 0.142
Epoch = [3/1000] Batch = [20

KeyboardInterrupt: 

# DCGAN

In [9]:
## DCGAN

# NOTE: these assignments rebind names already set in the vanilla GAN
# hyperparameter cell (num_epochs, batch_size, lr, noise_size, img_size,
# num_channel). In particular img_size changes meaning: it was 28*28 (a
# flattened length) there and is 64 here, where it is used as the image side
# length. Re-run this cell before defining or training the DCGAN models.
num_epochs = 10
batch_size = 32
lr = 0.001
noise_size = 64
img_size = 64
num_channel = 1

## DCGAN 모델

In [10]:
class GANGenerator(nn.Module):
  """Convolutional generator: noise vector -> (num_channel, img_size, img_size) image.

  A linear layer produces a 128-channel feature map at a quarter of the
  target resolution; two nearest-neighbour upsampling stages (x2 each), each
  followed by a 3x3 convolution, bring it to full size, and a final 3x3
  convolution plus Tanh yields the image. Sizes are read from the
  notebook-level hyperparameters.
  """

  def __init__(self):
    super().__init__()
    # Side length of the initial feature map (two x2 upsamplings follow).
    self.inp_sz = img_size // 4
    # Noise -> flattened 128-channel feature map.
    self.linear1 = nn.Linear(noise_size, 128*self.inp_sz**2)
    self.bn1 = nn.BatchNorm2d(128)
    # First upsampling stage.
    self.up1 = nn.Upsample(scale_factor=2, mode='nearest')
    self.cn1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
    self.bn2 = nn.BatchNorm2d(128)
    # Shared in-place activation used after both upsampling stages.
    self.rl = nn.LeakyReLU(0.2, inplace=True)
    # Second upsampling stage.
    self.up2 = nn.Upsample(scale_factor=2, mode='nearest')
    self.cn2 = nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1)
    self.bn3 = nn.BatchNorm2d(64)
    # Projection to the output channels.
    self.cn3 = nn.Conv2d(64, num_channel, kernel_size=3, stride=1, padding=1)
    self.tanh = nn.Tanh()

  def forward(self, x):
    """Return generated images for the (batch, noise_size) noise batch `x`."""
    n_samples = x.shape[0]
    feat = self.linear1(x).view(n_samples, 128, self.inp_sz, self.inp_sz)
    # Stage 1: normalize, upsample, convolve, normalize, activate.
    feat = self.cn1(self.up1(self.bn1(feat)))
    feat = self.rl(self.bn2(feat))
    # Stage 2: upsample, convolve, normalize, activate.
    feat = self.cn2(self.up2(feat))
    feat = self.rl(self.bn3(feat))
    # Output head.
    return self.tanh(self.cn3(feat))

## DCGAN Discriminator

In [12]:
class GANDiscriminator(nn.Module):
  """Convolutional discriminator: image batch -> sigmoid score per image.

  Four stride-2 3x3 convolution stages (16, 32, 64, 128 channels), each
  followed by LeakyReLU(0.2) and Dropout(0.25); every stage except the first
  also ends with BatchNorm2d. The flattened features feed one linear unit
  with a sigmoid. Sizes are read from the notebook-level hyperparameters.
  """

  def __init__(self):
    super().__init__()

    # Channel progression; consecutive pairs define one downsampling stage.
    widths = [num_channel, 16, 32, 64, 128]

    stages = []
    for stage_idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
      stages.append(nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1))
      stages.append(nn.LeakyReLU(0.2, inplace=True))
      stages.append(nn.Dropout(0.25))
      # The first stage is the only one without batch normalization.
      if stage_idx > 0:
        stages.append(nn.BatchNorm2d(c_out))

    # With img_size = 64 the spatial size goes 64 -> 32 -> 16 -> 8 -> 4.
    self.disc_model = nn.Sequential(*stages)

    # The side length is halved four times by the stride-2 convolutions.
    ds_size = img_size // 2**4
    self.fc = nn.Sequential(
        nn.Linear(128*(ds_size ** 2), 1), # 128 channels x (width x height)
        nn.Sigmoid()
    )

  def forward(self, x):
    """Return the (batch, 1) sigmoid scores for the image batch `x`."""
    features = self.disc_model(x)
    flat = features.view(features.shape[0], -1)
    return self.fc(flat)

In [17]:
## DCGAN DATA

# Resize MNIST digits to img_size x img_size, convert to tensors and
# normalize the single channel with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])

# Training split of MNIST, downloaded into the current directory when missing.
train_data = datasets.MNIST(
    root='./',
    train=True,
    transform=transform,
    download=True
)

# Reshuffle every epoch, matching the vanilla GAN loader; a fixed sample order
# would show the networks identical mini-batches in every epoch.
# NOTE: this rebinds `transform` and `data_loader` from the vanilla GAN cells.
data_loader = DataLoader(
    train_data,
    batch_size,
    shuffle=True
)

In [21]:
# Instantiate the DCGAN networks on `device`.
# NOTE: this rebinds `generator`, `discriminator`, `criterion` and both
# optimizers, replacing the vanilla GAN objects created in earlier cells.
generator = GANGenerator().to(device)
discriminator = GANDiscriminator().to(device)

# Binary cross-entropy applied to the discriminator's sigmoid output.
criterion = nn.BCELoss()

# One Adam optimizer per network, both using the shared learning rate `lr`.
g_optimizer = optim.Adam(generator.parameters(), lr=lr)
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

In [25]:
from tqdm import tqdm

# Directory that receives the sample grids written during training.
os.makedirs('./DCGAN_results', exist_ok=True)

for epoch in tqdm(range(num_epochs)):
  for idx, (images, _) in enumerate(data_loader):
    # Size everything from the actual batch so a short final batch works.
    cur_batch = images.shape[0]

    # Targets are created directly on the device (no CPU allocation + copy).
    real_label = torch.ones((cur_batch, 1), dtype=torch.float32, device=device)
    fake_label = torch.zeros((cur_batch, 1), dtype=torch.float32, device=device)

    ## generator training
    # The generator is rewarded when the discriminator labels its samples "real".
    noise = torch.randn(cur_batch, noise_size, device=device)
    fake_images = generator(noise)
    g_loss = criterion(discriminator(fake_images), real_label)

    g_optimizer.zero_grad()
    g_loss.backward()
    g_optimizer.step()

    ## discriminator training
    # The fakes are detached so this step does not back-propagate into the generator.
    real_images = images.to(device)
    real_loss = criterion(discriminator(real_images), real_label)
    fake_loss = criterion(discriminator(fake_images.detach()), fake_label)
    d_loss = (real_loss + fake_loss) / 2

    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    if idx % 200 == 0:
      # Monitoring only: mean discriminator score on fake / real batches after
      # the update. Computed just when it is logged, without building a graph.
      with torch.no_grad():
        g_perf = discriminator(fake_images).mean().item()
        d_perf = discriminator(real_images).mean().item()

      print(f"Epoch = {epoch}, Batch = {idx}, g_loss = {g_loss.item():.3f}, d_loss = {d_loss.item():.3f}")
      print(f"d_perf = {d_perf:.3f}, g_perf = {g_perf:.3f}")
      # The file name depends on the epoch only, so later snapshots within an
      # epoch overwrite earlier ones. detach() replaces the legacy .data accessor.
      save_image(fake_images.detach()[:25], f"DCGAN_results/{epoch}.png", nrow=5, normalize=True)

  0%|          | 0/10 [00:00<?, ?it/s]

Epoch = 0, Batch = 0, g_loss = 0.529, d_loss = 0.858
d_perf = 0.619, g_perf = 0.534
Epoch = 0, Batch = 200, g_loss = 1.962, d_loss = 0.274
d_perf = 0.755, g_perf = 0.187
Epoch = 0, Batch = 400, g_loss = 1.370, d_loss = 0.901
d_perf = 0.513, g_perf = 0.351
Epoch = 0, Batch = 600, g_loss = 1.208, d_loss = 0.503
d_perf = 0.563, g_perf = 0.330
Epoch = 0, Batch = 800, g_loss = 1.570, d_loss = 0.244
d_perf = 0.665, g_perf = 0.101
Epoch = 0, Batch = 1000, g_loss = 1.213, d_loss = 0.284
d_perf = 0.835, g_perf = 0.450
Epoch = 0, Batch = 1200, g_loss = 2.162, d_loss = 0.330
d_perf = 0.569, g_perf = 0.125
Epoch = 0, Batch = 1400, g_loss = 1.682, d_loss = 0.221
d_perf = 0.935, g_perf = 0.046
Epoch = 0, Batch = 1600, g_loss = 2.486, d_loss = 0.305
d_perf = 0.980, g_perf = 0.505
Epoch = 0, Batch = 1800, g_loss = 2.053, d_loss = 0.140
d_perf = 0.957, g_perf = 0.145


 10%|█         | 1/10 [00:51<07:44, 51.58s/it]

Epoch = 1, Batch = 0, g_loss = 3.438, d_loss = 0.240
d_perf = 0.911, g_perf = 0.149
Epoch = 1, Batch = 200, g_loss = 0.933, d_loss = 0.720
d_perf = 0.987, g_perf = 0.404
Epoch = 1, Batch = 400, g_loss = 5.978, d_loss = 0.618
d_perf = 0.787, g_perf = 0.095
Epoch = 1, Batch = 600, g_loss = 3.491, d_loss = 0.294
d_perf = 0.720, g_perf = 0.035
Epoch = 1, Batch = 800, g_loss = 3.317, d_loss = 0.103
d_perf = 0.830, g_perf = 0.021
Epoch = 1, Batch = 1000, g_loss = 4.166, d_loss = 0.128
d_perf = 0.727, g_perf = 0.044
Epoch = 1, Batch = 1200, g_loss = 4.970, d_loss = 0.276
d_perf = 0.819, g_perf = 0.009
Epoch = 1, Batch = 1400, g_loss = 3.644, d_loss = 0.030
d_perf = 0.975, g_perf = 0.088
Epoch = 1, Batch = 1600, g_loss = 1.290, d_loss = 0.578
d_perf = 0.827, g_perf = 0.184
Epoch = 1, Batch = 1800, g_loss = 2.020, d_loss = 0.323
d_perf = 0.987, g_perf = 0.165


 20%|██        | 2/10 [01:43<06:52, 51.54s/it]

Epoch = 2, Batch = 0, g_loss = 3.685, d_loss = 0.120
d_perf = 0.875, g_perf = 0.020
Epoch = 2, Batch = 200, g_loss = 1.466, d_loss = 0.169
d_perf = 0.938, g_perf = 0.134
Epoch = 2, Batch = 400, g_loss = 1.731, d_loss = 0.173
d_perf = 0.993, g_perf = 0.371
Epoch = 2, Batch = 600, g_loss = 3.029, d_loss = 0.084
d_perf = 0.983, g_perf = 0.015
Epoch = 2, Batch = 800, g_loss = 2.064, d_loss = 0.031
d_perf = 0.911, g_perf = 0.015
Epoch = 2, Batch = 1000, g_loss = 2.947, d_loss = 0.159
d_perf = 0.950, g_perf = 0.059
Epoch = 2, Batch = 1200, g_loss = 4.754, d_loss = 0.076
d_perf = 0.838, g_perf = 0.006
Epoch = 2, Batch = 1400, g_loss = 3.764, d_loss = 0.039
d_perf = 0.921, g_perf = 0.045
Epoch = 2, Batch = 1600, g_loss = 5.732, d_loss = 0.553
d_perf = 0.742, g_perf = 0.015
Epoch = 2, Batch = 1800, g_loss = 3.597, d_loss = 0.084
d_perf = 0.912, g_perf = 0.047


 30%|███       | 3/10 [02:34<06:00, 51.46s/it]

Epoch = 3, Batch = 0, g_loss = 8.242, d_loss = 0.140
d_perf = 0.831, g_perf = 0.003
Epoch = 3, Batch = 200, g_loss = 8.784, d_loss = 0.548
d_perf = 0.596, g_perf = 0.025
Epoch = 3, Batch = 400, g_loss = 3.888, d_loss = 0.042
d_perf = 0.998, g_perf = 0.175
Epoch = 3, Batch = 600, g_loss = 4.473, d_loss = 0.033
d_perf = 0.928, g_perf = 0.096
Epoch = 3, Batch = 800, g_loss = 6.217, d_loss = 0.033
d_perf = 0.973, g_perf = 0.003
Epoch = 3, Batch = 1000, g_loss = 6.753, d_loss = 0.264
d_perf = 0.594, g_perf = 0.009
Epoch = 3, Batch = 1200, g_loss = 2.930, d_loss = 0.119
d_perf = 0.982, g_perf = 0.238
Epoch = 3, Batch = 1400, g_loss = 7.117, d_loss = 0.043
d_perf = 0.956, g_perf = 0.024
Epoch = 3, Batch = 1600, g_loss = 2.880, d_loss = 0.136
d_perf = 0.994, g_perf = 0.803
Epoch = 3, Batch = 1800, g_loss = 1.001, d_loss = 0.544
d_perf = 0.999, g_perf = 0.121


 40%|████      | 4/10 [03:26<05:09, 51.53s/it]

Epoch = 4, Batch = 0, g_loss = 3.953, d_loss = 0.080
d_perf = 0.956, g_perf = 0.006
Epoch = 4, Batch = 200, g_loss = 2.946, d_loss = 0.213
d_perf = 0.968, g_perf = 0.149
Epoch = 4, Batch = 400, g_loss = 6.219, d_loss = 0.030
d_perf = 0.875, g_perf = 0.111
Epoch = 4, Batch = 600, g_loss = 6.504, d_loss = 0.160
d_perf = 0.809, g_perf = 0.001
Epoch = 4, Batch = 800, g_loss = 7.625, d_loss = 0.024
d_perf = 0.986, g_perf = 0.000
Epoch = 4, Batch = 1000, g_loss = 1.621, d_loss = 0.103
d_perf = 0.992, g_perf = 0.065
Epoch = 4, Batch = 1200, g_loss = 9.103, d_loss = 0.291
d_perf = 0.869, g_perf = 0.001
Epoch = 4, Batch = 1400, g_loss = 4.938, d_loss = 0.055
d_perf = 0.951, g_perf = 0.047
Epoch = 4, Batch = 1600, g_loss = 4.167, d_loss = 0.292
d_perf = 0.980, g_perf = 0.013
Epoch = 4, Batch = 1800, g_loss = 6.154, d_loss = 0.007
d_perf = 0.991, g_perf = 0.013


 50%|█████     | 5/10 [04:17<04:17, 51.60s/it]

Epoch = 5, Batch = 0, g_loss = 4.489, d_loss = 0.090
d_perf = 0.993, g_perf = 0.054
Epoch = 5, Batch = 200, g_loss = 9.275, d_loss = 1.020
d_perf = 0.870, g_perf = 0.024
Epoch = 5, Batch = 400, g_loss = 3.797, d_loss = 0.235
d_perf = 0.996, g_perf = 0.089
Epoch = 5, Batch = 600, g_loss = 5.466, d_loss = 0.135
d_perf = 0.783, g_perf = 0.023
Epoch = 5, Batch = 800, g_loss = 7.192, d_loss = 0.033
d_perf = 0.903, g_perf = 0.001
Epoch = 5, Batch = 1000, g_loss = 4.887, d_loss = 0.019
d_perf = 0.970, g_perf = 0.096
Epoch = 5, Batch = 1200, g_loss = 10.808, d_loss = 0.379
d_perf = 0.726, g_perf = 0.008
Epoch = 5, Batch = 1400, g_loss = 3.386, d_loss = 0.065
d_perf = 0.998, g_perf = 0.132
Epoch = 5, Batch = 1600, g_loss = 6.719, d_loss = 0.031
d_perf = 0.823, g_perf = 0.008
Epoch = 5, Batch = 1800, g_loss = 6.455, d_loss = 0.015
d_perf = 0.999, g_perf = 0.043


 60%|██████    | 6/10 [05:09<03:26, 51.63s/it]

Epoch = 6, Batch = 0, g_loss = 5.618, d_loss = 0.002
d_perf = 0.998, g_perf = 0.083
Epoch = 6, Batch = 200, g_loss = 4.728, d_loss = 0.217
d_perf = 0.999, g_perf = 0.107
Epoch = 6, Batch = 400, g_loss = 2.926, d_loss = 0.015
d_perf = 0.995, g_perf = 0.108
Epoch = 6, Batch = 600, g_loss = 4.555, d_loss = 0.401
d_perf = 0.968, g_perf = 0.004
Epoch = 6, Batch = 800, g_loss = 7.892, d_loss = 0.013
d_perf = 0.987, g_perf = 0.000
Epoch = 6, Batch = 1000, g_loss = 7.699, d_loss = 0.750
d_perf = 0.987, g_perf = 0.005
Epoch = 6, Batch = 1200, g_loss = 3.839, d_loss = 0.016
d_perf = 0.999, g_perf = 0.094
Epoch = 6, Batch = 1400, g_loss = 6.291, d_loss = 0.147
d_perf = 1.000, g_perf = 0.150
Epoch = 6, Batch = 1600, g_loss = 5.248, d_loss = 0.011
d_perf = 0.973, g_perf = 0.004
Epoch = 6, Batch = 1800, g_loss = 8.307, d_loss = 0.003
d_perf = 0.987, g_perf = 0.005


 70%|███████   | 7/10 [06:01<02:35, 51.72s/it]

Epoch = 7, Batch = 0, g_loss = 10.857, d_loss = 0.008
d_perf = 0.946, g_perf = 0.000
Epoch = 7, Batch = 200, g_loss = 3.505, d_loss = 0.021
d_perf = 0.999, g_perf = 0.316
Epoch = 7, Batch = 400, g_loss = 4.504, d_loss = 0.031
d_perf = 0.981, g_perf = 0.046
Epoch = 7, Batch = 600, g_loss = 4.537, d_loss = 0.039
d_perf = 0.996, g_perf = 0.040
Epoch = 7, Batch = 800, g_loss = 7.384, d_loss = 0.518
d_perf = 0.791, g_perf = 0.000
Epoch = 7, Batch = 1000, g_loss = 1.527, d_loss = 0.112
d_perf = 0.989, g_perf = 0.005
Epoch = 7, Batch = 1200, g_loss = 4.769, d_loss = 0.105
d_perf = 0.994, g_perf = 0.008
Epoch = 7, Batch = 1400, g_loss = 3.609, d_loss = 0.072
d_perf = 1.000, g_perf = 0.036
Epoch = 7, Batch = 1600, g_loss = 4.692, d_loss = 0.127
d_perf = 0.936, g_perf = 0.053
Epoch = 7, Batch = 1800, g_loss = 7.908, d_loss = 0.024
d_perf = 0.986, g_perf = 0.009


 80%|████████  | 8/10 [06:53<01:43, 51.71s/it]

Epoch = 8, Batch = 0, g_loss = 6.869, d_loss = 0.003
d_perf = 0.956, g_perf = 0.003
Epoch = 8, Batch = 200, g_loss = 6.493, d_loss = 0.033
d_perf = 0.953, g_perf = 0.013
Epoch = 8, Batch = 400, g_loss = 8.300, d_loss = 0.031
d_perf = 0.916, g_perf = 0.012
Epoch = 8, Batch = 600, g_loss = 5.777, d_loss = 0.052
d_perf = 0.998, g_perf = 0.004
Epoch = 8, Batch = 800, g_loss = 6.682, d_loss = 0.022
d_perf = 0.978, g_perf = 0.019
Epoch = 8, Batch = 1000, g_loss = 5.512, d_loss = 0.009
d_perf = 0.993, g_perf = 0.003
Epoch = 8, Batch = 1200, g_loss = 6.199, d_loss = 0.044
d_perf = 0.984, g_perf = 0.081
Epoch = 8, Batch = 1400, g_loss = 11.067, d_loss = 0.106
d_perf = 0.838, g_perf = 0.001
Epoch = 8, Batch = 1600, g_loss = 4.962, d_loss = 0.006
d_perf = 0.994, g_perf = 0.002
Epoch = 8, Batch = 1800, g_loss = 7.796, d_loss = 0.001
d_perf = 1.000, g_perf = 0.075


 90%|█████████ | 9/10 [07:44<00:51, 51.71s/it]

Epoch = 9, Batch = 0, g_loss = 8.960, d_loss = 0.016
d_perf = 0.958, g_perf = 0.001
Epoch = 9, Batch = 200, g_loss = 4.796, d_loss = 0.180
d_perf = 0.998, g_perf = 0.035
Epoch = 9, Batch = 400, g_loss = 9.456, d_loss = 0.068
d_perf = 0.999, g_perf = 0.007
Epoch = 9, Batch = 600, g_loss = 7.750, d_loss = 0.041
d_perf = 0.954, g_perf = 0.001
Epoch = 9, Batch = 800, g_loss = 4.940, d_loss = 0.003
d_perf = 0.999, g_perf = 0.009
Epoch = 9, Batch = 1000, g_loss = 7.659, d_loss = 1.597
d_perf = 0.888, g_perf = 0.005
Epoch = 9, Batch = 1200, g_loss = 5.102, d_loss = 0.011
d_perf = 0.999, g_perf = 0.005
Epoch = 9, Batch = 1400, g_loss = 3.507, d_loss = 0.046
d_perf = 0.998, g_perf = 0.044
Epoch = 9, Batch = 1600, g_loss = 9.557, d_loss = 0.012
d_perf = 0.818, g_perf = 0.000
Epoch = 9, Batch = 1800, g_loss = 8.568, d_loss = 0.003
d_perf = 0.999, g_perf = 0.000


100%|██████████| 10/10 [08:36<00:00, 51.65s/it]


In [26]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn, optim
import torch.nn.functional as F

import os
from tqdm import tqdm

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [27]:
# Fetch the tiny NeRF dataset only when it is not already in the working
# directory, so re-running the cell does not download it again.
if not os.path.exists("tiny_nerf_data.npz"):
  !wget http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz

--2025-03-31 07:11:23--  http://cseweb.ucsd.edu/~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz
Resolving cseweb.ucsd.edu (cseweb.ucsd.edu)... 132.239.8.30
Connecting to cseweb.ucsd.edu (cseweb.ucsd.edu)|132.239.8.30|:80... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cseweb.ucsd.edu//~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz [following]
--2025-03-31 07:11:23--  https://cseweb.ucsd.edu//~viscomp/projects/LF/papers/ECCV20/nerf/tiny_nerf_data.npz
Connecting to cseweb.ucsd.edu (cseweb.ucsd.edu)|132.239.8.30|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 12727482 (12M)
Saving to: ‘tiny_nerf_data.npz’


2025-03-31 07:11:24 (114 MB/s) - ‘tiny_nerf_data.npz’ saved [12727482/12727482]



## Positional encoding

$$\gamma(p) = (\sin(2^0\pi p), \cos(2^0\pi p), \dots, \sin(2^{L-1}\pi p), \cos(2^{L-1}\pi p))$$

In [28]:
def encoding(x, L=10):
  """Positional encoding of `x` with `L` frequency bands.

  The result concatenates, along the last dimension, the raw input followed
  by sin(2^k * pi * x) and cos(2^k * pi * x) for k = 0 .. L-1, so the last
  dimension grows from D to D * (1 + 2 * L).

  Args:
    x: Tensor whose last dimension holds the coordinates to encode.
    L: Number of frequency bands.

  Returns:
    Tensor with the same leading dimensions as `x` and the widened last dimension.
  """
  bands = [x]  # the raw coordinates are kept as the first block
  for octave in range(L):
    angle = 2**octave*torch.pi*x
    bands.extend((torch.sin(angle), torch.cos(angle)))
  return torch.cat(bands, dim=-1)

In [29]:
# Sanity check of the positional encoding on a single 3-D point.
x = torch.tensor([3.1, 5.6, 7.3]) # x, y, z
y = encoding(x, L=4)
print(y) # 27 values in total (3 raw + 3*4*2 encoded)

tensor([ 3.1000,  5.6000,  7.3000, -0.3090, -0.9511, -0.8090, -0.9511,  0.3090,
        -0.5878,  0.5878, -0.5878,  0.9511,  0.8090, -0.8090, -0.3090,  0.9511,
         0.9511, -0.5878,  0.3090,  0.3090, -0.8090,  0.5878,  0.5878,  0.9511,
        -0.8090, -0.8090,  0.3090])


# NeRF 클래스

In [31]:
class NeRF(nn.Module):
  """NeRF MLP mapping (position, view direction) to (density, RGB).

  Args:
    pos_enc_dim: Width of the encoded position. The default 63 matches
      `encoding` with L=10 on 3 coordinates (3 + 3*2*10).
    view_enc_dim: Width of the encoded view direction. The default 27 matches
      `encoding` with L=4 on 3 coordinates (3 + 3*2*4).
    hidden: Width of the hidden layers.

  The forward pass encodes positions with L=10 and directions with L=4, so
  non-default encoding widths must stay consistent with those values.
  """

  def __init__(self, pos_enc_dim=63, view_enc_dim=27, hidden=256):
    super().__init__()

    # Input layer on the encoded position.
    self.linear1 = nn.Sequential(
        nn.Linear(pos_enc_dim, hidden),
        nn.ReLU()
    )

    # Four hidden layers before the skip connection.
    self.pre_skip_linear = nn.Sequential()
    for _ in range(4):
      self.pre_skip_linear.append(nn.Linear(hidden, hidden))
      self.pre_skip_linear.append(nn.ReLU())

    # Layer that consumes the features concatenated with the encoded position.
    self.linear_skip = nn.Sequential(
        nn.Linear(pos_enc_dim+hidden, hidden),
        nn.ReLU()
    )

    # Two hidden layers after the skip connection.
    self.post_skip_linear = nn.Sequential()
    for _ in range(2):
      self.post_skip_linear.append(nn.Linear(hidden, hidden))
      self.post_skip_linear.append(nn.ReLU())

    # Density head; the ReLU keeps sigma non-negative.
    self.density_layer = nn.Sequential(
        nn.Linear(hidden, 1),
        nn.ReLU()
    )

    # Feature layer (no activation) feeding the colour branch.
    self.linear2 = nn.Linear(hidden, hidden)

    # Colour branch, conditioned on the encoded view direction.
    self.color_linear1 = nn.Sequential(
        nn.Linear(hidden+view_enc_dim, hidden//2),
        nn.ReLU()
    )

    # The sigmoid keeps the RGB values in [0, 1].
    self.color_linear2 = nn.Sequential(
        nn.Linear(hidden//2, 3),
        nn.Sigmoid()
    )

  def forward(self, input):
    """Predict density and colour for each input point.

    Args:
      input: Tensor whose last dimension holds the 3 position coordinates
        followed by the view direction.

    Returns:
      Tensor with the same leading dimensions and a last dimension of 4:
      sigma followed by RGB.
    """
    position = input[..., :3]  # x, y, z
    view_dirs = input[..., 3:] # direction

    # Encode
    pos_enc = encoding(position, L=10)
    view_enc = encoding(view_dirs, L=4)

    # The input layer is registered as `linear1`; `self.linear` does not
    # exist and raised AttributeError on the first forward pass.
    x = self.linear1(pos_enc)
    x = self.pre_skip_linear(x)

    # Skip connection
    x = torch.cat([x, pos_enc], dim=-1)
    x = self.linear_skip(x)

    x = self.post_skip_linear(x)

    # Density depends on the position only (it is read before the view
    # direction is concatenated).
    sigma = self.density_layer(x)

    x = self.linear2(x)

    # Colour additionally depends on the view direction.
    x = torch.cat([x, view_enc], dim=-1)
    x = self.color_linear1(x)
    rgb = self.color_linear2(x)

    return torch.cat([sigma, rgb], dim=-1)

## Get rays

In [None]:
def get_rays(H, W, focal, c2w):
  """
  Generate rays for a given camera configuration.

  Args:
    H: Image height.
    W: Image width.
    focal: Focal length (Python number, NumPy scalar/array or tensor).
    c2w: Camera-to-world transformation matrix (4x4). Only the top-left 3x3
      block (rotation) and the first three entries of the last column
      (translation) are read.

  Returns:
    rays_o: Ray origins (H*W, 3).
    rays_d: Ray directions (H*W, 3), in row-major pixel order.
  """

  device = c2w.device
  # torch.as_tensor accepts Python numbers, NumPy values and tensors alike;
  # torch.from_numpy only accepts ndarrays and failed for plain floats.
  focal = torch.as_tensor(focal, dtype=torch.float32, device=device)

  # The keyword is `indexing` (`index` raised TypeError). With 'xy' both
  # grids have shape (H, W): i varies along columns, j along rows.
  i, j = torch.meshgrid(
      torch.arange(W, dtype=torch.float32, device=device),
      torch.arange(H, dtype=torch.float32, device=device),
      indexing='xy'
  )

  # Camera-space directions: x to the right, y up (hence the sign flip on j),
  # and the camera looking down the -z axis.
  dirs = torch.stack([(i-W*0.5)/focal, -(j-H*0.5)/focal, -torch.ones_like(i)], -1)

  # Rotate every direction into world space: for each pixel this is
  # c2w[:3, :3] @ dir, written as a broadcasted multiply + sum.
  rays_d = torch.sum(dirs[..., None, :]*c2w[:3, :3], -1)
  rays_d = rays_d.reshape(-1, 3)

  # All rays share the camera position as their origin.
  rays_o = c2w[:3, -1].expand(rays_d.shape)

  return rays_o, rays_d

def render_rays(network_fn, rays_o, rays_d, near, far, N_samples, device, rand = False, embed_fn = None, chunk=1024*4):
  """Volume-render a batch of rays with `network_fn`.

  Args:
    network_fn: Callable mapping a (..., 6) tensor of [position, unit view
      direction] to a (..., 4) tensor of [sigma, r, g, b].
    rays_o: Ray origins, shape (N_rays, 3).
    rays_d: Ray directions, shape (N_rays, 3).
    near: Distance of the first sample along each ray.
    far: Distance of the last sample along each ray (before jitter).
    N_samples: Number of samples per ray.
    device: Device on which the sample depths are created.
    rand: If True, add an independent uniform jitter of up to
      (far - near) / N_samples to every sample of every ray.
    embed_fn: Unused; kept so existing callers keep working.
    chunk: Maximum number of rays passed to `network_fn` per call.

  Returns:
    rgb_map: Rendered colours, shape (N_rays, 3).
    depth_map: Expected sample depth per ray, shape (N_rays,).
    acc_map: Accumulated weight per ray, shape (N_rays,).
  """

  def batchify(fn, chunk):
    # Evaluate fn over slices of the first (ray) dimension to bound memory use.
    return lambda inputs: torch.cat([fn(inputs[i:i+chunk]) for i in range(0, inputs.shape[0], chunk)], 0)

  # Sample depths, shared by all rays: shape (N_samples,).
  z_vals = torch.linspace(near, far, steps=N_samples, device=device)
  if rand:
    # Per-ray jitter; z_vals becomes (N_rays, N_samples). Previously a single
    # jitter vector was drawn and shared by every ray.
    jitter = torch.rand(*rays_o.shape[:-1], N_samples, device=device)
    z_vals = z_vals + jitter * (far-near) / N_samples

  # point = origin + direction * depth. z_vals needs a trailing axis so it
  # broadcasts over the xyz components: (N_rays, 1, 3) * (..., N_samples, 1).
  pts = rays_o[..., None, :] + rays_d[..., None, :]*z_vals[..., :, None] # [N_rays, N_samples, 3]

  ## Normalize view directions

  # Divide by the norm so that only the direction, not the magnitude, remains.
  view_dirs = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
  view_dirs = view_dirs[..., None, :].expand(pts.shape)

  # input = point + direction
  input_pts = torch.cat([pts, view_dirs], dim=-1)
  raw = batchify(network_fn, chunk)(input_pts)

  sigma_a = raw[..., 0] # first channel is the density sigma
  rgb = raw[..., 1:]    # the remaining channels are rgb

  ## Distances between consecutive samples (samples 2..last minus 1..last-1)
  dists = z_vals[..., 1:] - z_vals[..., :-1]
  # The last interval is treated as effectively infinite (1e10); the filler
  # is expanded so this works for shared and per-ray depths alike.
  last_dist = torch.tensor([1e10], device=device).expand(z_vals[..., :1].shape)
  dists = torch.cat([dists, last_dist], -1)

  alpha = 1. - torch.exp(-sigma_a*dists)
  alpha = alpha.unsqueeze(-1) # add a channel axis, shape: [batch, N_samples, 1]

  ## Computing transmittance
  ones_shape = (alpha.shape[0], 1, 1) # leading 1 so the first sample is fully visible

  # Exclusive cumulative product of (1 - alpha); the 1e-10 avoids exact zeros.
  T = torch.cumprod(
        torch.cat([
            torch.ones(ones_shape, device=device),  # (N_rays, 1, 1)
            1. - alpha + 1e-10                      # (N_rays, N_samples, 1)
        ], dim=1),
        dim=1
    )[:, :-1]  # Shape: [batch, N_samples, 1]

  weights = alpha * T  # Shape: [batch, N_samples, 1]

  # Compute final colors and depths, accumulation
  rgb_map = torch.sum(weights * rgb, dim=1)  # sum along the sample dimension -> [batch, 3]
  depth_map = torch.sum(weights.squeeze(-1) * z_vals, dim=-1)  # Shape: [batch]
  acc_map = torch.sum(weights.squeeze(-1), dim=-1)  # Shape: [batch]

  return rgb_map, depth_map, acc_map


In [37]:
# Small demonstration of the pixel grids built with indexing='xy':
# both outputs have shape (H, W); i holds the column index, j the row index.
W = 3; H = 4; focal = 1

i, j = torch.meshgrid(
    torch.arange(W, dtype=torch.float32, device=device),
    torch.arange(H, dtype=torch.float32, device=device),
    indexing='xy'
)

print("i = \n", i)
print("j = \n", j)

i = 
 tensor([[0., 1., 2.],
        [0., 1., 2.],
        [0., 1., 2.],
        [0., 1., 2.]], device='cuda:0')
j = 
 tensor([[0., 0., 0.],
        [1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.]], device='cuda:0')
