In [None]:
%load_ext autoreload
%autoreload 2
from IPython.display import clear_output
import os
import fnmatch
import numpy as np
import pickle
import matplotlib.pyplot as plt
import umap
from sklearn.mixture import GaussianMixture
from scipy import stats
# from sklearn.cluster import OPTICS

# import tensorflow as tf
# from tensorflow.keras import layers
# from tensorflow.keras.backend import mean
# from tensorflow.keras.backend import square
# from tensorflow.keras.models import Sequential
# from tensorflow.keras.layers import CuDNNLSTM
# from tensorflow.keras.layers import Dense
# from tensorflow.keras.layers import RepeatVector
# from tensorflow.keras.layers import TimeDistributed
# from tensorflow.keras.callbacks import EarlyStopping
# from tensorflow.keras.callbacks import ModelCheckpoint
# from tensorflow.keras.layers import Flatten

# from tensorflow.keras.utils import Sequence
# from tensorflow.keras import Input
# from tensorflow.keras import Model
# from tensorflow.keras.layers import BatchNormalization
# from tensorflow.keras.layers import Conv1D
from scipy.stats import zscore

# Start from matplotlib's defaults and only then apply the notebook-wide
# overrides; resetting to the defaults afterwards would discard them again.
plt.rcParams.update(plt.rcParamsDefault)
plt.rcParams['figure.figsize'] = (5.0, 5.0)
plt.rcParams.update({'font.size': 6})

# Fixed seed for NumPy's global random number generator.
np.random.seed(seed=11)


cwd = os.getcwd()

# Pick the data directory from the first component of the working directory:
# anything under /export reads from "files_from_snuffy", everything else from
# "data_GRS1915". The slice (instead of [1]) avoids an IndexError when cwd
# contains no "/" at all.
if cwd.split("/")[1:2] == ["export"]:
    data_dir = "../../../files_from_snuffy"
else:
    data_dir = "../../../data_GRS1915"


In [None]:
import pandas as pd

# Start time of every light-curve segment (the original note suggests JMD,
# but marks the unit as uncertain).
seg_times_path = "{}/468202_len128_s2_4cad_start_times_errorfix.pkl".format(data_dir)
with open(seg_times_path, 'rb') as seg_times_file:
    seg_times = pickle.load(seg_times_file)

In [None]:
# IDs of the light-curve segments, read from the pickled ID list.
seg_ids_path = '{}/468202_len128_s2_4cad_ids_errorfix.pkl'.format(data_dir)
with open(seg_ids_path, 'rb') as seg_ids_file:
    seg_ids = pickle.load(seg_ids_file)

In [None]:
# Per-segment labels stored in shape_moments_GMM122_labels.pkl.
# NOTE(review): the previous comment called this "output of LSTM autoencoder's
# decoder"; the file name suggests Gaussian-mixture component labels - confirm.
gmm122_labels_path = "{}/shape_moments_GMM122_labels.pkl".format(data_dir)
with open(gmm122_labels_path, 'rb') as labels_file:
    shape_moments_GMM122_labels = pickle.load(labels_file)

In [None]:
# Load the (shuffled) list of unique IDs of the 468202 light-curve segments.
# Each ID has the form observationID_segmentIndex,
# e.g. ['96701-01-48-00_3', '20402-01-02-02_122', '70703-01-01-000_1420', ...]
seg_ids_path = '{}/468202_len128_s2_4cad_ids_errorfix.pkl'.format(data_dir)
with open(seg_ids_path, 'rb') as seg_ids_file:
    seg_ids = pickle.load(seg_ids_file)

# Observation ID of every segment: everything before the first "_", so the
# within-observation segment index is dropped and observation IDs repeat,
# e.g. ['96701-01-48-00', '20402-01-02-02', '70703-01-01-000', ...]
seg_ObIDs = [seg_id.partition("_")[0] for seg_id in seg_ids]

In [None]:
# Group segment positions by observation: for every observation ID, collect the
# positions in seg_ObIDs at which it occurs,
# e.g. ObID_SegIndices_dict == {'10258-01-01-00': [916, 949, 1046, ..., 467528, 467578], ...}
ObID_SegIndices_dict = dict((unique_ObID, []) for unique_ObID in np.unique(seg_ObIDs))
for position in range(len(seg_ObIDs)):
    ObID_SegIndices_dict.setdefault(seg_ObIDs[position], []).append(position)

In [None]:
# Same grouping as ObID_SegIndices_dict, but holding the label of each segment
# (looked up in shape_moments_GMM122_labels) instead of its position,
# e.g. ObID_GaussComps_dict_122 == {'10258-01-01-00': [<label>, <label>, ...], ...}
ObID_GaussComps_dict_122 = {
    ObID: [shape_moments_GMM122_labels[ind] for ind in Indices]
    for ObID, Indices in ObID_SegIndices_dict.items()
}

In [None]:
# Zero-initialised integer table with one row per observation ID and one column
# per Gaussian component (0..121). It is meant to hold the number of segments of
# each observation in each component; this cell only allocates it.
#
# The width is taken from the fixed component count rather than from the number
# of distinct labels present, so the frame can still be built when a component
# has no segments assigned to it (the two used to have to agree by coincidence).
n_components_122 = 122
unique_ObIDs = np.unique(seg_ObIDs)
obs_component_counts_df_122 = pd.DataFrame(
    np.zeros((len(unique_ObIDs), n_components_122)),
    index=unique_ObIDs,
    columns=list(range(n_components_122)),
    dtype=int,
)

In [None]:
# Segment start times re-ordered to follow seg_ids. Defined here because this
# cell is the first to need it; it was previously only created further down the
# notebook, so a top-to-bottom run failed with a NameError.
seg_times_ids_order = seg_times.loc[seg_ids]

# For every observation, the first-column value (start time) of each of its
# segments, in the same order as ObID_SegIndices_dict[ObID].
ObID_seg_times_dict = {}
for ObID, Indices in ObID_SegIndices_dict.items():
    ObID_seg_times_dict[ObID] = [seg_times_ids_order.iloc[ind,0] for ind in Indices]

In [None]:
# Per-observation labels re-ordered by ascending segment start time.
ObID_GaussComps_dict_122_chrono = {
    ObID: np.take(ObID_GaussComps_dict_122[ObID], np.argsort(times))
    for ObID, times in ObID_seg_times_dict.items()
}

In [None]:
# Persist the time-ordered per-observation labels next to the input data.
chrono_dict_path = "{}/ObID_GaussComps_dict_122_chrono.pkl".format(data_dir)
with open(chrono_dict_path, 'wb') as chrono_file:
    pickle.dump(ObID_GaussComps_dict_122_chrono, chrono_file)

In [None]:
# Display the labels of one example observation (order follows ObID_SegIndices_dict).
ObID_GaussComps_dict_122['10258-01-01-00']

In [None]:
# Display the whole time-ordered dictionary.
# NOTE(review): this prints every observation; consider showing a single key instead.
ObID_GaussComps_dict_122_chrono

In [None]:
# Sanity check: the start times of one observation in ascending order.
plt.plot(np.sort(ObID_seg_times_dict['10258-01-01-00']))
plt.show()

In [None]:
# Re-order the start-time table so that its rows follow the order of seg_ids.
# NOTE(review): an earlier cell (ObID_seg_times_dict) already reads this name,
# so on a fresh top-to-bottom run this definition comes too late.
seg_times_ids_order = seg_times.loc[seg_ids]

In [None]:
# Display the re-ordered start-time table.
seg_times_ids_order

In [None]:
# Paths of the pickled inputs (despite the "_dir" suffix these are file paths).
segments_dir = '{}/468202_len128_s2_4cad_counts_errorfix.pkl'.format(data_dir)
errors_dir = '{}/468202_len128_s2_4cad_errors_errorfix.pkl'.format(data_dir)
recos_dir = "{}/reconstructions_from_model_2020-08-30_11-42-38.pkl".format(data_dir)

with open(segments_dir, 'rb') as segments_file:
    segments = pickle.load(segments_file)
with open(errors_dir, 'rb') as errors_file:
    errors = pickle.load(errors_file)
with open(recos_dir, 'rb') as recos_file:
    recos = pickle.load(recos_file)

# Undo the per-segment standardisation of the reconstructions: scale by each
# segment's standard deviation and shift by its mean (both taken along axis 1).
# NOTE(review): relies on NumPy broadcasting between recos and the axis-1
# statistics of segments - the array shapes are not visible here, confirm.
segment_std = np.std(segments, axis=1)
segment_mean = np.mean(segments, axis=1)
recos = recos*segment_std + segment_mean

In [None]:
# Load the observation classifications (Huppenkothen 2017, per the original note).
# The context manager closes the file; it used to be opened and never closed.
with open('{}/1915Belloniclass_updated.dat'.format(data_dir)) as clean_belloni:
    lines = clean_belloni.readlines()

# The first line lists the state names; every following line lists the
# observation IDs of the state at the same position.
states = lines[0].split()
belloni_clean = {state_name: obs_line.split() for state_name, obs_line in zip(states, lines[1:])}

# Invert to observation ID -> state, merging chi1..chi4 into a single "chi".
ob_state = {}
for state, obs in belloni_clean.items():
    if state in ("chi1", "chi2", "chi3", "chi4"):
        state = "chi"
    for ob in obs:
        ob_state[ob] = state

# Load IDs of the segmented light curves: observationID_segmentIndex
with open('{}/468202_len128_s2_4cad_ids_errorfix.pkl'.format(data_dir), 'rb') as f:
    seg_ids = pickle.load(f)

# Drop the within-observation segment index, leaving one (repeating)
# observation ID per segment.
seg_ObIDs = [seg.split("_")[0] for seg in seg_ids]

classes = np.array(["alpha", "beta", "gamma", "delta", "theta", "kappa", "lambda", "mu", "nu", "rho", "phi", "chi", "eta", "omega"])
scales = []

# Class of every segment, inherited from its observation; observations without
# a classification are labelled "Unknown".
segment_class = [ob_state.get(ob, "Unknown") for ob in seg_ObIDs]

In [None]:
# Starting position for the segment browser cell below, which increments it on every run.
seg_ind=18

In [None]:
# Segment browser: every run of this cell advances to the next segment and
# plots it (with error bars) together with its reconstruction (magenta).
seg_ind+=1
# plt.rcParams is used because the pylab alias is only imported in a later
# cell, which made this cell fail with a NameError on a fresh kernel.
plt.rcParams['figure.figsize'] = (14, 3)
browse_time_axis = np.linspace(0,512, num=128)
plt.errorbar(browse_time_axis, segments[seg_ind], yerr=errors[seg_ind], ecolor="black")

plt.plot(browse_time_axis, recos[seg_ind], c="magenta", zorder=5)
plt.show()

In [None]:
# Display the positions of all segments whose class is "alpha".
np.where(np.array(segment_class) == "alpha")

In [None]:
import matplotlib.pylab as pylab
import matplotlib.ticker as ticker
#https://stackoverflow.com/questions/8389636/creating-over-20-unique-legend-colors-using-matplotlib

# Figure: one example segment per class (black) with its reconstruction
# (magenta), stacked as a single column of panels and saved as a PNG.
pylab.rcParams['figure.figsize'] = (3.15, 8.4)  # inches (width, height)

ids_ar = np.array(segment_class)

# Which occurrence of each class is shown: the value is the position among the
# segments of that class, in ids_ar order. The insertion order of this dict is
# the top-to-bottom panel order.
example_rank_by_class = {
    "alpha": 1, "beta": 7, "gamma": 4, "delta": 5, "theta": 4, "kappa": 3, "lambda": 2,
    "mu": 1, "nu": 1, "rho": 1, "phi": 1, "chi": 1, "eta": 4, "omega": 3,
}
class_names = np.array(list(example_rank_by_class))
selected_lcs = [np.where(ids_ar == name)[0][rank] for name, rank in example_rank_by_class.items()]

n_panels = len(selected_lcs)
panel_time_axis = np.linspace(0,512, num=128)

fig, axes = plt.subplots(nrows=n_panels, ncols=1)
axes = axes.flatten()

plt.subplots_adjust(hspace=0.05)
plt.subplots_adjust(wspace=0.01)

intervals = {}

for plot_ind, (light_c, class_name) in enumerate(zip(selected_lcs, class_names)):
    ax = axes[plot_ind]
    ax.plot(panel_time_axis, segments[light_c], c="black", linewidth=0.5, zorder=1, label="Input")
    ax.plot(panel_time_axis, recos[light_c], c="magenta", linewidth=0.5, zorder=2, label="Reconstruction")
    ax.set_ylim([0, 15000])
    ax.set_xlim([0, 512])
    # Class name rendered as a Greek letter in the top-right corner.
    ax.text(0.99,0.99,r"$\{}$".format(class_name), ha='right', va='top', transform=ax.transAxes, size=8)

    # Ticks every 2500 on y, labelled in thousands to match the "kcts/s" axis label.
    ax.set_yticks([0, 2500, 5000, 7500, 10000, 12500])
    ax.set_yticklabels([0, "", 5, "", 10, "",])
    ax.set_xticks([0, 100, 200,300, 400, 500])
    ax.set_xticklabels([0, 100, 200,300, 400, 500])

    ax.tick_params(axis="y", which="major", length=2, width=0.75, labelsize=6, direction="in")

    if plot_ind == 6:
        # A single y label, placed on a middle panel.
        ax.set_ylabel("Rate (kcts/s)", size=6)
    if plot_ind == n_panels - 1:
        # Only the bottom panel carries x tick labels and the x axis label.
        ax.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=6, direction="in")
        ax.set_xlabel("Time (s)", size=6)
    else:
        # Hidden after the explicit tick labels have been set.
        ax.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=0, direction="in")
        plt.setp(ax.get_xticklabels(), visible=False)

# (The former `axes.reshape((14,1))` call was dropped: its result was discarded.)
plt.savefig('all_classes_of_GRS1915_segments_with_reconstructions.png', dpi=300, bbox_inches = 'tight',pad_inches = 0)
# plt.show() is what the other cells use; fig.show() warns on non-GUI backends.
plt.show()

In [None]:
import matplotlib.pylab as pylab
import matplotlib.ticker as ticker
#https://stackoverflow.com/questions/8389636/creating-over-20-unique-legend-colors-using-matplotlib

# Figure: one example segment per class (alphabetical order) with its
# reconstruction. Left column: raw values on a fixed 0-15000 y range.
# Right column: the same curves divided by 1000 with automatic y limits.
plt.rcParams['figure.figsize'] = (6.97, 8.4)  # inches (width, height)
plt.rcParams.update({'font.size': 6})

ids_ar = np.array(segment_class)

# Which occurrence of each class is shown: the value is the position among the
# segments of that class, in ids_ar order.
example_rank_by_class = {
    "alpha": 1, "beta": 7, "gamma": 4, "delta": 5, "theta": 4, "kappa": 3, "lambda": 2,
    "mu": 1, "nu": 1, "rho": 1, "phi": 1, "chi": 1, "eta": 4, "omega": 3,
}
# Panels run alphabetically; the selection is derived from the sorted names so
# that labels and curves cannot get out of step.
class_names = np.sort(np.array(list(example_rank_by_class)))
selected_lcs = [np.where(ids_ar == name)[0][example_rank_by_class[str(name)]] for name in class_names]

n_panels = len(selected_lcs)
panel_time_axis = np.linspace(0,512, num=128)

fig, axes = plt.subplots(nrows=n_panels, ncols=2)

plt.subplots_adjust(hspace=0.05)
plt.subplots_adjust(wspace=0.01)

intervals = {}

