# This notebook encodes light curve segments into shape and intensity features.

Requirements: segmented light curves and a trained VAE-LSTM network (see notebook 01).

## Edit the paths below as required and run all cells to encode the segments

In [1]:
# --- Inputs -------------------------------------------------------------
# Trained VAE-LSTM weights file (.h5); see notebook 01.
weights_dir = "../models/VAE_weights/model_2021-09-12_17-45-06.h5"
# Folder holding the pickled train/valid/test light curve segments.
segment_data_dir = "../data/segments/"

# --- Outputs ------------------------------------------------------------
# Folder the encoded segments are written to.
encodings_dir = "../data/encodings/"

In [2]:
%load_ext autoreload
%autoreload 2
from IPython.display import clear_output
import os
import fnmatch
import numpy as np
import pickle
import matplotlib.pyplot as plt
import umap
from sklearn.mixture import GaussianMixture
from scipy import stats
# from sklearn.cluster import OPTICS

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.backend import mean
from tensorflow.keras.backend import square
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import CuDNNLSTM #CuDNNLSTM
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import RepeatVector
from tensorflow.keras.layers import TimeDistributed
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Flatten

from tensorflow.keras.utils import Sequence
from tensorflow.keras import Input
from tensorflow.keras import Model
# from tensorflow.keras.layers import BatchNormalization
# from tensorflow.keras.layers import Conv1D
from scipy.stats import zscore

# Reset matplotlib to its defaults FIRST, then apply this notebook's overrides.
# (Previously the reset ran last and silently discarded the two overrides.)
plt.rcdefaults()
plt.rcParams['figure.figsize'] = (5.0, 5.0)
plt.rcParams.update({'font.size': 12})

# Seed NumPy's global RNG so NumPy-based randomness is reproducible.
np.random.seed(seed=11)

cwd = os.getcwd()

examples.directory is deprecated; in the future, examples will be found relative to the 'datapath' directory.
  "found relative to the 'datapath' directory.".format(key))
examples.directory is deprecated; in the future, examples will be found relative to the 'datapath' directory.
  "found relative to the 'datapath' directory.".format(key))
The text.latex.unicode rcparam was deprecated in Matplotlib 2.2 and will be removed in 3.1.
  "2.2", name=key, obj_type="rcparam", addendum=addendum)


# Load the VAE-LSTM model

In [3]:
class Sampling(layers.Layer):
    """Reparameterisation-trick layer: returns mean + exp(0.5 * log_var) * eps, with eps ~ N(0, 1).

    Adapted from
    https://www.tensorflow.org/guide/keras/custom_layers_and_models#putting_it_all_together_an_end-to-end_example"""

    def call(self, inputs):
        """Draw one latent sample per row from the (mean, log-variance) pair in ``inputs``."""
        mean, log_var = inputs
        mean_shape = tf.shape(mean)
        # Standard-normal noise with the same (batch, dim) shape as the mean.
        noise = tf.keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))
        return mean + tf.exp(0.5 * log_var) * noise


# Architecture hyper-parameters; they must match the network whose weights are loaded below.
original_dim = 128       # length of each input segment (time steps), one value per step
intermediate_dim = 1024  # number of units in the encoder and decoder LSTMs
latent_dim = 20          # dimensionality of the latent space

# Define encoder model.
# The LSTM summarises the whole sequence into its final output (return_sequences=False);
# two Dense heads then predict the latent mean and log-variance, and Sampling draws z.
# The layer names 'z_mean' and 'z_log_var' are looked up via get_layer() in encode_data.
# NOTE: CuDNNLSTM is the cuDNN-backed LSTM implementation and requires a GPU.
original_inputs = tf.keras.Input(shape=(original_dim,1), name='encoder_input')
# Second model input for the per-step errors. It is not connected to any layer here;
# presumably it was consumed by the training loss in notebook 01 -- TODO confirm.
input_err = Input(shape=(original_dim,1))
x = layers.CuDNNLSTM(intermediate_dim, return_sequences=False)(original_inputs)
z_mean = layers.Dense(latent_dim, name='z_mean')(x)
z_log_var = layers.Dense(latent_dim, name='z_log_var')(x)
z = Sampling()((z_mean, z_log_var))
encoder = tf.keras.Model(inputs=original_inputs, outputs=z, name='encoder')

