# **데이터 가져오기**

In [None]:
from google.colab import files 

uploaded = files.upload()

# Report the name and byte size of every file the user uploaded.
for name, payload in uploaded.items():
    print(f'User uploaded file "{name}" with length {len(payload)} bytes')

In [None]:
! mkdir -p dataset_3000                          # extraction folder (must not reuse the uploaded .zip's name)
! unzip -q -o dataset_3000.zip -d ./dataset_3000  # -o: overwrite without an interactive prompt when re-run

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset

import cv2 as cv
import os 

import numpy as np
import pandas, random
import matplotlib.pyplot as plt

# **데이터 세팅**

In [None]:
src = '/content/dataset_3000/'

# Assumed image size on disk as (rows, cols). NOTE(review): inferred from the
# reshape below; reshape silently scrambles any image with the same pixel
# count but a different layout, so confirm against the dataset.
_IMG_ROWS, _IMG_COLS = 28, 48


def _load_rgb(src, file):
    """Decode ``src``/``file`` with OpenCV and return it as an RGB uint8 array.

    Raises:
        ValueError: if OpenCV cannot decode the file.
    """
    path = os.path.join(src, file)
    img = cv.imread(path)
    if img is None:
        # cv.imread reports failure by returning None; fail here with the path
        # rather than with an opaque cvtColor error.
        raise ValueError(f'could not read image: {path}')
    return cv.cvtColor(img, cv.COLOR_BGR2RGB)


def _to_model_layout(img):
    """Turn a (28, 48, 3) image into the (1, 3, 48, 28) layout the networks use."""
    # reshape adds a trailing axis; np.transpose then reverses all four axes.
    return np.transpose(img.reshape(_IMG_ROWS, _IMG_COLS, 3, 1))


def _flip_lr(img):
    """Mirror an image horizontally."""
    return cv.flip(img, 1)


def _median_blur(img):
    """Apply a 3x3 median blur."""
    return cv.medianBlur(img, 3)


def _brighten(img):
    """Add 50 to every channel, saturating at 255."""
    # A bare scalar (cv.add(img, 50)) can be taken as Scalar(50, 0, 0, 0) and
    # brighten only the first channel; a full-size array is unambiguous.
    return cv.add(img, np.full_like(img, 50))


def img_read(src, file):
    """Read an image as a (1, 3, 48, 28) uint8 array."""
    return _to_model_layout(_load_rgb(src, file))


def img_leftright(src, file):
    """Read an image mirrored left-right, as a (1, 3, 48, 28) uint8 array."""
    return _to_model_layout(_flip_lr(_load_rgb(src, file)))


def img_blur(src, file):
    """Read an image with a 3x3 median blur, as a (1, 3, 48, 28) uint8 array."""
    return _to_model_layout(_median_blur(_load_rgb(src, file)))


def img_bright(src, file):
    """Read an image brightened by 50, as a (1, 3, 48, 28) uint8 array.

    This previously skipped the transpose and returned (28, 48, 3, 1), which
    made the stacked dataset inhomogeneous.
    """
    return _to_model_layout(_brighten(_load_rgb(src, file)))


# Image file names in src. Named file_names (not `files`) so the
# google.colab `files` module from the upload cell is not shadowed; sorted so
# the dataset order (and indices such as 500) is reproducible.
file_names = sorted(os.listdir(src))

# Decode each file once and derive every augmentation from that copy.
_originals = [_load_rgb(src, file) for file in file_names]

# Same ordering as before: all flipped, all original, all blurred, all brightened.
_TRANSFORMS = (_flip_lr, lambda img: img, _median_blur, _brighten)

# Scale to [0, 1]. float32 halves memory compared with float64, and
# torch.Tensor converts to float32 anyway.
data_set = np.array(
    [
        _to_model_layout(transform(img)).astype(np.float32) / 255.
        for transform in _TRANSFORMS
        for img in _originals
    ]
)
del _originals

print('X_shape:',np.shape(data_set[0]))
print('X_list shape:',np.shape(data_set))


In [None]:
plt.imshow(data_set[500][0].transpose(2, 1, 0))  # (3, 48, 28) -> (28, 48, 3) for imshow

