In [None]:
import torch
import matplotlib.pyplot as plt

torch.manual_seed(24)

# 4 random positions in the unit square, each with a random heading angle in [0, 2*pi).
pos = torch.rand(4, 2)
rot = torch.rand(4, 1) * torch.pi * 2
d = torch.cat([rot.cos(), rot.sin()], dim=-1)  # unit heading vectors, [4, 2]
x, y = pos.T
u, v = d.T

plt.quiver(x, y, u, v)
plt.xlim(-0.2, 1.2)
plt.ylim(-0.2, 1.2)
# quiver draws arrow angles in screen space; an equal aspect keeps them consistent
# with the data-space geometry of the positions.
plt.gca().set_aspect("equal")
plt.xlabel("x")
plt.ylabel("y")
plt.title("Positions and heading directions")

pos

In [None]:
def separation(p0, p1, p1_d):
    """
    Split the offset from every p0 point to every p1 point into axial and radial parts.

    p0:   [n, d]
    p1:   [m, d]
    p1_d: [m, d] axis direction for each p1 point. The projection below is a true
          distance only if these are unit vectors (the notebook builds them from
          cos/sin, so they are).
    returns:
        z_distance: [n, m, 1] signed component of (p1[j] - p0[i]) along p1_d[j]
        r_distance: [n, m, 1] norm of what remains after removing that component
                    (the perpendicular distance when p1_d[j] has unit length)
    """
    # rel_pos[i, j] = p1[j] - p0[i]  ->  [n, m, d]
    rel_pos = p1.unsqueeze(0) - p0.unsqueeze(1)
    z_distance = (rel_pos * p1_d).sum(-1, keepdim=True)
    z_displacement = z_distance * p1_d

    r_displacement = rel_pos - z_displacement
    r_distance = torch.norm(r_displacement, dim=-1, keepdim=True)
    return z_distance, r_distance
    
def downwash(p0, p1, p1_d, kr=2, kz=1):
    """
    Downwash vector exerted by each source in p1 on each target in p0.

    p0:   [n, d] target positions
    p1:   [m, d] source positions
    p1_d: [m, d] source directions (presumably unit vectors -- TODO confirm)
    kr:   radial falloff rate; the radial profile is a Gaussian in r with std z / kr
    kz:   axial falloff rate in the 1 / (1 + kz * z)**2 factor
    returns: [n, m, d]; entry [i, j] points along -p1_d[j]

    With z, r = separation(p0, p1, p1_d), the magnitude is
        exp(-0.5 * (kr * r / z)**2) / (1 + kz * z)**2    for z > 0
        0                                                 for z <= 0
    Targets with z <= 0 get exactly 0. This includes a source paired with itself,
    where r = z = 0 and the naive formula would evaluate 0 / 0 = NaN.
    """
    z, r = separation(p0, p1, p1_d)
    ahead = z > 0
    # Use a harmless denominator where z <= 0 so that neither the forward value nor
    # the gradient through the unselected torch.where branch becomes inf/NaN.
    z_safe = torch.where(ahead, z, torch.ones_like(z))
    v = torch.exp(-0.5 * torch.square(kr * r / z_safe)) / (1 + kz * z_safe) ** 2
    v = torch.where(ahead, v, torch.zeros_like(v))
    f = v * -p1_d
    return f

def off_diag(a: torch.Tensor) -> torch.Tensor:
    """Remove the diagonal from a tensor that is square in its first two dims.

    a:       [n, n, *rest]
    returns: [n, n-1, *rest]; row i holds a[i, j] for every j != i, in increasing j.
    """
    assert a.shape[0] == a.shape[1]
    size, trailing = a.shape[0], a.shape[2:]
    # Boolean indexing walks the mask in row-major order, so each row keeps its
    # off-diagonal entries in their original column order.
    keep = ~torch.eye(size, dtype=torch.bool, device=a.device)
    return a[keep].reshape(size, size - 1, *trailing)

# Pairwise downwash among all positions: f[i, j] is the vector on pos[i] due to source j.
# Diagonal entries pair each source with itself; off_diag(f) removes them when needed.
f = downwash(pos, pos, p1_d=d)
f.size()

In [None]:
# Downwash field of a single source (row `source_idx` of pos / d) sampled on a 20x20 grid.
source_idx = 1
xx = torch.linspace(0, 1, 20)
# indexing="ij" is the implicit legacy default; passing it explicitly avoids torch's
# deprecation warning without changing the result.
to = torch.stack(torch.meshgrid(xx, xx, indexing="ij"), dim=-1).flatten(0, 1)  # [400, 2]
# downwash returns [400, 1, 2]; drop only the singleton source axis.
f = downwash(to, pos[[source_idx]], d[[source_idx]], kr=2).squeeze(1)

plt.quiver(*to.T, *f.T)
plt.scatter(*pos[source_idx].tolist(), color="red", marker="x")  # source location
plt.xlim(-0.2, 1.2)
plt.ylim(-0.2, 1.2)
plt.gca().set_aspect("equal")
plt.xlabel("x")
plt.ylabel("y")
plt.title("Downwash field of source {}".format(source_idx))
f.norm(dim=-1).max()

In [None]:
# Compare three decay profiles; the last one matches downwash's 1 / (1 + kz * z)**2
# factor with its default kz = 1.
x = torch.linspace(0, 3, 30)
plt.plot(x, 1/(x+1), label=r"$1/(1+x)$")
plt.plot(x, 1/(x*x+1), label=r"$1/(1+x^2)$")
plt.plot(x, 1/(x+1)**2, label=r"$1/(1+x)^2$")
plt.xlabel("x")
plt.ylabel("value")
plt.legend()
plt.title("Decay profiles");