### Notebook for plotting epochs based on behaviour (movement/non-movement)

In [2]:
import mne
import numpy as np
import pandas as pd
import seaborn as sns
from mne.time_frequency import psd_array_multitaper

from shared.helper_functions import *

In [3]:
epoch_folder = select_folder("Select the folder that holds epoch files starting with 'filtered_epochs_w_movement'")
plot_folder = select_or_create_folder("Create or select a folder the plots will be saved to")

Let's load the epoch object

In [None]:
# let's load the epochs file
epochs = mne.read_epochs(os.path.join(epoch_folder, "filtered_epochs_w_movement-epo.fif"), preload=True)

Establish a dictionary holding channels to be omitted from consideration

In [4]:
bad_epochs_per_subject = {
    "80630": ["OFC_R"],
    "39489": ["OFC_R"],
    "80625": ["OFC_L"],
    "81193": ["OFC_R", "OFC_L"]
}

Get the names of the channels we want to plot

In [4]:
wanted_chans = [channel for channel in epochs.info["ch_names"] if not channel in ['EMG_L', 'EMG_R']]
wanted_chans

['OFC_R', 'OFC_L', 'CG', 'STR_R', 'S1_L', 'S1_R', 'V1_R']

#### DRD2-WT and DRD2-KO average PSD for movement and non-movement epochs

Define the frequency domains

In [5]:
freq_bands = {
    r'$\delta$': (1, 4),  # Delta
    r'$\theta$': (4, 8),  # Theta
    r'$\alpha$': (8, 13),  # Alpha
    r'$\beta$': (13, 30),  # Beta
    r'$\gamma$': (30, 100)  # Gamma
}

For now, we want to separate the movement and non-movement data, so let's do that

In [9]:
movement_epochs = epochs[epochs.metadata["movement"] == 1]
non_movement_epochs = epochs[epochs.metadata["movement"] == 0]

First, let's create a dataframe that holds all PSD means per subject, and is annotated with the genotype of the subject

In [82]:
behavioral_dfs = {}
for behaviour, epochs in {'movement': movement_epochs, 'non_movement': non_movement_epochs}.items():
    print(f"Generating {behaviour} dataframe.")
    # loop through channels, as we want data per channel
    
    df = pd.DataFrame()

    for channel in wanted_chans:
        for subject_id in epochs.metadata["animal_id"].unique():
            
            subject_data = epochs[epochs.metadata["animal_id"] == subject_id]
            subject_data = subject_data[:-1]  # somehow the last epoch has only 0.0 values, so we remove this
            genotype = subject_data.metadata["genotype"].iloc[0]

            # skip plotting data of bad quality
            if subject_id in bad_epochs_per_subject.keys():
                if channel in bad_epochs_per_subject[subject_id]: 
                    print(f"Omitting channel {channel} for subject {subject_id}.")
                    continue
            
            # get the average PSD for this subject
            psds_sub, freqs = psd_array_multitaper(
                subject_data.get_data(picks=channel),
                fmin=0, fmax=100,
                sfreq=subject_data.info['sfreq'],
                n_jobs=-1
            )
            total_power = np.sum(psds_sub, axis=-1)
            psds_sub_norm = psds_sub / total_power[:, np.newaxis]
            
            mean_psd_sub = np.mean(psds_sub[:, 0, :], axis=0)
            mean_psd_sub_norm = np.mean(psds_sub_norm[:, 0, :], axis=0)
            
            # save the average of this subject, so we can later plot the mean of the subject averages
            df = pd.concat([df, pd.DataFrame({
                "freq": freqs,
                "psd (means)": mean_psd_sub,
                "psd (norm)": mean_psd_sub_norm,
                "subject_id": subject_id,
                "genotype": genotype,
                "channel": channel,
            })])
                
    behavioral_dfs[behaviour] = df

Generating movement dataframe.
Omitting channel OFC_R for subject 81193.
Omitting channel OFC_R for subject 39489.
Omitting channel OFC_R for subject 80630.
Omitting channel OFC_L for subject 81193.
Omitting channel OFC_L for subject 80625.
Generating non_movement dataframe.
Omitting channel OFC_R for subject 81193.


  psds_sub_norm = psds_sub / total_power[:, np.newaxis]


Omitting channel OFC_R for subject 39489.
Omitting channel OFC_R for subject 80630.
Omitting channel OFC_L for subject 81193.
Omitting channel OFC_L for subject 80625.


Now we have a dataframe for both movement, and non-movement data, let's generate a plot per channel where we average the PSDs of all WT and all KO subjects

