In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from os import sep as sep
import seaborn as sns
import statsmodels.api as sm
from sklearn.linear_model import LinearRegression

In [2]:
exp_path = r'Y:\Lior&Einav\Experiments\experiment23_271020'
enriched_trop_table_filename = r'trophallaxis_table_enriched_temp_with_conf.csv'
ant_filename = r'ants_list.csv'

In [3]:
to_save_transparency = True
transparency_filename = r'transparency_table.csv'

In [4]:
trop_table = pd.read_csv(exp_path + sep + enriched_trop_table_filename)
trop_table.head()

Unnamed: 0,vidnum,id,actual_ant1,actual_ant2,actual_start,actual_end,group,general_start_frame,general_end_frame,general_group_id,...,ant2_got_yellow,ant1_crop_before_red,ant1_crop_before_yellow,ant2_crop_before_red,ant2_crop_before_yellow,ant1_x,ant1_y,ant2_x,ant2_y,estimation_confidence
0,1.0,31.0,521,408.0,142.0,147.0,1.0,101.0,106.0,1.0,...,0.0,0.0,306399.34375,0.0,0.0,2356.364268,1750.350192,,,3.0
1,1.0,,421,408.0,145.0,147.0,1.0,104.0,106.0,1.0,...,0.0,0.0,0.0,0.0,3998.497977,,,,,3.0
2,1.0,,421,521.0,145.0,153.0,1.0,104.0,112.0,1.0,...,0.0,0.0,0.0,0.0,302372.525313,,,2314.53771,1717.31935,3.0
3,1.0,,197,421.0,146.0,147.0,1.0,105.0,106.0,1.0,...,0.0,0.0,0.0,0.0,369.450757,,,,,3.0
4,1.0,,197,521.0,146.0,147.0,1.0,105.0,106.0,1.0,...,0.0,0.0,0.0,0.0,301030.2525,,,2319.521674,1717.0094,3.0


In [5]:
no_groups = trop_table.loc[trop_table['general_group_id'].isna()]
no_groups = no_groups[['actual_ant1','actual_ant2', 'general_start_frame',
                                                                 'general_end_frame', 'ant1_got_red', 'ant1_got_yellow',
                                                                 'ant2_got_red', 'ant2_got_yellow', 'estimation_confidence']]
no_groups.head()

Unnamed: 0,actual_ant1,actual_ant2,general_start_frame,general_end_frame,ant1_got_red,ant1_got_yellow,ant2_got_red,ant2_got_yellow,estimation_confidence
6,137,529.0,109.0,109.0,-61493.957813,0.0,0.0,0.0,3.0
13,161,169.0,126.0,205.0,0.0,-345062.47959,0.0,240684.749219,3.0
17,13,529.0,132.0,134.0,0.0,0.0,0.0,0.0,3.0
18,16,521.0,134.0,201.0,0.0,0.0,0.0,-158349.657031,2.0
22,197,525.0,146.0,195.0,174195.203528,0.0,-183213.479858,0.0,3.0


In [6]:
high_conf = no_groups.loc[no_groups['estimation_confidence']==3]
high_conf.head()

Unnamed: 0,actual_ant1,actual_ant2,general_start_frame,general_end_frame,ant1_got_red,ant1_got_yellow,ant2_got_red,ant2_got_yellow,estimation_confidence
6,137,529.0,109.0,109.0,-61493.957813,0.0,0.0,0.0,3.0
13,161,169.0,126.0,205.0,0.0,-345062.47959,0.0,240684.749219,3.0
17,13,529.0,132.0,134.0,0.0,0.0,0.0,0.0,3.0
22,197,525.0,146.0,195.0,174195.203528,0.0,-183213.479858,0.0,3.0
23,156,484.0,146.0,189.0,111936.617188,0.0,-138819.625,0.0,3.0


In [7]:
len(high_conf)

1411

In [8]:
len(no_groups)

2700

In [9]:
len(trop_table)

2936

In [10]:
ants = pd.read_csv(exp_path + sep + ant_filename)
ants.head()

