In [1]:
# import stuff
# my code
import model_no_uncert as model
import my_toolbox as tb
from pars_shocks_and_wages import Pars, Shocks
import plot_lc
import simulate
# import solver
import run

# general
import numpy as np

In [28]:
"""
solver.py

This file contains the solver module for the project.

Author: Ben Boyajian
Date: 2024-05-31 11:42:26
"""
import model_no_uncert as model
from pars_shocks_and_wages import Pars
import my_toolbox as tb

from numba import njit, prange, float64
import numpy as np
import csv
from math import inf
from typing import Tuple
from interpolation.splines import eval_linear


# Solve the whole life cycle for the given parameters and return a dictionary of solutions
def solve_lc(myPars: Pars, path: str = None) -> dict:
    """
    Solve the whole life cycle by backward induction over periods j = J-1, ..., 0.

    Progress is logged to ``<path>/status.csv``: a header row is written first and one
    row per solved period is appended. With ``myPars.print_screen >= 2`` every state and
    its rounded solution are also written to that file.

    Args:
        myPars: model parameters (grids, sizes, J, path, print_screen, ...).
        path: directory for ``status.csv``; defaults to ``myPars.path``.

    Returns:
        dict mapping 'c', 'lab' and 'a_prime' to arrays of shape
        ``myPars.state_space_shape`` whose last axis is the period j.
    """
    # Resolve the status file once so the header and the per-period rows land in the
    # same file even when a custom `path` is passed.
    if path is None:
        path = myPars.path
    status_path = path + "/status.csv"
    with open(status_path, mode='w', newline='') as file:
        csv.writer(file).writerow(['solve_lc started'])

    # Initialize solution shells
    ### **NOTE:** DO NOT CHANGE ORDER OF var_list W/O CHANGING ORDER IN simulate.sim_lc_jit
    var_list = ['c', 'lab', 'a_prime']
    state_sols = {var: np.empty(myPars.state_space_shape) for var in var_list}
    if myPars.print_screen >= 1:
        print("state_sols shape", state_sols['c'].shape)

    # Placeholder "next-period consumption" for the last period; it is not read there.
    mat_c = np.full(myPars.state_space_shape_no_j, inf)

    for j in reversed(range(myPars.J)):
        mat_c_prime = mat_c
        last_per = (j >= myPars.J - 1)

        per_sol_list = solve_per_j(myPars, j, last_per, mat_c_prime)
        for var, sol in zip(var_list, per_sol_list):
            state_sols[var][:, :, :, :, j] = sol

        # solve_per_j returns consumption first; it becomes next iteration's c_prime.
        mat_c = per_sol_list[0]

        if myPars.print_screen >= 2:
            print(f'solved period {j} of {myPars.J}')
        with open(status_path, mode='a', newline='') as file:
            writer = csv.writer(file, quoting=csv.QUOTE_NONE, escapechar='\\')
            writer.writerow([f'solved period {j} of {myPars.J}'])
            if myPars.print_screen >= 2:
                _write_period_states(writer, myPars, state_sols, j)

    return state_sols


def _write_period_states(writer, myPars: Pars, state_sols: dict, j: int) -> None:
    """Write, for every (a, lab_FE, H, nu) state at period j, a state row and a rounded-solution row."""
    for state in range(np.prod(myPars.state_space_shape_no_j)):
        a_ind, lab_FE_ind, H_ind, nu_ind = tb.D4toD1(state, myPars.a_grid_size, myPars.lab_FE_grid_size, myPars.H_grid_size, myPars.nu_grid_size)
        ind_tuple = (a_ind, lab_FE_ind, H_ind, nu_ind, j)
        writer.writerow(['state:', state,
                         'a:', myPars.a_grid[a_ind],
                         'lab_FE:', myPars.lab_FE_grid[lab_FE_ind],
                         'H:', myPars.H_grid[H_ind],
                         'nu:', myPars.nu_grid[nu_ind],
                         'j:', j])
        writer.writerow(['c:', round(state_sols["c"][ind_tuple], 3),
                         'lab:', round(state_sols["lab"][ind_tuple], 3),
                         'a_prime:', round(state_sols["a_prime"][ind_tuple], 3)])
    
