In [None]:
import numpy as np
import pynbody
import pynbody.units as units
import matplotlib.pyplot as plt
import sys, os, glob, pickle, pylab as plt, struct

In [None]:
def load_halos_pickle(pickle_path):
    '''
    Load a pickled halo catalogue and return it as a dict of column arrays.

    The pickle is expected to hold a sequence indexable by 0..len-1 whose
    items are per-galaxy dicts (or ``None`` for galaxies without data).
    Column names are taken from the keys of item ``data[1]``.

    Parameters
    ----------
    pickle_path : str
        Path to the pickle file.  NOTE: ``pickle.load`` can execute arbitrary
        code -- only load files you produced yourself.

    Returns
    -------
    dict[str, numpy.ndarray]
        One float array of length ``len(data)`` per column.  Stored values are
        ``value + 1e-12`` (presumably to keep later log10 calls finite --
        TODO confirm).  Galaxies that are ``None``, lack the key, or hold a
        value that cannot be stored as a single float are left at 0.
    '''
    # Context manager so the file handle is closed even if unpickling fails.
    with open(pickle_path, "rb") as fh:
        data = pickle.load(fh)

    n_gal = len(data)
    output = {str(k): np.zeros(n_gal) for k in data[1]}

    for i in range(n_gal):
        gal_dict = data[i]
        if gal_dict is None:
            continue
        for key, column in output.items():
            try:
                column[i] = gal_dict[key] + 1e-12
            except (KeyError, TypeError, ValueError):
                # Best-effort by design: a missing, non-numeric or
                # non-scalar entry keeps its 0 placeholder.
                pass
    return output

In [None]:
def density_estimate(x, y, n_grid=50):
    '''
    Gaussian kernel-density estimate of the joint distribution of (x, y),
    evaluated on a regular grid spanning the data range (e.g. for contourf).

    Parameters
    ----------
    x, y : array_like
        1-D samples of equal length (lists are accepted).
    n_grid : int, optional
        Number of grid points along each axis, endpoints included (default 50).

    Returns
    -------
    X, Y, Z : numpy.ndarray
        ``(n_grid, n_grid)`` arrays: X varies along axis 0, Y along axis 1,
        and Z is the KDE evaluated at each (X, Y) point.
    '''
    from scipy import stats

    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    # A complex step tells mgrid to produce n_grid points including both ends.
    step = complex(0, n_grid)
    X, Y = np.mgrid[x.min():x.max():step, y.min():y.max():step]
    positions = np.vstack([X.ravel(), Y.ravel()])

    kernel = stats.gaussian_kde(np.vstack([x, y]))
    Z = np.reshape(kernel(positions), X.shape)

    return X, Y, Z

In [None]:
# Halo catalogue to analyse.  Set the HALO_PICKLE_PATH environment variable to
# use a different file; catalogues used previously include
#   data/60/vol_halo_center_z0.205.p
#   data/60/1e5_vol_halo_metals_z0.205.p
path = os.environ.get('HALO_PICKLE_PATH', '/scratch/hc2347/pickles/60/main_preload_z0.205.p')
data = load_halos_pickle(path)


In [None]:
# Show which columns the loaded catalogue provides.
print(repr(data.keys()))

In [None]:
def clean_arrays(array_iter, notzero=False):
    '''
    Build a boolean mask selecting positions that are usable in every array.

    Parameters
    ----------
    array_iter : sequence of array_like
        Arrays of identical shape to be cleaned jointly (e.g. the x and y
        values of one plot).
    notzero : bool, optional
        If True, additionally reject positions where any array is exactly 0.

    Returns
    -------
    numpy.ndarray of bool
        True where every array is finite (and non-zero when ``notzero``);
        apply it as ``arr[mask]`` to each array.

    Raises
    ------
    ValueError
        If ``array_iter`` is empty.
    '''
    arrays = [np.asarray(arr) for arr in array_iter]
    if not arrays:
        raise ValueError("clean_arrays needs at least one array")

    mask = np.ones(arrays[0].shape, dtype=bool)
    for arr in arrays:
        # AND-combine: a position survives only if it is valid in every array.
        mask &= np.isfinite(arr)
        if notzero:
            mask &= arr != 0
    return mask