for plot_ind, (light_c, class_name) in enumerate(zip(selected_lcs, class_names)):
    ax_fixed, ax_auto = axes[plot_ind, 0], axes[plot_ind, 1]

    ax_fixed.plot(panel_time_axis, segments[light_c], c="black", linewidth=0.5, zorder=1, label="Input")
    ax_fixed.plot(panel_time_axis, recos[light_c], c="magenta", linewidth=0.5, zorder=2, label="Reconstruction")
    ax_auto.plot(panel_time_axis, segments[light_c]/1000, c="black", linewidth=0.5, zorder=1, label="Input")
    ax_auto.plot(panel_time_axis, recos[light_c]/1000, c="magenta", linewidth=0.5, zorder=2, label="Reconstruction")
    ax_auto.yaxis.tick_right()
    ax_fixed.set_ylim([0, 15000])

    # Class name rendered as a Greek letter in the top-right corner of the left panel.
    ax_fixed.text(0.99,0.99,r"$\{}$".format(class_name), ha='right', va='top', transform=ax_fixed.transAxes, size=10)

    # Left panel: ticks every 2500, labelled in thousands to match "kcts/s".
    ax_fixed.set_yticks([0, 2500, 5000, 7500, 10000, 12500])
    ax_fixed.set_yticklabels([0, "", 5, "", 10, "",])

    for ax in (ax_fixed, ax_auto):
        ax.set_xlim([0, 512])
        ax.set_xticks([0, 100, 200,300, 400, 500])
        ax.set_xticklabels([0, "", 200,"", 400, ""])
        ax.tick_params(axis="y", which="major", length=2, width=0.75, labelsize=6, direction="in")

    if plot_ind == 6:
        # A single y label, placed on a middle panel of the left column.
        ax_fixed.set_ylabel("Rate (kcts/s)", size=6)
    if plot_ind == n_panels - 1:
        # Only the bottom row carries x tick labels; x=1 puts the shared x label
        # at the right edge of the left panel, i.e. between the two columns.
        ax_fixed.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=6, direction="in")
        ax_auto.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=6, direction="in")
        ax_fixed.set_xlabel("Time (s)", size=6, x=1)
    else:
        for ax in (ax_fixed, ax_auto):
            ax.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=0, direction="in")
            plt.setp(ax.get_xticklabels(), visible=False)

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/segments_fit_alphabetical.png', dpi=300, bbox_inches = 'tight',pad_inches = 0)
plt.show()

In [None]:
# BIC curves of the Gaussian-mixture grid search: one line per data subset plus
# their mean, zoomed to 90-150 components, saved to figures/.
plt.rcParams['figure.figsize'] = (3.15, 2.15)
plt.rcParams.update({'font.size': 6})
subset_size = 50e3
no_subsets = 5
# (start, stop) index ranges of the last two of the five subsets; not used by the plot.
data_subsets = [(int(start*subset_size), int((1+start)*subset_size)) for start in range(no_subsets)][3:]
# Component counts tried: 7 coarse values, 110..140 in steps of 2, and 150 (24 in total).
component_no_list = np.concatenate(([10, 25, 50, 70, 80, 90, 100], np.arange(110,142, 2), [150]))
# One row per data subset, one column per entry of component_no_list.
bics = np.array([
[1.853893282449822500e+07, 1.722017647717545927e+07, 1.658319324933229759e+07, 1.630327464005833492e+07, 1.622106282425292768e+07, 1.616687566882342100e+07, 1.615730210213817097e+07, 1.613055099964568764e+07, 1.611427361123152077e+07, 1.614833229564627074e+07, 1.612438706897565909e+07, 1.610314177309925482e+07, 1.611599805476822704e+07, 1.611190407219867036e+07, 1.613187932035312243e+07, 1.612686166295694374e+07, 1.610543674698158540e+07, 1.610464430899482593e+07, 1.612001376647870429e+07, 1.613567941551903449e+07, 1.613709874137604795e+07, 1.613797303043186106e+07, 1.613711653342354484e+07, 1.616225371089820936e+07],
[18649056.479567  , 17292560.83637504, 16573116.12680366,
16322533.02137523, 16216074.19606247, 16154734.81151236,
16129106.07263587, 16131854.50699835, 16119122.21325973,
16135573.44851911, 16158603.67399981, 16151476.94026508,
16127494.15584713, 16097320.76032251, 16140875.91346822,
16115866.55727115, 16118995.33300484, 16105372.49759903,
16112728.50581094, 16115875.38019666, 16153486.95589538,
16167876.91470784, 16139959.11138932, 1.616255141896034777e+07],
[1.861993137938806042e+07, 1.732917515701735020e+07, 1.657063235443255119e+07, 1.635879176514221355e+07, 1.626360311877287738e+07, 1.618802062452851236e+07, 1.614967957935284078e+07, 1.615186423880326003e+07, 1.616757082466319762e+07, 1.619746861296005547e+07, 1.617370181996636279e+07, 1.616590259545178898e+07, 1.616194254895075411e+07, 1.613518411618221551e+07, 1.617350772094361670e+07, 1.614869119315046445e+07, 1.612203835914292000e+07, 1.615331379861885123e+07, 1.613294921764167212e+07, 1.615344635884070210e+07, 1.613590505756535381e+07, 1.614922787741473131e+07, 1.615719204464513995e+07, 1.616830989636509307e+07],
[1.853940395327756181e+07, 1.730897341657006368e+07, 1.658365185461111180e+07, 1.630294958357473835e+07, 1.621625770949508809e+07, 1.618283259364969283e+07, 1.613909268951667286e+07, 1.611507191162823141e+07, 1.611494689585074782e+07, 1.610472038442745246e+07, 1.612916576028856635e+07, 1.608498335807089880e+07, 1.610001351276045479e+07, 1.612772030996491760e+07, 1.609651759386808611e+07, 1.608189611665542610e+07, 1.608605904953365773e+07, 1.611580276102753356e+07, 1.613783609712784737e+07, 1.611093687816699408e+07, 1.610780632114464603e+07, 1.613467261296702735e+07, 1.611441260656008683e+07, 1.616245731585987844e+07],
[18503327.3887386 , 17273712.70135051, 16551224.30526865,
        16276096.82541498, 16195889.17302874, 16159136.8378545 ,
        16140232.2930156 , 16153301.53835457, 16142290.54757776,
        16104116.02340264, 16116909.31318822, 16089586.5741044 ,
        16079214.70663652, 16110927.1401799 , 16111921.31548247,
        16150935.62304156, 16161332.18917462, 16142814.68089103,
        16148160.39985222, 16165847.30070143, 16136459.53606088,
        16178912.17868229, 16116594.94776744, 16172298.90584182]
])
# Component count with the lowest BIC for each subset.
minimums = []
fig, axes = plt.subplots(nrows=1, ncols=1)

for iteration, dataset_bics in enumerate(bics):
    minimums.append(component_no_list[np.argmin(dataset_bics)])
    # Values are divided by 1e7 to match the "($10^7$)" in the y label.
    axes.plot(component_no_list,dataset_bics/1e7)


axes.plot(component_no_list,np.mean(bics, axis=0)/1e7, marker = ".", c="cyan", label="mean")
axes.set_xlim([90,150])
axes.set_ylim([1.607, 1.62])
plt.legend(fontsize=6)
plt.xlabel("No. Gaussian components", size=6)
plt.ylabel("Bayesian information criterion ($10^7$)", size=6)

axes.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=6, direction="in")
axes.tick_params(axis="y", which="major", length=2, width=0.75, labelsize=6, direction="in")

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/bic_grid_search_1024lstm.png', dpi=300, bbox_inches = 'tight',pad_inches = 0.01)



plt.show()

In [None]:
# On-screen version of the BIC grid-search plot: one line per data subset plus
# their mean, zoomed to 90-150 components. The data and plotting code repeat
# the preceding cell; the only difference is that the figure is not saved.
plt.rcParams['figure.figsize'] = (3.15, 2.15)
plt.rcParams.update({'font.size': 6})
subset_size = 50e3
no_subsets = 5
# (start, stop) index ranges of the last two of the five subsets; not used below.
data_subsets = [(int(start*subset_size), int((1+start)*subset_size)) for start in range(no_subsets)][3:]
# gaussian_mixture_bics = np.zeros((len(data_subsets),len(component_no_list)))
# Component counts tried: 7 coarse values, 110..140 in steps of 2, and 150 (24 in total).
component_no_list = np.concatenate(([10, 25, 50, 70, 80, 90, 100], np.arange(110,142, 2), [150]))
# Five rows of 24 values: one row per data subset, one column per component count.
bics = np.array([
[1.853893282449822500e+07, 1.722017647717545927e+07, 1.658319324933229759e+07, 1.630327464005833492e+07, 1.622106282425292768e+07, 1.616687566882342100e+07, 1.615730210213817097e+07, 1.613055099964568764e+07, 1.611427361123152077e+07, 1.614833229564627074e+07, 1.612438706897565909e+07, 1.610314177309925482e+07, 1.611599805476822704e+07, 1.611190407219867036e+07, 1.613187932035312243e+07, 1.612686166295694374e+07, 1.610543674698158540e+07, 1.610464430899482593e+07, 1.612001376647870429e+07, 1.613567941551903449e+07, 1.613709874137604795e+07, 1.613797303043186106e+07, 1.613711653342354484e+07, 1.616225371089820936e+07],
[18649056.479567  , 17292560.83637504, 16573116.12680366,
16322533.02137523, 16216074.19606247, 16154734.81151236,
16129106.07263587, 16131854.50699835, 16119122.21325973,
16135573.44851911, 16158603.67399981, 16151476.94026508,
16127494.15584713, 16097320.76032251, 16140875.91346822,
16115866.55727115, 16118995.33300484, 16105372.49759903,
16112728.50581094, 16115875.38019666, 16153486.95589538,
16167876.91470784, 16139959.11138932, 1.616255141896034777e+07],
[1.861993137938806042e+07, 1.732917515701735020e+07, 1.657063235443255119e+07, 1.635879176514221355e+07, 1.626360311877287738e+07, 1.618802062452851236e+07, 1.614967957935284078e+07, 1.615186423880326003e+07, 1.616757082466319762e+07, 1.619746861296005547e+07, 1.617370181996636279e+07, 1.616590259545178898e+07, 1.616194254895075411e+07, 1.613518411618221551e+07, 1.617350772094361670e+07, 1.614869119315046445e+07, 1.612203835914292000e+07, 1.615331379861885123e+07, 1.613294921764167212e+07, 1.615344635884070210e+07, 1.613590505756535381e+07, 1.614922787741473131e+07, 1.615719204464513995e+07, 1.616830989636509307e+07],
[1.853940395327756181e+07, 1.730897341657006368e+07, 1.658365185461111180e+07, 1.630294958357473835e+07, 1.621625770949508809e+07, 1.618283259364969283e+07, 1.613909268951667286e+07, 1.611507191162823141e+07, 1.611494689585074782e+07, 1.610472038442745246e+07, 1.612916576028856635e+07, 1.608498335807089880e+07, 1.610001351276045479e+07, 1.612772030996491760e+07, 1.609651759386808611e+07, 1.608189611665542610e+07, 1.608605904953365773e+07, 1.611580276102753356e+07, 1.613783609712784737e+07, 1.611093687816699408e+07, 1.610780632114464603e+07, 1.613467261296702735e+07, 1.611441260656008683e+07, 1.616245731585987844e+07],
[18503327.3887386 , 17273712.70135051, 16551224.30526865,
        16276096.82541498, 16195889.17302874, 16159136.8378545 ,
        16140232.2930156 , 16153301.53835457, 16142290.54757776,
        16104116.02340264, 16116909.31318822, 16089586.5741044 ,
        16079214.70663652, 16110927.1401799 , 16111921.31548247,
        16150935.62304156, 16161332.18917462, 16142814.68089103,
        16148160.39985222, 16165847.30070143, 16136459.53606088,
        16178912.17868229, 16116594.94776744, 16172298.90584182]
])
# Component count with the lowest BIC for each subset (collected, not displayed).
minimums = []
fig, axes = plt.subplots(nrows=1, ncols=1)

for iteration, dataset_bics in enumerate(bics):
    minimums.append(component_no_list[np.argmin(dataset_bics)])
    # Divided by 1e7 to match the "($10^7$)" in the y label.
    axes.plot(component_no_list,dataset_bics/1e7)


# Mean over the subsets; the only labelled line, so the only legend entry.
axes.plot(component_no_list,np.mean(bics, axis=0)/1e7, marker = ".", c="cyan", label="mean")
axes.set_xlim([90,150])
axes.set_ylim([1.607, 1.62])
plt.legend(fontsize=6)
# plt.title("BIC values for Gaussian mixture models fit to 5 data subsets of 50k segments")
plt.xlabel("No. Gaussian components", size=6)
plt.ylabel("Bayesian information criterion ($10^7$)", size=6)

axes.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=6, direction="in")
axes.tick_params(axis="y", which="major", length=2, width=0.75, labelsize=6, direction="in")

# y=np.linspace(1.26, 1.28, 5)
# scale_y =1e7
# ticks_y = ticker.FuncFormatter(lambda y, pos: '{0:g}'.format(y/scale_y))
# axes.yaxis.set_major_formatter(ticks_y)

# Saving is intentionally disabled in this copy of the cell.
# plt.savefig('figures/bic_grid_search_1024lstm.png', dpi=300, bbox_inches = 'tight',pad_inches = 0.01)



plt.show()

In [None]:
# BIC values of the two additional, larger mixture sizes, one line per data subset.
plt.rcParams['figure.figsize'] = (3.15, 2.15)
plt.rcParams.update({'font.size': 6})
subset_size = 50e3
no_subsets = 5
# (start, stop) index ranges of the last two of the five subsets; not used by the plot.
data_subsets = [(int(start*subset_size), int((1+start)*subset_size)) for start in range(no_subsets)][3:]
# The cell used to contain only the bare rows below followed by a stray ")",
# which is a SyntaxError. The assignments are restored here; the component
# counts are taken from the "#160, 170 bics" note that accompanies the same
# values in the next cell.
component_no_list = np.array([160, 170])
bics = np.array(
[[16186733.20254564, 16241434.8049347 ],
       [16182579.3219899 , 16240088.8081091 ],
       [16174865.74778001, 16221455.58526794],
       [16199528.61543708, 16231765.41996747],
       [16154733.1189743 , 16193377.90690356]])
# Component count with the lowest BIC for each subset.
minimums = []
fig, axes = plt.subplots(nrows=1, ncols=1)

for iteration, dataset_bics in enumerate(bics):
    minimums.append(component_no_list[np.argmin(dataset_bics)])
    # Each line is labelled so that the legend below has entries to show.
    axes.plot(component_no_list,dataset_bics/1e7, label="subset {}".format(iteration))


plt.legend(fontsize=6)
plt.xlabel("No. Gaussian components", size=6)
plt.ylabel("Bayesian information criterion ($10^7$)", size=6)

axes.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=6, direction="in")
axes.tick_params(axis="y", which="major", length=2, width=0.75, labelsize=6, direction="in")

# Saving is intentionally disabled for this figure.
# plt.savefig('figures/bic_grid_search_1024lstm.png', dpi=300, bbox_inches = 'tight',pad_inches = 0.01)



plt.show()

In [None]:
# Scratch record of the BIC values for the larger mixture sizes. The literals
# are not assigned to any name; the same numbers appear as the trailing columns
# of the `bics` table defined in the next cell.
# Five rows in each table, one per data subset.

#160, 170 bics
[[16186733.20254564, 16241434.8049347 ],
       [16182579.3219899 , 16240088.8081091 ],
       [16174865.74778001, 16221455.58526794],
       [16199528.61543708, 16231765.41996747],
       [16154733.1189743 , 16193377.90690356]]
#180, 190, 200 bics

[[16237322.11813296, 16275376.95864639, 16296302.0147642 ],
       [16243712.80555515, 16299079.98722455, 16300187.13137991],
       [16210517.52764545, 16231785.91035465, 16326195.40168083],
       [16258597.89322176, 16287211.07189597, 16294112.22156437],
       [16228693.61702962, 16266838.69559363, 16286374.98240227]]

