In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import os
import glob
import matplotlib.pyplot as plt
import torchvision

In [2]:
print(torch.__version__)

1.13.1


In [3]:
print(torch.cuda.is_available())

True


In [4]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

### Layer 정의

In [5]:
class Block(nn.Module):
  def __init__(self, in_channels, out_channels, stride):
    super(Block, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 4, stride, 1, bias = True, padding_mode = 'reflect'),
        nn.InstanceNorm2d(out_channels),
        nn.LeakyReLU(0.2)
    )
  def forward(self, x):
    return self.conv(x)

In [6]:
class Discriminator(nn.Module):
  def __init__(self, in_channels = 3, features = [64, 128, 256, 512]):
    super(Discriminator, self).__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(in_channels, features[0], kernel_size = 4, stride = 2, padding = 1, padding_mode ='reflect'), # Conv(3, 64, 4, 2, 1)
    nn.LeakyReLU(0.2))

    layers = []
    in_channels = features[0]
    for feature in features[1:]:
      layers.append(Block(in_channels, feature, stride = 1 if feature == features[-1] else 2)) #Conv(64,128, 4, 2, 1),Conv(128,256, 4, 2, 1), Conv(256, 512, 4, 1, 1)  담는다
      in_channels = feature # 
    layers.append(nn.Conv2d(in_channels,1, kernel_size = 4, stride = 1, padding = 1, padding_mode = 'reflect')) #Conv(512, 1, 4, 1, 1)
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    x = self.initial(x)
    return torch.sigmoid(self.model(x)) # 30 x 30 pathgan

In [7]:
x = torch.randn((5, 3, 256, 256))
model = Discriminator(in_channels = 3)
preds = model(x)
print(preds.shape)
print(preds[0][0])

torch.Size([5, 1, 30, 30])
tensor([[0.3997, 0.4635, 0.3825, 0.5762, 0.4164, 0.5671, 0.5631, 0.4284, 0.5155,
         0.4465, 0.3008, 0.4125, 0.6461, 0.3669, 0.5287, 0.3254, 0.3260, 0.5202,
         0.5438, 0.5735, 0.4327, 0.5066, 0.4545, 0.4391, 0.4005, 0.5319, 0.4182,
         0.3856, 0.4044, 0.3371],
        [0.4296, 0.3853, 0.3748, 0.4903, 0.4241, 0.5474, 0.2968, 0.6216, 0.4105,
         0.4608, 0.4702, 0.3213, 0.4719, 0.3711, 0.2722, 0.4946, 0.5077, 0.4899,
         0.3581, 0.3996, 0.6745, 0.5588, 0.5144, 0.5302, 0.5815, 0.4528, 0.3726,
         0.3200, 0.4719, 0.4872],
        [0.5392, 0.5690, 0.3816, 0.4496, 0.5237, 0.6391, 0.4891, 0.5576, 0.3496,
         0.5032, 0.4880, 0.3674, 0.6315, 0.5784, 0.5499, 0.5230, 0.3261, 0.4750,
         0.3534, 0.5026, 0.6153, 0.5265, 0.4416, 0.6719, 0.5822, 0.5177, 0.5683,
         0.4058, 0.2487, 0.4515],
        [0.5151, 0.6575, 0.4452, 0.6741, 0.3999, 0.4591, 0.4490, 0.4454, 0.5137,
         0.4790, 0.5723, 0.6446, 0.5547, 0.4174, 0.5352, 0.39

In [8]:
class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels, down = True, use_act = True, **kwargs):
    super(ConvBlock, self).__init__()
    self.conv = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, padding_mode = 'reflect', **kwargs)
        if down else 
        nn.ConvTranspose2d(in_channels, out_channels, **kwargs),
        nn.InstanceNorm2d(out_channels),
        nn.ReLU(inplace = True) if use_act else nn.Identity()
    )
  def forward(self, x):
    return self.conv(x)

In [9]:
class ResidualBlock(nn.Module):
  def __init__(self, channels):
    super(ResidualBlock, self).__init__()
    self.block = nn.Sequential(
      ConvBlock(channels, channels, kernel_size = 3, padding = 1),
      ConvBlock(channels, channels, use_act = False, kernel_size = 3, padding = 1)
  )
  def forward(self, x):
    return x + self.block(x)

### 생성기 정의

