<a href="https://colab.research.google.com/github/kotosham/sleep-stages/blob/main/sleep_phase_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import random

class SleepPhaseDataset(Dataset):
    def __init__(self, file_paths, k):
        """
        Инициализация датасета.

        :param file_paths: Список путей к CSV файлам.
        :param k: Длина фрагмента.
        """
        self.data = []
        self.labels = []
        self.k = k

        # Загрузка данных из файлов
        for label, file_path in enumerate(file_paths):
            # Чтение и объединение всех частей файла
            df = pd.concat([pd.read_csv(part) for part in file_path], ignore_index=True)
            df = df.dropna()  # Удаляем строки с NaN
            self.data.append(df.iloc[:, :2].values)  # Первые две колонки
            self.labels.append(df.iloc[:, -1].values)  # Последняя колонка (метки)

    def __len__(self):
        return sum(len(data) for data in self.data)

    def __getitem__(self, idx):
        cumulative_length = 0
        for i, data in enumerate(self.data):
            cumulative_length += len(data)
            if idx < cumulative_length:
                data_idx = idx - (cumulative_length - len(data))
                start_idx = random.randint(0, max(0, len(data) - self.k))  # Случайный стартовый индекс
                end_idx = start_idx + self.k

                fragment = data[start_idx:end_idx]
                label = self.labels[i][data_idx]
                return torch.tensor(fragment, dtype=torch.float32), torch.tensor(label)

file_paths = [
    ['https://raw.githubusercontent.com/kotosham/sleep-stages/refs/heads/main/sleep_phase_0-1.csv',
     'https://raw.githubusercontent.com/kotosham/sleep-stages/refs/heads/main/sleep_phase_0-2.csv'],
    ['https://raw.githubusercontent.com/kotosham/sleep-stages/refs/heads/main/sleep_phase_1.csv'],
    ['https://raw.githubusercontent.com/kotosham/sleep-stages/refs/heads/main/sleep_phase_2.csv'],
    ['https://raw.githubusercontent.com/kotosham/sleep-stages/refs/heads/main/sleep_phase_3.csv']
]

dataset = SleepPhaseDataset(file_paths=file_paths, k=50)

In [20]:
from sklearn.model_selection import train_test_split

# Разделение на обучающий и тестовый наборы
train_files, test_files = train_test_split(file_paths, test_size=0.2, random_state=42)

# Дальнейшее разделение обучающего набора на валидационный и тренировочный
train_files, val_files = train_test_split(train_files, test_size=0.2, random_state=42)

# Создание DataLoader для каждого набора
train_dataset = SleepPhaseDataset(train_files, k=50)
val_dataset = SleepPhaseDataset(val_files, k=50)
test_dataset = SleepPhaseDataset(test_files, k=50)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

In [21]:
train_labels = []
for file_path in train_files:
    for part in file_path:
        df = pd.concat([pd.read_csv(part) for part in file_path], ignore_index=True)
        train_labels.extend(df['Sleep stages'].dropna().unique())
print("Unique labels in training set:", set(train_labels))

Unique labels in training set: {0, 2}


In [22]:
val_labels = []
for file_path in val_files:
    for part in file_path:
        df = pd.concat([pd.read_csv(part) for part in file_path], ignore_index=True)
        val_labels.extend(df['Sleep stages'].dropna().unique())

print("Unique labels in validation set:", set(val_labels))

Unique labels in validation set: {3}


In [23]:
test_labels = []
for file_path in test_files:
    for part in file_path:
        df = pd.concat([pd.read_csv(part) for part in file_path], ignore_index=True)
        test_labels.extend(df['Sleep stages'].dropna().unique())

print("Unique labels in test set:", set(test_labels))

Unique labels in test set: {1}


In [24]:
import torch.nn as nn
import torch.optim as optim

# Параметры
input_size = 2  # Количество признаков
num_classes = 4  # Количество классов (фаз сна)
num_epochs = 10   # Количество эпох
learning_rate = 0.0001

# Определение простой полносвязной нейронной сети
class SimpleNN(nn.Module):
    def __init__(self, input_size, num_classes):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(input_size * 50, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Инициализация модели, потерь и оптимизатора
model = SimpleNN(input_size=input_size, num_classes=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Обучение модели
for epoch in range(num_epochs):
    model.train()
    for batch_data, batch_labels in train_loader:
        optimizer.zero_grad()

        outputs = model(batch_data.view(batch_data.size(0), -1))
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}')

    # Оценка модели на валидационном наборе после каждой эпохи
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for val_data, val_labels in val_loader:
            outputs = model(val_data.view(val_data.size(0), -1))
            _, predicted = torch.max(outputs.data, 1)
            total += val_labels.size(0)
            correct += (predicted == val_labels).sum().item()

        print(f'Validation Accuracy: {100 * correct / total:.2f}%')

# Оценка модели на тестовом наборе после завершения обучения
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for test_data, test_labels in test_loader:
        outputs = model(test_data.view(test_data.size(0), -1))
        _, predicted = torch.max(outputs.data, 1)
        total += test_labels.size(0)
        correct += (predicted == test_labels).sum().item()

    print(f'Test Accuracy: {100 * correct / total:.2f}%')

KeyboardInterrupt: 