In [72]:
import os
os.environ["DEBUG"] = "1"

import torch
import numpy as np

from tinygrad.tensor import Tensor
from tinygrad.helpers import dtypes
from tinygrad.ops import Device
Device.DEFAULT = "CPU"

## Psuedocode

```
- input
    - input tensor (ndim=1 or ndim=2)
    - n is the number of samples to draw
    - replacement (bool) whether to draw with replacement or not
- validate input
    - make sure input is a tensor of ndim=1 or ndim=2, else raise error
    - make sure there's no negative values in the input tensor, else raise error
    - make sure there are no null values in the input tensor, else raise error
    - make sure the sum of the input tensor is 1, else raise error (or normalize it - pytorch does not automatically normalize i think)
- errors
    - if replacement is False, and n > k (len of input tensor), raise error 
```

```
function multinomial(input_tensor, n_samples, replacement)
    validate input
    check replacement logic, raise error if necessary

    if 2D tensor, iterate over each row:
        normalize if necessary
        calculate the CDF
        get the random numbers
        convert the random numbers to indices based on the CDF
        return the indices
    else if 1D tensor:
        normalize if necessary
        calculate the CDF
        get the random numbers
        convert the random numbers to indices based on the CDF
        return the indices
```

```
multinomial(input_tensor, n_samples, replacement)
    validation logic
    replacment logic
    def multinomial_1d(input_tensor of ndim=1, n_samples, replacement)
        normalize if necessary
        calculate the CDF
        get the random numbers
        convert the random numbers to indices based on the CDF
        return the indices

    if 2d:
        for each row:
            multinomial_1d(row, n_samples, replacement)
    else:
        multinomial_1d(input_tensor, n_samples, replacement)

    return the indices
```

In [None]:
@staticmethod
def multinomial(input: Tensor, num_samples: int, replacement: bool = False):
    assert isinstance(input, Tensor) and (input.ndim == 1 or input.ndim == 2), "input must be a 1 or 2-dimensional tensor"
    assert isinstance(num_samples, int) and num_samples > 0, "num_samples must be a positive integer"
    if replacement == False:
        assert num_samples <= input.numel(), f"num_samples: {num_samples} must be less than or equal to the number of elements in input tensor: {input.numel()}"
        # TODO: implement without replacement
        raise NotImplementedError("multinomial without replacement not implemented")
    
    def multinomial_1d(input: Tensor, num_samples: int, replacement: bool = False):
        assert input.ndim == 1, "input must be a 1-dimensional tensor"
        assert num_samples <= input.numel(), f"num_samples: {num_samples} must be less than or equal to the number of elements in input tensor: {input.numel()}"
        
        cumsum = input.cumsum()
        # generate random numbers 
        rand = Tensor.rand(num_samples, 1)
        # compute indices
        indices = (cumsum.unsqueeze(1) < rand).sum(axis=0)
        return indices
    

In [43]:
t_1d = Tensor([0.1, 0.2, 0.3, 0.4]) # 1-dimensional tensor
num_samples = 4 # number of samples to draw

cdf = t_1d.cumsum()
# print(f"cdf: {cdf.numpy()}")
assert cdf == (expected_cdf := Tensor([0.1, 0.3, 0.6, 1.0])), f"Expected cdf: {expected_cdf.numpy()}, got cdf: {cdf.numpy()}"


rand = Tensor([[0.5488], [0.7152], [0.6028], [0.5449]]) # mock random numbers for testing equal to the number of samples
# print(f"rand: {rand.numpy()}")
# get first index where cdf is greater than random number
indices = (rand.expand(num_samples, t_1d.numel()) >= cdf).sum(1) # why does this work???
# print(f"indices: {indices.numpy()}")
expected_indices = Tensor([2, 3, 3, 2])
assert indices == expected_indices, f"Expected indices: {expected_indices.numpy()}, got indices: {indices.numpy()}"

