In [1]:
import numpy as np
import os
import scipy.io

In [6]:
class EelEnv:
    """Two-eel choice task that replays premade trials stored as ``.mat`` files.

    Every ``.mat`` file found in ``data_folder`` is treated as one trial, and
    trials are visited in sorted (lexicographic) filename order. Each file is
    expected to hold a ``curr_trial_data`` variable with an ``eels`` field whose
    entries are positional records laid out as in ``_EEL_FIELDS``.

    Typical use::

        env = EelEnv('premade_eels')
        reward, done, info = env.step(0)   # pick the left eel
    """

    # Positional layout of one eel record as read from the .mat file.
    _EEL_FIELDS = (
        'idx',
        'side',              # 1 = left, 2 = right
        'color',             # RGB triplet
        'pos',
        'fish_pos',
        'potent',
        'comp_change',
        'dist_params',
        'competency',
        'final_competency',
        'reliability',
    )

    # Side code that means "left"; any other code is reported as "right".
    _LEFT_SIDE = 1

    def __init__(self, data_folder):
        """Index the trial files in ``data_folder`` and load the first trial.

        Args:
            data_folder: Directory containing one ``.mat`` file per trial.

        Raises:
            OSError: Propagated from ``os.listdir`` if the folder is unreadable.
        """
        self.data_folder = data_folder
        # NOTE(review): plain string sort, so 'trial_10.mat' precedes
        # 'trial_2.mat' if the files are numbered without zero padding.
        self.trial_files = sorted(
            name for name in os.listdir(data_folder) if name.endswith('.mat')
        )
        self.current_trial_idx = -1  # advanced to 0 by the first load below
        self.trial_data = None       # raw 'curr_trial_data' of the current trial
        self.eels_info = []          # one dict per eel in the current trial
        self.load_next_trial()

    def load_next_trial(self):
        """Advance to the next trial file and parse its eels.

        Returns:
            bool: ``True`` if a trial was loaded, ``False`` if the trial files
            are exhausted (in which case ``trial_data`` is ``None`` and
            ``eels_info`` is empty).
        """
        self.current_trial_idx += 1
        if self.current_trial_idx >= len(self.trial_files):
            self.trial_data = None
            self.eels_info = []
            return False

        filepath = os.path.join(self.data_folder, self.trial_files[self.current_trial_idx])
        mat = scipy.io.loadmat(filepath, squeeze_me=True)
        curr_trial_data = mat['curr_trial_data']
        self.trial_data = curr_trial_data

        eel_field = curr_trial_data['eels']
        if isinstance(eel_field, np.ndarray) and eel_field.dtype == object:
            # Unwrap the object array into plain Python containers.
            eel_records = eel_field.tolist()
        else:
            # Anything else is treated as a single eel record.
            eel_records = [eel_field]

        # Index positionally (rather than zip) so a short record still raises.
        self.eels_info = [
            {name: record[i] for i, name in enumerate(self._EEL_FIELDS)}
            for record in eel_records
        ]
        return True

    def color_name(self, rgb):
        """Return a readable name for an RGB triplet.

        Known colours map to "Blue" / "Purple"; anything else is rendered as
        ``RGB[r, g, b]`` using native Python numbers so NumPy scalar reprs do
        not leak into the text.
        """
        channels = np.asarray(rgb).tolist()
        if channels == [0, 0, 255]:
            return "Blue"
        if channels == [157, 0, 255]:
            return "Purple"
        return f"RGB{channels}"

    def _side_name(self, side):
        """Map a side code to 'left' (code 1) or 'right' (any other code)."""
        return 'left' if side == self._LEFT_SIDE else 'right'

    def _print_trial(self):
        """Print side, colour, competency and reliability of every eel."""
        for i, eel in enumerate(self.eels_info):
            print(f"Eel {i} info:")
            print(f"  Side: {self._side_name(eel['side'])}")
            print(f"  Color: {self.color_name(eel['color'])}")
            print(f"  Competency: {eel['competency']:.3f}")
            print(f"  Reliability: {eel['reliability']:.3f}")
        print("---------------------------------------------------------")

    def step(self, action, verbose=True):
        """Choose an eel on the current trial, sample the outcome, and advance.

        Args:
            action: 0 for the left eel, 1 for the right eel. The eel is looked
                up by its ``side`` code, not by its position in ``eels_info``,
                because the stored order does not follow the side.
            verbose: When True (default) print the trial and outcome summary.

        Returns:
            tuple: ``(reward, done, info)``.
            ``reward`` is 0 or 1, or ``None`` if there is no trial left.
            ``done`` is True when no further trial could be loaded.
            ``info`` describes the chosen eel and the sampled outcome, and is
            empty if there is no trial left.

        Raises:
            ValueError: If ``action`` is not 0 or 1, or no eel in the current
                trial sits on the requested side.
        """
        if self.trial_data is None or not self.eels_info:
            return None, True, {}

        if action not in (0, 1):
            raise ValueError(f"action must be 0 (left) or 1 (right), got {action!r}")

        if verbose:
            self._print_trial()

        wanted_side = action + 1  # side codes: 1 = left, 2 = right
        chosen_eel = next(
            (eel for eel in self.eels_info if eel['side'] == wanted_side), None
        )
        if chosen_eel is None:
            raise ValueError(f"no eel with side code {wanted_side} in the current trial")

        competency = chosen_eel['competency']    # Value between 0 and 1
        reliability = chosen_eel['reliability']  # Value between 0 and 1

        # 1. Number of fish caught (0..3): categorical draw whose weights are
        #    the Binomial(3, competency) probabilities.
        miss = 1 - competency
        prob_fish_caught = np.array([
            miss ** 3,                        # 0 fish
            3 * competency * miss ** 2,       # 1 fish
            3 * competency ** 2 * miss,       # 2 fish
            competency ** 3,                  # 3 fish
        ])
        prob_fish_caught /= prob_fish_caught.sum()  # guard against float drift
        num_fish_caught = int(np.random.choice([0, 1, 2, 3], p=prob_fish_caught))

        # 2. Reward probability = reliability x fish caught, capped at 1 so the
        #    reported value is a valid probability. The cap does not change the
        #    sampled reward, since rand() < p is always true for any p >= 1.
        reward_prob = float(min(1.0, reliability * num_fish_caught))
        reward = int(np.random.rand() < reward_prob)

        chosen_side = self._side_name(chosen_eel['side'])
        if verbose:
            print(f"\nChosen eel side: {chosen_side}")
            print(f"Chosen eel color: {self.color_name(chosen_eel['color'])}")
            print(f"Number of fish caught: {num_fish_caught}")
            print(f"Reward probability (reliability × fish caught): {reward_prob:.3f}")
            print(f"Reward delivered: {reward}\n{'-'*40}")

        info = {
            'chosen_side': chosen_side,
            'eel_color': chosen_eel['color'],
            'competency': competency,
            'reliability': reliability,
            'num_fish_caught': num_fish_caught,
            'reward_prob': reward_prob
        }

        # Load the next trial; the episode is done once none is left.
        done = not self.load_next_trial()
        return reward, done, info

