In [1]:
import pandas as pd
import numpy as np
import pickle

from sklearn.preprocessing import LabelEncoder, StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split

import ugtm
from ugtm import eGTM
import altair as alt


### Put together dataframe to generate dataset.

In [2]:
# Name of the label column in the CSV that defines the classes used in this notebook.
label_scheme = 'labels_2'

# Show every column when a DataFrame is displayed.
pd.set_option('display.max_columns', None)

# Read the feature table and keep only the rows that have a label in the chosen scheme.
dataset = pd.read_csv(f'../processed_data/dataset_allfeatures_inc_labels.csv', low_memory=False)
has_label = dataset[label_scheme].notna()
dataset = dataset.loc[has_label].reset_index(drop=True)
print(dataset[label_scheme].value_counts())

# Features are every column from position 14 onwards; the labels are the chosen label column.
X = dataset.iloc[:, 14:]
X_cols = list(X.columns)
y_names = dataset[label_scheme]

# Load a previously fitted label encoder from disk instead of fitting a new one.
# NOTE: pickle.load can execute arbitrary code - only load files you trust.
with open('../results/encoder/encoder_rf_sampling.pkl', 'rb') as encoder_file:
    enc = pickle.load(encoder_file)

# enc = LabelEncoder().fit(y_names)
y = enc.transform(y_names)

# Stratified 70/30 train/test split with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, stratify=y, random_state=1
)
dataset.head()


labels_2
dwarf_nova_SU_UMa    630
dwarf_nova_Z_Cam     174
nova_like            144
nova_like_VY_Scl     120
dwarf_nova_U_Gem     116
polar                114
int_polar             49
AMCVn                 46
nova                  46
Name: count, dtype: int64


