In [1]:
from typing import Dict, List, Optional

import numpy as np
import torch
from torch import nn, Tensor
from torch.nn import functional as F

from config_handling import prepare_config
from data_loading import prepare_downsteam_dataloaders
from models import NeuroSignalEncoder, NeuroDecoder, PositionalEncoding

In [2]:
class NeuralDecoderCustom(nn.Module):
    """
    Classifier that puts a small MLP head on top of a signal encoder.

    The encoder's LSTM embedding at the time step where the stimulus response
    ends is fed through two linear layers to produce class logits.
    """

    def __init__(self,
                 config: dict,
                 encoder: Optional['NeuroSignalEncoder'] = None):
        """
        Args:
            config: Full config dictionary. Reads config['model']
                ('lstm_embedding_dim', plus 'save_path' when no encoder is
                given), config['downstream'] ('n_classes', and in forward()
                'n_stimulus_samples' and 'tmax_samples') and, when no encoder
                is given, config['train_method'].
            encoder: Encoder to use as is. When None, a NeuroSignalEncoder is
                built from config['model'] and its weights are loaded from
                config['model']['save_path'].
        """
        super().__init__()
        self.config = config
        self.model_config = config['model']
        self.downstream_config = config['downstream']

        # Load the encoder model if necessary
        if encoder is None:
            encoder = self._load_pretrained_encoder()
        self.encoder = encoder

        self.decoder = nn.Sequential(
            nn.Linear(self.model_config['lstm_embedding_dim'], 128),
            nn.ReLU(),
            nn.Linear(128, self.downstream_config['n_classes']))

    def _load_pretrained_encoder(self):
        """Builds a NeuroSignalEncoder and loads the saved weights into it."""
        encoder = NeuroSignalEncoder(self.model_config)
        lstm_dim = self.model_config['lstm_embedding_dim']
        trained_with_cpc = self.config['train_method'].lower() == 'cpc'

        if trained_with_cpc:
            # CPC checkpoints contain the LSTM head, so the head has to exist
            # before the weights are loaded
            encoder.add_lstm_head(lstm_dim)

        # map_location='cpu' lets a checkpoint that was saved on a GPU be
        # loaded on a machine without one; load_state_dict copies the values
        # into the encoder's own parameters.
        state_dict = torch.load(self.model_config['save_path'],
                                map_location='cpu')
        encoder.load_state_dict(state_dict)

        if not trained_with_cpc:
            # MSM was not trained with the LSTM head, so it is added after
            # the weights for the rest of the model are loaded
            encoder.add_lstm_head(lstm_dim)

        return encoder

    def forward(self,
                primary_input: Tensor,
                calibration_input: Optional[Tensor] = None) -> Tensor:
        """
        Forward pass of the decoder, that takes stimulus inputs and outputs
        a predicted class.

        Args:
            primary_input: Tensor of shape (batch_size, n_timesteps, n_channels)
            calibration_input: Tensor of shape (1, n_timesteps, n_channels)

        Returns:
            Tensor of shape (batch_size, n_classes)
        """
        # Pass the input signals through the encoder
        output_dict = self.encoder(primary_input, calibration_input=calibration_input)

        lstm_embeds = output_dict['lstm_embeddings'] # Sequence of all hidden outputs
        primary_mask = output_dict['primary_embeddings_mask']

        # Select the embeddings corresponding to the primary sequence.
        # The reshape relies on every batch row having the same number of
        # masked-in positions.
        selected_embeds = lstm_embeds.masked_select(primary_mask.unsqueeze(-1).bool())
        primary_embeds = selected_embeds.reshape(lstm_embeds.shape[0], -1, lstm_embeds.shape[2])

        # Select the embeddings at the end of the stimulus response time
        n_steps = primary_embeds.shape[1]
        stimulus_fraction = self.downstream_config['n_stimulus_samples'] \
            / self.downstream_config['tmax_samples']
        target_output_idx = int(torch.tensor(stimulus_fraction * n_steps).ceil())
        # ceil() yields n_steps itself when the stimulus spans the whole
        # window, which is one past the last valid index, so clamp it.
        target_output_idx = max(0, min(target_output_idx, n_steps - 1))
        output_embeds = primary_embeds[:, target_output_idx]

        # Pass the embeddings through the decoder
        class_logits = self.decoder(output_embeds)

        return class_logits

