In [None]:
%load_ext autoreload
%autoreload 2

from rae.modules import attention
from rae.modules.relative_classifier import ReprPooling
from rae.modules.enumerations import *
import torch
from torch import nn
from typing import *
import torch.nn.functional as F

In [None]:
IN_FEATURES = 512  # width of each latent vector (columns of batch_latents / anchors_latents)
N_ANCHORS = 500
N_CLASSES = 100
SEED = 0

# The latents below come from torch.randn; seed the global RNG so Restart & Run All reproduces them.
torch.manual_seed(SEED)

In [None]:
BATCH_SIZE = 8  # number of samples pushed through the single- and multi-head models
batch_latents = torch.randn(BATCH_SIZE, IN_FEATURES, dtype=torch.double)
batch_latents

In [None]:
# One random anchor per row, in the same feature space (and dtype) as the batch latents.
anchors_latents = torch.randn(size=(N_ANCHORS, IN_FEATURES), dtype=torch.float64)
anchors_latents

In [None]:
class MultiHeadRelativeAttention(nn.Module):
    """Run one ``RelativeAttention`` head per equal-width slice of dim 1 and pool the heads' outputs.

    Head ``i`` receives columns ``[i * subspace_features, (i + 1) * subspace_features)`` of both the
    batch latents and the anchor latents. The heads' ``AttentionOutput.OUTPUT`` tensors are stacked
    along dim 1 and merged according to ``repr_pooling``.

    Args:
        attentions: one head per subspace; ``len(attentions)`` is the number of subspaces.
        hidden_features: size of dim 1 of the inputs to ``forward``; must be divisible by
            ``len(attentions)``.
        repr_pooling: how to merge the heads (a ``ReprPooling`` member, or a value accepted by
            ``ReprPooling(...)``). ``None`` means ``ReprPooling.NONE`` (concatenate the heads).
        n_classes: output size of ``classification_layer``; defaults to the notebook-level ``N_CLASSES``.

    Raises:
        ValueError: if ``attentions`` is empty or ``repr_pooling`` is not a valid ``ReprPooling``.
        AssertionError: if ``hidden_features`` is not divisible by the number of heads.
    """

    def __init__(
        self,
        attentions: Sequence[attention.RelativeAttention],
        hidden_features,
        repr_pooling: ReprPooling = None,
        n_classes: Optional[int] = None,
    ):
        super().__init__()
        if len(attentions) == 0:
            raise ValueError("At least one RelativeAttention head is required")
        self.num_subspaces = len(attentions)
        # ModuleList (not a plain list) registers the heads as submodules, so their parameters appear in
        # .parameters()/state_dict() and follow .to()/.double()/.train()/.eval() on this module.
        self.relative_attentions = nn.ModuleList(attentions)

        # Every head gets an equal, contiguous slice of the feature axis.
        assert hidden_features % self.num_subspaces == 0, (
            f"hidden_features={hidden_features} is not divisible by the number of heads ({self.num_subspaces})"
        )
        self.subspace_features = hidden_features // self.num_subspaces

        if repr_pooling is None:
            repr_pooling = ReprPooling.NONE
        try:
            # Normalise to an enum member so later comparisons do not rely on str/enum equality.
            self.repr_pooling: ReprPooling = ReprPooling(repr_pooling)
        except (ValueError, TypeError):
            raise ValueError(f"Representation Pooling method not supported: {repr_pooling}") from None

        if n_classes is None:
            # Backward-compatible fallback: the class used to read this notebook global implicitly.
            n_classes = N_CLASSES

        head_dims = [head.output_dim for head in self.relative_attentions]
        repr_dim: int = head_dims[0]
        concat_dim = sum(head_dims)

        # NOTE(review): classification_layer is never used by forward(); kept so the parameter set is unchanged.
        classifier_in = concat_dim if self.repr_pooling == ReprPooling.NONE else repr_dim
        self.classification_layer = nn.Linear(classifier_in, n_classes)

        if self.repr_pooling == ReprPooling.LINEAR:
            # Projects the concatenated heads back down to a single head's width.
            self.head_pooling = nn.Linear(in_features=concat_dim, out_features=repr_dim)

    def _pool_heads(self, stacked: torch.Tensor) -> torch.Tensor:
        """Reduce per-head outputs stacked along dim 1 according to ``self.repr_pooling``."""
        if self.repr_pooling == ReprPooling.LINEAR:
            return self.head_pooling(torch.flatten(stacked, 1, 2))
        if self.repr_pooling == ReprPooling.MAX:
            return stacked.max(dim=1).values
        if self.repr_pooling == ReprPooling.SUM:
            return stacked.sum(dim=1)
        if self.repr_pooling == ReprPooling.MEAN:
            return stacked.mean(dim=1)
        if self.repr_pooling == ReprPooling.NONE:
            # Plain concatenation of the heads along the feature axis.
            return torch.flatten(stacked, 1, 2)
        raise NotImplementedError(f"Unhandled representation pooling: {self.repr_pooling}")

    def forward(
        self,
        batch_latents: torch.Tensor,
        anchors_latents: torch.Tensor,
    ) -> torch.Tensor:
        """Apply each head to its slice of the inputs and pool the results.

        Args:
            batch_latents: samples to encode; dim 1 is split across the heads
                (presumably shaped ``(batch, hidden_features)`` — TODO confirm).
            anchors_latents: anchors, split along dim 1 the same way.

        Returns:
            The pooled ``AttentionOutput.OUTPUT`` tensor; other keys returned by the heads are discarded.

        Raises:
            ValueError: if dim 1 of either input does not equal ``hidden_features``.
        """
        expected_width = self.subspace_features * self.num_subspaces
        for name, latents in (("batch_latents", batch_latents), ("anchors_latents", anchors_latents)):
            if latents.shape[1] != expected_width:
                # Slicing alone would silently drop extra columns or hand heads truncated slices.
                raise ValueError(f"{name} has {latents.shape[1]} features along dim 1, expected {expected_width}")

        head_outputs = []
        for i, relative_attention in enumerate(self.relative_attentions):
            start, stop = i * self.subspace_features, (i + 1) * self.subspace_features
            head_result = relative_attention(
                x=batch_latents[:, start:stop],
                anchors=anchors_latents[:, start:stop],
            )
            head_outputs.append(head_result[AttentionOutput.OUTPUT])

        # Heads are stacked along dim 1: (batch, num_subspaces, ...).
        stacked = torch.stack(head_outputs, dim=1)
        return self._pool_heads(stacked)

