In [1]:
import os
import glob
from pathlib import Path
from typing import Tuple

import h5py
import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

In [26]:
def normalize_data(data: np.ndarray, scale_range: Tuple=None):
    """Normalize data to range [a, b]"""
    # scale to range [0, 1] first
    new_data = (data - data.min())/(data.max() - data.min())
    # scale to range [a, b]
    if scale_range is not None:
        a, b = scale_range
        assert a<=b, f'Invalid range: {scale_range}'
        new_data = (b-a)*new_data + a
        
    return new_data

In [35]:
data_dir = Path('/data2/quan/Datasets/EEG2fMRI/Kris/')

In [45]:
# read data
data_name = '02_eeg_fmri_data.h5'

with h5py.File(data_dir/data_name, 'r') as f:
    eeg_train = np.array(f['eeg_train'][:])
    fmri_train = np.array(f['fmri_train'][:])
    eeg_test = np.array(f['eeg_test'][:])
    fmri_test = np.array(f['fmri_test'][:])

In [46]:
eeg_train.shape

(2041, 43, 269, 10, 1)

In [47]:
eeg_train[:, :, :, :, 0].shape

(2041, 43, 269, 10)

In [48]:
fmri_train.shape

(2041, 64, 64, 32, 1)

In [49]:
eeg_test.shape

(628, 43, 269, 10, 1)

In [50]:
fmri_test.shape

(628, 64, 64, 32, 1)

### Save new h5 file

In [51]:
hf = h5py.File(data_dir/'02_eeg_fmri_data_new.h5', 'w')

In [52]:
hf.create_dataset('eeg_train', data=eeg_train[:, :, :, :, 0], compression="gzip", compression_opts=9)
hf.create_dataset('fmri_train', data=fmri_train[:, :, :, :, 0], compression="gzip", compression_opts=9)

hf.create_dataset('eeg_test', data=eeg_test[:, :, :, :, 0], compression="gzip", compression_opts=9)
hf.create_dataset('fmri_test', data=fmri_test[:, :, :, :, 0], compression="gzip", compression_opts=9)

hf.close()

### Test read new data

In [40]:
# read data
data_name = '01_eeg_fmri_data_new.h5'

with h5py.File(data_dir/data_name, 'r') as f:
    eeg_train = np.array(f['eeg_train'][:])
    fmri_train = np.array(f['fmri_train'][:])
    eeg_test = np.array(f['eeg_test'][:])
    fmri_test = np.array(f['fmri_test'][:])

In [44]:
fmri_test.shape

(861, 64, 64, 30)