In [1]:
import numpy as np
import torch

In [57]:
lats = np.load("/home/thaddaus/code/ocood/clevr_latents.npy")
# imgs = np.load("/mnt/qb/work/bethge/jbrady61/clevr_data/clevr_obs.npy")
nums = np.load("/home/thaddaus/code/ocood/clevr_num_objects.npy")

In [58]:
# distribution over number of objects
np.unique(nums, return_counts=True)

(array([1, 2, 3, 4, 5, 6]), array([   14,  1040, 13145, 12875, 13032, 13377]))

In [59]:
# normalize latents
mins = lats.min(axis=(0,1))
maxs = lats.max(axis=(0,1))
lats_norm = (lats - mins) / (maxs - mins)
print(lats_norm.min(axis=(0,1)), lats_norm.max(axis=(0,1)))

[0. 0. 0. 0. 0. 0.] [1. 1. 1. 1. 1. 1.]


In [62]:
np.sqrt(6) / 2

1.224744871391589

In [81]:
def get_ID_OOD_splits(
    latents: np.ndarray, delta: float
) -> np.ndarray:
    """
    All distances from the diagonal are bigger than delta.
    Opposite case of __sample_delta_diagonal_cube.

    Rejection sampling used as the algorithm.
    """
    indcs_ID = np.ndarray((0))
    indcs_OOD = np.ndarray((0))
    indcs_ID_per_n = []
    indcs_OOD_per_n = []

    # we need to separately sample for different numbers of objects
    for n_slots in range(1, latents.shape[1] + 1):
        max_delta = np.sqrt(n_slots) / 2

        # we only need to keep track of the indeces
        at_most_n_slots = np.all(lats[:, n_slots:, :] == 0, axis=(1, 2))
        at_least_n_slots = np.any(lats[:, n_slots - 1, :] != 0, axis=1)
        indcs = np.nonzero(at_most_n_slots & at_least_n_slots)[0]

        latents_n_slots = latents[indcs][:, :n_slots]

        # define the diagonal
        diag_unit = np.ones(n_slots) / np.sqrt(n_slots)
        # calculate the projection onto the diagonal (and from there the distance) along the
        # `slots`-dimension since one diagonal contains the i-th latent of each slot
        diag_scalar_component = np.dot(latents_n_slots.transpose(0, 2, 1), diag_unit)
        diag_component = diag_scalar_component[:, None, :] * diag_unit[None, :, None]
        orth_component = latents_n_slots - diag_component
        orth_component_norm = np.linalg.norm(orth_component, axis=1)

        # diag = np.ones((latents_n_slots.shape[0], n_slots, latents.shape[2]))

        # get distance to line defined by diagonal along each latent dimension
        # ort_vec = latents_n_slots - diag * (latents_n_slots * diag).sum(axis=1, keepdims=True)\
        #         / (diag * diag).sum(axis=1, keepdims=True)
        
        # find ID/OOD points based on distance to diagonal
        mask_ID = np.all(orth_component_norm <= delta * max_delta, axis=1)
        mask_OOD = np.any(orth_component_norm > delta * max_delta, axis=1)

        # print(mask_ID.sum(), mask_OOD.sum(), mask_OOD.sum() + mask_ID.sum() == len(indcs))
        
        indcs_ID = np.append(indcs_ID, indcs[mask_ID])
        indcs_OOD = np.append(indcs_OOD, indcs[mask_OOD])
        
        indcs_ID_per_n.append(mask_ID.sum())
        indcs_OOD_per_n.append(mask_OOD.sum())

    return indcs_ID, indcs_OOD, indcs_ID_per_n, indcs_OOD_per_n

In [84]:
# delta in [0 .. 1]
indcs_ID, indcs_OOD, indcs_ID_per_n, indcs_OOD_per_n = get_ID_OOD_splits(lats_norm, 0)

print([indcs_ID_per_n[i] / indcs_OOD_per_n[i] for i in range(6)])
print(indcs_ID_per_n)
print(sum(indcs_ID_per_n), sum(indcs_OOD_per_n))

[inf, 0.0, 0.0, 0.0, 0.0, 0.0]
[14, 0, 0, 0, 0, 0]
14 53469


  print([indcs_ID_per_n[i] / indcs_OOD_per_n[i] for i in range(6)])


# Sanity checks for ID and OOD sampling

