In [1]:
%load_ext line_profiler
%load_ext autoreload
%autoreload 2
%env PYTORCH_ENABLE_MPS_FALLBACK=1

env: PYTORCH_ENABLE_MPS_FALLBACK=1


In [4]:
import numbers
import logging

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import hyclib as lib

In [5]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(device)

lib.logging.basic_config()
logger = logging.getLogger()

mps


In [5]:
def as_tuple(out):
    return out if isinstance(out, tuple) else (out,)

def test_unique(M, shape, O, dim, sorted, return_index, return_inverse, return_counts, device):
    kwargs = {
        'return_index': return_index,
        'return_inverse': return_inverse,
        'return_counts': return_counts,
    }
    np_kwargs = kwargs.copy()
    np_kwargs['return_index'] = True # always use return_index=True in order to have stable sort
    
    t = torch.randint(M, size=shape)
    if O > 0:
        t = t.float()
        indices = [torch.randint(D, size=(O,)) for D in shape]
        t[tuple(indices)] = torch.nan
    t = t.to(device)
    a = t.cpu().numpy()

    torch_results = [t for t in as_tuple(lib.pt.unique(t, dim=dim, sorted=sorted, **kwargs))]
    np_results = as_tuple(np.unique(t.cpu().numpy(), axis=dim, equal_nan=False, **np_kwargs))
    if not kwargs['return_index']:
        np_results = list(np_results)
        del np_results[1]

    if not sorted:
        if dim is None:
            sort_idx = torch_results[0].argsort(stable=True)
        else:
            sort_idx = lib.pt.lexsort(torch_results[0].movedim(dim, -1).flip(0))
        sort_idx_inv = lib.pt.inv_perm(sort_idx)

    keys = ['x'] + [k for k, v in kwargs.items() if v]

    for key, torch_result, np_result in zip(keys, torch_results, np_results):
        print(key, torch_result)
        if sorted:
            torch.testing.assert_close(torch_result, torch.from_numpy(np_result).to(device), equal_nan=True)
        else:
            if key == 'return_inverse':
                torch.testing.assert_close(sort_idx_inv[torch_result], torch.from_numpy(np_result).to(device), equal_nan=True)
            else:
                if dim is None or key != 'x':
                    torch.testing.assert_close(torch_result[sort_idx], torch.from_numpy(np_result).to(device), equal_nan=True)
                else:
                    torch.testing.assert_close(torch_result.movedim(dim, 0)[sort_idx].movedim(0, dim), torch.from_numpy(np_result).to(device), equal_nan=True)

In [6]:
M, shape, O, dim = (3, (2, 10), 5, 1)
sorted = False
return_index = False
return_inverse = True
return_counts = False
device = 'cpu'
test_unique(M, shape, O, dim, sorted, return_index, return_inverse, return_counts, device)

x tensor([[0., 0., 1., 2., 2., nan, 0., 2., nan],
        [0., 2., 2., 1., 2., 2., nan, nan, 0.]])
return_inverse tensor([5, 0, 1, 4, 6, 3, 0, 2, 7, 8])


In [75]:
M, D, shape, O, dim = (4, 6, (99,100), 60_000, -1)
t = torch.randint(M, size=(D,*shape))
if O > 0:
    t = t.float()
    indices = [torch.randint(D, size=(O,))]
    indices += [torch.randint(N, size=(O,)) for N in shape]
    t[tuple(indices)] = torch.nan
t = t.to(device)
a = t.cpu().numpy()

pt_idx = lib.pt.lexsort(t, dim=dim)
np_idx = np.lexsort(a, axis=dim)

print(t.numel())
print(t[~t.isnan()].numel())

torch.testing.assert_close(pt_idx, torch.from_numpy(np_idx).to(device), equal_nan=True)

59400
21560


