In [9]:
import os
home = os.path.expanduser("~")

import numpy as np
import matplotlib.pyplot as plt
import glob
import pickle
import gc
import pandas as pd

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)

import datafunctions as dfunc
import headingchange_models as hm
dt=1/60

In [13]:
## data and file location definitions
# Defaults are the original machine's paths; set ZFISH_DATADIR / ZFISH_RESULTDIR to run elsewhere
# instead of editing (or commenting/uncommenting) these lines.
datadir = os.environ.get('ZFISH_DATADIR', '/media/jacob/JD_DATA/zfish/')
resultdir = os.environ.get('ZFISH_RESULTDIR', 'savedresults/')

## Note: 'treatment' refers to the different fish lines. It would be more accurate to call them 'lines' instead of 'treatments' (e.g. focuslines instead of focustreatments).

# This notebook contains computations that only need to be run once; their outputs are saved to file

# (Step 0):  simple quantities

## Run and save some simple treatment quantities (just run once)

In [14]:
# datadir = '/mnt/storage/zfish/'
# resultdir = home+'/Dropbox/zfish/savedresults/'
# Each line is a directory two levels below datadir, named after the line. The list is ordered by
# full path (i.e. by parent folder first), which is the order stored in 'treatmentlist.pkl'.
# basename(normpath(...)) extracts the folder name with any path separator, not just '/'.
treatments = [os.path.basename(os.path.normpath(s)) for s in sorted(glob.glob(datadir+"/*/*"+ os.path.sep))]
treatments.remove('chd8KO_WT')  # remove this because the groups were half mutants and half WT fish
treatments.remove('oxtHO_verified')  # remove because unsure what this mutation was (documentation lost)

with open(resultdir+'treatmentlist.pkl','wb') as f:
    pickle.dump([treatments],f)

numtreatments = len(treatments)
print('numtreatments:',numtreatments)

# Indices (into treatments) of the lines highlighted in the analysis, in this order.
# list.index raises a clear ValueError if one of these lines is missing from datadir.
focuslines = ['WT','immp2lHO','ctnnd2bHO','scn1labHT','disc1KO','chrna2aHO']
focustreatments = np.array([treatments.index(ft) for ft in focuslines], dtype=np.intp)

# All remaining lines, ordered alphabetically by name
sel = np.tile(True,numtreatments)
sel[focustreatments] = False
notfocus = np.arange(numtreatments)[sel]
notfocus = notfocus[np.argsort(np.array(treatments)[notfocus])]

print(focustreatments)
print(notfocus)

with open(resultdir+'focustreatmentlist.pkl','wb') as f:
    pickle.dump([focustreatments,notfocus],f)

numtreatments: 91
[ 0 46 65  4 50  2]
[ 1 48 56 57 58 59 60 49 61 62 63 64 51 66 52 53 54 41 55 67 68 42 69 70
 71 72 73 74 75 43 76 44 77 78 45 79 80  3 81 47 82 83 84 34 85 86 87 88
 89 90 35 13 14 15 16 17 36 18 19  5 20  6 37  7 21  8 26  9 22 38 39 23
 40 24 25 10 27 28 29 30 11 31 32 12 33]


## Get the number of trials for every line (run this once and save the results)

In [12]:
def get_numtrials(datadir,treatment):
    """Return the number of trials stored in '<datadir><treatment>-alltrials.pkl'.

    The pickle is a 12-element list whose 11th entry is the list of per-trial data files. Loading
    the whole file just to count trials is slow, so run this once and cache the result.

    Raises ValueError if the pickle does not unpack into exactly 12 entries.
    """
    # 'with' guarantees the (large) file handle is closed even if unpickling fails
    with open(datadir+treatment+'-alltrials.pkl','rb') as f:
        [_,_,_,_,
         _,_,_,
         _,_,_,
         datafiles,_] = pickle.load(f)
    return len(datafiles)

# Number of trials per line, cached so later cells don't have to reload every '-alltrials.pkl'
allnumtrials = np.array([get_numtrials(datadir,t) for t in treatments])
with open(resultdir+'allnumtrials.pkl','wb') as f:
    pickle.dump([allnumtrials],f)

# (Step 0) Read in the saved treatment quantities (if the cells above have already been run)

In [15]:
# Reload the quantities saved in step 0, so later steps don't need to rescan datadir
with open(resultdir+'treatmentlist.pkl','rb') as f:
    [treatments] = pickle.load(f)
numtreatments = len(treatments)
with open(resultdir+'focustreatmentlist.pkl','rb') as f:
    [focustreatments,notfocus] = pickle.load(f)
with open(resultdir+'allnumtrials.pkl','rb') as f:
    [allnumtrials] = pickle.load(f)
pxpercm = 4.02361434 * 10  # pixels per cm, from tracker
np.array(treatments)


