In [1]:
import os
import mne
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from mne.time_frequency import psd_array_multitaper
from mne import Epochs

from helper_functions import save_figure

In [ ]:
mne.set_log_level()

In [2]:
with open('../settings.json', "r") as f:
    settings = json.load(f)
    
epoch_folder = settings['epochs_folder']
plot_folder = settings['plots_folder']

Let's load the epoch object

In [3]:
# let's load the epochs file
epochs = mne.read_epochs(os.path.join(epoch_folder, "filtered_epochs_w_movement-epo.fif"), preload=True)

Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\filtered_epochs_w_movement-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =       0.00 ...    4998.53 ms
        0 CTF compensation matrices available
Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\filtered_epochs_w_movement-epo-1.fif ...
Isotrak not found
    Found the data of interest:
        t =       0.00 ...    4998.53 ms
        0 CTF compensation matrices available
Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\filtered_epochs_w_movement-epo-2.fif ...
Isotrak not found
    Found the data of interest:
        t =       0.00 ...    4998.53 ms
        0 CTF compensation matrices available
Adding metadata with 4 columns
21980 matching events found
No baseline correction applied
0 projection items activated


Establish a dictionary holding channels to be omitted from consideration

In [5]:
bad_epochs_per_subject = {
    "80630": ["OFC_R"],
    "81193": ["OFC_R"]
}

Get the names of the channels we want to plot

In [6]:
wanted_chans = [channel for channel in epochs.info["ch_names"] if not channel in ['EMG_L', 'EMG_R']]

['OFC_R', 'OFC_L', 'CG', 'STR_R', 'S1_L', 'S1_R', 'V1_R']


#### DRD2-WT and DRD2-KO average PSD for movement and non-movement epochs

Define the frequency domains

In [25]:
freq_bands = {
    r'$\delta$': (1, 4),  # Delta
    r'$\theta$': (4, 8),  # Theta
    r'$\alpha$': (8, 13),  # Alpha
    r'$\beta$': (13, 30),  # Beta
    r'$\gamma$': (30, 100)  # Gamma
}

For now, we want to separate the movement and non-movement data, so let's do that

In [7]:
movement_epochs = epochs[epochs.metadata["movement"] == 1]
non_movement_epochs = epochs[epochs.metadata["movement"] == 0]

First, let's create a dataframe that holds all PSD means per subject, and is annotated with the genotype of the subject

In [None]:
behavioral_dfs = {}
for behaviour, epochs in {'movement': movement_epochs, 'non_movement': non_movement_epochs}.items():
    # loop through channels, as we want data per channel
    
    df = pd.DataFrame(columns=['freq', 'psd (means)', 'subject_id', 'genotype'])

    for channel in wanted_chans:
        for subject_id in epochs.metadata["animal_id"].unique():
            
            subject_data = epochs[epochs.metadata["animal_id"] == subject_id]
            genotype = subject_data.metadata["genotype"].iloc[0]

            # skip plotting data of bad quality
            if subject_id in bad_epochs_per_subject.keys():
                if channel in bad_epochs_per_subject[subject_id]: 
                    continue
            
            # get the average PSD for this subject
            psds_sub, freqs = psd_array_multitaper(
                subject_data.get_data(picks=channel),
                fmin=0, fmax=100,
                sfreq=subject_data.info['sfreq'],
                n_jobs=-1
            )
            mean_psd_sub = np.mean(psds_sub[:, 0, :], axis=0)
            
            # save the average of this subject, so we can later plot the mean of the subject averages
            df = pd.concat([df, pd.DataFrame({
                "freq": freqs,
                "psd (means)": mean_psd_sub,
                "subject_id": subject_id,
                "genotype": genotype,
                "channel": channel,
            })])
                
    behavioral_dfs[behaviour] = df

Now we have a dataframe for both movement, and non-movement data, let's generate a plot per channel with it.

