In [1]:
%env CUDA_VISIBLE_DEVICES=0

import math
from IPython.display import clear_output
import torch
import matplotlib.pyplot as plt
from tqdm.auto import tqdm, trange

STEPS = 1000  # iterations of the grid optimisation loop per (dim, n_points) configuration
ALPHA = 0.2   # fraction of the way each center moves toward its weighted centroid per iteration
SEED = 0      # fixed RNG seed so regenerated grids are reproducible (up to nondeterministic GPU kernels)
torch.manual_seed(SEED)

env: CUDA_VISIBLE_DEVICES=0


In [2]:
def sample_points_in_ball(n_samples=10000, dim=2, radius=10, device="cuda"):
    """
    Rejection-sample points uniformly from the open ball of the given radius centered at the origin.

    Candidates are drawn uniformly from the cube [-radius, radius]^dim, and those with
    norm >= radius are discarded. The result therefore has *at most* ``n_samples`` rows.
    The acceptance rate is the ball-to-cube volume ratio: essentially 1 for dim=1 and
    about 0.785 for dim=2, and it shrinks rapidly as ``dim`` grows.

    :param n_samples: Number of candidate points drawn from the cube (upper bound on rows returned).
    :param dim: Dimensionality of the space.
    :param radius: Radius of the ball.
    :param device: Torch device for the returned tensor (defaults to "cuda").
    :return: Tensor of shape (m, dim) with m <= n_samples; every row has norm < radius.
    """
    candidates = torch.rand(n_samples, dim, device=device) * 2 * radius - radius
    # Keep only the candidates that fall strictly inside the ball.
    return candidates[torch.norm(candidates, dim=1) < radius]


def normal_density(points):
    """
    Evaluate the standard multivariate normal density N(0, I) at each row of ``points``.

    This works for any dimensionality d = points.shape[1]; the normalizing constant is
    (2*pi)^(d/2).

    :param points: Tensor of shape (n, d).
    :return: Tensor of shape (n,) holding the density value of each row.
    """
    dim = points.shape[1]
    squared_norms = torch.sum(points**2, dim=1)
    return torch.exp(-0.5 * squared_norms) / (2 * torch.pi) ** (dim / 2)


def compute_centers_of_mass_with_gaussian_weights(grid, sampled_points, chunk_size=4096):
    """
    Compute the Gaussian-weighted centroid of each grid point's Voronoi cell (one Lloyd-style step).

    Each sampled point is assigned to its nearest grid point (Euclidean distance) and
    weighted by ``normal_density``. Each grid point is then replaced by the weighted mean
    of the samples assigned to it. Grid points whose cell received no weight keep their
    current position.

    :param grid: Tensor of shape (k, d) holding the current cell generators.
    :param sampled_points: Tensor of shape (n, d), on the same device as ``grid``.
    :param chunk_size: Number of samples scored per nearest-neighbour chunk. It bounds
        the size of the temporary (chunk_size, k) score matrix.
    :return: New tensor of shape (k, d) with the updated centers (``grid`` is not modified).
    """
    weights = normal_density(sampled_points)

    # argmin_j ||x - g_j||^2 == argmax_j (2 x.g_j - ||g_j||^2), since ||x||^2 is common to all j.
    # The grid norms do not depend on the chunk, so compute them once.
    grid_sq_norms = torch.norm(grid, dim=1) ** 2
    closest_indices = torch.empty(sampled_points.shape[0], dtype=torch.long, device=sampled_points.device)
    for start in range(0, sampled_points.shape[0], chunk_size):
        chunk = sampled_points[start:start + chunk_size]
        closest_indices[start:start + chunk_size] = torch.argmax(2 * chunk @ grid.T - grid_sq_norms, dim=1)

    # Scatter-add the weighted points and their weights into their cells in one pass,
    # instead of building one boolean mask per grid point (k full passes over the samples).
    sums = torch.zeros_like(grid)
    sums.index_add_(0, closest_indices, (sampled_points * weights[:, None]).to(grid.dtype))
    weights_sum = torch.zeros(grid.shape[0], dtype=grid.dtype, device=grid.device)
    weights_sum.index_add_(0, closest_indices, weights.to(grid.dtype))

    # Empty (zero-weight) cells keep their current position.
    valid_cells = weights_sum > 0
    centers_of_mass = torch.clone(grid)
    centers_of_mass[valid_cells] = sums[valid_cells] / weights_sum[valid_cells, None]
    return centers_of_mass

