In [1]:
import numpy as np
from multiprocessing import Pool
from tqdm import tqdm

import gym
import time
import flexible_bus
import json
from concurrent.futures import ProcessPoolExecutor
from functools import partial
import os

import pandas as pd
import torch
import torch.nn as nn

In [None]:
def traj_simulate(seed,gamma,base_policy):
    
    """
    Func to Simulate the FRD process
        return: trajectory, return list
    
    """
    # np.random.seed(seed)
    trajectory = {
        "observations": [],
        "direct_actions":[],
        "actions": [],
        "return": 0.0
    }

    env = gym.make('FlexibleBus-v0')
    obs = env.reset()
    r = 0
    done = False
    
    while not done:
        trajectory["observations"].append(obs.tolist() if isinstance(obs, np.ndarray) else obs)
        obs = torch.tensor(obs,dtype=torch.float32)
        [p1, p2] = base_policy(obs)
        p1 = p1.item()
        p2 = p2.item()
        direct_action = [p1, p2]
        trajectory["direct_actions"].append(direct_action.tolist() if isinstance(direct_action, np.ndarray) else direct_action)
        deviate_1 = int(np.random.choice([0, 1], p=[1-p1, p1]))
        deviate_2 = int(np.random.choice([0, 1], p=[1-p2, p2]))
        action = [deviate_1,deviate_2]  
        trajectory["actions"].append(action.tolist() if isinstance(action, np.ndarray) else action)
        obs, rewards, done, info = env.step(action)
        r = r * gamma + rewards
    trajectory["return"] = r
    return trajectory, r

In [3]:
class Behavioural(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.model(x)

In [34]:
traj_num = 1000000
gamma = 0.99
behavior_pi = Behavioural(input_dim=5, hidden_dim=64, output_dim=2)
# 2. Load the weights into it
behavior_pi.load_state_dict(torch.load("./Behavioural_model.pth"))
trajectory, r = traj_simulate(gamma=gamma, base_policy=behavior_pi)
# Ensure the key names match the output of traj_simulate
trajectory

TypeError: traj_simulate() missing 1 required positional argument: 'seed'

In [40]:
df = pd.read_parquet("./Traject/trajectories.pq")
df

Unnamed: 0,trajectory_id,return,obs_0,action_0,direct_action_0,obs_1,action_1,direct_action_1,obs_2,action_2,direct_action_2,obs_3,action_3,direct_action_3,obs_4,action_4,direct_action_4
0,0,17.340915,"[2, 0, 0, 1, 3]","[0.5974307656288147, 0.20102691650390625]","[1, 0]","[1, 1, 1, 1, 1]","[0.5973367094993591, 0.6010875701904297]","[1, 1]","[3, 0, 1, 0, 1]","[0.19842304289340973, 0.20325665175914764]","[0, 0]","[1, 0, 1, 0, 0]","[0.1996067464351654, 0.6009144186973572]","[0, 0]","[1, 0, 0, 0, 0]","[0.5996896028518677, 0.6008356213569641]","[0, 0]"
1,1,13.447464,"[1, 0, 1, 0, 1]","[0.1988186538219452, 0.2009650021791458]","[0, 0]","[1, 0, 2, 1, 1]","[0.19658391177654266, 0.6007030606269836]","[0, 0]","[1, 0, 2, 0, 2]","[0.1975231021642685, 0.20006458461284637]","[0, 1]","[1, 0, 0, 0, 1]","[0.5996164679527283, 0.20083443820476532]","[0, 0]","[0, 0, 0, 0, 0]","[0.5992087125778198, 0.5997883081436157]","[0, 1]"
2,2,13.663189,"[1, 0, 1, 0, 0]","[0.1996067464351654, 0.6009144186973572]","[0, 1]","[3, 0, 0, 0, 0]","[0.6000474691390991, 0.6033831834793091]","[0, 1]","[1, 0, 0, 0, 0]","[0.5996896028518677, 0.6008356213569641]","[1, 1]","[1, 0, 1, 2, 1]","[0.19768930971622467, 0.6009380221366882]","[0, 1]","[2, 1, 0, 0, 1]","[0.6001291275024414, 0.20202820003032684]","[0, 0]"
3,3,16.582689,"[1, 0, 1, 1, 0]","[0.19903424382209778, 0.6006752848625183]","[0, 0]","[1, 0, 1, 0, 1]","[0.1988186538219452, 0.2009650021791458]","[0, 0]","[5, 0, 1, 0, 0]","[0.1989564448595047, 0.6059902310371399]","[0, 1]","[1, 0, 2, 0, 0]","[0.19875408709049225, 0.6007257699966431]","[0, 0]","[1, 0, 0, 1, 1]","[0.5987434387207031, 0.6006065011024475]","[0, 1]"
4,4,14.416757,"[2, 0, 0, 0, 0]","[0.600102961063385, 0.601844072341919]","[1, 1]","[3, 0, 1, 1, 0]","[0.19920769333839417, 0.6022247076034546]","[1, 1]","[0, 0, 1, 0, 0]","[0.1987830400466919, 0.6000381112098694]","[0, 0]","[2, 0, 1, 0, 1]","[0.19871538877487183, 0.20210181176662445]","[1, 0]","[1, 0, 0, 1, 1]","[0.5987434387207031, 0.6006065011024475]","[1, 1]"
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
999995,999995,19.420296,"[1, 0, 1, 0, 0]","[0.1996067464351654, 0.6009144186973572]","[1, 1]","[1, 0, 4, 1, 1]","[0.19522128999233246, 0.6018306016921997]","[0, 1]","[0, 0, 1, 1, 0]","[0.19849029183387756, 0.5999235510826111]","[0, 0]","[3, 0, 4, 1, 0]","[0.196189746260643, 0.6025325655937195]","[0, 1]","[1, 0, 0, 0, 1]","[0.5996164679527283, 0.20083443820476532]","[0, 0]"
999996,999996,20.633284,"[1, 0, 1, 1, 0]","[0.19903424382209778, 0.6006752848625183]","[0, 1]","[1, 0, 1, 0, 2]","[0.19780664145946503, 0.20025143027305603]","[0, 0]","[1, 0, 2, 1, 0]","[0.19746242463588715, 0.6007243990898132]","[0, 0]","[4, 0, 2, 2, 1]","[0.19518472254276276, 0.6031304597854614]","[0, 0]","[2, 1, 0, 0, 1]","[0.6001291275024414, 0.20202820003032684]","[1, 0]"
999997,999997,15.624573,"[2, 0, 2, 2, 0]","[0.19647951424121857, 0.6008825302124023]","[0, 1]","[2, 0, 1, 2, 0]","[0.1987171322107315, 0.6006354689598083]","[0, 0]","[1, 1, 0, 0, 0]","[0.6007809042930603, 0.6020244359970093]","[0, 0]","[2, 1, 0, 0, 1]","[0.6001291275024414, 0.20202820003032684]","[0, 0]","[2, 0, 1, 0, 0]","[0.19964157044887543, 0.601934015750885]","[0, 1]"
999998,999998,20.574278,"[2, 0, 0, 3, 0]","[0.5989845395088196, 0.5991877913475037]","[1, 1]","[0, 0, 1, 1, 0]","[0.19849029183387756, 0.5999235510826111]","[0, 1]","[3, 1, 2, 0, 2]","[0.19617819786071777, 0.20096886157989502]","[0, 0]","[2, 0, 1, 0, 0]","[0.19964157044887543, 0.601934015750885]","[0, 0]","[4, 0, 0, 0, 0]","[0.5998150110244751, 0.6044010519981384]","[1, 0]"