Unnamed: 0,ant_id,is_forager
0,1069,False
1,113,True
2,12,False
3,13,True
4,137,True


In [11]:
len(ants)

44

In [12]:
def get_ant_interaction_table(ant, interaction_table):
    
    ant_interaction_rows1 = interaction_table['actual_ant1'] == ant
    ant_interaction_rows2 = interaction_table['actual_ant2'] == ant
    ant_interactions = interaction_table.loc[ant_interaction_rows1 | ant_interaction_rows2]
    
    return ant_interactions

In [13]:
def get_ant_interaction_measurements(ant, ant_interactions):
    
    ant_got_red = []
    ant_got_yellow = []
    other_ant_got_red = []
    other_ant_got_yellow = []
    for idx, trop_row in ant_interactions.iterrows():
        if trop_row['actual_ant1'] == ant:
            ant_got_red.append(trop_row['ant1_got_red'])
            ant_got_yellow.append(trop_row['ant1_got_yellow'])
            other_ant_got_red.append(trop_row['ant2_got_red'])
            other_ant_got_yellow.append(trop_row['ant2_got_yellow'])
        else:
            ant_got_red.append(trop_row['ant2_got_red'])
            ant_got_yellow.append(trop_row['ant2_got_yellow'])
            other_ant_got_red.append(trop_row['ant1_got_red'])
            other_ant_got_yellow.append(trop_row['ant1_got_yellow'])
    
    ant_interaction_measurements = pd.DataFrame({'ant_got_red': ant_got_red, 'ant_got_yellow': ant_got_yellow, 
              'other_ant_got_red': other_ant_got_red, 'other_ant_got_yellow': other_ant_got_yellow})
    return ant_interaction_measurements

In [14]:
def get_ant_fit_data(ant_interaction_measurements, color, fit_intercept=False, remove_zeros=True):
    if color == 'pooled':
        x = np.array([ant_interaction_measurements['ant_got_red'],ant_interaction_measurements['ant_got_yellow']]).flatten()
        y = -np.array([ant_interaction_measurements['other_ant_got_red'],ant_interaction_measurements['other_ant_got_yellow']]).flatten()        
    else:
        x = np.array(ant_interaction_measurements['ant_got_'+color])
        y = -np.array(ant_interaction_measurements['other_ant_got_'+color])
    
    nanx = np.isnan(x)
    nany = np.isnan(y)
    
    x = np.delete(x, np.argwhere(nanx|nany)).reshape(-1, 1)
    y = np.delete(y, np.argwhere(nanx|nany)).reshape(-1, 1)
    
    if remove_zeros:
        x_non0 = x.nonzero()
        y_non0 = y.nonzero()
        tot_non0 = np.intersect1d(x_non0,y_non0)
        x = x[tot_non0]
        y = y[tot_non0]
    
    if len(x)>0:
        reg = LinearRegression(fit_intercept=fit_intercept).fit(x,y)
        slope = reg.coef_[0][0]
        r2 = reg.score(x,y)
        return slope, r2, x, y, reg
    else:
        return np.nan, np.nan, x, y, np.nan

In [15]:
def plot_ant_fit(ax,x,y,reg):
    ax.scatter(x,y)
    xlims = np.array(ax.get_xlim()).reshape(-1, 1)
    ax.plot(xlims, reg.predict(xlims),'r')

# Fit slope for each color separately

In [16]:
%matplotlib qt
n_cols = 10
n_ants = len(ants)
yellow_fig, yellow_axes = plt.subplots(np.int(np.ceil(n_ants/n_cols)),n_cols)
yellow_fig.suptitle('Yellow')
red_fig, red_axes = plt.subplots(np.int(np.ceil(n_ants/n_cols)),n_cols)
red_fig.suptitle('Red')


