In [25]:
"""Generative Adversarial Networks."""

from deepchem.models import TensorGraph
from deepchem.models.tensorgraph import layers
from collections import Sequence
import numpy as np
import tensorflow as tf
import time
import deepchem as dc
from deepchem.data.datasets import NumpyDataset # import NumpyDataset

import numpy as np
import tensorflow as tf
import collections

from deepchem.metrics import to_one_hot

from deepchem.models.tensorgraph.tensor_graph import TensorGraph, TFWrapper
from deepchem.models.tensorgraph.layers import Feature, Label, Weights, \
    WeightedError, Dense, Dropout, WeightDecay, Reshape, SoftMax, SoftMaxCrossEntropy, \
    L2Loss, ReduceSum, Concat, Stack
import rdkit

In [26]:
#!/usr/bin/env python2
# -*- coding: utf-8 -*-
from __future__ import division
from __future__ import unicode_literals

from deepchem.feat import Featurizer


class dRawFeaturizer(Featurizer):

  def __init__(self, smiles=False):
    self.smiles = smiles

  def _featurize(self, mol):
    from rdkit import Chem
    if self.smiles:
      return smiles
    else:
      return smiles

In [27]:
"""
hiv dataset loader.
"""
from __future__ import division
from __future__ import unicode_literals

import os
import logging
import deepchem

logger = logging.getLogger(__name__)


def dload_hiv(featurizer='ECFP', split='index', reload=True):
  """Load hiv datasets. Does not do train/test split"""
  # Featurize hiv dataset
  logger.info("About to featurize hiv dataset.")
  data_dir = deepchem.utils.get_data_dir()
  if reload:
    save_dir = os.path.join(data_dir, "hiv/" + featurizer + "/" + str(split))

  dataset_file = os.path.join(data_dir, "HIV.csv")
  if not os.path.exists(dataset_file):
    deepchem.utils.download_url(
        'http://deepchem.io.s3-website-us-west-1.amazonaws.com/datasets/HIV.csv'
    )

  hiv_tasks = ["HIV_active"]

  if reload:
    loaded, all_dataset, transformers = deepchem.utils.save.load_dataset_from_disk(
        save_dir)
    if loaded:
      return hiv_tasks, all_dataset, transformers

  if featurizer == 'ECFPsmall':
    featurizer = deepchem.feat.CircularFingerprint(size=1024)
  elif featurizer == 'ECFPmedium':
    featurizer = deepchem.feat.CircularFingerprint(size=4096)
  elif featurizer == 'ECFPlarge':
    featurizer = deepchem.feat.CircularFingerprint(size=16384)
  elif featurizer == 'ECFP4':
    featurizer = deepchem.feat.CircularFingerprint(radius=4, size = 1024)
  elif featurizer == 'ECFP6':
    featurizer = deepchem.feat.CircularFingerprint(radius=6, size = 1024)
  elif featurizer == 'Raw':
    featurizer = dRawFeaturizer()

  loader = deepchem.data.CSVLoader(
      tasks=hiv_tasks, smiles_field="smiles", featurizer=featurizer)
  dataset = loader.featurize(dataset_file, shard_size=8192)
  # Initialize transformers
  transformers = [
      deepchem.trans.BalancingTransformer(transform_w=True, dataset=dataset)
  ]

  logger.info("About to transform data")
  for transformer in transformers:
    dataset = transformer.transform(dataset)

  if split == None:
    return hiv_tasks, (dataset, None, None), transformers

  splitters = {
      'index': deepchem.splits.IndexSplitter(),
      'random': deepchem.splits.RandomSplitter(),
      'scaffold': deepchem.splits.ScaffoldSplitter(),
      'butina': deepchem.splits.ButinaSplitter()
  }
  splitter = splitters[split]
  train, valid, test = splitter.train_valid_test_split(dataset)

  if reload:
    deepchem.utils.save.save_dataset_to_disk(save_dir, train, valid, test,
                                             transformers)
  return hiv_tasks, (train, valid, test), transformers

In [28]:
#MODELS TO USE
#LOGREG, RANDOM FOREST, INFLUENCE RELEVANCE VECTOR, MULTITASK NETWORK 

#FEATURIZATIONS
#RAW, ECFP, 

In [29]:
n_features = 1024
eshiv_tasks, eshiv_datasets, eshiv_transformers = dload_hiv(featurizer='ECFPsmall', split='random', reload=True)
eshiv_train_dataset, eshiv_valid_dataset, eshiv_test_dataset = eshiv_datasets

