<a href="https://colab.research.google.com/github/chikara-n-ellipse/reconst3d/blob/feature%2Fsimple-gan/imsup3d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import shutil
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as tfs
import os
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline
# %matplotlib notebook

In [None]:
# Unpack the cropped-CelebA zip (from the mounted Drive folder) into local
# Colab storage, then record the root directory later cells read from.
extract_dir = "/content/img_cropped_celeba"
shutil.unpack_archive(
    filename="/content/drive/MyDrive/Projects/celebA/img_cropped_celeba.zip",
    extract_dir=extract_dir,
)
# Trailing slash matters: later cells append sub-directory names directly.
data_dir = "/content/img_cropped_celeba/"

In [None]:
# !ls /content/img_cropped_celeba/val

In [None]:
# Setup: use the first CUDA GPU when one is available (and make it the
# current device), otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    torch.cuda.set_device(device)

In [None]:
class Generator(nn.Module):
    """Vanilla GAN generator: 64-channel latent map -> RGB image.

    The training loop feeds Gaussian noise of shape ``(B, 64, 1, 1)``; the
    result is a ``(B, 3, 16, 16)`` tensor squashed by a final sigmoid.
    Resolution path for a 1x1 input:

        1x1 -> 3x3 -> 5x5      two 3x3 transposed convs, 64 -> 48 -> 32 ch
        -> 8x8 -> 6x6 -> 4x4   nearest resize, two unpadded 3x3 convs,
                               32 -> 24 -> 16 ch
        -> 16x16               nearest resize, two 'same' 3x3 convs,
                               16 -> 8 -> 3 ch
    """

    def __init__(self):
        super().__init__()
        # Layers are created (and therefore randomly initialised) in a fixed
        # order, and the attribute names double as state_dict keys.

        # Stage 0: grow the latent map with transposed convolutions.
        self.ct0_0 = nn.ConvTranspose2d(in_channels=64, out_channels=48, kernel_size=3)
        self.ct0_1 = nn.ConvTranspose2d(in_channels=48, out_channels=32, kernel_size=3)
        self.lrelu0 = nn.LeakyReLU(negative_slope=0.25)

        # Stage 1: resize to 8x8, then two unpadded convs shrink it to 4x4.
        self.upsample1 = nn.Upsample(size=(8, 8))
        self.c1_0 = nn.Conv2d(in_channels=32, out_channels=24, kernel_size=3)
        self.c1_1 = nn.Conv2d(in_channels=24, out_channels=16, kernel_size=3)
        self.lrelu1 = nn.LeakyReLU(negative_slope=0.25)

        # Stage 2: resize to 16x16, size-preserving convs, sigmoid output.
        self.upsample2 = nn.Upsample(size=(16, 16))
        self.c2_0 = nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3, padding='same')
        self.c2_1 = nn.Conv2d(in_channels=8, out_channels=3, kernel_size=3, padding='same')
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        pipeline = (
            self.ct0_0, self.ct0_1, self.lrelu0,
            self.upsample1, self.c1_0, self.c1_1, self.lrelu1,
            self.upsample2, self.c2_0, self.c2_1, self.sigmoid,
        )
        for layer in pipeline:
            x = layer(x)
        return x


