In [None]:
# Sanity check for the runtime: the value of the last expression (shown as the
# cell output) tells whether PyTorch can see a CUDA device.
import torch
torch.cuda.is_available()

True

In [None]:
# Pin the release so the FID implementation (and therefore the reported score)
# does not drift between runs, and use the %pip magic so the package is
# installed into the environment of the running kernel.
%pip install pytorch-fid==0.2.1

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytorch-fid
  Downloading pytorch-fid-0.2.1.tar.gz (14 kB)
Building wheels for collected packages: pytorch-fid
  Building wheel for pytorch-fid (setup.py) ... [?25l[?25hdone
  Created wheel for pytorch-fid: filename=pytorch_fid-0.2.1-py3-none-any.whl size=14835 sha256=5c349fd92faf86cefb1ef78f83c931ca49dc873507e4e2f9666d33dd15bdecbb
  Stored in directory: /root/.cache/pip/wheels/24/ac/03/c5634775c8a64f702343ef5923278f8d3bb8c651debc4a6890
Successfully built pytorch-fid
Installing collected packages: pytorch-fid
Successfully installed pytorch-fid-0.2.1


In [None]:
import torch.nn as nn
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import numpy as np
from PIL import Image

from tqdm import tqdm
import os
import torch

from pytorch_fid.fid_score import *

# Environment flag set before any training work starts
# (presumably a workaround for duplicate OpenMP runtimes - TODO confirm it is still needed).
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# Seed the CPU generator, the current CUDA device and all CUDA devices with 0,
# in that order, so repeated runs start from the same random state.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(0)


class Generator(nn.Module):
    """DCGAN-style generator: latent vector -> image with values in [-1, 1].

    Args:
        latent_dim: Channel count of the latent input (default 100).
        base_channels: Width unit of the feature maps; the hidden layers use
            8x, 4x, 2x and 1x this value (default 64).
        out_channels: Channel count of the generated image (default 3).

    With the default arguments the layers, their order inside ``self.main``
    and therefore the ``state_dict`` keys are the same as in the original
    hard-coded network.
    """

    def __init__(self, latent_dim=100, base_channels=64, out_channels=3):
        super(Generator, self).__init__()
        # Kept so callers can size the noise tensor without repeating the constant.
        self.latent_dim = latent_dim
        self.main = nn.Sequential(
            # [B, latent_dim, 1, 1] -> [B, 8 * base_channels, 4, 4]
            nn.ConvTranspose2d(latent_dim, base_channels * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(base_channels * 8),
            nn.ReLU(),
            # -> [B, 4 * base_channels, 8, 8]
            nn.ConvTranspose2d(base_channels * 8, base_channels * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels * 4),
            nn.ReLU(),
            # -> [B, 2 * base_channels, 16, 16]
            nn.ConvTranspose2d(base_channels * 4, base_channels * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(),
            # -> [B, base_channels, 32, 32]
            nn.ConvTranspose2d(base_channels * 2, base_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(),
            # -> [B, out_channels, 64, 64]; Tanh bounds every pixel to [-1, 1]
            nn.ConvTranspose2d(base_channels, out_channels, 4, 2, 1, bias=False),
            nn.Tanh()
        )

    def forward(self, input):
        """Generate images from latent vectors.

        Args:
            input: Tensor of shape [batch size, latent_dim, 1, 1].

        Returns:
            Tensor of shape [batch size, out_channels, 64, 64].
        """
        return self.main(input)


class Discriminator(nn.Module):
    """DCGAN-style discriminator producing one sigmoid score per image.

    Args:
        in_channels: Channel count of the input images (default 3).
        base_channels: Width unit of the feature maps; the hidden layers use
            1x, 2x, 4x and 8x this value (default 64).

    With the default arguments the layers, their order inside ``self.main``
    and therefore the ``state_dict`` keys are the same as in the original
    hard-coded network.
    """

    def __init__(self, in_channels=3, base_channels=64):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # [B, in_channels, 64, 64] -> [B, base_channels, 32, 32]
            # (no BatchNorm on the first layer)
            nn.Conv2d(in_channels, base_channels, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> [B, 2 * base_channels, 16, 16]
            nn.Conv2d(base_channels, base_channels * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # -> [B, 4 * base_channels, 8, 8]
            nn.Conv2d(base_channels * 2, base_channels * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # -> [B, 8 * base_channels, 4, 4]
            nn.Conv2d(base_channels * 4, base_channels * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(base_channels * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # -> [B, 1, 1, 1]; Sigmoid squashes the score into [0, 1]
            nn.Conv2d(base_channels * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: Tensor of shape [batch size, in_channels, 64, 64].

        Returns:
            Tensor of shape [batch size, 1, 1, 1] with values in [0, 1].
        """
        output = self.main(input)
        return output


if __name__ == "__main__":
    # All of the training code lives in this block: it builds the data
    # pipeline, instantiates both networks and runs the adversarial loop.
    
    ## hyperparameters
    # workers is the DataLoader worker count, img_size the side length used by
    # the Resize/CenterCrop transforms, lr and beta1 the Adam settings shared
    # by both optimizers.
    workers = 2
    batch_size = 128
    img_size = 64
    epochs = 5
    lr = 0.0002
    beta1 = 0.5

    # data_path is the ImageFolder root. real_img_path and fake_img_path are
    # not read anywhere inside this block.
    data_path = 'training_data/'
    real_img_path = 'training_data/celeba/'
    fake_img_path = 'fake_img/'

    # Resize, center-crop to img_size x img_size, convert to a tensor and then
    # normalize every channel with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
    dataset = datasets.ImageFolder(root=data_path,
                                   transform=transforms.Compose([
                                      transforms.Resize(img_size),
                                      transforms.CenterCrop(img_size),
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                  ]))
                              

    # Train on the first GPU when one is available, otherwise on the CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

    def weights_init(m):
        """Re-initialize one module in place; intended for ``Module.apply``.

        Modules whose class name contains 'Conv' get weights drawn from
        N(0, 0.02). Modules whose class name contains 'BatchNorm' get weights
        drawn from N(1, 0.02) and a zero bias. Any other module is left
        untouched.
        """
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    # The generator is built and re-initialized before the discriminator; with
    # the fixed seed this order determines which random numbers each one gets.
    generator = Generator().to(device)
    generator.apply(weights_init)

    discriminator = Discriminator().to(device)
    discriminator.apply(weights_init)

    # Binary cross-entropy on the discriminator's sigmoid outputs.
    criterion = nn.BCELoss()

    # Target values used with the criterion for real and generated images.
    real_label = 1.
    fake_label = 0.

    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
    g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    
    # Bookkeeping. G_losses / D_losses collect one value per training step.
    # img_list is never appended to in this block, and iters is only incremented.
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    print("Starting Training Loop...")
    for epoch in range(epochs):
        for i, data in enumerate(dataloader, 0):

            # Update discriminator
            # Step 1a: real batch with target real_label. Only data[0] (the
            # image batch) is used; the rest of the loader item is ignored.
            discriminator.zero_grad()
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            output = discriminator(real_cpu).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            # Mean discriminator score on the real batch (for logging only).
            D_x = output.mean().item()

            # Generate fake images to train discriminator
            # Step 1b: fake batch with target fake_label. fake.detach() keeps
            # this backward pass from reaching the generator; the gradients
            # add to those from step 1a before the single optimizer step.
            noise = torch.randn(b_size, 100, 1, 1, device=device)
            fake = generator(noise)
            label.fill_(fake_label)
            output = discriminator(fake.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            # Mean score on the fake batch before the discriminator update.
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            d_optimizer.step()

            # Update generator
            # Step 2: score the same fake batch with the just-updated
            # discriminator, using real_label as the target. The backward pass
            # also writes gradients into the discriminator's parameters; those
            # are cleared by discriminator.zero_grad() at the top of the next step.
            generator.zero_grad()
            label.fill_(real_label)  
            output = discriminator(fake).view(-1)
            errG = criterion(output, label)
            errG.backward()
            # Mean score on the fake batch after the discriminator update.
            D_G_z2 = output.mean().item()
            g_optimizer.step()

            # Output training progress
            # Printed every 50 steps; D(G(z)) shows the before / after values.
            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, epochs, i, len(dataloader),
                        errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            G_losses.append(errG.item())
            D_losses.append(errD.item())

            iters += 1


Starting Training Loop...
[0/5][0/549]	Loss_D: 1.6913	Loss_G: 6.6971	D(x): 0.7556	D(G(z)): 0.6796 / 0.0021
[0/5][50/549]	Loss_D: 0.0346	Loss_G: 17.5047	D(x): 0.9788	D(G(z)): 0.0000 / 0.0000
[0/5][100/549]	Loss_D: 0.4230	Loss_G: 11.4945	D(x): 0.9058	D(G(z)): 0.1863 / 0.0001
[0/5][150/549]	Loss_D: 1.0821	Loss_G: 11.4517	D(x): 0.9898	D(G(z)): 0.6026 / 0.0001
[0/5][200/549]	Loss_D: 0.5956	Loss_G: 4.4065	D(x): 0.7956	D(G(z)): 0.2151 / 0.0306
[0/5][250/549]	Loss_D: 1.2524	Loss_G: 8.7042	D(x): 0.9011	D(G(z)): 0.6074 / 0.0008
[0/5][300/549]	Loss_D: 0.5753	Loss_G: 4.8802	D(x): 0.8110	D(G(z)): 0.2396 / 0.0113
[0/5][350/549]	Loss_D: 0.4429	Loss_G: 3.5473	D(x): 0.8374	D(G(z)): 0.1918 / 0.0498
[0/5][400/549]	Loss_D: 0.7926	Loss_G: 5.0182	D(x): 0.8420	D(G(z)): 0.3445 / 0.0140
[0/5][450/549]	Loss_D: 0.5673	Loss_G: 5.5884	D(x): 0.8389	D(G(z)): 0.2689 / 0.0065
[0/5][500/549]	Loss_D: 0.5383	Loss_G: 5.1179	D(x): 0.8432	D(G(z)): 0.2424 / 0.0105
[1/5][0/549]	Loss_D: 0.4652	Loss_G: 4.7964	D(x): 0.8446	D(G(z

In [None]:
# Generates the fake images that are used to measure the FID score.
# Run this cell last, after the generator has finished training, to save the
# fake images to disk.

# Image.save() does not create missing directories, so make sure the output
# folder exists; otherwise the first save raises FileNotFoundError on a fresh
# runtime. NOTE: files left over from an earlier run are not removed here.
os.makedirs("./fake_img/", exist_ok=True)

test_noise = torch.randn(3000, 100, 1, 1, device=device)
with torch.no_grad():
    # No autograd graph is recorded inside no_grad(), so a single move to the
    # CPU is enough; no extra detach() is needed.
    test_fake = generator(test_noise).cpu()

for index, img in enumerate(test_fake):
    # Reorder the axes from [C, H, W] to [H, W, C] as expected by Image.fromarray.
    fake = np.transpose(img.numpy(), [1, 2, 0])
    # Map the generator's Tanh range [-1, 1] to pixel values [0, 255].
    fake = (fake * 127.5 + 127.5).astype(np.uint8)
    im = Image.fromarray(fake)
    im.save("./fake_img/fake_sample{}.jpeg".format(index))

In [None]:
import os
import torch

from pytorch_fid.fid_score import *

# Evaluation cell: computes the FID between the real training images and the
# saved generator samples, then prints it.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# Folder with the real images and folder with the generated samples.
real_img_path, fake_img_path = 'training_data/celeba/', 'fake_img/'

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Seed the CPU generator, the current CUDA device and all CUDA devices with 0.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(0)

if __name__ == "__main__":
    # Real images first, generated images second.
    image_dirs = [real_img_path, fake_img_path]
    fid = calculate_fid_given_paths(paths=image_dirs, batch_size=128, device=device, dims=2048)

    message = "fid score : {}".format(fid)
    print(message)

Downloading: "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" to /root/.cache/torch/hub/checkpoints/pt_inception-2015-12-05-6726825d.pth


  0%|          | 0.00/91.2M [00:00<?, ?B/s]

100%|██████████| 549/549 [04:44<00:00,  1.93it/s]
100%|██████████| 24/24 [00:12<00:00,  1.91it/s]


fid score : 68.76001566417895
