In [143]:
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from utils.utils import *
from utils.dataloader import distillDataset, distill
from helper_train import distill_train

SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)
rng = np.random.RandomState(SEED)

from tqdm import tqdm

EXPERIMENT_NAME = 'supervised_hmc'
MODALITY = 'ecg'
DO_KFOLD = False
BATCH_SIZE = 256
EPOCH_LEN = 7

DATASET_PATH = '/scratch/hmc'
DATASET_SUBJECTS = os.listdir(os.path.join(DATASET_PATH, 'subjects_data'))
SAVE_PATH = './saved_weights'

if not os.path.exists(SAVE_PATH):
    os.makedirs(SAVE_PATH, exist_ok=True)



DATASET_SUBJECTS.sort(key=natural_keys)
DATASET_SUBJECTS = [os.path.join(DATASET_PATH, 'subjects_data', f) for f in DATASET_SUBJECTS]
dataset_subjects_data = [np.load(f) for f in DATASET_SUBJECTS]


sub_train = rng.choice(DATASET_SUBJECTS, int(len(DATASET_SUBJECTS)*0.8), replace=False)
sub_test = sorted(list(set(DATASET_SUBJECTS) - set(sub_train)))

print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
print(f"Train: {len(sub_train)} \n")
print(f"Test: {len(sub_test)} \n")
print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

TRAIN_FILES = [np.load(f) for f in sub_train]
TEST_FILES = [np.load(f) for f in sub_test]

# load files
TRAIN_PATH = os.path.join(DATASET_PATH, 'train_ecg') 
TEST_PATH = os.path.join(DATASET_PATH, 'test_ecg')

if not os.path.exists(TRAIN_PATH):
    os.makedirs(TRAIN_PATH, exist_ok=True)

if not os.path.exists(TEST_PATH):
    os.makedirs(TEST_PATH, exist_ok=True)

half_window = 3

cnt = 0
for file in tqdm(TRAIN_FILES):
    x_dat = file[MODALITY]
    y_dat = file["y"].astype('int')

    for i in range(half_window,x_dat.shape[0]-half_window):
        dct = {}
        temp_path = os.path.join(TRAIN_PATH, str(cnt)+".npz")
        dct['X'] = x_dat[i-half_window:i+half_window+1]
        dct['y'] = y_dat[i-half_window:i+half_window+1]
        np.savez(temp_path,**dct)
        cnt+=1

cnt = 0
for file in tqdm(TEST_FILES):
    x_dat = file[MODALITY]
    y_dat = file["y"].astype('int')

    for i in range(half_window,x_dat.shape[0]-half_window):
        dct = {}
        temp_path = os.path.join(TEST_PATH, str(cnt)+".npz")
        dct['X'] = x_dat[i-half_window:i+half_window+1]
        dct['y'] = y_dat[i-half_window:i+half_window+1]
        np.savez(temp_path,**dct)
        cnt+=1


Extracting EDF parameters from /scratch/hmc/recordings/SN001.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6566399  =      0.000 ... 25649.996 secs...
Extracting EDF parameters from /scratch/hmc/recordings/SN002.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 6577407  =      0.000 ... 25692.996 secs...
Extracting EDF parameters from /scratch/hmc/recordings/SN003.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7330815  =      0.000 ... 28635.996 secs...
Extracting EDF parameters from /scratch/hmc/recordings/SN004.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7804927  =      0.000 ... 30487.996 secs...
Extracting EDF parameters from /scratch/hmc/recordings/SN005.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ...

In [146]:
for windows_subject in windows_dataset.datasets:
    print(windows_subject)
    break

(array([[ 1.7087381e-05,  1.4344381e-05,  1.1525641e-06, ...,
         1.2500958e-05,  2.5288666e-05,  2.6783413e-05],
       [ 1.4664362e-05,  7.3353440e-06,  2.4618416e-06, ...,
         4.4581006e-07,  1.2948366e-05,  1.7326725e-05],
       [ 3.8281894e-05,  2.0248757e-05,  2.4458244e-05, ...,
         7.9606389e-06,  1.8057051e-05,  2.1296686e-05],
       [ 5.8028559e-06,  6.9939942e-06, -2.9917583e-06, ...,
        -1.0429584e-05, -8.3001996e-06, -1.2006130e-06]], dtype=float32), 0, [0, 0, 3000])


In [174]:
a = windows_dataset.datasets[0]


for i in a.windows:
    print(i.shape)
    break

(4, 3000)


In [17]:
import os, shutil, re
import torch
import numpy as np
import argparse
from tqdm import tqdm


SEED = 1234
np.random.seed(SEED)
rng = np.random.RandomState(SEED)

# ARGS
HALF_WINDOW = 3

args = {}
args['dir'] = r'C:\Users\likit\Downloads\shhs'

def atoi(text):
    return int(text) if text.isdigit() else text

def natural_keys(text):
    return [ atoi(c) for c in re.split(r'(\d+)', text) ]

DATASET_SUBJECTS = sorted(os.listdir(os.path.join(args['dir'], 'subjects_data')))
DATASET_SUBJECTS = [os.path.join(args['dir'], 'subjects_data', f) for f in DATASET_SUBJECTS]
TRAIN_PATH = os.path.join(args['dir'], f'train_{HALF_WINDOW}') 
TEST_PATH = os.path.join(args['dir'], f'test_{HALF_WINDOW}')

dataset_subjects_data = [np.load(f) for f in DATASET_SUBJECTS]
sub_train = rng.choice(DATASET_SUBJECTS, int(len(DATASET_SUBJECTS)*0.8), replace=False)
sub_test = list(set(DATASET_SUBJECTS) - set(sub_train))





