In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import pickle
import sys
sys.path.append('/DATA/publish/mocet/analysis/scripts')
from utils.base import get_project_directory, get_configs

# Project root and shared configuration used by the loading/simulation cells below.
project_dir = get_project_directory()
configs = get_configs()

# Parameter grid of the basis simulation. The x/z offsets span ±5 degrees,
# converted to radians.
max_offset = np.deg2rad(5)
x_direction_range = np.linspace(-max_offset, max_offset, 9)
z_direction_range = np.linspace(-max_offset, max_offset, 9)
orientation_range = np.arange(-90, 91, 5)
depth_range = [0.14, 0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25]

# Every grid point contributes 12 consecutive rows to the basis log.
n_frames = 12
n_grid_points = (len(x_direction_range) * len(z_direction_range)
                 * len(depth_range) * len(orientation_range))
total_count = n_grid_points * n_frames
print(total_count, total_count / n_frames)

# Flat log of pupil detections for the basis simulation; the grid cell below
# consumes it in consecutive chunks of 12 rows per grid point.
basis_log_fname = '../../data/simulation/basis_log.csv'
basis_log = pd.read_csv(basis_log_fname)

# NOTE(review): dropna() removes rows in place of the flat sequence, which would
# shift every later row's alignment with the parameter grid — confirm that no
# NaN rows occur inside the grid portion of the log.
basis_pupill_data = (
    basis_log
    .loc[:, ["center_x", "center_y", "confidence"]]
    .dropna()
    .to_numpy()
)

# Split the flat basis log into one (12, 2) block of (center_x, center_y) per grid
# point and keep only blocks without a -1 coordinate.
# Rows are consumed with orientation varying fastest, then depth, z, x — the same
# order as meshgrid(..., indexing='ij') followed by a C-order ravel.
import warnings

n_frames = 12
grid_axes = np.meshgrid(x_direction_range, z_direction_range, depth_range,
                        orientation_range, indexing='ij')
# Columns: [x_direction, z_direction, depth, orientation]; promoted to float64.
grid_params = np.stack([axis.ravel() for axis in grid_axes], axis=1)

n_needed = len(grid_params) * n_frames
n_rows = len(basis_pupill_data)
if n_rows < n_needed:
    raise ValueError(f'basis log has {n_rows} valid rows, expected at least {n_needed} '
                     f'({len(grid_params)} grid points x {n_frames} frames)')
if n_rows > n_needed:
    # Extra rows are ignored, but they hint that the log and the grid may be misaligned.
    warnings.warn(f'basis log has {n_rows - n_needed} rows beyond the parameter grid; '
                  'they are ignored — check row alignment')

# Vectorized equivalent of copying 12 rows at a time (columns 0-1 = center_x, center_y).
grid_bases = basis_pupill_data[:n_needed, :2].astype(float).reshape(len(grid_params), n_frames, 2)

# -1 is presumably the pupil detector's failure sentinel; drop any basis containing it.
is_valid = ~np.any(grid_bases == -1, axis=(1, 2))
basis = grid_bases[is_valid]
basis_params = grid_params[is_valid]

# Close the file deterministically instead of leaking the handle.
# NOTE(review): pickle.load can execute arbitrary code — only load this trusted,
# locally produced file.
with open('../../data/valid_data_list.pkl', 'rb') as valid_data_file:
    valid_data = pickle.load(valid_data_file)

# Count the runs available per subject. The keys of valid_data are the
# (subject, session, task, run) tuples that are later passed to the simulation workers.
subjects_runs = {}
for key in valid_data.keys():
    subject_id = key[0]
    subjects_runs[subject_id] = subjects_runs.get(subject_id, 0) + 1

# Unique subject identifiers in sorted order.
subjects = sorted(set(subjects_runs))

431568 35964.0


In [2]:
import os
import mocet
from tqdm import tqdm

# Maps (subject, session, task, run) to the index/indices into basis_params used for
# that run's simulation. The `with` block closes the file handle deterministically.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted local files.
with open('../../data/simulation/basis_data.pkl', 'rb') as basis_data_file:
    basis_data = pickle.load(basis_data_file)

# Number of phase-randomized null simulations generated per run.
n_iterations = 100

def random_phase_shuffle(data, rng=None):
    """Return a phase-randomized surrogate of a real time series.

    The amplitude spectrum of each column (and therefore its power spectrum,
    variance and mean) is preserved exactly, while the Fourier phases of all
    non-DC frequencies are replaced with uniform random values. Columns are
    randomized independently of one another.

    Parameters
    ----------
    data : array_like
        Real-valued signal with time along axis 0 (1-D, or 2-D ``(time, channels)``).
    rng : numpy.random.Generator or module, optional
        Source of randomness exposing ``uniform(low, high, size)``. Defaults to the
        global ``np.random`` state, so ``np.random.seed`` in the caller keeps results
        reproducible.

    Returns
    -------
    numpy.ndarray
        Real surrogate with the same shape as ``data``.
    """
    rng = np.random if rng is None else rng
    data = np.asarray(data, dtype=float)
    n_samples = data.shape[0]

    # Use the one-sided spectrum: irfft enforces Hermitian symmetry, so the
    # surrogate is real without discarding an imaginary part (which would
    # attenuate the amplitudes and defeat the purpose of the shuffle).
    spectrum = np.fft.rfft(data, axis=0)
    amplitude = np.abs(spectrum)
    new_phase = rng.uniform(0, 2 * np.pi, spectrum.shape)

    # DC (and, for even lengths, Nyquist) bins must stay real: keep their original phase.
    new_phase[0] = np.angle(spectrum[0])
    if n_samples % 2 == 0:
        new_phase[-1] = np.angle(spectrum[-1])

    return np.fft.irfft(amplitude * np.exp(1j * new_phase), n=n_samples, axis=0)