Unnamed: 0,oid_ztf,oid_aavso,type_aavso,Eclipsing,CV_Types,CV_subtypes,CV_subsubtypes,eclipse_clear,manual_label,Clarity,labels_1,labels_2,labels_3,labels_4,Amplitude_g,AndersonDarling_g,Autocor_length_g,Beyond1Std_g,CAR_mean_g,CAR_sigma_g,CAR_tau_g,Con_g,Eta_e_g,FluxPercentileRatioMid20_g,FluxPercentileRatioMid35_g,FluxPercentileRatioMid50_g,FluxPercentileRatioMid65_g,FluxPercentileRatioMid80_g,Freq1_harmonics_amplitude_0_g,Freq1_harmonics_amplitude_1_g,Freq1_harmonics_amplitude_2_g,Freq1_harmonics_amplitude_3_g,Freq1_harmonics_rel_phase_1_g,Freq1_harmonics_rel_phase_2_g,Freq1_harmonics_rel_phase_3_g,Freq2_harmonics_amplitude_0_g,Freq2_harmonics_amplitude_1_g,Freq2_harmonics_amplitude_2_g,Freq2_harmonics_amplitude_3_g,Freq2_harmonics_rel_phase_1_g,Freq2_harmonics_rel_phase_2_g,Freq2_harmonics_rel_phase_3_g,Freq3_harmonics_amplitude_0_g,Freq3_harmonics_amplitude_1_g,Freq3_harmonics_amplitude_2_g,Freq3_harmonics_amplitude_3_g,Freq3_harmonics_rel_phase_1_g,Freq3_harmonics_rel_phase_2_g,Freq3_harmonics_rel_phase_3_g,Gskew_g,LinearTrend_g,MaxSlope_g,Mean_g,Meanvariance_g,MedianAbsDev_g,MedianBRP_g,PairSlopeTrend_g,PercentAmplitude_g,PercentDifferenceFluxPercentile_g,PeriodLS_g,Period_fit_g,Psi_CS_g,Psi_eta_g,Q31_g,Rcs_g,Skew_g,SlottedA_length_g,SmallKurtosis_g,Std_g,StetsonK_g,StetsonK_AC_g,StructureFunction_index_21_g,StructureFunction_index_31_g,StructureFunction_index_32_g,Amplitude_r,AndersonDarling_r,Autocor_length_r,Beyond1Std_r,CAR_mean_r,CAR_sigma_r,CAR_tau_r,Con_r,Eta_e_r,FluxPercentileRatioMid20_r,FluxPercentileRatioMid35_r,FluxPercentileRatioMid50_r,FluxPercentileRatioMid65_r,FluxPercentileRatioMid80_r,Freq1_harmonics_amplitude_0_r,Freq1_harmonics_amplitude_1_r,Freq1_harmonics_amplitude_2_r,Freq1_harmonics_amplitude_3_r,Freq1_harmonics_rel_phase_1_r,Freq1_harmonics_rel_phase_2_r,Freq1_harmonics_rel_phase_3_r,Freq2_harmonics_amplitude_0_r,Freq2_harmonics_amplitude_1_r,Freq2_harmonics_amplitude_2_r,Freq2_harmonics_amplitude_3_r,Freq2_harmonics_rel_phase_1_r,Fre
q2_harmonics_rel_phase_2_r,Freq2_harmonics_rel_phase_3_r,Freq3_harmonics_amplitude_0_r,Freq3_harmonics_amplitude_1_r,Freq3_harmonics_amplitude_2_r,Freq3_harmonics_amplitude_3_r,Freq3_harmonics_rel_phase_1_r,Freq3_harmonics_rel_phase_2_r,Freq3_harmonics_rel_phase_3_r,Gskew_r,LinearTrend_r,MaxSlope_r,Mean_r,Meanvariance_r,MedianAbsDev_r,MedianBRP_r,PairSlopeTrend_r,PercentAmplitude_r,PercentDifferenceFluxPercentile_r,PeriodLS_r,Period_fit_r,Psi_CS_r,Psi_eta_r,Q31_r,Rcs_r,Skew_r,SlottedA_length_r,SmallKurtosis_r,Std_r,StetsonK_r,StetsonK_AC_r,StructureFunction_index_21_r,StructureFunction_index_31_r,StructureFunction_index_32_r,Q31_color,StetsonJ,StetsonL,median_g,min_mag_g,max_mag_g,n_obs_g,dif_min_mean_g,dif_min_median_g,dif_max_mean_g,dif_max_median_g,dif_max_min_g,temporal_baseline_g,kurtosis_g,pwr_max_g,freq_pwr_max_g,FalseAlarm_prob_g,pwr_maxovermean_g,npeaks_pt5to1_g,rrate_pt5to1_g,drate_pt5to1_g,amp_pt5to1_g,npeaks_1to2_g,rrate_1to2_g,drate_1to2_g,amp_1to2_g,npeaks_2to5_g,rrate_2to5_g,drate_2to5_g,amp_2to5_g,npeaks_above5_g,rrate_above5_g,drate_above5_g,amp_above5_g,rollstd_ratio_t20s10_g,stdstilllev_t20s10_g,rollstd_ratio_t10s5_g,stdstilllev_t10s5g,pnts_leq_rollMedWin20-1mag_g,pnts_leq_rollMedWin20-2mag_g,pnts_leq_rollMedWin20-5mag_g,pnts_geq_rollMedWin20+1mag_g,pnts_geq_rollMedWin20+2mag_g,pnts_geq_rollMedWin20+3mag_g,pnts_leq_median-1mag_g,pnts_leq_median-2mag_g,pnts_leq_median-5mag_g,pnts_geq_median+1mag_g,pnts_geq_median+2mag_g,pnts_geq_median+3mag_g,median_r,min_mag_r,max_mag_r,n_obs_r,dif_min_mean_r,dif_min_median_r,dif_max_mean_r,dif_max_median_r,dif_max_min_r,temporal_baseline_r,kurtosis_r,pwr_max_r,freq_pwr_max_r,FalseAlarm_prob_r,pwr_maxovermean_r,npeaks_pt5to1_r,rrate_pt5to1_r,drate_pt5to1_r,amp_pt5to1_r,npeaks_1to2_r,rrate_1to2_r,drate_1to2_r,amp_1to2_r,npeaks_2to5_r,rrate_2to5_r,drate_2to5_r,amp_2to5_r,npeaks_above5_r,rrate_above5_r,drate_above5_r,amp_above5_r,rollstd_ratio_t20s10_r,stdstilllev_t20s10_r,rollstd_ratio_t10s5_r,stdstilllev_t10s5r,pnt
s_leq_rollMedWin20-1mag_r,pnts_leq_rollMedWin20-2mag_r,pnts_leq_rollMedWin20-5mag_r,pnts_geq_rollMedWin20+1mag_r,pnts_geq_rollMedWin20+2mag_r,pnts_geq_rollMedWin20+3mag_r,pnts_leq_median-1mag_r,pnts_leq_median-2mag_r,pnts_leq_median-5mag_r,pnts_geq_median+1mag_r,pnts_geq_median+2mag_r,pnts_geq_median+3mag_r,clr_mean,clr_median,clr_std,clr_bright,clr_faint,ra,dec,ra_error,dec_error,parallax,parallax_error,pm,pmra_error,pmdec_error,nu_eff_used_in_astrometry,astrometric_sigma5d_max,phot_g_n_obs,phot_g_mean_flux,phot_g_mean_flux_error,phot_g_mean_mag,phot_bp_n_obs,phot_bp_mean_flux,phot_bp_mean_flux_error,phot_bp_mean_mag,phot_rp_n_obs,phot_rp_mean_flux,phot_rp_mean_flux_error,phot_rp_mean_mag,bp_rp,bp_g,g_rp,l,b,ecl_lon,ecl_lat,distance,absmag_g,absmag_bp,absmag_rp
0,ZTF18abryuah,ASASSN-19dp,AM,0,polar,,,0.0,AM_Her,1.0,polar,polar,polar,magnetic,1.332255,1.0,31.0,0.328814,0.8399,0.252236,21.584233,0.0,1869.996451,0.257652,0.434712,0.605825,0.755212,0.874876,0.997426,0.505215,0.339324,0.484018,-0.600325,0.401968,-1.025766,0.581499,0.049171,0.218349,0.156956,1.208792,2.612223,1.16429,0.285483,0.114756,0.074438,0.032343,0.15093,-0.433048,-1.709862,0.807321,-0.000105,30.830754,18.12859,0.044738,0.612558,0.271186,0.1,0.10924,0.134467,63.297317,0.0,0.421883,0.256133,1.458057,0.421883,0.316666,8.0,-1.173994,0.81103,0.839603,0.876159,1.912595,2.758041,1.473054,1.525589,1.0,27.0,0.313158,63.131571,1.954631,0.104404,0.005291,26327.378818,0.138095,0.404874,0.537797,0.694851,0.861921,0.860887,0.351262,0.169339,0.523886,0.194579,1.224606,-0.479172,0.457272,0.151486,0.072937,0.147139,0.066141,-0.615173,1.114636,0.330389,0.068003,0.148562,0.067203,-0.957594,-0.67599,1.568794,0.496268,-0.000153,172.486888,17.714298,0.050951,0.760635,0.3,-0.1,0.165693,0.157677,58.970956,0.0,0.367095,0.604717,1.483988,0.367095,0.372961,11.0,-0.564214,0.902558,0.90492,0.801733,1.578056,1.922599,1.336618,0.638181,8.538768,7.957973,17.893334,16.544905,19.848006,295.0,1.583685,1.348428,1.719416,1.954673,3.303101,1629.756366,-1.186561,0.694003,0.00092,1.076016e-71,80.776209,24.0,0.061514,0.779111,0.871705,12.0,0.050872,0.080337,1.927747,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,9.169765,0.649583,16.998108,0.650764,0.0,0.0,0.0,2.0,0.0,0.0,4.0,0.0,0.0,75.0,0.0,0.0,17.672762,16.113612,20.601009,380.0,1.600686,1.55915,2.886711,2.928246,4.487397,1623.791701,-0.577142,0.481344,0.000924,2.0657709999999998e-50,53.782559,32.0,0.868543,0.276155,0.993821,41.0,0.375136,0.84464,1.922408,1.0,0.014205,0.007751,3.681235,0.0,0.0,0.0,0.0,14.305113,0.435391,54.475168,0.800782,3.0,0.0,0.0,14.0,4.0,1.0,47.0,0.0,0.0,67.0,5.0,0.0,0.633883,0.553139,0.494722,0.471238,1.060625,35.745917,43.653639,0.174089,0.183429,2.705361,0.176508,7.80863,0.332211,0.359794,1.457178,0.536487,179.0,292.473554,18.738121
,19.52215,19.0,202.238499,44.369358,19.573883,17.0,349.148984,61.556385,18.39037,1.183514,0.051733,1.131781,139.96794,-16.169142,48.441976,27.723366,369.636457,11.683276,11.735009,10.551495
1,ZTF18abtrvgp,BMAM-V789,AM,0,polar,,,0.0,AM_Her,1.0,polar,polar,polar,magnetic,0.714916,0.466196,1.0,0.303571,6.72318,0.29408,2.91073,0.0,8566.617616,0.138306,0.335229,0.427281,0.574874,0.729827,0.319827,0.101678,0.093589,0.136091,2.169127,0.871884,2.623909,0.228019,0.048791,0.036217,0.152055,-0.034098,-0.807468,0.344017,0.229103,0.059155,0.057695,0.106019,-0.578028,-1.035542,1.433269,0.061647,7.7e-05,30.830754,19.569362,0.019939,0.297931,0.267857,-0.033333,0.044537,0.069247,8.927349,1.0,0.217042,1.518778,0.579314,0.167864,0.09127,2.0,-0.518765,0.390195,0.814326,0.690661,1.551979,1.935767,1.341881,1.789813,1.0,3.0,0.1625,0.649302,0.268176,29.886054,0.089744,104430.255608,0.060065,0.121724,0.171309,0.337565,0.909486,1.20964,0.617273,0.48028,0.268577,0.573496,0.346147,-0.51876,0.673643,0.384837,0.132118,0.156886,2.088083,1.016367,-0.575305,0.523602,0.331903,0.177305,0.343568,-0.825428,1.261461,0.604871,2.135367,-0.00045,145.874805,19.405084,0.053637,0.287708,0.6125,0.033333,0.153743,0.176095,1.001825,0.001777,0.18996,0.970269,0.566743,0.248854,1.511615,8.0,1.080571,1.040834,0.840262,0.698896,1.927601,2.756006,1.476504,0.797858,0.969203,0.930821,19.542219,18.671871,20.367321,56.0,0.897491,0.870347,0.797959,0.825102,1.69545,1514.9711,-0.612637,0.197295,0.002838,0.6816253,4.408452,7.0,0.050678,0.046268,0.894904,2.0,0.218368,0.033067,1.669533,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.829709,0.349326,4.396222,0.673015,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,19.051129,18.282168,21.980108,80.0,1.122916,0.768961,2.575024,2.928979,3.69794,1551.926551,0.976278,0.342859,0.004704,8.188138e-05,6.581204,7.0,0.148288,0.039508,0.980322,1.0,0.038511,1.096651,1.076622,2.0,0.202491,0.013476,3.652952,0.0,0.0,0.0,0.0,10.703135,0.82078,16.124018,0.814912,1.0,0.0,0.0,13.0,9.0,0.0,0.0,0.0,0.0,13.0,13.0,0.0,0.689769,0.931386,0.886252,0.471734,-0.109275,38.492292,41.623972,0.209687,0.23109,1.389675,0.202936,9.820421,0.302572,0.351561,1.599227,0.539709,179.0,537.58142,28.596745,18.861256,19.0,
396.328598,63.694065,18.843403,20.0,417.576534,55.68394,18.196054,0.647348,-0.017853,0.665201,142.72671,-17.273221,49.748682,25.109881,719.592798,9.575822,9.557969,8.91062
2,ZTF17aaaehby,CSS 091026:002637+242916,AM,0,polar,,,0.0,AM_Her,1.0,polar,polar,polar,magnetic,1.644433,1.0,4.0,0.432292,1.736803,0.381532,11.311616,0.0,8566.617616,0.141683,0.239059,0.564077,0.652928,0.804496,0.715037,0.110132,0.225704,0.240926,-0.633752,2.070518,1.152802,0.418947,0.141479,0.229677,0.074199,1.089334,0.817529,0.903133,0.402646,0.018368,0.09808,0.139186,1.170299,0.691873,1.306343,-0.572981,7.6e-05,30.830754,19.646054,0.045227,0.473756,0.447917,0.033333,0.112177,0.137228,0.996259,8.983593e-08,0.166041,1.264095,1.527154,0.21849,-0.293967,1.0,-0.769495,0.888527,0.85939,0.798442,1.616912,2.035787,1.336415,1.663687,1.0,7.0,0.4,63.131571,1.954631,0.056446,0.0,51432.187925,0.188853,0.309676,0.563304,0.67593,0.803789,0.716571,0.190795,0.104668,0.313132,0.427267,-0.193512,-0.102752,0.348398,0.19948,0.080148,0.131079,-1.830988,-1.67044,-1.41259,0.3805,0.038135,0.182268,0.04484,2.610277,0.102692,2.41883,-0.408817,-0.000186,172.486888,19.122359,0.047421,0.615351,0.323077,0.1,0.10492,0.148467,58.970956,0.0,0.27022,0.788899,1.59161,0.27022,-0.194543,11.0,-0.828884,0.90681,0.891792,0.747857,1.562151,1.913389,1.412423,0.740399,5.601237,4.464399,19.915599,17.681528,21.602343,192.0,1.964526,2.234071,1.956289,1.686744,3.920815,1612.644502,-0.793165,0.274387,0.001178,1.874381e-10,25.871578,14.0,0.107039,0.196604,0.938706,5.0,0.057807,0.090047,1.84449,7.0,0.032119,0.091828,3.669366,0.0,0.0,0.0,0.0,8.840653,0.337017,18.087568,0.749636,28.0,7.0,0.0,11.0,0.0,0.0,59.0,4.0,0.0,10.0,0.0,0.0,19.26769,17.246121,21.260795,260.0,1.876238,2.02157,2.138437,1.993105,4.014675,1519.850058,-0.845822,0.28059,0.001119,1.741459e-15,21.65184,21.0,4.136798,0.140842,0.977942,18.0,0.111792,1.570404,1.934902,10.0,0.085642,0.216807,3.876295,0.0,0.0,0.0,0.0,35.36237,0.852854,48.56939,0.865017,20.0,1.0,0.0,11.0,1.0,0.0,71.0,1.0,0.0,20.0,0.0,0.0,0.312823,0.328959,0.65967,0.038225,0.367259,6.654417,24.487694,0.220313,0.148709,1.715388,0.292919,8.879793,0.275921,0.18228,,0.41132,351.0,312.668643,14.3
54109,19.449656,39.0,202.137355,30.217966,19.574427,41.0,394.688025,44.847103,18.257261,1.317165,0.124771,1.192394,115.757357,-38.038067,16.142642,19.77612,582.958395,10.621468,10.746239,9.429073
3,ZTF18abgjgiq,MGAB-V3453,AM,0,polar,,,0.0,AM_Her,1.0,polar,polar,polar,magnetic,1.469023,1.0,35.0,0.342007,0.762847,0.178939,25.120046,0.0,1199.127984,0.135518,0.257628,0.531476,0.716239,0.875099,0.994543,0.513242,0.587822,0.566582,-0.903428,-2.325182,-1.730958,0.752402,0.370057,0.231342,0.30128,0.541601,-1.017727,-1.278409,0.316263,0.038331,0.077325,0.064529,0.75046,-0.784758,0.658175,0.097987,0.00053,30.830754,19.162745,0.044069,0.651916,0.345725,0.233333,0.091102,0.140238,63.297317,0.0,0.395208,0.261884,1.43387,0.395208,0.005002,8.0,-0.860362,0.844488,0.910462,0.843821,1.89157,2.75648,1.475144,1.475417,1.0,42.0,0.388235,0.319283,0.13497,59.475019,0.0,24096.221078,0.345741,0.447988,0.569201,0.674468,0.844324,1.024625,0.584921,0.482062,0.736598,-0.745682,-2.402736,-2.178879,0.916486,0.675287,0.377092,0.600802,1.108338,1.303831,0.253936,0.870857,0.28068,0.204769,0.007151,0.683393,-1.038048,0.740042,1.407689,0.001071,172.486888,18.989363,0.045286,0.673766,0.305882,-0.1,0.134273,0.139325,58.970956,0.0,0.419146,0.266166,1.474571,0.419146,0.418753,11.0,-1.016049,0.859953,0.866704,0.843202,1.851645,2.634486,1.457728,0.338333,6.363495,5.767676,19.194971,17.52907,20.943669,269.0,1.633675,1.665901,1.780924,1.748699,3.414599,1665.689468,-0.876492,0.683346,0.00078,4.2387179999999995e-63,73.040614,17.0,0.108944,0.895331,0.887451,10.0,0.060061,0.062806,1.93772,2.0,0.016128,0.044149,3.402725,0.0,0.0,0.0,0.0,10.175502,0.591032,19.1631,0.876593,1.0,0.0,0.0,7.0,0.0,0.0,56.0,0.0,0.0,34.0,0.0,0.0,18.618959,17.705365,21.118974,255.0,1.283998,0.913593,2.129611,2.500016,3.413609,1647.668762,-1.031859,0.71929,0.000789,2.7097400000000004e-66,59.855095,11.0,0.145649,0.099606,0.935359,10.0,0.044667,0.135435,1.632509,1.0,0.004384,0.006779,3.130703,0.0,0.0,0.0,0.0,15.796832,0.753808,26.232004,0.929777,11.0,0.0,0.0,4.0,0.0,0.0,0.0,0.0,0.0,81.0,9.0,0.0,-0.009772,0.019568,0.481754,-0.097633,0.057774,8.972375,43.561528,0.313521,0.28773,0.962491,0.429397,11.67697,0.347751,0.291654,1.567939,0.5189
35,262.0,274.39266,8.464073,19.591436,27.0,140.773483,12.621718,19.96724,29.0,194.991571,18.285467,19.022856,0.944384,0.375803,0.568581,119.949817,-19.220222,27.826676,35.966564,1038.970729,9.50842,9.884223,8.939839
4,ZTF18abumlux,MGAB-V3769,AM,0,polar,,,0.0,AM_Her,1.0,polar,polar,polar,magnetic,0.886477,1.0,2.0,0.205882,0.195829,0.065789,100.0,0.0,32.527039,0.042661,0.070393,0.132222,0.778954,0.863251,4.72074,2.33685,1.83589,1.552995,-0.341776,-0.686713,-1.036843,4.989103,2.572036,0.412972,2.009564,1.371481,-0.582348,-0.541316,3.526632,2.907736,1.261553,2.761998,1.613119,1.130058,0.823072,1.205204,-0.000292,0.273689,19.58292,0.028624,0.111504,0.647059,0.1,0.079451,0.092954,63.297317,0.0,0.3938,0.473249,0.20408,0.3938,1.368331,1.0,0.652253,0.560546,0.710142,0.616771,1.786215,2.535633,1.451763,0.886477,1.0,2.0,0.205882,0.195829,0.065789,100.0,0.0,32.527039,0.042661,0.070393,0.132222,0.778954,0.863251,4.274413,2.169682,1.673386,1.472945,-0.341776,-0.686713,-1.036843,4.989103,2.572036,0.412972,1.822577,1.371481,-0.582348,-0.541316,3.526632,2.907736,1.261553,2.596602,1.613119,1.130058,0.823072,1.205204,-0.000292,0.273689,,0.028624,0.111504,0.647059,0.1,0.079451,0.092954,58.970956,0.0,0.3938,0.473249,0.20408,0.3938,1.368331,1.0,0.652253,0.560546,0.710142,0.616771,1.786215,2.535633,1.451763,,,,19.366703,19.033199,20.905411,34.0,0.549721,0.333504,1.322491,1.538708,1.872212,1393.09522,0.423572,0.806959,7.2e-05,5.791654e-09,6.665263,0.0,0.0,0.0,0.0,1.0,0.001532,0.006944,1.674918,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.978181,0.833021,11.373401,0.819208,0.0,0.0,0.0,7.0,0.0,0.0,0.0,0.0,0.0,7.0,0.0,0.0,,,,0.0,0.549721,0.333504,1.322491,1.538708,1.872212,,0.423572,0.806959,7.2e-05,5.791654e-09,6.665263,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,7.978181,0.833021,11.373401,0.819208,0.0,0.0,0.0,7.0,0.0,0.0,0.0,0.0,0.0,7.0,0.0,0.0,,,,,,353.638167,40.430611,0.306753,0.224506,1.798189,0.350782,17.624914,0.337015,0.300345,1.601431,0.498788,431.0,235.25747,6.584115,19.758509,47.0,161.057391,14.952202,19.82109,47.0,177.085964,16.932121,19.127436,0.693655,0.062582,0.631073,107.453119,-20.118888,13.425801,38.944087,556.115012,11.032686,11.095268,10.401613


