In [1]:
from os import listdir
from os.path import join
import random
import matplotlib.pyplot as plt
%matplotlib inline

import os
import time
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torchvision.transforms.functional import to_pil_image
import pennylane as qml

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# Costum dataset 생성
class FacadeDataset(Dataset):
    def __init__(self, path2img, direction='b2a', transform=False):
        super().__init__()
        self.direction = direction
        self.path2a = join(path2img, 'a')
        self.path2b = join(path2img, 'b')
        self.img_filenames = [x for x in listdir(self.path2a)]
        self.transform = transform

    def __getitem__(self, index):
        a = Image.open(join(self.path2a, self.img_filenames[index])).convert('RGB')
        b = Image.open(join(self.path2b, self.img_filenames[index])).convert('RGB')
        
        if self.transform:
            a = self.transform(a)
            b = self.transform(b)
        a=(a+1)/2
        b=(b+1)/2
        

        if self.direction == 'b2a':
            return b,a
        else:
            return a,b

    def __len__(self):
        return len(self.img_filenames)

In [3]:
transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize([0.5,0.5,0.5],[0.5,0.5,0.5]),
                    transforms.Resize((256,256))
])
path2img = 'data/facades/train'
train_ds = FacadeDataset(path2img, transform=transform)
train_dl = DataLoader(train_ds, batch_size=1, shuffle=True)

In [4]:
# UNet
class UNetDown(nn.Module):
    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super().__init__()

        layers = [nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1, bias=False)]

        if normalize:
            layers.append(nn.InstanceNorm2d(out_channels)),

        layers.append(nn.LeakyReLU(0.2))

        if dropout:
            layers.append(nn.Dropout(dropout))

        self.down = nn.Sequential(*layers)

    def forward(self, x):
        x = self.down(x)
        return x
    
class UNetUp(nn.Module):
    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()

        layers = [
            nn.ConvTranspose2d(in_channels, out_channels,4,2,1,bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU()
        ]

        if dropout:
            layers.append(nn.Dropout(dropout))

        self.up = nn.Sequential(*layers)

    def forward(self,x):
        x = self.up(x)
        return x
    
n_qubits = 6
dev = qml.device("lightning.qubit", wires=n_qubits)
@qml.qnode(dev, interface="torch")
def qc(inputs, weights):
    qml.AmplitudeEmbedding(features=inputs, wires=range(n_qubits),normalize=True)
    q_depth=3
    k=0


    # Repeated layer
    for i in range(q_depth):
        # Parameterised layer
        for y in range(n_qubits):
            qml.RY(weights[k], wires=y)
            k+=1

        # Control Z gates
        for y in range(n_qubits - 1):
            qml.CZ(wires=[y, y + 1])

    return qml.probs(wires=list(range(n_qubits)))

weight_shapes = {"weights": n_qubits*3}
qlayer = qml.qnn.TorchLayer(qc, weight_shapes)


In [5]:
class GeneratorUNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()

        self.down1 = UNetDown(in_channels, 8, normalize=False)
        self.down2 = UNetDown(8,16)                 
        self.down3 = UNetDown(16,32)               
        self.down4 = UNetDown(32,64,dropout=0.5) 
        self.down5 = UNetDown(64,64,dropout=0.5)      
        self.down6 = UNetDown(64,64,dropout=0.5)             
        self.down7 = UNetDown(64,64,dropout=0.5)              
        self.down8 = UNetDown(64,64,normalize=False,dropout=0.5)
        self.qlayer = qlayer
        self.up1 = UNetUp(64,64,dropout=0.5)
        self.up2 = UNetUp(64,64,dropout=0.5)
        self.up3 = UNetUp(64,64,dropout=0.5)
        self.up4 = UNetUp(64,64,dropout=0.5)
        self.up5 = UNetUp(64,32)
        self.up6 = UNetUp(32,16)
        self.up7 = UNetUp(16,8)
        self.up8 = nn.Sequential(
            UNetUp(8,3),
            nn.Tanh()
        )
        
    def forward(self, x):
        print(x)
        d1 = self.down1(x)
        
        d2 = self.down2(d1)
       
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)
        d8=d8.reshape(64)
        amp=d8.norm()
        d8=d8/amp
        c0=self.qlayer(d8)*amp
        c0=c0.reshape(1,64,1,1)
        u1 = self.up1(c0)
        u2 = self.up2(u1)
        u3 = self.up3(u2)
        u4 = self.up4(u3)
        u5 = self.up5(u4)
        u6 = self.up6(u5)
        u7 = self.up7(u6)
        u8 = self.up8(u7)
        return u8

# check
x = torch.randn(1,3,256,256,device=device)
model = GeneratorUNet().to(device)
out = model(x)
print(out.shape)



