In [1]:
import os
import torch
import numpy as np
import wandb

import pickle
from tqdm.auto import trange, tqdm
from torch.utils.data import Dataset
from dataclasses import dataclass
from datasets import load_from_disk
from omegaconf import OmegaConf

from citylearn.agents.rbc import HourRBC
from citylearn.agents.q_learning import TabularQLearning
from citylearn.citylearn import CityLearnEnv
from citylearn.data import DataSet
from citylearn.reward_function import RewardFunction
from citylearn.wrappers import NormalizedObservationWrapper
from citylearn.wrappers import StableBaselines3Wrapper
from citylearn.wrappers import TabularQLearningWrapper

from stable_baselines3.a2c import A2C

from torch.utils.data import DataLoader
from trajectory.models.gpt import GPT, GPTTrainer

from trajectory.utils.common import pad_along_axis
from trajectory.utils.discretization import KBinsDiscretizer
from trajectory.utils.env import create_env

%matplotlib inline
import matplotlib.pyplot as plt


  __DEFAULT = ''
  __STORAGE_SUFFIX = '_without_storage'
  __PARTIAL_LOAD_SUFFIX = '_and_partial_load'
  __PV_SUFFIX = '_and_pv'


In [2]:
offline_data_path = "data_interactions/PPO/model_PPO_timesteps_100000_seed_572.pkl"

In [3]:
dataset = load_from_disk(offline_data_path)

In [4]:
def join_trajectory(states, actions, rewards, discount=0.99):
    traj_length = states.shape[0]
    # I can vectorize this for all dataset as once,
    # but better to be safe and do it once and slow and right (and cache it)
    
    if actions.ndim == 3 :
        actions = actions.reshape(actions.shape[0],actions.shape[1])
    
    if rewards.ndim == 1 :
        rewards = rewards.reshape(rewards.shape[0],1)
        
    print("Discount "+str(discount))
    discounts = (discount ** np.arange(traj_length))

    values = np.zeros_like(rewards)
    for t in range(traj_length):
        # discounted return-to-go from state s_t:
        # r_{t+1} + y * r_{t+2} + y^2 * r_{t+3} + ...
        # .T as rewards of shape [len, 1], see https://github.com/Howuhh/faster-trajectory-transformer/issues/9
        values[t] = (rewards[t + 1:].T * discounts[:-t - 1]).sum()
    print(states.shape)
    print(actions.shape)
    print(rewards.shape)
    print(values.shape)

    joined_transition = np.concatenate([states, actions, rewards, values], axis=-1)

    return joined_transition

def segment(states, actions, rewards, terminals):
    assert len(states) == len(terminals)
    
    trajectories = {}

    episode_num = 0
    for t in trange(len(terminals), desc="Segmenting"):
        if episode_num not in trajectories:
            trajectories[episode_num] = {
                "states": [],
                "actions": [],
                "rewards": []
            }
        
        trajectories[episode_num]["states"].append(states[t])
        trajectories[episode_num]["actions"].append(actions[t])
        trajectories[episode_num]["rewards"].append(rewards[t])

        if terminals[t]:
            # next episode
            episode_num = episode_num + 1

    trajectories_lens = [len(v["states"]) for k, v in trajectories.items()]

    for t in trajectories:
        trajectories[t]["states"] = np.stack(trajectories[t]["states"], axis=0)
        trajectories[t]["actions"] = np.stack(trajectories[t]["actions"], axis=0)
        trajectories[t]["rewards"] = np.stack(trajectories[t]["rewards"], axis=0)

    return trajectories, trajectories_lens


In [5]:
trajectories,traj_lengths = segment(dataset["observations"],dataset["actions"],dataset["rewards"],dataset["dones"])

Segmenting:   0%|          | 0/100352 [00:00<?, ?it/s]

In [14]:
len(trajectories[0]["states"][0])

26

In [6]:
trajectories[0]["rewards"].reshape(8759,1)

array([[ -3.44464449],
       [-11.90737871],
       [ -5.21592383],
       ...,
       [-11.21602929],
       [ -8.80495369],
       [ -3.23890096]])

In [7]:
joined= join_trajectory(trajectories[0]["states"], trajectories[0]["actions"], trajectories[0]["rewards"],discount = 0.99)

Discount 0.99
(8759, 26)
(8759, 5)
(8759, 1)
(8759, 1)