### Load Model

In [44]:
# Load the pickled classifier into `mod`.
# NOTE: pickle.load can execute arbitrary code - only load files you trust.
model_path = '../results/model/model_xgb_weights.pkl'
with open(model_path, 'rb') as model_file:
    mod = pickle.load(model_file)

# Alternative (disabled): load the Keras model instead.
# from tensorflow.keras.models import load_model
# mod = load_model(f'../results/model/model_NN_weights.h5')

  If you are loading a serialized model (like pickle in Python, RDS in R) generated by
  older XGBoost, please export the model by calling `Booster.save_model` from that version
  first, then load it back in current version. See:

    https://xgboost.readthedocs.io/en/latest/tutorials/saving_model.html

  for more details about differences between saving model and serializing.



### Using class probabilities as input

In [40]:
# Preliminaries
# By default Altair raises MaxRowsError when a chart would embed a dataset with more than
# 5000 rows, because embedding large datasets in every chart quickly bloats the notebook.
# We do want the data embedded in the chart specification here, so the check is disabled.
alt.data_transformers.disable_max_rows()


# Class probabilities from the original model, one column per encoded class.
labels2 = enc.classes_.tolist()
preds_train = mod.predict_proba(X_train)
preds_test = mod.predict_proba(X_test)

# Build each frame in one go, with the class names as column names.
preds_train_df = pd.DataFrame(preds_train, columns=labels2)
preds_test_df = pd.DataFrame(preds_test, columns=labels2)