In [64]:
# M, N, D, O = 2, 100_000, 20, 1_000
# M, N, D, O = 3, 100, 2, 10
M, N, D, O = 5, 10_000, 5, 1_000
t = torch.randint(M, size=(N,D)).float()
idx_0 = torch.randint(N, size=(O,))
idx_1 = torch.randint(D, size=(O,))
t[idx_0, idx_1] = torch.nan
t = t.to(device)
t_cpu = t.cpu()
a = t.cpu().numpy()

In [60]:
dim = 0
kwargs = {
    'return_index': True,
    'return_inverse': True,
    'return_counts': True,
}

torch_results = [t for t in as_tuple(lib.pt.unique(t, dim=dim, **kwargs))]
np_results = as_tuple(np.unique(t.cpu().numpy(), axis=dim, equal_nan=False, **kwargs))

for torch_result, np_result in zip(torch_results, np_results):
    torch.testing.assert_close(torch_result, torch.from_numpy(np_result).to(device), equal_nan=True)

In [65]:
dim = 0
kwargs = {
    'return_index': True,
    'return_inverse': True,
    'return_counts': True,
}

def as_tuple(out):
    return out if isinstance(out, tuple) else (out,)

torch_results = [t for t in as_tuple(lib.pt.unique(t, dim=dim, **kwargs))]
np_results = as_tuple(np.unique(t.cpu().numpy(), axis=dim, equal_nan=False, **kwargs))
print(len(torch_results[0]))

sort_idx = lib.pt.lexsort(torch_results[0].t().flip(0))
sort_idx_inv = lib.pt.inv_perm(sort_idx)

keys = ['x'] + [k for k, v in kwargs.items() if v]

for key, torch_result, np_result in zip(keys, torch_results, np_results):
    if key == 'return_inverse':
        torch.testing.assert_close(sort_idx_inv[torch_result], torch.from_numpy(np_result).to(device), equal_nan=True)
    else:
        torch.testing.assert_close(torch_result[sort_idx], torch.from_numpy(np_result).to(device), equal_nan=True)

3902


In [9]:
%lprun -f lib.pt.unique -f lib.pt._unique_sorted lib.pt.unique(t, dim=dim, **kwargs)

Timer unit: 1e-09 s

Total time: 0.098428 s
File: /Users/hoyinchau/local_documents/research/hyclib/hyclib/pt/core.py
Function: _unique_sorted at line 30

Line #      Hits         Time  Per Hit   % Time  Line Contents
    30                                           def _unique_sorted(x, dim=None, return_index=False, return_inverse=False, return_counts=False):
    31         1          0.0      0.0      0.0      if dim is None:
    32                                                   x = x.reshape(-1)
    33                                                   dim = 0
    34                                                   
    35                                               # reshape to a 2D tensor where we compute unique rows
    36         1     100000.0 100000.0      0.1      x = x.movedim(dim, 0)
    37         1       3000.0   3000.0      0.0      shape = x.shape
    38         1      20000.0  20000.0      0.0      x = x.reshape(shape[0], -1)
    39                                 

In [28]:
lib.pt.unique(t, dim=dim, **kwargs)[0].shape

torch.Size([95389, 20])

In [34]:
%timeit lib.pt.unique(t, dim=dim, **kwargs)
%timeit lib.pt.unique(t_cpu, dim=dim, sorted=False, **kwargs)
%timeit lib.pt.unique(t_cpu, dim=dim, sorted=True, **kwargs)

58 ms ± 219 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
167 ms ± 1.34 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
313 ms ± 8.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [35]:
%timeit np.unique(a, axis=dim, equal_nan=False, **kwargs)

318 ms ± 9.39 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [30]:
%timeit lib.pt.unique(t, dim=dim, **kwargs)
%timeit lib.pt.unique(t_cpu, dim=dim, sorted=False, **kwargs)
%timeit lib.pt.unique(t_cpu, dim=dim, sorted=True, **kwargs)
%timeit np.unique(a, axis=dim, equal_nan=False, **kwargs)
%timeit lib.np.unique_rows(a, **kwargs)

