In [69]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import dill as pickle
import bisect

from matplotlib import animation, rc
rc('animation', html='html5')

import sys
sys.path.append('..')
from util import add_angles, angle_between, angled_vector, clip_angle, unit_vector, x_axis, get_rotation_matrix
import binning

l = [1,2,4,5]
l[bisect.bisect(l, 3)]

4

In [196]:
class History(object):
    def __init__(self, store_minimum=0.35):
        self.store_minimum = store_minimum
        
        self.data = [[],[]]
        self.time = [[], []]
        self.last_kick = [0, 0]
        
    def add(self, fish_id, time, data, is_kick=False):
        #assert(self.time[fish_id][-1] < time)           
        self.data[fish_id].append(data)
        self.time[fish_id].append(np.around(time, decimals=6))
        
        if is_kick:
            self.record_kick(fish_id, time)
            
        assert(len(self.data[fish_id]) == len(self.time[fish_id]))
                
    def record_kick(self, fish_id, time):
        """Sets last kick for fish id and resets buffer for other fish"""
        #time_kick_before = self.last_kick[fish_id]
        self.last_kick[fish_id] = self.time
        
        other_id = 1 - fish_id
        # Clean buffer for other fish if needed.
        if self.time[1-fish_id]:
            end = self.time[other_id][-1]
            start = self.time[other_id][0]
            elapsed = end - start
            
            if elapsed > 2 * self.store_minimum:
                print(f'Cleaning buffer for {other_id}')
                # Find idx where we can cutoff
                idx_time = bisect.bisect(self.time[other_id], end - self.store_minimum)
                self.data[other_id] = self.data[other_id][idx_time:]
                self.time[other_id] = self.time[other_id][idx_time:]
            
    def get(self, fish_id, time, truncate_at_kick=False):
        if not self.time[fish_id]:
            return []
        
        end = self.time[fish_id][-1]
        start = self.time[fish_id][0]
        elapsed = end - start
        
        idx = bisect.bisect(self.time[fish_id], end-time) - 1
        idx = max(0,idx) # Good idea?
        
        if truncate_at_kick:
            # TODO: Is this the correct fish for truncation?
            idx = max(idx, self.last_kick[1-fish_id])
        
        assert(self.data[fish_id])        
        return[{'time': t, **m} for t,m in zip(self.time[fish_id][idx:], self.data[fish_id][idx:])]


hist = History()
hist.data = d
hist.time = t

hist.get(fish_id=0, time=0.1), t[0][6:9]

([{'angle_f0': 0,
   'heading_change_f0': 0.21065706266065629,
   'length_f0': 4.457980772074707,
   'time': 4.9,
   'x_f0': 2,
   'y_f0': 2},
  {'angle_f0': 0,
   'heading_change_f0': 0.21065706266065629,
   'length_f0': 4.457980772074707,
   'time': 4.95,
   'x_f0': 2,
   'y_f0': 2},
  {'angle_f0': 0,
   'heading_change_f0': 0.21065706266065629,
   'length_f0': 4.457980772074707,
   'time': 5.0,
   'x_f0': 2,
   'y_f0': 2}],
 [4.9, 4.95, 5.0])

