In [None]:
import torch
import torch.nn as nn
import numpy as np
from scipy.signal import find_peaks

class SignalGating(nn.Module):
    """Segment a multi-channel IMU stream into fixed-length windows around accel_z peaks.

    Peaks are detected on row 2 (accel_z) with ``scipy.signal.find_peaks``. For every
    peak whose window lies fully inside the signal, the ``(channels, samples_per_window)``
    slice is flattened (channel-major) and collected.

    Args:
        threshold: Minimum accel_z height for a sample to be reported as a peak.
        window_size: Nominal window length in milliseconds. Stored for reference only;
            the window length actually used is ``samples_per_window``.
        sampling_rate: Nominal sensor sampling rate in Hz. Stored for reference only
            (120 ms at 416 Hz is ~50 samples, which matches the default window length).
        samples_per_window: Number of samples taken from each channel per window.
    """

    def __init__(self, threshold=0.5, window_size=120, sampling_rate=416, samples_per_window=50):
        super().__init__()
        self.threshold = threshold  # Minimum peak height on accel_z
        self.window_size = window_size  # Window size in milliseconds (informational)
        self.sampling_rate = sampling_rate  # Sensor sampling rate in Hz (informational)
        self.samples_per_window = samples_per_window  # Samples per channel per window

    def forward(self, imu_data):
        """Extract fixed-length windows centred on accel_z peaks.

        Args:
            imu_data: Array-like (NumPy array, tensor, or nested list) of shape
                (channels, N); row 2 is treated as accel_z. The intended layout is 6
                channels: [accel_x, accel_y, accel_z, gyro_x, gyro_y, gyro_z].

        Returns:
            list[torch.Tensor]: one float32 tensor of shape
            (channels * samples_per_window,) per accepted peak, i.e. (300,) for the
            default 6 channels x 50 samples. Empty when no peak qualifies.

        Raises:
            ValueError: if imu_data is not 2-D or has fewer than 3 rows.
        """
        # Normalise every input kind (ndarray, float64 tensor, list) to float32 so the
        # windows always match float32 model weights.
        imu_data = torch.as_tensor(imu_data, dtype=torch.float32)
        if imu_data.dim() != 2 or imu_data.shape[0] < 3:
            raise ValueError(
                f"imu_data must be 2-D (channels, N) with at least 3 rows, got {tuple(imu_data.shape)}"
            )

        # Use the Z-axis accelerometer signal for peak detection
        accel_z = imu_data[2]
        peaks, _ = find_peaks(accel_z.detach().cpu().numpy(), height=self.threshold)

        half = self.samples_per_window // 2
        n_samples = accel_z.shape[0]
        windows = []
        for peak in peaks:
            start = int(peak) - half
            # Derive end from start so odd window lengths still give exactly
            # samples_per_window samples.
            end = start + self.samples_per_window
            # Windows clipped by either edge of the signal would be too short; skip them.
            if start < 0 or end > n_samples:
                continue
            windows.append(imu_data[:, start:end].flatten())

        return windows

# Demo: synthetic 6-channel signal with two planted taps on accel_z
if __name__ == "__main__":
    np.random.seed(0)
    test_signal = np.random.normal(0, 0.1, (6, 1000))
    # Plant tap-like spikes (sample index -> amplitude) on accel_z (row 2)
    for tap_index, tap_height in ((200, 1.0), (600, 1.2)):
        test_signal[2, tap_index] = tap_height

    gating = SignalGating(threshold=0.5)
    detected_windows = gating(torch.tensor(test_signal))

    print(f"Number of detected taps: {len(detected_windows)}")
    for n, win in enumerate(detected_windows, start=1):
        print(f"Window {n} shape: {win.shape}")


Number of detected taps: 2
Window 1 shape: torch.Size([300])
Window 2 shape: torch.Size([300])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchsummary import summary
import matplotlib.pyplot as plt
from torchviz import make_dot

