**notebook objective**:
* Experiment with how to exclude zeroed-out features (spectral bins whose value is exactly zero) in APOGEE spectra from the training loss


In [1]:
import apogee.tools.read as apread
import matplotlib.pyplot as plt
import apogee.tools.path as apogee_path
from apogee.tools import bitmask

import random
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
import torch.nn as nn

from torchcontrib.optim import SWA

from apoNN.src.datasets import ApogeeDataset

from tagging.src.networks import ConditioningAutoencoder,Embedding_Decoder,Feedforward,ParallelDecoder,Autoencoder

# Prefer the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Switch the apogee package to release 16 before any catalogue is read
# (presumably APOGEE DR16 -- confirm against apogee.tools.path).
apogee_path.change_dr(16)

## Create dataset

In [2]:
# Read the allStar catalogue through the apogee reader.
allStar = apread.allStar(
    rmcommissioning=True,
    main=False,
    ak=True,
    akvers='targ',
    adddist=False,
)

# Effective-temperature window: 4000 < Teff < 5000.
lower_temp_cut = allStar["Teff"] > 4000
upper_temp_cut = allStar["Teff"] < 5000

# Surface-gravity window: 1.5 < logg < 3.
lower_g_cut = allStar["logg"] > 1.5
upper_g_cut = allStar["logg"] < 3

# Signal-to-noise window: 100 < SNR < 500.
snr_cut = allStar["SNR"] > 100
snr_highcut = allStar["SNR"] < 500

# Abundance sanity cuts: keep rows whose value is above -5
# (presumably removes sentinel / failed-fit entries -- TODO confirm).
feh_outliercut = allStar["Fe_H"] > -5
o_outliercut = allStar["O_FE"] > -5
c_outliercut = allStar["C_FE"] > -5
na_outliercut = allStar["Na_FE"] > -5
mg_outliercut = allStar["Mg_FE"] > -5
si_outliercut = allStar["Si_FE"] > -5

# A star is kept only if it passes every individual cut.
combined_cut = (
    lower_g_cut
    & upper_g_cut
    & lower_temp_cut
    & upper_temp_cut
    & snr_cut
    & snr_highcut
    & feh_outliercut
    & o_outliercut
    & c_outliercut
    & na_outliercut
    & mg_outliercut
    & si_outliercut
)
cut_allStar = allStar[combined_cut]



### Parameters


In [3]:
# Mini-batch size handed to the DataLoader.
n_batch = 128
# Width of the latent code (last encoder layer / first decoder layer).
n_z = 80
# Length of one input vector (first encoder layer / last decoder layer).
n_bins = 8575
# Learning rate passed to Adam.
lr = 1e-4
# Number of rows of the cut catalogue used to build the dataset.
n_datapoints = 10_000

### Training

In [None]:
# Earlier variants of this cell requested outputs ["apstar","physical","idx"] and
# ["aspcap","physical","idx"]; the current one additionally asks for "mask".
requested_outputs = ["aspcap","mask","physical","idx"]
dataset = ApogeeDataset(cut_allStar[:n_datapoints], outputs=requested_outputs)

# Fixed iteration order; a trailing batch smaller than n_batch is dropped.
loader = DataLoader(dataset, batch_size=n_batch, shuffle=False, drop_last=True)

# Quick look at the first output of the first item
# (presumably the "aspcap" spectrum, following the order of `outputs` -- confirm).
plt.plot(dataset[0][0])

### Training of the neural network

Build a fresh encoder/decoder pair and the optimizer used to train the autoencoder from scratch.

In [None]:
# Layer widths: the encoder narrows n_bins -> n_z, the decoder widens n_z -> n_bins.
encoder_widths = [n_bins, 2048, 512, n_z]
decoder_widths = [n_z, 512, 2048, 8192, n_bins]

encoder = Feedforward(encoder_widths, activation=nn.SELU()).to(device)
decoder = Feedforward(decoder_widths, activation=nn.SELU()).to(device)

# Combine both halves into one module and optimise all of its parameters with Adam.
autoencoder = Autoencoder(encoder, decoder, n_bins=n_bins).to(device)
optimizer_autoencoder = torch.optim.Adam(autoencoder.parameters(), lr=lr)

