# Cartpole

## A2C Agent 

In [None]:
import gym
from a2c import A2CAgent 
import time
import numpy as np

# Create Gym environment
a2c_env = "CartPole-v1"
env = gym.make(a2c_env)

# Check agent class for initialization parameters and initialize agent
if a2c_env == "CartPole-v1":
    gamma = 0.95
    lr = 1e-3

agent = A2CAgent(env, gamma, lr)

# Define training parameters
max_episodes = 500
max_steps = 500

episode_rewards = []
run_time = []
start_time = time.time()
for episode in range(max_episodes):
    trajectory = []
    state = env.reset()
    episode_reward = 0
    for step in range(max_steps):
        action = agent.get_action(state)
        next_state, reward, done, _ = env.step(action)
        trajectory.append((state, action, reward, next_state, done))
        episode_reward += reward  
        if done or step == max_steps:
            episode_rewards.append(episode_reward)
            print("Episode " + str(episode) + ": " + str(episode_reward))
            break
        state = next_state
    agent.update(trajectory, 0)
    elapse = time.time() - start_time
    run_time.append(elapse)
    
a2c_rewards = episode_rewards
a2c_runtime = run_time

In [None]:
name = './log_files/a2c/' + a2c_env + '-' + str(time.time()) + '.csv' 
out = np.column_stack((a2c_runtime, a2c_rewards))
with open(name, 'ab') as f:
    np.savetxt(f, out, delimiter=',')

## DRPO Agent (KL) 

In [9]:
import gym
import gym_electricitymarket
from drpo import DRTRPOAgent 
import time
import numpy as np

# Create Gym environment
kl_env = 'ElectricityMarketDiscreteDQN-v0'
env = gym.make(kl_env)

# Check agent class for initialization parameters and initialize agent

# When the learning rate is large, policy neural network can overflow and lead to NaNs. 
# A possible fix is to reduce lr or increase beta to lower the learning rate.

if kl_env == "ElectricityMarketDiscreteDQN-v0":
    gamma = 0.95
    lr = 1e-2
    beta = 8
    
agent = DRTRPOAgent(env, gamma, lr)

# Define training parameters
max_episodes = 1500
max_steps = 30

episode_rewards = []
run_time = []
start_time = time.time()
for episode in range(max_episodes):
    if episode == 0:
        first_state = env.reset()
    else:
        first_state = state
    state_adv = []
    total_value_loss = 0
    
    episode_reward = 0
    # loop through the first action
    for i in range(env.action_space.n):
        env.reset()
        state = first_state
        action = i
        trajectory = []
        
        for step in range(max_steps):
            if step != 0:
                action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            trajectory.append((state, action, reward, next_state, done))
            episode_reward += reward  
            if done or step == max_steps-1:
                break
            state = next_state
            
        adv, value_loss = agent.compute_adv_mc(trajectory)
        state_adv.append(adv[0])
        total_value_loss += value_loss
        
    avg_episode_reward = episode_reward/env.action_space.n        
    # add randomness for better exploration
#     if (avg_episode_reward <= 300) and (episode % 10 == 0):
#         state_adv[0] += (np.random.random()-0.5)*0.5
#         state_adv[1] += (np.random.random()-0.5)*0.5
    
#     state_adv[0] += 0.5
    
    # restart the agent if stuck
#     if (episode >= 5) and (avg_episode_reward <= 15):
#         agent = DRTRPOAgent(env, gamma, lr)   
    f
    policy_loss = agent.compute_policy_loss_kl(first_state, state_adv, beta)
    agent.update(value_loss, policy_loss)
    elapse = time.time() - start_time
    run_time.append(elapse)
    
    episode_rewards.append(avg_episode_reward)
    print("Episode " + str(episode) + ": " + str(avg_episode_reward))

Episode 0: -498.2267461048604
Episode 1: -497.5529793506982
Episode 2: -497.64767284097445
Episode 3: -491.37664966527706
Episode 4: -495.203925120832
Episode 5: -478.6152067895799
Episode 6: -480.1957415083302
Episode 7: -466.0819844402748
Episode 8: -472.0835985895797
Episode 9: -466.8353475631937
Episode 10: -449.64080615624886
Episode 11: -450.3046412680523
Episode 12: -444.36733392985815
Episode 13: -425.2618439659688
Episode 14: -418.47252298610937
Episode 15: -399.93933221110626
Episode 16: -394.78214599583043
Episode 17: -380.0066146354148
Episode 18: -365.1101677986087
Episode 19: -361.2507687361076
Episode 20: -342.55640730138964
Episode 21: -341.9417413597205
Episode 22: -318.0394334374973
Episode 23: -303.56667087708337
Episode 24: -303.2415510868051
Episode 25: -290.1661670562507
Episode 26: -268.9253438270836
Episode 27: -265.8768645652758
Episode 28: -254.2565959659708
Episode 29: -250.00158519235978
Episode 30: -253.11160955902815
Episode 31: -251.72122485833026
Episode

