In [6]:
# libraries
import matplotlib.pyplot as plt
%matplotlib widget
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader
import pandas as pd
import itertools
import numpy as np
import torch.nn as nn
import datetime
import tqdm
from itertools import accumulate


# custom
import utilities
import sync_head_utils 
import torch.nn.functional as F

# plotting: figure width and base height (scaled by the number of subplot
# rows) passed to plt.subplots(figsize=...)
WIDTH, HEIGHT = 12, 3

In [7]:
MACHINE = 'BrushlessMotor'  # BrushlessMotor, RoboticArm
# Paths to the training and testing HDF5 dataset files
TRAIN_DATASET_PATH = f'data/{MACHINE}/windowed/train_dataset_window_0.100s.h5'
TEST_DATASET_PATH = f'data/{MACHINE}/windowed/test_dataset_window_0.100s.h5'

# List of sensor names to be extracted from the dataset
SENSORS = [
    'imp23absu_mic',
    'ism330dhcx_acc',
    'ism330dhcx_gyro'
]

# List of label names to be extracted from the dataset
LABEL_NAMES = ['segment_id',
               'split_label',
               'anomaly_label',
               'domain_shift_op',
               'domain_shift_env']


PARAMS = {
    'batch_size': 64,
    'num_epochs': 1000,
    'lr': 0.001,
    # TO BE ADAPTED TO YOUR MACHINE: either 'mps' or 'cuda' if a GPU is
    # available, otherwise 'cpu'
    'device': 'mps',
    'patience': 3,
    'normalisation': 'std_window',
    'valid_size': 0.1,
    'seed': 1995,
}

# Seed every random generator this notebook uses: torch's global (CPU)
# generator, numpy's global generator (used later to pick a sample to plot)
# and, when applicable, the accelerator's own generator.
torch.manual_seed(PARAMS['seed'])
np.random.seed(PARAMS['seed'])
if PARAMS['device'] == 'mps':
    torch.mps.manual_seed(PARAMS['seed'])
elif PARAMS['device'] == 'cuda':
    torch.cuda.manual_seed(PARAMS['seed'])
elif PARAMS['device'] != 'cpu':
    raise ValueError(f"Unsupported device type: {PARAMS['device']}")

# Load the dataset
X_train_raw, Y_train_raw, X_test, Y_test = utilities.load_dataset(
    TRAIN_DATASET_PATH, TEST_DATASET_PATH, LABEL_NAMES, SENSORS)

# Combine anomaly labels and domain shift labels to form a combined label
# (used below to stratify the train/validation split)
Y_train_raw['combined_label'] = (Y_train_raw['anomaly_label']
                                 + Y_train_raw['domain_shift_op']
                                 + Y_train_raw['domain_shift_env'])
Y_test['combined_label'] = (Y_test['anomaly_label']
                            + Y_test['domain_shift_op']
                            + Y_test['domain_shift_env'])

# Split the training windows into training and validation indices,
# stratified by the combined label. Only the indices are needed: the split
# itself depends solely on the number of samples, the stratification labels
# and the random state.
train_indices, valid_indices = train_test_split(
    np.arange(len(Y_train_raw)),
    stratify=Y_train_raw['combined_label'],
    test_size=PARAMS['valid_size'],
    random_state=PARAMS['seed']
)

# Select the training and validation data based on the indices
X_train = [sensor_data[train_indices] for sensor_data in X_train_raw]
X_valid = [sensor_data[valid_indices] for sensor_data in X_train_raw]
Y_train = Y_train_raw.iloc[train_indices].reset_index(drop=True)
Y_valid = Y_train_raw.iloc[valid_indices].reset_index(drop=True)

# Normalize the training, validation, and test datasets using the
# specified normalization method
X_train, X_valid, X_test = utilities.normalize_data(
    X_train, X_valid, X_test, PARAMS['normalisation'])

# Extract the number of channels (axis 1) and window length (axis 2) of
# each sensor's array
NUM_CHANNELS = {SENSORS[i]: x.shape[1] for i, x in enumerate(X_train)}
WINDOW_LENGTHS = {SENSORS[i]: x.shape[2] for i, x in enumerate(X_train)}


X_train_tensor = [torch.from_numpy(x) for x in X_train]
X_valid_tensor = [torch.from_numpy(x) for x in X_valid]
X_test_tensor = [torch.from_numpy(x) for x in X_test]

train_dataset = utilities.CustomDataset(X_train_tensor)
valid_dataset = utilities.CustomDataset(X_valid_tensor)
test_dataset = utilities.CustomDataset(X_test_tensor)

# Only the training loader is shuffled; validation/test order is fixed
train_data_loader = DataLoader(
    train_dataset, batch_size=PARAMS['batch_size'], shuffle=True)
valid_data_loader = DataLoader(
    valid_dataset, batch_size=PARAMS['batch_size'], shuffle=False)
