In [None]:
import argparse
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
import gc

# Report whether torch can see a CUDA device at all.
print(torch.cuda.is_available())

# Single-precision floats are used on either code path.
dtype = torch.float32

# The trailing "and False" keeps the CUDA path switched off even when a GPU is
# present, so the CPU branch is the one that always runs.
if not (torch.cuda.is_available() and False):
    print ("cuda not used")
    device = torch.device('cpu')
    torch.set_default_tensor_type('torch.FloatTensor')
else:
    print ("cuda in use")
    device = torch.device('cuda')
    torch.set_default_tensor_type('torch.cuda.FloatTensor')

In [None]:
def _str2bool(text):
    """Parse a command-line boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is True
    because any non-empty string is truthy. This converter understands the
    usual spellings instead and rejects anything else.
    """
    if isinstance(text, bool):
        return text
    lowered = text.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays False, as it was under bool('').
    if lowered in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(text))


parser = argparse.ArgumentParser(description='PyTorch actor-critic example')
parser.add_argument('--gamma', type=float, default=0.99, metavar='G', help='discount factor (default: 0.99)')
parser.add_argument('--seed', type=int, default=543, metavar='N', help='random seed (default: 543)')
# parser.add_argument('--render', action='store_true', help='render the environment')
parser.add_argument('--render', type=_str2bool, default=False, help='render the environment')
parser.add_argument('--trace', type=_str2bool, default=False, help='print per-step trace output (values, losses, actions)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N', help='interval between training status logs (default: 100)')
parser.add_argument('-f','--file',help='Path for input file. (Dummy arg to enable execution in notebook.)' )
# A notebook kernel is started with launcher arguments of its own; tolerate the
# ones this parser does not know instead of aborting the cell, but report them
# so that a mistyped option does not go unnoticed.
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    print('Ignoring unrecognized arguments:', _unknown_args)

In [None]:
#from T. Kohonen '84, p. 183, equation 6.33

#Pass cf=1 to addBasis for one-step orthogonalization (callers here scale cf by their own alpha).
class Novelty_Filter():
    """Square matrix ``O`` that starts as the identity and is deflated vector by vector.

    ``novelty(X)`` is ``O @ X``; ``project(X)`` is the remainder ``X - O @ X``.
    ``addBasis`` subtracts a scaled outer product of the current novelty of ``X``
    from ``O``, so that later vectors along that direction yield less novelty.
    """

    def __init__(self,size,epsilon=1e-3):
        # epsilon: relative threshold below which addBasis leaves O untouched.
        self.epsilon = epsilon
        self.O = torch.eye(size, requires_grad=False, dtype=dtype, device=device)

    def addBasis(self,X,cf=1.):
        """Deflate ``O`` along the novelty of ``X``, weighted by ``cf``.

        Nothing happens unless ``cf * |novelty|`` exceeds ``epsilon * |X|``;
        that guard also rules out dividing by a zero norm.
        """
        residual = self.novelty(X)
        length = residual.norm()
        if not (cf * length > self.epsilon * X.norm()):
            return
        # In-place update: O <- O - cf * (l l^T) / |l|^2
        self.O -= cf * residual.ger(residual) / length.pow(2)

    def novelty(self,X):
        """Return ``O @ X``."""
        return torch.matmul(self.O, X)

    def project(self,X):
        """Return ``X`` minus its novelty."""
        return X - torch.matmul(self.O, X)

In [76]:

def selu(x):
    """Scaled ELU: ``scale * elu(x, alpha)`` with the fixed constants below."""
    elu_alpha = 1.6732632423543772848170429916717
    output_scale = 1.0507009873554804934193349852946
    activated = F.elu(x, alpha=elu_alpha)
    return output_scale * activated

