In [8]:
# cellworld is star-imported first so that the explicit imports below take
# precedence over any same-named names it exports.
from cellworld import *
import glob
import os
import pickle
import warnings

import matplotlib.animation
import matplotlib.cm as cmx
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

warnings.filterwarnings("ignore", category=FutureWarning) 

# project helpers
from _src.pose import *
from _src.visibility import *
from _src.itor import *

##  Load Data
Load the world data, the visibility matrix, and the pose library.

In [9]:
# get logs -- sorted, because glob order is filesystem-dependent and later cells
# pick experiments by position (logs[3], logs[4], logs[-10], logs[0:10]).
logs = sorted(glob.glob('./_data/logs/*.json'))

# load visibility objects
print('Loading world visibility...')
e = Experiment.load_from_file(logs[3])
vis, w = get_vis(e)

# load visibility matrix
# NOTE: pickle can execute code on load -- only open trusted, locally generated files.
print('Loading visibility matrix...')
with open('./_data/visibility-21_05-1000.pkl', 'rb') as fh:
    [A, V, pts, sparse_arr] = pickle.load(fh)
vis_graph = {'V': V, 'A': A, 'src': pts, 'dst': sparse_arr}

# load/build pose library
print('Loading/building pose library...')
poselib = build_pose_library(logs)

Loading world visibility...
Loading visibility matrix...
Loading/building pose library...
./_data/logs/_pose_library.pkl found, loading...


In [None]:
# Compute (and cache, inside compute_itor_null) the ITOR null distribution for
# the first ten experiment logs.
# Single-episode debug example -- MICE_20220609_1907_DMM1_21_05_SR4 is logs[-10]:
#   compute_itor_null(logs[-10], poselib, vis_graph, d, start_ep=20)
d = Display(w)
for log_path in logs[:10]:
    compute_itor_null(log_path, poselib, vis_graph, d)


./_data/logs\MICE_20220512_1353_DMM1_21_05_HT1_experiment.json
./_data/logs\MICE_20220512_1520_DMM2_21_05_HT1_experiment.json
./_data/logs\MICE_20220512_1913_DMM3_21_05_HT1_experiment.json
./_data/logs\MICE_20220512_2015_DMM4_21_05_HT1_experiment.json
 Episode 0


  0%|          | 0/888 [00:00<?, ?it/s]

  saving...
 Episode 1


  0%|          | 0/951 [00:00<?, ?it/s]

  saving...
 Episode 2


  0%|          | 0/515 [00:00<?, ?it/s]

  saving...
 Episode 3


  0%|          | 0/695 [00:00<?, ?it/s]

  saving...
 Episode 4


  0%|          | 0/618 [00:00<?, ?it/s]

  saving...
 Episode 5


  0%|          | 0/1260 [00:00<?, ?it/s]

In [7]:
# Inspect the experiment log paths from index 50 onward (handy when picking a log by position).
logs[50:]

['./_data/logs\\MICE_20220606_2009_DMM3_21_05_SR4_experiment.json',
 './_data/logs\\MICE_20220607_1801_DMM3_21_05_JR4_experiment.json',
 './_data/logs\\MICE_20220607_1910_DMM4_21_05_SR3_experiment.json',
 './_data/logs\\MICE_20220608_1458_DMM2_21_05_SR4_experiment.json',
 './_data/logs\\MICE_20220608_1557_DMM3_21_05_SR5_experiment.json',
 './_data/logs\\MICE_20220608_1949_DMM4_21_05_JR3_experiment.json',
 './_data/logs\\MICE_20220608_2044_DMM1_21_05_JR3_experiment.json',
 './_data/logs\\MICE_20220609_1542_DMM4_21_05_SR4_experiment.json',
 './_data/logs\\MICE_20220609_1634_DMM3_21_05_JR5_experiment.json',
 './_data/logs\\MICE_20220609_1907_DMM1_21_05_SR4_experiment.json',
 './_data/logs\\MICE_20220609_2041_DMM2_21_05_JR4_experiment.json',
 './_data/logs\\MICE_20220610_1546_DMM1_21_05_JR4_experiment.json',
 './_data/logs\\MICE_20220610_1645_DMM4_21_05_JR4_experiment.json',
 './_data/logs\\MICE_20220613_1537_DMM2_21_05_SR4_experiment.json',
 './_data/logs\\MICE_20220613_1627_DMM4_21_05_SR

In [None]:
# Scatter the part locations of a single pose: body/tail parts as one series,
# with the head and nose parts highlighted separately.
# NOTE(review): `pose_null` is only defined once the null-distribution cell further
# down has run, so this cell fails on a fresh kernel when run top to bottom.
ppoints = []
npoint = []
hpoint = []
for p in pose_null:
    if 'nose' in p.part:
        npoint = [p.location.x, p.location.y]
    elif 'head' in p.part:
        hpoint = [p.location.x, p.location.y]
    else:
        ppoints.append([p.location.x, p.location.y])
# np.vstack raises on an empty list, so fall back to an empty (0, 2) array
ppoints = np.vstack(ppoints) if ppoints else np.empty((0, 2))

fig_pose, ax_pose = plt.subplots()
ax_pose.scatter(ppoints[:, 0], ppoints[:, 1], label='body/tail')
if hpoint:
    ax_pose.scatter(*hpoint, marker='s', label='head')
if npoint:
    ax_pose.scatter(*npoint, marker='^', label='nose')
ax_pose.set(xlabel='x', ylabel='y', title='pose parts')
ax_pose.legend();

## Build pose library
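The library has one row per prey step that carries pose data, with the following columns:

1. experiment name
2. episode
3. POSEx (x location of each body part)
4. POSEy (y location of each body part)
5. score (score of each body part)

The parsed pose objects themselves are also kept, in `POSE`.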

In [None]:
# Flatten every prey pose of one experiment into parallel per-step lists.
# Steps without pose data are skipped, so row r of every list refers to the same step.
e = Experiment.load_from_file(logs[4])
EXP, EPI, POSEx, POSEy, SCORE, POSE = [], [], [], [], [], []
for j, ep in enumerate(e.episodes):
    prey_steps = ep.trajectories.where('agent_name', 'prey').get_unique_steps()
    for step in prey_steps:
        if not step.data:
            continue
        pose = PoseList.parse(step.data)
        parts = [pose[i] for i in range(len(pose))]
        EXP.append(e.name)
        EPI.append(j)
        POSEx.append([part.location.x for part in parts])
        POSEy.append([part.location.y for part in parts])
        SCORE.append([part.score for part in parts])
        POSE.append(pose)

# Stack into arrays with one row per step; this assumes every pose has the same
# number of parts (np.vstack raises otherwise).
experiment = np.vstack(EXP)
episodes = np.vstack(EPI)
poseX = np.vstack(POSEx)
poseY = np.vstack(POSEy)
score = np.vstack(SCORE)

## Compute ITOR
True value: the ITOR of the observed pose at every prey step of the selected episode.

In [None]:
# True ITOR for every pose-bearing prey step of one episode, computed twice:
# with the full body + tail parts, and with the body midpoint only.
# (The previous version wrapped this in `for ep in e.episodes:` without using
# `ep`, recomputing the same episode once per episode.)
episode = 10
body_parts = ['body_mid','tail_base','tail_post_base','tail_pre_tip','tail_tip']

pt = e.episodes[episode].trajectories.where('agent_name','prey').get_unique_steps()
ITOR = []       # one value per step WITH pose data (data-less steps are skipped)
ITOR_body = []
for step in tqdm(pt):
    if not step.data:
        continue
    pose = PoseList.parse(step.data)

    # compute true ITOR
    I = compute_itor_pose(pose,
                          step.rotation,
                          vis_graph,
                          head_parts=['head_base'],
                          body_parts=body_parts)
    ITOR.append(I['ITOR'])

    # compute ITOR with just the body
    I = compute_itor_pose(pose,
                          step.rotation,
                          vis_graph,
                          head_parts=['head_base'],
                          body_parts=['body_mid'])
    ITOR_body.append(I['ITOR'])

In [None]:
%matplotlib inline

_ = plt.hist(ITOR,bins=np.linspace(0,1,30),alpha=0.5,label='full body')
_ = plt.hist(ITOR_body,bins=np.linspace(0,1,30),alpha=0.5,label='COM only')
plt.xlabel('ITOR')
plt.ylabel('count')
plt.legend()

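Null distribution: for each step, the ITOR of random poses drawn from other episodes, offset and rotated onto the observed pose.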

In [None]:
# Null distribution: for every pose-bearing prey step, draw k library poses from
# *other* episodes, align each to the observed pose with match_pose, and compute
# its ITOR. Results are cached per episode under ./_data/.
# NOTE(review): the cache file name has no experiment id -- if `e` is switched to
# another log, files written for the previous experiment are silently reused.
# NOTE(review): `choices` comes from a star import (presumably random.choices) and
# no seed is set, so a fresh (uncached) run draws different samples.
k = 500
body_parts = ['body_mid','tail_base','tail_post_base','tail_pre_tip','tail_tip']

for episode, ep in enumerate(e.episodes):
    filename = f'./_data/episode{episode}_{k}null.pkl'
    pt = ep.trajectories.where('agent_name','prey').get_unique_steps()
    print(f'episode{episode}/{len(e.episodes)}')
    if os.path.exists(filename):
        with open(filename, 'rb') as fh:
            [ITOR_null, poseI] = pickle.load(fh)
        continue

    # library rows belonging to every other episode (same for all steps of this episode)
    other_rows = np.where(episodes != episode)[0]
    ITOR_null = []  # per pose-bearing step: list of k null ITOR values
    poseI = []      # per pose-bearing step: the k library row indices that were drawn
    for step in tqdm(pt):
        if not step.data:
            continue
        # get the real pose
        pose0 = PoseList.parse(step.data).copy()

        # get null samples
        rand_sample = choices(other_rows, k=k)
        itor = []
        for kk in rand_sample:
            # offset and rotate null pose onto the real one
            pose_null, src_angle, src_loc, ref_angle, ref_loc = match_pose(pose0, POSE[kk].copy())
            I = compute_itor_pose(pose_null,
                                  ref_angle,
                                  vis_graph,
                                  head_parts=['head_base'],
                                  body_parts=body_parts)
            itor.append(I['ITOR'])
        ITOR_null.append(itor)
        poseI.append(list(rand_sample))

    with open(filename, 'wb') as fh:
        pickle.dump([ITOR_null, poseI], fh)



In [None]:
def pose2array(pose, score=0):
    '''Convert a pose to an (n, 2) array of [x, y] part locations.

    Only parts whose ``score`` is strictly greater than `score` are kept. When no
    part passes, an empty (0, 2) array is returned (np.vstack would raise on an
    empty list).
    '''
    xy = [[part.location.x, part.location.y] for part in pose if part.score > score]
    return np.vstack(xy) if xy else np.empty((0, 2))


def get_null_pose_dist(frame_list, pose_lib, itor_lib, pose0=None):
    '''Collect one frame's null poses as point arrays, each paired with its ITOR.

    Parameters
    ----------
    frame_list : sequence of int
        Indices into `pose_lib` of the null poses drawn for this frame.
    pose_lib : sequence of poses
        Pose library to draw from (e.g. ``POSE``).
    itor_lib : sequence of float
        ITOR of each null pose; ``itor_lib[i]`` belongs to ``frame_list[i]``.
    pose0 : pose, optional
        Observed pose. When given (and non-empty) each null pose is first aligned
        to it with ``match_pose``; otherwise library poses are used as stored.

    Returns
    -------
    pose_null : list of (n_i, 2) arrays, one per null pose
    itor_null : list of (n_i,) arrays -- that pose's ITOR repeated once per point
    '''
    pose_null = []
    itor_null = []
    for i, kk in enumerate(frame_list):
        if pose0:
            pose_norm, _, _, _, _ = match_pose(pose0, pose_lib[kk])
        else:
            pose_norm = pose_lib[kk]
        pa = pose2array(pose_norm)
        pose_null.append(pa)
        itor_null.append(np.repeat(itor_lib[i], pa.shape[0]))
    return pose_null, itor_null

In [None]:
# Pre-transform the null poses of the current `episode` for the animation below.
# ITOR_null / poseI have one entry per prey step WITH pose data (data-less steps
# are skipped when they are built), so iterate over those steps to keep indices
# aligned -- indexing by raw trajectory position misaligns after any gap.
# NOTE(review): `episode`, `ITOR_null` and `poseI` are whatever the null-distribution
# loop left behind (its last episode) -- confirm that is the intended episode.
pt = e.episodes[episode].trajectories.where('agent_name','prey').get_unique_steps()
pose_steps = [step for step in pt if step.data]
pnull = []
inull = []
for ind, step in enumerate(tqdm(pose_steps)):
    pose_null, itor_null = get_null_pose_dist(poseI[ind], POSE, ITOR_null[ind], PoseList.parse(step.data))
    pnull.append(pose_null)
    inull.append(itor_null)

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
import matplotlib.animation
import numpy as np

# colormap
cm = plt.get_cmap('jet')
cn = colors.Normalize(vmin=0,vmax=1)
sm = cmx.ScalarMappable(norm=cn, cmap=cm)

fig, ax = plt.subplots()
x, y = [],[]
d = Display(w, fig_size=(7,9), padding=0, cell_edge_color="lightgrey",fig=fig,ax=ax)
sc = ax.scatter(x,y,s=10,vmin=0,vmax=1)
plt.colorbar(sc,cmap=sm)
plt.xlim(0,1)
plt.ylim(0,1)

def animate(i):
    pose_null = np.vstack(pnull[i])
    itor_null = np.hstack(inull[i])[:,np.newaxis]
    si = np.argsort(itor_null,axis=0)
    x = pose_null[si,0]
    y = pose_null[si,1]
    sc.set_offsets(np.c_[x,y])
    sc.set_array(np.squeeze(itor_null[si]))
    
ani = matplotlib.animation.FuncAnimation(fig, animate, 
                frames=len(pnull), interval=30, repeat=True) 
plt.show()

In [None]:
# Inspect the ITOR values of the first animation frame, in the order they are
# drawn (ascending). The previous version used `si`, a variable local to animate(),
# which raised NameError at top level.
np.sort(np.hstack(inull[0]))

In [None]:
# Per-frame diagnostic figures for one episode, saved to ./_plots/episode{episode}/:
#   [0,0] arena: observed pose (white), the k aligned null poses coloured by their
#         ITOR, and the robot (red diamond) when it has a step at the same frame
#   [0,1] same, zoomed to +/-0.1 around the reference location from match_pose
#   [1,0] true ITOR across the episode, current frame marked
#   [1,1] histogram of the null ITORs with the true value and its percentile

# set episode, samples etc
k = 500
episode = 14
body_parts = ['body_mid','tail_base','tail_post_base','tail_pre_tip','tail_tip']
pt = e.episodes[episode].trajectories.where('agent_name','prey').get_unique_steps()
rt = e.episodes[episode].trajectories.where('agent_name','predator')
duration = len(pt) / 30 # convert to time

# The null cache and the true ITOR hold one entry per prey step WITH pose data,
# so frames are indexed through this filtered list to keep everything aligned.
pose_steps = [step for step in pt if step.data]

# Null samples for THIS episode (written by the null-distribution cell). The
# globals that loop leaves behind belong to its last episode, not necessarily this one.
with open(f'./_data/episode{episode}_{k}null.pkl', 'rb') as fh:
    [ITOR_null, poseI] = pickle.load(fh)

# True ITOR for THIS episode, with the same parts as the null distribution.
itor_true = [
    compute_itor_pose(PoseList.parse(step.data), step.rotation, vis_graph,
                      head_parts=['head_base'], body_parts=body_parts)['ITOR']
    for step in tqdm(pose_steps)
]

# Loop-invariant plotting state.
scalarMap = cmx.ScalarMappable(norm=colors.Normalize(vmin=0, vmax=1), cmap=plt.get_cmap('jet'))
robot_frames = np.array(rt.get('frame'))
plot_dir = f'./_plots/episode{episode}'
os.makedirs(plot_dir, exist_ok=True)


def _style_true_pose(h, scatter_size, ms1, ms2):
    """Outline the observed-pose artists returned by plot_pose in black."""
    plt.setp(h[0], edgecolor='k', sizes=[scatter_size])
    plt.setp(h[1], markeredgecolor='k', ms=ms1)
    plt.setp(h[2], markeredgecolor='k', ms=ms2)


def make_frame(ind):
    """Render and save the diagnostic figure for the ind-th pose-bearing step."""
    print(f'frame {ind} / {len(pose_steps)}')
    fig, ax = plt.subplots(2, 2, figsize=(10, 10))

    # copy original pose
    step = pose_steps[ind]
    pose0 = PoseList.parse(step.data).copy()

    # arena background on both top panels
    for a in (ax[0, 0], ax[0, 1]):
        Display(w, fig_size=(7,9), padding=0, cell_edge_color="lightgrey", fig=fig, ax=a)

    # aligned null poses, coloured by their ITOR (None -> NaN via float dtype)
    itor = np.asarray(ITOR_null[ind], dtype=float)
    ref_loc = None
    for kk, itor_k in zip(poseI[ind], itor):
        pose_norm, _, _, _, ref_loc = match_pose(pose0, POSE[kk].copy())
        cval = scalarMap.to_rgba(itor_k)
        plot_pose(pose_norm, ax=ax[0, 0], color=cval, alpha=0.1)
        plot_pose(pose_norm, ax=ax[0, 1], color=cval, alpha=0.1)

    # full arena view
    h = plot_pose(pose0, ax=ax[0, 0], color='w')
    ax[0, 0].axis('scaled')
    ax[0, 0].set_xlabel('x')
    ax[0, 0].set_ylabel('y')
    ax[0, 0].set_xlim([-0.1, 1.1])
    ax[0, 0].set_ylim([-0.1, 1.1])
    _style_true_pose(h, 15, 5, 3)

    # robot location at this frame, if recorded
    rind = np.flatnonzero(robot_frames == step.frame)
    if rind.size > 0:
        rloc = rt[rind[0]].location
        for a in (ax[0, 0], ax[0, 1]):
            a.plot(rloc.x, rloc.y, 'rD', markersize=10)

    # zoomed view around the reference location (skipped if no null sample was drawn)
    h = plot_pose(pose0, ax=ax[0, 1], color='w')
    ax[0, 1].axis('scaled')
    if ref_loc is not None:
        ax[0, 1].set_xlim([ref_loc.x - 0.1, ref_loc.x + 0.1])
        ax[0, 1].set_ylim([ref_loc.y - 0.1, ref_loc.y + 0.1])
    ax[0, 1].set_xlabel('x')
    ax[0, 1].set_ylabel('y')
    _style_true_pose(h, 50, 15, 12)
    fig.colorbar(scalarMap, ax=ax[0, 1])

    # true ITOR over the episode
    ax[1, 0].plot(itor_true, label='ITOR(t)')
    ax[1, 0].axvline(ind, color='r', label='true pose')
    ax[1, 0].set_xlabel('frame')
    ax[1, 0].set_ylabel('ITOR')
    ax[1, 0].set_ylim((0, 1))
    ax[1, 0].legend()
    ax[1, 0].set_title(f'frame {ind}')

    # null ITOR histogram vs. the true value
    true_val = itor_true[ind]
    pct = np.sum(itor <= true_val) / np.sum(~np.isnan(itor))
    ax[1, 1].hist(itor, bins=np.linspace(0, 1, 30), label='random pose')
    ax[1, 1].axvline(true_val, color='r', label='frame')
    ax[1, 1].set_ylabel('count')
    ax[1, 1].set_xlabel('ITOR')
    ax[1, 1].set_xlim([0, 1])
    ax[1, 1].set_ylim([0, k])  # a bin can hold at most all k null samples
    ax[1, 1].legend()
    ax[1, 1].set_title(f'percentile = {pct:.3f}')

    fig.savefig(os.path.join(plot_dir, f'frame_{ind:03d}.jpeg'))
    plt.close(fig)


start = 0
for i in tqdm(range(start, len(pose_steps))):
    make_frame(i)

In [None]:
def plot_itor_null(pt, frame, w, ax=None, fig=None):
    '''Draw the arena for one frame of the null-distribution plot.

    Work in progress (TODO): only the arena background is drawn so far; `pt` and
    `frame` are accepted but not used yet.

    Parameters
    ----------
    pt : prey steps of the episode (unused for now)
    frame : frame index to plot (unused for now)
    w : world passed to Display
    ax, fig : target axes and figure; a new single-axes figure is created when
        either is missing (None, or the empty list used by the old defaults).

    Returns
    -------
    (fig, ax) actually drawn on.
    '''
    def _missing(obj):
        # `not ax` is ambiguous for a numpy array of Axes (e.g. from
        # plt.subplots(2, 2)), so test explicitly for "not supplied".
        return obj is None or (isinstance(obj, list) and len(obj) == 0)

    if _missing(ax) or _missing(fig):
        fig, ax = plt.subplots(1, 1)

    # plot the display
    Display(w, fig_size=(7,9), padding=0, cell_edge_color="lightgrey", fig=fig, ax=ax)
    return fig, ax
    
    
  



In [None]:
# Align one null pose to its observed pose for a closer look below: the first
# prey step that carries pose data, and the first null sample drawn for it.
pose_inds = [i for i, s in enumerate(pt) if s.data]
assert pose_inds, 'no prey step with pose data in this episode'
ind = pose_inds[0]
step = pt[ind]
pose0 = PoseList.parse(step.data).copy()
# poseI[j] belongs to the j-th *pose-bearing* step (data-less steps are skipped when
# the null is built), so the first pose-bearing step is poseI[0], not poseI[ind].
pose1 = POSE[poseI[0][0]].copy()
pose_norm, src_angle, src_loc, ref_angle, ref_loc = match_pose(pose0, pose1)

In [None]:
def pose2array(pose, score=0):
    '''Convert a pose to an (n, 2) array of [x, y] part locations.

    Only parts whose ``score`` is strictly greater than `score` are kept; an empty
    (0, 2) array is returned when none pass. The default matches the earlier
    definition of this function, so re-running this cell no longer silently
    changes callers that rely on the default (e.g. get_null_pose_dist). Pass
    ``score=0.8`` explicitly for high-score parts only.
    '''
    xy = [[part.location.x, part.location.y] for part in pose if part.score > score]
    return np.vstack(xy) if xy else np.empty((0, 2))

In [None]:
# Plot the parts of the aligned null pose whose score exceeds 0.8, on their own axes.
# (The previous version appended to a global `h` that is never defined at top
# level -- make_frame's `h` is local -- so it raised NameError.)
pa = pose2array(pose_norm, score=0.8)
fig_pa, ax_pa = plt.subplots()
h = [ax_pa.scatter(pa[:, 0], pa[:, 1], 10)]
ax_pa.set(xlabel='x', ylabel='y', title='aligned null pose (score > 0.8)');


In [None]:
# Display the null pose after match_pose aligned it to the observed pose.
pose_norm