In [None]:
# Index tuple of haloes with mvir above 1e8 (in whatever mass units the
# catalogue stores -- presumably Msun, TODO confirm).
massf = np.nonzero(data['mvir'] > 10**8)

In [None]:
def is_valid(elm):
    """Return a truthy value when ``elm`` is finite and strictly greater than 1.

    Values <= 1, NaN and +/-inf are rejected.
    """
    above_one = elm > 1
    if not above_one:
        return above_one
    return np.isfinite(elm)

def do_filter(a, b):
    '''
    Drop every position where ``b`` fails ``is_valid`` from both ``a`` and ``b``.

    Parameters
    ----------
    a, b : array_like
        1-D sequences of equal length; only ``b`` decides which positions
        are kept.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Filtered copies of ``a`` and ``b``, still aligned element-by-element.

    Raises
    ------
    ValueError
        If ``a`` and ``b`` have different lengths.
    '''
    a = np.asarray(a)
    b = np.asarray(b)
    if len(a) != len(b):
        raise ValueError("do_filter: a and b must have the same length")
    # Build the keep-mask in one pass instead of calling np.delete per
    # element, which copied both arrays each time (O(n**2)).
    keep = np.array([bool(is_valid(val)) for val in b], dtype=bool)
    return a[keep], b[keep]

# Keep only galaxies whose stellar mass passes is_valid; oxh is filtered alongside.
oxh, mstar = do_filter(data['oxh'], data['mstar'])
# Report indices of any infinite O/H values that survived the mass filter.
print(np.nonzero(np.isinf(oxh)))

In [None]:
# Oxygen Abundance

def plot_oxh(oxh, mstar,
             tremonti_path='/scratch/hc2347/references/obs/Tremonti_2004_mzr.csv'):
    '''
    Plot the gas-phase mass-metallicity relation, 12 + log(O/H) vs log M*,
    of the simulated galaxies as KDE contours over the Tremonti et al. (2004)
    median curve and 16th-84th percentile band.

    Parameters
    ----------
    oxh : array_like
        12 + log10(O/H) per galaxy.
    mstar : array_like
        Stellar mass per galaxy, same length as ``oxh`` (plotted as log10).
    tremonti_path : str, optional
        CSV with the observed relation, read column-wise as
        [log M*, (unused), 16th pct, median, 84th pct].

    Raises
    ------
    ValueError
        If ``oxh`` and ``mstar`` have different lengths.
    '''
    oxh = np.asarray(oxh, dtype=float)
    mstar = np.asarray(mstar, dtype=float)
    if len(oxh) != len(mstar):
        raise ValueError("plot_oxh: oxh and mstar must have the same length")

    # One shared mask keeps the two arrays paired galaxy-by-galaxy.  mstar is
    # checked as well because log10 of a 0 placeholder is not usable by the KDE.
    keep = np.array([bool(is_valid(o)) and bool(is_valid(m))
                     for o, m in zip(oxh, mstar)], dtype=bool)
    oxh = oxh[keep]
    logmstar = np.log10(mstar[keep])

    tremonti = np.genfromtxt(tremonti_path, unpack=True, skip_header=2, delimiter=',')
    logmstar_tr = tremonti[0]
    sixteen_tr = tremonti[2]
    median_tr = tremonti[3]
    eightyfour_tr = tremonti[4]

    fig, ax = plt.subplots(figsize=(7, 7))

    print("Beginning density estimate for " + str(len(oxh)) + " number of points.")

    # A single KDE over all (log M*, O/H) pairs; stellar mass on x to match the labels.
    X, Y, Z = density_estimate(logmstar, oxh)
    contours = ax.contourf(X, Y, Z, 20, cmap='Blues')
    fig.colorbar(contours, ax=ax)

    ax.plot(logmstar_tr, median_tr, marker='+')
    ax.fill_between(logmstar_tr, sixteen_tr, y2=eightyfour_tr, alpha=0.1, color='orange', label="Tremonti 2004")
    # Raw strings: '\o' is not a valid Python escape and warns on newer Pythons.
    ax.set_ylabel(r'$12+log_{10}(O/H)$', fontsize=18)
    ax.set_xlabel(r'$ log M_{*}/M_\odot$', fontsize=18)
    ax.legend()

    ax.set_ylim(8, 9.5)
    ax.set_xlim(8, 12)

                              
