In [1]:
import numpy as np
from numpy import matlib as ml
import pandas as pd
import os
import warnings
import librosa
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns
from random import shuffle
import scipy.stats as st

from pynwb import NWBHDF5IO

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.gridspec as grid_spec

import pdb
# warnings.filterwarnings('ignore')

In [None]:
# ----- LOAD DATA -----
# Annotation table read from the working directory; later cells filter it on
# the assay / sex / GT columns and build transitions from pair_tag / behavior.
data = pd.read_csv('all_annotations.csv')

# ----- FIGURES -----
# Output folder (relative to the working directory) for any saved figures.
outp = os.path.join('figures', '2024September')

# ----- SET UP COLORS -----
# Four RGB colors as 0-1 floats. The first pair builds `fpal`, the second pair
# builds `mpal` (the palette used for sex == 'M' plots); `apal` uses all four.
acols = [
    [0.627451, 0.57254905, 0.37254903],
    [0.9607843, 0.7882353, 0.15294118],
    [0.34901962, 0.35686275, 0.49019608],
    [0.24705882, 0.30588236, 0.9607843],
]
fcols, mcols = acols[:2], acols[2:]
fpal, mpal, apal = (sns.color_palette(colors) for colors in (fcols, mcols, acols))

# Genotype ordering to pass as seaborn's hue order.
ho = ['WT', 'Het']

In [None]:
# ----- TRANSITION MATRIX FUNCTION -----
def calculate_transition_matrix(data, normaxis):
    """Count behavior-to-behavior transitions and normalize them to probabilities.

    Within each ``pair_tag`` group, every row is paired with the row that
    follows it, giving one transition ``behavior -> next behavior``. The last
    row of each group has no successor and is not counted. Row order inside a
    group is used as-is (assumed chronological -- not checked here).

    Parameters
    ----------
    data : pd.DataFrame
        Must contain ``pair_tag`` and ``behavior`` columns. It is NOT modified.
    normaxis : {0, 1, 'index', 'rows', 'columns'}
        Axis whose sums are normalized to 1.
        0 / 'index'  : each column sums to 1 -> entry [next, cur] is
                       P(next behavior | current behavior).
        1 / 'columns': each row sums to 1 -> entry [next, cur] is
                       P(current behavior | next behavior).

    Returns
    -------
    pd.DataFrame
        Square matrix whose columns are the current behavior and whose rows
        (index) are the next behavior, both in sorted order. A column/row with
        no transitions along ``normaxis`` becomes NaN (0 / 0).

    Raises
    ------
    ValueError
        If ``normaxis`` is not a recognised DataFrame axis.
    """
    # DataFrame.div aligns a Series on the axis named by its `axis` argument.
    # sum(axis=0) is indexed by column labels, so it must divide column-wise
    # (axis='columns'); sum(axis=1) is indexed by row labels (axis='index').
    # Because index and columns hold the same labels, using the wrong one
    # silently produces an un-normalized matrix instead of raising.
    broadcast_axis = {
        0: 'columns', 'index': 'columns', 'rows': 'columns',
        1: 'index', 'columns': 'index',
    }
    if normaxis not in broadcast_axis:
        raise ValueError(
            f'normaxis must be 0, 1, "index", "rows" or "columns", got {normaxis!r}'
        )

    # Successor behavior within the same pair; NaN for each pair's final event.
    # Computed as a standalone Series so the caller's frame is never mutated.
    next_behav = data.groupby('pair_tag')['behavior'].shift(periods=-1)
    pairs = pd.DataFrame({
        'behavior': data['behavior'].to_numpy(),
        'next_behav': next_behav.to_numpy(),
    })
    # groupby drops NaN keys by default, so events without a successor vanish.
    counts = pairs.groupby(['behavior', 'next_behav']).size().to_dict()

    behavs = sorted(data['behavior'].unique())
    # Column = current behavior, row = next behavior.
    matrix = pd.DataFrame(
        {cur: [counts.get((cur, nxt), 0) for nxt in behavs] for cur in behavs},
        index=behavs,
    )

    totals = matrix.sum(axis=normaxis)
    return matrix.div(totals, axis=broadcast_axis[normaxis])

def filter_data_and_calculate_transitions(data, assay, sex, gt, normaxis):
    """Compute the transition matrix for one assay / sex / genotype subset.

    Parameters
    ----------
    data : pd.DataFrame
        Annotation table with at least ``assay``, ``sex`` and ``GT`` columns,
        plus whatever columns ``calculate_transition_matrix`` reads.
    assay, sex, gt :
        Values matched exactly against the ``assay``, ``sex`` and ``GT`` columns.
    normaxis :
        Forwarded unchanged to ``calculate_transition_matrix``.

    Returns
    -------
    pd.DataFrame
        The matrix ``calculate_transition_matrix`` returns for the subset.
    """
    mask = (data['assay'] == assay) & (data['sex'] == sex) & (data['GT'] == gt)
    # Explicit copy: the subset is handed to a helper that may write columns
    # into it. Passing a boolean-indexed slice would raise pandas'
    # SettingWithCopyWarning and blur whether `data` itself is affected.
    subset = data.loc[mask].copy()
    return calculate_transition_matrix(subset, normaxis)

In [None]:
# ----- MAKE TRANSITION MATRICES FOR MALE INTROS -----
# For each sex / assay combination, plot the element-wise difference between
# the WT and Het transition-probability matrices as an annotated heatmap.

assays = ['introduction']
sexes = ['M']
genotypes = ['WT', 'Het']
normaxis = 0
saveplots = False

cc = 0  # NOTE(review): incremented below but never read in this cell

for sex in sexes:
    # Male plots use the male palette; any other code falls back to the
    # female palette (assumes sexes are coded 'M' / 'F' -- TODO confirm).
    upal = mpal if sex == 'M' else fpal
    for assay in assays:
        wt = filter_data_and_calculate_transitions(data, assay, sex, 'WT', normaxis)
        het = filter_data_and_calculate_transitions(data, assay, sex, 'Het', normaxis)
        # Subtraction aligns on behavior labels: a behavior absent from one
        # genotype's data yields NaN (blank) cells rather than a fake 0.
        diff = wt - het

        fig, ax = plt.subplots()
        sns.heatmap(diff, cmap=upal, annot=True, linewidth=.5, fmt=".2f", ax=ax)
        ttl = f'WT minus het transition matrix for {sex}s in {assay}, norm axis = {normaxis}'
        ax.set(title=ttl)
        fname = f'transitionProbs_{assay}_{sex}_WT-het_normaxis{normaxis}.png'

        if saveplots:
            # Save before plt.show(): with some backends the figure is
            # closed/emptied once shown. Create the folder on first use.
            os.makedirs(outp, exist_ok=True)
            fig.savefig(os.path.join(outp, fname))
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the loop

    cc += 2