# The scaler is fitted on the training probabilities but NOT applied: `train` and `test`
# below are the raw probabilities. Enable the two commented lines to feed standardised
# probabilities to the GTM instead.
scaler = StandardScaler()
scaler.fit(preds_train_df)
# preds_train_df = scaler.fit_transform(preds_train_df)
# preds_test_df = scaler.transform(preds_test_df)

# Configuration:
#   train  - data used to build the latent space.
#   test   - data projected onto the latent space when ugtm.transform is used (see below).
#   labels - encoded labels of the data whose projections are plotted.
#   X_set  - feature frame of the data whose projections are plotted; it must have the
#            same rows, in the same order, as the plotted projections.
train = preds_train_df # preds_train_df, preds_test_df, preds_train_orig_df
test = preds_test_df # preds_train_df, preds_test_df, preds_train_orig_df
labels = y_train # y_train_fnl, y_test_fnl, y_train
X_set = X_train # X_train_fnl, X_test_fnl, X_train_imp

gtm_model = ugtm.runGTM(train, verbose=False, k=10)

# Use the following if you want to see the projections of the test data onto the latent space.
# transformed=ugtm.transform(optimizedModel=gtm_model,train=train,test=test)

# Mean projection: the mean position of each data point in latent space.
# Further information: https://ugtm.readthedocs.io/en/latest/ugtm.html?highlight=ugtm.matY#module-ugtm.ugtm_classes/
# For projections of the test data use transformed.matMeans instead.
mean_u = gtm_model.matMeans
# mean_u = transformed.matMeans
mean_u = pd.DataFrame(mean_u, columns=['U1','U2'])

