In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib qt 

from typing import List, Tuple
from itertools import zip_longest
import sys
import numpy as np
import matplotlib.pyplot as plt
import csv
import seaborn as sns
from scipy.stats import skew, zscore, linregress,mannwhitneyu
import statsmodels.api as sm
import statsmodels.formula.api as smf

from consts import HERE

sys.path.append(str(HERE.parent))
import pandas as pd

from ctko.utils import moving_average
sns.set_theme(context="talk", style="ticks")

In [2]:
def normalize(data):
    return (data - np.min(data)) / (np.max(data) - np.min(data))

In [None]:
type = 'all-spike-rate.npy'
# type = 'moving-rates.npy'
# type = 'resting-rates.npy'

suite2p_type = "cellpose"

correlation_paths = list(
    (HERE.parent / "cache").glob("*-correlation_coeff.npy")
)

n_transients_paths = list((HERE.parent / "cache").glob("*-n-transients.npy"))
if suite2p_type != "activity_suite2p_rates":
    rate_paths = list((HERE.parent / "cache").glob(f"*-{type}"))
else:
    rate_paths = list((HERE.parent /  "cache" / suite2p_type).glob(f"*-{type}"))


wt_color = sns.color_palette("muted")[0]
ctko_color = sns.color_palette("muted")[1]

data = {}
palette = {}

correlations_wt = []
correlations_ctko = []

n_transients_wt = []
n_transients_ctko = []

for path in rate_paths:
    mouse = path.name.split("-")[0]
    assert mouse[:3] == "J02"
    correlation_path = [path for path in correlation_paths if mouse in str(path)]
    n_transients_path = [path for path in n_transients_paths if mouse in str(path)]
    n_frames = 50912 if mouse == "J025" else 54000 
    assert len(correlation_path) == 1

    # Even numbers are WT
    genotype = "WT" if int(mouse[3]) % 2 == 0 else "CTKO"
    print(genotype)
    rate_vector = np.load(path)
    print(n_transients_path[0])
    data[mouse] = {"genotype": genotype, "data": rate_vector, "n_transients": np.load(n_transients_path[0]) / (n_frames / 30)}
    palette[mouse] = wt_color if genotype == "WT" else ctko_color

    if genotype == "WT":
        correlations_wt.append(np.load(correlation_path[0]).item())
        n_transients_wt.extend(np.load(n_transients_path[0]) / (n_frames / 30))

    elif genotype == "CTKO":
        correlations_ctko.append(np.load(correlation_path[0]).item())
        n_transients_ctko.extend(np.load(n_transients_path[0]) / (n_frames / 30))

df_all_rates = pd.concat(
    [
        pd.DataFrame({"Subject": subject, "Genotype": info["genotype"], "Firing rate": info["data"], "Number of transients per second": info['n_transients']})
        for subject, info in data.items()
    ],
).reset_index(drop=True)

all_rates_wt = df_all_rates[df_all_rates['Genotype'] == "WT"]["Firing rate"]
all_rates_ctko = df_all_rates[df_all_rates['Genotype'] == "CTKO"]["Firing rate"]

CTKO
/Users/jamesrowland/Code/ctko/cache/J023-2024-09-27-n-transients.npy


WT
/Users/jamesrowland/Code/ctko/cache/J028-2024-10-24-n-transients.npy


CTKO
/Users/jamesrowland/Code/ctko/cache/J027-2024-10-09-n-transients.npy


WT
/Users/jamesrowland/Code/ctko/cache/J022-2024-09-27-n-transients.npy


WT
/Users/jamesrowland/Code/ctko/cache/J026-2024-10-24-n-transients.npy


CTKO
/Users/jamesrowland/Code/ctko/cache/J025-2024-09-27-n-transients.npy


CTKO
/Users/jamesrowland/Code/ctko/cache/J029-2024-10-25-n-transients.npy


WT
/Users/jamesrowland/Code/ctko/cache/J024-2024-10-09-n-transients.npy




In [16]:
df_all_rates.to_csv(HERE.parent / "data_for_sam" / "Number_of_transients_and_rates_by_subject.csv")