In [None]:
# Keyword arguments shared by every RelativeAttention built below.
# Options that were commented out in this experiment (kept for reference):
#   dropout_p=0.1, num_subspaces=4, repr_pooling=None
params = dict(
    hidden_features=64,
    transform_elements=None,
    normalization_mode="off",
    similarity_mode="inner",
    values_mode="similarities",
    values_self_attention_nhead=8,
    similarities_quantization_mode=None,
    similarities_bin_size=None,
    similarities_aggregation_mode=None,
    similarities_aggregation_n_groups=None,
    anchors_sampling_mode=None,
    n_anchors_sampling_per_class=None,
)

In [None]:
%%capture
relative_attention = attention.RelativeAttention(
    in_features=IN_FEATURES, n_anchors=N_ANCHORS, n_classes=N_CLASSES, **params
)
relative_attention

In [None]:
%%capture

N_SUBSPACES = 8

attentions = [
    attention.RelativeAttention(in_features=IN_FEATURES, n_anchors=N_ANCHORS, n_classes=N_CLASSES, **params)
    for _ in range(N_SUBSPACES)
]
multihead = MultiHeadRelativeAttention(attentions, hidden_features=IN_FEATURES, repr_pooling="sum")
multihead

In [None]:
rel_out = relative_attention(batch_latents, anchors_latents)[AttentionOutput.OUTPUT]
mout = multihead(batch_latents, anchors_latents)
# Presumably the point of this check: with inner-product similarities, summing the per-subspace
# results should reproduce the full-width single-head output.
# torch.allclose broadcasts its arguments, so compare the shapes explicitly as well.
mout.shape == rel_out.shape and torch.allclose(mout, rel_out)

In [None]:
# Single-head reference output, for eyeballing against `mout` below.
rel_out

In [None]:
# Multi-head output (built with "sum" pooling); compare with `rel_out` above.
mout