# The projections, the labels and the feature frame must describe the same rows; fail
# early with a clear message if the configuration above is inconsistent.
assert len(mean_u) == len(labels) == len(X_set), (
    "mean_u, labels and X_set must refer to the same examples"
)

# Attach the class names and the original dataset index to every projected point.
mean_u_labels = mean_u.copy()
mean_u_labels['y'] = enc.inverse_transform(labels)
# Take the index from X_set (not a hard-coded X_train) so it follows the configuration.
mean_u_labels['index'] = X_set.index

# scikit-learn style wrapper around the same algorithm. It only returns the projected
# coordinates and does not expose the attributes of the model returned by runGTM, so
# the rest of the notebook uses `gtm_model`.
gtm_model2 = eGTM(k=10,verbose=False).fit(train).transform(train)
mean_u2 = pd.DataFrame({'U1': gtm_model2[:,0], 'U2': gtm_model2[:,1]})

# Plot the latent space with a combination of different shapes and colours.

# Legend-bound selection on the plotted class column 'y' (the data has no 'series' field).
# It only takes effect if the opacity encoding below is enabled and the selection is
# added to the chart.
selection = alt.selection_point(fields=['y'], bind='legend')

alt.Chart(mean_u_labels, width=500, height=500).mark_point(size=100).encode(
    x='U1',
    y='U2',
    color= 'y',
    shape='y',
    tooltip=['y', 'index']
    # opacity=alt.condition(selection, alt.value(1), alt.value(0.2))
    )




