In [None]:
import mne
import argparse
import scipy.io as sio
import numpy as np
import nibabel as nib

In [None]:
def freesurfer_to_mri(image_nii):
    '''
    Translation-only transform between MRI (scanner) space and FreeSurfer space.

    The scanner-space coordinate of the centre of the voxel grid (``c_ras``)
    is computed from the image affine. A 4 x 4 matrix ``T`` that translates by
    ``+c_ras`` is built, and its inverse (a translation by ``-c_ras``) is
    returned. Any rotation between the two frames is ignored.

    Parameters
    ----------
    image_nii : str | os.PathLike | nibabel image
        Path to the NIfTI file, or an already-loaded image (any object
        exposing ``.shape`` and ``.affine``).

    Returns
    -------
    translation : numpy.ndarray, shape (4, 4)
        ``inv(T)``: identity rotation, translation ``-c_ras`` in the units of
        the image affine (presumably mm -- callers rescale to metres).

    Notes
    -----
    NOTE(review): the original author was unsure of the direction ("maybe we
    want the inverse of this transformation"). In this file the result is
    used as the MRI -> FreeSurfer mapping. Confirm the direction.
    '''
    # Accept either a path (loaded here) or an already-loaded image. The
    # argument used to be ignored in favour of a hard-coded file path.
    img = image_nii if hasattr(image_nii, 'affine') else nib.load(image_nii)

    # Voxel coordinates of the volume centre. Only the three spatial axes are
    # used, so 4-D images (e.g. with a time axis) work too.
    center = np.asarray(img.shape[:3], dtype=float) / 2
    # Map the centre voxel through the voxel -> world affine to get c_ras.
    cras = (img.affine @ np.append(center, 1.0))[:3]

    translation = np.eye(4)
    translation[:3, 3] = cras

    return np.linalg.inv(translation)

def get_hpi_meg(epochs):
    '''
    Return the HPI coil positions recorded in MEG space, reordered so the
    coils appear in the same order as their MRI-space counterparts.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs whose ``info['hpi_results'][0]['dig_points']`` hold the coils.

    Returns
    -------
    hpi_coil_pos : numpy.ndarray
        One row per dig point (its ``'r'`` entry). The first four rows are
        permuted so that new row k is old row ``[2, 3, 1, 0][k]``.
    '''
    # NOTE(review): not 100 percent sure these are the right dig points.
    dig_points = epochs.info['hpi_results'][0]['dig_points']
    positions = np.array([point['r'] for point in dig_points])

    # The coil order differs between MEG and MRI space, so permute the first
    # four rows. The right-hand side is a copy, so the swap is safe.
    mri_order = [2, 3, 1, 0]
    positions[:4] = positions[mri_order]
    return positions

def rot3dfit(A, B):
    """
    Least-squares rigid fit of one 3-D point set onto another.

    Finds the rotation ``R`` (3 x 3, orthogonal, det = +1) and translation
    ``t`` (3 x 1) that minimise ``sum_i ||B[:, i] - (R @ A[:, i] + t)||**2``.
    In other words, it fits the model ``B ~= R @ A + t``.

    Implementation of the rigid 3D transform algorithm from:
    Least-Squares Fitting of Two 3-D Point Sets,
    Arun, K. S. and Huang, T. S. and Blostein, S. D (1987)

    Parameters
    ----------
    A, B : numpy.ndarray, shape (3, N)
        Corresponding source and target points, one point per column.

    Returns
    -------
    R : numpy.ndarray, shape (3, 3)
        Rotation matrix.
    t : numpy.ndarray, shape (3, 1)
        Translation vector.
    Yf : numpy.ndarray, shape (3, N)
        Fitted points ``R @ A + t``.

    Raises
    ------
    AssertionError
        If ``A`` and ``B`` have different shapes.
    ValueError
        If the inputs are not 3 x N.

    Prints the per-point Euclidean residuals ``||B[:, i] - Yf[:, i]||``. It
    also prints a message when a reflection has to be corrected.
    """
    assert A.shape == B.shape, f'A and B must have the same shape, got {A.shape} and {B.shape}'

    if A.shape[0] != 3 or B.shape[0] != 3:
        raise ValueError('A and B must be 3 x N matrices')

    # Re-centre both point sets so their centroids sit at the origin.
    centroid_A = A.mean(axis=1, keepdims=True)
    centroid_B = B.mean(axis=1, keepdims=True)
    Ac = A - centroid_A
    Bc = B - centroid_B

    # Cross-covariance and its SVD. numpy returns H = U @ diag(S) @ Vt,
    # so Vt.T is the V of the paper and R = V @ U.T.
    H = Ac @ Bc.T
    U, S, Vt = np.linalg.svd(H)
    R = Vt.T @ U.T

    if np.linalg.det(R) < 0:
        # Improper rotation (a reflection): flip the last singular vector.
        print("det(R) < 0, reflection detected!, correcting for it ...")
        Vt[2, :] *= -1
        R = Vt.T @ U.T

    t = centroid_B - R @ centroid_A

    # Best fit and per-point residual norms.
    Yf = R @ A + t
    errors = np.linalg.norm(B - Yf, axis=0).tolist()
    print(errors)

    return R, t, Yf