class TapNet(nn.Module):
    """Multi-task 1-D CNN over a single-channel signal plus device features.

    A shared Conv1d/MaxPool1d trunk and one linear layer produce a 64-d embedding.
    The embedding is concatenated with ``device_info`` and fed to five heads:

    1. tap event detection (2 logits)
    2. tap direction (6 logits)
    3. finger part (2 logits)
    4. location classification (35 logits)
    5. location regression (2 values, X and Y)

    Args:
        input_length: Length of the 1-channel input signal (default 300).
        device_info_dim: Number of device-info features appended to the shared
            embedding (default 5).
    """

    def __init__(self, input_length=300, device_info_dim=5):
        super().__init__()

        # Shared convolutional trunk. There is no padding, and MaxPool1d(2) floors odd
        # lengths. For input_length=300:
        #   300 -conv1-> 298 -pool-> 149 -conv2-> 147 -pool-> 73
        #       -conv3-> 71 -pool-> 35 -conv4-> 33   => 64 x 33 = 2112 features
        self.conv1 = nn.Conv1d(1, 32, kernel_size=3)
        self.pool1 = nn.MaxPool1d(2)

        self.conv2 = nn.Conv1d(32, 64, kernel_size=3)
        self.pool2 = nn.MaxPool1d(2)

        self.conv3 = nn.Conv1d(64, 64, kernel_size=3)
        self.pool3 = nn.MaxPool1d(2)

        self.conv4 = nn.Conv1d(64, 64, kernel_size=3)

        self.flatten = nn.Flatten()

        # Derive the flattened size from a dummy pass instead of hard-coding it. The
        # previous constant (2176) did not match the trunk's real output (2112 for
        # length 300), so forward() crashed.
        with torch.no_grad():
            n_flat = self._conv_features(torch.zeros(1, 1, input_length)).shape[1]
        self.fc_shared = nn.Linear(n_flat, 64)

        # Task-specific layers: device info is appended to the 64 shared features
        self.device_info_dim = device_info_dim
        self.fc_combined_dim = 64 + self.device_info_dim

        # Branch 1: Tap Event Detection (Binary)
        self.fc_tap_1 = nn.Linear(self.fc_combined_dim, 8)
        self.fc_tap_2 = nn.Linear(8, 4)
        self.fc_tap_out = nn.Linear(4, 2)

        # Branch 2: Tap Direction (6 Classes)
        self.fc_dir_1 = nn.Linear(self.fc_combined_dim, 8)
        self.fc_dir_2 = nn.Linear(8, 7)
        self.fc_dir_out = nn.Linear(7, 6)

        # Branch 3: Finger Part (2 Classes)
        self.fc_finger_1 = nn.Linear(self.fc_combined_dim, 32)
        self.fc_finger_2 = nn.Linear(32, 8)
        self.fc_finger_out = nn.Linear(8, 2)

        # Branch 4: Location Classification (35 Classes)
        self.fc_loc_1 = nn.Linear(self.fc_combined_dim, 32)
        self.fc_loc_out = nn.Linear(32, 35)

        # Branch 5: Location Regression (2 values for X and Y)
        self.fc_reg_1 = nn.Linear(self.fc_combined_dim, 8)
        self.fc_reg_2 = nn.Linear(8, 4)
        self.fc_reg_out = nn.Linear(4, 2)

    def _conv_features(self, x):
        """Run the shared conv/pool trunk and flatten to (batch, channels * length)."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        x = F.relu(self.conv4(x))
        return self.flatten(x)

    def forward(self, x, device_info):
        """Compute all five task outputs.

        Args:
            x: Float tensor of shape (batch, 1, input_length).
            device_info: Float tensor of shape (batch, device_info_dim).

        Returns:
            tuple (tap, direction, finger, location, regression) with shapes
            (batch, 2), (batch, 6), (batch, 2), (batch, 35), (batch, 2).
            The classification heads return raw logits (no softmax).
        """
        # Shared trunk + shared fully connected layer
        x = F.relu(self.fc_shared(self._conv_features(x)))

        # Append device information
        x_combined = torch.cat([x, device_info], dim=1)

        # Branch 1: Tap Event Detection (Binary)
        tap = F.relu(self.fc_tap_1(x_combined))
        tap = F.relu(self.fc_tap_2(tap))
        tap = self.fc_tap_out(tap)

        # Branch 2: Tap Direction (6 Classes)
        direction = F.relu(self.fc_dir_1(x_combined))
        direction = F.relu(self.fc_dir_2(direction))
        direction = self.fc_dir_out(direction)

        # Branch 3: Finger Part (2 Classes)
        finger = F.relu(self.fc_finger_1(x_combined))
        finger = F.relu(self.fc_finger_2(finger))
        finger = self.fc_finger_out(finger)

        # Branch 4: Location Classification (35 Classes)
        location = F.relu(self.fc_loc_1(x_combined))
        location = self.fc_loc_out(location)

        # Branch 5: Location Regression (2 values for X and Y)
        regression = F.relu(self.fc_reg_1(x_combined))
        regression = F.relu(self.fc_reg_2(regression))
        regression = self.fc_reg_out(regression)

        return tap, direction, finger, location, regression

# Model instantiation (seeded so weights and example inputs are reproducible on re-run)
torch.manual_seed(0)
model = TapNet()

# Example inputs: batch of 1 with one 300-sample channel, plus 5 device-info features
input_signal = torch.randn(1, 1, 300)
device_info = torch.randn(1, 5)

# Forward pass (autograd graph is kept so torchviz can trace it below)
outputs = model(input_signal, device_info)

# Visualization of the model architecture.
# NOTE: dot.render needs the Graphviz `dot` executable available on PATH.
dot = make_dot(outputs, params=dict(model.named_parameters()))
dot.render("TapNet", format="png")

# Print model summary (input sizes exclude the batch dimension)
summary(model, [(1, 300), (5,)])

# Output shapes for each task head
for i, output in enumerate(outputs):
    print(f"Output {i+1} shape: {output.shape}")
