In [2]:
import gymnasium as gym
import minigrid
from minigrid.wrappers import ImgObsWrapper, RGBImgObsWrapper
from src.modules.environment.minigrid_wrappers import FullyObsWrapper
import numpy as np

# minigrid.register_minigrid_envs()

In [8]:
env = gym.make("MiniGrid-Empty-5x5-v0")
observation, info = env.reset(seed=42)

env = FullyObsWrapper(env)
env = ImgObsWrapper(env)

observation, info = env.reset(seed=42)
print(observation)

observation, reward, terminated, truncated, info = env.step(0)
print(observation)


[[[ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [10  0  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 8  1  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]]]
[[[ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [10  0  3]
  [ 1  0  0]
  [ 1  0  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [ 1  0  0]
  [ 1  0  0]
  [ 8  1  0]
  [ 2  5  0]]

 [[ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]
  [ 2  5  0]]]


In [3]:
import torch
from typing import Any

class CustomIndex:
    def __init__(self, data):
        self.data = data if isinstance(data, torch.Tensor) else torch.tensor(data)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return CustomIndexOperation(self, idx)
        
    def __repr__(self):
        return f"CustomIndex(data={self.data})"

class CustomIndexOperation:
    def __init__(self, parent, idx):
        self.parent = parent
        self.idx = idx
    
    def __index__(self):
        # For single-value indexing
        if isinstance(self.idx, (int, slice)):
            return self.parent.data[self.idx].item()
        return self.parent.data.item()
    
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        
        # Get the tensor being indexed (first argument)
        tensor = args[0]
        
        # Handle different indexing cases
        if isinstance(self.idx, (int, slice)):
            # For nested indexing (e.g., x[custom_idx[0]])
            return tensor.__getitem__(self.parent.data[self.idx])
        else:
            # For direct indexing (e.g., x[custom_idx])
            return tensor.__getitem__(self.parent.data)
    
    def __array__(self):
        # Support numpy conversion if needed
        if isinstance(self.idx, (int, slice)):
            return self.parent.data[self.idx].numpy()
        return self.parent.data.numpy()
    
    def __repr__(self):
        return f"CustomIndexOperation(parent={self.parent}, idx={self.idx})"

# Example usage
def create_example():
    # Create sample data
    x = torch.arange(15)
    indices = torch.tensor([1, 3, 5, 10])
    
    # Create custom indexer
    custom_idx = CustomIndex(indices)
    
    print("Original tensor shape:", x.shape)
    
    print("\nDirect indexing - x[custom_idx]:")
    result0 = x[custom_idx]
    print(result0)
    print("Shape:", result0.shape)
    
    print("\nDirect indexing - x[custom_idx[0]]:")
    result1 = x[custom_idx[0]]
    print(result1)
    print("Shape:", result1.shape)
    
    print("\nNested indexing with slice - x[custom_idx[1:3]]:")
    result2 = x[custom_idx[1:3]]
    print(result2)
    print("Shape:", result2.shape)
    
    return x, custom_idx

if __name__ == "__main__":
    create_example()

Original tensor shape: torch.Size([15])

Direct indexing - x[custom_idx]:


RuntimeError: Could not infer dtype of CustomIndexOperation