In [12]:
import torch
from tqdm import tqdm
from sklearn.neighbors import KDTree
import h5py

In [2]:
# Number of synthetic embeddings (rows of Z / D) generated for this experiment.
n_embs = 100_000

In [3]:
# Synthetic stand-in data: Z holds n_embs random 2048-d embeddings and D assigns
# each of them a random domain id in {0, 1, 2, 3}.
# Seed the RNG first so that "Restart Kernel -> Run All" regenerates the same
# embeddings (and therefore the same neighborhoods) on every run.
torch.manual_seed(0)
Z, D = torch.rand(n_embs, 2048), torch.randint(high=4, size=(n_embs, 1))

In [4]:
# Bucket the embeddings by domain id.  For every domain we collect the
# embeddings that belong to it together with their global row index in Z, so
# that tree-local hits can later be mapped back to positions in Z.
doms = {}
for row_idx, (emb, dom) in enumerate(tqdm(zip(Z, D), total=n_embs)):
    emb_bucket, idx_bucket = doms.setdefault(dom.item(), ([], []))
    emb_bucket.append(emb)
    idx_bucket.append(row_idx)

# Freeze each bucket into a pair (stacked embeddings, global indices).
doms = {
    dom_id: (
        torch.vstack(emb_list).contiguous().detach(),
        torch.tensor(idx_list).contiguous().detach(),
    )
    for dom_id, (emb_list, idx_list) in doms.items()
}

100%|██████████| 100000/100000 [00:00<00:00, 1193499.68it/s]


In [5]:
# Build one KD-tree per domain over that domain's embeddings; the index half of
# each doms entry is not needed here.
# NOTE(review): scikit-learn documents that KD-trees lose efficiency as the
# dimensionality grows; with 2048-d embeddings queries may be close to
# brute-force cost -- worth benchmarking against alternatives.
trees = {
    dom_id: KDTree(dom_embs, leaf_size=1_000)
    for dom_id, (dom_embs, _global_idcs) in tqdm(doms.items())
}

100%|██████████| 4/4 [00:17<00:00,  4.45s/it]


In [7]:
def nearest_ood_neighbors(z: torch.Tensor, k: int, d: int):
    """Return global indices of the k nearest neighbors of z in every other domain.

    For each domain in the module-level ``doms`` other than ``d``, that
    domain's KD-tree in ``trees`` is queried for the ``k`` nearest neighbors
    of ``z``. The tree-local hits are mapped back to global row indices via
    ``doms[domain][1]`` and concatenated, giving ``(D - 1) * k`` indices in
    total when there are ``D`` domains.

    Parameters
    ----------
    z : torch.Tensor
        Query embedding, passed unchanged to ``KDTree.query`` (callers pass a
        single row as ``z[None, :]``). Only the neighbors of the first query
        row are returned.
    k : int
        Number of neighbors to return per domain.
    d : int
        ID of the domain from which z originates. A one-element tensor is
        accepted as well and converted to a plain int.

    Returns
    -------
    torch.Tensor
        1-D tensor of global indices, grouped by domain in the iteration
        order of ``doms``; within a group the order is the one returned by
        ``KDTree.query``. Empty if ``d`` is the only known domain.

    Raises
    ------
    ValueError
        If ``d`` is not a known domain id.
    """
    # Callers iterate over rows of D and hand over one-element tensors;
    # normalise so the comparison against the dict keys is a plain int compare.
    own_domain = int(d)
    if own_domain not in doms:
        raise ValueError(f"unknown domain id {own_domain!r}")

    idcs = []
    for other_domain in doms:
        if other_domain == own_domain:
            continue
        local_idcs = trees[other_domain].query(z, k=k, return_distance=False)
        # query returns one row of hits per query point; keep the first row
        # and translate its tree-local positions into global row indices.
        idcs.append(doms[other_domain][1][local_idcs[0]])

    if not idcs:
        # torch.hstack rejects an empty list, so return an empty index tensor.
        return torch.empty(0, dtype=torch.long)
    return torch.hstack(idcs)

### Make File 

In [21]:
# Open the output file for writing ('w' creates it, truncating an existing file).
# The handle stays open across the following cells and is closed at the end.
f = h5py.File(name='neighborhood.hdf5', mode='w')
# One row per query embedding (the first 1000) and 9 columns: the fill loop
# below requests 3 neighbors from each of the 3 domains other than the query's
# own (D holds 4 domain ids).  dtype 'i' stores the indices as int32.
dset = f.create_dataset(name="neighborhood", shape=(1000, 9), dtype='i')

In [22]:
# Sanity check: read the freshly created dataset back; it is zero-filled until
# the fill loop below writes the neighbor indices.
f['neighborhood'][:]

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)

In [23]:
# Fill the first 1000 rows: row i receives the 3 nearest neighbors of
# embedding i from each domain other than its own.
dset = f['neighborhood']  # look the dataset up once instead of on every iteration
# Slice before zipping so only 1000 (z, d) pairs are materialized rather than
# one pair per row of Z; wrapping in list() keeps a length for the progress bar.
for i, (z, d) in enumerate(tqdm(list(zip(Z[:1000], D[:1000])))):
    # z[None, :] turns the embedding into a single-row 2-D query; d.item()
    # converts the one-element domain tensor into the plain int domain id.
    dset[i] = nearest_ood_neighbors(z[None, :], 3, d.item())

100%|██████████| 1000/1000 [02:04<00:00,  8.06it/s]


In [24]:
# Close the HDF5 file handle opened in the "Make File" section.
f.close()