test_data_loader = DataLoader(
    test_dataset, batch_size=PARAMS['batch_size'], shuffle=False)

In [None]:
# Per-sensor sync-head parameters, consumed later by the synchronization and
# projection blocks. `lambda_` is forwarded to sync_head_utils unchanged
# (NOTE(review): its meaning is defined in sync_head_utils - document there).
lambda_ = 0.8
sync_head_conv_parameters = sync_head_utils.initialize_parameters(
    SENSORS,
    WINDOW_LENGTHS,
    NUM_CHANNELS,
    lambda_=lambda_,
)

In [86]:
# --------------------------------------------------------------------------------
# A UNIFIED FC HEAD FOR (N, C, L_in) --> (N, C, L_out)
# --------------------------------------------------------------------------------
class create_fc_head(nn.Module):
    """
    A fully-connected (FC) head that operates channel by channel.

    Every input channel has its own independent stack of ``nn.Linear``
    layers, so no information is mixed across channels. This can be used for:
      - Synchronization: map from a sensor's raw length L_sensor to a common length L_common.
      - Projection: map from the common length L_common back to L_sensor.

    Shape:
        - Input:  (N, C, L_in)
        - Output: (N, C, L_out)
    """

    def __init__(self, input_size, output_size, num_channels, num_layers):
        """
        Args:
            input_size (int): L_in
            output_size (int): L_out
            num_channels (int): C, the number of independent per-channel stacks
            num_layers (int): number of FC layers per channel (must be >= 1)

        Raises:
            ValueError: if ``num_layers < 1``. An empty stack would be an
                identity map and would silently return L_in instead of L_out.
        """
        super(create_fc_head, self).__init__()
        if num_layers < 1:
            raise ValueError(f"num_layers must be >= 1, got {num_layers}")
        self.fc_stacks = nn.ModuleList()

        # Create channel-specific FC stacks
        for _ in range(num_channels):
            layers = []
            current_size = input_size
            for layer_idx in range(num_layers):
                layers.append(nn.Linear(current_size, output_size))

                # ReLU + BatchNorm between layers but not after the last one,
                # so the head ends in a plain linear projection.
                if layer_idx < num_layers - 1:
                    layers.append(nn.ReLU())
                    layers.append(nn.BatchNorm1d(output_size))

                current_size = output_size

            self.fc_stacks.append(nn.Sequential(*layers))

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): (N, C, L_in)

        Returns:
            torch.Tensor: (N, C, L_out)

        Raises:
            ValueError: if C differs from the number of per-channel stacks.
        """
        if x.size(1) != len(self.fc_stacks):
            raise ValueError(
                f"Expected {len(self.fc_stacks)} channels, got {x.size(1)}")
        # Each stack maps its (N, L_in) channel slice to (N, L_out); stacking
        # along dim 1 restores the channel axis => (N, C, L_out)
        return torch.stack(
            [stack(x[:, c, :]) for c, stack in enumerate(self.fc_stacks)],
            dim=1)


# --------------------------------------------------------------------------------
# UTILITY CLASSES / FUNCTIONS
# --------------------------------------------------------------------------------

class MaskOutOneChannel(nn.Module):
    """
    Builds a "leave-one-channel-out" view of a synchronized feature map.

    The input is treated as ``total_channels`` consecutive blocks of
    ``c_sync`` channels each. For each block in turn, the module drops that
    block and keeps all the others. The resulting views are concatenated
    along the channel axis.
    """

    def __init__(self, sensors, num_channels, c_sync):
        """
        Args:
            sensors: sensor names (stored on the module; not used to compute).
            num_channels (dict): sensor name -> number of channels.
            c_sync (int): width of each channel block.
        """
        super(MaskOutOneChannel, self).__init__()
        self.sensors = sensors
        self.num_channels_dict = num_channels
        self.c_sync = c_sync

        # total_channels = sum of sensor channels
        self.total_channels = sum(num_channels.values())
        # So total input channels = total_channels * c_sync
        self.total_c = self.total_channels * self.c_sync

        # Registered as a buffer so it follows .to(device) and is saved in state_dict
        self.register_buffer('final_indices', self._create_indices())

    def _create_indices(self):
        """
        Return the flat channel-index tensor used by ``forward``.

        Block ``ch`` contributes every index in ``[0, total_c)`` except
        ``[ch * c_sync, (ch + 1) * c_sync)``. The blocks are concatenated in order.
        """
        all_indices = []
        for ch in range(self.total_channels):
            remove_start = ch * self.c_sync
            remove_end = remove_start + self.c_sync
            all_indices.extend(range(0, remove_start))
            all_indices.extend(range(remove_end, self.total_c))
        return torch.tensor(all_indices, dtype=torch.long)

    def forward(self, x):
        """
        Args:
            x: (B, total_channels * c_sync, L)

        Returns:
            (B, total_channels * (total_channels - 1) * c_sync, L)

        Raises:
            ValueError: if x does not have total_channels * c_sync channels.
        """
        if x.size(1) != self.total_c:
            raise ValueError(
                f"Expected {self.total_c} input channels, got {x.size(1)}")
        return x.index_select(dim=1, index=self.final_indices)


# --------------------------------------------------------------------------------
# SYNCHRONIZATION BLOCK
# --------------------------------------------------------------------------------

class SynchronizationBlock(nn.Module):
    """
    Resamples/synchronizes each sensor's data to a common length L_common.

    ``sync_method`` selects the strategy:
      - 'sync_head_conv':  learned heads from ``sync_head_utils`` that are
        asked for ``c_sync * C_sensor`` output channels per sensor.
      - 'sync_head_fc':    learned per-channel FC heads (``create_fc_head``);
        one output channel per input channel.
      - 'resample_interp': linear interpolation (``F.interpolate``).
      - 'resample_fft':    Fourier-domain resampling on the real spectrum.
      - 'zeropad':         right zero-padding, or truncation if longer.

    Per-sensor outputs are concatenated along the channel axis, in
    ``sensors`` order. Only 'sync_head_conv' applies ``c_sync``. The other
    methods keep ``C_sensor`` channels per sensor, so downstream blocks that
    expect ``total_channels * c_sync`` channels only line up when ``c_sync == 1``.
    """

    def __init__(
        self,
        sensors,
        num_channels,
        c_sync,
        sync_head_conv_parameters,
        params,
        sync_method,
        window_lengths,
        fc_num_layers
    ):
        """
        Args:
            sensors: ordered sensor names; ``forward`` expects inputs in this order.
            num_channels (dict): sensor -> C_sensor.
            c_sync (int): channel multiplier used by 'sync_head_conv'.
            sync_head_conv_parameters (dict): sensor -> parameter dict. The
                ``'input_2'`` entry of the first sensor is used as L_common
                (NOTE(review): presumably identical for all sensors - confirm).
            params (dict): must contain 'device'. The non-learned resamplers
                move their results there.
            sync_method (str): one of the methods listed on the class.
            window_lengths (dict): sensor -> L_sensor (used by 'sync_head_fc').
            fc_num_layers (int): depth of the FC heads for 'sync_head_fc'.

        Raises:
            ValueError: for an unknown ``sync_method``.
        """
        super(SynchronizationBlock, self).__init__()
        self.default_device = params['device']

        self.sync_window_length = next(
            iter(sync_head_conv_parameters.values()))['input_2']

        self.total_channels = sum(num_channels.values())
        self.c_sync = c_sync
        self.sync_method = sync_method
        self.window_lengths = window_lengths

        if sync_method == 'sync_head_conv':
            self.sync_heads = nn.ModuleList([
                sync_head_utils.create_synchronization_head(
                    input_sensor_channels=num_channels[sensor],
                    output_sensor_channels=c_sync * num_channels[sensor],
                    groups=num_channels[sensor],
                    parameters=sync_head_conv_parameters[sensor],
                    type='input'
                )
                for sensor in sensors
            ])
            self._sync_fn = self._resample_sync_head_conv

        elif sync_method == 'sync_head_fc':
            self.sync_heads = nn.ModuleList([
                create_fc_head(
                    input_size=window_lengths[sensor],      # L_in (raw)
                    output_size=self.sync_window_length,    # L_out (common)
                    num_channels=num_channels[sensor],
                    num_layers=fc_num_layers
                )
                for sensor in sensors
            ])
            self._sync_fn = self._resample_sync_head_fc

        elif sync_method == 'resample_interp':
            self.sync_heads = None
            self._sync_fn = self._resample_interp

        elif sync_method == 'resample_fft':
            self.sync_heads = None
            self._sync_fn = self._resample_fft

        elif sync_method == 'zeropad':
            self.sync_heads = None
            self._sync_fn = self._resample_zeropad

        else:
            raise ValueError(f"Unknown sync_method: {sync_method}")

    def forward(self, input_data_list):
        """
        input_data_list: list of (B, C_sensor, L_sensor), in ``sensors`` order
        returns: (B, channels, L_common), where channels is
            C_total * c_sync for 'sync_head_conv' and C_total otherwise
        """
        return self._sync_fn(input_data_list)

    # --------------------------------------------------------------------------
    # Different synchronization methods
    # --------------------------------------------------------------------------

    def _resample_with_heads(self, input_data_list):
        """Apply each sensor's learned head and concatenate along channels."""
        return torch.cat([
            head(inp) for head, inp in zip(self.sync_heads, input_data_list)
        ], dim=1)

    # The conv- and FC-based methods share the same wiring
    _resample_sync_head_conv = _resample_with_heads
    _resample_sync_head_fc = _resample_with_heads

    def _resample_interp(self, input_data_list):
        # Computed on CPU and then moved to default_device
        # (NOTE(review): presumably a workaround for missing accelerator
        # support, e.g. on MPS - confirm)
        input_data_list = [inp.cpu() for inp in input_data_list]
        resampled = [
            F.interpolate(
                inp, size=self.sync_window_length, mode='linear'
            ).to(self.default_device)
            for inp in input_data_list
        ]
        return torch.cat(resampled, dim=1)

    def _fft_resample_single(self, input_data):
        """
        Fourier-resample ``input_data`` along its last axis to L_common samples.

        Works on the one-sided real spectrum. ``irfft(..., n=L_new)`` uses the
        first ``L_new // 2 + 1`` bins, truncating when downsampling and
        zero-padding when upsampling. Positive and negative frequencies
        therefore stay paired, and the result is a genuine real signal.
        The factor ``L_new / L`` compensates for the FFT normalisation, so
        amplitudes (e.g. a constant offset) are preserved.
        """
        L = input_data.size(-1)
        L_new = self.sync_window_length
        spectrum = torch.fft.rfft(input_data, dim=-1)
        return torch.fft.irfft(spectrum, n=L_new, dim=-1) * (L_new / L)

    def _resample_fft(self, input_data_list):
        # FFT computed on CPU, then moved to default_device (see _resample_interp)
        input_data_list = [inp.cpu() for inp in input_data_list]
        resampled = [
            self._fft_resample_single(inp).to(self.default_device)
            for inp in input_data_list
        ]
        return torch.cat(resampled, dim=1)

    def _resample_zeropad(self, input_data_list):
        """Right-pad with zeros (or truncate) each input to L_common."""
        padded_data = []
        for inp in input_data_list:
            L = inp.shape[-1]
            if L < self.sync_window_length:
                pad_size = self.sync_window_length - L
                inp_padded = F.pad(inp, (0, pad_size), "constant", 0)
            else:
                inp_padded = inp[..., :self.sync_window_length]
            padded_data.append(inp_padded)
        return torch.cat(padded_data, dim=1)


