In [None]:
import torch
import torch.nn as nn
import numpy as np
from scipy.signal import find_peaks

class SignalGating(nn.Module):
    """Gate a continuous IMU stream into fixed-length windows around taps.

    Peaks are detected on the accelerometer Z channel (row 2 of the input) with
    ``scipy.signal.find_peaks``. For each peak, ``samples_per_window`` samples
    of every channel are cut out and flattened. The module has no learnable
    parameters.
    """

    def __init__(self, threshold=0.5, window_size=120, sampling_rate=416,
                 samples_per_window=50):
        """Configure the gate.

        Args:
            threshold: Minimum accel-Z value for a peak (the ``height``
                argument of ``find_peaks``).
            window_size: Nominal window length in milliseconds. Stored for
                reference only; the window length actually used is
                ``samples_per_window``.
            sampling_rate: Nominal sensor sampling rate in Hz. Stored for
                reference only.
            samples_per_window: Samples taken per channel around each peak.
                The default of 50 is ~120 ms at 416 Hz and gives
                6 * 50 = 300 values per window for 6-channel input.

        Raises:
            ValueError: if ``samples_per_window`` is less than 1.
        """
        super().__init__()
        if samples_per_window < 1:
            raise ValueError(
                f"samples_per_window must be >= 1, got {samples_per_window}"
            )
        self.threshold = threshold  # Threshold for peak detection
        self.window_size = window_size  # Window size in milliseconds (informational)
        self.sampling_rate = sampling_rate  # Sensor sampling rate (informational)
        self.samples_per_window = samples_per_window  # Samples per channel per window

    def forward(self, imu_data):
        """Extract one flattened window per detected accel-Z peak.

        Args:
            imu_data: Tensor or ndarray of shape (C, N) whose rows are
                [accel_x, accel_y, accel_z, gyro_x, gyro_y, gyro_z]; row 2 is
                used for peak detection. NumPy input is converted to float32.
                Tensor input keeps its dtype and device.

        Returns:
            A list (possibly empty) of 1-D tensors. Each has length
            ``C * samples_per_window`` (300 for 6 channels x 50 samples) and
            is flattened channel by channel. Peaks too close to either end of
            the signal for a full window are skipped.

        Raises:
            ValueError: if ``imu_data`` is not two-dimensional.
        """
        if isinstance(imu_data, np.ndarray):
            imu_data = torch.tensor(imu_data, dtype=torch.float32)
        if imu_data.dim() != 2:
            raise ValueError(
                f"imu_data must have shape (channels, samples), got {tuple(imu_data.shape)}"
            )

        accel_z = imu_data[2]  # Z-axis of accelerometer drives the gating
        peaks, _ = find_peaks(accel_z.detach().cpu().numpy(), height=self.threshold)

        num_samples = accel_z.shape[0]
        length = self.samples_per_window
        half = length // 2
        windows = []
        for peak in peaks:
            # end is derived from start so odd window lengths are exact too.
            start = int(peak) - half
            end = start + length
            if start < 0 or end > num_samples:
                continue  # not enough signal on one side for a full window
            windows.append(imu_data[:, start:end].flatten())
        return windows

# Test sample
if __name__ == "__main__":
    # Demo: six channels of Gaussian noise with two synthetic taps injected
    # into the accelerometer Z channel (row 2).
    np.random.seed(0)
    demo_signal = np.random.normal(0, 0.1, (6, 1000))
    for tap_index, tap_height in ((200, 1.0), (600, 1.2)):
        demo_signal[2, tap_index] = tap_height  # simulated tap peak on accel_z

    detector = SignalGating(threshold=0.5)
    taps = detector(torch.tensor(demo_signal))

    print(f"Number of detected taps: {len(taps)}")
    for number, tap_window in enumerate(taps, start=1):
        print(f"Window {number} shape: {tap_window.shape}")


Number of detected taps: 2
Window 1 shape: torch.Size([300])
Window 2 shape: torch.Size([300])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchviz import make_dot