In [61]:
class TestModel(nn.Module):
    """
    Fully connected baseline classifier over a fixed window of raw signal.

    The window is the last `n_calibration_steps` time steps of the calibration
    input followed by the first `n_primary_steps` time steps of the primary
    input, flattened over time and channels.
    """

    def __init__(self,
                 n_channels: int = 204,
                 n_calibration_steps: int = 200,
                 n_primary_steps: int = 500,
                 n_classes: int = 10):
        """
        Args:
            n_channels: Number of channels in each input.
            n_calibration_steps: Trailing calibration time steps to keep.
            n_primary_steps: Leading primary time steps to keep.
            n_classes: Number of output logits.

        Raises:
            ValueError: If any size argument is smaller than 1.
        """
        super().__init__()
        if min(n_channels, n_calibration_steps, n_primary_steps, n_classes) < 1:
            # A zero step count would also make `y[:, -0:]` keep all of y.
            raise ValueError('all size arguments must be positive integers')

        self.n_calibration_steps = n_calibration_steps
        self.n_primary_steps = n_primary_steps

        window_len = n_calibration_steps + n_primary_steps
        self.fc_layers = nn.Sequential(
            nn.Linear(window_len * n_channels, 1000),
            nn.ReLU(),
            nn.Linear(1000, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, n_classes))

    def forward(self, x, y):
        """
        Args:
            x: Primary input of shape (batch_size, n_timesteps, n_channels)
            y: Calibration input of shape (batch_size, n_timesteps, n_channels)

        Returns:
            Tensor of shape (batch_size, n_classes)
        """
        # Calibration tail first, then the start of the primary signal
        window = torch.cat(
            (y[:, -self.n_calibration_steps:], x[:, :self.n_primary_steps]),
            dim=1)
        return self.fc_layers(window.flatten(start_dim=1))
      
