In [None]:
def gwl_Wsto_vectorized_2(z, pF, Ksat=None, root=False):
    r""" Forms, for every soil column (row), interpolated functions between
    ground water depth, < 0 [m], and water storage [m], storage capacity and
    transmissivity.

    Args:
        - z (np.ndarray): depth of the lower boundary of each soil layer, one
          row per soil column (2-D). The layer thicknesses are derived from
          it as differences of abs(z); rows may be NaN-padded (nan-aware
          minima are used). Presumably negative downwards [m].
        - pF (sequence of dict): one dict of water retention parameters per
          soil column
            - dict
                - 'ThetaS' (np.ndarray): saturated water content [m\ :sup:`3` m\ :sup:`-3`\ ]
                - 'ThetaR' (np.ndarray): residual water content [m\ :sup:`3` m\ :sup:`-3`\ ]
                - 'alpha' (np.ndarray): air entry suction [cm\ :sup:`-1`]
                - 'n' (np.ndarray): pore size distribution [-]
        - Ksat (np.ndarray, optional): saturated hydraulic conductivity passed
          on to ``transmissivity_vectorized``. When None, the module-level
          ``deep_ksats`` is used, which is what this function always read
          before it honoured the argument.
        - root (bool): currently unused (the root-zone branch is disabled).
    Returns:
        - (dict): each value is a list with one ``interp1d`` per soil column
            - 'to_gwl': interpolated function for gwl(Wsto)
            - 'to_wsto': interpolated function for Wsto(gwl)
            - 'to_C': interpolated function for C(gwl) = dWsto/dgwl
            - 'to_Tr': interpolated function for Tr(gwl), transmissivity
              multiplied by 86400
    """
    step = -0.05  # vertical resolution of the fine grid, negative = downwards

    # Layer thicknesses: the first layer reaches from the surface to z[:, 0]
    depth = abs(z)
    dz = np.hstack((depth[:, :1], np.diff(depth, axis=1)))

    # Fine grid shared by all columns, long enough for the deepest column
    z_min = np.nanmin(z, axis=1)
    deepest = np.nanmin(z_min)
    max_len = int(abs(deepest) / abs(step)) + 1
    z_fine = np.tile(np.arange(0, step * max_len, step), (z.shape[0], 1)) + step
    dz_fine = np.full_like(z_fine, -step)
    z_mid_fine = dz_fine / 2 - np.cumsum(dz_fine, axis=1)

    # Index of the original layer that each fine node belongs to: the number
    # of layer boundaries lying strictly above the node. isclose keeps
    # floating point noise from shifting nodes that sit on a boundary.
    z_expanded = np.expand_dims(z, axis=1)            # (rows, 1, layers)
    z_fine_expanded = np.expand_dims(z_fine, axis=2)  # (rows, fine_steps, 1)
    below = (z_fine_expanded < z_expanded) & ~np.isclose(z_fine_expanded, z_expanded, atol=1e-9)
    ix = np.sum(below, axis=2)

    pF_fine = {}
    for key in pF[0]:
        try:
            pF_array = np.vstack([p[key] for p in pF])  # (rows, layers)
        except ValueError:
            # Rows of different length: pad the shorter ones with NaN
            max_depth = max(len(p[key]) for p in pF)
            pF_array = np.full((len(pF), max_depth), np.nan)
            for i, p in enumerate(pF):
                pF_array[i, :len(p[key])] = p[key]

        # Nodes below the last boundary reuse the parameters of the last layer
        ix_valid = np.clip(ix, 0, pF_array.shape[1] - 1)
        pF_fine[key] = np.take_along_axis(pF_array, ix_valid, axis=1)  # (rows, fine_steps)

    # --------- connection between gwl and Wsto, Tr, C------------
    gwl = np.arange(1.0, deepest - 5, step)

    # Water standing above the ground surface (g > 0) is added to the storage
    Wsto_deep = np.stack([h_to_cellmoist_vectorized(pF_fine, g - z_mid_fine, dz_fine) + max(0.0, g) for g in gwl]).T

    # Honour the Ksat argument; fall back to the global the code used to read
    ksat = deep_ksats if Ksat is None else Ksat
    Tr1 = np.stack([transmissivity_vectorized(dz, ksat, g) * 86400. for g in gwl]).T

    # One interpolator per row of Wsto_deep / Tr1, all sharing the same gwl axis
    dgwl = np.gradient(gwl)
    WstoToGwl = [interp1d(wsto_row, gwl, kind='linear', fill_value='extrapolate') for wsto_row in Wsto_deep]
    GwlToWsto = [interp1d(gwl, wsto_row, kind='linear', fill_value='extrapolate') for wsto_row in Wsto_deep]
    GwlToC = [interp1d(gwl, np.gradient(wsto_row) / dgwl, kind='linear', fill_value='extrapolate') for wsto_row in Wsto_deep]
    GwlToTr = [interp1d(gwl, tr_row, kind='linear', fill_value='extrapolate') for tr_row in Tr1]

    return {'to_gwl': WstoToGwl, 'to_wsto': GwlToWsto, 'to_C': GwlToC, 'to_Tr': GwlToTr}