In [197]:
class Fish(object):
    def __init__(self, fish_id, kick_model, boundary_condition):
        self.fish_id = fish_id
        self.kick_model = kick_model
        self.boundary_condition = boundary_condition
        
        self.kick_position_start = np.array([5,5])
        self.kick_position_end = np.array([5,5])
        
        self.kick_time_start = 0
        self.kick_time_end = 0
        self.time = 0
        
    def kick(self, history):
        #print(f"Kick of at {self.time}")
        # TODO: Replace with proper kick model
        print(f'{self.fish_id}\t{self.time}\tKicked')
               
        kick_trajectory, kick_duration = self.kick_model(fish_id=self.fish_id,
                                               history=history.get(fish_id=self.fish_id,
                                                                  time=self.time))
        
        self.kick_time_start = self.time
        self.kick_time_end = self.time + kick_duration
        # Apply bc to kick start and NOT to kick end.
        # This way we can simply interpolate between start and end!
        self.kick_position_start = self.boundary_condition(self.kick_position_end)
        self.kick_position_end = self.kick_position_start + kick_trajectory

    def step(self, 
             dt,
             history,
             boundary_condition=lambda x:x):
        eps = 1e-6 # to compare floats
        is_kick = False
        
        end_time = self.time + dt
        while dt > eps:
            # Move until length of current kick
            cur_time = min(self.kick_time_end, self.time + dt)
            dt -= cur_time - self.time
            self.time = cur_time

            if dt > eps:
                # We need to kick off here.
                self.kick(history=history)
                
                # Make sure that we didn't kick off before:
                assert(not is_kick)
                is_kick = True                
                
        # Avoid float inprecision!
        # This makes sure that both fish have same time at all history points.
        self.time = end_time
        
        history.add(fish_id=self.fish_id,
                   time=self.time,
                   data=self.get_state(),
                   is_kick=is_kick)
        
        
    def get_pos(self):
        # TODO: Use exponential interpolation here?
        pos_dt = self.kick_position_end - self.kick_position_start
        kick_time = self.kick_time_end - self.kick_time_start
        elapsed_time = self.time - self.kick_time_start
        
        if kick_time == 0:
            weight = 0
        else:
            weight = elapsed_time/kick_time
        return self.boundary_condition((1-weight) * self.kick_position_start
                            + weight *self.kick_position_end)
    
    def get_state(self):
        position = self.get_pos()
        
        kick_trajectory = self.kick_position_end - self.kick_position_start
        # TODO: Double check headings, is wrong!
        heading_change = angle_between(x_axis, kick_trajectory)
        heading = angle_between(x_axis, self.kick_position_end)
        #print(heading_change,heading)
        kick_length = np.linalg.norm(kick_trajectory)
        return {f'heading_change_f{self.fish_id}': heading_change,
                f'length_f{self.fish_id}': kick_length,
                #'dt': 0,
                f'angle_f{self.fish_id}': 0,
                f'x_f{self.fish_id}': 2,
                f'y_f{self.fish_id}': 2}

class WorldState(object):
    def __init__(self, size, kick_model):
        self.size = size
        # Use periodic boundary conditions
        boundary_condition = lambda pos: np.mod(pos, self.size)
            
        self.time = 0
        self.time_dt = 0.01 # time step from experiment
        self.history_dt = 5 # frames between history snapshots
        self.history_size = 11 # keep n snapshots
        self.fish = [Fish(fish_id=0,
                        boundary_condition=boundary_condition,
                        kick_model=kick_model),
                     Fish(fish_id=1,
                        boundary_condition=boundary_condition,
                        kick_model=kick_model)]
        self.history = History()
    
    def step(self, dt=None):
        if dt is None:
            dt = self.time_dt
        
        # Pass the history of the other fish to each fish
        for fish in self.fish:
            has_kicked = fish.step(dt, history=self.history)
            
        # Round time, not used for simulation anyway
        self.time = np.around(self.time + dt, decimals=4)
        print(f"Time: {self.time}")
    
    def get_pos(self):
        return np.array([self.fish[0].get_pos(), 
                        self.fish[1].get_pos()])
    
    
class SocialModel(object):
    def __init__(self, wall_model=None):
        self.history_params = {'history_size': 11,
                              'history_dt': 0.05,
                              'truncate_at_kick': False}
    
    def __call__(self, fish_id, history):
        # Wrong!
        if history:
            state = history[-1]
            current_angle = state[f'angle_f{fish_id}']
        else:
            current_angle = 0.0
        
        # TODO: Predict new trajectory in local coordinate system
        # (where fish is at (0,0) and has angle 0        
        kick_duration = np.random.random(1) + 0.16
        kick_length = np.random.random() * 5 #cm
        kick_trajectory = (angled_vector(np.deg2rad(np.random.random()*20)) *
                           kick_length)
        
        # Rotate vector
        rotation_matrix = get_rotation_matrix(current_angle)
        kick_trajectory = rotation_matrix @ kick_trajectory
        
        return kick_trajectory, kick_duration
    
