In [41]:
import os
import mne
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from mne.time_frequency import psd_array_multitaper

In [42]:
# Load the project-wide settings; the folders this notebook works with are stored in there.
with open('../settings.json', "r") as settings_file:
    settings = json.loads(settings_file.read())

# folder the epoch files are read from, and folder the figures are written to
epoch_folder = settings['epochs_folder']
plot_folder = settings['plots_folder']

Load the epoch files

In [14]:
# Load one Epochs object per subject.
# The listing is sorted because os.listdir returns names in arbitrary order, and
# the order of `epoch_objects` decides the order in which subjects are processed
# and drawn later on; sorting keeps that identical between runs and machines.
epoch_objects = []
for file in sorted(os.listdir(epoch_folder)):
    # skipping entire KO/WT filtered epoch objects and raw epoch objects
    if "filtered_epochs_r" not in file:
        continue
    epoch_objects.append(mne.read_epochs(os.path.join(epoch_folder, file), preload=True))

Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\filtered_epochs_resting_state_39489-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =       0.00 ...    4998.53 ms
        0 CTF compensation matrices available
Adding metadata with 3 columns
1197 matching events found
No baseline correction applied
0 projection items activated
Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\filtered_epochs_resting_state_39508-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =       0.00 ...    4998.53 ms
        0 CTF compensation matrices available
Adding metadata with 3 columns
2013 matching events found
No baseline correction applied
0 projection items activated
Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\filtered_epochs_resting_state_78211-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =      

Let's split up the filtered epoch array objects into WT and KO

In [15]:
# split WT and KO epoch objects: every subject whose genotype is not "DRD2-WT"
# ends up in the KO list
filtered_epochs_per_subj_WT = []
filtered_epochs_per_subj_KO = []
for subject_epochs in epoch_objects:
    # .iloc[0] takes the first metadata row by position. A plain [0] is a label
    # lookup and raises a KeyError when the metadata index has no label 0
    # (presumably possible when the first epoch of a recording was dropped).
    if subject_epochs.metadata["genotype"].iloc[0] == "DRD2-WT":
        filtered_epochs_per_subj_WT.append(subject_epochs)
    else:
        filtered_epochs_per_subj_KO.append(subject_epochs)

# number of WT subjects, number of KO subjects
print(len(filtered_epochs_per_subj_WT), len(filtered_epochs_per_subj_KO))

8 10


Let's report the average number of epochs per subject that passed the filtering (WT vs KO).

In [10]:
# number of epochs that each subject has left, collected per genotype
num_epochs_wt = list(map(len, filtered_epochs_per_subj_WT))
num_epochs_ko = list(map(len, filtered_epochs_per_subj_KO))

# report the mean of those per-subject counts for both genotypes
mean_epochs_wt = np.mean(num_epochs_wt)
print("Average number of epochs for DRD2-WT subjects: ", mean_epochs_wt)
mean_epochs_ko = np.mean(num_epochs_ko)
print("Average number of epochs for DRD2-KO subjects: ", mean_epochs_ko)

Average number of epochs for DRD2-WT subjects:  1757.375
Average number of epochs for DRD2-KO subjects:  1418.4


### Average PSD of all epochs per subject (WT/KO)

We now want to generate one figure per channel and genotype (DRD2-WT or DRD2-KO) that shows the average PSD trace of each individual subject, together with the average of those subject traces. This way we can see how much each subject influences the overall average PSD of its genotype.

Let's first create the data that is needed to plot this. This might take a while to run.

In [0]:
# Channel names without the two EMG channels, taken from the first WT subject.
# This list is reused quite often in the cells below.
wanted_chans = [
    ch_name
    for ch_name in filtered_epochs_per_subj_WT[0].info["ch_names"]
    if ch_name not in ('EMG_L', 'EMG_R')
]

In [None]:
# This cell builds, per genotype (WT / KO):
#   df_subj_psds_*: one column per subject and channel, named "<animal_id>-<channel>",
#                   holding that subject's PSD averaged over all of its epochs
#   df_av_psds_*:   a 'freqs' column plus one column per channel that holds the
#                   average of the subject traces of that channel
# Both genotypes go through the same code path; the frames are assembled once at
# the end instead of being grown with pd.concat inside the loop.
subjects_per_genotype = {
    "WT": filtered_epochs_per_subj_WT,
    "KO": filtered_epochs_per_subj_KO,
}
subj_psd_frames = {"WT": [], "KO": []}  # per genotype: one DataFrame per channel
av_psd_series = {"WT": [], "KO": []}  # per genotype: one averaged trace per channel

