# Training Custom Agents with MLFlow-Assist 🤖

This notebook shows how to train custom AI agents using MLFlow-Assist. We'll create a simple reinforcement learning agent that learns to solve Gym's `CartPole-v1` environment, where the goal is to keep a pole balanced upright on a moving cart.

In [None]:
import gym
import numpy as np
from torch.utils.data import Dataset, DataLoader

from mlflow_assist.agents import RLAgent, RLConfig, AgentTrainer

## 1. Create the Environment and Dataset

In [None]:
class RLDataset(Dataset):
    """Map-style dataset over ``(state, action, reward, next_state, done)`` transitions.

    Each item is a dict with the keys ``'states'``, ``'actions'``,
    ``'rewards'``, ``'next_states'`` and ``'dones'``. PyTorch's default
    ``DataLoader`` collate function batches such dicts field by field.
    """

    def __init__(self, transitions):
        """Split the transitions into one list per field.

        Args:
            transitions: Iterable of 5-element sequences
                ``(state, action, reward, next_state, done)``. Any iterable
                is accepted, including a one-shot generator.
        """
        # Materialize once: each per-field comprehension below iterates the
        # input, which would exhaust a generator after the first field and
        # leave the remaining fields empty.
        transitions = list(transitions)
        self.states = [t[0] for t in transitions]
        self.actions = [t[1] for t in transitions]
        self.rewards = [t[2] for t in transitions]
        self.next_states = [t[3] for t in transitions]
        self.dones = [t[4] for t in transitions]

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.states)

    def __getitem__(self, idx):
        """Return transition ``idx`` as a dict keyed by field name."""
        return {
            'states': self.states[idx],
            'actions': self.actions[idx],
            'rewards': self.rewards[idx],
            'next_states': self.next_states[idx],
            'dones': self.dones[idx]
        }

# Build the Gym environment shared by every cell below
ENV_ID = 'CartPole-v1'
env = gym.make(ENV_ID)

## 2. Configure and Create the Agent

In [None]:
# Configure the agent. The network's input/output sizes are read from the
# environment instead of being hard-coded, so they stay correct if a
# different environment is used above (CartPole-v1: 4 state dimensions,
# 2 discrete actions).
config = RLConfig(
    state_size=int(env.observation_space.shape[0]),  # assumes a flat (1-D) Box observation space
    action_size=int(env.action_space.n),  # assumes a Discrete action space
    hidden_size=128,
    learning_rate=3e-4
)

# Create the agent
agent = RLAgent(config)

## 3. Collect Training Data

In [None]:
def collect_episode(env, agent, max_steps=500):
    """Roll out a single episode and return its transitions.

    Args:
        env: Gym-style environment. Both the classic API
            (``reset() -> obs``, ``step() -> (obs, reward, done, info)``) and
            the gym>=0.26 API (``reset() -> (obs, info)``,
            ``step() -> (obs, reward, terminated, truncated, info)``) are
            supported.
        agent: Object whose ``predict(state)`` returns an action accepted by
            ``env.step``.
        max_steps: Maximum number of environment steps to take.

    Returns:
        List of ``(state, action, reward, next_state, done)`` tuples in the
        order they occurred. With the newer API, ``done`` is
        ``terminated or truncated``, matching the classic API's single flag.
    """
    reset_out = env.reset()
    # gym>=0.26 returns (obs, info); older versions return obs alone.
    if (isinstance(reset_out, tuple) and len(reset_out) == 2
            and isinstance(reset_out[1], dict)):
        state = reset_out[0]
    else:
        state = reset_out
    transitions = []

    for _ in range(max_steps):
        action = agent.predict(state)
        step_out = env.step(action)
        if len(step_out) == 5:
            # gym>=0.26 reports terminal and time-limit ends separately.
            # NOTE(review): if the trainer bootstraps from next_state,
            # consider storing only `terminated` as the done flag.
            next_state, reward, terminated, truncated, _ = step_out
            done = terminated or truncated
        else:
            next_state, reward, done, _ = step_out

        transitions.append((state, action, reward, next_state, done))
        state = next_state

        if done:
            break

    return transitions

# Gather the initial training set by rolling out the (still untrained) agent
num_rollouts = 10
episodes = [collect_episode(env, agent) for _rollout in range(num_rollouts)]
flat_transitions = [step for episode in episodes for step in episode]
train_dataset = RLDataset(flat_transitions)
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
)

## 4. Train the Agent

In [None]:
# Wrap the agent and its data loader in a trainer. experiment_name groups
# this run (presumably as an MLflow experiment — see AgentTrainer docs).
trainer = AgentTrainer(
    agent=agent,
    train_loader=train_loader,
    experiment_name="cartpole-training",
)

# Training schedule; eval_every / save_every are presumably epoch intervals
# for evaluation and checkpointing — confirm against AgentTrainer.train.
schedule = {"num_epochs": 50, "eval_every": 5, "save_every": 10}
trainer.train(**schedule)

## 5. Test the Trained Agent

In [None]:
def evaluate_agent(env, agent, episodes=10, max_steps=None):
    """Return the mean undiscounted episode reward of ``agent`` on ``env``.

    Args:
        env: Gym-style environment. Both the classic API
            (``reset() -> obs``, 4-tuple ``step``) and the gym>=0.26 API
            (``reset() -> (obs, info)``, 5-tuple ``step``) are supported.
        agent: Object whose ``predict(state)`` returns an action accepted by
            ``env.step``.
        episodes: Number of episodes to run and average over.
        max_steps: Optional per-episode step cap. ``None`` (the default)
            runs each episode until the environment reports it is done.

    Returns:
        ``np.mean`` of the per-episode reward totals. This is NaN, with a
        NumPy RuntimeWarning, if ``episodes`` is 0.
    """
    total_rewards = []

    for _ in range(episodes):
        reset_out = env.reset()
        # gym>=0.26 returns (obs, info); older versions return obs alone.
        if (isinstance(reset_out, tuple) and len(reset_out) == 2
                and isinstance(reset_out[1], dict)):
            state = reset_out[0]
        else:
            state = reset_out
        episode_reward = 0
        done = False
        steps = 0

        while not done and (max_steps is None or steps < max_steps):
            action = agent.predict(state)
            step_out = env.step(action)
            if len(step_out) == 5:
                # An episode ends on either termination or time-limit truncation.
                state, reward, terminated, truncated, _ = step_out
                done = terminated or truncated
            else:
                state, reward, done, _ = step_out
            episode_reward += reward
            steps += 1

        total_rewards.append(episode_reward)

    return np.mean(total_rewards)

# Evaluate the trained agent. The episode count is kept in one variable so
# the printed message always matches the number of episodes actually run.
num_eval_episodes = 10
mean_reward = evaluate_agent(env, agent, episodes=num_eval_episodes)
print(f"Average reward over {num_eval_episodes} episodes: {mean_reward:.2f}")

## Next Steps

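Now you can:
1. Try different environments (make sure `state_size` and `action_size` in `RLConfig` match the new environment's observation and action spaces)
2. Customize the agent architecture (for example, change `hidden_size`)
3. Experiment with hyperparameters such as `learning_rate`, the DataLoader's `batch_size`, and `num_epochs`
4. Add more advanced features, such as re-collecting data with the improved agent between training rounds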

For more examples, check out the other notebooks in the `examples` directory!