In [43]:
# matY has shape n_dimensions (number of features) * n_nodes in latent space.
# It is the manifold in n-dimensional data space (the projection of matX into data space):
# column matY[:,i] is the centre of Gaussian component i, with Y = W * Phi^T.
# In other words it gives the location of each node in the high dimensional space, which
# here is the class-probability space.
refvect = gtm_model.matY

# The GTM was trained on the raw class probabilities: the StandardScaler is only fitted,
# its transform is never applied to the training data. matY is therefore already expressed
# in probability units, and applying scaler.inverse_transform to it would multiply each
# class by its standard deviation and add its mean, distorting the values shown in the maps.
# NOTE(review): this assumes ugtm.runGTM does not standardise its input internally - confirm.
#
# Alternatives (enable at most one):
# - rescale every class to [0, 1] across nodes, to show each node's association with a
#   class relative to the other nodes:
# refvect = MinMaxScaler().fit_transform(refvect.T).T
# - undo the standardisation; ONLY valid if the GTM was trained on scaler.transform(...) output:
# refvect = scaler.inverse_transform(refvect.T).T


# Draw the GTM nodes at their 2D coordinates (matX) as a grid of squares, each square
# coloured by the value supplied for that node in `label`.
def plot_ref_vect(gtm_matX,label,title,fig_size=(200,200)):
    """Return an Altair chart of the node grid coloured by a per-node value.

    gtm_matX : 2D node coordinates, one row per node (columns become x1, x2).
    label    : one value per node, used for the colour scale.
    title    : chart title.
    fig_size : (width, height) of the chart.
    """
    node_frame = pd.DataFrame(gtm_matX, columns=["x1", "x2"])
    node_frame['label'] = label

    colour = alt.Color('label:Q',
                       #scale=alt.Scale(scheme='viridis')),
                       scale=alt.Scale(scheme='turbo'))
    squares = alt.Chart(node_frame).mark_square()
    encoded = squares.encode(
        x='x1',
        y='x2',
        color=colour,
        size=alt.value(350),
        tooltip=['x1','x2', 'label:Q'],
        #opacity='density'
    )
    return encoded.properties(title=title, width=fig_size[0], height=fig_size[1])
# %%

# One reference-vector map per class. The colour coding is based on matY, which defines
# the central position of each Gaussian (node) in feature space - here, class-probability
# space. Row i of `refvect` is paired with class name labels2[i].
# Looping over labels2 (instead of nine hand-written calls) keeps this working for any
# number of classes.
N_CLASS_MAP_COLS = 3  # maps per row in the final grid

gtm_refvects = [
    plot_ref_vect(gtm_model.matX, label=refvect[class_idx,:], title=class_name)
    for class_idx, class_name in enumerate(labels2)
]

# Lay the maps out N_CLASS_MAP_COLS per row, then stack the rows vertically.
class_map_rows = [
    alt.hconcat(*gtm_refvects[start:start + N_CLASS_MAP_COLS])
    for start in range(0, len(gtm_refvects), N_CLASS_MAP_COLS)
]
class_maps = alt.vconcat(*class_map_rows)
class_maps = class_maps.configure_title(fontSize=20,fontWeight='normal')
class_maps

# %%


In [35]:
# Compact version of the latent-space scatter: one point per example at its mean
# projection, with colour and shape both encoding the class.
latent_scatter_small = (
    alt.Chart(mean_u_labels, width=250, height=250)
    .mark_point(size=100)
    .encode(
        x='U1',
        y='U2',
        color='y',
        shape='y',
        tooltip=['y'],
        # opacity=alt.condition(selection, alt.value(1), alt.value(0.2))
    )
)
latent_scatter_small

