In [6]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

import fish_models
import robofish.io


In [8]:
class SimpleForwardModel(fish_models.gym_interface.AbstractRaycastBasedModel):
    def choose_action(self, view: np.ndarray):
        # Return speed and turn from view
        speed = np.random.random() * 5.
        turn = (np.random.random() - 0.5) * 5.
        return speed, turn

In [9]:
model = SimpleForwardModel()
raycast = fish_models.gym_interface.Raycast(
            n_wall_raycasts=5,
            n_fish_bins=4,
            fov_angle_fish_bins=np.pi,
            fov_angle_wall_raycasts=np.pi,
            world_bounds=([-50, -50], [50, 50]),
        )

data_folder = Path("data/live_female_female/train")

dset = fish_models.datasets.io_dataset.IoDataset(
    data_folder,
    raycast,
    output_strings=["poses", "actions", "views"],
    reduce_dim=2,
    max_files=10,
)

  0%|          | 0/10 [00:00<?, ?it/s]

Loading data from 10 files.


100%|██████████| 10/10 [00:03<00:00,  2.62it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

Calculating views from 10 files.


100%|██████████| 10/10 [00:18<00:00,  1.85s/it]

Created IoDataset:
Reduced the first 3 dimensions from (10, 2, 8989) to (179780)
poses	(179780, 3):	consisting of x, y, calc_ori_rad.
actions	(179760, 2):	consisting of speed[cm/s] and turn [rad/s].
views	(179760, 9):	4 fish_bins and 5 wall ray casts.






In [12]:
generator = fish_models.gym_interface.TrackGeneratorGymRaycast(
    model, raycast, [100,100], 25
)

track = generator.create_track(n_guppies=2, trackset_len=1000)
print(track.shape)

100%|██████████| 999/999 [00:01<00:00, 975.03it/s] 

(2, 1000, 3)





In [14]:
f = generator.as_io_file(track)
f.save_as("output/simple_forward.hdf5")

<HDF5 file "eb824036-216b-4e54-938c-0b1dd885d5df" (mode r+)>