Episode 261: 382.1121654757635
Episode 262: 382.4014157070135
Episode 263: 382.30039335354127
Episode 264: 382.6289302500691
Episode 265: 382.214922870208
Episode 266: 382.5921838702081
Episode 267: 382.0237829702078
Episode 268: 382.956878795208
Episode 269: 382.3846603847913
Episode 270: 382.43289753062487
Episode 271: 382.63382385631905
Episode 272: 382.5336140924302
Episode 273: 381.8973344722911
Episode 274: 382.6454607139579
Episode 275: 382.5054893806246
Episode 276: 382.4491118889579
Episode 277: 382.42860344729144
Episode 278: 382.7748372986804
Episode 279: 382.60444020354134
Episode 280: 382.17081956604125
Episode 281: 383.07712602784693
Episode 282: 382.10220189798576
Episode 283: 382.570962245208
Episode 284: 382.1829149063189
Episode 285: 382.19757506534677
Episode 286: 382.6441262972911
Episode 287: 382.3931636702081
Episode 288: 382.45002726951355
Episode 289: 382.30396456395795
Episode 290: 382.76863468409687
Episode 291: 382.4384821445136
Episode 292: 382.4774114861802

Episode 523: 382.9580138778471
Episode 524: 382.86729919729135
Episode 525: 382.8244157945136
Episode 526: 382.52675292506905
Episode 527: 382.93947383618024
Episode 528: 382.687648252847
Episode 529: 382.87195590840247
Episode 530: 383.05272130840257
Episode 531: 382.9063748750692
Episode 532: 382.79756458965255
Episode 533: 382.7325150590969
Episode 534: 382.8505593646525
Episode 535: 383.18600011118036
Episode 536: 382.8705334611803
Episode 537: 382.88342855562485
Episode 538: 382.84366655006926
Episode 539: 382.74051175631917
Episode 540: 382.917843409097
Episode 541: 382.95351208340253
Episode 542: 382.83512373618026
Episode 543: 383.04369788895804
Episode 544: 383.05036483340257
Episode 545: 382.9709465868748
Episode 546: 382.9210661979859
Episode 547: 382.8804690278472
Episode 548: 383.18794173618045
Episode 549: 382.96436198618034
Episode 550: 382.8530085139581
Episode 551: 382.92755711118025
Episode 552: 383.18039400006916
Episode 553: 382.9078012167359
Episode 554: 382.588108

Episode 784: 383.0083192222914
Episode 785: 383.12139823618037
Episode 786: 382.8527360313191
Episode 787: 382.90224197506916
Episode 788: 382.73037796465263
Episode 789: 383.0676532118748
Episode 790: 382.57058769729144
Episode 791: 382.91931643965256
Episode 792: 382.78551365631904
Episode 793: 383.2164152500693
Episode 794: 382.94032853131927
Episode 795: 383.169364166736
Episode 796: 382.9838829611804
Episode 797: 382.67862424520797
Episode 798: 382.88293819451354
Episode 799: 383.16246041673594
Episode 800: 382.5170034924302
Episode 801: 382.92699882576363
Episode 802: 382.72000600562467
Episode 803: 382.9656013611803
Episode 804: 383.00010529868024
Episode 805: 382.8074001695137
Episode 806: 382.5791612389579
Episode 807: 382.78433260840245
Episode 808: 383.0927924028469
Episode 809: 383.1657418056248
Episode 810: 382.7980710486803
Episode 811: 382.8246575722913
Episode 812: 382.8021903292358
Episode 813: 382.7722561840969
Episode 814: 382.8273181132636
Episode 815: 382.376082730

Episode 1044: 383.2324923084026
Episode 1045: 383.10150680562475
Episode 1046: 382.9024015090969
Episode 1047: 382.76735615354147
Episode 1048: 383.22516517784703
Episode 1049: 383.25786583340266
Episode 1050: 382.9247022146524
Episode 1051: 382.99876158687476
Episode 1052: 383.1814993229859
Episode 1053: 382.98499558965256
Episode 1054: 382.6192672396523
Episode 1055: 382.88294165631913
Episode 1056: 382.7880872361803
Episode 1057: 383.0271998750692
Episode 1058: 383.21056013895816
Episode 1059: 383.027219902847
Episode 1060: 382.77550922576347
Episode 1061: 383.26667911395816
Episode 1062: 382.55003867020804
Episode 1063: 382.6476287917358
Episode 1064: 383.26785618062473
Episode 1065: 382.70814815284695
Episode 1066: 383.15528040284704
Episode 1067: 383.22816518340267
Episode 1068: 383.0431354334024
Episode 1069: 382.67075331187465
Episode 1070: 382.9790805757636
Episode 1071: 382.9638874445136
Episode 1072: 382.77404975631924
Episode 1073: 382.85575341951375
Episode 1074: 382.99640

