# Static Code

In [1]:
from tensorboard_logger import configure, log_value
import tensorboard
import argparse
import os
import numpy as np
import math
import sys
import glob
import time
#import utils import Logger


import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch
import glob
import random
import os
import numpy as np

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    """Dataset over every file matching ``<folder_path>/*.*`` (non-recursive).

    Files are sorted by path so indexing is deterministic. Each item is loaded
    with PIL, converted to RGB and passed through the composed transforms.

    Args:
        folder_path: directory holding the images.
        transforms_: list of torchvision transforms wrapped in
            ``transforms.Compose``; ``None`` returns the RGB PIL image as-is.
    """

    def __init__(self, folder_path, transforms_=None):
        # Compose(None) only fails later, on the first __getitem__, so treat
        # None as "no transform" explicitly.
        self.transform = transforms.Compose(transforms_) if transforms_ is not None else None
        self.files = sorted(glob.glob(os.path.join(folder_path, "*.*")))

    def __getitem__(self, index):
        # The modulo keeps the original wrap-around behaviour for large indices.
        path = self.files[index % len(self.files)]
        # Context manager closes the file handle. convert("RGB") guarantees 3
        # channels, which the models and the 3-value Normalize expect;
        # RGBA/palette PNGs would otherwise break.
        with Image.open(path) as img:
            img = img.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.files)
    
    
# True when a CUDA device is usable; the notebook moves the models to the GPU
# and picks its Tensor type based on this flag.
cuda = torch.cuda.is_available()
print(cuda)

class Generator(nn.Module):
    """MLP generator mapping a latent vector to an image with values in [-1, 1].

    Args:
        latent_dim: length of the input noise vector (default 100, as before).
        output_shape: (C, H, W) of generated images. Defaults to the
            module-level ``img_shape`` so ``Generator()`` keeps working.
    """

    def __init__(self, latent_dim=100, output_shape=None):
        super().__init__()
        # Resolve the shape once, at construction. forward() used to read the
        # global img_shape, which is defined in a later notebook cell.
        self.img_shape = tuple(img_shape if output_shape is None else output_shape)
        n_out = int(np.prod(self.img_shape))

        def block(in_feat, out_feat, normalize=True):
            """Linear -> (optional BatchNorm1d) -> LeakyReLU(0.2)."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # NOTE(review): the positional 0.8 is BatchNorm1d's `eps`, not
                # `momentum`. It looks unintended, but it is kept so existing
                # checkpoints behave identically.
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, n_out),
            nn.Tanh(),  # squashes outputs to [-1, 1]
        )

    def forward(self, z):
        """Map noise ``z`` of shape (N, latent_dim) to images (N, *img_shape)."""
        img = self.model(z)
        return img.view(img.shape[0], *self.img_shape)
    
class Discriminator(nn.Module):
    """MLP critic mapping an image batch to one raw score per sample.

    The last layer is linear (no sigmoid): the WGAN loss in this notebook works
    on raw scores (``-mean(real) + mean(fake)``).

    Args:
        input_shape: (C, H, W) of the images. Defaults to the module-level
            ``img_shape`` so ``Discriminator()`` keeps working.
    """

    def __init__(self, input_shape=None):
        super().__init__()
        shape = img_shape if input_shape is None else input_shape
        self.model = nn.Sequential(
            nn.Linear(int(np.prod(shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        """Return scores of shape (N, 1) for an input batch of shape (N, ...)."""
        # reshape (not view) so non-contiguous batches are accepted too.
        img_flat = img.reshape(img.shape[0], -1)
        return self.model(img_flat)
    
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculate the WGAN-GP gradient penalty ``mean((||grad D(x_hat)||_2 - 1)^2)``.

    ``x_hat`` is a per-sample random interpolation between real and fake samples.

    Args:
        D: critic; called on a batch shaped like ``real_samples``. It is expected
            to return one score per sample.
        real_samples: tensor of shape (N, ...). Any rank >= 2 is accepted.
        fake_samples: tensor with the same shape as ``real_samples``.

    Returns:
        Scalar tensor, differentiable w.r.t. D's parameters (``create_graph``).
    """
    # One uniform [0, 1) weight per sample, broadcast over the remaining dims.
    # It is created on the samples' device/dtype, so no global Tensor alias is
    # needed and inputs of any rank work.
    alpha_shape = (real_samples.size(0),) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device, dtype=real_samples.dtype)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,  # the penalty itself is backpropagated into D
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(gradients.size(0), -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()


True


# Option settings 

In [2]:
n_epochs = 100  # number of training epochs
batch_size = 5  # images per batch
lr = 0.0002  # Adam learning rate
b1 = 0.5  # Adam beta1: decay rate of the first-moment (mean) estimate
b2 = 0.999  # Adam beta2: decay rate of the second-moment (squared-gradient) estimate
n_cpu = 8  # presumably the number of DataLoader worker processes
latent_dim = 100  # length of the generator's input noise vector
img_size = 250  # images are img_size x img_size pixels
channels = 3  # RGB
n_critic = 5  # the generator is updated once every n_critic batches
clip_value = 0.01  # weight-clipping bound; not referenced by the WGAN-GP code shown here
sample_interval = 1000  # batches between saved image grids
img_shape = (channels, img_size, img_size)
crop_size = 400  # not referenced by the code shown here
print(img_shape)

(3, 250, 250)


# Input data 

In [3]:


folder_path = "./zinc100k500pxRe250pxJPG/"
# ToTensor gives values in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generator's Tanh output.
transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
dataloader = DataLoader(
    ImageDataset(folder_path, transforms_=transforms_),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,  # the configured worker count instead of a literal 8
)

# Initialize Models and Load Settings

In [4]:
# Loss weight for gradient penalty
lambda_gp = 10


# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()

# If cuda, use cuda
if cuda:
    generator.cuda()
    discriminator.cuda()
    
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b2, b2))
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