In [21]:
sub_train

array(['C:\\Users\\likit\\Downloads\\shhs\\subjects_data\\201566.npz',
       'C:\\Users\\likit\\Downloads\\shhs\\subjects_data\\201329.npz',
       'C:\\Users\\likit\\Downloads\\shhs\\subjects_data\\201581.npz',
       'C:\\Users\\likit\\Downloads\\shhs\\subjects_data\\201586.npz',
       'C:\\Users\\likit\\Downloads\\shhs\\subjects_data\\200017.npz',
       'C:\\Users\\likit\\Downloads\\shhs\\subjects_data\\200010.npz',
       'C:\\Users\\likit\\Downloads\\shhs\\subjects_data\\201560.npz',
       'C:\\Users\\likit\\Downloads\\shhs\\subjects_data\\201371.npz'],
      dtype='<U54')

In [20]:
import wandb, os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

from utils.utils import *
from utils.dataloader import distillDataset, distill
from helper_train import distill_train

SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(SEED)
rng = np.random.RandomState(SEED)

from tqdm import tqdm

EXPERIMENT_NAME = 'supervised_hmc'
MODALITY = 'eeg'
DO_KFOLD = False
BATCH_SIZE = 256
EPOCH_LEN = 7

DATASET_PATH = '/scratch/hmc'
DATASET_SUBJECTS = os.listdir(os.path.join(DATASET_PATH, 'subjects_data'))
SAVE_PATH = './saved_weights'


wandb = wandb.init(
    project="distillECG",
    name=EXPERIMENT_NAME,
    save_code=False,
    entity="sleep-staging",
)

if not os.path.exists(SAVE_PATH):
    os.makedirs(SAVE_PATH, exist_ok=True)

wandb.save("./supervised/utils/*")
wandb.save("./supervised/models/*")
wandb.save("./supervised/helper_train.py")
wandb.save("./supervised/train.py")

DATASET_SUBJECTS.sort(key=natural_keys)
DATASET_SUBJECTS = [os.path.join(DATASET_PATH, 'subjects_data', f) for f in DATASET_SUBJECTS]
dataset_subjects_data = [np.load(f) for f in DATASET_SUBJECTS]


sub_train = rng.choice(DATASET_SUBJECTS, int(len(DATASET_SUBJECTS)*0.8), replace=False)
sub_test = sorted(list(set(DATASET_SUBJECTS) - set(sub_train)))

print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
print(f"Train: {len(sub_train)} \n")
print(f"Test: {len(sub_test)} \n")
print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

TRAIN_FILES = [np.load(f) for f in sub_train]
TEST_FILES = [np.load(f) for f in sub_test]

# load files
TRAIN_PATH = os.path.join(DATASET_PATH, 'train_eeg') 
TEST_PATH = os.path.join(DATASET_PATH, 'test_eeg')



TRAIN_EPOCH_FILES = [os.path.join(TRAIN_PATH, f) for f in os.listdir(TRAIN_PATH)]
TEST_EPOCH_FILES = [os.path.join(TEST_PATH, f) for f in os.listdir(TEST_PATH)]

train_dataset = distill(TRAIN_EPOCH_FILES)
test_dataset = distill(TEST_EPOCH_FILES)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=8,
)



['C:\\Users\\likit\\Downloads\\shhs\\subjects_data\\201359.npz',
 'C:\\Users\\likit\\Downloads\\shhs\\subjects_data\\201515.npz',
 'C:\\Users\\likit\\Downloads\\shhs\\subjects_data\\201552.npz']

In [116]:
from torch.utils.data import DataLoader, Dataset
from braindecode.datautil.windowers import create_windows_from_events, create_fixed_length_windows

windows_dataset = create_fixed_length_windows(
    dataset,
    window_size_samples=window_size_samples * 7,
    window_stride_samples=window_size_samples,
    drop_last_window=True,
    # mapping=mapping,
)


dl = DataLoader(
    windows_dataset,
    batch_size=4,
    shuffle=False,
)

Adding metadata with 4 columns
849 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 849 events and 21000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
850 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 850 events and 21000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
948 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 948 events and 21000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
1010 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1010 events and 21000 original time points ...
0 bad epochs dropped
Adding metadata with 4 columns
953 matching events found
No baseline correction applied
0 projection items activated
Using dat

In [118]:
windows_dataset

Using data from preloaded Raw for 1 events and 21000 original time points ...


(4, 21000)

In [141]:
class test(Dataset):
    def __init__(self):
        self.x =0
    def __getitem__(self, index):
        return index
    def __len__(self):
        return 16*10

dll = DataLoader(test(), batch_size=16, shuffle=True)

a = []
for i in dll:
    a.extend(list(i.tolist()))
    
a.sort()
a

[0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159]

In [102]:
annots.description

array(['Sleep stage W', 'Sleep stage W', 'Lights off@@EEG F4-A1',
       'Sleep stage W', 'Sleep stage W', 'Sleep stage W', 'Sleep stage W',
       'Sleep stage W', 'Sleep stage W', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N2', 'Sleep stage N1',
       'Sleep stage N2', 'Sleep stage N2', 'Sleep stage N2',
       'Sleep stage N2', 'Sleep stage N2', 'Sleep stage N2',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage W',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N2', 'Sleep stage N1', 'Sleep stage W',
       'Sleep stage W', 'Sleep stage W', 'Sleep stage W',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Sleep stage N1', 'Sleep stage N1', 'Sleep stage N1',
       'Slee