In [1]:
import timm
import torch
import torch.nn as nn


class ResNet_1D_Block(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, downsampling):
        super(ResNet_1D_Block, self).__init__()
        self.bn1 = nn.BatchNorm1d(num_features=in_channels)
        self.relu = nn.ReLU(inplace=False)
        self.dropout = nn.Dropout(p=0.1, inplace=False)
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm1d(num_features=out_channels)
        self.conv2 = nn.Conv1d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
                               stride=stride, padding=padding, bias=False)
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
        self.downsampling = downsampling

    def forward(self, x):
        identity = x

        out = self.bn1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.conv2(out)

        out = self.maxpool(out)
        identity = self.downsampling(x)

        out += identity
        return out



class HMSSpecEEGPararellModel(nn.Module):
    def __init__(self, cfg, kernels=[3,5,7,9], backbone_name='tf_efficientnet_b0', pretrained=False, in_channels_spec=1, in_channels_eeg=20, fixed_kernel_size=17, num_classes=6, is_training=True):
        super(HMSSpecEEGPararellModel, self).__init__()
        
        self.cfg = cfg
        self.kernels = kernels
        self.planes = 24
        self.parallel_conv = nn.ModuleList()
        self.in_channels = in_channels_eeg


        self.backbone_2d = timm.create_model(
            backbone_name,
            pretrained=pretrained,
            drop_rate=0.1,
            drop_path_rate=0.1,
            in_chans=in_channels_spec
        
        )
        
        self.features_2d = nn.Sequential(*list(self.backbone_2d.children())[:-2] + [nn.AdaptiveAvgPool2d(1),nn.Flatten()])
        
        # nn.Sequential(
            # nn.AdaptiveAvgPool2d(1),

        for i, kernel_size in enumerate(list(self.kernels)):
            sep_conv = nn.Conv1d(in_channels=in_channels_eeg, out_channels=self.planes, kernel_size=(kernel_size),
                               stride=1, padding=0, bias=False,)
            self.parallel_conv.append(sep_conv)

        self.bn1 = nn.BatchNorm1d(num_features=self.planes)
        self.relu = nn.ReLU(inplace=False)
        self.conv1 = nn.Conv1d(in_channels=self.planes, out_channels=self.planes, kernel_size=fixed_kernel_size,
                               stride=2, padding=2, bias=False)
        self.block = self._make_resnet_layer(kernel_size=fixed_kernel_size, stride=1, padding=fixed_kernel_size//2)
        self.bn2 = nn.BatchNorm1d(num_features=self.planes)
        self.avgpool = nn.AvgPool1d(kernel_size=4, stride=4, padding=2)
        self.rnn = nn.GRU(input_size=self.in_channels, hidden_size=128, num_layers=1, bidirectional=True)
        
        self.fc1 = nn.Linear(in_features=1280, out_features=128)
        self.fc2 = nn.Linear(in_features=736, out_features=128)
        self.fc = nn.Linear(in_features=256, out_features=num_classes)

        self.fc1d = nn.Linear(in_features=128, out_features=num_classes)
        self.fc2d = nn.Linear(in_features=128, out_features=num_classes)
        
        
        self.rnn1 = nn.GRU(input_size=156, hidden_size=156, num_layers=1, bidirectional=True)

        self.is_training = is_training

    def _make_resnet_layer(self, kernel_size, stride, blocks=8, padding=0):
        layers = []
        downsample = None
        base_width = self.planes

        for i in range(blocks):
            downsampling = nn.Sequential(
                    nn.MaxPool1d(kernel_size=2, stride=2, padding=0)
                )
            layers.append(ResNet_1D_Block(in_channels=self.planes, out_channels=self.planes, kernel_size=kernel_size,
                                       stride=stride, padding=padding, downsampling=downsampling))

        return nn.Sequential(*layers)

    def forward(self, batch):
        spec = batch['spec_img']
        x = batch['raw_eeg']

        # print(spec.shape) #2, 1280, 16, 8
        out_sep = []

        # x: (8, 10000)
        for i in range(len(self.kernels)):
            sep = self.parallel_conv[i](x)
            out_sep.append(sep)

        out = torch.cat(out_sep, dim=2)
        print(out.shape)
        out = self.bn1(out)
        out = self.relu(out)
        print(out.shape)
        out = self.conv1(out)  
        print(out.shape)

        out = self.block(out)
        print(out.shape)
        out = self.bn2(out)
        out = self.relu(out)
        out = self.avgpool(out)  
        print(out.shape)


        
        out = out.reshape(out.shape[0], -1)  

        # 
        rnn_out, _ = self.rnn(x.permute(0,2, 1))
        new_rnn_h = rnn_out[:, -1, :]  

        new_out = torch.cat([out, new_rnn_h], dim=1)  
        new_out = self.fc2(new_out)  
        eeg_output = self.fc1d(new_out)
        
        #spec = self._reshape_input(spec)
        # spec : (3, 512, 512)
        spec = self.features_2d(spec)
        spec = self.fc1(spec)  
        original_output = self.fc2d(spec)
        
        output = torch.cat([new_out, spec], dim=1)  
        output = self.fc(output)
    
        weighted_output = 0.5*output + 0.25*original_output + 0.25*eeg_output

        if self.is_training:       
            return {'weighted_output': weighted_output,'output': output, 'original_output': original_output, "eeg_output": eeg_output, "original_feat": new_out, "eeg_feat": spec}

        else: return weighted_output

In [2]:
m = HMSSpecEEGPararellModel(cfg=None)
batch = {"spec_img" : torch.randn((1, 1, 512, 512)), "raw_eeg": torch.randn((1, 20, 10000))}
y = m(batch)

torch.Size([1, 24, 39980])
torch.Size([1, 24, 39980])
torch.Size([1, 24, 19984])
torch.Size([1, 24, 78])
torch.Size([1, 24, 20])


In [3]:
batch = {"spec_img" : torch.randn((1, 1, 512, 512)), "raw_eeg": torch.randn((1, 20, 5000))}
y = m(batch)

torch.Size([1, 24, 19980])
torch.Size([1, 24, 19980])
torch.Size([1, 24, 9984])
torch.Size([1, 24, 39])
torch.Size([1, 24, 10])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x496 and 736x128)