In [485]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch

In [486]:
torch.manual_seed(8)

# ---- target: two-component 1-D Gaussian mixture (read by pdf_calc / grad_calc) ----
weight = torch.tensor([2/3, 1/3])   # mixture weights
mu = torch.tensor([-1.7, 1.7])      # component means
std = 0.6                           # standard deviation shared by both components

# ---- sampler setting ----
init = torch.tensor([-1.5, 1.3, 1.4, 1.5, 1.6, 1.7])  # initial particle positions
iteration = 200                     # update steps for each method
lr = 0.03                           # step size
alpha = 1.0                         # rate multiplier of the BLOBBD weight update
alpha_dk = 0.03                     # rate multiplier of the BLOBDK kill/duplicate probability

# ---- plot setting ----
x_axis = torch.linspace(-5, 5, 201) # grid on which the density curves are drawn
interval = 0.1                      # NOTE(review): not referenced in this part of the notebook
bw = 0.6                            # bandwidth passed to pdf_kernel for the plotted KDE
size, size_ = 20, 100               # marker sizes: particle squares / weight-stem tips
linewidth = 2                       # line width of the weight stems

In [487]:
def pdf_kernel(x, particles, mass, bandwidth):
    """Weighted Gaussian kernel density estimate evaluated at ``x``.

    Each particle contributes ``exp(-(x - p)^2 / bandwidth) / sqrt(pi * bandwidth)``
    (a normal density with variance ``bandwidth / 2``), scaled by its mass.

    Args:
        x: evaluation points; the plotting code passes a column ``(M, 1)``.
        particles: particle positions; the plotting code passes a flat ``(N,)`` tensor,
            so ``x - particles`` broadcasts to ``(M, N)``.
        mass: ``(N,)`` particle weights.
        bandwidth: kernel bandwidth (scalar).

    Returns:
        Tensor of shape ``(M,)`` with the density estimate at each point.
    """
    gap_sq = (x - particles).pow(2)
    normaliser = np.sqrt(np.pi * bandwidth)
    per_particle = torch.exp(-gap_sq / bandwidth) / normaliser
    return (per_particle * mass).sum(dim = 1)

def pdf_calc(particles, means = None, sigma = None, weights = None): # N * 1 array
    """Density of the 1-D Gaussian-mixture target at each particle.

    Args:
        particles: ``(N, 1)`` tensor of evaluation points; it broadcasts against
            the ``K`` mixture components.
        means: ``(K,)`` component means. Defaults to the module-level ``mu``.
        sigma: standard deviation shared by all components. Defaults to ``std``.
        weights: ``(K,)`` mixture weights. Defaults to ``weight``.

    Returns:
        ``(N,)`` tensor of mixture densities.
    """
    # Fall back to the notebook's target so existing one-argument calls are unchanged.
    means = mu if means is None else means
    sigma = std if sigma is None else sigma
    weights = weight if weights is None else weights
    results = torch.exp(- (particles - means).pow(2) / 2 / sigma**2) / np.sqrt(2 * np.pi * sigma**2)
    return torch.sum(results * weights, dim = 1)
    

def potential_calc(particles, means = None, sigma = None, weights = None):
    """Potential ``V = -log p`` of the Gaussian-mixture target at each particle.

    The value is computed in log space with ``logsumexp`` rather than as
    ``-log(pdf)``. Far from the modes the mixture density underflows to 0 in
    float32, and ``-log(0)`` gives ``inf``, which would then poison the
    birth-death rates.

    Args:
        particles: ``(N, 1)`` tensor of evaluation points.
        means: ``(K,)`` component means. Defaults to the module-level ``mu``.
        sigma: standard deviation shared by all components. Defaults to ``std``.
        weights: ``(K,)`` mixture weights. Defaults to ``weight``.

    Returns:
        ``(N,)`` tensor of potential values.
    """
    means = mu if means is None else means
    sigma = std if sigma is None else sigma
    weights = weight if weights is None else weights
    # log(w_k * N(x; mu_k, sigma)) without the shared normalising constant
    log_terms = torch.log(weights) - (particles - means).pow(2) / (2 * sigma**2)
    return 0.5 * np.log(2 * np.pi * sigma**2) - torch.logsumexp(log_terms, dim = 1)
    