In [2]:
def sample_delta_diagonal_cube(
    n_samples: int, n_slots: int, n_latents: int, delta: float, oversampling: int = 10
) -> torch.Tensor:
    """
    Sample near the diagonal in latent space i.e. all distances from the diagonal are less than delta.

    Algorithm:
        1. Draw points on the diagonal of [0, 1)^(n_slots, n_latents) cube.
        2. For every latent draw uniformly noise from n_slots-dimensional ball. For drawing uniformly inside the ball we
            use the following theorem (http://compneuro.uwaterloo.ca/files/publications/voelker.2017.pdf):
            if point uniformly sampled from the (n+1)-sphere, then n-first coordinates are uniformly sampled from the n-ball.
        3. Project sampled inside-ball points to the hyperplane perpendicular to the diagonal and normalize them
            (this gives us points on (n_slots-2)-sphere embedded in n_slots-space).
        4. Get final points by adding the diagonal point to the projected points.
        5. Keep only points inside the [0, 1)^(n_slots, n_latents) cube.
    """
    _n = oversampling * n_samples
    z_out = torch.Tensor(0, n_slots, n_latents)
    while z_out.shape[0] < n_samples:
        # sample randomly on diagonal
        z_sampled = torch.repeat_interleave(
            torch.rand(_n, n_latents), n_slots, dim=0
        ).reshape(_n, n_slots, n_latents)

        # sample noise from n_slots-ball
        noise = torch.randn(_n, n_slots + 2, n_latents)
        noise = noise / torch.norm(noise, dim=1, keepdim=True)  # points on n-sphere
        noise = noise[:, :n_slots, :]  # remove two last points

        # project to hyperplane perpendicular to diagonal
        ort_vec = noise - z_sampled * (noise * z_sampled).sum(axis=1, keepdim=True) / (
            z_sampled * z_sampled
        ).sum(axis=1, keepdim=True)
        ort_vec /= torch.norm(ort_vec, p=2, dim=1, keepdim=True)

        # final step
        # why n - 1 here? because we sample
        # "radius" not in the original space, but in the embedded
        final = z_sampled + (
            ort_vec
            * torch.pow(torch.rand([_n, 1, n_latents]), 1 / (n_slots - 1))
            * delta
        )

        # only keep samples inside [0, 1]^{k×l}
        mask = ((final - 0.5).abs() <= 0.5).flatten(1).all(1)
        idx = mask.nonzero().squeeze(1)

        z_out = torch.cat([z_out, final[idx]])
    z_out = z_out[:n_samples]
    return z_out[:n_samples]

In [76]:
import matplotlib.pyplot as plt

n = 100000
n_slots = 6
n_latents = 6
delta = 0.5  # in [0 .. sqrt(n_slots)/2]  or  [0 .. sqrt(n_slots * n_latents) / 2] for _norm

z = np.random.rand(n, n_slots, n_latents)

# if a_1=[a_1_1, ..., a_1_n_latents] is a single slot with dim=n_slots
# then the entire space is z=[[a_1_1, ..., a_1_n_latents], ..., [a_n_slots_1, ..., a_n_slots_n_latents]] with dim=n_latents
# on the diagonal, a_1 can assume all possible values, and the other slots must be equal to it
#   d = [a_1, a_1, a_1, ...]
# since a_1_i are completely independent of each other, there's essentially n_latents independent diagonals with dim=n_slots
#   d_:_i = [a_1_i, a_1_i, ..., a_1_i]  with  i in [1 .. n_latents]
# for a datapoint to lie "near to the diagonal", it must be near to all component diagonals, i.e.
#   [a_1_i, a_2_i, ..., a_n_slots_i] must be close to d_:_i for all i
# i.e. a_:_i - a^d_:_i < delta
# with a^d_:_i = a_:_i * diag_norm

diag_unit = np.ones(n_slots) / np.sqrt(n_slots)
# calculate the projection onto the diagonal (and from there the distance) along the
# `slots`-dimension since one diagonal contains the i-th latent of each slot
diag_scalar_component = np.dot(z.transpose(0, 2, 1), diag_unit)
diag_component = diag_scalar_component[:, None, :] * diag_unit[None, :, None]
orth_component = z - diag_component
orth_component_norm = np.linalg.norm(orth_component, axis=1)

# ID points lie within `delta` from each diagonal
mask_ID = np.all(orth_component_norm <= delta, axis=1)
# mask_ID_norm = np.linalg.norm(z_orth_component_norm, axis=1) <= delta
# OOD points have at least one component in one latent that lies outside `delta` from
# the diagonal
mask_OOD_any = np.any(orth_component_norm > delta, axis=1)
mask_OOD_all = np.all(orth_component_norm > delta, axis=1)
# mask_OOD_norm = np.linalg.norm(z_orth_component_norm, axis=1) > delta

n_ID = mask_ID.sum()
# n_ID_norm = mask_ID_norm.sum()
n_OOD_any = mask_OOD_any.sum()
n_OOD_all = mask_OOD_all.sum()
# n_OOD_norm = mask_OOD_norm.sum()
# print(n_ID, n_OOD_all, n_OOD_any, n_OOD_all / n_OOD_any)
print(n_ID, n_OOD_any, n_ID + n_OOD_any)
# print(n_ID_norm, n_OOD_norm, n_ID_norm + n_OOD_norm)