Episode 1298: 382.9048678722915
Episode 1299: 382.8743996834025
Episode 1300: 383.18650997229145
Episode 1301: 382.9541085313193
Episode 1302: 382.6618390313192
Episode 1303: 382.87465735423575
Episode 1304: 383.1135372396525
Episode 1305: 382.9654039306248
Episode 1306: 383.1247079445137
Episode 1307: 382.90916045840254
Episode 1308: 382.8801976597914
Episode 1309: 383.22364729173586
Episode 1310: 382.9150513084024
Episode 1311: 382.8666871250691
Episode 1312: 382.9532995028471
Episode 1313: 382.7565367729858
Episode 1314: 383.10297626673594
Episode 1315: 382.9819526972915
Episode 1316: 383.15479930562486
Episode 1317: 382.7009551695136
Episode 1318: 383.2055838889581
Episode 1319: 382.8881732639581
Episode 1320: 383.0417965500692
Episode 1321: 382.82551864173575
Episode 1322: 382.7581042063192
Episode 1323: 383.0749012063192
Episode 1324: 383.0187301459026
Episode 1325: 382.9566217306247
Episode 1326: 383.13597114520815
Episode 1327: 382.90208366673596
Episode 1328: 383.0140147674304

In [10]:
dr_trpo_kl_rewards = episode_rewards
dr_trpo_kl_runtime = run_time

In [11]:
name = './log_files/dr_trpo_kl/' + kl_env + '-' + str(time.time()) + '.csv' 
out = np.column_stack((dr_trpo_kl_runtime, dr_trpo_kl_rewards))
with open(name, 'ab') as f:
    np.savetxt(f, out, delimiter=',')

## DRPO Agent (Wasserstein)

In [119]:
import gym
from drpo import DRTRPOAgent 
import time
import numpy as np

wass_env = "CartPole-v1"
# Create Gym environment
env = gym.make(wass_env)

# Check agent class for initialization parameters and initialize agent
if wass_env == "CartPole-v1":
    gamma = 0.95
    lr = 1e-2
    
agent = DRTRPOAgent(env, gamma, lr)

# Define training parameters
max_episodes = 150
max_steps = 500
total_adv_diff = 0

episode_rewards = []
run_time = []
start_time = time.time()
for episode in range(max_episodes):
    if episode == 0:
        first_state = env.reset()
    else:
        first_state = state
    state_adv = []
    total_value_loss = 0
    
    episode_reward = 0
    # loop through the first action
    for i in range(env.action_space.n):
        env.reset()
        state = first_state
        action = i
        trajectory = []
        
        for step in range(max_steps):
            if step != 0:
                action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            trajectory.append((state, action, reward, next_state, done))
            episode_reward += reward  
            if done or step == max_steps-1:
                break
            state = next_state
            
        adv, value_loss = agent.compute_adv_mc(trajectory)
        state_adv.append(adv[0])
        total_value_loss += value_loss
    
    total_adv_diff += abs(state_adv[1] - state_adv[0])
    # larger beta, better stability; smaller beta, better exploration
    beta = total_adv_diff/episode 
    beta += np.random.random()*0.3-0.1
    
    avg_episode_reward = episode_reward/env.action_space.n
    # add randomness for better exploration
    if (episode % 10 == 0) and (avg_episode_reward <= 350): 
        state_adv[0] += (np.random.random()-0.5)*0.5
        state_adv[1] += (np.random.random()-0.5)*0.5
        
    state_adv[0] += 0.5
        
    # restart the agent if stuck
    if (episode >= 5) and (avg_episode_reward <= 15):
        agent = DRTRPOAgent(env, gamma, lr)
    
    policy_loss = agent.compute_policy_loss_wass(first_state, state_adv, beta)
    agent.update(value_loss, policy_loss)
    elapse = time.time() - start_time
    run_time.append(elapse)
    
    episode_rewards.append(avg_episode_reward)
    print("Episode " + str(episode) + ": " + str(avg_episode_reward))

dr_trpo_wass_rewards = episode_rewards
dr_trpo_wass_runtime = run_time

