## Identifiability check

We check whether the abundances are recoverable from the latent representation up to a linear transformation (the answer is: mostly yes, though not perfectly).

## Setup

In [None]:
import pandas as pd
import sqlite3
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy import spatial
import sys
import pickle

from tagging.src.datasets import ApogeeDataset
from tagging.src.networks import ConditioningAutoencoder,Embedding_Decoder,Feedforward
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

"""
sys.path.insert(0,'/share/splinter/ddm/taggingProject/taggingClean/')
from src.datasets import ApogeeDataset
from src.networks import ConditioningAutoencoder,Embedding_Decoder,Feedforward
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
"""


In [None]:
# Training spectra (the file name suggests these are noise-free).
data = pd.read_pickle(
    "/share/splinter/ddm/taggingProject/taggingRepo/data/processed/spectra_noiseless.pd"
)
# Second training set (the file name suggests a signal-to-noise ratio of 100).
noisy_data = pd.read_pickle(
    "/share/splinter/ddm/taggingProject/taggingClean/data/final/train/spectra_SN_100.pd"
)
# Held-out spectra used to check the fitted linear map further down.
val_data = pd.read_pickle(
    "/share/splinter/ddm/taggingProject/taggingRepo/data/processed/spectra_noiseless_val.pd"
)


In [None]:
# n_bins is forwarded to ApogeeDataset; n_batch is the DataLoader batch size.
n_bins, n_batch = 7751, 100

In [None]:
# Wrap each dataframe in the project's dataset class (same construction
# order as before: training, noisy, validation).
dataset, noisy_dataset, val_dataset = (
    ApogeeDataset(frame, n_bins) for frame in (data, noisy_data, val_data)
)

# Batches are served in dataset order (no shuffling) and a final
# incomplete batch is dropped.
loader = DataLoader(
    dataset=dataset,
    batch_size=n_batch,
    shuffle=False,
    drop_last=True,
)

We load the model we want to analyse.

In [None]:
# Checkpoint analysed in this notebook.  Alternative checkpoints that were
# used previously (swap the path below to analyse them instead):
#   ../../outputs/results_fader/run5/adN7214I1600
#   /share/splinter/ddm/taggingProject/taggingClean/models/wasDis/runs/0/wganI5400
model_file = (
    "/share/splinter/ddm/taggingProject/taggingRepo/outputs/results_fader/expandedLatent/adN7214I1600"
)

# The object stored at that path is unpickled as-is; the rest of the
# notebook calls it as ``conditioning_autoencoder(x, u)``.
conditioning_autoencoder = torch.load(
    model_file
)

In [None]:
def get_z(idx,dataset):
    """
    Run one dataset sample through ``conditioning_autoencoder`` and return
    the second element of its output (used as the latent ``z`` in this notebook).

    idx: int
        index of the sample inside ``dataset``
    dataset:
        indexable dataset; element 0 of a sample is fed as the first model
        input and the first two entries of element 1 as the second input
    """
    # A leading batch dimension of size 1 is added to both inputs.
    model_input = dataset[idx][0].unsqueeze(0)
    conditioning = dataset[idx][1][0:2].unsqueeze(0)
    outputs = conditioning_autoencoder(model_input, conditioning)
    _, latent = outputs
    return latent

## Extracting the abundances and latents

In [None]:
# Peek at the first row's parameter vector from index 2 onward; these are the
# entries that get appended to the abundance columns below.
data.params.values[0][2:]

In [None]:
# Collect the per-row abundance vectors into arrays, one row per dataframe row.
abundances_array = np.array(list(data.abundances.values))
abundances_val_array = np.array(list(val_data.abundances.values))

In [None]:
# Entries of each ``params`` vector from index 2 onward are treated as extra
# regression targets and appended as additional columns.
# NOTE: this cell is not idempotent -- re-running it appends the columns again.
params_array = np.array([entry[2:] for entry in data.params.values])
params_val_array = np.array([entry[2:] for entry in val_data.params.values])

abundances_array = np.concatenate([abundances_array, params_array], axis=1)
abundances_val_array = np.concatenate([abundances_val_array, params_val_array], axis=1)

In [None]:
# Latents of the first 2000 training / validation samples; squeeze removes
# the size-1 batch dimension that get_z introduces.
z_array = np.squeeze(
    np.array([get_z(sample, dataset).detach().cpu().numpy() for sample in range(2000)])
)
z_val_array = np.squeeze(
    np.array([get_z(sample, val_dataset).detach().cpu().numpy() for sample in range(2000)])
)

## Calculation

We recenter (make mean=0) both the latents ```z_array``` and the abundances ```abundances_array```.