In [7]:
# Demo: drive the environment with a random agent for up to MAX_STEPS trials.
SEED = 0        # fixed seed so Restart & Run All reproduces the same output
MAX_STEPS = 10  # upper bound on the number of trials to play

np.random.seed(SEED)  # EelEnv.step samples from the global NumPy RNG
env = EelEnv(data_folder='premade_eels')

for _ in range(MAX_STEPS):
    action = np.random.choice([0, 1])  # Replace with agent's choice later
    reward, done, info = env.step(action)
    if reward is None:
        # step() returns reward=None once every trial file has been consumed.
        break

Eel 0 info:
  Side: left
  Color: Purple
  Competency: 0.700
  Reliability: 0.100
Eel 1 info:
  Side: right
  Color: Blue
  Competency: 0.400
  Reliability: 0.250
---------------------------------------------------------

Chosen eel side: right
Chosen eel color: Blue
Number of fish caught: 1
Reward probability (reliability × fish caught): 0.250
Reward delivered: 0
----------------------------------------
Eel 0 info:
  Side: right
  Color: Blue
  Competency: 0.900
  Reliability: 0.250
Eel 1 info:
  Side: left
  Color: Purple
  Competency: 0.400
  Reliability: 0.100
---------------------------------------------------------

Chosen eel side: right
Chosen eel color: Blue
Number of fish caught: 2
Reward probability (reliability × fish caught): 0.500
Reward delivered: 1
----------------------------------------
Eel 0 info:
  Side: right
  Color: Blue
  Competency: 0.900
  Reliability: 0.250
Eel 1 info:
  Side: left
  Color: Purple
  Competency: 0.400
  Reliability: 0.100
---------------------