# Solve the individual period problem given the parameters and the period states
# this may need to be not jitted
# we must always return the consumption first in the solve_per_j function
# @njit
def solve_per_j( myPars: Pars, j: int, last_per: bool, mat_c_prime: np.ndarray)-> list:
    """
    Solve for c, lab and a_prime at every state of period j.

    Two steps: solve_per_j_iter solves the problem at each point of the a_prime grid and
    returns the implied current assets; transform_ap_to_a then maps those solutions back
    onto the exogenous a grid.

    Returns:
        [c, lab, a_prime], each shaped myPars.state_space_shape_no_j. Consumption must
        stay first because solve_lc feeds element 0 back in as next iteration's c_prime.
    """
    grid_shape = myPars.state_space_shape_no_j
    # Template for the a_prime-indexed solutions (filled with -inf) and for the
    # a-indexed solutions (filled with zeros).
    template_ap = np.full(grid_shape, -inf)
    template_a = np.zeros(grid_shape)

    on_ap_grid = solve_per_j_iter(myPars, j, template_ap, mat_c_prime, last_per)
    c_ap, lab_ap, a_of_ap = on_ap_grid

    c_sol, lab_sol, ap_sol = transform_ap_to_a(myPars, template_a, c_ap, lab_ap, a_of_ap, last_per)
    return [c_sol, lab_sol, ap_sol]

# Iterate over individual states
#@njit(parallel=True)
#@njit
def solve_per_j_iter(myPars: Pars, j: int, shell_a_prime: np.ndarray, mat_c_prime: np.ndarray, last_per: bool)-> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Solve the period-j problem at every (a_prime, lab_FE, H, nu) grid state.

    The first state index runs over myPars.a_grid but is read as next-period assets
    a_prime. For each state the routine returns consumption, labor and the implied
    current assets a, using the budget constraint c + a_prime = (1 + r) * a + wage * lab.

    Args:
        myPars: model parameters (grids, sizes, r, c_min, lab_min, lab_max, ...).
        j: period index.
        shell_a_prime: template array of shape myPars.state_space_shape_no_j, copied to
            initialize the three outputs.
        mat_c_prime: next-period consumption on the same grid; not read when last_per.
        last_per: True in the final period. There a_prime is 0, and the grid value is used
            directly as current assets.

    Returns:
        (mat_c_ap, mat_lab_ap, mat_a_ap): consumption, labor and current assets indexed by
        the a_prime grid state.
    """
    mat_c_ap, mat_lab_ap, mat_a_ap = np.copy(shell_a_prime), np.copy(shell_a_prime), np.copy(shell_a_prime)

    # One hard-coded grid state (a_prime idx, lab_FE idx, H idx, nu idx) whose
    # intermediate values are printed for debugging.
    debug_ind = (int(myPars.a_grid_size/2) - 1, 0, 1, 0)

    # prange allows parallelization once jitted; it also scrambles the order of the debug prints.
    for state in prange(myPars.state_space_no_j_size):

        a_prime_ind, lab_FE_ind, H_ind, nu_ind = tb.D4toD1(state, myPars.a_grid_size, myPars.lab_FE_grid_size, myPars.H_grid_size, myPars.nu_grid_size)
        a_prime, H = myPars.a_grid[a_prime_ind], myPars.H_grid[H_ind]
        ind_tuple = (a_prime_ind, lab_FE_ind, H_ind, nu_ind)
        is_debug = (ind_tuple == debug_ind)

        # NOTE(review): only the current wage is needed here if model.infer_c looks up the
        # future wage itself (the notebook copy does) — confirm for model_no_uncert.
        curr_wage = model.wage(myPars, j, lab_FE_ind, H_ind, nu_ind)

        if last_per:
            # No saving in the last period; the grid value is treated as current assets.
            last_per_a_prime = 0
            a = a_prime
            lab = myPars.lab_min
            cash = a*(1 + myPars.r)
            # Budget constraint at lab = lab_min and a_prime = 0 (includes lab_min's earnings).
            c = cash + curr_wage*lab
            mb_lab = model.mb_lab(myPars, c, curr_wage, lab, H)
            mc_lab = model.mc_lab(myPars, c, lab, H)
            if is_debug:
                _print_debug_state(j, a, c, lab, mb_lab, mc_lab, a_prime)

            if c >= myPars.c_min:
                # Scenario 1: c_min is affordable at lab_min; work more only if that pays off.
                if is_debug:
                    print("Case 1")
                if mb_lab >= mc_lab:
                    c = model.c_star(myPars, last_per_a_prime, a, H, curr_wage)
                    lab = model.solve_lab_a(myPars, c, last_per_a_prime, curr_wage, H_ind)[0]
            elif cash + myPars.lab_max*curr_wage >= myPars.c_min:
                # Scenario 2: c_min is affordable only with more than lab_min labor.
                if is_debug:
                    print("Case 2")
                if mb_lab >= mc_lab:
                    c = model.c_star(myPars, last_per_a_prime, a, H, curr_wage)
                    lab = model.solve_lab_a(myPars, c, last_per_a_prime, curr_wage, H_ind)[0]
                else:
                    lab = myPars.lab_max
                    c = cash + lab*curr_wage
            else:
                # Scenario 3: c_min is unaffordable even at lab_max; consumption is floored.
                if is_debug:
                    print("Case 3")
                c = myPars.c_min
                lab = myPars.lab_min
        else:
            # Not the last period: infer c from next-period consumption, then labor and a.
            c_prime = mat_c_prime[ind_tuple]
            c = model.infer_c(myPars, curr_wage, j, lab_FE_ind, H_ind, nu_ind, c_prime)
            lab = myPars.lab_min
            mb_lab = model.mb_lab(myPars, c, curr_wage, lab, H)
            mc_lab = model.mc_lab(myPars, c, lab, H)
            if mb_lab >= mc_lab:
                lab, a = model.solve_lab_a(myPars, c, a_prime, curr_wage, H_ind)
            else:
                a = (c + a_prime - curr_wage*lab)/(1 + myPars.r)

        if is_debug:
            _print_debug_state(j, a, c, lab, mb_lab, mc_lab, a_prime)

        # Store state specific solutions
        mat_c_ap[ind_tuple], mat_lab_ap[ind_tuple], mat_a_ap[ind_tuple] = c, lab, a

    return mat_c_ap, mat_lab_ap, mat_a_ap


def _print_debug_state(j, a, c, lab, mb_lab, mc_lab, a_prime):
    """Print one grid state's period-j values for debugging."""
    print(f"For j= {j}: a={round(a,3)}, c={round(c,4)}, labor={round(lab, 3)}, mb_lab={round(mb_lab,3)}, mc_lab={round(mc_lab,3)}, a_prime={round(a_prime, 3)}")

