## Different approaches to experience collection (n-step lookahead), from simple to complex

### 1. Simple Rolling Buffer (most common)

In [None]:
import numpy as np
import gymnasium as gym
from collections import deque

def compute_n_step_returns(rewards, gamma, n_steps):
    """
    Compute the discounted sum of at most ``n_steps`` rewards.

    Args:
        rewards: iterable of rewards [r_t, r_{t+1}, ...] in time order.
        gamma: discount factor; the i-th reward is weighted by ``gamma ** i``.
        n_steps: maximum number of leading rewards to include. Pass ``None``
            to sum every reward supplied.

    Returns:
        float: sum of ``gamma ** i * rewards[i]`` over the included rewards
        (``0.0`` when nothing is included).
    """
    result = 0.0
    for i, r in enumerate(rewards):
        # Honour the lookahead horizon: anything past n_steps is not part of
        # the n-step return even if the caller handed in a longer sequence.
        if n_steps is not None and i >= n_steps:
            break
        result += (gamma ** i) * r
    return result

class SimpleNStepCollector:
    """
    Minimal n-step experience collector built on rolling windows.

    Transitions are pushed one at a time through ``add``. Once ``n_steps``
    transitions are buffered, every further call emits the oldest one together
    with the discounted sum of the buffered rewards. When an episode ends, all
    buffered transitions are emitted at once with shortened lookahead windows.

    Emitted experiences are dicts with the keys 'state', 'action', 'return'
    and 'done'; no successor state is stored.
    """
    def __init__(self, gamma=0.99, n_steps=10):
        self.gamma = gamma
        self.n_steps = n_steps

        # Rolling buffers for the last n_steps transitions; maxlen makes the
        # deque drop its oldest entry automatically when a new one is appended.
        self.states = deque(maxlen=n_steps)
        self.actions = deque(maxlen=n_steps)
        self.rewards = deque(maxlen=n_steps)

        # Undiscounted reward totals: one entry per finished episode, plus the
        # running total of the episode in progress.
        self.episode_rewards = []
        self.current_episode_reward = 0.0

    def add(self, state, action, reward, done):
        """
        Add a transition and return the experiences that became ready.

        Args:
            state: state the action was taken in.
            action: action taken.
            reward: reward received for the action.
            done: truthy when this transition ended the episode.

        Returns:
            list of experience dicts, or None while the window is still
            filling. A non-terminal call yields a one-element list (the oldest
            buffered transition, 'done' False); a terminal call yields one
            experience per buffered transition, all with 'done' True.
        """
        self.states.append(state)
        self.actions.append(action)
        self.rewards.append(reward)
        self.current_episode_reward += reward

        # Case 1: Episode ended - flush all remaining transitions
        if done:
            # Snapshot the rewards once; each transition's window is a suffix.
            window = list(self.rewards)
            experiences = [
                {
                    'state': buffered_state,
                    'action': buffered_action,
                    'return': compute_n_step_returns(
                        window[i:],
                        self.gamma,
                        self.n_steps
                    ),
                    'done': True
                }
                for i, (buffered_state, buffered_action)
                in enumerate(zip(self.states, self.actions))
            ]

            # Reset for next episode
            self.episode_rewards.append(self.current_episode_reward)
            self.current_episode_reward = 0.0
            self.states.clear()
            self.actions.clear()
            self.rewards.clear()

            return experiences

        # Case 2: Buffer is full - emit oldest transition. It is not popped
        # here: the next append evicts it through the deque's maxlen.
        if len(self.states) == self.n_steps:
            n_step_return = compute_n_step_returns(
                list(self.rewards),
                self.gamma,
                self.n_steps
            )
            return [{
                'state': self.states[0],
                'action': self.actions[0],
                'return': n_step_return,
                'done': False
            }]

        # Case 3: Still accumulating
        return None

# Usage: step CartPole with random actions and feed every transition to the
# collector, "training" whenever at least 8 experiences have accumulated.
env = gym.make("CartPole-v1")
collector = SimpleNStepCollector(gamma=0.99, n_steps=10)

