In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DEVICE

device(type='cuda', index=0)

In [3]:
class GG:
    def __init__(self) -> None:
        pass

    def torch(self, X: np.ndarray, p: float = 2):
        X = torch.Tensor(X).to(DEVICE)
        n = X.shape[0]
        F = torch.cdist(X,X, p)**p
        F.fill_diagonal_(float('inf'))

        adj = torch.zeros((n,n), dtype=torch.bool).to(DEVICE)
        for i in tqdm(range(n-1)):
            A = F[i]+F[i+1:]
            idx_min = torch.argmin(A, axis=1)
            a = A[torch.arange(A.shape[0]), idx_min] - F[i, i+1:]
            adj[i, i+1:] = torch.where(a > 0, 1, 0)
        adj = adj + adj.T
        return adj.cpu()
    
    def batch(self, X: np.ndarray, btsz: int, p: float = 2):
        X = torch.Tensor(X).to(DEVICE)
        n = X.shape[0]
        adj = torch.zeros((n, n), dtype = torch.bool).to(DEVICE)
        for i in tqdm(range(n)):
            delta = torch.cdist(X[i:i+1, :], X, p)**p
            delta[0, i] = float('inf')
            val_min = torch.ones(n).to(DEVICE) * float('inf')
            for b in range(0, n, btsz):
                X_batch = X[b:b+btsz, :]
                F_batch = torch.cdist(X_batch, X, p)**p
                diag_idx = np.diag_indices(btsz)
                diag_idx = (diag_idx[0], diag_idx[1] + b)
                F_batch[diag_idx] = float('inf')
                A_batch = delta[0, :btsz] + F_batch.T
                val_min_batch, _ = torch.min(A_batch, axis = 1)
                val_min, _ = torch.min(torch.stack((val_min, val_min_batch), dim = 1), dim = 1)
                del X_batch, F_batch, A_batch, val_min_batch
            a = val_min - delta[0, :]
            adj[i, :] = torch.where(a > 0, 1, 0)
            del val_min, a
        adj = adj + adj.T
        return adj.cpu()
    
    def separate(self, X_train: np.ndarray, X_test: np.ndarray, btsz: int, p: float = 2):
        X_train = torch.Tensor(X_train).to(DEVICE)
        X_test = torch.Tensor(X_test).to(DEVICE)
        n = X_train.shape[0]
        N = X_test.shape[0]
        adj = torch.zeros((N, n), dtype = torch.bool).to(DEVICE)
        delta = torch.cdist(X_test, X_train, p)**p
        for i in tqdm(range(N)):
            # delta = torch.cdist(X_test[i:i+1, :], X_train, p)**p
            val_min = torch.ones(n).to(DEVICE) * float('inf')
            for b in range(0, n, btsz):
                X_batch = X_train[b:b+btsz, :]
                F_batch = torch.cdist(X_batch, X_train, p)**p
                A_batch = delta[i, :btsz] + F_batch.T
                val_min_batch, _ = torch.min(A_batch, axis = 1)
                val_min, _ = torch.min(torch.stack((val_min, val_min_batch), dim = 1), dim = 1)
                del X_batch, F_batch, A_batch, val_min_batch
            a = val_min - delta[i, :]
            adj[i, :] = torch.where(a > 0, 1, 0)
            del val_min, a
        return adj.cpu()

In [4]:
H_train = torch.load('data/H_train.pt')
H_test = torch.load('data/H_test.pt')

In [5]:
# ggclass = GG()
# # adjb = ggclass.batch(H_train, len(H_train) // 4, p = 64)
# adjt = ggclass.torch(H_train[25000:, :], p = 64)

In [6]:
# vizinhos = torch.sum(adjt, axis = 0)
# torch.mean(vizinhos.float())

In [7]:
# torch.save(adjt, 'data/gg_train_1.pt')

In [8]:
ggclass = GG()
adjs = ggclass.separate(H_train, H_test, len(H_train) // 4, p = 64)

  2%|▏         | 248/10000 [2:16:34<89:30:32, 33.04s/it]


KeyboardInterrupt: 

In [None]:
torch.save(adjs, 'data/gg_test.pt')