In [None]:
import math
import warnings

import matplotlib.pyplot as plt
import numpy as np

class Firm:
    """Representative urban firm of the city production model.

    Each call to :meth:`step` derives the workforce per firm from the urban
    population ``N`` and the firm count ``F``, evaluates output and marginal
    products, moves ``F`` part-way toward a target, and moves the wage
    premium part-way toward a target set by transport cost and density.
    """

    @property
    def wage(self):
        """Total urban wage: wage premium plus the model's subsistence wage."""
        return self.wage_premium + self.model.subsistence_wage

    def __init__(self, model,
                 init_wage_premium,
                 alpha_F, beta_F, Z,
                 price_of_output, cost_of_capital,
                 firm_adjustment_parameter,
                 wage_adjustment_parameter,
                 initial_population,
                 *,
                 initial_firm_count=5,
                 A_F=3500):
        """Initialise the firm.

        Args:
            model: Owning model object. This class reads its
                ``subsistence_wage``, ``workforce_rural_firm``, ``gamma``,
                ``transport_cost_per_dist`` and ``density`` attributes.
            init_wage_premium: Initial wage premium over subsistence (omega).
            alpha_F: Exponent on capital ``k`` in :meth:`output`.
            beta_F: Exponent on workforce per firm ``n`` in :meth:`output`.
            Z: Stored as ``self.Z``; not read by this class.
            price_of_output: Stored; not read by this class.
            cost_of_capital: Rental rate of capital ``r``.
            firm_adjustment_parameter: Weight on the target when updating ``F``.
            wage_adjustment_parameter: Weight on the target when updating
                the wage premium.
            initial_population: Initial urban population ``N``.
            initial_firm_count: Initial number of firms ``F`` (keyword-only;
                defaults to the previously hard-coded 5).
            A_F: Output scale factor (keyword-only; defaults to the
                previously hard-coded 3500).
        """
        # TODO: N lives on the firm temporarily so population growth can be
        # driven separately by the caller.
        self.N               = initial_population
        self.model           = model
        self.wage_premium    = init_wage_premium  # omega
        self.alpha_F         = alpha_F
        self.beta_F          = beta_F
        self.Z               = Z

        self.price_of_output = price_of_output
        self.r               = cost_of_capital

        self.firm_adjustment_parameter = firm_adjustment_parameter
        self.wage_adjustment_parameter = wage_adjustment_parameter

        # The urban firm's workforce initially equals the rural firm's.
        self.n = self.model.workforce_rural_firm

        # TODO: decide how F should be initialised (e.g. from a baseline
        # population divided by workforce_rural_firm).
        self.F = initial_firm_count

        # Initial capital: output is backed out of the wage bill
        # (Y_U = n * wage / beta_F), then k = alpha_F * Y_U / r, the same
        # rule step() uses for k_target.
        Y_U    = self.n * self.wage / beta_F
        self.k = alpha_F * Y_U / self.r
        # A calibrated alternative would be
        # Y_R / (k_R**alpha_F * n_R * psi**beta_F), with Y_R = n_R * psi / beta_F
        # and k_R = alpha_F * Y_R / r.
        self.A_F = A_F

    def step(self):
        """Advance the firm by one period.

        Reads ``self.N`` (maintained by the caller). Sets ``n``, ``y``,
        ``MPL``, ``MPK``, ``n_target``, ``y_target``, ``k_target``,
        ``F_target``, ``F_next``, ``N_target_total``, ``F_next_total`` and
        ``wage_premium_target``, then updates ``wage_premium``, ``k`` and ``F``.
        """
        # Current state: workforce per firm, output, marginal products.
        self.n = self.N/self.F
        self.y = self.output(self.N, self.k, self.n)

        self.MPL = self.beta_F  * self.y / self.n
        self.MPK = self.alpha_F * self.y / self.k

        # Target workforce where beta_F * y / n equals the wage (y held at its
        # current value), and the output and capital that go with it.
        self.n_target = self.beta_F * self.y / self.wage
        self.y_target = self.output(self.N, self.k, self.n_target)
        self.k_target = self.alpha_F * self.y_target / self.r

        # Partial adjustment of the firm count toward its target.
        adj_f = self.firm_adjustment_parameter
        self.F_target = self.n_target/self.n * self.F
        self.F_next = (1 - adj_f) * self.F + adj_f * self.F_target
        self.N_target_total = self.F_next * self.n_target
        # Round-trips to F_next (up to float rounding); kept as a separate
        # attribute for callers that inspect it.
        self.F_next_total = self.N_target_total / self.n_target

        # Wage premium implied by commuting cost to the edge of a city that
        # houses N_target_total people at the model's density.
        c = self.model.transport_cost_per_dist
        self.wage_premium_target = c * math.sqrt(self.N_target_total/(2*self.model.density))

        # Partial adjustment of the wage premium toward its target.
        adj_w = self.wage_adjustment_parameter
        self.wage_premium = (1-adj_w) * self.wage_premium + adj_w * self.wage_premium_target

        # TODO: capital currently jumps straight to its target.
        self.k = self.k_target
        self.F = self.F_next_total

    def output(self, N, k, n):
        """Return firm output ``A_F * N**gamma * k**alpha_F * n**beta_F``.

        Args:
            N: Urban population (agglomeration term, exponent ``model.gamma``).
            k: Capital per firm.
            n: Workforce per firm.
        """
        A_F     = self.A_F
        alpha_F = self.alpha_F
        beta_F  = self.beta_F
        gamma   = self.model.gamma

        return A_F * N**gamma * k**alpha_F * n**beta_F