@njit
def solve_j_indiv( myPars: Pars, a_prime: float, curr_wage: float, j: int, lab_fe_ind: int, H_ind: int, nu_ind: int, c_prime: float)-> Tuple[float, float, float]:
    """
    Solve a single state's period-j problem given next-period consumption c_prime.

    Current consumption comes from model.infer_c. Labor and the implied current assets
    then follow from model.solve_lab_a at the given a_prime.

    Returns:
        (c, lab, a)
    """
    c_today = model.infer_c(myPars, curr_wage, j, lab_fe_ind, H_ind, nu_ind, c_prime)
    lab_today, a_today = model.solve_lab_a(myPars, c_today, a_prime, curr_wage, H_ind)
    return c_today, lab_today, a_today

@njit
def transform_ap_to_a(myPars : Pars, shell_a, mat_c_ap, mat_lab_ap, mat_a_ap, last_per) :
    """
    Map period solutions indexed by the a_prime grid onto the exogenous a grid.

    For each (lab_FE, H, nu) combination, mat_a_ap[:, ...] holds the current assets implied
    by each a_prime grid point. c and lab are linearly interpolated from those points onto
    myPars.a_grid.

    Outside the last period, the a_prime policy is obtained by interpolating myPars.a_grid
    the same way. In the last period it keeps the zeros copied from shell_a.

    NOTE(review): assumes mat_a_ap is increasing along the asset axis, as linear
    interpolation on an irregular grid requires — confirm.

    Returns:
        [c, lab, a_prime] on the a grid (consumption first).
    """
    sol_ap = np.copy(shell_a)
    sol_c = np.copy(shell_a)
    sol_lab = np.copy(shell_a)

    # Evaluation points as a column: one row per a-grid point, one column (1-D problem).
    query = np.copy(myPars.a_grid).reshape(myPars.a_grid_size, 1)
    n_other_states = myPars.lab_FE_grid_size * myPars.H_grid_size * myPars.nu_grid_size

    for flat_ind in range(n_other_states):
        fe_i, h_i, nu_i = tb.D3toD1(flat_ind, myPars.lab_FE_grid_size, myPars.H_grid_size, myPars.nu_grid_size)
        knots = (mat_a_ap[:, fe_i, h_i, nu_i],)
        sol_c[:, fe_i, h_i, nu_i] = eval_linear(knots, mat_c_ap[:, fe_i, h_i, nu_i], query)
        sol_lab[:, fe_i, h_i, nu_i] = eval_linear(knots, mat_lab_ap[:, fe_i, h_i, nu_i], query)
        if not last_per:
            sol_ap[:, fe_i, h_i, nu_i] = eval_linear(knots, myPars.a_grid, query)

    return [sol_c, sol_lab, sol_ap]



