In [None]:
import numpy as np
import matplotlib.pyplot as plt
from TheCannon import dataset
from TheCannon import model
from tqdm import tqdm

import pymem
from figsave import savefig as savefig

In [None]:
# Typeset every figure with LaTeX in 10 pt Computer Modern serif and render
# figures at 200 dpi. Keys contain dots, so the mapping is built from pairs.
_latex_figure_style = dict([
    ("text.usetex", True),
    ("font.family", "serif"),
    ("font.serif", ["Computer Modern"]),
    ("font.size", 10),
    ("figure.dpi", 200),
])
plt.rcParams.update(_latex_figure_style)

In [None]:
import pickle as pkl

# Load the pre-extracted dataset. Only unpickle files you trust: pickle can
# execute arbitrary code while loading.
with open("data_large.pkl", "rb") as pickle_file:
    raw_data = pkl.load(pickle_file)

# Wrap the loaded container in a NumPy object array; its components are then
# indexed positionally (data[0] ... data[4]).
data = np.array(raw_data, dtype=object)

In [None]:
# Column indices of the labels to fit: all 20 by default.
# Use np.arange(8) instead to restrict the fit to the first eight labels.
label_select = np.arange(20)

In [None]:
# Unpack the loaded components by position, as they are used below:
# 0 = wavelength grid, 1 = spectrum identifiers (matched against the
# blacklist names), 2 = flux, 3 = inverse variance, 4 = label table
# (one column per label, subset here by label_select).
wl_full = np.array(data[0])
# np.asarray so that `targets == name` in the blacklist lookup is an
# element-wise comparison even if the pickle stored a plain list; an
# existing ndarray is returned unchanged (no copy).
targets = np.asarray(data[1])
flux_full = np.array(data[2], dtype=float)
ivar_full = np.array(data[3])
labels = np.array(data[4])[:, label_select]

In [None]:
# Flag pixels whose flux is out of range: above 1.1 (presumably the flux is
# continuum-normalised, hence this ceiling) or negative. Zero both their flux
# and their inverse variance, i.e. give them no weight.
unphysical = (flux_full > 1.1) | (flux_full < 0)
flux_full[unphysical] = 0
ivar_full[unphysical] = 0

In [None]:
# Spectra excluded from the analysis, one identifier per line, written
# exactly as they appear in `targets`. The block is split on whitespace into
# a list of names, preserving this order.
blacklist = """
    HD26965A_HARPS.2006-11-09T07:00:52.905_s1d_A
    HD72374_HARPS.2009-04-01T00:41:51.307_s1d_A
    HD131183_HARPS.2011-08-04T23:58:04.312_s1d_A
    HD26965A_HARPS.2011-12-04T03:37:00.086_s1d_A
    HD10647_HARPS.2011-12-29T00:42:06.216_s1d_A
    HD26965A_HARPS.2011-12-06T04:46:32.270_s1d_A
    HD151933_HARPS.2012-08-31T01:35:01.979_s1d_A
    HD114613_HARPS.2004-02-14T06:42:56.756_s1d_A
    HD65907A_HARPS.2015-12-05T06:52:13.596_s1d_A
    HD19994_HARPS.2016-01-10T03:16:38.232_s1d_A
    HD114613_HARPS.2017-02-14T07:46:36.285_s1d_A
    HD82114_HARPS.2018-04-02T02:52:02.509_s1d_A
    HD93932_HARPS.2018-04-02T05:24:31.488_s1d_A
    HD82114_HARPS.2018-04-02T03:33:04.344_s1d_A
    HD85725_HARPS.2018-04-02T04:51:10.528_s1d_A
    HD93932_HARPS.2018-04-02T05:58:14.379_s1d_A
""".split()

