In [2]:
import mne
import os
import s3fs
from dask.distributed import Client
from mne.time_frequency import psd_multitaper
from dask_cloudprovider.aws import FargateCluster
from dask import delayed
import numpy as np
from dask.distributed import performance_report

# Keep MNE quiet: only messages at WARNING level or above are logged.
mne.set_log_level(verbose='WARNING')
# Number of subjects to be processed in parallel

In [3]:
# Size of the worker pool requested from Fargate.
n_workers = 50
# AWS region picked up by the AWS client libraries for this session.
os.environ["AWS_DEFAULT_REGION"] = "us-east-1"

# Gather the cluster settings first so they can be read at a glance.
fargate_options = dict(
    image="mnetools/mne-python:0.22.1",  # base Docker image to use
    n_workers=n_workers,
    fargate_use_private_ip=False,
    scheduler_timeout="15 minutes",
)
cluster = FargateCluster(**fargate_options)

# Connect this notebook session to the cluster's scheduler.
client = Client(cluster)

  next(self.gen)


In [4]:
# Show the client summary (scheduler address, dashboard link, cluster resources).
client

0,1
Client  Scheduler: tcp://3.239.49.133:8786  Dashboard: http://3.239.49.133:8787/status,Cluster  Workers: 50  Cores: 200  Memory: 745.06 GiB


In [5]:
# Read the S3 credentials from the environment instead of embedding them in the
# notebook, so that no secret is stored in (or leaked through) this file.
# A KeyError is raised here if either variable is not set.
# The credentials are still passed explicitly because `fs` is the global used
# inside read_raw_s3, which is executed as a Dask task.
# NOTE(review): presumably the filesystem object (with its credentials) is
# shipped to the workers together with that function — confirm workers can
# reach the bucket.
fs = s3fs.S3FileSystem(
    anon=False,
    key=os.environ["AWS_ACCESS_KEY_ID"],
    secret=os.environ["AWS_SECRET_ACCESS_KEY"],
)

# List every object in the "mnedask" bucket; each entry is fed to the pipeline.
files = fs.ls("mnedask")
print(files)

