In [None]:
!pip install wfdb
!pip install neurokit2

import os
import zipfile
import requests
from collections import namedtuple
from typing import List
import numpy as np
import wfdb
import matplotlib.pyplot as plt
from scipy.fft import fft, fftfreq
from scipy.interpolate import interp1d

# Record type for one processed EMG recording.
#   id      - identifier assigned by the loader
#   subject - subject/participant number
#   gesture - gesture number
#   trial   - trial number
#   session - session number
#   image   - log-FFT image, indexed as (window, channel, frequency bin)
EMGData = namedtuple(
    'EMGData',
    'id subject gesture trial session image',
)


In [None]:
# ============================================================================
# SECTION 1: Download and Extract Dataset from Drive
# ============================================================================

def download_grabmyo_dataset(save_dir='grabmyo_data', drive_zip_path='/content/drive/MyDrive/gesture-recognition-and-biometrics-electromyogram-grabmyo-1.1.0.zip'):
    """Extract the GRABMyo zip archive into ``save_dir`` (once).

    A marker file ``.extraction_complete`` is written inside ``save_dir`` after
    a successful extraction; when it already exists the extraction is skipped.

    Args:
        save_dir: Directory the archive is extracted into (created if missing).
        drive_zip_path: Path of the zip archive. The Google Drive mount check is
            only applied when this path lives under ``/content/drive``.

    Returns:
        ``save_dir`` on success (or when already extracted), ``None`` when the
        archive is unavailable or extraction fails.
    """
    # Local import: the notebook's import cell does not bring in datetime, and
    # without it the marker write below raised NameError, which the broad
    # except reported as an extraction failure.
    from datetime import datetime

    drive_root = '/content/drive'

    os.makedirs(save_dir, exist_ok=True)

    # Define marker file path
    marker_file = os.path.join(save_dir, '.extraction_complete')

    # Check if extraction was already completed
    if os.path.exists(marker_file):
        print(f"✓ Dataset already extracted at: {save_dir}")
        print("Skipping extraction to save time.")
        return save_dir

    # Only insist on a mounted Drive when the archive is actually stored there
    if drive_zip_path.startswith(drive_root) and not os.path.exists(drive_root):
        print("Google Drive is not mounted. Please mount Google Drive to access the dataset.")
        print("You can do this by clicking the folder icon on the left, then the Google Drive icon.")
        return None

    print(f"Using dataset from Google Drive: {drive_zip_path}")

    if not os.path.exists(drive_zip_path):
        print(f"Error: Zip file not found at {drive_zip_path}")
        return None

    print("\nExtracting files...")
    try:
        with zipfile.ZipFile(drive_zip_path, 'r') as zip_ref:
            zip_ref.extractall(save_dir)

        print(f"Dataset extracted to: {save_dir}")

        # Build the timestamp before opening the marker so that a failure here
        # can never leave an empty marker file behind.
        stamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        with open(marker_file, 'w') as f:
            f.write(f"Extraction completed on: {stamp}\n")

        print("✓ Extraction complete! Marker file created.")

    except Exception as e:
        print(f"Error during extraction: {e}")
        return None

    return save_dir


In [None]:
def normalize_fft_images(data_list, global_min=None, global_max=None):
    """Min-max scale every sample's ``image`` to [-1, 1] using shared statistics.

    Args:
        data_list: Sequence of namedtuple records (e.g. ``EMGData``) whose
            ``image`` field is a NumPy array.
        global_min: Minimum used for scaling. If this or ``global_max`` is
            ``None``, both are recomputed from ``data_list``.
        global_max: Maximum used for scaling (see ``global_min``).

    Returns:
        ``(normalized_data, global_min, global_max)`` where ``normalized_data``
        is a new list of records identical to the inputs except for ``image``.
        The input records and arrays are not modified.

    Raises:
        ValueError: If statistics must be computed but ``data_list`` is empty.
    """
    # Compute global statistics if not provided
    if global_min is None or global_max is None:
        print("Computing global statistics from all data...")
        if len(data_list) == 0:
            raise ValueError(
                "normalize_fft_images: data_list is empty, cannot compute global statistics"
            )

        # Reduce per image instead of concatenating every image into one big
        # temporary array; np.min/np.max keep NumPy's NaN-propagating behaviour.
        global_min = np.min([data.image.min() for data in data_list])
        global_max = np.max([data.image.max() for data in data_list])

        print(f"Global min: {global_min:.4f}")
        print(f"Global max: {global_max:.4f}")

    # Normalize all images using global statistics
    print("Normalizing all images with global statistics...")
    value_range = global_max - global_min
    normalized_data = []

    for data in data_list:
        # Normalize to [0, 1]; a degenerate range maps everything to 0
        if value_range > 0:
            normalized_image = (data.image - global_min) / value_range
        else:
            normalized_image = np.zeros_like(data.image)

        # Scale to [-1, 1]
        normalized_image = 2 * normalized_image - 1

        # Copy the record, swapping in the normalized image
        normalized_data.append(data._replace(image=normalized_image))

    print(f"Normalized {len(normalized_data)} samples")

    return normalized_data, global_min, global_max


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

In [None]:
def prepare_data_for_cnn(data_list):
    """Turn a list of samples into stacked image and label arrays.

    Each sample's ``image`` has its axes reordered with ``(1, 2, 0)``, i.e. the
    leading axis (the windows) is moved to the end so it can act as the channel
    dimension of a channels-last 2D input. The label is the sample's
    ``subject``.

    Returns:
        ``(images, labels)`` as NumPy arrays built from the per-sample values.
    """
    # Single pass over the samples: (channels-last image, subject label)
    converted = [
        (np.transpose(sample.image, (1, 2, 0)), sample.subject)
        for sample in data_list
    ]

    images = np.array([image for image, _ in converted])
    labels = np.array([label for _, label in converted])
    return images, labels

class EMGDataset(Dataset):
    """In-memory dataset pairing float32 image tensors with int64 labels."""

    def __init__(self, images, labels):
        # torch.tensor copies its input, so later changes to the source arrays
        # do not affect the dataset.
        self.labels = torch.tensor(labels, dtype=torch.long)
        self.images = torch.tensor(images, dtype=torch.float32)

    def __len__(self):
        # Number of samples = size of the leading image dimension
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        return image, label

In [None]:
import numpy as np
import neurokit2 as nk
from scipy.interpolate import interp1d

# ============================================================================
# SECTION 1: Channel Configuration
# ============================================================================

# Column indices (0-indexed) of the unused U1-U4 slots in the 32-column record.
UNUSED_CHANNELS = [16, 23, 24, 31]

# Columns actually fed to the pipeline (0-indexed): 8-15 and 17-22.
# This is a deliberate subset, not "everything except UNUSED_CHANNELS".
ACTIVE_CHANNELS = list(range(8, 16)) + list(range(17, 23))

# Channel groups for reference, derived from the active subset
FOREARM_CHANNELS = list(filter(lambda idx: idx < 16, ACTIVE_CHANNELS))   # F9-F16 (0-indexed)
WRIST_CHANNELS = list(filter(lambda idx: idx >= 17, ACTIVE_CHANNELS))    # W1-W6  (0-indexed)

# ============================================================================
# SECTION 2: Data Augmentation Functions
# ============================================================================

def augment_emg_signal(signal, fs=2048, augmentation_params=None):
    """Return a randomly perturbed copy of a (samples, channels) signal.

    Args:
        signal: 2D array indexed as ``[sample, channel]``; it is not modified.
        fs: Sampling rate, used to build the time axis for baseline wander.
        augmentation_params: Dict with the keys ``noise_level`` (float),
            ``amplitude_scale``, ``time_warp``, ``dc_shift`` and
            ``baseline_wander`` (booleans). When ``None`` every step is enabled
            with ``noise_level=0.02``.

    Returns:
        Array with the same shape as ``signal``.

    Randomness comes from NumPy's global RNG (``np.random``).
    """
    params = augmentation_params
    if params is None:
        params = {
            'noise_level': 0.02,
            'amplitude_scale': True,
            'time_warp': True,
            'dc_shift': True,
            'baseline_wander': True
        }

    result = signal.copy()

    # Step 1: Gaussian noise, scaled by each channel's standard deviation
    noise_level = params.get('noise_level', 0)
    if noise_level > 0:
        channel_sigma = noise_level * np.std(signal, axis=0)
        result += np.random.normal(0, channel_sigma, signal.shape)

    # Step 2: one global gain drawn from [0.8, 1.2]
    if params.get('amplitude_scale', False):
        result *= np.random.uniform(0.8, 1.2)

    # Step 3: independent small offset for every channel
    if params.get('dc_shift', False):
        result += np.random.uniform(-0.05, 0.05, result.shape[1])

    # Step 4: slow sinusoidal drift (0.1-2 Hz) per channel
    if params.get('baseline_wander', False):
        t = np.arange(result.shape[0]) / fs
        for channel in range(result.shape[1]):
            freq = np.random.uniform(0.1, 2.0)
            amplitude = np.random.uniform(0.01, 0.05) * np.std(signal[:, channel])
            phase = np.random.uniform(0, 2*np.pi)
            result[:, channel] += amplitude * np.sin(2 * np.pi * freq * t + phase)

    # Step 5: mild time warp: resample on a grid of a different length, then
    # crop or pad (with the last value) back to the original length
    if params.get('time_warp', False):
        n_samples = result.shape[0]
        warp_factor = np.random.uniform(0.95, 1.05)
        source_grid = np.arange(n_samples)
        query_grid = np.linspace(0, n_samples - 1, int(n_samples * warp_factor))

        warped = np.zeros_like(result)
        for channel in range(result.shape[1]):
            resample = interp1d(source_grid, result[:, channel],
                                kind='cubic', fill_value='extrapolate')
            stretched = resample(query_grid)

            n_keep = min(len(stretched), n_samples)
            warped[:n_keep, channel] = stretched[:n_keep]
            warped[n_keep:, channel] = stretched[-1]

        result = warped

    return result


