# In [1]:
import numpy as np
import torch
import torch.nn as nn

# In [6]:
class temporal_channel_joint_attention(nn.Module):
    """Classify 4-D inputs with parallel temporal and channel self-attention.

    Input shape is ``(batch_size, num_windows, num_channels, num_FBs)``. The
    input is tokenized twice:

    * temporal tokens: a sequence of ``num_windows`` tokens, each of size
      ``num_channels * num_FBs`` (one token per window);
    * channel tokens: a sequence of ``num_channels`` tokens, each of size
      ``num_windows * num_FBs`` (one token per channel).

    Each sequence is processed by its own ``nn.TransformerEncoder``. Both
    outputs are flattened per sample, concatenated, passed through ReLU and
    dropout (p=0.5), and mapped to ``num_classes`` unnormalized scores by a
    linear layer (no softmax is applied here).

    Args:
        num_windows: Size of input axis 1.
        num_channels: Size of input axis 2.
        num_FBs: Size of input axis 3.
        num_classes: Number of output scores per sample.
        num_heads: Attention heads in both encoders. Must evenly divide both
            ``num_channels * num_FBs`` and ``num_windows * num_FBs``.
        dim_feedforward: Feed-forward hidden size of each encoder layer.
        num_encoder_layers: Number of layers in each encoder.
    """

    def __init__(self, num_windows, num_channels, num_FBs, num_classes, num_heads=8,
                 dim_feedforward=2048, num_encoder_layers=6):
        super().__init__()
        # Kept on the instance so forward() can tokenize without relying on
        # module-level globals with the same names.
        self.num_windows = num_windows
        self.num_channels = num_channels
        self.num_FBs = num_FBs

        # Temporal tokens are (batch, num_windows, num_channels * num_FBs), so
        # the model width is num_channels * num_FBs.
        self.temporal_attention = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=num_channels * num_FBs,
                nhead=num_heads,
                dim_feedforward=dim_feedforward,
                batch_first=True,
            ),
            num_layers=num_encoder_layers,
        )
        # Channel tokens are (batch, num_channels, num_windows * num_FBs), so
        # the model width is num_windows * num_FBs.
        self.channel_attention = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=num_windows * num_FBs,
                nhead=num_heads,
                dim_feedforward=dim_feedforward,
                batch_first=True,
            ),
            num_layers=num_encoder_layers,
        )
        # Each encoder emits num_windows * num_channels * num_FBs values per
        # sample; the two are concatenated before classification.
        self.fc = nn.Linear(2 * num_windows * num_channels * num_FBs, num_classes)
        self.dropout = nn.Dropout(0.5)
        self.activation = nn.ReLU()

    def _tokenize(self, PSDs):
        """Split a batch into (temporal_tokens, channel_tokens).

        Raises:
            ValueError: if ``PSDs`` is not shaped
                ``(batch_size, num_windows, num_channels, num_FBs)``.
        """
        expected = (self.num_windows, self.num_channels, self.num_FBs)
        if PSDs.dim() != 4 or tuple(PSDs.shape[1:]) != expected:
            raise ValueError(
                f"expected PSDs of shape (batch_size, {expected[0]}, {expected[1]}, "
                f"{expected[2]}), got {tuple(PSDs.shape)}"
            )
        batch_size = PSDs.shape[0]
        # Axis 1 is already the window axis: flatten (channels, FBs) per window.
        temporal_tokens = PSDs.reshape(
            batch_size, self.num_windows, self.num_channels * self.num_FBs
        )
        # Move the channel axis in front of the window axis before flattening so
        # each token holds one channel's values across all windows; a plain
        # reshape of the (B, W, C, F) layout would mix values from different
        # channels into the same token.
        channel_tokens = PSDs.permute(0, 2, 1, 3).reshape(
            batch_size, self.num_channels, self.num_windows * self.num_FBs
        )
        return temporal_tokens, channel_tokens

    def forward(self, PSDs):
        """Return class scores of shape (batch_size, num_classes).

        Args:
            PSDs: Tensor of shape (batch_size, num_windows, num_channels, num_FBs).
        """
        temporal_tokens, channel_tokens = self._tokenize(PSDs)
        temporal_features = self.temporal_attention(temporal_tokens)
        channel_features = self.channel_attention(channel_tokens)
        # Flatten each encoder output per sample and concatenate, temporal
        # first: (batch_size, 2 * num_windows * num_channels * num_FBs).
        joint_features = torch.cat(
            (temporal_features.flatten(1), channel_features.flatten(1)), dim=1
        )
        joint_features = self.activation(joint_features)
        joint_features = self.dropout(joint_features)
        return self.fc(joint_features)