In [1]:
# Initialize required random seeds before importing anything else.
import random

import numpy as np
from tensorflow import set_random_seed

# One seed shared by Python's `random`, NumPy's global RNG and TensorFlow's graph-level seed,
# so every stochastic step below (batch sampling, noise, weight init) is repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
set_random_seed(SEED)

  from ._conv import register_converters as _register_converters


In [2]:
%matplotlib inline

import json

from keras.utils import to_categorical
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.model_selection import train_test_split

import sgantech as sgt

Using TensorFlow backend.


In [3]:
# Figure style for the whole notebook; also cache seaborn's current default palette.
sns.set_style(style='whitegrid')
color_palette = sns.color_palette()

In [4]:
# Train/test split as returned by sgantech's preprocessed-MNIST loader.
(x_train, y_train), (x_test, y_test) = sgt.utils.load_preprocessed_mnist()

# Pooled arrays over both splits: training rows first, then test rows.
X, y = np.concatenate([x_train, x_test]), np.concatenate([y_train, y_test])

In [5]:
def _evaluate_discriminator(disc, x_test, test_targets):
    """Evaluate the discriminator on the test set.

    The same test images are fed to both discriminator inputs; ``test_targets`` is
    ``[validity_targets, one_hot_class_targets]``.
    """
    return disc.evaluate([x_test, x_test], test_targets)


