In [1]:
# %%
import numpy as np
import torch
import numba
from numba import cuda, prange
#import cupy as cp

from tqdm import tqdm

# Use the first CUDA GPU when PyTorch reports CUDA as available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Show which device was selected.
print(DEVICE)

class GG:
    """Build a proximity-graph adjacency matrix from a point set with PyTorch.

    Points ``i`` and ``j`` (``i != j``) are connected when, for every third
    point ``k``::

        d(i, k)**2 + d(j, k)**2 - d(i, j)**2 > 0

    i.e. no other point lies in or on the ball whose diameter is the segment
    ``i``-``j``. This is the Gabriel-graph criterion (the class name ``GG``
    presumably stands for that).

    Every method moves the data to the module-level ``DEVICE`` and returns a
    symmetric ``(n, n)`` boolean tensor on that device. The methods differ only
    in how much of the squared-distance matrix is materialised at once; the
    ``(n, n)`` boolean output is always allocated in full.
    """

    def __init__(self) -> None:
        # Stateless: the methods only use their arguments and DEVICE.
        pass

    def torch(self, X: np.ndarray):
        """Compute the adjacency from the full ``(n, n)`` squared-distance matrix.

        Args:
            X: Points, one row per point (must be 2-D, as ``torch.cdist``
                requires). Converted to float32 and moved to ``DEVICE``.

        Returns:
            Symmetric ``(n, n)`` bool tensor on ``DEVICE``; the diagonal is
            False.
        """
        X = torch.as_tensor(X, dtype=torch.float32).to(DEVICE)
        n = X.shape[0]
        F = torch.cdist(X, X) ** 2
        # With an infinite self-distance, a point can never be the blocking
        # third point "k" of one of its own pairs.
        F.fill_diagonal_(float('inf'))

        # Allocate directly on the device instead of on the host plus a copy.
        adj = torch.zeros((n, n), dtype=torch.bool, device=DEVICE)
        for i in tqdm(range(n - 1)):
            # A[r, k] = d(i, k)^2 + d(j, k)^2 for j = i + 1 + r.
            A = F[i] + F[i + 1:]
            # Smallest sum over k, minus d(i, j)^2; an edge needs it positive.
            slack = A.min(dim=1).values - F[i, i + 1:]
            adj[i, i + 1:] = slack > 0
        # Only the strict upper triangle was filled; mirror it.
        return adj | adj.T

    def torch_batch_ii(self, X: np.ndarray, btsz: int, tol: float = 1e-6):
        """Compute the adjacency in row batches, one column at a time.

        For each batch of at most ``btsz`` consecutive rows, the squared
        distances from the batch to all points are held as one
        ``(batch, n)`` block; columns ``j`` from the batch start to ``n - 1``
        are then processed one by one. The number of batches is printed once,
        and a progress bar is shown for each batch's column loop.

        Args:
            X: Points, one row per point (must be 2-D). Converted to float32
                and moved to ``DEVICE``.
            btsz: Number of rows per batch; must be >= 1.
            tol: Squared distances in the row block that are below ``tol`` are
                replaced by ``inf``. This removes each batched point's
                distance to itself (and to any point closer than that) from
                the search for a blocking third point. The single column point
                is masked at its own index only.

        Returns:
            Symmetric ``(n, n)`` bool tensor on ``DEVICE``; the diagonal is
            False.

        Raises:
            ValueError: If ``btsz`` is smaller than 1.
        """
        if btsz < 1:
            raise ValueError(f"btsz must be a positive integer, got {btsz}")
        X = torch.as_tensor(X, dtype=torch.float32).to(DEVICE)
        n = X.shape[0]

        adj = torch.zeros((n, n), dtype=torch.bool, device=DEVICE)
        batch_starts = range(0, n - 1, btsz)
        print(len(batch_starts))  # number of row batches
        for i in batch_starts:
            ii = slice(i, min(i + btsz, n))
            Fi = torch.cdist(X[ii, :], X) ** 2  # (batch, n)
            Fi[Fi < tol] = float('inf')
            for j in tqdm(range(i, n)):
                # Squared distances from point j to every point, shape (n,).
                Fj = torch.cdist(X[j:j + 1, :], X)[0] ** 2
                Fj[j] = float('inf')
                # A[r, k] = d(i + r, k)^2 + d(j, k)^2.
                A = Fi + Fj
                # Fj[ii] is d(j, i + r)^2. For the row r with i + r == j it is
                # inf, so the comparison is False and the diagonal stays unset.
                slack = A.min(dim=1).values - Fj[ii]
                adj[ii, j] = slack > 0
            # Drop the reference so the block's memory can be reused before
            # the next batch allocates its own.
            del Fi
        return adj | adj.T

    def torch_batch_jj(self, X: np.ndarray, btsz: int, tol: float = 1e-6):
        """Compute the adjacency one row at a time, in column batches.

        For each point ``i`` the candidate partners ``j > i`` are processed in
        batches of at most ``btsz``; each batch holds one ``(batch, n)`` block
        of squared distances.

        Args:
            X: Points, one row per point (must be 2-D). Converted to float32
                and moved to ``DEVICE``.
            btsz: Number of candidate partners per batch; must be >= 1.
            tol: Squared distances in the batched block that are below ``tol``
                are replaced by ``inf``. This removes each batched point's
                distance to itself (and to any point closer than that). Row
                ``i`` is masked at its own index only.

        Returns:
            Symmetric ``(n, n)`` bool tensor on ``DEVICE``; the diagonal is
            False.

        Raises:
            ValueError: If ``btsz`` is smaller than 1.
        """
        if btsz < 1:
            raise ValueError(f"btsz must be a positive integer, got {btsz}")
        X = torch.as_tensor(X, dtype=torch.float32).to(DEVICE)
        n = X.shape[0]

        adj = torch.zeros((n, n), dtype=torch.bool, device=DEVICE)
        for i in tqdm(range(n - 1)):
            Fi = torch.cdist(X[i:i + 1, :], X) ** 2  # (1, n)
            Fi[:, i] = float('inf')
            for j in range(i + 1, n, btsz):
                jj = slice(j, min(j + btsz, n))
                Fjj = torch.cdist(X[jj, :], X) ** 2  # (batch, n)
                Fjj[Fjj < tol] = float('inf')
                # A[r, k] = d(i, k)^2 + d(j + r, k)^2.
                A = Fi + Fjj
                # Fjj[:, i] is d(j + r, i)^2, taken after the tol masking.
                slack = A.min(dim=1).values - Fjj[:, i]
                adj[i, jj] = slack > 0
        # Only the strict upper triangle was filled; mirror it.
        return adj | adj.T

  from .autonotebook import tqdm as notebook_tqdm