We use stochastic weight averaging (SWA) in order to find better minima (at least in theory).

In [None]:
# To start from a previously saved model instead of the freshly built one, uncomment:
#autoencoder = torch.load("/share/splinter/ddm/taggingProject/apogeeFactory/outputs/pretrained/ae1")

# Base optimizer; it is re-created here so it tracks whichever `autoencoder` is current.
optimizer_autoencoder = torch.optim.Adam(params=autoencoder.parameters(), lr=lr)

# Wrap the base optimizer in torchcontrib's SWA.
# Per the torchcontrib docs: averaging starts after `swa_start` optimizer steps,
# the running average is refreshed every `swa_freq` steps, and `swa_lr` is the
# learning rate applied once averaging has started.
opt_swa = SWA(
    optimizer_autoencoder,
    swa_start=10,
    swa_freq=5,
    swa_lr=1e-4,
)

In [None]:
# Base criterion: L1 (absolute) error with PyTorch's default settings.
loss = torch.nn.L1Loss()

In [None]:
def generate_loss_with_masking(loss):
    """Wrap ``loss`` so that entries whose target is exactly zero are ignored.

    Args:
        loss: callable ``loss(prediction, target)`` used as the underlying criterion.

    Returns:
        A function ``loss_with_masking(x_pred, x_true)`` that keeps only the
        positions where ``x_true != 0``, selects those same positions from
        ``x_pred``, and returns ``loss`` evaluated on the two selections.
    """
    def loss_with_masking(x_pred, x_true):
        # Boolean selector: True wherever the target carries a non-zero value.
        keep = x_true != 0
        kept_pred = x_pred[keep]
        kept_true = x_true[keep]
        return loss(kept_pred, kept_true)

    return loss_with_masking
    

In [None]:
# Training criterion: the L1 loss evaluated only where the target value is non-zero.
masked_loss = generate_loss_with_masking(loss)

In [None]:
# Train the autoencoder to reconstruct its input, scoring only non-zero target entries.
for i in range(20000):
    for j,(x,mask,u,idx) in enumerate(loader):
        # Move the batch to the training device once and reuse it as both
        # the network input and the reconstruction target.
        x_device = x.to(device)

        opt_swa.zero_grad()
        x_pred,z = autoencoder(x_device)

        # Reconstruction error restricted to non-zero target entries.
        err_pred = masked_loss(x_pred,x_device)

        # The total loss currently consists of the reconstruction term alone.
        err_tot = err_pred
        err_tot.backward()
        opt_swa.step()
        if j%100==0:
            print(f"err:{err_tot},err_pred:{err_pred}")

# torchcontrib's SWA keeps the running weight average in optimizer state; the
# model itself still holds the plain Adam weights until they are swapped in.
# Without this call the averaging has no effect on `autoencoder`.
# NOTE: if the loop is interrupted by hand, run `opt_swa.swap_swa_sgd()` manually.
# Re-running this cell after the swap resumes training from the averaged weights.
opt_swa.swap_swa_sgd()


## Latent visualization

In [4]:
def get_z(idx,dataset):
    """Return the latent code the autoencoder produces for one dataset item.

    Args:
        idx: index of the item inside ``dataset``.
        dataset: indexable collection whose items expose a tensor at position 0.

    Returns:
        The second value returned by the module-level ``autoencoder`` when it
        is called on that tensor, moved to ``device`` and given a leading
        batch dimension of size 1.
    """
    x = dataset[idx][0].to(device).unsqueeze(0)
    _,z = autoencoder(x)
    return z

def get_v(idx,dataset,feedforward):
    """Return ``feedforward`` applied to the latent code of one dataset item.

    Args:
        idx: index of the item inside ``dataset``.
        dataset: indexable collection whose items expose a tensor at position 0.
        feedforward: callable applied to the latent code from ``get_z``.

    Returns:
        ``feedforward(z)`` where ``z = get_z(idx, dataset)``.
    """
    # Reuse get_z so the encoding step is defined in exactly one place.
    return feedforward(get_z(idx,dataset))