In [1]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import dill as pickle
import bisect
import torch
import re
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

from matplotlib import animation, rc
rc('animation', html='html5')

import sys
sys.path.append('..')
from util import add_angles, angle_between, angled_vector, clip_angle, unit_vector, x_axis, get_rotation_matrix
import binning
import data
from plots import *
from social_models import load_social_model



cuda:0
cuda:0


In [2]:
X_train, y_train, X_test, y_test = [d.to(device) for d in data.get_data('../../data/processed/')]

In [3]:
X_train.shape, X_test.shape, y_train.shape


(torch.Size([8, 19427, 192]),
 torch.Size([8, 4436, 192]),
 torch.Size([19427, 2]))

In [4]:
def multiple_replace(string, replacements):
    """
    Given a string and a replacement map, it returns the replaced string.
    :param str string: string to execute replacements on
    :param dict replacements: replacement dictionary {value to find: value to replace}
    :rtype: str
    
    Modified from: https://gist.github.com/bgusach/a967e0587d6e01e889fd1d776c5f3729
    """
    substrs = sorted(replacements, key=len, reverse=True)

    # Create a big OR regex that matches any of the substrings to replace
    regexp = re.compile('|'.join(map(re.escape, substrs)))

    # For each match, look up the new string in the replacements
    return regexp.sub(lambda match: replacements[match.group(0)], string)


def interpolate_time(df, required_timesteps):
    def interpolate_col(data, time, new_time):
        # Piece-wise linear interpolation.
        interpolated = np.interp(x=new_time, fp=data, xp=time)

        return pd.Series(interpolated)

    # Save columns. Otherwise we loose information about dropped frames!
    time = np.copy(df['time'].values)

    new_time = time.max() - required_timesteps
    print(time, '\n', new_time)
    
    # Make sure time is increasing
    assert(np.all(np.diff(time) > 0.0))
    
    # Resample entries.
    df = df.apply((lambda col: interpolate_col(col, time, new_time)), axis=0)
    
    # Restore time column.
    df.loc[:, 'time'] = new_time
       
    # Convert time column to dt
    df.loc[:, 'dt'] = pd.Series(np.uint((required_timesteps)*100), index=df.index)
    
    return df

class History(object):
    def __init__(self, store_minimum=0.5):
        self.store_minimum = store_minimum
        
        self.data = [[],[]]
        self.time = [[], []]
        self.last_kick = [0, 0]
        
    def add(self, fish_id, time, data, is_kick=False):
        #assert(self.time[fish_id][-1] < time)           
        self.data[fish_id].append(data)
        self.time[fish_id].append(np.around(time, decimals=6))
        
        if is_kick:
            self.record_kick(fish_id, time)
            
        assert(len(self.data[fish_id]) == len(self.time[fish_id]))
                
    def record_kick(self, fish_id, time):
        """Sets last kick for fish id and resets buffer for other fish"""
        #time_kick_before = self.last_kick[fish_id]
        self.last_kick[fish_id] = self.time
        
        other_id = 1 - fish_id
        # Clean buffer for other fish if needed.
        if self.time[1-fish_id]:
            end = self.time[other_id][-1]
            start = self.time[other_id][0]
            elapsed = end - start
            
            if elapsed > 2 * self.store_minimum:
                print(f'Cleaning buffer for {other_id}')
                # Find idx where we can cutoff
                idx_time = bisect.bisect(self.time[other_id], end - self.store_minimum)
                print(len(self.data[other_id]))
                self.data[other_id] = self.data[other_id][idx_time:]
                print(len(self.data[other_id]))
                self.time[other_id] = self.time[other_id][idx_time:]

    def get(self, required_timesteps, edges, fish_id):
        time = [{'time': t} for t in self.time[0]]
        history = [{**t, **a, **b} for a,b,t in zip(self.data[:][0], self.data[:][1], time)]
        df = pd.DataFrame(history, index=None)
        
        assert(df.size > 0 )
        # For all later usages, f0 is assumed to be the focal fish.
        # -> If fish_id is different, exchange both.
        #print(df)
        if fish_id != 'f0':
            replacements = {'f0': 'f1', 'f1':'f0'}
            df = df.rename(columns=lambda col: multiple_replace(col, replacements))

        # Get required timesteps
        df = interpolate_time(df=df, 
                              required_timesteps=required_timesteps)
        # Convert to local coordinate system:
        # Not used for predictions
        df.loc[:, 'heading_change'] = pd.Series(float('nan'), index=df.index)
        df.loc[:, 'length'] = pd.Series(float('nan'), index=df.index)

        df = binning.transform_coords_df(df, cutoff_wall_range=None)
        #df = df.drop(columns=['trajectory_f0_x', 'trajectory_f0_y'])

        # Discretize
        df = binning.get_bins_df(df=df, **edges)

        # Normalize
        # Means, stds are from training set
        means, stds = (0.6348294576188627, 0.0012032019097761566), (0.5596922160565242, 0.5326711711690991)

        # Values of y are NaN (they are not needed!)
        X, y = binning.get_Xy(df=df,
                  num_bins=edges['num_bins'],
                  means=means,
                  stds=stds)
        rf_df = binning.Xy_to_df(X, y)

        # Reshape according to our conventions
        X, y = data.process_df(rf_df)

        return X.to(device)
    
