In [1]:
%env CUDA_VISIBLE_DEVICES=6

import math
from IPython.display import clear_output
import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange

# Number of update iterations; passed as `steps` to get_grid for every generated grid.
STEPS = 200
# Blend factor read by get_grid: weight given to the freshly computed centers of mass,
# with (1 - ALPHA) kept on the previous centers.
ALPHA = 0.5

env: CUDA_VISIBLE_DEVICES=6


In [2]:
def sample_points_in_ball(n_samples=10000, dim=2, radius=10, device="cuda"):
    """
    Sample points uniformly inside a dim-dimensional ball centered at the origin.

    Uses rejection sampling: ``n_samples`` candidates are drawn uniformly from
    the cube ``[-radius, radius)^dim`` and only those whose Euclidean norm is
    strictly below ``radius`` are kept. The number of returned points is
    therefore random and at most ``n_samples``; the accepted fraction equals
    the ball-to-cube volume ratio, which shrinks quickly as ``dim`` grows.

    :param n_samples: Number of candidate points drawn before rejection.
    :param dim: Dimensionality of the sampled points.
    :param radius: Radius of the ball.
    :param device: Torch device on which the points are created. Defaults to
        "cuda", the value that used to be hard-coded.
    :return: Torch tensor of shape (n_accepted, dim), n_accepted <= n_samples.
    """
    # torch.rand is uniform on [0, 1); rescale it to the enclosing cube.
    candidates = torch.rand(n_samples, dim, device=device) * 2 * radius - radius
    # Keep only the candidates that fall strictly inside the ball.
    inside_ball = torch.norm(candidates, dim=1) < radius
    return candidates[inside_ball]


def normal_density(points):
    """
    Evaluate the standard multivariate normal density N(0, I) at every row of ``points``.

    The dimensionality is taken from ``points.shape[1]``, so this works for any
    number of coordinates, not just two.

    :param points: Tensor of shape (n_points, dim), one point per row.
    :return: Tensor of shape (n_points,) holding exp(-|x|^2 / 2) / (2 * pi)^(dim / 2).
    """
    dim = points.shape[1]
    squared_norms = torch.sum(points**2, dim=1)
    normalizer = (2 * torch.pi) ** (dim / 2)
    return torch.exp(-0.5 * squared_norms) / normalizer


def compute_centers_of_mass_with_gaussian_weights(grid, sampled_points):
    """
    Compute the density-weighted center of mass of every Voronoi cell of ``grid``.

    Each sampled point is assigned to its nearest grid point and weighted by
    ``normal_density``. The weighted points are then accumulated per cell with
    a scatter-add instead of a Python loop over cells.

    :param grid: Tensor of shape (n_cells, dim) with the current cell centers.
    :param sampled_points: Tensor of shape (n_samples, dim), on the same device
        as ``grid``.
    :return: New tensor of shape (n_cells, dim). Cells that received no weight
        keep their original ``grid`` position. ``grid`` itself is not modified.
    """
    # Calculate Gaussian weights for all sampled points
    weights = normal_density(sampled_points)

    # argmin_j |x - g_j|^2 equals argmax_j (2 x.g_j - |g_j|^2) because |x|^2 is
    # constant per point. The squared grid norms do not depend on the chunk,
    # so they are computed once instead of once per chunk.
    grid_sq_norms = torch.norm(grid, dim=1) ** 2

    # Find the Voronoi cell each point belongs to, in chunks so the
    # (chunk, n_cells) score matrix stays small.
    chunk_size = 4096
    n_samples = sampled_points.shape[0]
    closest_indices = torch.empty(n_samples, dtype=torch.long, device=sampled_points.device)
    for start in range(0, n_samples, chunk_size):
        chunk = sampled_points[start:start + chunk_size]
        closest_indices[start:start + chunk_size] = torch.argmax(
            2 * chunk @ grid.T - grid_sq_norms, dim=1
        )

    # Accumulate weighted points and total weights for each Voronoi cell.
    # index_add_ sums all rows sharing a cell index in one call; the casts keep
    # it working when the samples and the grid have different dtypes.
    # The summation order differs from a per-cell loop, so results can differ
    # at floating-point rounding level.
    sums = torch.zeros_like(grid)
    weights_sum = torch.zeros(grid.shape[0], dtype=grid.dtype, device=grid.device)
    sums.index_add_(0, closest_indices, (sampled_points * weights[:, None]).to(sums.dtype))
    weights_sum.index_add_(0, closest_indices, weights.to(weights_sum.dtype))

    # Calculate centers of mass; cells with zero total weight keep the old center
    valid_cells = weights_sum > 0

    centers_of_mass = torch.clone(grid)
    centers_of_mass[valid_cells] = sums[valid_cells] / weights_sum[valid_cells, None]

    return centers_of_mass

