<a href="https://colab.research.google.com/github/montest/stochastic-methods-optimal-quantization/blob/pytorch_implentation_dim_1/Lloyd_with_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import sys

import torch
import torch.nn.functional as F
import numpy as np
from tqdm import trange

# Show arrays and tensors in full, on one very wide line, so that all centroids
# and probabilities printed later in the notebook are visible without truncation.
np.set_printoptions(linewidth=10_000, threshold=np.inf)
torch.set_printoptions(linewidth=10_000, profile="full")

In [None]:
def lloyd_method(N: int, M: int, nbr_iter: int, device:str):
    """
    Perform scalar quantization using the Lloyd algorithm for a Gaussian random variable.

    x: input signal
    k: number of quantization levels

    Returns: quantized signal and quantization levels
    """
    with torch.no_grad():

      mean, sigma = 0, 1
      x = torch.tensor(torch.randn(M) * sigma + mean, dtype=torch.float32)

      # Initialize quantization centroids randomly
      centroids = torch.randn(N) * sigma + mean
      print(centroids)
      centroids, indices = centroids.sort()
      print(centroids)

      x = x.to(device)
      centroids = centroids.to(device)
      # Repeat the quantization process until convergence
      with trange(nbr_iter, desc=f'Lloyd method (pytorch: {device})') as t:
          for step in t:
              # x = torch.tensor(torch.randn(M) * sigma + mean, dtype=torch.float32)
              # # quick method
              # # Compute the thresholds that separate the quantization levels
              thresholds = 0.5 * (centroids[:-1] + centroids[1:])

              #
              # # Assign each sample to the closest quantization level
              indices = torch.sum(x[:, None] >= thresholds[None, :], dim=1).long()

              # slow one
              #dist_centroids_points = torch.norm(centroids - x.view(M, 1, 1), dim=1)
              #indices = dist_centroids_points.argmin(dim=1)

              # Compute the new quantization levels as the mean of the samples assigned to each level
              centroids = torch.tensor([torch.mean(x[indices == i]) for i in range(N)]).to(device)

              # Check if the quantization levels have converged
              # if torch.allclose(centroids, new_quantization_levels, rtol=epsilon):
              #     break

              # Update the quantization levels
              # centroids = new_quantization_levels

              # compute probas
              # probabilities = torch.bincount(indices).numpy() / float(M)
              # probabilities = np.array([x[indices == i].size()[0] for i in range(N)])/float(M)

              # compute distortion
              # quantized_signal = centroids[indices]
              # distortion = torch.mean((x - quantized_signal)**2) * 0.5
              # t.set_postfix(distortion=distortion.item())

              # x = torch.randn(n) * sigma + mean
              # x = torch.tensor(x, dtype=torch.float32)

      # Compute the probability of each centroid
      thresholds = 0.5 * (centroids[:-1] + centroids[1:])
      indices = torch.sum(x[:, None] >= thresholds[None, :], dim=1).long()
      # probabilities = np.array([x[indices == i].size()[0] for i in range(N)])/float(M)
      probabilities = torch.bincount(indices).to('cpu').numpy()/float(M)

      return centroids.to('cpu').numpy(), probabilities

In [None]:
# Fit a 50-level quantizer to 100,000 standard-Gaussian samples on every
# available device (GPU first when present, then CPU). Re-seeding before each
# run gives every device the same sample and the same initial centroids, so the
# results should agree up to floating-point rounding.
devices = ['cuda', 'cpu'] if torch.cuda.is_available() else ['cpu']
for device in devices:
    torch.manual_seed(0)
    centroids, probas = lloyd_method(M=100000, N=50, nbr_iter=500, device=device)
    print(centroids)
    print(probas)
    print(probas.sum())

  x = torch.tensor(torch.randn(M) * sigma + mean, dtype=torch.float32)