In [8]:
class DiscretizedDataset(Dataset):
    def __init__(self, dataset,env_name="city_learn", num_bins=100, seq_len=10, discount=0.99, strategy="uniform", cache_path="data"):
        self.seq_len = seq_len
        self.discount = discount
        self.num_bins = num_bins
        self.dataset = dataset
        self.env_name = env_name
        
        trajectories, traj_lengths = segment(self.dataset["observations"],self.dataset["actions"],self.dataset["rewards"],self.dataset["dones"])
        self.trajectories = trajectories
        self.traj_lengths = traj_lengths
        self.cache_path = cache_path
        self.cache_name = f"{env_name}_{num_bins}_{seq_len}_{strategy}_{discount}"
        
        self.joined_transitions = []
        for t in tqdm(trajectories, desc="Joining transitions"):
            self.joined_transitions.append(
                    join_trajectory(trajectories[t]["states"], trajectories[t]["actions"], trajectories[t]["rewards"],discount = self.discount)
                )
        """
        if cache_path is None or not os.path.exists(os.path.join(cache_path, self.cache_name)):
            self.joined_transitions = []
            for t in tqdm(trajectories, desc="Joining transitions"):
                self.joined_transitions.append(
                    join_trajectory(trajectories[t]["states"], trajectories[t]["actions"], trajectories[t]["rewards"],discount = self.discount)
                )

            os.makedirs(os.path.join(cache_path), exist_ok=True)
            # save cached version
            with open(os.path.join(cache_path, self.cache_name), "wb") as f:
                pickle.dump(self.joined_transitions, f)
        else:
            with open(os.path.join(cache_path, self.cache_name), "rb") as f:
                self.joined_transitions = pickle.load(f)
        """

        self.discretizer = KBinsDiscretizer(
            np.concatenate(self.joined_transitions, axis=0),
            num_bins=num_bins,
            strategy=strategy
        )

        # get valid indices for seq_len sampling
        indices = []
        for path_ind, length in enumerate(traj_lengths):
            end = length - 1
            for i in range(end):
                indices.append((path_ind, i, i + self.seq_len))
        self.indices = np.array(indices)

    def get_env_name(self):
        return self.env.name

    def get_discretizer(self):
        return self.discretizer

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        #print(idx)
        traj_idx, start_idx, end_idx = self.indices[idx]
        
        joined = self.joined_transitions[traj_idx][start_idx:end_idx]
        

        loss_pad_mask = np.ones((self.seq_len, joined.shape[-1]))
        if joined.shape[0] < self.seq_len:
            # pad to seq_len if at the end of trajectory, mask for padding
            loss_pad_mask[joined.shape[0]:] = 0
            joined = pad_along_axis(joined, pad_to=self.seq_len, axis=0)

        joined_discrete = self.discretizer.encode(joined).reshape(-1).astype(np.longlong)
        loss_pad_mask = loss_pad_mask.reshape(-1)

        return joined_discrete[:-1], joined_discrete[1:], loss_pad_mask[:-1]


In [9]:
config = OmegaConf.load("configs/medium/city_learn.yaml")
wandb.init(
        **config.wandb,
        config=dict(OmegaConf.to_container(config, resolve=True))
    )
device = "cuda:0"

  return LooseVersion(v) >= LooseVersion(check)


In [10]:
datasets = DiscretizedDataset(dataset,discount = 0.99)

Segmenting:   0%|          | 0/100352 [00:00<?, ?it/s]

Joining transitions:   0%|          | 0/12 [00:00<?, ?it/s]

Discount 0.99
(8759, 26)
(8759, 5)
(8759, 1)
(8759, 1)
Discount 0.99
(8759, 26)
(8759, 5)
(8759, 1)
(8759, 1)
Discount 0.99
(8759, 26)
(8759, 5)
(8759, 1)
(8759, 1)
Discount 0.99
(8759, 26)
(8759, 5)
(8759, 1)
(8759, 1)
Discount 0.99
(8759, 26)
(8759, 5)
(8759, 1)
(8759, 1)
Discount 0.99
(8759, 26)
(8759, 5)
(8759, 1)
(8759, 1)
Discount 0.99
(8759, 26)
(8759, 5)
(8759, 1)
(8759, 1)
Discount 0.99
(8759, 26)
(8759, 5)
(8759, 1)
(8759, 1)
Discount 0.99
(8759, 26)
(8759, 5)
(8759, 1)
(8759, 1)
Discount 0.99
(8759, 26)
(8759, 5)
(8759, 1)
(8759, 1)
Discount 0.99
(8759, 26)
(8759, 5)
(8759, 1)
(8759, 1)
Discount 0.99
(4003, 26)
(4003, 5)
(4003, 1)
(4003, 1)


In [11]:
datasets.joined_transitions[0][0].shape

(33,)

In [12]:
batch_size = 64
dataloader = DataLoader(datasets, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)


In [15]:
for batch in dataloader:
    print(batch[0].size())
    break

torch.Size([64, 329])


### Training Part

