In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiLabelClassificationHead(nn.Module):
    def __init__(self, input_dim, num_descriptors=18, options_per_descriptor=4):
        super(MultiLabelClassificationHead, self).__init__()
        self.num_descriptors = num_descriptors
        self.options_per_descriptor = options_per_descriptor
        
        # Create separate linear layers for each descriptor
        self.descriptor_layers = nn.ModuleList([
            nn.Linear(input_dim, options_per_descriptor) for _ in range(num_descriptors)
        ])
        
    def forward(self, x):
        # x shape: (batch_size, input_dim)
        outputs = []
        for layer in self.descriptor_layers:
            descriptor_output = layer(x)  # shape: (batch_size, options_per_descriptor)
            descriptor_output = F.softmax(descriptor_output, dim=1)  # Apply softmax for each descriptor
            outputs.append(descriptor_output)
        
        # Stack outputs along a new dimension
        return torch.stack(outputs, dim=1)  # shape: (batch_size, num_descriptors, options_per_descriptor)

# Example usage
input_dim = 512  # This should match the output dimension of your backbone network
classification_head = MultiLabelClassificationHead(input_dim)

# Assuming 'features' is the output from your backbone network
batch_size = 32
features = torch.randn(batch_size, input_dim)
outputs = classification_head(features)

print(f"Output shape: {outputs.shape}")  # Should be (batch_size, 18, 4)

# Loss function
class MultiLabelSoftmaxLoss(nn.Module):
    def __init__(self):
        super(MultiLabelSoftmaxLoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')
    
    def forward(self, predictions, targets):
        # predictions shape: (batch_size, num_descriptors, options_per_descriptor)
        # targets shape: (batch_size, num_descriptors)
        loss = 0
        for i in range(predictions.size(1)):
            loss += self.ce_loss(predictions[:, i, :], targets[:, i])
        return loss.mean()

# Example of computing loss
criterion = MultiLabelSoftmaxLoss()
targets = torch.randint(0, 4, (batch_size, 18))  # Random targets for illustration
print(f"Targets shape: {targets.shape}")
print(f"outputs shape: {outputs.shape}")
loss = criterion(outputs, targets)
print(f"Loss: {loss.item()}")

Output shape: torch.Size([32, 18, 4])
Loss: 24.95261001586914