@torch.no_grad()
def get_grid(dim, n_points, limit, steps, n_samples=10_000_000, alpha=None, plot_every=None):
    """
    Optimise ``n_points`` centers in ``dim`` dimensions with damped, Gaussian-weighted Lloyd iterations.

    Centers start as N(0, I) draws on CUDA. Any initial center outside the ball of radius
    ``limit`` is rescaled along its own direction to a random radius inside the ball. Each
    of the ``steps`` iterations draws fresh samples from that ball and computes the
    Gaussian-weighted centroid of every center's Voronoi cell. It then moves each center
    a fraction ``alpha`` of the way toward that centroid.

    :param dim: Dimensionality of the grid points.
    :param n_points: Number of grid points.
    :param limit: Radius of the ball that samples (and initial centers) are confined to.
    :param steps: Number of iterations.
    :param n_samples: Candidate samples drawn per iteration (rejection keeps only those inside the ball).
    :param alpha: Step size toward the centroid. None uses the module-level ``ALPHA``.
    :param plot_every: If set and dim > 1, plot the first two coordinates of up to 1000
        centers every ``plot_every`` iterations.
    :return: Tensor of shape (n_points, dim) with the optimised centers.
    """
    if alpha is None:
        alpha = ALPHA

    centers_of_mass = torch.empty(n_points, dim, device="cuda").normal_(0, 1)

    # Bring out-of-ball initial centers inside by rescaling each one to a random radius in
    # [0, limit). Zeroing them instead would stack them all on the origin. torch.argmax
    # breaks ties toward the lowest index, so coincident centers could then only be
    # separated one per iteration.
    norms = torch.norm(centers_of_mass, dim=1)
    outside = norms > limit
    if outside.any():
        new_radii = limit * torch.rand(int(outside.sum()), device=centers_of_mass.device)
        centers_of_mass[outside] = centers_of_mass[outside] * (new_radii / norms[outside])[:, None]

    for step in trange(steps, leave=False):
        # Fresh Monte-Carlo samples from the ball of radius `limit` every iteration.
        sampled_points = sample_points_in_ball(n_samples=n_samples, dim=dim, radius=limit)

        # Damped update: move each center a fraction `alpha` toward its weighted centroid.
        targets = compute_centers_of_mass_with_gaussian_weights(centers_of_mass, sampled_points)
        centers_of_mass = alpha * targets + (1 - alpha) * centers_of_mass

        if plot_every and dim > 1 and step % plot_every == 0:
            clear_output(wait=True)
            fig, ax = plt.subplots(figsize=(10, 10))
            shown = centers_of_mass[:1000, :2].cpu().numpy()
            ax.plot(shown[:, 0], shown[:, 1], "ro")  # centers marked in red
            ax.set_xlim(-limit, limit)
            ax.set_ylim(-limit, limit)
            ax.set_title(f"{step}/{steps}")
            plt.show()
            plt.close(fig)
    return centers_of_mass


In [3]:
import os
import numpy as np

GRID_DIR = "./grids"
MIN_BITS = 0.25  # skip configurations with fewer bits per coordinate than this
MAX_BITS = None  # e.g. 6 to also skip very large bitwidths; None disables the check

# torch.save does not create missing directories.
os.makedirs(GRID_DIR, exist_ok=True)

for dim in tqdm([1]):
    for n_points in tqdm([256], leave=False):
        # Bits per coordinate: log2 of the codebook size spread over `dim` coordinates.
        bits = np.log2(n_points) / dim
        path = os.path.join(GRID_DIR, f"EDEN{dim}-{n_points}.pt")

        if MAX_BITS is not None and bits > MAX_BITS:
            print(f"Skipping {dim=},{n_points=}: bitwidth {bits} too large")
            continue

        if bits < MIN_BITS:
            print(f"Skipping {dim=},{n_points=}: bitwidth {bits} too small")
            continue

        if os.path.isfile(path):
            print(f"Skipping {dim=},{n_points=}: already exists")
            continue

        # Radius of the sampling ball, chosen by bitwidth bucket.
        if bits < 1.5:
            limit = 4
        elif bits < 2.5:
            limit = 5
        else:
            limit = 6

        print(f"Computing {dim=},{n_points=}...")
        grid = get_grid(dim, n_points, limit, STEPS)
        assert grid.shape == (n_points, dim), f"unexpected grid shape {tuple(grid.shape)}"

        torch.save(grid, path)

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Skipping dim=1,n_points=256: already exists