In [13]:
path = "configs/medium/city_learn.yaml"
config = OmegaConf.load("configs/medium/city_learn_traj.yaml")
trainer_conf = config.trainer
data_conf = config.dataset

In [14]:
model = GPT(**config.model)
model.to(device)

GPT(
  (tok_emb): Embedding(3300, 128)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-3): 4 x TransformerBlock(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.1, inplace=False)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (head): EinLinear(n_models=33, in_features=128, out_features=100, bias=False)
)

In [15]:
num_epochs = int(3e4 / len(datasets) * trainer_conf.num_epochs_ref)

warmup_tokens = len(datasets) * data_conf.seq_len * config.model.transition_dim
final_tokens = warmup_tokens * num_epochs

In [16]:
num_epochs

14

In [17]:
trainer = GPTTrainer(
        final_tokens=final_tokens,
        warmup_tokens=warmup_tokens,
        action_weight=trainer_conf.action_weight,
        value_weight=trainer_conf.value_weight,
        reward_weight=trainer_conf.reward_weight,
        learning_rate=trainer_conf.lr,
        betas=trainer_conf.betas,
        weight_decay=trainer_conf.weight_decay,
        clip_grad=trainer_conf.clip_grad,
        eval_seed=trainer_conf.eval_seed,
        eval_every=trainer_conf.eval_every,
        eval_episodes=trainer_conf.eval_episodes,
        eval_temperature=trainer_conf.eval_temperature,
        eval_discount=trainer_conf.eval_discount,
        eval_plan_every=trainer_conf.eval_plan_every,
        eval_beam_width=trainer_conf.eval_beam_width,
        eval_beam_steps=trainer_conf.eval_beam_steps,
        eval_beam_context=trainer_conf.eval_beam_context,
        eval_sample_expand=trainer_conf.eval_sample_expand,
        eval_k_obs=trainer_conf.eval_k_obs,  # as in original implementation
        eval_k_reward=trainer_conf.eval_k_reward,
        eval_k_act=trainer_conf.eval_k_act,
        checkpoints_path=trainer_conf.checkpoints_path,
        save_every=1,
        device=device
    )

In [18]:
trainer.train(
        model=model,
        dataloader=dataloader,
        num_epochs=num_epochs
    )

Training:   0%|          | 0/14 [00:00<?, ?it/s]

Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

   EPOCH 1: 5.208583519399168


Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

   EPOCH 2: 4.204621164260852


Epoch:   0%|          | 0/1568 [01:20<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb7f1ac45e0>
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fb7f1ac45e0>Exception ignored in: Traceback (most recent call last):

  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
<function _MultiProcessingDataLoaderIter.__del__ at 0x7fb7f1ac45e0>Traceback (most recent call last):
    
  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
self._shutdown_workers()Traceback (most recent call last):

      File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers
self._shutdown_workers()        
if w.is_alive():s

   EPOCH 3: 3.937248721385705


Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

   EPOCH 4: 3.7893836208994514


Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

   EPOCH 5: 3.70218578502506


Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

   EPOCH 6: 3.642668913940583


Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

   EPOCH 7: 3.59653997517662


Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb7f1ac45e0>
Traceback (most recent call last):
  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers
    if w.is_alive():
  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb7f1ac45e0>
Traceback (most recent call last):
  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ml-stud15/a

   EPOCH 8: 3.5569109584974266


Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

   EPOCH 9: 3.5215758410977123


Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

   EPOCH 10: 3.4915563310790403


Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

   EPOCH 11: 3.467680165991402


Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

   EPOCH 12: 3.4512651699067414


Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb7f1ac45e0>
Traceback (most recent call last):
  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers
    if w.is_alive():
  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb7f1ac45e0>
Traceback (most recent call last):
  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ml-stud15/a

   EPOCH 13: 3.444154729218725


Epoch:   0%|          | 0/1568 [00:00<?, ?it/s]

Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb7f1ac45e0><function _MultiProcessingDataLoaderIter.__del__ at 0x7fb7f1ac45e0>

Traceback (most recent call last):
  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
Traceback (most recent call last):
      File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
Exception ignored in: self._shutdown_workers()    <function _MultiProcessingDataLoaderIter.__del__ at 0x7fb7f1ac45e0>
self._shutdown_workers()
  File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers

Traceback (most recent call last):
      File "/home/ml-stud15/anaconda3/envs/stable_env/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1462, in _shutdown_workers
if w.is_ali

   EPOCH 14: 3.4392617026875443


GPT(
  (tok_emb): Embedding(3300, 128)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-3): 4 x TransformerBlock(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.1, inplace=False)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (head): EinLinear(n_models=33, in_features=128, out_features=100, bias=False)
)

In [None]:
device 

In [None]:
torch.cuda.is_available()