In [10]:
import torch
import torch.nn as nn


# Dummy multi-scale feature maps, keyed from the deepest stage ("layer4")
# down to the shallowest ("layer1"). Each entry is a random tensor of shape
# (2, 64, 192); insertion order is the order the decoder consumes them in.
features = {
    f"layer{stage}": torch.rand(2, 64, 192)
    for stage in (4, 3, 2, 1)
}

class TransformerDecoder(nn.Module):
    """Query-based transformer decoder with a pooled classification head.

    A set of learned query embeddings is refined by a stack of
    ``nn.TransformerDecoderLayer`` modules. Layer *i* cross-attends to the
    *i*-th feature map (in dict insertion order). The refined queries are
    averaged over the query axis and projected to ``output_classes`` logits.
    """

    def __init__(
        self,
        num_layers: int,
        num_queries: int,
        num_heads: int,
        hidden_dim: int,
        output_classes: int,
    ) -> None:
        """Build the query table, decoder stack and classification head.

        Args:
            num_layers: Number of decoder layers in the stack.
            num_queries: Number of learned query vectors.
            num_heads: Attention heads per decoder layer (``nhead``).
            hidden_dim: Width of queries and feature maps (``d_model``).
            output_classes: Size of the final linear projection.
        """
        super().__init__()
        self.queries = nn.Embedding(num_queries, hidden_dim)
        self.decoder = nn.ModuleList(
            [
                nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=num_heads, batch_first=True)
                for _ in range(num_layers)
            ]
        )
        self.classification_layer = nn.Linear(hidden_dim, output_classes)

    def forward(self, features):
        """Decode the learned queries against the given feature maps.

        Args:
            features: Mapping from name to a tensor of shape
                ``(batch, sequence, hidden_dim)`` (layers are built with
                ``batch_first=True``). Feature maps are paired with decoder
                layers in insertion order; pairing stops at the shorter of
                the two, so surplus layers or surplus feature maps are
                silently left unused.

        Returns:
            Tensor of shape ``(batch, output_classes)``.

        Raises:
            ValueError: If ``features`` is empty.
        """
        if not features:
            raise ValueError("features must contain at least one feature map")

        # Read the batch size from the first feature map instead of a
        # hard-coded key, so any key naming scheme works.
        batch_size = next(iter(features.values())).shape[0]

        # expand() broadcasts the shared query table across the batch without
        # copying it; the decoder layers return new tensors, so the view is
        # never written to in place.
        queries = self.queries.weight.unsqueeze(0).expand(batch_size, -1, -1)

        for layer, memory in zip(self.decoder, features.values()):
            queries = layer(queries, memory)

        # Pool over the query axis, then classify.
        return self.classification_layer(queries.mean(dim=1))
        

# The constructor requires all five arguments; the original call passed only
# two and raised a TypeError.
decoder = TransformerDecoder(
    num_layers=len(features),  # one decoder layer per feature map
    num_queries=100,
    num_heads=8,  # NOTE(review): original value unknown; must divide hidden_dim
    hidden_dim=192,  # matches the last dimension of the feature maps
    output_classes=1,  # matches the recorded output shape torch.Size([2, 1])
)

print(decoder(features).shape)


torch.Size([2, 1])