# --------------------------------------------------------------------------------
# FUSING BLOCK (PMCE)
# --------------------------------------------------------------------------------

class PMCE(nn.Module):
    """
    Leave-one-channel-out fusing block.

    The input ``x`` of shape (B, total_channels * c_sync, L) is first expanded by
    ``MaskOutOneChannel`` into ``total_channels`` views. Each view omits one
    ``c_sync``-wide channel block, giving
    (B, total_channels * (total_channels - 1) * c_sync, L) in total. Two
    stages of dilated convolutions with 1x1 residual shortcuts then map the
    views back to (B, total_channels * c_sync, L).

    Every convolution uses ``groups=total_channels``, so output block ``g``
    is computed only from the view that omits block ``g``. Each padding
    equals ``dilation * (kernel_size - 1) / 2``, which keeps L unchanged.
    """

    def __init__(self, total_channels, c_sync, c_fuse, kernel_size, sensors, num_channels):
        super(PMCE, self).__init__()
        self.total_channels = total_channels
        self.c_sync = c_sync

        # Width of the leave-one-out input and of the hidden layers
        self.total_input_fuse_channels = total_channels * (total_channels - 1) * c_sync
        self.total_middle_fuse_channels = self.total_input_fuse_channels * c_fuse
        width_in = self.total_input_fuse_channels
        width_mid = self.total_middle_fuse_channels
        width_out = total_channels * c_sync

        self.mask_module = MaskOutOneChannel(sensors, num_channels, c_sync)

        def conv_relu_bn(in_ch, out_ch, dilation):
            # Grouped dilated conv -> ReLU -> BatchNorm, length-preserving
            return [
                nn.Conv1d(
                    in_channels=in_ch,
                    out_channels=out_ch,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=(kernel_size - 1) * dilation // 2,
                    dilation=dilation,
                    groups=total_channels,
                ),
                nn.ReLU(),
                nn.BatchNorm1d(out_ch),
            ]

        # Stage 1: dilations 2 and 4; stage 2: dilations 8 and 16
        self.fusing_part1 = nn.Sequential(
            *conv_relu_bn(width_in, width_mid, dilation=2),
            *conv_relu_bn(width_mid, width_mid, dilation=4),
        )
        self.fusing_part2 = nn.Sequential(
            *conv_relu_bn(width_mid, width_mid, dilation=8),
            *conv_relu_bn(width_mid, width_out, dilation=16),
        )

        # 1x1 grouped shortcuts matching each stage's input/output widths
        self.residual_conv1 = nn.Conv1d(
            in_channels=width_in,
            out_channels=width_mid,
            kernel_size=1,
            stride=1,
            groups=total_channels,
        )
        self.residual_conv2 = nn.Conv1d(
            in_channels=width_mid,
            out_channels=width_out,
            kernel_size=1,
            stride=1,
            groups=total_channels,
        )

    def forward(self, x):
        """
        x: (B, C_total*c_sync, L)

        returns: (B, C_total*c_sync, L) fused output
        """
        views = self.mask_module(x)
        stage1 = self.fusing_part1(views) + self.residual_conv1(views)
        return self.fusing_part2(stage1) + self.residual_conv2(stage1)