metric = dc.metrics.Metric(dc.metrics.roc_auc_score, np.mean)

Loading raw samples now.
shard_size: 8192
About to start loading CSV from /var/folders/ph/d61mj9js3hvf9z7n92kk8y6c0000gn/T/HIV.csv
Loading shard 1 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 0 took 14.438 s
Loading shard 2 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 1 took 14.378 s
Loading shard 3 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 2 took 15.088 s
Loading shard 

In [30]:
emhiv_tasks, emhiv_datasets, emhiv_transformers = dload_hiv(featurizer='ECFPmedium', split='random', reload=True)
emhiv_train_dataset, emhiv_valid_dataset, emhiv_test_dataset = emhiv_datasets

Loading raw samples now.
shard_size: 8192
About to start loading CSV from /var/folders/ph/d61mj9js3hvf9z7n92kk8y6c0000gn/T/HIV.csv
Loading shard 1 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 0 took 35.405 s
Loading shard 2 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 1 took 35.204 s
Loading shard 3 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 2 took 38.385 s
Loading shard 

In [31]:
n_features = 1024
elhiv_tasks, elhiv_datasets, elhiv_transformers = dload_hiv(featurizer='ECFPlarge', split='random', reload=True)
elhiv_train_dataset, elhiv_valid_dataset, elhiv_test_dataset = elhiv_datasets

metric = dc.metrics.Metric(dc.metrics.roc_auc_score, np.mean)

Loading raw samples now.
shard_size: 8192
About to start loading CSV from /var/folders/ph/d61mj9js3hvf9z7n92kk8y6c0000gn/T/HIV.csv
Loading shard 1 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 0 took 119.645 s
Loading shard 2 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 1 took 113.834 s
Loading shard 3 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 2 took 110.962 s
Loading sha

In [32]:
n_features = 1024
e4hiv_tasks, e4hiv_datasets, e4hiv_transformers = dload_hiv(featurizer='ECFP4', split='random', reload=True)
e4hiv_train_dataset, e4hiv_valid_dataset, e4hiv_test_dataset = e4hiv_datasets

metric = dc.metrics.Metric(dc.metrics.roc_auc_score, np.mean)

Loading raw samples now.
shard_size: 8192
About to start loading CSV from /var/folders/ph/d61mj9js3hvf9z7n92kk8y6c0000gn/T/HIV.csv
Loading shard 1 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 0 took 13.376 s
Loading shard 2 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 1 took 13.528 s
Loading shard 3 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 2 took 13.788 s
Loading shard 

In [33]:
n_features = 1024
e6hiv_tasks, e6hiv_datasets, e6hiv_transformers = dload_hiv(featurizer='ECFP6', split='random', reload=True)
e6hiv_train_dataset, e6hiv_valid_dataset, e6hiv_test_dataset = e6hiv_datasets

metric = dc.metrics.Metric(dc.metrics.roc_auc_score, np.mean)

Loading raw samples now.
shard_size: 8192
About to start loading CSV from /var/folders/ph/d61mj9js3hvf9z7n92kk8y6c0000gn/T/HIV.csv
Loading shard 1 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 0 took 14.214 s
Loading shard 2 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 1 took 14.752 s
Loading shard 3 of size 8192.
Featurizing sample 0
Featurizing sample 1000
Featurizing sample 2000
Featurizing sample 3000
Featurizing sample 4000
Featurizing sample 5000
Featurizing sample 6000
Featurizing sample 7000
Featurizing sample 8000
TIMING: featurizing shard 2 took 14.732 s
Loading shard 

In [38]:
n_features = 1024
rahiv_tasks, rahiv_datasets, rahiv_transformers = dload_hiv(featurizer='Raw', split='random', reload=True)
rahiv_train_dataset, rahiv_valid_dataset, rahiv_test_dataset = rahiv_datasets

metric = dc.metrics.Metric(dc.metrics.roc_auc_score, np.mean)

Loading dataset from disk.
Loading dataset from disk.
Loading dataset from disk.


In [None]:
#TRAIN TWO GANS, ONE ON RAW, ONE ON ECFP. MAYBE EVEN TRY MULTIPLE ECFPS...