class Discriminator(nn.Module):
    """Vanilla GAN discriminator.

    Scores a batch of RGB images ``(B, 3, H, W)`` and returns, for each image,
    a sigmoid probability of being real, as a tensor of shape ``(B, 1)``.

    Spatial sizes for the 16x16 images used in this notebook:
        16 -> 14 -> 12   two unpadded 3x3 convs (3 -> 8 -> 16 ch)
        -> 8             nearest resize
        -> 6 -> 4        two unpadded 3x3 convs (16 -> 24 -> 32 ch)
        -> 6             nearest resize (this one enlarges)
        -> 3 -> 1        4x4 conv then 3x3 conv (32 -> 48 -> 64 ch)
    The resulting 64-channel 1x1 map is flattened and reduced to one logit.
    """

    def __init__(self):
        super().__init__()

        # First downsampling.
        self.c0_0 = nn.Conv2d(3, 8, 3)
        self.c0_1 = nn.Conv2d(8, 16, 3)
        self.lrelu0 = nn.LeakyReLU(0.25)
        # nn.Upsample with a fixed `size` just resizes (nearest by default),
        # so it can shrink as well as grow.
        self.downsample0 = nn.Upsample((8, 8))

        # Second downsampling.
        self.c1_0 = nn.Conv2d(16, 24, 3)
        self.c1_1 = nn.Conv2d(24, 32, 3)
        self.lrelu1 = nn.LeakyReLU(0.25)
        # Despite the name this resizes 4x4 *up* to 6x6, so that the two
        # final convs (kernel 4, then 3) land exactly on 1x1. The attribute
        # name is kept because it is a state_dict key.
        self.downsample1 = nn.Upsample((6, 6))

        # Third downsampling and classification head.
        self.c2_0 = nn.Conv2d(32, 48, 4)
        self.c2_1 = nn.Conv2d(48, 64, 3)
        self.linear = nn.Linear(64, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.c0_0(x)
        x = self.c0_1(x)
        x = self.lrelu0(x)
        x = self.downsample0(x)

        x = self.c1_0(x)
        x = self.c1_1(x)
        x = self.lrelu1(x)
        x = self.downsample1(x)

        x = self.c2_0(x)
        x = self.c2_1(x)
        # (B, 64, 1, 1) -> (B, 64). Unlike x.squeeze(), flatten(1) keeps the
        # batch axis when B == 1 (e.g. a final short batch of one image).
        x = self.linear(x.flatten(1))
        x = self.sigmoid(x)
        return x

In [None]:
def fix_seed(seed):
    """Seed every RNG the notebook relies on so runs are repeatable.

    Seeds Python's ``random`` module, NumPy's global RNG, and PyTorch's CPU
    and all CUDA generators. It also puts cuDNN into deterministic,
    non-autotuned mode.

    Args:
        seed: integer seed given to each generator.
    """
    # Local import: the notebook's import cell does not import `random`.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels alone are not enough: with benchmark mode on,
    # cuDNN's autotuner may choose different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

class ImageDataset(Dataset):
    """Map-style dataset that loads one image per file path.

    Args:
        image_paths: sequence of image file paths; item ``i`` is read from
            ``image_paths[i]``.
        transform: optional callable applied to each image. It receives a
            channels-first ``(C, H, W)`` float tensor.
    """

    def __init__(self, image_paths, transform=None):
        self.image_paths = image_paths
        self.transform = transform

    def get_class_label(self, image_name):
        """Return the class label for ``image_name``.

        Not implemented: this dataset carries no labels. Override in a
        subclass if labels are needed.
        """
        raise NotImplementedError(
            "ImageDataset has no class labels; override get_class_label"
        )

    def __getitem__(self, index):
        """Load image ``index`` as a float tensor scaled to [0, 1]."""
        image = plt.imread(self.image_paths[index])
        # plt.imread returns integer arrays (uint8 for ordinary JPEGs) for
        # most formats, but float arrays already in [0, 1] for PNG. Only
        # integer data needs rescaling (assumes 8-bit integer images).
        if np.issubdtype(image.dtype, np.integer):
            x = torch.tensor(image) / 255.0
        else:
            x = torch.tensor(image)
        if self.transform is not None:
            # HWC -> CHW, the layout torchvision transforms expect.
            x = self.transform(x.permute(2, 0, 1))
        # NOTE(review): without a transform the tensor is returned in
        # (H, W, C) layout, unlike the transformed path. This is kept as-is
        # for backward compatibility.
        return x

    def __len__(self):
        return len(self.image_paths)

fix_seed(0)
batch_size = 512

# Sort the listing: os.listdir order is filesystem-dependent, so without it
# the seeded shuffle would not reproduce the same batches across machines.
train_dir = os.path.join(data_dir, 'train')
paths = [os.path.join(train_dir, fname) for fname in sorted(os.listdir(train_dir))]
p_depth = 4
# The generator always emits 2**p_depth x 2**p_depth (16x16) images. An int
# size would only fix the shorter edge, so non-square crops would not match
# the fakes; pin both edges (identical result for square crops).
transform = tfs.Resize((2**p_depth, 2**p_depth))
dataset = ImageDataset(paths, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size,
                        shuffle=True, num_workers=0)

generator = Generator().to(device)
discriminator = Discriminator().to(device)

optimizer_gen = torch.optim.Adam(generator.parameters(), lr=0.001)
optimizer_dis = torch.optim.Adam(discriminator.parameters(), lr=0.0001)


def train_one_epoch(
        turn_steps=None,
        max_step=None,
        is_debug=False,
        show_interval=100,
        ):
    """Run one pass over ``dataloader``, alternating D and G updates.

    Batches are assigned in repeating cycles. The first ``turn_steps['dis']``
    batches of each cycle update the discriminator; the next
    ``turn_steps['gen']`` batches update the generator.

    Uses the notebook globals ``dataloader``, ``generator``,
    ``discriminator``, ``optimizer_gen``, ``optimizer_dis`` and ``device``.

    Args:
        turn_steps: dict with int entries ``'dis'`` and ``'gen'``. Defaults
            to ``{'dis': 10, 'gen': 10}``.
        max_step: if not None, stop before batch index ``max_step``.
        is_debug: print each batch's size and plot its first three images.
        show_interval: every ``show_interval`` batches, print the mean
            discriminator outputs and plot three generated images.
            ``0`` disables this.
    """
    # None sentinel instead of a shared mutable dict default.
    if turn_steps is None:
        turn_steps = {'dis': 10, 'gen': 10}
    cycle = turn_steps['dis'] + turn_steps['gen']

    for i_batch, sample_batched in enumerate(dataloader):
        if max_step is not None and i_batch >= max_step:
            break

        step_turn = 'dis' if i_batch % cycle < turn_steps['dis'] else 'gen'

        if is_debug:
            print(i_batch, sample_batched.size())
            plt.imshow(torch.cat(list(sample_batched.permute(0, 2, 3, 1)[:3]), 1))
            plt.show()

        generator.train()
        discriminator.train()
        optimizer_gen.zero_grad()
        optimizer_dis.zero_grad()

        image_real = sample_batched.to(device)
        # Size the noise by the real batch: the last batch of an epoch can be
        # smaller than the configured batch size.
        z = torch.randn((image_real.size(0), 64, 1, 1)).to(device)

        image_fake = generator(z)

        if step_turn == 'dis':
            p_real = discriminator(image_real)
            # detach: a discriminator step has no reason to backprop into G.
            p_fake = discriminator(image_fake.detach())

            # BCE(p, 1) == -mean(log p) and BCE(p, 0) == -mean(log(1 - p)),
            # but BCE clamps the log at -100, so a saturated sigmoid gives a
            # large finite loss instead of inf/nan.
            loss_dis_real = nn.functional.binary_cross_entropy(
                p_real, torch.ones_like(p_real))
            loss_dis_fake = nn.functional.binary_cross_entropy(
                p_fake, torch.zeros_like(p_fake))
            loss_dis = loss_dis_real + loss_dis_fake

            loss_dis.backward()
            optimizer_dis.step()

        else:
            p_fake = discriminator(image_fake)

            # Non-saturating generator loss: -mean(log D(G(z))), clamped.
            loss_gen = nn.functional.binary_cross_entropy(
                p_fake, torch.ones_like(p_fake))

            loss_gen.backward()
            optimizer_gen.step()

        if show_interval and i_batch % show_interval == 0:
            if step_turn == 'dis':
                print(f"p_fake:{p_fake.detach().cpu().mean().item():.3f}, p_real:{p_real.detach().cpu().mean().item():.3f}")
            else:
                print(f"p_fake:{p_fake.detach().mean().item():.3f}")
            print(image_fake.size())
            plt.imshow(torch.cat(list(image_fake.detach().cpu().permute(0, 2, 3, 1)[:3]), 1))
            plt.show()

# Effectively "train until interrupted": 100000 epochs, with D and G taking
# alternating runs of 10 batches. Because the show interval exceeds any
# realistic batch count, only batch 0 of each epoch prints and plots.
for epoch in range(100_000):
    print(f"epoch: {epoch:05}")
    train_one_epoch(
        show_interval=100_000_000,
        turn_steps={"dis": 10, "gen": 10},
    )


In [None]:
# Run settings collected in one place. (Not read by the training cells above.)
config = dict(
    batchsize=2,
    num_workers=4,
    image_size=64,
)