In [38]:
import os
import mne
import json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from mne.time_frequency import psd_array_multitaper
from mne import Epochs

from helper_functions import save_figure

In [ ]:
# Reset MNE's logging verbosity. Passing no explicit level (verbose=None) makes
# MNE fall back to its configured default (MNE_LOGGING_LEVEL, else INFO).
mne.set_log_level(verbose=None)

In [2]:
# Load the project settings (a JSON file one directory above the notebook).
# The encoding is pinned to UTF-8 so that non-ASCII characters in the configured
# paths are decoded identically on every platform instead of depending on the
# locale's default code page.
with open('../settings.json', "r", encoding="utf-8") as f:
    settings = json.load(f)

# Folders used throughout the notebook: where the epoch files are read from and
# where the generated figures are written to.
epoch_folder = settings['epochs_folder']
plot_folder = settings['plots_folder']

Let's load the epoch object

In [3]:
# Read the filtered, movement-annotated epochs from the configured epochs folder
# and keep the data in memory (preload=True).
epochs_path = os.path.join(epoch_folder, "filtered_epochs_w_movement-epo.fif")
epochs = mne.read_epochs(epochs_path, preload=True)

Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\filtered_epochs_w_movement-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =       0.00 ...    4998.53 ms
        0 CTF compensation matrices available
Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\filtered_epochs_w_movement-epo-1.fif ...
Isotrak not found
    Found the data of interest:
        t =       0.00 ...    4998.53 ms
        0 CTF compensation matrices available
Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\filtered_epochs_w_movement-epo-2.fif ...
Isotrak not found
    Found the data of interest:
        t =       0.00 ...    4998.53 ms
        0 CTF compensation matrices available
Adding metadata with 4 columns
21980 matching events found
No baseline correction applied
0 projection items activated


Establish a dictionary holding channels to be omitted from consideration

In [53]:
# Channels to leave out of the analysis, keyed by subject id (as a string).
# NOTE(review): despite the variable name, the values are channel names, not epochs.
bad_epochs_per_subject = dict(
    [
        ("80630", ["OFC_R"]),
        ("80625", ["OFC_L"]),
        ("81193", ["OFC_R", "OFC_L"]),
    ]
)

Get the names of the channels we want to plot

In [54]:
# Every channel of the epochs object except the two EMG channels, in the
# original channel order.
excluded_chans = ['EMG_L', 'EMG_R']
wanted_chans = [
    ch_name
    for ch_name in epochs.info["ch_names"]
    if ch_name not in excluded_chans
]

#### DRD2-WT and DRD2-KO average PSD for movement and non-movement epochs

Define the frequency bands

In [55]:
# Frequency bands as (lower, upper) limits in Hz, keyed by the LaTeX label that
# is drawn on the plots. Order: delta, theta, alpha, beta, gamma.
band_labels = (r'$\delta$', r'$\theta$', r'$\alpha$', r'$\beta$', r'$\gamma$')
band_limits = ((1, 4), (4, 8), (8, 13), (13, 30), (30, 100))
freq_bands = dict(zip(band_labels, band_limits))

For now, we want to separate the movement and non-movement data, so let's do that

In [7]:
# Split the epochs on the "movement" metadata flag: 1 = movement, 0 = non-movement.
movement_flag = epochs.metadata["movement"]
movement_epochs = epochs[movement_flag == 1]
non_movement_epochs = epochs[movement_flag == 0]

First, let's create, for each behaviour, a dataframe that holds the mean PSD per subject and channel, annotated with the genotype of the subject

In [59]:
# Build one long-format dataframe per behaviour ('movement' / 'non_movement').
# Each dataframe holds, for every wanted channel and every subject, the PSD
# averaged over that subject's epochs, annotated with subject id, genotype and
# channel name.
#
# Changes with respect to the previous version of this cell:
#   * the loop variable is no longer called `epochs`, so the notebook-level
#     `epochs` object is not overwritten (re-running earlier cells stays safe);
#   * per-subject frames are collected in a list and concatenated once, instead
#     of growing a dataframe with pd.concat inside the loop (quadratic, and it
#     concatenated onto an empty frame);
#   * omitted channels are skipped before the subject's epochs are selected.
behavioral_dfs = {}
psd_columns = ['freq', 'psd (means)', 'subject_id', 'genotype', 'channel']
epochs_per_behaviour = {'movement': movement_epochs, 'non_movement': non_movement_epochs}