# RMCE

class RMCE(nn.Module):
    """
    Random-masking fusing block.

    In training mode, each element of the synchronized input is kept with
    probability ``1 - mask_ratio`` and zeroed otherwise. Two dilated
    convolutional stages with 1x1 residual shortcuts then fuse the masked
    input. In eval mode no mask is applied, so validation and inference are
    deterministic.

    Every dilated convolution uses ``padding = dilation * (kernel_size - 1) / 2``,
    so the sequence length L is preserved end to end.
    """

    def __init__(self, total_channels, c_sync, c_fuse, kernel_size, sensors,
                 num_channels, mask_ratio=0.95):
        """
        Args:
            total_channels (int): number of sensor channels across all sensors.
            c_sync (int): input/output width is ``total_channels * c_sync``.
            c_fuse (int): hidden width is ``total_channels * c_fuse``.
            kernel_size (int): kernel size of every dilated convolution.
            sensors, num_channels: accepted for interface parity with PMCE; unused.
            mask_ratio (float): probability of zeroing each input element
                during training. The default 0.95 reproduces the previous
                hard-coded ``rand > 0.95`` keep rule. 0 disables masking.

        Raises:
            ValueError: if ``mask_ratio`` is outside [0, 1].
        """
        super(RMCE, self).__init__()
        if not 0.0 <= mask_ratio <= 1.0:
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")
        self.total_channels = total_channels
        self.c_sync = c_sync
        self.mask_ratio = mask_ratio

        padding = kernel_size - 1
        self.total_middle_fuse_channels = total_channels * c_fuse
        # First fuse part (dilations 2 and 4)
        self.fusing_part1 = nn.Sequential(
            nn.Conv1d(
                in_channels=self.total_channels * c_sync,
                out_channels=self.total_middle_fuse_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding,
                dilation=2
            ),
            nn.ReLU(),
            nn.BatchNorm1d(self.total_middle_fuse_channels),
            nn.Conv1d(
                in_channels=self.total_middle_fuse_channels,
                out_channels=self.total_middle_fuse_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding * 2,
                dilation=4
            ),
            nn.ReLU(),
            nn.BatchNorm1d(self.total_middle_fuse_channels),
        )

        # Second fuse part (dilations 8 and 16)
        self.fusing_part2 = nn.Sequential(
            nn.Conv1d(
                in_channels=self.total_middle_fuse_channels,
                out_channels=self.total_middle_fuse_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding * 4,
                dilation=8
            ),
            nn.ReLU(),
            nn.BatchNorm1d(self.total_middle_fuse_channels),
            nn.Conv1d(
                in_channels=self.total_middle_fuse_channels,
                out_channels=self.total_channels * c_sync,
                kernel_size=kernel_size,
                stride=1,
                padding=padding * 8,
                dilation=16
            ),
            nn.ReLU(),
            nn.BatchNorm1d(self.total_channels * c_sync),
        )

        # Residual (1x1) shortcuts matching each stage's input/output widths
        self.residual_conv1 = nn.Conv1d(
            in_channels=self.total_channels * c_sync,
            out_channels=self.total_middle_fuse_channels,
            kernel_size=1,
            stride=1
        )
        self.residual_conv2 = nn.Conv1d(
            in_channels=self.total_middle_fuse_channels,
            out_channels=self.total_channels * c_sync,
            kernel_size=1,
            stride=1
        )

    def forward(self, x):
        """
        x: (B, C_total*c_sync, L)

        returns: (B, C_total*c_sync, L) fused output
        """
        # Element-wise random mask over (B, C, L), used only while training so
        # that evaluation sees the full, unmodified input.
        if self.training and self.mask_ratio > 0:
            keep = torch.rand(x.shape, device=x.device) > self.mask_ratio
            x = x * keep

        out_part1 = self.fusing_part1(x)
        residual1 = self.residual_conv1(x)
        out_part1 = out_part1 + residual1

        out_part2 = self.fusing_part2(out_part1)
        residual2 = self.residual_conv2(out_part1)
        out_part2 = out_part2 + residual2

        return out_part2





