In [None]:
%load_ext autoreload
%autoreload 2

import astropy
from astropy.io import fits
import numpy as np
import matplotlib.pyplot as plt

import inspect
import itertools
import random


import apogee.tools.read as apread
import apogee.tools.path as apogee_path
from apogee.tools import bitmask
import collections

from apoNN.src.occam import Occam
from apoNN.src.datasets import ApogeeDataset,AspcapDataset
from apoNN.src.utils import get_mask_elem,dump,load
from apoNN.src.plotting import summarize_representation,get_intracluster_distances,get_intercluster_
import apoNN.src.vectors as vector


import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn as nn
from sklearn.decomposition import PCA,KernelPCA

from sklearn.preprocessing import PolynomialFeatures
# Degree-2, interaction-only polynomial feature expander (not referenced by any later visible cell).
poly = PolynomialFeatures(2,interaction_only=True)
# Torch device: first CUDA GPU when available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Switch apogee.tools.path to data release 16 (per the function name / argument).
apogee_path.change_dr(16)

# NOTE(review): bare literal expression with no effect other than being echoed as this cell's
# output. It lists parameter names, element names, abundance labels and a 0/1 flag per element
# (1 exactly where the label is relative to H, e.g. [Na/H]; 0 where relative to M, e.g. [C/M]).
# TODO: document where it comes from or remove it.
[(['TEFF', 'LOGG', 'LOG10VDOP', 'METALS', 'C', 'N', 'O Mg Si S Ca Ti'], ['C', 'N', 'O', 'Na', 'Mg', 'Al', 'Si', 'S', 'K', 'Ca', 'Ti', 'V', 'Mn', 'Fe', 'Ni'], ['[C/M]', '[N/M]', '[O/M]', '[Na/H]', '[Mg/M]', '[Al/H]', '[Si/M]', '[S/M]', '[K/H]', '[Ca/M]', '[Ti/M]', '[V/H]', '[Mn/H]', '[Fe/H]', '[Ni/H]'], [0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1])]


In [None]:
# Number of leading stars used when building the label vectors below.
n_data = 10000
# Training sample (alternative inputs: "allStar_training_clean" / "aspcap_training_clean").
allStar = load("allStar_occamlike")
dataset = AspcapDataset(
    filename="aspcap_occamlike",
    tensor_type=torch.FloatTensor,
    recenter=True,
)

In [None]:
# Outlier cut: keep a star only if every column below is strictly greater than OUTLIER_THRESHOLD
# (lower values are presumably bad/sentinel entries — confirm against the catalogue).
# NOTE: Teff and logg, although used as labels later, are not part of this cut.
OUTLIER_THRESHOLD = -5
outlier_cut_columns = ["Fe_H", "O_FE", "C_FE", "Na_FE", "Mg_FE", "Si_FE", "Al_FE", "S_FE",
                       "P_FE", "Ti_FE", "Cr_FE", "N_FE", "Ni_FE", "Mn_FE"]
# One boolean mask per column, kept for inspection (e.g. outlier_cuts["Fe_H"].sum()).
outlier_cuts = {column: allStar[column] > OUTLIER_THRESHOLD for column in outlier_cut_columns}
# AND all masks together with `&`, the same operator the per-column chain used.
combined_cut = outlier_cuts[outlier_cut_columns[0]]
for column in outlier_cut_columns[1:]:
    combined_cut = combined_cut & outlier_cuts[column]


In [None]:
# Stellar parameters/abundances used as labels; column j of y corresponds to considered_parameters[j].
considered_parameters = ["Teff","logg","Fe_H","O_FE","C_FE","Na_FE","Mg_FE","Si_FE","S_FE","Al_FE","Ni_FE","N_FE","Cr_FE"]
# Rows used for the labels: the first n_data stars that pass combined_cut. The selection is made
# once and reused; it already has at most n_data rows, so no further truncation is needed.
allStar_selected = allStar[:n_data][combined_cut[:n_data]]
y = vector.Vector(np.array([allStar_selected[param] for param in considered_parameters]).T)
y_astronn = vector.AstroNNVector(allStar_selected,considered_parameters)

In [None]:
# OCCAM open-cluster sample: its allStar rows, its cluster index data and its spectra.
occam = load("occam")
allStar_occam = occam["allStar"]
dataset_occam = AspcapDataset(
    filename="aspcap_occam",
    recenter=True,
    tensor_type=torch.FloatTensor,
    # presumably used to fill missing pixels from the main dataset — TODO confirm
    filling_dataset=dataset.dataset["aspcap"],
)
occam_cluster_idxs = occam["cluster_idxs"]


# Compute the latent representations z (PCA compression of the spectra)

In [None]:
# Compress the interpolated spectra to 20 principal components.
# random_state makes the fit reproducible if sklearn's svd_solver="auto" picks its randomized
# solver (which it may for large inputs).
compressor = PCA(n_components=20, whiten=False, random_state=0)
compressor.fit(dataset.dataset["aspcap_interpolated"])

In [None]:
# The compressor was already fitted on the same array in the previous cell; refitting is
# redundant (and without a fixed random_state could change the components). Display the
# fitted estimator instead (the same object `fit` would have returned).
compressor

In [None]:
# Latent representation of every interpolated spectrum: its PCA coordinates, wrapped as a
# first-order vector.Vector (alternative: keep only the rows passing combined_cut[:n_data]).
z_pca = compressor.transform(
    dataset.dataset["aspcap_interpolated"]
)
z = vector.Vector(
    z_pca,
    order=1,
    interaction_only=False,
)

In [None]:
# Sanity check of the PCA compression: row 450's spectrum vs. its reconstruction, zoomed to
# pixels 1000-1200. Only that row is inverse-transformed instead of the whole dataset.
plt.plot(compressor.inverse_transform(z_pca[450:451])[0], label="PCA reconstruction")
plt.plot(dataset.dataset["aspcap_interpolated"][450], label="original")
plt.xlabel("pixel index")
plt.legend()
plt.xlim(1000,1200)

In [None]:
# Project the OCCAM spectra with the same PCA, then attach the cluster index data.
z_raw_occam = compressor.transform(
    dataset_occam.dataset["aspcap_interpolated"]
)
z_occam = vector.OccamLatentVector(
    occam_cluster_idxs, raw=z_raw_occam, order=1, interaction_only=False
)

### Use the fitter object

In [None]:
# Build the fitter from the full sample z and all OCCAM clusters
# (variant: pass z_occam.without("NGC 6819") to leave that cluster out).
fitter = vector.Fitter(
    z,
    z_occam,
)

In [None]:
# Display the fitter's scaling_factor attribute (defined in apoNN.src.vectors).
fitter.scaling_factor

In [None]:
# Map NGC 6819's centred cluster vectors and the fitter's whole centred sample through the fitter.
v_centered_occam = fitter.transform(
    z_occam.centered.only("NGC 6819")
)
v = fitter.transform(
    fitter.z.centered
)

In [None]:
# Summarize one column of the full-sample representation against one column of NGC 6819.
# NOTE(review): v uses column 20 while v_centered_occam uses column 4 — these look like
# different dimensions; confirm the mismatch is intentional. The two trailing 0. arguments are
# passed positionally to summarize_representation (their meaning is not visible here).
summarize_representation(v[:,20],v_centered_occam[:,4],0.,0.)

In [None]:
# Same summary for the last column of both representations; the two trailing 0.1 arguments are
# passed positionally to summarize_representation (meaning defined in apoNN.src.plotting).
summarize_representation(v[:,-1],v_centered_occam[:,-1],0.1,0.1)

In [None]:
def get_combinations(len_cluster):
    """Return every unordered index pair ``[i, j]`` with ``0 <= i < j < len_cluster``.

    Pairs are produced in lexicographic order ([0, 1], [0, 2], ..., [1, 2], ...), each as a
    two-element list. ``itertools.combinations`` generates them directly, instead of scanning
    the growing result list for duplicates (O(n^3)).

    Args:
        len_cluster: number of items to pair up.

    Returns:
        list of ``[i, j]`` lists; empty when ``len_cluster < 2``.
    """
    return [list(pair) for pair in itertools.combinations(range(len_cluster), 2)]

In [None]:
def get_intracluster_distances(z,z_occam):
    """Pairwise distances between members of each OCCAM cluster, with that cluster left out of the fit.

    NOTE: redefines (shadows) the ``get_intracluster_distances`` imported from
    ``apoNN.src.plotting`` in the first cell.

    For every cluster in ``z_occam.registry``, a ``vector.Fitter`` is built from ``z`` and
    ``z_occam.without(cluster)``. That cluster's centred vectors (``z_occam.centered.only(cluster)``)
    are transformed with it, and the Euclidean distance of every unordered pair of members is recorded.

    Args:
        z: latent vector of the full sample, passed to ``vector.Fitter``.
        z_occam: cluster latent vector exposing ``registry``, ``without`` and ``centered.only``.

    Returns:
        list of lists: one list of pair distances per cluster, in registry order
        (empty for clusters with fewer than two members).
    """
    distances = []
    # Snapshot the cluster names before the loop body calls z_occam.without(...).
    for cluster in list(z_occam.registry.keys()):
        # Fit without the current cluster, then map that cluster's members with the fitter.
        fitter = vector.Fitter(z,z_occam.without(cluster))
        v_centered_occam = fitter.transform(z_occam.centered.only(cluster))
        distances_cluster = [
            np.linalg.norm(v_centered_occam[idx1]-v_centered_occam[idx2])
            for idx1, idx2 in get_combinations(len(v_centered_occam))
        ]
        distances.append(distances_cluster)
    return distances

In [None]:
def get_intercluster_distances(z,z_occam,n_random = 200,seed=None):
    """Distances from each cluster member to randomly drawn stars of the full sample.

    Baseline for ``get_intracluster_distances``. For every cluster in ``z_occam.registry``, a
    ``vector.Fitter`` is built from ``z`` and ``z_occam.without(cluster)``. Both the cluster's
    centred vectors and the fitter's whole centred sample (``fitter.z.centered``) are transformed.
    For each cluster member, ``n_random`` sample stars are then drawn uniformly with replacement,
    and their Euclidean distances to the member are recorded.

    NOTE(review): probably shadows a ``get_intercluster_...`` function imported from
    ``apoNN.src.plotting`` in the first cell (that import line is truncated) — confirm.

    Args:
        z: latent vector of the full sample, passed to ``vector.Fitter``.
        z_occam: cluster latent vector exposing ``registry``, ``without`` and ``centered.only``.
        n_random: number of random sample stars drawn per cluster member.
        seed: optional seed for a private ``random.Random``, making the draws reproducible.
            ``None`` (default) draws from the global ``random`` module state.

    Returns:
        list of lists: one list of ``len(members) * n_random`` distances per cluster,
        in registry order.

    Raises:
        ValueError: from ``randint`` if the transformed sample is empty.
    """
    # The random module and a random.Random instance share the randint API.
    rng = random if seed is None else random.Random(seed)
    distances = []
    # Snapshot the cluster names before the loop body calls z_occam.without(...).
    for cluster in list(z_occam.registry.keys()):
        fitter = vector.Fitter(z,z_occam.without(cluster))
        v_centered_occam = fitter.transform(z_occam.centered.only(cluster))
        # The fitter differs per cluster, so the full sample is re-transformed every iteration.
        v = fitter.transform(fitter.z.centered)
        n_v = len(v)
        distances_cluster = []
        for idx in np.arange(len(v_centered_occam)):
            for _ in np.arange(n_random):
                random_idx = rng.randint(0,n_v-1)
                distances_cluster.append(np.linalg.norm(v_centered_occam[idx]-v[random_idx]))
        distances.append(distances_cluster)
    return distances

In [None]:
# Leave-one-cluster-out distances for all OCCAM clusters
# (variant: pass z_occam.without("NGC 6791") to both calls to drop that cluster).
distances = get_intracluster_distances(
    z, z_occam
)
random_distances = get_intercluster_distances(
    z, z_occam
)

In [None]:
# Mean distance per cluster: intra-cluster pairs and the random baseline.
mean_distances = list(map(np.mean, distances))
mean_random_distances = list(map(np.mean, random_distances))

In [None]:
# One figure per cluster: intra-cluster pair distances vs. distances from cluster members to
# random sample stars (both lists are in the same cluster order).
for cluster_distances, cluster_random_distances in zip(distances, random_distances):
    fig, ax = plt.subplots()
    ax.hist(cluster_distances, alpha=0.5, density=True, bins=20, label="intra-cluster pairs")
    ax.hist(cluster_random_distances, alpha=0.5, density=True, bins=50, label="member vs. random star")
    ax.set_title(f"{len(cluster_distances)} intra-cluster pairs")
    ax.set_xlabel("distance")
    ax.set_ylabel("density")
    ax.set_xlim(0, 100)
    ax.legend()
    plt.show()
    # Release the figure so many clusters do not accumulate open figures.
    plt.close(fig)

### Visualizing cluster directions in spectra space


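Goal: find the coefficient matrix $A$ (`a` in the code) such that $V A \approx X$. Here $V$ is the representation `v` with an appended column of ones for the intercept (`v_ones`), and $X$ is the matrix of spectra `x`. $A$ is obtained by least squares.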

In [None]:
# Regression target: the interpolated spectra (the same array the PCA compressor was fitted on).
x = dataset.dataset["aspcap_interpolated"]

In [None]:
# Append a constant column of ones to v so the linear fit below gets an intercept term.
v_ones = np.column_stack((v, np.ones((v.shape[0], 1))))

In [None]:
# Shape check: v_ones has one more column than v (the appended ones column).
v_ones.shape

In [None]:
# NOTE(review): re-creates `dataset` with exactly the same arguments as the loading cell at the
# top; no visible cell in between modifies it, so this reload looks redundant. `x` still refers
# to the array from the previous `dataset` object.
dataset=  AspcapDataset(filename="aspcap_occamlike",tensor_type=torch.FloatTensor,recenter=True)


In [None]:
# Shape of the (reloaded) interpolated-spectra array.
dataset.dataset["aspcap_interpolated"].shape

In [None]:
# Shape of z's underlying matrix; assumes calling a vector.Vector returns its raw array — TODO confirm.
z().shape

In [None]:
# Goal: find the coefficient matrix `a` with np.dot(v_ones, a) ≈ x, i.e. a linear map from the
# representation (plus an intercept column) back to the spectra. It is solved by least squares
# in the following cell.

In [None]:
# Least-squares coefficients for v_ones @ a ≈ x: one row per column of v_ones (the last row is
# the intercept), one column per column of x. lstsq solves the problem directly instead of
# materialising the full pseudo-inverse of v_ones.
a = np.linalg.lstsq(v_ones, x, rcond=None)[0]

In [None]:
# Two rows of the coefficient matrix a, each a per-pixel profile: a[-2] belongs to the last
# column of v (a[-1] is the intercept row from the ones column), a[1] to column 1 of v.
plt.plot(a[-2], alpha=0.5, label="a[-2]")
plt.plot(a[1], alpha=0.5, label="a[1]")
plt.xlabel("pixel index")
plt.legend();

In [None]:
# Shape of a: (number of columns of v_ones, number of columns of x).
a.shape

In [None]:
# Linear reconstruction of the spectra from the representation (with intercept column).
x_pred2 = v_ones @ a

In [None]:
# Row 0: original spectrum vs. its linear reconstruction from the representation, pixels 1000-1200.
plt.plot(x[0], label="x (original)")
plt.plot(x_pred2[0], label="x_pred2 (v_ones @ a)")
plt.xlabel("pixel index")
plt.legend()
plt.xlim(1000,1200)

In [None]:
# Shape of a single spectrum (one row of x).
x[0].shape

In [None]:
# Row 0: original interpolated spectrum vs. its PCA reconstruction, pixels 1000-1200.
# Only row 0 is inverse-transformed instead of reconstructing every row.
plt.plot(dataset.dataset["aspcap_interpolated"][0], label="original")
plt.plot(compressor.inverse_transform(z_pca[:1])[0], label="PCA reconstruction")
plt.xlabel("pixel index")
plt.legend()
plt.xlim(1000,1200)

In [None]:
# Display the first 50 rows of v.
v[:50]

In [None]:
# Display fitter.scaling_factor again. The global `fitter` is unchanged by the distance
# helpers, which build their own local fitters.
fitter.scaling_factor

In [None]:
# First (up to) 50 columns of v divided by fitter.scaling_factor.
# NOTE(review): v_scaled is not used by any later visible cell, and how scaling_factor
# broadcasts against v's columns is not visible here — confirm.
v_scaled = v[:,:50]/fitter.scaling_factor

In [None]:
# Plot the last of the first (up to) 50 rows of a after dividing by fitter.scaling_factor.T.
# NOTE(review): a has one row per column of v_ones, and its last row is the intercept (ones
# column). If v_ones has <= 50 columns, a[:50] includes that intercept row — confirm that this
# and the broadcasting against scaling_factor.T are intended.
plt.plot((a[:50]/fitter.scaling_factor.T)[-1])

### I think this roughly shows the information carried by each pixel, scaled by the open-cluster scaling factor

In [None]:
# Sum the scaled coefficient rows over axis 0, giving one value per column of a (per pixel of x).
plt.plot(np.sum(a[:50]/fitter.scaling_factor.T,axis=0))
# plt.plot(x[0])  # optional overlay of the first spectrum for reference


### Without scaling

In [None]:
# Same per-pixel sum over the first (up to) 50 rows of a, without the scaling factor.
plt.plot(np.sum(a[:50],axis=0))

In [None]:
# Spectra for the same stars as y: the first n_data rows of x that pass combined_cut.
# Assumes rows of the spectra array are aligned with allStar rows — TODO confirm.
x_fe = x[:n_data][combined_cut[:n_data]]

In [None]:
# Fe_H column of y.centered(), kept 2-D as (selected stars, 1). y's columns follow
# considered_parameters, so the column is looked up by name rather than hard-coded.
fe_h_column = considered_parameters.index("Fe_H")
y_fe = y.centered()[:, fe_h_column:fe_h_column + 1]

In [None]:
# Shape of the first n_data spectra, before applying combined_cut.
x[:n_data].shape

In [None]:
# Length of the outlier mask over the full allStar table (not truncated to n_data).
combined_cut.shape

In [None]:
# Shape of y_fe: (selected stars, 1).
y_fe.shape

In [None]:
# Shape of x_fe; its row count should match y_fe's.
x_fe.shape

In [None]:
# Goal: find weights w_fe such that np.dot(x_fe, w_fe) ≈ y_fe, a least-squares linear map from
# spectra to Fe_H. The weights are solved in the following cell.

In [None]:
# Least-squares weights mapping spectra to Fe_H (x_fe @ w_fe ≈ y_fe); lstsq solves this
# directly instead of materialising the full pseudo-inverse of x_fe.
w_fe = np.linalg.lstsq(x_fe, y_fe, rcond=None)[0]

In [None]:
# Shape of pinv(x_fe) is x_fe's shape transposed. Read it off directly instead of computing an
# SVD-based pseudo-inverse of the full spectra matrix just to inspect its shape.
x_fe.shape[::-1]

In [None]:
# Disabled alternative, kept for reference: build z only from the rows passing combined_cut.
#z = vector.Vector(z_pca[combined_cut[:n_data]],order=1,interaction_only=False)

In [None]:
# NOTE(review): repeats an earlier x_fe.shape inspection cell; can be removed.
x_fe.shape

### Compare the representation with metallicity ([Fe/H])

In [None]:
# Shape of the second-to-last column of v, kept as a 2-D slice (rows, 1).
v[:,-2:-1].shape

In [None]:
# One line per column of v, plotted against the row index.
plt.plot(v)

In [None]:
# Display the first n_data rows of v().
# NOTE(review): elsewhere v is used like an array (v[:,20], v.shape[0], np.concatenate); if
# fitter.transform returns a plain ndarray, calling v() raises TypeError — confirm v's type.
v()[:n_data]

In [None]:
# 2-D histogram of the last column of v against the Fe_H column of y.centered() for the same
# stars (first n_data rows passing combined_cut); both arrays have one entry per selected star.
plt.hist2d(
    v[:n_data][combined_cut[:n_data]][:, -1],
    y.centered()[:, considered_parameters.index("Fe_H")],
    bins=50,
)
plt.xlabel("v[:, -1]")
plt.ylabel("y.centered() Fe_H");

In [None]:
# Number of selected stars in the last column of v; should equal the number of rows of y.
v[:n_data][combined_cut[:n_data]][:,-1].shape

In [None]:
# Shape of the Fe_H column of y.centered() (index 2 in considered_parameters), kept 2-D.
y.centered()[:,2:3].shape

In [None]:
# Scatter column 19 of v against the Fe_H column of y.centered() for the same stars
# (first n_data rows passing combined_cut). The Fe_H column is looked up by name.
plt.scatter(
    v[:n_data][combined_cut[:n_data]][:, 19],
    y.centered()[:, considered_parameters.index("Fe_H")],
    s=0.5,
)
plt.xlabel("v[:, 19]")
plt.ylabel("y.centered() Fe_H");