world = WorldState(size=np.array([10,10]), kick_model=SocialModel(wall_model=None))

for i in range(100):
    world.step(0.05)

0	0	Kicked
1	0	Kicked
Time: 0.05
Time: 0.1
Time: 0.15
Time: 0.2
Time: 0.25
Time: 0.3
Time: 0.35
Time: 0.4
Time: 0.45
Time: 0.5
Time: 0.55
Time: 0.6
Time: 0.65
1	[0.67945408]	Kicked
Time: 0.7
Time: 0.75
Time: 0.8
Time: 0.85
0	[0.89736832]	Kicked
Cleaning buffer for 1
Time: 0.9
Time: 0.95
Time: 1.0
Time: 1.05
Time: 1.1
Time: 1.15
Time: 1.2
Time: 1.25
Time: 1.3
Time: 1.35
Time: 1.4
Time: 1.45
1	[1.47647615]	Kicked
Cleaning buffer for 0
Time: 1.5
Time: 1.55
Time: 1.6
Time: 1.65
Time: 1.7
Time: 1.75
Time: 1.8
Time: 1.85
Time: 1.9
Time: 1.95
Time: 2.0
0	[2.00249402]	Kicked
Cleaning buffer for 1
Time: 2.05
1	[2.05450969]	Kicked
Cleaning buffer for 0
Time: 2.1
Time: 2.15
Time: 2.2
Time: 2.25
Time: 2.3
0	[2.33530561]	Kicked
Time: 2.35
Time: 2.4
Time: 2.45
Time: 2.5
Time: 2.55
Time: 2.6
Time: 2.65
Time: 2.7
Time: 2.75
1	[2.75650932]	Kicked
Cleaning buffer for 0
Time: 2.8
Time: 2.85
Time: 2.9
Time: 2.95
Time: 3.0
Time: 3.05
Time: 3.1
0	[3.12998934]	Kicked
Cleaning buffer for 1
Time: 3.15
Time: 3.

In [222]:
history = world.history.data
time = [{'time': t} for t in world.history.time[0]]
history = [{**t, **a, **b} for a,b,t in zip(history[:][0], history[:][1], time)]
df = pd.DataFrame(history, index=None)

def prepare_history(df, backward, step=0.05):
    def interpolate_col(data, time, new_time):
        # Piece-wise linear interpolation.
        interpolated = np.interp(x=new_time, fp=data, xp=time)

        return pd.Series(interpolated)

    # Save columns. Otherwise we loose information about dropped frames!
    time = np.copy(df['time'].values)
    end = time.max()
    begin = max(time.min(), end - backward)
    print(end, begin)
   # 
    # Make sure that we do not have to extrapolate.
   # print(begin, end, time[0], time[-1])
    
    print(np.ceil(end-begin)/step)
    new_time = np.linspace(start=begin, stop=end, num=int(np.ceil((end-begin)/step) + 1))
    
    # Make sure time is increasing
    assert(np.all(np.diff(time) > 0.0))
    assert(np.all(np.diff(new_time) > 0.0))
    
    # Resample invalid entries.
    df = df.apply((lambda col: interpolate_col(col, time, new_time)), axis=0)
    
    # Restore time column.
    df.loc[:, 'time'] = new_time
    
    # Convert time column to dt
    print(new_time - begin)
    df.loc[:, 'dt'] = pd.Series(np.uint((new_time[::-1] - begin)*100), index=df.index)
    
    return df