for behaviour, behaviour_epochs in epochs_per_behaviour.items():
    print(f"Generating {behaviour} dataframe.")

    subject_frames = []

    # loop through channels, as we want data per channel
    for channel in wanted_chans:
        for subject_id in behaviour_epochs.metadata["animal_id"].unique():

            # skip data of bad quality
            if channel in bad_epochs_per_subject.get(subject_id, []):
                print(f"Omitting channel {channel} for subject {subject_id}.")
                continue

            subject_data = behaviour_epochs[behaviour_epochs.metadata["animal_id"] == subject_id]
            # the first metadata row is used as the subject's genotype
            genotype = subject_data.metadata["genotype"].iloc[0]

            # PSD of every epoch of this subject for this single channel
            psds_sub, freqs = psd_array_multitaper(
                subject_data.get_data(picks=channel),
                fmin=0, fmax=100,
                sfreq=subject_data.info['sfreq'],
                n_jobs=-1
            )
            # only one channel was picked, so index 0 on the channel axis;
            # then average over epochs to get one spectrum for the subject
            mean_psd_sub = np.mean(psds_sub[:, 0, :], axis=0)

            # keep the subject average, so we can later plot the mean of the subject averages
            subject_frames.append(pd.DataFrame({
                "freq": freqs,
                "psd (means)": mean_psd_sub,
                "subject_id": subject_id,
                "genotype": genotype,
                "channel": channel,
            }))

    # single concatenation; fall back to an empty, correctly-labelled frame when
    # nothing was collected (pd.concat raises on an empty list)
    if subject_frames:
        behavioral_dfs[behaviour] = pd.concat(subject_frames, ignore_index=True)
    else:
        behavioral_dfs[behaviour] = pd.DataFrame(columns=psd_columns)