In [30]:
# NOTE(review): absolute, machine-specific output directory — adjust (or make it configurable) before running elsewhere.
main_path = "C:/Users/Ben/My Drive/PhD/PhD Year 3/3rd Year Paper/Model/My Code/Main_Git_Clone/Model/My Code/my_model_2/output/"

# Labor fixed-effect levels, one per FE type. Add entries here (up to the length of the
# coefficient lists below) and w_coeff_grid grows automatically.
my_lab_FE_grid = np.array([10.0])
# Per-type wage coefficients, indexed by FE type.
lin_wage_coeffs = [0.0, 1.0, 1.0, 1.0]
quad_wage_coeffs = [-0.000, -0.020, -0.020, -0.020]
cub_wage_coeffs = [0.0, 0.0, 0.0, 0.0]

num_FE_types = len(my_lab_FE_grid)
assert num_FE_types <= min(len(lin_wage_coeffs), len(quad_wage_coeffs), len(cub_wage_coeffs)), \
    "need linear, quadratic and cubic wage coefficients for every FE type"

# Row i = [FE level, linear, quadratic, cubic] coefficients for FE type i; presumably
# evaluated as a polynomial in age by tb.cubic inside model.wage — confirm.
w_coeff_grid = np.zeros([num_FE_types, 4])
for type_ind in range(num_FE_types):
    w_coeff_grid[type_ind, :] = [my_lab_FE_grid[type_ind], lin_wage_coeffs[type_ind],
                                 quad_wage_coeffs[type_ind], cub_wage_coeffs[type_ind]]


myPars = Pars(main_path, J=45, a_grid_size=100, a_min= -100.0, a_max = 100.0, lab_FE_grid = my_lab_FE_grid,
                H_grid=np.array([0.0, 1.0]), nu_grid_size=1, alpha = 0.45, sim_draws=10, sigma_util = 0.9999,
                wage_coeff_grid = w_coeff_grid,
                print_screen=0)

myShocks = Shocks(myPars)
state_sols = solve_lc(myPars)


For j= 44: a=-1.01, c=-1.0303, labor=0.0, mb_lab=nan, mc_lab=nan, a_prime=-1.01
Case 2
For j= 44: a=-1.01, c=7.8586, labor=0.889, mb_lab=nan, mc_lab=nan, a_prime=-1.01
For j= 43: a=6.714, c=7.8586, labor=0.0, mb_lab=0.573, mc_lab=0.619, a_prime=-1.01
For j= 42: a=1.569, c=5.1749, labor=0.256, mb_lab=0.87, mc_lab=0.619, a_prime=-1.01
For j= 41: a=0.5, c=4.6842, labor=0.316, mb_lab=0.961, mc_lab=0.619, a_prime=-1.01
For j= 40: a=0.052, c=4.4782, labor=0.342, mb_lab=1.005, mc_lab=0.619, a_prime=-1.01
For j= 39: a=-0.195, c=4.3649, labor=0.355, mb_lab=1.031, mc_lab=0.619, a_prime=-1.01
For j= 38: a=-0.352, c=4.2932, labor=0.364, mb_lab=1.048, mc_lab=0.619, a_prime=-1.01
For j= 37: a=-0.459, c=4.2437, labor=0.37, mb_lab=1.06, mc_lab=0.619, a_prime=-1.01
For j= 36: a=-0.538, c=4.2076, labor=0.375, mb_lab=1.07, mc_lab=0.619, a_prime=-1.01
For j= 35: a=-0.598, c=4.18, labor=0.378, mb_lab=1.077, mc_lab=0.619, a_prime=-1.01
For j= 34: a=-0.645, c=4.1583, labor=0.381, mb_lab=1.082, mc_lab=0.619, 

In [31]:
# Inspect the solved policies at one (a, lab_FE, H, nu) grid state over the life cycle,
# then one simulated path and the average simulated labor profile.
asset_ind = myPars.a_grid_size // 2 - 1
fe_ind = 0
h_ind = 1
nu_ind = 0
print(f"My asset grid = {myPars.a_grid}")
print(f"My asset_ind = {asset_ind}")
print(f"My assets = {myPars.a_grid[asset_ind]}")
for j in range(myPars.J):
    ind_tuple = asset_ind, fe_ind, h_ind, nu_ind, j
    print(f"For j={j}: a_prime= {round(state_sols['a_prime'][ind_tuple],4)}, lab = {round(state_sols['lab'][ind_tuple],4)}, c = {round(state_sols['c'][ind_tuple],4)}, wage = {round(model.wage(myPars, j, fe_ind, h_ind, nu_ind),2)}")

