<a href="https://colab.research.google.com/github/ketanp23/sit-neuralnetworks-class/blob/main/Hessian_Matrix_Calculation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch.autograd.functional import hessian

# --- 1. Define the Scalar Function ---
# We define a simple function f(x, y) = x^3 + 2y^2 + 5xy
# The input to this function will be a 1D tensor [x, y].
def simple_scalar_function(input_tensor):
    """Evaluate f(x, y) = x^3 + 2y^2 + 5xy.

    Args:
        input_tensor: Tensor whose first two entries are x and y; any further
            entries are ignored.

    Returns:
        The tensor f(x, y) (0-d when ``input_tensor`` is 1-D), built from
        differentiable torch operations so autograd can trace it.
    """
    x, y = input_tensor[0], input_tensor[1]

    # Each term of the polynomial, computed separately for readability.
    cubic_term = x.pow(3)
    quadratic_term = 2 * y.pow(2)
    cross_term = 5 * x * y

    return cubic_term + quadratic_term + cross_term

# --- 2. Define the Point for Calculation ---
# We want to calculate the Hessian at the point (x=1.0, y=2.0).
# `hessian` makes its own differentiable copy of the input, so the point
# itself does not need requires_grad=True.
point_of_interest = torch.tensor([1.0, 2.0], dtype=torch.float32)


def analytical_hessian(point):
    """Return the closed-form Hessian of f(x, y) = x^3 + 2y^2 + 5xy at ``point``.

    Derivation (used to verify the autograd result):
        First derivatives:   df/dx = 3x^2 + 5y,   df/dy = 4y + 5x
        Second derivatives:  d^2f/dx^2 = 6x,      d^2f/dy^2 = 4
                             d^2f/dxdy = d^2f/dydx = 5  (Clairaut's theorem)

    Args:
        point: 1-D tensor ``[x, y]``; entries after the first two are ignored.

    Returns:
        A 2x2 tensor ``[[6x, 5], [5, 4]]`` with the same dtype as ``point``.
    """
    x = float(point[0])
    return torch.tensor([[6.0 * x, 5.0], [5.0, 4.0]], dtype=point.dtype)


# --- 3. Compute the Hessian Matrix ---
# The hessian function computes the square matrix of second-order partial derivatives.
# The matrix H will be:
# [[ d^2f/dx^2, d^2f/dxdy ],
#  [ d^2f/dydx, d^2f/dy^2 ]]
try:
    # Only the autograd call sits inside the try, so errors in the reporting
    # code below are not misreported as Hessian-computation failures.
    hessian_matrix = hessian(simple_scalar_function, point_of_interest)
except Exception as e:
    # Notebook demo: report the failure and keep the cell from aborting.
    print(f"An error occurred during Hessian computation: {e}")
else:
    # --- 4. Print Output and Explanation ---
    print("--- Hessian Matrix Calculation Example ---")
    print("Function f(x, y) = x^3 + 2y^2 + 5xy")
    print(f"Point of Interest (x, y): {point_of_interest.tolist()}")
    print("-" * 40)

    # Derive the expected values from the point instead of hard-coding them,
    # so changing `point_of_interest` keeps this printout correct.
    expected_hessian = analytical_hessian(point_of_interest)
    (h_xx, h_xy), (h_yx, h_yy) = expected_hessian.tolist()

    print("Expected Hessian Matrix (Analytical):")
    print(f"[[ {h_xx:.1f}, {h_xy:.1f} ],")
    print(f" [ {h_yx:.1f}, {h_yy:.1f} ]]")
    print("-" * 40)

    print("Calculated Hessian Matrix (PyTorch):")
    print(hessian_matrix)

    # Check the two results against each other; stay silent when they agree.
    if not torch.allclose(hessian_matrix, expected_hessian):
        print("WARNING: autograd Hessian does not match the analytical result.")

--- Hessian Matrix Calculation Example ---
Function f(x, y) = x^3 + 2y^2 + 5xy
Point of Interest (x, y): [1.0, 2.0]
----------------------------------------
Expected Hessian Matrix (Analytical):
[[ 6.0, 5.0 ],
 [ 5.0, 4.0 ]]
----------------------------------------
Calculated Hessian Matrix (PyTorch):
tensor([[6., 5.],
        [5., 4.]])