class World():
    """Thin wrapper around the gym 'CartPole-v0' environment.

    Keeps the latest observation (as a tensor), reward and done flag as
    attributes so the actor and critic can read them after each transition.
    """

    def __init__(self):
        self.env = gym.make('CartPole-v0')
        self.reset()

    def _to_tensor(self, observation):
        # Observations are stored as non-differentiable tensors on the shared device.
        return torch.tensor(observation, requires_grad=False, dtype=dtype, device=device)

    def reset(self):
        """Start a new episode: clear reward/done and store the initial observation."""
        self.reward = 0.0
        self.done = False
        self.state = self._to_tensor(self.env.reset())

    def action_count(self):
        """Number of discrete actions offered by the environment."""
        return self.env.action_space.n

    def dimension_count(self):
        """Length of the observation vector."""
        observation_shape = self.env.observation_space.shape
        return observation_shape[0]

    def step(self,action):
        """Apply ``action`` (a one-element tensor) and store the resulting transition."""
        observation, self.reward, self.done, _ = self.env.step(action.item())
        self.state = self._to_tensor(observation)
        if args.render:
            self.env.render()

class Critic(nn.Module):
    """State-value estimator: one hidden layer with SELU activation and a scalar head.

    The critic remembers the value from the previous ``forward`` call
    (``prev_value``) so that a temporal-difference loss can be formed for the
    state that was left behind by the last transition.
    """

    def __init__(self,world: 'World',hidden_nodes=32):
        """
        world: object providing ``dimension_count()``, ``reward`` and ``done``.
        hidden_nodes: width of the hidden layer.
        """
        super(Critic, self).__init__()
        self.world = world
        self.one = torch.ones([1], requires_grad=False, dtype=dtype, device=device)
        self.zero = torch.zeros([1], requires_grad=False, dtype=dtype, device=device)
        self.l1 = nn.Linear(world.dimension_count(),hidden_nodes)
        self.l1.weight.data.normal_(0.0, np.sqrt(1./(self.world.dimension_count())))
        self.head = nn.Linear(hidden_nodes, 1)
        # Fan-in scaled initialisation of the head. This previously re-initialised
        # l1 a second time, which overwrote l1's own scaling and left the head
        # at the nn.Linear default.
        self.head.weight.data.normal_(0.0, np.sqrt(1./hidden_nodes))
        self.prev_value = self.zero
        self.value = self.zero

    def forward(self, state):
        """Evaluate ``state``; the previous estimate is moved to ``prev_value``."""
        self.prev_value = self.value
        # inputs and l1_out are kept as attributes; subclasses read them later.
        self.inputs = state
        self.l1_out = selu(self.l1(state))
        self.value = self.head(self.l1_out)
        return self.value

    #What the previous value should have been knowing what we know after the last state transition
    def hindsight_value(self):
        """Target for ``prev_value``.

        Returns ``reward * zero`` (a zero tensor) when the episode is done,
        otherwise ``reward + gamma * value`` using the value data only.
        """
        #Do not include gradient of the critic value here, just the data.
        return self.world.reward * self.zero if self.world.done else self.world.reward + args.gamma * self.value.data

    #Temporal Difference Loss is for the previous state!
    def get_loss(self):
        """MSE between ``prev_value`` and ``hindsight_value()``; also stored in ``self.loss``."""
        self.loss = F.mse_loss(self.prev_value,self.hindsight_value())
        if args.trace: print("Critic value and loss:",self.prev_value,self.loss)
        return self.loss

    def do_post_gradient(self):#NONE on main class
        """Hook called after ``backward()``; a no-op in the base class."""
        pass

    def do_pre_update(self):
        """Hook called before the optimizer step; a no-op in the base class."""
        pass

    def gc(self):
        """Drop the stored loss tensor (raises AttributeError if none is stored)."""
        del self.loss
    