prepare_history(df, backward=0.1, step=0.1)


5.0 4.9
10.0
[0.  0.1]


Unnamed: 0,angle_f0,angle_f1,heading_change_f0,heading_change_f1,length_f0,length_f1,time,x_f0,x_f1,y_f0,y_f1,dt
0,0.0,0.0,0.083152,0.167469,2.00208,3.386704,4.9,2.0,2.0,2.0,2.0,9
1,0.0,0.0,0.083152,0.167469,2.00208,3.386704,5.0,2.0,2.0,2.0,2.0,0


In [179]:
from itertools import product
with open('../adaptive_bins.model', 'rb') as f:
    edges = pickle.load(f)

#binning.transform_coords_df(df, cutoff_wall_range=0)


In [190]:
world.history.get(fish_id=0, time=0.15)

[{'angle_f0': 0,
  'heading_change_f0': 0.21065706266065629,
  'length_f0': 4.457980772074707,
  'time': 4.9,
  'x_f0': 2,
  'y_f0': 2},
 {'angle_f0': 0,
  'heading_change_f0': 0.21065706266065629,
  'length_f0': 4.457980772074707,
  'time': 4.95,
  'x_f0': 2,
  'y_f0': 2},
 {'angle_f0': 0,
  'heading_change_f0': 0.21065706266065629,
  'length_f0': 4.457980772074707,
  'time': 5.0,
  'x_f0': 2,
  'y_f0': 2}]

In [224]:
animation_dt

0.2

In [225]:
def add_to_buffer(buffer, value):
    buffer_local = np.roll(buffer, shift=-1)
    buffer_local[-1] = value
    np.copyto(dst=buffer, src=buffer_local)
    return buffer

fig = plt.figure(figsize=(10,10))
ax = plt.axes(xlim=(0, world.size[0]), ylim=(0, world.size[1]))
plt.close(fig)

lines = [None] * 2
lines[0], = ax.plot([], [], c='red', linewidth=5, label='fish 1')
lines[1], = ax.plot([], [], c='green', linewidth=5, label='fish 2')

ax.legend(loc='upper right')

# Set up animation buffers.
visible_steps = 10

# Shape of buffer: fish_id, coord, step
animation_buffer = np.zeros((2,2,visible_steps)) + (world.size/2)[0]

animation_frames = 1000
animation_interval = 10
animation_dt = animation_interval/animation_frames

def init():
    for line in lines:
        line.set_data([], [])
    return lines[0], lines[1]

def animate(i):
    world.step(animation_dt)
    cur_positions = world.get_pos()
    #print(cur_positions)
    for fish_id, position in enumerate(cur_positions):
        # Update animation buffers
        add_to_buffer(animation_buffer[fish_id, 0], position[0])
        add_to_buffer(animation_buffer[fish_id, 1], position[1])
        
        # Update graphic
        lines[fish_id].set_data(animation_buffer[fish_id, 0],
                                animation_buffer[fish_id, 1])

    return lines[0], lines[1], 

anim = animation.FuncAnimation(fig,
                               animate,
                               init_func=init,
                               frames=animation_frames,
                               interval=animation_interval,
                               blit=True)
anim

