In [1]:
import timm
import torch.nn as nn
import torch

In [2]:
class HMSSpecMultiModalModel(nn.Module):
    def __init__(
            self,
            in_channels_original_spec: int = 1,
            in_channels_eeg_spec: int = 4,
            num_classes: int = 6,
            backbone_name: str = "efficientnet_b0",
            pretrained: bool = True,
        ):
        super().__init__()

        self.original_spec_backbone = timm.create_model(
            model_name=backbone_name, pretrained=pretrained,
            num_classes=0, in_chans=in_channels_original_spec)

        self.eeg_spec_backbone = timm.create_model(
            model_name=backbone_name, pretrained=pretrained,
            num_classes=0, in_chans=in_channels_eeg_spec)
        
        #if 'eca' in backbone_name:
        #    self.in_features = self.original_spec_backbone.head.fc.in_features
        #else:
        #    self.in_features = self.original_spec_backbone.classifier.in_features

        self.in_features = 1280

        self.fc_original_spec_1 = nn.Linear(self.in_features, 128)
        self.fc_original_spec_2 = nn.Linear(128, num_classes)

        self.fc_eeg_spec_1 = nn.Linear(self.in_features, 128)
        self.fc_eeg_spec_2 = nn.Linear(128, num_classes)

        self.fc_final = nn.Linear(128*2, num_classes)


    def forward(self, data):
        original_spec = data["spec_img"]
        eeg_spec = data["eeg_img"]

        original_feat = self.original_spec_backbone(original_spec)
        original_feat = self.fc_original_spec_1(original_feat)
        original_output = self.fc_original_spec_2(original_feat)    

        eeg_feat = self.eeg_spec_backbone(eeg_spec)
        eeg_feat = self.fc_eeg_spec_1(eeg_feat)
        eeg_output = self.fc_eeg_spec_2(eeg_feat)  

        output = self.fc_final(torch.cat([original_feat, eeg_feat], -1))
            
        return output, original_output, eeg_output, original_feat, eeg_feat

In [3]:
m = HMSSpecMultiModalModel()

data = {"spec_img": torch.randn((1, 1, 512, 512)), "eeg_img": torch.randn((1, 4, 512, 512))}

output, original_output, eeg_output, original_feat, eeg_feat = m(data)
output.shape, original_output.shape, eeg_output.shape, original_feat.shape, eeg_feat.shape

(torch.Size([1, 6]),
 torch.Size([1, 6]),
 torch.Size([1, 6]),
 torch.Size([1, 128]),
 torch.Size([1, 128]))

In [4]:
m.original_spec_backbone

EfficientNet(
  (conv_stem): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (bn1): BatchNormAct2d(
    32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
    (drop): Identity()
    (act): SiLU(inplace=True)
  )
  (blocks): Sequential(
    (0): Sequential(
      (0): DepthwiseSeparableConv(
        (conv_dw): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn1): BatchNormAct2d(
          32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
          (drop): Identity()
          (act): SiLU(inplace=True)
        )
        (se): SqueezeExcite(
          (conv_reduce): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
          (act1): SiLU(inplace=True)
          (conv_expand): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
          (gate): Sigmoid()
        )
        (conv_pw): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn2): BatchNormAct2d(
      