class Actor(nn.Module):
    """Policy network: one hidden SELU layer followed by a softmax over actions."""

    def __init__(self, critic: 'Critic',hidden_nodes=64):
        """
        critic: provides ``world`` (for the layer sizes) and the value
            estimates that ``get_loss`` turns into an advantage.
        hidden_nodes: width of the hidden layer.
        """
        super(Actor, self).__init__()
        self.critic = critic
        # Stand-in for l1_out while acting randomly. It must have the width of
        # the hidden layer (it used to be sized by the input dimension, which
        # does not match what forward() stores in l1_out).
        self.l1_zeros = torch.zeros([hidden_nodes], requires_grad=False, dtype=dtype, device=device)
        self.l1 = nn.Linear(critic.world.dimension_count(),hidden_nodes)
        self.l1.weight.data.normal_(0.0, np.sqrt(1./(critic.world.dimension_count())))
        self.head = nn.Linear(hidden_nodes, critic.world.action_count())
        self.head.weight.data.normal_(0.0, np.sqrt(1./hidden_nodes))

    def forward(self, state):
        """Compute the action probabilities for ``state`` and store them in ``self.value``."""
        self.inputs = state
        self.l1_out = selu(self.l1(state))
        self.value = F.softmax(selu(self.head(self.l1_out)),dim=0)
        return self.value

    def randomize(self):
        """Replace the policy output with a softmax over uniform random scores."""
        self.l1_out = self.l1_zeros
        self.value = F.softmax(torch.rand([self.head.out_features], requires_grad=False, dtype=dtype, device=device), dim=0)
        return self.value

    def choose_action(self):
        """Sample an action from the stored probabilities (``self.value``)."""
        self.categories = Categorical(self.value)
        self.action = self.categories.sample()
        if args.trace: print("action scores:",self.categories.probs,"Action:",self.action.item())
        return self.action

    #The "advantage" is how much better the state is after the action than we expected it would be
    def get_loss(self):
        """Policy-gradient loss ``-log_prob(action) * advantage`` for the last sampled action."""
        #Do not include gradient of prev_value here, just the data.
        advantage = self.critic.hindsight_value() - self.critic.prev_value.data
        self.loss = -self.categories.log_prob(self.action)*advantage
        if args.trace: print("actor loss:", self.loss)
        return self.loss

    def do_post_gradient(self):#NONE on main class
        """Hook called after ``backward()``; a no-op in the base class."""
        pass

    def do_pre_update(self):
        """Hook called before the optimizer step; a no-op in the base class."""
        pass

    def gc(self):
        """Drop the stored loss tensor (raises AttributeError if none is stored)."""
        del self.loss

class Stable_Critic(Critic):
    """Critic whose weight gradients are rescaled through per-layer novelty filters.

    One ``Novelty_Filter`` tracks the inputs of ``l1`` and another the inputs of
    ``head``. After each backward pass the gradients are multiplied by the
    ratio novelty/input, and the filters are then updated with the current
    inputs, weighted by a certainty factor derived from the loss.
    """

    def __init__(self,world: World,hidden_nodes=32, alpha=0.001):
        super().__init__(world, hidden_nodes)
        # alpha scales the certainty factor; eps keeps the divisions below away from zero.
        self.alpha = alpha
        self.eps = 1.e-12
        self.l1_filter = Novelty_Filter(self.l1.in_features)
        self.head_filter = Novelty_Filter(self.head.in_features)
        # Exponentially smoothed absolute loss; starts tiny rather than at zero.
        self.smoothed_loss = 1.e-12

    def get_loss(self):
        """Compute the base-class loss and fold its magnitude into ``smoothed_loss``."""
        super().get_loss()
        blended = 0.999 * self.smoothed_loss + 0.001 * self.loss.abs()
        self.smoothed_loss = blended.detach()
        return self.loss

    def get_certainty(self):
        """Return ``alpha / exp(|4 * loss / smoothed_loss|)``."""
        relative_loss = self.loss.data * 4. / self.smoothed_loss
        return self.alpha / torch.exp(relative_loss.abs())

    def do_post_gradient(self):
        """Scale each weight gradient by novelty(input) / (input + eps)."""
        l1_scale = self.l1_filter.novelty(self.inputs) / (self.inputs + self.eps)
        self.l1.weight.grad *= l1_scale
        hidden = self.l1_out.data
        head_scale = self.head_filter.novelty(hidden) / (hidden + self.eps)
        self.head.weight.grad *= head_scale

    def do_pre_update(self):
        """Feed the current layer inputs into the filters, weighted by the certainty."""
        weight = self.get_certainty()
        self.l1_filter.addBasis(self.inputs.data, weight)
        self.head_filter.addBasis(self.l1_out.data, weight)

    def gc(self):
        """Drop the stored loss, plus the optional novelty attributes when they exist."""
        super().gc()
        for stale_name in ("l1_novelty", "head_novelty"):
            try:
                delattr(self, stale_name)
            except AttributeError:
                pass
        