In [10]:
class Generator(nn.Module):
  def __init__(self, img_channels,num_features = 64, num_resblock = 9):
    super(Generator, self).__init__()
    self.initial = nn.Sequential(
        nn.Conv2d(img_channels, num_features, kernel_size = 7, stride = 1, padding = 3, padding_mode = 'reflect'),
        nn.ReLU(inplace = True)
    )
    self.down_blocks = nn.ModuleList(
      [
          ConvBlock(num_features, num_features * 2, kernel_size = 3, stride = 2, padding = 1),
       ConvBlock(num_features * 2, num_features * 4, kernel_size = 3, stride = 2, padding = 1),]
    )
    self.residual_block = nn.Sequential(
        *[ResidualBlock(num_features * 4) for _ in range(num_resblock)]
    )

    self.up_blocks = nn.ModuleList(
        [
        ConvBlock(num_features * 4, num_features * 2, down = False, kernel_size = 3, stride = 2,padding = 1, output_padding = 1),
        ConvBlock(num_features * 2, num_features, down = False, kernel_size = 3, stride = 2,padding = 1, output_padding = 1)
    ])
    
    self.last = nn.Conv2d(num_features * 1, img_channels, kernel_size = 7, stride = 1, padding = 3, padding_mode = 'reflect')
  
  def forward(self, x):
    x = self.initial(x)
    for layer in self.down_blocks:
      x = layer(x)
    x= self.residual_block(x)
    for layer in self.up_blocks:
      x = layer(x)
    
    return torch.tanh(self.last(x))

In [12]:
test = torch.randn(2, 3, 256, 256)
gen = Generator(3)
print(gen(test).shape)

torch.Size([2, 3, 256, 256])


