-
Notifications
You must be signed in to change notification settings - Fork 0
/
machine_replacement_plot_return_dists.py
84 lines (70 loc) · 2.21 KB
/
machine_replacement_plot_return_dists.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import bayesian_irl
import mdp_worlds
import utils
import mdp
import numpy as np
import scipy
import random
import generate_efficient_frontier
from machine_replacement import generate_posterior_samples
def parse_policy_usas(lines):
    """Parse the text listing of policy occupancy vectors (u_sa's).

    The input is an iterable of text lines made of header lines and data rows:

    * A line containing ``--`` is a header.  If it also contains ``mean`` it
      opens the section for the mean policy; otherwise its second
      space-separated token is parsed as the lambda value of the section.
    * Any other non-blank line is a comma-separated row of floats.  A row in
      the mean section becomes the mean vector (a later row overwrites an
      earlier one); any other row is appended to the per-lambda list.

    Blank lines are skipped.

    Args:
        lines: iterable of strings (for example an open text file).

    Returns:
        A tuple ``(lambdas, lambda_usas, mean_usa)`` where ``lambdas`` is a
        list of floats in file order, ``lambda_usas`` is a list of 1-D numpy
        arrays in file order, and ``mean_usa`` is a 1-D numpy array, or
        ``None`` when no row was found in a mean section.

    Raises:
        ValueError: if a value cannot be converted to float.
        IndexError: if a non-mean header has no second token.
    """
    lambdas = []
    lambda_usas = []
    mean_usa = None
    in_mean_section = False
    for raw_line in lines:
        line = raw_line.strip()
        if not line:
            # A trailing or separating blank line would otherwise reach
            # float("") and abort the whole run.
            continue
        if "--" in line:
            if "mean" in line:
                in_mean_section = True
            else:
                lambdas.append(float(line.split(" ")[1]))
                in_mean_section = False
            continue
        values = np.array([float(item) for item in line.split(",")])
        if in_mean_section:
            mean_usa = values
        else:
            lambda_usas.append(values)
    return lambdas, lambda_usas, mean_usa


def main():
    """Plot histograms of returns under the posterior for selected policies.

    Reads the policy occupancy vectors from
    ``./results/machine_replacement/policy_usas.txt``, draws posterior samples
    with ``generate_posterior_samples``, and histograms the dot product of the
    transposed posterior with each selected occupancy vector.  The figure is
    written to ``./figs/machine_replacement/`` and then shown.

    Raises:
        ValueError: if the results file has no mean section, if the number of
            lambda headers differs from the number of per-lambda rows, or if a
            lambda listed in ``lambdas_to_plot`` is not present in the file.
    """
    # Local imports: only needed when the script is actually run.
    import os

    seed = 1234
    # ``scipy.random`` was just an alias of ``numpy.random`` in older SciPy
    # releases and is not available in recent ones, so seeding numpy is
    # sufficient (and the former ``scipy.random.seed`` call was redundant).
    np.random.seed(seed)
    random.seed(seed)

    num_samples = 200000
    lambdas_to_plot = [0.0, 0.5, 0.95]

    posterior = generate_posterior_samples(num_samples)

    # read in the u_sa's from the file
    with open('./results/machine_replacement/policy_usas.txt') as usas_file:
        lines = usas_file.readlines()
    for line in lines:
        print(line)
    lambdas, lambda_usas, mean_usa = parse_policy_usas(lines)

    if mean_usa is None:
        raise ValueError("no mean policy section found in policy_usas.txt")
    if len(lambdas) != len(lambda_usas):
        # The lookup below pairs the two lists by position, so they must
        # line up one-to-one.
        raise ValueError(
            "expected one u_sa row per lambda header, got {} headers and {} rows".format(
                len(lambdas), len(lambda_usas)
            )
        )

    print(mean_usa)
    for usa in lambda_usas:
        print(usa)
    print()
    print(lambdas)

    # plot the mean returns versus the returns for the selected lambdas
    import matplotlib.pyplot as plt

    to_histogram = []
    labels = []
    for lam in lambdas_to_plot:
        # Exact float match against the values parsed from the file;
        # list.index raises ValueError when the lambda is missing.
        usas = lambda_usas[lambdas.index(lam)]
        # NOTE(review): assumes posterior.T @ usas yields one return per
        # posterior sample -- confirm against generate_posterior_samples.
        to_histogram.append(np.dot(posterior.transpose(), usas))
        # Raw strings keep the LaTeX backslash without an invalid escape.
        labels.append(r"$\lambda={}$".format(lam))
    to_histogram.append(np.dot(posterior.transpose(), mean_usa))
    labels.append(r"mean ($\lambda=1.0$)")

    plt.hist(to_histogram, 100, label=labels, stacked=False, fill=False,
             histtype='step', linewidth=2)
    plt.xlim(-2500, 0)
    plt.yticks(fontsize=18)
    plt.xticks(fontsize=18)
    plt.legend(loc='upper left', fontsize=18)
    plt.xlabel('Discounted return', fontsize=20)
    plt.ylabel('Number of runs', fontsize=20)
    plt.tight_layout()

    figure_path = "./figs/machine_replacement/return_distribution_machine_replacement.png"
    # Make sure the output directory exists so the expensive sampling step
    # is not wasted on a failed save.
    os.makedirs(os.path.dirname(figure_path), exist_ok=True)
    plt.savefig(figure_path)
    plt.show()


if __name__ == "__main__":
    main()