
Commit

load trained models
grohith327 committed Jun 12, 2020
1 parent c83d560 commit 776d8d1
Showing 9 changed files with 130 additions and 5 deletions.
12 changes: 12 additions & 0 deletions simplegan/gan/cgan.py
@@ -55,6 +55,8 @@ def __init__(
activation="relu",
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
gen_path=None,
disc_path=None,
):

self.image_size = None
@@ -278,6 +280,13 @@ def __load_model(self):

self.gen_model, self.disc_model = self.generator(), self.discriminator()

if self.config["gen_path"] is not None:
self.gen_model.load_weights(self.config["gen_path"])
print("Generator checkpoint restored")
if self.config["disc_path"] is not None:
self.disc_model.load_weights(self.config["disc_path"])
print("Discriminator checkpoint restored")

def fit(
self,
train_ds=None,
@@ -432,6 +441,9 @@ def generate_samples(self, n_samples=1, labels_list=None, save_dir=None):
len(labels_list) == n_samples
), "Number of samples does not match length of labels list"

if self.gen_model is None:
self.__load_model()

Z = np.random.uniform(-1, 1, (n_samples, self.noise_dim))
labels_list = np.array(labels_list)
generated_samples = self.gen_model([Z, labels_list]).numpy()
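
Taken together, the cgan.py changes make it possible to reload a trained conditional GAN for inference only: the constructor accepts optional gen_path and disc_path checkpoint locations, __load_model restores those weights after building the networks, and generate_samples builds the models on demand when fit was never called. A minimal usage sketch of that flow; the checkpoint paths and the use_mnist flag passed to load_data are placeholders assumed from the rest of the package, not taken from this diff:

from simplegan.gan import CGAN  # assuming CGAN is re-exported from simplegan.gan

# placeholder paths written by an earlier fit(..., save_model=...) run
gan = CGAN(gen_path="ckpt/generator_checkpoint",
           disc_path="ckpt/discriminator_checkpoint")

# load_data is still needed so that image_size and n_classes are known
# before __load_model builds the generator and discriminator
gan.load_data(use_mnist=True)

# gen_model is None here, so generate_samples calls __load_model, which
# restores the weights from gen_path / disc_path before sampling
samples = gan.generate_samples(n_samples=4, labels_list=[0, 1, 2, 3])
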
32 changes: 32 additions & 0 deletions simplegan/gan/cyclegan.py
@@ -58,6 +58,10 @@ def __init__(
disc_channels=[64, 128, 256, 512],
kernel_size=(4, 4),
kernel_initializer=tf.random_normal_initializer(0.0, 0.02),
gen_g_path=None,
gen_f_path=None,
disc_x_path=None,
disc_y_path=None,
):

Pix2Pix.__init__(
@@ -75,6 +79,8 @@ def __init__(
self.disc_model_x = None
self.disc_model_y = None

self.config = locals()

def load_data(
self,
data_dir=None,
@@ -280,6 +286,19 @@ def __load_model(self):
self.discriminator(),
)

if self.config["gen_g_path"] is not None:
self.gen_model_g.load_weights(self.config["gen_g_path"])
print("Generator-G checkpoint restored")
if self.config["gen_f_path"] is not None:
self.gen_model_f.load_weights(self.config["gen_f_path"])
print("Generator-F checkpoint restored")
if self.config["disc_x_path"] is not None:
self.disc_model_x.load_weights(self.config["disc_x_path"])
print("Discriminator-X checkpoint restored")
if self.config["disc_y_path"] is not None:
self.disc_model_y.load_weights(self.config["disc_y_path"])
print("Discriminator-Y checkpoint restored")

def _save_samples(self, model, image, count):

assert os.path.exists(self.save_img_dir), "sample directory does not exist"
@@ -552,8 +571,18 @@ def fit(
assert isinstance(save_model, str), "Not a valid directory"
if save_model[-1] != "/":
self.gen_model_g.save_weights(save_model + "/generator_g_checkpoint")
self.gen_model_f.save_weights(save_model + "/generator_f_checkpoint")
self.disc_model_x.save_weights(
save_model + "/discrimnator_x_checkpoint"
)
self.disc_model_y.save_weights(
save_model + "/discrimnator_y_checkpoint"
)
else:
self.gen_model_g.save_weights(save_model + "generator_g_checkpoint")
self.gen_model_f.save_weights(save_model + "generator_f_checkpoint")
self.disc_model_x.save_weights(save_model + "discriminator_x_checkpoint")
self.disc_model_y.save_weights(save_model + "discriminator_y_checkpoint")

def generate_samples(self, test_ds=None, save_dir=None):

@@ -569,6 +598,9 @@ def generate_samples(self, test_ds=None, save_dir=None):

assert test_ds is not None, "Enter input test dataset"

if self.gen_model_g is None:
self.__load_model()

generated_samples = []
for image in test_ds:
gen_image = self.gen_model_g(image, training=False).numpy()
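
The cyclegan.py diff pairs the new restore paths with an expanded save step: fit(..., save_model=...) now writes checkpoints for both generators and both discriminators, and a later run can hand those four files back to the constructor. A sketch of the restore side, with hypothetical checkpoint paths and a dummy test dataset standing in for a real test split (the 256x256 image size is an assumption, not taken from this diff):

import tensorflow as tf
from simplegan.gan import CycleGAN  # assuming CycleGAN is re-exported from simplegan.gan

# the four file names mirror what fit(..., save_model="ckpt/") writes above
gan = CycleGAN(
    gen_g_path="ckpt/generator_g_checkpoint",
    gen_f_path="ckpt/generator_f_checkpoint",
    disc_x_path="ckpt/discriminator_x_checkpoint",
    disc_y_path="ckpt/discriminator_y_checkpoint",
)

# dummy stand-in for a test split: two random RGB images, batch size 1
test_ds = tf.data.Dataset.from_tensor_slices(
    tf.random.uniform((2, 256, 256, 3))
).batch(1)

# gen_model_g is None, so generate_samples triggers __load_model, which
# builds all four networks and restores their weights before translating
translated = gan.generate_samples(test_ds)
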
14 changes: 13 additions & 1 deletion simplegan/gan/dcgan.py
@@ -57,6 +57,8 @@ def __init__(
activation="relu",
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
gen_path=None,
disc_path=None,
):

self.image_size = None
@@ -109,7 +111,7 @@ def load_data(

else:

- train_data = load_custom_data(data_dir, img_size)
+ train_data = load_custom_data(data_dir, img_shape)

self.image_size = train_data.shape[1:]

@@ -321,6 +323,13 @@ def __load_model(self):

self.gen_model, self.disc_model = self.generator(), self.discriminator()

if self.config["gen_path"] is not None:
self.gen_model.load_weights(self.config["gen_path"])
print("Generator checkpoint restored")
if self.config["disc_path"] is not None:
self.disc_model.load_weights(self.config["disc_path"])
print("Discriminator checkpoint restored")

def fit(
self,
train_ds=None,
@@ -466,6 +475,9 @@ def generate_samples(self, n_samples=1, save_dir=None):
returns ``None`` if save_dir is ``not None``, otherwise returns a numpy array with generated samples
"""

if self.gen_model is None:
self.__load_model()

Z = tf.random.normal([n_samples, self.noise_dim])
generated_samples = self.gen_model(Z).numpy()

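
Because gen_path and disc_path are independent, an inference-only run can restore just the generator, which is all generate_samples touches. A sketch under the same assumptions as above (placeholder checkpoint path, load_data flags as used elsewhere in the package):

from simplegan.gan import DCGAN  # assuming DCGAN is re-exported from simplegan.gan

# restore only the generator; disc_path stays None, so the freshly built
# discriminator keeps its random initialization
gan = DCGAN(gen_path="ckpt/generator_checkpoint")

# load_data sets image_size, which the generator/discriminator builders use
gan.load_data(use_mnist=True)

# no fit() call: gen_model is None, so generate_samples runs __load_model
# and restores the generator weights before sampling
images = gan.generate_samples(n_samples=16)
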
14 changes: 13 additions & 1 deletion simplegan/gan/infogan.py
@@ -5,7 +5,7 @@
from tensorflow.keras import layers
from ..datasets.load_mnist import load_mnist
from ..datasets.load_cifar10 import load_cifar10
- from ..datasets.load_custom_data import load_custom_data
+ from ..datasets.load_custom_data import load_custom_data_with_labels
from ..losses.minmax_loss import gan_discriminator_loss, gan_generator_loss
from ..losses.infogan_loss import auxillary_loss
import datetime
@@ -57,6 +57,8 @@ def __init__(
activation="leaky_relu",
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
gen_path=None,
disc_path=None,
):

self.image_size = None
@@ -304,6 +306,13 @@ def __load_model(self):

self.gen_model, self.disc_model = self.generator(), self.discriminator()

if self.config["gen_path"] is not None:
self.gen_model.load_weights(self.config["gen_path"])
print("Generator checkpoint restored")
if self.config["disc_path"] is not None:
self.disc_model.load_weights(self.config["disc_path"])
print("Discriminator checkpoint restored")

def fit(
self,
train_ds=None,
@@ -457,6 +466,9 @@ def generate_samples(self, n_samples=1, save_dir=None):
returns ``None`` if save_dir is ``not None``, otherwise returns a numpy array with generated samples
"""

if self.gen_model is None:
self.__load_model()

Z = np.random.randn(n_samples, self.noise_dim)
label_input = tf.keras.utils.to_categorical(
(np.random.randint(0, self.n_classes, n_samples)), self.n_classes
12 changes: 12 additions & 0 deletions simplegan/gan/pix2pix.py
@@ -56,6 +56,8 @@ def __init__(
disc_channels=[64, 128, 256, 512],
kernel_size=(4, 4),
kernel_initializer=tf.random_normal_initializer(0.0, 0.02),
gen_path=None,
disc_path=None,
):

self.gen_model = None
@@ -365,6 +367,13 @@ def __load_model(self):

self.gen_model, self.disc_model = self.generator(), self.discriminator()

if self.config["gen_path"] is not None:
self.gen_model.load_weights(self.config["gen_path"])
print("Generator checkpoint restored")
if self.config["disc_path"] is not None:
self.disc_model.load_weights(self.config["disc_path"])
print("Discriminator checkpoint restored")

def _save_samples(self, model, ex_input, ex_target, count):

assert os.path.exists(self.save_img_dir), "sample directory does not exist"
@@ -582,6 +591,9 @@ def generate_samples(self, test_ds=None, save_dir=None):

assert test_ds is not None, "Enter input test dataset"

if self.gen_model is None:
self.__load_model()

generated_samples = []
for image in test_ds:
gen_image = self.gen_model(image, training=False).numpy()
16 changes: 13 additions & 3 deletions simplegan/gan/sagan.py
@@ -115,13 +115,13 @@ class SAGAN:
noise_dim (int, optional): represents the dimension of the prior to sample values. Defaults to ``128``
"""

- def __init__(
-     self, noise_dim=128,
- ):
+ def __init__(self, noise_dim=128, gen_path=None, disc_path=None):

self.image_size = None
self.noise_dim = noise_dim
self.n_classes = None
self.gen_model = None
self.disc_model = None
self.config = locals()

def load_data(
@@ -211,6 +211,13 @@ def __load_model(self):
Discriminator(self.n_classes),
)

if self.config["gen_path"] is not None:
self.gen_model.load_weights(self.config["gen_path"])
print("Generator checkpoint restored")
if self.config["disc_path"] is not None:
self.disc_model.load_weights(self.config["disc_path"])
print("Discriminator checkpoint restored")

@tf.function
def train_step(self, images, labels):

@@ -381,6 +388,9 @@ def generate_samples(self, n_samples=1, labels_list=None, save_dir=None):
len(labels_list) == n_samples
), "Number of samples does not match length of labels list"

if self.gen_model is None:
self.__load_model()

Z = np.random.uniform(-1, 1, (n_samples, self.noise_dim))
labels_list = np.array(labels_list)
generated_samples = self.gen_model([Z, labels_list]).numpy()
12 changes: 12 additions & 0 deletions simplegan/gan/vanilla_gan.py
@@ -52,6 +52,8 @@ def __init__(
activation="relu",
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
gen_path=None,
disc_path=None,
):

self.image_size = None
@@ -241,6 +243,13 @@ def __load_model(self):

self.gen_model, self.disc_model = self.generator(), self.discriminator()

if self.config["gen_path"] is not None:
self.gen_model.load_weights(self.config["gen_path"])
print("Generator checkpoint restored")
if self.config["disc_path"] is not None:
self.disc_model.load_weights(self.config["disc_path"])
print("Discriminator checkpoint restored")

def fit(
self,
train_ds=None,
@@ -380,6 +389,9 @@ def generate_samples(self, n_samples=1, save_dir=None):
returns ``None`` if save_dir is ``not None``, otherwise returns a numpy array with generated samples
"""

if self.gen_model is None:
self.__load_model()

Z = np.random.uniform(-1, 1, (n_samples, self.noise_dim))
generated_samples = self.gen_model(Z)
generated_samples = tf.reshape(
12 changes: 12 additions & 0 deletions simplegan/gan/voxelgan.py
@@ -44,6 +44,8 @@ def __init__(
activation="relu",
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
gen_path=None,
disc_path=None,
):

self.noise_dim = noise_dim
@@ -256,6 +258,13 @@ def __load_model(self):

self.gen_model, self.disc_model = self.generator(), self.discriminator()

if self.config["gen_path"] is not None:
self.gen_model.load_weights(self.config["gen_path"])
print("Generator checkpoint restored")
if self.config["disc_path"] is not None:
self.disc_model.load_weights(self.config["disc_path"])
print("Discriminator checkpoint restored")

def fit(
self,
train_ds=None,
@@ -407,6 +416,9 @@ def generate_sample(self, n_samples=1, plot=False):
``None`` if ``plot`` is ``True`` else a numpy array of samples of shape ``(n_samples, side_length, side_length, side_length, 1)``
"""

if self.gen_model is None:
self.__load_model()

Z = np.random.uniform(0, 1, (n_samples, self.noise_dim))
generated_samples = self.gen_model(Z).numpy()

11 changes: 11 additions & 0 deletions simplegan/gan/wgan.py
@@ -58,6 +58,8 @@ def __init__(
activation="relu",
kernel_initializer="glorot_uniform",
kernel_regularizer=None,
gen_path=None,
disc_path=None,
):

DCGAN.__init__(
@@ -70,12 +72,21 @@ def __init__(
activation,
kernel_initializer,
kernel_regularizer,
gen_path,
disc_path,
)

def __load_model(self):

self.gen_model, self.disc_model = self.generator(), self.discriminator()

if self.config["gen_path"] is not None:
self.gen_model.load_weights(self.config["gen_path"])
print("Generator checkpoint restored")
if self.config["disc_path"] is not None:
self.disc_model.load_weights(self.config["disc_path"])
print("Discriminator checkpoint restored")

def fit(
self,
train_ds=None,
