In [1]:
# Mount Google Drive at /content/drive; later cells read their inputs from, and
# write their outputs to, paths under drive/MyDrive/...
from google.colab import drive

_mountpoint = '/content/drive'
drive.mount(_mountpoint)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import sys
sys.path.append('/content/drive/MyDrive/GWPAC/MSG-cGAN/Ours/classes')
!pip install tensorflow-addons



In [3]:
import tensorflow as tf 
import dataset as ds
import msggan as gan
from pathlib import Path
import zipfile
import csv

In [4]:
# Run configuration: hyperparameters and data paths used by the cells below.
# (Some of these could eventually move inside the dataset/GAN classes.)

# -- run control --
loadweights = True    # resume from saved discriminator/generator weights before training
usebias     = False   # forwarded to MSG_CcGAN_ts as usebias

# -- model / optimiser hyperparameters --
nrns  = 128     # forwarded to MSG_CcGAN_ts as neurs
glr   = 1e-2    # generator learning rate (g_lr)
dlr   = 1e-2    # discriminator learning rate (d_lr)
r1g   = 1e-1    # forwarded as r1_gamma
epsln = 1e-3    # forwarded as epsilon
ksize = 9       # forwarded as ksize

# -- data --
res   = 8192    # forwarded as endres to both the Dataset and the GAN
bsize = 32      # batch size

# Relative paths: assumes the working directory is /content (Colab default),
# where Drive is mounted — TODO confirm if running elsewhere.
_ins_dir = 'drive/MyDrive/GWPAC/MSG-cGAN/Ours/ins/time_series/has_postmergers'
zippath  = f'{_ins_dir}/sims.zip'
metadata = f'{_ins_dir}/METADATA.csv'

In [5]:
# Extract all (~16 GB) of the simulations onto the runtime's local disk up front;
# reading from local disk is much faster than loading each file from Drive just-in-time.
# The `with` block closes the archive even if extraction fails part-way
# (e.g. the runtime runs out of disk space).
print('unzipping files in runtime environment. this will take ~10 minutes.')
with zipfile.ZipFile(zippath, 'r') as sims:
    sims.extractall()  # extracts into the current working directory
print('...done.')

unzipping files in runtime environment. this will take ~10 minutes.
...done.


In [None]:
# Paths for this run. Outputs go to the "_c" directory; when resuming, weights are
# read from the "_b" run with the same hyperparameters.
pathway  = 'sims'   # presumably the top-level directory of the extracted sims.zip — confirm
runtag   = f'tseq_labels_{res}d_{nrns}n_{glr}_{dlr}_{r1g}r_{epsln}e_{ksize}k'
outpath  = f'drive/MyDrive/GWPAC/MSG-cGAN/Ours/outs/{runtag}_c'
weights  = f'drive/MyDrive/GWPAC/MSG-cGAN/Ours/outs/{runtag}_b/weights'
nepochs  = 500000   # passed to MSG_CcGAN_ts as epochs
ckptstep = 34000    # checkpoint of the "_b" run to load when loadweights is set

# Save the run configuration to attrs.csv, in case we want to look it up later.
Path(outpath).mkdir(exist_ok=True)
with open(Path(outpath) / 'attrs.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    writer.writerows([
        ['loadweights', loadweights],
        ['neurons', nrns],
        ['g learning rate', glr],
        ['d learning rate', dlr],
        ['r1gamma', r1g],
        ['epsilon', epsln],
        ['resolution', res],
        ['batch size', bsize],
        # also record settings that shape the model but were previously not saved
        ['kernel size', ksize],
        ['use bias', usebias],
        ['epochs', nepochs],
    ])
    if loadweights:
        writer.writerow(['weights source', weights])
        writer.writerow(['checkpoint', ckptstep])

# Load the dataset.
dataset = ds.Dataset(batch_size=bsize, pathway=pathway, mdpath=metadata, outpath=outpath, nlabels=4, endres=res)

# Build the GAN.
msgg = gan.MSG_CcGAN_ts(neurs=nrns, endres=res, g_lr=glr, d_lr=dlr, r1_gamma=r1g, epsilon=epsln,
                        outpath=outpath, epochs=nepochs, nchannels=2, ksize=ksize, usebias=usebias)

# Continue from the previous save point?
if loadweights:
    msgg.D.load_weights(f'{weights}/discriminator{ckptstep}.h5')
    msgg.G.load_weights(f'{weights}/generator{ckptstep}.h5')

# Start the training loop...
msgg.train(dataset)

# ...or, instead of training, generate sample waveforms:
# msgg.makewaves(dataset)

Collecting objects for use...
Total objects found:  50141
Model: "Generator"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 noise (InputLayer)             [(None, 124)]        0           []                               
                                                                                                  
 labels (InputLayer)            [(None, 4)]          0           []                               
                                                                                                  
 concatenate (Concatenate)      (None, 128)          0           ['noise[0][0]',                  
                                                                  'labels[0][0]']                 
                                                                                                  
 reshape (Reshape)              