In [None]:
from __future__ import absolute_import, division, print_function, unicode_literals

import os

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import PercentFormatter

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, models

In [None]:
from ipynb.fs.full.LIGO_data_simulation_functions import noise, wave, dataset, norm

In [None]:
# Simulation settings shared by every cell below.
fs = 4096            # sampling frequency (presumably Hz)
dt = 1 / fs          # spacing between consecutive samples, the reciprocal of fs
data_size = fs * 5   # points per simulated series: 5 * fs (i.e. 5 s if fs is in Hz)
sample_size = 10000  # number of simulated series per dataset split

In [None]:
# Build the training and test splits with two separate calls to `dataset` using the same
# settings, then normalise each split with `norm` (both helpers come from
# LIGO_data_simulation_functions). Call order is train-then-test for both steps.
train_data, train_labels = dataset(sample_size, data_size)
test_data, test_labels = dataset(sample_size, data_size)
train_data, test_data = norm(train_data), norm(test_data)

In [None]:
# Unpack the label matrices into one array per target column:
# column 0 -> chirp mass, column 1 -> d (presumably distance), column 2 -> SNR.
train_chirp_m, train_d, train_snr = train_labels[:, 0], train_labels[:, 1], train_labels[:, 2]
test_chirp_m, test_d, test_snr = test_labels[:, 0], test_labels[:, 1], test_labels[:, 2]

In [None]:
# Keep the first two label columns (chirp mass and d) together as a two-column target.
train_pm, test_pm = train_labels[:, 0:2], test_labels[:, 0:2]

In [None]:
# Cache the simulated arrays so a later session can reload them instead of re-simulating.
# Each array goes into its own .npz file, stored under a key equal to its variable name.
data_dir = "/Users/jkliao117/Desktop/simulated data/pm_new"
# np.savez does not create missing parent directories and would raise; create it first.
os.makedirs(data_dir, exist_ok=True)
np.savez(os.path.join(data_dir, "train data.npz"), train_data = train_data)
np.savez(os.path.join(data_dir, "train labels.npz"), train_labels = train_labels)
np.savez(os.path.join(data_dir, "test data.npz"), test_data = test_data)
np.savez(os.path.join(data_dir, "test labels.npz"), test_labels = test_labels)

In [None]:
# Reload the cached arrays written by the save cell (lets a fresh kernel skip simulation).
data_dir = "/Users/jkliao117/Desktop/simulated data/pm_new"
# np.load on an .npz returns an NpzFile that keeps the file open; use it as a context
# manager so each handle is closed once its array has been read into memory.
with np.load(os.path.join(data_dir, "train data.npz")) as archive:
    train_data = archive['train_data']
with np.load(os.path.join(data_dir, "train labels.npz")) as archive:
    train_labels = archive['train_labels']
with np.load(os.path.join(data_dir, "test data.npz")) as archive:
    test_data = archive['test_data']
with np.load(os.path.join(data_dir, "test labels.npz")) as archive:
    test_labels = archive['test_labels']

# The per-target label arrays were sliced from the label matrices *before* this reload.
# Re-derive them here so downstream cells use exactly the arrays just loaded, and so
# this cell works in a fresh kernel that skipped the simulation cells.
train_chirp_m, train_d, train_snr = train_labels[:, 0], train_labels[:, 1], train_labels[:, 2]
test_chirp_m, test_d, test_snr = test_labels[:, 0], test_labels[:, 1], test_labels[:, 2]
train_pm, test_pm = train_labels[:, :2], test_labels[:, :2]

In [None]:
# 1-D convolutional regressor mapping one time series to a single scalar.
model = models.Sequential(
    [
        # Append a channel axis so the Conv1D layers receive (timesteps, 1).
        layers.Reshape((train_data.shape[-1], 1), input_shape=[train_data.shape[-1]]),
        # Two feature stages: pointwise (kernel size 1) Conv1D, 4x max-pooling, then a
        # Dense layer applied along the last (channel) axis.
        layers.Conv1D(16, 1, activation='relu'),
        layers.MaxPooling1D(4),
        layers.Dense(16, activation='relu'),
        layers.Conv1D(8, 1, activation='relu'),
        layers.MaxPooling1D(4),
        layers.Dense(8, activation='relu'),
        layers.Flatten(),
    ]
    # Fully-connected head: for each width, a ReLU layer followed by a linear layer.
    + [
        layers.Dense(width, activation=act)
        for width in (256, 128, 64, 32)
        for act in ('relu', 'linear')
    ]
    # Single output unit; the ReLU keeps the prediction non-negative.
    + [layers.Dense(1, activation='relu')]
)
model.compile(
    optimizer='adam',
    loss='mean_squared_error',
    metrics=['mean_squared_error'],
)
model.summary()