Generating movement dataframe.
Omitting channel OFC_R for subject 81193.


  df = pd.concat([df, pd.DataFrame({


Omitting channel OFC_R for subject 80630.
Omitting channel OFC_L for subject 81193.
Omitting channel OFC_L for subject 80625.
Generating non_movement dataframe.
Omitting channel OFC_R for subject 81193.


  df = pd.concat([df, pd.DataFrame({


Omitting channel OFC_R for subject 80630.
Omitting channel OFC_L for subject 81193.
Omitting channel OFC_L for subject 80625.


Now that we have a dataframe for both the movement and the non-movement data, let's generate a plot per channel in which we average the PSDs of all WT and all KO subjects

In [78]:
# For every behaviour and channel: plot the PSD averaged over subjects per
# genotype (with a 95% confidence interval), mark the frequency bands and save
# the figure.
palette = {'DRD2-WT': '#2D6D8E', 'DRD2-KO': '#9B2B11'}  # same for every figure, so defined once

for behaviour, df in behavioral_dfs.items():
    for channel in df.channel.unique():
        fig, ax = plt.subplots(figsize=(12, 8))  # initiate plot

        channel_data = df[df.channel == channel]

        sns.lineplot(
            data=channel_data, x='freq', y='psd (means)',
            hue='genotype', hue_order=['DRD2-WT', 'DRD2-KO'], palette=palette,
            legend=True, ax=ax, errorbar=('ci', 95),
        )

        ax.set_yscale('log')

        # add frequency bands
        for band, (start, end) in freq_bands.items():
            ax.axvline(x=start, color='gray', linestyle='--', alpha=0.3)
            ax.axvline(x=end, color='gray', linestyle='--', alpha=0.3)
            # x in data coordinates, y as a fraction of the axes height: the label
            # stays just under the top edge whatever the y-scale or y-limits are
            # (previously the y position was read from the limits of the linear
            # axis, before the switch to a log scale changed them)
            ax.text(
                (start + end) / 2, 0.99, band,
                transform=ax.get_xaxis_transform(),
                horizontalalignment='center', verticalalignment='top',
                fontsize=8, color='black',
            )

        ax.set_xlabel('Frequency (Hz)')
        # NOTE(review): the plotted values are the multitaper PSD on a log axis and
        # are not converted to dB anywhere in this notebook -- confirm the unit label.
        ax.set_ylabel('Power/Frequency (dB/Hz)')
        ax.set_title(f'Average PSD of all {behaviour} epochs ({channel})')
        ax.legend(loc='best')
        save_figure(os.path.join(plot_folder, f"movement_vs_non_movement/wt_vs_ko/{behaviour}_{channel}"))

Let's also colour the lines by subject_id instead, to see which subject averages stand out.

In [79]:
# For every behaviour and channel: plot each subject's average PSD as its own
# line (coloured by subject id), mark the frequency bands and save the figure.
for behaviour, df in behavioral_dfs.items():
    for channel in df.channel.unique():
        fig, ax = plt.subplots(figsize=(12, 8))  # initiate plot

        channel_data = df[df.channel == channel]
        sns.lineplot(data=channel_data, x='freq', y='psd (means)', hue='subject_id', legend=True, ax=ax)

        ax.set_yscale('log')

        # add frequency bands
        for band, (start, end) in freq_bands.items():
            ax.axvline(x=start, color='gray', linestyle='--', alpha=0.3)
            ax.axvline(x=end, color='gray', linestyle='--', alpha=0.3)
            # x in data coordinates, y as a fraction of the axes height: the label
            # stays just under the top edge whatever the y-scale or y-limits are
            # (previously the y position was read from the limits of the linear
            # axis, before the switch to a log scale changed them)
            ax.text(
                (start + end) / 2, 0.99, band,
                transform=ax.get_xaxis_transform(),
                horizontalalignment='center', verticalalignment='top',
                fontsize=8, color='black',
            )

        ax.set_xlabel('Frequency (Hz)')
        # NOTE(review): the plotted values are the multitaper PSD on a log axis and
        # are not converted to dB anywhere in this notebook -- confirm the unit label.
        ax.set_ylabel('Power/Frequency (dB/Hz)')
        ax.set_title(f'Average PSD of {behaviour} epochs per subject ({channel})')
        ax.legend(loc='best')
        save_figure(os.path.join(plot_folder, f"movement_vs_non_movement/all_subjects_{behaviour}_{channel}"))

Now let's plot both behaviours in the same plot and hue on genotype

In [80]:
# For every channel: draw the genotype-averaged PSD of movement and non-movement
# ("idle") epochs in one figure, using a darker colour pair for movement and a
# lighter pair for idle, then save the figure.
mov_df = behavioral_dfs['movement']
non_mov_df = behavioral_dfs['non_movement']

palette_mov = {'DRD2-WT': '#2D6D8E', 'DRD2-KO': '#9B2B11'}
palette_non = {'DRD2-WT': '#4CB6ED', 'DRD2-KO': '#B96A58'}
genotype_order = ['DRD2-WT', 'DRD2-KO']

# Legend entries are derived from the palettes so colours and labels cannot drift apart.
legend_entries = [
    (palette_mov['DRD2-WT'], 'DRD2-WT - Movement'),
    (palette_mov['DRD2-KO'], 'DRD2-KO - Movement'),
    (palette_non['DRD2-WT'], 'DRD2-WT - Idle'),
    (palette_non['DRD2-KO'], 'DRD2-KO - Idle'),
]

for channel in wanted_chans:
    fig, ax = plt.subplots(figsize=(12, 8))  # initiate plot

    channel_data_mov = mov_df[mov_df.channel == channel]
    channel_data_non = non_mov_df[non_mov_df.channel == channel]

    # legend=False: the automatic seaborn legends would be replaced by the custom
    # four-entry legend below anyway
    sns.lineplot(
        data=channel_data_mov, x='freq', y='psd (means)',
        hue='genotype', hue_order=genotype_order, palette=palette_mov,
        legend=False, ax=ax, errorbar=None,
    )
    sns.lineplot(
        data=channel_data_non, x='freq', y='psd (means)',
        hue='genotype', hue_order=genotype_order, palette=palette_non,
        legend=False, ax=ax, errorbar=None,
    )

    ax.set_yscale('log')

    # add frequency bands
    for band, (start, end) in freq_bands.items():
        ax.axvline(x=start, color='gray', linestyle='--', alpha=0.3)
        ax.axvline(x=end, color='gray', linestyle='--', alpha=0.3)
        # x in data coordinates, y as a fraction of the axes height: the label
        # stays just under the top edge whatever the y-scale or y-limits are
        # (previously the y position was read from the limits of the linear
        # axis, before the switch to a log scale changed them)
        ax.text(
            (start + end) / 2, 0.99, band,
            transform=ax.get_xaxis_transform(),
            horizontalalignment='center', verticalalignment='top',
            fontsize=8, color='black',
        )

    ax.set_xlabel('Frequency (Hz)')
    # NOTE(review): the plotted values are the multitaper PSD on a log axis and
    # are not converted to dB anywhere in this notebook -- confirm the unit label.
    ax.set_ylabel('Power/Frequency (dB/Hz)')
    ax.set_title(f'Average PSD - ({channel})')

    handles = [
        mpatches.Patch(color=colour, linestyle='-', label=label)
        for colour, label in legend_entries
    ]
    ax.legend(loc='upper right', handles=handles)

    save_figure(os.path.join(plot_folder, f"movement_vs_non_movement/wt_vs_ko/combined/{channel}"))