In [13]:
transforms_ = [transforms.Resize(256),
                                  transforms.RandomHorizontalFlip(),
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
test_trans = transforms.Compose([
    transforms.Resize(128),
    transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

### 데이터셋은 celeb attribute

In [14]:
class Custom_Dataset(Dataset):
  def __init__(self, root, transforms_ = None, mode = 'train'):
    self.transforms = transforms.Compose(transforms_)
    
    self.file_A = sorted(glob.glob(os.path.join(root, 'trainA') + '/*')) 
    self.file_B = sorted(glob.glob(os.path.join(root, 'trainB') + '/*'))
  
  def __getitem__(self, index):
    itemA = self.transforms(Image.open(self.file_A[index % len(self.file_A)]))

    itemB = self.transforms(Image.open(self.file_B[index % len(self.file_B)]))

    return {'A' : itemA, 'B' : itemB}

  def __len__(self):
    return max(len(self.file_A), len(self.file_B))  

In [15]:
#root= "C:\\Users\\User\\Desktop\\GAN proj\\archive\\img_align_celeba\\img_align_celeba_AB\\"
root = "D:\\다운로드\\archive\\dataAB\\"
batch_size = 32
learning_rate = 2e-4
lambda_identity = 0
lambda_cycle = 10
num_epochs = 200

In [16]:
disc_H = Discriminator(in_channels = 3).to(device)
disc_Z = Discriminator(in_channels = 3).to(device)
gen_Z = Generator(img_channels = 3, num_resblock = 9).to(device)
gen_H = Generator(img_channels = 3, num_resblock = 9).to(device)

optim_disc = optim.Adam(list(disc_H.parameters()) + list(disc_Z.parameters()),
                        lr = learning_rate, betas = (0.5, 0.999))
optim_gen = optim.Adam(list(gen_Z.parameters()) + list(gen_H.parameters()),
                       lr = learning_rate, betas = (0.5, 0.999))
L1 = nn.L1Loss()
mse = nn.MSELoss()

In [17]:
dataset = Custom_Dataset(
     root, transforms_ = transforms_
)
dataloader = DataLoader(dataset, batch_size = 2, shuffle = True)

In [20]:
imagea = next(iter(dataloader))['A'][0]
imageb = next(iter(dataloader))['B'][0]

In [21]:
imagea

tensor([[[ 0.0353,  0.0275,  0.0275,  ...,  0.3020,  0.3020,  0.3020],
         [ 0.0353,  0.0275,  0.0275,  ...,  0.3020,  0.3020,  0.3020],
         [ 0.0353,  0.0275,  0.0275,  ...,  0.3020,  0.3020,  0.3020],
         ...,
         [ 0.2078,  0.2078,  0.2078,  ...,  0.5529,  0.5529,  0.5529],
         [ 0.2078,  0.2078,  0.2078,  ...,  0.5529,  0.5529,  0.5529],
         [ 0.2078,  0.2078,  0.2078,  ...,  0.5529,  0.5529,  0.5529]],

        [[-0.0118, -0.0196, -0.0196,  ...,  0.2784,  0.2784,  0.2784],
         [-0.0118, -0.0196, -0.0196,  ...,  0.2784,  0.2784,  0.2784],
         [-0.0118, -0.0196, -0.0196,  ...,  0.2784,  0.2784,  0.2784],
         ...,
         [-0.5373, -0.5373, -0.5373,  ...,  0.5922,  0.5922,  0.5922],
         [-0.5373, -0.5373, -0.5373,  ...,  0.5922,  0.5922,  0.5922],
         [-0.5373, -0.5373, -0.5373,  ...,  0.5922,  0.5922,  0.5922]],

        [[-0.1843, -0.1922, -0.1922,  ..., -0.1373, -0.1373, -0.1373],
         [-0.1843, -0.1922, -0.1922,  ..., -0

In [22]:
print(imagea.size())
print(imageb.size())

torch.Size([3, 313, 256])
torch.Size([3, 313, 256])


In [23]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [24]:
# disc_H = Discriminator(in_channels = 3).to(device) = disc_A
# disc_Z = Discriminator(in_channels = 3).to(device) = disc_O
# gen_Z = Generator(img_channels = 3, num_resblock = 9).to(device) = gen_O
# gen_H = Generator(img_channels = 3, num_resblock = 9).to(device) = gen_A
# zebra = orange, horse = apple

for epoch in range(num_epochs):
  for idx, batch in enumerate(dataloader):
  
    orange = batch['B'].to(device)
    apple = batch['A'].to(device)
    optim_disc.zero_grad()
    optim_gen.zero_grad()

    fake_apple = gen_H(orange) 
    D_A_real = disc_H(apple) 
    D_A_fake = disc_H(fake_apple.detach()) 
    D_A_real_loss = mse(D_A_real, torch.ones_like(D_A_real))
    D_A_fake_loss = mse(D_A_fake, torch.zeros_like(D_A_fake)) 
    D_A_loss = D_A_real_loss + D_A_fake_loss 

    fake_orange = gen_Z(apple)
    D_O_real = disc_Z(orange)
    D_O_fake = disc_Z(fake_orange.detach())
    D_O_real_loss = mse(D_O_real, torch.ones_like(D_O_real))
    D_O_fake_loss = mse(D_O_fake, torch.zeros_like(D_O_fake))
    D_O_loss = D_O_real_loss + D_O_fake_loss

    D_loss = (D_A_loss + D_O_loss) / 2

    D_loss.backward()
    optim_disc.step()
    #generator 학습
    D_A_fake = disc_H(fake_apple)
    D_O_fake = disc_Z(fake_orange)
    loss_G_A = mse(D_O_fake, torch.ones_like(D_A_fake)) #CE 를 쓰지않고 여기서 정의된 새로운 손실함수 
    loss_G_O = mse(D_O_fake, torch.ones_like(D_O_fake))

    cycle_orange = gen_Z(fake_apple)
    cycle_apple = gen_H(fake_orange)
    cycle_orange_loss = L1(orange, cycle_orange)
    cycle_apple_loss = L1(apple, cycle_apple)

    identity_orange = gen_Z(orange)
    identity_apple = gen_H(apple)
    identity_orange_loss = L1(orange, identity_orange)
    identity_apple_loss = L1(apple, identity_apple)

    G_loss = (loss_G_O + loss_G_A + cycle_orange_loss * lambda_cycle + cycle_apple_loss * lambda_cycle
              +identity_apple_loss * lambda_identity + identity_orange_loss * lambda_identity)
    G_loss.backward()
    optim_gen.step()
    if idx == 50:
      print('EPOCH : ', epoch, 'G_loss : ', G_loss, 'D_loss :', D_loss)
      break

NameError: name 'gen_A' is not defined

### 모델 저장

In [None]:
torch.save(gen_H, 'gen_H.pt')
torch.save(gen_Z, 'gen_Z.pt')

In [None]:
apple_image = iter(dataloader).next()['A'][1]
orange_image = iter(dataloader).next()['B'][1]

In [None]:
apple_image = apple_image.to(device)
orange_image = orange_image.to(device)

모델에 이미지 넣어서 결과 확인.
나름 잘 변환된 것을 확인할 수 있다.

In [None]:
plt.figure(figsize = (10, 10))
plt.subplot(1, 2, 1)
plt.imshow(torchvision.utils.make_grid(apple_image.cpu(), normalize = True).permute(1, 2, 0)) 
plt.subplot(1,2,2)
plt.imshow(torchvision.utils.make_grid(orange_image.cpu(), normalize = True).permute(1, 2, 0)) 

In [None]:
trans_apple = gen_H(orange_image).cpu()
plt.imshow(torchvision.utils.make_grid(trans_apple.cpu(), normalize = True).permute(1, 2, 0)) 

In [None]:
trans_orange = gen_Z(apple_image).cpu()
plt.imshow(torchvision.utils.make_grid(trans_orange.cpu(), normalize = True).permute(1, 2, 0)) 