In [None]:
# dataset. cuda 설정 필요?
## shuffle 부분이 GPU에서 작동하도록.

# celabA h5dy 어떻게 했는지 다시 알아보기
## h5py 형식으로 바꾼다. 이때 분류작업을 한다.
### zip을 풀면서 이름을 'ooo/ooo.jpg'처럼 만든다. 그럼 폴더를 사용하듯 indexing하며 분류할 수 있다.
## 폴더 별로 읽어서 훈련한다.
### 훈련하며 알아서 닫고 다음 폴더 읽음?

# GDL의 celabA는 small 인가?

In [None]:
# conv로 모델 만들기
## conv 

## 모델 내부에서 view처럼 작동하는 View class 만들기

### 왜 필요? 어떤 기능 필요? 
#### sequential 내부에서 forward로 진행됨. 
#### Conv2d()나 Linear() 같은 함수는 노드 사이 관계와 parameters()만 표현함.
#### 노드를 지목하고 싶음


In [4]:
a = 1,
b = (1,2),

In [7]:
print(type(a),a)
print(type(b),b)
print(*a)
print(*b)


<class 'tuple'> (1,)
<class 'tuple'> ((1, 2),)
1
(1, 2)


In [39]:
def crop_centre(img, new_width, new_height):
    height, width, _ = img.shape
    startx = width//2 - new_width//2
    starty = height//2 - new_height//2    
    return img[  starty:starty + new_height, startx:startx + new_width, :]


In [59]:
import torch
from torch.utils.data import Dataset,DataLoader 
import numpy as np
import h5py
import matplotlib.pyplot as plt

class CelebADataset(Dataset):
    
    def __init__(self, file):
        self.file_object = h5py.File(file, 'r')
        self.dataset = self.file_object['img_align_celeba']
        pass
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, index):
        if (index >= len(self.dataset)):
          raise IndexError()
        img = np.array(self.dataset[str(index)+'.jpg'])
        # 128x128 사각형으로 크롭
        img = crop_centre(img, 128, 128)
        return torch.FloatTensor(img).permute(2,0,1).view(3,128,128) / 255.0
    
    def plot_image(self, index):
        img = np.array(self.dataset[str(index)+'.jpg'])
        # 128x128 사각형으로 크롭
        img = crop_centre(img, 128, 128)
        plt.imshow(img, interpolation='nearest')
        pass
    

In [71]:
import torch
import torch.nn as nn

class View(nn.Module):
    def __init__(self,shape):
        super().__init__()
        self.shape = shape,
    def forward(self,x):
        return x.view(*self.shape)

In [67]:
class ConvVAE(nn.Module):
    def __init__(self,dim_latent=200):
        # 파이토치 부모 클래스 초기화
        super().__init__()
        self.dim_latent = dim_latent
       # 신경망 레이어 정의
        self.encoder = nn.Sequential(
            # (1,3,128,128) 형태를 의도
            nn.Conv2d(3, 32, kernel_size=2, stride=2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            #(1,32,64,64)
            
            nn.Conv2d(32, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            #(1,64,32,32)

            nn.Conv2d(64, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            #(1,64,16,16)
            
            nn.Conv2d(64, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            #(1,64,8,8)

            View((-1,64*8*8)),
            nn.Linear(64*8*8, 2*self.dim_latent),
        )
        
        self.decoder = nn.Sequential(
            #(200)
            nn.Linear(self.dim_latent,64*8*8),
            View((-1,64,8,8)),
            #(64,8,8)

            nn.ConvTranspose2d(64,64,2,stride=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),            
            #(64,16,16)

            nn.ConvTranspose2d(64,64,2,stride=2),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            #(64,32,32)

            nn.ConvTranspose2d(64,32,2,stride=2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            #(32,64,64)

            nn.ConvTranspose2d(32,3,2,stride=2),
            nn.Sigmoid()
            #(3,128,128)
        )

        # 손실 함수 생성
        self.loss_function = nn.BCELoss()

        # 옵티마이저 생성
        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)
    def forward(self,x):
        x = self.encoder(x)
        print(x.shape)
        x = x.view(-1,2,self.dim_latent)

        self.mu = x[:,0,:]
        self.log_var = x[:,1,:]
        print(self.mu.shape)
        print(self.log_var.shape)

        epsilon = torch.randn_like(self.mu)
        sigma = torch.exp(self.log_var/2)

        self.z = self.mu+sigma*epsilon
        de = self.decoder(self.z)
        return de

    def train(self,data_loader):
        for epoch in range(1, 101):
            for x_input in data_loader:
                self.optimiser.zero_grad()
                print(x_input.shape)
                y_pred = self.forward(x_input)
                
                mse_loss = ((y_pred- x_input)**2).sum()
                kl_loss = -0.5*torch.sum(1+self.log_var-2*self.mu**2 - 2*torch.exp(self.log_var))
                loss = mse_loss+kl_loss*10000
                loss.backward()
                self.optimiser.step()
            if epoch % 10 == 0:
                print(f"Epoch {epoch}: loss = {loss.item()}")


In [68]:
x_input = torch.FloatTensor(torch.randn((1,3,128,128)))
model = ConvVAE()
model(x_input).shape

torch.Size([1, 400])
torch.Size([1, 200])
torch.Size([1, 200])


torch.Size([1, 3, 128, 128])

In [69]:
dataset = CelebADataset('./data/celeba_aligned_small.h5py')
dataloader = DataLoader(dataset,32,True)


In [70]:
model.train(dataloader)

torch.Size([32, 3, 128, 128])
torch.Size([32, 400])
torch.Size([32, 200])
torch.Size([32, 200])
--hi
hi
torch.Size([32, 3, 128, 128])
torch.Size([32, 400])
torch.Size([32, 200])
torch.Size([32, 200])
--hi
hi
torch.Size([32, 3, 128, 128])
torch.Size([32, 400])
torch.Size([32, 200])
torch.Size([32, 200])
--hi
hi
torch.Size([32, 3, 128, 128])
torch.Size([32, 400])
torch.Size([32, 200])
torch.Size([32, 200])
--hi
hi
torch.Size([32, 3, 128, 128])
torch.Size([32, 400])
torch.Size([32, 200])
torch.Size([32, 200])
--hi
hi
torch.Size([32, 3, 128, 128])
torch.Size([32, 400])
torch.Size([32, 200])
torch.Size([32, 200])
--hi
hi
torch.Size([32, 3, 128, 128])
torch.Size([32, 400])
torch.Size([32, 200])
torch.Size([32, 200])
--hi
hi
torch.Size([32, 3, 128, 128])
torch.Size([32, 400])
torch.Size([32, 200])
torch.Size([32, 200])
--hi
hi
torch.Size([32, 3, 128, 128])
torch.Size([32, 400])
torch.Size([32, 200])
torch.Size([32, 200])
--hi
hi
torch.Size([32, 3, 128, 128])
torch.Size([32, 400])
torch.Size([

KeyboardInterrupt: 