In [5]:
import torch
import numpy as np
from math import sqrt
import gc

DEVICE = 'cuda:0'

In [6]:

class AmsSketch:
    """
    AMS Sketch class for approximate second moment estimation in PyTorch.
    """

    def __init__(self, depth=3, width=500):
        
        self.depth = depth
        self.width = width
        
        self.epsilon = 1. / sqrt(width)

        self.F = torch.randint(0, (1 << 31) - 1, (6, depth), dtype=torch.int32)

        # Dictionary to store precomputed results
        self.precomputed_dict = {}

    def precompute(self, d):
        pos_tensor = self.tensor_hash31(torch.arange(d), self.F[0], self.F[1]) % self.width  # shape=(d, depth)
        four = self.tensor_fourwise(torch.arange(d)).float()  # shape=(d, depth)
        self.precomputed_dict[('pos_tensor', d)] = pos_tensor.to(DEVICE)  # shape=(d, depth)
        self.precomputed_dict[('four', d)] = four.to(DEVICE)  # shape=(d, depth)

    @staticmethod
    def hash31(x, a, b):
        r = a * x + b
        fold = torch.bitwise_xor(r >> 31, r)
        return fold & 2147483647

    @staticmethod
    def tensor_hash31(x, a, b):
        """ Assumed that x is tensor shaped (d,) , i.e., a vector (for example, indices, i.e., torch.arange(d)) """
        x_reshaped = x.unsqueeze(-1)
        r = a * x_reshaped + b
        fold = torch.bitwise_xor(r >> 31, r)
        return fold & 2147483647

    def tensor_fourwise(self, x):
        """ Assumed that x is tensor shaped (d,) , i.e., a vector (for example, indices, i.e., torch.arange(d)) """
        in1 = self.tensor_hash31(x, self.F[2], self.F[3])  # shape = (`x_dim`, `depth`)
        in2 = self.tensor_hash31(x, in1, self.F[4])  # shape = (`x_dim`, `depth`)
        in3 = self.tensor_hash31(x, in2, self.F[5])  # shape = (`x_dim`, `depth`)

        in4 = in3 & 32768  # shape = (`x_dim`, `depth`)
        return 2 * (in4 >> 15) - 1  # shape = (`x_dim`, `depth`)

    def sketch_for_vector(self, v):
        """ Efficient computation of sketch using PyTorch tensors.

        Args:
        - v (torch.Tensor): Vector to sketch. Shape=(d,).

        Returns:
        - torch.Tensor: An AMS Sketch. Shape=(`depth`, `width`).
        """
        d = v.shape[0]

        if ('four', d) not in self.precomputed_dict:
            self.precompute(d)

        four, pos_tensor = self.precomputed_dict[('four', d)], self.precomputed_dict[('pos_tensor', d)]
        
        sketch = self._sketch_for_vector(v, four, pos_tensor)
        
        gc.collect()
        
        return sketch

    def _sketch_for_vector(self, v, four, pos_tensor):
        """
        PyTorch translation of the TensorFlow function using a simple for loop.

        Args:
        - v (torch.Tensor): Vector to sketch. Shape=(d,).
        - four (torch.Tensor): Precomputed fourwise tensor. Shape=(d, depth).
        - indices (torch.Tensor): Precomputed indices for scattering. Shape=(d, depth, 2).

        Returns:
        - sketch (torch.Tensor): The AMS sketch tensor. Shape=(depth, width).
        """

        # Expand the input vector v to match dimensions for element-wise multiplication
        v_expand = v.unsqueeze(-1).to(DEVICE)  # shape=(d, 1)

        # Element-wise multiply v_expand and four to get deltas
        deltas_tensor = four * v_expand  # shape=(d, depth)

        # Initialize the sketch tensor with zeros
        sketch = torch.zeros((self.depth, self.width), dtype=torch.float32).to(DEVICE)

        # Loop over each depth and scatter the corresponding values
        for i in range(self.depth):
            # Compute the width indices on the fly
            width_indices = pos_tensor[:, i]  # shape=(d,), indices for the width dimension
            
            deltas = deltas_tensor[:, i]

            # Add the deltas_tensor[:, i] (shape=(d,)) into the correct rows
            # using index_add on the width dimension
            sketch[i].index_add_(0, width_indices, deltas)

        return sketch

    @staticmethod
    def estimate_euc_norm_squared(sketch):
        """ Estimate the Euclidean norm squared of a vector using its AMS sketch.

        Args:
        - sketch (torch.Tensor): AMS sketch of a vector. Shape=(`depth`, `width`).

        Returns:
        - float: Estimated squared Euclidean norm.
        """
        norm_sq_rows = torch.sum(sketch ** 2, dim=1)
        return torch.median(norm_sq_rows).item()