array(['WT', 'adra1aaHO', 'chrna2aHO', 'kctd13KO', 'scn1labHT',
       'shank3bKO', 'slc18a2HT', 'slc22a15HO', 'slc25a27HO', 'slc39a11HO',
       'srrHO', 'tph2HO', 'ube3aKO', 'pard3baHO', 'pcdh10bHO', 'pdfHO',
       'pomcbHO', 'prkg1aHO', 'setd8aHO', 'shank3aHO', 'slc16a3HO',
       'slc25a14HO', 'slc4a10bHO', 'slc6a7HO', 'slc9a6aHO', 'slc9a6bHO',
       'slc30a5HO', 'sst1,1HO', 'sst3HO', 'sstr2bHO', 'stat6HO', 'trhHO',
       'trhraHO', 'uts2aHO', 'nfkb1HO', 'oxtKO', 'sapap2KO', 'slc1a1HO',
       'slc6a3KO', 'slc6a4aHO', 'slc6a8HO', 'drd4-rsHO', 'esr2aHO',
       'gnrhr4HO', 'grm5aHO', 'homer1bHO', 'immp2lHO', 'lrrn3KO',
       'adra1abHO', 'chrm4aHO', 'disc1KO', 'dlg4aHO', 'drd1bHO',
       'drd2bKO', 'drd3HO', 'drd4aHO', 'adrb3aHO', 'avpHO', 'avpr1abHO',
       'ca9HO', 'cdnfHO', 'cnn1bHO', 'cpdbHO', 'csmd1aHO', 'ctnnd2aHO',
       'ctnnd2bHO', 'drd1aHO', 'drd6bHO', 'ercc6HO', 'esrraHO',
       'esrrgaHO', 'fgf12aHO', 'fgf12bHO', 'gabrpHO', 'galnHO', 'gnrh3HO',
       'gpc6aHO', 

# (step 1) Calculate distance and rotated neighbor coordinates and save for each line

In [None]:
# For every line and trial: offsets of all fish from each focal fish, rotated into the focal fish's
# heading frame, plus pairwise distances. Saved per line to '<line>-dcoords+dist-heading.pkl' as
# [trial_dcoords, trial_dist], with per-trial arrays of shape
# (numsteps, numfish, numfish, 2) and (numsteps, numfish, numfish), indexed [step, focal, neighbor].
for treatment in treatments:
    print(treatment)

    with open(datadir+treatment+'-alltrials.pkl','rb') as f:
        [trial_speeds,trial_trajectories,trial_headings,trial_theta,
                        trial_smoothspeeds,trial_smoothtrajectories,trial_smoothheadings,
                        trial_ellipses,trial_arena,trial_sex,
                        datafiles,trial_trackingerrors] = pickle.load(f)

    numtrials = len(datafiles)
    numfish = trial_speeds[0].shape[1]

    print(treatment,',', numtrials, 'trials')

    ddfilename = datadir+treatment+'-dcoords+dist-heading.pkl'
    trial_dcoords = []
    trial_dist = []
    for trial in range(numtrials):
        print(trial,numtrials)
        # NOTE: positions are the raw 'trajectories' (NOT smoothed); headings are the smoothed ones
        trajectories = trial_trajectories[trial]
        theta = trial_smoothheadings[trial]
        numsteps = trajectories.shape[0]
        alldcoords_rotated = np.zeros((numsteps,numfish,numfish,2))
        alldist = np.zeros((numsteps,numfish,numfish))

        for i in range(numfish):
            # offsets of every fish from focal fish i, for all frames at once
            # (replaces a per-frame Python loop with a 2x2 np.dot; same rotation)
            dx = trajectories[:,:,0] - trajectories[:,i:i+1,0]
            dy = trajectories[:,:,1] - trajectories[:,i:i+1,1]
            # rotate by -theta, so the focal heading maps onto the +x axis
            c = np.cos(-theta[:,i])[:,np.newaxis]
            s = np.sin(-theta[:,i])[:,np.newaxis]
            xrot = c*dx - s*dy
            yrot = s*dx + c*dy
            alldcoords_rotated[:,i,:,0] = xrot
            alldcoords_rotated[:,i,:,1] = yrot
            alldist[:,i] = np.sqrt(xrot**2 + yrot**2)

        trial_dcoords.append(alldcoords_rotated)
        trial_dist.append(alldist)

    # output coordinate transformed neighbors to file
    with open(ddfilename,'wb') as f:
        pickle.dump([trial_dcoords,trial_dist],f)
    print('wrote dcoords to:', ddfilename)


# (Step 2) Quantiles, boundary distances, and group quantities

In [None]:
rnumbins = 20  # number of distance bins

alldskip=50 # subsampling step for the samples saved at the end; these are pooled later to get 'combined quantiles'


for tnum in range(numtreatments):
    gc.collect()
    treatment = treatments[tnum]

    ### TREATMENT
    print(tnum, treatment)
    numtrials=allnumtrials[tnum]

    trialsel='all'
    alldist, alldcoords_rotated, smoothtrajectories, smoothspeeds, theta, heading, trajectories, trialids, trackingerrors = dfunc.getcat(treatment,datadir,trialsel)  # do this, because it saves a lot on memory

    ## Calculate arena border, using function
    trial_arena_mid, trial_arena_r = dfunc.getboundaries(trajectories,trialids,numtrials)

    ## Identify 'tracking errors', using the thresholds:
    neighbordist_threshold = 40  # this is about 1 cm
    numfish=6
    neighborRH_threshold = 0.2 # angular threshold in radians
    offdiag = np.logical_not(np.eye(numfish, dtype=bool))  # every (focal, neighbor) pair except self
    neighborsel = lambda x: np.reshape(x[:,offdiag],(-1,numfish,numfish-1))
    relativeheading = np.array([heading - np.tile(heading[:,i],(numfish,1)).T for i in range(numfish)]).swapaxes(0,1)
    # per frame and fish: number of neighbors that are both very close and nearly parallel in heading
    overlap_individuals = np.sum(neighborsel((np.cos(relativeheading)>np.cos(neighborRH_threshold)) & (alldist<neighbordist_threshold)),axis=-1)
    step_overlap = np.sum(overlap_individuals,axis=-1)>0
    step_te = np.sum(trackingerrors,axis=-1)>0
    # now count error time steps as ones where either individuals overlap, or there is a tracking error
    errortimesteps = step_overlap | step_te

    # fractions of frames (removed_*) and of fish-frames (indiv_*) that are excluded because of errors
    numsteps = overlap_individuals.shape[0]
    removed_te = np.sum(step_te)/numsteps
    removed_overlap = np.sum(step_overlap & np.logical_not(step_te))/numsteps
    indiv_te = np.sum(trackingerrors)/(numsteps*numfish)
    # overlap_individuals holds integer counts: convert to bool before '&', otherwise it is a bitwise
    # AND with the boolean mask (e.g. a count of 2 & True == 0) and even counts get dropped
    indiv_overlap = np.sum((overlap_individuals>0) & np.logical_not(trackingerrors))/(numsteps*numfish)

    ## Calculate group heading, rotation and polarization, then the rotated coordinates
    groupcentroid = np.mean(trajectories,axis=1)
    gcdiff=np.gradient(groupcentroid,axis=0)/dt
    groupspeed=np.sqrt(gcdiff[:,0]**2+gcdiff[:,1]**2)

    groupheadingxy=gcdiff/(groupspeed[:, np.newaxis])
    groupheading=np.arctan2(groupheadingxy[:,1],groupheadingxy[:,0])
    positions_relative=np.swapaxes(np.array([trajectories[:,k]-groupcentroid for k in range(numfish)]),0,1)
    # calculate the polarization and rotation coefficients, using the equations from Tunstrom 2013
    grouppolarization=np.sqrt(np.nanmean(np.cos(heading),axis=1)**2 + np.nanmean(np.sin(heading),axis=1)**2)
    xall = positions_relative/np.expand_dims(np.linalg.norm(positions_relative,axis=2),2)  # unit vectors centroid -> fish
    yall = np.stack([np.cos(heading),np.sin(heading)],axis=-1)  # unit heading vectors, (numsteps, numfish, 2)
    # z-component of the 2D cross product, written out because np.cross on 2-vectors is deprecated (NumPy 2)
    grouprotation = np.mean(xall[:,:,0]*yall[:,:,1] - xall[:,:,1]*yall[:,:,0],axis=1)

    groupiid = np.mean(alldist[:,offdiag],axis=1)  # mean inter-individual distance over all pairs

    ## Calculate boundary distances (arena radius minus distance from the arena center; positive inside)
    fn = (lambda x: np.sqrt(x[:,:,0]**2+x[:,:,1]**2))
    boundarydist = np.concatenate([r-fn(trajectories[trialids==n]-mid)
                         for n, r, mid in zip(range(numtrials),trial_arena_r,trial_arena_mid)])

    #### QUANTILE CALCULATION
    # Keep only fish-frames WITHOUT errors (overlap or tracking errors). The filter previously selected
    # the error frames instead, so quantile files written before this fix should be regenerated.
    # Boundary misidentification is handled by the brmin threshold below.
    quantilefilter = np.logical_not(np.tile(errortimesteps,(numfish,1)).T)
    scat = smoothspeeds[quantilefilter]
    speedquantiles10 = np.quantile(scat,np.linspace(0,1,11))
    speedquantiles20 = np.quantile(scat,np.linspace(0,1,21))

    # neighbor distance quantiles
    ncat = np.reshape(alldist[:,offdiag],(-1,numfish,numfish-1))
    # for simplicity, only filter probabilities by the focal fish
    ncat = np.reshape(ncat[quantilefilter],(-1))
    rmin = 0  # for neighbor, use a min
    ndistquantiles20 = np.quantile(ncat[ncat>rmin],np.linspace(0,1,rnumbins+1))

    # boundary distance quantiles
    brmin = 0  # this can be negative if the boundary is identified wrong, so use a threshold
    bcat = np.reshape(boundarydist[quantilefilter],(-1))
    bdistquantiles20 = np.quantile(bcat[bcat>brmin],np.linspace(0,1,rnumbins+1))

    # save these results
    outfile = datadir+treatment+'-quantile+group+boundary.pkl'
    with open(outfile,'wb') as f:
        pickle.dump([speedquantiles10, speedquantiles20, ndistquantiles20, bdistquantiles20,
            groupcentroid, groupspeed, grouppolarization, grouprotation, groupiid, boundarydist,
            trial_arena_mid, trial_arena_r,errortimesteps], f)
    # save things for calculating the combined distributions
    outfile = datadir+treatment+'quantile-short.pkl'
    with open(outfile,'wb') as f:
        pickle.dump([scat[::alldskip],ncat[::alldskip],bcat[::alldskip], removed_te, removed_overlap, indiv_te, indiv_overlap],f)

In [7]:
###  get the combined distributions for speed, ndist, bdist
# Pools the subsampled per-line samples saved in step 2 and computes quantiles over all lines together.
grid_scat = dfunc.initarray([numtreatments])
grid_ncat = dfunc.initarray([numtreatments])
grid_bcat = dfunc.initarray([numtreatments])

# per line: fraction of frames removed for (tracking errors, overlap only)
grid_removed_because_of_errors = np.zeros((numtreatments,2))
# per line: fraction of fish-frames with (tracking errors, overlap)
grid_individual_errors = np.zeros((numtreatments,2))

for tnum in range(numtreatments):
    treatment = treatments[tnum]
    outfile = datadir+treatment+'quantile-short.pkl'
    if os.path.isfile(outfile):
        with open(outfile,'rb') as f:
            scat, ncat, bcat, removed_te, removed_overlap, indiv_te, indiv_overlap = pickle.load(f)
        grid_scat[tnum] = scat
        grid_ncat[tnum] = ncat
        grid_bcat[tnum] = bcat
        grid_removed_because_of_errors[tnum] = (removed_te,removed_overlap)
        grid_individual_errors[tnum] = (indiv_te,indiv_overlap)
    else:
        print(tnum,':  no file')

rnumbins = 20  # number of distance bins
allscat = np.concatenate(grid_scat)
all_speedquantiles10 = np.quantile(allscat,np.linspace(0,1,11))
all_speedquantiles20 = np.quantile(allscat,np.linspace(0,1,21))
allncat = np.concatenate(grid_ncat)
all_ndistquantiles20 = np.quantile(allncat[allncat>0],np.linspace(0,1,rnumbins+1))
allbcat = np.concatenate(grid_bcat)
all_bdistquantiles20 = np.quantile(allbcat[allbcat>0],np.linspace(0,1,rnumbins+1))


outfile = resultdir+'combinedquantiles.pkl'
with open(outfile,'wb') as f:
    pickle.dump([all_speedquantiles10, all_speedquantiles20, all_ndistquantiles20, all_bdistquantiles20], f)
print('wrote to:',outfile)


wrote to: /media/jacob/JD_DATA/zfish-code+save/savedresults/combinedquantiles.pkl


# (step 3) Loop through to get all medians and save to file

In [None]:
# Frame selection used for the per-trial summaries below; it also sets the output filename postfix:
#   0 = all frames
#   1 = exclude frames with overlapping fish or tracking errors ('-nooverlap')
#   2 = only fish farther than 10% of the median arena radius from the boundary ('-far')
#   3 = as 1, and additionally exclude freezing fish from the speed statistics
datachoice = 1


def varcoeff(data):
    """Median over rows of (within-row standard deviation / standard deviation of all of `data`).

    As called below, rows are time steps and columns are fish, so this compares the spread across
    fish at each moment to the overall spread.
    """
    overall_spread = np.std(data)
    per_row_spread = np.std(data, axis=1)
    ratios = per_row_spread / overall_spread
    return np.median(ratios)

def varcoeff_IQR(data):
    """Return the median over rows of the within-row standard deviation.

    Despite the name, the result is NOT normalized by the overall interquartile range. The IQR used to
    be computed here and then discarded; that dead computation also raised on an empty selection,
    whereas this returns NaN. NOTE(review): if a normalized measure was intended, divide by
    np.diff(np.quantile(data, [0.25, 0.75])) — TODO confirm before changing saved results.
    """
    return np.median(np.std(data, axis=1))

# Result containers, one slot per line:
#   trial_*     object arrays; each slot later holds an array with one value per trial
#   treatment_* float arrays; each slot later holds the value pooled over all trials of that line
(trial_speedmedians, trial_speedIQR, trial_nnmedians, trial_cosmedians, trial_groupnums,
 trial_groupiidmedians, trial_grouppolmedians, trial_groupiid_together_medians,
 trial_grouppol_together_medians, trial_freezefrac, trial_freezefrac1sec, trial_speedsync,
 trial_centroiddist, trial_boundarydist, trial_nnrelspeed) = [np.empty(len(treatments), dtype=object) for _ in range(15)]

(treatment_speedmedians, treatment_speedIQR, treatment_nnmedians, treatment_cosmedians, treatment_groupnums,
 treatment_groupiidmedians, treatment_grouppolmedians, treatment_groupiid_together_medians,
 treatment_grouppol_together_medians, treatment_freezefrac, treatment_freezefrac1sec, treatment_speedsync,
 treatment_centroiddist, treatment_boundarydist, treatment_nnrelspeed) = [np.zeros(len(treatments)) for _ in range(15)]

# Fail fast on an invalid frame-selection choice, instead of crashing later with a NameError
if datachoice not in (0, 1, 2, 3):
    raise ValueError('error in datachoice number')

for tnum in range(numtreatments):
    print(tnum)
    treatment = treatments[tnum]
    numtrials = allnumtrials[tnum]
    alldist, alldcoords_rotated, smoothtrajectories, smoothspeeds, theta, heading, trajectories, trialids, trackingerrors = dfunc.getcat(treatment,datadir,'all')  # do this, because it saves a lot on memory
    numsteps = alldist.shape[0]
    numfish = alldist.shape[1]

    outfile = datadir+treatment+'-quantile+group+boundary.pkl'
    with open(outfile,'rb') as f:
        [speedquantiles10, speedquantiles20, ndistquantiles, bdistquantiles,
            groupcentroid, groupspeed, grouppolarization, grouprotation, groupiid, boundarydist,
            trial_arena_mid, trial_arena_r, errortimesteps    ] = pickle.load(f)


    # nearest neighbor: drop the self-pair, giving (numsteps, numfish, numfish-1) arrays
    offdiag = np.logical_not(np.eye(numfish, dtype=bool))
    neighborsel = lambda x: np.reshape(x[:,offdiag],(-1,numfish,numfish-1))
    nnums = np.argsort(np.argsort(neighborsel(alldist),axis=2),axis=2)  # distance rank of each neighbor
    n=0  # nearest neighbor is n=0
    nndata = np.reshape(neighborsel(alldist)[nnums==n],(numsteps,-1))

    relativespeeds = smoothspeeds[:,np.newaxis,:] - smoothspeeds[:,:,np.newaxis]
    nnrelspd = np.reshape(neighborsel(relativespeeds)[nnums==n],(numsteps,-1))

    cosalign = np.abs(np.cos(heading[:,:,np.newaxis] - heading[:,np.newaxis,:]))
    meancosalign = np.mean(neighborsel(cosalign),axis=2)

    # group membership
    outfile = datadir+treatment+'-groupmembership.pkl'
    with open(outfile,'rb') as f:
        groupmembership,groupnumber,groupnumcomponents,cutoff = pickle.load(f)
    grouptogether = np.all(groupmembership==0,axis=1)  # frames where every fish is in component 0

    # freezing: moving-median speed below the 5% quantile of the combined speed distribution
    def getfreezesel(freezemedian):
        rs, frac = dfunc.mmed_all(smoothspeeds,freezemedian,skipcalc=1)
        return (rs<all_speedquantiles20[1]) & np.logical_not(np.isnan(rs))
    freezesel600 = getfreezesel(600)
    freezesel60 = getfreezesel(60)

    #filters
    far10 = boundarydist > 0.1*np.median(trial_arena_r)

    # select which fish-frames to use (see datachoice)
    if datachoice==0:
        datasel = np.tile(True,boundarydist.shape)
        speedsel = datasel
    elif datachoice==1:
        datasel = np.logical_not(np.tile(errortimesteps,(numfish,1)).T)
        speedsel = datasel
    elif datachoice==2:
        datasel = far10
        speedsel = datasel
    elif datachoice==3:
        datasel = np.logical_not(np.tile(errortimesteps,(numfish,1)).T)
        speedsel = datasel & np.logical_not(freezesel600)
    else:
        raise ValueError('error in datachoice number')

    dataselcount = np.sum(datasel,axis=1)
    dataseltimesN = np.array([dataselcount>=i+1 for i in range(numfish)])
    nsel_datasel = 3  # if make this 5, then actually have no data for some trials with the boundary threshold.  With the tracking errors, it doesn't matter - these are tiled per frame

    datasel_group = dataseltimesN[nsel_datasel]  # frames with at least nsel_datasel+1 selected fish

    # initialize (index 0 = all trials pooled, index j>=1 = trial j-1)
    [ speedmedians, speedIQR, speedstd,  nnmedians,  cosmedians,  groupnums,  groupiidmedians, grouppolmedians,
        groupiid_together_medians,    grouppol_together_medians,  freezefrac, freezefrac1sec, group_centroiddist, group_speedsync, group_speedsyncvar,
        group_boundarydist,  nnrelspeed ]  = np.zeros((17,numtrials+1))

    for j in range(numtrials+1):
        if j==0:  # then select all
            trial=-1
            trialsel = np.tile(True,trialids.shape)
        else:
            trial=j-1
            trialsel = trialids==trial
        trialsel2d = trialsel[:,np.newaxis]

        spd = smoothspeeds[trialsel2d & speedsel]
        speedmedians[j] = np.median(spd)
        # compute the IQR as a scalar (assigning the 1-element np.diff array to a slot is deprecated)
        q25, q75 = np.quantile(spd,[0.25,0.75])
        speedIQR[j] = q75 - q25
        speedstd[j] = np.std(spd)  # computed but not currently saved
        nq25, nq75 = np.quantile(nnrelspd[trialsel2d & speedsel],[0.25,0.75])
        nnrelspeed[j] = nq75 - nq25
        nnmedians[j] = np.median(nndata[trialsel2d & datasel])
        cosmedians[j] = np.median(meancosalign[trialsel2d & datasel])
        groupnums[j] = np.mean(groupnumber[trialsel & datasel_group])

        groupiidmedians[j] = np.nanmedian(groupiid[trialsel & datasel_group])
        grouppolmedians[j] = np.nanmedian(grouppolarization[trialsel & datasel_group])
        groupiid_together_medians[j] = np.nanmedian(groupiid[trialsel&grouptogether&datasel_group])
        grouppol_together_medians[j] = np.nanmedian(grouppolarization[trialsel&grouptogether&datasel_group])

        totalcounts = np.sum(trialsel2d & datasel)
        freezefrac[j] = np.sum(freezesel600[trialsel2d & datasel]) / totalcounts
        freezefrac1sec[j] = np.sum(freezesel60[trialsel2d & datasel]) / totalcounts

        # centroid displacement between consecutive frames, only where both frames pass the group filter
        dc = np.diff(groupcentroid[trialsel],axis=0)
        dcdist = np.sqrt(dc[:,0]**2+dc[:,1]**2)
        gsel = datasel_group[trialsel]
        group_centroiddist[j] = np.nanmedian(dcdist[gsel[0:-1] & gsel[1:]])
        group_speedsync[j] = varcoeff_IQR(smoothspeeds[trialsel & datasel_group])
        group_speedsyncvar[j] = varcoeff(smoothspeeds[trialsel & datasel_group])  # computed but not currently saved
        group_boundarydist[j] = np.nanmedian(boundarydist[trialsel])


    treatment_speedmedians[tnum], trial_speedmedians[tnum] = speedmedians[0], speedmedians[1:]
    treatment_speedIQR[tnum], trial_speedIQR[tnum] = speedIQR[0], speedIQR[1:]
    treatment_nnmedians[tnum], trial_nnmedians[tnum] = nnmedians[0], nnmedians[1:]
    treatment_cosmedians[tnum], trial_cosmedians[tnum] = cosmedians[0], cosmedians[1:]

    treatment_groupnums[tnum], trial_groupnums[tnum] = groupnums[0], groupnums[1:]
    treatment_groupiidmedians[tnum], trial_groupiidmedians[tnum] = groupiidmedians[0], groupiidmedians[1:]
    treatment_grouppolmedians[tnum], trial_grouppolmedians[tnum] = grouppolmedians[0], grouppolmedians[1:]
    treatment_groupiid_together_medians[tnum], trial_groupiid_together_medians[tnum] = groupiid_together_medians[0], groupiid_together_medians[1:]
    treatment_grouppol_together_medians[tnum], trial_grouppol_together_medians[tnum] = grouppol_together_medians[0], grouppol_together_medians[1:]

    treatment_freezefrac[tnum], trial_freezefrac[tnum] = freezefrac[0], freezefrac[1:]
    treatment_freezefrac1sec[tnum], trial_freezefrac1sec[tnum] = freezefrac1sec[0], freezefrac1sec[1:]
    treatment_centroiddist[tnum], trial_centroiddist[tnum] = group_centroiddist[0], group_centroiddist[1:]
    treatment_speedsync[tnum], trial_speedsync[tnum] = group_speedsync[0], group_speedsync[1:]
    treatment_boundarydist[tnum], trial_boundarydist[tnum] = group_boundarydist[0], group_boundarydist[1:]

    treatment_nnrelspeed[tnum], trial_nnrelspeed[tnum] = nnrelspeed[0], nnrelspeed[1:]




### TRIAL save results to pickle file

# Output filename postfix for each frame-selection choice (see datachoice).
# NOTE(review): choice 3 lacks the leading '-' used by the others; kept as-is so existing files/readers still match.
datachoice_postfixes = {0: '', 1: '-nooverlap', 2: '-far', 3: 'nooverlap-nofreezeforspeed'}
if datachoice not in datachoice_postfixes:
    # previously this only printed, and then failed with a NameError on 'postfix'
    raise ValueError('error in datachoice number')
postfix = datachoice_postfixes[datachoice]

outfile = resultdir + 'Fig2-TrialQuantities'+postfix+'.pkl'
with open(outfile,'wb') as f:
    pickle.dump([
    trial_speedmedians, trial_speedIQR, trial_nnmedians, trial_cosmedians, trial_groupnums, trial_groupiidmedians, trial_grouppolmedians,
     trial_groupiid_together_medians, trial_grouppol_together_medians,
     trial_freezefrac, trial_freezefrac1sec, trial_speedsync, trial_centroiddist, trial_boundarydist, trial_nnrelspeed
    ],f)

### TREATMENT save results to pickle file
outfile = resultdir + 'Fig2-TreatmentQuantities'+postfix+'.pkl'
with open(outfile,'wb') as f:
    pickle.dump([
    treatment_speedmedians, treatment_speedIQR, treatment_nnmedians, treatment_cosmedians, treatment_groupnums, treatment_groupiidmedians, treatment_grouppolmedians,
     treatment_groupiid_together_medians, treatment_grouppol_together_medians,
     treatment_freezefrac, treatment_freezefrac1sec, treatment_speedsync, treatment_centroiddist, treatment_boundarydist, treatment_nnrelspeed
    ],f)
print('wrote results to file')

#  (step 4) Make input-outputs for model fitting

In [14]:
# import the combined quantiles (pooled over all lines, written in step 2)
outfile = resultdir+'combinedquantiles.pkl'
with open(outfile,'rb') as f:
    [all_speedquantiles10, all_speedquantiles20, all_ndistquantiles20, all_bdistquantiles20] = pickle.load(f)

In [15]:
# Functions related to defining 'freezing', so that can filter results
def movingaverage(data,N):
    """Trailing moving average over windows of N samples.

    Only complete windows are returned, so the result has len(data) - N + 1 values.
    See https://stackoverflow.com/questions/13728392/moving-average-or-running-mean/30141358#30141358
    """
    rolling_means = pd.Series(data).rolling(window=N).mean()
    # the first N-1 entries are NaN (incomplete windows); drop them
    return rolling_means.values[N-1:]
def movingmedian(data,N):
    """Centered moving median over windows of N samples.

    Returns an array of the same length as `data`; positions whose centered window is incomplete are NaN.
    """
    centered_window = pd.Series(data).rolling(window=N, center=True)
    return centered_window.median().values

def mmed_all(smoothspeeds,N,skipcalc=5,threshold=None):
    """Centered moving median of each column of `smoothspeeds`, and the fraction below a threshold.

    Parameters
    ----------
    smoothspeeds : 2D array with one column per fish (rows presumably time steps — TODO confirm).
    N : window length in original samples. Speeds are subsampled every `skipcalc` rows, so the median
        window is int(N/skipcalc) subsampled values. skipcalc is forced to 1 when N == 1.
    threshold : freezing threshold. Defaults to all_speedquantiles20[1], looked up when called (it used
        to be captured when the function was defined).

    Returns
    -------
    allmed : array (number of subsampled rows, number of columns); NaN where the window is incomplete.
    frac : fraction of allmed entries below threshold (NaN entries count as not below).
    """
    if threshold is None:
        threshold = all_speedquantiles20[1]
    skipcalc = 1 if N==1 else skipcalc
    window = int(N/skipcalc)
    # use every column, rather than a hard-coded 6 fish
    allmed = np.array([movingmedian(smoothspeeds[::skipcalc,i],window) for i in range(smoothspeeds.shape[1])]).T
    return allmed, np.mean(allmed<threshold)

import networkx as nx

def getcomponents(Aij_single):
    """Connected components of the undirected graph with (square) adjacency matrix `Aij_single`.

    Returns
    -------
    components : int array with the component index of each node
    largestgroupsize : size of the largest component
    numcomponents : number of components
    """
    # nx.from_numpy_matrix was removed in networkx 3.0; from_numpy_array is the replacement
    G = nx.from_numpy_array(Aij_single)
    connected = list(nx.connected_components(G))
    componentsizes = [len(c) for c in connected]
    # one entry per node, sized from the matrix instead of the global numfish
    components = np.full(Aij_single.shape[0], -1, dtype=int)
    for i, nodes in enumerate(connected):
        components[list(nodes)] = i
    largestgroupsize = np.max(componentsizes)
    numcomponents = len(componentsizes)
    return components, largestgroupsize, numcomponents

In [None]:
# Load trial quantities
postfix_s = '-nooverlap'
outfile = resultdir + 'Fig2-TrialQuantities'+postfix_s+'.pkl'
trial_speedmedians_all = pickle.load(open(outfile,'rb'))[0]

medmedspeed = np.array([np.median(t) for t in trial_speedmedians_all] )  # use this to adjust the delay time

numfish=6

# make model io (done here only for focustreatments)
for tnum in focustreatments:
    treatment = treatments[tnum]

    ### TREATMENT
    print(tnum, treatment)
    numtrials=allnumtrials[tnum]
    trialsel='all'
    alldist, alldcoords_rotated, smoothtrajectories, smoothspeeds, theta, heading, trajectories, trialids,trackingerrors = dfunc.getcat(treatment,datadir,trialsel)  # do this, because it saves a lot on memory    
    numsteps = alldist.shape[0]
    
    # Group membership - get which are in the largest group.  (not using this metric anymore)
#     groupmembership = np.zeros(smoothspeeds.shape)
#     groupnumber = np.zeros(smoothspeeds.shape[0])
#     groupnumcomponents = np.zeros(smoothspeeds.shape[0])
#     cutoff = 286.86
#     for step in range(numsteps):
#         Aij = np.heaviside(cutoff-alldist[step],0) 
#         Aij[range(numfish),range(numfish)] = 0  
#         groupmembership[step], groupnumber[step], groupnumcomponents[step]  = getcomponents(Aij)
    outfile = datadir+treatment+'-groupmembership.pkl'
    groupmembership,groupnumber,groupnumcomponents,cutoff = pickle.load(open(outfile,'rb'))
    print('read in data from:',outfile)    
#     pickle.dump([groupmembership,groupnumber,groupnumcomponents,cutoff],open(outfile,'wb'))
#     print('wrote to:',outfile)
    
    outfile = datadir+treatment+'-quantile+group+boundary.pkl'
    [speedquantiles10, speedquantiles20, ndistquantiles20, bdistquantiles20,    _, _, _, _, _, boundarydist,
            trial_arena_mid, trial_arena_r,errortimesteps] = pickle.load(open(outfile,'rb'))
    

    # Calculate boundary coordinates, as if the boundary was a "neighbor"
    allbcoords = np.zeros(trajectories.shape)
    allbcoords_rotated = np.zeros(trajectories.shape)
    for i in range(numfish):
        thetafish = np.arctan2(trajectories[:,i,1]-trial_arena_mid[trialids,1],trajectories[:,i,0]-trial_arena_mid[trialids,0])  
        eboundary = np.tile(trial_arena_r[trialids],(2,1)).T * np.array([np.cos(thetafish),np.sin(thetafish)]).T # vector to the boundary
        allbcoords[:,i] = eboundary - (trajectories[:,i]-trial_arena_mid[trialids])    
        rel_orientation = thetafish-heading[:,i]
        x = boundarydist[:,i] * np.cos(rel_orientation)
        y = boundarydist[:,i] * np.sin(rel_orientation)    
        allbcoords_rotated[:,i] = np.array([x,y]).T    

    del allbcoords
    gc.collect()

    # freezing
    medianavgspeeds600frames, frac = dfunc.mmed_all(smoothspeeds,600,skipcalc=1)     
    medianavgspeeds600frames[np.isnan(medianavgspeeds600frames)] = 0 # these will be in the 'negative' bin, so will ignore in getinputoutputs fn.  This is only at the start and end of the array    
    medianavgspeeds60frames, frac = dfunc.mmed_all(smoothspeeds,60,skipcalc=1)     
    medianavgspeeds60frames[np.isnan(medianavgspeeds60frames)] = 0 # these will be in the 'negative' bin, so will ignore in getinputoutputs fn.  This is only at the start and end of the array        
    
    # Save the input data
    inputdata = [smoothspeeds,heading,boundarydist,groupmembership,alldist,alldcoords_rotated,allbcoords_rotated,medianavgspeeds600frames,medianavgspeeds60frames]    
    pickle.dump(inputdata,open(datadir+treatment+'-inputdata.pkl','wb'))
    
    ######################################### BINNING AND SAVING IO #########################################
    # define bins to use for encoding
    thetanumbins = 32
    offsettheta = False  # if true, then will rotate bins to get "front", "back", "side" centers
    dtheta = 2*np.pi/(thetanumbins)
    theta_edge = np.linspace(-np.pi,np.pi,thetanumbins+1) - offsettheta*dtheta/2
    

    binscheme='A10'
    if binscheme=='A': # 20 bins for all except neighbor speed
        sq, sjq, nq, bq = all_speedquantiles20, all_speedquantiles10, all_ndistquantiles20, all_bdistquantiles20
    elif binscheme=='A10':  # 10 bins for all
        sq, sjq, nq, bq = all_speedquantiles10, all_speedquantiles10, all_ndistquantiles20[::2], all_bdistquantiles20[::2]
    elif binscheme=='T':  # use 'treatment' level bins
        sq, sjq, nq, bq = speedquantiles20, speedquantiles10, ndistquantiles20, bdistquantiles20
    elif binscheme=='T10':  # use 'treatment' level bins
        sq, sjq, nq, bq = speedquantiles10, speedquantiles10, ndistquantiles20[::2], bdistquantiles20[::2]

    bdist_threshold=0.05
    speed_threshold = all_speedquantiles20[1]/2
    freeze_threshold = all_speedquantiles20[1]

    # input data bins
    binnedinputdata = hm.getbins(inputdata,[sq,sjq,nq,bq,theta_edge],
                         speed_threshold = speed_threshold, freeze_threshold = freeze_threshold)

    # output data (calculate here, not in a function)
    shifts_const = np.array([30,60])  # for heading change shifts, to make model output   
    shifts_adjusted =  np.round(medmedspeed[0]/medmedspeed[tnum] * shifts_const).astype(int)
    shifts = np.concatenate((shifts_const,shifts_adjusted))
    
    
    headingchange_byshift = [dfunc.get_headingchange(data=heading,shift=s) for s in shifts]
    vectorheadingchange_byshift = [dfunc.get_vectorheadingchange(data=trajectories,data2=heading,shift=s) for s in shifts]
    
    outputdata = headingchange_byshift + vectorheadingchange_byshift
    numheadingchangesteps = np.min([h.shape[0] for h in outputdata])
    binnedinputdata = [bd[0:numheadingchangesteps] for bd in binnedinputdata] 
    outputdata = [bd[0:numheadingchangesteps] for bd in outputdata] 
    
    # some filters    
    # for boundary distance, using the median trial arena_r size to set a threshold, for simplicity
    WT_median_trial_arena_r = 920.6505385557605
    far = boundarydist > bdist_threshold*WT_median_trial_arena_r
    
    # for when there are errors in the future timestep
    futureerrors = np.tile(False,(numsteps,numfish))
    for sh in shifts:
        # use individual tracking errors for this, not 'group'.
        # because its easier, don't use ones where any of the future cases have errors
        futureerrors[0:numsteps-sh] = futureerrors[0:numsteps-sh] | trackingerrors[sh:] 
    
    # this removes cases where the individual is far from the boundary, none are overlapped with other fish 
    for case in [0]:
        if case==0:
            farlabel='far05'
            sel = far & np.logical_not(errortimesteps[:,np.newaxis])  & np.logical_not(futureerrors)
        else:
            farlabel= 'all'
            sel = np.logical_not(errortimesteps[:,np.newaxis]) & np.logical_not(futureerrors)

        inputs, alloutputs, alloutputsraw = hm.getinputoutputs([binnedinputdata,outputdata],sel,tnum,trialids)

        gc.collect()
        for skip in [1,10,50]:
            outfile = datadir+treatment+'-io'+str(skip)+'-'+binscheme+'-'+farlabel+'.pkl'
            pickle.dump([inputs[::skip],[o[::skip] for o in alloutputs],[o[::skip] for o in alloutputsraw],sq,nq,bq],open(outfile,'wb'), protocol=4)
            print('wrote to',outfile)