In [None]:
# Train the regressor on chirp mass alone for 30 epochs. The test split is passed as
# validation data, so its loss/metric are reported after every epoch.
history = model.fit(
    train_data,
    train_chirp_m,
    epochs=30,
    validation_data=(test_data, test_chirp_m),
)

In [None]:
# Persist the trained weights; the load_weights cell below restores from this same path.
model.save_weights("/Users/jkliao117/Desktop/simulated data/pm_new/new_mass_model_weights.ckpt")

In [None]:
# Restore the weights written by the save_weights cell (same checkpoint path), e.g. in a
# fresh kernel after re-running the model-definition cell.
model.load_weights(
    filepath="/Users/jkliao117/Desktop/simulated data/pm_new/new_mass_model_weights.ckpt"
)

In [None]:
# Score the model on the test split. `evaluate` returns [loss, *metrics]; the model was
# compiled with mean squared error as both loss and sole metric, so the second value is
# an MSE, not an accuracy.
test_loss, test_mse = model.evaluate(test_data, test_chirp_m, verbose=2)
test_acc = test_mse  # backward-compatible alias for later cells; this is MSE, not accuracy

In [None]:
# Flatten the model's (N, 1) output to shape (N,). Using -1 instead of `sample_size`
# keeps this working when test_data was reloaded from disk or generated with a
# different sample_size.
predictions = np.reshape(model.predict(test_data), (-1,))
# Residuals: label minus prediction (positive means the model under-predicts).
difference = test_chirp_m-predictions

In [None]:
# Predicted vs. true chirp mass on the test split; the red line marks perfect prediction (y = x).
plt.figure(figsize=(10,8))
# 0..100 inclusive so the y = x reference line spans the whole plotted range.
x = np.arange(0,101,1)
plt.scatter(test_chirp_m,predictions,alpha=0.25)
plt.plot(x,x,'r',alpha=0.5)
# Fixed 0-100 axes: any point with a value above 100 is outside the visible area.
plt.xlim(0,100)
plt.ylim(0,100)
plt.grid()
# Raw strings: "\o" in a normal literal is an invalid escape sequence (warns on newer Pythons).
plt.xlabel(r"true chirp mass ($M_{\odot}$)",fontsize=35)
plt.ylabel(r"predicted chirp mass ($M_{\odot}$)",fontsize=35)
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
plt.tight_layout()
#plt.savefig('/Users/jkliao117/Desktop/prediction vs label', transparent=True)

In [None]:
# Histogram of residuals (label - prediction) over the test split, as a percentage of samples.
plt.figure(figsize=(12.5,8))
# Weight every sample by 1/N so bar heights are fractions of the test set and sum to 1.
# (density=True produced a probability *density* -- fraction per unit of x -- which
# PercentFormatter(xmax=1) then misrepresented as a percentage of samples.)
plt.hist(difference,bins=50,weights=np.ones_like(difference)/len(difference))
plt.xlabel("difference between prediction and label",fontsize=35)
plt.ylabel("percentage of samples",fontsize=35)
plt.xticks(fontsize=30)
plt.yticks(fontsize=30)
plt.locator_params(axis='x',nbins=8)
plt.locator_params(axis='y',nbins=8)
# Render fractional heights (1.0 == all samples) as percentages.
plt.gca().yaxis.set_major_formatter(PercentFormatter(xmax=1))

In [None]:
# Mean relative error of the chirp-mass predictions: average of |label - prediction| / label.
chirp_m_error = (np.abs(difference) / test_chirp_m).mean()
print(chirp_m_error)

In [None]:
plt.figure(figsize=(10,8))
# Select only the low-SNR test samples (SNR below 10).
mask = (test_snr<10)
# Absolute chirp-mass error against SNR for those samples; low alpha makes overlapping points translucent.
plt.scatter(test_snr[mask],np.abs(difference)[mask],alpha=0.1)