In [7]:
ams_sketch = AmsSketch()

In [8]:
ams_sketch.F

tensor([[1077708736,  775553956,  197980044],
        [1378570210,  759419985,  667263297],
        [ 186700019,  431933018,  500035991],
        [1822604758,  446099739, 1236368238],
        [ 820040940, 1133144505,  723769476],
        [ 871766692, 1792924851, 1434257607]], dtype=torch.int32)

In [11]:
ams_sketch.F

tensor([[1608637542, 1273642420, 1935803229],
        [ 787846414,  996406379, 1201263688],
        [ 423734973,  415968277,  670094950],
        [1914837113,  669991378,  429389014],
        [ 249467210, 1972458954, 1572714584],
        [1433267572,  434285668,  613608295]], dtype=torch.int32)

In [3]:
for _ in range(20):
    v1 = torch.rand(100_000_000)
    sk1 = ams_sketch.sketch_for_vector(v1)
    est = ams_sketch.estimate_euc_norm_squared(sk1)
    print(abs(est - torch.dot(v1,v1)) / est)

tensor(0.0670)
tensor(0.0705)
tensor(0.0449)
tensor(0.0664)
tensor(0.0067)
tensor(0.1256)
tensor(0.0938)
tensor(0.0893)
tensor(0.0602)
tensor(0.0464)
tensor(0.0781)
tensor(0.0488)
tensor(0.0454)
tensor(0.0262)
tensor(0.0615)
tensor(0.1158)
tensor(0.1045)
tensor(0.0785)
tensor(0.0740)
tensor(0.0133)


In [9]:
v1 = torch.rand(1000)
v2 = torch.rand(1000)

In [10]:
sk1 = ams_sketch.sketch_for_vector(v1)
sk2 = ams_sketch.sketch_for_vector(v2)

In [11]:
ams_sketch.estimate_euc_norm_squared(sk1)

368.45062255859375

In [12]:
torch.dot(v1,v1)

tensor(357.6067)

In [13]:
ams_sketch.estimate_euc_norm_squared(sk2)

334.3333740234375

In [14]:
torch.dot(v2,v2)

tensor(332.4880)

In [15]:
sk_l = ams_sketch.sketch_for_vector(v1+v2)

In [16]:
sk1+sk2

tensor([[-0.7485,  0.3308, -0.9262,  ...,  2.5623, -1.4802, -0.2253],
        [-0.7550, -0.5428,  2.0598,  ...,  1.9245,  0.5743,  1.6090],
        [-4.0223,  0.0695,  1.3860,  ...,  0.4399,  1.7191,  2.7875],
        [-1.6184,  2.8533, -1.1031,  ...,  5.1887,  2.1836,  5.7894],
        [-5.6956, -1.1338, -6.5732,  ..., -1.7909, -0.3551, -3.2422]],
       device='cuda:0')

In [17]:
sk_l

tensor([[-0.7485,  0.3308, -0.9262,  ...,  2.5623, -1.4802, -0.2253],
        [-0.7550, -0.5428,  2.0598,  ...,  1.9245,  0.5743,  1.6090],
        [-4.0223,  0.0695,  1.3860,  ...,  0.4399,  1.7191,  2.7875],
        [-1.6184,  2.8533, -1.1031,  ...,  5.1887,  2.1836,  5.7894],
        [-5.6956, -1.1338, -6.5732,  ..., -1.7909, -0.3551, -3.2422]],
       device='cuda:0')