# --------------------------------------------------------------------------------
# PROJECTION BLOCK
# --------------------------------------------------------------------------------

class ProjectionBlock(nn.Module):
    """
    Projects the fused representation (common length L_common)
    back to each sensor's original length L_sensor (or to any desired L_out).

    The fused map is split along channels into one contiguous slice of
    ``c_sync * C_sensor`` channels per sensor, in ``sensors`` order. Each
    slice goes through that sensor's projection head.
    Supported ``projection_method`` values: 'conv' or 'fc'.
    """

    def __init__(
        self,
        sensors,
        num_channels,
        c_sync,
        sync_head_conv_parameters,
        projection_method='conv',
        fc_num_layers=1,
        window_lengths=None
    ):
        """
        Args:
            sensors: ordered sensor names.
            num_channels (dict): sensor -> C_sensor.
            c_sync (int): channels per sensor channel in the fused map.
            sync_head_conv_parameters (dict): sensor -> parameter dict. 'conv'
                inverts it via sync_head_utils; 'fc' reads its ``'input_2'``
                entry as L_common.
            projection_method (str): 'conv' or 'fc'.
            fc_num_layers (int): depth of the FC heads for 'fc'.
            window_lengths (dict): sensor -> L_sensor; required for 'fc'.

        Raises:
            ValueError: for an unknown ``projection_method``; for 'fc' without
                ``window_lengths``; or for 'fc' with ``c_sync != 1``. Each FC
                head has one stack per sensor channel, so it cannot consume
                the ``c_sync * C_sensor`` channels of its slice.
        """
        super(ProjectionBlock, self).__init__()
        self.sensors = sensors
        self.num_channels = num_channels
        self.c_sync = c_sync
        self.total_channels = sum(num_channels.values())

        # For slicing out each sensor's portion from the fused feature map
        self.proj_slices = self._channel_slices(sensors, num_channels, c_sync)

        self.proj_heads = nn.ModuleList()
        if projection_method == 'conv':
            # Use existing sync_head_utils in "output" mode (conv-based)
            for sensor in sensors:
                out_params = sync_head_utils.invert_synchronization_head_parameters(
                    sync_head_conv_parameters[sensor]
                )
                proj_head = sync_head_utils.create_synchronization_head(
                    input_sensor_channels=c_sync * num_channels[sensor],
                    output_sensor_channels=num_channels[sensor],
                    groups=num_channels[sensor],
                    parameters=out_params,
                    type='output'
                )
                self.proj_heads.append(proj_head)
            self._proj_fn = self._proj_fn_conv

        elif projection_method == 'fc':
            if window_lengths is None:
                raise ValueError("For 'fc' projection, you need `window_lengths`.")
            if c_sync != 1:
                raise ValueError(
                    f"'fc' projection requires c_sync == 1, got {c_sync}: the "
                    "per-channel FC heads cannot reduce c_sync * C_sensor "
                    "channels back to C_sensor.")
            for sensor in sensors:
                L_common = sync_head_conv_parameters[sensor]['input_2']
                L_sensor = window_lengths[sensor]
                proj_head = create_fc_head(
                    input_size=L_common,
                    output_size=L_sensor,
                    num_channels=num_channels[sensor],
                    num_layers=fc_num_layers
                )
                self.proj_heads.append(proj_head)
            self._proj_fn = self._proj_fn_fc

        else:
            raise ValueError(f"Unknown projection_method: {projection_method}")

    @staticmethod
    def _channel_slices(sensors, num_channels, c_sync):
        """Return one contiguous channel slice per sensor (width c_sync * C_sensor)."""
        bounds = list(accumulate(
            (num_channels[sensor] * c_sync for sensor in sensors), initial=0))
        return [slice(lo, hi) for lo, hi in zip(bounds, bounds[1:])]

    def forward(self, fused_output):
        """
        fused_output: (B, C_total*c_sync, L_common)
        returns: list with one projected tensor per sensor, in ``sensors`` order
        """
        return self._proj_fn(fused_output)

    def _project_with_heads(self, fused_output):
        """Route each sensor's channel slice through its projection head."""
        return [
            proj_head(fused_output[:, sl, :])
            for proj_head, sl in zip(self.proj_heads, self.proj_slices)
        ]

    # Conv and FC projection share the same slicing/routing
    _proj_fn_conv = _project_with_heads
    _proj_fn_fc = _project_with_heads