freqs = np.array([])

# we want to plot for every channel
for channel in wanted_chans:
    for genotype_short, subjects in subjects_per_genotype.items():
        # calculate the average PSD of all epochs of each subject of this genotype
        subj_mean_psds = {}
        for subject in subjects:
            # .iloc[0]: first metadata row by position; a plain [0] is a label
            # lookup and fails when the metadata index has no label 0
            animal_id = subject.metadata["animal_id"].iloc[0]
            psds, freqs = psd_array_multitaper(
                subject.get_data(picks=channel),
                fmin=0,
                fmax=100,
                sfreq=subject.info['sfreq'],
                n_jobs=-1,
            )
            # a single channel was picked, hence index 0 on the second axis;
            # take the average of the psds of all epochs
            subj_mean_psds[f"{animal_id}-{channel}"] = np.mean(psds[:, 0, :], axis=0)
        channel_psds = pd.DataFrame(subj_mean_psds)

        # keep the individual subject means, and their average (one trace per genotype)
        subj_psd_frames[genotype_short].append(channel_psds)
        av_psd_series[genotype_short].append(channel_psds.mean(axis=1).rename(channel))

# assemble the per-subject frames (pd.concat refuses an empty list, hence the guard)
df_subj_psds_WT = pd.concat(subj_psd_frames["WT"], axis=1) if subj_psd_frames["WT"] else pd.DataFrame([])
df_subj_psds_KO = pd.concat(subj_psd_frames["KO"], axis=1) if subj_psd_frames["KO"] else pd.DataFrame([])

# assemble the averaged frames: 'freqs' first, then one column per channel
freq_column = pd.Series(freqs, name='freqs')
df_av_psds_WT = pd.concat([freq_column, *av_psd_series["WT"]], axis=1)
df_av_psds_KO = pd.concat([freq_column, *av_psd_series["KO"]], axis=1)

Let's generate all the desired plots; one per channel for both WT and KO. Each plot contains the average PSD per subject for that channel, as well as the average of those.

In [61]:
# df_av_psds_WT: holds WT total average PSD and freqs (x-axis for our plots)
# df_av_psds_KO: holds KO total average PSD and freqs (x-axis for our plots)
# df_subj_psds_WT: holds WT average PSD traces per channel for all subjects
# df_subj_psds_KO: holds KO average PSD traces per channel for all subjects

# savefig does not create missing folders, so make sure the target exists
out_dir = os.path.join(plot_folder, "averaged_per_subject")
os.makedirs(out_dir, exist_ok=True)

freqs = df_av_psds_WT['freqs']  # doesn't matter which one we retrieve (are equal)

# per genotype: the individual subject traces, the averaged traces and the colour of the average
genotype_inputs = [
    ("DRD2-WT", df_subj_psds_WT, df_av_psds_WT, "#0097e6"),
    ("DRD2-KO", df_subj_psds_KO, df_av_psds_KO, "#c23616"),
]

# we want to create a plot per genotype and channel
for genotype, df_subj_psds, df_av_psds, av_color in genotype_inputs:
    for channel in wanted_chans:
        fig, ax = plt.subplots(figsize=(12, 8))  # initiate plot

        # Columns are named "<animal_id>-<channel>". Compare the channel part
        # exactly: a substring test would also select the columns of any other
        # channel whose name merely contains this channel's name.
        cols_to_plot = [col for col in df_subj_psds.columns if col.split('-', 1)[-1] == channel]

        # plot the PSD (in dB) for every subject for this channel
        for col in cols_to_plot:
            subj_id = int(col.split('-')[0])  # subject id from column name
            ax.plot(freqs, 10 * np.log10(df_subj_psds[col]), label=f"id: {subj_id}", linewidth=0.8, alpha=0.6)

        # add average of the psd of all subjects of this genotype for this channel
        ax.plot(freqs, 10 * np.log10(df_av_psds[channel]), label="Average", linewidth=3, color=av_color, alpha=1.0)

        # add axis labels and title
        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('Power/Frequency (dB/Hz)')
        ax.set_title(f'Average PSD of {genotype} subjects - {channel}')
        ax.legend()
        fig.savefig(os.path.join(out_dir, f"averaged_per_subj_{genotype}_{channel}.png"), bbox_inches="tight", dpi=300)
        plt.close(fig)  # close explicitly; many figures are created in this loop

Now let's create one figure per channel that includes all subjects (colored by genotype), together with the averages of WT and KO.

In [65]:
# Take the x-axis from the averaged data instead of relying on whatever `freqs`
# an earlier cell happened to leave behind in the kernel.
freqs = df_av_psds_WT['freqs']