class TestModel2(nn.Module):
    """
    Convolutional baseline classifier over a fixed window of raw signal.

    The window is the last `n_calibration_steps` time steps of the calibration
    input followed by the first `n_primary_steps` time steps of the primary
    input. It goes through five strided 1-D convolutions and then a fully
    connected head.
    """

    _N_CONV_LAYERS = 5
    _CONV_CHANNELS = 256
    _KERNEL_SIZE = 3
    _STRIDE = 2

    def __init__(self,
                 n_channels: int = 204,
                 n_calibration_steps: int = 200,
                 n_primary_steps: int = 500,
                 n_classes: int = 10):
        """
        Args:
            n_channels: Number of channels in each input.
            n_calibration_steps: Trailing calibration time steps to keep.
            n_primary_steps: Leading primary time steps to keep.
            n_classes: Number of output logits.

        Raises:
            ValueError: If any size argument is smaller than 1, or the window
                is too short to pass through all convolution layers.
        """
        super().__init__()
        if min(n_channels, n_calibration_steps, n_primary_steps, n_classes) < 1:
            # A zero step count would also make `y[:, -0:]` keep all of y.
            raise ValueError('all size arguments must be positive integers')

        self.n_calibration_steps = n_calibration_steps
        self.n_primary_steps = n_primary_steps

        # 700 input steps shrink to 20 with the default arguments
        conv_out_len = self._conv_output_length(
            n_calibration_steps + n_primary_steps)

        conv_blocks = []
        in_channels = n_channels
        for _ in range(self._N_CONV_LAYERS):
            conv_blocks.append(nn.Conv1d(in_channels, self._CONV_CHANNELS,
                                         kernel_size=self._KERNEL_SIZE,
                                         stride=self._STRIDE))
            conv_blocks.append(nn.ReLU())
            in_channels = self._CONV_CHANNELS
        self.conv_layers = nn.Sequential(*conv_blocks)

        self.fc_layers = nn.Sequential(
            nn.Linear(conv_out_len * self._CONV_CHANNELS, 1000),
            nn.ReLU(),
            nn.Linear(1000, 500),
            nn.ReLU(),
            nn.Linear(500, 500),
            nn.ReLU(),
            nn.Linear(500, n_classes))

    @classmethod
    def _conv_output_length(cls, n_timesteps: int) -> int:
        """Returns the sequence length left after the convolution stack."""
        length = n_timesteps
        for _ in range(cls._N_CONV_LAYERS):
            if length < cls._KERNEL_SIZE:
                raise ValueError(
                    f'a window of {n_timesteps} time steps is too short for '
                    f'{cls._N_CONV_LAYERS} convolution layers')
            # Unpadded, undilated convolution
            length = (length - cls._KERNEL_SIZE) // cls._STRIDE + 1
        return length

    def forward(self, x, y):
        """
        Args:
            x: Primary input of shape (batch_size, n_timesteps, n_channels)
            y: Calibration input of shape (batch_size, n_timesteps, n_channels)

        Returns:
            Tensor of shape (batch_size, n_classes)
        """
        # Calibration tail first, then the start of the primary signal
        window = torch.cat(
            (y[:, -self.n_calibration_steps:], x[:, :self.n_primary_steps]),
            dim=1)
        # Conv1d expects (batch_size, n_channels, n_timesteps)
        features = self.conv_layers(window.transpose(1, 2))
        return self.fc_layers(features.flatten(start_dim=1))
     
     