def transform_geometry(epochs, hpi_mri, image_nii):
    '''
    Move the MEG sensor geometry from device space into MRI space and set
    ``dev_head_t`` to the MRI -> FreeSurfer translation.

    A rigid transform ``p -> R @ p + T`` that maps the MEG-space HPI coil
    positions onto their MRI-space positions is fitted with ``rot3dfit``.
    It is then applied to the position (``loc[:3]``) and orientation
    (``loc[3:12]``) of each of the first 306 channels.

    Parameters
    ----------
    epochs : mne.Epochs
        Epochs whose ``info`` is modified in place.
    hpi_mri : numpy.ndarray, shape (n_coils, 3)
        HPI coil positions in MRI space, in millimetres (converted to metres
        here). Must be in the coil order returned by ``get_hpi_meg``.
    image_nii : str | nibabel image
        T1 image passed to ``freesurfer_to_mri``.

    Returns
    -------
    epochs : mne.Epochs
        The same object, with changed sensor positions/orientations and
        ``dev_head_t``.
    '''
    hpi_meg = get_hpi_meg(epochs)
    hpi_mri = np.asarray(hpi_mri) / 1000  # convert to meters

    # Fit B = R @ A + T with A = MEG-space coils and B = MRI-space coils.
    R, T, _ = rot3dfit(hpi_meg.T, hpi_mri.T)  # function needs 3 x N matrices

    # Homogeneous MEG -> MRI transform, column-vector convention (M @ [p, 1]).
    # Previously the rotation was stored transposed and points were
    # multiplied as row vectors. That still produced R @ p, but the
    # translation T ended up in the homogeneous coordinate and was dropped.
    meg_mri_t = np.eye(4)
    meg_mri_t[:3, :3] = R
    meg_mri_t[:3, 3] = T.ravel()

    # MRI -> FreeSurfer translation. freesurfer_to_mri works in image-affine
    # units (presumably mm), so rescale the translation to metres.
    trans = freesurfer_to_mri(image_nii=image_nii)
    trans[:3, -1] = trans[:3, -1] / 1000
    epochs.info['dev_head_t']['trans'] = trans

    # Only the first 306 channels are transformed.
    # NOTE(review): assumes these are exactly the MEG sensors.
    n_meg_channels = 306
    for ch in epochs.info['chs'][:n_meg_channels]:
        loc = ch['loc']
        # Position: p -> R @ p + T.
        loc[:3] = (meg_mri_t @ np.append(loc[:3], 1.0))[:3]
        # Orientation: loc[3:12] holds three 3-vectors (one per row after the
        # reshape). Rotate each one; rows @ R.T == (R @ rows.T).T.
        loc[3:12] = (loc[3:12].reshape(3, 3) @ R.T).ravel()
        ch['loc'] = loc

    return epochs

def plot_3d_points(A, B, title):
    '''
    Scatter two corresponding 3-D point sets in one 3-D axes.

    Point i of ``A`` is drawn as a circle and point i of ``B`` as a triangle,
    both in the same colour, so matching pairs are easy to spot.

    Parameters
    ----------
    A, B : numpy.ndarray, shape (3, N)
        Point sets, one point per column. ``B`` must have at least as many
        columns as ``A``.
    title : str
        Figure title.
    '''
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on old matplotlib

    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for i in range(A.shape[1]):
        # Cycle the palette so more than len(colors) points no longer raises IndexError.
        color = colors[i % len(colors)]
        ax.scatter(A[0, i], A[1, i], A[2, i], color=color, marker='o', label=f'A{i}', alpha=1)
        ax.scatter(B[0, i], B[1, i], B[2, i], color=color, marker='^', label=f'B{i}', alpha=1)
    ax.set_title(title)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    ax.legend()
    plt.show()