def generate_augmented_samples(signal, fs=2048, num_augmentations=5):
    """Build a list holding the original signal followed by random variants.

    Element 0 of the returned list is ``signal`` itself (not a copy); it is
    followed by ``num_augmentations`` outputs of ``augment_emg_signal``, each
    produced with freshly drawn settings. DC shift is always enabled; the noise
    level is drawn from [0.01, 0.03] and the other steps are toggled at random.
    """
    samples = [signal]  # original goes first

    for _ in range(num_augmentations):
        # Draw order matters for reproducibility with a seeded global RNG
        noise_level = np.random.uniform(0.01, 0.03)
        use_amplitude_scale = np.random.choice([True, False])
        use_time_warp = np.random.choice([True, False])
        use_baseline_wander = np.random.choice([True, False])

        settings = {
            'noise_level': noise_level,
            'amplitude_scale': use_amplitude_scale,
            'time_warp': use_time_warp,
            'dc_shift': True,
            'baseline_wander': use_baseline_wander
        }
        samples.append(augment_emg_signal(signal, fs, settings))

    return samples


# ============================================================================
# SECTION 3: EMG Signal Filtering
# ============================================================================

def filter_emg_signal(signal, fs=2048, method='neurokit'):
    """Filter every channel (column) of ``signal`` independently.

    Args:
        signal: 2D array indexed as ``[sample, channel]``.
        fs: Sampling rate passed to the NeuroKit routines.
        method: ``'neurokit'`` uses ``nk.emg_clean``; any other value uses a
            4th-order Butterworth band-pass with cut-offs 100 and 500 via
            ``nk.signal_filter``.

    Returns:
        New array shaped like ``signal`` holding the filtered channels.
    """
    use_default_cleaner = (method == 'neurokit')

    def _clean_channel(samples):
        # Chosen per call so NeuroKit is only touched when a channel exists
        if use_default_cleaner:
            return nk.emg_clean(samples, sampling_rate=fs)
        # Custom bandpass filter
        return nk.signal_filter(
            samples,
            sampling_rate=fs,
            lowcut=100,
            highcut=500,
            method='butterworth',
            order=4
        )

    output = np.zeros_like(signal)
    for channel in range(signal.shape[1]):
        output[:, channel] = _clean_channel(signal[:, channel])

    return output


# ============================================================================
# SECTION 4: Updated Data Loading Functions
# ============================================================================

def parse_filename(filename):
    """Extract ``(session, subject, gesture, trial)`` from a record name.

    The name is split on ``_`` and the first four pieces are expected to look
    like ``session<N>``, ``participant<N>``, ``gesture<N>`` and ``trial<N>``.
    Raises ``ValueError``/``IndexError`` when the name does not fit.
    """
    pieces = filename.split('_')
    prefixes = ('session', 'participant', 'gesture', 'trial')

    # Strip each expected prefix and convert what remains to an integer
    session, subject, gesture, trial = (
        int(pieces[position].replace(prefix, ''))
        for position, prefix in enumerate(prefixes)
    )
    return session, subject, gesture, trial


def create_fft_image(signals, fs, window_duration=1, target_width=512):
    """Convert a (samples, channels) signal into stacked log-spectrum images.

    The signal is cut into consecutive, non-overlapping windows of
    ``int(fs * window_duration)`` samples (a trailing partial window is
    dropped). For every window and channel the FFT magnitude over the
    non-negative frequencies is linearly resampled to ``target_width`` points.

    Returns:
        ``log10(magnitude + 1e-10)`` as an array indexed
        ``[window, channel, frequency]``. No further normalization is applied.
    """
    window_len = int(fs * window_duration)
    n_windows = signals.shape[0] // window_len
    n_channels = signals.shape[1]

    def _spectrum_row(samples):
        # Magnitude spectrum of one channel window, resampled to target_width
        n_points = len(samples)
        spectrum = fft(samples)
        bin_freqs = fftfreq(n_points, 1/fs)

        keep = bin_freqs >= 0
        kept_freqs = bin_freqs[keep]
        magnitude = np.abs(spectrum[keep])

        resampler = interp1d(kept_freqs, magnitude,
                             kind='linear', fill_value='extrapolate')
        return resampler(np.linspace(0, kept_freqs[-1], target_width))

    window_images = []
    for w in range(n_windows):
        segment = signals[w * window_len:w * window_len + window_len]

        image = np.zeros((n_channels, target_width))
        for channel in range(n_channels):
            image[channel, :] = _spectrum_row(segment[:, channel])

        window_images.append(image)

    stacked = np.array(window_images)

    # Log transform only; the small offset avoids log10(0)
    return np.log10(stacked + 1e-10)


def load_single_file(file_path, unique_id, apply_filtering=True,
                     apply_augmentation=False, num_augmentations=5):
    """Load one WFDB record and convert it into ``EMGData`` samples.

    Args:
        file_path: Record path without extension, as passed to
            ``wfdb.rdrecord``. Its basename is parsed by ``parse_filename``.
        unique_id: Identifier stored in every returned sample. Augmented
            samples share the id of the recording they were derived from.
        apply_filtering: Run ``filter_emg_signal`` on the selected channels.
        apply_augmentation: Also emit augmented variants of the signal.
        num_augmentations: Number of augmented variants to emit.

    Returns:
        List of ``EMGData``: the original sample first, followed by
        ``num_augmentations`` augmented samples (trial numbers 8, 9, ...) when
        augmentation is enabled.
    """
    # Read WFDB record
    record = wfdb.rdrecord(file_path)

    # Extract metadata from filename
    filename = os.path.basename(file_path)
    session, subject, gesture, trial = parse_filename(filename)

    # Keep only the columns listed in ACTIVE_CHANNELS
    signals = record.p_signal[:, ACTIVE_CHANNELS]

    # Apply filtering if requested
    if apply_filtering:
        signals = filter_emg_signal(signals, fs=record.fs)

    def _make_sample(signal_array, trial_number):
        # Wrap one signal variant, sharing the recording's metadata
        return EMGData(
            id=unique_id,
            subject=subject,
            gesture=gesture,
            trial=trial_number,
            session=session,
            image=create_fft_image(signal_array, record.fs)
        )

    # The unmodified recording is always the first sample
    data_samples = [_make_sample(signals, trial)]

    # Generate augmented versions if requested
    if apply_augmentation:
        augmented_signals = generate_augmented_samples(
            signals, fs=record.fs, num_augmentations=num_augmentations
        )

        # generate_augmented_samples returns the original signal as element 0;
        # it is already in data_samples, so skip it to avoid a duplicate that
        # would be stored as a fake "augmented" trial.
        for aug_idx, aug_signal in enumerate(augmented_signals[1:]):
            # Augmented trials start at 8
            data_samples.append(_make_sample(aug_signal, 8 + aug_idx))

    return data_samples



from concurrent.futures import ProcessPoolExecutor, as_completed
import os