red_slopes = {}
red_r2 = {}
yellow_slopes = {}
yellow_r2 = {}
num_points = {}
ant_ids = {}
for idx, ant in enumerate(ants['ant_id']):
    coor = np.unravel_index(idx,(np.int(np.ceil(n_ants/n_cols)),n_cols))
    ant_interactions = get_ant_interaction_table(ant, high_conf)
    ant_interaction_measurements = get_ant_interaction_measurements(ant, ant_interactions)
    # num_points[idx] = len(ant_interactions)
    ant_ids[idx] = ant
    for color in ['red','yellow']:
        
        if color == 'red':
            red_slopes[idx], red_r2[idx], x, y, reg = get_ant_fit_data(ant_interaction_measurements, color)
            num_points[idx] = len(x)
            if len(x)>0:
                plot_ant_fit(red_axes[coor],x,y,reg)
            red_axes[coor].set_title(ant)
        else:
            yellow_slopes[idx], yellow_r2[idx], x, y, reg = get_ant_fit_data(ant_interaction_measurements, color)
            num_points[idx] = len(x)
            if len(x)>0:
                plot_ant_fit(yellow_axes[coor],x,y,reg)
            yellow_axes[coor].set_title(ant)
    
#     all_slopes[idx] = slope
#     all_r2[idx] = r2

# red_fig.title('Red')
# yellow_fig.title('Yellow')
    
red_fit_results = pd.DataFrame({'ant': ants['ant_id'], 'slope': red_slopes, 'r2': red_r2, 'num_points': num_points})
yellow_fit_results = pd.DataFrame({'ant': ants['ant_id'], 'slope': yellow_slopes, 'r2': yellow_r2, 'num_points': num_points})

red_fit_results.head()

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  after removing the cwd from sys.path.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations


Unnamed: 0,ant,slope,r2,num_points
0,1069,0.828053,0.74337,7
1,113,0.707877,0.719077,14
2,12,0.703239,0.362399,3
3,13,0.657651,0.619054,25
4,137,0.696487,0.457601,2


In [17]:
fig,ax = plt.subplots(1,2)
ax[0].hist(red_fit_results['r2'],bins=50)
ax[0].set_title('red')
ax[1].hist(yellow_fit_results['r2'],bins=50)
ax[1].set_title('yellow')
plt.show()

In [18]:
plt.figure()
plt.hist(yellow_fit_results['r2'].loc[yellow_fit_results['r2']>0],bins=50)

(array([1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 3., 1.,
        2., 1., 0., 1., 2., 1., 0., 3., 1., 2., 0., 1., 0., 3., 0., 2., 0.,
        0., 1., 1., 1., 0., 2., 1., 0., 2., 0., 1., 1., 0., 0., 0., 3.]),
 array([0.26038291, 0.27517525, 0.2899676 , 0.30475994, 0.31955228,
        0.33434462, 0.34913696, 0.3639293 , 0.37872165, 0.39351399,
        0.40830633, 0.42309867, 0.43789101, 0.45268335, 0.4674757 ,
        0.48226804, 0.49706038, 0.51185272, 0.52664506, 0.54143741,
        0.55622975, 0.57102209, 0.58581443, 0.60060677, 0.61539911,
        0.63019146, 0.6449838 , 0.65977614, 0.67456848, 0.68936082,
        0.70415316, 0.71894551, 0.73373785, 0.74853019, 0.76332253,
        0.77811487, 0.79290722, 0.80769956, 0.8224919 , 0.83728424,
        0.85207658, 0.86686892, 0.88166127, 0.89645361, 0.91124595,
        0.92603829, 0.94083063, 0.95562297, 0.97041532, 0.98520766,
        1.        ]),
 <BarContainer object of 50 artists>)

In [19]:
r2_thres = 0.2
min_num_points = 2
good_red_results = red_fit_results.loc[(red_fit_results['r2']>r2_thres) & (red_fit_results['num_points']>min_num_points)]
good_yellow_results = yellow_fit_results.loc[(yellow_fit_results['r2']>r2_thres) & (yellow_fit_results['num_points']>min_num_points)]

In [20]:
fig,ax = plt.subplots(1,2)
ax[0].hist(good_red_results['r2'],bins=50)
ax[0].set_title('red')
ax[1].hist(good_yellow_results['r2'],bins=50)
ax[1].set_title('yellow')
plt.show()

