In [1]:
import torch
import numpy as np

from single_state_mdp import SingleStateMDP
from structures import Trainer, Critic, ReplayBuffer

# Master seed: makes the per-run seeds below -- and therefore every experiment
# in this notebook -- reproducible under Restart Kernel -> Run All.
SEED = 0
N_RUNS = 10

# One independent seed per experimental run, in [0, 2**32 - 1).
# dtype is pinned to int64: the legacy np.random.randint defaults to the
# platform C long, which is 32-bit on Windows, where this upper bound would
# raise "ValueError: high is out of bounds for int32".
seeds = np.random.default_rng(SEED).integers(0, 2**32 - 1, size=N_RUNS, dtype=np.int64)

In [3]:
# Compare the "MIN" and "MIN2" bias-correction methods of Trainer, one run per
# seed, and report the per-method average of the three metrics returned by
# Trainer.evaluate (printed as averages / variances / distances).
#
# torch is reseeded with the run's seed before each method's critic is built,
# so both methods start from identical initial weights (a paired comparison)
# and the cell gives the same numbers on every fresh-kernel run.

# Derive the run count from `seeds` instead of hard-coding 10, so the result
# arrays cannot fall out of sync with the seeds cell.
n_runs = len(seeds)

min1_means = np.zeros(n_runs)
min1_variances = np.zeros(n_runs)
min1_distances = np.zeros(n_runs)

min2_means = np.zeros(n_runs)
min2_variances = np.zeros(n_runs)
min2_distances = np.zeros(n_runs)

# The action grids do not depend on the seed, so build them once.
train_action_grid = torch.linspace(-1, 1, 50)
eval_action_grid = torch.linspace(-1, 1, 2000)

for i, seed in enumerate(seeds):
    # Setup env
    env = SingleStateMDP(seed=seed)

    # Training data; max_size matches the grid length
    train_replay_buffer = ReplayBuffer(env, train_action_grid, max_size=len(train_action_grid))

    # Train each method's critic from the same torch RNG state
    torch.manual_seed(int(seed))
    critic1 = Critic(n_nets=10)
    min_trainer1 = Trainer(critic1, bias_correction_method="MIN")
    min_trainer1.train(train_replay_buffer, 3000)

    torch.manual_seed(int(seed))
    critic2 = Critic(n_nets=10)
    min_trainer2 = Trainer(critic2, bias_correction_method="MIN2")
    min_trainer2.train(train_replay_buffer, 3000)

    # Evaluate both methods on the same dense action grid
    eval_replay_buffer = ReplayBuffer(env, eval_action_grid, max_size=len(eval_action_grid))

    mean1, var1, dist1 = min_trainer1.evaluate(env, eval_replay_buffer, verbose=1)
    mean2, var2, dist2 = min_trainer2.evaluate(env, eval_replay_buffer, verbose=1)

    min1_means[i] = mean1
    min1_variances[i] = var1
    min1_distances[i] = dist1

    min2_means[i] = mean2
    min2_variances[i] = var2
    min2_distances[i] = dist2

print(f"averages : {min1_means.mean()} - {min2_means.mean()}")
print(f"variances : {min1_variances.mean()} - {min2_variances.mean()}")
print(f"distances : {min1_distances.mean()} - {min2_distances.mean()}")

argmax: 1.0 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: 1.0 - a*: 0.33166587352752686
argmax: 0.3236619234085083 - a*: 0.33166587352752686
argmax: 1.0 - a*: 0.33166587352752686
argmax: 0.36168086528778076 - a*: 0.33166587352752686
argmax: 1.0 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: 1.0 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: 1.0 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: 1.0 - a*: 0.33166587352752686
argmax: 0.29764890670776367 - a*: 0.33166587352752686
argmax: 1.0 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: 1.0 - a*: 0.33166587352752686
argmax: 0.34267139434814453 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
averages : 428.56835174560547 - 50.91926517486572
variances : 44.031157922744754 - 0.10818810909986495
distances : 0.7346673011779785 - 0.8073036670684814


In [2]:
# Compare the "AVG" and "AVG2" bias-correction methods of Trainer, one run per
# seed, and report the per-method average of the three metrics returned by
# Trainer.evaluate (printed as averages / variances / distances).
#
# torch is reseeded with the run's seed before each method's critic is built,
# so both methods start from identical initial weights (a paired comparison)
# and the cell gives the same numbers on every fresh-kernel run.

# Derive the run count from `seeds` instead of hard-coding 10, so the result
# arrays cannot fall out of sync with the seeds cell.
n_runs = len(seeds)

mean1_means = np.zeros(n_runs)
mean1_variances = np.zeros(n_runs)
mean1_distances = np.zeros(n_runs)

mean2_means = np.zeros(n_runs)
mean2_variances = np.zeros(n_runs)
mean2_distances = np.zeros(n_runs)

# The action grids do not depend on the seed, so build them once.
train_action_grid = torch.linspace(-1, 1, 50)
eval_action_grid = torch.linspace(-1, 1, 2000)

for i, seed in enumerate(seeds):
    # Setup env
    env = SingleStateMDP(seed=seed)

    # Training data; max_size matches the grid length
    train_replay_buffer = ReplayBuffer(env, train_action_grid, max_size=len(train_action_grid))

    # Train each method's critic from the same torch RNG state
    torch.manual_seed(int(seed))
    critic1 = Critic(n_nets=10)
    mean_trainer1 = Trainer(critic1, bias_correction_method="AVG")
    mean_trainer1.train(train_replay_buffer, 3000)

    torch.manual_seed(int(seed))
    critic2 = Critic(n_nets=10)
    mean_trainer2 = Trainer(critic2, bias_correction_method="AVG2")
    mean_trainer2.train(train_replay_buffer, 3000)

    # Evaluate both methods on the same dense action grid
    eval_replay_buffer = ReplayBuffer(env, eval_action_grid, max_size=len(eval_action_grid))

    mean1, var1, dist1 = mean_trainer1.evaluate(env, eval_replay_buffer, verbose=1)
    mean2, var2, dist2 = mean_trainer2.evaluate(env, eval_replay_buffer, verbose=1)

    mean1_means[i] = mean1
    mean1_variances[i] = var1
    mean1_distances[i] = dist1

    mean2_means[i] = mean2
    mean2_variances[i] = var2
    mean2_distances[i] = dist2

print(f"averages : {mean1_means.mean()} - {mean2_means.mean()}")
print(f"variances: {mean1_variances.mean()} - {mean2_variances.mean()}")
print(f"distances: {mean1_distances.mean()} - {mean2_distances.mean()}")

argmax: 0.1135568618774414 - a*: 0.33166587352752686
argmax: -0.16058027744293213 - a*: 0.33166587352752686
argmax: 0.15857934951782227 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: 1.0 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: -0.0985492467880249 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: 0.08454227447509766 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: -0.09654825925827026 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: 0.09354686737060547 - a*: 0.33166587352752686
argmax: -0.17358678579330444 - a*: 0.33166587352752686
argmax: 0.15557777881622314 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: 0.22961485385894775 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
argmax: -1.0 - a*: 0.33166587352752686
averages : 1620.9749633789063 - 976.980093383789
variances: 71.923744