In [5]:
import pickle
import os
import seaborn as sn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# Loading in ...

In [6]:
# Every input is read relative to DATA_DIR instead of os.chdir-ing into the
# folder and back: if a load fails half-way the kernel's working directory is
# left untouched, so this cell can simply be re-run.
DATA_DIR = 'Data'


def _load_pickle(filename):
    """Unpickle ``DATA_DIR/filename`` and close the file handle afterwards."""
    # NOTE(review): pickle.load can execute arbitrary code - only load files
    # that come from a trusted source.
    with open(os.path.join(DATA_DIR, filename), 'rb') as pkl_file:
        return pickle.load(pkl_file)


stacked_isochrones = _load_pickle('isochrones.pkl')
x_columns = _load_pickle('columns.pkl')
x_values = _load_pickle('x_values.pkl')
isoc_columns = _load_pickle('isoc_cols.pkl')

x_input = pd.read_csv(os.path.join(DATA_DIR, 'x_input'))
x_input_err = pd.read_csv(os.path.join(DATA_DIR, 'x_input_err'))

In [3]:
# Display the column names loaded from columns.pkl.
x_columns

array(['ra', 'dec', 'parallax', 'phot_g_mean_mag', 'phot_g_mean_flux',
       'phot_rp_mean_flux', 'phot_bp_mean_flux', 'phot_bp_mean_mag',
       'phot_rp_mean_mag', 'bp_rp', 'mean_absolute_mag_g_band',
       'ra_error', 'dec_error', 'parallax_error',
       'phot_g_mean_flux_error', 'phot_bp_mean_flux_error',
       'phot_rp_mean_flux_error', 'phot_g_mean_mag_error', 'bp_error',
       'rp_error', 'bp_rp_error', 'G'], dtype=object)

In [4]:
def find_nearest(array, value):
    """Return the entry of ``array`` closest to ``value`` and its index.

    ``array`` is converted with ``np.asarray``. The index is the ``argmin`` of
    the absolute differences ``|array - value|``.

    Returns:
        tuple: ``(nearest_entry, index)``.
    """
    candidates = np.asarray(array)
    distances = np.abs(candidates - value)
    nearest_idx = distances.argmin()
    return candidates[nearest_idx], nearest_idx

def isochrone_selector(feh,age):
    """Map a ``(feh, age)`` pair to a flat isochrone index.

    Each input is snapped to the nearest node of a fixed grid using
    ``find_nearest``: 90 ``feh`` nodes spanning [-4, 0.5] and 105 ``age``
    nodes spanning [5, 10.3]. The result is ``feh_index * 105 + age_index``.

    Out-of-range inputs are not rejected; they snap to the closest grid edge.
    (A range check raising ``NotImplementedError`` for ``feh`` outside
    [-4, 0.5] or ``age`` outside [5, 10.3] used to sit here, disabled inside a
    string literal.)
    """
    age_nodes = np.linspace(5, 10.3, 105)
    feh_nodes = np.linspace(-4, 0.5, 90)

    _, feh_pos = find_nearest(feh_nodes, feh)
    _, age_pos = find_nearest(age_nodes, age)

    # The feh index selects a block of len(age_nodes) entries; the age index
    # is the offset inside that block.
    return feh_pos * len(age_nodes) + age_pos

In [5]:
def column_index(name):
    """Return the position of column ``name`` in the combined column layout.

    Names found in the global ``x_columns`` map to their position there.
    Any other name is looked up in the global ``isoc_columns`` and its
    position is shifted by ``len(x_columns)``.

    Raises:
        IndexError: if ``name`` appears in neither list.
    """
    if name not in x_columns:
        isoc_matches = np.where(np.array(isoc_columns) == name)[0]
        return isoc_matches[0] + len(x_columns)
    x_matches = np.where(np.array(x_columns) == name)[0]
    return x_matches[0]

# Loaded.

In [6]:
# First sample, first feature, every position along the third axis.
x_values[0,0,:]