plot_oxh(oxh=data['oxh'][massf], mstar=data['mstar'][massf])

In [None]:
# Stellar Metallicity

# Stellar and gas metallicity columns from the catalogue.
z_star, z_gas = data['z_star'], data['z_gas']

def plot_stellar_metallicity(z_star, m_star,
                             gallazzi_path='/scratch/hc2347/references/obs/Gallazzi_2005_zstar.csv'):
    '''
    Plot the stellar mass-metallicity relation, log(Z*/Zsun) vs log M*, of the
    simulated galaxies as KDE contours over the Gallazzi et al. (2005) median
    curve and 16th-84th percentile band.

    Parameters
    ----------
    z_star : array_like
        Stellar metallicity per galaxy (normalised by ``z_sol`` below).
    m_star : array_like
        Stellar mass per galaxy, same shape as ``z_star`` (plotted as log10).
    gallazzi_path : str, optional
        CSV with the observed relation, read column-wise as
        [log M*, median, 16th pct, 84th pct].

    Raises
    ------
    ValueError
        If ``z_star`` and ``m_star`` have different shapes.
    '''

    z_sol = 0.013 # solar metallicity normalisation (originally noted as "primordial" -- TODO confirm which solar scale)

    # Use the m_star argument (not a global catalogue) so x and y describe the
    # same galaxies.
    m_star = np.asarray(m_star, dtype=float)
    z_star = np.asarray(z_star, dtype=float)
    if m_star.shape != z_star.shape:
        raise ValueError("plot_stellar_metallicity: z_star and m_star must have the same shape")

    # Zero/negative placeholders give -inf/nan logs, which the KDE cannot
    # handle; drop those pairs.
    with np.errstate(divide='ignore', invalid='ignore'):
        x = np.log10(m_star)
        y = np.log10(z_star/z_sol)
    finite = np.isfinite(x) & np.isfinite(y)
    x = x[finite]
    y = y[finite]

    gallazzi = np.genfromtxt(gallazzi_path, unpack=True, skip_header=2, delimiter=',')
    logmstar_tr = gallazzi[0]

    median_tr = gallazzi[1]
    sixteen_tr = gallazzi[2]
    eightyfour_tr = gallazzi[3]

    fig, ax = plt.subplots(figsize=(10,8))

    X, Y, Z = density_estimate(x,y)
    contours = ax.contourf(X, Y, Z, 20, cmap='Blues')
    fig.colorbar(contours, ax=ax)

    # plot the observation.
    ax.plot(logmstar_tr, median_tr, marker='+')
    ax.fill_between(logmstar_tr, sixteen_tr, y2 = eightyfour_tr, alpha = 0.1, color='orange', label = "Gallazzi 2005")
    # Raw strings: '\o' is not a valid Python escape and warns on newer Pythons.
    ax.set_ylabel(r'$log(Z_{*}/Z_\odot)$',fontsize=18)
    ax.set_xlabel(r'$ log M_{*}/M_\odot$',fontsize=18)
    ax.legend()

    ax.set_ylim(-1.5,0.5)
    ax.set_xlim(7,12)