In [None]:
# Row index of the first occurrence of each blacklisted spectrum in `targets`.
# Any name that is absent (e.g. after regenerating the pickle) is reported
# explicitly instead of surfacing as a bare IndexError from `[0][0]`.
_target_names = np.asarray(targets)  # element-wise == even if targets is a list
_found_idx = []
_missing = []
for bad in blacklist:
    hits = np.flatnonzero(_target_names == bad)
    if hits.size:
        _found_idx.append(hits[0])
    else:
        _missing.append(bad)

if _missing:
    raise ValueError(f"Blacklisted spectra not found in targets: {_missing}")

blacklist_idx = np.array(_found_idx, dtype=int)
print(blacklist_idx)

In [None]:
# Remove the blacklisted rows (axis 0 = one spectrum) from every per-spectrum
# array in one pass. The source tuple is built from the current values before
# any name is rebound.
targets, flux_full, ivar_full, labels = [
    np.delete(per_spectrum, blacklist_idx, axis=0)
    for per_spectrum in (targets, flux_full, ivar_full, labels)
]

In [None]:
# Number of spectra left after blacklisting, and how many of them (20%,
# rounded down) are held out for testing.
num_stars = len(targets)
num_test = int(num_stars * 0.2)

print(f"{num_stars} stars in dataset.")
print(f"{num_test} stars used in testing.")

In [None]:
# Split the pixel grid into n_ranges equal-width chunks. `ranges` holds the
# n_ranges + 1 integer pixel boundaries; range_num selects one chunk.
n_ranges = 16
ranges = (np.linspace(0, 1, n_ranges + 1) * len(wl_full)).astype(int)

range_num = 5
# Negative values would silently wrap around to another (or an empty) chunk,
# and values >= n_ranges fail with an obscure IndexError, so validate up front.
if not 0 <= range_num < n_ranges:
    raise ValueError(f"range_num must be in [0, {n_ranges - 1}], got {range_num}")
wl_range = slice(ranges[range_num], ranges[range_num + 1])

print(wl_range)

In [None]:
# Restrict the wavelength grid and the per-spectrum pixel arrays to the
# selected chunk (pixels run along axis 1 of flux/ivar).
wl = wl_full[wl_range]
flux, ivar = (pixels[:, wl_range] for pixels in (flux_full, ivar_full))

print(f"WL range: {wl[0]:.2f} to {wl[-1]:.2f}")

In [None]:
# Deterministic split without shuffling: the first num_test spectra form the
# test set and the remainder the training set.
# NOTE(review): several stars contribute multiple spectra (see the blacklist),
# so the same star may appear in both sets — confirm this is intended.
test_rows = slice(None, num_test)
train_rows = slice(num_test, num_stars)

targets_train, flux_train, ivar_train, labels_train = (
    targets[train_rows], flux[train_rows], ivar[train_rows], labels[train_rows]
)
targets_test, flux_test, ivar_test, labels_test = (
    targets[test_rows], flux[test_rows], ivar[test_rows], labels[test_rows]
)

In [None]:
# LaTeX (math-mode) names and units for all 20 label columns. Both are subset
# by label_select so they line up with the columns of `labels`.
# Raw strings keep '\log' and '\sin' literal; the previous non-raw literals
# relied on invalid escape sequences, which trigger SyntaxWarning on newer
# Pythons.
label_names = np.array((
    'T_{eff}', r'\log g', r'v \sin i', '[Fe/H]', '[Na/H]', '[Mg/H]', '[Al/H]',
    '[Si/H]', '[Ca/H]', '[V/H]', '[Mn/H]', '[Co/H]', '[O/H]', '[Ni/H]', '[C/H]',
    '[ScI/H]', '[TiI/H]', '[CrI/H]', '[YII/H]', '[S/H]',
))[label_select]
label_units = np.array(['K', 'dex', 'km/s'] + ['dex'] * 17)[label_select]
print(','.join(label_names))

In [None]:
# Assemble the Cannon dataset: the wavelength grid, then the labelled
# training set (ids, flux, ivar, labels), then the test set, which gets no
# labels.
training_set = (targets_train, flux_train, ivar_train, labels_train)
test_set = (targets_test, flux_test, ivar_test)
ds = dataset.Dataset(wl, *training_set, *test_set)
ds.set_label_names(label_names)