array([282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60587141,
       282.60587141, 282.60587141, 282.60587141, 282.60

In [7]:
# Display the isochrone column names loaded from isoc_cols.pkl.
isoc_columns

array(['logg', 'logteff', 'logl', 'mass', 'logage', 'feh', 'phase',
       'Gaia_RP_EDR3', 'Gaia_BP_EDR3', 'Gaia_G_EDR3', 'BPRP', 'p_slopes',
       'slopes', 'low_c', 'high_c'], dtype='<U12')

# Data

The data we have loaded in is as follows:

1. `x_values` is a NumPy array of shape (n_samples, n_features, longest_isochrone_tang_length). Each sample's n_features values are repeated along the third axis so that its length matches that of the longest isochrone.
2. `stacked_isochrones` is a NumPy array of shape (n_isochrones, n_features, largest_tangent_numb_size). Each isochrone has its own number of slopes and p_slopes; these values run along the third axis, and the remaining positions are padded with NaN values.
3. `x_input` and `x_input_err` are easier-to-access versions of the data, used as input.

In [8]:
# column_index() offsets isochrone columns by len(x_columns); subtracting it
# again gives the row of 'p_slopes' inside a single stacked isochrone.
stacked_isochrones[0][column_index('p_slopes')-len(x_columns)]


array([-1.73925433e+01, -1.75523825e+01, -1.74907631e+01, -1.59839254e-01,
        7.07219486e-02, -1.41398819e+01, -1.76695745e+01, -1.74794589e+01,
        1.99925974e+00,  4.14005443e+00,  9.40709167e+00,  1.93193193e+00,
        5.00000000e+00, -4.00000000e+00,  0.00000000e+00, -1.75908508e+01,
       -1.77373116e+01, -1.76799325e+01, -1.46460821e-01,  6.14792768e-02,
       -1.62656435e+01, -1.78577022e+01, -1.76709282e+01,  1.97464221e+00,
        4.12802174e+00,  9.45416162e+00,  1.95195195e+00,  5.00000000e+00,
       -4.00000000e+00,  0.00000000e+00, -1.77842444e+01, -1.79192657e+01,
       -1.78660032e+01, -1.35021331e-01,  6.15102157e-02, -1.62574621e+01,
       -1.80437757e+01, -1.78576980e+01,  1.95036799e+00,  4.11602791e+00,
        9.50104221e+00,  1.97197197e+00,  5.00000000e+00, -4.00000000e+00,
        0.00000000e+00, -1.79769199e+01, -1.81005387e+01, -1.80513795e+01,
       -1.23618792e-01,  5.92113304e-02, -1.68886595e+01, -1.82252834e+01,
       -1.80440599e+01,  

# VAESTAR

In [9]:
import torch
import torch.nn as nn
from torch.autograd import Variable

In [10]:
# Use Apple's MPS backend when it is available and fall back to the CPU
# otherwise, so that later `.to(device)` calls do not fail on other machines.
device=torch.device("mps" if torch.backends.mps.is_available() else "cpu")
torch.backends.mps.is_available()

True

In [11]:
# Display the frame read from the 'x_input' CSV.
x_input

Unnamed: 0.1,Unnamed: 0,ra,dec,parallax,phot_g_mean_mag,bp_rp,G
0,0,,,,,,
1,1,,,,,,
2,2,,,,,,
3,3,,,,,,
4,4,,,,,,
...,...,...,...,...,...,...,...
1096,1096,,,,,,
1097,1097,,,,,,
1098,1098,,,,,,
1099,1099,,,,,,


In [None]:

class encoder(nn.Module): #q(z|x)
    """Approximate posterior q(z|x) with latents (ext_0, ext_1, feh, age).

    The last column of the input is treated as a distance channel and passed
    through a one-step LSTM; the remaining columns go through two dense
    layers. Both branches are concatenated and mapped to a mean and a standard
    deviation per latent dimension. Rows are then averaged in consecutive
    groups of ``av_sample`` before a latent sample is drawn.

    Args:
        input_dim: total number of columns of the input, *including* the
            trailing distance column.
        hidden_dims: sizes of the two dense layers, ``hidden_dims[0]`` and
            ``hidden_dims[1]``.
        z_dim: latent dimensionality. ``forward`` compares against a fixed
            4-dimensional prior and reads latent columns 0-3, so it needs
            ``z_dim == 4``.
        av_sample: number of consecutive rows averaged into one posterior.
            Defaults to 1 (no averaging).
    """
    def __init__(self,input_dim,hidden_dims,z_dim,av_sample=1):
        super().__init__()
        # Shapes
        self.sample_size=32
        self.input_dim=1  # feature size of the distance channel fed to the LSTM
        self.n_layers=1
        self.lstm_hidden_dim=5

        self.z_dim=z_dim
        self.ave_sample=av_sample

        # Model Definition

        # the LSTM input has shape (batch_size, sequence_length=1, input_dim=1)
        self.dist_lstm=nn.LSTM(self.input_dim,self.lstm_hidden_dim,self.n_layers,batch_first=True)
        self.lstm_activation=nn.Tanh()

        # forward() strips the distance column before this layer, hence the -1.
        self.input_dense=nn.Linear(in_features=input_dim-1,out_features=hidden_dims[0])
        self.hidden_dense=nn.Linear(in_features=hidden_dims[0],out_features=hidden_dims[1])
        self.input_activation=nn.ReLU()
        self.hidden_activation=nn.ReLU()

        self.concat_dense=nn.Linear(in_features=hidden_dims[1]+self.lstm_hidden_dim,out_features=z_dim*2)
        # ReLU keeps the extinction *mean* non-negative; a drawn sample can
        # still be negative.
        self.z_activation=nn.ReLU()

        # Standard normal used for the reparameterised draw of every latent.
        self.N=torch.distributions.Normal(0,1)

    def forward(self, x):
        """Encode a batch and draw one latent sample per averaged group.

        Args:
            x: 2-D tensor ``(batch, input_dim)`` whose last column is the
                distance channel. ``batch`` must be a multiple of ``av_sample``.

        Returns:
            tuple: ``(z_ext, z_feh, z_age, kl)`` with shapes ``(groups, 2)``,
            ``(groups,)``, ``(groups,)`` and ``(groups,)``, where
            ``groups = batch // av_sample`` and ``kl`` is KL(q || prior).
        """
        batch=x.size(0)
        # TODO: the distance column can contain NaNs; they are not handled yet.
        x_dist=x[:,-1]
        x_not_dist=x[:,:-1]

        # everything but the distance channel
        x_not_dist=self.input_activation(self.input_dense(x_not_dist))
        x_not_dist=self.hidden_activation(self.hidden_dense(x_not_dist))

        # distance channel LSTM; the initial states live on the input's device
        h_0=torch.zeros(self.n_layers,batch,self.lstm_hidden_dim,device=x.device,dtype=x.dtype)
        c_0=torch.zeros(self.n_layers,batch,self.lstm_hidden_dim,device=x.device,dtype=x.dtype)
        lstm_out,_=self.dist_lstm(x_dist.reshape(batch,1,self.input_dim),(h_0,c_0))
        x_dist=self.lstm_activation(lstm_out).reshape(batch,self.lstm_hidden_dim)

        # concatenate both channels
        output=self.concat_dense(torch.cat([x_not_dist,x_dist],dim=1))

        z_mu=output[:,:self.z_dim]
        z_sigma=torch.exp(output[:,self.z_dim:])

        # Average consecutive groups of rows (out-of-place, so autograd can
        # differentiate through it).
        average_sample_length=self.ave_sample
        assert len(z_mu)%average_sample_length==0, "batch size must be a multiple of av_sample"
        z_mu=z_mu.reshape(-1,average_sample_length,self.z_dim).mean(dim=1)
        z_sigma=z_sigma.reshape(-1,average_sample_length,self.z_dim).mean(dim=1)

        # Apply the activation to the two extinction columns without in-place writes.
        z_mu=torch.cat([self.z_activation(z_mu[:,:2]),z_mu[:,2:]],dim=1)
        z_sigma=torch.cat([self.z_activation(z_sigma[:,:2]),z_sigma[:,2:]],dim=1)

        # Diagonal Gaussians: q has standard deviation z_sigma; the prior has
        # mean (0, 0, -2, 7.5) and variances (1, 1, 1.5, 1.5).
        q=torch.distributions.multivariate_normal.MultivariateNormal(loc=z_mu,scale_tril=torch.diag_embed(z_sigma))
        prior_loc=z_mu.new_tensor([0.0,0.0,-2.0,7.5])
        prior_var=z_mu.new_tensor([1.0,1.0,1.5,1.5])
        p=torch.distributions.multivariate_normal.MultivariateNormal(prior_loc,scale_tril=torch.diag(prior_var.sqrt()))

        # Reparameterised sample: z = mu + sigma * eps with eps ~ N(0, 1).
        eps=self.N.sample(z_mu.shape).to(z_mu)
        z_ext=z_mu[:,:2]+z_sigma[:,:2]*eps[:,:2]
        z_feh=z_mu[:,2]+z_sigma[:,2]*eps[:,2]
        z_age=z_mu[:,3]+z_sigma[:,3]*eps[:,3]

        return z_ext,z_feh,z_age, torch.distributions.kl_divergence(q,p)



In [None]:

class decoder(nn.Module):
    """Scores each latent draw against the isochrone picked by its (feh, age).

    The module has no learnable parameters. For every entry of ``z_age`` the
    forward pass selects one isochrone with ``isochrone_selector``, joins it
    with ``x_values[i]``, computes a slope-normalised residual ``x`` with a
    matching ``x_err`` and evaluates ``x`` under a zero-mean multivariate normal.

    NOTE(review): ``z_ext`` is returned unchanged and never enters the
    computation - confirm whether extinction is meant to be applied here.
    """
    def __init__(self):
        super(decoder,self).__init__()
        
    def forward(self,z_ext,z_feh,z_age,all_isochrones,x_values):
        """Return ``(log_prob, z_ext, z_feh, z_age)``.

        ``log_prob`` is a Python list with one ``dist.log_prob(x)`` result per
        entry of ``z_age``; the three latent arguments are passed through as-is.
        """
        
        # should i have a log scale on the age or will it know it
        #NEED TO STACK ALL OF THE ISOCHRONES AND WILL MAKE EAZ
        log_prob=[]
        # NOTE(review): log_l is accumulated below but never returned or read.
        log_l=0.0
        for i in range(len(z_age)):
            # NOTE(review): torch.cat needs tensors, so x_values and
            # all_isochrones must already be tensors here - confirm the caller
            # converts them.
            isochrone=torch.cat([x_values[i],all_isochrones[isochrone_selector(z_feh[i],z_age[i])]],dim=0)

            
            #can i apply this that easy or does a lot need to be changed to make it more tensor frienly


            #now we need to do row by row the calculation we did before in the cube.
            # find out where the extinction needs to be added to or taken away from in the cluster_df
            # im going to let the normal prior say everything about the z values, and take away the implementation error.

            # the next is very specific to my isochrone file df i created
            
            #x.shape
            #e=torch.tensor([i for i in range(92)]).repeat(len(x),1)
            #x[:,column_index('G'),:]+e
            # NOTE(review): every access below uses three indices
            # ([:, column, :]) on the tensor produced by the dim-0 cat above -
            # confirm the intended rank of `isochrone` (per-sample vs. batched cube).

            # Mask of positions where G - bp_rp*p_slopes lies within [low_c, high_c].
            truth_1=(isochrone[:,column_index('G'),:]-isochrone[:,column_index('bp_rp'),:]*isochrone[:,column_index('p_slopes'),:]<=isochrone[:,column_index('high_c'),:])

            truth_1=truth_1.reshape(truth_1.shape[0],1,truth_1.shape[1])

            truth_2=(isochrone[:,column_index('low_c'),:]<=isochrone[:,column_index('G'),:]-isochrone[:,column_index('bp_rp'),:]*isochrone[:,column_index('p_slopes'),:])
            truth_2=truth_2.reshape(truth_2.shape[0],1,truth_2.shape[1])

            # Product of the two boolean masks = logical AND.
            truth=truth_1*truth_2


            # NOTE(review): leftover debug output.
            print('tag2')

            # Signed perpendicular distance from the point (bp_rp, G) to the line
            # of slope `slopes` passing through (BPRP, Gaia_G_EDR3).
            x=((1/torch.sqrt(1+isochrone[:,column_index('slopes'),:]**2))*(isochrone[:,column_index('G'),:]-isochrone[:,column_index('bp_rp'),:]*isochrone[:,column_index('slopes'),:]-isochrone[:,column_index('Gaia_G_EDR3'),:] + isochrone[:,column_index('slopes'),:]*isochrone[:,column_index('BPRP'),:]))



            # Dividing by the boolean mask sends masked-out positions to inf
            # (NaNs are mapped to inf as well), so argmin only picks among
            # positions where the mask is True.
            idx=torch.argmin(torch.abs(x/truth.reshape(x.shape)).nan_to_num(nan=torch.inf),1)

            x=x.gather(1,idx.view(-1,1))
            #error needs to be corrected
            x_err=(1/(1+isochrone[:,column_index('slopes'),:]**2))*isochrone[:,column_index('phot_g_mean_mag_error'),:]**2+(isochrone[:,column_index('slopes'),:]*isochrone[:,column_index('bp_rp_error'),:])**2
            x_err=x_err.gather(1,idx.view(-1,1))

            # NOTE(review): the extended `isochrone` built here is not read
            # again before the next loop iteration overwrites it.
            isochrone=torch.cat((isochrone,x.reshape(x.shape[0],1,x.shape[1]).repeat(1,1,(isochrone).shape[-1]),x_err.reshape(x_err.shape[0],1,x_err.shape[1]).repeat(1,1,(isochrone).shape[-1])),1)

            #((1/(1+isochrone['slopes'][i]**2))*test['phot_g_mean_mag_error']**2+(isochrone['slopes'][i]*test['bp_rp_error'])**2)
                                
            # NOTE(review): x_err is 2-D after gather(1, idx.view(-1,1));
            # torch.diag on a 2-D input returns its diagonal instead of building
            # a matrix - confirm the covariance has the intended shape.
            dist=torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros_like(x),torch.eye(len(x))+torch.diag(x_err**2))
            log_l+=dist.log_prob(x)
            log_prob.append(dist.log_prob(x))
        return log_prob,z_ext,z_feh,z_age
        


class VaeStar(nn.Module):
    """Wires an ``encoder`` and a ``decoder`` together.

    Args:
        input_dim, hidden_dims, z_dim, av_sample: forwarded positionally, in
            this order, to ``encoder``. The ``decoder`` takes no arguments.
    """
    def __init__(self,input_dim,hidden_dims,z_dim,av_sample):
        super().__init__()

        self.encoder = encoder(input_dim, hidden_dims, z_dim, av_sample)
        self.decoder = decoder()

    def forward(self,x_resampled, x_values, all_isochrones):
        """Encode ``x_resampled`` and decode the resulting latents.

        Returns:
            tuple: ``(kl, decoder_output)`` - the fourth value returned by the
            encoder followed by whatever the decoder returns, as one nested item.
        """
        # The encoder yields three latents plus the KL term.
        latent_ext, latent_feh, latent_age, kl_term = self.encoder(x_resampled)

        decoded = self.decoder(latent_ext, latent_feh, latent_age, all_isochrones, x_values)
        return kl_term, decoded






In [None]:
from sklearn.preprocessing import MinMaxScaler
# Min-max scaler instance; it is fitted where the input tensor is built.
scaler=MinMaxScaler()

In [None]:
# NOTE(review): `x_resample` is not defined in any cell visible here - confirm
# where it is created so the notebook survives Restart & Run All.
# The scaled inputs are data, not parameters, so the tensor is built without
# requires_grad; otherwise autograd would track and accumulate gradients for
# the whole data set on every backward pass.
x_input=torch.tensor(scaler.fit_transform(x_resample.values))

In [None]:
lr=1e-3
# NOTE(review): `draw_size` is not defined in any cell visible here - confirm
# where it is set.
# The model is moved to the target device *before* the optimizer is built, as
# recommended by the torch.optim documentation.
model=VaeStar(input_dim=x_input.shape[1],hidden_dims=[10,10],z_dim=4,av_sample=draw_size).to(device)

optimizer=torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
model

In [None]:
# One batch spans the whole data set.
batch_size = len(x_input)#draw_size #need to make sure everything adds up
# NOTE(review): `lr` is also set to the same value in the cell that builds the
# optimizer; this second assignment looks redundant.
lr = 1e-3
epochs = 50

In [None]:
from torch.utils.data import DataLoader
# NOTE(review): this rebinds `x_input` from a tensor to a DataLoader, so the
# cell cannot be run twice (a DataLoader has no `.float()`); consider a
# separate name for the loader.
x_input=DataLoader(x_input.float(),batch_size=batch_size,drop_last=False) 

In [None]:
model.train()
for epoch in range(epochs):
    overall_loss=0.0
    n_seen=0  # number of rows processed in this epoch
    for batch_idx,x in enumerate(x_input):
        # The loader already yields (rows, features) batches; no reshape to a
        # fixed batch_size, which would fail on a shorter final batch.
        x=x.to(device)

        optimizer.zero_grad()
        # VaeStar.forward returns (kl, decoder_output), where decoder_output
        # is itself the 4-tuple (log_prob, z_ext, z_feh, z_age).
        kl, (log_prob,z_ext,z_feh,z_age) = model(x,x_values,stacked_isochrones)

        # log_prob is a Python list of tensors, so reduce it entry by entry.
        log_likelihood=sum(lp.sum() for lp in log_prob)

        # Negative ELBO: KL(q || p) minus the expected log-likelihood.
        loss=kl.sum()-log_likelihood
        loss.backward()
        optimizer.step()

        overall_loss+=loss.item()
        n_seen+=len(x)

    # Report once per epoch; max() guards against an empty loader.
    print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss / max(n_seen,1))
    print("Overall Loss: ", overall_loss)
