In [None]:
import itertools

import corner
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import umap

In [None]:
from hubersed.style import *
from hubersed.paths import PATHS

# Input-data and results locations, looked up in the project's PATHS mapping.
# Both are combined with file names via the `/` operator further down.
DATA_PATH, RESULTS_PATH = PATHS['DATA'], PATHS['RESULTS']

In [None]:
# All tensors in this notebook are placed on the CPU; `device` is passed to
# torch.load(map_location=...) and Tensor.to(...) in the cells below.
device = torch.device('cpu')

In [None]:
# Read the saved latent-space dictionary from disk, remapping stored tensors
# onto `device` while loading.
latent_space_dict = torch.load(
    DATA_PATH / 'desi_noise_spender_10latent_space.pt',
    map_location=device,
)

In [None]:
# Pull out the entries used for the UMAP figure.
latents = latent_space_dict['latents']  # input to the UMAP fit/transform
z = latent_space_dict['zs']             # colors the left panel (presumably redshift; confirm)
A = latent_space_dict['A'].squeeze()    # colors the right panel; singleton dimensions removed

In [None]:
# Fit UMAP on the latents. The fixed random_state keeps the embedding
# reproducible between runs.
reducer = umap.UMAP(
    n_neighbors=5,
    random_state=14,
).fit(latents)

In [None]:
# Map the same latents that were used for fitting into the UMAP embedding space;
# columns 0 and 1 of the result are the plot coordinates in the next cell.
embeddings = reducer.transform(latents)

In [None]:
# Two-panel UMAP figure: the first two embedding coordinates, colored by `z`
# (left, linear scale) and by `A` (right, logarithmic scale), saved as a PDF.
fig, ax = plt.subplots(1, 2, figsize=(6., 3.5), dpi=300)
ax = ax.ravel()

# both ax should be equal aspect
for a in ax:
    a.set_aspect('equal', 'box')

# Settings shared by both panels, defined once so the panels cannot drift apart.
scatter_kwargs = dict(s=0.005, alpha=0.7, rasterized=True, marker='.')
colorbar_kwargs = dict(
    orientation='horizontal',
    location='top',  # colorbar on top
    pad=0.02,        # distance between plot and colorbar
    fraction=0.05,   # thickness of colorbar
)

# Left panel: no explicit norm, so the color limits follow the data range of `z`.
scatter = ax[0].scatter(
    embeddings[:, 0], embeddings[:, 1],
    c=z, cmap='viridis',
    **scatter_kwargs,
)
cbar = fig.colorbar(scatter, ax=ax[0], **colorbar_kwargs)
cbar.set_label('z')

# Right panel: logarithmic color scale fixed to [1, 1000]. The colormap is set
# explicitly so this panel matches the left one regardless of the rcParams
# default (the star-import from hubersed.style may have changed it).
norm = mpl.colors.LogNorm(vmin=1, vmax=1000)
scatter = ax[1].scatter(
    embeddings[:, 0], embeddings[:, 1],
    c=A, cmap='viridis', norm=norm,
    **scatter_kwargs,
)
cbar = fig.colorbar(scatter, ax=ax[1], **colorbar_kwargs)
cbar.set_label('A')

# no ticks
for a in ax:
    a.set_xticks([])
    a.set_yticks([])

fig.tight_layout()

fig.savefig(RESULTS_PATH / 'desi_noise_spender_10latent_space_umap.pdf', dpi=300)


In [None]:
# Repeat the latent-space loading for the spectra: two saved dictionaries, each
# with a 'latents' entry. `s_l` is labelled 'DESI' and `p_l` 'Prospector' in the
# corner plot below.
#
# map_location=device matches the earlier torch.load call and lets files that
# were saved from a GPU session load on a CPU-only machine. Without it,
# deserialization fails before the `.to(device)` call is ever reached.
spender_spec = torch.load(DATA_PATH / 'spender_spec_6latent', map_location=device)
s_l = spender_spec['latents'].to(device)

prospector_spec = torch.load(DATA_PATH / 'prospector_noise_spec_6latent', map_location=device)
p_l = prospector_spec['latents'].to(device)

In [None]:
# Corner plot overlaying the two sets of spectral latents in one figure:
# `s_l` in 'C0' (legend entry 'DESI') and `p_l` in grey (legend entry 'Prospector').
# NOTE(review): the original comment described a "non-overlapping" DESI subset
# drawn in orange, but every row of `s_l` is plotted and the color is 'C0';
# confirm which was intended.
fig = plt.figure(figsize=(10, 10))
labels = [f'Latent {k+1}' for k in range(s_l.shape[1])]

# Options common to both overlays: raw data points only (no density shading,
# no contours), drawn into the same figure with the same axis labels.
corner_kwargs = dict(
    labels=labels,
    plot_datapoints=True,
    plot_density=False,
    plot_contours=False,
    corner_mask=False,
    fig=fig,
    label_kwargs={"fontsize": 16},
)

# Draw order matters: the legend below assigns its labels to artists in the
# order they were added.
for samples, color in ((s_l, 'C0'), (p_l, 'grey')):
    corner.corner(samples.detach().cpu().numpy(), color=color, **corner_kwargs)

for ax in fig.get_axes():
    ax.set_xticks([])
    ax.set_yticks([])

# Attach the legend to the last axes of the grid explicitly instead of relying
# on pyplot's implicit "current axes"; bbox_to_anchor is in that axes' coordinates.
legend_ax = fig.get_axes()[-1]
legend_ax.legend(
    ['DESI', 'Prospector'],
    loc='upper center',
    markerscale=80,
    fontsize=20,
    ncol=1,
    bbox_to_anchor=(0.1, 3)
)

fig.tight_layout()
fig.savefig(RESULTS_PATH / 'desi_prospector_latent_space_corner.pdf', dpi=300)