In [None]:
# Full BIC table used by the final figure: the 24 component counts of the
# original grid search plus 160-200 in steps of 10 (29 counts in total).
component_no_list = np.concatenate(([10, 25, 50, 70, 80, 90, 100], np.arange(110,142, 2), [150, 160, 170, 180, 190, 200]))
# One row per data subset (5 rows), one column per entry of component_no_list.
# Each row is the earlier 24-value row followed by its 160/170 and 180/190/200 values.
bics = np.array([
[1.853893282449822500e+07, 1.722017647717545927e+07, 1.658319324933229759e+07, 1.630327464005833492e+07, 1.622106282425292768e+07, 1.616687566882342100e+07, 1.615730210213817097e+07, 1.613055099964568764e+07, 1.611427361123152077e+07, 1.614833229564627074e+07, 1.612438706897565909e+07, 1.610314177309925482e+07, 1.611599805476822704e+07, 1.611190407219867036e+07, 1.613187932035312243e+07, 1.612686166295694374e+07, 1.610543674698158540e+07, 1.610464430899482593e+07, 1.612001376647870429e+07, 1.613567941551903449e+07, 1.613709874137604795e+07, 1.613797303043186106e+07, 1.613711653342354484e+07, 1.616225371089820936e+07, 16186733.20254564, 16241434.8049347, 16237322.11813296, 16275376.95864639, 16296302.0147642],
[18649056.479567  , 17292560.83637504, 16573116.12680366,
16322533.02137523, 16216074.19606247, 16154734.81151236,
16129106.07263587, 16131854.50699835, 16119122.21325973,
16135573.44851911, 16158603.67399981, 16151476.94026508,
16127494.15584713, 16097320.76032251, 16140875.91346822,
16115866.55727115, 16118995.33300484, 16105372.49759903,
16112728.50581094, 16115875.38019666, 16153486.95589538,
16167876.91470784, 16139959.11138932, 1.616255141896034777e+07, 16182579.3219899 , 16240088.8081091, 16243712.80555515, 16299079.98722455, 16300187.13137991],
[1.861993137938806042e+07, 1.732917515701735020e+07, 1.657063235443255119e+07, 1.635879176514221355e+07, 1.626360311877287738e+07, 1.618802062452851236e+07, 1.614967957935284078e+07, 1.615186423880326003e+07, 1.616757082466319762e+07, 1.619746861296005547e+07, 1.617370181996636279e+07, 1.616590259545178898e+07, 1.616194254895075411e+07, 1.613518411618221551e+07, 1.617350772094361670e+07, 1.614869119315046445e+07, 1.612203835914292000e+07, 1.615331379861885123e+07, 1.613294921764167212e+07, 1.615344635884070210e+07, 1.613590505756535381e+07, 1.614922787741473131e+07, 1.615719204464513995e+07, 1.616830989636509307e+07, 16174865.74778001, 16221455.58526794, 16210517.52764545, 16231785.91035465, 16326195.40168083],
[1.853940395327756181e+07, 1.730897341657006368e+07, 1.658365185461111180e+07, 1.630294958357473835e+07, 1.621625770949508809e+07, 1.618283259364969283e+07, 1.613909268951667286e+07, 1.611507191162823141e+07, 1.611494689585074782e+07, 1.610472038442745246e+07, 1.612916576028856635e+07, 1.608498335807089880e+07, 1.610001351276045479e+07, 1.612772030996491760e+07, 1.609651759386808611e+07, 1.608189611665542610e+07, 1.608605904953365773e+07, 1.611580276102753356e+07, 1.613783609712784737e+07, 1.611093687816699408e+07, 1.610780632114464603e+07, 1.613467261296702735e+07, 1.611441260656008683e+07, 1.616245731585987844e+07, 16199528.61543708, 16231765.41996747, 16258597.89322176, 16287211.07189597, 16294112.22156437],
[18503327.3887386 , 17273712.70135051, 16551224.30526865,
        16276096.82541498, 16195889.17302874, 16159136.8378545 ,
        16140232.2930156 , 16153301.53835457, 16142290.54757776,
        16104116.02340264, 16116909.31318822, 16089586.5741044 ,
        16079214.70663652, 16110927.1401799 , 16111921.31548247,
        16150935.62304156, 16161332.18917462, 16142814.68089103,
        16148160.39985222, 16165847.30070143, 16136459.53606088,
        16178912.17868229, 16116594.94776744, 16172298.90584182, 16154733.1189743 , 16193377.90690356, 16228693.61702962, 16266838.69559363, 16286374.98240227]
])

In [None]:
import matplotlib.patches as patches

plt.rcParams['figure.figsize'] = (3.32, 3.32)

# ax1 shows the complete BIC curves; ax2 is an inset magnifying the region around their minima.
fig, ax1 = plt.subplots()

# Inset placement as fractions of the figure size ((0, 0) is the bottom-left corner).
left, bottom, width, height = [0.40, 0.35, 0.5, 0.53]
ax2 = fig.add_axes([left, bottom, width, height])

# One BIC curve per row of `bics`; remember the component count at which each curve is lowest.
minimums = []
for iteration, dataset_bics in enumerate(bics):
    minimums.append(component_no_list[np.argmin(dataset_bics)])
    for target_axes in (ax1, ax2):
        target_axes.plot(component_no_list, dataset_bics/1e7)

# Dashed box on the main axes outlining the x/y range displayed by the inset.
rect = patches.Rectangle((90, 1.607), 60, 0.013, linewidth=1, edgecolor='black',
                         facecolor='none', linestyle="--")
ax1.add_patch(rect)

# The mean curve is drawn in the inset only; the inset is the current axes, so the legend belongs to it.
ax2.plot(component_no_list, np.mean(bics, axis=0)/1e7, marker=".", c="cyan", label="mean")
ax2.set_xlim([90, 150])
ax2.set_ylim([1.607, 1.62])
ax2.legend(fontsize=6)

ax1.set_xlabel("No. Gaussian components", size=6)
ax1.set_ylabel("Bayesian information criterion ($10^7$)", size=6)
ax1.tick_params(axis="both", which="major", length=2, width=0.75, labelsize=6, direction="in")

ax2.set_yticks(np.arange(1.608, 1.618, 0.002))
ax2.set_xticks(np.arange(100, 150, 10))

# Final limits of the main axes, kept for inspection in the next cell.
bottom_y, top_y = ax1.get_ylim()
left_x, right_x = ax1.get_xlim()

# plt.savefig('figures/bic_curves.png', dpi=300, bbox_inches = 'tight',pad_inches = 0.01)

plt.show()


In [None]:
# Scratch check: 90 paired with the ax1 data x-coordinate that lies 45% of the way across its x-range
# (left_x / right_x come from ax1.get_xlim() in the BIC figure cell).
[90, left_x+(0.45*(right_x-left_x))]

In [None]:
# Path of the pickled 122-component Gaussian mixture model (despite the name, this is a file path).
# NOTE(review): pickle.load executes arbitrary code; only load files from a trusted source.
GMmodel_dir = "{}/GMM_122comp_20d_alldata_model_2020-08-30_11-42-38.pkl".format(data_dir)
with open(GMmodel_dir, 'rb') as model_file:
    GMmodel122 = pickle.load(model_file)

# Fit UMAP on the component means, then project those same means through the fitted mapper.
component_means = GMmodel122.means_
reducer = umap.UMAP(random_state=42)
reducer.fit(component_means)

embedding = reducer.transform(component_means)


In [None]:
# Pickled UMAP mapper for the segment features (despite the name, segments_UMAP_dir is a file path).
# NOTE(review): pickle.load executes arbitrary code; only load files from a trusted source.
segments_UMAP_dir = "{}/UMAPmapper_20d_shape16_moments4_trainedonall_468202.pkl".format(data_dir)
with open(segments_UMAP_dir, 'rb') as mapper_file:
    segments_UMAP = pickle.load(mapper_file)

In [None]:
# Light curve segments.
with open('{}/468202_len128_s2_4cad_counts_errorfix.pkl'.format(data_dir), 'rb') as counts_file:
    segments_counts = pickle.load(counts_file)

# Latent variables of the segments. The encoding file name is assembled from the stems
# (base name up to the first dot) of the model-weights file and of the segments file.
weights_dir = "../../../model_weights/model_2020-08-30_11-42-38.h5"
segments_dir = '../../../data_GRS1915/468202_len128_s2_4cad_counts_errorfix.pkl'
weights_stem = weights_dir.split("/")[-1].split(".")[0]
segments_stem = segments_dir.split("/")[-1].split(".")[0]
segment_encoding_dir = '{}/segment_encoding_{}_segments_{}.pkl'.format(data_dir, weights_stem, segments_stem)
with open(segment_encoding_dir, 'rb') as encoding_file:
    segment_encoding = pickle.load(encoding_file)

# Keep slice 0 along the second axis (the latent means, 16 values per segment per the original note)
# and standardise every column.
segment_encoding_scaled_means = zscore(segment_encoding[:,0,:], axis=0).astype(np.float32)

# Four descriptive statistics per segment, one column each: mean, std, skew, kurtosis.
moment_functions = (np.mean, np.std, stats.skew, stats.kurtosis)
desc_stats = np.zeros((len(segments_counts), len(moment_functions)))
for moment_column, moment_function in enumerate(moment_functions):
    desc_stats[:, moment_column] = moment_function(segments_counts, axis=1).flatten()
zscore_desc_stats = zscore(desc_stats, axis=0)

# Standardised latent means and standardised moments side by side
# (the original note gives the resulting shape as [468202, 20]).
shape_moments = np.hstack((segment_encoding_scaled_means, zscore_desc_stats))

In [None]:
# Project the segment feature matrix through the pre-fitted UMAP mapper loaded above.
umap_embedding = segments_UMAP.transform(shape_moments)

In [None]:
# Scatter of the first two UMAP coordinates of every segment.
plt.rcParams['figure.figsize'] = (3.32, 3.32)
embedding_x, embedding_y = umap_embedding[:, 0], umap_embedding[:, 1]
plt.scatter(embedding_x, embedding_y, c="black", s=0.1)
plt.show()

In [None]:
plt.rcParams['figure.figsize'] = (3.32,3.32)

# Marker size shared by every scatter call.
s=5

# Hand-picked groups of Gaussian components: (legend label, colour, component indices).
component_groups = [
    ("periodic flare", "orange",
     [6,10,16,21,22,24,37,38,41,45,56,57,59,63,79,81,82,83,90,94,96,105,108,115,119,121]),
    ("mid random", "magenta",
     [0,32,40,50,60,73,100,103,110,113,118]),
    ("low random", "red",
     [1,2,4,5,7,12,14,18,27,30,31,47,49,51,54,67,71,76,80,99,106,109]),
    ("+ve gradient", "cyan",
     [8,11,52,66,68,117]),
    ("'square wave'", "khaki",
     [43,88,89,98,114]),
]

# Per-component bookkeeping for all 122 components. Previously `labels` was rebuilt from `zorders`
# (so it ended up holding integers), two groups were never recorded, and the square-wave group was
# recorded as "cyan" / "Low random"; all three lists are now derived from the table above.
colour_filter = ["black"]*122
zorders = [-1]*122
labels = ["Other"]*122

# Every component in black first; the groups are then drawn on top in their own colour.
plt.scatter(embedding[:,0], embedding[:,1], c="black", s=s)

for group_label, group_colour, filter_indices in component_groups:
    plt.scatter(embedding[filter_indices,0], embedding[filter_indices,1],
                c=group_colour, label=group_label, s=s)
    for ind in filter_indices:
        colour_filter[ind] = group_colour
        zorders[ind] = 1
        labels[ind] = group_label.capitalize()

plt.xlabel("UMAP axis 1")
plt.ylabel("UMAP axis 2")

plt.legend()
# plt.savefig('figures/UMAP_component_means.png', dpi=300, bbox_inches = 'tight',pad_inches = 0.01)

plt.show()

In [None]:
# with open("UMAP_agglo_clustering_results4x3x100.pkl", 'wb') as f:
#     pickle.dump(results_list, f)

# test_this=results_arr[3,3,0,:]

# UMAP_n_components = [2,5,10,20]
# UMAP_n_neighbors = [15, 300, 1737]

# One panel per entry of UMAP_n_components (first axis of results_arr); each panel overlays
# one histogram per entry along the second axis.
# NOTE(review): results_arr and UMAP_n_components are not defined in any visible cell (only in
# commented-out code) — confirm they are created before this cell runs.
fig, axes = plt.subplots(nrows=2, ncols=2)
axes = axes.flatten()

for comp_ind in range(4):
    panel = axes[comp_ind]
    for neigh_ind in range(3):
        knee_cluster_counts = results_arr[comp_ind,neigh_ind,:]
        mean_no_clusters = np.mean(knee_cluster_counts)
        panel.hist(knee_cluster_counts.flatten(), alpha=0.5, range=(30,90),
                   bins=20, label="{:.2f}".format(mean_no_clusters))
    panel.set_xlim((30,90))
    panel.set_ylim((0,40))
    panel.text(0.99,0.99,r"No. UMAP components {}".format(UMAP_n_components[comp_ind]),
               ha='right', va='top', transform=panel.transAxes, size=6)
    panel.legend(loc='center right', title="mean")
    print( np.mean(results_arr[comp_ind,:,:]))

# Shared axis labels are attached to the bottom-left panel and shifted to the panel edges.
axes[2].set_ylabel("Population", size=6, y=1)
axes[2].set_xlabel("No. clusters at the knee point", size=6, x=1)
plt.show()
    
# plt.hist(results_arr[:,:,:,:].flatten(), bins=100, range=(0,1737))
# plt.hist(test_this.flatten(), alpha=0.8, bins=100,range=(0,1737))



# plt.show()
# plt.hist(results_arr[:,:,:,:].flatten(), bins=100,range=(0,1737))
# plt.hist(test_this.flatten(), alpha=0.8, bins=100,range=(0,1737))
# plt.show()

In [None]:
# NOTE(review): this cell repeats, statement for statement, the earlier cell that builds
# `shape_moments`; it recomputes the same variables from the same files.

# load light curve segments
# NOTE(review): pickle.load executes arbitrary code; only load files from a trusted source.
with open('{}/468202_len128_s2_4cad_counts_errorfix.pkl'.format(data_dir), 'rb') as f:
    segments_counts = pickle.load(f)
    
# load latent variables for light curve segments
# (the file name is assembled from the stems of the weights file and of the segments file)
weights_dir = "../../../model_weights/model_2020-08-30_11-42-38.h5"
segments_dir = '../../../data_GRS1915/468202_len128_s2_4cad_counts_errorfix.pkl'
segment_encoding_dir = '{}/segment_encoding_{}_segments_{}.pkl'.format(data_dir, weights_dir.split("/")[-1].split(".")[0], segments_dir.split("/")[-1].split(".")[0])
with open(segment_encoding_dir, 'rb') as f:
    segment_encoding = pickle.load(f)

# take latent variable means, i.e. 16 values per segment
segment_encoding_scaled_means = zscore(segment_encoding[:,0,:], axis=0).astype(np.float32)  # standardize per feature

# calculate statistical moments for the segments
desc_stats = np.zeros((len(segments_counts), 4)) #mean, std, skew, kurt
desc_stats[:,0] = np.mean(segments_counts, axis=1).flatten()
desc_stats[:,1] = np.std(segments_counts, axis=1).flatten()
desc_stats[:,2] = stats.skew(segments_counts, axis=1).flatten()
desc_stats[:,3] = stats.kurtosis(segments_counts, axis=1).flatten()
zscore_desc_stats = zscore(desc_stats, axis=0)

# merge the two types of features; shape of shape_moments is [468202, 20]
shape_moments = np.hstack((segment_encoding_scaled_means, zscore_desc_stats)) # every column is standardized

In [None]:
# Histograms of the four per-segment statistics held in the columns of desc_stats.
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator, FuncFormatter)

# Figure size and font size must be set BEFORE the figure is created; previously they were set
# afterwards and therefore only took effect on a second run of the cell.
plt.rcParams['figure.figsize'] = (3.32, 3.32*(1/1))
plt.rcParams.update({'font.size': 6})

fig, axes = plt.subplots(nrows=2, ncols=2)
axes = axes.flatten()

stat_names = np.array(["mean", "st.d.", "skew", "kurtosis"])

# Major x-tick spacing and histogram range for each statistic.
x_tick_space = [2500, 1000, 2, 5]
x_limits = [[0,8000],[0,3000], [-4,4], [-2,15]]

# Show segment counts in units of 10^4. A formatter is used instead of fixed tick-label strings so
# that the labels stay correct if matplotlib moves the ticks when the figure is drawn.
counts_in_1e4 = FuncFormatter(lambda count, _position: '{:3.0f}'.format(count / 10000))

for plot_ind in range(4):
    panel = axes[plot_ind]
    panel.hist(desc_stats[:,plot_ind], bins=50, range = x_limits[plot_ind])
    panel.yaxis.set_major_formatter(counts_in_1e4)
    panel.xaxis.set_major_locator(MultipleLocator(x_tick_space[plot_ind]))
    panel.xaxis.set_major_formatter(FormatStrFormatter('%d'))
    panel.xaxis.set_minor_locator(AutoMinorLocator())
    panel.text(0.99,0.99,"{}".format(stat_names[plot_ind]), ha='right', va='top', transform=panel.transAxes, size=6)
    if plot_ind == 2:
        # Shared axis labels live on the bottom-left panel, shifted to its edges.
        panel.set_ylabel("No. segments ($10^4$)", y=1, size=6)
        panel.set_xlabel("Value of the statistic", x=1, size=6)

# plt.savefig('figures/stat_distributions.png', dpi=300, bbox_inches = 'tight',pad_inches = 0)

plt.show()

In [None]:
# Component counts ordered from lowest to highest BIC, using the BIC averaged over the rows of `bics`.
component_no_list[np.argsort(np.mean(bics, axis=0))]

