In [1]:
"""
Copyright (c) Meta Platforms, Inc. and affiliates.

This source code is licensed under the CC BY-NC license found in the
LICENSE.md file in the root directory of this source tree.
"""

#from torch.utils.tensorboard import SummaryWriter
import argparse
import pickle
import random
import time
import gym

import torch
import numpy as np

from datasets import load_from_disk
import datasets

import utils
from replay_buffer import ReplayBuffer
from lamb import Lamb
#from stable_baselines3.common.vec_env import SubprocVecEnv
from pathlib import Path
from data import create_dataloader
from decision_transformer.models.decision_transformer import DecisionTransformer
from evaluation import create_vec_eval_episodes_fn, vec_evaluate_episode_rtg
from trainer import SequenceTrainer
from logger import Logger
from wrappers_custom import *
from utils_.helpers import *

from citylearn.citylearn import CityLearnEnv
from citylearn.wrappers import *
from utils_.variant_dict import variant


In [2]:
# Stand-in for the experiment object (`self`) so that method bodies copied from
# the trainer class can be run cell-by-cell; later cells attach attributes to it.
# A SimpleNamespace avoids defining a class called `self` and then shadowing it
# with its own instance.
from types import SimpleNamespace

self = SimpleNamespace(a=3)  # `a=3` kept from the original placeholder class

In [3]:
# CityLearn 2022 phase-2 environment driven by one central agent, wrapped for
# normalized observations and the custom SB3-style interface.
env = CityLearnEnv(schema="citylearn_challenge_2022_phase_2")
# NOTE(review): central_agent is switched on after construction; confirm this
# also rebuilds the observation/action spaces (later cells show state_dim=44,
# act_dim=5).
env.central_agent = True
env = StableBaselines3WrapperCustom(NormalizedObservationWrapper(env))

In [4]:
def _get_env_spec(env):
    state_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]
    action_range = [
            float(env.action_space.low.min()) ,
            float(env.action_space.high.max()) ,
        ]
    return state_dim,act_dim, action_range