In [None]:
# Configuration for one session, then load its source space and epochs.
session = "visual_18"
subject = 'subj1'
subject_dir = '/media/8.1/raw_data/franscescas_data/mri'
bem_sol = '/media/8.1/raw_data/franscescas_data/mri/subj1-bem_solution.fif'
epoch_path = f'/media/8.1/final_data/laurap/epochs/{session}-epo.fif'

# NOTE(review): the source-space file is prefixed 'sub1' while the subject is
# 'subj1' -- confirm this is the intended file.
src = mne.read_source_spaces('/media/8.1/raw_data/franscescas_data/mri/sub1-oct6-src.fif')
epochs = mne.read_epochs(epoch_path)


In [3]:
dat = epochs.get_data(picks="meg")  # MEG channels only

In [4]:
np.shape(dat)

(215, 306, 250)

In [6]:
np.transpose(dat, (2, 0, 1)).shape  # move the last axis to the front

(250, 215, 306)

In [None]:
# --- configuration ---------------------------------------------------------
session = "visual_18"
subject = 'subj1'
subject_dir = '/media/8.1/raw_data/franscescas_data/mri'
bem_sol = '/media/8.1/raw_data/franscescas_data/mri/subj1-bem_solution.fif'
epoch_path = f'/media/8.1/final_data/laurap/epochs/{session}-epo.fif'
path_nii = '/media/8.1/scripts/laurap/franscescas_data/meg_headcast/mri/T1/sMQ03532-0009-00001-000192-01.nii'

src = mne.read_source_spaces('/media/8.1/raw_data/franscescas_data/mri/sub1-oct6-src.fif')
hpi_mri = sio.loadmat('/media/8.1/scripts/laurap/franscescas_data/meg_headcast/hpi_mri.mat').get('hpi_mri')

# --- move sensor geometry into MRI space -----------------------------------
epochs = mne.read_epochs(epoch_path)
epochs = transform_geometry(epochs, hpi_mri, path_nii)

# --- check alignment: head -> MRI is the identity ----------------------------
trans = mne.transforms.Transform('head', 'mri')
mne.viz.plot_alignment(
    epochs.info,
    trans=trans,
    subject=subject,
    subjects_dir=subject_dir,
    surfaces=['pial', 'head'],
    meg=['sensors'],
    dig=False,
    show_axes=True,
)

# --- forward model and inverse operator ------------------------------------
fwd = mne.make_forward_solution(epochs.info, trans=None, src=src, bem=bem_sol)
# Sample (empirical) covariance over the full epochs.
# NOTE(review): this includes post-stimulus data; a noise covariance is
# usually estimated from the baseline only (e.g. tmax=0) -- confirm intent.
cov = mne.compute_covariance(epochs, method='empirical')
inv = mne.minimum_norm.make_inverse_operator(epochs.info, fwd, cov, loose='auto')

# --- dSPM source estimate of the evoked response ---------------------------
evoked = epochs.average()
stc = mne.minimum_norm.apply_inverse(
    evoked,
    inv,
    lambda2=1.0 / 3.0 ** 2,  # regularisation for an assumed SNR of 3
    method="dSPM",
    pick_ori="normal",
)

# Plot brain with activation.
brain = stc.plot(
    subject=subject,
    subjects_dir=subject_dir,
    views='lat',
    size=(800, 800),
    smoothing_steps=5,
    time_viewer=True,
)

In [None]:
# Visual check of the HPI coil correspondence for each session. The MRI
# coordinates are in mm and the MEG ones in m, so MEG is scaled by 1000.
sessions = [f'visual_{k}' for k in range(23, 39)]  # visual_23 ... visual_38
hpi_mri = sio.loadmat('/media/8.1/scripts/laurap/franscescas_data/meg_headcast/hpi_mri.mat').get('hpi_mri')

for session in sessions:
    epoch_path = f'/media/8.1/final_data/laurap/epochs/{session}-epo.fif'
    epochs = mne.read_epochs(epoch_path)
    hpi_meg = get_hpi_meg(epochs)
    plot_3d_points(hpi_mri.T, hpi_meg.T * 1000, f'before {session}')