# Load previous model checkpoints

In [5]:
# Restore model weights and optimizer state from checkpoint dicts with keys
# 'state_dict' and 'optimizer'.
# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine.
# load_state_dict then copies tensors onto each model's / optimizer's device.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
state_g = torch.load("/home/jgmeyer2/vangan/gans/models/g250px_jpegs.model", map_location="cpu")
state_d = torch.load("/home/jgmeyer2/vangan/gans/models/d250px_jpegs.model", map_location="cpu")

generator.load_state_dict(state_g['state_dict'])
optimizer_G.load_state_dict(state_g['optimizer'])

discriminator.load_state_dict(state_d['state_dict'])
optimizer_D.load_state_dict(state_d['optimizer'])




# Initialize logger

In [5]:
# Point tensorboard_logger (used by the log_value calls below) at ./runs/test5,
# with flush_secs=5 so pending scalars are written to disk every few seconds.
configure("runs/test5",flush_secs=5)

# Training loop

In [6]:
# Quick one-epoch WGAN-GP run. Losses go to TensorBoard, sample grids go to
# molpics500px/, and wall time is measured per interval between generator steps.
n_epochs = 1
batches_done = 0
# Log scalars every sample_interval // 10 batches (at least 1).
log_interval = max(1, sample_interval // 10)
os.makedirs("molpics500px", exist_ok=True)
start = time.time()
# Seconds per picture, one entry per timing interval. Starts empty: seeding
# it with 0.0 biased the average downward.
batchtimes = []
imgs_since_tick = 0  # images processed since `start`

print(batch_size)
for epoch in range(n_epochs):
    for i, imgs in enumerate(dataloader):
        # Configure input
        real_imgs = imgs.type(Tensor)
        imgs_since_tick += imgs.shape[0]

        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Sample noise as generator input
        z = Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))
        fake_imgs = generator(z)
        real_validity = discriminator(real_imgs)
        # detach(): the critic update must not build a graph through the generator.
        fake_validity = discriminator(fake_imgs.detach())
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.detach(), fake_imgs.detach())
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
        d_loss.backward()
        optimizer_D.step()
        optimizer_G.zero_grad()

        # Train the generator every n_critic steps
        if i % n_critic == 0:
            # -----------------
            #  Train Generator
            # -----------------
            fake_imgs = generator(z)
            fake_validity = discriminator(fake_imgs)
            g_loss = -torch.mean(fake_validity)
            g_loss.backward()
            optimizer_G.step()

            # `batches_done % sample_interval/10` used to parse as
            # `(batches_done % sample_interval) / 10`, so this fired only every
            # sample_interval batches. log_value needs plain floats.
            if batches_done % log_interval == 0:
                log_value('g_loss', g_loss.item(), batches_done)
                log_value('d_loss', d_loss.item(), batches_done)

            if batches_done % sample_interval == 0:
                save_image(fake_imgs.data[:25], "molpics500px/%d_d.png" % batches_done, nrow=5, normalize=True)

            end = time.time()
            batchtime = end - start  # wall time since the previous generator step

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [5batch time: %f]"
                % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), batchtime)
            )
            # Per-picture time = interval time / images actually processed in it.
            batchtimes.append(batchtime / imgs_since_tick)
            imgs_since_tick = 0
            start = time.time()

            batches_done += n_critic