In [21]:
fig = plt.figure()
plt.scatter(red_fit_results['slope'], yellow_fit_results['slope'], c=yellow_fit_results['num_points'])
plt.plot([0,4],[0,4],'k')
#plt.xlim([0,1.5])
#plt.ylim([0,1.5])
plt.xlabel('red slope')
plt.ylabel('yellow slope')
plt.colorbar(label='yellow n')
#plt.clim([0,1])

<matplotlib.colorbar.Colorbar at 0x244d8f79888>

In [22]:
def calc_ant_transparency_by_mean(ant, red_fit, yellow_fit, min_r2, min_n):
    red_slope = red_fit['slope']
    yellow_slope = yellow_fit['slope']
    red_r2 = red_fit['r2']
    yellow_r2 = yellow_fit_results['r2']
    n = red_fit_results['num_points']
    
    if n < minN:
        ant_transparency = np.nan
    if (red_slope <= 0) or (red_r2 < min_r2):
        if (yellow_slope <= 0) or (yellow_r2 < min_r2):
            ant_transparency = np.nan
        else:
            ant_transparency = 1/yellow_slope
    elif (yellow_slope <= 0) or (yellow_r2 < min_r2):
        ant_transparency = 1/red_slope
    else:
        ant_transparency = 1/(np.mean([red_slope, yellow_slope]))
        
    return ant_transparency
            

In [23]:
for (idx_r, red_row), (idx_y, yellow_row) in zip(red_fit_results.iterrows(), yellow_fit_results.iterrows()):
    print(idx_r, red_row, yellow_row)

0 ant           1069.000000
slope            0.828053
r2               0.743370
num_points       7.000000
Name: 0, dtype: float64 ant           1069.000000
slope            1.363597
r2               0.260383
num_points       7.000000
Name: 0, dtype: float64
1 ant           113.000000
slope           0.707877
r2              0.719077
num_points     14.000000
Name: 1, dtype: float64 ant           113.000000
slope           0.710169
r2              0.716695
num_points     14.000000
Name: 1, dtype: float64
2 ant           12.000000
slope          0.703239
r2             0.362399
num_points     3.000000
Name: 2, dtype: float64 ant           12.000000
slope          0.213080
r2            -0.090293
num_points     3.000000
Name: 2, dtype: float64
3 ant           13.000000
slope          0.657651
r2             0.619054
num_points    25.000000
Name: 3, dtype: float64 ant           13.000000
slope          0.885719
r2             0.561150
num_points    25.000000
Name: 3, dtype: float64
4 ant   

# Calculate pooled transparency

In [24]:
%matplotlib qt
n_cols = 10
n_ants = len(ants)
fig, axes = plt.subplots(np.int(np.ceil(n_ants/n_cols)),n_cols)
fig.suptitle('Pooled')

pooled_slopes = {}
pooled_r2 = {}
num_points = {}
ant_ids = {}
for idx, ant in enumerate(ants['ant_id']):
    coor = np.unravel_index(idx,(np.int(np.ceil(n_ants/n_cols)),n_cols))
    ant_interactions = get_ant_interaction_table(ant, high_conf)
    ant_interaction_measurements = get_ant_interaction_measurements(ant, ant_interactions)
    ant_ids[idx] = ant
   
    pooled_slopes[idx], pooled_r2[idx], x, y, reg = get_ant_fit_data(ant_interaction_measurements, 'pooled')
    num_points[idx] = len(x)
    if len(x)>0:
        plot_ant_fit(axes[coor],x,y,reg)
    axes[coor].set_title(ant)


    
pooled_fit_results = pd.DataFrame({'ant': ants['ant_id'], 'slope': pooled_slopes, 'r2': pooled_r2, 'num_points': num_points})

pooled_fit_results.head()

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  after removing the cwd from sys.path.
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if sys.path[0] == '':
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if sys.path[0] == '':
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if sys.path[0] == '':
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if sys.path[0] == '':
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if sys.path[0] == '':
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if sys.path[0] == '':
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if sys.path[0] == '':
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if sys.path[0] == '':
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if sys.path[0] == '':
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if sys.path[0] == '':
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  if sys.path[0] == '':
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  