# per genotype: legend label, individual subject traces, averaged traces and colour
genotype_inputs = [
    ("WT", df_subj_psds_WT, df_av_psds_WT, "#0097e6"),
    ("KO", df_subj_psds_KO, df_av_psds_KO, "#c23616"),
]

# one figure per channel, holding the traces of both genotypes
for channel in wanted_chans:
    fig, ax = plt.subplots(figsize=(12, 8))  # initiate plot

    for genotype_short, df_subj_psds, df_av_psds, genotype_color in genotype_inputs:
        # Columns are named "<animal_id>-<channel>". Compare the channel part
        # exactly: a substring test would also select the columns of any other
        # channel whose name merely contains this channel's name.
        cols_to_plot = [col for col in df_subj_psds.columns if col.split('-', 1)[-1] == channel]

        # plot the PSD (in dB) for every subject for this channel, coloured by genotype
        for col in cols_to_plot:
            ax.plot(freqs, 10 * np.log10(df_subj_psds[col]), linewidth=0.8, alpha=0.4, color=genotype_color)

        # add average of the psd of all subjects of this genotype for this channel
        ax.plot(freqs, 10 * np.log10(df_av_psds[channel]), label=f"{genotype_short} Average", linewidth=3, color=genotype_color, alpha=1.0)

    # add axis labels and title
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power/Frequency (dB/Hz)')
    ax.set_title(f'Average PSD - {channel}')
    ax.legend()
    fig.savefig(os.path.join(plot_folder, f"averaged_per_subj_combined_{channel}.png"), bbox_inches="tight", dpi=300)
    plt.close(fig)  # close explicitly; many figures are created in this loop

### Contribution of individual epochs to subject average

This plotting chunk can take a while since it generates (n-subjects x n-channels) plots.

In [39]:
# savefig does not create missing folders, so make sure the target exists
out_dir = os.path.join(plot_folder, "averaged_per_subject", "individual_epochs")
os.makedirs(out_dir, exist_ok=True)

# we want to generate a plot for every subject and channel (wt / ko)
for subject in epoch_objects:
    # .iloc[0]: first metadata row by position; a plain [0] is a label lookup
    # and fails when the metadata index has no label 0
    subj_id = subject.metadata["animal_id"].iloc[0]
    genotype_short = 'KO' if subject.metadata["genotype"].iloc[0] == 'DRD2-KO' else 'WT'
    av_color = "#c23616" if genotype_short == "KO" else "#0097e6"

    for channel in wanted_chans:
        fig, ax = plt.subplots(figsize=(12, 8))  # initiate plot

        psds, freqs = psd_array_multitaper(subject.get_data(picks=channel), fmin=0, fmax=100, sfreq=subject.info['sfreq'], n_jobs=-1)
        epoch_psds = psds[:, 0, :]  # a single channel was picked: (n_epochs, n_freqs)
        subj_mean_psd = np.mean(epoch_psds, axis=0)  # take the average of the psds of all epochs

        # Draw all epochs with one call: given a 2-D y, matplotlib draws one line
        # per column, hence the transpose to (n_freqs, n_epochs).
        # NOTE(review): the recorded output shows a warning raised on the log10 of
        # the epoch PSDs, presumably from zero power values - confirm.
        ax.plot(freqs, 10 * np.log10(epoch_psds).T, linewidth=0.1, alpha=0.1, color="grey")

        ax.plot(freqs, 10 * np.log10(subj_mean_psd), label="Average", linewidth=2, color=av_color, alpha=1.0)

        # the epoch traces carry no label; represent them by a single grey legend entry
        handles, labels = ax.get_legend_handles_labels()
        patch = mpatches.Patch(color='grey', linestyle='--', label='Individual Epochs')
        handles.append(patch)

        ax.set_xlabel('Frequency (Hz)')
        ax.set_ylabel('Power/Frequency (dB/Hz)')
        ax.set_title(f'Epoch (n={psds.shape[0]}) and Average PSD for subject {subj_id} ({channel})')
        ax.legend(handles=handles)
        fig.savefig(os.path.join(out_dir, f"subject_{subj_id}_{genotype_short}_{channel}.png"), bbox_inches="tight", dpi=300)
        plt.close(fig)  # close explicitly; many figures are created in this loop

    Using multitaper spectrum estimation with 7 DPSS windows


  ax.plot(freqs, 10 * np.log10(psds[i, 0, :]), linewidth=0.1, alpha=0.1, color="grey")


    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spe

