### Notebook for generating grid PSD plots

In [1]:
import os
import mne
import json
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from mne.time_frequency import psd_array_multitaper
from mne import Epochs

from helper_functions import save_figure

In [2]:
# Read the shared project settings and pull out the input/output folders used below.
settings_path = '../settings.json'
with open(settings_path) as settings_file:
    settings = json.load(settings_file)

epoch_folder = settings['epochs_folder']  # where the saved epoch files live
plot_folder = settings['plots_folder']    # where generated figures are written

In [3]:
# Canonical frequency bands (Hz) as (lower, upper) edges, keyed by their LaTeX Greek
# symbol so the keys can be used directly as matplotlib labels.
freq_bands = {
    '$\\' + name + '$': (low, high)
    for name, low, high in zip(
        ('delta', 'theta', 'alpha', 'beta', 'gamma'),
        (1, 4, 8, 13, 30),
        (4, 8, 13, 30, 100),
    )
}

Load the per-subject filtered epoch objects.

In [5]:
# Load every per-subject filtered epoch file. Files whose name does not contain
# "filtered_epochs_r" (whole-genotype KO/WT epoch objects and raw epoch objects) are skipped.
# The listing is sorted because os.listdir returns entries in arbitrary order; sorting makes
# the subject order (and therefore epoch_objects[0], whose channels are used below) reproducible.
epoch_objects = [
    mne.read_epochs(os.path.join(epoch_folder, file), preload=True)
    for file in sorted(os.listdir(epoch_folder))
    if "filtered_epochs_r" in file
]
assert epoch_objects, f"No '*filtered_epochs_r*' files found in {epoch_folder}"

Define a dictionary of channels to omit for each subject because of quality issues (or extreme outliers).

In [107]:
# Channels (not epochs, despite the name) to exclude for a given subject because of
# poor signal quality or extreme outliers. Keys are animal IDs written as strings.
bad_epochs_per_subject = dict([
    ("80630", ["OFC_R"]),
    ("39489", ["OFC_R"]),
    ("80625", ["OFC_L"]),
    ("81193", ["OFC_R", "OFC_L"]),
])

Get the names of the channels we want to plot

In [108]:
# All channels except the EMG ones. The channel list of the first subject is taken as
# representative; this presumably holds for every subject — verify if recordings differ.
excluded_chans = {'EMG_L', 'EMG_R'}
wanted_chans = [channel for channel in epoch_objects[0].info["ch_names"] if channel not in excluded_chans]
wanted_chans

['OFC_R', 'OFC_L', 'CG', 'STR_R', 'S1_L', 'S1_R', 'V1_R']

Next, build a dataframe that holds, for every channel, each subject's PSD averaged over its epochs.

In [109]:
# Long-format frame with one row per (channel, subject, frequency) holding the subject's
# PSD averaged over its epochs. Per-subject frames are collected in a list and concatenated
# once: concatenating inside the loop is quadratic and leaves duplicate index labels.
psd_frames = []

for channel in wanted_chans:
    for subject_epochs in epoch_objects:
        metadata = subject_epochs.metadata
        # Positional access for both fields: the metadata index is not guaranteed to hold label 0.
        subject_id = metadata["animal_id"].iloc[0]
        genotype = metadata["genotype"].iloc[0]

        # skip channels flagged as bad quality for this subject
        if channel in bad_epochs_per_subject.get(subject_id, ()):
            print(f"Omitting channel {channel} for subject {subject_id}.")
            continue

        # A single channel is picked, so psds_sub is indexed as (epoch, channel=0, freq) below.
        psds_sub, freqs = psd_array_multitaper(
            subject_epochs.get_data(picks=channel),
            fmin=0, fmax=100,
            sfreq=subject_epochs.info['sfreq'],
            n_jobs=-1
        )
        # average over epochs -> one spectrum per subject
        mean_psd_sub = np.mean(psds_sub[:, 0, :], axis=0)

        # keep the subject average, so we can later plot the mean of the subject averages
        psd_frames.append(pd.DataFrame({
            "freq": freqs,
            "psd (means)": mean_psd_sub,
            "subject_id": subject_id,
            "genotype": genotype,
            "channel": channel,
        }))

psd_columns = ["freq", "psd (means)", "subject_id", "genotype", "channel"]
df = pd.concat(psd_frames, ignore_index=True) if psd_frames else pd.DataFrame(columns=psd_columns)

Omitting channel OFC_R for subject 81193.
Omitting channel OFC_R for subject 39489.
Omitting channel OFC_R for subject 80630.
Omitting channel OFC_L for subject 81193.
Omitting channel OFC_L for subject 80625.


And now let's generate a grid plot of the PSD averages per genotype (PDF)

In [110]:
# Create a figure and axes for subplots
fig, axs = plt.subplots(2, 4, figsize=(35, 15), sharex=True, sharey=True)
axs = axs.ravel()

