<a href="https://colab.research.google.com/github/gomdoori/AI/blob/main/SRGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import math
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim

In [None]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.datasets as datasets

In [None]:
# Pick the compute device: CUDA when a GPU is available, otherwise the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
print(device)

cuda


In [None]:
# Mount Google Drive at /content/drive so that files stored on it can be
# reached through ordinary filesystem paths.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Directory on the mounted Drive that the dataset readers below use as root.
data_root = "/content/drive/MyDrive/Colab Notebooks/SRGAN"

In [None]:
# Side lengths used by the Resize transforms:
#   image_low  - size of the low-resolution images used as input
#   image_high - size of the high-resolution images used as output
image_low, image_high = 32, 64

In [None]:
# Preprocessing for the low-resolution view: resize to image_low x image_low,
# convert to a tensor, then normalize each of the three channels with
# mean 0.5 and std 0.5.
low_size = (image_low, image_low)
trans_low = transforms.Compose(
    [
        transforms.Resize(low_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Preprocessing for the high-resolution view: resize to image_high x image_high,
# convert to a tensor, then normalize each of the three channels with
# mean 0.5 and std 0.5.
high_size = (image_high, image_high)
trans_high = transforms.Compose(
    [
        transforms.Resize(high_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Image-folder dataset rooted at data_root; every image is passed through trans_low.
trainlow = torchvision.datasets.ImageFolder(data_root, transform=trans_low)

In [None]:
# Image-folder dataset rooted at data_root; every image is passed through trans_high.
trainhigh = torchvision.datasets.ImageFolder(data_root, transform=trans_high)

In [None]:
# Number of samples per mini-batch for the data loaders.
batch_size = 8

In [None]:
# Shuffled mini-batch loader over the low-resolution dataset.
low_loader = DataLoader(trainlow, shuffle=True, batch_size=batch_size)

In [None]:
# Shuffled mini-batch loader over the high-resolution dataset.
high_loader = DataLoader(trainhigh, shuffle=True, batch_size=batch_size)

In [None]:
# Number of mini-batches the loader produces per pass over the data.
print(len(low_loader))

# Pull only the first (images, labels) batch from the loader, then stop.
for images, labels in low_loader:
  break

# Report both tensor shapes, one per line.
print(images.shape, labels.shape, sep="\n")

443
torch.Size([8, 3, 32, 32])
torch.Size([8])


In [None]:
import torchvision.utils as vutils

In [None]:
from torchvision.utils import save_image

In [None]:
# Write the batch currently held in `images` to image_32.png as a single grid
# with 4 images per row (nrow=4). normalize=True is passed through to
# torchvision -- presumably to rescale the values for display; see its docs.
save_image(images, 'image_32.png', nrow=4,normalize=True)

In [None]:
# Number of mini-batches the loader produces per pass over the data.
print(len(high_loader))

# Pull only the first (images, labels) batch from the loader, then stop.
for images, labels in high_loader:
  break

# Report both tensor shapes, one per line.
print(images.shape, labels.shape, sep="\n")

443
torch.Size([8, 3, 64, 64])
torch.Size([8])


In [None]:
# Write the batch currently held in `images` to image_64.png as a single grid
# with 4 images per row (nrow=4). normalize=True is passed through to
# torchvision -- presumably to rescale the values for display; see its docs.
save_image(images, 'image_64.png', nrow=4,normalize=True)

In [None]:
# Preprocessing for the "low -> high" view: shrink the image to
# image_low x image_low (discarding detail), convert it to a normalized
# tensor, then resize that tensor back up to image_high x image_high.
#
# Resize takes (height, width). The first step used to be
# (image_low, image_high), i.e. 32 rows by 64 columns, which squashed only the
# vertical axis before the upscale instead of producing a square
# low-resolution image.
trans_lowhigh = transforms.Compose([
    transforms.Resize((image_low, image_low)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    # This Resize runs on the tensor produced by the steps above.
    transforms.Resize((image_high, image_high)),
])

In [None]:
# Image-folder dataset rooted at data_root; every image is passed through trans_lowhigh.
trainlowhigh = torchvision.datasets.ImageFolder(data_root, transform=trans_lowhigh)

In [None]:
# Shuffled mini-batch loader over the downscaled-then-upscaled dataset.
lowhigh_loader = DataLoader(trainlowhigh, shuffle=True, batch_size=batch_size)

In [None]:
# Number of mini-batches the loader produces per pass over the data.
print(len(lowhigh_loader))

# Fetch the first (images, labels) batch from lowhigh_loader. This cell used to
# iterate high_loader here, so the shapes printed below (and anything saved
# from `images` afterwards) described the wrong dataset.
# next(iter(...)) also fails loudly on an empty loader instead of leaving
# `images` / `labels` bound to values from an earlier cell.
images, labels = next(iter(lowhigh_loader))

print(images.shape)
print(labels.shape)

443
torch.Size([8, 3, 64, 64])
torch.Size([8])


In [None]:
# Write the batch currently held in `images` to image_3264.png as a single grid
# with 4 images per row (nrow=4). normalize=True is passed through to
# torchvision -- presumably to rescale the values for display; see its docs.
save_image(images, 'image_3264.png', nrow=4,normalize=True)

In [None]:
import glob
from PIL import Image

In [None]:
class CelebA(Dataset):
  """Paired low/high-resolution views of the JPEG files in ``data_root + '/CNN'``.

  Indexing returns ``[low_res, high_res]``: the same source image resized to
  ``image_low`` x ``image_low`` and to ``image_high`` x ``image_high``. Both
  views are converted to tensors and normalized per channel with mean 0.5 and
  std 0.5.
  """

  def __init__(self):
    # glob returns paths in arbitrary order; sorting makes the
    # index -> file mapping reproducible between runs.
    self.imgs = sorted(glob.glob(data_root + '/CNN/*.jpg'))

    self.low_res_tf = self._make_transform(image_low)
    self.high_res_tf = self._make_transform(image_high)

  @staticmethod
  def _make_transform(size):
    """Return the resize -> tensor -> normalize pipeline for a ``size`` x ``size`` image."""
    return transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

  def __len__(self):
    """Return the number of image files found."""
    return len(self.imgs)

  def __getitem__(self, i):
    """Return ``[low_res, high_res]`` tensors for the ``i``-th image file."""
    # The context manager closes the underlying file once both views are built.
    with Image.open(self.imgs[i]) as img:
      # Force three channels: the 3-channel Normalize above cannot be applied
      # to grayscale or RGBA files.
      img = img.convert('RGB')

      # Low-quality image is the model input ...
      img_low_res = self.low_res_tf(img)
      # ... and the high-quality image is the target.
      img_high_res = self.high_res_tf(img)

    return [img_low_res, img_high_res]

In [None]:
# Paired low/high-resolution dataset and its shuffled mini-batch loader.
batch_size = 8
dataset = CelebA()
loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [None]:
class ResidualBlock(nn.Module):
  """Residual unit: conv-BN-PReLU-conv-BN with an identity skip connection.

  All convolutions are 3x3 with stride 1 and padding 1 on 64 channels, so the
  output has the same shape as the input and the two can be added.
  """

  def __init__(self):
    super().__init__()

    def conv_bn():
      # One 3x3 convolution (64 -> 64 channels) followed by batch norm.
      return [
          nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
          nn.BatchNorm2d(64),
      ]

    # Components of the block, applied in order.
    self.layers = nn.Sequential(*conv_bn(), nn.PReLU(), *conv_bn())

  def forward(self, x):
    # Add the original input tensor to the output of the convolution layers.
    return x + self.layers(x)

In [None]:
class UpSample(nn.Sequential):
  """2x upsampling stage: 3x3 conv (64 -> 256 channels), PixelShuffle(2), PReLU.

  PixelShuffle with factor 2 rearranges the 256 channels into 64 channels while
  doubling the height and width.
  """

  def __init__(self):
    expand = nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1)
    shuffle = nn.PixelShuffle(upscale_factor=2)
    activation = nn.PReLU()
    super().__init__(expand, shuffle, activation)

In [None]:
class Generator(nn.Module):
  """Super-resolution generator: maps an RGB image to an RGB image twice as large.

  Structure: 9x9 conv + PReLU, three ``ResidualBlock``s, 3x3 conv + batch norm
  with a long skip connection back to the first conv's output, one ``UpSample``
  stage (2x), and a final 9x9 conv down to 3 channels.
  """

  def __init__(self):
    super().__init__()

    # First convolution layer
    self.conv1 = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4),
        nn.PReLU(),
    )

    # Residual convolution blocks
    self.res_blocks = nn.Sequential(*[ResidualBlock() for _ in range(3)])

    self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
    self.bn2 = nn.BatchNorm2d(64)

    # Upsampling layer
    self.upsample_blocks = nn.Sequential(UpSample())

    # Last convolution layer
    self.conv3 = nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4)

  def forward(self, x):
    shallow = self.conv1(x)

    # Deep features from the residual trunk, merged with the shallow features
    # through the long skip connection.
    deep = self.bn2(self.conv2(self.res_blocks(shallow)))
    merged = shallow + deep

    return self.conv3(self.upsample_blocks(merged))

In [None]:
class DiscBlock(nn.Module):
  """Discriminator stage: 3x3 stride-2 conv (64 -> 64 channels), batch norm, LeakyReLU."""

  def __init__(self):
    super().__init__()

    # The stride-2 convolution is the step that shrinks the spatial size.
    downsample = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1)
    self.layers = nn.Sequential(downsample, nn.BatchNorm2d(64), nn.LeakyReLU())

  def forward(self, x):
    out = self.layers(x)
    return out

In [None]:
class Discriminator(nn.Module):
  """Real/fake classifier for square RGB images; outputs one value in (0, 1) per image.

  Pipeline: 3x3 conv (3 -> 64 channels) + LeakyReLU, one ``DiscBlock``,
  flatten, ``Linear(..., 1024)`` + LeakyReLU, ``Linear(1024, 1)`` + sigmoid.

  Args:
    image_size: Height and width of the images passed to ``forward``. It is
      only used to size the first linear layer. Defaults to 64.
  """

  def __init__(self, image_size=64):
    super(Discriminator, self).__init__()

    # First convolution layer (stride 1 / padding 1 keeps the spatial size).
    self.conv1 = nn.Sequential(
        nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
        nn.LeakyReLU()
    )
    self.blocks = DiscBlock()

    # DiscBlock applies a single 3x3, stride-2, padding-1 convolution, which
    # maps a side length n to (n - 1) // 2 + 1 (64 -> 32). fc1 must be sized
    # for exactly that feature map; the former hard-coded 64*16*16 assumed one
    # more halving than the network performs and made fc1 fail with a
    # shape-mismatch error on 64x64 inputs.
    reduced = (image_size - 1) // 2 + 1
    self.fc1 = nn.Linear(64 * reduced * reduced, 1024)
    self.activation = nn.LeakyReLU()
    self.fc2 = nn.Linear(1024, 1)
    self.sigmoid = nn.Sigmoid()

  def forward(self, x):
    x = self.conv1(x)
    x = self.blocks(x)

    # Keep the batch dimension, flatten channels and spatial dimensions.
    x = torch.flatten(x, start_dim=1)

    x = self.fc1(x)
    x = self.activation(x)
    x = self.fc2(x)
    x = self.sigmoid(x)

    return x

In [None]:
from torchvision.models.vgg import vgg19

class FeatureExtractor(nn.Module):
  """Fixed feature extractor built from the first 9 layers of a pretrained VGG19.

  ``forward`` runs only ``feature_extractor`` (layers 0-8 of ``vgg19().features``).
  ``features`` still exposes the complete VGG19 feature stack.
  """

  def __init__(self):
    super(FeatureExtractor, self).__init__()

    # Build the pretrained network once. The previous code constructed it
    # twice and kept the second, unused copy alive as a submodule.
    vgg19_model = vgg19(pretrained=True)
    self.features = vgg19_model.features

    # The truncated extractor reuses (shares) the first 9 layer objects.
    self.feature_extractor = nn.Sequential(
        *list(self.features.children())[:9])

    # The extractor is used as a fixed network: freezing its weights avoids
    # computing and storing gradients for them. Gradients still flow through
    # the layers to the input image.
    for param in self.parameters():
      param.requires_grad_(False)

  def forward(self, img):
    """Return the output of the truncated VGG19 stack for ``img``."""
    return self.feature_extractor(img)

In [None]:
# Build the generator and discriminator on the selected device.
G = Generator().to(device)
D = Discriminator().to(device)

# Build the feature extractor, move it to the same device (Module.to works in
# place), and switch it to evaluation mode.
feature_extractor = FeatureExtractor()
feature_extractor.to(device)
feature_extractor.eval()

FeatureExtractor(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride

In [None]:
# One Adam optimizer per network, both with the same hyper-parameters.
adam_settings = {"lr": 0.0001, "betas": (0.5, 0.999)}
G_optim = optim.Adam(G.parameters(), **adam_settings)
D_optim = optim.Adam(D.parameters(), **adam_settings)

In [None]:
epochs = 25

# The loss modules hold no state, so build them once instead of on every batch.
adversarial_criterion = nn.MSELoss()
content_criterion = nn.L1Loss()

for epoch in range(epochs):
  # `loader` yields (low_res, high_res) image pairs. The ImageFolder-based
  # loaders yield (image, class_label) instead: iterating lowhigh_loader here
  # fed 64x64 images to G (whose 2x output then did not match D's linear
  # layer) and used the class labels as the "high resolution" targets.
  for i, (low_res, high_res) in enumerate(loader):
    low_res = low_res.to(device)
    high_res = high_res.to(device)

    b_size = len(high_res)

    # Targets for real and fake images
    label_true = torch.ones(b_size, 1, device=device)
    label_false = torch.zeros(b_size, 1, device=device)

    # ---- Generator update ----
    G_optim.zero_grad()

    fake_hr = G(low_res)
    GAN_loss = adversarial_criterion(D(fake_hr), label_true)
    # Features of the generated image
    fake_features = feature_extractor(fake_hr)
    # Features of the real image: a fixed target, so no graph is needed.
    with torch.no_grad():
      real_features = feature_extractor(high_res)
    # Compare the two
    content_loss = content_criterion(fake_features, real_features)

    loss_G = content_loss + 0.001*GAN_loss
    loss_G.backward()
    G_optim.step()

    # ---- Discriminator update ----
    # The generator's backward pass above also deposited gradients on D's
    # parameters (through D(fake_hr)); clear them so that D is updated from
    # its own loss only.
    D_optim.zero_grad()

    # Loss on real images
    real_loss = adversarial_criterion(D(high_res), label_true)
    # Loss on generated images (detached: no gradient into G)
    fake_loss = adversarial_criterion(D(fake_hr.detach()), label_false)

    loss_D = (real_loss + fake_loss)/2
    loss_D.backward()
    D_optim.step()

  # Reports the generator loss of the last batch of the epoch.
  print(f"Epoch {epoch} of {epochs}")
  print(f"Generator loss: {loss_G.item():.8f}")

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x262144 and 65536x1024)

In [None]:
# Debug cell: call the discriminator directly on the most recent generated
# batch. In the recorded run this raises the same shape error as the training
# loop.
D(fake_hr)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x262144 and 65536x1024)

In [None]:
import torchvision.utils as vutils
# Run the generator in evaluation mode so that its BatchNorm layers use their
# running statistics rather than the statistics of this single image. The
# previous mode is restored afterwards.
was_training = G.training
G.eval()

with torch.no_grad():
  low_res, high_res = dataset[0]

  # Generator input: add the batch dimension and move to the model's device.
  input_tensor = torch.unsqueeze(low_res, 0).to(device)
  pred = G(input_tensor)
  pred = pred.squeeze(0)
  pred = pred.permute(1,2,0).cpu()

  # Channels-last layout for plotting.
  low_res = low_res.permute(1,2,0)
  high_res = high_res.permute(1,2,0)

  # The third panel shows the dataset's high-resolution image (the target),
  # not a generated one; it used to be titled 'generated image'.
  panels = [
      ('low resolution image', low_res),
      ('predicted high resolution image', pred),
      ('original high resolution image', high_res),
  ]

  for row, (title, image) in enumerate(panels, start=1):
    plt.subplot(3,1,row)
    plt.title(title)
    plt.axis("off")
    # make_grid is called with normalize=True as before -- presumably for
    # its value rescaling, since each call receives a single image.
    plt.imshow(vutils.make_grid(image, padding=2, normalize=True))

  plt.show()

G.train(was_training)

In [None]:
class Celeb_test(Dataset):
  def __init__(self):
    self.imgs = glob.glob(data_root + '/ex2/*.jpg')