print("average time per picture = " + str(np.mean(batchtimes)))
print("minutes per 100,000 pictures = " + str((np.mean(batchtimes) * 100000) / 60))

5
[Epoch 0/1] [Batch 0/2000] [D loss: 8.087671] [G loss: 0.065366] [5batch time: 0.931179]
[Epoch 0/1] [Batch 5/2000] [D loss: -567.965088] [G loss: -0.170554] [5batch time: 0.898888]
[Epoch 0/1] [Batch 10/2000] [D loss: -1505.326172] [G loss: -1.529939] [5batch time: 0.907479]
[Epoch 0/1] [Batch 15/2000] [D loss: -2695.595703] [G loss: -5.341569] [5batch time: 0.882582]
[Epoch 0/1] [Batch 20/2000] [D loss: -3933.389893] [G loss: -13.375514] [5batch time: 0.905233]
[Epoch 0/1] [Batch 25/2000] [D loss: -4838.215820] [G loss: -31.952581] [5batch time: 0.907908]
[Epoch 0/1] [Batch 30/2000] [D loss: -4806.338867] [G loss: -53.683540] [5batch time: 0.897358]
[Epoch 0/1] [Batch 35/2000] [D loss: -3855.025391] [G loss: -93.181923] [5batch time: 0.898592]
[Epoch 0/1] [Batch 40/2000] [D loss: -3222.952148] [G loss: -126.084206] [5batch time: 0.910570]
[Epoch 0/1] [Batch 45/2000] [D loss: -3320.944336] [G loss: -140.084290] [5batch time: 0.908895]
[Epoch 0/1] [Batch 50/2000] [D loss: -3752.23339

[Epoch 0/1] [Batch 420/2000] [D loss: -347.932861] [G loss: -3203.039795] [5batch time: 0.896833]
[Epoch 0/1] [Batch 425/2000] [D loss: -612.530945] [G loss: -2956.669189] [5batch time: 0.896049]
[Epoch 0/1] [Batch 430/2000] [D loss: -102.558044] [G loss: -3495.974365] [5batch time: 0.892218]
[Epoch 0/1] [Batch 435/2000] [D loss: -461.666077] [G loss: -3138.668945] [5batch time: 0.874080]
[Epoch 0/1] [Batch 440/2000] [D loss: -312.091492] [G loss: -3286.340332] [5batch time: 0.898537]
[Epoch 0/1] [Batch 445/2000] [D loss: -599.583862] [G loss: -2986.123047] [5batch time: 0.901865]
[Epoch 0/1] [Batch 450/2000] [D loss: 425.260193] [G loss: -4002.648438] [5batch time: 0.907014]
[Epoch 0/1] [Batch 455/2000] [D loss: 432.650757] [G loss: -3973.274658] [5batch time: 0.909390]
[Epoch 0/1] [Batch 460/2000] [D loss: -605.986938] [G loss: -2915.540039] [5batch time: 0.909810]
[Epoch 0/1] [Batch 465/2000] [D loss: -368.296997] [G loss: -3119.831543] [5batch time: 0.892771]
[Epoch 0/1] [Batch 470

[Epoch 0/1] [Batch 850/2000] [D loss: -50.385334] [G loss: -718.950806] [5batch time: 0.907542]
[Epoch 0/1] [Batch 855/2000] [D loss: -103.970139] [G loss: -712.428040] [5batch time: 0.887082]
[Epoch 0/1] [Batch 860/2000] [D loss: -20.397224] [G loss: -842.537048] [5batch time: 0.893337]
[Epoch 0/1] [Batch 865/2000] [D loss: -37.454174] [G loss: -872.766785] [5batch time: 0.893759]
[Epoch 0/1] [Batch 870/2000] [D loss: -70.321899] [G loss: -885.544434] [5batch time: 0.902352]
[Epoch 0/1] [Batch 875/2000] [D loss: -41.195652] [G loss: -960.776550] [5batch time: 0.908836]
[Epoch 0/1] [Batch 880/2000] [D loss: -131.632477] [G loss: -916.584106] [5batch time: 0.889803]
[Epoch 0/1] [Batch 885/2000] [D loss: -94.097504] [G loss: -1003.889832] [5batch time: 0.904941]
[Epoch 0/1] [Batch 890/2000] [D loss: -21.544722] [G loss: -1121.779297] [5batch time: 0.895882]
[Epoch 0/1] [Batch 895/2000] [D loss: -13.370022] [G loss: -1176.402710] [5batch time: 0.902519]
[Epoch 0/1] [Batch 900/2000] [D los

[Epoch 0/1] [Batch 1280/2000] [D loss: -9.199570] [G loss: -153.148483] [5batch time: 0.904542]
[Epoch 0/1] [Batch 1285/2000] [D loss: -15.171153] [G loss: -195.885452] [5batch time: 0.902480]
[Epoch 0/1] [Batch 1290/2000] [D loss: -16.968481] [G loss: -243.492844] [5batch time: 0.880398]
[Epoch 0/1] [Batch 1295/2000] [D loss: -39.509613] [G loss: -271.964020] [5batch time: 0.892226]
[Epoch 0/1] [Batch 1300/2000] [D loss: -7.695563] [G loss: -358.232880] [5batch time: 0.900847]
[Epoch 0/1] [Batch 1305/2000] [D loss: -12.231634] [G loss: -406.617004] [5batch time: 0.929706]
[Epoch 0/1] [Batch 1310/2000] [D loss: -15.365938] [G loss: -459.221039] [5batch time: 0.878980]
[Epoch 0/1] [Batch 1315/2000] [D loss: -9.917288] [G loss: -517.554504] [5batch time: 0.904765]
[Epoch 0/1] [Batch 1320/2000] [D loss: -21.670959] [G loss: -557.828674] [5batch time: 0.883117]
[Epoch 0/1] [Batch 1325/2000] [D loss: -15.388021] [G loss: -614.793152] [5batch time: 0.906604]
[Epoch 0/1] [Batch 1330/2000] [D 

[Epoch 0/1] [Batch 1710/2000] [D loss: 19.121912] [G loss: 232.357529] [5batch time: 0.898061]
[Epoch 0/1] [Batch 1715/2000] [D loss: 6.124456] [G loss: 259.982849] [5batch time: 0.872641]
[Epoch 0/1] [Batch 1720/2000] [D loss: 1.529225] [G loss: 280.949707] [5batch time: 0.903325]
[Epoch 0/1] [Batch 1725/2000] [D loss: 0.978690] [G loss: 297.282379] [5batch time: 0.915385]
[Epoch 0/1] [Batch 1730/2000] [D loss: -0.912380] [G loss: 312.043549] [5batch time: 0.904616]
[Epoch 0/1] [Batch 1735/2000] [D loss: -0.413338] [G loss: 325.781799] [5batch time: 0.906282]
[Epoch 0/1] [Batch 1740/2000] [D loss: 12.197227] [G loss: 324.784760] [5batch time: 0.900841]
[Epoch 0/1] [Batch 1745/2000] [D loss: 2.199239] [G loss: 344.365631] [5batch time: 0.878794]
[Epoch 0/1] [Batch 1750/2000] [D loss: 9.074334] [G loss: 345.891785] [5batch time: 0.911001]
[Epoch 0/1] [Batch 1755/2000] [D loss: -2.444327] [G loss: 363.500916] [5batch time: 0.899207]
[Epoch 0/1] [Batch 1760/2000] [D loss: 11.595289] [G lo

## Try pre-resized images

In [None]:
from PIL import Image
import os, sys

path = "./zinc100k500px/"
# Only the first 10,000 entries are processed; os.listdir order is arbitrary.
dirs = os.listdir(path)[0:10000]
newpath = "./zinc100k500pxRe250px/"
final_size = 250


def resize(final_size, newpath):
    """Resize each regular file in ``dirs`` (under ``path``) to a square RGB PNG.

    The output is written to ``newpath`` as ``<basename>resized.png``. The
    module-level ``path`` and ``dirs`` are read, and ``newpath`` is created if
    it is missing.

    Args:
        final_size: side length in pixels of the square output.
        newpath: output directory prefix (string-concatenated with file names).
    """
    os.makedirs(newpath, exist_ok=True)
    for item in dirs:
        src = path + item
        if not os.path.isfile(src):
            continue
        basename = item.split(".")[0]
        with Image.open(src) as im:
            # Image.ANTIALIAS (an alias of LANCZOS) was removed in Pillow 10.
            im_resized = im.resize((final_size, final_size), Image.LANCZOS)
        # Paste onto an RGB canvas so every output is 3-channel RGB.
        new_im = Image.new("RGB", (final_size, final_size))
        new_im.paste(im_resized)
        new_im.save(newpath + basename + 'resized.png', 'PNG')


resize(final_size, newpath)

In [None]:
from PIL import Image
import os, sys

path = "./zinc100k500pxRe250px/"

# Only the first 10,000 entries are processed; os.listdir order is arbitrary.
dirs = os.listdir(path)[0:10000]
newpath = "./zinc100k500pxRe250pxJPG/"
os.makedirs(newpath, exist_ok=True)


def png2jpg(path, newpath, size=(250, 250)):
    """Re-encode each regular file named in ``dirs`` as an RGB JPEG (quality 95).

    Each image is pasted at the top-left of a black RGB canvas of ``size``.
    Larger images are cropped and smaller ones padded. The result is saved to
    ``newpath`` as ``<basename>resized.jpeg``. The module-level ``dirs`` list
    is read.

    Args:
        path: source directory prefix.
        newpath: output directory prefix (must already exist).
        size: (width, height) of the output canvas; defaults to the previous
            hard-coded 250x250.
    """
    for item in dirs:
        src = path + item
        if not os.path.isfile(src):
            continue
        basename = item.split(".")[0]
        new_im = Image.new("RGB", size)
        # Context manager closes the source file once its pixels are pasted.
        with Image.open(src) as im:
            new_im.paste(im)
        new_im.save(newpath + basename + 'resized.jpeg', 'JPEG', quality=95)


png2jpg(path, newpath)

# Train one epoch to measure training speed

In [11]:


def train(dataloader, batch_size, n_epochs):
    """Run WGAN-GP training for ``n_epochs`` and report throughput.

    Relies on these notebook globals: ``generator``, ``discriminator``,
    ``optimizer_G``, ``optimizer_D``, ``Tensor``, ``latent_dim``,
    ``lambda_gp``, ``n_critic`` and ``sample_interval``.

    The critic is updated every batch and the generator every ``n_critic``
    batches. Losses are sent to ``log_value`` every ``sample_interval // 10``
    batches, and image grids are saved to ``molpics500px/`` every
    ``sample_interval // 100`` batches.

    Args:
        dataloader: iterable of image-tensor batches (must support ``len``).
        batch_size: kept for interface compatibility. Timing now divides by
            the number of images actually processed.
        n_epochs: number of passes over ``dataloader``.
    """
    # `batches_done % sample_interval/10` parsed as
    # `(batches_done % sample_interval) / 10`; use explicit intervals instead.
    log_interval = max(1, sample_interval // 10)
    image_interval = max(1, sample_interval // 100)
    os.makedirs("molpics500px", exist_ok=True)

    batches_done = 0
    batchtimes = []  # seconds per picture, one entry per timing interval
    n_batches = 0  # batches since `start`
    n_images = 0  # images since `start`
    start = time.time()
    for epoch in range(n_epochs):
        for i, imgs in enumerate(dataloader):
            real_imgs = imgs.type(Tensor)
            n_batches += 1
            n_images += imgs.shape[0]

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            z = Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim)))
            fake_imgs = generator(z)
            real_validity = discriminator(real_imgs)
            # detach(): the critic update must not build a graph through the generator.
            fake_validity = discriminator(fake_imgs.detach())
            gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.detach(), fake_imgs.detach())
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
            d_loss.backward()
            optimizer_D.step()
            optimizer_G.zero_grad()

            # Train the generator every n_critic steps
            if i % n_critic != 0:
                continue

            # -----------------
            #  Train Generator
            # -----------------
            fake_imgs = generator(z)
            fake_validity = discriminator(fake_imgs)
            g_loss = -torch.mean(fake_validity)
            g_loss.backward()
            optimizer_G.step()

            # log_value needs plain floats, not tensors.
            if batches_done % log_interval == 0:
                log_value('g_loss', g_loss.item(), batches_done)
                log_value('d_loss', d_loss.item(), batches_done)

            if batches_done % image_interval == 0:
                save_image(fake_imgs.data[:25], "molpics500px/%d_d.png" % batches_done, nrow=5, normalize=True)

            elapsed = time.time() - start
            # Average per batch over the interval (previously hard-coded /5).
            batchtime = elapsed / n_batches

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [batch time: %f]"
                % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item(), batchtime)
            )
            batchtimes.append(elapsed / n_images)
            n_batches = 0
            n_images = 0
            start = time.time()

            batches_done += n_critic
    print("average time per picture = " + str(np.mean(batchtimes)))
    print("minutes per 100,000 pictures = " + str((np.mean(batchtimes) * 100000) / 60))

In [12]:
folder_path = "./zinc100k500pxRe250pxJPG/"
batch_size = 60
# ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps it to [-1, 1] (Tanh range).
transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
dataloader = DataLoader(
    ImageDataset(folder_path, transforms_=transforms_),
    batch_size=batch_size,
    shuffle=True,
    num_workers=n_cpu,  # the configured worker count instead of a literal 8
)
# One epoch, to measure training speed.
train(dataloader, batch_size, 1)

[Epoch 0/1] [Batch 0/167] [D loss: 15.302143] [G loss: -908.854858] [batch time: 0.247740]
[Epoch 0/1] [Batch 5/167] [D loss: 17.402771] [G loss: -921.105347] [batch time: 0.189697]
[Epoch 0/1] [Batch 10/167] [D loss: 21.201168] [G loss: -931.035400] [batch time: 0.189470]
[Epoch 0/1] [Batch 15/167] [D loss: 16.733734] [G loss: -929.514221] [batch time: 0.193110]
[Epoch 0/1] [Batch 20/167] [D loss: 15.244930] [G loss: -927.200256] [batch time: 0.190559]
[Epoch 0/1] [Batch 25/167] [D loss: 12.640652] [G loss: -918.174927] [batch time: 0.189335]
[Epoch 0/1] [Batch 30/167] [D loss: 14.900623] [G loss: -912.032715] [batch time: 0.188965]
[Epoch 0/1] [Batch 35/167] [D loss: 18.028090] [G loss: -904.062317] [batch time: 0.188347]
[Epoch 0/1] [Batch 40/167] [D loss: 18.058678] [G loss: -889.497314] [batch time: 0.187090]
[Epoch 0/1] [Batch 45/167] [D loss: 13.704739] [G loss: -868.281372] [batch time: 0.186964]
[Epoch 0/1] [Batch 50/167] [D loss: 10.296509] [G loss: -844.734985] [batch time: 

# Save models

In [None]:
# Checkpoint both models with their optimizers as dicts with keys
# 'epoch', 'state_dict' and 'optimizer'.
PATH = "/home/jgmeyer2/vangan/gans/models/"
modelid = "250px_jpegs"
# Create the directory from PATH itself rather than a duplicated literal.
os.makedirs(PATH, exist_ok=True)

# NOTE(review): `epoch` is whatever the last top-level training loop left in
# the notebook namespace. train() keeps its epoch counter local and does not
# update it; if no top-level loop ran this session, this raises NameError.
state_g = {
    'epoch': epoch,
    'state_dict': generator.state_dict(),
    'optimizer': optimizer_G.state_dict(),
}
torch.save(state_g, os.path.join(PATH, "g" + modelid + ".model"))

state_d = {
    'epoch': epoch,
    'state_dict': discriminator.state_dict(),
    'optimizer': optimizer_D.state_dict(),
}
torch.save(state_d, os.path.join(PATH, "d" + modelid + ".model"))
print("saved models @")
print(epoch)


#def save_model(net, optim, ckpt_fname):
#    state_dict = net.module.state_dict()
#    for key in state_dict.keys():
#        state_dict[key] = state_dict[key].cpu()
#        torch.save({
#            'epoch': epoch,                                                                                                                                                                                     
#            'state_dict': state_dict,                                                                                                                                                                                
#            'optimizer': optim},                                                                                                                                                                                     
#            ckpt_fname)