In [None]:
# Start sampling memory usage (0.1 is presumably the sampling interval in
# seconds — confirm against pymem.MemLogger). The logger is stopped after
# label inference.
memlogger = pymem.MemLogger(0.1)
memlogger.start()

# Cannon model with first argument 2 (presumably the polynomial order —
# confirm against TheCannon's CannonModel), with useErrors=False.
md = model.CannonModel(2, useErrors=False)
try:
    md.fit(ds)
except BaseException:
    # If training fails or is interrupted, stop the sampler instead of
    # leaving it running; re-running this cell would otherwise start another.
    memlogger.stop()
    raise

In [None]:
# Infer labels for the test spectra, then stop memory sampling. `finally`
# guarantees the logger is stopped even if inference raises or is interrupted.
try:
    md.infer_labels(ds)
finally:
    memlogger.stop()

In [None]:
# Memory trace as plotted below: column 0 = sample time (seconds, shown
# relative to the first sample), column 1 = memory in bytes (shown in GiB).
# NOTE(review): column meanings inferred from the original axis labels —
# confirm against pymem.MemLogger.get_log().
memlog = np.array(memlogger.get_log())

fig = plt.figure(figsize=(4, 3))
ax = fig.add_subplot()

ax.plot(memlog[:, 0] - memlog[0, 0], memlog[:, 1] / 2**30)

# Units are set as upright text rather than italic math symbols; the trailing
# `;` suppresses the Text repr as cell output.
ax.set_ylabel('RAM Usage (GiB)')
ax.set_xlabel('Time (s)');

In [None]:
# Peak RAM in GiB, scaled by the number of wavelength chunks. This presumably
# extrapolates to a fit over the full spectrum, assuming memory scales
# linearly with pixel count. Only the memory column (1) enters the max: the
# old np.max(memlog) also scanned the timestamp column.
n_chunks = len(ranges) - 1
print(np.max(memlog[:, 1]) / 2**30 * n_chunks)

In [None]:
# Per-label residuals (Cannon - literature) for the test spectra. Row i stays
# paired with labels_test[i]; the residual-vs-literature plot relies on this.
residuals = ds.test_label_vals - labels_test

# Robust scatter per label: half the spread between the 16th and 84th
# percentile residuals (order statistics of a sorted copy), which equals sigma
# for a Gaussian. The copy is sorted so `residuals` keeps its row order.
sorted_residuals = np.sort(residuals, axis=0)
lower_limits = sorted_residuals[int(0.16 * num_test), :]
upper_limits = sorted_residuals[int(0.84 * num_test), :]

# NOTE(review): earlier runs stored the full 16-84 width (= 2 sigma); halve
# older rows of range_results_robust.csv before comparing with new ones.
robust_stds = 0.5 * (upper_limits - lower_limits)

print(robust_stds)

In [None]:
from datetime import datetime

# Append one CSV row per run: timestamp, wavelength chunk index, then the
# robust scatter of every label. The row uses its own name so the raw dataset
# held in `data` is not overwritten (which broke re-running the unpack cell).
timestamp = datetime.now().strftime("%Y-%m-%d-%H%M%S")
row = [timestamp, range_num] + robust_stds.tolist()
with open('range_results_robust.csv', 'a') as file:
    file.write(','.join(str(value) for value in row) + '\n')

In [None]:
# Cannon-derived vs literature values for the first (up to) 12 labels on a
# 3x4 grid, each with the one-to-one line for reference.
n_panels = min(12, len(label_names))

fig = plt.figure(figsize=(10, 5))
gs = fig.add_gridspec(3, 4, hspace=0.2, wspace=0.15, left=0.08, right=0.95, bottom=0.1, top=0.95)

