In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt

In [2]:
def potential(x):
    """Return the Euclidean norm of each particle, taken over the last axis."""
    return torch.linalg.norm(x, dim=-1)


def is_valid(x):
    """Return a boolean mask that is True where a particle's norm is below 10."""
    radius = torch.linalg.norm(x, dim=-1)
    return radius < 10

In [3]:
num_particles = 10
dim = 4

# Integer-valued starting positions: Gaussian draws (mean 0, std 5) rounded to
# the nearest integer, stored as a float64 leaf tensor that tracks gradients.
particles = torch.tensor(
    np.random.normal(loc=0, scale=5, size=(num_particles, dim)).round(),
    dtype=torch.float64,
    requires_grad=True,
)

# Constant 0.1 for each of the 2 * dim signed unit moves of every particle.
v = 0.1 * torch.ones(num_particles, 2 * dim, dtype=torch.float64)

In [4]:
# Evaluate the potential for every particle and backpropagate the total, which
# populates particles.grad. Note that .grad accumulates across backward() calls.
vals = potential(particles)
torch.sum(vals).backward()

In [5]:
# Shape check: v has one column per signed unit move, i.e. (num_particles, 2 * dim).
v.size()

torch.Size([10, 8])

In [18]:
def step(particles, v, potential):
    """Move every particle by one unit along one randomly chosen coordinate.

    For each particle, 2 * dim candidate moves are considered: "+1" on each
    coordinate (columns [0, dim)) and "-1" on each coordinate (columns
    [dim, 2 * dim)). The unnormalised probability of a move is the matching
    entry of ``v`` plus the magnitude of the potential's gradient on that
    coordinate when the move goes downhill (and just ``v`` otherwise). One
    move per particle is drawn by inverse-CDF sampling and applied in place.

    Args:
        particles: 2-D leaf tensor of shape (num_particles, dim) with
            ``requires_grad=True``. Updated in place; its ``.grad`` is left
            zeroed on return.
        v: tensor of shape (num_particles, 2 * dim) added to every move's
            weight. Not modified.
        potential: callable mapping ``particles`` to one scalar per particle.

    Returns:
        None.
    """
    # Take the sizes from the argument itself rather than from notebook
    # globals, so the function keeps working after those are reassigned.
    n_particles, n_dims = particles.shape

    # backward() accumulates into .grad, so clear anything left over from
    # earlier backward() calls before computing this step's gradient.
    if particles.grad is not None:
        particles.grad.zero_()
    potential(particles).sum().backward()

    with torch.no_grad():
        grad = particles.grad

        # A negative partial derivative means "+1" on that coordinate lowers
        # the potential; otherwise "-1" is the non-increasing direction.
        up_is_downhill = grad < 0
        rates = v + torch.cat(
            [up_is_downhill * grad, (~up_is_downhill) * grad], dim=1
        ).abs()

        # Row-wise cumulative distribution over the 2 * dim candidate moves.
        cdf = rates.cumsum(dim=1)
        cdf = cdf / cdf[:, -1:]

        # Inverse-CDF sampling: pick the first column whose CDF exceeds xi.
        # argmax returns the first maximal entry, i.e. the first True.
        xi = torch.rand(n_particles, 1, dtype=cdf.dtype, device=cdf.device)
        sampled_index = (xi < cdf).int().argmax(dim=1)

        # Decode the flat move index into a coordinate and a direction.
        move_up = sampled_index < n_dims
        sign = move_up.to(particles.dtype) * 2 - 1
        index = sampled_index % n_dims

        rows = torch.arange(n_particles, device=particles.device)
        particles[rows, index] += sign

        particles.grad.zero_()

In [7]:
# Run ten jump steps. After each one, particles that fail is_valid are
# overwritten with copies of uniformly chosen particles that pass it.
for i in range(10):
    step(particles, v, potential)

    is_particle_valid = is_valid(particles)
    num_valid = int(is_particle_valid.sum())

    if num_valid == 0:
        # No valid particle is left to copy from, so stop.
        break

    if num_valid != num_particles:
        # Uniform distribution over the valid particles only.
        support = is_particle_valid.detach().double().numpy()
        p = support / support.sum()

        resampled_particles = np.random.choice(
            num_particles,
            size=num_particles - num_valid,
            replace=True,
            p=p,
        )
        with torch.no_grad():
            particles[~is_particle_valid] = particles[resampled_particles]

In [8]:
# Euclidean norm of each particle (one value per row).
torch.linalg.vector_norm(particles, dim=1)

tensor([2.6458, 1.0000, 1.0000, 4.2426, 1.7321, 1.0000, 3.8730, 3.1623, 2.0000,
        6.4031], dtype=torch.float64, grad_fn=<LinalgVectorNormBackward0>)

In [152]:
# Euclidean norm of each particle (one value per row).
torch.linalg.vector_norm(particles, dim=1)

tensor([14.0357,  3.1623, 11.7473,  0.0000,  9.2736,  0.0000,  0.0000,  1.0000,
        12.7279,  2.4495], dtype=torch.float64,
       grad_fn=<LinalgVectorNormBackward0>)