# Grid construction: repeatedly move the centers towards the weighted centers of mass of their cells
@torch.no_grad()
def get_grid(dim, bits, limit, steps, n_samples=10000000, alpha=None):
    """
    Build a grid of ``floor(2 ** (bits * dim))`` points by iterated centroid updates.

    The centers start as standard normal draws on the GPU. Every step draws
    fresh points inside the ball of radius ``limit``, computes the weighted
    center of mass of each center's cell, and blends it with the previous
    centers.

    :param dim: Dimensionality of the grid points.
    :param bits: Bits per dimension; the grid has floor(2 ** (bits * dim)) points.
    :param limit: Radius of the ball the samples are drawn from. Initial
        centers whose norm exceeds it are reset to the origin.
    :param steps: Number of update iterations.
    :param n_samples: Number of candidate points requested from
        ``sample_points_in_ball`` at every step (previously hard-coded).
    :param alpha: Weight of the new centers of mass in the blend; the old
        centers get ``1 - alpha``. ``None`` (default) uses the module-level
        ``ALPHA`` at call time.
    :return: Tensor of shape (n_points, dim) on the "cuda" device.
    """
    if alpha is None:
        alpha = ALPHA

    n_points = math.floor(2 ** (bits * dim))
    centers_of_mass = torch.empty(n_points, dim, device="cuda").normal_(0, 1)
    # Initial centers that landed outside the sampling ball are moved to the origin.
    centers_of_mass[torch.norm(centers_of_mass, dim=1) > limit] = 0

    for _ in trange(steps, leave=False):
        # Draw fresh points inside the ball of radius `limit`
        sampled_points = sample_points_in_ball(n_samples=n_samples, dim=dim, radius=limit)

        # Move each center part of the way towards its cell's center of mass
        updated_centers = compute_centers_of_mass_with_gaussian_weights(centers_of_mass, sampled_points)
        centers_of_mass = alpha * updated_centers + (1 - alpha) * centers_of_mass

    return centers_of_mass


In [3]:
import os
import numpy as np

# Create the output directory before any work starts, so torch.save cannot fail
# on a missing directory after an expensive get_grid run.
GRID_DIR = "./grids"
os.makedirs(GRID_DIR, exist_ok=True)

for dim in trange(1, 11, leave=False):
    for bits in tqdm(np.linspace(0.1, 3, num=20)):
        n_points = math.floor(2 ** (bits * dim))
        # Files are keyed by (dim, n_points), so different `bits` values that
        # floor to the same n_points share one file.
        grid_path = f"{GRID_DIR}/EDEN{dim}-{n_points}.pt"

        # Less than one bit in total: the grid would hold a single point.
        if bits * dim < 1:
            print(f"Skipping {dim}D with {bits} bits")
            continue

        # More than 12 bits in total: the grid would exceed 4096 points.
        if bits * dim > 12:
            print(f"Skipping {dim}D with {bits} bits")
            continue

        # Already generated by an earlier run or an earlier `bits` value.
        if os.path.isfile(grid_path):
            print(f"Skipping {dim}D with {bits} bits")
            continue

        # Sampling radius grows with the per-dimension bit budget.
        if bits < 1.5:
            limit = 4
        elif bits < 2.5:
            limit = 5
        else:
            limit = 6

        grid = get_grid(dim, bits, limit, STEPS)
        assert n_points == grid.shape[0]

        torch.save(grid, grid_path)

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/20 [00:00<?, ?it/s]

Skipping 1D with 0.1 bits
Skipping 1D with 0.25263157894736843 bits
Skipping 1D with 0.4052631578947369 bits
Skipping 1D with 0.5578947368421052 bits
Skipping 1D with 0.7105263157894737 bits
Skipping 1D with 0.8631578947368421 bits
Skipping 1D with 1.0157894736842106 bits
Skipping 1D with 1.168421052631579 bits
Skipping 1D with 1.3210526315789475 bits
Skipping 1D with 1.473684210526316 bits
Skipping 1D with 1.6263157894736844 bits
Skipping 1D with 1.7789473684210528 bits
Skipping 1D with 1.931578947368421 bits
Skipping 1D with 2.0842105263157893 bits
Skipping 1D with 2.236842105263158 bits
Skipping 1D with 2.3894736842105266 bits
Skipping 1D with 2.542105263157895 bits
Skipping 1D with 2.694736842105263 bits
Skipping 1D with 2.8473684210526318 bits
Skipping 1D with 3.0 bits


  0%|          | 0/20 [00:00<?, ?it/s]