In [5]:
# Cache the environment's state/action sizes and action range on `self`.
(
    self.state_dim,
    self.act_dim,
    self.action_range,
) = _get_env_spec(env)

  logger.warn(
  logger.warn(
  logger.warn(


## Load Dataset

In [6]:
def _load_dataset(trajectories):
    states, traj_lens, returns = [], [], []
    for path in trajectories:
        states.append(path["observations"])
        traj_lens.append(len(path["observations"]))
        returns.append(np.array(path["rewards"]).sum())
    traj_lens, returns = np.array(traj_lens), np.array(returns)

        # used for input normalization
    states = np.concatenate(states, axis=0)
    state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
    num_timesteps = sum(traj_lens)

    print("=" * 50)
    print(f"Starting new experiment: city_learn")
    print(f"{len(traj_lens)} trajectories, {num_timesteps} timesteps found")
    print(f"Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}")
    print(f"Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}")
    print(f"Average length: {np.mean(traj_lens):.2f}, std: {np.std(traj_lens):.2f}")
    print(f"Max length: {np.max(traj_lens):.2f}, min: {np.min(traj_lens):.2f}")
    print("=" * 50)

    sorted_inds = np.argsort(returns)  # lowest to highest
    num_trajectories = 1
    timesteps = traj_lens[sorted_inds[-1]]
    ind = len(trajectories) - 2
    while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] < num_timesteps:
        timesteps += traj_lens[sorted_inds[ind]]
        num_trajectories += 1
        ind -= 1
    sorted_inds = sorted_inds[-num_trajectories:]
    print(sorted_inds)
    #print(trajectories[1])
    for ii in sorted_inds:
        print(ii)
    #print(trajectories[0].keys())
    trajectories = [trajectories[int(ii)] for ii in sorted_inds]

    for trajectory in trajectories:
        for key in trajectory.keys():
            trajectory[key] = np.array(trajectory[key])


    return trajectories, state_mean, state_std

In [7]:
# Location of the SAC interaction dataset.
# NOTE(review): despite the ".pkl" suffix, the next cell reads this path with
# datasets.load_from_disk (presumably a save_to_disk directory, not a pickle) —
# consider renaming the path to avoid confusion.
dataset_path = "./data_interactions/sac_dataset.pkl"

In [8]:
# Load the interaction dataset with the Hugging Face `datasets` library.
dataset = load_from_disk(dataset_path=dataset_path)

In [9]:
# Split the flat observation/action/reward/done columns into segments
# (presumably per episode, using `dones`). This rebinds `dataset` to
# segment_v2's first return value, so re-running this cell needs the load cell
# to run first.
dataset, _ = segment_v2(
    dataset["observations"],
    dataset["actions"],
    dataset["rewards"],
    dataset["dones"],
)
   

Segmenting:   0%|          | 0/30000 [00:00<?, ?it/s]

In [10]:
# Length of the first segment's reward sequence (8759 steps in the output below).
dataset[0]["rewards"].shape

(8759,)

In [11]:
# Re-pack the list of per-segment dicts into a column-oriented Dataset
# (one row per segment), using the first segment's keys as columns.
trajectories = datasets.Dataset.from_dict(
    {
        field: [segment[field] for segment in dataset]
        for field in dataset[0].keys()
    }
)


In [12]:
# Keep the selected offline trajectories and the observation statistics that
# the dataloader / evaluation use for state normalization.
(
    self.offline_trajs,
    self.state_mean,
    self.state_std,
) = _load_dataset(trajectories)

Starting new experiment: city_learn
4 trajectories, 30000 timesteps found
Average return: -6148.22, std: 1638.95
Max return: -3352.98, min: -7544.23
Average length: 7500.00, std: 2180.65
Max length: 8759.00, min: 3723.00
[1 2 3]
1
2
3


In [13]:
# Replay buffer seeded with the offline trajectories.
# NOTE(review): the first argument is hard-coded to 1000; it is presumably the
# buffer capacity — confirm against replay_buffer.ReplayBuffer and consider
# moving it into `variant`.
self.replay_buffer = ReplayBuffer(
    1000,
    self.offline_trajs,
)

In [14]:
# Container for augmented trajectories (presumably filled during online
# fine-tuning); starts empty.
self.aug_trajs = []

In [15]:
# Run this walkthrough on CPU.
self.device = "cpu"
# Target entropy of -|A| (the common SAC-style heuristic), passed to the model.
self.target_entropy = -self.act_dim
# 365 * 24 = 8760 steps; the longest loaded segment has 8759 steps.
MAX_EPISODE_LEN = 365 * 24

In [16]:
# -act_dim: -5 here, i.e. five action dimensions.
self.target_entropy

-5

In [17]:
# Transformer width taken from the shared config (512 here).
variant["embed_dim"]

512

In [18]:
# Decision Transformer with a stochastic policy head.
# Every timestep contributes three tokens (return-to-go, state, action), so the
# attention context is 3 * K tokens (60 for the K=20 batches inspected below).
# n_ctx is derived from K instead of being hard-coded to 60, so changing
# variant["K"] keeps the attention size (presumably the causal-mask size) in sync.
self.model = DecisionTransformer(
    state_dim=self.state_dim,
    act_dim=self.act_dim,
    action_range=self.action_range,
    max_length=variant["K"],
    eval_context_length=variant["eval_context_length"],
    max_ep_len=MAX_EPISODE_LEN,
    hidden_size=variant["embed_dim"],
    n_layer=variant["n_layer"],
    n_head=variant["n_head"],
    n_inner=4 * variant["embed_dim"],
    activation_function=variant["activation_function"],
    n_positions=1024,
    resid_pdrop=variant["dropout"],
    attn_pdrop=variant["dropout"],
    n_ctx=3 * variant["K"],
    stochastic_policy=True,
    ordering=variant["ordering"],
    init_temperature=variant["init_temperature"],
    target_entropy=self.target_entropy,
).to(device=self.device)

nx 512
nx 512
nx 512
nx 512


In [19]:
# LAMB optimizer for the transformer weights.
self.optimizer = Lamb(
    self.model.parameters(),
    lr=variant["learning_rate"],
    weight_decay=variant["weight_decay"],
    eps=1e-8,
)
# Linear learning-rate warm-up over variant["warmup_steps"] scheduler steps,
# then held at the base rate (factor capped at 1).
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
    self.optimizer,
    lr_lambda=lambda step: min((step + 1) / variant["warmup_steps"], 1),
)
# Separate Adam optimizer for the model's log-temperature tensor.
self.log_temperature_optimizer = torch.optim.Adam(
    [self.model.log_temperature],
    lr=1e-4,
    betas=[0.9, 0.999],
)

In [20]:
# Iteration / sample counters and run configuration.
self.pretrain_iter = self.online_iter = 0
self.total_transitions_sampled = 0
self.variant = variant
self.reward_scale = 1.0
# Creating the Logger reports the experiment log path (see output below).
self.logger = Logger(variant)

Experiment log path: ./exp/2024.06.03/103455-default


In [21]:
def loss_fn(
    a_hat_dist,
    a,
    attention_mask,
    entropy_reg,
):
    """Negative log-likelihood loss with an entropy bonus for a stochastic policy.

    Args:
        a_hat_dist: predicted action distribution exposing ``log_likelihood(a)``
            and ``entropy()`` (a SquashedNormal according to the original note).
        a: target actions.
        attention_mask: only positions where it is > 0 enter the
            log-likelihood term.
        entropy_reg: weight of the entropy bonus.

    Returns:
        tuple: ``(loss, nll, entropy)`` with ``loss = nll - entropy_reg * entropy``
        and ``nll`` the negated masked mean log-likelihood.

    NOTE(review): the entropy is averaged over all positions, including those
    masked out of the log-likelihood — confirm this is intended.
    """
    valid = attention_mask > 0
    # log_likelihood is evaluated before entropy() to keep the original call order.
    nll = -a_hat_dist.log_likelihood(a)[valid].mean()
    entropy = a_hat_dist.entropy().mean()
    loss = nll - entropy_reg * entropy
    return loss, nll, entropy

In [22]:
def pretrain(self, eval_envs, loss_fn):
    """Offline pre-training loop, ported to take the experiment state as ``self``.

    Each iteration builds a fresh dataloader over ``self.offline_trajs``, runs one
    ``SequenceTrainer.train_iteration``, evaluates, logs metrics and saves a
    checkpoint, until ``self.pretrain_iter`` reaches
    ``self.variant["max_pretrain_iters"]``.

    Args:
        self: object carrying the experiment state (model, optimizers, scheduler,
            logger, variant, normalization statistics, counters).
        eval_envs: vectorized evaluation environment(s) forwarded to
            ``create_vec_eval_episodes_fn``.
        loss_fn: loss callable forwarded to ``trainer.train_iteration``.

    NOTE(review): this also calls ``self.evaluate`` and ``self._save_model``,
    which no visible cell attaches to ``self``.
    """
    print("\n\n\n*** Pretrain ***")

    # "time/total" below is measured from self.start_time, which no visible
    # cell sets; start the clock here when it is missing.
    if not hasattr(self, "start_time"):
        self.start_time = time.time()

    eval_fns = [
        create_vec_eval_episodes_fn(
            vec_env=eval_envs,
            eval_rtg=self.variant["eval_rtg"],
            state_dim=self.state_dim,
            act_dim=self.act_dim,
            state_mean=self.state_mean,
            state_std=self.state_std,
            device=self.device,
            use_mean=True,
            reward_scale=self.reward_scale,
        )
    ]

    trainer = SequenceTrainer(
        model=self.model,
        optimizer=self.optimizer,
        log_temperature_optimizer=self.log_temperature_optimizer,
        scheduler=self.scheduler,
        device=self.device,
    )

    writer = None
    if self.variant["log_to_tb"]:
        # Imported here: the module-level SummaryWriter import is commented out,
        # so the bare name would raise NameError.
        from torch.utils.tensorboard import SummaryWriter

        writer = SummaryWriter(self.logger.log_path)

    try:
        while self.pretrain_iter < self.variant["max_pretrain_iters"]:
            # in every iteration, prepare the data loader
            dataloader = create_dataloader(
                trajectories=self.offline_trajs,
                num_iters=self.variant["num_updates_per_pretrain_iter"],
                batch_size=self.variant["batch_size"],
                max_len=self.variant["K"],
                state_dim=self.state_dim,
                act_dim=self.act_dim,
                state_mean=self.state_mean,
                state_std=self.state_std,
                reward_scale=self.reward_scale,
                action_range=self.action_range,
            )

            train_outputs = trainer.train_iteration(
                loss_fn=loss_fn,
                dataloader=dataloader,
            )
            eval_outputs, _eval_reward = self.evaluate(eval_fns)
            outputs = {"time/total": time.time() - self.start_time}
            outputs.update(train_outputs)
            outputs.update(eval_outputs)
            self.logger.log_metrics(
                outputs,
                iter_num=self.pretrain_iter,
                total_transitions_sampled=self.total_transitions_sampled,
                writer=writer,
            )

            self._save_model(
                path_prefix=self.logger.log_path,
                is_pretrain_model=True,
            )

            self.pretrain_iter += 1
    finally:
        # Flush and release the TensorBoard event file even if training fails.
        if writer is not None:
            writer.close()

## Pretrain

In [23]:
# Stand-alone trainer used to step through a single batch by hand below.
trainer = SequenceTrainer(
    model=self.model,
    device=self.device,
    optimizer=self.optimizer,
    scheduler=self.scheduler,
    log_temperature_optimizer=self.log_temperature_optimizer,
)

In [24]:
# Debugging dataloader: batch_size=3 instead of variant["batch_size"], giving
# small tensors for the manual inspection below.
dataloader = create_dataloader(
    trajectories=self.offline_trajs,
    batch_size=3,
    num_iters=self.variant["num_updates_per_pretrain_iter"],
    max_len=self.variant["K"],
    state_dim=self.state_dim,
    act_dim=self.act_dim,
    action_range=self.action_range,
    state_mean=self.state_mean,
    state_std=self.state_std,
    reward_scale=self.reward_scale,
)
# NOTE: the model's n_ctx must match the query/key/value length, i.e. three
# tokens per timestep -> 3 * K (see the stacked inputs further down).



In [25]:
# Inspect the dataloader object (its repr shows only the type and address).
dataloader

<torch.utils.data.dataloader.DataLoader at 0x16c6d0640>

In [26]:
# Full training iteration, disabled: the following cells step through a single
# batch by hand instead.
#train_outputs = trainer.train_iteration(
#                loss_fn=loss_fn,
#                dataloader=dataloader,
#            )

In [27]:
# Take the first batch for manual inspection. next(iter(...)) replaces the
# for/enumerate/break pattern: same first batch, but an empty dataloader now
# raises StopIteration here instead of silently leaving these names unset, and
# IPython's `_` output variable is no longer overwritten.
(
    states,
    actions,
    rewards,
    dones,
    rtg,
    timesteps,
    ordering,
    padding_mask,
) = next(iter(dataloader))

In [28]:
# Device the batch tensors are moved to in the next cells ('cpu').
self.device

'cpu'

In [29]:
# Same dataloader inspection as earlier (duplicate cell; same object address).
dataloader

<torch.utils.data.dataloader.DataLoader at 0x16c6d0640>

In [30]:
# Move every batch tensor to the configured device, keeping the same names.
(
    states,
    actions,
    rewards,
    dones,
    rtg,
    timesteps,
    ordering,
    padding_mask,
) = (
    tensor.to(self.device)
    for tensor in (states, actions, rewards, dones, rtg, timesteps, ordering, padding_mask)
)

In [31]:
# (batch_size, seq_length, state_dim) -> [3, 20, 44] for this debug batch.
states.shape

torch.Size([3, 20, 44])

In [32]:
# (batch_size, seq_length, act_dim) -> [3, 20, 5].
actions.shape

torch.Size([3, 20, 5])

In [33]:
# (batch_size, seq_length, 1) -> [3, 20, 1]; rewards carry a trailing singleton dim.
rewards.shape

torch.Size([3, 20, 1])

In [34]:
# (batch_size, seq_length) -> [3, 20]; no trailing feature dim.
dones.shape

torch.Size([3, 20])

In [35]:
# [3, 21, 1]: one more step than states/actions; the extra step is dropped
# later with rtg[:, :-1].
rtg.shape

torch.Size([3, 21, 1])

In [36]:
# (batch_size, seq_length) -> [3, 20].
timesteps.shape

torch.Size([3, 20])

## Model Forward

In [37]:
from torch import nn

# Batch and context sizes of the sampled batch (states is batch x seq x features).
batch_size, seq_length = states.shape[:2]


In [38]:
# Scratch embedding width for the manual forward pass below; apparently chosen
# independently of the model's hidden size (variant["embed_dim"] = 512).
hidden_size = 56

In [39]:
# Fresh (untrained) embedding layers mirroring the model's return/state/action
# embeddings at the scratch width, plus the LayerNorm for the token sequence.
# Creation order is unchanged so random initialization matches.
self.embed_return = nn.Linear(1, hidden_size)
self.embed_state = nn.Linear(self.state_dim, hidden_size)
self.embed_action = nn.Linear(self.act_dim, hidden_size)
self.embed_ln = nn.LayerNorm(hidden_size)

In [40]:
# First look at the state embedding (recomputed with the other modalities a few
# cells below).
state_embeddings = self.embed_state(states)

In [41]:
# (batch_size, seq_length, hidden_size) -> [3, 20, 56].
state_embeddings.shape

torch.Size([3, 20, 56])

In [42]:
# rtg has seq_length + 1 steps (21 vs 20); drop the last one so returns line up
# with states and actions.
returns = rtg[:, :-1]

In [43]:
# Embed each modality to the same width: (batch_size, seq_length, hidden_size).
returns_embeddings = self.embed_return(returns)
state_embeddings = self.embed_state(states)
action_embeddings = self.embed_action(actions)

In [44]:
# Stack the three modalities on a new axis: (batch, 3, seq, hidden_size).
stacked = torch.stack(
    [returns_embeddings, state_embeddings, action_embeddings], 1
)

In [45]:
# Interleave the tokens per timestep as (R_1, s_1, a_1, R_2, s_2, a_2, ...):
# (batch, 3, seq, hidden) -> (batch, seq, 3, hidden) -> (batch, 3 * seq, hidden),
# then layer-normalize the token sequence.
stacked_inputs = self.embed_ln(
    torch.stack((returns_embeddings, state_embeddings, action_embeddings), dim=1)
    .transpose(1, 2)
    .reshape(batch_size, 3 * seq_length, hidden_size)
)

In [46]:
# (batch_size, 3 * seq_length, hidden_size) -> [3, 60, 56].
stacked_inputs.shape

torch.Size([3, 60, 56])

In [47]:
# Repeat each timestep's padding flag for its three tokens so the mask lines up
# with the interleaved inputs: (batch, seq) -> (batch, seq, 3) -> (batch, 3 * seq).
stacked_padding_mask = torch.stack(
    [padding_mask] * 3, dim=2
).reshape(batch_size, 3 * seq_length)

In [48]:
# (batch_size, 3 * seq_length) -> [3, 60], matching stacked_inputs' token axis.
stacked_padding_mask.shape

torch.Size([3, 60])

In [49]:
# Print the model architecture (GPT-2 backbone plus the token embeddings).
self.model

DecisionTransformer(
  (transformer): GPT2Model(
    (wte): Embedding(1, 512)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-3): 4 x Block(
        (ln_1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): ReLU()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (embed_timestep): Embedding(8760, 512)
  (embed_return): Linear(in_features=1, out_features=512, bias=True)
  (embed_state): Linear(in_features=44, out_features=512, bias=True)
  (embed_action): Linear(in_features=5, out_features=512, bias=

In [51]:
# Gradient updates per pre-training iteration (5000); passed as num_iters to
# create_dataloader.
variant["num_updates_per_pretrain_iter"]

5000