In [None]:
import numpy as np

def rhsProductivities(A, L, w, d, R, D, sigma, vbar):
    """
    Model-implied labor income for each region, given LINEAR productivities A.

    Computes, for every region i,

      labIncomeHat[i] = sum_n [
        L[i]*(d[n,i]*w[i]/A[i])^(1-sigma)
        / sum_k [ L[k]*(d[n,k]*w[k]/A[k])^(1-sigma) ]
      ] * [vbar[n]*R[n] + D[n]].

    Parameters
    ----------
    A : array-like of shape (N,)
        Productivity A[i] of each region, entering as w[i]/A[i]. This is the
        linear A, not A^(sigma-1). Must be > 0.
    L, w : array-like of shape (N,)
        Employment and wages for each region i.
    d : array-like of shape (N, N)
        d[n,i] = distance factor between n and i; row n is the region summed
        over, column i the region whose income is computed.
    R, D, vbar : array-like of shape (N,)
        R[n], D[n], vbar[n] for region n.
    sigma : float
        Elasticity (> 1).

    Returns
    -------
    labIncomeHat : ndarray of shape (N,)
        sum over n of share_{n,i} * (vbar[n]*R[n] + D[n]). Since the shares in
        each row n sum to one, sum(labIncomeHat) equals sum(vbar*R + D).
    """
    A = np.asarray(A, dtype=float)
    L = np.asarray(L, dtype=float)
    w = np.asarray(w, dtype=float)
    d = np.asarray(d, dtype=float)

    # M[n, i] = L[i] * (d[n, i] * w[i] / A[i])^(1 - sigma); the (N,) vectors
    # broadcast along the last axis, i.e. across columns i.
    M = L * np.power(d * (w / A), 1.0 - sigma)

    # denom[n] = sum_k M[n, k]: normalizer that turns row n into shares.
    denom = M.sum(axis=1)

    # X[n] = vbar[n]*R[n] + D[n]: the amount row n distributes across regions i.
    X = (np.asarray(vbar, dtype=float) * np.asarray(R, dtype=float)
         + np.asarray(D, dtype=float))

    # labIncomeHat[i] = sum_n M[n, i] / denom[n] * X[n] = (M^T @ (X / denom))[i];
    # a matrix-vector product avoids building another N x N temporary.
    return M.T @ (X / denom)


def productivities(
    workemp,      # L_i
    workwage,     # w_i
    resExp,       # mapping with keys "R", "D", "vbar"
    distmatrix,   # NxN
    sigma,        # elasticity
    psi,          # distance exponent
    detailsYN=True,
    max_iter=10000,
    tol=1e-6,
    lam=0.9990,
):
    """
    Solve for regional productivities by damped tâtonnement.

    Finds A such that, for every region i,
      w_i L_i == sum_n share_{n,i} * [vbar[n]*R[n] + D[n]],
    with share_{n,i} = M[n,i] / sum_k M[n,k],
    M[n,i] = L_i * (d_{n,i} * w_i / A_i)^(1-sigma) and d = distmatrix**psi.

    The iteration runs on a_i = A_i^(sigma-1), which enters M linearly:
      a_i <- [lam + (1-lam) * gap_i] * a_i,  gap_i = (w_i L_i) / labIncomeHat_i,
    and a is renormalized to mean 1 after every step (the shares are invariant
    to a common rescaling of A, so A is only identified up to scale).

    Parameters
    ----------
    workemp, workwage : array-like of shape (N,)
        Employment L_i and wage w_i per region, used as given (no rescaling).
    resExp : mapping
        Must provide "R", "D" and "vbar", each of shape (N,).
    distmatrix : array-like of shape (N, N)
        Raw distances; raised to the power psi.
    sigma : float
        Elasticity; must be > 1.
    psi : float
        Distance exponent.
    detailsYN : bool
        Print progress diagnostics.
    max_iter : int
        Maximum number of iterations.
    tol : float
        Stop once max_i |gap_i - 1| <= tol.
    lam : float
        Weight on the previous guess in the damped update (0 <= lam < 1).
        Larger values take smaller, more cautious steps. Defaults to 0.999.

    Returns
    -------
    A_final : ndarray of shape (N,)
        Linear productivities A_i = a_i^(1/(sigma-1)), normalized to mean 1.

    Raises
    ------
    ValueError
        If sigma <= 1.
    FloatingPointError
        If the income gap becomes non-finite (e.g. a region whose
        model-implied income is zero).

    Warns
    -----
    RuntimeWarning
        If tol is not reached within max_iter iterations.
    """
    import warnings

    if sigma <= 1:
        raise ValueError(f"sigma must be > 1, got {sigma}")

    L = np.asarray(workemp, dtype=float)
    w = np.asarray(workwage, dtype=float)
    if L.size == 0:
        return np.empty(0)

    d = np.asarray(distmatrix, dtype=float) ** psi

    R = np.asarray(resExp["R"], dtype=float)
    D = np.asarray(resExp["D"], dtype=float)
    vbar = np.asarray(resExp["vbar"], dtype=float)

    # Observed labor income w_i L_i, in whatever units the inputs use.
    labIncome = L * w

    # Maps a = A^(sigma-1) back to linear A.
    exponent = 1.0 / (sigma - 1.0)

    # a = A^(sigma-1), initialized to 1 for every region.
    A0 = np.ones(L.size)

    error = np.inf    # max_i |gap_i - 1|
    distmin = np.inf  # min_i |gap_i - 1|
    c = 0

    while error > tol and c < max_iter:
        c += 1

        # rhsProductivities expects LINEAR A, so convert a -> a^(1/(sigma-1)).
        # Passing a directly would make income respond to a^(sigma-1) and the
        # final exponentiation below would then be wrong.
        labIncomeHat = rhsProductivities(A0 ** exponent, L, w, d, R, D,
                                         sigma, vbar)

        gap = labIncome / labIncomeHat
        dev = np.abs(gap - 1.0)
        error = np.max(dev)
        distmin = np.min(dev)
        if not np.isfinite(error):
            # A NaN would otherwise end the loop silently (nan > tol is False).
            raise FloatingPointError(
                f"non-finite income gap at iteration {c}; check for regions "
                f"with zero employment, wages or model-implied income"
            )

        # Damped multiplicative update, then renormalize to mean 1.
        A1 = (lam + (1.0 - lam) * gap) * A0
        A0 = A1 / np.mean(A1)

        if detailsYN and c % 50 == 1 and error > 0.01:
            print(f"Iteration {c}, max gap={error:8.5f}, min gap={distmin:8.5f}, "
                  f"A0 in [{A0.min():.3e}, {A0.max():.3e}]")

    if detailsYN:
        print(f"End of iteration {c}, max gap={error:8.5f}, "
              f"A0 in [{A0.min():.3e}, {A0.max():.3e}]")

    if error > tol:
        warnings.warn(
            f"productivities did not converge in {c} iterations "
            f"(max gap={error:.3e} > tol={tol:.3e})",
            RuntimeWarning,
            stacklevel=2,
        )

    # Back to linear A, normalized to mean 1.
    A_linear = A0 ** exponent
    return A_linear / np.mean(A_linear)