# --------------------------------------------------------------------------------
# MAIN MODEL
# --------------------------------------------------------------------------------

class SynchronMaskEstimator(nn.Module):
    """
    A neural network module that:
      1) Synchronizes input sensor data to a common window length (L_common)
         via ``SynchronizationBlock``.
      2) Fuses the synchronized channels with a fusing block: ``RMCE``
         (random input masking, the default) or ``PMCE``
         (leave-one-channel-out views).
      3) Desynchronizes the fused representation back to each sensor's
         space (L_sensor) via ``ProjectionBlock``.
    """

    def __init__(
        self,
        sensors,
        num_channels,
        window_lengths,
        c_sync,
        c_fuse,
        kernel_size,
        params,
        sync_method,
        sync_head_conv_parameters,
        projection_method='conv',
        fc_num_layers=1,
        fusion_method='rmce'
    ):
        """
        Args:
            sensors: ordered sensor names; inputs/outputs follow this order.
            num_channels (dict): sensor -> C_sensor.
            window_lengths (dict): sensor -> L_sensor.
            c_sync (int): channel multiplier produced by synchronization.
            c_fuse (int): hidden width multiplier of the fusing block.
            kernel_size (int): kernel size of the fusing convolutions.
            params (dict): run configuration; must contain 'device'.
            sync_method (str): see ``SynchronizationBlock``.
            sync_head_conv_parameters (dict): per-sensor sync-head parameters.
            projection_method (str): 'conv' or 'fc' (see ``ProjectionBlock``).
            fc_num_layers (int): depth of any FC heads.
            fusion_method (str): 'rmce' (default) or 'pmce'.

        Raises:
            ValueError: for an unknown ``fusion_method``.
        """
        super(SynchronMaskEstimator, self).__init__()
        fusing_classes = {'rmce': RMCE, 'pmce': PMCE}
        if fusion_method not in fusing_classes:
            raise ValueError(f"Unknown fusion_method: {fusion_method}")

        self.sensors = sensors
        self.num_channels_dict = num_channels
        self.c_sync = c_sync
        self.c_fuse = c_fuse
        self.kernel_size = kernel_size
        self.params = params
        self.default_device = params['device']
        self.total_channels = sum(num_channels.values())
        self.fusion_method = fusion_method

        # 1) Synchronization block
        self.synchronizer = SynchronizationBlock(
            sensors=sensors,
            num_channels=num_channels,
            c_sync=c_sync,
            sync_head_conv_parameters=sync_head_conv_parameters,
            params=params,
            sync_method=sync_method,
            window_lengths=window_lengths,
            fc_num_layers=fc_num_layers
        )

        # 2) Fusing block (both classes share the same constructor signature)
        self.fusing_block = fusing_classes[fusion_method](
            total_channels=self.total_channels,
            c_sync=self.c_sync,
            c_fuse=self.c_fuse,
            kernel_size=self.kernel_size,
            sensors=self.sensors,
            num_channels=self.num_channels_dict
        )

        # 3) Projection block
        self.projection_block = ProjectionBlock(
            sensors=sensors,
            num_channels=num_channels,
            c_sync=c_sync,
            sync_head_conv_parameters=sync_head_conv_parameters,
            projection_method=projection_method,
            fc_num_layers=fc_num_layers,
            window_lengths=window_lengths
        )

    def forward(self, input_data_list):
        """
        input_data_list: list of (B, C_sensor, L_sensor)
        returns: list with one reconstructed tensor per sensor, in ``sensors`` order
        """
        # Step 1: Synchronize & concat => (B, C_total*c_sync, L_common)
        synced_data = self.synchronizer(input_data_list)

        # Step 2: Fuse => (B, C_total*c_sync, L_common)
        fused_output = self.fusing_block(synced_data)

        # Step 3: Project back => one tensor per sensor
        return self.projection_block(fused_output)