In [None]:
# Segment start times (the same file is loaded as `seg_times` in the first cells of the notebook).
# NOTE(review): pickle.load executes arbitrary code; only load files from a trusted source.
with open('{}/468202_len128_s2_4cad_start_times_errorfix.pkl'.format(data_dir), 'rb') as f:
    seg_id_start_time_df = pickle.load(f)
    
# Start times in ascending order.
JD_times = np.sort(seg_id_start_time_df.Start_time.values)

    
# Presumably the 500-component labels of the segments in chronological order — TODO confirm.
with open('{}/GMM_500comp_labels_chronologically.pkl'.format(data_dir), 'rb') as f:
    chrono_clusters = pickle.load(f)
    
with open('{}/GMM_shape16_moments4_components500_alldata.pkl'.format(data_dir), 'rb') as f:
    GMM_500 = pickle.load(f)
    
from sklearn.cluster import OPTICS

# Run OPTICS on the component means and use its ordering to relabel the components.
# NOTE(review): ordering_[k] is the index of the sample visited k-th. Looking it up by label maps
# "label -> sample visited at that position", not "label -> position in the ordering"; if the aim
# is to place similar components at neighbouring y-values, np.argsort(clustering.ordering_) may be
# the intended lookup table — confirm.
clustering = OPTICS().fit(GMM_500.means_)
reordered_chrono_clusters = clustering.ordering_[chrono_clusters]



# One very wide figure per consecutive chunk of 20000 segments.
plt.rcParams['figure.figsize'] = [100,5]
for interval in np.arange(0,468202, 20000):
    plt.scatter(JD_times[interval:interval+20000], reordered_chrono_clusters[interval:interval+20000])
    plt.show()

In [None]:
# Per-segment labels stored in shape_moments_GM114_labels.pkl.
# NOTE(review): pickle.load executes arbitrary code; only load files from a trusted source.
labels_114_path = '{}/shape_moments_GM114_labels.pkl'.format(data_dir)
with open(labels_114_path, 'rb') as labels_file:
    shape_moments_GM114_labels = pickle.load(labels_file)

In [None]:
# Per-segment labels stored in shape_moments_GM500_labels.pkl.
# NOTE(review): pickle.load executes arbitrary code; only load files from a trusted source.
labels_500_path = '{}/shape_moments_GM500_labels.pkl'.format(data_dir)
with open(labels_500_path, 'rb') as labels_file:
    shape_moments_GM500_labels = pickle.load(labels_file)

In [None]:
# Permutation that puts the segments into ascending start-time order ...
segment_start_times = seg_id_start_time_df.Start_time.values
argsort_segment_times = np.argsort(segment_start_times)

# ... applied to both the start times and the 114-component labels.
segment_JD_times = segment_start_times[argsort_segment_times]
segment_chrono_labels114 = shape_moments_GM114_labels[argsort_segment_times]

In [None]:
# Label versus start time, one very wide figure per consecutive chunk of 20000 segments.
plt.rcParams['figure.figsize'] = [100,5]
for interval in np.arange(0,468202, 20000):
    chunk = slice(interval, interval+20000)
    plt.scatter(segment_JD_times[chunk], segment_chrono_labels114[chunk])
    plt.show()

In [None]:
# Time difference between the two earliest segment start times.
segment_JD_times[1] - segment_JD_times[0]

In [None]:
from copy import deepcopy

argsort_segment_times = np.argsort(seg_id_start_time_df.Start_time.values)

segment_JD_times = seg_id_start_time_df.Start_time.values[argsort_segment_times]
segment_chrono_labels114 = shape_moments_GM114_labels[argsort_segment_times]

# A transition is a pair of chronologically adjacent segments whose start times differ by
# exactly 8 and whose labels differ.
current_labels = segment_chrono_labels114[:-1]
following_labels = segment_chrono_labels114[1:]
is_transition = (following_labels != current_labels) & (np.diff(segment_JD_times) == 8.)

# Label -> labels that directly follow it, in chronological order, repetitions included.
Edges_dict = {key:[] for key in np.unique(segment_chrono_labels114)}
for source_label, target_label in zip(current_labels[is_transition], following_labels[is_transition]):
    Edges_dict.setdefault(source_label, []).append(target_label)

# Keep the version with repetitions, then reduce every list to its distinct successors.
Edges_dict_repetitions = deepcopy(Edges_dict)

for key, values in Edges_dict_repetitions.items():
    Edges_dict[key] = np.unique(values)
    print(len(Edges_dict[key]))

In [None]:
from copy import deepcopy

argsort_segment_times = np.argsort(seg_id_start_time_df.Start_time.values)

segment_JD_times = seg_id_start_time_df.Start_time.values[argsort_segment_times]
segment_chrono_labels500 = shape_moments_GM500_labels[argsort_segment_times]

# Label -> labels that directly follow it, in chronological order, repetitions included.
Edges_dict = {key:[] for key in np.unique(segment_chrono_labels500)}

# Walk over chronologically adjacent segment pairs.
adjacent_segments = zip(segment_JD_times[:-1], segment_JD_times[1:],
                        segment_chrono_labels500[:-1], segment_chrono_labels500[1:])
for point_time, next_time, point_cluster, next_cluster in adjacent_segments:
    if next_cluster == point_cluster:
        continue  # same label: not a transition
    if next_time-point_time != 8.:
        continue  # start times are not exactly 8 apart
    Edges_dict[point_cluster].append(next_cluster)

# Keep the version with repetitions, then reduce every list to its distinct successors.
Edges_dict_repetitions = deepcopy(Edges_dict)

for key, values in Edges_dict_repetitions.items():
    Edges_dict[key] = np.unique(values)
    print(len(Edges_dict[key]))

In [None]:
# Display the label -> distinct successor labels mapping built by the most recently run edges cell.
Edges_dict

In [None]:
# Stored split of the segment indices.
# NOTE(review): this path is hard-coded and does not use data_dir like the other loads — confirm.
# NOTE(review): pickle.load executes arbitrary code; only load files from a trusted source.
split_file_path = '../../../data_GRS1915/468202_len128_s2_4cad_observation90-10split.pkl'
with open(split_file_path, 'rb') as split_file:
    split_segment_indices = pickle.load(split_file)

In [None]:
# Segment indices that occur in BOTH parts of the stored split (ideally none).
# The previous loop rebuilt an array and scanned it linearly for every index, and printed a
# progress counter on each step; a single vectorised membership test gives the same list, in the
# same order, without the quadratic cost or the output flood.
first_part_indices = np.array(split_segment_indices[0])
second_part_indices = np.array(split_segment_indices[1])
bad = list(second_part_indices[np.isin(second_part_indices, first_part_indices)])

In [None]:
# Indices found in both parts of the split; an empty list means the parts do not overlap.
print(bad)

In [None]:
# Segment IDs.
# NOTE(review): pickle.load executes arbitrary code; only load files from a trusted source.
with open('{}/468202_len128_s2_4cad_ids_errorfix.pkl'.format(data_dir), 'rb') as ids_file:
    seg_ids = pickle.load(ids_file)

# Observation ID of every segment: the part of the segment ID before the first underscore.
ObID_per_sample = np.array([seg_id.split("_")[0] for seg_id in seg_ids])

# Target size of the validation set: one tenth of 468202 segments.
needed_validation_segments = 468202/10

# unique_ObIDs[0] holds the distinct observation IDs, unique_ObIDs[1] their segment counts.
unique_ObIDs = np.unique(ObID_per_sample, return_counts=True)
ObIDs_no = len(unique_ObIDs[0])

# Visit the observations in a reproducible random order.
shuffle_indices = np.array(range(ObIDs_no))
np.random.seed(seed=11)
np.random.shuffle(shuffle_indices)

# Whole observations are moved to the validation set until it holds more than the target size.
valid_set_obs = []
valid_set_size = 0
for ob_index in shuffle_indices:
    valid_set_obs.append(unique_ObIDs[0][ob_index])
    valid_set_size += unique_ObIDs[1][ob_index]
    if valid_set_size > needed_validation_segments:
        break

# Segment indices of the validation observations, grouped observation by observation.
valid_set_sample_indices = [
    sample_index
    for valid_set_ob in np.array(valid_set_obs)
    for sample_index in np.where(ObID_per_sample == valid_set_ob)[0]
]

# Every observation that comes after the validation ones in the shuffled order is training data.
train_set_sample_indices = [
    sample_index
    for train_set_ob in shuffle_indices[len(valid_set_obs):]
    for sample_index in np.where(ObID_per_sample == unique_ObIDs[0][train_set_ob])[0]
]

split_indices = [train_set_sample_indices, valid_set_sample_indices]

In [None]:
# Observation IDs that contribute segments to BOTH the validation and the training set (ideally none).
# The previous loop recomputed np.unique over the whole training set for every validation
# observation and printed a progress counter on each step; intersecting the two ID sets once
# yields the same sorted list of shared IDs.
valid_observation_ids = ObID_per_sample[valid_set_sample_indices]
train_observation_ids = ObID_per_sample[train_set_sample_indices]
bad = list(np.intersect1d(valid_observation_ids, train_observation_ids))

In [None]:
# Observation IDs found in both sets; an empty list means the sets share no observation.
print(bad)

In [None]:
# Number of distinct observations in the training set.
np.unique(ObID_per_sample[train_set_sample_indices]).shape

In [None]:
# Number of distinct observations in the validation set.
np.unique(ObID_per_sample[valid_set_sample_indices]).shape

In [None]:
# First three observation IDs in shuffled order; expected to match valid_set_obs[:3] shown next.
unique_ObIDs[0][shuffle_indices][:len(valid_set_obs)][:3]

In [None]:
# First three observation IDs assigned to the validation set.
valid_set_obs[:3]

# Feature agglomeration clustering

In [None]:
# Load the (shuffled) list of unique IDs of the 468202 light curve segments.
# Each ID has the form observationID_segmentIndex,
# e.g. ['96701-01-48-00_3', '20402-01-02-02_122', '70703-01-01-000_1420', ...]
# NOTE(review): pickle.load executes arbitrary code; only load files from a trusted source.
with open('{}/468202_len128_s2_4cad_ids_errorfix.pkl'.format(data_dir), 'rb') as ids_file:
    seg_ids = pickle.load(ids_file)

# Observation ID of every segment: everything before the first underscore. The list has one entry
# per segment, so observation IDs repeat,
# e.g. ['96701-01-48-00', '20402-01-02-02', '70703-01-01-000', ...]
seg_ObIDs = [seg.partition("_")[0] for seg in seg_ids]

In [None]:

# Per-segment labels; later cells compare their counts with the output of a Gaussian mixture
# model's predict(), so these are presumably the 122-component GMM assignments — TODO confirm.
# NOTE(review): the first cells of the notebook load "shape_moments_GMM122_labels.pkl" (with
# extension) into the same variable; confirm that the extension-less file name used here exists.
# NOTE(review): pickle.load executes arbitrary code; only load files from a trusted source.
with open("{}/shape_moments_GMM122_labels".format(data_dir), 'rb') as f: # NOTE(review): original comment said "output of LSTM autoencoder's decoder" — looks stale, confirm
    shape_moments_GMM122_labels = pickle.load(f)

In [None]:
# make a dict that groups indices of segments of the same observation 
# i.e. where each observation id can be found in seg_ObIDs
#i.e. ObID_SegIndices_dict == {'10258-01-01-00': [916, 949, 1046...467528, 467578], ....}
# Start with one empty list per distinct observation ID (keys in sorted order), then record the
# position of every segment under the observation it belongs to.
ObID_SegIndices_dict = {observation_id: [] for observation_id in np.unique(seg_ObIDs)}
for ID_index, ObID in enumerate(seg_ObIDs):
    ObID_SegIndices_dict[ObID].append(ID_index)

In [None]:
# make a dictionary of Gaussian component labels instead of segment indices  
#i.e. ObID_GaussComps_dict_500 == {'10258-01-01-00': [401, 433, 382...101, 152], ....}
# Same grouping as ObID_SegIndices_dict, with every segment index replaced by that segment's label.
ObID_GaussComps_dict_122 = {
    observation_id: [shape_moments_GMM122_labels[segment_index] for segment_index in segment_indices]
    for observation_id, segment_indices in ObID_SegIndices_dict.items()
}

In [None]:
import pandas as pd

# Zero-filled integer table: one row per observation, one column per Gaussian component (0..121).
# It is filled with per-observation segment counts in the next cell.
n_observations = len(ObID_GaussComps_dict_122)
n_components = len(np.unique(shape_moments_GMM122_labels))
obs_component_counts_df_122 = pd.DataFrame(
    np.zeros((n_observations, n_components)),
    index=np.unique(seg_ObIDs),
    columns=list(range(122)),
    dtype=int,
)

In [None]:
# Populate the data frame: for every observation, count its segments per Gaussian component.
# The counts are written with a single .loc[row, columns] assignment. The previous chained form
# df.loc[row][col] = value assigns into an intermediate object that may be a temporary copy, in
# which case the frame itself silently stays at zero.
for ObID, GaussComps in ObID_GaussComps_dict_122.items():
    comp_ids, comp_counts = np.unique(GaussComps, return_counts=True)
    obs_component_counts_df_122.loc[ObID, comp_ids] = comp_counts
obs_component_counts_df_122

In [None]:
# Sanity checks on the populated table.
segments_per_component = np.unique(shape_moments_GMM122_labels, return_counts=True)[1]
column_totals = obs_component_counts_df_122.sum().values

# 1) The column totals equal the number of segments carrying each label, i.e. the counts were
#    transferred from the label list to the data frame without loss.
print((segments_per_component == column_totals).all())
# 2) The grand total; it should equal the number of segments in the data set.
print(obs_component_counts_df_122.sum().sum())
# 3) The size of the smallest component, as a guard against having loaded the labels of a
#    different mixture model.
print(column_totals.min())

In [None]:
# Scale every row (observation) by its largest absolute value: Normalizer(norm='max') divides each
# sample by its max-norm. This is not min-max scaling — no minimum is subtracted.
from sklearn.preprocessing import Normalizer

row_scaler = Normalizer(norm='max')
row_scaled_counts = row_scaler.fit_transform(obs_component_counts_df_122)
normalized_obs_component_counts_df_122 = pd.DataFrame(
    row_scaled_counts,
    index=np.unique(seg_ObIDs),
    columns=list(range(122)),
)
normalized_obs_component_counts_df_122

In [None]:
from sklearn import cluster
# Agglomerate the columns (Gaussian components) of the count table using complete linkage on cosine
# distance; with n_clusters=None the tree is cut at distance_threshold=1.
# NOTE(review): the model is fitted on the raw counts, whereas the next cell transforms the
# max-normalised frame — confirm that fitting and transforming different tables is intended.
# NOTE(review): newer scikit-learn releases rename `affinity` to `metric` — check the installed version.
agglo = cluster.FeatureAgglomeration(distance_threshold=1, n_clusters=None, affinity="cosine", linkage="complete")
agglo.fit(obs_component_counts_df_122)

In [None]:
# Pool the columns of the max-normalised table according to the fitted feature clusters.
transformed_features = agglo.transform(normalized_obs_component_counts_df_122)

In [None]:
# Observations plotted against the first two pooled features.
plt.scatter(transformed_features[:,0], transformed_features[:,1])
plt.show()

In [None]:
from scipy.cluster.hierarchy import dendrogram
import scipy.cluster.hierarchy as sch

def plot_dendrogram(model, **kwargs):
    """Plot a fitted agglomerative model as a dendrogram, then its cluster count versus cut distance.

    Two figures are shown (each finished with ``plt.show()``):

    1. the dendrogram of the merge tree stored in ``model``;
    2. the number of flat clusters obtained when that tree is cut at each of 1000 evenly spaced
       distance thresholds in [0, 1], zoomed to thresholds 0.8-1.0, with a reference line at 13.

    Finally the last ten (threshold, number of clusters) pairs are printed.

    Parameters
    ----------
    model : fitted estimator exposing ``children_``, ``distances_`` and ``labels_``
        (e.g. ``FeatureAgglomeration`` fitted with ``distance_threshold`` set).
    **kwargs : forwarded unchanged to ``scipy.cluster.hierarchy.dendrogram``.

    Side effects
    ------------
    Sets ``plt.rcParams['figure.figsize']`` to ``[5, 5]`` before drawing the second figure.
    """
    # Number of original items under every merge node: a child index below n_leaves is a leaf
    # (counts 1); any other child is an earlier merge whose count has already been computed.
    n_leaves = len(model.labels_)
    counts = np.zeros(model.children_.shape[0])
    for merge_index, merge in enumerate(model.children_):
        counts[merge_index] = sum(
            1 if child_idx < n_leaves else counts[child_idx - n_leaves]
            for child_idx in merge
        )

    # SciPy linkage-matrix layout: [child a, child b, merge distance, items under the node].
    linkage_matrix = np.column_stack([model.children_, model.distances_,
                                      counts]).astype(float)

    # Plot the corresponding dendrogram
    dendrogram(linkage_matrix, **kwargs)
    plt.show()

    # Number of flat clusters as a function of the distance at which the tree is cut.
    thresholds = np.linspace(0, 1, 1000)
    no_clusters = [
        len(np.unique(sch.fcluster(linkage_matrix, threshold, criterion='distance')))
        for threshold in thresholds
    ]

    plt.rcParams['figure.figsize'] = [5,5]
    plt.axhline(13, c="cyan")
    plt.plot(thresholds, no_clusters, c="magenta")
    plt.xlim((0.8,1.))
    plt.title("No. of clusters as the function of distance for hierarchical clustering")
    plt.ylabel("No. clusters")
    plt.xlabel("Distance threshold")
    plt.show()

    # Last ten (threshold, no. clusters) rows. The previous slice [:, -10:] acted on the two
    # COLUMNS of the (1000, 2) table and therefore selected every row.
    print(np.array((thresholds, no_clusters)).T[-10:])