In [54]:
class ESMALLHIVGAN(dc.models.WGAN):

  def get_noise_input_shape(self):
    return (None, 1024)

  def get_data_input_shapes(self):
    return [(None, 1024)]

  def create_generator(self, noise_input, conditional_inputs):
    dense1 = layers.Dense(1024, in_layers=noise_input, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense2 = layers.Dense(1024, in_layers=dense1, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense3 = layers.Dense(1024, in_layers=dense2, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense4 = layers.Dense(1024, in_layers=dense3, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense5 = layers.Dense(1024, in_layers=dense4, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    reshaped = layers.Reshape((None, 1024), in_layers=dense5)
    return [reshaped]

  def create_discriminator(self, data_inputs, conditional_inputs):
    dense1 = layers.Dense(1024, in_layers=data_inputs, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense2 = layers.Dense(1024, in_layers=dense1, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense3 = layers.Dense(1024, in_layers=dense2, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense4 = layers.Dense(1024, in_layers=dense3, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense5 = layers.Dense(1024, in_layers=dense4, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense6 = layers.Dense(1024, in_layers=dense5, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense7 = layers.Dense(1, in_layers=dense6, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    return dense7
#DISCRIMINATOR NEEDS TO BE BETTER!
esmallgan = ESMALLHIVGAN(learning_rate=0.0005)

In [57]:
class EMEDIUMHIVGAN(dc.models.WGAN):

  def get_noise_input_shape(self):
    return (None, 4096)

  def get_data_input_shapes(self):
    return [(None, 4096)]

  def create_generator(self, noise_input, conditional_inputs):
    dense1 = layers.Dense(4096, in_layers=noise_input, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense2 = layers.Dense(4096, in_layers=dense1, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense3 = layers.Dense(4096, in_layers=dense2, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense4 = layers.Dense(4096, in_layers=dense3, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense5 = layers.Dense(4096, in_layers=dense4, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    reshaped = layers.Reshape((None, 4096), in_layers=dense5)
    return [reshaped]

  def create_discriminator(self, data_inputs, conditional_inputs):
    dense1 = layers.Dense(4096, in_layers=data_inputs, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense2 = layers.Dense(4096, in_layers=dense1, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense3 = layers.Dense(4096, in_layers=dense2, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense4 = layers.Dense(4096, in_layers=dense3, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense5 = layers.Dense(4096, in_layers=dense4, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense6 = layers.Dense(4096, in_layers=dense5, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense7 = layers.Dense(1, in_layers=dense6, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    return dense7
#DISCRIMINATOR NEEDS TO BE BETTER!
emediumgan = EMEDIUMHIVGAN(learning_rate=0.0005)

In [58]:
class ELARGEHIVGAN(dc.models.WGAN):

  def get_noise_input_shape(self):
    return (None, 16384)

  def get_data_input_shapes(self):
    return [(None, 16384)]

  def create_generator(self, noise_input, conditional_inputs):
    dense1 = layers.Dense(16384, in_layers=noise_input, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense2 = layers.Dense(16384, in_layers=dense1, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense3 = layers.Dense(16384, in_layers=dense2, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense4 = layers.Dense(16384, in_layers=dense3, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense5 = layers.Dense(16384, in_layers=dense4, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    reshaped = layers.Reshape((None, 16384), in_layers=dense5)
    return [reshaped]

  def create_discriminator(self, data_inputs, conditional_inputs):
    dense1 = layers.Dense(16384, in_layers=data_inputs, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense2 = layers.Dense(16384, in_layers=dense1, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense3 = layers.Dense(16384, in_layers=dense2, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense4 = layers.Dense(16384, in_layers=dense3, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense5 = layers.Dense(16384, in_layers=dense4, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense6 = layers.Dense(16384, in_layers=dense5, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense7 = layers.Dense(1, in_layers=dense6, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    return dense7
#DISCRIMINATOR NEEDS TO BE BETTER!
elargegan = ELARGEHIVGAN(learning_rate=0.0005)

In [59]:
class E4HIVGAN(dc.models.WGAN):

  def get_noise_input_shape(self):
    return (None, 1024)

  def get_data_input_shapes(self):
    return [(None, 1024)]

  def create_generator(self, noise_input, conditional_inputs):
    dense1 = layers.Dense(1024, in_layers=noise_input, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense2 = layers.Dense(1024, in_layers=dense1, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense3 = layers.Dense(1024, in_layers=dense2, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense4 = layers.Dense(1024, in_layers=dense3, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense5 = layers.Dense(1024, in_layers=dense4, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    reshaped = layers.Reshape((None, 1024), in_layers=dense5)
    return [reshaped]

  def create_discriminator(self, data_inputs, conditional_inputs):
    dense1 = layers.Dense(1024, in_layers=data_inputs, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense2 = layers.Dense(1024, in_layers=dense1, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense3 = layers.Dense(1024, in_layers=dense2, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense4 = layers.Dense(1024, in_layers=dense3, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense5 = layers.Dense(1024, in_layers=dense4, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense6 = layers.Dense(1024, in_layers=dense5, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense7 = layers.Dense(1, in_layers=dense6, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    return dense7
#DISCRIMINATOR NEEDS TO BE BETTER!
e4gan = E4HIVGAN(learning_rate=0.0005)

In [60]:
class E6HIVGAN(dc.models.WGAN):

  def get_noise_input_shape(self):
    return (None, 1024)

  def get_data_input_shapes(self):
    return [(None, 1024)]

  def create_generator(self, noise_input, conditional_inputs):
    dense1 = layers.Dense(1024, in_layers=noise_input, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense2 = layers.Dense(1024, in_layers=dense1, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense3 = layers.Dense(1024, in_layers=dense2, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense4 = layers.Dense(1024, in_layers=dense3, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    dense5 = layers.Dense(1024, in_layers=dense4, activation_fn=tf.sigmoid, normalizer_fn=tf.layers.batch_normalization)
    reshaped = layers.Reshape((None, 1024), in_layers=dense5)
    return [reshaped]

  def create_discriminator(self, data_inputs, conditional_inputs):
    dense1 = layers.Dense(1024, in_layers=data_inputs, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense2 = layers.Dense(1024, in_layers=dense1, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense3 = layers.Dense(1024, in_layers=dense2, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense4 = layers.Dense(1024, in_layers=dense3, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense5 = layers.Dense(1024, in_layers=dense4, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense6 = layers.Dense(1024, in_layers=dense5, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    dense7 = layers.Dense(1, in_layers=dense6, activation_fn=tf.nn.relu, normalizer_fn=tf.layers.batch_normalization)
    return dense7
#DISCRIMINATOR NEEDS TO BE BETTER!
e6gan = E6HIVGAN(learning_rate=0.0005)

In [50]:
def esmalliterbatches(epochs):
  for i in range(epochs):
    print(esmallgan.batch_size)
    for batch in eshiv_train_dataset.iterbatches(batch_size=esmallgan.batch_size):
      yield {esmallgan.data_inputs[0]: batch[0]}

esmallgan.fit_gan(esmalliterbatches(10), generator_steps=1.5, checkpoint_interval=5000)

100
100


KeyboardInterrupt: 

In [61]:
def emediumiterbatches(epochs):
  for i in range(epochs):
    print(emediumgan.batch_size)
    for batch in emhiv_train_dataset.iterbatches(batch_size=emediumgan.batch_size):
      yield {emediumgan.data_inputs[0]: batch[0]}

emediumgan.fit_gan(emediumiterbatches(10), generator_steps=1.5, checkpoint_interval=5000)

100
100


KeyboardInterrupt: 

In [40]:
def elargeiterbatches(epochs):
  for i in range(epochs):
    print(elargegan.batch_size)
    for batch in elhiv_train_dataset.iterbatches(batch_size=elargegan.batch_size):
      yield {elargegan.data_inputs[0]: batch[0]}

elargegan.fit_gan(elargeiterbatches(10), generator_steps=1.5, checkpoint_interval=5000)

100


KeyboardInterrupt: 

In [None]:
def e4iterbatches(epochs):
  for i in range(epochs):
    print(e4gan.batch_size)
    for batch in e4hiv_train_dataset.iterbatches(batch_size=e4gan.batch_size):
      yield {e4gan.data_inputs[0]: batch[0]}

e4gan.fit_gan(e4iterbatches(10), generator_steps=1.5, checkpoint_interval=5000)

In [None]:
def e6iterbatches(epochs):
  for i in range(epochs):
    print(e6gan.batch_size)
    for batch in e6hiv_train_dataset.iterbatches(batch_size=e6gan.batch_size):
      yield {e6gan.data_inputs[0]: batch[0]}

e6gan.fit_gan(e6iterbatches(10), generator_steps=1.5, checkpoint_interval=5000)