In [17]:
import mne
import os
import s3fs
from dask.distributed import Client
from mne.time_frequency import psd_multitaper
from dask_cloudprovider.aws import FargateCluster
from dask import delayed
import numpy as np
from dask.distributed import performance_report

# Reduce verbosity: have MNE log only messages at WARNING level.
mne.set_log_level('WARNING')
# Number of subjects to be processed in parallel

In [18]:
# AWS region used for the Fargate cluster.
os.environ["AWS_DEFAULT_REGION"] = "us-east-2"

# Cluster sizing: number of workers and memory requested per worker.
n_workers = 2
worker_mem = 30720

cluster = FargateCluster(
    n_workers=n_workers,
    worker_mem=worker_mem,
    image="daskdev/dask:latest",  # Base Docker image to use
    environment={
        # Extra packages handed to the container through EXTRA_PIP_PACKAGES.
        "EXTRA_PIP_PACKAGES": "dask-ml==1.6.0 scikit-learn==0.23.2 s3fs mne bokeh"
    },
    scheduler_timeout="15 minutes",
    fargate_use_private_ip=False,
)

# Connect this notebook session to the cluster's scheduler.
client = Client(cluster)

In [20]:
# Show the client summary (scheduler address, dashboard link, worker resources).
client

0,1
Client  Scheduler: tcp://3.138.124.15:8786  Dashboard: http://3.138.124.15:8787/status,Cluster  Workers: 2  Cores: 8  Memory: 60.00 GB


In [21]:
# Read the S3 credentials from the environment instead of writing them into
# the notebook, where they would be saved and shared along with it.
# A missing variable raises KeyError here rather than failing later on S3.
# The key and secret are still passed explicitly, as before, rather than
# relying on the default credential lookup.
fs = s3fs.S3FileSystem(
    anon=False,
    key=os.environ["AWS_ACCESS_KEY_ID"],
    secret=os.environ["AWS_SECRET_ACCESS_KEY"],
)

# List the objects in the "mnedask" bucket; these are the recordings to process.
files = fs.ls("mnedask")
print(files)

['mnedask/rest1.fif', 'mnedask/rest10.fif', 'mnedask/rest11.fif', 'mnedask/rest12.fif', 'mnedask/rest13.fif', 'mnedask/rest14.fif', 'mnedask/rest16.fif', 'mnedask/rest17.fif', 'mnedask/rest2.fif', 'mnedask/rest3.fif', 'mnedask/rest4.fif', 'mnedask/rest5.fif', 'mnedask/rest6.fif', 'mnedask/rest7.fif', 'mnedask/rest8.fif', 'mnedask/rest9.fif']


In [22]:
def read_raw_s3(fif_file_s3):
    """Load a FIF recording from S3 and crop it to its first 50 seconds.

    Parameters
    ----------
    fif_file_s3 : str
        Path of the FIF file on S3, in the form accepted by ``fs.open``
        (e.g. one of the entries returned by ``fs.ls``).

    Returns
    -------
    raw
        The object returned by ``mne.io.read_raw_fif``, preloaded and
        cropped in place with ``raw.crop(0, 50)``.
    """
    # preload=True reads the data into memory while the handle is open, so the
    # S3 file can be closed as soon as reading finishes instead of leaking.
    with fs.open(fif_file_s3, mode='rb') as fif_file:
        raw = mne.io.read_raw_fif(fif_file, preload=True)
    raw.crop(0, 50)
    return raw

In [23]:
def compute_psd(raw):
    """Return the log10 multitaper power spectrum of the magnetometer channels.

    Channels are selected with ``meg='mag'`` (EEG, EOG and stim excluded) and
    the spectrum is estimated between ``fmin=2`` and ``fmax=55`` with
    ``normalization="full"``. The frequency vector returned by
    ``psd_multitaper`` is discarded.

    Parameters
    ----------
    raw
        Recording exposing an ``info`` attribute, as accepted by
        ``mne.pick_types`` and ``psd_multitaper``.

    Returns
    -------
    numpy.ndarray
        ``np.log10`` of the estimated power spectral density.
    """
    mag_picks = mne.pick_types(
        raw.info, meg='mag', eeg=False, eog=False, stim=False
    )
    spectra = psd_multitaper(
        raw, fmin=2, fmax=55, picks=mag_picks, normalization="full"
    )[0]
    return np.log10(spectra)

In [24]:
def compute_mean_psd(psds):
    """Average a collection of PSD arrays element-wise.

    Parameters
    ----------
    psds : sequence of array-like
        Arrays of equal shape; they are stacked along a new leading axis.

    Returns
    -------
    numpy.ndarray
        Mean over the leading axis (``axis=0``) of the stacked arrays.
    """
    stacked = np.array(psds)
    return stacked.mean(axis=0)

In [25]:
# Build the lazy task graph: for each of the first four files, read the
# recording and compute its PSD. Nothing is executed at this point.
psds = [
    delayed(compute_psd)(delayed(read_raw_s3)(file))
    for file in files[:4]
]

# Final lazy step: average the per-file PSDs.
mean_psd = delayed(compute_mean_psd)(psds)

In [26]:
# Submit the graph to the cluster and block until the averaged PSD is returned.
all_psds = client.compute(mean_psd).result()

In [27]:
# Disconnect the client, then shut the cluster down as well. Closing only the
# client leaves the Fargate scheduler and workers running until the scheduler
# timeout expires.
client.close()
cluster.close()

In [28]:
# Display the client again, after client.close() was called in the previous cell.
client