In [None]:
# NOTE(review): plot_dendrogram() already calls plt.show() on its figures, so the xlabel below is
# probably applied to a new, empty figure rather than to the dendrogram — confirm.
plot_dendrogram(agglo)#, truncate_mode='level', p=3)
plt.xlabel("Number of points in node (or index of point if no parenthesis).")
plt.show()

In [None]:
# Light curves and, in a separate file, their IDs; later cells look a light curve up by the
# position of its ID in `ids`.
# NOTE(review): pickle.load executes arbitrary code; only load files from a trusted source.
with open('{}/1776_light_curves_1s_bin_errorfix.pkl'.format(data_dir), 'rb') as light_curves_file:
    lcs = pickle.load(light_curves_file)
with open('{}/1776_light_curves_1s_bin_ids_errorfix.pkl'.format(data_dir), 'rb') as light_curve_ids_file:
    ids = pickle.load(light_curve_ids_file)

In [None]:
# Refit with the tree cut at distance 0.875 and keep the resulting cluster label of every column.
# NOTE(review): the name labels_13 suggests this threshold yields 13 clusters — confirm.
agglo = cluster.FeatureAgglomeration(distance_threshold=0.875, n_clusters=None, affinity="cosine", linkage="complete")
agglo.fit(obs_component_counts_df_122)
labels_13 = agglo.labels_

In [None]:
plt.rcParams['figure.figsize'] = (10.15, 4.15)

# Cluster label to inspect, and the exclusive end index of the plotted slice (-1 leaves out the
# final sample of every light curve).
# NOTE(review): FeatureAgglomeration labels the columns (Gaussian components), yet the selected
# positions are used below to pick rows (observations) of the count table — confirm this is intended.
# NOTE(review): if labels_13 holds 13 clusters numbered from 0, label 13 selects nothing — confirm.
plot_cluster = 13
show_first_x_seconds = -1

lc_indices = np.flatnonzero(labels_13 == plot_cluster)
lc_IDs = obs_component_counts_df_122.index.values[lc_indices]

# One figure per selected observation, found by its ID among the loaded light curves.
light_curve_ids = np.array(ids)
for n_plot, lc_ID in enumerate(lc_IDs):
    lc_index_in_1776set = np.where(light_curve_ids == lc_ID)[0][0]
    light_curve = lcs[lc_index_in_1776set]
    plt.plot(light_curve[0][:show_first_x_seconds], light_curve[1][:show_first_x_seconds])
    plt.ylim((0,13000))
    plt.show()

In [None]:
# Class name of every segment, as an array.
# NOTE(review): segment_class is not defined in any visible cell — confirm it exists before this runs.
ids_ar = np.array(segment_class)


def _nth_segment_of_class(class_name, occurrence):
    """Return the position of the occurrence-th (0-based) segment labelled class_name."""
    return np.where(ids_ar == class_name)[0][occurrence]


# One hand-picked example segment per class.
alpha = _nth_segment_of_class("alpha", 1)
beta = _nth_segment_of_class("beta", 7)
gamma = _nth_segment_of_class("gamma", 4)
delta = _nth_segment_of_class("delta", 5)
theta = _nth_segment_of_class("theta", 4)
kappa = _nth_segment_of_class("kappa", 3)
lambda1 = _nth_segment_of_class("lambda", 2)
mu = _nth_segment_of_class("mu", 1)
nu = _nth_segment_of_class("nu", 1)
rho = _nth_segment_of_class("rho", 1)
phi = _nth_segment_of_class("phi", 1)
chi = _nth_segment_of_class("chi", 1)
eta = _nth_segment_of_class("eta", 4)
omega = _nth_segment_of_class("omega", 3)

In [None]:
# Positions of all segments labelled "phi".
np.where(ids_ar == "phi")[0]

In [None]:
# NOTE(review): phi_stats is first assigned in the FOLLOWING cell, so on a fresh kernel run from
# top to bottom this cell raises NameError — consider moving it below that cell.
phi_stats

In [None]:
# Rows of desc_stats (mean, std, skew, kurtosis) that belong to the segments of each class.
phi_stats = desc_stats[np.where(ids_ar == "phi")[0]]
chi_stats = desc_stats[np.where(ids_ar == "chi")[0]]
gamma_stats = desc_stats[np.where(ids_ar == "gamma")[0]]

In [None]:
# For phi, chi and gamma in turn: column-wise mean followed by column-wise standard deviation.
for class_stats in (phi_stats, chi_stats, gamma_stats):
    print(np.mean(class_stats, axis=0), np.std(class_stats, axis=0))

In [None]:
def _stat_by_class(stat_column):
    """Return a frame with one column per class (phi, chi, gamma) holding column stat_column of its stats."""
    per_class_values = [phi_stats[:,stat_column], chi_stats[:,stat_column], gamma_stats[:,stat_column]]
    return pd.DataFrame(per_class_values, index = ["phi", "chi", "gamma"]).T


# Columns of desc_stats are ordered mean, std, skew, kurtosis.
means_to_box = _stat_by_class(0)
std_to_box = _stat_by_class(1)
skew_to_box = _stat_by_class(2)
kurt_to_box = _stat_by_class(3)


In [None]:

# fake up some data
# Synthetic sample: a uniform spread, a block of constant values and high/low outliers.
# NOTE(review): `data` is not used by any cell visible here, but these calls advance the global
# NumPy random state and so affect later random draws — confirm before removing.
spread = np.random.rand(50) * 100
center = np.ones(25) * 50
flier_high = np.random.rand(10) * 100 + 100
flier_low = np.random.rand(10) * -100
data = np.concatenate((spread, center, flier_high, flier_low))

In [None]:
import matplotlib.pylab as pylab
import matplotlib.ticker as ticker

pylab.rcParams['figure.figsize'] = (6.97, 6.97*(21.0/29.7) ) # A4 aspect ratio (210mm x 297mm)

ids_ar = np.array(ids)

class_names = list(inv_ob_state.keys())

# One row per panel, in plotting order:
#   (class name,
#    position of the example observation within inv_ob_state[class name],
#    data gap that starts the plotted interval -- an index into the list of gaps, None = first sample,
#    data gap that ends the plotted interval   -- an index into the list of gaps, None = last sample)
# Listing the class name next to its light curve keeps every panel label tied to the curve it
# describes, independent of the key order of inv_ob_state.
PANEL_SPECS = [
    ("alpha",  0,  None, 0),
    ("beta",   5,  0,    1),
    ("gamma",  0,  0,    1),
    ("delta",  9,  None, None),
    ("theta",  13, 1,    2),
    ("kappa",  6,  -1,   None),
    ("lambda", 3,  -1,   None),
    ("mu",     6,  None, None),
    ("nu",     2,  0,    1),
    ("rho",    9,  None, None),
    ("phi",    3,  None, None),
    ("chi",    27, 0,    None),
    ("eta",    2,  2,    3),
    ("omega",  1,  0,    None),
]


def _plot_class_panel(ax, light_curve, class_name, start_break, end_break):
    """Draw one interval of light_curve on ax with time starting at 0, and label it with its class symbol.

    light_curve[0] holds the times and light_curve[1] the rates. A data gap is any place where
    consecutive times are not exactly 1 apart; start_break / end_break index into the list of gaps
    (None selects the first / last sample of the light curve).
    """
    light_c = np.copy(light_curve)
    light_c[1] /= 1000  # rates are plotted in thousands (axis label: kcounts/s)
    times, rates = light_c[0], light_c[1]

    start, end = 0, None
    if start_break is not None or end_break is not None:
        breaks = np.where(np.diff(times) != 1.)[0] + 1
        if start_break is not None:
            start = breaks[start_break]
        if end_break is not None:
            end = breaks[end_break]
    offset = times[start]

    # An open end now runs through the final sample; it used to stop one sample early (end = -1).
    ax.plot(times[start:end]-offset, rates[start:end], c="blue", linewidth=0.5, zorder=-5)
    ax.text(0.99,0.95,r"$\{}$".format(class_name), ha='right', va='top', transform=ax.transAxes, size=10)


# Example light curve of every class, looked up by observation ID.
selected_lcs = [
    lcs[np.where(ids_ar == inv_ob_state[class_name][observation_no])[0][0]]
    for class_name, observation_no, _start_break, _end_break in PANEL_SPECS
]
# Kept for backward compatibility: cells that read these names get the light curves, as before.
alpha, beta, gamma, delta, theta, kappa, lambda1, mu, nu, rho, phi, chi, eta, omega = selected_lcs

fig, axes = plt.subplots(nrows=7, ncols=2)
axes = axes.flatten()

plt.subplots_adjust(hspace=0.05)
plt.subplots_adjust(wspace=0.01)

good_classes = ["delta", "mu", "rho", "phi"]  # classes plotted in full; retained for later cells
intervals = {}  # not filled here; retained in case a later cell uses it

for plot_ind, (class_name, _observation_no, start_break, end_break) in enumerate(PANEL_SPECS):
    panel = axes[plot_ind]
    panel.set_ylim([0, 12.5])
    _plot_class_panel(panel, selected_lcs[plot_ind], class_name, start_break, end_break)
    panel.set_xlim([0, 3500])

    # y tick labels only in the left column.
    if plot_ind%2 == 0:
        panel.tick_params(axis="y", which="major", length=2, width=1, labelsize=6, direction="in")
    else:
        panel.tick_params(axis="y", which="major", length=2, width=1, labelsize=0, direction="in")
        plt.setp(panel.get_yticklabels(), visible=False)
    if plot_ind == 6:
        panel.set_ylabel("Rate (kcounts/s)", size=6)

    # x tick labels only in the bottom row.
    if plot_ind == 12 or plot_ind == 13:
        panel.tick_params(axis="x", which="major", length=2, width=1, labelsize=6, direction="in")
    else:
        panel.tick_params(axis="x", which="major", length=2, width=1, labelsize=0, direction="in")
        plt.setp(panel.get_xticklabels(), visible=False)
    if plot_ind == 12:
        panel.set_xlabel("Time (s)", size=6, x=1)

    # Ticks every 500 s / 2.5 kcounts/s, with every second tick labelled.
    panel.set_xticks([0,500, 1000,1500, 2000,2500, 3000])
    panel.set_yticks([0,2.5,5,7.5,10])
    panel.set_xticklabels([0,"", 1000,"", 2000,"", 3000])
    panel.set_yticklabels([0,"",5,"",10])

# Show the figure; the cell used to end with axes.reshape((7,2)), which only echoed the Axes array.
plt.show()


In [None]:
# axes = axes.flatten()


plt.rcParams['figure.figsize'] = (3.32, 8.4)#6.97)
plt.rcParams.update({'font.size': 6})

fig, axes = plt.subplots(nrows=4, ncols=1)

# One panel per statistic (a column of the *_stats arrays), one box per class.
stat_axis_labels = ["mean", "standard deviation", "skewness", "kurtosis"]
class_symbols = [r"$\{}$".format(x) for x in ["phi", "chi", "gamma"]]

for plot_ind in range(4):
    panel = axes[plot_ind]
    per_class_values = [phi_stats[:,plot_ind], chi_stats[:,plot_ind], gamma_stats[:,plot_ind]]
    stat_df = pd.DataFrame(per_class_values, index = class_symbols).T
    # Whiskers at the 5th and 95th percentiles; outliers are not drawn.
    stat_df.boxplot(whis=[5,95], sym="", ax=panel, grid=False)
    panel.tick_params(axis="y", which="major", length=2, width=1, labelsize=6, direction="in")
    if plot_ind == 3:
        # Only the bottom panel keeps its class labels on the x axis.
        panel.tick_params(axis="x", which="major", length=0, width=1, labelsize=8, direction="in")
    else:
        panel.set_xticks([])
        panel.set_xticklabels([])

    panel.set_ylabel(stat_axis_labels[plot_ind])

plt.subplots_adjust(hspace=0.01, wspace=0.01)
# NOTE(review): writes into ./figures, which must already exist.
plt.savefig('figures/stat_boxplot.png', dpi=300, bbox_inches = 'tight',pad_inches = 0.05)
plt.show()

In [None]:
# Leftover debugging output: the value the loop variable of the previously run plotting cell ended on.
plot_ind

In [None]:

# Box plots of the four statistics for the phi, chi and gamma classes on a 2 x 2 grid.
# The calls previously passed grid positions such as ax=[0,0] / axes=[0,1], which are not Axes
# objects; the positions are now resolved against a real grid of Axes.
box_fig, box_axes = plt.subplots(nrows=2, ncols=2)

# (grid row, grid column, data, y label, extra boxplot options). As before, only the first plot
# hides its outliers (sym="").
box_panels = [
    (0, 0, means_to_box, "mean", {"sym": ""}),
    (0, 1, std_to_box, "std", {}),
    (1, 0, skew_to_box, "skew", {}),
    (1, 1, kurt_to_box, "kurt", {}),
]

for grid_row, grid_col, frame_to_box, y_label, box_options in box_panels:
    ax = box_axes[grid_row, grid_col]
    # Whiskers at the 1st and 99th percentiles.
    frame_to_box.boxplot(whis=[1,99], ax=ax, **box_options)
    ax.set_ylabel(y_label)

plt.show()

In [None]:
# Box plots of the first statistic column for phi, chi and gamma.
# NOTE(review): the three calls draw onto the same axes at the same default position, so the
# boxes overlap; plt.boxplot([phi, chi, gamma]) would place them side by side — confirm intent.
plt.boxplot(phi_stats[:,0])
plt.boxplot(chi_stats[:,0])
plt.boxplot(gamma_stats[:,0])

plt.show()

In [None]:
# Scratch check of the shape obtained when the three per-class columns are packed into one array.
np.array((phi_stats[:,0], chi_stats[:,0],gamma_stats[:,0])).shape

In [None]:
from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator, FuncFormatter)

# Figure size and font size must be set BEFORE the figure is created; previously they were set
# afterwards and therefore only took effect on a second run of the cell.
plt.rcParams['figure.figsize'] = (3.32, 3.32*(1/1))
plt.rcParams.update({'font.size': 6})

fig, axes = plt.subplots(nrows=2, ncols=2)
axes = axes.flatten()

stat_names = np.array(["mean", "st.d.", "skew", "kurtosis"])

# Major x-tick spacing and histogram range for each statistic.
x_tick_space = [2500, 1000, 1, 2]
x_limits = [[0,8000],[0,2500], [-2,3], [-2,8]]

# Rows of desc_stats of one example segment per class, with the colour of its marker line.
# They are looked up here (same positions as in the example-selection cell) because the names
# chi / phi / gamma are rebound to whole light curves by the 14-panel figure cell and are then
# no longer valid row indices.
segment_class_ar = np.array(segment_class)
example_segments = [
    (np.where(segment_class_ar == "chi")[0][1], "cyan"),
    (np.where(segment_class_ar == "phi")[0][1], "magenta"),
    (np.where(segment_class_ar == "gamma")[0][4], "black"),
]

# Show segment counts in units of 10^4. A formatter is used instead of fixed tick-label strings so
# that the labels stay correct if matplotlib moves the ticks when the figure is drawn.
counts_in_1e4 = FuncFormatter(lambda count, _position: '{:3.0f}'.format(count / 10000))

