In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up before each cell executes, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import random
from snake import SnakeGame
from agent import A2CAgent
import mlflow
import os
from dotenv import load_dotenv

# Populate the process environment from a local .env file, if one exists.
load_dotenv()

# The tracking server address is mandatory: fail fast when it is missing or empty.
MLFLOW_URI = os.environ.get("MLFLOW_URI")
if MLFLOW_URI is None or MLFLOW_URI == "":
    raise Exception("MLFLOW_URI is not set")

# Route every subsequent MLflow call in this kernel to the configured server.
mlflow.set_tracking_uri(uri=MLFLOW_URI)

In [3]:
import mlflow
import mlflow.pytorch
import torch
from mlflow import log_metric, log_param, log_artifacts
import os

def _log_saved_path(path):
    """Upload ``path`` to the active MLflow run, whether it is a file or a directory.

    ``mlflow.log_artifacts`` logs the *contents of a directory*; a single
    saved file has to be uploaded with ``mlflow.log_artifact`` instead.
    """
    if os.path.isdir(path):
        mlflow.log_artifacts(path)
    else:
        mlflow.log_artifact(path)


def train_agents(episodes):
    """Train an ``A2CAgent`` on ``SnakeGame(10, 10, max_steps=20000)`` and track it in MLflow.

    One MLflow run is opened under the "Snake Game A2C Training" experiment.
    Each episode is played to completion, then ``agent.replay`` is called once
    with a batch size of 64. Per-episode reward, length and (when available)
    losses are logged as metrics with the episode index as the step.

    Side effects:
        * Logs the model to MLflow every 5000 episodes (including episode 0)
          and once more as "final_model" after training.
        * Writes ``./agent_10x10.state`` and ``./rewards_10x10.state`` in the
          current working directory and uploads both to the run.

    Args:
        episodes: Number of episodes to play.

    Returns:
        A ``(agent, rewards)`` tuple: the agent after training and the list of
        accumulated rewards, one entry per episode.
    """
    # Set up MLflow
    mlflow.set_experiment("Snake Game A2C Training")

    agent_state_path = "./agent_10x10.state"
    rewards_path = "./rewards_10x10.state"

    with mlflow.start_run():
        env = SnakeGame(10, 10, max_steps=20000)
        agent = A2CAgent()
        batch_size = 64
        rewards = []

        # Log parameters
        log_param("episodes", episodes)
        log_param("batch_size", batch_size)
        log_param("gamma", agent.gamma)
        log_param("learning_rate", agent.lr)
        log_param("memory_size", agent.maxlen)

        for e in range(episodes):
            state = env.reset()
            acc_reward = 0
            done = False
            steps = 0

            # Play one full episode; the loop condition alone ends it once
            # the environment reports `done`.
            while not done:
                action = agent.act(state)
                next_state, reward, done = env.step(action)
                agent.remember(state, action, reward, next_state, done)
                acc_reward += reward
                state = next_state
                steps += 1

            # One learning update per episode.
            losses = agent.replay(batch_size)
            rewards.append(acc_reward)

            # Log metrics
            log_metric("episode_reward", acc_reward, step=e)
            log_metric("episode_length", steps, step=e)
            # replay() may return None, in which case there are no losses to
            # report (presumably too little stored experience -- confirm in agent.py).
            if losses is not None:
                total_loss, actor_loss, critic_loss = losses
                log_metric("total_loss", total_loss, step=e)
                log_metric("actor_loss", actor_loss, step=e)
                log_metric("critic_loss", critic_loss, step=e)

            if e % 5000 == 0:
                print(f"episode={e}")
                # Log model every 5000 episodes
                mlflow.pytorch.log_model(agent.model, f"model_episode_{e}")

        # Log final model and the agent state saved on disk
        mlflow.pytorch.log_model(agent.model, "final_model")
        agent.save(agent_state_path)
        _log_saved_path(agent_state_path)

        # Save rewards to a file and log it as an artifact
        torch.save(rewards, rewards_path)
        _log_saved_path(rewards_path)

    return agent, rewards

# Kick off the full training run and keep the trained agent and reward history.
agent, rewards = train_agents(episodes=500_000)

  from .autonotebook import tqdm as notebook_tqdm


TypeError: cannot unpack non-iterable NoneType object