In [4]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import fish_models
import robofish.io

In [5]:
fishes = 4

In [71]:
from sklearn.mixture import GaussianMixture
import sklearn.cluster as cl
import sklearn.metrics as mt

class EMAlgorithmFishModel(fish_models.gym_interface.AbstractRaycastBasedModel):
    """
    Representation of a Gaussian mixture model probability distribution. 
    GMMs are probabilistic models that assume all the data points are generated 
    from a mixture of several Gaussian distributions with unknown parameters.

    The EM algorithm is an iterative approach that cycles between two modes:
    E-Step. Estimate the missing variables in the dataset.
    M-Step. Maximize the parameters of the model in the presence of the data.
    """
    def __init__(self):
        self.model = GaussianMixture(n_components=2,
                                     covariance_type='full',
                                     tol=0.01,
                                     max_iter=1000)
        
    
    def choose_action(self, view: np.ndarray):
        
        prediction = self.model.predict_proba([view])
        print(self.model.means_)
        print(self.model.covariances_)

        print(prediction)
        
        speed = prediction[0]
        turn = prediction[1]
        
        # turn correction for walls avoidance
        turn = self.avoid_walls(view, turn)
        
        return speed, turn
    
    def avoid_walls(self, view, turn):
        """
        Forces to turn a fish in a random direction
        if in a view's raycast of the walls
        a wall in the front of a fish is detected to near

        Parameters
        ---------
        view : array_like
            The observations of the virtual fish
        turn : float
            Turn predicted by a model that is to modify

        Returns
        ---------
        turn : float
            Original or modified turn depending on the wall distance
        """
        param = random.randint(-5, 5)
        
        if param == 0:
            param = random.randint(5, 11)
        
        if view[6] > 0.9:
            return param * np.pi
        else:
            return turn
        
 
    def fit(self, dset):
        """
        Learns parameters and runs the the two steps of EM Algorithm 
        until the average log-likelihood converges.
        """
        
        X = dset[:]["views"]
        y = dset[:]["actions"]
        
        self.model.fit(X, y)
        
    
model = EMAlgorithmFishModel()

In [72]:
raycast = fish_models.gym_interface.Raycast(
            n_wall_raycasts=5,
            n_fish_bins=4,
            fov_angle_fish_bins=np.pi,
            fov_angle_wall_raycasts=np.pi,
            world_bounds=([-50, -50], [50, 50]),
        )



In [8]:
data_folder = Path("data/live_female_female/train")

dset = fish_models.datasets.io_dataset.IoDataset(
    data_folder,
    raycast,
    output_strings=["poses", "actions", "views"],
    reduce_dim=2,
    max_files=5,
)

  0%|          | 0/5 [00:00<?, ?it/s]

Loading data from 5 files.


100%|██████████| 5/5 [00:01<00:00,  4.12it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Calculating views from 5 files.


100%|██████████| 5/5 [00:09<00:00,  1.84s/it]

Created IoDataset:
Reduced the first 3 dimensions from (5, 2, 8989) to (89890)
poses	(89890, 3):	consisting of x, y, calc_ori_rad.
actions	(89880, 2):	consisting of speed[cm/s] and turn [rad/s].
views	(89880, 9):	4 fish_bins and 5 wall ray casts.






In [73]:
model.fit(dset)

In [74]:
generator = fish_models.gym_interface.TrackGeneratorGymRaycast(
    model, raycast, [100,100], 25
)

track = generator.create_track(n_guppies=fishes, trackset_len=5000)

  0%|          | 0/4999 [00:00<?, ?it/s]

[[0.      0.31055 0.2857  0.26292 0.78932 0.75222 0.66286 0.55659 0.51382]
 [0.1552  0.      0.      0.      0.74137 0.74917 0.71404 0.62992 0.5633 ]]
[[[ 0.       0.       0.       0.       0.       0.       0.
    0.       0.     ]
  [ 0.       0.18026 -0.08872 -0.08165 -0.01159 -0.01344 -0.00964
   -0.0005   0.01185]
  [ 0.      -0.08872  0.17429 -0.07512  0.01173  0.01427  0.00587
   -0.00743 -0.01568]
  [ 0.      -0.08165 -0.07512  0.16619  0.00218  0.0073   0.01205
    0.00898 -0.00319]
  [ 0.      -0.01159  0.01173  0.00218  0.0426   0.03152  0.00107
   -0.02776 -0.03724]
  [ 0.      -0.01344  0.01427  0.0073   0.03152  0.04306  0.01917
   -0.01709 -0.03714]
  [ 0.      -0.00964  0.00587  0.01205  0.00107  0.01917  0.03676
    0.01728 -0.01461]
  [ 0.      -0.0005  -0.00743  0.00898 -0.02776 -0.01709  0.01728
    0.04854  0.02961]
  [ 0.       0.01185 -0.01568 -0.00319 -0.03724 -0.03714 -0.01461
    0.02961  0.05545]]

 [[ 0.11681  0.       0.       0.      -0.00765 -0.00851 -0.




IndexError: index 1 is out of bounds for axis 0 with size 1

In [None]:
f = generator.as_io_file(track)
f.save_as("output/em_algorithm.hdf5")

In [None]:
plt.figure(figsize=(10,10))
plt.xlim(-50,50)
plt.ylim(-50,50)
for fish_id in range(fishes):
    plt.plot(track[fish_id, :, 0], track[fish_id, :, 1])
plt.show()