def grad_calc(particles, means = None, sigma = None, weights = None):
    """Gradient of the potential ``V = -log p`` of the Gaussian-mixture target.

    Uses ``grad V(x) = sum_k r_k(x) (x - mu_k) / sigma^2``. The responsibilities are
    ``r_k = softmax_k(log w_k - (x - mu_k)^2 / (2 sigma^2))``, so all
    normalising constants cancel exactly. The softmax also stays finite far from
    the modes, where the ratio ``p'(x) / p(x)`` would become ``0 / 0``.

    Args:
        particles: ``(N, 1)`` tensor of particle positions.
        means: ``(K,)`` component means. Defaults to the module-level ``mu``.
        sigma: standard deviation shared by all components. Defaults to ``std``.
        weights: ``(K,)`` mixture weights. Defaults to ``weight``.

    Returns:
        ``(N, 1)`` tensor with ``grad V`` at each particle.
    """
    means = mu if means is None else means
    sigma = std if sigma is None else sigma
    weights = weight if weights is None else weights
    diff = particles - means                                          # (N, K)
    log_terms = torch.log(weights) - diff.pow(2) / (2 * sigma**2)
    resp = torch.softmax(log_terms, dim = 1)                          # posterior over components
    return torch.sum(resp * diff, dim = 1, keepdim = True) / sigma**2

def duplicate_kill_particles(prob_list, kill_list: torch.Tensor, particles: torch.Tensor, noise_amp, mode = 'parallel'):
    """Apply one stochastic kill/duplicate (birth-death) step, keeping the particle count.

    Particle ``k`` is selected when a uniform draw falls below ``prob_list[k]``.
    A selected particle is killed if ``kill_list[k]`` is true and duplicated
    otherwise. Every newly created copy gets Gaussian noise of scale ``noise_amp``.

    Args:
        prob_list: ``(N,)`` per-particle selection probabilities.
        kill_list: ``(N,)`` boolean tensor: True -> kill when selected,
            False -> duplicate when selected.
        particles: ``(N, D)`` particle positions.
        noise_amp: standard deviation of the noise added to each new copy.
        mode: ``'sequential'`` handles particles one at a time and modifies
            ``particles`` in place. A killed particle is overwritten by a noisy
            copy of a random other particle; a duplicated one overwrites a random
            other particle. ``'parallel'`` makes all decisions at once, then
            restores the count by duplicating random survivors or dropping
            random particles.

    Returns:
        ``(N, D)`` tensor. In sequential mode this is ``particles`` itself. In
        parallel mode it is a new tensor whose row order is not preserved.

    Raises:
        NotImplementedError: if ``mode`` is not recognised.
    """
    n = particles.shape[0]
    device = particles.device
    rand_number = torch.rand(n, device = device)
    if mode == 'sequential':
        if n < 2:  # no other particle to copy from or onto
            return particles
        # position in the list of the n - 1 particles other than k
        rand_index = torch.randint(0, n - 1, (n,), device = device)
        for k in range(n):
            if not rand_number[k] < prob_list[k]:
                continue
            r = int(rand_index[k])
            other = r + (r >= k)  # map [0, n-2] onto the indices excluding k
            noise = torch.randn(particles.shape[1], device = device) * noise_amp
            if kill_list[k]:  # kill particle k: replace it by a noisy copy of another one
                particles[k] = particles[other].clone() + noise
            else:  # duplicate particle k onto another slot
                # integer indexing writes through; particles[mask][j] = ... would write to a copy
                particles[other] = particles[k].clone() + noise
        return particles
    if mode == 'parallel':
        unchanged = particles[rand_number >= prob_list]
        duplicated = particles[(rand_number < prob_list) & torch.logical_not(kill_list)]
        new_particles = torch.cat([unchanged, duplicated, duplicated + torch.randn_like(duplicated) * noise_amp], dim = 0)
        deficit = n - new_particles.shape[0]
        if deficit > 0:
            # duplicate random survivors; if every particle was killed, draw from the pre-step set
            source = new_particles if new_particles.shape[0] > 0 else particles
            picked = source[torch.randint(0, source.shape[0], (deficit,), device = device)]
            new_particles = torch.cat([new_particles, picked + torch.randn_like(picked) * noise_amp], dim = 0)
        elif deficit < 0:  # too many: keep a random subset
            perm = torch.randperm(new_particles.shape[0], device = device)
            new_particles = new_particles[perm][:n].clone()
        assert new_particles.shape[0] == n, 'change the particle number!'
        return new_particles
    raise NotImplementedError(f'unknown mode {mode!r}')