with open('../../models/adaptive_bins.model', 'rb') as f:
    edges = pickle.load(f)
    
with open('../../models/kick_duration.model', 'rb') as f:
    kick_duration = pickle.load(f)   

In [5]:
class Fish(object):
    def __init__(self, fish_id, kick_model, boundary_condition):
        self.fish_id = fish_id
        self.kick_model = kick_model
        self.boundary_condition = boundary_condition
        
        self.kick_position_start = np.array([25,25]) + np.random.random(size=2) * 4
        self.kick_position_end = self.kick_position_start + np.array([0.5, 0])
        
        self.kick_time_start = 0
        self.kick_time_end = 0
        self.time = 0
        
    def kick(self, history):
        print(f'{self.fish_id}\t{self.time}\tKicked')
               
        kick_trajectory, kick_duration = self.kick_model(fish_id=self.fish_id,
                       history=history,
                       state=self.get_state())
        
        self.kick_time_start = self.time
        self.kick_time_end = self.time + kick_duration
        # Apply bc to kick start and NOT to kick end.
        # This way we can simply interpolate between start and end!
        self.kick_position_start = self.boundary_condition(self.kick_position_end)
        self.kick_position_end = self.kick_position_start + kick_trajectory

    def step(self, 
             dt,
             history,
             boundary_condition=lambda x:x):
        eps = 1e-6 # to compare floats
        is_kick = False
        
        end_time = self.time + dt
        while dt > eps:
            # Move until length of current kick
            cur_time = min(self.kick_time_end, self.time + dt)
            dt -= cur_time - self.time
            self.time = cur_time

            if dt > eps:
                # We need to kick off here.
                self.kick(history=history)
                
                # Make sure that we didn't kick off before:
                assert(not is_kick)
                is_kick = True                
                
        # Avoid float inprecision!
        # This makes sure that both fish have same time at all history points.
        self.time = end_time
        
        history.add(fish_id=self.fish_id,
                   time=self.time,
                   data=self.get_state(),
                   is_kick=is_kick)        
        
    def get_pos(self):
        # Linear interpolation between end and start
        pos_dt = self.kick_position_end - self.kick_position_start
        kick_time = self.kick_time_end - self.kick_time_start
        elapsed_time = self.time - self.kick_time_start
        
        if kick_time == 0:
            weight = 0
        else:
            weight = elapsed_time/kick_time
        return self.boundary_condition((1-weight) * self.kick_position_start
                            + weight *self.kick_position_end)
    
    def get_state(self):
        x, y = self.get_pos()
        
        kick_trajectory = self.kick_position_end - self.kick_position_start
        # TODO: Double check headings, is wrong!
        heading_change = angle_between(x_axis, kick_trajectory)
        heading = angle_between(x_axis, kick_trajectory)

        kick_length = np.linalg.norm(kick_trajectory)
        return {f'heading_change_f{self.fish_id}': heading_change,
                f'length_f{self.fish_id}': kick_length,
                f'angle_f{self.fish_id}': heading,
                f'x_f{self.fish_id}': x,
                f'y_f{self.fish_id}': y}