class Stable_Actor(Actor):
    """Actor whose weight gradients are rescaled through per-layer novelty filters.

    Mirrors the filter handling of the stabilised critic: gradients are scaled
    by novelty/input after ``backward()`` and the filters are updated before
    the optimizer step.
    """

    def __init__(self, critic: Critic,hidden_nodes = 64,alpha=0.001):
        super().__init__(critic, hidden_nodes)
        # alpha scales the certainty factor; eps keeps the divisions below away from zero.
        self.alpha = alpha
        self.eps = 1.e-12
        self.l1_filter = Novelty_Filter(self.l1.in_features)
        self.head_filter = Novelty_Filter(self.head.in_features)
        # Exponentially smoothed absolute loss; starts tiny rather than at zero.
        self.smoothed_loss = 1.e-12

    def get_loss(self):
        """Compute the base-class loss and fold its magnitude into ``smoothed_loss``."""
        super().get_loss()
        blended = 0.999 * self.smoothed_loss + 0.001 * self.loss.abs()
        self.smoothed_loss = blended.detach()
        return self.loss

    def get_certainty(self):
        """Return ``alpha / exp(|4 * loss / smoothed_loss|)``."""
        relative_loss = self.loss.data * 4. / self.smoothed_loss
        return self.alpha / torch.exp(relative_loss.abs())

    def do_post_gradient(self):
        """Scale each weight gradient by novelty(input) / (input + eps)."""
        l1_scale = self.l1_filter.novelty(self.inputs) / (self.inputs + self.eps)
        self.l1.weight.grad *= l1_scale
        hidden = self.l1_out.data
        head_scale = self.head_filter.novelty(hidden) / (hidden + self.eps)
        self.head.weight.grad *= head_scale

    def do_pre_update(self):
        """Feed the current layer inputs into the filters, weighted by the certainty."""
        weight = self.get_certainty()
        self.l1_filter.addBasis(self.inputs.data, weight)
        self.head_filter.addBasis(self.l1_out.data, weight)

    def gc(self):
        """Drop the stored loss, plus the optional novelty attributes when they exist."""
        super().gc()
        for stale_name in ("l1_novelty", "head_novelty"):
            try:
                delattr(self, stale_name)
            except AttributeError:
                pass