In [109]:
def muntinomial(input: Tensor, num_samples: int, replacement=False):
    assert isinstance(input, Tensor) and (input.ndim == 1 or input.ndim == 2), "input must be a 1 or 2-dimensional tensor"
    assert isinstance(num_samples, int) and num_samples > 0, "num_samples must be a positive integer"
    if replacement == False:
        assert num_samples <= input.numel(), f"num_samples: {num_samples} must be less than or equal to the number of elements in input tensor: {input.numel()}"
        # TODO: implement without replacement
        raise NotImplementedError("multinomial without replacement not implemented")
    
    def _multinomial_1D(input: Tensor, num_samples: int, replacement=replacement):
        assert input.ndim == 1, "input must be a 1-dimensional tensor"
        cdf = input.cumsum()
        cdf /= cdf[-1] # normalize the cdf (you could normalize the input tensor instead) (this needs to be across rows)
        unif_samples = Tensor.rand(num_samples, 1)
        indices = (unif_samples.expand(num_samples, input.numel()) >= cdf).sum(1) # expand is O(1), comparison is O(n * m), sum is O(m * n)
        return indices
    
    return Tensor.stack([_multinomial_1D(row, num_samples, replacement) for row in input])\
        if input.ndim == 2 else _multinomial_1D(input, num_samples, replacement)


In [126]:
def multinomial_optim(input: Tensor, num_samples, replacement=False):
    assert isinstance(input, Tensor) and (input.ndim in [1, 2]), "input must be a 1 or 2-dimensional tensor"
    assert isinstance(num_samples, int) and num_samples > 0, "num_samples must be a positive integer"

    if not replacement:
        assert num_samples <= input.numel(), f"num_samples: {num_samples} must be less than or equal to the number of elements in input tensor: {input.numel()}"
        raise NotImplementedError("multinomial without replacement not implemented")

    if input.ndim == 1: input = input.reshape(1, -1) # Reshape input to 2D if it's 1D
    input = input / input.sum(1, keepdim=True) # Normalize each row of the input tensor
    cdf = input.cumsum(1) # Compute CDF for each row
    unif_samples = Tensor.rand(num_samples, input.shape[0]) # Generate uniform random samples
    indices = (unif_samples.unsqueeze(2) >= cdf.unsqueeze(0)).sum(1) # Determine indices by comparing samples with CDF
    return indices


In [128]:
t = Tensor([[0.1, 0.2, 0.3, 0.4], [0.1, 0.2, 0.3, 0.4], [1, 2, 3, 3]]) # 2-dimensional tensor

optim = multinomial_optim(t, 5, replacement=True)
non_optim = muntinomial(t, 5, replacement=True)

In [129]:
print(f"optim: {optim.numpy()}")
print(f"non_optim: {non_optim.numpy()}")

optim: [[3. 3. 2. 0.]
 [3. 2. 0. 0.]
 [3. 2. 2. 0.]
 [3. 2. 2. 0.]
 [3. 2. 1. 0.]]
non_optim: [[1. 1. 2. 3. 3.]
 [1. 1. 3. 1. 1.]
 [3. 3. 2. 3. 3.]]


In [110]:
arr_1d = np.random.rand(1000)
arr_2d = np.random.rand(1000, 1000)

t_1d = Tensor(arr_1d)
t_2d = Tensor(arr_2d)

pt_1d = torch.tensor(arr_1d)
pt_2d = torch.tensor(arr_2d)

assert np.allclose(pt_1d.numpy(), t_1d.numpy()), f"pt_1d: {pt_1d.numpy()} != t_1d: {t_1d.numpy()}"

num_samples = 1000

# test pytorchs multinomial
pt_indices_1d = torch.multinomial(pt_1d, num_samples, replacement=True).numpy()
pt_indices_2d = torch.multinomial(pt_2d, num_samples, replacement=True).numpy()

# test tinygrads multinomial
indices_1d = muntinomial(t_1d, num_samples, replacement=True).numpy()
indices_2d = muntinomial(t_2d, num_samples, replacement=True).numpy()

print(np.mean(indices_1d))
print(np.mean(pt_indices_1d))

print(np.mean(indices_2d))
print(np.mean(pt_indices_2d))

476.379
490.792
498.95505
499.520523
