In [1]:
# Initialize required random seeds before importing anything else.
import random

import numpy as np
from tensorflow import set_random_seed

np.random.seed(42)
set_random_seed(42)
random.seed(42)

  from ._conv import register_converters as _register_converters


In [2]:
%matplotlib inline

import json

from keras.optimizers import RMSprop
from keras.utils import to_categorical
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

import sgantech as sgt

Using TensorFlow backend.


In [3]:
sns.set_style('whitegrid')
color_palette = sns.color_palette()

In [4]:
(x_train, y_train), (x_test, y_test) = sgt.utils.load_preprocessed_mnist()
X = np.concatenate((x_train, x_test))
y = np.concatenate((y_train, y_test))

In [5]:
n_critic = 5
clip_value = 0.01

In [6]:
def train_test_gan(gan_factory, x_train, y_train, x_test, y_test, epochs=40000, batch_size=32, test_steps=500, num_labels=None, verbose=True):
    classes = np.unique(y)
    num_classes = len(classes)
    
    quarter_batch = batch_size // 4
    cw1 = {-1: 1, 1: 1}
    cw2 = {i: 10 / quarter_batch for i in range(num_classes)}
    cw2[num_classes] = 1 / quarter_batch
    
    if num_labels is None:
        x_train_labeled = x_train
        y_train_labeled = y_train
    else:
        x_train_labeled, _, y_train_labeled, _ = train_test_split(x_train, y_train, train_size=num_labels, stratify=y_train)
    
    disc, gen, combined = gan_factory()
    
    histories = []
    for epoch in range(epochs):
        for _ in range(n_critic):
            # Select a random quater batch of images for real/fake
            real_fake_idx = np.random.randint(0, x_train.shape[0], quarter_batch)
            real_fake_imgs = x_train[real_fake_idx]
            # Select a random quarter batch of labeled images
            labeled_idx = np.random.randint(0, x_train_labeled.shape[0], quarter_batch)
            labeled_imgs = x_train_labeled[labeled_idx]
            # Sample noise and generate a half batch of new images
            noise = np.random.normal(0, 1, (quarter_batch * 2, 100))
            gen_imgs = gen.predict(noise)

            valid = -np.ones((quarter_batch, 1))
            fake = np.ones((quarter_batch, 1))

            labels = to_categorical(y_train_labeled[labeled_idx], num_classes=num_classes+1)
            fake_labels = to_categorical(np.full((quarter_batch, 1), num_classes), num_classes=num_classes+1)

            # Train the discriminator
            d_loss_real = disc.train_on_batch([real_fake_imgs, labeled_imgs], [valid, labels], class_weight=[cw1, cw2])
            d_loss_fake = disc.train_on_batch([gen_imgs[:quarter_batch], gen_imgs[quarter_batch:]], [fake, fake_labels], class_weight=[cw1, cw2])
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
            
            for l in disc.layers:
                weights = l.get_weights()
                weights = [np.clip(w, -clip_value, clip_value) for w in weights]
                l.set_weights(weights)

        noise = np.random.normal(0, 1, (batch_size, 100))
        validity = -np.ones((batch_size, 1))

        # Train the generator
        g_loss = combined.train_on_batch(noise, validity, class_weight=[cw1, cw2])
        
        if epoch % test_steps == 0:
            test_loss = disc.evaluate(
                [x_test, x_test],
                [np.ones((len(x_test), 1)), to_categorical(y_test, num_classes=num_classes+1)]
            )
            print(test_loss)
            histories.append((d_loss, g_loss, test_loss))

        # Plot the progress
        if verbose:
            print("%d [D loss: %f, acc: %.2f%%, op_acc: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[3], 100*d_loss[4], g_loss))
    
    test_loss = disc.evaluate(
        [x_test, x_test],
        [np.ones((len(x_test), 1)), to_categorical(y_test, num_classes=num_classes+1)]
    )
    histories.append(test_loss)
    
    return histories

In [7]:
def dense_wgan_factory():
    dense_d = sgt.discriminators.dense((28, 28, 1), 10, unsupervised_activation='linear')
    dense_g = sgt.generators.dense((100,), (28, 28, 1))
    return sgt.gan.SGANBuilder(dense_g, dense_d, (28, 28, 1), (100,)) \
        .set_discriminator_loss([sgt.losses.wasserstein_loss, 'categorical_crossentropy']) \
        .set_generator_loss(sgt.losses.wasserstein_loss) \
        .set_discriminator_optimizer(RMSprop(lr=0.00005)) \
        .set_generator_optimizer(RMSprop(lr=0.00005)) \
        .build()

In [8]:
dense_histories = train_test_gan(dense_wgan_factory, x_train, y_train, x_test, y_test, epochs=12500, verbose=False)

  'Discrepancy between trainable weights and collected trainable'


[1.197403161239624, -0.0019956754026934503, 2.3968020027160644, 0.0, 0.103]
[14.05979585723877, 25.764546923828124, 2.35504483795166, 0.0, 0.1135]
[13.275085575866699, 24.134525231933594, 2.415645907974243, 0.0, 0.1028]
[12.388497633361816, 22.343237283325195, 2.433757963180542, 0.0, 0.098]
[9.716625354003906, 17.01142036743164, 2.421830376815796, 0.0, 0.1009]
[4.770108943176269, 7.135394436645508, 2.404823447418213, 0.0, 0.1009]
[1.9379847164154054, 1.4741973217010498, 2.4017721138000487, 0.9228, 0.1135]
[1.3065621587753296, 0.6690315710067749, 1.944092751121521, 0.4121, 0.2044]
[1.346818564605713, 0.5410958429336548, 2.1525412826538086, 0.4749, 0.1876]
[0.9440448625564575, 0.12305027339458466, 1.7650394529342652, 0.0, 0.2928]
[0.7638637909889221, -0.07504746795296668, 1.60277504901886, 0.0, 0.3948]
[0.7790240445137024, 0.06251202437877655, 1.4955360664367676, 0.0, 0.4237]
[0.7429135074615478, 0.013741545180231332, 1.4720854698181152, 0.0, 0.549]
[0.67816492395401, -0.0278030903927981

In [9]:
def cnn_wgan_factory():
    cnn_d = sgt.discriminators.cnn((28, 28, 1), 10, unsupervised_activation='linear')
    cnn_g = sgt.generators.cnn((100,), (7, 7))
    return sgt.gan.SGANBuilder(cnn_g, cnn_d, (28, 28, 1), (100,)) \
        .set_discriminator_loss([sgt.losses.wasserstein_loss, 'categorical_crossentropy']) \
        .set_generator_loss(sgt.losses.wasserstein_loss) \
        .set_discriminator_optimizer(RMSprop(lr=0.00005)) \
        .set_generator_optimizer(RMSprop(lr=0.00005)) \
        .build()

In [10]:
cnn_histories = train_test_gan(cnn_wgan_factory, x_train, y_train, x_test, y_test, epochs=12500, verbose=False)

  'Discrepancy between trainable weights and collected trainable'


[1.1990313787460327, 0.0001704237805912271, 2.397892288208008, 0.0, 0.101]
[1.1996344678878783, -0.008763673540949821, 2.4080326103210448, 0.0, 0.0]
[1.2075552717208862, 0.0247183290630579, 2.3903922092437746, 0.0, 0.204]
[1.1926410427093506, -0.0023240239698439835, 2.387606108093262, 0.0, 0.2056]
[1.1855406366348267, -0.012687110218405724, 2.3837683769226072, 0.0, 0.1938]
[1.1844341131210328, -0.015254679092764855, 2.3841229038238527, 0.0, 0.098]
[1.1839132236480714, -0.014246545204520225, 2.382072996902466, 0.0, 0.1767]
[1.183460143661499, -0.014997118881344796, 2.3819174087524413, 0.0, 0.1923]
[1.184146026992798, -0.012973507159948348, 2.381265560531616, 0.0, 0.199]
[1.1825757215499877, -0.015531020629405975, 2.3806824584960937, 0.0, 0.1439]
[1.1833165950775146, -0.013018601006269455, 2.379651792907715, 0.0, 0.1495]
[1.1860692331314087, -0.008982692870497704, 2.3811211574554445, 0.0, 0.1137]
[1.1865172359466554, -0.0037652269564568998, 2.3767997016906737, 0.0, 0.1961]
[1.18462406845

In [13]:
from sklearn.model_selection import train_test_split

dense_ss_histories = {}
for num_examples in [100, 300, 500, 1000]:
    print(num_examples)
    dense_ss_histories[num_examples] = train_test_gan(
        dense_wgan_factory, x_train, y_train, x_test, y_test, epochs=50000, num_labels=num_examples, verbose=False
    )

100


  'Discrepancy between trainable weights and collected trainable'


[1.204011713409424, 0.01119626025557518, 2.3968271759033204, 0.0, 0.1199]
[14.620000622558594, 26.868520629882813, 2.3714805931091307, 0.0, 0.1032]
[14.04086590270996, 25.670133981323243, 2.411597789764404, 0.0, 0.0958]
[13.270549453735352, 24.146482904052736, 2.3946159759521484, 0.0, 0.1135]
[11.830445603942872, 21.2529085357666, 2.407982676315308, 0.0, 0.1135]
[6.83843620300293, 11.269701022338868, 2.4071713973999023, 0.0, 0.0982]
[2.838068974685669, 3.272713101196289, 2.403424839019775, 0.0, 0.1032]
[1.211424487876892, 0.022819855827093123, 2.4000291145324706, 0.0, 0.0]
[1.0659988498687745, 0.10298681316375732, 2.0290108837127687, 0.0104, 0.3046]
[0.8837500127792358, 0.020311928924918173, 1.7471880922317504, 0.0, 0.4233]
[0.8310049837112427, 0.0457795806825161, 1.6162303857803344, 0.0, 0.4391]
[0.7912660607337951, -0.014227733683586121, 1.5967598556518554, 0.0, 0.5182]
[0.8111813258171081, 0.029437933930009604, 1.5929247215270996, 0.0, 0.5135]
[0.7917610613822937, 0.0071258541077375

[0.7828291113853455, 0.01385477720350027, 1.5518034448623657, 0.0, 0.5597]
[0.7624292971611023, -0.02800359134078026, 1.5528621864318848, 0.0, 0.5728]
[0.7772472512245179, -0.011806803105771542, 1.5663013063430786, 0.0, 0.5386]
[0.7939010718345642, 0.015701901811361312, 1.5721002403259277, 0.0, 0.5378]
[0.7279104077339172, -0.09065493549108505, 1.5464757480621338, 0.0, 0.5707]
[0.7706083152770996, 0.0023804808925837278, 1.5388361476898194, 0.0, 0.5643]
[0.7722240639686584, -0.010966439636424184, 1.5554145656585694, 0.0, 0.5399]
[0.7838748317718506, 0.007361217431351543, 1.5603884477615357, 0.0, 0.5235]
[0.7748920244216919, 0.011276629947684705, 1.5385074186325074, 0.0, 0.5558]
[0.7909342675209046, 0.02773600270152092, 1.5541325271606445, 0.0, 0.5625]
[0.7646217092514038, -0.0013619243394583464, 1.5306053409576417, 0.0, 0.5644]
[0.7568169487953186, -0.044090224523842335, 1.5577241214752198, 0.0, 0.554]
[0.7856790115356446, 0.03400556723475456, 1.5373524576187134, 0.0, 0.5653]
[0.7865097

[0.8731090294837952, -0.016179607404768467, 1.762397668838501, 0.0, 0.3715]
[0.8719294109344482, -0.020654971302300693, 1.7645137935638429, 0.0, 0.3595]
[0.8771961116790772, 0.0004820874644443393, 1.7539101341247558, 0.0, 0.3704]
[0.8737041703224182, -0.012114058296382428, 1.7595224029541017, 0.0, 0.3651]
[0.8712271175384522, -0.0022786013967357574, 1.7447328397750854, 0.0, 0.3728]
[0.8791037318229675, 0.015874538972228767, 1.7423329299926757, 0.0, 0.3768]
[0.8706862325668335, -0.0018409744903445243, 1.743213441467285, 0.0, 0.3415]
[0.8712145224571228, -0.009365437648445368, 1.751794482421875, 0.0, 0.3902]
[0.8673848667144776, -0.009140996130928397, 1.7439107339859008, 0.0, 0.3715]
[0.8507487863540649, -0.03268920039460063, 1.73418677444458, 0.0, 0.3713]
[0.8598974678993225, -0.01691172415614128, 1.7367066612243653, 0.0, 0.3474]
[0.8637723167419433, -0.004946555956080556, 1.7324911890029908, 0.0, 0.3731]
[0.857549636554718, -0.021522141972184182, 1.7366214149475097, 0.0, 0.38]
[0.87240

[0.6203196290016174, 0.004544651648402214, 1.2360946062088012, 0.0, 0.6681]
[0.611190832901001, -0.020121484452486037, 1.2425031517028808, 0.0, 0.6533]
[0.6441631107330322, 0.04185585605949164, 1.2464703666687011, 0.0, 0.6546]
[0.6504481995582581, 0.0034814962428063156, 1.297414902496338, 0.0, 0.611]
[0.6112362073898315, -0.016045028972625733, 1.2385174409866333, 0.0, 0.6654]
[0.6317313121795655, 0.03284350828528404, 1.2306191144943237, 0.0, 0.6912]
[0.6187536979675293, -0.0025943635475821794, 1.2401017576217652, 0.0, 0.6507]
[0.6151827103614808, -0.007486924047023058, 1.237852346611023, 0.0, 0.6392]
[0.630444870185852, 0.02015184145644307, 1.2407378995895386, 0.0, 0.6396]
[0.6220655570983886, -0.0050889582075178625, 1.2492200716018678, 0.0, 0.6868]
[0.6270541994094848, 0.01788199287503958, 1.2362264080047607, 0.0, 0.6479]
[0.6320866761207581, 0.034235763972997664, 1.2299375890731812, 0.0, 0.6793]
[0.6152019745826721, -0.023466007217764855, 1.2538699562072755, 0.0, 0.6419]
[0.593357557

[14.209712370300293, 26.067600134277345, 2.351824598312378, 0.0, 0.098]
[13.37063542327881, 24.364450158691405, 2.376820669174194, 0.0, 0.1135]
[12.267789981079101, 22.114933251953126, 2.420646695327759, 0.0, 0.0892]
[9.102778619384766, 15.790830072021484, 2.414727174758911, 0.0, 0.1028]
[4.186281909942627, 5.966947753143311, 2.4056160469055174, 0.0, 0.1135]
[1.6706795528411866, 0.9403078709602356, 2.4010512355804443, 1.0, 0.0974]
[1.3764957748413087, 0.4505564215183258, 2.3024351322174073, 0.633, 0.1965]
[1.3948724403381347, 0.7582965222358704, 2.031448358154297, 0.5905, 0.2898]
[1.4747647556304933, 1.2650795024871826, 1.6844499990463258, 0.8971, 0.4026]
[1.5029309673309326, 1.483465685081482, 1.5223962467193604, 0.5637, 0.4438]
[1.0508916600227356, 0.6485877737045288, 1.4531955486297607, 0.999, 0.5326]
[0.6652593037605286, -0.033102806317806245, 1.3636214149475097, 0.0, 0.5442]
[0.6546518270492554, -0.058440335869789126, 1.3677439907073974, 0.0, 0.5312]
[0.6727251004219055, -0.000708

[0.6162960761070252, 0.0020470367873087524, 1.230545116043091, 0.0, 0.6802]
[0.6302319295883179, 0.03197123099416494, 1.2284926267623901, 0.0, 0.6707]
[0.6208618525505066, -0.007942432315275073, 1.2496661367416382, 0.0, 0.6757]
[0.603865029335022, -0.040335889983177185, 1.2480659450531006, 0.0, 0.6682]
[0.6491188742637635, 0.06290317951440812, 1.2353345691680908, 0.0, 0.6435]
[0.6252757669448853, -0.0008564422968775034, 1.251407976913452, 0.0, 0.63]
[0.6209393746376037, 0.00919427226204425, 1.2326844758987427, 0.0, 0.689]
[0.6168567053794861, 0.007517949637211859, 1.226195460128784, 0.0, 0.6677]
[0.598876414489746, -0.02819542874544859, 1.225948257446289, 0.0, 0.6699]
[0.6161189891815185, 0.012165480642393232, 1.2200724967956542, 0.0, 0.6902]
[0.6302777379989624, 0.043136423057317734, 1.217419055366516, 0.0, 0.6732]
[0.6281582933425903, 0.011577266132459045, 1.244739317512512, 0.0, 0.6458]
[0.5970558793067932, -0.04751953784227371, 1.241631294441223, 0.0, 0.6352]
[0.5856250723838806, -

In [None]:
cnn_ss_histories = {}
for num_examples in [100, 300, 500, 1000]:
    print(num_examples)
    cnn_ss_histories[num_examples] = train_test_gan(
        cnn_wgan_factory, x_train, y_train, x_test, y_test, epochs=50000, num_labels=num_examples, verbose=False
    )

100


  'Discrepancy between trainable weights and collected trainable'


[1.199020803833008, 0.0001444129419978708, 2.3978971458435057, 0.0, 0.0892]
[1.2034447160720825, -0.00455792110785842, 2.4114473541259764, 0.0, 0.0868]
[1.2173508432388305, 0.042682951325178145, 2.3920187324523927, 0.0, 0.098]
[1.1863086349487304, -0.010958937323093415, 2.383576203918457, 0.0, 0.1822]
[1.1819973251342772, -0.020649126720428467, 2.3846437767028807, 0.0, 0.1749]
[1.1821890487670899, -0.02132643413543701, 2.385704529953003, 0.0, 0.157]
[1.1808946201324464, -0.023594550281763077, 2.3853837936401368, 0.0, 0.098]
[1.1807229125976562, -0.023016786503791808, 2.384462607955933, 0.0, 0.1954]
[1.1836269836425781, -0.017711034417152403, 2.384964999008179, 0.0, 0.098]
[1.1821100498199464, -0.017773621013760565, 2.381993716430664, 0.0, 0.1795]
[1.1839344789505004, -0.018774266615509988, 2.386643218231201, 0.0, 0.098]
[1.184864582633972, -0.01578654153048992, 2.3855157051086424, 0.0, 0.098]
[1.1848929361343383, -0.016464816415309905, 2.386250684738159, 0.0, 0.098]
[1.1856680513381959

In [None]:
def dense_wgan_factory_test():
    from keras.optimizers import Adam
    
    dense_d = sgt.discriminators.dense((28, 28, 1), 10, unsupervised_activation='linear')
    dense_g = sgt.generators.dense((100,), (28, 28, 1))
    return sgt.gan.SGANBuilder(dense_g, dense_d, (28, 28, 1), (100,)) \
        .set_discriminator_loss([sgt.losses.wasserstein_loss, 'categorical_crossentropy']) \
        .set_generator_loss(sgt.losses.wasserstein_loss) \
        .set_discriminator_optimizer(Adam()) \
        .set_generator_optimizer(Adam()) \
        .build()

In [None]:
dense_histories_adam = train_test_gan(dense_wgan_factory_test, x_train, y_train, x_test, y_test, epochs=12500, verbose=False)

In [None]:
clip_value = 0.1

In [None]:
dense_histories_large_clip = train_test_gan(dense_wgan_factory, x_train, y_train, x_test, y_test, epochs=12500, verbose=False)