def load_all_data(base_dir='grabmyo_data', apply_filtering=True,
                  apply_augmentation=False, num_augmentations=5, gesture=1,
                  load_all_gestures=False, use_multiprocessing=True, max_workers=None):
    """Find WFDB records under ``base_dir`` and load them into ``EMGData`` samples.

    Args:
        base_dir: Directory searched recursively for ``.hea`` header files.
        apply_filtering: Forwarded to ``load_single_file``.
        apply_augmentation: Forwarded to ``load_single_file``.
        num_augmentations: Forwarded to ``load_single_file``.
        gesture: Gesture number to load when ``load_all_gestures`` is False.
        load_all_gestures: Load every gesture instead of a single one.
        use_multiprocessing: Use a ``ProcessPoolExecutor`` when more than one
            file matches; otherwise load sequentially.
        max_workers: Worker count; defaults to the number of CPU cores.

    Returns:
        List of ``EMGData`` samples (empty when nothing matched). With
        multiprocessing the order follows task completion, not file order.
        Files that fail to load are reported and skipped.
    """
    import multiprocessing as mp

    if load_all_gestures:
        print(f"Scanning for ALL gesture files...")
    else:
        print(f"Scanning for gesture {gesture} files...")

    # First pass: collect the record paths matching the criteria
    all_files = []
    gesture_counts = {}

    for root, dirs, files in os.walk(base_dir):
        for file in files:
            if not file.endswith('.hea'):
                continue

            filename = file[:-4]  # Remove .hea extension
            try:
                _, _, file_gesture, _ = parse_filename(filename)
            except (ValueError, IndexError):
                # Header whose name does not follow the expected pattern.
                # Only parse errors are skipped, so Ctrl-C and genuine bugs
                # are no longer swallowed.
                continue

            # Track gesture counts
            gesture_counts[file_gesture] = gesture_counts.get(file_gesture, 0) + 1

            # Filter based on load_all_gestures flag
            if load_all_gestures or file_gesture == gesture:
                all_files.append(os.path.join(root, filename))

    # os.walk order is filesystem dependent; sort so that the ids handed out
    # below are the same on every run.
    all_files.sort()
    total_files = len(all_files)

    if load_all_gestures:
        print(f"Found {total_files} files across {len(gesture_counts)} gestures:")
        for gest, count in sorted(gesture_counts.items()):
            print(f"  Gesture {gest}: {count} files")
    else:
        print(f"Found {total_files} files for gesture {gesture}")

    if total_files == 0:
        print("No files found!")
        return []

    # Check available CPU cores
    available_cores = mp.cpu_count()
    if max_workers is None:
        max_workers = available_cores

    print(f"\nAvailable CPU cores: {available_cores}")

    run_parallel = use_multiprocessing and total_files > 1
    if run_parallel:
        print(f"Using multiprocessing with {max_workers} workers")
    else:
        print("Using single-threaded processing")

    if apply_augmentation:
        print(f"Augmentation enabled: {num_augmentations} additional samples per file")

    print("\nLoading files...")

    data_list = []
    files_processed = 0

    def _record_result(samples):
        # Store one file's samples and print progress every 20 files
        nonlocal files_processed
        data_list.extend(samples)
        files_processed += 1

        if files_processed % 20 == 0:
            percentage = (files_processed / total_files) * 100
            total_samples = len(data_list)
            print(f"Progress: {files_processed}/{total_files} files ({percentage:.1f}%) - {total_samples} samples generated")

    if run_parallel:
        # NOTE(review): augmentation draws from NumPy's global RNG inside the
        # workers; with a fork start method the workers presumably start from
        # the same inherited RNG state. Confirm, and reseed per worker if
        # distinct augmentations across workers matter.
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            # Submit all tasks
            future_to_file = {
                executor.submit(
                    load_single_file,
                    file_path,
                    idx + 1,
                    apply_filtering,
                    apply_augmentation,
                    num_augmentations
                ): file_path
                for idx, file_path in enumerate(all_files)
            }

            # Process results as they complete
            for future in as_completed(future_to_file):
                file_path = future_to_file[future]
                try:
                    _record_result(future.result())
                except Exception as e:
                    filename = os.path.basename(file_path)
                    print(f"Error loading {filename}: {e}")
    else:
        # Single-threaded approach (fallback)
        for idx, file_path in enumerate(all_files):
            try:
                _record_result(load_single_file(
                    file_path,
                    idx + 1,
                    apply_filtering,
                    apply_augmentation,
                    num_augmentations
                ))
            except Exception as e:
                filename = os.path.basename(file_path)
                print(f"Error loading {filename}: {e}")

    total_count = len(data_list)

    print(f"\n=== Loading Complete ===")
    if load_all_gestures:
        print(f"All gestures loaded")
        # Count samples per gesture
        gesture_sample_counts = {}
        for data in data_list:
            gesture_sample_counts[data.gesture] = gesture_sample_counts.get(data.gesture, 0) + 1
        print("Samples per gesture:")
        for gest, count in sorted(gesture_sample_counts.items()):
            print(f"  Gesture {gest}: {count} samples")
    else:
        print(f"Gesture: {gesture}")

    # Report the real success rate instead of a hard-coded 100%
    success_rate = (files_processed / total_files) * 100
    print(f"Files processed: {files_processed}/{total_files} ({success_rate:.0f}%)")

    if apply_augmentation:
        original_samples = files_processed
        augmented_samples = total_count - original_samples
        print(f"Original samples: {original_samples}")
        print(f"Augmented samples: {augmented_samples}")
        print(f"Total samples: {total_count}")
    else:
        print(f"Total samples: {total_count}")

    print(f"Active channels: {len(ACTIVE_CHANNELS)} (excluding {len(UNUSED_CHANNELS)} unused)")

    return data_list


In [None]:
# Extract the dataset archive (skipped when the marker file already exists).
# dataset_path is the extraction directory, or None when the archive could not
# be found or extracted.
dataset_path = download_grabmyo_dataset()

In [None]:
# Step 1: load gesture 17 for every recording, with augmentation enabled
print("Loading data...")
data_vector_raw = load_all_data('grabmyo_data', apply_augmentation=True, gesture=17)
print(f"Loaded {len(data_vector_raw)} samples")

# Step 2: Apply global normalization to ALL data
# NOTE(review): min/max are computed over every sample, including the ones that
# end up in test_data below, so test statistics leak into the normalization.
data_vector, global_min, global_max = normalize_fft_images(data_vector_raw)


# Step 3: Split into train/test by subject (and by session for reserved subjects)
# Sorting in place gives a deterministic sample order regardless of load order.
data_vector.sort(key=lambda x: (x.subject, x.session, x.trial))

# Get all unique subjects
subjects = sorted(list(set([data.subject for data in data_vector])))

# Define subject groups by POSITION in the sorted id list (the ids match the
# "1-35 / 36-38 / 39+" wording below only if subject ids are contiguous from 1)
main_train_subjects = subjects[:35]  # First 35 subject ids
reserved_train_subjects = subjects[35:38]  # Next 3 subject ids
test_only_subjects = subjects[38:]  # All remaining subject ids

print(f"Main training subjects: {len(main_train_subjects)} subjects")
print(f"Reserved training subjects: {len(reserved_train_subjects)} subjects")
print(f"Test-only subjects: {len(test_only_subjects)} subjects")

# Initialize data lists
train_data_main = []  # Main subjects, all data
train_data_reserved = []  # Reserved subjects, all sessions except their last
test_data = []  # Reserved subjects' last session + all data from test-only subjects

# Main subjects - all data goes to main training
for data in data_vector:
    if data.subject in main_train_subjects:
        train_data_main.append(data)

# Reserved subjects - split by last session
for subject in reserved_train_subjects:
    subject_data = [d for d in data_vector if d.subject == subject]

    # Get all unique sessions for this subject
    sessions = sorted(list(set([d.session for d in subject_data])))

    if len(sessions) > 0:
        last_session = sessions[-1]

        # All sessions except last → reserved training
        for data in subject_data:
            if data.session != last_session:
                train_data_reserved.append(data)
            else:
                # Last session → test
                test_data.append(data)

# Test-only subjects - all data goes to test
# (augmented samples are included too, since augmentation happened at load time)
for data in data_vector:
    if data.subject in test_only_subjects:
        test_data.append(data)

# Print statistics
print(f"\n=== Data Split Summary ===")
print(f"Main training samples (subjects 1-35): {len(train_data_main)}")
print(f"Reserved training samples (subjects 36-38, excluding last session): {len(train_data_reserved)}")
print(f"Test samples (subjects 36-38 last session + subjects 39+): {len(test_data)}")
print(f"\nTotal training samples: {len(train_data_main) + len(train_data_reserved)}")
print(f"Total test samples: {len(test_data)}")

# Optional: Verify subject distribution
print(f"\n=== Subject Distribution ===")
train_main_subjects_actual = sorted(list(set([d.subject for d in train_data_main])))
train_reserved_subjects_actual = sorted(list(set([d.subject for d in train_data_reserved])))
test_subjects_actual = sorted(list(set([d.subject for d in test_data])))

print(f"Main train subjects: {train_main_subjects_actual}")
print(f"Reserved train subjects: {train_reserved_subjects_actual}")
print(f"Test subjects: {test_subjects_actual}")

In [None]:
class ImprovedSubjectCNN(nn.Module):
    """Four conv blocks followed by a three-layer classifier head.

    Inputs are channels-last batches ``(batch, height, width, channels)``; they
    are permuted to channels-first inside ``forward``.

    Args:
        num_classes: Size of the output layer.
        input_shape: Per-sample shape ``(height, width, channels)``. The last
            entry sets the number of input channels of the first convolution.
        dropout_rate: Dropout probability used after both hidden FC layers.
    """

    def __init__(self, num_classes, input_shape, dropout_rate=0.5):
        super(ImprovedSubjectCNN, self).__init__()

        # Take the channel count from input_shape instead of hard-coding 5
        in_channels = input_shape[-1]

        # Input after permute: (batch, channels, height, width), e.g. (batch, 5, 14, 512)
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU()
        # Only pool width, not height (use (1, 2) instead of (2, 2))
        self.pool1 = nn.MaxPool2d(kernel_size=(1, 2))

        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU()
        # The only block that also halves the height
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2))

        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=(1, 2))

        self.conv4 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.relu4 = nn.ReLU()
        self.pool4 = nn.MaxPool2d(kernel_size=(1, 2))

        # Calculate flattened size with actual input shape.
        # The probe runs in eval mode: in training mode the BatchNorm layers
        # would fold the all-zero dummy batch into their running statistics.
        self.flatten = nn.Flatten()
        self.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, *input_shape).permute(0, 3, 1, 2)
            flattened_size = self.flatten(self._conv_features(dummy)).shape[1]
            print(f"Calculated flattened size: {flattened_size}")
        self.train()

        self.fc1 = nn.Linear(flattened_size, 512)
        self.bn5 = nn.BatchNorm1d(512)
        self.relu5 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(512, 256)
        self.bn6 = nn.BatchNorm1d(256)
        self.relu6 = nn.ReLU()
        self.dropout2 = nn.Dropout(dropout_rate)

        self.fc3 = nn.Linear(256, num_classes)

    def _conv_features(self, x):
        """Run the four conv/BN/ReLU/pool blocks on a channels-first batch."""
        x = self.pool1(self.relu1(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu2(self.bn2(self.conv2(x))))
        x = self.pool3(self.relu3(self.bn3(self.conv3(x))))
        x = self.pool4(self.relu4(self.bn4(self.conv4(x))))
        return x

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = x.permute(0, 3, 1, 2)  # (batch, height, width, channels) -> (batch, channels, height, width)
        x = self.flatten(self._conv_features(x))
        x = self.dropout1(self.relu5(self.bn5(self.fc1(x))))
        x = self.dropout2(self.relu6(self.bn6(self.fc2(x))))
        x = self.fc3(x)
        return x




# Prepare data: channels-last image array plus raw subject ids as labels.
# Only train_data_main is used here; train_data_reserved is not part of this
# classifier's training set.
train_images, train_labels = prepare_data_for_cnn(train_data_main)

# Get unique labels FIRST
unique_labels = sorted(list(set(train_labels)))  # sorted for consistency

# Map each subject id to a contiguous class index (0..num_classes-1), as
# required by CrossEntropyLoss targets
subject_to_int = {subject: i for i, subject in enumerate(unique_labels)}

print(f"Number of classes: {len(unique_labels)}")
print(f"Unique subjects: {unique_labels}")

# Convert labels to integers
train_labels_int = np.array([subject_to_int[label] for label in train_labels])

print(f"\nTraining samples: {len(train_labels_int)}")
print(f"Train label range: {train_labels_int.min()} to {train_labels_int.max()}")

# Create datasets and loaders
train_dataset = EMGDataset(train_images, train_labels_int)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)