for plot_ind in range(4):
    panel = axes[plot_ind]
    panel.hist(desc_stats[:,plot_ind], bins=50, range = x_limits[plot_ind],zorder=-4)
    panel.yaxis.set_major_formatter(counts_in_1e4)
    panel.xaxis.set_major_locator(MultipleLocator(x_tick_space[plot_ind]))
    panel.xaxis.set_major_formatter(FormatStrFormatter('%d'))
    panel.xaxis.set_minor_locator(AutoMinorLocator())
    panel.text(0.99,0.99,"{}".format(stat_names[plot_ind]), ha='right', va='top', transform=panel.transAxes, size=6)
    # Vertical line at the value of the statistic for each example segment.
    for segment_index, line_colour in example_segments:
        panel.axvline(desc_stats[segment_index,plot_ind], c=line_colour, zorder=-3, alpha=0.5)
    if plot_ind == 2:
        # Shared axis labels live on the bottom-left panel, shifted to its edges.
        panel.set_ylabel("No. segments ($10^4$)", y=1, size=6)
        panel.set_xlabel("Value of the statistic", x=1, size=6)

# plt.savefig('figures/stat_distributions.png', dpi=300, bbox_inches = 'tight',pad_inches = 0)

plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import axes3d
from mpl_toolkits.mplot3d.art3d import Poly3DCollection  # needed to draw 3d polygons
from matplotlib import style

# Illustration: random points scattered on the triangle spanned by the three
# unit vectors, drawn mirrored in x and y (coordinates are plotted as 1 - value).
plt.rcParams.update({'figure.figsize': (3.32, 3.32*(1/2)), 'axes.grid': False})

custom = plt.subplot(1, 2, 1, projection='3d')

# Coordinates of the three triangle corners, stored per axis.
x1 = np.array([1, 0, 0])
y1 = np.array([0, 1, 0])
z1 = np.array([0, 0, 1])

x,y,z=np.zeros((3,3))

# 100 random 3-vectors, each divided by its coordinate sum so it lies on the
# plane x + y + z = 1.
points = np.random.random(size=300).reshape((100,3))
px, py, pz = points.T / points.sum(axis=1)
custom.scatter(1 - px, 1 - py, pz, s=0.5, c="magenta")

# Mirrored corner coordinates, reused for the polygon and the corner markers.
corner_x = 1 - x1
corner_y = 1 - y1

# Translucent triangle through the three corners.
verts = [list(zip(corner_x, corner_y, z1))]
srf = Poly3DCollection(verts, alpha=.25, facecolor='pink', zorder=20)
custom.add_collection3d(srf)

# Strip ticks and tick labels from all three axes (y, x, z order as before).
for set_ticks, set_ticklabels in (
    (custom.set_yticks, custom.set_yticklabels),
    (custom.set_xticks, custom.set_xticklabels),
    (custom.set_zticks, custom.set_zticklabels),
):
    set_ticks([])
    set_ticklabels([])

for set_limits in (custom.set_xlim, custom.set_ylim, custom.set_zlim):
    set_limits([0,1])

# Corner markers with different sizes: corner 1 -> 10, corner 2 -> 5, corner 0 -> 3.
for corner, marker_size in ((1, 10), (2, 5), (0, 3)):
    custom.scatter(corner_x[corner], corner_y[corner], z1[corner], s=marker_size, c="black", zorder=100 , depthshade=False)

# NOTE(review): `_axinfo` is a private matplotlib attribute; this tweak may
# break between matplotlib versions.
custom.xaxis._axinfo['juggled'] = (0,0,0)

plt.savefig('figures/simplex.png', dpi=300, bbox_inches = 'tight',pad_inches = 0)

plt.show()

# Comparison of the 122-component Gaussian Mixture Model with the Belloni classifications

In [None]:
import pandas as pd

In [None]:
# Load the pickled `shape_moments_GMM122_labels` object.
# NOTE(review): judging by the file name these are presumably the per-segment
# component labels of the 122-component Gaussian mixture model, not decoder
# output as the previous comment stated — confirm.
with open("{}/shape_moments_GMM122_labels.pkl".format(data_dir), 'rb') as f:
    shape_moments_GMM122_labels = pickle.load(f)

In [None]:
# load observation classifications from Huppenkothen 2017
# File layout as parsed below: the first line lists the state names and every
# following line holds the whitespace-separated observation IDs of the
# corresponding state. Reading inside `with` guarantees the handle is closed.
with open('{}/1915Belloniclass_updated.dat'.format(data_dir)) as clean_belloni:
    lines = clean_belloni.readlines()
states = lines[0].split()
belloni_clean = {}
for h,l in zip(states, lines[1:]):
    belloni_clean[h] = l.split()
    #state: obsID1, obsID2...

# Invert to observation ID -> class, merging the four chi sub-classes into "chi",
# i.e. ob_state == {'20187-02-01-00': 'alpha', '20187-02-01-01': 'alpha', '20402-01-22-00': 'alpha', ...}
ob_state = {}
for state, obs in belloni_clean.items():
    if state in ("chi1", "chi2", "chi3", "chi4"): state = "chi"
    for ob in obs:
        ob_state[ob] = state


# load IDs of segmented light curves: observationsID_segmentIndex
with open('{}/468202_len128_s2_4cad_ids_errorfix.pkl'.format(data_dir), 'rb') as f:
    seg_ids = pickle.load(f)


seg_ObIDs = [seg.split("_")[0] for seg in seg_ids] # get rid of the within-observation segment indices and create a degenerate list of observation IDs

classes = np.array(["alpha", "beta", "gamma", "delta", "theta", "kappa", "lambda", "mu", "nu", "rho", "phi", "chi", "eta", "omega"])
scales = []
# Class of every segment, inherited from its parent observation; segments of
# observations missing from the classification file are tagged "Unknown".
segment_class = [ob_state.get(ob, "Unknown") for ob in seg_ObIDs]

In [None]:
# Contingency table: rows are the cluster labels found in `new_classification`,
# columns are the Belloni classes (plus "Unknown"); each cell holds the number
# of segments with that (cluster, class) combination.
new_classification = shape_moments_GMM122_labels

Belloni_classes = np.array(["alpha", "beta", "gamma", "delta", "theta", "kappa", "lambda", "mu", "nu", "rho", "phi", "chi", "eta", "omega", "Unknown"])

cluster_labels = np.unique(new_classification)

comparison_matrix = np.zeros((len(cluster_labels), len(Belloni_classes)), dtype=int)

comparison_matrix_df = pd.DataFrame(comparison_matrix, columns=Belloni_classes, index=cluster_labels)

# Converted once, outside the loop: the per-segment class list is loop-invariant.
segment_class_ar = np.array(segment_class)

for n_Bc, Belloni_class in enumerate(Belloni_classes):
    Belloni_class_indices = np.where(segment_class_ar == Belloni_class)[0]
    clusters_in_class, cluster_counts = np.unique(np.take(new_classification, Belloni_class_indices), return_counts=True)
    if len(clusters_in_class) == 0:
        continue  # no segment carries this class; the column stays at zero
    # Single .loc assignment instead of chained `df[col][row] = ...`, which
    # writes to a temporary object under pandas copy-on-write and triggers
    # SettingWithCopyWarning otherwise.
    comparison_matrix_df.loc[clusters_in_class, Belloni_class] = cluster_counts

In [None]:
# Min-max scale every column of the count table (one column per Belloni class,
# including "Unknown") to the [0, 1] range.
class_min = comparison_matrix_df.min()
class_span = comparison_matrix_df.max() - class_min
class_normalized_comparison_matrix_df = (comparison_matrix_df - class_min) / class_span

# Drop the "Unknown" column and transpose, so that columns are now components;
# min-max scaling the columns therefore normalises per component.
known_comparison_matrix_df = comparison_matrix_df.drop(columns=['Unknown']).T
component_min = known_comparison_matrix_df.min()
component_span = known_comparison_matrix_df.max() - component_min
component_normalized_comparison_matrix_df = (known_comparison_matrix_df - component_min) / component_span

In [None]:
# find particularly class-homogeneous components
#
# A component (row position in `comparison_matrix_df`) is kept when either
#   * its most populous column is a known class, or
#   * its most populous column is "Unknown" but, among the remaining (known)
#     segments, the largest class accounts for more than 99.9 %.
# `good_comps` maps the row position to that row's non-zero counts, sorted in
# descending order.

good_comps = {}
# Iterate over the rows actually present instead of assuming exactly 122.
for comp in range(len(comparison_matrix_df)):
    comp_counts = comparison_matrix_df.iloc[comp,:].sort_values(ascending=False)
    comp_counts_nonzero = comp_counts[comp_counts > 0].astype(int)
    if comp_counts_nonzero.empty:
        continue  # row without any segment
    if comp_counts_nonzero.index[0] == "Unknown":
        known_counts = comp_counts_nonzero.iloc[1:]
        if known_counts.empty:
            continue  # only unclassified segments: nothing to judge homogeneity on
        comp_class_proportion = known_counts/known_counts.sum()
        # .iloc: the Series is indexed by class name, so the largest share must
        # be fetched by position explicitly.
        if comp_class_proportion.iloc[0]>0.999:
            good_comps[comp] = comp_counts_nonzero
    else:
        good_comps[comp] = comp_counts_nonzero
print(len(good_comps))

In [None]:
# Hand-written notes on which components are dominated by which class
# (component numbers per class), followed by a dump of `good_comps`.
# chi is >95% bar unknowns : (10), 3, 5, 7, 15, 24, 28, 37, 44, 53, 57, 60, 69, 76, 91, 103, 112
# kappa : (79),  4, 42, 78, 104 (87 has 90.8% kappa)
# theta : 25, 26, 71
# lambda : 66
# rho : 12, 18, 19, 20, 22, 23, 30, 38, 45, 49, 50, 54, 56, 63, 65, 74, 77, 86, 97, 107

print(good_comps)

In [None]:
# Show the per-class segment counts of the component at row position 98.
comparison_matrix_df.iloc[98]#/(comparison_matrix_df.iloc[87].sum()-2628)

In [None]:
# Keys of `good_comps` whose stored count Series has more than one entry.
comp_list = [comp_key for comp_key, comp_counts_kept in good_comps.items() if len(comp_counts_kept) > 1]

In [None]:
# Print the keys of `good_comps` that have more than one non-zero count.
print(comp_list)

In [None]:
# For the component at row position 7: sort the class counts in descending
# order, keep the non-zero ones, then show each remaining count as a fraction
# of the total after dropping the largest entry.
comp_counts = comparison_matrix_df.iloc[7,:].sort_values(ascending=False)
comp_counts_nonzero = comp_counts.where(comp_counts>0).dropna().astype(int)
comp_counts_nonzero[1:]/comp_counts_nonzero[1:].sum()

In [None]:
# Total number of segments per class column (summed over all components).
comparison_matrix_df.sum(axis=0)

In [None]:
# Per-class counts of the component at position 66 (a column after transposing).
comparison_matrix_df.T.iloc[:,66]

In [None]:
# NOTE(review): scratch calculation with hard-coded numbers whose origin is not
# visible in this notebook section — document where 0.002524 and 7617 come from.
(0.002524/1.002524)*7617

In [None]:
# Occurrence count of the 11th distinct label (position 10 of the sorted unique
# values).
# NOTE(review): `shape_moments_GM114_labels` is not defined in the cells shown
# here (the loaded variable is `shape_moments_GMM122_labels`) — confirm this is
# not a stale name from an earlier experiment.
np.unique(shape_moments_GM114_labels, return_counts=True)[1][10]

In [None]:
# Positions of the components whose most populous class column is not "Unknown".
np.where(comparison_matrix_df.T.idxmax().values != "Unknown")

In [None]:
# Row labels of the transposed, class-normalised table with its last row removed.
class_normalized_comparison_matrix_df.T.iloc[:-1,:].index.values

In [None]:
# Transposed, class-normalised table with its last row removed.
class_normalized_comparison_matrix_df.T.iloc[:-1,:]

In [None]:
# Same table as above, with the rows sorted by label in ascending order.
class_normalized_comparison_matrix_df.T.iloc[:-1,:].sort_index(ascending=True)

In [None]:
import seaborn as sns

from matplotlib.ticker import (MultipleLocator, FormatStrFormatter,
                               AutoMinorLocator)

from matplotlib.transforms import ScaledTranslation

# Heat map of the class-normalised count table: one row per class, one column
# per mixture component.
plt.rcParams['figure.figsize'] = (6.97, 6.97*(1/3))
plt.rcParams.update({'font.size': 6})

# Transpose so classes become rows, drop the last row and sort the rows by
# label; built once and reused for both the plot and the tick labels.
heatmap_table = class_normalized_comparison_matrix_df.T.iloc[:-1,:].sort_index(ascending=True)

ax = sns.heatmap(heatmap_table, xticklabels=True, yticklabels=True, cmap='coolwarm')

# Major x tick every 4 units with integer labels; minor ticks carry no labels.
ax.xaxis.set_major_locator(MultipleLocator(4))
ax.xaxis.set_major_formatter(FormatStrFormatter('%d'))
ax.xaxis.set_minor_locator(AutoMinorLocator())

# Render the row labels as LaTeX commands (e.g. "alpha" -> "$\alpha$").
classes_names = heatmap_table.index.values
greek_labels = [r"$\{}$".format(c) for c in classes_names]
ax.set_yticklabels(greek_labels)

plt.savefig('figures/GMM122vsBelloni_heatmap.png', dpi=300, bbox_inches = 'tight',pad_inches = 0)

plt.show()

In [None]:
import matplotlib.pylab as pylab
import matplotlib.ticker as ticker

# 7x2 grid with one example light curve per class.
# Requires `lcs`, `ids` and `inv_ob_state` to exist already; in this notebook
# they are created by a cell located further down, so run that cell first.
# NOTE(review): a later cell writes to the same 'figures/light_curves.png'
# file and will overwrite this figure.

pylab.rcParams['figure.figsize'] = (6.97, 6.97*(21.0/29.7) ) # A4 size 210mm x 297mm

ids_ar = np.array(ids)

# Panel order, spelled out so that each label is guaranteed to match the
# light curve at the same position in `selected_lcs` (previously this relied
# on the key order of `inv_ob_state`).
class_names = ["alpha", "beta", "gamma", "delta", "theta", "kappa", "lambda", "mu", "nu", "rho", "phi", "chi", "eta", "omega"]

# One hand-picked observation per class; trailing comments list alternative
# indices that were tried.
alpha = lcs[np.where(ids_ar == inv_ob_state["alpha"][0])[0][0]]
beta= lcs[np.where(ids_ar == inv_ob_state["beta"][5])[0][0]] #3
gamma=lcs[np.where(ids_ar == inv_ob_state["gamma"][0])[0][0]]
delta=lcs[np.where(ids_ar == inv_ob_state["delta"][9])[0][0]]
theta=lcs[np.where(ids_ar == inv_ob_state["theta"][13])[0][0]]#11
kappa=lcs[np.where(ids_ar == inv_ob_state["kappa"][6])[0][0]]#6
lambda1=lcs[np.where(ids_ar == inv_ob_state["lambda"][3])[0][0]] #3
mu=lcs[np.where(ids_ar == inv_ob_state["mu"][6])[0][0]]#6
nu=lcs[np.where(ids_ar == inv_ob_state["nu"][2])[0][0]]#0
rho=lcs[np.where(ids_ar == inv_ob_state["rho"][9])[0][0]]#9
phi=lcs[np.where(ids_ar == inv_ob_state["phi"][3])[0][0]]# 3,6
chi=lcs[np.where(ids_ar == inv_ob_state["chi"][27])[0][0]]# 1,17,27
eta=lcs[np.where(ids_ar == inv_ob_state["eta"][2])[0][0]]# 1
omega=lcs[np.where(ids_ar == inv_ob_state["omega"][1])[0][0]]


selected_lcs = [alpha,beta,gamma,delta,theta,kappa,lambda1,mu,nu,rho,phi,chi,eta,omega]


fig, axes = plt.subplots(nrows=7, ncols=2)
axes = axes.flatten()

plt.subplots_adjust(hspace=0.05)
plt.subplots_adjust(wspace=0.01)

# Classes whose light curve is plotted in full.
good_classes = ["delta", "mu", "rho", "phi"]

# Classes for which only one stretch of the light curve is plotted.
# `breaks` holds the sample indices that follow a time step different from 1.
# Each entry is (start, end), both given as positions in `breaks`:
#   start None -> begin at sample 0
#   end   None -> stop one sample before the end (slice end of -1)
segment_bounds = {
    "alpha": (None, 0),
    "beta": (0, 1),
    "gamma": (0, 1),
    "theta": (1, 2),
    "kappa": (-1, None),
    "lambda": (-1, None),
    "nu": (0, 1),
    "chi": (0, None),
    "eta": (2, 3),
    "omega": (0, None),
}

