In [1]:
import argparse
import gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
print(torch.cuda.is_available())
if torch.cuda.is_available() and False:
    print ("cuda in use")
    device = torch.device('cuda') 
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
else:
    print ("cuda not used")
    device = torch.device('cpu')
    torch.set_default_tensor_type('torch.FloatTensor')
    dtype = torch.float32

True
cuda not used


In [2]:
parser = argparse.ArgumentParser(description='PyTorch actor-critic example')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G', help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=543, metavar='N', help='random seed (default: 543)')
# parser.add_argument('--render', action='store_true', help='render the environment')
parser.add_argument('--render', type=bool,default=False, help='render the environment')
parser.add_argument('--trace', type=bool,default=False, help='render the environment')
parser.add_argument('--log-interval', type=int, default=100, metavar='N', help='interval between training status logs (default: 10)')
parser.add_argument('-f','--file',help='Path for input file. (Dummy arg to enable execution in notebook.)' )
args = parser.parse_args() 

In [25]:
        
#       Cart Pole state values
#         Num	Observation           Min         Max
#         0	Cart Position             -4.8            4.8
#         1	Cart Velocity             -Inf            Inf
#         2	Pole Angle                -24 deg        24 deg
#         3	Pole Velocity At Tip      -Inf            Inf

        
class World():
    
    def __init__(self,hidden_nodes=12):
        self.env = gym.make('CartPole-v0')
#         self.env = gym.make('CartPole-v1')
#         self.env = gym.make('Acrobot-v1')
        self.zeros = torch.zeros([self.dimension_count()], requires_grad=False, dtype=dtype,device=device)
        self.loss_function = torch.nn.L1Loss(reduction='sum')#Mean Absolute Error
        self.reset()
        
    def reset(self):
        self.done = False
        self.state = torch.tensor(self.env.reset(), requires_grad=False, dtype=dtype, device=device)
    
    def dimension_count(self):
        return self.env.observation_space.shape[0]
        
    def action_count(self):
        return self.env.action_space.n
    
    def step(self,action):
        self.prior_state = self.state
        self.state, self.reward, self.done, _ = self.env.step(action)
        self.state = torch.tensor(self.state, requires_grad=False, dtype=dtype, device=device)
        if args.render: self.env.render()
    
    def value(self):
        return 1./torch.log(1.01 + self.loss_function(self.zeros,self.state))

class Critic(nn.Module):
    
    def __init__(self,world: World,hidden_nodes=32):
        super(Critic, self).__init__()
        self.world = world
        self.one = torch.ones([1], requires_grad=False, dtype=dtype, device=device)
        self.zero = torch.zeros([1], requires_grad=False, dtype=dtype, device=device)
        self.l1 = nn.Linear(world.dimension_count(),hidden_nodes)
        self.l1.weight.data.normal_(0.0, np.sqrt(1./(world.dimension_count())))
        self.head = nn.Linear(hidden_nodes, 1)
        self.head.weight.data.normal_(0.0, np.sqrt(1./(hidden_nodes)))
        self.prev_value = self.zero
        self.value = self.zero
        self.loss_function = torch.nn.L1Loss(reduction="mean")
        self.optimizer = optim.Adam(self.parameters(), lr=3e-3,weight_decay=0.0001)#lr=4e-5,weight_decay=0.00001)
        
    def forward(self, state):
        self.prev_value = self.value
        self.l1_out = F.selu(self.l1(state))
        self.value = self.head(self.l1_out)
        return self.value
    
    def step(self):
        self.optimizer.step()
        
    
    #What the previous value should have been knowing what we know after the last state transition
    def hindsight_value(self):
        #Do not include gradient of the critic value here, just the data.
        v = self.world.value() # or world_actor.value if dreaming.... come back to this.
        return v * self.zero if self.world_actor.world.done else v + args.gamma * self.value.data
         
    #Temporal Difference Loss is for the previous state!
    def get_loss(self):
        self.loss = self.loss_function(self.prev_value,self.hindsight_value())
        if args.trace: print("Critic value and loss:",self.prev_value,self.loss)
        return self.loss