state, _ = env.reset()
batch = []  # experiences gathered since the last training step

for _ in range(1000):
    action = env.action_space.sample()  # Replace with policy
    next_state, reward, terminated, truncated, _ = env. step(action)
    # Either flag is treated as the end of the episode.
    done = terminated or truncated
    
    # add() returns None while its window is filling, otherwise a list.
    experiences = collector.add(state, action, reward, done)
    
    if experiences:
        batch. extend(experiences)
    
    # An end-of-episode flush can add several experiences at once, so the
    # batch may hold more than 8 items when this fires.
    if len(batch) >= 8:
        # Train here
        print(f"Training on {len(batch)} experiences")
        batch.clear()
    
    # Start a fresh episode after the end of one; otherwise keep stepping.
    if done:
        state, _ = env.reset()
    else:
        state = next_state

### 2. Using generators

In [None]:
def n_step_experience_generator(env, policy, gamma=0.99, n_steps=10):
    """
    Endless generator of n-step experiences, one episode after another.

    Args:
        env: environment whose ``reset()`` returns ``(state, info)`` and whose
            ``step(action)`` returns
            ``(next_state, reward, terminated, truncated, info)``.
        policy: callable mapping a state to an action.
        gamma: discount factor.
        n_steps: length of the lookahead window.

    Yields:
        dict with
            'state', 'action': the oldest buffered transition,
            'return': discounted sum of the rewards in its window,
            'done': True when the window's last step ended the episode
                (``terminated or truncated``), so there is no later state
                left to bootstrap from,
            'steps': number of rewards that went into 'return'.

    The generator never stops on its own; the consumer must break out.
    """
    while True:
        # Start new episode
        state, _ = env.reset()
        done = False

        # Sliding windows for the current episode; deques give O(1) removal
        # of the oldest entry.
        states = deque()
        actions = deque()
        rewards = deque()

        while not done:
            # Get action from policy
            action = policy(state)

            # Execute
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Store
            states.append(state)
            actions.append(action)
            rewards.append(reward)

            # Yield the oldest transition once a full window is available
            if len(states) >= n_steps:
                n_step_return = sum(
                    (gamma ** i) * r
                    for i, r in zip(range(n_steps), rewards)
                )

                yield {
                    'state': states[0],
                    'action': actions[0],
                    'return': n_step_return,
                    # If the step just taken ended the episode, this window
                    # runs right up to the end and must be flagged as such.
                    'done': done,
                    'steps': n_steps
                }

                # Remove oldest
                states.popleft()
                actions.popleft()
                rewards.popleft()

            state = next_state

        # Episode ended - flush remaining transitions with shortened windows
        while states:
            n_step_return = sum(
                (gamma ** i) * r
                for i, r in enumerate(rewards)
            )

            yield {
                'state': states[0],
                'action': actions[0],
                'return': n_step_return,
                'done': True,
                'steps': len(states)
            }

            # Remove oldest
            states.popleft()
            actions.popleft()
            rewards.popleft()

# Usage
def random_policy(state):
    """Ignore ``state`` and sample an action from the global ``env``."""
    # ``env`` is looked up when the policy is called, so it only has to exist
    # by the time the generator takes its first step.
    return env.action_space.sample()

env = gym.make("CartPole-v1")
exp_gen = n_step_experience_generator(env, random_policy, gamma=0.99, n_steps=10)

batch = []
total_experiences = 0  # counts everything consumed, across batches
for exp in exp_gen:
    batch.append(exp)
    total_experiences += 1

    if len(batch) >= 8:
        # Train
        print(f"Batch of {len(batch)} experiences")
        batch.clear()

    # The generator is endless and ``batch`` is emptied every 8 items, so the
    # stop condition has to look at the running total, not at len(batch).
    if total_experiences > 100:  # Stop after some experiences
        break

### 3. Using a Vectorized/Batch Approach (Most Efficient)