In [36]:
# Each example in data space is associated with a location in the 2D latent space.
# gtm_model.matMeans holds one such 2D position per example (the mean projection);
# matModes is the alternative, mode-based projection (see the ugtm documentation).
# Binning those positions with a 2D histogram gives the number of examples that fall in
# each cell of the latent grid - roughly, the number of examples associated with each node.
# H holds the histogram counts as a 2D array and is used to size the squares in the plot.
#
# The number of bins per axis is derived from the number of nodes in matX rather than
# hard-coded, so H.flatten() always has one value per node.
# NOTE(review): assumes the nodes form a square grid, and that matX is ordered so that
# it lines up with H.flatten() (x1 varying slowest) - confirm against ugtm.
# NOTE(review): the bins span the observed range of matMeans rather than the node grid
# itself, so the counts are only approximately per-node.
n_nodes = gtm_model.matX.shape[0]
grid_side = int(round(np.sqrt(n_nodes)))
assert grid_side * grid_side == n_nodes, "expected a square grid of GTM nodes"

H, xedges, yedges = np.histogram2d(gtm_model.matMeans[:,0],gtm_model.matMeans[:,1],[grid_side,grid_side])
# H, xedges, yedges = np.histogram2d(transformed.matMeans[:,0],transformed.matMeans[:,1],[grid_side,grid_side])

dfmap = pd.DataFrame(gtm_model.matX, columns=["x1", "x2"])
# dfmap = pd.DataFrame(transformed.matX, columns=["x1", "x2"])

dfmap['size'] = H.flatten()

alt.Chart(dfmap).mark_square().encode(
    x='x1',
    y='x2',
    #color=alt.Color('label:Q',
    #                scale=alt.Scale(scheme='viridis')),
    size='size',
    tooltip=['x1','x2','size'],
    #opacity='density'
).properties(title = "GTM Membership Map",width = 300, height = 300)

In [37]:
from sklearn.preprocessing import minmax_scale

def factor_map(gtm_model, Xfact):
    """Average one feature over the examples assigned to each GTM node.

    Parameters
    ----------
    gtm_model : object
        Must expose ``matR`` (responsibilities, n_examples x n_nodes) and ``matX``
        (2D latent coordinates, one row per node).
    Xfact : pandas.Series
        Feature values, one per example, in the same row order as ``matR``.

    Returns
    -------
    pandas.DataFrame
        One row per node with columns ``x1``, ``x2`` (node coordinates), ``membership``
        (node number), ``scale`` (mean feature value of the examples assigned to the
        node, min-max scaled across nodes; NaN for nodes with no assigned examples)
        and ``size`` (constant 1).
    """
    # Single-column frame holding the feature values under the name 'scale'.
    dfclus = Xfact.to_frame(name='scale')
    # matR holds the responsibilities: the posterior probability that each Gaussian
    # (node) is responsible for each data point. Assign every example to the node
    # that is most responsible for it.
    dfclus['membership'] = np.argmax(gtm_model.matR,axis=1)
    # Mean feature value per node. The string 'mean' is used rather than np.mean,
    # which pandas has deprecated as a groupby aggregation argument.
    dfclus = dfclus.groupby('membership', as_index=False).agg('mean')
    # Scale the per-node means to the range 0..1.
    dfclus['scale'] = minmax_scale(dfclus['scale'])
    # matX holds the 2D grid location of every node.
    df_map = pd.DataFrame(gtm_model.matX, columns=["x1", "x2"])
    # Number the nodes by row position. The count comes from matX instead of a
    # hard-coded 100 so that any grid size works.
    df_map['membership'] = np.arange(len(df_map))
    # Left-merge so that every node is kept; nodes without assigned examples get
    # NaN in 'scale'.
    df_map = df_map.merge(dfclus,how='left',on='membership')
    # df_map.fillna(0,inplace=True)
    # Constant column available for sizing the squares in a plot.
    df_map['size'] = 1
    return df_map

def plot_factor_map(df_map, title='Factor Map',fig_size=(115,115),node_size=1):
    """Plot a factor map: the node grid coloured by the per-node feature value.

    df_map    : frame with columns x1, x2 (node coordinates) and scale (colour value).
    title     : chart title (a string, or a list of strings for a multi-line title).
    fig_size  : (width, height) of the chart.
    node_size : multiplier applied to the base square size of 140. The default of 1
                keeps the original appearance; previously this argument was ignored.
    """
    #df_map['size']=df_map['size']*node_size
    return alt.Chart(df_map).mark_square().encode(
        x=alt.X('x1',axis=None),
        y=alt.Y('x2',axis=None),
        color=alt.Color('scale:Q',
                        scale=alt.Scale(scheme='turbo')),
        size=alt.value(140 * node_size),
        tooltip=['x1','x2','scale:Q'],
        #opacity='density'
    ).properties(title=title, width=fig_size[0], height=fig_size[1])

In [38]:
# Factor maps for every feature, laid out in a grid.
varnames = X_cols
gtm_model_for_plot = gtm_model # transformed, gtm_model
# X_set = X_test_fnl # X_train_fnl, X_test_fnl, X_train_imp
# varnames = selected
N_FACTOR_MAP_COLS = 6  # maps per row

