# Import libraries

In [1]:
import gym
import ptan
from typing import List, Optional, Tuple, Any

# Toy environment

In [2]:
class ToyEnv(gym.Env):
    """
    Minimal deterministic environment for demonstrations.

    The observation is the internal step counter modulo 5 (values 0-4), and
    three discrete actions (0-2) are available. The reward for a step is the
    action value converted to float. An episode ends once 10 steps were made;
    stepping a finished episode yields zero reward and leaves the counter
    untouched.
    """

    def __init__(self):
        super().__init__()

        self.step_index = 0
        self.observation_space = gym.spaces.Discrete(n=5)
        self.action_space = gym.spaces.Discrete(n=3)

    def reset(self):
        # Restart the episode; the initial observation is the zeroed counter
        self.step_index = 0
        return self.step_index

    def step(self, action):
        # Guard: the episode is already over, so report "done" with no reward
        if self.step_index == 10:
            return self.step_index % self.observation_space.n, 0.0, True, {}

        self.step_index += 1

        observation = self.step_index % self.observation_space.n
        reward = float(action)
        finished = self.step_index == 10
        return observation, reward, finished, {}

# Simple agent

In [3]:
class DullAgent(ptan.agent.BaseAgent):
    """
    Agent that ignores its input and answers every observation
    with the same preconfigured action
    """

    def __init__(self, action: int):
        self.action = action

    def __call__(self, observations: List[Any], state: Optional[List]=None) -> Tuple[List[int], Optional[List]]:
        # One copy of the fixed action per observation in the batch;
        # the agent state is handed back unchanged
        chosen = []
        for _ in observations:
            chosen.append(self.action)
        return chosen, state

# Simple experience replay buffer

In [4]:
# Environment and an agent that always picks action 1
env = ToyEnv()
agent = DullAgent(action=1)

# Experience source producing ExperienceFirstLast items from env/agent
# interaction, configured with steps_count=1 and no discounting (gamma=1.0)
exp_source = ptan.experience.ExperienceSourceFirstLast(
    env=env,
    agent=agent,
    gamma=1.0,
    steps_count=1,
)

# Replay buffer fed by the experience source
# (buffer_size=100 is presumably its maximum capacity -- confirm in ptan docs)
buffer = ptan.experience.ExperienceReplayBuffer(
    exp_source,
    buffer_size=100,
)

In [5]:
len(buffer) # Current length; nothing is stored until populate() is called

0

# Training loop

In [6]:
# Demonstration "training" loop: fill the buffer one sample at a time and,
# once enough samples are stored, draw a batch and print it.
for step in range(6):
    # Ask the buffer to pull one more sample from the experience source
    buffer.populate(1)

    # Warm-up: skip sampling until the buffer holds at least 5 entries
    if len(buffer) < 5:
        continue

    batch = buffer.sample(4)

    # Interpolate the batch size into the message; passing it as a second
    # print() argument left the literal "%d" in the output
    print("Train time, %d batch samples:" % len(batch))
    for s in batch:
        print(s)

Train time, %d batch samples: 4
ExperienceFirstLast(state=1, action=1, reward=1.0, last_state=2)
ExperienceFirstLast(state=3, action=1, reward=1.0, last_state=4)
ExperienceFirstLast(state=0, action=1, reward=1.0, last_state=1)
ExperienceFirstLast(state=1, action=1, reward=1.0, last_state=2)
Train time, %d batch samples: 4
ExperienceFirstLast(state=3, action=1, reward=1.0, last_state=4)
ExperienceFirstLast(state=2, action=1, reward=1.0, last_state=3)
ExperienceFirstLast(state=3, action=1, reward=1.0, last_state=4)
ExperienceFirstLast(state=3, action=1, reward=1.0, last_state=4)