class TapNet(nn.Module):
    """Multi-task 1-D CNN over a flattened IMU window plus a device-info vector.

    A shared convolutional trunk encodes the signal. Its 64-d embedding is
    concatenated with ``device_info`` and fed to five independent MLP heads.
    Every head returns raw, un-normalised outputs:

    1. tap event detection      - 2 logits
    2. tap direction            - 6 logits
    3. finger part              - 2 logits
    4. location classification  - 35 logits
    5. location regression      - 2 values (X, Y)
    """

    def __init__(self, device_info_dim=5, input_length=300):
        """Build the network.

        Args:
            device_info_dim: Length of the per-sample device-info vector.
            input_length: Length of each flattened signal window. The default
                of 300 matches 6 channels x 50 samples. The input size of the
                shared fully connected layer is derived from it.
        """
        super().__init__()
        self.device_info_dim = device_info_dim  # Device vector size
        self.input_length = input_length

        # Shared convolutional trunk; (channels, length) shown for input_length=300.
        self.conv1 = nn.Conv1d(1, 32, kernel_size=3)              # (32, 298)
        self.pool1 = nn.MaxPool1d(kernel_size=2)                  # (32, 149)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3)             # (64, 147)
        self.pool2 = nn.MaxPool1d(kernel_size=2)                  # (64, 73)
        self.conv3 = nn.Conv1d(64, 64, kernel_size=3)             # (64, 71)
        self.pool3 = nn.MaxPool1d(kernel_size=2)                  # (64, 35)
        self.conv4 = nn.Conv1d(64, 64, kernel_size=3, padding=1)  # (64, 35)
        self.flatten = nn.Flatten()                               # 64 * 35 = 2240

        # Measure the flattened trunk size with a dummy pass rather than
        # hard-coding it, so it always matches input_length (2240 for 300).
        with torch.no_grad():
            flat_dim = self._encode(torch.zeros(1, input_length)).shape[1]
        self.fc_shared = nn.Linear(flat_dim, 64)  # Fully connected layer before branching

        self.concat_dim = 64 + device_info_dim

        # Branch 1: Tap Event Detection (Binary)
        self.branch1_fc1 = nn.Linear(self.concat_dim, 8)
        self.branch1_fc2 = nn.Linear(8, 4)
        self.branch1_fc3 = nn.Linear(4, 2)

        # Branch 2: Tap Direction (6 Classes)
        self.branch2_fc1 = nn.Linear(self.concat_dim, 8)
        self.branch2_fc2 = nn.Linear(8, 7)
        self.branch2_fc3 = nn.Linear(7, 6)

        # Branch 3: Finger Part (2 Classes)
        self.branch3_fc1 = nn.Linear(self.concat_dim, 32)
        self.branch3_fc2 = nn.Linear(32, 8)
        self.branch3_fc3 = nn.Linear(8, 2)

        # Branch 4: Location Classification (35 Classes)
        self.branch4_fc1 = nn.Linear(self.concat_dim, 32)
        self.branch4_fc2 = nn.Linear(32, 35)

        # Branch 5: Location Regression (2 values for X and Y)
        self.branch5_fc1 = nn.Linear(self.concat_dim, 8)
        self.branch5_fc2 = nn.Linear(8, 4)
        self.branch5_fc3 = nn.Linear(4, 2)

    def _encode(self, x):
        """Run the conv trunk on a (batch, length) signal and flatten the result."""
        # Add the single input channel and match the weights' dtype, so that
        # float64 windows (e.g. from a float64 tensor) are accepted.
        x = x.unsqueeze(1).to(self.conv1.weight.dtype)
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        x = F.relu(self.conv4(x))
        return self.flatten(x)

    @staticmethod
    def _head(x, layers):
        """Apply ``layers`` in order with ReLU between them (none after the last)."""
        *hidden, last = layers
        for layer in hidden:
            x = F.relu(layer(x))
        return last(x)

    def forward(self, x, device_info):
        """Compute all five head outputs.

        Args:
            x: Signal tensor of shape (batch, input_length).
            device_info: Tensor of shape (batch, device_info_dim); cast to float32.

        Returns:
            Tuple ``(b1, b2, b3, b4, b5)`` with shapes (batch, 2), (batch, 6),
            (batch, 2), (batch, 35) and (batch, 2).
        """
        embedding = F.relu(self.fc_shared(self._encode(x)))
        combined = torch.cat((embedding, device_info.float()), dim=1)

        b1 = self._head(combined, (self.branch1_fc1, self.branch1_fc2, self.branch1_fc3))
        b2 = self._head(combined, (self.branch2_fc1, self.branch2_fc2, self.branch2_fc3))
        b3 = self._head(combined, (self.branch3_fc1, self.branch3_fc2, self.branch3_fc3))
        b4 = self._head(combined, (self.branch4_fc1, self.branch4_fc2))
        b5 = self._head(combined, (self.branch5_fc1, self.branch5_fc2, self.branch5_fc3))
        return b1, b2, b3, b4, b5