class WorldState(object):
    def __init__(self, size, kick_model):
        self.size = size
        # Use periodic boundary conditions
        boundary_condition = lambda pos: np.mod(pos, self.size)
            
        self.time = 0
        self.time_dt = 0.01 # time step from experiment
        self.history_dt = 5 # frames between history snapshots
        self.history_size = 11 # keep n snapshots
        self.fish = [Fish(fish_id=0,
                        boundary_condition=boundary_condition,
                        kick_model=kick_model),
                     Fish(fish_id=1,
                        boundary_condition=boundary_condition,
                        kick_model=kick_model)]
        self.history = History()
        
        # Init history with initial pos.
        # otherwise df is zero for beginning!
        for i,f in enumerate(self.fish):
            self.history.add(fish_id=i,
                        time=self.time,
                        data=f.get_state())
    
    def step(self, dt=None):
        if dt is None:
            dt = self.time_dt
        
        # Pass the history of the other fish to each fish
        for fish in self.fish:
            has_kicked = fish.step(dt, history=self.history)
            
        # Round time, not used for simulation anyway
        self.time = np.around(self.time + dt, decimals=4)
        print(f"Time: {self.time}")
    
    def get_pos(self):
        return np.array([self.fish[0].get_pos(), 
                        self.fish[1].get_pos()])  
   

class KickModel(object):
    def __init__(self, social_model, duration_model, wall_model=None):
        self.social_model = social_model
        self.wall_model = wall_model
        self.duration_model = duration_model
        assert(self.wall_model is None) # not supported
    
    # todo: pass edges
    def __call__(self, fish_id, state, history,):
        receptive_field = history.get(
            required_timesteps=self.social_model.get_required_timesteps()/100,
            edges=edges,
            fish_id=fish_id)
        current_angle = state[f'angle_f{fish_id}']
        
        kick_trajectory = self.social_model.sample(receptive_field).reshape(-1)
        
        # TODO: Predict new trajectory in local coordinate system (???)
        # (where fish is at (0,0) and has angle 0        
        kick_duration = self.duration_model.sample()[0][0,0]
        #kick_length = np.random.random() * 5 #cm
        #kick_trajectory = (angled_vector(np.deg2rad(np.random.random()*20)) *
         #                  kick_length)
        
        # Rotate vector - prediction was in local rf coordinate system
        rotation_matrix = get_rotation_matrix(current_angle)
        print(np.rad2deg(current_angle))
        print(kick_trajectory)
        kick_trajectory = rotation_matrix @ kick_trajectory
        print(kick_trajectory)
        
        return kick_trajectory, kick_duration
    
    
    def get_required_timesteps(self):
        return social_model.get_required_timesteps()
  

social_model = load_social_model('../../models/rnn_mdn.pt', X_train, y_train)
kick_model = KickModel(social_model=social_model,
                      duration_model=kick_duration)
print(social_model.get_required_timesteps())

world = WorldState(size=np.array([50,50]), 
                   kick_model=kick_model)

<class 'int'>
{'no_memory': array([0]), 'memory': array([ 0.        ,  5.71428571, 11.42857143, 17.14285714, 22.85714286,
       28.57142857, 34.28571429, 40.        ])}
[ 0.  5. 10. 15. 20. 25. 30. 35.]


In [6]:
def add_to_buffer(buffer, value):
    buffer_local = np.roll(buffer, shift=-1)
    buffer_local[-1] = value
    np.copyto(dst=buffer, src=buffer_local)
    return buffer

fig = plt.figure(figsize=(10,10))
ax = plt.axes(xlim=(0, world.size[0]), ylim=(0, world.size[1]))
plt.close(fig)

lines = [None] * 2
lines[0], = ax.plot([], [], c='red', linewidth=5, label='fish 1')
lines[1], = ax.plot([], [], c='gray', linewidth=5, label='fish 2')

ax.legend(loc='upper right')

# Set up animation buffers.
visible_steps = 20

# Shape of buffer: fish_id, coord, step
animation_buffer = np.ones((2,2,visible_steps))