In [None]:
# Centre with the *training* statistics and transpose so that samples run
# along the second axis.  The validation latents are shifted by the training
# mean as well, so the map fitted on the training set can be applied to them.
z_calibrated = np.transpose(z_array - z_array.mean(axis=0))
z_val_calibrated = np.transpose(z_val_array - z_array.mean(axis=0))
# The mean is taken over every row, then only the first 2000 rows are kept
# to match the number of extracted latents.
abundances_calibrated = np.transpose(
    (abundances_array - abundances_array.mean(axis=0))[:2000]
)

We learn a matrix corresponding to a linear transformation between the two spaces.

In [None]:
# Display the shape of the centred latent matrix (features first, samples second after the transpose).
z_calibrated.shape

In [None]:
# Least-squares solution of  z_calibrated.T @ W.T ~= abundances_calibrated.T,
# obtained through the pseudo-inverse and transposed back to W.
W_est = np.dot(
    np.linalg.pinv(z_calibrated.T),
    abundances_calibrated.T,
).T

In [None]:
# Display the shape of the fitted linear map.
W_est.shape

In [None]:
# Same least-squares map written directly:  W = A @ pinv(Z).
W_est = abundances_calibrated.dot(np.linalg.pinv(z_calibrated))

In [None]:
# Apply the linear map to the training latents ...
abundances_calibrated_est = W_est.dot(z_calibrated)
# ... and undo the centring by adding the per-column mean back to every sample.
abundances_est = abundances_calibrated_est + abundances_array.mean(axis=0)[:, np.newaxis]

In [None]:
# Apply the map fitted on the training set to the validation latents ...
abundances_calibrated_val_est = W_est.dot(z_val_calibrated)
# ... and add the training mean back (the validation latents were centred with
# training statistics, so the training mean is the matching offset).
abundances_val_est = abundances_calibrated_val_est + abundances_array.mean(axis=0)[:, np.newaxis]

In [None]:
# Plot labels, indexed in the same way as the rows of ``abundances_est``.
# NOTE(review): the order presumably matches the abundance columns followed by
# the three appended parameter columns -- confirm against the dataset.
elements = [
    "[N/Fe]",
    "[O/Fe]",
    "[Na/Fe]",
    "[Mg/Fe]",
    "[Al/Fe]",
    "[Si/Fe]",
    "[S/Fe]",
    "[K/Fe]",
    "[Ca/Fe]",
    "[Ti/Fe]",
    "[V/Fe]",
    "[Mn/Fe]",
    "[Ni/Fe]",
    "[P/Fe]",
    "[Cr/Fe]",
    "[Co/Fe]",
    "[Rb/Fe]",
    "[Fe/H]",
    # Was r"[$\alpha$\Fe]": the stray backslash showed up literally in plot titles.
    r"[$\alpha$/Fe]",
    "[C/Fe]",
]

We can now plot the abundances estimated from the latent, ```abundances_est```, and compare them to the true abundances.

In [None]:
# One scatter plot per target: linear estimate (x) against the true value (y)
# for the first 2000 training samples.
for idx in range(20):
    plt.scatter(abundances_est[idx, :2000], abundances_array[:2000, idx])
    plt.title(f"element:{elements[idx]}")
    plt.ylabel("true")
    plt.xlabel("estimated")
    plt.show()

In [None]:
# Same comparison on the validation set: linear estimate (x) against the true
# value (y) for the first 2000 validation samples.
for idx in range(20):
    plt.scatter(abundances_val_est[idx, :2000], abundances_val_array[:2000, idx])
    plt.title(f"element:{elements[idx]}")
    plt.ylabel("true")
    plt.xlabel("estimated")
    plt.show()

## Estimating the actual information content

In [None]:
def train_network(loader,v_index,train_u=False):
    """
    Train a feedforward regressor that predicts a single target column from
    the latent produced by ``conditioning_autoencoder``.

    Only the feedforward network is optimised; the autoencoder is used purely
    as a fixed feature extractor.

    loader: 
        pytorch dataset loader yielding ``(x, u, v, idx)`` batches
    v_index: int
        index of the target column to regress
    train_u: bool
        whether the target column is taken from ``u`` (True) or from ``v`` (False)

    returns:
        the trained feedforward network
    """
    n_z = z_calibrated.shape[0]
    feedforward = Feedforward([n_z,512,256,128,1],activation=nn.SELU()).to(device)
    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(feedforward.parameters(),lr=0.0001)
    for i in range(6):
        for j,(x,u,v,idx) in enumerate(loader):
            optimizer.zero_grad()
            # No gradient is needed through the autoencoder, so skip building its graph.
            with torch.no_grad():
                _,z = conditioning_autoencoder(x,u[:,0:2])
            # The regressor lives on `device`; move its input and target there too
            # (a no-op when they already are) so the devices can never mismatch.
            source = u if train_u else v
            target = source[:,v_index:v_index+1].to(device)
            pred = feedforward(z.detach().to(device))
            err = loss(pred,target)
            err.backward()
            optimizer.step()
            if j%100==0:
                print(f"epoch:{i},err:{err}")
    return feedforward