In [32]:
for behaviour, df in behavioral_dfs.items():
    for channel in df.channel.unique():
        fig, ax = plt.subplots(figsize=(12, 8))  # initiate plot
        
        channel_data = df[df.channel == channel]
        
        # plot individual subject traces traces
        for subject_id in channel_data["subject_id"].unique():
            
            subject_data = channel_data[channel_data["subject_id"] == subject_id]
            genotype = subject_data["genotype"].iloc[0]
            
            # get actual data to plot and plot data for this subject
            freqs, mean_psd_sub = subject_data.freq, subject_data["psd (means)"]
            ax.plot(freqs, 10 * np.log10(mean_psd_sub), linewidth=1, alpha=0.2, color="#c23616" if genotype == "DRD2-KO" else "#0097e6")
            
        # group by genotype and index and take the mean of the PSD traces as well as the STD
        mean_psd_by_genotype = channel_data.groupby(['genotype', channel_data.index], as_index=False).agg({'freq': 'first', 'psd (means)': 'mean', 'genotype': 'first'})
        std_psd_by_genotype = channel_data.groupby(['genotype', channel_data.index], as_index=False).agg({'freq': 'first', 'psd (means)': 'std', 'genotype': 'first'})
        
        # split into ko and wt data
        mean_psd_wt, mean_psd_ko = mean_psd_by_genotype[mean_psd_by_genotype.genotype == "DRD2-WT"], mean_psd_by_genotype[mean_psd_by_genotype.genotype == "DRD2-KO"]
        std_psd_wt, std_psd_ko = std_psd_by_genotype[std_psd_by_genotype.genotype == "DRD2-WT"], std_psd_by_genotype[std_psd_by_genotype.genotype == "DRD2-KO"]
        
        conf_int_wt = 1.96 * std_psd_wt['psd (means)'] / np.sqrt(len(std_psd_wt['psd (means)']))  # 95% confidence interval
        conf_int_ko = 1.96 * std_psd_ko['psd (means)'] / np.sqrt(len(std_psd_ko['psd (means)']))  # 95% confidence interval
    
        ax.plot(mean_psd_wt.freq, 10 * np.log10(mean_psd_wt["psd (means)"]), linewidth=3, alpha=1, color="#0097e6", label="DRD2-WT")
        ax.plot(mean_psd_ko.freq, 10 * np.log10(mean_psd_ko["psd (means)"]), linewidth=3, alpha=1, color="#c23616", label="DRD2-KO")
        ax.fill_between(mean_psd_wt.freq, 10 * np.log10(mean_psd_wt["psd (means)"] - conf_int_wt), 10 * np.log10(mean_psd_wt["psd (means)"] + conf_int_wt), alpha=0.2)
        ax.fill_between(mean_psd_ko.freq, 10 * np.log10(mean_psd_ko["psd (means)"] - conf_int_ko), 10 * np.log10(mean_psd_ko["psd (means)"] + conf_int_ko), alpha=0.2)
        
        # Add vertical lines and labels for frequency bands
        for band, (start, end) in freq_bands.items():
            ax.axvline(x=start, color='gray', linestyle='--', alpha=0.3)
            ax.axvline(x=end, color='gray', linestyle='--', alpha=0.3)
            ax.text((start + end) / 2, ax.get_ylim()[1] * 1.01, band, horizontalalignment='center', verticalalignment='top', fontsize=8, color='black')
        
        plt.legend(loc='upper right', bbox_to_anchor=(1, 1))
    
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('Power/Frequency (dB/Hz)')
        ax.set_title(f'Average PSD of all {behaviour} epochs ({channel})')
        plt.legend()
        save_figure(os.path.join(plot_folder, f"movement_vs_non_movement/wt_vs_ko/{behaviour}_{channel}"))

  mean_psd_by_genotype = channel_data.groupby(['genotype', channel_data.index], as_index=False).agg({'freq': 'first', 'psd (means)': 'mean', 'genotype': 'first'})
  std_psd_by_genotype = channel_data.groupby(['genotype', channel_data.index], as_index=False).agg({'freq': 'first', 'psd (means)': 'std', 'genotype': 'first'})
  mean_psd_by_genotype = channel_data.groupby(['genotype', channel_data.index], as_index=False).agg({'freq': 'first', 'psd (means)': 'mean', 'genotype': 'first'})
  std_psd_by_genotype = channel_data.groupby(['genotype', channel_data.index], as_index=False).agg({'freq': 'first', 'psd (means)': 'std', 'genotype': 'first'})
  mean_psd_by_genotype = channel_data.groupby(['genotype', channel_data.index], as_index=False).agg({'freq': 'first', 'psd (means)': 'mean', 'genotype': 'first'})
  std_psd_by_genotype = channel_data.groupby(['genotype', channel_data.index], as_index=False).agg({'freq': 'first', 'psd (means)': 'std', 'genotype': 'first'})
  mean_psd_by_genotype = cha