# Init buffer with starting position of fish
animation_buffer[0,0,:] *= world.get_pos()[0][0]
animation_buffer[0,1,:] *= world.get_pos()[0][1]
animation_buffer[1,0,:] += world.get_pos()[1][0]
animation_buffer[1,1,:] *= world.get_pos()[1][1]

animation_frames = 2000//2
animation_interval = 40//2
animation_dt = animation_interval/animation_frames
print(animation_dt)

def init():
    for line in lines:
        line.set_data([], [])
    return lines[0], lines[1]

def animate(i):
    world.step(animation_dt)
    cur_positions = world.get_pos()
    #print(cur_positions)
    
    for i, f in enumerate(world.fish):
        break
        state = f.get_state()
        x, y = state[f'x_f{i}'], state[f'x_f{i}']
        angle = state[f'angle_f{i}']
        traj_x, traj_y = angled_vector(angle)*2
        ax.arrow(x, y, traj_x, traj_y, width=0.2, color='gray')
        ax.scatter(x,y, s=300, marker='x', linewidth=2)
    
    for fish_id, position in enumerate(cur_positions):
        # Update animation buffers
        add_to_buffer(animation_buffer[fish_id, 0], position[0])
        add_to_buffer(animation_buffer[fish_id, 1], position[1])
        
        data_0 = animation_buffer[fish_id, 0]
        data_1 = animation_buffer[fish_id, 1]
        
        # Don't draw data before 'jump' due to boundary cond!
        cond = (np.abs(data_0 - data_0[-1]) < 2) & (np.abs(data_1 - data_1[-1]) < 2)
        data_0 = data_0[cond]
        data_1 = data_1[cond]
        
        # Update graphic
        lines[fish_id].set_data(data_0,
                                data_1)
        
        #lines[fish_id].set_data(animation_buffer[fish_id, 0],
        #                        animation_buffer[fish_id, 1])

    return lines[0], lines[1], 

anim = animation.FuncAnimation(fig,
                               animate,
                               init_func=init,
                               frames=animation_frames,
                               interval=animation_interval,
                               blit=True,)

#anim

0.02


In [7]:
anim.save('../../figures/social_anim.mp4', dpi=92*2, fps=animation_frames/20)

0	0	Kicked
[0] 
 [ 0.   -0.05 -0.1  -0.15 -0.2  -0.25 -0.3  -0.35]
0.0
[1.79623966 0.80131481]
[1.79623966 0.80131481]
1	0	Kicked
[0] 
 [ 0.   -0.05 -0.1  -0.15 -0.2  -0.25 -0.3  -0.35]
0.0
[1.07311928 0.93760591]
[1.07311928 0.93760591]
Time: 0.02
Time: 0.04
Time: 0.06
Time: 0.08
Time: 0.1
Time: 0.12
Time: 0.14
Time: 0.16
Time: 0.18
Time: 0.2
Time: 0.22
Time: 0.24
Time: 0.26
Time: 0.28
0	0.29411522106870347	Kicked
[0.   0.02 0.04 0.06 0.08 0.1  0.12 0.14 0.16 0.18 0.2  0.22 0.24 0.26
 0.28] 
 [ 0.28  0.23  0.18  0.13  0.08  0.03 -0.02 -0.07]
24.041977330445444
[ 1.57622592 -0.35656973]
[1.58475253 0.31652732]
Time: 0.3
Time: 0.32
1	0.32008255331912755	Kicked
[0.   0.02 0.04 0.06 0.08 0.1  0.12 0.14 0.16 0.18 0.2  0.22 0.24 0.26
 0.28 0.3  0.32] 
 [ 0.32  0.27  0.22  0.17  0.12  0.07  0.02 -0.03]
41.144365998490336
[ 1.74274007 -0.17614278]
[1.42827227 1.01400569]
Time: 0.34
Time: 0.36
Time: 0.38
Time: 0.4
Time: 0.42
Time: 0.44
Time: 0.46
Time: 0.48
Time: 0.5
0	0.5170310529347135	Kicke

