In [1]:
# Install the required libraries
!pip install mne wandb edfio pandas -q

In [2]:
import mne
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import numpy as np
import pandas as pd
from mne.io import read_raw_edf, RawArray
from mne.annotations import read_annotations
from torch.utils.data import Dataset, DataLoader
import random
import torch.optim as optim
from tqdm import tqdm
import wandb
from datetime import datetime
import warnings
import shutil

# Suppress RuntimeWarning messages for the remainder of the notebook.
warnings.filterwarnings("ignore", category=RuntimeWarning)

# --- Data locations -------------------------------------------------------
DATA_DIRS = [
    "/kaggle/input/sleep-edf-and-apnea/sleep-edf-database-expanded-1.0.0/sleep-edf-database-expanded-1.0.0/sleep-cassette/",
    "/kaggle/input/sleep-edf-and-apnea/sleep-edf-database-expanded-1.0.0/sleep-edf-database-expanded-1.0.0/sleep-telemetry/",
]
EXPORT_DIR = "/kaggle/working/unseen_test_data"

# --- Signal geometry ------------------------------------------------------
EPOCH_SEC = 30                            # length of one scored window, in seconds
SAMPLE_RATE = 100                         # sampling rate every recording is resampled to (Hz)
EPOCH_LENGTH = SAMPLE_RATE * EPOCH_SEC    # samples per window

# --- Channel selection (order defines the channel axis of the model input) -
CHANNELS_TO_TRAIN = ['EEG Fpz-Cz', 'EOG horizontal', 'EMG submental']
CHANNELS_TO_EXPORT = list(CHANNELS_TO_TRAIN)

# --- Reproducibility ------------------------------------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Seed every visible GPU as well.
    torch.cuda.manual_seed_all(SEED)
print("Using device:", device)

# Annotation text -> integer class. Stages 3 and 4 share class 3; entries mapped
# to -1 are filtered out before epochs are built.
stage_map = {
    "Sleep stage W": 0,
    "Sleep stage 1": 1,
    "Sleep stage 2": 2,
    "Sleep stage 3": 3,
    "Sleep stage 4": 3,
    "Sleep stage R": 4,
    "Sleep stage ?": -1,
    "Movement time": -1,
}

# Authenticate with Weights & Biases using a key supplied at runtime through the
# WANDB_API_KEY environment variable instead of a literal stored in the notebook.
# When the variable is unset, key=None leaves credential discovery to wandb itself.
# NOTE(review): the API key that used to be hardcoded on this line has been exposed
# to everyone with access to the notebook and should be revoked/rotated.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