subjects = df["subject_id"].unique()
for i, channel in enumerate(wanted_chans):

    channel_data = df[df.channel == channel]
    channel_data = channel_data[(channel_data.freq > 52) | (channel_data.freq < 48)]  # remove the 50Hz peak
    
    palette = {'DRD2-WT': '#2D6D8E', 'DRD2-KO': '#9B2B11'}
    sns.lineplot(data=channel_data, x='freq', y='psd (means)', palette=palette, hue_order=['DRD2-WT', 'DRD2-KO'], hue='genotype', legend=True, ax=axs[i], errorbar='se')
    
    for band, (start, end) in freq_bands.items():
        axs[i].axvline(x=start, color='gray', linestyle='--', alpha=0.3)
        axs[i].axvline(x=end, color='gray', linestyle='--', alpha=0.3)
        axs[i].text((start + end) / 2, axs[i].get_ylim()[1] * 1.01, band, horizontalalignment='center', verticalalignment='top', fontsize=8, color='black')
    
    axs[i].set_yscale('log')
    axs[i].set_xlabel('Frequency (Hz)')
    axs[i].set_ylabel('Power/Frequency (dB/Hz)')
    axs[i].set_title(f'Average PSD per genotype ({channel})')

plt.subplots_adjust(wspace=0.08, hspace=0.08)
axs[-1].remove()
save_figure(os.path.join(plot_folder, "WT_vs_KO_PSD_averages.pdf"))

Now, let's also add the subject averages to that plot.

In [111]:
# Create a figure and axes for subplots
fig, axs = plt.subplots(2, 4, figsize=(35, 15), sharex=True, sharey=True)
axs = axs.ravel()

subjects = df["subject_id"].unique()
palettes = {}
for subject in subjects:
    if df[df["subject_id"] == subject].genotype.iloc[0] == "DRD2-WT":
        palettes[subject] = '#2D6D8E'
    else:
        palettes[subject] = '#9B2B11'

for i, channel in enumerate(wanted_chans):

    channel_data = df[df.channel == channel]
    channel_data = channel_data[(channel_data.freq > 52) | (channel_data.freq < 48)]  # remove the 50Hz peak
    
    palette = {'DRD2-WT': '#2D6D8E', 'DRD2-KO': '#9B2B11'}
    sns.lineplot(data=channel_data, x='freq', y='psd (means)', palette=palette, hue_order=['DRD2-WT', 'DRD2-KO'], hue='genotype', legend=True, ax=axs[i], errorbar='se')
    sns.lineplot(data=channel_data, x='freq', y='psd (means)', palette=palettes, hue='subject_id', linewidth=.3, legend=False, ax=axs[i], alpha=.4)
    
    for band, (start, end) in freq_bands.items():
        axs[i].axvline(x=start, color='gray', linestyle='--', alpha=0.3)
        axs[i].axvline(x=end, color='gray', linestyle='--', alpha=0.3)
        axs[i].text((start + end) / 2, axs[i].get_ylim()[1] * 1.01, band, horizontalalignment='center', verticalalignment='top', fontsize=8, color='black')
    
    axs[i].set_yscale('log')
    axs[i].set_xlabel('Frequency (Hz)')
    axs[i].set_ylabel('Power/Frequency (dB/Hz)')
    axs[i].set_title(f'Average PSD per genotype ({channel})')

plt.subplots_adjust(wspace=0.08, hspace=0.08)
axs[-1].remove()
save_figure(os.path.join(plot_folder, "WT_vs_KO_PSD_averages_w_subjects.pdf"))

And make one with only the individual PSD averages, so that we can identify possible outliers

In [112]:
# Create a figure and axes for subplots
fig, axs = plt.subplots(2, 4, figsize=(35, 15), sharex=True, sharey=True)
axs = axs.ravel()

for i, channel in enumerate(wanted_chans):

    channel_data = df[df.channel == channel]
    channel_data = channel_data[(channel_data.freq > 52) | (channel_data.freq < 48)]  # remove the 50Hz peak
    
    sns.lineplot(data=channel_data, x='freq', y='psd (means)', hue='subject_id', legend=True, linewidth=1, ax=axs[i])
    
    for band, (start, end) in freq_bands.items():
        axs[i].axvline(x=start, color='gray', linestyle='--', alpha=0.3)
        axs[i].axvline(x=end, color='gray', linestyle='--', alpha=0.3)
        axs[i].text((start + end) / 2, axs[i].get_ylim()[1] * 1.01, band, horizontalalignment='center', verticalalignment='top', fontsize=8, color='black')
    
    axs[i].set_yscale('log')
    axs[i].set_xlabel('Frequency (Hz)')
    axs[i].set_ylabel('Power/Frequency (dB/Hz)')
    axs[i].set_title(f'Average PSD per genotype ({channel})')

plt.subplots_adjust(wspace=0.08, hspace=0.08)
axs[-1].remove()
save_figure(os.path.join(plot_folder, "PSD_average_per_subject.pdf"))