24.6 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
144 ms ± 300 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
150 ms ± 386 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
226 ms ± 5.65 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
111 ms ± 1.44 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [175]:
torch.masked.MaskedTensor(torch.tensor([1,2,3]), torch.tensor([True, True, False])).max()

If you would like this operator to be supported, please file an issue for a feature request at https://github.com/pytorch/maskedtensor/issues with a minimal reproducible code snippet.
In the case that the semantics for the operator are not trivial, it would be appreciated to also include a proposal for the semantics.


TypeError: no implementation found for 'torch._ops.aten.max.default' on types that implement __torch_dispatch__: [<class 'torch.masked.maskedtensor.core.MaskedTensor'>]

In [79]:
for _ in range(1000):
    t = torch.normal(mean=0, std=1, size=(10,), device='mps')
    out_1 = lib.pt.stats.bin(t.cpu(), bins=3)[0]
    out_2 = lib.pt.stats.bin(t, bins=3)[0].cpu()
    out_3 = lib.sp.stats.bin(t.cpu().numpy(), bins=3)[0]
    assert (out_1 == out_2.cpu()).all() and (out_1.numpy() == out_3).all()

In [74]:
out_1 = lib.pt.stats.bin(a.cpu(), bins=3)
out_2 = lib.pt.stats.bin(a, bins=3)
out_3 = lib.sp.stats.bin(a.cpu(), bins=3)
print(out_1)
print(out_2)
print(out_3)

(tensor([3, 2, 3, 3, 3, 1, 3, 2, 2, 3]), tensor([    nan, -1.2486, -0.3541,  0.5404,     nan]), tensor([-1.6958, -0.8013,  0.0931,  0.9876]))
(tensor([3, 2, 3, 3, 3, 1, 3, 2, 2, 3], device='mps:0'), tensor([    nan, -1.2486, -0.3541,  0.5404,     nan], device='mps:0'), tensor([-1.6958, -0.8013,  0.0931,  0.9876], device='mps:0'))
(array([3, 2, 3, 3, 3, 1, 3, 2, 2, 3]), array([        nan, -1.24855185, -0.3540917 ,  0.54036844,         nan]), array([-1.695782  , -0.8013218 ,  0.09313837,  0.98759854], dtype=float32))


In [7]:
M, N = 100, 100000

# bins = [-2.5,-2,-1,0,1,2,3.0]
bins = 100
arr = np.random.normal(size=N)
indices = np.random.randint(N, size=M)
arr[indices] = np.nan

In [8]:
binnumbers, centers, edges = lib.sp.stats.bin(arr, bins=bins, nan_policy='raise')

t = torch.tensor(arr)
tbinnumbers, tcenters, tedges = lib.pt.stats.bin(t, bins=bins, nan_policy='raise')

pbinnumbers, pedges = pd.cut(arr, bins=bins, retbins=True, labels=False, right=False)

# print(binnumbers)
# print(tbinnumbers.numpy())
# print(pbinnumbers)
# print(centers)
# print(tcenters.numpy())
# print(binnumbers[np.isnan(arr)])
# print(tbinnumbers.numpy()[np.isnan(arr)])
# print(pbinnumbers[np.isnan(arr)])

torch.testing.assert_close(torch.from_numpy(binnumbers), tbinnumbers)
torch.testing.assert_close(torch.from_numpy(centers), tcenters, equal_nan=True)
torch.testing.assert_close(torch.from_numpy(edges), tedges)

isnan = np.isnan(pbinnumbers)
np.testing.assert_allclose(binnumbers[~isnan], pbinnumbers[~isnan] + 1)
if isinstance(bins, numbers.Number):
    r = np.nanmax(arr) - np.nanmin(arr)
    edges[-1] = edges[-1] + r * 0.001
np.testing.assert_allclose(edges, pedges)

ValueError: array([-0.79857488, -0.29282404,  0.29925128, ..., -0.66033549,
       -0.61356288,  0.04627834]) contains non-finite values.