# double-check that all points in .all are in .any
print((mask_OOD_any & mask_OOD_all).sum() == mask_OOD_all.sum())

# double-check that OOD_any and ID together contain all points
print((mask_ID | mask_OOD_any).sum() == n)

# fig = plt.figure(figsize=(12, 12))
# ax = fig.add_subplot(projection='3d')
# ax.scatter(orth_component_norm[mask_OOD_any][:, 0], orth_component_norm[mask_OOD_any][:, 1], orth_component_norm[mask_OOD_any][:, 2])
# ax.scatter(z_orth_component_norm[mask_OOD_all][:, 0], z_orth_component_norm[mask_OOD_all][:, 1], z_orth_component_norm[mask_OOD_all][:, 2])

# check that ID samples don't contain any OOD samples
z_ID = sample_delta_diagonal_cube(1000, n_slots, n_latents, delta+0.00001).numpy()

diag_unit = np.ones(n_slots) / np.sqrt(n_slots)
# calculate the projection onto the diagonal (and from there the distance) along the
# `slots`-dimension since one diagonal contains the i-th latent of each slot
diag_scalar_component = np.dot(z_ID.transpose(0, 2, 1), diag_unit)
diag_component = diag_scalar_component[:, None, :] * diag_unit[None, :, None]
orth_component = z_ID - diag_component
orth_component_norm = np.linalg.norm(orth_component, axis=1)

mask_ID = np.all(orth_component_norm <= delta, axis=1)
# mask_ID_norm = np.linalg.norm(z_orth_component_norm, axis=1) <= delta
mask_OOD_any = np.any(orth_component_norm > delta, axis=1)
# mask_OOD_norm = np.linalg.norm(z_orth_component_norm, axis=1) > delta

n_ID = mask_ID.sum()
# n_ID_norm = mask_ID_norm.sum()
n_OOD_any = mask_OOD_any.sum()
# n_OOD_norm = mask_OOD_norm.sum()

print(n_ID, n_OOD_any, n_ID + n_OOD_any)
# print(n_ID_norm, n_OOD_norm, n_ID_norm + n_OOD_norm)

ax.scatter(orth_component_norm[:, 0], orth_component_norm[:, 1], orth_component_norm[:, 2])

5 99995 100000
True
True
1000 0 1000


<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x7fcbf06271f0>

In [None]:
def sample_delta_diagonal_cube(
    n_samples: int, n_slots: int, n_latents: int, delta: float, oversampling: int = 100
) -> torch.Tensor:
    """
    Sample near the diagonal in latent space i.e. all distances from the diagonal are less than delta.

    Algorithm:
        1. Draw points on the diagonal of [0, 1)^(n_slots, n_latents) cube.
        2. For every latent draw uniformly noise from n_slots-dimensional ball. For drawing uniformly inside the ball we
            use the following theorem (http://compneuro.uwaterloo.ca/files/publications/voelker.2017.pdf):
            if point uniformly sampled from the (n+1)-sphere, then n-first coordinates are uniformly sampled from the n-ball.
        3. Project sampled inside-ball points to the hyperplane perpendicular to the diagonal and normalize them
            (this gives us points on (n_slots-2)-sphere embedded in n_slots-space).
        4. Get final points by adding the diagonal point to the projected points.
        5. Keep only points inside the [0, 1)^(n_slots, n_latents) cube.
    """
    _n = oversampling * n_samples
    z_out = torch.Tensor(0, n_slots, n_latents)
    while z_out.shape[0] < n_samples:
        # sample randomly on diagonal
        z_sampled = torch.repeat_interleave(
            torch.rand(_n, n_latents), n_slots, dim=0
        ).reshape(_n, n_slots, n_latents)

        # sample noise from n_slots-ball
        noise = torch.randn(_n, n_slots + 2, n_latents)
        noise = noise / torch.norm(noise, dim=1, keepdim=True)  # points on n-sphere
        noise = noise[:, :n_slots, :]  # remove two last points

        # project to hyperplane perpendicular to diagonal
        ort_vec = noise - z_sampled * (noise * z_sampled).sum(axis=1, keepdim=True) / (
            z_sampled * z_sampled
        ).sum(axis=1, keepdim=True)
        ort_vec /= torch.norm(ort_vec, p=2, dim=1, keepdim=True)

        # final step
        # why n - 1 here? because we sample
        # "radius" not in the original space, but in the embedded
        final = z_sampled + (
            ort_vec
            * torch.pow(torch.rand([_n, 1, n_latents]), 1 / (n_slots - 1))
            * delta
        )

        # only keep samples inside [0, 1]^{k×l}
        mask = ((final - 0.5).abs() <= 0.5).flatten(1).all(1)
        idx = mask.nonzero().squeeze(1)

        z_out = torch.cat([z_out, final[idx]])
    z_out = z_out[:n_samples]
    return z_out[:n_samples]