class City:
    """Top-level model: parameters, derived constants, one Firm, and time series."""

    def __init__(self, **parameters):
        """Create a city from the default parameters overridden by keywords.

        All keywords are merged into ``self.params``. Names that are not among
        the defaults are still stored, but they trigger a ``UserWarning``
        because nothing in the model reads them (typically a typo in a sweep).
        """
        default_parameters = {
            'run_notes': 'Debugging model.',
            'subfolder': None,
            'width': 50,
            'height': 1,
            'init_city_extent': 10.,  # f CUT OR CHANGE?
            'seed_population': 10,
            'density': 300,
            'subsistence_wage': 40000.,  # psi
            'init_wage_premium_ratio': 0.2,
            'workforce_rural_firm': 100,
            'price_of_output': 1.,  # TODO CUT?
            'alpha_F': 0.18,
            'beta_F': 0.72,  # was lambda: workers' share of agglomeration surplus
            'beta_city': 1.12,
            'gamma': 0.02,  # FIX value
            'Z': 0.5,  # CUT? Scales new entrants
            'firm_adjustment_parameter': 0.25,
            'wage_adjustment_parameter': 0.5,
            'mortgage_period': 5.0,  # T, in years
            'working_periods': 40,  # in years
            'savings_rate': 0.3,
            'r_prime': 0.05,  # 0.03
            'discount_rate': 0.07,  # 1/delta
            'r_margin': 0.01,
            'property_tax_rate': 0.04,  # tau, annual rate, was c
            'housing_services_share': 0.3,  # a
            'maintenance_share': 0.2,  # b
            'max_mortgage_share': 0.9,
            'ability_to_carry_mortgage': 0.28,
            'wealth_sensitivity': 0.1,
            'initial_population': 1000,
            # Placeholder population dynamics: added to firm.N on every step.
            'population_growth_per_step': 100,
        }

        unknown = sorted(set(parameters) - set(default_parameters))
        if unknown:
            warnings.warn(
                f"Unrecognised City parameter(s) {unknown}; "
                "they are stored in params but not used by the model.",
                stacklevel=2,
            )
        # **parameters is always a dict (possibly empty), so a plain merge suffices.
        self.params = {**default_parameters, **parameters}

        self.time_step = 1.
        self.seed_population         = self.params['seed_population']
        self.density                 = self.params['density']  # coarse-grained population
        # c: chosen so that the initial wage premium equals c * init_city_extent.
        self.transport_cost_per_dist = (self.params['init_wage_premium_ratio']
                                        * self.params['subsistence_wage']
                                        / self.params['init_city_extent'])

        # People
        self.working_periods  = self.params['working_periods']
        self.savings_per_step = self.params['subsistence_wage'] * self.params['savings_rate']

        # Production model
        self.subsistence_wage     = self.params['subsistence_wage']  # psi
        self.workforce_rural_firm = self.params['workforce_rural_firm']
        self.gamma                = self.params['gamma']
        self.beta_city            = self.params['beta_city']
        self.workers_share        = self.params['beta_F']  # lambda

        # Housing market model
        self.mortgage_period        = self.params['mortgage_period']
        self.housing_services_share = self.params['housing_services_share']  # a
        self.maintenance_share      = self.params['maintenance_share']  # b
        self.r_prime  = self.params['r_prime']
        self.r_margin = self.params['r_margin']
        self.delta    = 1/self.params['discount_rate']  # TODO: guard discount_rate == 0
        self.max_mortgage_share        = self.params['max_mortgage_share']
        self.ability_to_carry_mortgage = self.params['ability_to_carry_mortgage']
        self.wealth_sensitivity        = self.params['wealth_sensitivity']

        self.population_growth_per_step = self.params['population_growth_per_step']

        init_wage_premium    = self.params['init_wage_premium_ratio'] * self.params['subsistence_wage']  # omega
        firm_cost_of_capital = self.params['r_prime']

        # Initial values, kept for plotting
        self.initial_population = self.params['initial_population']
        self.initial_wage       = init_wage_premium + self.params['subsistence_wage']

        # Time series appended to once per step()
        self.wage_premium_list = []
        self.wage_list         = []
        self.population_list   = []

        # Single representative firm
        self.firm = Firm(self,
                         init_wage_premium,
                         self.params['alpha_F'], self.params['beta_F'], self.params['Z'],
                         self.params['price_of_output'], firm_cost_of_capital,
                         self.params['firm_adjustment_parameter'],
                         self.params['wage_adjustment_parameter'],
                         self.params['initial_population'])

    def step(self):
        """Advance one period: step the firm, grow population, record series."""
        self.firm.step()
        # TODO: replace the fixed increment with real population dynamics.
        self.firm.N += self.population_growth_per_step

        self.wage_list.append(self.firm.wage_premium + self.subsistence_wage)
        self.wage_premium_list.append(self.firm.wage_premium)
        self.population_list.append(self.firm.N)
        