for plot_ind in range(14):
    panel = axes[plot_ind]
    light_c = np.copy(selected_lcs[plot_ind])  # copy: the rescaling below must not touch `lcs`
    light_c[1] /=1000  # row 1 is plotted as the rate, in thousands (see the y label)
    class_name = class_names[plot_ind]
    panel.set_ylim([0, 12.5])

    if class_name in segment_bounds:
        breaks = np.where(np.diff(light_c[0]) != 1.)[0]+1
        start_break, end_break = segment_bounds[class_name]
        start = 0 if start_break is None else breaks[start_break]
        end = -1 if end_break is None else breaks[end_break]
        offset = light_c[0][start]  # shift so the plotted stretch starts at t = 0
        panel.plot(light_c[0][start:end]-offset, light_c[1][start:end], c="blue", linewidth=0.5, zorder=-5)
        panel.text(0.99,0.95,r"$\{}$".format(class_name), ha='right', va='top', transform=panel.transAxes, size=10)
    elif class_name in good_classes:
        offset = light_c[0][0]
        panel.plot(light_c[0]-offset, light_c[1], c="blue", linewidth=0.5, zorder=-5)
        panel.text(0.99,0.95,r"$\{}$".format(class_name), ha='right', va='top', transform=panel.transAxes, size=10)
    else:
        # Fallback for a class without a rule: full curve plus its first 3500 samples.
        offset = light_c[0][0]
        panel.plot(light_c[0]-offset, light_c[1])
        panel.plot(light_c[0][:3500]-offset, light_c[1][:3500])

    panel.set_xlim([0, 3500])

    # y tick labels only in the left-hand column.
    if plot_ind%2 == 0:
        panel.tick_params(axis="y", which="major", length=2, width=1, labelsize=6, direction="in")
    else:
        panel.tick_params(axis="y", which="major", length=2, width=1, labelsize=0, direction="in")
        plt.setp(panel.get_yticklabels(), visible=False)
    if plot_ind == 6:
        panel.set_ylabel("Rate (kcounts/s)", size=6)
    # x tick labels only in the bottom row.
    if plot_ind == 12 or plot_ind == 13:
        panel.tick_params(axis="x", which="major", length=2, width=1, labelsize=6, direction="in")
    else:
        panel.tick_params(axis="x", which="major", length=2, width=1, labelsize=0, direction="in")
        plt.setp(panel.get_xticklabels(), visible=False)
    if plot_ind == 12:
        panel.set_xlabel("Time (s)", size=6, x=1)

    # Identical tick positions in every panel; every other tick is left unlabelled.
    panel.set_xticks([0,500, 1000,1500, 2000,2500, 3000])
    panel.set_yticks([0,2.5,5,7.5,10])
    panel.set_xticklabels([0,"", 1000,"", 2000,"", 3000])
    panel.set_yticklabels([0,"",5,"",10])

plt.savefig('figures/light_curves.png', dpi=300, bbox_inches = 'tight',pad_inches = 0.01)

plt.show()

In [None]:
# Load the 1 s-binned light curves and their observation IDs.
with open('{}/1776_light_curves_1s_bin_errorfix.pkl'.format(data_dir), 'rb') as f:
    lcs = pickle.load(f)
with open('{}/1776_light_curves_1s_bin_ids_errorfix.pkl'.format(data_dir), 'rb') as f:
    ids = pickle.load(f)


# Parse the classification file: the first line lists the state names, each
# following line the observation IDs of the corresponding state. Reading
# inside `with` guarantees the handle is closed.
with open('{}/1915Belloniclass_updated.dat'.format(data_dir)) as clean_belloni:
    lines = clean_belloni.readlines()
states = lines[0].split()
belloni_clean = {}
for h,l in zip(states, lines[1:]):
    belloni_clean[h] = l.split()
    #state: obsID1, obsID2...

# observation ID -> class, with the four chi sub-classes merged into "chi"
ob_state = {}
for state, obs in belloni_clean.items():
    if state in ("chi1", "chi2", "chi3", "chi4"): state = "chi"
    for ob in obs:
        ob_state[ob] = state


# Inverse mapping: class -> list of observation IDs, in the order in which
# they were inserted into `ob_state`.
inv_ob_state = {}
for k, v in ob_state.items():
    inv_ob_state.setdefault(v, []).append(k)

In [None]:
# Class names present in `inv_ob_state`, in insertion order.
list(inv_ob_state.keys())

In [None]:
import matplotlib.pylab as pylab
import matplotlib.ticker as ticker

# 7x2 grid with one example light curve per class, panels in alphabetical
# class order. Uses `lcs`, `ids` and `inv_ob_state` from the loading cell.

pylab.rcParams['figure.figsize'] = (6.97, 6.97*(21.0/29.7) ) # A4 size 210mm x 297mm

ids_ar = np.array(ids)

# Alphabetically sorted class names; `selected_lcs` below is written in the
# same order so that labels and curves line up.
class_names = np.sort(list(inv_ob_state.keys()))

# One hand-picked observation per class; trailing comments list alternative
# indices that were tried.
alpha = lcs[np.where(ids_ar == inv_ob_state["alpha"][0])[0][0]]
beta= lcs[np.where(ids_ar == inv_ob_state["beta"][5])[0][0]] #3
gamma=lcs[np.where(ids_ar == inv_ob_state["gamma"][0])[0][0]]
delta=lcs[np.where(ids_ar == inv_ob_state["delta"][9])[0][0]]
theta=lcs[np.where(ids_ar == inv_ob_state["theta"][13])[0][0]]#11
kappa=lcs[np.where(ids_ar == inv_ob_state["kappa"][6])[0][0]]#6
lambda1=lcs[np.where(ids_ar == inv_ob_state["lambda"][3])[0][0]] #3
mu=lcs[np.where(ids_ar == inv_ob_state["mu"][6])[0][0]]#6
nu=lcs[np.where(ids_ar == inv_ob_state["nu"][2])[0][0]]#0
rho=lcs[np.where(ids_ar == inv_ob_state["rho"][9])[0][0]]#9
phi=lcs[np.where(ids_ar == inv_ob_state["phi"][3])[0][0]]# 3,6
chi=lcs[np.where(ids_ar == inv_ob_state["chi"][27])[0][0]]# 1,17,27
eta=lcs[np.where(ids_ar == inv_ob_state["eta"][2])[0][0]]# 1
omega=lcs[np.where(ids_ar == inv_ob_state["omega"][1])[0][0]]


selected_lcs = [alpha,beta,chi,delta,eta,gamma,kappa,lambda1,mu,nu,omega,phi,rho,theta]


fig, axes = plt.subplots(nrows=7, ncols=2)
axes = axes.flatten()

plt.subplots_adjust(hspace=0.05)
plt.subplots_adjust(wspace=0.01)

# Classes whose light curve is plotted in full.
good_classes = ["delta", "mu", "phi", "rho"]

# Classes for which only one stretch of the light curve is plotted.
# `breaks` holds the sample indices that follow a time step different from 1.
# Each entry is (start, end), both given as positions in `breaks`:
#   start None -> begin at sample 0
#   end   None -> stop one sample before the end (slice end of -1)
segment_bounds = {
    "alpha": (None, 0),
    "beta": (0, 1),
    "chi": (0, None),
    "eta": (2, 3),
    "gamma": (0, 1),
    "kappa": (-1, None),
    "lambda": (-1, None),
    "nu": (0, 1),
    "omega": (0, None),
    "theta": (1, 2),
}

for plot_ind in range(14):
    panel = axes[plot_ind]
    light_c = np.copy(selected_lcs[plot_ind])  # copy: the rescaling below must not touch `lcs`
    light_c[1] /=1000  # row 1 is plotted as the rate, in thousands (see the y label)
    class_name = class_names[plot_ind]
    panel.set_ylim([0, 12.5])

    if class_name in segment_bounds:
        breaks = np.where(np.diff(light_c[0]) != 1.)[0]+1
        start_break, end_break = segment_bounds[class_name]
        start = 0 if start_break is None else breaks[start_break]
        end = -1 if end_break is None else breaks[end_break]
        offset = light_c[0][start]  # shift so the plotted stretch starts at t = 0
        panel.plot(light_c[0][start:end]-offset, light_c[1][start:end], c="blue", linewidth=0.5, zorder=-5)
        panel.text(0.99,0.95,r"$\{}$".format(class_name), ha='right', va='top', transform=panel.transAxes, size=10)
    elif class_name in good_classes:
        offset = light_c[0][0]
        panel.plot(light_c[0]-offset, light_c[1], c="blue", linewidth=0.5, zorder=-5)
        panel.text(0.99,0.95,r"$\{}$".format(class_name), ha='right', va='top', transform=panel.transAxes, size=10)
    else:
        # Fallback for a class without a rule: full curve plus its first 3500 samples.
        offset = light_c[0][0]
        panel.plot(light_c[0]-offset, light_c[1])
        panel.plot(light_c[0][:3500]-offset, light_c[1][:3500])

    panel.set_xlim([0, 3500])

    # y tick labels only in the left-hand column.
    if plot_ind%2 == 0:
        panel.tick_params(axis="y", which="major", length=2, width=1, labelsize=6, direction="in")
    else:
        panel.tick_params(axis="y", which="major", length=2, width=1, labelsize=0, direction="in")
        plt.setp(panel.get_yticklabels(), visible=False)
    if plot_ind == 6:
        panel.set_ylabel("Rate (kcounts/s)", size=6)
    # x tick labels only in the bottom row.
    if plot_ind == 12 or plot_ind == 13:
        panel.tick_params(axis="x", which="major", length=2, width=1, labelsize=6, direction="in")
    else:
        panel.tick_params(axis="x", which="major", length=2, width=1, labelsize=0, direction="in")
        plt.setp(panel.get_xticklabels(), visible=False)
    if plot_ind == 12:
        panel.set_xlabel("Time (s)", size=6, x=1)

    # Identical tick positions in every panel; every other tick is left unlabelled.
    panel.set_xticks([0,500, 1000,1500, 2000,2500, 3000])
    panel.set_yticks([0,2.5,5,7.5,10])
    panel.set_xticklabels([0,"", 1000,"", 2000,"", 3000])
    panel.set_yticklabels([0,"",5,"",10])

plt.savefig('figures/light_curves.png', dpi=300, bbox_inches = 'tight',pad_inches = 0.01)

plt.show()

In [None]:
# Display the class names used as panel labels in the figure cell.
class_names

In [None]:
# Quick look at one entry of `selected_lcs`: row 0 on the x axis, row 1 on the
# y axis, without any of the cropping applied in the multi-panel figure.
n=9  # position in `selected_lcs` to inspect
plt.plot(selected_lcs[n][0], selected_lcs[n][1])
plt.show()

# 4/1 second paper

In [None]:
# load observation classifications from Huppenkothen 2017
# File layout as parsed below: the first line lists the state names and every
# following line holds the whitespace-separated observation IDs of the
# corresponding state. Reading inside `with` guarantees the handle is closed.
with open('{}/1915Belloniclass_updated.dat'.format(data_dir)) as clean_belloni:
    lines = clean_belloni.readlines()
states = lines[0].split()
belloni_clean = {}
for h,l in zip(states, lines[1:]):
    belloni_clean[h] = l.split()
    #state: obsID1, obsID2...

# observation ID -> class, with the four chi sub-classes merged into "chi"
ob_state = {}
for state, obs in belloni_clean.items():
    if state in ("chi1", "chi2", "chi3", "chi4"): state = "chi"
    for ob in obs:
        ob_state[ob] = state

# load IDs of segmented light curves: observationsID_segmentIndex
with open('{}/468202_len128_stride8_4sec_cad_ids_sum_bin.pkl'.format(data_dir), 'rb') as f:
    seg_ids = pickle.load(f)


seg_ObIDs = [seg.split("_")[0] for seg in seg_ids] # get rid of the within-observation segment indices and create a degenerate list of observation IDs

classes = np.array(["alpha", "beta", "gamma", "delta", "theta", "kappa", "lambda", "mu", "nu", "rho", "phi", "chi", "eta", "omega"])
scales = []
# Class of every segment, inherited from its parent observation; segments of
# observations missing from the classification file are tagged "Unknown".
segment_class = [ob_state.get(ob, "Unknown") for ob in seg_ObIDs]

In [None]:
# Paths of the pickled segments, their errors and the model reconstructions.
segments_dir = '{}/468202_len128_stride8_4sec_cad_countrates_sum_bin.pkl'.format(data_dir)
errors_dir = '{}/468202_len128_stride8_4sec_cad_errors_sum_bin.pkl'.format(data_dir)
recos_dir = "{}/reconstructions_model_model_2020-12-21_20-11-39_segments_468202_len128_stride8_4sec_cad_countrates_sum_bin.pkl".format(data_dir)

# Load the three files in the same order as before.
unpickled = []
for pickle_path in (segments_dir, errors_dir, recos_dir):
    with open(pickle_path, 'rb') as f:
        unpickled.append(pickle.load(f))
segments, errors, recos = unpickled

# Rescale the reconstructions with each segment's standard deviation and mean
# taken along axis 1.
# NOTE(review): presumably this undoes a per-segment standardisation applied
# before training; the array shapes must broadcast as intended — confirm.
segment_std = np.std(segments, axis=1)
segment_mean = np.mean(segments, axis=1)
recos = recos*segment_std + segment_mean

# errors = ((errors)/np.expand_dims(np.std(segments, axis=1), axis=1)).astype(np.float32)
# segments = zscore(segments, axis=1).astype(np.float32)  # standardize per segment

In [None]:
# Load the pickled split and keep, for each ID in its third element (index 2),
# the part before the first underscore.
# NOTE(review): the file name suggests a train/val/test split with the test
# IDs at index 2 — confirm.
split_path = '{}/lightcurve1738_train70_val10_test20.pkl'.format(data_dir)
with open(split_path, 'rb') as f:
    split_ob_ids = pickle.load(f)
test_ids = [split_id.partition("_")[0] for split_id in split_ob_ids[2]]

In [None]:
import matplotlib.pylab as pylab
import matplotlib.ticker as ticker

# Figure: one hand-picked segment per class, input (black) vs. reconstruction (magenta).
# Left column: raw values on a fixed 0-15000 y-range with tick labels in thousands.
# Right column: the same curves divided by 1000 on an auto-scaled y-axis.

plt.rcParams['figure.figsize'] = (6.97, 8.4)
plt.rcParams.update({'font.size': 6})

ids_ar = np.array(segment_class)

# Example segment of each class: the k-th segment (in dataset order) carrying that label.
alpha = np.where(ids_ar == "alpha")[0][0]
beta = np.where(ids_ar == "beta")[0][778]
gamma = np.where(ids_ar == "gamma")[0][0]
delta = np.where(ids_ar == "delta")[0][2]
theta = np.where(ids_ar == "theta")[0][0]
kappa = np.where(ids_ar == "kappa")[0][5]
lambda1 = np.where(ids_ar == "lambda")[0][220]
mu = np.where(ids_ar == "mu")[0][1]
nu = np.where(ids_ar == "nu")[0][2]
rho = np.where(ids_ar == "rho")[0][1]
phi = np.where(ids_ar == "phi")[0][14]
chi = np.where(ids_ar == "chi")[0][13]
eta = np.where(ids_ar == "eta")[0][1]
omega = np.where(ids_ar == "omega")[0][3]

# Tie every class name to its example explicitly, so the panel label and the plotted
# curve cannot drift apart (previously two separately hand-sorted lists had to agree).
example_by_class = {
    "alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "theta": theta,
    "kappa": kappa, "lambda": lambda1, "mu": mu, "nu": nu, "rho": rho,
    "phi": phi, "chi": chi, "eta": eta, "omega": omega,
}
class_names = np.sort(np.array(["alpha", "beta", "gamma", "delta", "theta", "kappa", "lambda", "mu", "nu", "rho", "phi", "chi", "eta", "omega"]))
selected_lcs = [example_by_class[name] for name in class_names]  # alphabetical order

intervals = {}

fig, axes = plt.subplots(nrows=len(class_names), ncols=2)
plt.subplots_adjust(hspace=0.05, wspace=0.01)

# 128 samples spread over 0-512; shared by every panel, so built once.
time_axis = np.linspace(0, 512, num=128)

