In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import numpy as np
from agents import Agent
import pandas as pd

def _gradient_report(numerical, analytical):
    """
    Builds the comparison record for one numerical/analytical gradient pair.

    Args:
    - numerical: Central-difference estimate of the derivative.
    - analytical: Derivative reported by the agent for the same quantity.

    Returns:
    - report (dict): Keys "numerical", "analytical", "difference" (absolute) and
      "relative_error" = |n - a| / (|n| + |a|), which is 0 when both values are 0.
    """
    difference = np.abs(numerical - analytical)
    scale = np.abs(numerical) + np.abs(analytical)
    # When scale is 0 the difference is 0 as well, so dividing by 1 yields 0
    # without triggering a divide-by-zero warning.
    relative_error = difference / np.where(scale == 0, 1.0, scale)
    return {
        "numerical": numerical,
        "analytical": analytical,
        "difference": difference,
        "relative_error": relative_error,
    }


def gradient_check(agent, epsilon=1e-5):
    """
    Checks the correctness of analytical gradients using numerical gradient approximation.

    Every scalar parameter, and every element of every array parameter, found in
    `agent.params` is perturbed by +/- epsilon and the central difference of the
    loss is compared with the analytical gradient stored under the key
    "d<param_name>". Parameters of any other type are skipped. All parameters
    are restored to their exact original objects/values, even if
    `agent.compute_loss()` raises part-way through.

    Args:
    - agent (Agent): Object exposing a `params` dict and a `compute_loss()` method whose
      result holds the loss at index 0 and the analytical gradients at index 1.
    - epsilon (float): Small perturbation value for numerical gradients. Must be non-zero.

    Returns:
    - discrepancies (dict): For a scalar parameter, a dict with "numerical", "analytical",
      "difference" and "relative_error". For an array parameter, a list of such dicts, one
      per element, each with an extra "index" key (an int for 1-D arrays, a tuple otherwise).

    Raises:
    - ValueError: If epsilon is zero.
    """
    if epsilon == 0:
        raise ValueError("epsilon must be non-zero for a finite-difference estimate")

    analytical_grad = agent.compute_loss()[1]

    def current_loss():
        return agent.compute_loss()[0]

    discrepancies = {}

    # Iterate over a snapshot because entries are temporarily reassigned below.
    for param_name, param_value in list(agent.params.items()):
        if isinstance(param_value, (float, int, np.floating, np.integer)):
            # Work in double precision so the perturbation is not lost to a
            # narrow or integer scalar type.
            base = float(param_value)
            try:
                agent.params[param_name] = base + epsilon
                loss_plus = current_loss()
                agent.params[param_name] = base - epsilon
                loss_minus = current_loss()
            finally:
                # Put back the original object: no rounding drift, no type change.
                agent.params[param_name] = param_value

            numerical_grad = (loss_plus - loss_minus) / (2 * epsilon)
            discrepancies[param_name] = _gradient_report(
                numerical_grad, analytical_grad[f"d{param_name}"]
            )

        elif isinstance(param_value, np.ndarray):
            analytical_values = np.asarray(analytical_grad[f"d{param_name}"])

            if np.issubdtype(param_value.dtype, np.floating):
                working = param_value  # perturb in place, as callers expect
            else:
                # An integer/bool array would silently truncate the perturbation,
                # so a float copy stands in for it while this parameter is checked.
                working = param_value.astype(float)
                agent.params[param_name] = working

            entries = []
            try:
                # ndindex visits every element, whatever the number of dimensions.
                for idx in np.ndindex(*param_value.shape):
                    saved = working[idx]
                    try:
                        working[idx] = saved + epsilon
                        loss_plus = current_loss()
                        working[idx] = saved - epsilon
                        loss_minus = current_loss()
                    finally:
                        working[idx] = saved  # exact restore

                    numerical_grad = (loss_plus - loss_minus) / (2 * epsilon)
                    entries.append({
                        # Plain int index for 1-D arrays keeps the historical format.
                        "index": idx[0] if param_value.ndim == 1 else idx,
                        **_gradient_report(numerical_grad, analytical_values[idx]),
                    })
            finally:
                agent.params[param_name] = param_value

            discrepancies[param_name] = entries

    return discrepancies


# Example usage: build an Agent on synthetic data and run the gradient check.

# Parameter values at which the gradients are evaluated
params = dict(
    HDP=0.5,
    HGP=0.2,
    wh=0.1,
    w0=1.0,
    hs=0.5,
    B=np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]),
)

# Synthetic observations: three numeric columns followed by 0/1 indicator columns
N = 10000
info = {
    'weekday': np.random.randint(0, 7, size=N),
    'dfrw': np.random.uniform(0, 100, size=N),
    'income': np.random.uniform(30, 100, size=N),
    **{
        flag: np.random.randint(0, 2, size=N)
        for flag in (
            'shelter-in-place',
            'pre-vax',
            'during',
            'post-vax',
            'is_weekend',
            'is_holiday_extra',
            'commute_filled_diff_dist',
        )
    },
}

init_habit_strength = np.random.uniform(size=(7, 2))
param_constraint = []

# Agent built from the random initial habit strengths, the synthetic frame and
# an empty constraint list; its parameters are then replaced with the ones above.
agent = Agent(init_habit_strength, pd.DataFrame(data=info), param_constraint)
agent.params = params


# Compare analytical and numerical gradients for every parameter
discrepancies = gradient_check(agent)

# Report the comparison: one header per parameter, then one line per scalar
# parameter or per array element (array lines are prefixed with their index).
for param, discrepancy in discrepancies.items():
    print(f"Parameter: {param}")
    is_array = isinstance(discrepancy, list)
    for entry in (discrepancy if is_array else [discrepancy]):
        label = f"  Index {entry['index']}: " if is_array else "  "
        print(
            f"{label}Numerical={entry['numerical']:.6f}, "
            f"Analytical={entry['analytical']:.6f}, "
            f"Difference={entry['difference']:.6e}"
        )


Parameter: HDP
  Numerical=-11.399054, Analytical=-11.442229, Difference=4.317514e-02
Parameter: HGP
  Numerical=-109.499240, Analytical=-109.492785, Difference=6.454804e-03
Parameter: wh
  Numerical=-300.611683, Analytical=-300.593280, Difference=1.840290e-02
Parameter: w0
  Numerical=2551.037529, Analytical=2551.080338, Difference=4.280961e-02
Parameter: hs
  Numerical=8.432300, Analytical=8.447128, Difference=1.482798e-02
Parameter: B
  Index 0: Numerical=2873.525056, Analytical=2873.571516, Difference=4.645949e-02
  Index 1: Numerical=14.294331, Analytical=14.294583, Difference=2.527343e-04
  Index 2: Numerical=11858.880966, Analytical=11859.075634, Difference=1.946682e-01
  Index 3: Numerical=1498.236779, Analytical=1498.261248, Difference=2.446896e-02
  Index 4: Numerical=1533.222328, Analytical=1533.244714, Difference=2.238609e-02
  Index 5: Numerical=3048.328242, Analytical=3048.384500, Difference=5.625812e-02
  Index 6: Numerical=1568.639891, Analytical=1568.669646, Difference