ax = [fig.add_subplot(gs[i // 4, i % 4]) for i in range(n_panels)]

for i, panel in enumerate(ax):
    literature = labels_test[:, i]
    lo, hi = np.min(literature), np.max(literature)
    panel.plot(literature, ds.test_label_vals[:, i], '.')
    panel.plot((lo, hi), (lo, hi))
    # Empty legend used only as a boxed panel title; passing handles=[]
    # avoids the "No artists with labels found" warning.
    panel.legend(handles=[], title=r'$%s$' % label_names[i], loc=0)

# The y-axis shows Cannon-derived labels, not residuals.
fig.supylabel(r'Cannon-Derived Labels')
fig.supxlabel(r'Literature Value')

savefig(fig, 'results')

In [None]:
# Residual (Cannon - literature) vs literature value for every selected label,
# four panels per row, each with a zero line and a +/-2 sigma band.
# Residuals are recomputed here so row i is guaranteed to pair with
# labels_test[i]; a per-column sort of the residual array would scramble it.
test_residuals = ds.test_label_vals - labels_test
n_labels = len(label_select)
n_rows = -(-n_labels // 4)  # ceiling division

fig = plt.figure()
gs = fig.add_gridspec(n_rows, 4, hspace=0.15, wspace=0.2, left=0.06, right=0.95, bottom=0.05, top=0.95)

ax = [fig.add_subplot(gs[i // 4, i % 4]) for i in range(n_labels)]

for i, panel in enumerate(ax):
    literature = labels_test[:, i]
    resid = test_residuals[:, i]
    sigma = np.std(resid)

    panel.plot(literature, resid, '.')
    panel.axhline(0, color='C1')
    panel.fill_between(
        (np.min(literature), np.max(literature)),
        [-2 * sigma] * 2,
        [2 * sigma] * 2,
        color='orange',
        alpha=0.3
    )
    # Empty legend used only as a boxed panel title (handles=[] avoids the
    # "No artists with labels found" warning).
    panel.legend(handles=[], title=r'$%s$ ($\sigma=%.2f$)' % (label_names[i], sigma), loc=3)

fig.supylabel(r'Residuals')
fig.supxlabel(r'Literature Value')

fig.set_figwidth(15)
fig.set_figheight(12)

SAVE_RESIDUALS_FIG = False  # set True to write this figure to disk
if SAVE_RESIDUALS_FIG:
    savefig(fig, 'residuals_range-%d' % range_num)

In [None]:
# Persist the train/test identifiers, literature labels and the
# Cannon-inferred test labels for later analysis.
with open('data_large_labels.pkl', 'wb') as out_file:
    payload = (targets_train, targets_test, labels_train, labels_test, ds.test_label_vals)
    pkl.dump(payload, out_file)

In [None]:
# Cannon-derived vs literature labels for every selected label (five panels
# per row), each annotated with its robust scatter and titled with its units.
n_labels = len(label_names)
n_cols = 5
n_rows = -(-n_labels // n_cols)  # ceiling division

fig = plt.figure(figsize=(15, 9))
gs = fig.add_gridspec(n_rows, n_cols, hspace=0.3, wspace=0.25, left=0.07, right=0.93, bottom=0.07, top=0.93)

for i in range(n_labels):
    ax = fig.add_subplot(gs[i // n_cols, i % n_cols])
    literature = labels_test[:, i]
    ax.plot(literature, ds.test_label_vals[:, i], '.', ms=2)
    min_l, max_l = np.min(literature), np.max(literature)
    ax.plot((min_l, max_l), (min_l, max_l))
    # Units are set as upright text rather than italic math.
    ax.text(
        0.04, 0.9,
        r'$\sigma=%.2f$ %s' % (robust_stds[i], label_units[i]),
        transform=ax.transAxes
    )
    ax.set_title(r'$%s$ (%s)' % (label_names[i], label_units[i]), fontsize=9)

fig.supxlabel('Literature Labels')
fig.supylabel('Cannon-Derived Labels')

# Name the output after the wavelength chunk actually analysed rather than a
# hard-coded chunk number.
savefig(fig, 'robust-residuals-range%d' % range_num)