def h_to_cellmoist_vectorized_2(pF, gwl, z_mid, dz):
    r""" Water storage of each soil column for a set of groundwater levels,
    based on the van Genuchten water retention curve.

    The pressure head of a layer is taken as ``gwl - z_mid``. Layers whose
    midpoint lies within half a layer thickness of the water table are
    treated as partly saturated: their water content is the thickness-weighted
    mean of a saturated part and an unsaturated part.

    Args:
        pF (dict): Soil parameters
            'ThetaS' (np.ndarray): saturated water content [m³/m³]
            'ThetaR' (np.ndarray): residual water content [m³/m³]
            'alpha' (np.ndarray): air entry suction [cm⁻¹]
            'n' (np.ndarray): pore size distribution [-]
        gwl (np.ndarray): Groundwater levels below surface, shape (n_gwl,)
        z_mid (np.ndarray): Depth of layer midpoints, shape (n_cells, n_layers)
        dz (np.ndarray): Layer thickness, shape (n_cells, n_layers)

    Returns:
        Wsto (np.ndarray): Water storage summed over the layers of each cell
            for each gwl, shape (n_cells, n_gwl)
    """
    theta_sat = pF['ThetaS']
    theta_res = pF['ThetaR']
    n_vg = pF['n']
    m_vg = 1.0 - 1.0 / n_vg
    # The factor 100 matches alpha given per cm with heads given in metres
    alpha_scaled = pF['alpha'] * 100

    def retention(head):
        # van Genuchten water content at the given pressure head
        return theta_res + (theta_sat - theta_res) / (1 + np.abs(alpha_scaled * head) ** n_vg) ** m_vg

    # (n_gwl, 1, 1) against (n_cells, n_layers) -> (n_gwl, n_cells, n_layers)
    head = gwl[:, np.newaxis, np.newaxis] - z_mid

    # Non-negative heads are clipped to zero, which yields saturation
    moisture = retention(np.minimum(head, 0))

    # Layers cut by the water table
    half = dz / 2
    straddles = np.abs(head) < half
    if straddles.any():
        dry_part = half - head   # thickness above the water table
        wet_part = half + head   # thickness below the water table
        # The unsaturated part is evaluated at the head of its own midpoint
        blended = (retention(-dry_part / 2) * dry_part + theta_sat * wet_part) / dz
        moisture[straddles] = blended[straddles]

    # Water content -> storage per layer, then total per column
    np.multiply(moisture, dz, out=moisture)
    return np.transpose(moisture.sum(axis=2))


def transmissivity_vectorized_2(dz, Ksat, gwl):
    r""" Vectorized transmissivity of the saturated zone for many soil columns
    and many groundwater levels at once.

    For every column the transmissivity is the sum over layers of
    ``Ksat * saturated thickness``. The saturated thickness of a layer is the
    part of that layer lying below the groundwater level, so it is limited to
    the layer's own thickness; without that limit deeper layers would also
    count the water column of the layers above them. Where no layer is
    saturated a small constant (1e-4 / 86400) is returned instead of zero.

    Args:
       dz (np.ndarray):  Soil compartment thickness, node in center [m], shape (n_cells, n_layers)
       Ksat (np.ndarray): Horizontal saturated hydraulic conductivity [m/s], shape (n_cells, n_layers)
       gwl (float or array-like): Groundwater level(s) below surface, < 0 [m];
           a scalar, a sequence or a 1-D array of length n_gwl

    Returns:
       Tr (np.ndarray): Transmissivity for each cell and groundwater level [m²/s], shape (n_cells, n_gwl)
    """
    dz = np.asarray(dz, dtype=float)
    Ksat = np.asarray(Ksat, dtype=float)
    # Scalars and lists are promoted to 1-D, then shaped (n_gwl, 1, 1)
    gwl = np.atleast_1d(np.asarray(gwl, dtype=float))[:, np.newaxis, np.newaxis]

    # Layer midpoints and boundaries, negative downwards, (n_cells, n_layers)
    z = dz / 2 - np.cumsum(dz, axis=1)
    layer_bottom = z - dz / 2
    layer_top = z + dz / 2

    # Depth of the impermeable bottom = total profile thickness, (1, n_cells, 1)
    ib = np.sum(dz, axis=1, keepdims=True)[np.newaxis, :, :]

    # Layers that reach below the water table and lie above the profile bottom
    mask = (layer_bottom - gwl < 0) & (layer_top > -ib)  # (n_gwl, n_cells, n_layers)
    valid_cells = np.any(mask, axis=2)                   # (n_gwl, n_cells)

    # Saturated thickness per layer, never more than the layer itself
    dz_sat = np.minimum(np.maximum(gwl - layer_bottom, 0), dz)

    # Deepest contributing layer: first True when scanning from the bottom up
    last_layer_ix = mask.shape[2] - 1 - np.argmax(mask[:, :, ::-1], axis=2)
    row_idx, col_idx = np.where(valid_cells)
    last_idx = last_layer_ix[row_idx, col_idx]

    # Correct the deepest contributing layer by the offset between its lower
    # boundary and the impermeable bottom (zero when they coincide)
    dz_sat[row_idx, col_idx, last_idx] += layer_bottom[col_idx, last_idx] + ib[0, col_idx, 0]

    Trans = Ksat * dz_sat  # (n_gwl, n_cells, n_layers)

    Tr = np.where(valid_cells, np.sum(Trans * mask, axis=2), 1e-4 / 86400)  # (n_gwl, n_cells)

    return Tr.T  # (n_cells, n_gwl)