# Import Modules

In [None]:
import numpy as np
from fourrooms import Fourrooms
from IPython.display import clear_output
from aoaoc_classes import *
import matplotlib.pyplot as plt

# HyperParameters

In [None]:
# Stand-in for command-line argparse: hyperparameters are fixed here as attributes of `args`
class Arguments:
    """Hyperparameter container used in place of a command-line argparse namespace.

    Every setting is exposed as a plain instance attribute (``args.<name>``).
    The defaults are declared in one ordered table and installed on the
    instance in declaration order.
    """

    def __init__(self):
        defaults = {
            # Numbers
            "nepisodes": 4000,
            "nruns": 1,
            "nsteps": 2000,
            "noptions": 3,

            # Learning Rates
            "lr_term": 0.1,
            "lr_intra": 0.25,
            "lr_critic": 0.5,
            "lr_critic_pseudo": 0.5,
            "lr_criticA": 0.5,
            "lr_criticA_pseudo": 0.5,
            "lr_attend": 0.02,

            # Environment Parameters
            "discount": 0.99,
            "deterministic": False,
            "punishEachStep": False,

            # Attention Parameters
            "h_learn": False,
            "clipthres": 0.1,
            "stretchthres": 1.0,
            "stretchstep": 1.0,

            # Distraction Parameters
            "xi": 1.0,
            "n": 0.5,

            # Policy Parameters
            "epsilon": 0.1,
            "temperature": 1.0,

            # Objective Parameters (weights)
            "wo1": 1.0,   # q
            "wo2": 2.0,   # cosim
            "wo3": 2.0,   # entropy
            "wo4": 5.0,   # size
            "wo4p": 2,    # kept as an int, unlike the float weights above

            # Randomness Parameters
            "seed": 2222,
            "seed_startstate": 1111,

            # Other Parameters
            "baseline": True,
            "dc": 0.1,
        }
        for name, value in defaults.items():
            setattr(self, name, value)


# Single shared configuration object read by the rest of the script.
args = Arguments()

# Run

## Set up

In [None]:
# Random source handed to POO and every Option; seeded from args.seed.
rng = np.random.RandomState(args.seed)

# Environment built from the start-state seed and the two environment flags.
env = Fourrooms(args.seed_startstate, args.punishEachStep, args.deterministic)

# Constant forwarded as the `R` argument to POO and Option in the main loop.
R = 50.0

# NOTE(review): presumably Fourrooms state indices that may serve as goals;
# not referenced anywhere in the code visible here - confirm where it is used.
possible_next_goals = [
    *range(68, 73),
    *range(78, 83),
    *range(88, 94),
    *range(99, 104),
]

# Feature map over the environment's discrete observations, plus the two
# sizes the learner classes are constructed with.
features = Tabular(env.observation_space.n)
nfeatures = len(features)
nactions = env.action_space.n

## Main loop

In [None]:
# Training driver: for each run, build a fresh policy over options and a fresh
# set of options, then play `args.nepisodes` episodes of at most `args.nsteps`
# frames each, updating the active option and the policy over options on every
# frame. `rng` and `env` are created once outside the loop and shared by all runs.
for run in range(args.nruns):
    # Set up classes
    policy_over_options = POO(rng, nfeatures, args, R)
    # NOTE(review): CoSimObj comes from aoaoc_classes; presumably this clears
    # state shared between options so runs do not leak into each other - confirm.
    CoSimObj.reset()
    # One Option per index in range(args.noptions); each receives the shared
    # policy over options and its own index `i`.
    options = [Option(rng, nfeatures, nactions, args, R, policy_over_options, i) for i in range(args.noptions)]

    # Loop through games
    for episode in range(args.nepisodes):
        # Initial state
        # Discounted return of the episode; accumulated below but not read
        # anywhere in the code visible here.
        return_per_episode = 0.0
        observation = env.reset()
        phi = features(observation)    
        option = policy_over_options.sample(phi)
        action = options[option].sample(phi)
        # Rolling trajectory record:
        #   traject[0] = previous [phi, option]
        #   traject[1] = current  [phi, option]
        #   traject[2] = action most recently executed in the environment
        # Both state slots start out as the initial [phi, option].
        traject = [[phi,option],[phi,option],action]
        
        # Reset record
        cumreward = 0.          # sum of the distract()-transformed rewards
        duration = 1            # frames counted for the currently active option
        option_switches = 0     # number of times the active option terminated
        avgduration = 0.        # running mean of `duration` at each termination
        
        # Loop through frames in 1 game
        for step in range(args.nsteps):
            # Collect feedback from environment
            observation, reward, done, _ = env.step(action)
            phi = features(observation)
            # Uses the raw environment reward, discounted by discount**step.
            return_per_episode += pow(args.discount,step)*reward
            
            # Store option index
            # `last_option` is the option that produced the action just executed;
            # `option` may be replaced below if that option terminates.
            last_option = option
            
            # Check termination
            # terminate() is called twice: once with value=True, whose result is
            # only forwarded to the update() calls below, and once without it,
            # whose result is used as the boolean switch decision.
            # NOTE(review): presumably value=True returns the termination
            # probability and the plain call samples from it - confirm.
            termination = options[option].terminate(phi, value=True)
            if options[option].terminate(phi):
                option = policy_over_options.sample(phi)
                option_switches += 1
                # Incremental mean: avg += (x - avg) / n
                avgduration += (1./option_switches)*(duration - avgduration)
                duration = 1
        
            # Record into trajectory
            # Shift current -> previous, then store the new state with the
            # (possibly switched) option and the action that led here.
            traject[0] = traject[1]
            traject[1] = [phi, option]
            traject[2] = action
            
            # Sample next action
            # Drawn before the updates below, from the option that will act next.
            action = options[option].sample(phi)

            # Policy Evaluation + Policy Improvement
            # Both updates are applied with respect to `last_option`, not the
            # option that may just have been selected above.
            options[last_option].update(traject, reward, done, phi, last_option, termination)
            policy_over_options.update(traject, reward, options[last_option].distract(reward,traject[2]), done, termination)
            
            # End of frame
            # distract() is called a second time with the same arguments as in
            # the update above.
            # NOTE(review): if distract() is stochastic, the logged value can
            # differ from the one used for learning - confirm.
            cumreward += options[last_option].distract(reward, traject[2])
            duration += 1
            if done:
                break

        # `step` is the zero-based index of the last frame executed in this episode.
        # NOTE(review): with args.nsteps == 0 the frame loop never runs, so `step`
        # (and last_option) would be undefined or stale here.
        print('Run {} episode {} steps {} cumreward {} avg. duration {} switches {}'.format(run, episode, step, cumreward, avgduration, option_switches))

# Visualization