Unifying shapes of tensors throughout the library #63

Closed
shuheng-liu opened this issue Oct 12, 2020 · 1 comment
shuheng-liu (Member) commented:

Rethink the diff function

At the core of the neurodiffeq library is the diff(x, t) function, which computes the partial derivative ∂x/∂t evaluated at t. Usually, both tensors x and t have shape (n_samples, 1). When either x.shape or t.shape is malformed, however, broadcasting can silently produce wrong results. These cases are subtle enough that they have gone unnoticed for a long time.
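
For context, here is a minimal sketch of a diff-style derivative built on torch.autograd.grad, together with a demonstration of the broadcasting hazard. The sketch is an assumption made for illustration, not necessarily the library's actual implementation:

import torch

def diff_sketch(x, t):
    # hypothetical minimal version of diff(x, t) -- assumed, not the real one
    return torch.autograd.grad(
        x, t, grad_outputs=torch.ones_like(x), create_graph=True
    )[0]

# The hazard: elementwise ops between (n, 1) and (n,) tensors silently
# broadcast to (n, n) instead of raising an error.
n = 5
a = torch.rand(n, 1)
b = torch.rand(n)
print((a * b).shape)  # torch.Size([5, 5])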

All our generators (as defined in neurodiffeq.generator) currently return tensors of shape (n_samples,) instead of (n_samples, 1). We should put effort into unifying the tensor shapes everywhere.
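
For illustration, unifying shapes can be as simple as reshaping each 1-D sample tensor to a column immediately after it is generated. The generator below is a hypothetical stand-in for the real classes in neurodiffeq.generator:

import torch

def fake_generator(n_samples):
    # hypothetical generator returning a 1-D tensor, like the current ones
    return torch.rand(n_samples, requires_grad=True)  # shape (n_samples,)

t = fake_generator(10).reshape(-1, 1)  # unify to (n_samples, 1)
assert t.shape == (10, 1)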

Here are two simple cases for review.

Case 1: Shapes don't matter

In this case, we try different combinations of x.shape and t.shape and check the shape of the output ∂x/∂t. In every combination, the output takes the shape of t:

  • [n, 1] and [n] --> [n]
  • [n] and [n] --> [n]
  • [n, 1] and [n, 1] --> [n, 1]
  • [n] and [n, 1] --> [n, 1]

To see this, run the following code. Note that d1, d2, d3, and d4, while having different shapes, hold the same values. This is why we incorrectly believed the diff() function was sound.

import torch
from neurodiffeq import diff

n = 10

t = torch.rand(n, requires_grad=True)  # t: (n,)
x = torch.sin(t)
d1 = diff(x.reshape(-1, 1), t)  # x: (n, 1), t: (n,)   --> d1: (n,)
d2 = diff(x.reshape(-1), t)     # x: (n,),   t: (n,)   --> d2: (n,)

t = t.reshape(-1, 1)  # t: (n, 1)
x = torch.sin(t)
d3 = diff(x.reshape(-1, 1), t)  # x: (n, 1), t: (n, 1) --> d3: (n, 1)
d4 = diff(x.reshape(-1), t)     # x: (n,),   t: (n, 1) --> d4: (n, 1)
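
A quick check of the shapes claimed above (run right after the snippet, so d1 through d4 are still in scope):

assert d1.shape == (n,)
assert d2.shape == (n,)
assert d3.shape == (n, 1)
assert d4.shape == (n, 1)
# the values agree even though the shapes differ
assert torch.allclose(d1, d3.flatten())
assert torch.allclose(d2, d4.flatten())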

Case 2: Shapes matter

In this second case, we examine two new operators, div and curl in spherical coordinates, and show that the vector identity div(curl(...)) == 0 holds only when x.shape and t.shape are both (n, 1).

Here are the definitions of curl and divergence in spherical coordinates:

import torch
from torch import sin
from neurodiffeq import diff

# these two operators have been recently implemented in neurodiffeq.operators
def spherical_curl(u_r, u_theta, u_phi, r, theta, phi):
    d_r = lambda u: diff(u, r)
    d_theta = lambda u: diff(u, theta)
    d_phi = lambda u: diff(u, phi)

    curl_r = (d_theta(u_phi * sin(theta)) - d_phi(u_theta)) / (r * sin(theta))
    curl_theta = (d_phi(u_r) / sin(theta) - d_r(u_phi * r)) / r
    curl_phi = (d_r(u_theta * r) - d_theta(u_r)) / r

    return curl_r, curl_theta, curl_phi


def spherical_div(u_r, u_theta, u_phi, r, theta, phi):
    div_r = diff(u_r * r ** 2, r) / r ** 2
    div_theta = diff(u_theta * sin(theta), theta) / (r * sin(theta))
    div_phi = diff(u_phi, phi) / (r * sin(theta))
    return div_r + div_theta + div_phi

Here we define a vector field q by specifying the rule that computes q from the coordinates (r, theta, phi):

def compute_q(r, theta, phi):
    # stack the (flattened) coordinates into an (n, 3) matrix
    r_theta_phi = torch.stack([r.flatten(), theta.flatten(), phi.flatten()], dim=1)
    # a fixed weight matrix, chosen arbitrarily
    W = torch.tensor([
        [.01, .04, .07],
        [.02, .05, .08],
        [.03, .06, .09],
    ])
    q = torch.matmul(r_theta_phi, W)
    q = torch.tanh(q)
    # each component has shape (n,)
    return q[:, 0], q[:, 1], q[:, 2]
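
Note that compute_q always returns flat (n,) components regardless of the input shapes, which is why the test below reshapes its outputs explicitly. A quick check (the inputs here are arbitrary):

r_test = torch.rand(4, 1, requires_grad=True) + 0.1
q_components = compute_q(r_test, r_test, r_test)
print([q.shape for q in q_components])  # [torch.Size([4])] * 3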

We then test the vector identity div(curl(q)) == 0 for this q:

import numpy as np

n = 10

# create r, theta, and phi with shape (n, 1)
r = torch.rand(n, 1, requires_grad=True) + 0.1
theta = torch.rand(n, 1, requires_grad=True) * np.pi
phi = torch.rand(n, 1, requires_grad=True) * np.pi * 2
q_r, q_theta, q_phi = compute_q(r, theta, phi)

# bind the operators to the r, theta, phi created above
div = lambda u_r, u_theta, u_phi: spherical_div(u_r, u_theta, u_phi, r, theta, phi)
curl = lambda u_r, u_theta, u_phi: spherical_curl(u_r, u_theta, u_phi, r, theta, phi)

div_curl_q1 = div(*curl(q_r.reshape(-1, 1), q_theta.reshape(-1, 1), q_phi.reshape(-1, 1)))
div_curl_q2 = div(*curl(q_r.reshape(-1), q_theta.reshape(-1), q_phi.reshape(-1)))

# reshape r, theta, and phi to shape (n,)
r = r.reshape(-1)
theta = theta.reshape(-1)
phi = phi.reshape(-1)
q_r, q_theta, q_phi = compute_q(r, theta, phi)

# rebind the operators to the reshaped r, theta, phi
div = lambda u_r, u_theta, u_phi: spherical_div(u_r, u_theta, u_phi, r, theta, phi)
curl = lambda u_r, u_theta, u_phi: spherical_curl(u_r, u_theta, u_phi, r, theta, phi)

div_curl_q3 = div(*curl(q_r.reshape(-1, 1), q_theta.reshape(-1, 1), q_phi.reshape(-1, 1)))
div_curl_q4 = div(*curl(q_r.reshape(-1), q_theta.reshape(-1), q_phi.reshape(-1)))

print(div_curl_q1, div_curl_q2, div_curl_q3, div_curl_q4, sep="\n")

Printing all four div_curl_q tensors shows that only div_curl_q1 is (approximately) equal to 0. In other words, both the dependent and independent variables must have shape (n, 1) for the differentiation to work correctly.

shuheng-liu (Member, Author) commented:

After discussing with David, we decided to change the default behavior of the diff function.

  • Before v0.2.0, we'll provide these (Python) functions for differentiation:
    • safe_diff: performs shape checking (a sketch follows below)
    • unsafe_diff: performs no shape checking; behaves exactly like the original diff
    • diff: defaults to unsafe_diff, but issues a one-time warning that it will default to safe_diff in the future
  • Starting from v0.2.0:
    • safe_diff: performs shape checking
    • unsafe_diff: performs no shape checking; behaves exactly like the original diff
    • diff: defaults to safe_diff; no warning will be issued
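
For reference, here is a minimal sketch of what safe_diff's shape checking could look like under the (n, 1) convention. This is an assumption for illustration, not the actual implementation:

import torch

def safe_diff_sketch(x, t):
    # hypothetical check enforcing the (n, 1) convention -- assumed behavior
    if x.shape != t.shape or x.ndim != 2 or x.shape[1] != 1:
        raise ValueError(
            f"x and t must both have shape (n, 1); got {tuple(x.shape)} and {tuple(t.shape)}"
        )
    return torch.autograd.grad(
        x, t, grad_outputs=torch.ones_like(x), create_graph=True
    )[0]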