def train_test_gan(gan_factory, x_train, y_train, x_test, y_test, epochs=40000, batch_size=32, test_steps=500,
                   num_labels=None, verbose=True, latent_dim=100):
    """Train a semi-supervised GAN and periodically evaluate its discriminator on the test set.

    Each iteration trains the discriminator twice:
      * on a quarter batch of real images (real/fake output) plus a quarter batch of labeled
        images (class output);
      * on a half batch of generated images, split the same way, with "fake" targets.
    It then trains the generator through ``combined`` on a full batch of noise with all-ones
    validity targets.

    Parameters
    ----------
    gan_factory : callable
        Zero-argument callable returning ``(disc, gen, combined)`` Keras-style models.
    x_train, y_train : array
        Training images and integer labels. Labels must be ``0 .. num_classes - 1``;
        index ``num_classes`` is reserved for the extra "fake" class.
    x_test, y_test : array
        Test images and integer labels.
    epochs : int
        Number of training iterations. Each iteration is one batch, not a full pass over the data.
    batch_size : int
        Generator batch size. The discriminator sees ``batch_size // 4`` images per input.
        Must be at least 4.
    test_steps : int
        Evaluate on the test set every ``test_steps`` iterations, including iteration 0.
    num_labels : int or None
        If given, the class output only sees a stratified subset of this many training
        examples. The real/fake output still samples from all of ``x_train``.
    verbose : bool
        Print per-iteration discriminator/generator losses.
    latent_dim : int
        Length of the noise vector fed to the generator. Must match the generator's input.

    Returns
    -------
    list
        One ``(d_loss, g_loss, test_loss)`` tuple per in-training evaluation. These are followed
        by the final ``test_loss`` appended on its own (NOT a tuple). That layout is kept so
        previously saved results stay compatible.

    Raises
    ------
    ValueError
        If ``batch_size < 4`` (the quarter batch would be empty).
    """
    # Count classes from the data actually passed in, not from notebook globals.
    num_classes = np.union1d(y_train, y_test).size

    quarter_batch = batch_size // 4
    if quarter_batch < 1:
        raise ValueError("batch_size must be at least 4, got %d" % batch_size)

    # Class weights per output: the real/fake output is unweighted. In the class output,
    # real classes weigh 10x the extra "fake" class, all scaled by 1 / quarter_batch.
    validity_cw = {0: 1, 1: 1}
    class_cw = {i: 10 / quarter_batch for i in range(num_classes)}
    class_cw[num_classes] = 1 / quarter_batch
    class_weights = [validity_cw, class_cw]

    if num_labels is None:
        x_train_labeled = x_train
        y_train_labeled = y_train
    else:
        x_train_labeled, _, y_train_labeled, _ = train_test_split(
            x_train, y_train, train_size=num_labels, stratify=y_train
        )

    disc, gen, combined = gan_factory()

    # Targets that never change between iterations; build them once.
    valid = np.ones((quarter_batch, 1))
    fake = np.zeros((quarter_batch, 1))
    fake_labels = to_categorical(np.full((quarter_batch, 1), num_classes), num_classes=num_classes + 1)
    validity = np.ones((batch_size, 1))
    test_targets = [np.ones((len(x_test), 1)), to_categorical(y_test, num_classes=num_classes + 1)]

    histories = []
    for epoch in range(epochs):
        # Select a random quarter batch of images for real/fake
        real_fake_idx = np.random.randint(0, x_train.shape[0], quarter_batch)
        real_fake_imgs = x_train[real_fake_idx]
        # Select a random quarter batch of labeled images
        labeled_idx = np.random.randint(0, x_train_labeled.shape[0], quarter_batch)
        labeled_imgs = x_train_labeled[labeled_idx]
        # Generate a half batch: the first quarter goes to the real/fake input,
        # the second quarter to the class input.
        noise = np.random.normal(0, 1, (quarter_batch * 2, latent_dim))
        gen_imgs = gen.predict(noise)

        labels = to_categorical(y_train_labeled[labeled_idx], num_classes=num_classes + 1)

        # Train the discriminator
        d_loss_real = disc.train_on_batch([real_fake_imgs, labeled_imgs], [valid, labels],
                                          class_weight=class_weights)
        d_loss_fake = disc.train_on_batch([gen_imgs[:quarter_batch], gen_imgs[quarter_batch:]],
                                          [fake, fake_labels], class_weight=class_weights)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # Train the generator against all-ones validity targets
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = combined.train_on_batch(noise, validity, class_weight=class_weights)

        if epoch % test_steps == 0:
            test_loss = _evaluate_discriminator(disc, x_test, test_targets)
            print(test_loss)
            histories.append((d_loss, g_loss, test_loss))

        # Plot the progress
        if verbose:
            print("%d [D loss: %f, acc: %.2f%%, op_acc: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[3], 100*d_loss[4], g_loss))

    # The final evaluation is appended bare (not as a tuple), matching the historical layout.
    histories.append(_evaluate_discriminator(disc, x_test, test_targets))

    return histories

## MLP (dense) SGAN — all training labels

In [6]:
def dense_gan_factory(img_shape=(28, 28, 1), latent_shape=(100,), num_classes=10):
    """Build an SGAN whose generator and discriminator are both dense (MLP) networks.

    The defaults match the 28x28x1 preprocessed MNIST images and the 100-d noise that
    ``train_test_gan`` samples by default.

    Parameters
    ----------
    img_shape : tuple
        Image shape produced by the generator and consumed by the discriminator.
    latent_shape : tuple
        Shape of the generator's noise input.
    num_classes : int
        Second argument to ``sgt.discriminators.dense``. Presumably the number of real classes
        (10 MNIST digits); confirm against sgantech.

    Returns
    -------
    tuple
        ``(discriminator, generator, combined)`` from ``sgt.gan.SGANBuilder.build``.
    """
    discriminator_net = sgt.discriminators.dense(img_shape, num_classes)
    generator_net = sgt.generators.dense(latent_shape, img_shape)
    disc, gen, combined = sgt.gan.SGANBuilder(generator_net, discriminator_net, img_shape, latent_shape).build()
    return disc, gen, combined

In [7]:
# Fully labeled baseline (num_labels=None, so every training label is used), for comparison
# with the Keras-GAN SGAN reference implementation.
dense_histories = train_test_gan(
    dense_gan_factory,
    x_train,
    y_train,
    x_test,
    y_test,
    epochs=50000,
    verbose=False,
)

  'Discrepancy between trainable weights and collected trainable'


[2.2246240142822264, 0.0006687384232878685, 4.448579292297364, 1.0, 0.101]
[0.2999799731850624, 0.05961843948364258, 0.5403415083408356, 0.9839, 0.8297]
[0.2161311832770705, 0.07307303901836276, 0.3591893277183175, 0.9752, 0.8946]
[0.3200689273267984, 0.24706843256652355, 0.3930694226115942, 0.8947, 0.8753]
[0.2444122080899775, 0.03218081219019368, 0.4566436036318541, 0.9909, 0.868]
[0.21635690970569849, 0.10674314368776977, 0.32597067624628545, 0.948, 0.8981]
[0.2035786152422428, 0.12498221894800662, 0.2821750132501125, 0.9424, 0.9171]
[0.29513024272918703, 0.23387618803977966, 0.3563842994108796, 0.8898, 0.8985]
[0.4029760581731796, 0.43840651924610136, 0.36754559778273105, 0.8295, 0.8845]
[0.3627657168507576, 0.35865169382095335, 0.36687974074482915, 0.8288, 0.8972]
[0.26464787408113477, 0.2700628353357315, 0.2592329122364521, 0.8698, 0.9225]
[0.3824726029396057, 0.46178493514060975, 0.3031602714553475, 0.7635, 0.9156]
[0.3514248103141785, 0.43868491754531863, 0.2641647025167942, 0.

[0.31287341623306275, 0.41129388732910155, 0.21445294565558434, 0.8473, 0.9436]
[0.38085583572387693, 0.5189483971118927, 0.2427632746785879, 0.6997, 0.9357]
[0.35770034642219545, 0.5400064956665039, 0.17539419580288232, 0.7011, 0.956]
[0.43414911193847655, 0.5865520609855652, 0.2817461644232273, 0.6645, 0.9053]
[0.5087850812911987, 0.6412317261695862, 0.37633843283653257, 0.6172, 0.8719]
[0.46013143553733826, 0.5857758693695069, 0.3344870013535023, 0.6602, 0.9008]
[0.4222005461215973, 0.5848980908393859, 0.2595030020684004, 0.6487, 0.9323]
[0.4374427368164063, 0.6597972102165223, 0.2150882605969906, 0.5805, 0.9429]
[0.47409060537815095, 0.6900483713150024, 0.25813284121751784, 0.5852, 0.9261]
[0.6279630009174347, 0.7132008902549744, 0.542725111579895, 0.5933, 0.8451]
[0.30854683721065523, 0.3777119121551514, 0.23938176089525223, 0.8813, 0.935]
[0.2923755549669266, 0.3884540763378143, 0.1962970328733325, 0.8519, 0.9481]
[0.32609912209510805, 0.37978089637756346, 0.2724173473358154, 0.8

## CNN SGAN — all training labels

In [8]:
def cnn_gan_factory(img_shape=(28, 28, 1), latent_shape=(100,), num_classes=10, gen_base_shape=(7, 7)):
    """Build an SGAN whose generator and discriminator are both convolutional networks.

    The defaults reproduce the original MNIST configuration.

    Parameters
    ----------
    img_shape : tuple
        Image shape consumed by the discriminator and the SGAN builder.
    latent_shape : tuple
        Shape of the generator's noise input.
    num_classes : int
        Second argument to ``sgt.discriminators.cnn``. Presumably the number of real classes;
        confirm against sgantech.
    gen_base_shape : tuple
        Second argument to ``sgt.generators.cnn``. Presumably the spatial size of the first
        feature map before upsampling to ``img_shape`` (7x7 -> 28x28); confirm against sgantech.

    Returns
    -------
    Whatever ``sgt.gan.SGANBuilder.build`` returns, unpacked by callers as
    ``(discriminator, generator, combined)``.
    """
    discriminator_net = sgt.discriminators.cnn(img_shape, num_classes)
    generator_net = sgt.generators.cnn(latent_shape, gen_base_shape)
    return sgt.gan.SGANBuilder(generator_net, discriminator_net, img_shape, latent_shape).build()

In [9]:
# Same fully labeled setup as the dense baseline, but with convolutional networks.
cnn_histories = train_test_gan(
    cnn_gan_factory,
    x_train, y_train,
    x_test, y_test,
    epochs=50000,
    verbose=False,
)

  'Discrepancy between trainable weights and collected trainable'


[1.5266181573867799, 0.6581736038208008, 2.3950627113342287, 0.9993, 0.155]
[0.8894659698486328, 0.46102712755203246, 1.3179048151016235, 0.9985, 0.8395]
[0.8684154569625855, 0.2218155288219452, 1.515015379524231, 1.0, 0.5102]
[0.8795646850585938, 0.11662167875766755, 1.6425076969146728, 1.0, 0.3823]
[1.130880242729187, 0.16031430435180663, 2.1014461736679078, 1.0, 0.2078]
[1.1170890327453613, 0.06855704011917114, 2.1656210342407225, 1.0, 0.2203]
[0.9929895433425904, 0.08758923246860505, 1.8983898563385009, 1.0, 0.3092]
[0.9637416146278381, 0.09734471324682235, 1.8301385147094726, 1.0, 0.1787]
[0.8467981915473938, 0.014840499311685563, 1.6787558828353881, 1.0, 0.4105]
[0.911191126537323, 0.01503033764064312, 1.8073519193649292, 1.0, 0.1923]
[0.7248379307746887, 0.0031875285699963568, 1.446488332748413, 1.0, 0.3835]
[0.9534582098007203, 0.004687890223413706, 1.9022285318374634, 1.0, 0.4519]
[0.7845143791198731, 0.006523920612037182, 1.5625048370361327, 1.0, 0.3523]
[0.8691721320152282, 

[0.38382933056354523, 0.5197834910869599, 0.24787516993358732, 0.6871, 0.9434]
[0.32149990377426146, 0.49456463785171506, 0.148435169737041, 0.8001, 0.9654]
[0.3964980451107025, 0.6520879076004028, 0.14090818375991657, 0.6185, 0.9661]
[0.33857399230003354, 0.3980466520547867, 0.27910133403390647, 0.8444, 0.9272]
[0.3495643633365631, 0.36394208245277404, 0.3351866435021162, 0.8546, 0.9103]
[0.49586045055389405, 0.6330274152755737, 0.3586934882432222, 0.6534, 0.9086]
[0.3612495900154114, 0.4429230386734009, 0.27957614121586083, 0.8155, 0.925]
[0.38935787692070006, 0.5421806503295898, 0.23653510444220155, 0.6998, 0.9487]
[0.5136109098911286, 0.7282948203086853, 0.2989269975721836, 0.5635, 0.9252]
[0.5769280168533325, 0.7794682544708252, 0.37438777875751256, 0.5367, 0.9144]
[0.6461996832370758, 0.882741287612915, 0.4096580808363855, 0.511, 0.909]
[0.5988951130867004, 0.7763278427124023, 0.42146238180696965, 0.4576, 0.8984]
[0.6203972496032715, 0.8592284657001495, 0.381566033603251, 0.4313,

## Semi-supervised: 100, 300, 500 and 1000 labeled examples

In [10]:
# Semi-supervised dense runs: the class output only sees a stratified subset of
# `num_examples` labeled training images; the real/fake output still samples all of x_train.
dense_ss_histories = {}
for num_examples in (100, 300, 500, 1000):
    print(num_examples)
    dense_ss_histories[num_examples] = train_test_gan(
        dense_gan_factory,
        x_train, y_train,
        x_test, y_test,
        epochs=50000,
        num_labels=num_examples,
        verbose=False,
    )

100


  'Discrepancy between trainable weights and collected trainable'


[2.515849605178833, 0.0017914900243282319, 5.029907718658447, 1.0, 0.1001]
[0.7649905866146087, 0.04282216841429472, 1.4871590037822724, 0.9947, 0.6915]
[1.4135153090953827, 0.3277073676288128, 2.4993232554912566, 0.8585, 0.6444]
[1.0723413139104843, 0.21002634667642414, 1.934656285238266, 0.9077, 0.6786]
[2.042448923873901, 0.6546838081359864, 3.4302140325546264, 0.6972, 0.5851]
[1.8237023423194885, 0.2081240452557802, 3.439280640888214, 0.9179, 0.576]
[1.5538107577323914, 0.09709620481431484, 3.010525311756134, 0.9676, 0.6079]
[2.3414799247741698, 0.1811539857506752, 4.5018058658599855, 0.9199, 0.4964]
[2.638203807067871, 0.37351196658611296, 4.902895632171631, 0.8127, 0.4819]
[2.995414539909363, 0.2567243714809418, 5.734104692840576, 0.8839, 0.4068]
[2.8628258432388307, 0.3291000228047371, 5.396551666641235, 0.8416, 0.4474]
[4.365072267913819, 0.9411244416236877, 7.789020093536377, 0.5517, 0.3183]
[2.249270009231567, 0.21430132138729097, 4.284238696861267, 0.8891, 0.5529]
[3.8758810

[8.032024713134765, 0.6378113030433655, 15.426238107299804, 0.6123, 0.026]
[8.066327494812011, 0.7186646424293518, 15.413990353393554, 0.5277, 0.028]
[8.169986408996582, 0.6324118275642395, 15.70756099243164, 0.6387, 0.0152]
[7.923958697509765, 0.4054349669456482, 15.442482421875, 0.8378, 0.025]
[7.319776390075684, 0.530298872089386, 14.10925389099121, 0.7274, 0.0921]
[7.745778936767578, 0.772349182510376, 14.719208685302734, 0.4728, 0.0591]
[7.874770028686523, 0.7530525532722473, 14.996487506103515, 0.4882, 0.0486]
[6.378655696868896, 0.7177016582489014, 12.039609730529785, 0.5545, 0.2125]
[8.01507415008545, 0.7138112245559692, 15.316337054443359, 0.5533, 0.0352]
[8.182550216674805, 0.6892609746932984, 15.675839511108398, 0.5669, 0.0178]
[7.32320975112915, 0.5491743075847626, 14.097245199584961, 0.7306, 0.09]
[7.517759736633301, 0.9583263027191162, 14.07719316558838, 0.395, 0.0942]
[7.859720950317382, 0.6796562471389771, 15.039785675048828, 0.5757, 0.0502]
[6.956099906921387, 0.755991

[1.9072146780014039, 0.37628837089538575, 3.43814098777771, 0.8367, 0.4804]
[1.3250133082389832, 0.2906609682559967, 2.359365653514862, 0.9091, 0.6054]
[2.200373101711273, 0.431927108669281, 3.9688191028594972, 0.8007, 0.466]
[2.4359854539871217, 0.4676918351173401, 4.40427907409668, 0.7719, 0.4072]
[2.563080751609802, 0.79978935546875, 4.326372143936157, 0.5425, 0.4337]
[2.0118855949401855, 0.6919116821289063, 3.331859508228302, 0.5793, 0.5406]
[2.9019658069610594, 0.7160561367034912, 5.087875482940674, 0.6137, 0.394]
[2.5219101522445677, 0.5536780158996581, 4.490142295455932, 0.7205, 0.4332]
[2.3372345809936523, 0.4900675008773804, 4.184401674461364, 0.7416, 0.4286]
[2.4142487722396853, 0.7281907217025757, 4.100306821250916, 0.5389, 0.4731]
[2.9840893787384033, 0.5379694083213806, 5.430209349060059, 0.7272, 0.3747]
[2.7238996980667114, 0.7130310663223267, 4.734768333435059, 0.5473, 0.3988]
[3.4470830768585206, 0.7738821088790894, 6.1202840431213374, 0.5413, 0.3306]
[2.207043991470337

[7.240650826263428, 0.7093758821487427, 13.771925750732422, 0.547, 0.0655]
[5.307984998321533, 0.5742014317512513, 10.041768562316895, 0.6926, 0.2428]
[5.630534813690185, 0.7009467061996459, 10.560122927856446, 0.5723, 0.2374]
[6.160115161895752, 0.810018799495697, 11.510211517333984, 0.4759, 0.1708]
[5.263171966552735, 0.5965653539657593, 9.92977859802246, 0.6677, 0.2628]
[6.432643675231934, 0.5829336260795593, 12.282353727722167, 0.6766, 0.1324]
[5.925945153808594, 0.4206330905437469, 11.431257217407227, 0.8247, 0.1844]
[7.226842660522461, 0.7562947498321533, 13.697390560913085, 0.5162, 0.0812]
[6.470574243164062, 0.6888536898612976, 12.252294766235352, 0.5738, 0.1251]
[5.639593807220459, 0.7215758147239685, 10.557611787414551, 0.5218, 0.2234]
[4.6655971042633055, 0.6275778980255127, 8.70361633758545, 0.6398, 0.3427]
[5.991794513702392, 0.6728104883193969, 11.310778541564941, 0.5868, 0.1864]
[6.608305249786377, 0.8622998319625854, 12.354310641479492, 0.4378, 0.1361]
[6.56276377105712

[1.7744902782440186, 0.5860330963134766, 2.9629474660873414, 0.6556, 0.5677]
[2.112378572273254, 0.6739541276931763, 3.5508030086517333, 0.6167, 0.4608]
[2.3005557231903078, 0.8658748126983643, 3.7352366451263426, 0.4802, 0.464]
[2.092632611656189, 0.8534998009681701, 3.3317654156684875, 0.4689, 0.549]
[2.1114181774139404, 0.5701677483558655, 3.652668605041504, 0.6839, 0.5097]
[1.8324682443618774, 0.6496458885192871, 3.0152905961990357, 0.6227, 0.6138]
[2.4891701164245608, 0.611570820236206, 4.366769414520264, 0.6202, 0.3988]
[2.3965236637115477, 0.6331191157341003, 4.159928211021423, 0.643, 0.5238]
[3.4161727054595947, 0.7775340481758117, 6.054811364746094, 0.503, 0.3345]
[2.473207534790039, 0.7539063844680786, 4.192508689117432, 0.5638, 0.4568]
[2.7185387947082518, 0.8629475640296936, 4.5741300308227535, 0.4788, 0.4063]
[2.5236544952392577, 0.7975617291450501, 4.249747270965576, 0.4747, 0.4366]
[2.6438565757751467, 0.6189908425331115, 4.66872230834961, 0.6288, 0.4268]
[2.450205332565

[0.9422626517295838, 1.164494453239441, 0.7200308458924294, 0.5268, 0.7919]
[0.33395760291665794, 0.04760699761789292, 0.6203082084655762, 0.9829, 0.842]
[0.3362341431826353, 0.07619663880616427, 0.5962716463536024, 0.9733, 0.8632]
[0.42933369869589805, 0.219424117384851, 0.6392432789737359, 0.9054, 0.87]
[0.48681901820898055, 0.40958174135684966, 0.564056295183301, 0.8178, 0.8717]
[0.32408938060812653, 0.037077117112092675, 0.6111016445791349, 0.9918, 0.876]
[0.5101758904516697, 0.3408222050189972, 0.6795295798503794, 0.8226, 0.8727]
[0.9608943796634674, 0.794165391778946, 1.1276233641505242, 0.6834, 0.755]
[0.5425233539342881, 0.21825981410741807, 0.866786894511804, 0.9095, 0.8326]
[0.46346196860671046, 0.21312224682569503, 0.7138016910135746, 0.9048, 0.8534]
[0.5856716544151306, 0.42508199331760405, 0.7462613170400262, 0.7656, 0.8394]
[0.6158110651493073, 0.3403493225812912, 0.8912728080809116, 0.8355, 0.8302]
[0.5100063983798027, 0.2882492258310318, 0.7317635701924562, 0.8815, 0.84

[1.3959858779907226, 0.707103707408905, 2.084868050670624, 0.6009, 0.6329]
[1.673159322166443, 0.7419014142036437, 2.6044172277450564, 0.5331, 0.5389]
[1.3737207359313965, 0.7183171666145325, 2.029124291038513, 0.5364, 0.6325]
[1.1016588409423829, 0.49313592228889463, 1.7101817632198333, 0.7282, 0.6627]
[1.6238709049224853, 0.4878020308494568, 2.7599397815704347, 0.7637, 0.5474]
[1.4178063041687012, 0.4519706088066101, 2.38364200630188, 0.794, 0.5678]
[1.3490315809249878, 0.5769094190597535, 2.121153745365143, 0.6873, 0.6404]
[1.4509976029396057, 0.5619345174789429, 2.34006068983078, 0.6819, 0.6147]
[1.4685921564102173, 0.6468318097114563, 2.290352499771118, 0.624, 0.5578]
[1.4822949854850769, 0.5582281790733338, 2.4063617911338806, 0.6861, 0.5507]
[1.180291506576538, 0.5576263501167298, 1.802956660079956, 0.694, 0.6907]
[1.4438987805366517, 0.5534797205448151, 2.334317840862274, 0.7096, 0.573]
[1.5383431617736816, 0.6124210678100586, 2.464265255355835, 0.6233, 0.5024]
[1.5163115855216

In [11]:
# Semi-supervised CNN runs over the same label budgets as the dense experiment.
cnn_ss_histories = {}
for num_examples in (100, 300, 500, 1000):
    print(num_examples)
    cnn_ss_histories[num_examples] = train_test_gan(
        cnn_gan_factory,
        x_train, y_train,
        x_test, y_test,
        epochs=50000,
        num_labels=num_examples,
        verbose=False,
    )

100


  'Discrepancy between trainable weights and collected trainable'


[1.4985946659088134, 0.6073497560501099, 2.3898395767211915, 1.0, 0.0982]
[1.2441082008361817, 0.5166885471343994, 1.9715278526306153, 1.0, 0.6793]
[0.9565770183563232, 0.1810712418794632, 1.732082797241211, 1.0, 0.4182]
[1.179902975654602, 0.24691066746711732, 2.112895280838013, 1.0, 0.2873]
[1.2106107084274291, 0.1988023513555527, 2.222419067764282, 1.0, 0.3451]
[1.4543880603790282, 0.35323369922637937, 2.555542420196533, 1.0, 0.2802]
[2.3444155162811278, 0.06679221549034119, 4.62203882522583, 1.0, 0.1963]
[1.7076219764709473, 0.04701990394592285, 3.3682240493774414, 1.0, 0.1751]
[1.877545306968689, 0.012696791546046735, 3.7423938285827636, 1.0, 0.1714]
[2.1517030979156493, 0.010064041477441788, 4.29334214630127, 1.0, 0.1384]
[2.335878858947754, 0.0393317468225956, 4.632425974273682, 1.0, 0.1329]
[2.3067614032745363, 0.024973936840891837, 4.5885488845825195, 1.0, 0.1121]
[1.1127429695129394, 0.014883412313461304, 2.210602536010742, 1.0, 0.4919]
[1.9304768577575684, 0.0108791349887847

[1.2688737897872924, 0.0638268291413784, 2.4739207523345947, 1.0, 0.5367]
[0.8779583947658539, 0.09455883738994598, 1.6613579594612122, 1.0, 0.6331]
[0.7702453996658325, 0.039910281187295914, 1.5005805151939393, 1.0, 0.6908]
[0.6661845278263092, 0.12240806167125702, 1.2099609944820404, 1.0, 0.7264]
[0.9853360743522644, 0.05637223070263862, 1.9142999151229858, 1.0, 0.6492]
[0.7592845076084137, 0.06540474535226821, 1.453164266872406, 1.0, 0.7154]
[1.6259967693328858, 0.0489689478635788, 3.203024576950073, 1.0, 0.4728]
[1.0608111229896546, 0.06744588174819946, 2.054176362991333, 1.0, 0.6189]
[1.1905700150489806, 0.1120780104637146, 2.2690620166778563, 1.0, 0.6188]
[1.5736238021850586, 0.058039085906744003, 3.089208522415161, 1.0, 0.5186]
[1.5371327152252197, 0.07255722458362579, 3.0017082048416137, 1.0, 0.5316]
[2.7427111679077147, 0.2093959886074066, 5.276026351928711, 0.9943, 0.3349]
[4.02992430267334, 0.07428400175571441, 7.985564608764649, 1.0, 0.1787]
[6.9639351997375485, 0.127867965

[1.2871027811050415, 0.001415241764485836, 2.572790320587158, 1.0, 0.2517]
[1.8164931203842163, 0.0007495769061148167, 3.632236665344238, 1.0, 0.2431]
[1.293234081840515, 0.002590798008069396, 2.5838773635864256, 1.0, 0.2274]
[1.023189059638977, 0.0008712671702727676, 2.045506849479675, 1.0, 0.3461]
[0.939436173248291, 0.0019733432162553073, 1.8768990087509154, 1.0, 0.4994]
[0.6958663699150085, 0.0030996551595628263, 1.3886330846786499, 1.0, 0.5558]
[1.072591451358795, 0.002138513045758009, 2.143044388961792, 1.0, 0.394]
[1.250443324661255, 0.016779351446032525, 2.4841072994232176, 1.0, 0.2892]
[1.2728209365844727, 0.0034901155434548855, 2.5421517574310304, 1.0, 0.3731]
[1.0103725501060485, 0.0038624424524605276, 2.016882661628723, 1.0, 0.5015]
[0.8361039336204529, 0.002254210114106536, 1.66995365524292, 1.0, 0.4896]
[1.0346546023368834, 0.007231213159114122, 2.062077992630005, 1.0, 0.4209]
[0.9803707631111145, 0.010250457563996314, 1.9504910682678223, 1.0, 0.4959]
[0.7826128762245178,

[6.812694115447998, 0.1212199536204338, 13.50416828918457, 0.999, 0.0219]
[6.503511608886718, 0.07559082233905792, 12.931432405090332, 0.9999, 0.0306]
[4.783632330322265, 0.13518796906471253, 9.432076692199708, 0.9989, 0.1158]
[4.363242832946777, 0.1153192017197609, 8.611166456604003, 0.9989, 0.133]
[1.6785003946304322, 0.11429309030771255, 3.242707699584961, 0.9995, 0.5229]
[4.674165459442139, 0.06323242827653885, 9.285098510742188, 1.0, 0.1582]
[6.933655207824707, 0.07009594838619232, 13.797214468383789, 1.0, 0.0301]
[6.77259642791748, 0.14597169264554977, 13.399221122741698, 0.9929, 0.0414]
[4.36242292175293, 0.0905540940284729, 8.634291764068603, 1.0, 0.1623]
[2.8544394149780272, 0.062316152334213254, 5.646562687683105, 0.9999, 0.3227]
[2.6890940113067625, 0.16479268453121185, 5.21339532623291, 0.9977, 0.4017]
[2.3766890441894533, 0.11033392279148102, 4.643044163131714, 1.0, 0.3926]
[7.002189405822754, 0.15128800625801087, 13.853090827941895, 0.9774, 0.0261]
[7.549118685150146, 0.2

[0.547460053062439, 0.09338602545261383, 1.001534078025818, 1.0, 0.7023]
[0.6244592350959778, 0.1846057774066925, 1.0643126929283142, 1.0, 0.648]
[0.4785594481945038, 0.043086139678955075, 0.9140327557563782, 1.0, 0.7146]
[0.7185782565116883, 0.04012835041284561, 1.3970281642913818, 1.0, 0.6128]
[0.42006851572990417, 0.08523921325206757, 0.7548978179454804, 1.0, 0.764]
[0.4492713203430176, 0.04853493835926056, 0.8500077037334443, 1.0, 0.731]
[0.6387534721374512, 0.10296841201782227, 1.1745385330200195, 1.0, 0.6613]
[0.629540514087677, 0.0445586770772934, 1.214522350883484, 1.0, 0.6645]
[1.1324786958694457, 0.0992227605342865, 2.16573462600708, 1.0, 0.4938]
[0.43769925515651703, 0.08736013892889023, 0.7880383702039718, 1.0, 0.7556]
[0.2218289255619049, 0.04676227584481239, 0.39689557352662086, 1.0, 0.8858]
[0.3447379789471626, 0.03680056080818176, 0.6526753993034363, 1.0, 0.8201]
[0.5929706004142761, 0.04277396156787872, 1.1431672391891479, 1.0, 0.7058]
[0.49501846766471863, 0.092802620

1000
[1.5140812139511108, 0.6372059030532837, 2.3909565280914307, 1.0, 0.1162]
[1.0164955419540405, 0.5880085718154907, 1.4449825128555298, 0.9165, 0.8345]
[0.8740831069946289, 0.375041801738739, 1.3731244142532348, 1.0, 0.652]
[1.2146071392059326, 0.38408376750946044, 2.0451305084228517, 1.0, 0.2531]
[0.9918223151206971, 0.24289398896694184, 1.7407506462097169, 1.0, 0.324]
[1.0608794764518739, 0.2888744164943695, 1.8328845401763916, 1.0, 0.2664]
[0.9910440117835998, 0.2640333550453186, 1.718054666519165, 1.0, 0.3099]
[1.2784117855072021, 0.15555117945671082, 2.4012723918914793, 1.0, 0.2321]
[1.2474618350982667, 0.12767769827842712, 2.367245969390869, 1.0, 0.2364]
[1.2132732847213745, 0.0906486330628395, 2.3358979331970215, 1.0, 0.2564]
[0.7865627235412598, 0.09003275294303895, 1.48309269657135, 1.0, 0.4746]
[0.8269599536895752, 0.14154992685317994, 1.5123699779510498, 1.0, 0.4409]
[0.6408173104286193, 0.03669128397107124, 1.2449433336257933, 1.0, 0.4192]
[0.485282773065567, 0.03415838

[0.23521900000572205, 0.21849003794193267, 0.25194796276101844, 0.9967, 0.9362]
[0.274181675362587, 0.21726124222278595, 0.33110210888162256, 0.9954, 0.9158]
[0.2416612304329872, 0.15555986285209655, 0.3277625973193906, 0.9997, 0.9251]
[0.20258528941869736, 0.14788392207622528, 0.25728665644340215, 1.0, 0.9398]
[0.2446005278944969, 0.19606945247650145, 0.2931316035006195, 0.999, 0.9335]
[0.2941709589242935, 0.24582966387271882, 0.34251225348077713, 0.994, 0.9265]
[0.21853250337839128, 0.18144098439216613, 0.25562402309179305, 1.0, 0.935]
[0.29446388614177704, 0.2878195671081543, 0.3011082047978416, 0.9891, 0.9302]
[0.3060439649820328, 0.35993627099990844, 0.25215165859324623, 0.9691, 0.944]
[0.272543869805336, 0.21867931916713715, 0.326408419981692, 0.9956, 0.931]
[0.305760640501976, 0.32819320816993713, 0.28332807217927186, 0.9649, 0.9397]
[0.2754109127521515, 0.2374866891860962, 0.31333513546930625, 0.9948, 0.9358]
[0.38382756958007813, 0.45446909828186033, 0.3131860410379013, 0.8763

In [12]:
import pickle

# Persist every run's history so later analysis does not need to repeat the 50k-iteration trainings.
vanilla_gan_results = dict(
    dense_histories=dense_histories,
    cnn_histories=cnn_histories,
    dense_ss_histories=dense_ss_histories,
    cnn_ss_histories=cnn_ss_histories,
)
with open('vanilla_gan_results.pkl', 'wb') as results_file:
    pickle.dump(vanilla_gan_results, results_file)