Unnamed: 0,ant,slope,r2,num_points
0,1069,0.85579,0.682866,16
1,113,0.71108,0.727257,41
2,12,0.467843,0.150649,9
3,13,0.694231,0.596366,51
4,137,0.696146,0.457642,27


In [25]:
fig2,ax2 = plt.subplots(1,2)
ax2[0].hist(pooled_fit_results['slope'],bins=50)
ax2[0].set_title('slope')
ax2[1].hist(pooled_fit_results['r2'],bins=50)
ax2[1].set_title('r2')
plt.show()

# Test transparency

### correct transparency

In [26]:
# ant = 139
# pooled_fit_results['slope'].loc[pooled_fit_results['ant']==ant]

In [27]:
# plt.figure()
# plt.hist(1/pooled_fit_results['slope'].loc[pooled_fit_results['r2']>0.2],bins=50)
# np.mean(1/pooled_fit_results['slope'].loc[pooled_fit_results['r2']>0.2])

In [28]:
# plt.figure()
# plt.scatter(pooled_fit_results['r2'], pooled_fit_results['slope'])
# plt.xlabel('r2')
# plt.ylabel('slope')
# plt.grid()

In [29]:
# good_fits = pooled_fit_results.loc[pooled_fit_results['r2']>0.2]
# len(good_fits)

In [30]:
# len(pooled_fit_results)

In [31]:
transparency_table = pooled_fit_results
mean_slope = np.mean(pooled_fit_results['slope'].loc[pooled_fit_results['r2']>=0.2])
transparency_table['transparency'] = 1/transparency_table['slope'] - 1/mean_slope + 1
transparency_table['transparency'].loc[transparency_table['r2']<0.2]=1
transparency_table['transparency'].loc[transparency_table['slope'].isna() | transparency_table['r2'].isna()]=1
transparency_table

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  self._setitem_with_indexer(indexer, value)


Unnamed: 0,ant,slope,r2,num_points,transparency
0,1069,0.85579,0.682866,16,0.954128
1,113,0.71108,0.727257,41,1.191929
2,12,0.467843,0.150649,9,1.0
3,13,0.694231,0.596366,51,1.22606
4,137,0.696146,0.457642,27,1.222097
5,156,0.539741,0.572413,23,1.638358
6,16,0.661161,0.506185,20,1.298109
7,160,0.674114,0.827472,18,1.269046
8,161,1.063563,0.610881,28,0.725853
9,164,1.09163,0.590331,26,0.701679


In [32]:
plt.figure()
plt.hist(transparency_table['transparency'],bins=50)
np.mean(transparency_table['transparency'])

1.0609075067945168

### transform measurements with ant transparency

In [33]:
# high_conf

In [34]:
transparency_table = transparency_table.set_index('ant')
transparency_table

Unnamed: 0_level_0,slope,r2,num_points,transparency
ant,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
1069,0.85579,0.682866,16,0.954128
113,0.71108,0.727257,41,1.191929
12,0.467843,0.150649,9,1.0
13,0.694231,0.596366,51,1.22606
137,0.696146,0.457642,27,1.222097
156,0.539741,0.572413,23,1.638358
16,0.661161,0.506185,20,1.298109
160,0.674114,0.827472,18,1.269046
161,1.063563,0.610881,28,0.725853
164,1.09163,0.590331,26,0.701679


In [35]:
# transparency_table.loc[392]

In [36]:
if to_save_transparency:
    transparency_table.to_csv(exp_path+sep+transparency_filename)


In [37]:
def correct_measurements_by_transparency(measurement, ant):
    return measurement/transparency_table['transparency'][ant]
    

In [38]:
high_conf = high_conf.drop(high_conf.loc[(high_conf['actual_ant1']==-1) | (high_conf['actual_ant2']==-1) | 
                                        (high_conf['actual_ant1']==-5) | (high_conf['actual_ant2']==-5)].index)
corrected_measurements = high_conf