### Average PSD of all epochs of a genotype (WT/KO)

Now, instead of averaging per subject first, we pool all epochs of a genotype and, for each channel, take the average PSD over all WT epochs and over all KO epochs.

In [3]:
# The combined epochs of each genotype live in their own file: one for WT, one for KO.
wt_path = os.path.join(epoch_folder, "filtered_epochs_WT-epo.fif")
ko_path = os.path.join(epoch_folder, "filtered_epochs_KO-epo.fif")

# read both files with the data loaded into memory
wt_epochs = mne.read_epochs(wt_path, preload=True)
ko_epochs = mne.read_epochs(ko_path, preload=True)

Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\filtered_epochs_WT-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =       0.00 ...    4998.53 ms
        0 CTF compensation matrices available
Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\filtered_epochs_WT-epo-1.fif ...
Isotrak not found
    Found the data of interest:
        t =       0.00 ...    4998.53 ms
        0 CTF compensation matrices available
Adding metadata with 3 columns
14059 matching events found
No baseline correction applied
0 projection items activated
Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\filtered_epochs_KO-epo.fif ...
Isotrak not found
    Found the data of interest:
        t =       0.00 ...    4998.53 ms
        0 CTF compensation matrices available
Reading C:\Users\Olle de Jong\Documents\MSc Biology\rp2\rp2_data\resting_state\output\epochs\fil

In [4]:
# report how many epochs each combined object holds
epoch_count_message = f"Number of DRD2-WT and DRD2-KO epochs, respectively: {len(wt_epochs)}, {len(ko_epochs)}"
print(epoch_count_message)

Number of DRD2-WT and DRD2-KO epochs, respectively: 14059, 14184


In [6]:
# channel names of the combined WT epochs, without the two EMG channels
wanted_chans = [
    ch_name
    for ch_name in wt_epochs.info["ch_names"]
    if ch_name not in ('EMG_L', 'EMG_R')
]

In [13]:
# savefig does not create missing folders, so make sure the target exists
out_dir = os.path.join(plot_folder, "averaged_all_epochs")
os.makedirs(out_dir, exist_ok=True)

# loop through desired channels; one WT-vs-KO figure is written per channel
for chan in wanted_chans:
    # calculate the psd of all epochs in both WT and KO for this channel
    psds_wt, freqs = psd_array_multitaper(wt_epochs.get_data(picks=chan), fmin=0, fmax=100, sfreq=wt_epochs.info['sfreq'])
    psds_ko, freqs = psd_array_multitaper(ko_epochs.get_data(picks=chan), fmin=0, fmax=100, sfreq=ko_epochs.info['sfreq'])

    # take the average of all epochs in this channel
    # (a single channel was picked, hence index 0 on the second axis)
    mean_psd_wt = np.mean(psds_wt[:, 0, :], axis=0)
    mean_psd_ko = np.mean(psds_ko[:, 0, :], axis=0)

    # Disabled: 95% confidence band around both averages.
    # NOTE(review): mean - conf_int can be <= 0, for which log10 is undefined -
    # confirm how the lower bound should be handled before re-enabling.
    # conf_int_wt = 1.96 * np.std(psds_wt[:, 0, :], axis=0) / np.sqrt(psds_wt.shape[0])  # 95% confidence interval
    # conf_int_ko = 1.96 * np.std(psds_ko[:, 0, :], axis=0) / np.sqrt(psds_ko.shape[0])  # 95% confidence interval

    # plot both averages in dB
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(freqs, 10 * np.log10(mean_psd_wt), label='DRD2-WT', color="#0097e6")
    ax.plot(freqs, 10 * np.log10(mean_psd_ko), label='DRD2-KO', color="#c23616")
    # ax.fill_between(freqs, 10 * np.log10(mean_psd_wt - conf_int_wt), 10 * np.log10(mean_psd_wt + conf_int_wt), alpha=0.2)
    # ax.fill_between(freqs, 10 * np.log10(mean_psd_ko - conf_int_ko), 10 * np.log10(mean_psd_ko + conf_int_ko), alpha=0.2)
    ax.set_xlabel('Frequency (Hz)')
    ax.set_ylabel('Power/Frequency (dB/Hz)')
    ax.set_title(f'Average PSD - Channel {chan}')
    ax.legend()
    fig.savefig(os.path.join(out_dir, f"psd_wt_vs_ko_{chan}.png"), dpi=300, bbox_inches='tight')
    plt.close(fig)  # close explicitly; one figure is created per channel

    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
    Using multitaper spectrum estimation with 7 DPSS windows