class World_Actor(nn.Module):
    
    def __init__(self,world):
        super(World_Actor, self).__init__()
        self.world = world
        input_size = world.dimension_count()
        self.zeros = torch.zeros([input_size], requires_grad=False, dtype=dtype,device=device)
        self.actor_training = True
        self.world_training = True
        
        hidden_action_nodes = 24
        self.actor1 = nn.Linear(input_size,hidden_action_nodes)
        self.actor1.weight.data.normal_(0.0, np.sqrt(1./input_size))
        self.actor_optimizer = optim.Adam(self.actor1.parameters(), lr=3e-2,weight_decay=0.0001)#lr=4e-5,weight_decay=0.00001)
        
        self.actor2 = nn.Linear(hidden_action_nodes,hidden_action_nodes)
        self.actor2.weight.data.normal_(0.0, np.sqrt(1./hidden_action_nodes))
        self.actor_optimizer.add_param_group({'params': self.actor2.weight})
        
        self.actor_final = nn.Linear(hidden_action_nodes, world.action_count())
        self.actor_final.weight.data.normal_(0.0, np.sqrt(1./(hidden_action_nodes)))
        self.actor_optimizer.add_param_group({'params': self.actor_final.weight})
        
        hidden_world_nodes = 12
        parms = {}
        input_size = world.dimension_count()+world.action_count()
        
        self.world1 = nn.Linear(input_size,hidden_world_nodes)
        self.world1.weight.data.normal_(0.0, np.sqrt(1./input_size))
        self.world_optimizer = optim.Adam(self.world1.parameters(), lr=3e-4,weight_decay=0.0001)#lr=4e-5,weight_decay=0.00001)
        
        self.world2 = nn.Linear(hidden_world_nodes,hidden_world_nodes)
        self.world2.weight.data.normal_(0.0, np.sqrt(1./hidden_world_nodes))
        self.world_optimizer.add_param_group({'params': self.world2.weight})
        
        self.world_final = nn.Linear(hidden_world_nodes, world.dimension_count())
        self.world_final.weight.data.normal_(0.0, np.sqrt(1./hidden_world_nodes))
        self.world_optimizer.add_param_group({'params': self.world_final.weight})
        
        self.loss_function = nn.MSELoss()
        self.mean_loss = 100.0
        
    def call_actor(self,state):
        self.actor1_out = F.selu(self.actor1(state))
        self.actor2_out = F.selu(self.actor2(self.actor1_out))
        self.actor_value = F.softmax(F.selu(self.actor_final(self.actor2_out)),dim=0)
        self.categories = Categorical(self.actor_value)
        self.action = self.categories.sample()
        self.policy_action = self.actor_value.argmax();
        if args.trace: print("Action Scores:",self.categories.probs,"Selected Action:",self.action.item(),"Policy Action:", self.policy_action)
        
    def random_action(self, state):
        self.action = torch.tensor([random.randint(0,self.world.action_count()-1)], requires_grad=False, dtype=torch.int, device=device)
        self.actor_value = torch.zeros([self.world.action_count()], requires_grad=False, dtype=dtype, device=device)
        self.actor_value[self.action.item()] = 1.
        self.categories = Categorical(self.actor_value)
            
    def call_world(self, state):
        state_action = torch.cat([state, self.actor_value], dim=0)
        self.world1_out = F.selu(self.world1(state_action))
        self.world2_out = F.selu(self.world2(self.world1_out))
        self.world_value = self.world_final(self.world2_out)
        if args.trace: print("world value:",self.world_value)
    
    def forward(self,state):
        self.prior_state = state
        if self.actor_training:
            self.call_actor(state)
        else:
            self.random_action(state)
        self.call_world(state)
        return self.world_value
    
    def value(self):
        return self.loss_function(self.zeros,self.world_value)
    
    def get_actor_loss(self):
        #target of zeros is the perfect pole balance at the center.
        loss = self.critic.get_loss()
#         loss = self.loss_function(self.zeros,self.world_value)
        return loss
    
    def get_world_loss(self,target):
        loss = self.loss_function(target,self.world_value)
        self.mean_loss = 0.95 * self.mean_loss + 0.05 * loss
        return loss
    
    def update_actor(self):
        loss = self.get_actor_loss()
        loss.backward(retain_graph=True)
        self.actor_optimizer.step()

    def update_world(self,target):
        loss = self.get_world_loss(target)
        loss.backward(retain_graph=True)
        self.world_optimizer.step()
        
    def update_critic():
        loss = self.critic.get_loss(target)
        loss.backward(retain_graph=True)
        self.critic.optimizer.step()
        
    def calibrate(self,target):
        if self.actor_training:
            self.update_actor()
            self.actor_optimizer.zero_grad()
        if self.world_training:
            self.update_world(target)
            self.world_optimizer.zero_grad()

In [26]:
args.trace = True
world = World()
state = torch.tensor(world.env.reset(),requires_grad=False, dtype=dtype, device=device)
world_actor = World_Actor(world)
critic = Critic(world)
critic.world_actor = world_actor
world_actor.critic = critic