for plot_ind, class_name in enumerate(class_names):
    light_c = selected_lcs[plot_ind]
    ax_fixed = axes[plot_ind, 0]
    ax_auto = axes[plot_ind, 1]

    ax_fixed.plot(time_axis, segments[light_c], c="black", linewidth=0.5, zorder=1, label="Input")
    ax_fixed.plot(time_axis, recos[light_c], c="magenta", linewidth=0.5, zorder=2, label="Reconstruction")
    ax_auto.plot(time_axis, segments[light_c]/1000, c="black", linewidth=0.5, zorder=1, label="Input")
    ax_auto.plot(time_axis, recos[light_c]/1000, c="magenta", linewidth=0.5, zorder=2, label="Reconstruction")
    ax_auto.yaxis.tick_right()
    ax_fixed.set_ylim([0, 15000])
    ax_fixed.set_xlim([0, 512])
    ax_auto.set_xlim([0, 512])

    # Greek-letter class label in the top-right corner of the left panel.
    ax_fixed.text(0.99,0.99,r"$\{}$".format(class_name), ha='right', va='top', transform=ax_fixed.transAxes, size=10)

    ax_fixed.tick_params(axis="y", which="major", length=2, width=0.75, labelsize=6, direction="in")
    ax_auto.tick_params(axis="y", which="major", length=2, width=0.75, labelsize=6, direction="in")

    if plot_ind == 6:
        # Single y-label placed on a middle row.
        ax_fixed.set_ylabel("Rate (kcts/s)", size=6)
    if plot_ind == 13:
        # Only the bottom row shows x tick labels and the x-axis label.
        ax_fixed.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=6, direction="in")
        ax_auto.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=6, direction="in")
        ax_fixed.set_xlabel("Time (s)", size=6, x=1)
    else:
        ax_fixed.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=0, direction="in")
        plt.setp(ax_fixed.get_xticklabels(), visible=False)
        ax_auto.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=0, direction="in")
        plt.setp(ax_auto.get_xticklabels(), visible=False)

    # Left-column y ticks are positioned in raw units but labelled in thousands.
    ax_fixed.set_yticks([0, 2500, 5000, 7500, 10000, 12500])
    ax_fixed.set_yticklabels([0, "", 5, "", 10, "",])
    ax_fixed.set_xticks([0, 100, 200,300, 400, 500])
    ax_fixed.set_xticklabels([0, "", 200,"", 400, ""])
    ax_auto.set_xticks([0, 100, 200,300, 400, 500])
    ax_auto.set_xticklabels([0, "", 200,"", 400, ""])

# savefig does not create missing directories; make sure the target exists first.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/segments_fit_alphabetical_4s_test_only.png', dpi=300, bbox_inches = 'tight',pad_inches = 0)
plt.show()

In [None]:
plt.rcParams['figure.figsize'] = (4, 1)

# Browse segments of one class that belong to test-set observations. The printed
# number is the segment's rank within its class, i.e. the k in
# np.where(ids_ar == <class>)[0][k].
browse_class = "beta"
max_candidates = 500

# The class lookup does not depend on the loop variable, so scan ids_ar once rather
# than on every iteration. Slicing caps the loop at the number of segments that
# actually exist, where indexing with [i] raised IndexError for small classes.
class_indices = np.where(ids_ar == browse_class)[0]
test_id_set = set(test_ids)  # constant-time membership checks
time_axis = np.linspace(0,128, num=128)

for i, light_c in enumerate(class_indices[:max_candidates]):
    if seg_ObIDs[light_c] in test_id_set:
        plt.plot(time_axis, segments[light_c], c="black", linewidth=0.5, zorder=1, label="Input")
        plt.show()
        print(i)
        print(seg_ObIDs[light_c])

In [None]:
# load observation classifications from Huppenkothen 2017
# The first line holds the class names, and the i-th following line holds the
# whitespace-separated observation IDs of the i-th class. The file is read inside
# a `with` block so the handle is closed deterministically.
with open('{}/1915Belloniclass_updated.dat'.format(data_dir)) as clean_belloni:
    lines = clean_belloni.readlines()
states = lines[0].split()
belloni_clean = {}  # state -> [obsID1, obsID2, ...]
for h, l in zip(states, lines[1:]):
    belloni_clean[h] = l.split()

# Invert to obsID -> state, folding the four chi sub-classes into a single "chi".
ob_state = {}
for state, obs in belloni_clean.items():
    if state in ("chi1", "chi2", "chi3", "chi4"):
        state = "chi"
    for ob in obs:
        ob_state[ob] = state

# load IDs of segmented light curves: observationID_segmentIndex
with open('{}/474471_len128_stride10_1sec_cad_ids_sum_bin.pkl'.format(data_dir), 'rb') as f:
    seg_ids = pickle.load(f)

# Drop the within-observation segment index, leaving one (repeated) observation ID per segment.
seg_ObIDs = [seg.split("_")[0] for seg in seg_ids]

classes = np.array(["alpha", "beta", "gamma", "delta", "theta", "kappa", "lambda", "mu", "nu", "rho", "phi", "chi", "eta", "omega"])
scales = []
# Class label of every segment; observations missing from the table are labelled "Unknown".
segment_class = [ob_state.get(ob, "Unknown") for ob in seg_ObIDs]

In [None]:
# Pickle files holding the segments, their errors and the reconstructions.
segments_dir = '{}/474471_len128_stride10_1sec_cad_countrates_sum_bin.pkl'.format(data_dir)
errors_dir = '{}/474471_len128_stride10_1sec_cad_errors_sum_bin.pkl'.format(data_dir)
recos_dir = "{}/reconstructions_model_model_2020-12-24_13-14-02_segments_474471_len128_stride10_1sec_cad_countrates_sum_bin.pkl".format(data_dir)

with open(segments_dir, 'rb') as f:
    segments = pickle.load(f)

with open(errors_dir, 'rb') as f:
    errors = pickle.load(f)

with open(recos_dir, 'rb') as f:
    recos = pickle.load(f)

# Scale the reconstructions by each segment's standard deviation and shift by its
# mean (the inverse of the z-scoring that is kept commented out below).
segment_std = np.std(segments, axis=1)
segment_mean = np.mean(segments, axis=1)
recos = recos * segment_std + segment_mean

# Alternative: standardise the data instead of de-standardising the reconstructions.
# errors = ((errors)/np.expand_dims(np.std(segments, axis=1), axis=1)).astype(np.float32)
# segments = zscore(segments, axis=1).astype(np.float32)  # standardize per segment

In [None]:
# Load the stored data split; index 2 is the part used to build test_ids.
split_path = '{}/lightcurve1738_train70_val10_test20.pkl'.format(data_dir)
with open(split_path, 'rb') as f:
    split_ob_ids = pickle.load(f)
# Keep only the part of each ID that precedes the first underscore.
test_ids = [seg.partition("_")[0] for seg in split_ob_ids[2]]

In [None]:
# Rebuilds test_ids from split_ob_ids[2]; the same statement also ends the cell
# that loads split_ob_ids, so this cell is redundant when run in order.
test_ids = [seg.split("_")[0] for seg in split_ob_ids[2]]

In [None]:
# Display the full test_ids list for inspection.
test_ids

In [None]:
import matplotlib.pylab as pylab
import matplotlib.ticker as ticker

# Figure: one hand-picked segment per class, input (black) vs. reconstruction (magenta).
# Left column: raw values on a fixed 0-15000 y-range with tick labels in thousands.
# Right column: the same curves divided by 1000 on an auto-scaled y-axis.

plt.rcParams['figure.figsize'] = (6.97, 8.4)
plt.rcParams.update({'font.size': 6})

ids_ar = np.array(segment_class)

# Example segment of each class: the k-th segment (in dataset order) carrying that label.
alpha = np.where(ids_ar == "alpha")[0][24]
beta = np.where(ids_ar == "beta")[0][72]
gamma = np.where(ids_ar == "gamma")[0][9]
delta = np.where(ids_ar == "delta")[0][5]
theta = np.where(ids_ar == "theta")[0][30]
kappa = np.where(ids_ar == "kappa")[0][13]
lambda1 = np.where(ids_ar == "lambda")[0][110]
mu = np.where(ids_ar == "mu")[0][0]
nu = np.where(ids_ar == "nu")[0][6]
rho = np.where(ids_ar == "rho")[0][0]
phi = np.where(ids_ar == "phi")[0][4]
chi = np.where(ids_ar == "chi")[0][4]
eta = np.where(ids_ar == "eta")[0][3]
omega = np.where(ids_ar == "omega")[0][164]

# Tie every class name to its example explicitly, so the panel label and the plotted
# curve cannot drift apart (previously two separately hand-sorted lists had to agree).
example_by_class = {
    "alpha": alpha, "beta": beta, "gamma": gamma, "delta": delta, "theta": theta,
    "kappa": kappa, "lambda": lambda1, "mu": mu, "nu": nu, "rho": rho,
    "phi": phi, "chi": chi, "eta": eta, "omega": omega,
}
class_names = np.sort(np.array(["alpha", "beta", "gamma", "delta", "theta", "kappa", "lambda", "mu", "nu", "rho", "phi", "chi", "eta", "omega"]))
selected_lcs = [example_by_class[name] for name in class_names]  # alphabetical order

intervals = {}

fig, axes = plt.subplots(nrows=len(class_names), ncols=2)
plt.subplots_adjust(hspace=0.05, wspace=0.01)

# 128 samples spread over 0-128; shared by every panel, so built once.
time_axis = np.linspace(0, 128, num=128)

for plot_ind, class_name in enumerate(class_names):
    light_c = selected_lcs[plot_ind]
    ax_fixed = axes[plot_ind, 0]
    ax_auto = axes[plot_ind, 1]

    ax_fixed.plot(time_axis, segments[light_c], c="black", linewidth=0.5, zorder=1, label="Input")
    ax_fixed.plot(time_axis, recos[light_c], c="magenta", linewidth=0.5, zorder=2, label="Reconstruction")
    ax_auto.plot(time_axis, segments[light_c]/1000, c="black", linewidth=0.5, zorder=1, label="Input")
    ax_auto.plot(time_axis, recos[light_c]/1000, c="magenta", linewidth=0.5, zorder=2, label="Reconstruction")
    ax_auto.yaxis.tick_right()
    ax_fixed.set_ylim([0, 15000])
    ax_fixed.set_xlim([0, 128])
    ax_auto.set_xlim([0, 128])

    # Greek-letter class label in the top-right corner of the left panel.
    ax_fixed.text(0.99,0.99,r"$\{}$".format(class_name), ha='right', va='top', transform=ax_fixed.transAxes, size=10)

    ax_fixed.tick_params(axis="y", which="major", length=2, width=0.75, labelsize=6, direction="in")
    ax_auto.tick_params(axis="y", which="major", length=2, width=0.75, labelsize=6, direction="in")

    if plot_ind == 6:
        # Single y-label placed on a middle row.
        ax_fixed.set_ylabel("Rate (kcts/s)", size=6)
    if plot_ind == 13:
        # Only the bottom row shows x tick labels and the x-axis label.
        ax_fixed.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=6, direction="in")
        ax_auto.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=6, direction="in")
        ax_fixed.set_xlabel("Time (s)", size=6, x=1)
    else:
        ax_fixed.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=0, direction="in")
        plt.setp(ax_fixed.get_xticklabels(), visible=False)
        ax_auto.tick_params(axis="x", which="major", length=2, width=0.75, labelsize=0, direction="in")
        plt.setp(ax_auto.get_xticklabels(), visible=False)

    # Left-column y ticks are positioned in raw units but labelled in thousands.
    ax_fixed.set_yticks([0, 2500, 5000, 7500, 10000, 12500])
    ax_fixed.set_yticklabels([0, "", 5, "", 10, "",])
    ax_fixed.set_xticks([0, 25, 50,75, 100, 125])
    ax_fixed.set_xticklabels([0, "", 50,"", 100, ""])
    ax_auto.set_xticks([0, 25, 50,75, 100, 125])
    ax_auto.set_xticklabels([0, "", 50,"", 100, ""])

# savefig does not create missing directories; make sure the target exists first.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/segments_fit_alphabetical_1s_test_only.png', dpi=300, bbox_inches = 'tight',pad_inches = 0)
plt.show()

In [None]:
# Identifiers noted down while browsing (presumably observation-ID prefixes - TODO confirm).
# Stored as strings: the bare form 10408-01 is a SyntaxError in Python 3, because
# decimal integer literals may not carry a leading zero.
i = "10408-01"
j = "20187-02"
k = "20402-01"

In [None]:
# NOTE(review): inv_ob_state is not defined in the nearby cells - presumably an inverse
# (class -> observation IDs) mapping built elsewhere in the notebook; confirm it exists
# before this cell on a fresh Restart & Run All.
inv_ob_state["lambda"]

In [None]:
plt.rcParams['figure.figsize'] = (4, 1)

# Browse segments of one class that belong to test-set observations. The printed
# number is the segment's rank within its class, i.e. the k in
# np.where(ids_ar == <class>)[0][k].
browse_class = "eta"
max_candidates = 500

# The class lookup does not depend on the loop variable, so scan ids_ar once rather
# than on every iteration. Slicing caps the loop at the number of segments that
# actually exist, where indexing with [i] raised IndexError for small classes.
class_indices = np.where(ids_ar == browse_class)[0]
test_id_set = set(test_ids)  # constant-time membership checks
time_axis = np.linspace(0,128, num=128)

for i, light_c in enumerate(class_indices[:max_candidates]):
    if seg_ObIDs[light_c] in test_id_set:
        plt.plot(time_axis, segments[light_c], c="black", linewidth=0.5, zorder=1, label="Input")
        plt.show()
        print(i)
        print(seg_ObIDs[light_c])

In [None]:
# Pickle files holding the segments, their errors and the reconstructions.
segments_dir = '{}/474471_len128_stride10_1sec_cad_countrates_sum_bin.pkl'.format(data_dir)
errors_dir = '{}/474471_len128_stride10_1sec_cad_errors_sum_bin.pkl'.format(data_dir)
recos_dir = "{}/reconstructions_model_model_2020-12-24_13-14-02_segments_474471_len128_stride10_1sec_cad_countrates_sum_bin.pkl".format(data_dir)

with open(segments_dir, 'rb') as f:
    segments = pickle.load(f)

with open(errors_dir, 'rb') as f:
    errors = pickle.load(f)

with open(recos_dir, 'rb') as f:
    recos = pickle.load(f)

# The reconstructions are left as loaded here; the data are standardised instead.
# recos= recos*np.std(segments, axis=1) + np.mean(segments, axis=1)

# Per-segment standard deviation of the raw data. It must be taken before
# `segments` is overwritten with its z-scored version on the last line.
raw_segment_std = np.std(segments, axis=1)
errors = (errors / np.expand_dims(raw_segment_std, axis=1)).astype(np.float32)
segments = zscore(segments, axis=1).astype(np.float32)  # standardize per segment

In [None]:
# Inspect the shape of the errors array.
errors.shape

In [None]:
# NOTE(review): hard-coded ratio; presumably the share of the 474471 segments that
# fall in the test split - TODO confirm where 94951 comes from.
94951/474471

In [None]:
# Per-segment mean of the squared, error-normalised residuals between the data and
# the reconstructions. The arrays are flattened to (n_segments, -1) from their own
# leading dimension, so the cell is not tied to one dataset size (was 474471 x 128).
flat_segments = segments.reshape((segments.shape[0], -1))
flat_errors = errors.reshape((errors.shape[0], -1))
chi_square = np.mean(((flat_segments - recos) / flat_errors)**2, axis=1)

In [None]:
# Reload the segment IDs and reduce each to the part before the first underscore.
ids_path = '{}/474471_len128_stride10_1sec_cad_ids_sum_bin.pkl'.format(data_dir)
with open(ids_path, 'rb') as f:
    seg_ids = pickle.load(f)

seg_ObIDs = [seg.partition("_")[0] for seg in seg_ids]

In [None]:
# Display seg_ObIDs as a NumPy array for inspection.
np.array(seg_ObIDs) 

In [None]:
# Count the segments whose observation ID appears in split_ob_ids[2].
in_test_split = np.isin(np.array(seg_ObIDs), np.array(split_ob_ids[2]))
in_test_split.sum()

In [None]:
# Mean chi_square over the segments whose observation ID appears in split_ob_ids[2].
in_test_split = np.isin(np.array(seg_ObIDs), np.array(split_ob_ids[2]))
np.mean(chi_square[in_test_split])

In [None]:
# Median chi_square over the segments whose observation ID appears in split_ob_ids[2].
in_test_split = np.isin(np.array(seg_ObIDs), np.array(split_ob_ids[2]))
np.median(chi_square[in_test_split])

In [None]:
# Histogram (20 bins) of chi_square for the segments whose observation ID appears in split_ob_ids[2].
in_test_split = np.isin(np.array(seg_ObIDs), np.array(split_ob_ids[2]))
plt.hist(chi_square[in_test_split], bins=20)
# plt.show()