class TestModel3(nn.Module):
    """
    Convolution + transformer baseline classifier over a fixed signal window.

    The window is the last `n_calibration_steps` time steps of the calibration
    input followed by the first `n_primary_steps` time steps of the primary
    input. It goes through five strided 1-D convolutions, one transformer
    encoder layer and a fully connected head.
    """

    _N_CONV_LAYERS = 5
    _CONV_CHANNELS = 128
    _KERNEL_SIZE = 3
    _STRIDE = 2

    def __init__(self,
                 n_channels: int = 204,
                 n_calibration_steps: int = 200,
                 n_primary_steps: int = 600,
                 n_classes: int = 10):
        """
        Args:
            n_channels: Number of channels in each input.
            n_calibration_steps: Trailing calibration time steps to keep.
            n_primary_steps: Leading primary time steps to keep.
            n_classes: Number of output logits.

        Raises:
            ValueError: If any size argument is smaller than 1, or the window
                is too short to pass through all convolution layers.
        """
        super().__init__()
        if min(n_channels, n_calibration_steps, n_primary_steps, n_classes) < 1:
            # A zero step count would also make `y[:, -0:]` keep all of y.
            raise ValueError('all size arguments must be positive integers')

        self.n_calibration_steps = n_calibration_steps
        self.n_primary_steps = n_primary_steps

        # Sequence length after the convolutions (24 for the default 800 steps)
        self.seq_len = self._conv_output_length(
            n_calibration_steps + n_primary_steps)

        conv_blocks = []
        in_channels = n_channels
        for _ in range(self._N_CONV_LAYERS):
            conv_blocks.append(nn.Conv1d(in_channels, self._CONV_CHANNELS,
                                         kernel_size=self._KERNEL_SIZE,
                                         stride=self._STRIDE))
            conv_blocks.append(nn.ReLU())
            in_channels = self._CONV_CHANNELS
        self.conv_layers = nn.Sequential(*conv_blocks)

        # NOTE(review): nn.TransformerEncoder appears to deep-copy the layer it
        # is given, which would make `self.tlayer` an unused duplicate. It is
        # kept so existing state dicts still load - confirm before removing.
        self.tlayer = nn.TransformerEncoderLayer(
                d_model = self._CONV_CHANNELS,
                nhead = 2,
                dim_feedforward = 64,
                activation = 'relu',
                batch_first=True
            )

        self.transformer = nn.TransformerEncoder(self.tlayer, 1)

        self.fc_layers = nn.Sequential(
            nn.Linear(self.seq_len * self._CONV_CHANNELS, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, n_classes))

        # Constructed but not applied in forward()
        self.pos_enc = PositionalEncoding(
            self._CONV_CHANNELS, 0.25 * 0.1, self.seq_len)

    @classmethod
    def _conv_output_length(cls, n_timesteps: int) -> int:
        """Returns the sequence length left after the convolution stack."""
        length = n_timesteps
        for _ in range(cls._N_CONV_LAYERS):
            if length < cls._KERNEL_SIZE:
                raise ValueError(
                    f'a window of {n_timesteps} time steps is too short for '
                    f'{cls._N_CONV_LAYERS} convolution layers')
            # Unpadded, undilated convolution
            length = (length - cls._KERNEL_SIZE) // cls._STRIDE + 1
        return length

    def forward(self, x, y):
        """
        Args:
            x: Primary input of shape (batch_size, n_timesteps, n_channels)
            y: Calibration input of shape (batch_size, n_timesteps, n_channels)

        Returns:
            Tensor of shape (batch_size, n_classes)
        """
        # Calibration tail first, then the start of the primary signal
        window = torch.cat(
            (y[:, -self.n_calibration_steps:], x[:, :self.n_primary_steps]),
            dim=1)
        # Conv1d expects (batch_size, n_channels, n_timesteps); the transformer
        # (batch_first=True) expects (batch_size, seq_len, d_model)
        features = self.conv_layers(window.transpose(1, 2)).transpose(1, 2)
        # Positional encoding is currently disabled:
        # features = self.pos_enc(features)
        features = self.transformer(features)
        return self.fc_layers(features.flatten(start_dim=1))

In [62]:
# Load the experiment configuration. `config` is read by later cells
# (format_labels, data loading and the training loop).
config = prepare_config('configs/w2v_cpc.yaml')
model_config = config['model']
# Disabled: build the encoder with its LSTM head and load the saved weights.
# encoder = NeuroSignalEncoder(model_config)
# encoder.add_lstm_head(model_config['lstm_embedding_dim'])
# encoder.load_state_dict(torch.load(model_config['save_path']))

In [63]:
def format_inputs(inputs: List[Tensor], config: Dict) -> List[Tensor]:
    """
    Puts every input tensor into (batch, time, channel) layout on the
    configured device.

    Args:
        inputs: Tensors of shape (batch_size, n_channels, n_timesteps). A 2-D
            tensor of shape (n_channels, n_timesteps) is treated as a batch
            of one.
        config: Config dictionary; config['device'] is the target device.

    Returns:
        A new list holding one tensor of shape
        (batch_size, n_timesteps, n_channels) per input.
    """
    def _format_one(tensor: Tensor) -> Tensor:
        # Give unbatched inputs a leading batch dimension
        batched = tensor.unsqueeze(0) if tensor.dim() == 2 else tensor
        # Swap the channel and time axes, then move to the target device
        return batched.permute(0, 2, 1).to(config['device'])

    return [_format_one(tensor) for tensor in inputs]