In [57]:
def train(episodes=1000):
    """Run online actor-critic training using the notebook-global trainer objects.

    Relies on the globals ``world``, ``critic``, ``actor``, ``critic_optimizer``,
    ``actor_optimizer`` and ``args``. Every environment step performs one
    critic update followed by one actor update. Stops early once the moving
    average reward exceeds the environment's reward threshold.

    episodes: maximum number of episodes to run.
    """
    mave_reward = 10
    mave_value = 10.
    action_preferences = np.array([0.5,0.5])

    for i_episode in range(1,episodes+1):
        ep_reward = 0
        ep_value = 0.
        world.reset()
        # Prime the critic so prev_value belongs to the initial state.
        critic.forward(world.state)
        moves = 0
        ep_action_preferences = np.array([0.,0.])
        # Weight of the actor loss; decays by gamma with every step of the episode.
        action_relevance = 1.
        for t in range(10000):
            actor.forward(world.state)
            world.step(actor.choose_action())
            ep_action_preferences += actor.categories.probs.detach().cpu().numpy()
            moves += 1
            critic.forward(world.state)

            ep_value += critic.value.item()
            ep_reward += world.reward

            # Critic update (the loss refers to the state we just left).
            critic_optimizer.zero_grad()
            critic.get_loss().backward()
            critic.do_post_gradient()
            critic.do_pre_update()
            critic_optimizer.step()

            # Actor update, discounted by how far into the episode we are.
            actor_optimizer.zero_grad()
            (action_relevance * actor.get_loss()).backward()
            actor.do_post_gradient()
            actor.do_pre_update()
            actor_optimizer.step()
            action_relevance *= args.gamma

            if(world.done):
                if args.trace: print("DONE")
                break

        torch.cuda.empty_cache()
        ep_action_preferences /= moves
        action_preferences =  0.05 * ep_action_preferences + (1 - 0.05) * action_preferences
        # Average the critic value over the episode before smoothing. This used
        # to divide the moving average itself, which shrank it every episode.
        ep_value /= moves
        mave_reward = 0.05 * ep_reward + (1 - 0.05) * mave_reward
        mave_value = 0.05 * ep_value + (1 - 0.05) * mave_value
        if i_episode % args.log_interval == 0:
            print('Episode {}\tLast reward: {:.2f}\tMoving average reward: {:.2f}\tMoving average critic value: {:.2f}\tAction Preferences: {:.2f},{:.2f}'.format(
                  i_episode, ep_reward, mave_reward, mave_value,action_preferences[0],action_preferences[1]))
        if mave_reward > world.env.spec.reward_threshold:
            print("Episode {}\tSolved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(i_episode,mave_reward, t+1))
            break

In [None]:
def prime_actor(episodes=1000):
    """Warm up the actor with single-step episodes.

    Each iteration resets the world, takes exactly one action and applies one
    actor update. The critic is only evaluated, never updated, and the actor's
    post-gradient / pre-update hooks are not invoked. Uses the notebook-global
    ``world``, ``critic``, ``actor``, ``actor_optimizer`` and ``args``.
    """
    for iteration in range(episodes):
        # One transition from a fresh start state.
        world.reset()
        critic.forward(world.state)
        actor.forward(world.state)
        chosen = actor.choose_action()
        world.step(chosen)
        critic.forward(world.state)

        # Single actor update for that transition.
        actor_optimizer.zero_grad()
        actor.get_loss().backward()
        actor_optimizer.step()
        actor.gc()

        should_log = (iteration % args.log_interval == 0)
        if should_log:
            print(actor.categories.probs,critic.value)
        if world.done:
            print("DONE")

In [None]:
def prime_critic(episodes=1000):
    """Pre-train the critic while the actor acts at random.

    Actions come from ``actor.randomize()`` rather than the policy network and
    only the critic is updated (its post-gradient hook is not invoked). Uses
    the notebook-global ``world``, ``critic``, ``actor``, ``critic_optimizer``
    and ``args``. Stops early if the moving average reward exceeds the
    environment's reward threshold.
    """
    mave_reward = 10.
    mave_value = 10.
    for i_episode in range(1,episodes+1):
        ep_reward = 0.
        ep_value = 0.
        moves = 0
        world.reset()
        critic.forward(world.state)
        for step_index in range(1000):
            # Random policy output, then one environment transition.
            actor.randomize()
            action = actor.choose_action()
            world.step(action)
            moves += 1
            ep_reward += world.reward
            critic.forward(world.state)
            ep_value += critic.value.item()

            # One critic update for the state that was just left.
            critic_optimizer.zero_grad()
            critic.get_loss().backward()
            critic_optimizer.step()
            critic.gc()
            if world.done:
                if args.trace:
                    print("DONE")
                break
        # Mean critic value over the episode, then exponential smoothing.
        ep_value /= moves
        mave_reward = 0.05 * ep_reward + (1 - 0.05) * mave_reward
        mave_value = 0.05 * ep_value + (1 - 0.05) * mave_value
        if i_episode % args.log_interval == 0:
            print('Episode {}\tLast reward: {:.2f}\tMoving average reward: {:.2f}\tMoving average critic value: {:.2f}'.format(
                  i_episode, ep_reward, mave_reward, mave_value))
        if mave_reward > world.env.spec.reward_threshold:
            print("Episode {}\tSolved! Running reward is now {} and "
                  "the last episode runs to {} time steps!".format(i_episode,mave_reward, step_index+1))
            break

In [102]:
def reset_trainer():
    """Rebuild the notebook-global world, networks and optimizers from scratch.

    Also switches ``args.trace`` and ``args.render`` off. Note that the seeds
    are applied after the world and the networks have been constructed.
    """
    global world, critic, actor
    global actor_optimizer, critic_optimizer

    args.trace = False
    args.render = False

    # Environment first, then the critic on top of it, then the actor on top of the critic.
    world = World()
    critic = Stable_Critic(world, 32, 0.002)
    actor = Stable_Actor(critic, 64, 0.002)

    world.env.seed(args.seed)
    torch.manual_seed(args.seed)

    actor_optimizer = optim.Adam(actor.parameters(), lr=4e-5, weight_decay=0.0001)
    critic_optimizer = optim.Adam(critic.parameters(), lr=5e-3, weight_decay=0.0001)

In [111]:

# Three independent runs: rebuild the trainer, pre-train the critic with a
# random actor, warm up the actor on single-step episodes, then train both.
args.trace = False
args.log_interval = 100
for i in range(3):
    reset_trainer()  # note: this also switches args.trace and args.render off
    prime_critic(1000)
    prime_actor(50)
    train(1000)

Episode 100	Last reward: 30.00	Moving average reward: 23.96	Moving average critic value: 16.70
Episode 200	Last reward: 13.00	Moving average reward: 20.81	Moving average critic value: 12.19
Episode 300	Last reward: 15.00	Moving average reward: 20.68	Moving average critic value: 12.03
Episode 400	Last reward: 45.00	Moving average reward: 27.92	Moving average critic value: 13.20
Episode 500	Last reward: 39.00	Moving average reward: 23.52	Moving average critic value: 12.12
Episode 600	Last reward: 26.00	Moving average reward: 21.54	Moving average critic value: 11.94
Episode 700	Last reward: 21.00	Moving average reward: 23.35	Moving average critic value: 12.11
Episode 800	Last reward: 20.00	Moving average reward: 18.64	Moving average critic value: 9.24
Episode 900	Last reward: 32.00	Moving average reward: 21.69	Moving average critic value: 9.87
Episode 1000	Last reward: 27.00	Moving average reward: 23.24	Moving average critic value: 11.46
tensor([0.4067, 0.5933], grad_fn=<DivBackward0>) te

In [None]:
# Short run without any priming: 10 training episodes, logging every 10th.
args.log_interval = 10
args.trace = False
args.render = False

reset_trainer()
train(10) 

In [None]:
# Show the novelty-filter matrices of both networks. The filters are stored as
# l1_filter / head_filter; no filter1 / filter2 attributes are ever assigned.
print(critic.l1_filter.O,critic.head_filter.O)
print(actor.l1_filter.O,actor.head_filter.O)

In [None]:
# Scratch cell: tabulate 0.0001 / exp(|x|) for integer x in [-6, 5] to see how
# fast an exponential certainty factor decays.
eps = 1e-12  # not used in this cell
#    certainty_factor = 0.1/torch.exp(abs(self.loss))
for i in range(-6,6):
    X = torch.tensor([i], requires_grad=False, dtype=dtype, device=device)
    print(i,0.0001/torch.exp(abs(X)))

In [104]:
# Current state of the critic's first-layer novelty filter matrix.
critic.l1_filter.O

tensor([[ 0.4112, -0.0249,  0.0298, -0.0159],
        [-0.0249,  0.3935, -0.0209,  0.0370],
        [ 0.0298, -0.0209,  0.5550, -0.0433],
        [-0.0159,  0.0370, -0.0433,  0.3683]])

In [105]:

# Current state of the actor's first-layer novelty filter matrix.
actor.l1_filter.O

tensor([[ 0.4873, -0.0536,  0.0392, -0.0297],
        [-0.0536,  0.4505, -0.0662,  0.0567],
        [ 0.0392, -0.0662,  0.8194, -0.0619],
        [-0.0297,  0.0567, -0.0619,  0.4468]])

In [106]:
# Last raw loss next to its exponentially smoothed magnitude, for each network.
print(actor.loss,actor.smoothed_loss)
print(critic.loss,critic.smoothed_loss)

tensor([-25.8279], grad_fn=<MulBackward0>) tensor([0.4362])
tensor(2973.9817, grad_fn=<MseLossBackward>) tensor(19.4608)


In [109]:
# State most recently passed to actor.forward().
actor.inputs

tensor([-1.5032, -0.1820,  0.0473, -0.2592])

In [110]:
# State most recently passed to critic.forward().
critic.inputs

tensor([-1.5068, -0.3778,  0.0421,  0.0480])

In [58]:
# Run a single episode with per-step tracing switched on. train() continues
# with the existing global trainer objects; nothing is rebuilt here.
args.trace = False
# train(100)
args.trace = True
train(1)