In [5]:
import datetime
import os
import torch

from omegaconf import OmegaConf
import cProfile, pstats

from svpg.algos import A2C

In [6]:

# Timestamp fragment reused by both the logger directory and the run
# directory. The leading and trailing "/" are part of the fragment, so it is
# appended to the base paths by plain string concatenation.
dtime = datetime.datetime.now().strftime("/%y-%m-%d/%H-%M-%S/")

# Experiment configuration; converted to an OmegaConf object below and handed
# to A2C as-is.
params = {
    "save_best": True,
    "logger": {
        "classname": "salina.logger.TFLogger",
        # NOTE(review): "./tmp/" + dtime yields a doubled slash
        # ("./tmp//yy-mm-dd/...") because dtime already starts with "/".
        # Kept unchanged here; confirm the logger is fine with it.
        "log_dir": "./tmp/" + dtime,
        "verbose": True,
        "cache_size": 10000,
        "every_n_seconds": 10,
    },
    "algorithm": {
        "n_particles": 16,
        "seed": 5,
        "n_envs": 1,
        "n_steps": 10000,
        "eval_interval": 99,
        "n_evals": 1,
        "clipped": True,
        "max_epochs": 2,
        "discount_factor": 0.95,
        "policy_coef": 0.1,
        "critic_coef": 1.0,
        "entropy_coef": 1e-3,
        "architecture": {"hidden_size": [100, 50, 25]},
    },
    "gym_env": {
        "classname": "svpg.agents.env.make_gym_env",
        "env_name": "MyCartPole-v0",
    },
    "optimizer": {"classname": "torch.optim.Adam", "lr": 5e-3},
}

config = OmegaConf.create(params)

# Output directory passed to a2c.run() in the next cell:
# ../runs/<env_name>/<yy-mm-dd>/<HH-MM-SS>/
directory = "../runs/" + config.gym_env.env_name + dtime

# exist_ok=True creates the directory (and any missing parents) in one call
# and does not fail if it already exists, avoiding the check-then-create race
# of a separate os.path.exists() test.
os.makedirs(directory, exist_ok=True)

# Seed PyTorch's global RNG from the configured seed.
torch.manual_seed(config.algorithm.seed)

a2c = A2C(config, solo=True)

# Profiler instance is created here and enabled/disabled around the run in the
# next cell.
profiler = cProfile.Profile()


In [7]:
# Profile only the a2c.run(...) call. The try/finally guarantees the profiler
# is switched off even if the run raises or the cell is interrupted; otherwise
# it would stay enabled in the kernel and keep recording later cells.
profiler.enable()
try:
    a2c.run(directory)
finally:
    profiler.disable()

total_policy_loss tensor(14.4531, grad_fn=<RsubBackward1>)
total_critic_loss tensor(89.7181, grad_fn=<AddBackward0>)
total_entropy_loss tensor(-1.6033, grad_fn=<RsubBackward1>)
total_loss tensor(91.1618, grad_fn=<AddBackward0>)
['reward_0' at 10000] = 99.95480346679688
total_policy_loss tensor(14.3515, grad_fn=<RsubBackward1>)
total_critic_loss tensor(89.4323, grad_fn=<AddBackward0>)
total_entropy_loss tensor(-1.6062, grad_fn=<RsubBackward1>)
total_loss tensor(90.8659, grad_fn=<AddBackward0>)
['reward_0' at 20000] = 89.93496704101562


In [8]:
# Rank the profiled functions by time spent inside the function itself
# ("tottime", reported by pstats as "internal time").
stats = pstats.Stats(profiler).sort_stats("tottime")

# Print only the 30 most expensive entries instead of every profiled function.
# print_stats() returns the Stats object, so the trailing ";" stops the
# notebook from echoing its repr below the table.
stats.print_stats(30);

         12757468 function calls (12342150 primitive calls) in 31.339 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        2    7.908    3.954    7.908    3.954 {method 'run_backward' of 'torch._C._EngineBase' objects}
   160088    2.055    0.000    2.055    0.000 {built-in method torch._C._nn.linear}
    20022    1.063    0.000    1.882    0.000 c:\Users\Jules Dubreuil\AppData\Local\Programs\Python\Python39\lib\site-packages\torch\distributions\normal.py:71(log_prob)
   240045    0.985    0.000    0.985    0.000 {method 'unsqueeze' of 'torch._C._TensorBase' objects}
   140172    0.928    0.000    0.928    0.000 {built-in method cat}
   120066    0.775    0.000    0.775    0.000 {built-in method relu}
    19999    0.700    0.000    0.700    0.000 {built-in method normal}
   220314    0.611    0.000    1.086    0.000 c:\users\jules dubreuil\documents\universite\projet-androide\pandroide-svpg\libs\salina\salina\workspace.p

<pstats.Stats at 0x1ec51678310>