In [None]:
from jaxtyping import Float
import numpy as np

from scipy.stats import rankdata, spearmanr

def steerability_slope(
    multipliers: Float[np.ndarray, "n_multipliers"],
    propensities: Float[np.ndarray, "batch n_multipliers"],
) -> Float[np.ndarray, "batch"]:
    """Least-squares slope of propensity versus multiplier, one value per row.

    Every row of ``propensities`` is fitted against the same ``multipliers``
    grid (the grid is assumed to be shared by all datasets).

    Args:
        multipliers: 1-D array of x values, length ``n_multipliers``.
        propensities: array of shape ``(batch, n_multipliers)``; each row is
            one dataset's y values.

    Returns:
        Array of shape ``(batch,)`` holding the degree-1 fit slope of each row.
    """
    # np.polyfit fits each *column* of a 2-D y independently, so transpose
    # to lay one dataset per column over the shared x grid.
    per_column_y = propensities.T
    coefficients = np.polyfit(multipliers, per_column_y, deg=1)
    # Coefficients are ordered highest power first: row 0 = slope, row 1 = intercept.
    return coefficients[0]

def steerabilty_spearman(
    multipliers: Float[np.ndarray, "n_multipliers"],
    propensities: Float[np.ndarray, "batch n_multipliers"],
) -> Float[np.ndarray, "batch"]:
    """Compute the Spearman rank correlation between multipliers and each row of propensities.

    Every row of ``propensities`` is correlated against the same ``multipliers``
    vector (the multiplier grid is assumed shared by all rows).

    Args:
        multipliers: 1-D array of length ``n_multipliers``.
        propensities: 2-D array of shape ``(batch, n_multipliers)``. A 1-D
            array is treated as a single row.

    Returns:
        1-D array of shape ``(batch,)`` with one Spearman rho per row. A row
        (or a ``multipliers`` vector) with no variation yields NaN, with a
        numpy RuntimeWarning.

    Raises:
        ValueError: if ``multipliers`` is not 1-D, ``propensities`` has more
            than two dimensions, or their ``n_multipliers`` lengths differ.
    """
    multipliers = np.asarray(multipliers)
    propensities = np.atleast_2d(np.asarray(propensities))
    if multipliers.ndim != 1:
        raise ValueError(f"multipliers must be 1-D, got shape {multipliers.shape}")
    if propensities.ndim != 2:
        raise ValueError(f"propensities must be 1-D or 2-D, got shape {propensities.shape}")
    if propensities.shape[1] != multipliers.shape[0]:
        raise ValueError(
            f"propensities has {propensities.shape[1]} columns but there are "
            f"{multipliers.shape[0]} multipliers"
        )

    # Spearman's rho is Pearson's r computed on (average-tie) ranks. Ranking
    # each row directly avoids spearmanr(axis=1), which builds the full
    # (batch + 1) x (batch + 1) correlation matrix and collapses it to a
    # scalar when batch == 1 (which made the old [0, 1:] indexing crash).
    ranked_multipliers = rankdata(multipliers)
    ranked_propensities = rankdata(propensities, axis=1)

    centered_multipliers = ranked_multipliers - ranked_multipliers.mean()
    centered_propensities = ranked_propensities - ranked_propensities.mean(axis=1, keepdims=True)

    covariance = centered_propensities @ centered_multipliers
    norm = np.sqrt(np.sum(centered_propensities**2, axis=1) * np.sum(centered_multipliers**2))
    return covariance / norm


# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
steerability_spearman = steerabilty_spearman

In [None]:
# Smoke test on pure noise: propensities are independent of the multipliers,
# so slopes and Spearman correlations should all be close to 0.
# Seeded generator so the printed output is identical under Restart & Run All.
demo_rng = np.random.default_rng(0)
random_multipliers = np.arange(100)
random_propensities = demo_rng.random((10, 100))

print(steerability_slope(random_multipliers, random_propensities))
print(steerabilty_spearman(random_multipliers, random_propensities))

In [None]:
# Sanity check on a perfectly linear, strictly increasing relationship: each of
# the 10 rows equals multipliers / 100, so every slope must be 0.01 and every
# Spearman correlation exactly 1. Assert it instead of only eyeballing prints.
multipliers = np.arange(100)
propensities = np.tile(multipliers / 100, (10, 1))

slopes = steerability_slope(multipliers, propensities)
spearman_rhos = steerabilty_spearman(multipliers, propensities)
np.testing.assert_allclose(slopes, 0.01)
np.testing.assert_allclose(spearman_rhos, 1.0)

print(slopes)
print(spearman_rhos)