In [27]:
world_actor.pretraining = True
print(world_actor.actor1.weight.grad)
print(world_actor.actor1.weight[0])
print(world_actor.world1.weight.grad)
print(world_actor.world1.weight[0])
world_actor.forward(state)
critic.forward(state)     
world.step(world_actor.action.item())
critic.forward(world.state)
loss = world_actor.get_world_loss(state)
print("loss",loss)
loss.backward(retain_graph=True)
world_actor.world_optimizer.step()
loss = critic.get_loss()
print("loss",loss)
loss.backward(retain_graph=True)
critic.step()
print(world_actor.actor1.weight.grad)
print(world_actor.actor1.weight[0])
print(world_actor.world1.weight.grad[0])
print(world_actor.world1.weight[0])

world_actor.actor_optimizer.zero_grad()
world_actor.world_optimizer.zero_grad()

None
tensor([-0.5053, -0.4807, -0.9853, -0.5358], grad_fn=<SelectBackward>)
None
tensor([ 0.0974, -0.3094,  0.3273, -0.3574,  0.3676, -0.3686],
       grad_fn=<SelectBackward>)
Action Scores: tensor([0.2582, 0.7418], grad_fn=<DivBackward0>) Selected Action: 1 Policy Action: tensor(1)
world value: tensor([-0.1903, -0.1714, -0.5341,  1.2331], grad_fn=<AddBackward0>)
loss tensor(0.4410, grad_fn=<MeanBackward0>)
Critic value and loss: tensor([0.6205], grad_fn=<AddBackward0>) tensor(2.2481, grad_fn=<L1LossBackward>)
loss tensor(2.2481, grad_fn=<L1LossBackward>)
tensor([[ 4.3868e-04, -5.5687e-04, -1.1274e-04,  5.3265e-04],
        [-3.3182e-04,  4.2122e-04,  8.5280e-05, -4.0290e-04],
        [ 1.7126e-04, -2.1740e-04, -4.4016e-05,  2.0795e-04],
        [-1.4613e-05,  1.8551e-05,  3.7558e-06, -1.7744e-05],
        [-6.1862e-04,  7.8529e-04,  1.5899e-04, -7.5114e-04],
        [ 2.8455e-04, -3.6121e-04, -7.3131e-05,  3.4550e-04],
        [-2.3128e-04,  2.9359e-04,  5.9440e-05, -2.8082e-04],
   

In [57]:
def train(episodes=1000):
    mave_reward = 10.
    mave_value = 100.
    action_preferences = np.array([0.5,0.5])
    
    for i_episode in range(1,episodes+1):
        ep_reward = 0.
        ep_value = 0.
        ep_action_preferences = np.array([0.,0.])
        world.reset() 
        critic(world.state)
        I = 1.
        for moves in range(10000):
            
            world_actor.forward(world.state)
            ep_action_preferences += world_actor.categories.probs.detach().cpu().numpy()
            
            world.step(world_actor.action.item())
            ep_reward += world.reward
            
            critic(world.state)
            world_actor.calibrate(world.state)
            
            ep_value += critic.value.item()
            loss = critic.get_loss()
#             print("loss",loss)
            loss.backward(retain_graph=True)
            critic.step()
            
            if(world.done):
                if args.trace: print("DONE")
                break

        ep_action_preferences /= moves
        action_preferences =  0.05 * ep_action_preferences + (1 - 0.05) * action_preferences
        mave_reward = 0.05 * ep_reward + (1 - 0.05) * mave_reward
        mave_value = 0.05 * ep_value + (1 - 0.05) * mave_value
        if i_episode % args.log_interval == 0:
            print('Episode {}\tLast reward: {:.2f}\tMoving average reward: {:.2f}\tAction Preferences: {:.2f},{:.2f}\tMoving average model loss: {:.5f}\tCritic value: {:.2f}'.format(
                  i_episode, ep_reward, mave_reward, action_preferences[0],action_preferences[1],world_actor.mean_loss,mave_value))
        if mave_reward > world.env.spec.reward_threshold:
            print("Episode {}\tSolved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(i_episode,mave_reward, moves))
            break

In [58]:
def reset_trainer():
    args.trace = False
    args.render = False
    global world
    global critic
    global world_actor
    global actor_optimizer
    global world_optimizer
    world = World()
    critic = Critic(world)
    world_actor = World_Actor(world)
    
    
    critic = Critic(world)
    critic.world_actor = world_actor
    world_actor.critic = critic
    world_actor.pretraining = False
    world.env.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.empty_cache()

In [59]:
reset_trainer()

args.trace = False
world_actor.actor_training = False
world_actor.world_training = True
train(2000)
world_actor.actor_training = True
world_actor.world_training = False
train(500)

Episode 100	Last reward: 22.00	Moving average reward: 24.26	Action Preferences: 0.54,0.51	Moving average model loss: 0.00	Critic value: 9088.22
Episode 200	Last reward: 21.00	Moving average reward: 21.74	Action Preferences: 0.50,0.56	Moving average model loss: 0.00	Critic value: -15389.52
Episode 300	Last reward: 12.00	Moving average reward: 20.83	Action Preferences: 0.52,0.54	Moving average model loss: 0.00	Critic value: -33228.85
Episode 400	Last reward: 12.00	Moving average reward: 22.16	Action Preferences: 0.57,0.49	Moving average model loss: 0.00	Critic value: -34270.17
Episode 500	Last reward: 28.00	Moving average reward: 22.60	Action Preferences: 0.54,0.51	Moving average model loss: 0.00	Critic value: 56159.84
Episode 600	Last reward: 12.00	Moving average reward: 21.41	Action Preferences: 0.53,0.53	Moving average model loss: 0.00	Critic value: -104808.99
Episode 700	Last reward: 33.00	Moving average reward: 23.75	Action Preferences: 0.53,0.52	Moving average model loss: 0.00	Crit

In [None]:
args.render = True
train(25)

In [None]:
world.reset()
for i in range(3):
    args.trace = True
    actor.forward(world.prior_state)
    world.step(actor)
    print(actor.categories.log_prob(actor.action))

    world.model_optimizer.zero_grad()
    world.input_action = actor.value.clone().detach().requires_grad_(True)
    world.model(torch.cat([world.prior_state, world.input_action], dim=0))
    print("World input state:",world.prior_state)
    print("World output state:",world.state)
    print("model output state:",world.model.value)
    loss = world.model.get_loss(world.state)
    loss.backward(retain_graph=True)
    print("model loss",loss)
    # world.model_optimizer.step()    

    actor_optimizer.zero_grad()
    loss = world.model.get_actor_loss()
    loss.backward(retain_graph=True)
    print("get actor loss from world",loss)
    print("world input action grad",world.input_action.grad)

    loss = actor.get_loss()
    print("actual actor loss",loss)
    loss.backward()
    print("")
    actor_optimizer.step()
    actor_optimizer.zero_grad()

In [None]:
action scores: tensor([0.1061, 0.8939], grad_fn=<DivBackward0>) Action: 0
tensor(-2.2434, grad_fn=<SqueezeBackward1>)
model output state: tensor([-0.1267, -0.8792,  0.1683,  1.5418], grad_fn=<AddBackward0>)
get actor loss from world tensor(0.7986, grad_fn=<MeanBackward0>)
world input action grad tensor([-0.4064])
actual actor loss tensor([-0.9117], grad_fn=<MulBackward0>)

In [None]:
args.trace = False
train(1000)

In [None]:

world.reset()
actor.forward(world.model.value)
print("forward")
world.step(actor.choose_action())
print(actor.l1.weight[0])
print("step")
actor_optimizer.zero_grad()
print("zero grad")
loss = world.model.get_actor_loss()
print(loss)
print("actor loss")
loss.backward()    
# actor_optimizer.step()
# actor.forward(world.model.value)
# world.step(actor.choose_action())
# actor_optimizer.zero_grad()
# loss = world.model.get_actor_loss()
# loss.backward()    |
# actor_optimizer.step()
print(world.input_action.grad)
loss = actor.get_loss()
print(loss)
loss.backward()
actor_optimizer.step()
# print(world.model.l1.weight.grad)
print(actor.l1.weight[0])

In [None]:

args.trace = False
args.render = False
args.log_interval = 100
for i in range(1):
    print("New training test")
#     reset_trainer()
#     prime_model(3000)
#     prime_actor(1)
    train(1000)

In [None]:
args.log_interval = 10
args.trace = True
args.render = False
train(10) 

In [11]:
v = torch.tensor([0.5,0.6,0.7,0.8])
print(v[2:4])
x = Categorical(v)
x.entropy().sqrt()

tensor([0.7000, 0.8000])


tensor(1.1711)

In [None]:
world.input_action.grad.detach().requires_grad_(False)

In [None]:
actor.categories.log_prob(actor.action)

In [None]:
actor.categories.log_prob(actor.action) * world.input_action.grad.detach().requires_grad_(False)