tensor([ 1.5558, -1.7621, -2.0133,  0.2595, -0.4927, -0.4737, -0.2074, -1.9427, -0.6407,  0.5961, -0.2810, -0.3851,  0.1829, -0.0854, -0.3732, -0.3063, -0.8163,  0.1983, -0.0392, -0.4143,  1.1076,  0.0795, -0.9777,  0.4277, -1.1232,  1.4520, -1.1108,  0.7190, -0.1672, -1.8211,  1.7581,  0.3415, -1.7858,  0.2259,  0.7186,  0.6501,  0.3941,  1.6776, -1.7012, -0.6319, -1.4278,  1.3123,  1.4974, -1.2202, -0.5113,  1.0046,  1.1069,  1.4470, -1.7761, -0.7340])
tensor([-2.0133, -1.9427, -1.8211, -1.7858, -1.7761, -1.7621, -1.7012, -1.4278, -1.2202, -1.1232, -1.1108, -0.9777, -0.8163, -0.7340, -0.6407, -0.6319, -0.5113, -0.4927, -0.4737, -0.4143, -0.3851, -0.3732, -0.3063, -0.2810, -0.2074, -0.1672, -0.0854, -0.0392,  0.0795,  0.1829,  0.1983,  0.2259,  0.2595,  0.3415,  0.3941,  0.4277,  0.5961,  0.6501,  0.7186,  0.7190,  1.0046,  1.1069,  1.1076,  1.3123,  1.4470,  1.4520,  1.4974,  1.5558,  1.6776,  1.7581])


Lloyd method (pytorch: cuda): 100%|██████████| 500/500 [00:02<00:00, 198.97it/s]


[-3.528091   -2.9802203  -2.6384788  -2.3809927  -2.1708677  -1.9945942  -1.8397757  -1.698688   -1.5707275  -1.446537   -1.3321846  -1.2279814  -1.1247361  -1.0251892  -0.92976886 -0.84033996 -0.75436914 -0.6713103  -0.5892103  -0.5082337  -0.43156746 -0.35479936 -0.27882677 -0.2024892  -0.1273823  -0.05128891  0.02491156  0.10145945  0.17912774  0.25875446  0.33940196  0.42119375  0.50701874  0.5962884   0.6863401   0.7788686   0.8765815   0.97787935  1.082488    1.1912245   1.3125548   1.4431299   1.5924487   1.7509642   1.9247192   2.11428     2.3218226   2.594847    2.9499176   3.5267856 ]
[0.00065 0.00194 0.00366 0.00542 0.00707 0.00897 0.0108  0.01264 0.01542 0.01714 0.01753 0.0194  0.02232 0.02317 0.02396 0.02456 0.02525 0.02541 0.02744 0.02737 0.02717 0.02848 0.02917 0.0291  0.03075 0.03019 0.03133 0.03044 0.03013 0.02983 0.0304  0.03076 0.0303  0.03046 0.0296  0.02789 0.02734 0.02554 0.02377 0.02299 0.02153 0.01937 0.01669 0.01374 0.01153 0.00816 0.00614 0.00416 0.00226 0.000

Lloyd method (pytorch: cpu): 100%|██████████| 500/500 [00:24<00:00, 20.62it/s]


[-3.528091   -2.9802203  -2.638479   -2.3809931  -2.1708677  -1.994594   -1.8397756  -1.6986881  -1.5707275  -1.446537   -1.3321844  -1.2279813  -1.1247361  -1.0251892  -0.9297689  -0.8403401  -0.754369   -0.67131025 -0.5892103  -0.5082338  -0.43156746 -0.35479942 -0.27882677 -0.20248918 -0.1273823  -0.05128891  0.02491157  0.10145946  0.17912774  0.25875446  0.33940193  0.4211937   0.5070188   0.5962884   0.6863401   0.7788686   0.8765814   0.97787917  1.0824881   1.1912245   1.3125548   1.4431298   1.5924487   1.7509639   1.9247191   2.11428     2.3218224   2.5948465   2.9499178   3.5267851 ]
[0.00065 0.00194 0.00366 0.00542 0.00707 0.00897 0.0108  0.01264 0.01542 0.01714 0.01753 0.0194  0.02232 0.02317 0.02396 0.02456 0.02525 0.02541 0.02744 0.02737 0.02717 0.02848 0.02917 0.0291  0.03075 0.03019 0.03133 0.03044 0.03013 0.02983 0.0304  0.03076 0.0303  0.03046 0.0296  0.02789 0.02734 0.02554 0.02377 0.02299 0.02153 0.01937 0.01669 0.01374 0.01153 0.00816 0.00614 0.00416 0.00226 0.000