In [None]:
# Show a 2x3 grid of generated samples.
# NOTE: this cell uses G and generate_random_seed, which are defined in later
# cells. Run it only after the generator exists (it mirrors the final cell).
f, axarr = plt.subplots(2, 3, figsize=(16, 8))
for ax in axarr.flat:
    output = G.forward(generate_random_seed(54))
    # Reverse the axes of the single (3, H, W) sample, as the real-sample
    # preview does. The old permute(...).reshape(3, 48, 28) reinterpreted
    # memory and mixed pixels across channels.
    img = np.transpose(output.detach()[0].numpy())
    ax.imshow(img, interpolation='none', cmap='Blues')

# **helper functions**

In [None]:
def generate_random_image(size):
    """Return a tensor of shape ``size`` filled with uniform noise in [0, 1)."""
    return torch.rand(size)


def generate_random_seed(size):
    """Return a tensor of shape ``size`` filled with standard-normal noise."""
    return torch.randn(size)


class View(nn.Module):
    """Reshape layer for use inside ``nn.Sequential``.

    Discriminator (``View(3*10)``) and Generator (``View((1, 3, ...))``) use
    it. ``shape`` may be an int or a tuple, exactly as accepted by
    ``Tensor.view``.
    """

    def __init__(self, shape):
        super().__init__()
        self.shape = shape

    def forward(self, x):
        return x.view(self.shape)

# **판별기 (Discriminator)**