# TODO: The default label_offset of 1 matches the MEG colors dataset; pass
# label_offset explicitly for datasets whose labels are encoded differently.
def format_labels(labels: Tensor,
                  device: Optional[torch.device] = None,
                  label_offset: int = 1) -> Tensor:
    """
    Reshapes the labels, moves them to the correct device,
    and shifts them down by `label_offset` so that, with the default of 1,
    1-based labels end up in the range [0, n_classes - 1].

    Args:
        labels: Tensor of shape (batch_size, 1) or (batch_size,)
        device: Target device (anything `Tensor.to` accepts, e.g. a device
            string). When None, the notebook-level `config['device']` is used.
        label_offset: Value subtracted from every label.

    Returns:
        Tensor of shape (batch_size,)
    """
    if device is None:
        # Fall back to the notebook-level config so existing one-argument
        # calls keep working
        device = config['device']

    # Labels that are already 1-D have no dimension 1 to squeeze
    if labels.dim() > 1:
        labels = labels.squeeze(1)
    labels = labels - label_offset
    return labels.to(device)


In [64]:
# print(NeuroDecoder(config))

In [65]:
# print(TestModel3())

In [66]:
# for param in encoder.parameters():
#     param.requires_grad = False

In [67]:
# Train a baseline model on the downstream task, printing a running loss and
# accuracy over the most recent batches.
ds_config = config['downstream']

# Hyperparameters for this quick experiment
batch_size = 16
n_epochs = 2             # ds_config['train_epochs']
learning_rate = 0.0002   # ds_config['learning_rate']
grad_accum_steps = 16    # batches accumulated into one optimizer step
log_every = 32           # batches between progress print-outs

# Prepare the data
dataloaders = prepare_downsteam_dataloaders(config, batch_size)
train_loader = dataloaders['train']
val_loader = dataloaders['val']
test_loader = dataloaders['test']

model = TestModel3()  # alternatively: NeuroDecoder(config)
model = model.to(config['device'])
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

for epoch in range(n_epochs):
    print('Starting epoch', epoch + 1)
    epoch_preds = []
    epoch_labels = []
    epoch_losses = []

    model.train()
    optimizer.zero_grad()
    pending_batches = 0  # batches whose gradients have not been applied yet
    for sample_idx, data in enumerate(train_loader):
        # Get and format the data for the batch
        primary_input = data['primary_input']
        calibration_input = data['calibration_input']
        labels = data['label']

        primary_input, calibration_input = format_inputs(
            (primary_input, calibration_input), config)
        labels = format_labels(labels)

        # Run the data through the model
        logits = model(primary_input, calibration_input)

        # Update the result buffers
        preds = logits.argmax(dim=1)
        epoch_preds.extend(preds.detach().cpu().numpy())
        epoch_labels.extend(labels.detach().cpu().numpy())

        # CrossEntropyLoss applies log-softmax itself, so it must be given the
        # raw logits; feeding it softmax output squashes the gradients.
        loss = criterion(logits, labels)
        epoch_losses.append(loss.item())

        # Accumulate gradients batch by batch instead of keeping every
        # batch's autograd graph alive until the optimizer step. Dividing by
        # grad_accum_steps makes the sum equal the mean loss of the group.
        (loss / grad_accum_steps).backward()
        pending_batches += 1

        if pending_batches == grad_accum_steps:
            optimizer.step()
            optimizer.zero_grad()
            pending_batches = 0

        # Print updates
        if (sample_idx + 1) % log_every == 0:
            lookback = batch_size * log_every
            recent_preds = np.array(epoch_preds[-lookback:])
            recent_labels = np.array(epoch_labels[-lookback:])
            print(sample_idx, np.mean(epoch_losses[-log_every:]))
            print('Accuracy:',
                (recent_preds == recent_labels).sum() / len(recent_preds))

    # Apply the gradients of a trailing, incomplete accumulation group so the
    # last batches of the epoch are not silently dropped.
    if pending_batches > 0:
        # Rescale from "sum / grad_accum_steps" to the mean over the batches
        # that were actually accumulated.
        grad_scale = grad_accum_steps / pending_batches
        for param in model.parameters():
            if param.grad is not None:
                param.grad.mul_(grad_scale)
        optimizer.step()
        optimizer.zero_grad()

Starting epoch 1