In [8]:
sns.kdeplot(df_all_rates[df_all_rates['Genotype'] == "WT"]['Number of transients per second'],fill=True, label='WT')
sns.kdeplot(df_all_rates[df_all_rates['Genotype'] == "CTKO"]['Number of transients per second'],fill=True, label='CTKO')
plt.legend()

plt.xlim(0, None)
sns.despine()
plt.tight_layout()
plt.xlabel("Number of calcium transients / second")


Text(0.5, 9.883333333333328, 'Number of calcium transients / second')

In [6]:
df_all_rates
df_all_rates.to_csv(HERE.parent / "data_for_sam"/ 'subject_firing_rates.csv', sep='\t')


In [7]:
with open(HERE.parent / "data_for_sam"/ 'firing_rates.csv', mode="w", newline="") as file:
    writer = csv.writer(file)
    # Write headers
    writer.writerow(["Controls", "cTKO"])
    # Write rows
    for control, ctko_value in zip_longest(all_rates_wt, all_rates_ctko, fillvalue=""):
        writer.writerow([control, ctko_value])

In [8]:
sns.boxplot({"WT": correlations_wt, "CTKO": correlations_ctko}, showfliers= False)
sns.stripplot({"WT": correlations_wt, "CTKO": correlations_ctko}, linewidth=1)
plt.ylabel("Average Correlation Coffecient")


Text(0, 0.5, 'Average Correlation Coffecient')

In [17]:
sns.kdeplot(n_transients_wt,fill=True, label='WT')
sns.kdeplot(n_transients_ctko,fill=True, label="CTKO")
plt.legend()
sns.despine()
plt.tight_layout()
plt.xlabel("Number of calcium transients / second")


plt.xlim(0, None)

with open(HERE.parent / "data_for_sam"/ 'number_of_transients.csv', mode="w", newline="") as file:
    writer = csv.writer(file)
    # Write headers
    writer.writerow(["Controls", "cTKO"])
    # Write rows
    for control, ctko_value in zip_longest(n_transients_wt, n_transients_ctko, fillvalue=""):
        writer.writerow([control, ctko_value])

data_loaded = pd.read_csv(HERE.parent / "data_for_sam"/ 'Number_of_transients_and_rates_by_subject.csv')


In [10]:
print(f"Mean Control: {np.mean(n_transients_wt)}")
print(f"Median Control: {np.median(n_transients_wt)}")
print(f"Mean ctko: {np.mean(n_transients_ctko)}")
print(f"Median ctko: {np.median(n_transients_ctko)}")

Mean Control: 0.21746565113500593
Median Control: 0.2088888888888889
Mean ctko: 0.09790577131929311
Median ctko: 0.05527777777777777


In [19]:
data_loaded

Unnamed: 0.1,Unnamed: 0,Subject,Genotype,Firing rate,Number of transients per second
0,0,J023,CTKO,0.989750,0.197222
1,1,J023,CTKO,1.376770,0.152222
2,2,J023,CTKO,0.123205,0.013889
3,3,J023,CTKO,1.823752,0.231111
4,4,J023,CTKO,1.043952,0.183333
...,...,...,...,...,...
1407,1407,J024,WT,0.513040,0.368889
1408,1408,J024,WT,0.275604,0.069444
1409,1409,J024,WT,0.392947,0.436111
1410,1410,J024,WT,0.324955,0.006111


In [22]:
sns.kdeplot(data_loaded[data_loaded['Genotype'] == 'WT']['Firing rate'],fill=True, label='WT')
sns.kdeplot(data_loaded[data_loaded['Genotype'] == 'CTKO']['Firing rate'],fill=True, label='CTKO')
plt.legend()
sns.despine()
plt.tight_layout()
plt.xlabel("Number of calcium transients / second")
plt.xlim(0, None)


(0.0, 12.152914204550504)

