In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50

class DilationModel(nn.Module):
    def __init__(
        self,
        time_window=None,
        eeg_input_dimension=64,
        env_input_dimension=1,
        spatial_filters=8,
        num_mismatched_segments=2
    ):
        super(DilationModel, self).__init__()

        self.eeg_conv = nn.Conv1d(1, spatial_filters, kernel_size=1)
        
        # Load ResNet50 model without top classification layer
        resnet_model = resnet50(pretrained=True)
        self.resnet_layers = nn.Sequential(*list(resnet_model.children())[:-1])  # Remove last classification layer

        # Set the ResNet layers to non-trainable
        for param in self.resnet_layers.parameters():
            param.requires_grad = False

        # Linear projection similarity
        self.linear_proj_sim = nn.Linear(spatial_filters, 1)

        # Number of stimuli input
        self.stimuli_projs = nn.ModuleList([
            nn.Sequential(
                nn.Conv1d(1, spatial_filters, kernel_size=1),
                self.resnet_layers
            ) for _ in range(num_mismatched_segments + 1)
        ])

    def forward(self, eeg, stimuli_inputs):
        # Spatial convolution for EEG
        eeg_proj_1 = self.eeg_conv(eeg)

        # Apply ResNet50 to EEG input
        eeg_proj_1 = self.resnet_layers(eeg_proj_1)

        # Apply ResNet50 to each stimulus input
        stimuli_projs = [stimuli_proj(stimulus_input) for stimuli_proj, stimulus_input in zip(self.stimuli_projs, stimuli_inputs)]

        # Comparison
        cos = [F.cosine_similarity(eeg_proj_1, stimulus_proj, dim=2) for stimulus_proj in stimuli_projs]

        # Linear projection similarity
        cos_proj = [self.linear_proj_sim(cos_i.flatten()) for cos_i in cos]

        # Classification
        out = F.softmax(torch.cat(cos_proj, dim=1), dim=1)

        return out

# Example usage
time_window = 100
eeg_input_dimension = 64
env_input_dimension = 1
spatial_filters = 8
num_mismatched_segments = 2

model = DilationModel(time_window, eeg_input_dimension, env_input_dimension, spatial_filters, num_mismatched_segments)
print(model)


DilationModel(
  (eeg_conv): Conv1d(1, 8, kernel_size=(1,), stride=(1,))
  (resnet_layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inpl