Using device: cpu


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/codespace/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnamthse182380[0m ([33mnamthse182380-fpt-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
class DeepSleepNetBranch(nn.Module):
    """Single-channel encoder: two parallel 1-D CNN stacks followed by a 2-layer BiLSTM.

    ``conv1`` starts with a 50-sample kernel (stride 6) and ``conv2`` with a
    400-sample kernel (stride 50). Their outputs are cropped to a common length,
    summed, and passed through the LSTM; the LSTM output at the last time step
    (2 * 512 = 1024 features) is returned as the embedding.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = self._build_cnn(
            first_kernel=50, first_stride=6, first_padding=22,
            inner_kernel=8, inner_padding=4, final_pool=4,
        )
        self.conv2 = self._build_cnn(
            first_kernel=400, first_stride=50, first_padding=200,
            inner_kernel=6, inner_padding=3, final_pool=2,
        )
        self.lstm = nn.LSTM(
            input_size=128,
            hidden_size=512,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )

    @staticmethod
    def _build_cnn(first_kernel, first_stride, first_padding, inner_kernel, inner_padding, final_pool):
        """Assemble one CNN stack: strided stem conv, pool, three inner convs, final pool."""
        layers = [
            nn.Conv1d(1, 64, kernel_size=first_kernel, stride=first_stride, padding=first_padding),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),
            nn.Dropout(0.5),
        ]
        in_channels = 64
        for _ in range(3):
            layers.extend([
                nn.Conv1d(in_channels, 128, kernel_size=inner_kernel, padding=inner_padding),
                nn.BatchNorm1d(128),
                nn.ReLU(),
            ])
            in_channels = 128
        layers.extend([
            nn.MaxPool1d(kernel_size=final_pool, stride=final_pool),
            nn.Dropout(0.5),
        ])
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` (the CNN stems expect exactly one input channel) into an embedding."""
        short_kernel_feats = self.conv1(x)
        long_kernel_feats = self.conv2(x)
        # The two stacks produce different temporal lengths; crop both to the shorter one.
        steps = min(short_kernel_feats.shape[2], long_kernel_feats.shape[2])
        merged = short_kernel_feats[:, :, :steps] + long_kernel_feats[:, :, :steps]
        # (batch, features, time) -> (batch, time, features) for the batch_first LSTM.
        sequence_out, _ = self.lstm(merged.permute(0, 2, 1))
        # NOTE(review): with bidirectional=True the reverse-direction half of this
        # last time step has only processed a single step; confirm this is intended.
        return sequence_out[:, -1, :]

class MultiChannelDeepSleepNet(nn.Module):
    """Three-branch classifier: one ``DeepSleepNetBranch`` per input channel, fused by concatenation.

    Input channel 0 feeds ``eeg_branch``, channel 1 ``eog_branch`` and channel 2
    ``emg_branch``. The three embeddings are concatenated (``fusion_dim`` = 3 * 1024)
    and mapped to ``num_classes`` logits by a small MLP.
    """

    def __init__(self, num_classes=5):
        super().__init__()
        self.eeg_branch = DeepSleepNetBranch()
        self.eog_branch = DeepSleepNetBranch()
        self.emg_branch = DeepSleepNetBranch()
        self.fusion_dim = 1024 * 3
        hidden_units = 512
        self.classifier = nn.Sequential(
            nn.Linear(self.fusion_dim, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return class logits for ``x``, whose dim 1 is indexed as EEG, EOG, EMG."""
        branches = (self.eeg_branch, self.eog_branch, self.emg_branch)
        # Slice with idx:idx+1 so each branch still receives a channel dimension of size 1.
        embeddings = [branch(x[:, idx:idx + 1, :]) for idx, branch in enumerate(branches)]
        return self.classifier(torch.cat(embeddings, dim=1))

class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a ``(batch, seq, d_model)`` tensor.

    Args:
        d_model: Size of the feature dimension. Both even and odd sizes are supported.
        max_len: Number of positions precomputed; inputs passed to ``forward`` must
            not have more than ``max_len`` steps along dim 1.

    The table is stored as the non-trainable buffer ``pe`` of shape
    ``(1, max_len, d_model)``: even feature columns hold sines, odd columns cosines.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so trim the
        # frequency axis to fit; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions."""
        x = x + self.pe[:, :x.size(1)]
        return x

class TransformerBranch(nn.Module):
    """Single-channel encoder that splits a signal into patches and runs a Transformer over them.

    The signal is cut into non-overlapping patches of ``patch_size`` samples, each
    patch is linearly projected to ``d_model`` features, sinusoidal positions are
    added, and the mean of the encoder output over patches is returned.

    Raises:
        ValueError: If ``input_size`` is not a multiple of ``patch_size``.
    """

    def __init__(self, input_size=3000, patch_size=30, d_model=128, num_heads=8, num_layers=4):
        super().__init__()
        self.patch_size = patch_size
        self.d_model = d_model
        if input_size % patch_size:
            raise ValueError(f"input_size ({input_size}) must be divisible by patch_size ({patch_size}).")
        self.n_patches = input_size // patch_size

        self.embedding = nn.Linear(patch_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model=d_model, max_len=self.n_patches)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=256,
            dropout=0.1,
            activation='relu',
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, x):
        """Encode ``x`` into one ``d_model``-sized vector per batch element."""
        # unfold on dim 2 appends a patch axis; squeeze(1) then drops dim 1 when it has size 1.
        patches = x.unfold(dimension=2, size=self.patch_size, step=self.patch_size).squeeze(1)
        tokens = self.pos_encoder(self.embedding(patches))
        encoded = self.transformer_encoder(tokens)
        return encoded.mean(dim=1)

class MultiChannelSleepTransformer(nn.Module):
    """Three-branch classifier: one ``TransformerBranch`` per input channel, fused by concatenation.

    Input channel 0 feeds ``eeg_branch``, channel 1 ``eog_branch`` and channel 2
    ``emg_branch``. The concatenated embeddings (``fusion_dim`` = 3 * ``d_model``)
    are layer-normalised and mapped to ``num_classes`` logits by a small MLP.
    """

    def __init__(self, input_size=3000, patch_size=30, d_model=128, num_heads=8, num_layers=4, num_classes=5):
        super().__init__()
        branch_args = (input_size, patch_size, d_model, num_heads, num_layers)
        self.eeg_branch = TransformerBranch(*branch_args)
        self.eog_branch = TransformerBranch(*branch_args)
        self.emg_branch = TransformerBranch(*branch_args)
        self.fusion_dim = 3 * d_model
        hidden_units = 256
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.fusion_dim),
            nn.Linear(self.fusion_dim, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        """Return class logits for ``x``, whose dim 1 is indexed as EEG, EOG, EMG."""
        branches = (self.eeg_branch, self.eog_branch, self.emg_branch)
        # Slice with idx:idx+1 so each branch still receives a channel dimension of size 1.
        embeddings = [branch(x[:, idx:idx + 1, :]) for idx, branch in enumerate(branches)]
        return self.classifier(torch.cat(embeddings, dim=1))

In [4]:
def get_hypnogram_path(data_dir, subject_id):
    """Locate the hypnogram file that belongs to a recording.

    Hypnogram files are named ``<subject_id><X>-Hypnogram.edf`` where ``<X>`` is a
    single character. Sleep-EDF uses several different letters in that position,
    so looking only for ``C`` and ``P`` silently drops valid recordings.

    Args:
        data_dir: Directory that holds the EDF files.
        subject_id: Recording prefix (the first 7 characters of the PSG file name).

    Returns:
        The path of the matching hypnogram, or ``None`` when there is none (including
        when ``data_dir`` is not a directory). ``C`` and ``P`` are tried first, in that
        order; otherwise the alphabetically first match is returned.
    """
    suffix = "-Hypnogram.edf"

    # Preferred candidates first, preserving the historical C-then-P precedence.
    for scorer in ("C", "P"):
        candidate = os.path.join(data_dir, f"{subject_id}{scorer}{suffix}")
        if os.path.exists(candidate):
            return candidate

    if not os.path.isdir(data_dir):
        return None

    # Fall back to any single-character code between the prefix and the suffix.
    expected_len = len(subject_id) + 1 + len(suffix)
    matches = sorted(
        name for name in os.listdir(data_dir)
        if len(name) == expected_len and name.startswith(subject_id) and name.endswith(suffix)
    )
    if matches:
        return os.path.join(data_dir, matches[0])
    return None

class SleepDataset(Dataset):
    """In-memory dataset of per-epoch z-scored PSG windows and their stage labels.

    Every ``(recording_id, data_dir)`` pair in ``subject_infos`` is loaded eagerly in
    the constructor; recordings that are missing files, lack a training channel or
    fail to parse are skipped. ``data`` and ``labels`` hold the concatenation over
    all recordings (both are empty tensors when nothing could be loaded).
    """

    def __init__(self, subject_infos, augment=False):
        self.subject_infos = subject_infos
        self.augment = augment
        data_chunks, label_chunks = [], []

        for sid, data_dir in tqdm(self.subject_infos, desc="Loading subjects"):
            psg_path = os.path.join(data_dir, f"{sid}0-PSG.edf")
            hyp_path = get_hypnogram_path(data_dir, sid)

            if not (os.path.exists(psg_path) and hyp_path):
                print(f"Warning: File pair not found for subject {sid} in {data_dir}. Skipping.")
                continue

            try:
                subject_data, subject_labels = self.extract_epochs(
                    psg_path, hyp_path, channels=CHANNELS_TO_TRAIN
                )
                has_epochs = subject_data is not None and subject_data.nelement() > 0
                if has_epochs:
                    data_chunks.append(subject_data)
                    label_chunks.append(subject_labels)
            except Exception as err:
                print(f"[ERROR] Failed for subject {sid}: {err}")
                continue

        if data_chunks:
            self.data = torch.cat(data_chunks, dim=0)
            self.labels = torch.cat(label_chunks, dim=0)
        else:
            self.data = torch.tensor([])
            self.labels = torch.tensor([])

    def extract_epochs(self, psg_path, hyp_path, channels):
        """Cut one recording into labelled windows of ``EPOCH_SEC`` seconds.

        Returns:
            ``(data, labels)`` as a float tensor and a long tensor, or
            ``(None, None)`` when a requested channel is absent or no usable stage
            annotation is found.
        """
        raw = read_raw_edf(psg_path, preload=True, verbose=False)
        annot = read_annotations(hyp_path)

        missing_channels = [name for name in channels if name not in raw.ch_names]
        if missing_channels:
            return None, None

        raw.pick(list(channels))
        raw.reorder_channels(channels)
        raw.resample(sfreq=SAMPLE_RATE, verbose=False)
        raw.set_annotations(annot, emit_warning=False)

        # Only ask MNE for the stage descriptions this file actually contains.
        descriptions_in_file = set(annot.description)
        file_event_id = {
            name: code for name, code in stage_map.items() if name in descriptions_in_file
        }
        if not file_event_id:
            return None, None

        try:
            events, _ = mne.events_from_annotations(
                raw, event_id=file_event_id, chunk_duration=EPOCH_SEC, verbose=False
            )
        except ValueError:
            return None, None

        # Negative codes mark annotations that must not become training samples.
        events = events[events[:, -1] >= 0]
        if len(events) == 0:
            return None, None

        clean_event_id = {name: code for name, code in file_event_id.items() if code >= 0}
        if not clean_event_id:
            return None, None

        # Stop one sample short of EPOCH_SEC so each window spans EPOCH_SEC * sfreq samples.
        tmax = EPOCH_SEC - 1. / raw.info['sfreq']
        epochs = mne.Epochs(
            raw=raw, events=events, event_id=clean_event_id, tmin=0., tmax=tmax,
            proj=False, baseline=None, preload=True, verbose=False,
        )
        data = epochs.get_data(copy=False)
        # Z-score along the last axis, separately for every epoch and channel.
        centered = data - data.mean(axis=2, keepdims=True)
        scale = data.std(axis=2, keepdims=True) + 1e-8
        data = centered / scale
        labels = epochs.events[:, -1]
        return torch.from_numpy(data).float(), torch.from_numpy(labels).long()

    def __len__(self):
        if self.labels is None:
            return 0
        return self.labels.shape[0]

    def __getitem__(self, idx):
        sample, target = self.data[idx], self.labels[idx]
        if self.augment:
            sample = self.apply_augmentation(sample)
        return sample, target

    def apply_augmentation(self, x):
        """Randomly add Gaussian noise and/or circularly shift ``x`` along its last axis.

        Each of the two transforms is applied independently with probability 0.5.
        """
        if random.random() > 0.5:
            sigma = random.uniform(0.01, 0.1)
            x = x + torch.randn_like(x) * sigma
        if random.random() > 0.5:
            offset = random.randint(-150, 150)
            x = torch.roll(x, shifts=offset, dims=-1)
        return x

def split_subjects(data_dirs, seed=SEED):
    """Split all usable recordings into train / val / test / unseen sets, by person.

    A recording is usable when its ``*PSG.edf`` file has a hypnogram next to it.
    Sleep-EDF recording ids look like ``SC4ssNE`` / ``ST7ssNJ`` (``ss`` = subject,
    ``N`` = night), so the same person usually contributes several recordings.
    Recordings are therefore grouped on the first five characters of their id
    (per directory) and whole groups are assigned to a split; splitting individual
    recordings would let one person's nights leak from training into the held-out sets.

    Args:
        data_dirs: Directories to scan; missing ones are reported and skipped.
        seed: Seed for the shuffle. Note that this reseeds the global ``random`` module.

    Returns:
        ``(train_info, val_info, test_info, unseen_info)``, each a list of
        ``(recording_id, data_dir)`` tuples. Roughly 20% of the subjects go to
        ``unseen``, 10% to ``val``, 10% to ``test`` and the rest to ``train``.
    """
    recordings = set()
    for data_dir in data_dirs:
        if not os.path.exists(data_dir):
            print(f"Warning: Directory not found, skipping: {data_dir}")
            continue
        for psg_file in sorted(os.listdir(data_dir)):
            if not psg_file.endswith("PSG.edf"):
                continue
            recording_id = psg_file[:7]
            if get_hypnogram_path(data_dir, recording_id):
                recordings.add((recording_id, data_dir))

    # Group recordings per person so that a person never straddles two splits.
    recordings_by_subject = {}
    for recording_id, data_dir in sorted(recordings):
        subject_key = (data_dir, recording_id[:5])
        recordings_by_subject.setdefault(subject_key, []).append((recording_id, data_dir))

    # Sort before shuffling so the result depends only on the seed, not on dict order.
    subject_keys = sorted(recordings_by_subject)
    print(
        f"Found {len(recordings)} valid recordings from {len(subject_keys)} subjects "
        f"across all directories."
    )

    random.seed(seed)
    random.shuffle(subject_keys)

    n = len(subject_keys)
    n_unseen = int(n * 0.2)
    n_val = int(n * 0.1)
    n_test = int(n * 0.1)

    def _recordings_of(keys):
        # Flatten a list of subject keys back into (recording_id, data_dir) tuples.
        return [recording for key in keys for recording in recordings_by_subject[key]]

    unseen_info = _recordings_of(subject_keys[:n_unseen])
    val_info = _recordings_of(subject_keys[n_unseen : n_unseen + n_val])
    test_info = _recordings_of(subject_keys[n_unseen + n_val : n_unseen + n_val + n_test])
    train_info = _recordings_of(subject_keys[n_unseen + n_val + n_test :])

    print(f"Split: Train={len(train_info)}, Val={len(val_info)}, Test={len(test_info)}, Unseen for App={len(unseen_info)}")
    return train_info, val_info, test_info, unseen_info
    
def get_dataloaders(train_infos, val_infos, batch_size=64, augment_train=True):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        train_infos: ``(recording_id, data_dir)`` pairs for training.
        val_infos: ``(recording_id, data_dir)`` pairs for validation.
        batch_size: Batch size used by both loaders.
        augment_train: Whether the training dataset applies augmentation.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles and drops the
        last incomplete batch; the validation loader does neither.

    Raises:
        ValueError: If either dataset ends up with zero samples.
    """
    print("\n--- Loading Training Data ---")
    training_data = SleepDataset(train_infos, augment=augment_train)
    print("\n--- Loading Validation Data ---")
    validation_data = SleepDataset(val_infos, augment=False)

    if 0 in (len(training_data), len(validation_data)):
        raise ValueError("Training or validation dataset is empty.")

    # Settings shared by both loaders.
    common = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
    training_loader = DataLoader(training_data, shuffle=True, drop_last=True, **common)
    validation_loader = DataLoader(validation_data, shuffle=False, **common)
    return training_loader, validation_loader

In [5]:
def export_unseen_data(subject_infos, export_dir, channels_to_export, epochs_per_stage=20, max_total_epochs=100):
    """Export a stage-balanced sample of single-epoch EDF files from held-out recordings.

    The function (1) scans every recording to list its scorable epochs, (2) draws at
    most ``epochs_per_stage`` epochs per sleep stage (capped at ``max_total_epochs``
    overall), and (3) writes each sampled epoch to
    ``<export_dir>/<recording>_epoch_<NNN>_label_<stage>.edf``.

    Args:
        subject_infos: ``(recording_id, data_dir)`` pairs to sample from.
        export_dir: Output directory. It is DELETED and recreated on every call.
        channels_to_export: Channel names written to each file, in this order.
        epochs_per_stage: Maximum number of epochs sampled per stage.
        max_total_epochs: Upper bound on the total number of exported epochs.

    Returns:
        None. Progress and problems are reported on stdout; a recording or epoch that
        cannot be exported is skipped with a warning.
    """
    if os.path.exists(export_dir):
        shutil.rmtree(export_dir)
    os.makedirs(export_dir, exist_ok=True)
    print(f"\n--- Exporting a BALANCED unseen dataset to '{export_dir}' ---")

    label_to_stage_map = {0: "W", 1: "N1", 2: "N2", 3: "N3", 4: "REM"}
    all_possible_epochs = []

    print("Step 1: Scanning all unseen subjects to gather epoch information...")
    for sid, data_dir in tqdm(subject_infos, desc="Scanning Subjects"):
        psg_path = os.path.join(data_dir, f"{sid}0-PSG.edf")
        hyp_path = get_hypnogram_path(data_dir, sid)
        if not os.path.exists(psg_path) or not hyp_path: continue

        try:
            raw = mne.io.read_raw_edf(psg_path, preload=False, verbose=False)

            # Reject recordings that cannot be exported later, so they do not take
            # up slots in the balanced sample and then fail in step 3.
            missing_channels = [ch for ch in channels_to_export if ch not in raw.ch_names]
            if missing_channels:
                print(f"Warning: {sid} lacks channels {missing_channels}. Skipping.")
                continue

            annot = mne.read_annotations(hyp_path)
            raw.set_annotations(annot, emit_warning=False, verbose=False)

            present_annotations = set(annot.description)
            file_event_id = {key: val for key, val in stage_map.items() if key in present_annotations and val >= 0}
            if not file_event_id: continue

            events, _ = mne.events_from_annotations(raw, event_id=file_event_id, chunk_duration=EPOCH_SEC, verbose=False)

            for event_idx, event in enumerate(events):
                all_possible_epochs.append({
                    'sid': sid, 'data_dir': data_dir, 'event_idx': event_idx, 'label': int(event[-1])
                })
        except Exception as e:
            print(f"Warning: Could not process {sid}. Error: {e}")

    if not all_possible_epochs:
        print("Error: No valid epochs found in the unseen set. Halting export.")
        return

    print(f"\nStep 2: Found {len(all_possible_epochs)} total epochs. Creating a balanced sample...")
    df = pd.DataFrame(all_possible_epochs)

    # Sample each stage separately and concatenate. This keeps the 'label' column on
    # every pandas version, unlike groupby(...).apply(...), whose handling of the
    # grouping column is deprecated.
    balanced_sample_df = pd.concat([
        stage_df.sample(min(len(stage_df), epochs_per_stage), random_state=SEED)
        for _, stage_df in df.groupby('label')
    ])

    if len(balanced_sample_df) > max_total_epochs:
        balanced_sample_df = balanced_sample_df.sample(max_total_epochs, random_state=SEED)

    print(f"Sampled {len(balanced_sample_df)} epochs to export. Stage distribution:\n{balanced_sample_df['label'].map(label_to_stage_map).value_counts()}")

    print("\nStep 3: Exporting the balanced sample...")
    exported_count = 0

    for sid, group_df in tqdm(balanced_sample_df.groupby('sid'), desc="Exporting Subjects"):
        data_dir = group_df['data_dir'].iloc[0]
        psg_path = os.path.join(data_dir, f"{sid}0-PSG.edf")
        hyp_path = get_hypnogram_path(data_dir, sid)

        try:
            raw = mne.io.read_raw_edf(psg_path, preload=True, verbose=False)
            annot = mne.read_annotations(hyp_path)

            raw.pick(channels_to_export)
            raw.reorder_channels(channels_to_export)
            raw.resample(sfreq=SAMPLE_RATE, verbose=False)
            raw.set_annotations(annot, emit_warning=False, verbose=False)

            present_annotations = set(annot.description)
            file_event_id = {key: val for key, val in stage_map.items() if key in present_annotations and val >= 0}
            events, _ = mne.events_from_annotations(raw, event_id=file_event_id, chunk_duration=EPOCH_SEC, verbose=False)

            tmax = EPOCH_SEC - 1. / raw.info['sfreq']
            epochs = mne.Epochs(raw=raw, events=events, event_id=file_event_id, tmin=0., tmax=tmax, baseline=None, preload=True, verbose=False)

            for row in group_df.itertuples(index=False):
                # The index was recorded in step 1; make sure it still addresses an event.
                if row.event_idx >= len(events):
                    print(f"Warning: Event {row.event_idx} of {sid} no longer exists. Skipping epoch.")
                    continue

                # Epochs may have dropped some events, so look the epoch up by its
                # onset sample and skip only that epoch when it is gone.
                event_sample = events[row.event_idx][0]
                matches = np.flatnonzero(epochs.events[:, 0] == event_sample)
                if matches.size == 0:
                    print(f"Warning: Epoch {row.event_idx} of {sid} was dropped by MNE. Skipping epoch.")
                    continue

                epoch_data = epochs[int(matches[0])].get_data(copy=True)[0]

                stage_label_str = label_to_stage_map.get(row.label, "UNKNOWN")
                info = mne.create_info(ch_names=channels_to_export, sfreq=SAMPLE_RATE, ch_types='eeg')
                raw_epoch = RawArray(epoch_data.reshape(len(channels_to_export), -1), info, verbose=False)

                export_filename = f"{sid}_epoch_{exported_count:03d}_label_{stage_label_str}.edf"
                export_filepath = os.path.join(export_dir, export_filename)
                raw_epoch.export(export_filepath, fmt='edf', overwrite=True, verbose=False)
                exported_count += 1
        except Exception as e:
            print(f"Warning: Skipped exporting subject {sid}. Error: {e}")

    print(f"✅ Exported {exported_count} balanced epoch files to '{export_dir}'")

In [6]:
def train(model, train_loader, val_loader, device, num_epochs=50, lr=1e-4):
    """Train ``model`` with AdamW, logging to Weights & Biases and checkpointing the best epoch.

    After every epoch the model is scored on ``val_loader``; the learning rate is
    halved when validation accuracy has not improved for more than 5 epochs, and the
    weights are saved to ``<ModelClass>_<timestamp>_best.pt`` whenever validation
    accuracy reaches a new maximum.

    Args:
        model: Network to train; moved to ``device`` in place.
        train_loader: Loader yielding ``(inputs, integer_labels)`` training batches.
        val_loader: Loader used for validation.
        device: Device the batches and the model are moved to.
        num_epochs: Number of passes over ``train_loader``.
        lr: Initial learning rate.

    Raises:
        ValueError: If ``train_loader`` yields no batches in an epoch.
    """
    model_name = model.__class__.__name__
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    run_name = f"{model_name}_{timestamp}"
    wandb.init(
        project="sleep-stage-classification-multichannel", name=run_name,
        config={
            "model": model_name, "epochs": num_epochs, "learning_rate": lr,
            "batch_size": train_loader.batch_size, "optimizer": "AdamW",
            "scheduler": "ReduceLROnPlateau", "architecture": str(model),
        }
    )
    # Always close the W&B run, even if training fails halfway.
    try:
        model.to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-2)
        # The deprecated `verbose` flag is not used; learning-rate changes are
        # reported explicitly after each scheduler step instead.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)
        best_val_acc = 0
        for epoch in range(num_epochs):
            model.train()
            train_loss, train_correct, train_samples = 0.0, 0, 0
            loop = tqdm(train_loader, desc=f"Train {model_name} Epoch {epoch+1}/{num_epochs}", leave=True)
            for x, y in loop:
                x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
                optimizer.zero_grad()
                outputs = model(x)
                loss = criterion(outputs, y)
                loss.backward()
                optimizer.step()
                batch_samples = x.size(0)
                train_loss += loss.item() * batch_samples
                preds = outputs.argmax(dim=1)
                train_correct += (preds == y).sum().item()
                train_samples += batch_samples

            # Normalise by the samples actually processed: a loader built with
            # drop_last=True sees fewer samples than len(dataset), which would
            # otherwise bias both metrics downwards.
            if train_samples == 0:
                raise ValueError("Training loader yielded no batches; check the dataset size against batch_size.")
            train_loss /= train_samples
            train_acc = train_correct / train_samples

            val_loss, val_acc = evaluate(model, val_loader, criterion, device)

            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_acc)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after != lr_before:
                print(f"Epoch {epoch+1}: reducing learning rate from {lr_before:.3e} to {lr_after:.3e}")

            print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Acc={train_acc:.4f} | Val Loss={val_loss:.4f}, Acc={val_acc:.4f}")
            wandb.log({
                "epoch": epoch + 1, "train_loss": train_loss, "train_acc": train_acc,
                "val_loss": val_loss, "val_acc": val_acc, "lr": lr_after
            })
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                model_save_path = f"{run_name}_best.pt"
                torch.save(model.state_dict(), model_save_path)
                print(f"Saved best model to {model_save_path} with Val Acc: {best_val_acc:.4f}")
                wandb.run.summary["best_val_acc"] = best_val_acc
                wandb.save(model_save_path)
    finally:
        wandb.finish()
    print(f"Training for {model_name} finished.")
    
def evaluate(model, loader, criterion, device):
    """Compute the sample-weighted mean loss and the accuracy of ``model`` over ``loader``.

    The model is switched to eval mode (and left in it) and gradients are disabled.
    Both metrics are normalised by the number of samples actually yielded by the
    loader, so they stay correct when the loader drops an incomplete last batch or
    uses a sampler that does not cover the whole dataset.

    Args:
        model: Network to evaluate.
        loader: Iterable of ``(inputs, integer_labels)`` batches.
        criterion: Loss returning the batch mean, e.g. ``nn.CrossEntropyLoss()``.
        device: Device the batches are moved to.

    Returns:
        ``(loss_avg, acc)`` as Python floats.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.eval()
    loss_total, correct, total_samples = 0.0, 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            outputs = model(x)
            loss = criterion(outputs, y)
            batch_samples = x.size(0)
            # The criterion returns a batch mean; weight it by the batch size so the
            # final average is per sample even when the last batch is smaller.
            loss_total += loss.item() * batch_samples
            preds = outputs.argmax(dim=1)
            correct += (preds == y).sum().item()
            total_samples += batch_samples
    if total_samples == 0:
        raise ValueError("evaluate() received a loader that yielded no samples.")
    loss_avg = loss_total / total_samples
    acc = correct / total_samples
    return loss_avg, acc

In [None]:
import traceback  # used only below, to report where an unexpected failure happened

# End-to-end driver: split the recordings, export the held-out sample for the app,
# build the loaders, then train both architectures on the same data.
try:
    print("--- 1. Splitting Subjects from All Available Directories ---")
    # NOTE(review): test_info is produced here but not used anywhere in this cell.
    train_info, val_info, test_info, unseen_info = split_subjects(data_dirs=DATA_DIRS)

    print("\n--- 2. Exporting Unseen Data for App ---")
    export_unseen_data(
        subject_infos=unseen_info,
        export_dir=EXPORT_DIR,
        channels_to_export=CHANNELS_TO_EXPORT
    )

    print("\n--- 2b. Zipping the unseen data directory ---")
    if os.path.exists(EXPORT_DIR) and os.listdir(EXPORT_DIR):
        # Writes EXPORT_DIR + ".zip" next to the directory.
        shutil.make_archive(
            base_name=EXPORT_DIR,
            format='zip',
            root_dir=EXPORT_DIR
        )
        print(f"Successfully created zip file: {EXPORT_DIR}.zip")
    else:
        print("Export directory is empty or does not exist. Skipping zip creation.")

    print("\n--- 3. Preparing Dataloaders ---")
    BATCH_SIZE = 128
    train_loader, val_loader = get_dataloaders(
        train_infos=train_info,
        val_infos=val_info,
        batch_size=BATCH_SIZE
    )

    print("\n--- 4a. Training MultiChannelDeepSleepNet ---")
    model_deepsleepnet = MultiChannelDeepSleepNet(num_classes=5)
    train(model_deepsleepnet, train_loader, val_loader, device=device, num_epochs=50)

    print("\n--- 4b. Training MultiChannelSleepTransformer ---")
    model_sleeptransformer = MultiChannelSleepTransformer(
        input_size=EPOCH_LENGTH, patch_size=30, d_model=128, num_heads=8, num_layers=6, num_classes=5
    )
    train(model_sleeptransformer, train_loader, val_loader, device=device, num_epochs=50)

    print("\n--- All training and exporting processes finished successfully! ---")
    print("Remember to 'Save Version' -> 'Save & Run All (Commit)' to persist the exported data.")
    print(f"The zipped data file can be found at: '{EXPORT_DIR}.zip'")


except ValueError as e:
    print(f"\nExecution stopped due to a data error: {e}")
except Exception as e:
    print(f"\nAn unexpected error occurred: {e}")
    # The message alone rarely identifies the failing call; keep the cell
    # non-fatal but show the full traceback.
    traceback.print_exc()