In [1]:
#import sys
#!conda install --yes --prefix {sys.prefix} torchvision

In [2]:
#import sys
#!conda install --yes --prefix {sys.prefix} -c rdkit rdkit

In [1]:
from tensorboard_logger import configure, log_value
import tensorboard
import argparse
import os
import numpy as np
import math
import sys
import glob
#import utils import Logger


import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch

### Part to use my own images

In [2]:
import glob
import random
import os
import numpy as np

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

class ImageDataset(Dataset):
    def __init__(self, folder_path, transforms_=None):
        self.transform = transforms.Compose(transforms_)
        self.files = sorted(glob.glob('%s/*.*' % folder_path))

    def __getitem__(self, index):
        img = Image.open(self.files[index % len(self.files)])
        img = self.transform(img)
        return img

    def __len__(self):
        return len(self.files)

In [3]:
img_shape = (3, 100, 100)

cuda = True if torch.cuda.is_available() else False
print(cuda)

True


## Option Settings

In [56]:
n_epochs=5000 #number of epochs of training
batch_size=400 #size of the batches
lr=0.0002 #adam: learning rate
b1=0.5  #"adam: decay of first order momentum of gradient")
b2=0.999 #adam: decay of first order momentum of gradient")
n_cpu=8
latent_dim=100
img_size=200
channels=3
n_critic=5
clip_value=0.01
sample_interval=1000
img_shape = (channels, img_size, img_size)
crop_size = 400
print(img_shape)

(3, 200, 200)


### nn.modules

In [57]:

folder_path = "./zinc100k500px"
transforms_ = [ transforms.Resize(img_size),
                transforms.CenterCrop(img_size),
                transforms.ToTensor(),
                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
dataloader = DataLoader(ImageDataset(folder_path, transforms_=transforms_),
                        batch_size=batch_size, shuffle=True, num_workers=8)

In [58]:
#img_shape = (3, 100, 100)

#folder_path = "./molecules100"
#transforms_ = [ transforms.Resize(100),
#                transforms.CenterCrop(100),
#                transforms.ToTensor(),
#                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5)) ]
#dataloader = DataLoader(ImageDataset(folder_path, transforms_=transforms_),
#                        batch_size=100, shuffle=True, num_workers=8)

In [59]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(100, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        img = img.view(img.shape[0], *img_shape)
        return img

In [60]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, img):
        img_flat = img.view(img.shape[0], -1)
        validity = self.model(img_flat)
        return validity

In [61]:
# Loss weight for gradient penalty
lambda_gp = 10

# Initialize generator and discriminator
generator = Generator()
discriminator = Discriminator()


In [62]:
if cuda:
    generator.cuda()
    discriminator.cuda()

TO DO: Determine optimal learning rate

# Actual training now

In [63]:
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b2, b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculates the gradient penalty loss for WGAN GP"""
    # Random weight term for interpolation between real and fake samples
    alpha = Tensor(np.random.random((real_samples.size(0), 1, 1, 1)))
    # Get random interpolation between real and fake samples
    interpolates = (alpha * real_samples + ((1 - alpha) * fake_samples)).requires_grad_(True)
    d_interpolates = D(interpolates)
    fake = Variable(Tensor(real_samples.shape[0], 1).fill_(1.0), requires_grad=False)
    # Get gradient w.r.t. interpolates
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=fake,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    gradients = gradients.view(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return gradient_penalty



In [64]:
imgs=next(iter(dataloader))

In [65]:
# Sample noise as generator input

imgs
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
print(z)
print(latent_dim)
print(imgs.shape)
print(imgs)

# Generate a batch of images
fake_imgs = generator(z)


tensor([[ 0.6221,  0.9872,  0.2508,  ..., -0.4239,  0.2535,  0.6143],
        [ 0.4697,  1.0298,  0.1345,  ..., -0.1841, -1.0988,  1.6070],
        [ 1.1106, -0.5274, -0.1055,  ...,  0.2113,  0.5366, -0.5801],
        ...,
        [-1.0739,  0.6992, -0.1014,  ..., -0.3990,  1.0972, -0.4691],
        [-0.7951,  0.4914, -0.8391,  ..., -0.4018, -0.5528,  0.2405],
        [ 0.1755, -0.2522,  0.2104,  ...,  1.2836,  0.4065, -0.2441]],
       device='cuda:0')
100
torch.Size([400, 3, 200, 200])
tensor([[[[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          ...,
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000]],

         [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.00

In [None]:
# old logger
#from utils import Logger
#logger = Logger(model_name='wGANGP', data_name='mol224')
#num_batches = len(dataloader)
n_epochs=500
batches_done = 0
### new logger
#configure("runs/test1",flush_secs=5)

#batch_size=100
#print(latent_dim)
for epoch in range(n_epochs):
    for i, imgs in enumerate(dataloader):
        # Configure input
        real_imgs = Variable(imgs.type(Tensor))
        # ---------------------
        #  Train Discriminator
        # ---------------------
        optimizer_D.zero_grad()
        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], latent_dim))))
        # Generate a batch of images
        fake_imgs = generator(z)
        # Real images
        real_validity = discriminator(real_imgs)
        # Fake images
        fake_validity = discriminator(fake_imgs)
        # Gradient penalty
        gradient_penalty = compute_gradient_penalty(discriminator, real_imgs.data, fake_imgs.data)
        # Adversarial loss
        d_loss = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gradient_penalty
        d_loss.backward()
        optimizer_D.step()
        optimizer_G.zero_grad()
        # Train the generator every n_critic steps
        if i % n_critic == 0:
            # -----------------
            #  Train Generator
            # -----------------
            # Generate a batch of images
            fake_imgs = generator(z)
            # Loss measures generator's ability to fool the discriminator
            # Train on fake images
            fake_validity = discriminator(fake_imgs)
            g_loss = -torch.mean(fake_validity)
            g_loss.backward()
            optimizer_G.step()
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

            #logger.log(d_loss, g_loss, epoch, batches_done, num_batches)
            if batches_done % sample_interval/10 ==0:
                log_value('g_loss', g_loss, batches_done)
                log_value('d_loss', d_loss, batches_done)
            
            if batches_done % sample_interval == 0:
                save_image(fake_imgs.data[:25], "molpics500px/%d_b.png" % batches_done, nrow=5, normalize=True)

            batches_done += n_critic

[Epoch 0/500] [Batch 0/250] [D loss: 8.083148] [G loss: 0.003303]
[Epoch 0/500] [Batch 5/250] [D loss: -351.765991] [G loss: -0.149655]
[Epoch 0/500] [Batch 10/250] [D loss: -930.247498] [G loss: -1.183127]
[Epoch 0/500] [Batch 15/250] [D loss: -1678.090942] [G loss: -4.038185]
[Epoch 0/500] [Batch 20/250] [D loss: -2472.244141] [G loss: -10.242505]
[Epoch 0/500] [Batch 25/250] [D loss: -3096.625732] [G loss: -22.473108]
[Epoch 0/500] [Batch 30/250] [D loss: -3221.287109] [G loss: -42.861942]
[Epoch 0/500] [Batch 35/250] [D loss: -2691.095215] [G loss: -70.594467]
[Epoch 0/500] [Batch 40/250] [D loss: -2159.705078] [G loss: -98.534950]
[Epoch 0/500] [Batch 45/250] [D loss: -2086.342285] [G loss: -124.022202]
[Epoch 0/500] [Batch 50/250] [D loss: -2327.064453] [G loss: -142.505280]
[Epoch 0/500] [Batch 55/250] [D loss: -2662.669922] [G loss: -156.052536]
[Epoch 0/500] [Batch 60/250] [D loss: -2936.402344] [G loss: -167.244370]
[Epoch 0/500] [Batch 65/250] [D loss: -3089.248047] [G loss:

[Epoch 2/500] [Batch 65/250] [D loss: 20.854404] [G loss: 383.124207]
[Epoch 2/500] [Batch 70/250] [D loss: 24.007080] [G loss: 461.651398]
[Epoch 2/500] [Batch 75/250] [D loss: 31.401312] [G loss: 530.231934]
[Epoch 2/500] [Batch 80/250] [D loss: 36.679817] [G loss: 594.485779]
[Epoch 2/500] [Batch 85/250] [D loss: 40.748421] [G loss: 652.752930]
[Epoch 2/500] [Batch 90/250] [D loss: 53.014137] [G loss: 694.534973]
[Epoch 2/500] [Batch 95/250] [D loss: 52.534916] [G loss: 739.802002]
[Epoch 2/500] [Batch 100/250] [D loss: 59.745953] [G loss: 767.630005]
[Epoch 2/500] [Batch 105/250] [D loss: 63.451279] [G loss: 788.779053]
[Epoch 2/500] [Batch 110/250] [D loss: 67.600464] [G loss: 799.071472]
[Epoch 2/500] [Batch 115/250] [D loss: 74.160706] [G loss: 796.354126]
[Epoch 2/500] [Batch 120/250] [D loss: 74.420433] [G loss: 789.849731]
[Epoch 2/500] [Batch 125/250] [D loss: 73.220802] [G loss: 775.397095]
[Epoch 2/500] [Batch 130/250] [D loss: 68.945419] [G loss: 755.137390]
[Epoch 2/500]

[Epoch 4/500] [Batch 145/250] [D loss: 19.082449] [G loss: 419.848694]
[Epoch 4/500] [Batch 150/250] [D loss: 18.133350] [G loss: 364.137024]
[Epoch 4/500] [Batch 155/250] [D loss: 14.364824] [G loss: 309.080353]
[Epoch 4/500] [Batch 160/250] [D loss: 12.306402] [G loss: 249.864532]
[Epoch 4/500] [Batch 165/250] [D loss: 10.522099] [G loss: 188.781174]
[Epoch 4/500] [Batch 170/250] [D loss: 10.077324] [G loss: 125.855629]
[Epoch 4/500] [Batch 175/250] [D loss: 8.115013] [G loss: 62.740025]
[Epoch 4/500] [Batch 180/250] [D loss: 6.805653] [G loss: -2.477991]
[Epoch 4/500] [Batch 185/250] [D loss: 2.669103] [G loss: -70.089981]
[Epoch 4/500] [Batch 190/250] [D loss: -2.311835] [G loss: -143.936127]
[Epoch 4/500] [Batch 195/250] [D loss: -8.295969] [G loss: -220.384109]
[Epoch 4/500] [Batch 200/250] [D loss: -12.813758] [G loss: -300.992462]
[Epoch 4/500] [Batch 205/250] [D loss: -18.558140] [G loss: -382.746918]
[Epoch 4/500] [Batch 210/250] [D loss: -20.028784] [G loss: -470.269897]
[Ep

[Epoch 6/500] [Batch 220/250] [D loss: 21.910301] [G loss: -929.359497]
[Epoch 6/500] [Batch 225/250] [D loss: 15.576054] [G loss: -871.729065]
[Epoch 6/500] [Batch 230/250] [D loss: 12.649599] [G loss: -812.616943]
[Epoch 6/500] [Batch 235/250] [D loss: 8.584141] [G loss: -748.364990]
[Epoch 6/500] [Batch 240/250] [D loss: 2.105625] [G loss: -677.723083]
[Epoch 6/500] [Batch 245/250] [D loss: 0.819658] [G loss: -608.660156]
[Epoch 7/500] [Batch 0/250] [D loss: -1.066672] [G loss: -535.843750]
[Epoch 7/500] [Batch 5/250] [D loss: -3.028653] [G loss: -459.507751]
[Epoch 7/500] [Batch 10/250] [D loss: -5.558800] [G loss: -380.282013]
[Epoch 7/500] [Batch 15/250] [D loss: -5.809355] [G loss: -301.790558]
[Epoch 7/500] [Batch 20/250] [D loss: -5.896058] [G loss: -222.479584]
[Epoch 7/500] [Batch 25/250] [D loss: -5.654815] [G loss: -143.147659]
[Epoch 7/500] [Batch 30/250] [D loss: -4.727274] [G loss: -64.659889]
[Epoch 7/500] [Batch 35/250] [D loss: -3.946465] [G loss: 13.474526]
[Epoch 7

[Epoch 9/500] [Batch 50/250] [D loss: 38.139580] [G loss: 909.389954]
[Epoch 9/500] [Batch 55/250] [D loss: 40.542965] [G loss: 919.643982]
[Epoch 9/500] [Batch 60/250] [D loss: 39.539295] [G loss: 927.985779]
[Epoch 9/500] [Batch 65/250] [D loss: 40.673172] [G loss: 930.260437]
[Epoch 9/500] [Batch 70/250] [D loss: 40.828182] [G loss: 928.670654]
[Epoch 9/500] [Batch 75/250] [D loss: 39.641285] [G loss: 923.212097]
[Epoch 9/500] [Batch 80/250] [D loss: 41.834957] [G loss: 909.520386]
[Epoch 9/500] [Batch 85/250] [D loss: 39.992886] [G loss: 895.057373]
[Epoch 9/500] [Batch 90/250] [D loss: 35.978111] [G loss: 878.481079]
[Epoch 9/500] [Batch 95/250] [D loss: 34.466225] [G loss: 855.026001]
[Epoch 9/500] [Batch 100/250] [D loss: 33.395618] [G loss: 827.003418]
[Epoch 9/500] [Batch 105/250] [D loss: 30.880005] [G loss: 797.131226]
[Epoch 9/500] [Batch 110/250] [D loss: 30.221182] [G loss: 762.091309]
[Epoch 9/500] [Batch 115/250] [D loss: 24.913099] [G loss: 728.685547]
[Epoch 9/500] [B

[Epoch 11/500] [Batch 125/250] [D loss: 16.300587] [G loss: 618.929260]
[Epoch 11/500] [Batch 130/250] [D loss: 15.530472] [G loss: 593.288513]
[Epoch 11/500] [Batch 135/250] [D loss: 14.072131] [G loss: 566.715637]
[Epoch 11/500] [Batch 140/250] [D loss: 12.888462] [G loss: 538.655823]
[Epoch 11/500] [Batch 145/250] [D loss: 12.648497] [G loss: 508.426636]
[Epoch 11/500] [Batch 150/250] [D loss: 11.198539] [G loss: 479.230255]
[Epoch 11/500] [Batch 155/250] [D loss: 9.852489] [G loss: 449.630981]
[Epoch 11/500] [Batch 160/250] [D loss: 10.128193] [G loss: 418.182770]
[Epoch 11/500] [Batch 165/250] [D loss: 9.752403] [G loss: 387.595367]
[Epoch 11/500] [Batch 170/250] [D loss: 9.362565] [G loss: 357.155884]
[Epoch 11/500] [Batch 175/250] [D loss: 9.047735] [G loss: 326.827484]
[Epoch 11/500] [Batch 180/250] [D loss: 8.467027] [G loss: 297.619049]
[Epoch 11/500] [Batch 185/250] [D loss: 7.445475] [G loss: 269.324768]
[Epoch 11/500] [Batch 190/250] [D loss: 8.422608] [G loss: 239.526016]

[Epoch 13/500] [Batch 200/250] [D loss: 7.757856] [G loss: 551.642212]
[Epoch 13/500] [Batch 205/250] [D loss: 7.277392] [G loss: 527.211548]
[Epoch 13/500] [Batch 210/250] [D loss: 5.930305] [G loss: 502.242218]
[Epoch 13/500] [Batch 215/250] [D loss: 6.473877] [G loss: 474.390808]
[Epoch 13/500] [Batch 220/250] [D loss: 6.131089] [G loss: 446.053070]
[Epoch 13/500] [Batch 225/250] [D loss: 4.132030] [G loss: 418.388184]
[Epoch 13/500] [Batch 230/250] [D loss: 4.006530] [G loss: 388.234100]
[Epoch 13/500] [Batch 235/250] [D loss: 3.130096] [G loss: 358.050079]
[Epoch 13/500] [Batch 240/250] [D loss: 3.274602] [G loss: 326.382416]
[Epoch 13/500] [Batch 245/250] [D loss: 3.286685] [G loss: 294.445251]
[Epoch 14/500] [Batch 0/250] [D loss: 2.172831] [G loss: 263.526978]
[Epoch 14/500] [Batch 5/250] [D loss: 2.164679] [G loss: 231.603134]
[Epoch 14/500] [Batch 10/250] [D loss: 2.838851] [G loss: 199.146423]
[Epoch 14/500] [Batch 15/250] [D loss: 3.041368] [G loss: 167.657501]
[Epoch 14/50

[Epoch 16/500] [Batch 25/250] [D loss: 8.425225] [G loss: 108.629829]
[Epoch 16/500] [Batch 30/250] [D loss: 8.316772] [G loss: 72.145523]
[Epoch 16/500] [Batch 35/250] [D loss: 8.545118] [G loss: 34.557957]
[Epoch 16/500] [Batch 40/250] [D loss: 8.027904] [G loss: -3.426916]
[Epoch 16/500] [Batch 45/250] [D loss: 7.984835] [G loss: -42.885723]
[Epoch 16/500] [Batch 50/250] [D loss: 7.437178] [G loss: -83.299530]
[Epoch 16/500] [Batch 55/250] [D loss: 6.485192] [G loss: -124.869156]
[Epoch 16/500] [Batch 60/250] [D loss: 5.559310] [G loss: -168.069977]
[Epoch 16/500] [Batch 65/250] [D loss: 4.402359] [G loss: -212.516464]
[Epoch 16/500] [Batch 70/250] [D loss: 3.473150] [G loss: -258.524475]
[Epoch 16/500] [Batch 75/250] [D loss: 2.052891] [G loss: -305.108948]
[Epoch 16/500] [Batch 80/250] [D loss: 1.468907] [G loss: -353.466522]
[Epoch 16/500] [Batch 85/250] [D loss: -0.607627] [G loss: -400.764221]
[Epoch 16/500] [Batch 90/250] [D loss: -0.776999] [G loss: -450.244141]
[Epoch 16/500

[Epoch 18/500] [Batch 100/250] [D loss: 8.402599] [G loss: -654.252747]
[Epoch 18/500] [Batch 105/250] [D loss: 12.070439] [G loss: -689.552002]
[Epoch 18/500] [Batch 110/250] [D loss: 12.855333] [G loss: -720.071106]
[Epoch 18/500] [Batch 115/250] [D loss: 14.565430] [G loss: -746.852478]
[Epoch 18/500] [Batch 120/250] [D loss: 17.935595] [G loss: -772.189758]
[Epoch 18/500] [Batch 125/250] [D loss: 18.931620] [G loss: -791.007568]
[Epoch 18/500] [Batch 130/250] [D loss: 21.642246] [G loss: -807.355530]
[Epoch 18/500] [Batch 135/250] [D loss: 21.715052] [G loss: -817.146057]
[Epoch 18/500] [Batch 140/250] [D loss: 23.222876] [G loss: -824.848877]
[Epoch 18/500] [Batch 145/250] [D loss: 24.063042] [G loss: -827.468872]
[Epoch 18/500] [Batch 150/250] [D loss: 24.382217] [G loss: -826.062378]
[Epoch 18/500] [Batch 155/250] [D loss: 22.885365] [G loss: -818.577637]
[Epoch 18/500] [Batch 160/250] [D loss: 23.501842] [G loss: -809.846619]
[Epoch 18/500] [Batch 165/250] [D loss: 22.982716] [

[Epoch 20/500] [Batch 175/250] [D loss: 11.265419] [G loss: -637.566528]
[Epoch 20/500] [Batch 180/250] [D loss: 11.479801] [G loss: -620.890198]
[Epoch 20/500] [Batch 185/250] [D loss: 10.498545] [G loss: -600.880615]
[Epoch 20/500] [Batch 190/250] [D loss: 8.319132] [G loss: -577.946411]
[Epoch 20/500] [Batch 195/250] [D loss: 7.801656] [G loss: -555.111877]
[Epoch 20/500] [Batch 200/250] [D loss: 6.669035] [G loss: -529.820923]
[Epoch 20/500] [Batch 205/250] [D loss: 5.184715] [G loss: -503.159241]
[Epoch 20/500] [Batch 210/250] [D loss: 3.982026] [G loss: -475.661865]
[Epoch 20/500] [Batch 215/250] [D loss: 3.055005] [G loss: -447.409821]
[Epoch 20/500] [Batch 220/250] [D loss: 1.477794] [G loss: -417.716675]
[Epoch 20/500] [Batch 225/250] [D loss: 1.302143] [G loss: -388.961914]
[Epoch 20/500] [Batch 230/250] [D loss: 0.357796] [G loss: -358.832642]
[Epoch 20/500] [Batch 235/250] [D loss: -0.284384] [G loss: -328.899292]
[Epoch 20/500] [Batch 240/250] [D loss: -0.600223] [G loss: 

[Epoch 23/500] [Batch 0/250] [D loss: 18.696548] [G loss: 407.163544]
[Epoch 23/500] [Batch 5/250] [D loss: 20.365963] [G loss: 437.375641]
[Epoch 23/500] [Batch 10/250] [D loss: 22.217279] [G loss: 463.437225]
[Epoch 23/500] [Batch 15/250] [D loss: 22.630314] [G loss: 486.535614]
[Epoch 23/500] [Batch 20/250] [D loss: 23.449387] [G loss: 504.901672]
[Epoch 23/500] [Batch 25/250] [D loss: 24.403479] [G loss: 518.580811]
[Epoch 23/500] [Batch 30/250] [D loss: 25.785126] [G loss: 527.975220]
[Epoch 23/500] [Batch 35/250] [D loss: 26.231459] [G loss: 532.945923]
[Epoch 23/500] [Batch 40/250] [D loss: 25.404852] [G loss: 534.744019]
[Epoch 23/500] [Batch 45/250] [D loss: 25.101208] [G loss: 531.361877]
[Epoch 23/500] [Batch 50/250] [D loss: 24.254957] [G loss: 524.408020]
[Epoch 23/500] [Batch 55/250] [D loss: 24.306072] [G loss: 512.293762]
[Epoch 23/500] [Batch 60/250] [D loss: 22.241339] [G loss: 498.259338]
[Epoch 23/500] [Batch 65/250] [D loss: 21.594561] [G loss: 479.189636]
[Epoch 2

[Epoch 25/500] [Batch 75/250] [D loss: 11.643518] [G loss: -51.917294]
[Epoch 25/500] [Batch 80/250] [D loss: 10.512949] [G loss: -30.802246]
[Epoch 25/500] [Batch 85/250] [D loss: 8.284092] [G loss: -10.403492]
[Epoch 25/500] [Batch 90/250] [D loss: 8.175785] [G loss: 9.316833]
[Epoch 25/500] [Batch 95/250] [D loss: 7.739910] [G loss: 28.345058]
[Epoch 25/500] [Batch 100/250] [D loss: 7.325931] [G loss: 46.407200]
[Epoch 25/500] [Batch 105/250] [D loss: 6.793611] [G loss: 63.648388]
[Epoch 25/500] [Batch 110/250] [D loss: 6.378135] [G loss: 79.563286]
[Epoch 25/500] [Batch 115/250] [D loss: 5.953534] [G loss: 94.229858]
[Epoch 25/500] [Batch 120/250] [D loss: 5.420074] [G loss: 107.510834]
[Epoch 25/500] [Batch 125/250] [D loss: 5.112176] [G loss: 119.528992]
[Epoch 25/500] [Batch 130/250] [D loss: 4.454015] [G loss: 130.299438]
[Epoch 25/500] [Batch 135/250] [D loss: 4.198717] [G loss: 139.562256]
[Epoch 25/500] [Batch 140/250] [D loss: 3.838957] [G loss: 147.487427]
[Epoch 25/500] [

In [22]:

#tensorboard
!python -m tensorboard.main --logdir runs
#tensorboard --help
#tensorboard --logdir=runs

#logger.d_loss.data.cpu().numpy()
#logger.display_status(epoch, n_epochs, batches_done, num_batches, d_loss, g_loss, d_pred_real, d_pred_fake)

TensorBoard 1.10.0 at http://minilearn:6006 (Press CTRL+C to quit)
[33mW0913 15:35:36.738302 Thread-1 application.py:276] path /[[_dataImageSrc]] not found, sending 404
[0mW0913 15:35:36.738301 139841509275392 application.py:276] path /[[_dataImageSrc]] not found, sending 404
[33mW0913 15:35:36.786560 Thread-1 application.py:276] path /[[_imageURL]] not found, sending 404
[0mW0913 15:35:36.786560 139841509275392 application.py:276] path /[[_imageURL]] not found, sending 404
^C


### Save the models

In [None]:
os.makedirs("/home/jgmeyer2/vangan/gans/models",exist_ok=True)
PATH = "/home/jgmeyer2/vangan/gans/models/g5k.model"
modelid="5k"


state_g = {
    'epoch': epoch,
    'state_dict': generator.state_dict(),
    'optimizer': optimizer_G.state_dict()
    }
torch.save(state_g, PATH+"g"+modelid+".model")

state_d = {
    'epoch': epoch,
    'state_dict': discriminator.state_dict(),
    'optimizer': optimizer_D.state_dict()
    }
torch.save(state_d, PATH+"d"+modelid+".model")
print("saved models @")
print(epoch)


#def save_model(net, optim, ckpt_fname):
#    state_dict = net.module.state_dict()
#    for key in state_dict.keys():
#        state_dict[key] = state_dict[key].cpu()
#        torch.save({
#            'epoch': epoch,                                                                                                                                                                                     
#            'state_dict': state_dict,                                                                                                                                                                                
#            'optimizer': optim},                                                                                                                                                                                     
#            ckpt_fname)

# Load models

In [None]:



model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

state = torch.load(filepath)