In [6]:
for behaviour, df in behavioral_dfs.items():
    for y_axis in ['psd (means)', 'psd (norm)']:
        # create palette for WT and KO subjects (background lines)
        palettes, subjects = {}, df["subject_id"].unique()
        [palettes.update({subject: '#427C99'}) if (df[df["subject_id"] == subject].genotype.iloc[0] == "DRD2-WT") else palettes.update({subject: '#AF5541'}) for subject in subjects]
    
        print(f"Generating plot grid for behaviour: {behaviour} ({y_axis})")
        fig, axs = plt.subplots(2, 4, figsize=(35, 15), sharex=True, sharey=True)
        axs = axs.ravel()
        
        for i, channel in enumerate(wanted_chans):
        
            channel_data = df[df.channel == channel]
            channel_data = channel_data[(channel_data.freq > 52) | (channel_data.freq < 48)]  # remove the 50Hz peak
            
            palette = {'DRD2-WT': '#427C99', 'DRD2-KO': '#AF5541'}
            sns.lineplot(data=channel_data, x='freq', y=y_axis, palette=palette, hue_order=['DRD2-WT', 'DRD2-KO'], hue='genotype', legend=True, ax=axs[i], errorbar='se')
            sns.lineplot(data=channel_data, x='freq', y=y_axis, palette=palettes, hue='subject_id', linewidth=.4, legend=False, ax=axs[i], alpha=.5)
            
            for band, (start, end) in freq_bands.items():
                axs[i].axvline(x=start, color='gray', linestyle='--', alpha=0.3)
                axs[i].axvline(x=end, color='gray', linestyle='--', alpha=0.3)
                axs[i].text((start + end) / 2, axs[i].get_ylim()[1] * 1.01, band, horizontalalignment='center', verticalalignment='top', fontsize=8, color='black')
            
            axs[i].set_yscale('log')
            axs[i].set_xlabel('Frequency (Hz)')
            axs[i].set_ylabel('Power/Frequency (dB/Hz)')
            axs[i].set_title(f'Average PSD per genotype ({channel}) - Behaviour: {behaviour}')
        
        plt.subplots_adjust(wspace=0.08, hspace=0.08)
        axs[-1].remove()
        save_figure(os.path.join(plot_folder, f"mov_vs_non_mov/WT_vs_KO_{'norm' if 'norm' in y_axis else 'abs'}_PSD_averages_{behaviour}.pdf"))

Generating plot grid for behaviour: movement (psd (means))
Generating plot grid for behaviour: movement (psd (norm))
Generating plot grid for behaviour: non_movement (psd (means))
Generating plot grid for behaviour: non_movement (psd (norm))


Let's concatenate the two behaviour dataframes

In [7]:
behavioral_dfs["movement"]["behaviour"] = "Movement"
behavioral_dfs["non_movement"]["behaviour"] = "Non-movement"
total_df = pd.concat([behavioral_dfs["movement"], behavioral_dfs["non_movement"]])

Now we plot the movement and non-movement epochs separately, so hue'ing on behaviour instead of genotype

In [10]:
for y_axis in ['psd (means)', 'psd (norm)']:
    fig, axs = plt.subplots(2, 4, figsize=(35, 15), sharex=True, sharey=True)
    axs = axs.ravel()
    
    for i, channel in enumerate(wanted_chans):
        
        channel_data = total_df[total_df.channel == channel]
        channel_data = channel_data[(channel_data.freq > 52) | (channel_data.freq < 48)]  # remove the 50Hz peak
        
        sns.lineplot(data=channel_data, x='freq', y=y_axis, hue="behaviour", legend=True, ax=axs[i], errorbar='se')
        
        for band, (start, end) in freq_bands.items():
            axs[i].axvline(x=start, color='gray', linestyle='--', alpha=0.3)
            axs[i].axvline(x=end, color='gray', linestyle='--', alpha=0.3)
            axs[i].text((start + end) / 2, axs[i].get_ylim()[1] * 1.01, band, horizontalalignment='center', verticalalignment='top', fontsize=8, color='black')
        
        axs[i].set_yscale('log')
        axs[i].set_xlabel('Frequency (Hz)')
        axs[i].set_ylabel('Power/Frequency (dB/Hz)')
        axs[i].set_title(f'Average PSD, movement vs non-movement ({channel})')
    
    plt.subplots_adjust(wspace=0.08, hspace=0.08)
    axs[-1].remove()
    save_figure(os.path.join(plot_folder, f"mov_vs_non_mov/mov_vs_nonmov_{'norm' if 'norm' in y_axis else 'abs'}_PSD_averages.pdf"))

Let's alter the data some more, so we can create a plot that includes traces for movement-wt, movement-ko, non-movement-wt, non-movement-ko.

In [11]:
total_df["mov_geno"] = total_df['genotype'] + " - " + total_df['behaviour']

And plot it

In [12]:
for y_axis in ['psd (means)', 'psd (norm)']:
    fig, axs = plt.subplots(2, 4, figsize=(35, 15), sharex=True, sharey=True)
    axs = axs.ravel()
    
    for i, channel in enumerate(wanted_chans):
        
        channel_data = total_df[total_df.channel == channel]
        channel_data = channel_data[(channel_data.freq > 52) | (channel_data.freq < 48)]  # remove the 50Hz peak
        
        hue_order = ['DRD2-WT - Movement', 'DRD2-KO - Movement', 'DRD2-WT - Non-movement', 'DRD2-KO - Non-movement']
        sns.lineplot(data=channel_data, x='freq', y=y_axis, hue="mov_geno", hue_order=hue_order, ax=axs[i], errorbar='se')
    
        for band, (start, end) in freq_bands.items():
            axs[i].axvline(x=start, color='gray', linestyle='--', alpha=0.3)
            axs[i].axvline(x=end, color='gray', linestyle='--', alpha=0.3)
            axs[i].text((start + end) / 2, axs[i].get_ylim()[1] * 1.01, band, horizontalalignment='center', verticalalignment='top', fontsize=8, color='black')
        
        axs[i].set_yscale('log')
        axs[i].set_xlabel('Frequency (Hz)')
        axs[i].set_ylabel('Power/Frequency (dB/Hz)')
        axs[i].set_title(f'Average PSD of different genotype-behaviour combinations ({channel})')
        axs[i].legend(title="Genotype - Behaviour")
    
    plt.subplots_adjust(wspace=0.08, hspace=0.08)
    axs[-1].remove()
    save_figure(os.path.join(plot_folder, f"mov_vs_non_mov/mov_vs_nonmov_WT_vs_KO_{'norm' if 'norm' in y_axis else 'abs'}_PSD_averages.pdf"))