In [None]:
%matplotlib inline
from json import load
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import minimize
from scipy.stats import binned_statistic


import sys
sys.path.insert(0, '../')
from kelp import Filter, Planet, PhaseCurve, Model

In [None]:
# Targets to fit, visited in alphabetical (string) order
names = ['WASP-18', 'WASP-103', 'WASP-43', 'WASP-12', 'KELT-9', 'HD 189733']
names.sort()

In [None]:
# Initial parameter guesses, keyed by target name and then by "ch1"/"ch2"
with open('data/initp_lmax2.json') as initp_file:
    results_dict = load(initp_file)
    
# Phase-angle grid shared by every target/channel. Bin centres span
# |xi| <= XI_MAX, the same cut used to mask the raw data in the plots.
XI_MAX = 2.7
N_XI_BINS = 50

xi_grid = np.linspace(-XI_MAX, XI_MAX, N_XI_BINS)
xi_bin_size = xi_grid[1] - xi_grid[0]
# Edges are shifted by half a bin so that each xi_grid point is a bin centre.
xi_bins = np.append(xi_grid, xi_grid[-1] + xi_bin_size) - xi_bin_size / 2

# Free parameters (in theta order) and their box bounds for Powell.
keys = ['hotspot_offset', 'ln_omega_drag', 'ln_c1', 'c4']
bounds = [[-1, 1], [0., 100], [-20, 1],
          [-10, 10]]


def pc_model(theta, xi_grid, obs_bins, eclipse, planet=None, bandpass=None):
    """
    Phase curve model on ``xi_grid``, scaled to best match ``obs_bins``.

    Parameters
    ----------
    theta : sequence of 4 floats
        ``(hotspot_offset, ln_omega_drag, ln_c1, c4)``. ``omega_drag`` and the
        ``C_ml[1, 1]`` coefficient are passed as natural logs.
    xi_grid : array
        Bin centres at which the model is evaluated.
    obs_bins : array
        Binned observations; NaN bins are ignored when fitting the scale.
    eclipse : array
        Binned eclipse model, multiplied into the phase curve.
    planet, bandpass : optional
        ``Planet`` and ``Filter`` to model. When omitted, the module-level
        ``p`` and ``filt`` (the current loop target) are used.

    Returns
    -------
    array
        ``1e6 * eclipse * model.phase_curve(xi_grid)`` times the least-squares
        scale factor onto ``obs_bins``. The factor is 1 when no bin is usable.
    """
    if planet is None:
        planet = p
    if bandpass is None:
        bandpass = filt

    hotspot_offset, ln_omega_drag, ln_c1, c4 = theta

    C_ml = np.array([[0, 0.0, 0.0],
                     [0, np.exp(ln_c1), 0],
                     [0, 0, c4]])
    model = Model(hotspot_offset=hotspot_offset,
                  alpha=0.6, omega_drag=np.exp(ln_omega_drag),
                  A_B=0, C_ml=C_ml, lmax=2,
                  a_rs=planet.a, rp_a=planet.rp_a, T_s=planet.T_s,
                  filt=bandpass)
    phase_curve = 1e6 * eclipse * model.phase_curve(xi_grid)

    usable = np.logical_not(np.isnan(obs_bins) |
                            np.isnan(phase_curve) |
                            np.isinf(phase_curve))

    if np.count_nonzero(usable) > 0:
        # A single free amplitude: least-squares scale of the model onto the data
        constant = np.linalg.lstsq(phase_curve[usable, None],
                                   obs_bins[usable], rcond=-1)[0][0]
    else:
        constant = 1
    return phase_curve * constant


def chi2(theta, xi_grid, obs_bins, eclipse, planet=None, bandpass=None):
    """
    Objective for ``minimize``: sum of squared residuals, ignoring NaN bins.

    Returns ``np.inf`` when no bin gives a non-NaN residual. ``np.nansum`` of an
    all-NaN array is 0, which would otherwise make a model that failed to
    evaluate look like a perfect fit and attract the optimizer.
    """
    model = pc_model(theta, xi_grid, obs_bins, eclipse, planet, bandpass)
    sq_resid = (model - obs_bins) ** 2

    if np.all(np.isnan(sq_resid)):
        return np.inf
    # NOTE(review): bins where only the model is NaN are still skipped, which
    # slightly favours parameters that produce NaNs; consider penalising them.
    return np.nansum(sq_resid)


for n in names:
    for channel in (1, 2):
        # KELT-9 is only fit in channel 2
        if n == 'KELT-9' and channel == 1:
            continue
        ch_key = f"ch{channel}"

        # Load the phase curve, planet parameters and (binned-down) bandpass
        pc = PhaseCurve.from_name(n, channel, year=2010 if n == 'WASP-12' else None)
        p = Planet.from_name(n)
        filt = Filter.from_name(f"IRAC {channel}")
        filt.bin_down()

        fig, ax = plt.subplots(2, 1, figsize=(6, 8), sharex=True)
        pc.plot(ax=ax[0], mask=np.abs(pc.xi) < XI_MAX, alpha=0.8)

        # Bin observations and the eclipse model onto the xi grid
        obs_bins = binned_statistic(pc.xi, pc.flux, bins=xi_bins,
                                    statistic=np.nanmedian).statistic
        eclipse = binned_statistic(pc.xi, p.eclipse_model(pc.xi), bins=xi_bins,
                                   statistic=np.nanmedian).statistic

        # Minimise chi^2 starting from the stored guesses, then store the result
        fit_args = (xi_grid, obs_bins, eclipse, p, filt)
        initp = [results_dict[n][ch_key][k] for k in keys]
        result = minimize(chi2, initp, args=fit_args,
                          bounds=bounds,
                          method='powell')
        results_dict[n][ch_key].update(zip(keys, result.x))

        print(result)
        best_model = pc_model(result.x, *fit_args)

        ax[0].plot(xi_grid, obs_bins, 'sk')
        ax[0].plot(xi_grid, best_model, color='r')
        ax[0].set_title(f'{n} Ch {channel}')

        ax[1].plot(xi_grid, obs_bins - best_model, 'sk')
        ax[1].axhline(0, ls='--', color='silver')
        ax[1].set_xlabel(r'$\xi$')
        ax[1].set_ylabel('Residual')
        fig.tight_layout()
        plt.show()
        # Release the figure; many are created in this loop
        plt.close(fig)

In [None]:
from json import dump

# Write the refined parameters back to the file the fit cell reads its initial guesses from
with open('data/initp_lmax2.json', mode='w') as initp_file:
    dump(results_dict, initp_file, sort_keys=True, indent=4)