tensor([[[[ 3.1692e-01,  2.1175e+00, -4.4934e-01,  ...,  2.1827e-01,
           -4.1943e-01, -7.7055e-01],
          [-2.4606e-01, -5.6883e-01,  7.5542e-02,  ...,  6.3938e-01,
            1.0312e-01,  7.6918e-01],
          [-1.9574e-01, -3.0436e-01,  1.2157e-01,  ...,  2.7604e-01,
           -3.4550e-02,  1.0803e-01],
          ...,
          [-6.4386e-01,  3.1231e-01, -7.5478e-01,  ..., -2.1602e-03,
           -5.3222e-01, -8.1635e-02],
          [ 7.3150e-01,  9.0138e-03,  2.7234e-01,  ...,  1.8248e-01,
            9.5824e-01,  3.3488e-02],
          [-8.1353e-01,  3.4577e-01, -8.1151e-02,  ..., -6.1934e-01,
           -9.4608e-01,  1.9028e+00]],

         [[-9.4013e-01, -1.3689e+00, -9.2546e-01,  ..., -1.3012e+00,
            2.8168e-01, -1.4023e+00],
          [ 4.2357e-02, -1.1090e+00,  4.6232e-01,  ..., -1.4432e+00,
            1.2009e+00, -2.9127e-01],
          [ 1.5367e-01,  5.8763e-01,  4.5354e-01,  ..., -9.0218e-02,
            1.7929e-01, -1.8051e+00],
          ...,
     

In [6]:
class Dis_block(nn.Module):
    def __init__(self, in_channels, out_channels, normalize=True):
        super().__init__()

        layers = [nn.Conv2d(in_channels, out_channels, 3, stride=2, padding=1)]
        if normalize:
            layers.append(nn.InstanceNorm2d(out_channels))
        layers.append(nn.LeakyReLU(0.2))
    
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        x = self.block(x)
        return x

# check
x = torch.randn(16,64,128,128,device=device)
model = Dis_block(64,128).to(device)
out = model(x)
print(out.shape)

torch.Size([16, 128, 64, 64])


In [7]:
class Discriminator(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()

        self.stage_1 = Dis_block(in_channels*2,64,normalize=False)
        self.stage_2 = Dis_block(64,128)
        self.stage_3 = Dis_block(128,256)
        self.stage_4 = Dis_block(256,512)

        self.patch = nn.Conv2d(512,1,3,padding=1) # 16x16 패치 생성

    def forward(self,a,b):
        x = torch.cat((a,b),1)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = self.stage_4(x)
        x = self.patch(x)
        x = torch.sigmoid(x)
        return x
# check
x = torch.randn(16,3,256,256,device=device)
model = Discriminator().to(device)
out = model(x,x)
print(out.shape)

torch.Size([16, 1, 16, 16])


In [8]:
def initialize_weights(model):
    class_name = model.__class__.__name__
    if class_name.find('Conv') != -1:
        nn.init.normal_(model.weight.data, 0.0, 0.02)




In [9]:
loss_func_gan = nn.BCELoss()
loss_func_pix = nn.L1Loss()

# loss_func_pix 가중치
lambda_pixel = 100

# patch 수
patch = (1,256//2**4,256//2**4)

# 최적화 파라미터
from torch import optim
lr = 2e-4
beta1 = 0.5
beta2 = 0.999
model_dis=Discriminator().cuda()
model_gen=GeneratorUNet().cuda()
opt_dis = optim.Adam(model_dis.parameters(),lr=lr,betas=(beta1,beta2))
opt_gen = optim.Adam(model_gen.parameters(),lr=lr,betas=(beta1,beta2))

In [10]:
# 학습
model_gen.train()
model_dis.train()

batch_count = 0
num_epochs = 20
start_time = time.time()

loss_hist = {'gen':[],
             'dis':[]}

for epoch in range(num_epochs):
    for a, b in train_dl:
        ba_si = a.size(0)

        # real image
        real_a = a.to(device)
        real_b = b.to(device)

        # patch label
        real_label = torch.ones(ba_si, *patch, requires_grad=False).to(device)
        fake_label = torch.zeros(ba_si, *patch, requires_grad=False).to(device)

        # generator
        model_gen.zero_grad()

        fake_b = model_gen(real_a) # 가짜 이미지 생성
        out_dis = model_dis(fake_b, real_b) # 가짜 이미지 식별
        

        gen_loss = loss_func_gan(out_dis, real_label)
        pixel_loss = loss_func_pix(fake_b, real_b)

        g_loss = gen_loss + lambda_pixel * pixel_loss
        g_loss.backward()
        opt_gen.step()

        # discriminator
        model_dis.zero_grad()

        out_dis = model_dis(real_b, real_a) # 진짜 이미지 식별
        real_loss = loss_func_gan(out_dis,real_label)
        
        out_dis = model_dis(fake_b.detach(), real_a) # 가짜 이미지 식별
        fake_loss = loss_func_gan(out_dis,fake_label)

        d_loss = (real_loss + fake_loss) / 2.
        d_loss.backward()
        opt_dis.step()

        loss_hist['gen'].append(g_loss.item())
        loss_hist['dis'].append(d_loss.item())

        batch_count += 1
        if batch_count % 10 == 0:
            print('Epoch: %.0f, G_Loss: %.6f, D_Loss: %.6f, time: %.2f min' %(epoch, g_loss.item(), d_loss.item(), (time.time()-start_time)/60) )

tensor([[[[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          ...,
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

         [[0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          [0.6667, 0.6667, 0.6667,  ..., 0.6667, 0.6667, 0.6667],
          [0.6667, 0.6667, 0.6667,  ..., 0



ValueError: State vectors have to be of norm 1.0, vector 0 has norm nan