Skipping 2D with 0.1 bits
Skipping 2D with 0.25263157894736843 bits
Skipping 2D with 0.4052631578947369 bits
Skipping 2D with 0.5578947368421052 bits
Skipping 2D with 0.7105263157894737 bits
Skipping 2D with 0.8631578947368421 bits
Skipping 2D with 1.0157894736842106 bits
Skipping 2D with 1.168421052631579 bits
Skipping 2D with 1.3210526315789475 bits
Skipping 2D with 1.473684210526316 bits
Skipping 2D with 1.6263157894736844 bits
Skipping 2D with 1.7789473684210528 bits
Skipping 2D with 1.931578947368421 bits
Skipping 2D with 2.0842105263157893 bits
Skipping 2D with 2.236842105263158 bits
Skipping 2D with 2.3894736842105266 bits
Skipping 2D with 2.542105263157895 bits
Skipping 2D with 2.694736842105263 bits
Skipping 2D with 2.8473684210526318 bits
Skipping 2D with 3.0 bits


  0%|          | 0/20 [00:00<?, ?it/s]

Skipping 3D with 0.1 bits
Skipping 3D with 0.25263157894736843 bits
Skipping 3D with 0.4052631578947369 bits
Skipping 3D with 0.5578947368421052 bits
Skipping 3D with 0.7105263157894737 bits
Skipping 3D with 0.8631578947368421 bits
Skipping 3D with 1.0157894736842106 bits
Skipping 3D with 1.168421052631579 bits
Skipping 3D with 1.3210526315789475 bits
Skipping 3D with 1.473684210526316 bits
Skipping 3D with 1.6263157894736844 bits
Skipping 3D with 1.7789473684210528 bits
Skipping 3D with 1.931578947368421 bits
Skipping 3D with 2.0842105263157893 bits
Skipping 3D with 2.236842105263158 bits
Skipping 3D with 2.3894736842105266 bits
Skipping 3D with 2.542105263157895 bits
Skipping 3D with 2.694736842105263 bits
Skipping 3D with 2.8473684210526318 bits
Skipping 3D with 3.0 bits


  0%|          | 0/20 [00:00<?, ?it/s]

Skipping 4D with 0.1 bits
Skipping 4D with 0.25263157894736843 bits
Skipping 4D with 0.4052631578947369 bits
Skipping 4D with 0.5578947368421052 bits
Skipping 4D with 0.7105263157894737 bits
Skipping 4D with 0.8631578947368421 bits
Skipping 4D with 1.0157894736842106 bits
Skipping 4D with 1.168421052631579 bits
Skipping 4D with 1.3210526315789475 bits
Skipping 4D with 1.473684210526316 bits
Skipping 4D with 1.6263157894736844 bits
Skipping 4D with 1.7789473684210528 bits
Skipping 4D with 1.931578947368421 bits
Skipping 4D with 2.0842105263157893 bits
Skipping 4D with 2.236842105263158 bits
Skipping 4D with 2.3894736842105266 bits
Skipping 4D with 2.542105263157895 bits
Skipping 4D with 2.694736842105263 bits
Skipping 4D with 2.8473684210526318 bits
Skipping 4D with 3.0 bits


  0%|          | 0/20 [00:00<?, ?it/s]

Skipping 5D with 0.1 bits
Skipping 5D with 0.25263157894736843 bits
Skipping 5D with 0.4052631578947369 bits
Skipping 5D with 0.5578947368421052 bits
Skipping 5D with 0.7105263157894737 bits
Skipping 5D with 0.8631578947368421 bits
Skipping 5D with 1.0157894736842106 bits
Skipping 5D with 1.168421052631579 bits
Skipping 5D with 1.3210526315789475 bits
Skipping 5D with 1.473684210526316 bits
Skipping 5D with 1.6263157894736844 bits
Skipping 5D with 1.7789473684210528 bits
Skipping 5D with 1.931578947368421 bits
Skipping 5D with 2.0842105263157893 bits
Skipping 5D with 2.236842105263158 bits
Skipping 5D with 2.3894736842105266 bits
Skipping 5D with 2.542105263157895 bits
Skipping 5D with 2.694736842105263 bits
Skipping 5D with 2.8473684210526318 bits
Skipping 5D with 3.0 bits


  0%|          | 0/20 [00:00<?, ?it/s]