In [39]:
corrected_measurements['ant1_got_red_corrected'] = corrected_measurements.apply(lambda x: correct_measurements_by_transparency(x['ant1_got_red'], x['actual_ant1']),axis=1)
corrected_measurements['ant1_got_yellow_corrected'] = corrected_measurements.apply(lambda x: correct_measurements_by_transparency(x['ant1_got_yellow'], x['actual_ant1']),axis=1)
corrected_measurements['ant2_got_red_corrected'] = corrected_measurements.apply(lambda x: correct_measurements_by_transparency(x['ant2_got_red'], x['actual_ant2']),axis=1)
corrected_measurements['ant2_got_yellow_corrected'] = corrected_measurements.apply(lambda x: correct_measurements_by_transparency(x['ant2_got_yellow'], x['actual_ant2']),axis=1)
corrected_measurements

Unnamed: 0,actual_ant1,actual_ant2,general_start_frame,general_end_frame,ant1_got_red,ant1_got_yellow,ant2_got_red,ant2_got_yellow,estimation_confidence,ant1_got_red_corrected,ant1_got_yellow_corrected,ant2_got_red_corrected,ant2_got_yellow_corrected
6,137,529.0,109.0,109.0,-61493.957813,0.000000,0.000000,0.000000,3.0,-50318.377550,0.000000,0.000000,0.000000
13,161,169.0,126.0,205.0,0.000000,-345062.479590,0.000000,240684.749219,3.0,0.000000,-475388.995510,0.000000,418141.460079
17,13,529.0,132.0,134.0,0.000000,0.000000,0.000000,0.000000,3.0,0.000000,0.000000,0.000000,0.000000
22,197,525.0,146.0,195.0,174195.203528,0.000000,-183213.479858,0.000000,3.0,257111.990668,0.000000,-230223.493018,0.000000
23,156,484.0,146.0,189.0,111936.617188,0.000000,-138819.625000,0.000000,3.0,68322.442885,0.000000,-170573.877803,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...
2903,42,329.0,10010.0,10046.0,-127200.868750,0.000000,0.000000,0.000000,3.0,-135612.721841,0.000000,0.000000,0.000000
2914,325,365.0,10055.0,10150.0,0.000000,55838.238281,0.000000,0.000000,3.0,0.000000,36314.809462,0.000000,0.000000
2916,418,329.0,10070.0,10138.0,59329.453711,61329.028418,0.000000,-72079.934064,3.0,65510.141625,67718.023445,0.000000,-70080.338749
2918,45,484.0,10081.0,10097.0,0.000000,66122.825000,0.000000,0.000000,3.0,0.000000,80904.779650,0.000000,0.000000


### test correlation again

In [40]:
def get_ant_corrected_measurements(ant, ant_interactions):
    
    ant_got_red = []
    ant_got_yellow = []
    other_ant_got_red = []
    other_ant_got_yellow = []
    for idx, trop_row in ant_interactions.iterrows():
        if trop_row['actual_ant1'] == ant:
            ant_got_red.append(trop_row['ant1_got_red_corrected'])
            ant_got_yellow.append(trop_row['ant1_got_yellow_corrected'])
            other_ant_got_red.append(trop_row['ant2_got_red_corrected'])
            other_ant_got_yellow.append(trop_row['ant2_got_yellow_corrected'])
        else:
            ant_got_red.append(trop_row['ant2_got_red_corrected'])
            ant_got_yellow.append(trop_row['ant2_got_yellow_corrected'])
            other_ant_got_red.append(trop_row['ant1_got_red_corrected'])
            other_ant_got_yellow.append(trop_row['ant1_got_yellow_corrected'])
    
    ant_interaction_measurements = pd.DataFrame({'ant_got_red': ant_got_red, 'ant_got_yellow': ant_got_yellow, 
              'other_ant_got_red': other_ant_got_red, 'other_ant_got_yellow': other_ant_got_yellow})
    return ant_interaction_measurements

In [41]:
n_cols = 10
n_ants = len(ants)
fig, axes = plt.subplots(np.int(np.ceil(n_ants/n_cols)),n_cols)
fig.suptitle('corrected')