# Quick smoke run: build a city with default parameters and advance it 5 steps.
city = City()
for _ in range(5):  # loop index unused
    city.step()


In [None]:
    # Parameter sweep: run the model on a grid of two parameters and plot the
    # population and wage trajectories for every combination.
    a_name   = 'beta_F'
    a_values = np.linspace(.9, .7, 4)    # start, stop, number of values

    # NOTE(review): the default gamma is 0.02, but this sweep explores 1.0-1.4.
    # Confirm that is intended (beta_city defaults to 1.12).
    b_name   = 'gamma'
    b_values = np.linspace(1.0, 1.4, 4)  # start, stop, number of values

    time_steps = 50
    t = np.arange(time_steps)
    # Last axis of data_matrix holds one channel per recorded series.
    POP, WAGE = 0, 1
    data_matrix = np.zeros((len(a_values), len(b_values), time_steps, 2))

    # Run the model for every parameter combination.
    for i, a in enumerate(a_values):
        for j, b in enumerate(b_values):
            city = City(**{a_name: a, b_name: b})
            for _ in range(time_steps):
                city.step()
            data_matrix[i, j, :, POP]  = city.population_list
            data_matrix[i, j, :, WAGE] = city.wage_list

    # squeeze=False keeps axs 2-D even when a sweep has a single value.
    fig, axs = plt.subplots(len(a_values), len(b_values), figsize=(15, 10), squeeze=False)

    for i, a in enumerate(a_values):
        for j, b in enumerate(b_values):
            ax = axs[i, j]
            # Population and wage differ by orders of magnitude, so the wage
            # gets its own y-axis instead of flattening the population curve.
            pop_line, = ax.plot(t, data_matrix[i, j, :, POP], color='tab:blue', label='population')
            ax_wage = ax.twinx()
            wage_line, = ax_wage.plot(t, data_matrix[i, j, :, WAGE], color='tab:orange', label='wage')
            ax.set_xlabel('Time')
            ax.set_ylabel('Population')
            ax_wage.set_ylabel('Wage')
            ax.set_title(f'{a_name}={a:.2f}, {b_name}={b:.2f}')
            ax.legend(handles=[pop_line, wage_line])

    plt.tight_layout()
    # To save the figure, call fig.savefig(path) here, before plt.show().
    plt.show()