In [None]:
C_sync = 1
C_fuse = 1
kernel_size = 3

model = SynchronMaskEstimator(
    sensors=SENSORS,
    num_channels=NUM_CHANNELS,
    window_lengths=WINDOW_LENGTHS,
    c_sync=C_sync,
    c_fuse=C_fuse,
    kernel_size=kernel_size,
    params=PARAMS,
    # one of: 'sync_head_conv', 'sync_head_fc', 'resample_interp', 'resample_fft', 'zeropad'
    sync_method='sync_head_conv',
    sync_head_conv_parameters=sync_head_conv_parameters,
    projection_method='conv',  # 'fc' or 'conv'
    fc_num_layers=1,
).to(PARAMS['device'])

optimizer = torch.optim.Adam(model.parameters(), lr=PARAMS['lr'])

# Parameter counts: all parameters, then only those that receive gradients
all_params = list(model.parameters())
total_params = sum(p.numel() for p in all_params)
print(f"Total number of parameters: {total_params}")
trainable_params = sum(p.numel() for p in all_params if p.requires_grad)
print(f"Total number of trainable parameters: {trainable_params}")
model

In [None]:
def reconstruction_loss(x_true, x_pred):
    """Sum, over every channel of every sensor, of that channel's MSE
    (averaged over batch and time)."""
    per_channel = torch.cat([((x_s - x_s_out) ** 2).mean(0).mean(1)
                             for x_s, x_s_out in zip(x_true, x_pred)])
    return per_channel.sum()


best_loss = float('inf')
non_improving_count = 0
best_model = None
train_losses_epoch = []
valid_losses_epoch = []