Episode 0: 20.0
Episode 1: 15.0
Episode 2: 13.5
Episode 3: 12.5
Episode 4: 11.0
Episode 5: 11.0
Episode 6: 15.5
Episode 7: 20.0
Episode 8: 20.5
Episode 9: 20.5
Episode 10: 40.0
Episode 11: 50.5
Episode 12: 86.5
Episode 13: 75.0
Episode 14: 62.5
Episode 15: 56.5
Episode 16: 46.5
Episode 17: 75.0
Episode 18: 46.0
Episode 19: 52.5
Episode 20: 50.0
Episode 21: 35.0
Episode 22: 26.0
Episode 23: 27.0
Episode 24: 17.5
Episode 25: 16.5
Episode 26: 14.5
Episode 27: 20.5
Episode 28: 18.0
Episode 29: 18.0
Episode 30: 19.0
Episode 31: 24.0
Episode 32: 19.0
Episode 33: 27.0
Episode 34: 32.0
Episode 35: 59.5
Episode 36: 71.5
Episode 37: 63.0
Episode 38: 105.0
Episode 39: 41.0
Episode 40: 33.5
Episode 41: 30.0
Episode 42: 26.5
Episode 43: 26.5
Episode 44: 19.0
Episode 45: 28.0
Episode 46: 27.0
Episode 47: 31.5
Episode 48: 38.0
Episode 49: 44.0
Episode 50: 47.0
Episode 51: 83.5
Episode 52: 142.5
Episode 53: 95.0
Episode 54: 76.5
Episode 55: 93.5
Episode 56: 53.0
Episode 57: 91.5
Episode 58: 74.5
Episo

In [120]:
name = './log_files/dr_trpo_wass/' + wass_env + '-' + str(time.time()) + '.csv' 
out = np.column_stack((dr_trpo_wass_runtime, dr_trpo_wass_rewards))
with open(name, 'ab') as f:
    np.savetxt(f, out, delimiter=',')

## DRPO Agent (Sinkhorn)

In [None]:
import gym
from drpo import DRTRPOAgent 
import time
import numpy as np

sink_env = "CartPole-v1"
# Create Gym environment
env = gym.make(sink_env)

# Check agent class for initialization parameters and initialize agent
if wass_env == "CartPole-v1":
    gamma = 0.95
    lr = 1e-2
    
agent = DRTRPOAgent(env, gamma, lr)

# Define training parameters
max_episodes = 200
max_steps = 500
total_adv_diff = 0

episode_rewards = []
run_time = []
start_time = time.time()
for episode in range(max_episodes):
    if episode == 0:
        first_state = env.reset()
    else:
        first_state = state
    state_adv = []
    total_value_loss = 0
    
    episode_reward = 0
    # loop through the first action
    for i in range(env.action_space.n):
        env.reset()
        state = first_state
        action = i
        trajectory = []
        
        for step in range(max_steps):
            if step != 0:
                action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            trajectory.append((state, action, reward, next_state, done))
            episode_reward += reward  
            if done or step == max_steps-1:
                break
            state = next_state
            
        adv, value_loss = agent.compute_adv_mc(trajectory)
        state_adv.append(adv[0])
        total_value_loss += value_loss

    total_adv_diff += abs(state_adv[1] - state_adv[0])
    # larger beta, better stability; smaller beta, better exploration
    beta = total_adv_diff/episode 
    beta += np.random.random()*0.3-0.1
    
    avg_episode_reward = episode_reward/env.action_space.n
    # add randomness for better exploration
    if (episode % 10 == 0) and (avg_episode_reward <= 350): 
        state_adv[0] += (np.random.random()-0.5)*0.5
        state_adv[1] += (np.random.random()-0.5)*0.5
        
    # restart the agent if stuck
    if (episode >= 5) and (avg_episode_reward <= 10):
        agent = DRTRPOAgent(env, gamma, lr)
    
    beta = 50
    policy_loss = agent.compute_policy_loss_sinkhorn(first_state, state_adv, beta)
    agent.update(value_loss, policy_loss)
    elapse = time.time() - start_time
    run_time.append(elapse)
    
    episode_rewards.append(avg_episode_reward)
    print("Episode " + str(episode) + ": " + str(avg_episode_reward))

dr_trpo_sink_rewards = episode_rewards
dr_trpo_sink_runtime = run_time

In [None]:
name = './log_files/dr_trpo_sink/' + sink_env + '-' + str(time.time()) + '.csv' 
out = np.column_stack((dr_trpo_sink_runtime, dr_trpo_sink_rewards))
with open(name, 'ab') as f:
    np.savetxt(f, out, delimiter=',')