Time: 70.41
Time: 70.42
Time: 70.43
Time: 70.44
Time: 70.45
Time: 70.46
Time: 70.47
Time: 70.48
Time: 70.49
Time: 70.5
Time: 70.51
Time: 70.52
Time: 70.53
Time: 70.54
Time: 70.55
Time: 70.56
Time: 70.57
Time: 70.58
0	[70.77173675]	Kicked
Cleaning buffer for 1
Time: 70.59
Time: 70.6
Time: 70.61
Time: 70.62
Time: 70.63
Time: 70.64
Time: 70.65
Time: 70.66
Time: 70.67
Time: 70.68
Time: 70.69
1	[70.69477891]	Kicked
Cleaning buffer for 0
Time: 70.7
Time: 70.71
Time: 70.72
Time: 70.73
Time: 70.74
Time: 70.75
Time: 70.76
Time: 70.77
Time: 70.78
Time: 70.79
Time: 70.8
Time: 70.81
Time: 70.82
Time: 70.83
Time: 70.84
Time: 70.85
Time: 70.86
Time: 70.87
Time: 70.88
Time: 70.89
Time: 70.9
Time: 70.91
Time: 70.92
Time: 70.93
Time: 70.94
Time: 70.95
Time: 70.96
Time: 70.97
Time: 70.98
Time: 70.99
Time: 71.0
Time: 71.01
Time: 71.02
Time: 71.03
Time: 71.04
Time: 71.05
Time: 71.06
Time: 71.07
Time: 71.08
Time: 71.09
Time: 71.1
Time: 71.11
Time: 71.12
Time: 71.13
Time: 71.14
Time: 71.15
Time: 71.16
Time:

Time: 76.72
Time: 76.73
0	[76.92216067]	Kicked
Cleaning buffer for 1
Time: 76.74
Time: 76.75
Time: 76.76
Time: 76.77
Time: 76.78
Time: 76.79
Time: 76.8
Time: 76.81
Time: 76.82
Time: 76.83
Time: 76.84
Time: 76.85
Time: 76.86
Time: 76.87
Time: 76.88
Time: 76.89
Time: 76.9
Time: 76.91
Time: 76.92
Time: 76.93
Time: 76.94
Time: 76.95
Time: 76.96
Time: 76.97
0	[77.15610159]	Kicked
Time: 76.98
Time: 76.99
Time: 77.0
Time: 77.01
Time: 77.02
Time: 77.03
Time: 77.04
Time: 77.05
Time: 77.06
Time: 77.07
Time: 77.08
Time: 77.09
Time: 77.1
Time: 77.11
Time: 77.12
Time: 77.13
Time: 77.14
Time: 77.15
Time: 77.16
Time: 77.17
Time: 77.18
Time: 77.19
Time: 77.2
Time: 77.21
Time: 77.22
Time: 77.23
Time: 77.24
Time: 77.25
Time: 77.26
Time: 77.27
Time: 77.28
Time: 77.29
Time: 77.3
Time: 77.31
Time: 77.32
Time: 77.33
Time: 77.34
Time: 77.35
Time: 77.36
Time: 77.37
Time: 77.38
1	[77.38170378]	Kicked
Cleaning buffer for 0
Time: 77.39
Time: 77.4
Time: 77.41
Time: 77.42
Time: 77.43
Time: 77.44
Time: 77.45
Time: 

In [226]:
world.history.time

[[array([79.55375]),
  array([79.56375]),
  array([79.57375]),
  array([79.58375]),
  array([79.59375]),
  array([79.60375]),
  array([79.61375]),
  array([79.62375]),
  array([79.63375]),
  array([79.64375]),
  array([79.65375]),
  array([79.66375]),
  array([79.67375]),
  array([79.68375]),
  array([79.69375]),
  array([79.70375]),
  array([79.71375]),
  array([79.72375]),
  array([79.73375]),
  array([79.74375]),
  array([79.75375]),
  array([79.76375]),
  array([79.77375]),
  array([79.78375]),
  array([79.79375]),
  array([79.80375]),
  array([79.81375]),
  array([79.82375]),
  array([79.83375]),
  array([79.84375]),
  array([79.85375]),
  array([79.86375]),
  array([79.87375]),
  array([79.88375]),
  array([79.89375]),
  array([79.90375]),
  array([79.91375]),
  array([79.92375]),
  array([79.93375]),
  array([79.94375]),
  array([79.95375]),
  array([79.96375]),
  array([79.97375]),
  array([79.98375]),
  array([79.99375]),
  array([80.00375]),
  array([80.01375]),
  array([80.0