In [488]:
# ----------------------------------------------- BLOB -----------------------------------------------------------------------------
# Plain blob method: every particle keeps the uniform weight in mass_blob and moves along
# -grad V plus a kernel interaction term. The bandwidth is re-estimated each step as the
# median, over particles, of the squared distance to the nearest other particle.
blob = init.clone().view(-1, 1)
mass_blob = torch.ones(len(blob)) / len(blob)
for step in range(iteration):
    grads = grad_calc(blob)
    assert not torch.isnan(grads.max()), 'nan'   # same test as max() != max()
    sq_dist = torch.cdist(blob, blob, p = 2)**2
    # lift the diagonal above every entry so each row minimum skips the particle itself
    nn_sq = (sq_dist + torch.diag(torch.diag(sq_dist) + sq_dist.max())).min(dim = 1).values
    bandwidth_h = nn_sq.median()
    kernel = torch.exp(- sq_dist / bandwidth_h)

    row_sum = kernel.sum(1, keepdim = True)      # (N, 1): sum_j K_ij
    neigh_kernel = kernel / kernel.sum(1)        # K_ij / sum_l K_jl
    term_own = (blob * row_sum - torch.matmul(kernel, blob)) / row_sum
    drift = term_own + blob * neigh_kernel.sum(1, keepdim = True) - torch.matmul(neigh_kernel, blob)
    blob = blob - lr * grads + lr * 2 * drift / bandwidth_h
# ---------------------------------------------- BLOBBD --------------------------------------------------------------------------
# Weighted blob method: positions follow the BLOB drift with kernel sums weighted by
# mass_blobbd. The weights are then rescaled by (1 - beta_bar * lr * alpha), so mass
# shrinks where beta_bar > 0, and renormalised.
blobbd = init.clone().view(-1, 1)
mass_blobbd = torch.ones(len(blobbd)) / len(blobbd)
for step in range(iteration):
    grads = grad_calc(blobbd)
    sq_dist = torch.cdist(blobbd, blobbd, p = 2)**2
    # lift the diagonal above every entry so each row minimum skips the particle itself
    nn_sq = (sq_dist + torch.diag(torch.diag(sq_dist) + sq_dist.max())).min(dim = 1).values
    bandwidth_h = nn_sq.median()
    kernel = torch.exp(- sq_dist / bandwidth_h)

    weighted_kernel = kernel * mass_blobbd                               # K_ij * m_j
    own_sum = torch.matmul(kernel, mass_blobbd.view(-1, 1))              # (N, 1): sum_j K_ij m_j
    neigh_kernel = weighted_kernel / torch.matmul(kernel, mass_blobbd)   # K_ij m_j / sum_l K_jl m_l
    term_own = (blobbd * own_sum - torch.matmul(weighted_kernel, blobbd)) / own_sum
    drift = term_own + blobbd * neigh_kernel.sum(1, keepdim = True) - torch.matmul(neigh_kernel, blobbd)
    blobbd = blobbd - lr * grads + lr * 2 * drift / bandwidth_h

    # rate: potential at the new positions plus the two kernel terms (old positions / old weights)
    potential = potential_calc(blobbd)
    log_kde = torch.log(weighted_kernel.sum(1) + 1e-6)
    cross = (weighted_kernel / torch.matmul(mass_blobbd, kernel)).sum(1)
    beta = log_kde + cross + potential
    beta_bar = beta - (beta * mass_blobbd).sum()
    reweighted = mass_blobbd * (1 - beta_bar * lr * alpha)
    mass_blobbd = reweighted / reweighted.sum()
