In [1]:
import os
import mne
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from mne.time_frequency import psd_array_multitaper
from mne import Epochs

from helper_functions import save_figure

In [3]:
with open('../settings.json', "r") as f:
    settings = json.load(f)
    
epoch_folder = settings['epochs_folder']
plot_folder = settings['plots_folder']

Let's load the epoch object

In [33]:
# let's load the epochs file
epochs = mne.read_epochs(os.path.join(epoch_folder, "filtered_epochs_w_movement-epo.fif"), preload=True)

Establish a dictionary holding channels to be omitted from consideration

In [34]:
bad_epochs_per_subject = {
    "80630": ["OFC_R"],
    "81193": ["OFC_R"]
}

Get the names of the channels we want to plot

In [35]:
wanted_chans = [channel for channel in epochs.info["ch_names"] if not channel in ['EMG_L', 'EMG_R']]
print(wanted_chans)

['OFC_R', 'OFC_L', 'CG', 'STR_R', 'S1_L', 'S1_R', 'V1_R']


Now, let's plot the average PSD for each channel that holds either movement or non-movement epochs, annotated by genotype

In [36]:
movement_epochs = epochs[epochs.metadata["movement"] == 1]
non_movement_epochs = epochs[epochs.metadata["movement"] == 0]

In [38]:
for i, behaviour_epochs in enumerate([movement_epochs, non_movement_epochs]):
    s_freq = behaviour_epochs.info['sfreq']
    
    # loop through channels, as we want a plot for every channel
    for channel in wanted_chans:
        fig, ax = plt.subplots(figsize=(12, 8))  # initiate plot
        
        wt_epochs = behaviour_epochs[behaviour_epochs.metadata["genotype"] == "DRD2-WT"]
        ko_epochs = behaviour_epochs[behaviour_epochs.metadata["genotype"] == "DRD2-KO"]
        
        df = pd.DataFrame(columns=['freq', 'psd (means)', 'subject_id', 'genotype', 'channel'])

        print("Plotting all subject traces")
        # for each genotype, plot the subject average PSD
        for genotype_data in [wt_epochs, ko_epochs]:
            genotype = genotype_data.metadata["genotype"].iloc[0]
            for subject_id in genotype_data.metadata["animal_id"].unique():
                if subject_id in bad_epochs_per_subject.keys():
                    if channel in bad_epochs_per_subject[subject_id]:  # skip bad channel data
                        continue
                subject_data = genotype_data[genotype_data.metadata["animal_id"] == subject_id]
                psds_sub, freqs = psd_array_multitaper(
                    subject_data.get_data(picks=channel),
                    fmin=0, fmax=100,
                    sfreq=s_freq,
                    n_jobs=-1
                )
                mean_psd_sub = np.mean(psds_sub[:, 0, :], axis=0)
                ax.plot(freqs, 10 * np.log10(mean_psd_sub), linewidth=1, alpha=0.2, color="#c23616" if genotype == "DRD2-KO" else "#0097e6")
                
                # save the average of this subject
                df = pd.concat([df, pd.DataFrame({
                    "freq": freqs,
                    "psd (means)": mean_psd_sub,
                    "subject_id": subject_id,
                    "genotype": genotype,
                })])
        
        mean_psd_by_genotype = df.groupby('genotype')['psd (means)'].mean()
        print(mean_psd_by_genotype)
        
        print("Done plotting all subject traces, proceeding to the averages..")

        # ax.plot(freqs, 10 * np.log10(mean_psd_wt), linewidth=3, alpha=1, color="#0097e6", label="DRD2-WT")
        # ax.plot(freqs, 10 * np.log10(mean_psd_ko), linewidth=3, alpha=1, color="#c23616", label="DRD2-KO")
        # 
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('Power/Frequency (dB/Hz)')
        ax.set_title(f'Average PSD of all {"movement" if i == 0 else "non_movement"} epochs ({channel})')
        plt.legend()
        save_figure(os.path.join(plot_folder, f"movement_vs_non_movement/wt_vs_ko/{'movement' if i == 0 else 'non_movement'}_{channel}"))
        
        break
    break

Plotting all subject traces


  df = pd.concat([df, pd.DataFrame({
No artists with labels found to put in legend.  Note that artists whose label start with an underscore are ignored when legend() is called with no argument.


genotype
DRD2-KO    2.027366e-07
DRD2-WT    2.951520e-07
Name: psd (means), dtype: float64
Done plotting all subject traces, proceeding to the averages..