In [147]:
# Display the current particle positions, one row per particle.
particles

tensor([[  6.,   3.,  -7.,  13.],
        [  3.,  13.,  -2.,  -6.],
        [ 13.,   4.,  11.,  -2.],
        [ -1.,   0.,  -3.,   2.],
        [  5.,   1., -10.,  -6.],
        [  0.,   0.,   0.,   0.],
        [  1.,  -2.,  -1.,   2.],
        [  4.,   0.,   0.,   1.],
        [ -7.,   1.,   2.,  18.],
        [  2.,   1.,   0.,   1.]], dtype=torch.float64, requires_grad=True)

In [28]:
# Problem instance as NumPy data: a 5 x 5 matrix for the quadratic objective,
# one weight per item, and a total capacity.
n1 = 5
Q1 = np.array(
    [
        [8, 2, 3, 4, 5],
        [2, 7, 2, 3, 4],
        [3, 2, 6, 2, 3],
        [4, 3, 2, 5, 2],
        [5, 4, 3, 2, 9],
    ]
)
weights1 = np.array([2, 3, 4, 5, 9])
capacity1 = 10

# Working copies used by the sampler (float64 tensors).
dim = n1
capacity = capacity1
Q = torch.tensor(Q1, dtype=torch.float64)
weights = torch.tensor(weights1, dtype=torch.float64)

# Weights expressed as a fraction of the capacity, so the budget becomes 1.
weights_torch = weights / capacity

In [29]:
def potential(x):
    """Return the negated quadratic form -x_n^T Q x_n for every row x_n of x."""
    return -torch.einsum('ni,ij,nj->n', x, Q, x)


def isvalid(x):
    """Return a boolean mask of the rows of x that satisfy every constraint.

    A row passes when its capacity-normalised weight (x @ weights_torch) is
    at most 1 and all of its entries lie in [0, 1].
    """
    within_capacity = x @ weights_torch <= 1
    inside_unit_box = ((x >= 0) & (x <= 1)).all(dim=1)
    return within_capacity & inside_unit_box

In [50]:
# Draw random 0/1 starting vectors, then replace every infeasible one with a
# copy of a uniformly chosen feasible one.
num_particles = 100
particles = np.round(np.random.rand(num_particles, dim))
particles = torch.tensor(particles, dtype=torch.float64, requires_grad=True)
v = torch.ones(num_particles, 2 * dim, dtype=torch.float64) * 0.1

is_particle_valid = isvalid(particles)
num_valid = is_particle_valid.sum().item()
if num_valid == 0:
    # With no feasible particle, support.sum() would be 0 and the resampling
    # probabilities NaN, which makes np.random.choice fail with an opaque
    # ValueError. Fail early with an actionable message instead.
    raise RuntimeError(
        "No feasible particle in the initial draw; re-run this cell or "
        "increase num_particles."
    )

# Uniform distribution over the feasible particles only.
support = is_particle_valid.to(torch.float64).clone().detach().numpy()
p = support / support.sum()

resampled_particles = np.random.choice(np.arange(num_particles), num_particles - num_valid,
                                       replace=True, p=p)
with torch.no_grad():
    particles[~is_particle_valid] = particles[resampled_particles]

In [57]:
# Run one jump step, then overwrite particles that fail isvalid with copies
# of uniformly chosen particles that pass it.
for i in range(1):
    step(particles, v, potential)

    is_particle_valid = isvalid(particles)
    num_valid = int(is_particle_valid.sum())

    if num_valid == 0:
        # No valid particle is left to copy from, so stop.
        break

    if num_valid != num_particles:
        # Uniform distribution over the valid particles only.
        support = is_particle_valid.detach().double().numpy()
        p = support / support.sum()

        resampled_particles = np.random.choice(
            num_particles,
            size=num_particles - num_valid,
            replace=True,
            p=p,
        )
        with torch.no_grad():
            particles[~is_particle_valid] = particles[resampled_particles]

In [58]:
# Evaluate the potential at the current particle positions (one value per particle).
potential(particles)

tensor([-69., -63., -67., -63., -58., -64., -58., -69., -69., -69., -63., -64.,
        -68., -69., -64., -74., -68., -64., -69., -74., -69., -69., -67., -58.,
        -67., -63., -58., -58., -58., -63., -64., -69., -58., -67., -69., -68.,
        -58., -67., -58., -69., -58., -69., -63., -64., -74., -69., -58., -74.,
        -68., -64., -58., -64., -69., -68., -68., -64., -69., -74., -69., -58.,
        -74., -69., -64., -69., -69., -69., -67., -67., -69., -69., -69., -69.,
        -64., -69., -74., -69., -69., -58., -69., -69., -64., -74., -67., -69.,
        -69., -68., -67., -68., -69., -63., -68., -69., -67., -69., -63., -68.,
        -63., -69., -64., -69.], dtype=torch.float64, grad_fn=<NegBackward0>)

In [59]:
# Per-particle result of the isvalid constraint check on the current positions.
isvalid(particles)

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False])