def generate_null_eyetracking_simulation(subject, session, task, run, k):
    """Render ``n_iterations`` null eye-tracking simulations for one run.

    For each iteration, the run's six fMRIPrep head-motion parameters are
    phase-randomized with ``random_phase_shuffle`` and re-referenced to the first
    volume. They are then fed to ``mocet.simulation.generate`` together with the
    run's basis parameters. The detected pupil coordinates are written to
    ``{project_dir}/data/simulation_null/{subject}/{session}/
    {subject}_{session}_{task}_{run}_simulation-eyetracking_null-{i:04d}.csv``.

    Parameters
    ----------
    subject, session, task, run
        Run identifiers, used as the lookup key into ``basis_data`` and to build
        input/output file names.
    k : int
        Index of this run among all runs. Combined with the iteration index to
        give every (run, iteration) pair its own reproducible seed.

    Notes
    -----
    Relies on the module-level ``project_dir``, ``basis_data``, ``basis_params``
    and ``n_iterations``. Returns nothing; results are written to disk.
    """
    motion_param_labels = ['trans_x', 'trans_y', 'trans_z', 'rot_x', 'rot_y', 'rot_z']
    root = f'{project_dir}/data/eyetracking/{subject}/{session}'
    confounds_fname = f'{root}/{subject}_{session}_{task}_{run}_desc-confounds_timeseries.tsv'

    basis_idx = basis_data[(subject, session, task, run)]
    fmriprep_confounds = pd.read_csv(confounds_fname, delimiter='\t')
    # Loop-invariant: the observed motion series (NaN -> 0) is identical for every iteration.
    motion_params = np.nan_to_num(fmriprep_confounds[motion_param_labels])

    output_dir = f'{project_dir}/data/simulation_null/{subject}/{session}'
    os.makedirs(output_dir, exist_ok=True)

    for i in tqdm(range(n_iterations), desc=f'{subject}_{session}_{task}_{run}'):
        # Unique, reproducible seed per (run, iteration). The previous `i * k` scheme
        # seeded every iteration of run k=0 with 0 (identical surrogates) and
        # collided across runs (e.g. i=2,k=3 vs i=3,k=2).
        np.random.seed(k * n_iterations + i)
        random_motion_params = random_phase_shuffle(motion_params)
        # Re-reference so the simulated motion starts at zero displacement.
        random_motion_params = random_motion_params - random_motion_params[0, :]

        _, pupil_coordinates = mocet.simulation.generate(random_motion_params,
                                                         basis_params[basis_idx],
                                                         render=True,
                                                         render_resolution=(128, 96),
                                                         detect_pupil=True)

        pupil_coordinates.to_csv(f'{output_dir}/{subject}_{session}_{task}_{run}_simulation-eyetracking_null-{i:04d}.csv', index=False)

In [7]:
import time
from multiprocessing import Pool

# Wall-clock start of the batch.
time_sta = time.time()
n_processes = 60

# One job per run: the run's key fields followed by its enumeration index k,
# which the worker uses when seeding its random number generator.
job_args = [(*key, k) for k, key in enumerate(valid_data)]

with Pool(n_processes) as pool:
    pool.starmap(generate_null_eyetracking_simulation, job_args)

100%|██████████| 100/100 [24:35:06<00:00, 885.07s/it]  
100%|██████████| 100/100 [24:39:06<00:00, 887.47s/it] 
100%|██████████| 100/100 [24:44:16<00:00, 890.56s/it] 
100%|██████████| 100/100 [24:47:59<00:00, 892.80s/it] 
100%|██████████| 100/100 [24:48:55<00:00, 893.35s/it]
100%|██████████| 100/100 [24:48:57<00:00, 893.38s/it]
100%|██████████| 100/100 [24:49:27<00:00, 893.67s/it] 
100%|██████████| 100/100 [24:49:31<00:00, 893.71s/it]
100%|██████████| 100/100 [24:49:53<00:00, 893.94s/it]]
100%|██████████| 100/100 [24:50:12<00:00, 894.12s/it]
100%|██████████| 100/100 [24:50:38<00:00, 894.39s/it]
100%|██████████| 100/100 [24:51:39<00:00, 895.00s/it]]
100%|██████████| 100/100 [24:52:32<00:00, 895.52s/it]
100%|██████████| 100/100 [24:54:36<00:00, 896.76s/it]
100%|██████████| 100/100 [24:55:25<00:00, 897.26s/it]
100%|██████████| 100/100 [24:56:37<00:00, 897.97s/it]
100%|██████████| 100/100 [24:58:10<00:00, 898.90s/it] 
100%|██████████| 100/100 [24:58:57<00:00, 899.37s/it]
100%|██████████| 10