Time: 3.74
Time: 3.76
Time: 3.78
Time: 3.8
Time: 3.82
Time: 3.84
Time: 3.86
Time: 3.88
Time: 3.9
Time: 3.92
1	3.920096101279506	Kicked
[2.92 2.94 2.96 2.98 3.   3.02 3.04 3.06 3.08 3.1  3.12 3.14 3.16 3.18
 3.2  3.22 3.24 3.26 3.28 3.3  3.32 3.34 3.36 3.38 3.4  3.42 3.44 3.46
 3.48 3.5  3.52 3.54 3.56 3.58 3.6  3.62 3.64] 
 [3.64 3.59 3.54 3.49 3.44 3.39 3.34 3.29]
-90.91925514872884
[ 0.51000806 -2.3220817 ]
[-2.32996508 -0.47268847]
Cleaning buffer for 0
52
25
Time: 3.94
Time: 3.96
Time: 3.98
Time: 4.0
Time: 4.02
Time: 4.04
Time: 4.06
Time: 4.08
0	4.080171331644991	Kicked
[3.46 3.48 3.5  3.52 3.54 3.56 3.58 3.6  3.62 3.64 3.66 3.68 3.7  3.72
 3.74 3.76 3.78 3.8  3.82 3.84 3.86 3.88 3.9  3.92 3.94 3.96 3.98 4.
 4.02 4.04 4.06 4.08] 
 [4.08 4.03 3.98 3.93 3.88 3.83 3.78 3.73]
163.467808946815
[-0.0920471  -1.53410631]
[0.52477797 1.444494  ]
Time: 4.1
Time: 4.12
Time: 4.14
Time: 4.16
Time: 4.18
Time: 4.2
Time: 4.22
Time: 4.24
Time: 4.26
Time: 4.28
Time: 4.3
Time: 4.32
Time: 4.34
Time: 

Time: 7.76
Time: 7.78
Time: 7.8
1	7.803315069726653	Kicked
[7.16 7.18 7.2  7.22 7.24 7.26 7.28 7.3  7.32 7.34 7.36 7.38 7.4  7.42
 7.44 7.46 7.48 7.5  7.52 7.54 7.56 7.58 7.6  7.62 7.64 7.66 7.68 7.7
 7.72 7.74 7.76 7.78 7.8  7.82] 
 [7.82 7.77 7.72 7.67 7.62 7.57 7.52 7.47]
-11.45991964196846
[ 0.1363506  -1.30187662]
[-0.12502766 -1.30301278]
Time: 7.82
Time: 7.84
Time: 7.86
Time: 7.88
Time: 7.9
Time: 7.92
Time: 7.94
Time: 7.96
Time: 7.98
Time: 8.0
0	8.013271633138203	Kicked
[7.16 7.18 7.2  7.22 7.24 7.26 7.28 7.3  7.32 7.34 7.36 7.38 7.4  7.42
 7.44 7.46 7.48 7.5  7.52 7.54 7.56 7.58 7.6  7.62 7.64 7.66 7.68 7.7
 7.72 7.74 7.76 7.78 7.8  7.82 7.84 7.86 7.88 7.9  7.92 7.94 7.96 7.98
 8.  ] 
 [8.   7.95 7.9  7.85 7.8  7.75 7.7  7.65]
-101.41807599500285
[ 0.74294295 -0.32347148]
[-0.46414747 -0.66420264]
Cleaning buffer for 1
64
25
Time: 8.02
Time: 8.04
Time: 8.06
Time: 8.08
Time: 8.1
Time: 8.12
Time: 8.14
Time: 8.16
Time: 8.18
Time: 8.2
Time: 8.22
Time: 8.24
Time: 8.26
0	8.2643654352

Time: 11.42
Time: 11.44
Time: 11.46
Time: 11.48
Time: 11.5
Time: 11.52
0	11.533248953190913	Kicked
[10.9  10.92 10.94 10.96 10.98 11.   11.02 11.04 11.06 11.08 11.1  11.12
 11.14 11.16 11.18 11.2  11.22 11.24 11.26 11.28 11.3  11.32 11.34 11.36
 11.38 11.4  11.42 11.44 11.46 11.48 11.5  11.52] 
 [11.52 11.47 11.42 11.37 11.32 11.27 11.22 11.17]