['mnedask/rest1.fif', 'mnedask/rest10.fif', 'mnedask/rest100.fif', 'mnedask/rest101.fif', 'mnedask/rest102.fif', 'mnedask/rest103.fif', 'mnedask/rest104.fif', 'mnedask/rest105.fif', 'mnedask/rest106.fif', 'mnedask/rest107.fif', 'mnedask/rest108.fif', 'mnedask/rest109.fif', 'mnedask/rest11.fif', 'mnedask/rest110.fif', 'mnedask/rest111.fif', 'mnedask/rest112.fif', 'mnedask/rest113.fif', 'mnedask/rest114.fif', 'mnedask/rest115.fif', 'mnedask/rest116.fif', 'mnedask/rest117.fif', 'mnedask/rest118.fif', 'mnedask/rest119.fif', 'mnedask/rest12.fif', 'mnedask/rest120.fif', 'mnedask/rest121.fif', 'mnedask/rest122.fif', 'mnedask/rest123.fif', 'mnedask/rest124.fif', 'mnedask/rest125.fif', 'mnedask/rest126.fif', 'mnedask/rest127.fif', 'mnedask/rest128.fif', 'mnedask/rest129.fif', 'mnedask/rest13.fif', 'mnedask/rest130.fif', 'mnedask/rest131.fif', 'mnedask/rest132.fif', 'mnedask/rest133.fif', 'mnedask/rest134.fif', 'mnedask/rest135.fif', 'mnedask/rest136.fif', 'mnedask/rest137.fif', 'mnedask/rest138

In [6]:
def read_raw_s3(fif_file_s3):
    """Load a FIF recording from S3 fully into memory.

    Parameters
    ----------
    fif_file_s3 : str
        Path of the object as understood by the global ``fs`` filesystem
        (e.g. an entry of ``fs.ls(...)``).

    Returns
    -------
    The object returned by ``mne.io.read_raw_fif(..., preload=True)``.
    """
    # The handle is used as a context manager so it is always closed, even if
    # reading fails. Closing it is safe because preload=True asks MNE to read
    # the data while the file is still open.
    with fs.open(fif_file_s3, mode='rb') as fif_file:
        raw = mne.io.read_raw_fif(fif_file, preload=True)
    return raw

In [7]:
def crop_raw(raw, tmin=0, tmax=60):
    """Crop a recording to the window from ``tmin`` to ``tmax``.

    Parameters
    ----------
    raw : object exposing ``crop(tmin, tmax)``
        Presumably the MNE Raw object produced by ``read_raw_s3``.
    tmin : float
        Start of the window, forwarded to ``raw.crop``. Defaults to 0,
        the value that used to be hard-coded.
    tmax : float
        End of the window, forwarded to ``raw.crop``. Defaults to 60,
        the value that used to be hard-coded.

    Returns
    -------
    Whatever ``raw.crop`` returns.
    NOTE(review): MNE documents ``crop`` as operating in place and taking
    seconds — confirm against the pinned MNE version.
    """
    return raw.crop(tmin, tmax)

In [8]:
def apply_proj(raw):
    """Call ``raw.apply_proj()`` and hand back its result.

    Thin wrapper so the step can be scheduled as its own Dask task.
    """
    return raw.apply_proj()

In [9]:
def apply_filter(raw, l_freq=1, h_freq=None):
    """Filter a recording through ``raw.filter(l_freq, h_freq)``.

    Parameters
    ----------
    raw : object exposing ``filter(l_freq, h_freq)``
        Presumably the MNE Raw object coming from the previous pipeline step.
    l_freq : float | None
        Lower cut-off forwarded to ``raw.filter``. Defaults to 1, the value
        that used to be hard-coded.
    h_freq : float | None
        Upper cut-off forwarded to ``raw.filter``. Defaults to None, the value
        that used to be hard-coded.
        NOTE(review): per the MNE docs None means "no upper cut-off", i.e. the
        default is a 1 Hz high-pass — confirm.

    Returns
    -------
    Whatever ``raw.filter`` returns.
    """
    return raw.filter(l_freq, h_freq)

In [10]:
def compute_psd(raw):
    """Return the base-10 logarithm of the multitaper PSD of a recording.

    Channels are selected with ``mne.pick_types(..., meg='mag')`` (EEG, EOG
    and stim channels excluded) and the spectrum is computed by
    ``psd_multitaper`` with ``fmin=2``, ``fmax=55`` and "full" normalization.
    The frequency vector returned alongside the PSD is discarded.
    """
    mag_picks = mne.pick_types(
        raw.info, meg='mag', eeg=False, eog=False, stim=False
    )
    power = psd_multitaper(
        raw, fmin=2, fmax=55, picks=mag_picks, normalization="full"
    )[0]
    return np.log10(power)

In [11]:
def compute_mean_psd(psds):
    """Average a collection of PSD arrays element-wise.

    The inputs are combined into one array with a new leading axis and the
    mean is taken along that axis.
    """
    stacked = np.array(psds)
    return stacked.mean(axis=0)

In [12]:
# Processing steps applied to every recording, in execution order.
pipeline = (read_raw_s3, crop_raw, apply_proj, apply_filter, compute_psd)

# Build one lazy chain of tasks per file; nothing runs until compute() is called.
psds = []
for fif_path in files:
    node = fif_path
    for step in pipeline:
        node = delayed(step)(node)
    psds.append(node)

# Final task: average the per-file PSDs.
mean_psd = delayed(compute_mean_psd)(psds)

In [13]:
%%time
# Submit the graph to the cluster and block until the averaged PSD is returned.
all_psds = client.compute(mean_psd).result()

Wall time: 2min 50s


In [14]:
# Disconnect this session from the scheduler.
client.close()
# Also shut the Fargate cluster down explicitly. Closing only the client would
# leave the scheduler and workers running until the cluster's
# scheduler_timeout ("15 minutes") expires.
cluster.close()

In [15]:
# Display the client once more after close().
# NOTE(review): presumably to confirm that it is disconnected — confirm intent.
client