cuda:0


In [2]:
# H = np.load('data/H_train.npy')
# Load the input array; it is later passed as `X` to GG.torch_batch_ii.
# Presumably shaped (n_points, n_features), since the GG methods call
# torch.cdist on it -- TODO confirm.
X = np.load('data/H_train.npy')

In [3]:
# Number of rows per batch; passed as `btsz` to the batched GG method below.
batch_size = 10000
# Disabled: random subsample of `sz` rows of H for quicker experiments.
# sz = 10000
# print(sz, batch_size)
# idx = np.random.choice(len(H), size = sz)
# X = H[idx, :]

In [4]:
# GG.__init__ takes no arguments and stores no state.
ggclass = GG()

In [5]:
# adj = ggclass.torch(X)

In [6]:
# adj_batch = ggclass.torch_batch_jj(X, btsz = batch_size, tol = 1e-6)

In [7]:
# label, counts = np.unique(adj.cpu() == adj_batch.cpu(), return_counts = True)
# print(label, counts / sum(counts))

In [8]:
# Build the boolean adjacency with the row-batched method on the full array.
# In the recorded run below this printed 6 batches, and the progress bars add
# up to roughly an hour.
adj_batch = ggclass.torch_batch_ii(X, btsz = batch_size, tol = 1e-6)

6


100%|██████████| 60000/60000 [17:51<00:00, 55.97it/s]
100%|██████████| 50000/50000 [14:53<00:00, 55.97it/s]
100%|██████████| 40000/40000 [11:54<00:00, 55.97it/s]
100%|██████████| 30000/30000 [08:56<00:00, 55.96it/s]
100%|██████████| 20000/20000 [05:57<00:00, 55.95it/s]
100%|██████████| 10000/10000 [02:58<00:00, 55.90it/s]


In [9]:
# label, counts = np.unique(adj.cpu() == adj_batch.cpu(), return_counts = True)
# print(label, counts / sum(counts))

In [10]:
# Persist the adjacency tensor. It is saved as returned by GG, i.e. still on
# DEVICE, so loading it on a machine without that device needs
# torch.load(..., map_location=...).
torch.save(adj_batch, 'data/H_gg_train_full.pt')