-68.53657228756605
[ 1.03945047 -0.05772951]
[ 0.32661642 -0.98848959]
Time: 11.54
Time: 11.56
1	11.577759164221495	Kicked
[10.9  10.92 10.94 10.96 10.98 11.   11.02 11.04 11.06 11.08 11.1  11.12
 11.14 11.16 11.18 11.2  11.22 11.24 11.26 11.28 11.3  11.32 11.34 11.36
 11.38 11.4  11.42 11.44 11.46 11.48 11.5  11.52 11.54 11.56 11.58] 
 [11.58 11.53 11.48 11.43 11.38 11.33 11.28 11.23]
-84.3239716480681
[1.19585275 0.75076572]
[ 0.86535868 -1.11573623]
Time: 11.58
Time: 11.6
Time: 11.62
Time: 11.64
Time: 11.66
Time: 11.68
Time: 11.7
Time: 11.72
Time: 11.74
Time: 11.76
Time: 11.78
Time: 11.8
Time: 11.82
Time: 11.84
Time: 11.86
Time: 11.88
Time: 1

Time: 14.6
Time: 14.62
Time: 14.64
Time: 14.66
Time: 14.68
1	14.694247711954267	Kicked
[13.88 13.9  13.92 13.94 13.96 13.98 14.   14.02 14.04 14.06 14.08 14.1
 14.12 14.14 14.16 14.18 14.2  14.22 14.24 14.26 14.28 14.3  14.32 14.34
 14.36 14.38 14.4  14.42 14.44 14.46 14.48 14.5  14.52] 
 [14.52 14.47 14.42 14.37 14.32 14.27 14.22 14.17]
25.619181459252165
[1.03781733 0.87710118]
[0.55653954 1.23961084]
Time: 14.7
Time: 14.72
Time: 14.74
Time: 14.76
Time: 14.78
Time: 14.8
Time: 14.82
Time: 14.84
Time: 14.86
Time: 14.88
Time: 14.9
Time: 14.92
Time: 14.94
0	14.956139836579219	Kicked
[13.88 13.9  13.92 13.94 13.96 13.98 14.   14.02 14.04 14.06 14.08 14.1
 14.12 14.14 14.16 14.18 14.2  14.22 14.24 14.26 14.28 14.3  14.32 14.34
 14.36 14.38 14.4  14.42 14.44 14.46 14.48 14.5  14.52 14.54 14.56 14.58
 14.6  14.62 14.64 14.66 14.68 14.7  14.72 14.74 14.76 14.78] 
 [14.78 14.73 14.68 14.63 14.58 14.53 14.48 14.43]
-157.21447197715048
[0.54995508 0.63044693]
[-0.26287593 -0.79423559]
Time: 14.9

Time: 17.52
Time: 17.54
Time: 17.56
Time: 17.58
Time: 17.6
Time: 17.62
Time: 17.64
0	17.652488121193237	Kicked
[16.62 16.64 16.66 16.68 16.7  16.72 16.74 16.76 16.78 16.8  16.82 16.84
 16.86 16.88 16.9  16.92 16.94 16.96 16.98 17.   17.02 17.04 17.06 17.08
 17.1  17.12 17.14 17.16 17.18 17.2  17.22 17.24 17.26 17.28 17.3  17.32
 17.34 17.36 17.38 17.4  17.42 17.44 17.46 17.48] 
 [17.48 17.43 17.38 17.33 17.28 17.23 17.18 17.13]
-94.97229819335116
[0.5771018  1.22574592]
[ 1.17111333 -0.68117041]
Time: 17.66
Time: 17.68
1	17.68869447783379	Kicked
[16.62 16.64 16.66 16.68 16.7  16.72 16.74 16.76 16.78 16.8  16.82 16.84
 16.86 16.88 16.9  16.92 16.94 16.96 16.98 17.   17.02 17.04 17.06 17.08
 17.1  17.12 17.14 17.16 17.18 17.2  17.22 17.24 17.26 17.28 17.3  17.32
 17.34 17.36 17.38 17.4  17.42 17.44 17.46 17.48 17.5  17.52] 
 [17.52 17.47 17.42 17.37 17.32 17.27 17.22 17.17]
163.67844530103426
[ 0.93735604 -0.20691311]
[-0.84143193  0.46199751]
Cleaning buffer for 0
55
25
Time: 17.7
Time:

In [8]:
import matplotlib