In [None]:
# SRGAN (Super-Resolution GAN)

# Dataset directory: /content/drive/MyDrive/data/GAN/img_align_celeba



In [None]:
import glob
import torchvision.transforms as tf
from torch.utils.data.dataset import Dataset
from PIL import Image
import torch

In [None]:
class CelebA(Dataset):
    """Paired low-/high-resolution images for super-resolution training.

    Every ``*.jpg`` file found directly inside ``root`` is one sample.
    Indexing the dataset returns the same picture resized to two resolutions,
    so a model can learn to map the small version onto the large one.

    Args:
        root: Directory searched (non-recursively) for ``*.jpg`` files.
        low_res_size: ``(height, width)`` of the low-resolution image.
        high_res_size: ``(height, width)`` of the high-resolution image.
    """

    def __init__(self,
                 root='/content/drive/MyDrive/data/GAN/img_align_celeba',
                 low_res_size=(32, 32),
                 high_res_size=(64, 64)):
        import os  # function-scope import: only needed to build the glob pattern

        # glob yields files in arbitrary OS order; sorting keeps the
        # index -> file mapping stable from one run to the next.
        self.imgs = sorted(glob.glob(os.path.join(root, '*.jpg')))

        # Low-resolution version of the image.
        self.low_res_tf = self._make_transform(low_res_size)
        # High-resolution version of the image.
        self.high_res_tf = self._make_transform(high_res_size)

    @staticmethod
    def _make_transform(size):
        """Resize to ``size``, convert to a tensor and rescale [0, 1] -> [-1, 1]."""
        return tf.Compose([
            tf.Resize(size),
            tf.ToTensor(),
            tf.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __len__(self):
        """Number of image files discovered in ``root``."""
        return len(self.imgs)

    def __getitem__(self, i):
        """Return ``[low_res, high_res]`` tensors for the ``i``-th image file."""
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle as soon as the pixels have been read.
        # convert('RGB') loads the data and guarantees three channels, which
        # the three-channel Normalize above requires.
        with Image.open(self.imgs[i]) as img_file:
            img = img_file.convert('RGB')

        img_low_res = self.low_res_tf(img)
        img_high_res = self.high_res_tf(img)

        return [img_low_res, img_high_res]



In [None]:
from torch.nn.modules.activation import PReLU
from torch.nn.modules import BatchNorm2d
import torch.nn as nn

class ResidualBlock(nn.Module):
    """Conv -> BatchNorm -> PReLU -> Conv -> BatchNorm with a skip connection.

    Both convolutions are 3x3 with stride 1 and padding 1, so the spatial
    size of the feature map is preserved. ``forward`` adds the block's input
    to the output of the stack, so the two must have matching (or
    broadcastable) shapes; callers normally pass
    ``in_channels == out_channels``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        stack = [
            nn.Conv2d(in_channels, out_channels, **conv_kwargs),
            nn.BatchNorm2d(out_channels),
            # PReLU: the slope applied to negative inputs is a learned parameter.
            nn.PReLU(),
            nn.Conv2d(out_channels, out_channels, **conv_kwargs),
            nn.BatchNorm2d(out_channels),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return ``x + layers(x)``."""
        # After the convolution stack, add the original input tensor back in.
        return x + self.layers(x)
    


In [None]:
class UpSample(nn.Sequential):
    """Learned upsampling: 3x3 convolution -> PixelShuffle -> PReLU.

    The convolution keeps the spatial size and produces ``out_channels``
    channels. ``PixelShuffle`` then trades channels for resolution: with
    upscale factor ``r`` it turns ``(B, C * r**2, H, W)`` into
    ``(B, C, H * r, W * r)``. The module's final output therefore has
    ``out_channels / r**2`` channels, not ``out_channels``, and
    ``out_channels`` must be divisible by ``r**2``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution (before the shuffle).
        upscale_factor: Spatial enlargement factor ``r``. Defaults to 2,
            the value that was previously hard-coded.
    """

    def __init__(self, in_channels, out_channels, upscale_factor=2):
        super().__init__(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(upscale_factor=upscale_factor),
            nn.PReLU()
        )


In [None]:
class Generator(nn.Module):
    """SRGAN-style generator that doubles the resolution of an RGB image.

    Pipeline:
        1. ``conv1``: 9x9 convolution (3 -> 64 channels) followed by PReLU.
        2. ``res_block``: three 64-channel residual blocks.
        3. ``conv2`` + ``bn2``: 3x3 convolution and batch norm, after which
           the output of step 1 is added back (long skip connection).
        4. ``upsample_blocks``: one ``UpSample(64, 256)`` stage; the pixel
           shuffle brings the 256 channels back to 64 at twice the height
           and width.
        5. ``conv3``: 9x9 convolution mapping 64 channels to 3.

    Input ``(B, 3, H, W)`` -> output ``(B, 3, 2H, 2W)``. No activation is
    applied to the output of ``conv3``.
    """

    def __init__(self):
        super(Generator, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, stride=1, padding=4),
            nn.PReLU()
        )
        self.res_block = nn.Sequential(
            ResidualBlock(64, 64),
            ResidualBlock(64, 64),
            ResidualBlock(64, 64)
        )

        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)

        self.bn2 = nn.BatchNorm2d(64)

        self.upsample_blocks = nn.Sequential(
            UpSample(64, 256)
        )

        self.conv3 = nn.Conv2d(64, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Upscale a batch of low-resolution images.

        Args:
            x: Tensor of shape ``(B, 3, H, W)``.

        Returns:
            Tensor of shape ``(B, 3, 2H, 2W)``.
        """
        # Fix: the layer has to be *called* on the input. The previous code
        # assigned the module object itself to ``x``, so the next layer
        # received an nn.Sequential instead of a tensor and raised.
        x = self.conv1(x)

        # Kept for the long skip connection below.
        skip = x

        x = self.res_block(x)
        x = self.conv2(x)
        x = self.bn2(x)

        x = x + skip

        x = self.upsample_blocks(x)
        x = self.conv3(x)

        return x

In [None]:
# Basic building block of the discriminator

class DisBlock(nn.Module):
    """Downsampling unit: 3x3 stride-2 convolution -> BatchNorm -> LeakyReLU.

    The stride-2 convolution (padding 1) roughly halves the height and width
    of the feature map while changing the channel count from ``in_channels``
    to ``out_channels``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels of the returned feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        downsample = nn.Conv2d(in_channels, out_channels,
                               kernel_size=3, stride=2, padding=1)
        norm = nn.BatchNorm2d(out_channels)
        act = nn.LeakyReLU()

        self.layers = nn.Sequential(downsample, norm, act)

    def forward(self, x):
        """Apply convolution, normalisation and activation in sequence."""
        out = self.layers(x)
        return out

In [None]:
# Discriminator network

class Discriminator(nn.Module):
    """Binary classifier that scores an image with a value in (0, 1).

    Layout: 3x3 convolution (3 -> 64) + LeakyReLU, one stride-2 ``DisBlock``
    (64 -> 64), flatten, then two fully connected layers (65536 -> 1024 -> 1)
    with a LeakyReLU in between and a sigmoid at the end.

    ``fc1`` is sized for 64x64 inputs: the first convolution keeps 64x64,
    the ``DisBlock`` halves it to 32x32, and 64 * 32 * 32 = 65536. Other
    input sizes will not match ``fc1``.
    """

    def __init__(self):
        super().__init__()

        # Stem: keeps the spatial size, lifts RGB to 64 feature channels.
        stem_conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Sequential(stem_conv, nn.LeakyReLU())

        # A single downsampling stage.
        self.blocks = DisBlock(64, 64)

        # Classifier head.
        self.fc1 = nn.Linear(65536, 1024)
        self.activation = nn.LeakyReLU()
        self.fc2 = nn.Linear(1024, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a ``(B, 1)`` tensor of sigmoid scores for a batch of images."""
        feats = self.blocks(self.conv1(x))

        # Collapse everything except the batch dimension.
        flat = torch.flatten(feats, start_dim=1)

        hidden = self.activation(self.fc1(flat))
        return self.sigmoid(self.fc2(hidden))

    




In [None]:
# VGG19 is used for feature extraction

from torchvision.models.vgg import vgg19

class FeatureExtactor(nn.Module):
    """Frozen front end of an ImageNet-pretrained VGG19.

    Keeps the first nine modules of ``vgg19().features``
    (conv-ReLU-conv-ReLU-maxpool-conv-ReLU-conv-ReLU) and uses them as a
    fixed feature extractor. The class name keeps its original spelling so
    existing references continue to resolve.

    The weights are frozen (``requires_grad = False``): this module only
    supplies features, so autograd should not compute or store gradients for
    its parameters. Gradients still flow through it to the input image.
    """

    def __init__(self):
        super(FeatureExtactor, self).__init__()

        # ``pretrained=True`` is deprecated in newer torchvision releases in
        # favour of ``weights=``; fall back to it only where the weights enum
        # does not exist. Both select the same ImageNet checkpoint.
        try:
            from torchvision.models import VGG19_Weights
        except ImportError:
            vgg19_model = vgg19(pretrained=True)
        else:
            vgg19_model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)

        self.feature_extractor = nn.Sequential(
            *list(vgg19_model.features.children())[:9]
        )

        # Freeze the extractor; it is never meant to be updated.
        for param in self.feature_extractor.parameters():
            param.requires_grad = False

    def forward(self, img):
        """Return the VGG19 feature map for a batch of 3-channel images."""
        return self.feature_extractor(img)

In [None]:
import tqdm
from torch.utils.data.dataloader import DataLoader
from torch.optim.adam import Adam

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device  # bare expression so the notebook displays the chosen device

In [None]:
# --- Data -------------------------------------------------------------------
# Mini-batches of [low_res, high_res] image pairs, reshuffled every epoch.
batch_size = 8
dataset = CelebA()
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# --- Networks ---------------------------------------------------------------
# Generator is built before the discriminator, as in the original cell.
G = Generator().to(device)
D = Discriminator().to(device)

# The VGG feature extractor is only used for inference, so it is put in eval
# mode; eval() returns the module itself, which allows the call to be chained.
feature_extractor = FeatureExtactor().to(device).eval()

# --- Optimisers -------------------------------------------------------------
# Both networks use Adam with identical hyper-parameters.
adam_settings = dict(lr=0.0001, betas=(0.5, 0.999))
G_optim = Adam(G.parameters(), **adam_settings)
D_optim = Adam(D.parameters(), **adam_settings)