# Training and Testing Functions
def train_model(model, dataloader, optimizer, criterion, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(signals, device_info)`` that returns a
            sequence of per-head outputs.
        dataloader: Iterable of ``(signals, device_info, labels)`` batches,
            where ``labels`` is a sequence of per-head targets.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Either one loss callable applied to every head, or a
            list/tuple/``nn.ModuleList`` with one loss per head (e.g.
            cross-entropy for classifiers, MSE for a regression head).
        device: Device the batch tensors are moved to.

    Returns:
        Per-batch total loss (summed over heads), averaged over batches.
        Returns ``nan`` if the dataloader yields no batches. As with ``zip``,
        heads beyond the shortest of outputs/labels/criteria are ignored.
    """
    per_head = isinstance(criterion, (list, tuple, nn.ModuleList))
    model.train()
    total_loss = 0.0
    num_batches = 0
    for signals, device_info, labels in dataloader:
        signals, device_info = signals.to(device), device_info.to(device)
        labels = [label.to(device) for label in labels]
        optimizer.zero_grad()
        outputs = model(signals, device_info)
        criteria = criterion if per_head else [criterion] * len(outputs)
        loss = sum(
            loss_fn(out, label)
            for loss_fn, out, label in zip(criteria, outputs, labels)
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # Count batches instead of len(dataloader): works for unsized iterables
    # and avoids ZeroDivisionError on an empty loader.
    return total_loss / num_batches if num_batches else float("nan")

def test_model(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for signals, device_info, labels in dataloader:
            signals, device_info = signals.to(device), device_info.to(device)
            labels = [label.to(device) for label in labels]
            outputs = model(signals, device_info)
            loss = sum(criterion(out, label) for out, label in zip(outputs, labels))
            total_loss += loss.item()
    return total_loss / len(dataloader)

# Example Training Loop
if __name__ == "__main__":
    # Fall back to CPU so the example also runs on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = TapNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    classification_loss = nn.CrossEntropyLoss()
    regression_loss = nn.MSELoss()

    def criterion(output, target):
        """Per-head loss: MSE for floating-point targets, cross-entropy otherwise.

        Heads 1-4 are classifiers; head 5 regresses (X, Y). This assumes the
        classification labels are integer class indices and the regression
        labels are floats -- TODO confirm against the dataset. Without the
        dispatch, the regression head would be scored with cross-entropy.
        """
        if target.is_floating_point():
            return regression_loss(output, target)
        return classification_loss(output, target)

    # NOTE(review): train_loader / test_loader are not defined in this cell;
    # they must be created beforehand, yielding (signals, device_info, labels).
    for epoch in range(10):
        train_loss = train_model(model, train_loader, optimizer, criterion, device)
        test_loss = test_model(model, test_loader, criterion, device)
        print(f"Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