In [None]:
import numpy as np

class VectorizedNStepBuffer:
    """
    Fixed-size ring buffer of transitions with n-step return computation.

    Transitions are stored in pre-allocated numpy arrays. Once the buffer is
    full, new transitions overwrite the oldest ones. n-step returns are always
    computed over the stored transitions in chronological order.
    """
    def __init__(self, buffer_size, state_shape, gamma=0.99, n_steps=10):
        self.buffer_size = buffer_size
        self.gamma = gamma
        self.n_steps = n_steps
        self.ptr = 0   # slot the next transition is written to
        self.size = 0  # number of valid transitions currently stored

        # Pre-allocate arrays
        self.states = np.zeros((buffer_size, *state_shape), dtype=np.float32)
        self.actions = np.zeros(buffer_size, dtype=np.int64)
        self.rewards = np.zeros(buffer_size, dtype=np.float32)
        self.dones = np.zeros(buffer_size, dtype=np.bool_)

        # Precompute gamma powers for efficiency
        self.gamma_powers = np.array([gamma ** i for i in range(n_steps)])

    def add(self, state, action, reward, done):
        """Add single transition, overwriting the oldest one when full."""
        self.states[self.ptr] = state
        self.actions[self.ptr] = action
        self.rewards[self.ptr] = reward
        self.dones[self.ptr] = done

        self.ptr = (self.ptr + 1) % self.buffer_size
        self.size = min(self.size + 1, self.buffer_size)

    def _chronological_indices(self):
        """Return storage indices of the valid transitions, oldest first."""
        if self.size < self.buffer_size:
            # Not wrapped yet: slots 0..size-1 were filled in order.
            return np.arange(self.size)
        # Wrapped: ptr points at the oldest slot (the next to be overwritten).
        return (self.ptr + np.arange(self.size)) % self.buffer_size

    def compute_n_step_returns(self):
        """
        Compute n-step returns for the stored transitions.

        Only transitions that still have ``n_steps`` stored successors
        (counting themselves) are included, so the newest ``n_steps - 1``
        transitions are left out.

        Returns:
            None if fewer than ``n_steps`` transitions are stored, otherwise a
            tuple ``(states, actions, returns, dones)`` of arrays in
            chronological order. ``returns[i]`` sums discounted rewards up to
            and including the first episode end inside the window, and
            ``dones[i]`` is True when the episode ended inside that window.
        """
        if self.size < self.n_steps:
            return None

        # Reorder to time order first; otherwise a window could straddle the
        # write pointer and mix the newest transitions with the oldest.
        order = self._chronological_indices()
        rewards = self.rewards[order]
        step_dones = self.dones[order]

        n_samples = self.size - self.n_steps + 1
        returns = np.zeros(n_samples, dtype=np.float32)
        window_dones = np.zeros(n_samples, dtype=np.bool_)

        for i in range(n_samples):
            flags = step_dones[i:i + self.n_steps]
            if flags.any():
                # Episode ended, use only rewards up to termination
                horizon = int(np.argmax(flags)) + 1
                window_dones[i] = True
            else:
                # Full n-step return
                horizon = self.n_steps
            returns[i] = np.sum(
                self.gamma_powers[:horizon] * rewards[i:i + horizon]
            )

        head = order[:n_samples]
        # Indexing with an index array already yields fresh copies.
        return (
            self.states[head],
            self.actions[head],
            returns,
            window_dones
        )

    def sample_batch(self, batch_size):
        """
        Sample random batch of n-step experiences without replacement.

        Returns:
            dict with 'states', 'actions', 'returns' and 'dones' arrays of
            length ``batch_size``, or None when the buffer cannot supply
            ``batch_size`` distinct samples yet.
        """
        data = self.compute_n_step_returns()
        if data is None:
            return None

        states, actions, returns, dones = data

        # Sampling without replacement needs at least batch_size candidates;
        # treat a shortfall like the other "not ready" case.
        if batch_size > len(states):
            return None

        # Random sampling
        indices = np.random.choice(len(states), size=batch_size, replace=False)

        return {
            'states': states[indices],
            'actions': actions[indices],
            'returns': returns[indices],
            'dones': dones[indices]
        }

    def clear(self):
        """Clear buffer (stored values are simply treated as invalid)."""
        self.ptr = 0
        self.size = 0

