In [1]:
from typing import *
import random

import torch
import torch.nn.functional as F
from pytorch_lightning import seed_everything


# Device identifier; not referenced by any other visible cell — TODO confirm it is still needed
DEVICE: str = "cpu"
# Number of anchors requested from `LatentSpace.get_anchors` in the demo cells below
NUM_ANCHORS = 300

In [2]:
def relative_projection(x: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:
    """Project samples onto a set of anchors through cosine similarity

    Both inputs are L2-normalized along their last dimension, so every entry of
    the output is the cosine similarity between one sample and one anchor.

    Args:
        x: absolute latents of the samples [batch, hidden_dim]
        anchors: absolute latents of the anchors [anchors, hidden_dim]

    Returns:
        the relative representation of x [batch, anchors]. The rows of the result
        are *not* re-normalized; doing so can help when training on top of them
    """
    return torch.einsum(
        "bm, am -> ba",
        F.normalize(x, p=2, dim=-1),
        F.normalize(anchors, p=2, dim=-1),
    )


class LatentSpace:
    def __init__(
        self,
        encoding_type: str,
        encoder_name: str,
        vectors: torch.Tensor,
        ids: Sequence[int],
    ):
        """Utility class to represent a generic latent space

        Args:
            encoding_type: the type of latent space, i.e. "absolute" or "relative" usually
            encoder_name: the name of the encoder used to obtain the vectors
            vectors: the latents that compose the latent space, one row per id
            ids: the ids associated to the vectors
        """
        assert vectors.shape[0] == len(ids), "vectors and ids must have the same length"

        self.encoding_type: str = encoding_type
        self.vectors: torch.Tensor = vectors
        self.ids: Sequence[int] = ids
        self.encoder_name: str = encoder_name

    def get_anchors(self, anchor_choice: str, num_anchors: int, seed: int) -> Sequence[int]:
        """Adopt some strategy to select the anchors.

        Args:
            anchor_choice: the selection strategy for the anchors, only "uniform" is supported
            num_anchors: how many anchors to select
            seed: the random seed to use

        Returns:
            the ids of the chosen anchors, sorted in ascending order

        Raises:
            NotImplementedError: if the selection strategy is not supported
        """
        # Reject unknown strategies before seeding, so that a failed call
        # does not alter the global random state as a side effect
        if anchor_choice != "uniform":
            raise NotImplementedError(f"Unsupported anchor selection strategy: {anchor_choice!r}")

        # Seed right before sampling: the same seed always yields the same anchors
        seed_everything(seed)

        # Uniform sampling without replacement over all the available ids
        anchor_set: Sequence[int] = random.sample(self.ids, num_anchors)

        return sorted(anchor_set)

    def to_relative(
        self,
        anchor_choice: Optional[str] = None,
        seed: Optional[int] = None,
        anchors: Optional[Sequence[int]] = None,
        num_anchors: Optional[int] = None,
    ) -> "RelativeSpace":
        """Compute the relative transformation on the current space returning a new one.

        Args:
            anchor_choice: the anchors selection strategy to use, if no anchors are provided
            seed: the random seed to use, if no anchors are provided
            anchors: the ids of the anchors to use
            num_anchors: how many anchors to select, if no anchors are provided

        Returns:
            the RelativeSpace associated to the current LatentSpace

        Raises:
            ValueError: if neither the anchors nor both anchor_choice and num_anchors are provided
            IndexError: if an anchor id does not belong to this latent space
        """
        assert self.encoding_type != "relative"  # TODO: for now

        if anchors is None:
            if anchor_choice is None or num_anchors is None:
                raise ValueError("Provide either `anchors` or both `anchor_choice` and `num_anchors`")
            anchors = self.get_anchors(anchor_choice=anchor_choice, num_anchors=num_anchors, seed=seed)

        # Anchors are expressed as ids: translate them into row positions of `vectors`.
        # When the ids are 0..N-1 in order this is the identity mapping.
        row_of_id = {int(sample_id): row for row, sample_id in enumerate(self.ids)}
        try:
            anchor_rows = [row_of_id[int(anchor_id)] for anchor_id in anchors]
        except KeyError as exc:
            # IndexError mirrors what plain tensor indexing raises for an out-of-range anchor
            raise IndexError(f"Anchor id {exc.args[0]!r} is not among the ids of this latent space") from exc

        anchor_latents: torch.Tensor = self.vectors[anchor_rows]

        # The anchor latents stay on the device of `vectors`, so both operands always match
        relative_vectors = relative_projection(x=self.vectors, anchors=anchor_latents)

        return RelativeSpace(vectors=relative_vectors, encoder_name=self.encoder_name, anchors=anchors, ids=self.ids)


class RelativeSpace(LatentSpace):
    def __init__(
        self,
        vectors: torch.Tensor,
        ids: Sequence[int],
        anchors: Sequence[int],
        encoder_name: str = None,
    ):
        """Latent space whose encoding type is fixed to "relative"

        Args:
            vectors: the latents that compose the latent space
            ids: the ids associated to the vectors
            anchors: the ids of the anchors this space refers to
            encoder_name: the name of the encoder used to obtain the vectors
        """
        super().__init__(
            encoding_type="relative",
            encoder_name=encoder_name,
            vectors=vectors,
            ids=ids,
        )
        # Kept on the instance so the anchors behind this space can be looked up later
        self.anchors: Sequence[int] = anchors



In [3]:
from torch import cosine_similarity
from scipy.stats import ortho_group

NUM_SAMPLES = 1000
HIDDEN_DIM = 128
DATA_SEED = 0

# Some fake absolute latents, drawn from a dedicated seeded generator so that
# re-running the notebook yields the same data without touching the global RNG state
latents_rng = torch.Generator().manual_seed(DATA_SEED)
absolute_latents = torch.randn((NUM_SAMPLES, HIDDEN_DIM), generator=latents_rng)

# Apply a perfect isometry to the fake absolute latents:
# a (seeded) random orthogonal matrix of size [hidden_dim, hidden_dim]
isometric_transformation = torch.tensor(
    ortho_group.rvs(HIDDEN_DIM, random_state=DATA_SEED), dtype=torch.float
)
isometric_absolute_latents = absolute_latents @ isometric_transformation


latent_space = LatentSpace(
    encoding_type="absolute",
    encoder_name="random_vectors",
    vectors=absolute_latents,
    ids=list(range(NUM_SAMPLES)),
)
iso_latent_space = LatentSpace(
    encoding_type="absolute",
    encoder_name="iso_random_vectors",
    vectors=isometric_absolute_latents,
    ids=list(range(NUM_SAMPLES)),
)

# The shape is [num_samples, hidden_dim]
print(latent_space.vectors.shape, iso_latent_space.vectors.shape)

# Compare the absolute latents --> low similarity since there is an isometry
cosine_similarity(latent_space.vectors, iso_latent_space.vectors).mean()

torch.Size([1000, 128]) torch.Size([1000, 128])


In [4]:
# Sample the anchor ids once, from the first space
anchors_ids = latent_space.get_anchors(
    anchor_choice="uniform",
    num_anchors=NUM_ANCHORS,
    seed=0,
)

# Project both spaces onto the very same anchors
rel_latent_space = latent_space.to_relative(anchors=anchors_ids)
rel_iso_latent_space = iso_latent_space.to_relative(anchors=anchors_ids)

# Both relative spaces have shape [num_samples, num_anchors]
print(rel_latent_space.vectors.shape, rel_iso_latent_space.vectors.shape)

# Mean per-sample cosine similarity between the two relative spaces --> expected to be
# (numerically) perfect, since the relative representation is invariant to isometries
cosine_similarity(rel_latent_space.vectors, rel_iso_latent_space.vectors).mean()

Global seed set to 0


torch.Size([1000, 300]) torch.Size([1000, 300])