for epoch in range(PARAMS['num_epochs']):
    print('----------- Epoch', epoch + 1, '-----------')

    # Training phase
    model.train()
    total_train_loss = 0.0
    loader = tqdm.tqdm(train_data_loader)
    for batch_idx, x_batch in enumerate(loader):
        optimizer.zero_grad()
        x_batch = [x.to(PARAMS['device']) for x in x_batch]
        loss = reconstruction_loss(x_batch, model(x_batch))
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()
        avg_train_loss = total_train_loss / (batch_idx + 1)
        loader.set_postfix({'Train_loss': avg_train_loss})
    train_losses_epoch.append(avg_train_loss)

    # Validation phase
    model.eval()
    total_valid_loss = 0.0
    with torch.no_grad():
        loader = tqdm.tqdm(valid_data_loader)
        for batch_idx, x_batch in enumerate(loader):
            x_batch = [x.to(PARAMS['device']) for x in x_batch]
            loss = reconstruction_loss(x_batch, model(x_batch))
            total_valid_loss += loss.item()
            avg_val_loss = total_valid_loss / (batch_idx + 1)
            loader.set_postfix({'Valid_loss': avg_val_loss})
    valid_losses_epoch.append(avg_val_loss)

    if avg_val_loss < best_loss:
        best_loss = avg_val_loss
        # state_dict() returns references to the live tensors; clone them so
        # later optimizer steps cannot overwrite the best snapshot.
        best_model = {k: v.detach().clone() for k, v in model.state_dict().items()}
        non_improving_count = 0
    else:
        non_improving_count += 1

    # Stop once validation loss has not improved for more than `patience` epochs
    if non_improving_count > PARAMS['patience']:
        print("Stopping early due to no improvement in validation loss.")
        break

# Continue with the best-validation weights rather than the last epoch's
if best_model is not None:
    model.load_state_dict(best_model)

In [89]:
# Reconstruct one validation batch twice: once as-is, and once with sensor
# `sensor_zero_out` replaced by Gaussian noise. Eval mode and no_grad,
# since this is inference only.
model.eval()
x_batch = next(iter(valid_data_loader))
x_batch_t = [x.to(PARAMS['device']) for x in x_batch]
with torch.no_grad():
    x_batch_output = model(x_batch_t)

sensor_zero_out = 0
x_batch_t_zeros = [x.clone() for x in x_batch_t]
# Replace the chosen sensor with noise (use torch.zeros_like instead to zero it out)
x_batch_t_zeros[sensor_zero_out] = torch.randn_like(x_batch_t_zeros[sensor_zero_out])

with torch.no_grad():
    x_batch_output_zeros = model(x_batch_t_zeros)

In [None]:
# compute MSE of normal and 1 masked sensor version
sensor_losses = torch.cat([((x_sensor - x_sensor_out)**2).mean(0).mean(1)
                          for x_sensor, x_sensor_out in zip(x_batch_t, x_batch_output)]).cpu().detach().numpy()
sensor_losses_zeros = torch.cat([((x_sensor - x_sensor_out)**2).mean(0).mean(1)
                                for x_sensor, x_sensor_out in zip(x_batch_t, x_batch_output_zeros)]).cpu().detach().numpy()
# print the MSE of the normal and masked sensor version
print(sensor_losses)
print(sensor_losses_zeros)

In [None]:

# Plot input vs. both reconstructions for one random sample of the batch
n_in_batch = x_batch_t[0].shape[0]  # may be smaller than PARAMS['batch_size']
sample = np.random.randint(0, n_in_batch)
total_plot = sum(NUM_CHANNELS[sensor] for sensor in SENSORS)
plt.close('all')
# squeeze=False keeps `axs` 2-D even when there is a single channel to plot
fig, axs = plt.subplots(total_plot, 1, figsize=(
    WIDTH, HEIGHT * total_plot * 0.7), sharex=True, squeeze=False)
axs = axs[:, 0]

i = 0
overall_MSE = 0  # sum of the per-channel MSEs of this sample
for sensor_idx, sensor in enumerate(SENSORS):
    t = np.linspace(0, 1, WINDOW_LENGTHS[sensor])
    for channel_idx in range(NUM_CHANNELS[sensor]):
        signal_in = x_batch_t[sensor_idx][sample, channel_idx].detach().cpu().numpy()
        output = x_batch_output[sensor_idx][sample, channel_idx].detach().cpu().numpy()
        output_zeros = x_batch_output_zeros[sensor_idx][sample, channel_idx].detach().cpu().numpy()
        loss_output = ((signal_in - output) ** 2).mean(0)
        loss_output_zeros = ((signal_in - output_zeros) ** 2).mean(0)

        # Two adjacent f-strings: a line break inside a replacement field
        # only parses on Python >= 3.12
        print(f"Sensor:{sensor}, Channel:{channel_idx}, Loss on output: "
              f"{loss_output}, Loss on output_zeros: {loss_output_zeros}")

        axs[i].plot(t, signal_in, label='Input')
        axs[i].plot(t, output, label='Output')
        axs[i].plot(t, output_zeros, label='Output (Random)')
        axs[i].set_title(f"{sensor} - channel {channel_idx}")
        i += 1
        overall_MSE += loss_output
print(f"Overall MSE (sum over channels): {overall_MSE}")
axs[-1].set_xlabel('Time (fraction of window)')
axs[-1].legend()
plt.tight_layout()