In [None]:
from pathlib import Path

import numpy as np
import pandas as pd

# Change these to your own data directories.
# basedir: root folder from which the data table and the pickled tables are loaded below.
# imagedir: root folder under which the per-mouse face image frames are looked up.
basedir = Path('/Volumes/data-1/behavior/hakan')
imagedir = Path('/mnt/ssd_cache/manual_cache')


### Load data

In [None]:
# Read the table stored under the 'data' key of the all-mice HDF5 file.
data_path = basedir.joinpath('Kucukdereli_atal_data/data_table_all_mice.h5')
data_df = pd.read_hdf(data_path, 'data')


In [None]:
# Additional imports and setting up the plots
import pickle
from tqdm import tqdm

import matplotlib.pyplot as plt

# Global matplotlib defaults: 8 pt regular sans font and 0.8 pt axes outlines.
font = dict(family='sans', weight='regular', size=8)
plt.rc('font', **font)
plt.rc('axes', linewidth=0.8)

# Named hex colors used by the plots (neutral grays plus a base/light pair per sex).
colors = {
    'gray': '#D4D4D4',
    'darkgray': '#545454',
    'male': '#FF5E74',
    'male_light': '#FFABB6',
    'female': '#00B7FF',
    'female_light': '#9EE3FF',
}


## Figure 4E

### Display mean faces

Mean faces are plotted from a randomly selected, matched number of frames from each condition. For details, see *Classification of facial expressions* under Methods.

In [None]:
# Mouse identifiers, ordered by each mouse's value in the ('test_', 5) column of the data table.
mice_ordered = data_df.sort_values([('test_', 5)])[('mouse','mouse')]

# NOTE(review): pickle.load can execute arbitrary code -- only load tables from a trusted source.
# The context manager closes the file handle as soon as the table has been read.
with open(f'{basedir}/oren/eval_agrp_stress_all_mice_table.pkl', 'rb') as table_file:
    full_df = pickle.load(table_file)

# Keep only the rows with experiment 'test_', day 5 and train==0.
full_df = full_df.query("experiment=='test_' & day==5 & train==0").copy()

# Re-assemble the filtered rows mouse by mouse so they follow the order of mice_ordered.
full_df_ordered = pd.concat([full_df.query("mouse==@mouse") for mouse in mice_ordered])


In [None]:
# Which way each mouse's face points in its images: 'R' (right) or 'L' (left).
# Left-facing faces are meant to be mirrored so that every face points right.
face_ori = {
    'HK125': 'R',
    'HK89': 'R',
    'HK129': 'L',
    'HK90': 'L',
    'HK94': 'R',
    'HK88': 'L',
    'HK127': 'R',
    'HK96': 'L',
    'HK123': 'R',
    'HK120': 'L',
    'HK122': 'L',
    'HK98': 'L',
    'HK99': 'R',
    'HK95': 'R',
    'HK128': 'R',
    'HK124': 'L',
}


In [None]:
# Minimum predicted class probability a frame needs in order to enter a mean face.
class_threshold = 0.95

# Layout: one column per (mouse, date) session, in the order of full_df_ordered;
# row 0 shows the mean 'Stim' face and row 1 the mean 'Neutral' face.
# squeeze=False keeps axs two-dimensional so axs[row, col] is always valid.
sessions = list(full_df_ordered.groupby(['mouse', 'date'], sort=False))
n_cols = max(len(sessions), 1)  # plt.subplots needs at least one column
fig, axs = plt.subplots(2, n_cols, figsize=(n_cols*1.5, 2*1.5), dpi=150, squeeze=False)

# Figure-level decoration is done once, not on every loop iteration.
for panel in axs.ravel():
    panel.axis('off')
axs[0,0].text(x=-0.1, y=0.5, s='Stim', rotation=90, ha='center', va='center', transform=axs[0,0].transAxes)
axs[1,0].text(x=-0.1, y=0.5, s='Neutral', rotation=90, ha='center', va='center', transform=axs[1,0].transAxes)
fig.suptitle(f"p>={class_threshold}")

for j, ((mouse, date), session_df) in enumerate(sessions):
    axs[0, j].set_title(f'{mouse}')

    # The shap pickle provides the list of frame numbers ('frame_n') for this session.
    shap_path = Path(basedir) / str(mouse) / 'shap' / f'{mouse}_{date}_shap.pkl'
    if not shap_path.is_file():
        # Skip instead of silently reusing the previous session's frame list.
        print(f'{mouse} {date}: {shap_path} not found, skipping')
        continue
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    with shap_path.open('rb') as shap_file:
        outs = pickle.load(shap_file)
    frame_ids = np.asarray(outs['frame_n'])
    if frame_ids.size == 0:
        continue

    # Load every listed frame; plt.imread returns each image as a numpy array.
    images = np.stack([
        plt.imread(f'{imagedir}/{mouse}/{date}_{mouse}/DLCmask/{mouse}_{date}_{frame_n}.jpg')
        for frame_n in tqdm(frame_ids, desc=f'{mouse} {date}')
    ])

    # session_df already holds only this session's rows (full_df was filtered to train==0).
    # Select frames classified as stim or neutral with high probability and map their
    # frame numbers to positions in the loaded image stack.
    stim_frames = session_df.query(f"pred==1 & p_1>={class_threshold}")['frame_n'].to_numpy()
    neutral_frames = session_df.query(f"pred==0 & p_0>={class_threshold}")['frame_n'].to_numpy()
    indx_stim = np.flatnonzero(np.isin(frame_ids, stim_frames))
    indx_neutral = np.flatnonzero(np.isin(frame_ids, neutral_frames))
    print(f'{mouse} {class_threshold}-> ', indx_stim.shape, indx_neutral.shape)

    # Left-facing mice (per face_ori) are mirrored so that every face points right.
    flip_face = face_ori.get(mouse) == 'L'
    for row, indx in enumerate((indx_stim, indx_neutral)):
        if len(indx):
            mean_face = images[indx].mean(axis=0)
            if flip_face:
                mean_face = np.fliplr(mean_face)
            axs[row, j].imshow(mean_face, 'gray')