Skipping 6D with 0.1 bits
Skipping 6D with 0.25263157894736843 bits
Skipping 6D with 0.4052631578947369 bits
Skipping 6D with 0.5578947368421052 bits
Skipping 6D with 0.7105263157894737 bits
Skipping 6D with 0.8631578947368421 bits
Skipping 6D with 1.0157894736842106 bits
Skipping 6D with 1.168421052631579 bits
Skipping 6D with 1.3210526315789475 bits
Skipping 6D with 1.473684210526316 bits
Skipping 6D with 1.6263157894736844 bits
Skipping 6D with 1.7789473684210528 bits
Skipping 6D with 1.931578947368421 bits
Skipping 6D with 2.0842105263157893 bits
Skipping 6D with 2.236842105263158 bits
Skipping 6D with 2.3894736842105266 bits
Skipping 6D with 2.542105263157895 bits
Skipping 6D with 2.694736842105263 bits
Skipping 6D with 2.8473684210526318 bits
Skipping 6D with 3.0 bits


  0%|          | 0/20 [00:00<?, ?it/s]

Skipping 7D with 0.1 bits
Skipping 7D with 0.25263157894736843 bits
Skipping 7D with 0.4052631578947369 bits
Skipping 7D with 0.5578947368421052 bits
Skipping 7D with 0.7105263157894737 bits
Skipping 7D with 0.8631578947368421 bits
Skipping 7D with 1.0157894736842106 bits
Skipping 7D with 1.168421052631579 bits
Skipping 7D with 1.3210526315789475 bits
Skipping 7D with 1.473684210526316 bits


  0%|          | 0/200 [00:00<?, ?it/s]

Skipping 7D with 1.7789473684210528 bits
Skipping 7D with 1.931578947368421 bits
Skipping 7D with 2.0842105263157893 bits
Skipping 7D with 2.236842105263158 bits
Skipping 7D with 2.3894736842105266 bits
Skipping 7D with 2.542105263157895 bits
Skipping 7D with 2.694736842105263 bits
Skipping 7D with 2.8473684210526318 bits
Skipping 7D with 3.0 bits


  0%|          | 0/20 [00:00<?, ?it/s]

Skipping 8D with 0.1 bits
Skipping 8D with 0.25263157894736843 bits
Skipping 8D with 0.4052631578947369 bits
Skipping 8D with 0.5578947368421052 bits
Skipping 8D with 0.7105263157894737 bits
Skipping 8D with 0.8631578947368421 bits
Skipping 8D with 1.0157894736842106 bits
Skipping 8D with 1.168421052631579 bits


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

Skipping 8D with 1.6263157894736844 bits
Skipping 8D with 1.7789473684210528 bits
Skipping 8D with 1.931578947368421 bits
Skipping 8D with 2.0842105263157893 bits
Skipping 8D with 2.236842105263158 bits
Skipping 8D with 2.3894736842105266 bits
Skipping 8D with 2.542105263157895 bits
Skipping 8D with 2.694736842105263 bits
Skipping 8D with 2.8473684210526318 bits
Skipping 8D with 3.0 bits


  0%|          | 0/20 [00:00<?, ?it/s]

Skipping 9D with 0.1 bits
Skipping 9D with 0.25263157894736843 bits
Skipping 9D with 0.4052631578947369 bits


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

Skipping 9D with 1.473684210526316 bits
Skipping 9D with 1.6263157894736844 bits
Skipping 9D with 1.7789473684210528 bits
Skipping 9D with 1.931578947368421 bits
Skipping 9D with 2.0842105263157893 bits
Skipping 9D with 2.236842105263158 bits
Skipping 9D with 2.3894736842105266 bits
Skipping 9D with 2.542105263157895 bits
Skipping 9D with 2.694736842105263 bits
Skipping 9D with 2.8473684210526318 bits
Skipping 9D with 3.0 bits


  0%|          | 0/20 [00:00<?, ?it/s]

Skipping 10D with 0.1 bits
Skipping 10D with 0.25263157894736843 bits
Skipping 10D with 0.4052631578947369 bits


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

Skipping 10D with 1.3210526315789475 bits
Skipping 10D with 1.473684210526316 bits
Skipping 10D with 1.6263157894736844 bits
Skipping 10D with 1.7789473684210528 bits
Skipping 10D with 1.931578947368421 bits
Skipping 10D with 2.0842105263157893 bits
Skipping 10D with 2.236842105263158 bits
Skipping 10D with 2.3894736842105266 bits
Skipping 10D with 2.542105263157895 bits
Skipping 10D with 2.694736842105263 bits
Skipping 10D with 2.8473684210526318 bits
Skipping 10D with 3.0 bits