In [None]:
# Initialize model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_shape = train_images.shape[1:]  # per-sample shape (height, width, channels)
num_classes = len(unique_labels)

model = ImprovedSubjectCNN(num_classes, input_shape, dropout_rate=0.3)
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-4)

# Use StepLR scheduler - halves the LR every 15 epochs
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=15, gamma=0.5)

# Or alternatively, use CosineAnnealingLR for smooth decay
# (num_epochs is defined further down; move its definition up before enabling this)
# scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

print(f"Model initialized on {device}")
print(f"Number of classes: {num_classes}")
print(f"Input shape: {input_shape}")

# Training loop
# NOTE(review): there is no validation pass; the "best" checkpoint is chosen on
# training loss only.
num_epochs = 100
best_loss = float('inf')

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # loss.item() is a batch mean; weight it by batch size so the epoch
        # average is per sample
        running_loss += loss.item() * inputs.size(0)
        _, predicted = torch.max(outputs, 1)
        running_corrects += torch.sum(predicted == labels.data)

    epoch_loss = running_loss / len(train_dataset)
    epoch_acc = running_corrects.double() / len(train_dataset)

    # Step the scheduler (once per epoch)
    scheduler.step()

    # Get current learning rate (value after the scheduler step)
    current_lr = optimizer.param_groups[0]['lr']

    print(f'Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}, LR: {current_lr:.6f}')

    # Save best model based on training loss
    if epoch_loss < best_loss:
        best_loss = epoch_loss
        torch.save(model.state_dict(), 'best_model.pth')
        print(f'  Saved best model with loss: {best_loss:.4f}\n')

print(f"\nFinished Training. Best training loss: {best_loss:.4f}")


In [None]:
class BiometricVerificationModel(nn.Module):
    """Binary verification head on top of a frozen ``ImprovedSubjectCNN`` backbone.

    The pretrained classifier is rebuilt from its checkpoint, every layer up to
    (and including) the second hidden FC block is reused with gradients
    disabled, and a small trainable head producing a sigmoid score is appended.

    Args:
        pretrained_model_path: Path of the ``state_dict`` saved for
            ``ImprovedSubjectCNN``.
        dropout_rate: Dropout probability of the verification head.
        device: ``map_location`` used when loading the checkpoint.
    """

    # Layers taken over from the pretrained classifier, in registration order
    _BACKBONE_LAYERS = (
        'conv1', 'bn1', 'relu1', 'pool1',
        'conv2', 'bn2', 'relu2', 'pool2',
        'conv3', 'bn3', 'relu3', 'pool3',
        'conv4', 'bn4', 'relu4', 'pool4',
        'flatten',
        'fc1', 'bn5', 'relu5', 'dropout1',
        'fc2', 'bn6', 'relu6', 'dropout2',
    )

    def __init__(self, pretrained_model_path="best_model.pth", dropout_rate=0.3, device="cpu"):
        super(BiometricVerificationModel, self).__init__()

        checkpoint = torch.load(pretrained_model_path, map_location=device)

        # Extract original parameters
        original_num_classes = checkpoint['fc3.weight'].shape[0]
        original_flattened_size = checkpoint['fc1.weight'].shape[1]

        print(f"Loading pretrained model:")
        print(f"  - Num classes: {original_num_classes}")
        print(f"  - Flattened size: {original_flattened_size}")

        # The backbone is always rebuilt for (14, 512, 5) inputs, which gives a
        # flattened size of 114688; any other checkpoint only gets a warning
        # here (load_state_dict below fails if the shapes really differ).
        original_input_shape = (14, 512, 5)
        if original_flattened_size != 114688:
            print(f"  - WARNING: Unexpected flattened size, using shape: {original_input_shape}")

        # Create and load original model
        original_model = ImprovedSubjectCNN(
            num_classes=original_num_classes,
            input_shape=original_input_shape,
            dropout_rate=0.5
        )
        original_model.load_state_dict(checkpoint)

        # Copy all backbone layers (attribute names are kept so the state_dict
        # keys stay the same)
        for layer_name in self._BACKBONE_LAYERS:
            setattr(self, layer_name, getattr(original_model, layer_name))

        # Freeze pretrained layers (the head below is created afterwards and
        # therefore stays trainable)
        for p in self.parameters():
            p.requires_grad = False

        # Add new verification head
        self.fc_verify1 = nn.Linear(256, 128)
        self.bn_verify1 = nn.BatchNorm1d(128)
        self.relu_verify1 = nn.ReLU()
        self.dropout_verify1 = nn.Dropout(dropout_rate)

        self.fc_verify2 = nn.Linear(128, 1)
        self.sigmoid = nn.Sigmoid()

        # Put the backbone into inference mode right away
        self.train()

    def train(self, mode=True):
        """Switch the head's mode while keeping the frozen backbone in eval mode.

        ``requires_grad=False`` alone does not stop BatchNorm layers from
        updating their running statistics, nor Dropout from being applied, so
        the backbone layers are forced back to eval mode after every call.
        """
        super().train(mode)
        for layer_name in self._BACKBONE_LAYERS:
            getattr(self, layer_name).eval()
        return self

    def forward(self, x):
        """Return a sigmoid score of shape ``(batch, 1)`` for channels-last input."""
        x = x.permute(0, 3, 1, 2)
        x = self.pool1(self.relu1(self.bn1(self.conv1(x))))
        x = self.pool2(self.relu2(self.bn2(self.conv2(x))))
        x = self.pool3(self.relu3(self.bn3(self.conv3(x))))
        x = self.pool4(self.relu4(self.bn4(self.conv4(x))))
        x = self.flatten(x)
        x = self.dropout1(self.relu5(self.bn5(self.fc1(x))))
        x = self.dropout2(self.relu6(self.bn6(self.fc2(x))))

        # Verification head
        x = self.dropout_verify1(self.relu_verify1(self.bn_verify1(self.fc_verify1(x))))
        x = self.fc_verify2(x)
        x = self.sigmoid(x)
        return x


In [None]:
# Choose one of the reserved training subjects as the target verification subject
# reserved_train_subjects was defined when splitting data_vector
TARGET_SUBJECT = reserved_train_subjects[2]  # third reserved subject; index 0 or 1 selects the others

# Get all data for the target subject from the reserved+test sets
target_all = [d for d in train_data_reserved if d.subject == TARGET_SUBJECT] + \
             [d for d in test_data if d.subject == TARGET_SUBJECT]

# Find this subject's sessions
target_sessions = sorted(list(set(d.session for d in target_all)))
assert len(target_sessions) >= 2, "Target subject must have at least 2 sessions."

# Define training and test sessions:
# - training: first two sessions
# - test: last session
# NOTE(review): with exactly 2 sessions the last session is also one of the
# training sessions, so genuine_train and genuine_test would overlap.
train_sessions_target = target_sessions[:2]
test_session_target = target_sessions[-1]

genuine_train = [d for d in target_all if d.session in train_sessions_target]
genuine_test = [d for d in target_all if d.session == test_session_target]

print(f"Target subject: {TARGET_SUBJECT}")
print(f"Target sessions: {target_sessions}")
print(f"Train sessions (target): {train_sessions_target}")
print(f"Test session (target): {test_session_target}")
print(f"Genuine train samples: {len(genuine_train)}")
print(f"Genuine test samples: {len(genuine_test)}")

# -----------------------------
# Impostor pools
# -----------------------------

# 1) Training impostors: samples from train_data_main only, restricted to the
#    same session numbers as the target's training sessions
impostor_train_pool = [
    d for d in train_data_main
    if d.subject != TARGET_SUBJECT and d.session in train_sessions_target
]

# 2) Test impostors:
#    - the other 2 reserved subjects: only their last session
#    - every subject in test_data that is not a reserved subject: all sessions
#    (samples from main training subjects are NOT added to this pool)
other_reserved_subjects = [s for s in reserved_train_subjects if s != TARGET_SUBJECT]

impostor_test_pool = []

# Other reserved subjects: get their last session and put into impostor pool
for subj in other_reserved_subjects:
    subj_all = [d for d in train_data_reserved if d.subject == subj] + \
               [d for d in test_data if d.subject == subj]
    if len(subj_all) == 0:
        continue
    subj_sessions = sorted(list(set(d.session for d in subj_all)))
    last_sess = subj_sessions[-1]
    impostor_test_pool.extend([d for d in subj_all if d.session == last_sess])

# Test-only subjects: everything is impostor test data.
# This rebinds test_only_subjects (first defined in the split cell) to the
# non-reserved subjects actually present in test_data.
test_only_subjects = [s for s in sorted(list(set(d.subject for d in test_data)))
                      if s not in reserved_train_subjects]
impostor_test_pool.extend([d for d in test_data if d.subject in test_only_subjects])

print(f"Impostor train pool: {len(impostor_train_pool)}")
print(f"Impostor test pool: {len(impostor_test_pool)}")


