# Abundance Estimation

This notebook covers obtaining abundances directly from the disentangled neural networks.

## Setup

In [None]:
import torch.nn as nn
import torch
import torch.nn.functional as F


from tagging.paths import path_dataset
from tagging.src.datasets import ApogeeDataset
from tagging.src.networks import ConditioningAutoencoder,Embedding_Decoder,Feedforward,ParallelDecoder
import pandas as pd
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [None]:
# --- Hyperparameters -------------------------------------------------------
n_bins = 1000        # passed to ApogeeDataset, the encoder input layer and the decoder
n_batch = 64         # mini-batch size used by the DataLoader
n_z = 20             # width of the encoder output (latent vector z)
n_cat = 30           # number of classes each conditioned parameter is binned into
n_hidden = 10        # forwarded to ParallelDecoder as n_hidden
lr = 0.0001          # learning rate shared by all Adam optimizers
n_conditioned = 2    # number of conditioning parameters appended to the input / latent
loss_ratio = 1e-3    # weight of the adversarial terms in the total loss (same value as 10e-4)

In [None]:
# Load the pickled table and wrap its first 50000 rows in an ApogeeDataset.
data = pd.read_pickle(path_dataset)
dataset = ApogeeDataset(data[:50000], n_bins)

# Shuffled mini-batches; the trailing incomplete batch is dropped so every
# batch has exactly n_batch rows.
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=n_batch,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Encoder: takes n_bins + n_conditioned inputs and outputs an n_z-wide latent.
encoder = Feedforward(
    [n_bins + n_conditioned, 512, 128, 32, n_z],
    activation=nn.SELU(),
).to(device)

# Decoder: receives the latent together with the conditioning parameters.
decoder = ParallelDecoder(
    n_bins=n_bins,
    n_hidden=n_hidden,
    n_latent=n_z + n_conditioned,
    activation=nn.SELU(),
).to(device)

autoencoder = ConditioningAutoencoder(encoder, decoder, n_bins=n_bins).to(device)

# Reconstruction objective and the optimizer driving the autoencoder weights.
loss = nn.MSELoss()
optimizer_autoencoder = torch.optim.Adam(autoencoder.parameters(), lr=lr)

In [None]:
# Classification loss for the binned conditioning parameters.
loss2 = nn.CrossEntropyLoss()

# Two classifiers, one per conditioned parameter. Each takes n_z + 1 inputs
# (the training loop feeds the latent plus the *other* parameter) and emits
# n_cat logits.
pred_u0_given_v = Feedforward(
    [n_z + 1, 512, 256, n_cat],
    activation=nn.SELU(),
).to(device)
pred_u1_given_v = Feedforward(
    [n_z + 1, 512, 256, n_cat],
    activation=nn.SELU(),
).to(device)

# Each classifier is updated by its own Adam optimizer.
optimizer_u0 = torch.optim.Adam(pred_u0_given_v.parameters(), lr=lr)
optimizer_u1 = torch.optim.Adam(pred_u1_given_v.parameters(), lr=lr)


In [None]:
# Constant (n_batch, 2) tensors of zeros and ones on the training device.
zeros = torch.zeros((n_batch, 2), device=device)
ones = torch.ones((n_batch, 2), device=device)

# Inverse of the standard deviation of the base Gaussian noise.
noise = 100

# One fixed noise realisation per datapoint (50000 rows, n_bins columns),
# drawn once with std 1/noise and scaled by 4, so the same noise is reused
# every time. This was found to work better, but is not fully understood.
noise_matrix = torch.empty(50000, n_bins).normal_(mean=0, std=1 / noise).to(device) * 4


In [None]:

# Adversarial training. Each iteration performs two updates:
#   1. the autoencoder minimises the reconstruction error while *maximising*
#      the classifiers' cross-entropy (the negative loss_ratio terms), pushing
#      information about the conditioned parameters u out of the latent z;
#   2. the two classifiers minimise their cross-entropy on the detached latent.
# The classifier update uses a fresh forward pass on z.detach(). Calling
# backward() through the retained autoencoder graph after
# optimizer_autoencoder.step() has modified the encoder weights in place is
# rejected by autograd (RuntimeError on PyTorch >= 1.5), and the encoder
# gradients it would produce are discarded anyway.
for i in range(100):
    for j, (x, u, v, idx) in enumerate(loader):

        # Bin every conditioned parameter into n_cat integer classes.
        # Assumes u lies in [-1, 1] -- TODO confirm against ApogeeDataset.
        # A value mapping exactly to n_cat is folded into the last class.
        u_cat = ((u + 1) * n_cat / 2).long()
        u_cat[u_cat == n_cat] = n_cat - 1

        # ---- 1. autoencoder update ---------------------------------------
        optimizer_autoencoder.zero_grad()

        x_pred, z = autoencoder(x, u[:, 0:2].detach())
        err_pred = loss(x_pred, x)

        # Each classifier sees the latent plus the *other* parameter.
        z0 = torch.cat((z, u[:, 1:2]), 1)
        z1 = torch.cat((z, u[:, 0:1]), 1)
        adv_u0 = loss2(pred_u0_given_v(z0), u_cat[:, 0])
        adv_u1 = loss2(pred_u1_given_v(z1), u_cat[:, 1])

        # Aggregated loss: reconstruction minus the weighted classifier losses.
        err_tot = err_pred - loss_ratio * adv_u0 - loss_ratio * adv_u1
        err_tot.backward()
        optimizer_autoencoder.step()

        # ---- 2. classifier updates ---------------------------------------
        # Detach so these losses only reach the classifier parameters.
        z_fixed = z.detach()

        # zero_grad() also clears what err_tot.backward() left on the classifiers.
        optimizer_u0.zero_grad()
        u0_pred = pred_u0_given_v(torch.cat((z_fixed, u[:, 1:2]), 1))
        err_u0 = loss2(u0_pred, u_cat[:, 0])
        err_u0.backward()
        optimizer_u0.step()

        optimizer_u1.zero_grad()
        u1_pred = pred_u1_given_v(torch.cat((z_fixed, u[:, 0:1]), 1))
        err_u1 = loss2(u1_pred, u_cat[:, 1])
        err_u1.backward()
        optimizer_u1.step()

        if j % 10 == 0:
            # .item() logs plain numbers instead of tensor reprs.
            print("epoch:{},tot:{},err:{},err_u0:{},er_u1:{}".format(
                i, err_tot.item(), err_pred.item(), err_u0.item(), err_u1.item()))

# Persist the trained autoencoder weights once training has finished.
torch.save(autoencoder.state_dict(), "conditional_parallel_decoder.p")

In [None]:
# Write the autoencoder's weights (state_dict only) to disk.
torch.save(obj=autoencoder.state_dict(), f="conditional_parallel_decoder.p")

In [None]:
# Class probabilities predicted for the first sample of the last batch.
u0_pred.softmax(dim=1)[0]

In [None]:
# Save the weights as a state_dict rather than pickling the whole module:
# this path is later passed to autoencoder.load_state_dict(), which requires a
# state_dict, and a pickled module would overwrite the usable checkpoint.
torch.save(autoencoder.state_dict(), "conditional_parallel_decoder.p")

In [None]:
# Restore the autoencoder weights from disk. map_location lets a checkpoint
# saved on a CUDA device load on whatever `device` resolved to (e.g. CPU-only).
checkpoint = torch.load("conditional_parallel_decoder.p", map_location=device)
# The file may hold a fully pickled module instead of a state_dict (the
# notebook has saved both forms to this path); extract the weights then.
if isinstance(checkpoint, nn.Module):
    checkpoint = checkpoint.state_dict()
autoencoder.load_state_dict(checkpoint)