# Define decoder model.
# The latent vector is repeated once per time step, decoded by an LSTM that returns the
# full sequence, and each step is mapped back to a single value by a shared Dense layer.
latent_inputs = tf.keras.Input(shape=(latent_dim,), name='z_sampling')
x = layers.RepeatVector(original_dim)(latent_inputs)
x = layers.CuDNNLSTM(intermediate_dim, return_sequences=True)(x)
outputs = layers.TimeDistributed(layers.Dense(1))(x)
decoder = tf.keras.Model(inputs=latent_inputs, outputs=outputs, name='decoder')

# Define VAE model.
# The full VAE takes [segments, errors] as inputs and decodes the sampled z.
outputs = decoder(z)
vae = tf.keras.Model(inputs=[original_inputs, input_err], outputs=outputs, name='vae')

# By default (by_name=False) Keras matches .h5 weights to layers by their order in the
# model, so the layers above must be created exactly as during training -- do not reorder.
vae.load_weights(weights_dir)

# Use the VAE-LSTM encoder to extract features from the L=128 light curve segments

In [3]:
# Load the train / validation / test splits of the segmented light curves.
# NOTE(review): pickle.load can execute arbitrary code -- only load segment files you produced yourself.
def load_segments(split, data_dir):
    """Unpickle and return one split ("train", "valid" or "test") of the segmented light curves in ``data_dir``."""
    file_name = "segments_1024s_256stride_0125cad_segmented_to64_{}.pkl".format(split)
    # os.path.join copes with data_dir given with or without a trailing slash.
    with open(os.path.join(data_dir, file_name), 'rb') as f:
        return pickle.load(f)


train_data = load_segments("train", segment_data_dir)
valid_data = load_segments("valid", segment_data_dir)
test_data = load_segments("test", segment_data_dir)

In [5]:

def _prepare_segments(segment_data):
    """
    Convert raw segment records into standardized encoder inputs.

    Each record is indexed as ``seg[2][k]``: k=0 presumably holds the time stamps (used
    to derive the cadence), k=1 the counts and k=2 their errors -- TODO confirm against
    the segmentation notebook.

    Returns
    -------
    counts : np.ndarray, float32, shape (N, L, 1)
        Count rates, z-scored independently within each segment.
    errors : np.ndarray, float32, shape (N, L, 1)
        Errors with zeros replaced by a small positive value, divided by the
        per-segment standard deviation of the count rates.

    Raises
    ------
    ValueError
        If every error value is zero (no replacement value can be derived).
    """
    # Smallest time step over all segments; dividing by it turns counts into count rates.
    cadence = np.min(np.diff([seg[2][0] for seg in segment_data]))

    counts = np.vstack([seg[2][1] for seg in segment_data]) / cadence
    counts = np.expand_dims(counts, axis=-1)

    errors = np.vstack([seg[2][2] for seg in segment_data])
    # Error values must be non-zero: replace zeros with a tenth of the smallest non-zero error.
    nonzero = errors != 0
    if not nonzero.any():
        raise ValueError("all error values are zero; cannot derive a replacement for zero errors")
    errors[~nonzero] = np.min(errors[nonzero]) / 10
    errors = np.expand_dims(errors, axis=-1)

    # Standardize per segment (ddof=0 for both np.std and zscore).
    # NOTE(review): counts were divided by the cadence but the errors were not; if seg[2][2]
    # holds count (not rate) errors the units are mixed here -- confirm. For the VAE built in
    # this notebook the error input is not connected to z_mean/z_log_var, so encodings are
    # unaffected either way.
    segment_std = np.std(counts, axis=1, keepdims=True)
    errors = (errors / segment_std).astype(np.float32)
    counts = zscore(counts, axis=1).astype(np.float32)
    return counts, errors