In [12]:
plt.figure()
genotype_colors = {'WT': 'green', 'CTKO': 'orange'}
ax = sns.boxplot(x='Genotype', y='Firing rate', hue='Subject', data=df_all_rates, dodge=True, showfliers=False, legend=True)
sns.stripplot(x='Genotype', y='Firing rate', hue='Subject', data=df_all_rates, dodge=True, legend=False, alpha=0.5)
sns.despine()
plt.tight_layout()
plt.ylabel(f"Firing rate {type.strip(".npy").replace("-" , " ")}")
plt.savefig(HERE.parent / "figures"/ f"rates {suite2p_type} {type.strip(".npy")}")

In [13]:
n = 100
plt.figure()

sns.kdeplot(all_rates_wt,fill=True, label='WT')
sns.kdeplot(all_rates_ctko,fill=True, label="CTKO")
plt.legend()
# plt.xlim(0,4)
sns.despine()
plt.tight_layout()


plt.xlim(0, None)
# plt.hist(df_all_rates[df_all_rates['Genotype'] == "CTKO"]["Firing rate"], n, color='red', alpha=0.5)
mannwhitneyu(all_rates_wt, all_rates_ctko)

MannwhitneyuResult(statistic=np.float64(340038.0), pvalue=np.float64(5.3302554321106566e-33))

In [14]:
print(f"Mean Control: {np.mean(all_rates_wt)}")
print(f"Median Control: {np.median(all_rates_wt)}")

print(f"Mean ctko: {np.mean(all_rates_ctko)}")
print(f"Median ctko: {np.median(all_rates_ctko)}")

Mean Control: 1.15641679371668
Median Control: 0.6667706258526307
Mean ctko: 0.6399564340532031
Median ctko: 0.3111814249755955


In [15]:
movement_paths = list((HERE.parent / "cache" ).glob("*-diffed.npy"))

movement_data = {}

for path in movement_paths:

    mouse = path.name.split("-")[0]
    assert mouse[:3] == "J02"

    movement_data[mouse] = np.load(path)

In [16]:

plt.close('all')
# plt.figure()
wt_legend_done = False
ctko_lend_done = False
x = []
y = []


for mouse in movement_data.keys():
    genotype = "WT" if int(mouse[3]) % 2 == 0 else "CTKO"
    rates = data[mouse]["data"] 
    # rates = sum(rates < 100) / len(rates)
    movement  = movement_data[mouse]
    movement = moving_average(movement, 20)
    movement = normalize(movement)

    percent_moving = np.sum(movement > 0.1) / len(movement)
    print(mouse)
    print(percent_moving)
    print('\n')
    # percent_moving = np.sum(movement)

    if genotype == "WT" and not wt_legend_done:
        label = "WT"
    elif genotype == "CTKO" and not ctko_lend_done:
        label = "CTKO"
    else:
        label = None

    label = mouse
    x.append(percent_moving)
    y.append(np.mean(rates))


    # plt.plot(percent_moving, np.mean(rates), ".", color="blue" if genotype == "WT" else "orange", label=label)
    # plt.plot(percent_moving, np.mean(rates), ".", label=label)

    if genotype == "WT":
        wt_legend_done = True
    if genotype == "CTKO":
        ctko_lend_done = True


slope, intercept, r_value, p_value, std_err = linregress(x,y)

x = np.linspace(0, 0.5)

y = x * slope + intercept


# plt.plot(x, y)

# plt.savefig(f"{type.strip(".npy")}")
# plt.xlabel("Percent moving")
# plt.ylabel("Mean firing rate")
# plt.legend()

J025
0.24641040246705034


J026
0.038371080945943443


J028
0.061056686234930276


J027
0.04418600344450823


J029
0.29837589584992313


J023
0.271856886238634


J024
0.012518750347228652


J022
0.016426230115372507




In [17]:
rows = []

for mouse in movement_data.keys():
    genotype = "WT" if int(mouse[3]) % 2 == 0 else "CTKO"

    rates = data[mouse]["data"] 
    n_transients = data[mouse]["n_transients"] 
    # rates = sum(rates < 100) / len(rates)

    movement  = movement_data[mouse]
    movement = moving_average(movement, 100)
    movement = normalize(movement)
    percent_moving = sum(movement > 0.1) / len(movement)



    # plt.figure()
    # plt.plot(movement)
    # plt.title(mouse)
    # plt.axhline(0.1)

    # rows.append({
    #         "subject": mouse,
    #         "group": genotype,
    #         "value": np.median(rates),
    #         "percent_moving": percent_moving,
    #     })

    rows.extend(
        {
            "subject": mouse,
            "group": genotype,
            "value": rate,
            "percent_moving": percent_moving,
        }
        for rate in rates
    )

    # rows.extend(
    #     {
    #         "subject": mouse,
    #         "group": genotype,
    #         "value": n,
    #         "percent_moving": percent_moving,
    #     }
    #     for n in n_transients
    # )