# ---------------------------------------------- BLOBDK --------------------------------------------------------------------------
# Blob method whose birth-death step moves particles instead of weights: mass_blobdk stays
# uniform. After each drift step, particle k is selected with probability
# 1 - exp(-|beta_bar_k| * alpha_dk * lr). A selected particle is killed when beta_bar_k > 0
# and duplicated when beta_bar_k < 0.
blobdk = init.clone().view(-1, 1)
mass_blobdk = torch.ones(len(blobdk)) / len(blobdk)
for i in range(iteration):
    grads = grad_calc(blobdk)
    sq_dist = torch.cdist(blobdk, blobdk, p = 2)**2
    # lift the diagonal above every entry so each row minimum skips the particle itself
    bandwidth_h = sq_dist + torch.diag(torch.diag(sq_dist) + sq_dist.max())
    bandwidth_h = bandwidth_h.min(dim = 1)[0].median()
    kernel = torch.exp( - sq_dist / bandwidth_h)
    blobdk = blobdk - lr * grads + lr * 2 * (
        (blobdk * torch.matmul(kernel,mass_blobdk.view(-1,1)) - torch.matmul(kernel * mass_blobdk, blobdk)) / torch.matmul(kernel, mass_blobdk.view(-1,1)) + \
        blobdk * (kernel * mass_blobdk / torch.matmul(kernel, mass_blobdk)).sum(1, keepdim = True) - torch.matmul(kernel * mass_blobdk / torch.matmul(kernel, mass_blobdk), blobdk)
    ) / bandwidth_h
    potential = potential_calc(blobdk)
    beta = torch.log((mass_blobdk * kernel).sum(1) + 1e-6) + ((kernel * mass_blobdk) / torch.matmul(mass_blobdk,kernel)).sum(1) + potential
    # centre beta by its mass-weighted mean (weights, not positions: (N,) * (N, 1) would
    # broadcast to an N x N outer product)
    beta_bar = beta - (beta * mass_blobdk).sum()
    prob_list = 1 - torch.exp( - beta_bar.abs() * alpha_dk * lr)
    # kill where beta_bar > 0 and duplicate where beta_bar < 0, the same sign convention as
    # the BLOBBD weight update, where beta_bar > 0 shrinks a particle's mass
    blobdk = duplicate_kill_particles(prob_list, beta_bar > 0, blobdk, noise_amp = np.sqrt(2 * lr), mode = 'parallel')
# ------------------------------------------------------- plot_figure --------------------------------------------------------------------
def _plot_panel(fig, position, particles, masses, title, with_labels = False):
    """Draw one comparison panel into a 1 x 3 grid and return its Axes.

    The panel shows the target density (``pdf_calc``, dashed black) and the
    kernel density estimate of the weighted particles (``pdf_kernel`` with
    bandwidth ``bw``, red). It also marks the particle locations at y = 0 and
    draws one vertical stem per particle whose height is that particle's weight.

    Args:
        fig: Figure to draw into.
        position: 1-based subplot index in the 1 x 3 grid.
        particles: particle positions (flattened with ``view(-1)``).
        masses: ``(N,)`` particle weights.
        title: panel title.
        with_labels: attach legend labels. Only the first panel needs them,
            because one figure-level legend is built from it.
    """
    ax = fig.add_subplot(1, 3, position)
    grid = x_axis.view(-1, 1)
    xs = particles.view(-1).numpy()
    ms = masses.numpy()

    def label(text):
        return text if with_labels else None

    ax.plot(x_axis.numpy(), pdf_calc(grid).numpy(), c = 'black', linestyle = '--', label = label('target'), zorder = 2)
    ax.plot(x_axis.numpy(), pdf_kernel(grid, particles.view(-1), masses, bw).numpy(), c = 'red', linestyle = '-', label = label('kernel density estimation'), zorder = 3)
    ax.scatter(xs, np.zeros(len(xs)), s = size, c = 'purple', marker = 's', edgecolor = 'purple', label = label('particle'), zorder = 4)
    # weight stems, each capped with a horizontal tick at its top
    ax.vlines(xs, 0, ms, linewidth = linewidth, linestyle = '-', label = label('weight of particle'), zorder = 1)
    ax.scatter(xs, ms, c = '#1f77b4', s = size_, marker = '_', zorder = 4)
    ax.set_xticks([])
    fig.tight_layout()
    ax.tick_params(labelsize = 16)
    ax.grid()
    ax.set_title(title, fontdict = {'size': 16})
    return ax


fig = plt.figure(figsize = (13.8, 2.8))
_plot_panel(fig, 1, blob, mass_blob, 'Blob', with_labels = True)
# one figure-level legend, gathered from the labelled artists of the first panel
fig.legend(fontsize = 18, ncol = 4, bbox_to_anchor = (0.04, 1.0), loc = 3, borderaxespad = 0)
_plot_panel(fig, 2, blobbd, mass_blobbd, 'D-Blob-CA')
_plot_panel(fig, 3, blobdk, mass_blobdk, 'D-Blob-DK')

os.makedirs('./figures', exist_ok = True)
fig.savefig('./figures/demo2.pdf', bbox_inches = 'tight', dpi = 300)
# show before closing; closing first would leave nothing to display
plt.show()
plt.close(fig)