sim_lc = simulate.sim_lc(myPars, myShocks, state_sols)
sim_ind = 0
print(f"for simulation {sim_ind}:")
for j in range(myPars.J):
    # Simulation arrays are indexed (lab_FE, H, nu, sim, j).
    ind_tuple = fe_ind, h_ind, nu_ind, sim_ind, j
    print(f"For j={j}: a= {round(sim_lc['a'][ind_tuple],4)}, lab = {round(sim_lc['lab'][ind_tuple],4)}, c = {round(sim_lc['c'][ind_tuple],4)}, wage = {round(sim_lc['wage'][ind_tuple],2)}")
print('The average across all simulations:')
# Slice is (sim, j); averaging over axis 0 gives the mean labor profile by period.
print(np.mean(sim_lc['lab'][fe_ind, h_ind, nu_ind, :], axis=0))

My asset gird = [-100.          -97.97979798  -95.95959596  -93.93939394  -91.91919192
  -89.8989899   -87.87878788  -85.85858586  -83.83838384  -81.81818182
  -79.7979798   -77.77777778  -75.75757576  -73.73737374  -71.71717172
  -69.6969697   -67.67676768  -65.65656566  -63.63636364  -61.61616162
  -59.5959596   -57.57575758  -55.55555556  -53.53535354  -51.51515152
  -49.49494949  -47.47474747  -45.45454545  -43.43434343  -41.41414141
  -39.39393939  -37.37373737  -35.35353535  -33.33333333  -31.31313131
  -29.29292929  -27.27272727  -25.25252525  -23.23232323  -21.21212121
  -19.19191919  -17.17171717  -15.15151515  -13.13131313  -11.11111111
   -9.09090909   -7.07070707   -5.05050505   -3.03030303   -1.01010101
    1.01010101    3.03030303    5.05050505    7.07070707    9.09090909
   11.11111111   13.13131313   15.15151515   17.17171717   19.19191919
   21.21212121   23.23232323   25.25252525   27.27272727   29.29292929
   31.31313131   33.33333333   35.35353535   37.37373737   39

In [32]:
# Produce the run output for the current solution and simulation (without TeX output).
run.output(myPars, state_sols, sim_lc, no_tex=True)

In [6]:
# Marginal utility of c under the future wage, inverted back to consumption under the
# current wage: shows how far the implied consumption moves when only the wage differs.
c = 100.0
fut_wage = 10.0
cur_wage = 1000.0
# Named mu_c (not util_c) so this cell does not overwrite the notebook-level util_c()
# function, which infer_c and invert_lab look up as a global.
mu_c = model.util_c(myPars, c, fut_wage)
inv = model.util_c_inv(myPars, mu_c, cur_wage)
print(inv)
print(c - inv)

99.97467223889932
0.025327761100683688


In [19]:
"""
My Model 2 - basic model no uncertainty

Contains the model equations and derivatives to be used by solver.py and others

Author: Ben Boyajian
Date: 2024-05-29 20:16:01
"""

# Import packages
import time
import numpy as np
from pars_shocks_and_wages import Pars
import my_toolbox as tb
from numba import njit, guvectorize, prange 
from interpolation import interp
import sys

#convert labor to leisure within period
@njit
def leis_giv_lab(myPars: Pars, labor: float, health: float) -> float:
    """
    Leisure left by the time endowment for a given labor supply and health.

    leisure = 1 - phi_n * labor - phi_H * (1 - health), clipped to [leis_min, leis_max].
    """
    time_left = 1.0 - labor*myPars.phi_n - (1.0-health)*myPars.phi_H
    return max(myPars.leis_min, min(myPars.leis_max, time_left))

#convert leisure to labor within period
@njit
def lab_giv_leis(myPars: Pars, leisure: float, health: float) -> float:
    """
    Labor implied by the time endowment for a given leisure and health.

    labor = (1 - leisure - phi_H * (1 - health)) / phi_n. No clipping to
    [lab_min, lab_max] is applied here (solve_lab_a clamps the result). phi_n must be non-zero.
    """
    time_for_work = 1.0 - leisure - (1.0-health)*myPars.phi_H
    return time_for_work / myPars.phi_n

#convert consumption to leisure within period
@njit
def leis_giv_c(myPars: Pars, c: float, wage: float) -> float:
    """
    Leisure consistent with consumption c under the static FOC phi_n * u_leis = wage * u_c.

    With the Cobb-Douglas bundle c**alpha * leis**(1-alpha) this is linear in c:
        leis = phi_n * (1 - alpha) / (wage * alpha) * c
    No clipping to [leis_min, leis_max] is applied here.
    """
    # wage > 0 and alpha != 0 keep the denominator non-zero.
    leis_per_c = (myPars.phi_n * (1 - myPars.alpha)) / (wage * myPars.alpha)
    return leis_per_c * c
   
    

#convert leisure to consumption within period
@njit
def c_giv_leis(myPars: Pars,  leis: float, wage: float) -> float:
    """
    Consumption consistent with leisure under the static FOC phi_n * u_leis = wage * u_c.

    Inverse of leis_giv_c: c = wage * alpha / (phi_n * (1 - alpha)) * leis.
    alpha != 1 and phi_n != 0 keep the denominator non-zero.
    """
    c_per_leis = (wage * myPars.alpha) / (myPars.phi_n * (1 - myPars.alpha))
    return c_per_leis * leis

#utility function given leisure and consumption
@njit
def util_giv_leis(myPars: Pars, c: float, leis: float) -> float:
    """
    Period utility from consumption c and leisure leis.

    For sigma_util != 1:
        u = (c**alpha * leis**(1-alpha))**(1-sigma) / (1-sigma)
    For sigma_util == 1 that expression divides by zero, so the log form
        u = alpha*log(c) + (1-alpha)*log(leis)
    is returned instead. This is the standard sigma -> 1 limit of CRRA utility up to an
    additive constant; marginal utilities (util_c_giv_leis, util_leis_giv_c) are unaffected.
    """
    sig = myPars.sigma_util
    alpha = myPars.alpha
    if sig == 1.0:
        return alpha*np.log(c) + (1-alpha)*np.log(leis)
    return (1/(1-sig)) * ((c**alpha) * (leis**(1-alpha))) ** (1-sig)

#derivative of utility function with respect to consumption given consumption and leisure
@njit
def util_c_giv_leis(myPars: Pars, c: float, leis: float) -> float:
    """
    Marginal utility of consumption du/dc at (c, leis):
        alpha * c**(alpha-1) * leis**(1-alpha) / (c**alpha * leis**(1-alpha))**sigma
    Divides by zero when c or leis is 0.
    """
    a = myPars.alpha
    s = myPars.sigma_util
    numer = a*c**(a - 1)*leis**(1 - a)
    bundle = c**a*leis**(1 - a)
    return numer/bundle**s

#derivative of utility function with respect to consumption given consumption and wage
@njit
def util_c(myPars: Pars, c: float, wage: float) -> float:
    """
    Marginal utility of consumption when leisure is pinned down by the static FOC.

    Leisure is not a free argument here: it is set to leis_giv_c(c, wage), so the
    result depends only on c and the wage.
    """
    return util_c_giv_leis(myPars, c, leis_giv_c(myPars, c, wage))

@njit
def mb_lab(myPars: Pars, c: float, wage: float, labor: float, health: float) -> float:
    """
    Marginal benefit of labor: wage * du/dc.

    du/dc is evaluated at consumption c and at the leisure remaining after `labor`
    at this health level (leis_giv_lab).
    """
    leis_left = leis_giv_lab(myPars, labor, health)
    marg_u_c = util_c_giv_leis(myPars, c, leis_left)
    return wage * marg_u_c

@njit
def mc_lab(myPars: Pars, c: float, labor: float, health: float) -> float:
    """
    Marginal cost of labor: phi_n * du/dleis.

    du/dleis is evaluated at consumption c and at the leisure remaining after `labor`
    at this health level (leis_giv_lab).
    """
    leis_left = leis_giv_lab(myPars, labor, health)
    marg_u_leis = util_leis_giv_c(myPars, leis_left, c)
    return myPars.phi_n * marg_u_leis

@njit
def util_leis_giv_c(myPars: Pars, leis: float, c: float) -> float:
    """
    Marginal utility of leisure du/dleis at (leis, c):
        (1-alpha) * c**alpha * leis**(-alpha) / (c**alpha * leis**(1-alpha))**sigma
    Note the argument order: leisure first, then consumption.
    """
    a = myPars.alpha
    numer = (1-a)*c**(a)*leis**(-a)
    bundle = c**a*leis**(1-a)
    return numer/bundle**myPars.sigma_util

#inverse of the derivative of the utility function with respect to consumption
@njit
def util_c_inv(myPars: Pars, u: float, wage: float) ->float:
    """
    Inverse of util_c in c: the consumption whose marginal utility equals u at this wage.

    Substituting leis = k * c with k = phi_n (1-alpha) / (wage alpha) gives
        u = alpha * k**((1-alpha)(1-sigma)) * c**(-sigma),
    which is solved here for c.
    """
    a = myPars.alpha
    s = myPars.sigma_util

    # wage > 0 keeps this denominator non-zero.
    leis_per_c = (myPars.phi_n * (1 - a)) / (wage* a)
    # Equals -(1 - alpha)(1 - sigma).
    k_exponent = a*(-s) + a + s - 1

    scaled = (u*leis_per_c**k_exponent) / a
    return scaled**(-1/s)

# infer what current consumption should be given future consumption, curr wage, and the curr state space
@njit
def infer_c(myPars: Pars, curr_wage: float, age: int, lab_fe_ind: int, health_ind: int, nu_ind: int, c_prime: float ) -> float:
    """
    Current consumption implied by next-period consumption via the Euler equation.

        u_c(c; curr_wage) = beta * (1 + r) * u_c(c_prime; next-period wage)

    With no uncertainty the expectation on the right is just u_c(c_prime). The result
    is floored at c_min.
    """
    next_wage = wage(myPars, age+1, lab_fe_ind, health_ind, nu_ind)
    rhs = myPars.beta *(1 + myPars.r) * util_c(myPars, c_prime, next_wage)
    return max(myPars.c_min, util_c_inv(myPars, rhs, curr_wage))

# given current choice of c and a_prime, as well as the state's wage and health, solve for labor and current assets
@njit
def solve_lab_a(myPars: Pars, c: float, a_prime: float,  curr_wage: float, health_ind: int) -> tuple:
    """
    Solve for labor and current assets given consumption c and next-period assets a_prime.

    Steps:
      1. Leisure comes from the static FOC (leis_giv_c), clipped to [leis_min, leis_max].
      2. Labor comes from the time constraint at health H_grid[health_ind], clipped to
         [lab_min, lab_max].
      3. Current assets follow from the budget constraint
         c + a_prime = (1 + r) * a + curr_wage * lab.

    Args:
        health_ind: index into myPars.H_grid (an integer, not the health value).

    Returns:
        (lab, a) as a tuple of floats.
    """
    leis = leis_giv_c(myPars, c, curr_wage)
    leis = max(myPars.leis_min, min(myPars.leis_max, leis))

    health = myPars.H_grid[health_ind]
    lab = lab_giv_leis(myPars, leis, health)
    lab = max(myPars.lab_min, min(myPars.lab_max, lab))

    # a is not clipped to [a_min, a_max].
    a = (c + a_prime - curr_wage*lab)/(1 + myPars.r)
    return lab, a

@njit
def invert_lab (myPars : Pars, c: float, curr_wage: float, health: float) -> float:
    """
    Labor implied by consumption c through the labor FOC phi_n * u_leis = wage * u_c.

    Computes the target u_leis = (wage / phi_n) * u_c, inverts it for leisure
    (util_leis_inv), then converts leisure to labor with the time constraint.
    The result is not clipped to [lab_min, lab_max].
    """
    target_u_leis = (curr_wage/myPars.phi_n) * util_c(myPars, c, curr_wage)
    implied_leis = util_leis_inv(myPars, target_u_leis, c)
    return lab_giv_leis(myPars, implied_leis, health)

@njit
def util_leis_inv(myPars: Pars, u: float, c: float) -> float:
    """
    Invert the marginal utility of leisure for leisure, holding consumption c fixed.

    From util_leis_giv_c, u = (1-alpha) * c**(alpha(1-sigma)) * leis**(alpha*sigma - alpha - sigma), so
        leis = (u * c**(-alpha(1-sigma)) / (1-alpha)) ** (1 / (alpha*sigma - alpha - sigma)).
    """
    alpha = myPars.alpha
    sigma = myPars.sigma_util

    out_exp = 1 / (alpha*sigma - alpha - sigma)
    inside = (u * c ** (-alpha * (1-sigma)))/(1 - alpha)
    return inside ** out_exp
    


# return the optimal labor decision given an asset choice a_prime and a current asset level, health status, and wage
# @njit
# def lab_star(myPars: Pars, a_prime: float, a: float, health: float, wage: float)-> float:
#     lab =  ((myPars.alpha/myPars.phi_n)*(1 - myPars.phi_H*(1-health))
#             + ((myPars.alpha - 1)/wage)*((1 + myPars.r)*a - a_prime))
#     lab = min(myPars.lab_max, lab)
#     return max(myPars.lab_min, lab)

@njit
def lab_star(myPars: Pars, a_prime: float, a: float, health: float, wage: float)-> float:
    """
    Optimal labor given next-period assets a_prime, current assets a, health and wage.

    The formula combines three conditions:
        static FOC:        leis = phi_n (1 - alpha) / (wage alpha) * c
        time constraint:   leis = 1 - phi_n lab - phi_H (1 - health)
        budget constraint: c = wage lab + (1 + r) a - a_prime
    Solving them gives
        lab = (alpha / phi_n) (1 - phi_H (1 - health)) - ((1 - alpha) / wage) ((1 + r) a - a_prime),
    clipped to [lab_min, lab_max]. This is the labor choice that matches c_star.
    """
    # Resources from assets net of saving.
    net_asset_income = (1 + myPars.r)*a - a_prime
    lab = ((myPars.alpha/myPars.phi_n)*(1 - myPars.phi_H*(1.0-health))
           - ((1 - myPars.alpha)/wage)*net_asset_income)
    return max(myPars.lab_min, min(myPars.lab_max, lab))

@njit
def c_star(myPars: Pars, a_prime: float, a: float, health: float, wage: float) -> float:
    """
    Optimal consumption given an asset choice a_prime, current assets a, health and wage.

    c = alpha * [ (wage / phi_n) (1 - phi_H (1 - health)) + (1 + r) a - a_prime ],
    i.e. a share alpha of full income, floored at c_min.
    """
    full_income = (wage/myPars.phi_n)*(1-myPars.phi_H*(1.0-health)) + (1 + myPars.r)*a - a_prime
    return max(myPars.c_min, myPars.alpha*full_income)

#calculate deterministic part of the wage given health and age
@njit
def det_wage(myPars: Pars, health: float, age: int) -> float:
    """
    Deterministic part of the wage process.

    exp(w_determ_cons + cubic in age + health level and health-by-age terms).
    """
    j = age
    age_poly = myPars.w_age*j + myPars.w_age_2*j**2 + myPars.w_age_3*j**3
    health_terms = myPars.w_good_health*health + myPars.w_good_health_age*health*j
    log_wage = myPars.w_determ_cons + age_poly + health_terms
    return np.exp(log_wage)

#calculate the wage given health, age, lab_fe, and nu i.e. the shocks
@njit
def wage(myPars: Pars,  age: int, lab_fe_ind: int, h_ind: int,  nu_ind: int) -> float:
    """
    Wage for labor fixed-effect type lab_fe_ind at the given age.

    The wage is tb.cubic evaluated at age with coefficients myPars.wage_coeff_grid[lab_fe_ind],
    floored at 0.0001 so it stays strictly positive (callers divide by it). h_ind and
    nu_ind are currently unused; they keep the signature aligned with the state indices.
    """
    coeffs = myPars.wage_coeff_grid[lab_fe_ind]
    return max(0.0001, tb.cubic(age, coeffs))
@njit
def gen_wages(myPars: Pars) -> np.ndarray:
    """
    Tabulate wage(...) on the full state grid.

    Returns:
        array of shape (lab_FE_grid_size, H_grid_size, nu_grid_size, J).
    """
    wage_grid = np.zeros((myPars.lab_FE_grid_size, myPars.H_grid_size, myPars.nu_grid_size, myPars.J))
    # Each entry is filled independently, so the loop order is free; iterate in array order.
    for lab_fe_ind in range(myPars.lab_FE_grid_size):
        for h_ind in range(myPars.H_grid_size):
            for nu_ind in range(myPars.nu_grid_size):
                for j in range(myPars.J):
                    wage_grid[lab_fe_ind, h_ind, nu_ind, j] = wage(myPars, j, lab_fe_ind, h_ind, nu_ind)
    return wage_grid

@njit
def recover_wage(myPars: Pars, c: float, lab: float, a_prime: float, a: float) -> float:
    """
    Back out the wage from the budget constraint c + a_prime = (1 + r) * a + wage * lab.

    Divides by zero when lab == 0.
    """
    labor_income = c + a_prime - (1 + myPars.r)*a
    return labor_income / lab


In [23]:
# Marginal benefit of labor at the consumption floor, with zero labor and good health.
# Named test_wage (not wage) so this cell does not overwrite the notebook-level wage()
# function, which infer_c and gen_wages call.
test_wage = 10
c = myPars.c_min
lab = 0.0
health = 1.0

leis = leis_giv_lab(myPars, lab, health)
print(f"leisure given labor = {leis}")
util = util_c_giv_leis(myPars, c, leis)
print(f"utility given consumption and leisure = {util}")
print(f"wage * util = {test_wage * util}")
my_mb_lab = mb_lab(myPars, c, test_wage, lab, health)
print(f"marginal benefit of labor = {my_mb_lab}")

leisure given labor = 1.0
utility given consumption and leisure = 2786750193.874479
wage * util = 27867501938.74479
marginal benefit of labor = 27867501938.74479