In [18]:
df = pd.DataFrame(rows)
# df["movement_centered"] = df["percent_moving"] / max(df['percent_moving'])
df['movement_centered'] = (df['percent_moving'] - df['percent_moving'].mean()) / df['percent_moving'].std()

df['movement_centered_group'] = df.groupby('group')['percent_moving'].transform(lambda x:  (x - x.mean()) / x.std())

In [19]:
plt.figure()
df_movement_plot = df.drop_duplicates('subject')
wt_movement = df_movement_plot[df_movement_plot['group'] == "WT"]['percent_moving']
ctko_movement = df_movement_plot[df_movement_plot['group'] == "CTKO"]['percent_moving']

movement_plot_data = {"WT":  wt_movement, 
                      "CTKO" : ctko_movement}

sns.boxplot(movement_plot_data, showfliers=False)
sns.stripplot(movement_plot_data, linewidth=1)
plt.ylabel("Percent time moving")
sns.despine()
plt.tight_layout()



with open(HERE.parent / "data_for_sam"/ 'percent_moving.csv', mode="w", newline="") as file:
    writer = csv.writer(file)
    # Write headers
    writer.writerow(["Controls", "cTKO"])
    # Write rows
    for control, ctko_value in zip_longest(wt_movement, ctko_movement, fillvalue=""):
        writer.writerow([control, ctko_value])

In [20]:
mannwhitneyu(movement_plot_data['WT'], movement_plot_data['CTKO'])

MannwhitneyuResult(statistic=np.float64(2.0), pvalue=np.float64(0.11428571428571428))

In [22]:
# Mixed-effects model: group as fixed effect, movement as covariate, and subject as random effect
model = smf.mixedlm(
    # formula="value ~ group",  # Fixed effects
    formula="value ~ group",  # Fixed effects
    # formula="value ~ group",  # Fixed effects
    data=df,
    groups=df["subject"], # Random effects by subject
    use_sqrt=True

)
result = model.fit(reml=True)

In [23]:
df.to_csv(HERE.parent / "data_for_sam"/ 'data_for_model.csv')

In [24]:
print(result.summary())

         Mixed Linear Model Regression Results
Model:            MixedLM Dependent Variable: value     
No. Observations: 1412    Method:             REML      
No. Groups:       8       Scale:              1.2354    
Min. group size:  75      Log-Likelihood:     -2159.4416
Max. group size:  295     Converged:          Yes       
Mean group size:  176.5                                 
--------------------------------------------------------
                Coef. Std.Err.   z   P>|z| [0.025 0.975]
--------------------------------------------------------
Intercept       0.663    0.075 8.785 0.000  0.515  0.811
group[T.WT]     0.489    0.105 4.651 0.000  0.283  0.695
Group Var       0.014    0.012                          



In [68]:
beta_group = result.params["group[T.WT]"]  # The effect of being in the WT group (relative to the baseline)
beta_movement = result.params["movement_centered_group"]  # The effect of movement_centered on value
movement_effect = df["movement_centered"] * beta_movement

df["corrected_value"] = df['value'] - movement_effect
df['fitted_value'] = result.fittedvalues - movement_effect

KeyError: 'movement_centered_group'

In [62]:
plt.clf()
sns.boxplot(x='group', y='corrected_value', hue='subject', data=df, dodge=True, legend=True)
sns.stripplot(x='group', y='corrected_value', hue='subject', data=df, dodge=True, legend=False)

plt.show()

In [None]:
fixed_effects = model.fit().params
movement_effect = fixed_effects["movement_centered_group"] * df["movement_centered_group"]
