In [2]:
import numpy as np
import torch
from scipy.stats import multivariate_normal
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda', index=0)

In [5]:
class GG:
    def __init__(self) -> None:
        pass

    def torch(self, X: np.ndarray, p: float = 2):
        X = torch.Tensor(X).to(DEVICE)
        n = X.shape[0]
        F = torch.cdist(X,X, p)**p
        F.fill_diagonal_(float('inf'))

        adj = torch.zeros((n,n), dtype=torch.bool).to(DEVICE)
        for i in tqdm(range(n-1)):
            A = F[i]+F[i+1:]
            idx_min = torch.argmin(A, axis=1)
            a = A[torch.arange(A.shape[0]), idx_min] - F[i, i+1:]
            adj[i, i+1:] = torch.where(a > 0, 1, 0)
        adj = adj + adj.T
        return adj.cpu()
    
    def batch(self, X: np.ndarray, btsz: int, p: float = 2):
        X = torch.Tensor(X).to(DEVICE)
        n = X.shape[0]
        adj = torch.zeros((n, n), dtype = torch.bool).to(DEVICE)
        for i in tqdm(range(n)):
            delta = torch.cdist(X[i:i+1, :], X, p)**p
            delta[0, i] = float('inf')
            val_min = torch.ones(n).to(DEVICE) * float('inf')
            for b in range(0, n, btsz):
                X_batch = X[b:b+btsz, :]
                F_batch = torch.cdist(X_batch, X, p)**p
                diag_idx = np.diag_indices(btsz)
                diag_idx = (diag_idx[0], diag_idx[1] + b)
                F_batch[diag_idx] = float('inf')
                A_batch = delta[0, :btsz] + F_batch.T
                val_min_batch, _ = torch.min(A_batch, axis = 1)
                val_min, _ = torch.min(torch.stack((val_min, val_min_batch), dim = 1), dim = 1)
                del X_batch, F_batch, A_batch, val_min_batch
            a = val_min - delta[0, :]
            adj[i, :] = torch.where(a > 0, 1, 0)
            del val_min, a
        adj = adj + adj.T
        return adj.cpu()

In [None]:
def separate(X_train: np.ndarray, X_test: np.ndarray, btsz: int, p: float = 2)
    X_train = torch.Tensor(X_train).to(DEVICE)
    X_test = torch.Tensor(X_test).to(DEVICE)
    n = X_train.shape[0]
    N = X_test.shape[0]
    adj = torch.zeros((N, n), dtype = torch.bool).to(DEVICE)
    for i in tqdm(range(N)):
        delta = torch.cdist(X[i:i+1, :], X, p)**p
        delta[0, i] = float('inf')
        val_min = torch.ones(n).to(DEVICE) * float('inf')
        for b in range(0, n, btsz):
            X_batch = X[b:b+btsz, :]
            F_batch = torch.cdist(X_batch, X, p)**p
            diag_idx = np.diag_indices(btsz)
            diag_idx = (diag_idx[0], diag_idx[1] + b)
            F_batch[diag_idx] = float('inf')
            A_batch = delta[0, :btsz] + F_batch.T
            val_min_batch, _ = torch.min(A_batch, axis = 1)
            val_min, _ = torch.min(torch.stack((val_min, val_min_batch), dim = 1), dim = 1)
            del X_batch, F_batch, A_batch, val_min_batch
        a = val_min - delta[0, :]
        adj[i, :] = torch.where(a > 0, 1, 0)
        del val_min, a
    adj = adj + adj.T
    return adj

In [11]:
N = 5000
btsz = 1000
X = multivariate_normal.rvs(cov = np.eye(16), size = N)
X.shape

(5000, 16)

In [12]:
ggclass = GG()
adjt = ggclass.torch(X)

100%|██████████| 4999/4999 [00:02<00:00, 2477.63it/s]


In [13]:
adjb = ggclass.batch(X, btsz)

100%|██████████| 5000/5000 [00:14<00:00, 344.27it/s]


In [14]:
np.unique(adjt.cpu().numpy() == adjb.cpu().numpy(), return_counts = True)

(array([False,  True]), array([ 2310978, 22689022]))

In [15]:
ggclass = GG()
adjt = ggclass.torch(X, p = 16)
adjb = ggclass.batch(X, btsz, p = 16)
np.unique(adjt.cpu().numpy() == adjb.cpu().numpy(), return_counts = True)

100%|██████████| 4999/4999 [00:02<00:00, 2458.93it/s]
  2%|▏         | 93/5000 [00:16<14:53,  5.49it/s]


KeyboardInterrupt: 

In [27]:
X = torch.Tensor(X).to(DEVICE)
n = X.shape[0]
adj = torch.zeros((n, n), dtype = torch.bool).to(DEVICE)

In [28]:
for i in tqdm(range(n)):
    break

  0%|          | 0/500 [00:00<?, ?it/s]


In [29]:
delta = torch.cdist(X[i:i+1, :], X)**2
val_min = torch.ones(n).to(DEVICE) * 1e9
delta.shape

torch.Size([1, 500])

In [30]:
for b in range(0, n, btsz):
    break

In [36]:
X_batch = X[b:b+btsz, :]
F_batch = torch.cdist(X_batch, X)**2
F_batch.shape

torch.Size([100, 500])

In [37]:
diag_idx = np.diag_indices(btsz)
diag_idx = (diag_idx[0], diag_idx[1] + b)
diag_idx

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
        51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
        68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
        85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]),
 array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,
        51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67,
        68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84,
        85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]))

In [40]:
F_batch[diag_idx] = float('inf')
F_batch

tensor([[    inf, 44.4403, 35.6596,  ..., 30.7452, 53.4618, 51.7180],
        [44.4403,     inf, 23.8903,  ..., 21.2678, 51.6006, 48.1585],
        [35.6596, 23.8903,     inf,  ..., 21.4548, 49.9805, 24.7000],
        ...,
        [59.0344, 16.4907, 22.4236,  ..., 38.8746, 69.6844, 51.4538],
        [72.1047, 39.4661, 55.4588,  ..., 40.0942, 49.2782, 70.8818],
        [46.5412, 37.9725, 35.3928,  ..., 47.0189, 61.4506, 35.0358]],
       device='cuda:0')

In [44]:
A_batch = delta[0, :btsz].T + F_batch.T
A_batch.shape

torch.Size([500, 100])

In [None]:


    X_batch = X[b:b+btsz, :]
    F_batch = torch.cdist(X_batch, X)**2
    diag_idx = np.diag_indices(btsz)
    diag_idx = (diag_idx[0], diag_idx[1] + b)
    F_batch[diag_idx] = np.inf
    A_batch = delta[0, :btsz] + F_batch.T
    val_min_batch, _ = torch.min(A_batch, axis = 1)
    val_min, _ = torch.min(torch.stack((val_min, val_min_batch), dim = 1), dim = 1)
    del X_batch, F_batch, A_batch, val_min_batch
a = val_min - delta[0, :]
adj[i, :] = torch.where(a > 0, 1, 0)
del val_min, a
adj = adj + adj.T