In [None]:
import random

def generate_training_pairs(ratio_impostor_to_genuine=1.0, genuine_samples=None,
                            impostor_pool=None, rng=None):
    """Build a shuffled list of ``(sample, label)`` training examples.

    Every genuine sample is included once with label ``1.0``; impostor
    samples are drawn with replacement from the impostor pool and labelled
    ``0.0``.

    Args:
        ratio_impostor_to_genuine: Number of impostor examples to draw per
            genuine example (the product is truncated with ``int``).
        genuine_samples: Genuine samples to use. Defaults to the notebook
            global ``genuine_train``.
        impostor_pool: Pool to draw impostors from. Defaults to the notebook
            global ``impostor_train_pool``.
        rng: Object providing ``choice`` and ``shuffle`` (for example a
            ``random.Random(seed)`` instance) for reproducible sampling.
            Defaults to the global ``random`` module.

    Returns:
        A list of ``(sample, label)`` tuples in random order.

    Raises:
        IndexError: If impostors are requested but the pool is empty
            (raised by ``choice``).
    """
    # Resolve the defaults at call time so the notebook globals can be
    # rebuilt between calls without redefining this function.
    if genuine_samples is None:
        genuine_samples = genuine_train
    if impostor_pool is None:
        impostor_pool = impostor_train_pool
    if rng is None:
        rng = random

    # Genuine examples (target subject), label 1.0
    pairs = [(sample, 1.0) for sample in genuine_samples]

    # Impostor examples, sampled with replacement, label 0.0
    n_impostor_needed = int(ratio_impostor_to_genuine * len(genuine_samples))
    for _ in range(n_impostor_needed):
        pairs.append((rng.choice(impostor_pool), 0.0))

    rng.shuffle(pairs)
    return pairs


def generate_test_pairs(ratio_impostor_to_genuine=4.0, genuine_samples=None,
                        impostor_pool=None, rng=None):
    """Build a shuffled list of ``(sample, label)`` test examples.

    Every genuine sample is included once with label ``1.0``; impostor
    samples are drawn with replacement from the impostor pool and labelled
    ``0.0``.

    Args:
        ratio_impostor_to_genuine: Number of impostor examples to draw per
            genuine example (the product is truncated with ``int``).
        genuine_samples: Genuine samples to use. Defaults to the notebook
            global ``genuine_test``.
        impostor_pool: Pool to draw impostors from. Defaults to the notebook
            global ``impostor_test_pool``.
        rng: Object providing ``choice`` and ``shuffle`` (for example a
            ``random.Random(seed)`` instance) for reproducible sampling.
            Defaults to the global ``random`` module.

    Returns:
        A list of ``(sample, label)`` tuples in random order.

    Raises:
        IndexError: If impostors are requested but the pool is empty
            (raised by ``choice``).
    """
    # Resolve the defaults at call time so the notebook globals can be
    # rebuilt between calls without redefining this function.
    if genuine_samples is None:
        genuine_samples = genuine_test
    if impostor_pool is None:
        impostor_pool = impostor_test_pool
    if rng is None:
        rng = random

    # Genuine examples (target subject), label 1.0
    pairs = [(sample, 1.0) for sample in genuine_samples]

    # Impostor examples, sampled with replacement, label 0.0
    n_impostor_needed = int(ratio_impostor_to_genuine * len(genuine_samples))
    for _ in range(n_impostor_needed):
        pairs.append((rng.choice(impostor_pool), 0.0))

    rng.shuffle(pairs)
    return pairs


In [None]:
class VerificationDataset(Dataset):
    """Torch dataset over a list of ``(sample, label)`` pairs.

    Each item is returned as ``(image_tensor, label_tensor)``, both
    ``float32``. The sample's ``image`` array has its first axis moved to
    the end before conversion.
    """

    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        record, target = self.pairs[idx]
        # Reorder axes (0, 1, 2) -> (1, 2, 0), e.g. (5, 32, 512) -> (32, 512, 5)
        channels_last = np.transpose(record.image, (1, 2, 0))
        image_tensor = torch.from_numpy(channels_last.astype(np.float32))
        label_tensor = torch.tensor(target, dtype=torch.float32)
        return image_tensor, label_tensor


In [None]:
# Fine-tune the verification model for the target subject, then run it once
# over a freshly sampled test set. Leaves `model`, `all_preds` and
# `all_labels` in the notebook namespace for later cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = BiometricVerificationModel(pretrained_model_path="best_model.pth", device=device).to(device)

# BCELoss expects probabilities in [0, 1], so the model presumably ends in a
# sigmoid - TODO confirm against BiometricVerificationModel.
criterion = nn.BCELoss()
# Only parameters that still require gradients are optimised.
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-3
)

batch_size = 64
num_epochs = 100

for epoch in range(num_epochs):
    model.train()

    # Regenerate training pairs each epoch for fresh impostors
    train_pairs = generate_training_pairs(ratio_impostor_to_genuine=1.0)
    train_dataset = VerificationDataset(train_pairs)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    running_loss = 0.0
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device).view(-1, 1)  # reshape labels to a column, (batch, 1)

        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by batch size so the epoch figure is a
        # per-sample average even when the last batch is smaller.
        running_loss += loss.item() * x.size(0)

    epoch_loss = running_loss / len(train_dataset)
    print(f"Epoch {epoch+1}/{num_epochs} - train loss: {epoch_loss:.4f}")

# Build test loader once (or regenerate if you want random impostors)
test_pairs = generate_test_pairs(ratio_impostor_to_genuine=4.0)
test_dataset = VerificationDataset(test_pairs)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Inference pass over the test pairs; scores and labels are collected on CPU.
model.eval()
all_preds = []
all_labels = []
with torch.no_grad():
    for x, y in test_loader:
        x = x.to(device)
        y = y.to(device).view(-1, 1)

        y_hat = model(x)
        all_preds.append(y_hat.cpu())
        all_labels.append(y.cpu())

# Flatten to 1-D numpy arrays (one score / one label per test pair).
all_preds = torch.cat(all_preds).numpy().ravel()
all_labels = torch.cat(all_labels).numpy().ravel()


In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix
import torch
from torch.utils.data import DataLoader

def comprehensive_validation(verification_model, device, target_subject,
                           train_data_main, train_data_reserved, test_data,
                           test_only_min_subject=39, max_original_trial=7):
    """Evaluate the verification model on five data scenarios (A-E).

    For each scenario every sample of ``target_subject`` is labelled genuine
    and every other sample impostor; the model scores are then turned into
    confusion matrices for a sweep of decision thresholds.

    Scenarios:
        A: all data.
        B: all data except the target's first two sessions.
        C: like B, restricted to samples whose id does not contain "_aug".
        D: the target's last session plus all "test-only" subjects.
        E: like D, restricted to samples with ``trial <= max_original_trial``.

    NOTE(review): C and E use different criteria for "non-augmented"
    ("_aug" in the id vs. the trial number) - confirm both select the same
    kind of sample.

    Args:
        verification_model: Model to evaluate; it is switched to eval mode.
        device: Device passed to ``get_predictions``.
        target_subject: Subject id treated as the genuine user.
        train_data_main: Samples of the main training subjects.
        train_data_reserved: Samples of the reserved subjects.
        test_data: Held-out samples.
        test_only_min_subject: Subjects in ``test_data`` with an id greater
            than or equal to this value count as "test-only" subjects.
        max_original_trial: Highest trial number treated as original
            (non-augmented) data in scenario E.

    Returns:
        Dict mapping the scenario letter ('A'..'E') to the DataFrame
        returned by ``create_confusion_matrices``.

    Raises:
        ValueError: If ``target_subject`` has no samples in
            ``train_data_reserved`` or ``test_data``.
    """
    verification_model.eval()

    # All samples of the target subject, wherever they were filed.
    target_all = [d for d in train_data_reserved if d.subject == target_subject] + \
                 [d for d in test_data if d.subject == target_subject]
    if not target_all:
        raise ValueError(
            f"No samples found for target subject {target_subject} in "
            "train_data_reserved or test_data"
        )

    # First two sessions are the enrolment sessions, the last one is held out.
    target_sessions = sorted(set(d.session for d in target_all))
    train_sessions_target = target_sessions[:2]
    test_session_target = target_sessions[-1]

    # Subjects that were never used for training.
    test_only_subjects = sorted(set(
        d.subject for d in test_data if d.subject >= test_only_min_subject
    ))

    print(f"Target subject: {target_subject}")
    print(f"Target sessions: {target_sessions}")
    print(f"Train sessions: {train_sessions_target}")
    print(f"Test session: {test_session_target}")
    print(f"Test-only subjects ({test_only_min_subject}+): {test_only_subjects}")

    # Decision thresholds: 0% to 100% in 5% increments
    thresholds = np.arange(0.00, 1.05, 0.05)

    results = {}

    def _evaluate(key, data_pool, scenario_name):
        # Label the pool, score it, and store the per-threshold table.
        pairs = generate_all_data_pairs(data_pool, target_subject)
        predictions, labels = get_predictions(verification_model, pairs, device)
        results[key] = create_confusion_matrices(predictions, labels, thresholds,
                                                 scenario_name)

    # ------------------------------------------------------------------
    # SCENARIO A: all data (shows over-training / data-augmentation effects)
    # ------------------------------------------------------------------
    print("\n=== SCENARIO A: All data ===")
    all_data = train_data_main + train_data_reserved + test_data
    _evaluate('A', all_data, "A: All data")

    # ------------------------------------------------------------------
    # SCENARIO B: all data except the target's training sessions
    # ------------------------------------------------------------------
    print("\n=== SCENARIO B: All data except sessions 1&2 of target subject ===")
    filtered_data_b = [
        d for d in all_data
        if not (d.subject == target_subject and d.session in train_sessions_target)
    ]
    _evaluate('B', filtered_data_b, "B: All data except target sessions 1&2")

    # ------------------------------------------------------------------
    # SCENARIO C: non-augmented data only, excluding target training sessions
    # ------------------------------------------------------------------
    print("\n=== SCENARIO C: Non-augmented data only, excluding target sessions 1&2 ===")
    # Non-augmented samples are recognised by the absence of "_aug" in the id.
    non_aug_data = [d for d in filtered_data_b if "_aug" not in str(d.id)]
    _evaluate('C', non_aug_data, "C: Non-augmented only, excluding target sessions 1&2")

    # ------------------------------------------------------------------
    # SCENARIO D: target's held-out session + test-only subjects
    # ------------------------------------------------------------------
    print(f"\n=== SCENARIO D: Target session 3 + test-only subjects ({test_only_min_subject}+) ===")
    target_session3 = [d for d in target_all if d.session == test_session_target]
    test_only_data = [d for d in test_data if d.subject in test_only_subjects]
    _evaluate('D', target_session3 + test_only_data,
              "D: Target session 3 + test-only subjects")

    # ------------------------------------------------------------------
    # SCENARIO E: like D, original (non-augmented) trials only
    # ------------------------------------------------------------------
    print("\n=== SCENARIO E: Non-augmented only - Target session 3 + test-only subjects ===")
    # Here "original" means a trial number up to max_original_trial.
    target_session3_orig = [d for d in target_session3 if d.trial <= max_original_trial]
    test_only_data_orig = [d for d in test_only_data if d.trial <= max_original_trial]
    scenario_e_data = target_session3_orig + test_only_data_orig

    print(f"Target session 3 (original only): {len(target_session3_orig)} samples")
    print(f"Test-only subjects (original only): {len(test_only_data_orig)} samples")
    print(f"Total scenario E samples: {len(scenario_e_data)}")

    _evaluate('E', scenario_e_data,
              "E: Non-augmented only - Target session 3 + test-only subjects")

    return results


