In [None]:
from __future__ import print_function, division

%matplotlib inline
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

## Boilerplate path hack to give access to full clustered_SNe package
import sys, os


def ensure_on_sys_path(path):
    """Prepend ``path`` (made absolute) to ``sys.path`` unless already present.

    Membership is checked on the absolute path, so re-running this cell never
    stacks duplicate entries, and the entry stays valid even if the working
    directory changes later in the session.

    Parameters
    ----------
    path : str
        Directory to make importable.

    Returns
    -------
    str
        The absolute form of ``path`` that was checked / inserted.
    """
    abs_path = os.path.abspath(path)
    if abs_path not in sys.path:
        sys.path.insert(0, abs_path)
    return abs_path


if __package__ is None:
    # Running as a notebook/script rather than as part of the package: make the
    # directory two levels above the working directory importable so that the
    # `clustered_SNe.*` imports below resolve.
    # NOTE(review): assumes the kernel starts two levels below the directory
    # that contains `clustered_SNe` -- update if the notebook is moved.
    ensure_on_sys_path(os.path.join(os.getcwd(), os.pardir, os.pardir))
        

from clustered_SNe.analysis.constants import m_proton, pc, yr, M_solar, \
                                   metallicity_solar
from clustered_SNe.analysis.parse import Overview, RunSummary, \
                                         Inputs, parse_into_scientific_notation
    
from clustered_SNe.analysis.database_helpers import session, \
                                                Simulation, \
                                                Simulation_Inputs, \
                                                Simulation_Status
            
from clustered_SNe.analysis.fit_helpers import Aggregated_Results, \
                                               Momentum_Model
                                         

In [None]:
# Build the aggregated simulation results used by every cell below.
# Later cells read its 1D grid axes (`metallicities_1D`, `densities_1D`),
# its 3D grids (`masses_3D`, `densities_3D`, `metallicities_3D`, `momenta_3D`),
# and its per-run arrays (`masses`, `densities`, `metallicities`, `num_SNe`,
# `momenta`, `usable`), plus the `plot_slice` method.
# NOTE(review): presumably populated from the simulation database -- confirm
# in clustered_SNe.analysis.fit_helpers.
results = Aggregated_Results()

## Example: visualize surface (fixed metallicity)

In [None]:
%matplotlib notebook

metallicity_index = np.argmax(results.metallicities_1D==metallicity_solar)

from mpl_toolkits.mplot3d import axes3d
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

with sns.plotting_context("poster"):
    surf = ax.plot_wireframe(np.log10(results.masses_3D[   metallicity_index,:,:]),
                             np.log10(results.densities_3D[metallicity_index,:,:]),
                             results.momenta_3D[           metallicity_index,:,:],
                             rstride=1, cstride=1, linewidth=1)

    plt.xlabel("log Mass")
    plt.ylabel("log density")
    plt.show()

## Example: visualize surface (fixed density)

In [None]:
%matplotlib notebook

density_index = np.argmax(results.densities_1D==1.33 * m_proton)

from mpl_toolkits.mplot3d import axes3d
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

with sns.plotting_context("poster"):
    surf = ax.plot_wireframe(np.log10(results.masses_3D[       :, density_index, :]),
                             np.log10(results.metallicities_3D[:, density_index, :]),
                             results.momenta_3D[               :, density_index, :],
                             rstride=1, cstride=1, linewidth=1)

    plt.xlabel("log Mass")
    plt.ylabel("log Z / Z_sun")
    plt.show()

## Test: call the model (chi-by-eye parameters, fixed density and metallicity)

In [None]:
%matplotlib inline

metallicity_index = np.argmax(results.metallicities_1D==metallicity_solar)
density_index = 2
metallicity = results.metallicities_1D[metallicity_index]
density     = results.densities_1D[density_index]
mask = np.isclose(results.densities, density, atol=0) \
     & np.isclose(results.metallicities, metallicity, atol=0) \
     & results.usable & (results.momenta>0)

tmp_model = Momentum_Model(1e3, 1e4, 
                           0, 0, 
                           0, 0, 
                           2, -.2)

print(tmp_model(metallicity, density, results.masses[mask]))

with sns.plotting_context("poster", font_scale=2):

    plt.plot(results.num_SNe[mask],
             results.momenta[mask] / (results.num_SNe[mask] * 100 * M_solar),
             marker= "o", linestyle="",
             label="data")
    
    
    x_fit = np.logspace(-.5,3.5, num=100)
    y_fit = tmp_model(metallicity, density, x_fit)
    plt.plot(x_fit, y_fit,
            label="chi-by-eye")
    
    plt.xscale("log")
    plt.xlabel("N_SNe")
    plt.ylabel("Momentum / (100 M$_\odot$ * N$_\mathrm{SNe}$) [km / s]")
    plt.legend(loc="best")

## Test: fit the model (using 1D slice -- fixed density, metallicity)

In [None]:
%matplotlib inline

metallicity_index = np.argmax(results.metallicities_1D==metallicity_solar)
density_index = 2
metallicity = results.metallicities_1D[metallicity_index]
density     = results.densities_1D[density_index]
mask = np.isclose(results.densities, density, atol=0) \
     & np.isclose(results.metallicities, metallicity, atol=0) \
     & results.usable & (results.momenta>0)

tmp_model = Momentum_Model(1e3, 1e4, 
                           0, 0, 
                           0, 0, 
                           2, -.12)

fixed = np.array([False, False, 
                  True, True, 
                  True, True, 
                  False, True])


y_init = tmp_model(metallicity, density, results.num_SNe[mask])
print("y_init: ", y_init)

x = (results.metallicities[mask],
     results.densities[mask],
     results.num_SNe[mask])
y = results.momenta[mask] / (results.num_SNe[mask] * 100 * M_solar)
popt, pcov = tmp_model.fit(x, y, fixed=fixed)

print("params_0: ", tmp_model.params_0)
print("params:   ", tmp_model.params)

with sns.plotting_context("poster", font_scale=2):
    plt.plot(results.num_SNe[mask], 
             results.momenta[mask] / (results.num_SNe[mask] * 100 * M_solar),
             marker= "o", linestyle="",
             label="data")
    
    x_fit = np.logspace(-.5,3.5, num=100)
    y_fit = tmp_model(metallicity, density, x_fit)
    plt.plot(x_fit, y_fit,
             label="fit")
    plt.xscale("log")
    plt.legend(loc="best")
    plt.xlabel("N_SNe ")
    plt.ylabel("Momentum / (100 * M$_\odot$ * N$_\mathrm{SNe}$) [km / s]")

## Test: fit the model (using 2D slice -- fixed solar metallicity)

In [None]:
%matplotlib inline

metallicity_index = np.argmax(np.isclose(results.metallicities_1D, metallicity_solar, atol=0))

with sns.plotting_context("poster", font_scale=2):
    for density_index in range(len(results.densities_1D)):
        results.plot_slice(metallicity_index, density_index, verbose=True)
        plt.title("density = {0:.2e} g cm^-3".format(results.densities_1D[density_index]))
        plt.show()

## Test: fit the model (using 2D slice -- fixed density)

In [None]:
%matplotlib inline

density_index = np.argmax(np.isclose(results.densities_1D, 1.33e-1 * m_proton, atol=0, rtol=1e-4))

with sns.plotting_context("poster", font_scale=2):
    for metallicity_index in range(len(results.metallicities_1D)):
        results.plot_slice(metallicity_index, density_index,verbose=True)
        plt.title("log Z / Z_solar = {0:.1f}".format(np.log10(results.metallicities_1D[metallicity_index] / metallicity_solar)))
        plt.show()