# Q-Learning Agent Demo

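This notebook demonstrates training a Q-learning agent for sentiment classification on a subset of the IMDB reviews dataset, visualizing its learning curve, and testing it on a few hand-written reviews.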

In [None]:
import sys
sys.path.append('..')

from agent.q_learning import QLearningAgent
from environments.sentiment_env import SentimentEnv
import numpy as np
import matplotlib.pyplot as plt

## 1. Initialize Environment and Agent

In [None]:
# Environment: a subset (subset_size=500) of the IMDB training split
env_config = {
    'dataset_name': 'imdb',
    'split': 'train',
    'use_subset': True,
    'subset_size': 500,
}
env = SentimentEnv(**env_config)

# Agent: its action set is the list of sentiment labels reported by the environment
agent_hyperparams = {
    'alpha': 0.1,
    'gamma': 0.9,
    'epsilon': 0.3,
    'epsilon_decay': 0.995,
}
agent = QLearningAgent(actions=env.get_all_sentiments(), **agent_hyperparams)

# Sanity-check what was built
print(f"Environment: {len(env.dataset)} samples")
print(f"Actions: {agent.actions}")

## 2. Train the Agent

In [None]:
# Training parameters
num_episodes = 200
log_every = 50  # print a rolling accuracy over this many recent episodes

# Per-episode metrics (read by the plotting cell below)
episode_rewards = []
episode_accuracies = []

# Each episode is a single classification: one reset, one prediction, one step.
for episode in range(num_episodes):
    text, true_label = env.reset()
    prediction = agent.predict(text, explore=True)
    _, reward, done, info = env.step(prediction)

    # The loop never takes a second step within an episode, so the update is
    # always treated as terminal regardless of the environment's `done` flag.
    agent.learn(text, prediction, reward, done=True)
    agent.decay_epsilon()

    episode_rewards.append(reward)
    episode_accuracies.append(int(info['correct']))

    completed = episode + 1
    if completed % log_every == 0:
        recent_acc = np.mean(episode_accuracies[-log_every:])
        print(f"Episode {completed}: Accuracy = {recent_acc:.3f}, Epsilon = {agent.epsilon:.3f}")

print("\nTraining complete!")

## 3. Visualize Learning

In [None]:
def _plot_with_moving_average(ax, values, window, raw_label):
    """Plot raw per-episode *values* and their trailing *window*-episode moving average on *ax*.

    The moving-average line is drawn only when there are at least *window*
    values. With fewer, ``np.convolve(..., mode='valid')`` returns an array
    whose length does not match the x-range (or raises on empty input), and
    ``ax.plot`` would fail.
    """
    ax.plot(values, alpha=0.3, label=raw_label)
    if len(values) >= window:
        ma = np.convolve(values, np.ones(window) / window, mode='valid')
        # 'valid' yields len(values) - window + 1 points; the first one covers episodes 0..window-1
        ax.plot(range(window - 1, len(values)), ma, label=f'{window}-Episode MA', linewidth=2)


fig, axes = plt.subplots(1, 2, figsize=(14, 5))
window = 20

# Uniform random guessing is correct 1/k of the time for k actions (0.5 for binary sentiment).
random_baseline = 1.0 / len(agent.actions)

# Rewards
_plot_with_moving_average(axes[0], episode_rewards, window, 'Episode Reward')
axes[0].set_xlabel('Episode')
axes[0].set_ylabel('Reward')
axes[0].set_title('Episode Rewards')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy
_plot_with_moving_average(axes[1], episode_accuracies, window, 'Episode Accuracy')
axes[1].axhline(y=random_baseline, color='red', linestyle='--', alpha=0.5, label='Random Baseline')
axes[1].set_xlabel('Episode')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Learning Curve')
axes[1].set_ylim([-0.05, 1.05])
axes[1].legend()
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 4. Test the Agent

In [None]:
# Hand-written reviews: clearly positive, clearly negative, and lukewarm
test_texts = [
    "This movie was absolutely amazing! I loved every minute of it.",
    "Terrible waste of time. Would not recommend.",
    "The film was okay, nothing special."
]

print("Testing trained agent:\n")
for text in test_texts:
    # Greedy prediction (explore=False) alongside the agent's Q-values for this text
    prediction = agent.predict(text, explore=False)
    q_values = agent.get_state_action_values(text)
    # One block per sample, followed by a blank separator line
    print(f"Text: {text}\nPrediction: {prediction}\nQ-values: {q_values}\n")

## 5. Agent Statistics

In [None]:
stats = agent.get_stats()
# Header line, then one indented "key: value" line per statistic
print("Agent Statistics:", *(f"  {key}: {value}" for key, value in stats.items()), sep="\n")