In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import random
from snake import SnakeGame
from agent import A2CAgent
import mlflow
import os
from dotenv import load_dotenv

load_dotenv()

MLFLOW_URI = os.getenv("MLFLOW_URI")
if not MLFLOW_URI:
    raise Exception("MLFLOW_URI is not set")

mlflow.set_tracking_uri(uri=MLFLOW_URI)

In [8]:
import mlflow
import mlflow.pytorch
import torch
from mlflow import log_metric, log_param, log_artifacts
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

def train_agents(episodes):
    mlflow.set_experiment("Snake Game A2C Training")
    
    with mlflow.start_run():
        env = SnakeGame(10, 10, max_steps=20000)
        agent = A2CAgent(device=device)
        batch_size = 64
        rewards = []
        
        # Log parameters
        log_param("episodes", episodes)
        log_param("batch_size", batch_size)
        log_param("gamma", agent.gamma)
        log_param("learning_rate", agent.lr)
        log_param("memory_size", agent.maxlen)
        
        # Initialize metric accumulators
        metric_buffer = {
            "episode_reward": [],
            "episode_length": [],
            "total_loss": [],
            "actor_loss": [],
            "critic_loss": [],
            "entropy": []
        }
        
        for e in range(episodes):
            state = env.reset()
            acc_reward = 0
            done = False
            steps = 0
            
            while not done:
                action = agent.act(state)
                next_state, reward, done = env.step(action)
                agent.remember(state, action, reward, next_state, done)
                acc_reward += reward
                state = next_state
                steps += 1
            
            losses = agent.replay(batch_size)
            rewards.append(acc_reward)
            
            # Accumulate metrics
            metric_buffer["episode_reward"].append(acc_reward)
            metric_buffer["episode_length"].append(steps)
            if losses is not None:
                total_loss, actor_loss, critic_loss, entropy = losses
                metric_buffer["total_loss"].append(total_loss)
                metric_buffer["actor_loss"].append(actor_loss)
                metric_buffer["critic_loss"].append(critic_loss)
                metric_buffer["entropy"].append(entropy)
            
            # Log metrics every 100 episodes
            if e % 100 == 0 and e > 0:
                for metric, values in metric_buffer.items():
                    if values:
                        log_metric(metric, sum(values) / len(values), step=e)
                metric_buffer = {k: [] for k in metric_buffer}
            
            if e % 100_000 == 0 and e > 0:
                print(f"episode={e}")
                # Log model every 100k episodes
                mlflow.pytorch.log_model(agent.model, f"model_episode_{e}")
        
        # Log final model and rewards
        mlflow.pytorch.log_model(agent.model, "final_model")
        agent.save("./agent_10x10.state")
        log_artifacts("agent_10x10.state")
        
        # Save rewards to a file and log as artifact
        torch.save(rewards, "./rewards_10x10.state")
        log_artifacts("rewards_10x10.state")
    
    return agent, rewards

# Run the training
agent, rewards = train_agents(500000)

In [None]:
%load_ext line_profiler


In [None]:

# Profile the train_agents function
%lprun -f train_agents train_agents(1000)


Agent saved to agent_10x10.state




Timer unit: 1e-09 s

Total time: 16.8762 s
File: /tmp/ipykernel_2588604/82477549.py
Function: train_agents at line 10

Line #      Hits         Time  Per Hit   % Time  Line Contents
    10                                           def train_agents(episodes):
    11         1    6232367.0    6e+06      0.0      mlflow.set_experiment("Snake Game A2C Training")
    12                                               
    13         2   14530720.0    7e+06      0.1      with mlflow.start_run():
    14         1     138733.0 138733.0      0.0          env = SnakeGame(10, 10, max_steps=20000)
    15         1    1566002.0    2e+06      0.0          agent = A2CAgent(device=device)
    16         1        150.0    150.0      0.0          batch_size = 64
    17         1        150.0    150.0      0.0          rewards = []
    18                                                   
    19                                                   # Log parameters
    20         1    3759891.0    4e+06      0

In [7]:
import torch.autograd.profiler as profiler 

def profile_train_step(agent, env, pre_simulate_games=300):
    # Pre-simulate games to fill the replay buffer
    for _ in range(pre_simulate_games):
        state = env.reset()
        done = False
        while not done:
            action = agent.act(state)
            next_state, reward, done = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                break

        agent.replay(256)


    # Now start profiling
    state = env.reset()
    with profiler.profile(use_cuda=True) as prof:
        for _ in range(100):  # Profile 100 steps
            action = agent.act(state)
            next_state, reward, done = env.step(action)
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                state = env.reset()
        losses = agent.replay(256)
    
    print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))


env = SnakeGame(width=10, height=10, max_steps=20000)
agent = A2CAgent(device=device)
profile_train_step(agent, env)

STAGE:2024-07-11 04:48:42 2625213:2625213 ActivityProfilerController.cpp:312] Completed Stage: Warm Up
STAGE:2024-07-11 04:48:42 2625213:2625213 ActivityProfilerController.cpp:318] Completed Stage: Collection
STAGE:2024-07-11 04:48:42 2625213:2625213 ActivityProfilerController.cpp:322] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                      aten::multinomial        16.09%      11.173ms        48.45%      33.649ms     336.490us       7.441ms         9.40%      33.907ms     339.070us           100  
                                               aten::to         2.00%       1.390ms         9.89%       6.866ms      21.191us       1.655ms         2.09%       7.716ms      23.815us           324  
         