corrected_slopes = {}
corrected_r2 = {}
num_points = {}
ant_ids = {}
for idx, ant in enumerate(ants['ant_id']):
    coor = np.unravel_index(idx,(np.int(np.ceil(n_ants/n_cols)),n_cols))
    ant_interactions = get_ant_interaction_table(ant, corrected_measurements)
    ant_interaction_measurements = get_ant_corrected_measurements(ant, ant_interactions)
    ant_ids[idx] = ant
   
    corrected_slopes[idx], corrected_r2[idx], x, y, reg = get_ant_fit_data(ant_interaction_measurements, 'pooled')
    num_points[idx] = len(x)
    if len(x)>0:
        plot_ant_fit(axes[coor],x,y,reg)
    axes[coor].set_title(ant)


    
corrected_fit_results = pd.DataFrame({'ant': ants['ant_id'], 'slope': corrected_slopes, 'r2': corrected_r2, 'num_points': num_points})

corrected_fit_results.head()

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  This is separate from the ipykernel package so we can avoid doing imports until
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/d

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-no

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()


Unnamed: 0,ant,slope,r2,num_points
0,1069,0.666393,0.67415,16
1,113,0.815558,0.781974,41
2,12,0.442843,0.413043,9
3,13,1.008519,0.629765,51
4,137,0.837637,0.490185,27


In [42]:
n_cols = 10
n_ants = len(ants)
fig, axes = plt.subplots(np.int(np.ceil(n_ants/n_cols)),n_cols)
fig.suptitle('raw')

raw_slopes = {}
raw_r2 = {}
num_points = {}
ant_ids = {}
for idx, ant in enumerate(ants['ant_id']):
    coor = np.unravel_index(idx,(np.int(np.ceil(n_ants/n_cols)),n_cols))
    ant_interactions = get_ant_interaction_table(ant, high_conf)
    ant_interaction_measurements = get_ant_interaction_measurements(ant, ant_interactions)
    ant_ids[idx] = ant
   
    raw_slopes[idx], raw_r2[idx], x, y, reg = get_ant_fit_data(ant_interaction_measurements, 'pooled')
    num_points[idx] = len(x)
    if len(x)>0:
        plot_ant_fit(axes[coor],x,y,reg)
    axes[coor].set_title(ant)


    
raw_fit_results = pd.DataFrame({'ant': ants['ant_id'], 'slope': raw_slopes, 'r2': raw_r2, 'num_points': num_points})

raw_fit_results.head()

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  This is separate from the ipykernel package so we can avoid doing imports until
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/d

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  # This is added back by InteractiveShellApp.init_path()
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-no

Unnamed: 0,ant,slope,r2,num_points
0,1069,0.85579,0.682866,16
1,113,0.71108,0.727257,41
2,12,0.467843,0.150649,9
3,13,0.694231,0.596366,51
4,137,0.696146,0.457642,27


In [43]:
raw_fit_results

Unnamed: 0,ant,slope,r2,num_points
0,1069,0.85579,0.682866,16
1,113,0.71108,0.727257,41
2,12,0.467843,0.150649,9
3,13,0.694231,0.596366,51
4,137,0.696146,0.457642,27
5,156,0.539741,0.572413,23
6,16,0.661161,0.506185,20
7,160,0.674114,0.827472,18
8,161,1.063563,0.610881,28
9,164,1.09163,0.590331,26


In [44]:
corrected_fit_results

Unnamed: 0,ant,slope,r2,num_points
0,1069,0.666393,0.67415,16
1,113,0.815558,0.781974,41
2,12,0.442843,0.413043,9
3,13,1.008519,0.629765,51
4,137,0.837637,0.490185,27
5,156,0.870832,0.551984,23
6,16,0.770819,0.503553,20
7,160,0.756596,0.810868,18
8,161,0.852804,0.794985,28
9,164,0.810525,0.639934,26


In [45]:
np.mean(raw_fit_results['slope'])

0.978857985816961

In [46]:
np.mean(raw_fit_results['r2'])

0.5561142970597877

In [47]:
np.mean(corrected_fit_results['slope'])

0.9543209499255858

In [48]:
np.mean(corrected_fit_results['r2'])

0.5963732196506127