# Usage: fill the buffer with random-action CartPole transitions, then draw
# one batch of n-step experiences from it.
env = gym.make("CartPole-v1")
state_shape = env.observation_space. shape

buffer = VectorizedNStepBuffer(
    buffer_size=1000,
    state_shape=state_shape,
    gamma=0.99,
    n_steps=10
)

# Collect experience
state, _ = env.reset()
for _ in range(500):
    action = env.action_space.sample()
    next_state, reward, terminated, truncated, _ = env.step(action)
    # Either flag is stored as the end of the episode.
    done = terminated or truncated
    
    buffer. add(state, action, reward, done)
    
    # Start a fresh episode after the end of one; otherwise keep stepping.
    if done:
        state, _ = env.reset()
    else:
        state = next_state

# Sample batch for training
batch = buffer.sample_batch(batch_size=32)
# sample_batch returns None when the buffer holds fewer than n_steps transitions.
if batch: 
    print(f"Sampled batch with shapes:")
    print(f"  States: {batch['states'].shape}")
    print(f"  Actions: {batch['actions'].shape}")
    print(f"  Returns: {batch['returns'].shape}")

### 4. Simple Episodic Logic (Clearest)

In [None]:
def collect_n_step_episode(env, policy_fn, gamma=0.99, n_steps=10):
    """
    Run one full episode, then compute an n-step return for every transition.

    Args:
        env: environment whose ``reset()`` returns ``(state, info)`` and whose
            ``step(action)`` returns
            ``(next_state, reward, terminated, truncated, info)``.
        policy_fn: callable mapping a state to an action.
        gamma: discount factor.
        n_steps: maximum length of the lookahead window.

    Returns:
        list of dicts, one per transition in time order, with
            'state', 'action': the transition itself,
            'return': discounted sum of the rewards in its lookahead window,
            'done': True when that window reaches the end of the episode, so
                there is no later state left to bootstrap from,
            'n_steps_used': number of rewards that went into 'return'.
    """
    trajectory = []
    state, _ = env.reset()
    done = False

    while not done:
        action = policy_fn(state)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated

        trajectory.append({
            'state': state,
            'action': action,
            'reward': reward,
            'done': done
        })

        state = next_state

    # Compute n-step returns for each transition
    episode_length = len(trajectory)
    experiences = []
    for i, transition in enumerate(trajectory):
        # Look ahead up to n_steps or end of episode
        horizon = min(n_steps, episode_length - i)

        # Compute discounted return
        n_step_return = sum(
            (gamma ** j) * trajectory[i + j]['reward']
            for j in range(horizon)
        )

        experiences.append({
            'state': transition['state'],
            'action': transition['action'],
            'return': n_step_return,
            # The rollout only stops on a done step, so a window that touches
            # the last transition is a window that ran into the episode end.
            'done': i + horizon == episode_length,
            'n_steps_used': horizon
        })

    return experiences

# Usage - clearest pattern: collect whole episodes and train once enough
# experiences have piled up.
env = gym.make("CartPole-v1")

def policy(state):
    """Ignore ``state`` and sample an action from the global ``env``."""
    # Your policy here
    return env.action_space.sample()

batch = []  # experiences gathered since the last training step
for episode in range(10):
    # One call runs one complete episode and returns all of its experiences.
    experiences = collect_n_step_episode(env, policy, gamma=0.99, n_steps=10)
    batch.extend(experiences)
    
    # Whole episodes are appended at once, so the batch can exceed 32 items.
    if len(batch) >= 32:
        # Train on batch
        print(f"Training on {len(batch)} experiences")
        # ...  training code ...
        batch.clear()