def generate_all_data_pairs(data_pool, target_subject):
    """Label every sample in ``data_pool`` as genuine or impostor.

    Returns a list of ``(sample, label)`` tuples in the order of
    ``data_pool``; the label is ``1.0`` when ``sample.subject`` equals
    ``target_subject`` and ``0.0`` otherwise.
    """
    return [
        (record, 1.0 if record.subject == target_subject else 0.0)
        for record in data_pool
    ]


def get_predictions(model, pairs, device, batch_size=64):
    """Run ``model`` over ``pairs`` and collect its outputs and the labels.

    The pairs are wrapped in a ``VerificationDataset`` and evaluated in
    order (no shuffling) with gradients disabled. The model's train/eval
    mode is not changed here; the caller is expected to set it.

    Args:
        model: Callable taking a batch tensor and returning a tensor of scores.
        pairs: Sequence of ``(sample, label)`` tuples.
        device: Device the input batches are moved to.
        batch_size: Number of pairs per forward pass.

    Returns:
        ``(predictions, labels)`` as flat 1-D numpy arrays. Both are empty
        when ``pairs`` is empty.
    """
    # np.concatenate rejects an empty list, so an empty scenario would crash
    # below; return empty arrays instead.
    if len(pairs) == 0:
        return np.empty(0, dtype=np.float32), np.empty(0, dtype=np.float32)

    dataset = VerificationDataset(pairs)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for x, y in dataloader:
            x = x.to(device)
            y_pred = model(x)

            all_predictions.append(y_pred.cpu().numpy())
            all_labels.append(y.numpy())

    predictions = np.concatenate(all_predictions).ravel()
    labels = np.concatenate(all_labels).ravel()

    return predictions, labels


def create_confusion_matrices(predictions, labels, thresholds, scenario_name):
    """Tabulate confusion-matrix metrics for each decision threshold.

    For every threshold the scores are binarised (``score >= threshold``
    means "genuine"), the 2x2 confusion matrix is computed and plotted via
    ``plot_confusion_matrix``, and the derived metrics are recorded. The
    resulting table is printed and returned as a DataFrame with one row per
    threshold. Any ratio whose denominator is zero is reported as 0.
    """
    print(f"\n{scenario_name}")
    print(f"Total samples: {len(predictions)}")
    print(f"Genuine samples: {np.sum(labels == 1)}")
    print(f"Impostor samples: {np.sum(labels == 0)}")

    def _ratio(numerator, denominator):
        # Guarded division: 0 when there is nothing to divide by.
        return numerator / denominator if denominator > 0 else 0

    rows = []
    for cutoff in thresholds:
        decisions = (predictions >= cutoff).astype(int)

        # Rows are true classes, columns predicted classes, order [0, 1].
        matrix = confusion_matrix(labels, decisions, labels=[0, 1])
        plot_confusion_matrix(matrix, cutoff, scenario_name)

        tn, fp, fn, tp = matrix.ravel()

        accuracy = _ratio(tp + tn, tp + tn + fp + fn)
        precision = _ratio(tp, tp + fp)
        recall = _ratio(tp, tp + fn)
        specificity = _ratio(tn, tn + fp)
        f1 = _ratio(2 * (precision * recall), precision + recall)

        # False Accept Rate and False Reject Rate
        far = _ratio(fp, fp + tn)
        frr = _ratio(fn, fn + tp)

        row = {'Threshold': f"{cutoff:.2f}"}
        row.update(TP=tp, TN=tn, FP=fp, FN=fn)
        row['Accuracy'] = f"{accuracy:.3f}"
        row['Precision'] = f"{precision:.3f}"
        row['Recall'] = f"{recall:.3f}"
        row['Specificity'] = f"{specificity:.3f}"
        row['F1'] = f"{f1:.3f}"
        row['FAR'] = f"{far:.3f}"
        row['FRR'] = f"{frr:.3f}"
        rows.append(row)

    # Assemble the table and print it in full
    df = pd.DataFrame(rows)
    print(f"\nConfusion Matrix Results for {scenario_name}:")
    print("=" * 120)
    print(df.to_string(index=False))

    return df


# Usage example:
def run_comprehensive_validation():
    """Run ``comprehensive_validation`` on the notebook globals and save CSVs.

    Uses the notebook-level ``model``, ``device``, ``TARGET_SUBJECT`` and
    data lists, writes one ``validation_results_scenario_<X>.csv`` per
    scenario into the working directory, and returns the results dict.
    """
    scenario_tables = comprehensive_validation(
        verification_model=model,
        device=device,
        target_subject=TARGET_SUBJECT,
        train_data_main=train_data_main,
        train_data_reserved=train_data_reserved,
        test_data=test_data
    )

    # Persist each scenario's table next to the notebook
    for scenario_key in scenario_tables:
        csv_path = f"validation_results_scenario_{scenario_key}.csv"
        scenario_tables[scenario_key].to_csv(csv_path, index=False)
        print(f"\nSaved results for scenario {scenario_key} to {csv_path}")

    return scenario_tables

import matplotlib.pyplot as plt
import seaborn as sns

def plot_confusion_matrix(cm: np.ndarray, threshold: float, scenario_name: str,
                         save_dir: str = "cm_plots"):
    """Render a 2x2 confusion matrix as a heat-map PNG.

    The image is written to
    ``<save_dir>/<scenario_name without ':' and with '_' for spaces>_th<threshold>.png``
    (the directory is created if needed) and the figure is closed afterwards,
    so nothing is displayed inline.
    """
    import os

    os.makedirs(save_dir, exist_ok=True)

    class_names = ['Impostor (0)', 'Genuine (1)']
    file_stem = scenario_name.replace(' ', '_').replace(':', '')
    out_path = os.path.join(save_dir, f'{file_stem}_th{threshold:.2f}.png')

    fig, ax = plt.subplots(figsize=(4, 3))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar=False, ax=ax)
    ax.set_title(f'{scenario_name}\nThreshold = {threshold:.2f}')
    ax.set_ylabel('True')
    ax.set_xlabel('Predicted')
    fig.tight_layout()

    fig.savefig(out_path, dpi=150)
    # Close explicitly: this runs once per threshold and scenario.
    plt.close(fig)

# Run the validation for all scenarios; this also writes the per-scenario CSV
# files and the confusion-matrix PNGs as a side effect.
validation_results = run_comprehensive_validation()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Selection criteria for the sample to visualise. Every filename and
# diagnostic message below is derived from these values, so the saved PNGs
# can no longer be labelled with a different subject than the one plotted.
# NOTE(review): the filter previously selected subject 35 while comments,
# filenames and messages said 36; 35 is kept because that is what was
# actually plotted - change VIS_SUBJECT if 36 was intended.
VIS_SUBJECT = 35
VIS_TRIAL = 1
VIS_GESTURE = 1
VIS_SESSION = 1

target_samples = [d for d in data_vector
                  if d.subject == VIS_SUBJECT
                  and d.trial == VIS_TRIAL
                  and d.gesture == VIS_GESTURE
                  and d.session == VIS_SESSION]

print(f"Found {len(target_samples)} matching samples")