# ------------------------------------------------------------------------
# EXAMPLE USAGE (dummy example)
if __name__ == "__main__":
    # Demo run on random, seeded data. Names are kept at module level so that
    # later notebook cells can keep referring to them.
    sigma, psi = 4.0, 0.43
    ncounties = 3000

    np.random.seed(0)
    # The order of the three draws below fixes the generated data.
    workemp = 10000 * np.random.rand(ncounties)
    workwage = 100000 * np.random.rand(ncounties)
    distm = np.random.rand(ncounties, ncounties)

    # Residence-side inputs: R mirrors employment, vbar mirrors wages,
    # and D is zero.
    R, vbar = workemp, workwage
    D = np.zeros(ncounties)
    resExp = dict(R=R, D=D, vbar=vbar)

    solver_kwargs = dict(
        sigma=sigma,
        psi=psi,
        detailsYN=True,
        max_iter=20000,
        tol=1e-8,
    )
    A_solution = productivities(
        workemp=workemp,
        workwage=workwage,
        resExp=resExp,
        distmatrix=distm,
        **solver_kwargs,
    )
    print("\nFinal productivities A:", A_solution)


lab income
[2.27188208e+08 4.50296315e+08 4.69302080e+08 ... 2.25404351e+08
 3.93508971e+08 2.36802808e+08]
total labor income true
734281281374.0294
lab income hat
[284.08561847  42.41169458  91.05546707 ...  14.29385875 756.91387833
 228.00937784]
total lab income hat
734281281374.0298
mean A0
1.0000000000000002
A0
[0.10709362 1.42016402 0.6894702  ... 2.1092361  0.06966695 0.13903898]
Iteration 1, max gap=95905331.97540, min gap= 0.01119, A0 in [1.336e-04, 1.283e+01]
lab income
[2.27188208e+08 4.50296315e+08 4.69302080e+08 ... 2.25404351e+08
 3.93508971e+08 2.36802808e+08]
total labor income true
734281281374.0294
lab income hat
[3.58137530e+05 2.60795134e+08 5.02870980e+07 ... 1.72232239e+08
 4.45477303e+05 8.41496323e+05]
total lab income hat
734281281374.0295
mean A0
1.0
A0
[0.08662393 0.7037952  0.34427983 ... 1.04484433 0.06494081 0.08816108]
lab income
[2.27188208e+08 4.50296315e+08 4.69302080e+08 ... 2.25404351e+08
 3.93508971e+08 2.36802808e+08]
total labor income true
73428