In [None]:
# One regressor per column 0..16 of the ``v`` targets.
networks = [train_network(loader, column) for column in range(17)]

In [None]:
# Three more regressors, for columns 2..4 of the ``u`` targets, appended after
# the ``v`` regressors.  NOTE: re-running this cell appends them again.
networks.extend(
    train_network(loader, column, train_u=True) for column in range(2, 5)
)

In [None]:
def get_v(z,network):
    """
    Evaluate ``network`` on ``z`` and return the prediction as a numpy array.

    z:
        array-like of latents; converted to a tensor and moved to ``device``
    network:
        callable applied to the latent tensor
    """
    prediction = network(torch.tensor(z).to(device))
    return prediction.detach().cpu().numpy()

In [None]:
# Compare, for every target, the network prediction ("optimal") with the
# linear estimate ("linear") against the true training values.
for idx in range(20):
    true_column = abundances_array[:, idx]
    col_min, col_max = min(true_column), max(true_column)

    v_net_array = get_v(z_array, networks[idx])
    # Map the network output linearly from [-1, 1] onto [col_min, col_max].
    # NOTE(review): assumes the training targets were scaled to [-1, 1] -- confirm.
    v_net_array = (col_max - col_min) * (v_net_array + 1) / 2 + col_min

    plt.title(f"element:{elements[idx]}")
    plt.scatter(v_net_array, true_column[0:2000], s=0.5, alpha=0.5, label="optimal")
    plt.scatter(abundances_est[idx, 0:2000], true_column[0:2000], s=0.5, alpha=0.5, label="linear")
    plt.legend()
    plt.ylabel("true")
    plt.xlabel("estimated")
    plt.show()

In [None]:
import matplotlib.gridspec as gridspec

def draw_figure(ax,idx):
    """
    Draw one comparison panel on ``ax``: true values on the x axis against the
    network ("non-linear") and linear estimates on the y axis.

    ax:
        matplotlib axes to draw on
    idx: int
        index of the target; selects the label, the trained network and the
        column of ``abundances_array``
    """
    ax.set_title(f"{elements[idx]}")
    v_net_array = get_v(z_array,networks[idx])
    # Map the network output linearly from [-1, 1] onto the column's [min, max].
    # NOTE(review): assumes the training targets were scaled to [-1, 1] -- confirm.
    v_net_array = (max(abundances_array[:,idx])-min(abundances_array[:,idx]))*(v_net_array+1)/2+min(abundances_array[:,idx])
    ax.scatter(abundances_array.T[idx,0:2000],v_net_array,s=0.5,alpha=0.5,label="non-linear")
    ax.scatter(abundances_array.T[idx,0:2000],abundances_est[idx,0:2000],s=0.5,alpha=0.5,label="linear")
    lgnd = ax.legend()
    # `Legend.legendHandles` was renamed to `legend_handles` in Matplotlib 3.7
    # and the old name was removed later; support both.
    handles = getattr(lgnd, "legend_handles", None)
    if handles is None:
        handles = lgnd.legendHandles
    # The scatter markers are tiny (s=0.5); enlarge them in the legend only,
    # through the public setter rather than the private `_sizes` attribute.
    for handle in handles[:2]:
        handle.set_sizes([30])
    ax.set_ylabel("estimated (dex)")
    ax.set_xlabel("true (dex)")
    
def make_canvas():
    """
    Build the summary figure: a 5x4 grid of comparison panels.

    The first four rows hold targets 0..15, with the panel at (row, col)
    showing target ``row + 4*col``; the last row holds targets 17 and 19 in
    its first two columns.

    returns:
        the matplotlib figure
    """
    canvas = plt.figure(constrained_layout=True, figsize=[14, 17.5])
    grid = gridspec.GridSpec(ncols=4, nrows=5, figure=canvas)

    for row in range(4):
        for col in range(4):
            panel = canvas.add_subplot(grid[row, col])
            draw_figure(panel, row + col * 4)

    # Bottom row: only two of the remaining targets are shown.
    for col, offset in enumerate([0, 2]):
        panel = canvas.add_subplot(grid[4, col])
        draw_figure(panel, 17 + offset)

    return canvas
            
# Build the summary figure and save it as a PDF in the current working directory.
# NOTE(review): the file name is spelled "interpretabilty"; left unchanged in
# case anything downstream refers to it.
fig = make_canvas()
fig.savefig("latent_interpretabilty.pdf",format="pdf")

In [None]:
! pwd

In [None]:
? plt.save

In [None]:
nd

In [None]:
# Sanity check: 20 regressors are expected (17 for the ``v`` targets plus 3 for
# the ``u`` targets) when each training cell has been run exactly once.
len(networks)