if len(target_samples) > 0:
    # Take the first matching sample
    sample = target_samples[0]

    print(f"\nSample details:")
    print(f"  ID: {sample.id}")
    print(f"  Subject: {sample.subject}")
    print(f"  Gesture: {sample.gesture}")
    print(f"  Trial: {sample.trial}")
    print(f"  Session: {sample.session}")
    print(f"  Image shape: {sample.image.shape}")  # (windows, channels, freq bins)

    # Read the layout from the data instead of hard-coding 5 windows / 512 bins.
    n_windows, _, n_freq_bins = sample.image.shape
    heatmap_extent = [0, n_freq_bins, len(ACTIVE_CHANNELS), 0]
    file_stem = (f'subject{sample.subject}_trial{sample.trial}'
                 f'_gesture{sample.gesture}_session{sample.session}')

    # Detailed view: one row per window, with channel tick labels and colorbar
    fig, axes = plt.subplots(n_windows, 1, figsize=(15, 20))
    axes = np.atleast_1d(axes)  # subplots returns a bare Axes when n_windows == 1
    fig.suptitle(f'Subject {sample.subject}, Gesture {sample.gesture}, Trial {sample.trial}, Session {sample.session}',
                 fontsize=16, fontweight='bold')

    for window_idx in range(n_windows):
        ax = axes[window_idx]

        # FFT image for this window (channels x freq bins)
        window_data = sample.image[window_idx, :, :]

        im = ax.imshow(window_data, aspect='auto', cmap='viridis',
                       extent=heatmap_extent)

        ax.set_title(f'Window {window_idx + 1}')
        ax.set_xlabel('Frequency Bin')
        ax.set_ylabel('Channel Index')
        # Centre one tick on each channel row
        ax.set_yticks(np.arange(len(ACTIVE_CHANNELS)) + 0.5)
        ax.set_yticklabels([f'Ch {ch}' for ch in ACTIVE_CHANNELS])

        plt.colorbar(im, ax=ax, label='Log Magnitude (Normalized)')

    plt.tight_layout()
    plt.savefig(f'{file_stem}.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Summary view: all windows side by side
    fig, axes = plt.subplots(1, n_windows, figsize=(25, 6))
    axes = np.atleast_1d(axes)
    fig.suptitle(f'All Windows - Subject {sample.subject}, Gesture {sample.gesture}, Trial {sample.trial}, Session {sample.session}',
                 fontsize=16, fontweight='bold')

    for window_idx in range(n_windows):
        ax = axes[window_idx]
        window_data = sample.image[window_idx, :, :]

        im = ax.imshow(window_data, aspect='auto', cmap='viridis',
                       extent=heatmap_extent)
        ax.set_title(f'Window {window_idx + 1}')
        ax.set_xlabel('Freq Bin')
        ax.set_ylabel('Channel')

    plt.tight_layout()
    plt.savefig(f'{file_stem}_summary.png', dpi=150, bbox_inches='tight')
    plt.show()

else:
    print("No matching samples found!")
    print(f"\nAvailable data for subject {VIS_SUBJECT}:")
    subject_data = [d for d in data_vector if d.subject == VIS_SUBJECT]
    if subject_data:
        for d in subject_data[:10]:  # Show first 10
            print(f"  Subject {d.subject}, Gesture {d.gesture}, Trial {d.trial}, Session {d.session}")
    else:
        print(f"  No data found for subject {VIS_SUBJECT}")


In [None]:
# ----------------------------------------------------------
# Grad-CAM – FULLY DETERMINISTIC (seed everything)
# ----------------------------------------------------------
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Prepare a model instance for Grad-CAM: pick the device, seed the RNGs,
# build the model in eval mode and choose which parameters track gradients.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- FORCE DETERMINISM ----
# Seeds torch and numpy; Python's `random` module is not seeded here.
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ---- 1. Load model ONCE ----
# NOTE(review): this rebinds the notebook-level `model`, replacing the
# instance trained in the earlier training cell. Whether the constructor
# restores the trained verification head from 'best_model.pth' cannot be
# seen here - confirm the heat-maps explain the intended weights.
model = BiometricVerificationModel(
    pretrained_model_path='best_model.pth',
    dropout_rate=0.3,
    device=device
).to(device)
model.eval()

# Explicitly set all dropout/batchnorm to eval (model.eval() above already
# does this) and stop gradient tracking for their parameters.
for m in model.modules():
    if isinstance(m, (torch.nn.Dropout, torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
        m.eval()
        for p in m.parameters():
            p.requires_grad = False

# Re-enable gradients for conv4 + bn4, the layer Grad-CAM hooks into.
# Parameters of other non-dropout/batchnorm layers keep whatever
# requires_grad setting the constructor gave them.
for p in model.conv4.parameters():
    p.requires_grad = True
for p in model.bn4.parameters():
    p.requires_grad = True

# ---- 2. GradCAM Class ----
class GradCAM:
    """Grad-CAM heat-map generator for one layer of a model.

    Hooks are registered on ``layer`` at construction time to capture its
    forward output and the gradient flowing back into that output. Call
    ``cleanup()`` when done so the hooks do not accumulate on the layer.

    Calling the instance with an input runs a forward and backward pass and
    returns a 2-D numpy heat-map for the first sample in the batch: the
    ReLU of the layer's activation maps weighted by their spatially
    averaged gradients. The model output must hold a single element so that
    ``backward()`` can be called on it directly.
    """

    def __init__(self, model, layer):
        self.model = model
        self.layer = layer
        self.acts = None   # layer output captured by the forward hook
        self.grads = None  # gradient w.r.t. the layer output (backward hook)
        self._f = layer.register_forward_hook(self._save_act)
        self._b = layer.register_full_backward_hook(self._save_grad)

    def _save_act(self, m, inp, out):
        self.acts = out.detach().clone()

    def _save_grad(self, m, gin, gout):
        self.grads = gout[0].detach().clone()

    def __call__(self, x):
        """Return the Grad-CAM map (2-D numpy array) for the first sample of ``x``.

        Raises:
            RuntimeError: If the hooks captured nothing, e.g. because the
                hooked layer is not used by the model's forward pass or
                ``cleanup()`` was already called.
        """
        self.model.zero_grad()
        # Reset so a stale capture from a previous call is never reused.
        self.acts = None
        self.grads = None

        # Force gradient tracking even if called inside torch.no_grad().
        with torch.set_grad_enabled(True):
            out = self.model(x)

        out.backward()

        if self.acts is None or self.grads is None:
            raise RuntimeError(
                "Grad-CAM hooks captured no activations/gradients; check that "
                "the hooked layer is used in the forward pass and that "
                "cleanup() has not been called"
            )

        g = self.grads[0]
        a = self.acts[0]
        # One weight per feature map: the gradient averaged over both
        # spatial axes.
        w = g.mean(dim=(1, 2))
        cam = torch.relu(torch.einsum('c,chw->hw', w, a))

        return cam.cpu().numpy()

    def cleanup(self):
        """Remove both hooks from the layer."""
        self._f.remove()
        self._b.remove()

# ---- 3. Helper function ----
def pick_sample(s, sess, t, g):
    """Return the first entry of ``data_vector`` matching subject ``s``,
    session ``sess``, trial ``t`` and gesture ``g``, or ``None`` if there is
    no such entry."""
    matches = (
        rec for rec in data_vector
        if rec.subject == s and rec.session == sess and rec.trial == t and rec.gesture == g
    )
    return next(matches, None)

# ---- 4. Main visualization loop ----
for subj in (35, 36, 37, 38):
    samp = pick_sample(subj, 1, 1, 1)
    if samp is None:
        print(f'⚠ No data for subject {subj}')
        continue

    # samp.image is indexed (axis 0, rows, columns); read the sizes from the
    # data instead of assuming 14 x 512.
    _, img_h, img_w = samp.image.shape
    print(f"Subject {subj} - Original shape: {samp.image.shape}")

    # Generate Grad-CAM ONCE for the full image.
    # Move axis 0 to the end to match the layout fed to the model elsewhere
    # in this notebook, then add a batch dimension.
    full_img_transposed = samp.image.transpose(1, 2, 0)
    full_img_tensor = torch.from_numpy(full_img_transposed.copy()).unsqueeze(0).float().to(device)
    full_img_tensor.requires_grad_(True)

    print(f"Full image tensor shape: {full_img_tensor.shape}")

    # Always remove the hooks, even if the forward/backward pass fails;
    # otherwise they pile up on model.conv4 across iterations and re-runs.
    gradcam = GradCAM(model, model.conv4)
    try:
        heatmap_full = gradcam(full_img_tensor)
    finally:
        gradcam.cleanup()

    # Resize the heat-map to the image's rows x columns.
    # cv2.resize takes the target size as (width, height).
    heatmap_full_resized = cv2.resize(heatmap_full, (img_w, img_h))

    # Split the column axis into equal windows; columns left over by the
    # integer division (e.g. 512 - 5 * 102 = 2) are not shown.
    num_windows = 5
    window_width = img_w // num_windows

    fig, axes = plt.subplots(1, num_windows, figsize=(20, 4))
    fig.suptitle(f'Subject {subj}', fontsize=14)

    for w, ax in enumerate(axes):
        start_col = w * window_width
        end_col = start_col + window_width

        # Same column range from the image and from the heat-map
        spec_window = samp.image[:, :, start_col:end_col]
        heatmap_window = heatmap_full_resized[:, start_col:end_col]

        # NOTE(review): these panels are slices of the LAST image axis
        # (documented elsewhere in the notebook as frequency bins), averaged
        # over axis 0 - the "Second N" title looks like a mislabel; confirm.
        spec_avg = spec_window.mean(axis=0)
        ax.imshow(spec_avg, cmap='gray', aspect='auto')
        ax.imshow(heatmap_window, cmap='Reds', aspect='auto', alpha=0.45)
        ax.set_title(f'Second {w + 1}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    print(f'✓ Subject {subj} completed\n')


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import confusion_matrix

# -------------------------------------------------
# 1.  Re-create the raw counts from the description
#     (hard-coded here, not computed from the model in this cell)
# -------------------------------------------------
TP = 3
TN = 105
FP = 0
FN = 4

# Rows = actual class, columns = predicted class, order [impostor, genuine]
cm = np.array([[TN, FP],
               [FN, TP]])

# -------------------------------------------------
# 2.  Row-normalise → per-class accuracy
#     (each row sums to 1; the 1e-9 avoids division by zero for an empty class)
# -------------------------------------------------
cm_norm = cm.astype('float') / (cm.sum(axis=1)[:, None] + 1e-9)

# -------------------------------------------------
# 3.  Build percentage + count labels
# -------------------------------------------------
labels = [f"{v:.1%}\n({c})" for v, c in zip(cm_norm.ravel(), cm.ravel())]
labels = np.asarray(labels).reshape(2, 2)

# -------------------------------------------------
# 4.  Plot
#     Cell colour encodes the row-normalised rate; the annotation shows the
#     rate and the raw count.
# -------------------------------------------------
plt.figure(figsize=(5.5, 4))
sns.heatmap(cm_norm,
            annot=labels,
            fmt='',
            cmap='Blues',
            vmin=0,
            vmax=1,
            xticklabels=['Predicted Impostor', 'Predicted Genuine'],
            yticklabels=['Actual Impostor', 'Actual Genuine'])
plt.title('Weighted Confusion Matrix')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.show()

# -------------------------------------------------
# 5.  Quick summary
#     FAR = FP / (TN + FP), FRR = FN / (TP + FN)
# -------------------------------------------------
print(f"Impostor Accuracy : {TN/(TN+FP):.1%}")
print(f"Genuine Accuracy  : {TP/(TP+FN):.1%}")
print(f"FAR               : {FP/(TN+FP):.4f}")
print(f"FRR               : {FN/(TP+FN):.4f}")


In [None]:
# ----------------------------------------------------------
# Grad-CAM – Different heatmap regions for each 1-second segment (Take 2)
# ----------------------------------------------------------
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2

# Prepare a model instance for Grad-CAM: pick the device, seed the RNGs,
# build the model in eval mode and choose which parameters track gradients.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- FORCE DETERMINISM ----
# Seeds torch and numpy; Python's `random` module is not seeded here.
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# ---- 1. Load model ONCE ----
# NOTE(review): this rebinds the notebook-level `model` again. Whether the
# constructor restores the trained verification head from 'best_model.pth'
# cannot be seen here - confirm the heat-maps explain the intended weights.
model = BiometricVerificationModel(
    pretrained_model_path='best_model.pth',
    dropout_rate=0.3,
    device=device
).to(device)
model.eval()

# Explicitly set all dropout/batchnorm to eval (model.eval() above already
# does this) and stop gradient tracking for their parameters.
for m in model.modules():
    if isinstance(m, (torch.nn.Dropout, torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)):
        m.eval()
        for p in m.parameters():
            p.requires_grad = False

# Re-enable gradients for conv4 + bn4, the layer Grad-CAM hooks into.
# Parameters of other non-dropout/batchnorm layers keep whatever
# requires_grad setting the constructor gave them.
for p in model.conv4.parameters():
    p.requires_grad = True
for p in model.bn4.parameters():
    p.requires_grad = True

# ---- 2. GradCAM Class ----
class GradCAM:
    """Grad-CAM heat-map generator for one layer of a model.

    Hooks are registered on ``layer`` at construction time to capture its
    forward output and the gradient flowing back into that output. Call
    ``cleanup()`` when done so the hooks do not accumulate on the layer.

    Calling the instance with an input runs a forward and backward pass and
    returns a 2-D numpy heat-map for the first sample in the batch: the
    ReLU of the layer's activation maps weighted by their spatially
    averaged gradients. The model output must hold a single element so that
    ``backward()`` can be called on it directly.
    """

    def __init__(self, model, layer):
        self.model = model
        self.layer = layer
        self.acts = None   # layer output captured by the forward hook
        self.grads = None  # gradient w.r.t. the layer output (backward hook)
        self._f = layer.register_forward_hook(self._save_act)
        self._b = layer.register_full_backward_hook(self._save_grad)

    def _save_act(self, m, inp, out):
        self.acts = out.detach().clone()

    def _save_grad(self, m, gin, gout):
        self.grads = gout[0].detach().clone()

    def __call__(self, x):
        """Return the Grad-CAM map (2-D numpy array) for the first sample of ``x``.

        Raises:
            RuntimeError: If the hooks captured nothing, e.g. because the
                hooked layer is not used by the model's forward pass or
                ``cleanup()`` was already called.
        """
        self.model.zero_grad()
        # Reset so a stale capture from a previous call is never reused.
        self.acts = None
        self.grads = None

        # Force gradient tracking even if called inside torch.no_grad().
        with torch.set_grad_enabled(True):
            out = self.model(x)

        out.backward()

        if self.acts is None or self.grads is None:
            raise RuntimeError(
                "Grad-CAM hooks captured no activations/gradients; check that "
                "the hooked layer is used in the forward pass and that "
                "cleanup() has not been called"
            )

        g = self.grads[0]
        a = self.acts[0]
        # One weight per feature map: the gradient averaged over both
        # spatial axes.
        w = g.mean(dim=(1, 2))
        cam = torch.relu(torch.einsum('c,chw->hw', w, a))

        return cam.cpu().numpy()

    def cleanup(self):
        """Remove both hooks from the layer."""
        self._f.remove()
        self._b.remove()

# ---- 3. Helper function ----
def pick_sample(s, sess, t, g):
    """Return the first entry of ``data_vector`` matching subject ``s``,
    session ``sess``, trial ``t`` and gesture ``g``, or ``None`` if there is
    no such entry."""
    matches = (
        rec for rec in data_vector
        if rec.subject == s and rec.session == sess and rec.trial == t and rec.gesture == g
    )
    return next(matches, None)

# ---- 4. Main visualization loop ----
for subj in (35, 36, 37, 38):
    samp = pick_sample(subj, 1, 1, 1)
    if samp is None:
        print(f'⚠ No data for subject {subj}')
        continue

    # samp.image is indexed (second, rows, columns); read the sizes from the
    # data instead of assuming 14 x 512.
    _, img_h, img_w = samp.image.shape
    print(f"Subject {subj} - samp.image shape: {samp.image.shape}")

    # Generate Grad-CAM ONCE for the full image (axis 0 moved to the end,
    # batch dimension added).
    full_img_transposed = samp.image.transpose(1, 2, 0)
    full_img_tensor = torch.from_numpy(full_img_transposed.copy()).unsqueeze(0).float().to(device)
    full_img_tensor.requires_grad_(True)

    # Always remove the hooks, even if the forward/backward pass fails;
    # otherwise they pile up on model.conv4 across iterations and re-runs.
    gradcam = GradCAM(model, model.conv4)
    try:
        heatmap_full = gradcam(full_img_tensor)
    finally:
        gradcam.cleanup()

    # Resize the heat-map to the image's rows x columns.
    # cv2.resize takes the target size as (width, height).
    heatmap_full_resized = cv2.resize(heatmap_full, (img_w, img_h))

    # Divide heatmap horizontally into 5 sections (one per panel); columns
    # left over by the integer division are not shown.
    num_windows = 5
    window_width = img_w // num_windows

    # Create figure: 1 row × 5 columns (one per second)
    fig, axes = plt.subplots(1, num_windows, figsize=(20, 4))
    fig.suptitle(f'Subject {subj}', fontsize=14)

    # Extent of a full-size image in imshow's default pixel coordinates; used
    # to stretch each heat-map slice over the whole spectrogram. Without it
    # the second imshow resets the axes to the slice's width and the overlay
    # only covers the first `window_width` columns of the spectrogram.
    full_extent = (-0.5, img_w - 0.5, img_h - 0.5, -0.5)

    for sec in range(num_windows):
        ax = axes[sec]

        # Axis 0 of samp.image is the 1-second segment
        spec_second = samp.image[sec, :, :]

        # Corresponding horizontal section of the heat-map.
        # NOTE(review): the heat-map's column axis comes from the image's
        # column axis, not from time, so pairing slice `sec` with second
        # `sec` may not be meaningful - confirm the intended mapping.
        start_col = sec * window_width
        end_col = start_col + window_width
        heatmap_window = heatmap_full_resized[:, start_col:end_col]

        # Visualize this second's spectrogram with its heatmap overlay
        ax.imshow(spec_second, cmap='gray', aspect='auto')
        ax.imshow(heatmap_window, cmap='Reds', aspect='auto', alpha=0.45,
                  extent=full_extent)
        ax.set_title(f'Second {sec + 1}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    print(f'✓ Subject {subj} completed\n')


In [None]:
import matplotlib.pyplot as plt

# Bar chart comparing the F1 score of the evaluated model families.
# The values are hard-coded (in percent), not computed in this notebook cell.
models = ['XGBoost', 'Random Forest', 'Autoencoder\n+ MLP', 'CNN']
f1 = [55.16, 55.01, 59.7, 82.1]

plt.figure(figsize=(6, 4))
bars = plt.bar(models, f1, color=['#4c72b0', '#55a868', '#c44e52', '#8172b2'])

# label each bar with its value, centred just above the bar top
for b, v in zip(bars, f1):
    plt.text(b.get_x() + b.get_width()/2, b.get_height() + 1,
             f'{v:.1f}%', ha='center', va='bottom')

plt.ylabel('F1 score (%)')
plt.title('Model comparison')
plt.ylim(0, 100)
plt.tight_layout()
plt.show()
# plt.savefig('f1_bars.png', dpi=300)
