In [None]:
# import related libraries
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
import torchvision.datasets as dataset
import torchvision.transforms as transforms
import torchvision.utils as vision_utils

from models.srgan_discriminator import Discriminator
from models.srgan_generator import Generator

import os
import tqdm
import numpy

import matplotlib.pyplot as plot
import cv2

# constants
# Adam hyper-parameters shared by both optimizers.
learning_rate = 0.0002
beta1_for_adam = 0.5
beta2_for_adam = 0.999
# BCE targets: the discriminator should output real_label for dataset images
# and fake_label for generator output.
real_label = 1.0
fake_label = 0.0
# Dimensionality of the latent noise vector fed to the generator.
length_of_z_input_vector = 100

# Seed torch's RNGs (CPU and all CUDA devices) so weight init, noise and data
# shuffling repeat across runs; full bit-for-bit determinism on GPU would also
# need deterministic cuDNN settings.
random_seed = 999
torch.manual_seed(random_seed)

In [None]:
# model and dataset initializing step
def initial_weights(input_module):
    """Re-initialise a module's parameters in place, DCGAN-style.

    Intended to be passed to ``torch.nn.Module.apply``, which calls it on every
    submodule. Dispatch is by class-name substring:

    * class name contains ``'Conv'``      -> weight ~ N(0.0, 0.02) (bias untouched)
    * class name contains ``'BatchNorm'`` -> weight ~ N(1.0, 0.02), bias = 0

    Any other module is left unchanged. Parameters that are absent or ``None`` are
    skipped. Examples are a user-defined container such as ``ConvBlock`` (no
    ``weight`` attribute) and ``BatchNorm*(affine=False)`` (``weight``/``bias`` are
    ``None``). Without this check both would crash ``apply``.

    Args:
        input_module: the ``torch.nn.Module`` to initialise.
    """
    class_name = type(input_module).__name__
    weight = getattr(input_module, 'weight', None)
    bias = getattr(input_module, 'bias', None)
    if 'Conv' in class_name:
        if weight is not None:
            torch.nn.init.normal_(weight, 0.0, 0.02)
    elif 'BatchNorm' in class_name:
        if weight is not None:
            torch.nn.init.normal_(weight, 1.0, 0.02)
        if bias is not None:
            torch.nn.init.zeros_(bias)

# Side length (pixels) that dataset images are resized/cropped to. It is also reused
# below as the number of fixed-noise samples rendered during training.
generator_discriminator_interface_size = 64

# Fall back to CPU so the notebook still runs on a machine without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# NOTE(review): the argument 3 presumably is the RGB channel count (ImageFolder loads
# images as RGB) — confirm against the model definitions.
discriminator_object = Discriminator(3).to(device)
generator_object = Generator(length_of_z_input_vector, 3).to(device)

# apply() visits every submodule, so Conv*/BatchNorm* layers get the custom init.
discriminator_object.apply(initial_weights)
generator_object.apply(initial_weights)

optimizer_for_discriminator = Adam(discriminator_object.parameters(), lr=learning_rate, betas=(beta1_for_adam, beta2_for_adam))
optimizer_for_generator = Adam(generator_object.parameters(), lr=learning_rate, betas=(beta1_for_adam, beta2_for_adam))

# Binary cross-entropy; despite its name this criterion is used for BOTH the
# discriminator and the generator updates in the training cell.
loss_for_discriminator = torch.nn.BCELoss()

# NOTE(review): ToTensor() yields real images in [0, 1] and no Normalize is applied.
# If the Generator's last layer is tanh (range [-1, 1], as in DCGAN), add
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) so real and fake images share
# a range. Otherwise the discriminator can separate them by sign alone. Confirm.
image_transform = transforms.Compose([
    transforms.Resize(generator_discriminator_interface_size),
    transforms.CenterCrop(generator_discriminator_interface_size),
    transforms.ToTensor(),
])
tensored_data_set = dataset.ImageFolder(root='/workspace', transform=image_transform)

# NOTE(review): no batch_size is passed, so DataLoader uses its default of 1.
# BatchNorm-based GANs are usually trained with 64-128 images per batch. Confirm
# that a batch of one is intended.
loaded_data_set = DataLoader(tensored_data_set, shuffle=True, num_workers=10)

# Fixed latent batch reused at every logging step so generated samples stay
# comparable across training.
fixed_noise = torch.randn(generator_discriminator_interface_size, length_of_z_input_vector, 1, 1, device=device)



In [None]:
# train both models
number_of_epochs = 501
output_directory = '/data/home/taeho/pytorch_tutorials'
os.makedirs(output_directory, exist_ok=True)  # save_image fails if the directory is missing

for epoch in range(number_of_epochs):
    for index, data in enumerate(loaded_data_set):
        # ---- update discriminator network ----
        ## training with real data
        discriminator_object.zero_grad()
        real_images = data[0].to(device)  # data[1] (ImageFolder class index) is unused
        batch_size = real_images.size(0)
        label = torch.full((batch_size,), real_label, dtype=real_images.dtype, device=device)

        # view(-1) flattens the discriminator output to match label's (batch_size,) shape
        output = discriminator_object(real_images).view(-1)
        loss_real = loss_for_discriminator(output, label)
        loss_real.backward()

        ## training with fake data
        noise = torch.randn(batch_size, length_of_z_input_vector, 1, 1, device=device)
        fake = generator_object(noise)
        label.fill_(fake_label)
        # detach(): this backward pass must not reach the generator.
        # view(-1): BCELoss requires input and target of the same shape, as in the real pass.
        output = discriminator_object(fake.detach()).view(-1)
        loss_fake = loss_for_discriminator(output, label)
        loss_fake.backward()

        # gradients from loss_real and loss_fake have accumulated; apply both at once
        optimizer_for_discriminator.step()

        # ---- update generator network ----
        generator_object.zero_grad()
        # the generator is rewarded when the discriminator labels its fakes as real
        label.fill_(real_label)
        # reuse the non-detached `fake` so gradients flow back into the generator
        output = discriminator_object(fake).view(-1)
        loss_generator = loss_for_discriminator(output, label)
        loss_generator.backward()

        optimizer_for_generator.step()

        if index % 100 == 0:
            print('[%d/%d][%d/%d] loss_D: %.4f loss_G: %.4f' % (
                epoch, number_of_epochs, index, len(loaded_data_set),
                (loss_real + loss_fake).item(), loss_generator.item()))
            vision_utils.save_image(real_images, os.path.join(output_directory, 'samples_epoch%03d.png' % epoch), normalize=True)

            # visualisation only: no autograd graph needed
            with torch.no_grad():
                fixed_fake = generator_object(fixed_noise)
                grid = vision_utils.make_grid(fixed_fake, padding=2, normalize=True).cpu()

            # Show, then close, each figure; otherwise figures pile up for the whole run.
            figure, axis = plot.subplots(figsize=(8, 8))
            axis.axis('off')
            axis.set_title('Generated Images')
            axis.imshow(grid.permute(1, 2, 0).numpy())  # CHW -> HWC for imshow
            plot.show()
            plot.close(figure)

            vision_utils.save_image(fixed_fake, os.path.join(output_directory, 'fake_samples_epoch%03d.png' % epoch), normalize=True)

In [None]:
# save weight models: write each network's state_dict next to the training samples
weights_directory = '/data/home/taeho/pytorch_tutorials'
for trained_model, weight_file_name in (
    (discriminator_object, 'srgan_discriminator_weight'),
    (generator_object, 'srgan_generator_weight'),
):
    torch.save(trained_model.state_dict(), weights_directory + '/' + weight_file_name)