In [None]:
class Discriminator(nn.Module):
    """GAN discriminator: scores an image with a probability of being real.

    Input images in this notebook are (1, 3, 48, 28) tensors. The model owns
    its BCE loss, an Adam optimiser (lr=1e-4), a step counter and a loss
    history sampled every 10 training steps.
    """

    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            # (1, 3, 48, 28) -> (1, 256, 23, 13)
            nn.Conv2d(3, 256, kernel_size=3, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            # -> (1, 256, 11, 6)
            nn.Conv2d(256, 256, kernel_size=3, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            # -> (1, 3, 5, 2), i.e. 3 * 5 * 2 = 30 values
            nn.Conv2d(256, 3, kernel_size=3, stride=2),
            nn.LeakyReLU(0.2),

            # NOTE(review): View is defined outside this cell; presumably it
            # flattens to 30 values, which fits a single image (batch 1) only.
            View(3 * 10),
            nn.Linear(3 * 10, 1),
            nn.Sigmoid(),
        )

        self.loss_function = nn.BCELoss()

        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        """Return the model's sigmoid output for ``inputs``."""
        return self.model(inputs)

    def train(self, inputs=True, targets=None):
        """Run one optimisation step on ``inputs`` against label ``targets``.

        Args:
            inputs: image tensor to score.
            targets: label tensor compared with the output by BCELoss
                (``torch.FloatTensor([1.0])`` for real, ``[0.0]`` for fake).

        This method shadows ``nn.Module.train(mode)``, and ``nn.Module.eval()``
        calls ``self.train(False)``. When called with a single bool, it
        delegates to the base implementation so ``.eval()`` and
        ``.train(True)`` keep working.
        """
        if targets is None and isinstance(inputs, bool):
            return super().train(inputs)

        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, targets)

        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 1000 == 0:
            print("counter = ", self.counter)

        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the recorded loss history (one point per 10 training steps)."""
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=0, figsize=(16, 8), alpha=0.1, marker='.', grid=True,
                yticks=(0, 0.25, 0.5, 1.0, 5.0))


# **판별기 (Discriminator) 테스트**

In [None]:
D = Discriminator()

# Teach D to tell dataset images (label 1) from uniform noise (label 0):
# one real step followed by one noise step per dataset entry.
real_label = torch.FloatTensor([1.0])
noise_label = torch.FloatTensor([0.0])
for sample in data_set:
    D.train(torch.Tensor(sample), real_label)
    D.train(generate_random_image((1, 3, 48, 28)), noise_label)


In [None]:
# Plot the discriminator loss recorded on every 10th D.train call above.
D.plot_progress()

In [None]:
# Spot-check the discriminator: ideally real samples score near 1 and
# uniform-noise images score near 0.
for _ in range(4):
    # randrange excludes len(data_set), so the index is always valid.
    # The former randint(0, 10000) was inclusive and hard-coded, and could
    # overrun a smaller dataset.
    image_data = data_set[random.randrange(len(data_set))]
    print(D.forward(torch.Tensor(image_data)).item())

for _ in range(4):
    print(D.forward(generate_random_image((1, 3, 48, 28))).item())

# **생성기**

In [None]:
class Generator(nn.Module):
    """GAN generator: maps a 54-value noise seed to a (1, 3, 48, 28) image.

    The model owns an Adam optimiser (lr=1e-4), a step counter and a loss
    history. The loss function is borrowed from the discriminator passed to
    ``train``.
    """

    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            nn.Linear(54, 3 * 13 * 23),
            nn.LeakyReLU(0.2),

            # Seed map is laid out as 23 rows x 13 cols so the output below is
            # (1, 3, 48, 28), matching the real data (reshape(28, 48, 3, 1) +
            # np.transpose) and the discriminator test input. A (13, 23) seed
            # would give (1, 3, 28, 48), transposed relative to the data.
            View((1, 3, 23, 13)),

            # 23x13 -> 47x27
            nn.ConvTranspose2d(3, 256, kernel_size=3, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            # 47x27 -> 95x55
            nn.ConvTranspose2d(256, 256, kernel_size=3, stride=2),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),

            # 95x55 -> 48x28
            nn.Conv2d(256, 3, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(3),

            nn.Sigmoid(),
        )

        self.optimiser = torch.optim.Adam(self.parameters(), lr=0.0001)

        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        """Return the generated image for noise ``inputs``."""
        return self.model(inputs)

    def train(self, D=True, inputs=None, targets=None):
        """Run one generator update against discriminator ``D``.

        Args:
            D: discriminator that scores the generated image; its
                ``loss_function`` supplies the loss.
            inputs: noise seed for ``forward`` (the notebook passes
                ``generate_random_seed(54)``).
            targets: label the discriminator output is pushed toward
                (``torch.FloatTensor([1.0])`` in the training loop).

        This method shadows ``nn.Module.train(mode)``, and ``nn.Module.eval()``
        calls ``self.train(False)``. When called with a single bool, it
        delegates to the base implementation so ``.eval()`` and
        ``.train(True)`` keep working.
        """
        if inputs is None and isinstance(D, bool):
            return super().train(D)

        g_output = self.forward(inputs)
        d_output = D.forward(g_output)
        loss = D.loss_function(d_output, targets)

        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())
        if self.counter % 1000 == 0:
            print("counter = ", self.counter)

        # backward() also fills D's gradients. Only G's optimiser steps here,
        # and D.train zeroes D's gradients before its own update.
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()

    def plot_progress(self):
        """Plot the recorded loss history (one point per 10 training steps)."""
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=0, figsize=(16, 8), alpha=0.1, marker='.', grid=True,
                yticks=(0, 0.25, 0.5, 1.0, 5.0))


# **생성기 테스트**

In [None]:
G = Generator()

output = G.forward(generate_random_seed(54))
# Reverse the axes of the single (3, H, W) sample, as the real-sample preview
# does. The old permute(...).reshape(3, 48, 28) reinterpreted memory and
# mixed pixels across channels.
img = np.transpose(output.detach()[0].numpy())
plt.imshow(img, interpolation='none', cmap='Blues')


# **GAN 학습**

In [None]:
D = Discriminator()
G = Generator()

epochs = 1

REAL = torch.FloatTensor([1.0])
FAKE = torch.FloatTensor([0.0])

for epoch_idx in range(epochs):
    print("epoch = ", epoch_idx + 1)

    for sample in data_set:
        # Discriminator step on a real image.
        D.train(torch.Tensor(sample), REAL)

        # Discriminator step on a generated image. detach() keeps this step
        # from backpropagating into G.
        fake_image = G.forward(generate_random_seed(54)).detach()
        D.train(fake_image, FAKE)

        # Generator step: push D's verdict on a fresh fake toward "real".
        G.train(D, generate_random_seed(54), REAL)


In [None]:
# Plot the discriminator loss recorded on every 10th D.train call during GAN training.
D.plot_progress()

In [None]:
# Plot the generator loss recorded on every 10th G.train call during GAN training.
G.plot_progress()

# **GAN 실행**

In [None]:
# Show a 2x3 grid of images from the trained generator.
f, axarr = plt.subplots(2, 3, figsize=(16, 8))
for ax in axarr.flat:
    output = G.forward(generate_random_seed(54))
    # Reverse the axes of the single (3, H, W) sample, as the real-sample
    # preview does. The old permute(...).reshape(3, 48, 28) reinterpreted
    # memory and mixed pixels across channels.
    img = np.transpose(output.detach()[0].numpy())
    ax.imshow(img, interpolation='none', cmap='Blues')