# Walk through the features one row of maps at a time. Iterating over slices of the
# feature list creates exactly as many rows as are needed: no empty rows are appended
# and there is no cap on the number of features.
factor_rows = []
for start in range(0, len(varnames), N_FACTOR_MAP_COLS):
    row = alt.hconcat()
    for name in varnames[start:start + N_FACTOR_MAP_COLS]:
        # Features are looked up in X_set by their position in X_cols.
        idx_X_set = X_cols.index(name)
        row |= plot_factor_map(
            factor_map(gtm_model_for_plot, X_set.iloc[:,idx_X_set]),
            # Long names are split over two title lines.
            title=[name[0:17], name[17:34]])
    factor_rows.append(row)

chart = alt.vconcat(*factor_rows)
chart




In [16]:
# Display the list of feature column names (the columns of X).
X_cols

['Amplitude_g',
 'AndersonDarling_g',
 'Autocor_length_g',
 'Beyond1Std_g',
 'CAR_mean_g',
 'CAR_sigma_g',
 'CAR_tau_g',
 'Con_g',
 'Eta_e_g',
 'FluxPercentileRatioMid20_g',
 'FluxPercentileRatioMid35_g',
 'FluxPercentileRatioMid50_g',
 'FluxPercentileRatioMid65_g',
 'FluxPercentileRatioMid80_g',
 'Freq1_harmonics_amplitude_0_g',
 'Freq1_harmonics_amplitude_1_g',
 'Freq1_harmonics_amplitude_2_g',
 'Freq1_harmonics_amplitude_3_g',
 'Freq1_harmonics_rel_phase_1_g',
 'Freq1_harmonics_rel_phase_2_g',
 'Freq1_harmonics_rel_phase_3_g',
 'Freq2_harmonics_amplitude_0_g',
 'Freq2_harmonics_amplitude_1_g',
 'Freq2_harmonics_amplitude_2_g',
 'Freq2_harmonics_amplitude_3_g',
 'Freq2_harmonics_rel_phase_1_g',
 'Freq2_harmonics_rel_phase_2_g',
 'Freq2_harmonics_rel_phase_3_g',
 'Freq3_harmonics_amplitude_0_g',
 'Freq3_harmonics_amplitude_1_g',
 'Freq3_harmonics_amplitude_2_g',
 'Freq3_harmonics_amplitude_3_g',
 'Freq3_harmonics_rel_phase_1_g',
 'Freq3_harmonics_rel_phase_2_g',
 'Freq3_harmonics_rel_

In [9]:
def plot_factor_map2(df_map, title='Factor Map',fig_size=(150,150),node_size=1):
    """Plot a factor map with larger defaults, for the feature-subset figures.

    df_map    : frame with columns x1, x2 (node coordinates) and scale (colour value).
    title     : chart title (a string, or a list of strings for a multi-line title).
    fig_size  : (width, height) of the chart.
    node_size : multiplier applied to the base square size of 250. The default of 1
                keeps the original appearance; previously this argument was ignored.
    """
    #df_map['size']=df_map['size']*node_size
    return alt.Chart(df_map).mark_square().encode(
        x=alt.X('x1',axis=None),
        y=alt.Y('x2',axis=None),
        color=alt.Color('scale:Q',
                        scale=alt.Scale(scheme='turbo')),
        size=alt.value(250 * node_size),
        tooltip=['x1','x2','scale:Q'],
        #opacity='density'
    ).properties(title=title, width=fig_size[0], height=fig_size[1])

# Select subset of columns

# NOTE(review): 'dif_min_median_g' appears twice in this list - confirm whether the second
# entry was meant to be a different feature.
varnames_g = ['Amplitude_g','dif_min_median_g','dif_min_median_g','npeaks_1to2_g','npeaks_2to5_g',
            'npeaks_above5_g','Eta_e_g','CAR_sigma_g','Freq1_harmonics_amplitude_0_g','Skew_g',
            'LinearTrend_g','freq_pwr_max_g','Std_g','MedianAbsDev_g','stdstilllev_t20s10_g',
            'Mean_g','min_mag_g','n_obs_g'
            ]
# Swap the trailing _g for _r. Only the suffix is touched, so a '_g' occurring in the
# middle of a name can never be rewritten by accident.
varnames_r = [v[:-2] + '_r' if v.endswith('_g') else v for v in varnames_g]

# Colour and Gaia features
varnames_gaia = ['parallax','pm','clr_mean','clr_bright','bp_rp','bp_g','g_rp','StetsonJ','StetsonL']

# NAM varnames
varnames_nam = ['bp_rp','nu_eff_used_in_astrometry','parallax','absmag_g',
                'PeriodLS_g','npeaks_1to2_g','npeaks_2to5_g','npeaks_above5_g',
                'pnts_leq_rollMedWin20-5mag_r','stdstilllev_t20s10_g','Gskew_g','n_obs_g',
                'dif_min_median_g','kurtosis_r','LinearTrend_r','StetsonJ']


# Choose which of the lists above to plot.
# varnames = selected
feature_list = varnames_nam
N_SUBSET_MAP_COLS = 4  # maps per row

# Build the grid one row of maps at a time. Iterating over slices of the feature list
# creates exactly as many rows as are needed, so no empty rows are appended.
subset_rows = []
for start in range(0, len(feature_list), N_SUBSET_MAP_COLS):
    row = alt.hconcat()
    for name in feature_list[start:start + N_SUBSET_MAP_COLS]:
        # Features are looked up in X_set by their position in X_cols.
        idx_X_set = X_cols.index(name)
        row |= plot_factor_map2(
            factor_map(gtm_model, X_set.iloc[:,idx_X_set]),
            # Long names are split over two title lines.
            title=[name[0:20], name[20:34]])
    subset_rows.append(row)

chart = alt.vconcat(*subset_rows).configure_title(fontSize=16,fontWeight='bold')
chart