def encode_data(segment_data, model, save_encoding=False, save_file_dir=None, batch_size=256):
    """
    Encode segmented light curves with the encoder half of a trained VAE-LSTM.

    Parameters
    ----------
    segment_data : sequence
        Segment records; ``seg[0]`` is the segment id, ``seg[1]`` its label and
        ``seg[2]`` the per-step arrays (presumably time, counts, errors).
    model : tf.keras.Model
        Trained VAE whose ``model.input`` is ``[segments, errors]`` and which contains
        layers named ``"z_mean"`` and ``"z_log_var"``.
    save_encoding : bool
        If true and ``save_file_dir`` is given, pickle the results instead of returning them.
    save_file_dir : str or None
        Output file path (despite the name, a file, not a directory). Missing parent
        directories are created.
    batch_size : int
        Number of segments passed to the encoder per ``predict`` call.

    Returns
    -------
    tuple or None
        ``(seg_ids, seg_labels, segment_encoding)`` where ``segment_encoding`` is a float64
        array of shape (N, 2, latent_dim): ``[:, 0]`` holds z_mean and ``[:, 1]`` holds
        z_log_var. Returns None when the results are written to disk instead.

    Raises
    ------
    ValueError
        If ``segment_data`` is empty, ``batch_size`` is not positive, or all errors are zero.
    """
    if len(segment_data) == 0:
        raise ValueError("segment_data is empty; nothing to encode")
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")

    segments, errors = _prepare_segments(segment_data)
    seg_ids = [seg[0] for seg in segment_data]
    seg_labels = [seg[1] for seg in segment_data]

    # Sub-model sharing the VAE's weights that exposes the deterministic latent
    # parameters instead of the randomly sampled z.
    trained_encoder = tf.keras.Model(
        inputs=model.input,
        outputs=[model.get_layer("z_mean").output, model.get_layer("z_log_var").output])

    # Encode in chunks rather than one predict() call per segment: per-call overhead
    # dominated the runtime for hundreds of thousands of segments.
    n_segments = segments.shape[0]
    mean_chunks, log_var_chunks = [], []
    for start in range(0, n_segments, batch_size):
        stop = min(start + batch_size, n_segments)
        batch_mean, batch_log_var = trained_encoder.predict(
            [segments[start:stop], errors[start:stop]], batch_size=batch_size)
        mean_chunks.append(np.reshape(batch_mean, (stop - start, -1)))
        log_var_chunks.append(np.reshape(batch_log_var, (stop - start, -1)))
        clear_output(wait=True)
        print("Encoded {}/{} segments".format(stop, n_segments))

    # (N, 2, latent_dim); float64 keeps the dtype of previously saved encodings.
    segment_encoding = np.stack(
        [np.concatenate(mean_chunks), np.concatenate(log_var_chunks)], axis=1).astype(np.float64)

    if save_encoding and save_file_dir is not None:
        save_dir = os.path.dirname(save_file_dir)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        with open(save_file_dir, 'wb') as f:
            pickle.dump((seg_ids, seg_labels, segment_encoding), f)
        print("Encodings and metadata saved to: ", save_file_dir)
        return None

    return (seg_ids, seg_labels, segment_encoding)

In [6]:
# Encode the training split and pickle (ids, labels, encodings) into encodings_dir.
encode_data(train_data, vae, save_encoding=True,
            save_file_dir=os.path.join(encodings_dir, 'segments_1024s_256stride_0125cad_segmented_to64_encoded_train.pkl'))

Encoded 483584/483584 segments
Encodings and metadata saved to:  ../../../data_GRS1915/segments_1024s_256stride_0125cad_segmented_to64_encoded_train.pkl


In [7]:
# Encode the validation split and pickle (ids, labels, encodings) into encodings_dir.
encode_data(valid_data, vae, save_encoding=True,
            save_file_dir=os.path.join(encodings_dir, 'segments_1024s_256stride_0125cad_segmented_to64_encoded_valid.pkl'))

Encoded 185984/185984 segments
Encodings and metadata saved to:  ../../../data_GRS1915/segments_1024s_256stride_0125cad_segmented_to64_encoded_valid.pkl


In [8]:
# Encode the test split and pickle (ids, labels, encodings) into encodings_dir.
encode_data(test_data, vae, save_encoding=True,
            save_file_dir=os.path.join(encodings_dir, 'segments_1024s_256stride_0125cad_segmented_to64_encoded_test.pkl'))

Encoded 36224/36224 segments
Encodings and metadata saved to:  ../../../data_GRS1915/segments_1024s_256stride_0125cad_segmented_to64_encoded_test.pkl
