In [24]:
import os
import torch
import numpy as np
import torch.nn.functional as F
from trajectory.utils.common import round_to_multiple

import pickle
from tqdm.auto import trange, tqdm
from torch.utils.data import Dataset
from dataclasses import dataclass
from datasets import load_from_disk
from omegaconf import OmegaConf

from trajectory.planning.sample import sample

from citylearn.agents.rbc import HourRBC
from citylearn.agents.q_learning import TabularQLearning
from citylearn.citylearn import CityLearnEnv
from citylearn.data import DataSet
from citylearn.reward_function import RewardFunction
from citylearn.wrappers import NormalizedObservationWrapper
from citylearn.wrappers import StableBaselines3Wrapper
from citylearn.wrappers import TabularQLearningWrapper

import wandb
from torch.utils.data import DataLoader
from trajectory.models.gpt import GPT, GPTTrainer
from trajectory.utils.common import pad_along_axis
from trajectory.utils.discretization import KBinsDiscretizer
from trajectory.utils.env import create_env
from trajectory.utils.CityStableEnv import EnvCityGym


from omegaconf import OmegaConf

In [2]:
device = "cpu"
checkpoints_path = 'checkpoints/city_learn/uniform/baseline'
schema =  "citylearn_challenge_2022_phase_2"
config = "configs/eval_base.yaml"
config = OmegaConf.load(config)
run_config= "configs/medium/city_learn_traj.yaml"
run_config = OmegaConf.load(run_config)

In [3]:
beam_context = config.beam_context
beam_width = config.beam_width
beam_steps = config.beam_steps
plan_every = config.plan_every
sample_expand = config.sample_expand
k_act = config.k_act
k_obs = config.k_obs
k_reward = config.k_reward
temperature = config.temperature
discount = config.discount
max_steps = 719

In [4]:
discretizer = torch.load(os.path.join(checkpoints_path, "discretizer.pt"), map_location=device)


In [5]:
model = GPT(**run_config.model)
model.eval()
model.to(device)
model.load_state_dict(torch.load(os.path.join(checkpoints_path, "model_last.pt"), map_location=device))

<All keys matched successfully>

In [178]:
env = CityLearnEnv(schema=schema)
env = EnvCityGym(env)

In [198]:
obs = env.reset()

In [199]:
obs = np.array(obs)

In [200]:
obs

array([5.83333333e-01, 1.00000000e+00, 8.53622034e-02, 1.82999992e-01,
       8.10000000e-01, 2.19999999e-01, 2.98903322e-01, 0.00000000e+00,
       0.00000000e+00, 2.98903322e-01, 1.54143333e-01, 0.00000000e+00,
       0.00000000e+00, 1.54143333e-01, 1.95058192e-08, 0.00000000e+00,
       0.00000000e+00, 1.95058192e-08, 1.26090002e-01, 0.00000000e+00,
       0.00000000e+00, 1.26090002e-01, 1.09140003e-01, 0.00000000e+00,
       0.00000000e+00, 1.09140003e-01])

In [201]:
model

GPT(
  (tok_emb): Embedding(3300, 128)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-3): 4 x TransformerBlock(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.1, inplace=False)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (head): EinLinear(n_models=33, in_features=128, out_features=100, bias=False)
)

In [202]:
 transition_dim, obs_dim, act_dim = model.transition_dim, model.observation_dim, model.action_dim

In [203]:
context = torch.zeros(1, model.transition_dim * (max_steps + 1), dtype=torch.long).to(device)

In [204]:
obs_tokens = discretizer.encode(obs, subslice=(0, obs_dim)).squeeze()

In [205]:
context[:, :model.observation_dim] = torch.as_tensor(obs_tokens, device=device)  # initial tokens for planning

In [206]:
context

tensor([[54, 99, 47,  ...,  0,  0,  0]])

In [207]:
done, total_reward, render_frames = False, 0.0, []

In [208]:
context_offset = model.transition_dim * (0 + 1) - model.action_dim - 2 

In [209]:
context_bm = context[:, :context_offset]

## Beam Plan

In [210]:
beam_width = 3
steps = beam_steps

In [211]:
rewards = torch.zeros(beam_width, steps + 1, device=context.device)
discounts = discount ** torch.arange(steps + 1, device=context.device)

In [212]:
context_size = 5
context_size = context_size * model.transition_dim
n_crop = round_to_multiple(max(0, context_bm.shape[1] - context_size), model.transition_dim)

In [213]:
context_bm = context_bm[:, n_crop:]


In [214]:
context_bm.shape

torch.Size([1, 26])

In [215]:
plan = context_bm.repeat(beam_width, 1)

In [216]:
for t in trange(steps, leave=False):
        # [beam_width * sample_expand, ...]
    plan = plan.repeat(sample_expand, 1)
    rewards = rewards.repeat(sample_expand, 1)

    #if model_state is not None:
            # [beam_width * sample_expand, cache_len, emb_dim]
    #   model_state = [s.repeat(sample_expand, 1, 1) for s in model_state]

        # sample action tokens
    plan, model_state, _ = sample(
            model, plan, model_state=None, steps=model.action_dim, top_k=k_act, temperature=temperature
        )
        # sample reward and value estimates
    plan, model_state, logits = sample(
            model, plan, model_state=None, steps=2, top_k=k_reward, temperature=temperature
        )
    probs = F.softmax(logits, dim=-1)
    reward_and_value = discretizer.expectation(probs, subslice=[model.transition_dim - 2, model.transition_dim])
    print(reward_and_value)
    rewards[:, t:t + 2] = reward_and_value
    print("---------------")
    print(rewards)
    values = (rewards * discounts).sum(-1)
    values, idxs = torch.topk(values, k=beam_width)

    plan, rewards = plan[idxs], rewards[idxs]
    
    print("------------ Final Plan------------------")
    print(plan.size())
    
    print("------ Final Rewards ---------")
    print(rewards)
    print("-------------------------------")
    model_state = [s[idxs] for s in model_state]

    if t < steps - 1:
            # sample obs unless last step
        plan, model_state, _ = sample(
                model, plan, model_state=model_state, steps=model.observation_dim, top_k=k_obs, temperature=temperature
            )
    

  0%|          | 0/5 [00:00<?, ?it/s]

tensor([[ -10.5940, -607.4584],
        [  -7.0893, -623.1907],
        [ -10.9646, -614.9783],
        [ -12.4505, -607.2290],
        [  -8.4440, -615.8477],
        [ -10.4521, -605.1689]], dtype=torch.float64, grad_fn=<SumBackward1>)
---------------
tensor([[ -10.5940, -607.4584,    0.0000,    0.0000,    0.0000,    0.0000],
        [  -7.0893, -623.1907,    0.0000,    0.0000,    0.0000,    0.0000],
        [ -10.9646, -614.9783,    0.0000,    0.0000,    0.0000,    0.0000],
        [ -12.4505, -607.2290,    0.0000,    0.0000,    0.0000,    0.0000],
        [  -8.4440, -615.8477,    0.0000,    0.0000,    0.0000,    0.0000],
        [ -10.4521, -605.1689,    0.0000,    0.0000,    0.0000,    0.0000]],
       grad_fn=<CopySlices>)
------------ Final Plan------------------
torch.Size([3, 33])
------ Final Rewards ---------
tensor([[ -10.4521, -605.1689,    0.0000,    0.0000,    0.0000,    0.0000],
        [ -10.5940, -607.4584,    0.0000,    0.0000,    0.0000,    0.0000],
        [ -12.4

In [219]:
context_bm.shape

torch.Size([1, 26])

In [177]:
rewards

tensor([[ -12.9115,  -16.2543,   -4.0252,   -2.7249,   -7.1054, -554.9442],
        [ -12.9115,  -16.2543,   -4.0252,   -2.8037,   -5.7905, -560.5143],
        [ -12.9115,  -16.2543,   -4.0252,   -2.7249,  -10.4618, -555.9832]],
       grad_fn=<IndexBackward0>)

In [70]:
values

tensor([-539.0653, -539.1727, -550.2013], grad_fn=<TopkBackward0>)

### Steps

In [194]:
sample_expand = 1

In [195]:
plan = plan.repeat(1, 1)
rewards = rewards.repeat(sample_expand, 1)

In [196]:
rewards

tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])

In [197]:
# Just one action for further one 
test,ms = model(context[:,-1:])

In [198]:
test.shape

torch.Size([1, 1, 100])

#### Sample

In [199]:
batch_size = context.shape[0]
batch_size 

1

In [200]:
raw_logits = torch.zeros(batch_size, steps, model.vocab_size, device=context.device)

In [201]:
raw_logits.shape

torch.Size([1, 5, 100])

In [202]:
logits,model_state = model(context, state=None)

In [203]:
logits.size()

torch.Size([1, 26, 100])

In [204]:
def _sample_inner(logits, top_k, temperature, greedy=False):
    logits = logits / temperature

    if top_k is not None:
        logits = top_k_logits(logits, k=top_k)

    probs = F.softmax(logits, dim=-1)

    if greedy:
        idx = torch.topk(probs, k=1, dim=-1)[-1]
    else:
        idx = torch.multinomial(probs, num_samples=1)

    return idx

In [205]:
logits[:,-1,:].shape

torch.Size([1, 100])

In [206]:
sampled_tokens = _sample_inner(logits[:, -1, :], None, 1, False)

In [207]:
 context = torch.hstack([context, sampled_tokens])

In [208]:
context.size()

torch.Size([1, 27])

In [209]:
logits

tensor([[[ -2.6365, -15.5346, -15.5312,  ..., -15.4928, -15.5322,  -2.4405],
         [ -8.1907,  -6.3107,  -4.2083,  ...,  -6.0341,  -8.8072,  -5.8740],
         [ -6.3233,  -3.6427, -13.5979,  ..., -13.4999, -13.5672,  -8.4786],
         ...,
         [  1.1276,  -0.2794,  -0.5848,  ...,   0.8913,   1.1704,   1.9366],
         [ -7.6923,  -6.4762,  -4.2407,  ..., -11.7560, -11.7819,  -8.5468],
         [  1.8463,  -0.9704,  -1.5757,  ...,  -1.0963,  -1.5979,   1.9532]]],
       grad_fn=<SliceBackward0>)

In [210]:
raw_logits[:, 0] = logits[:, -1, :]

In [211]:
raw_logits.shape

torch.Size([1, 5, 100])

### RESUMING STEPS

In [212]:
plan, model_state, _ = sample(
            model, plan, model_state=None, steps=model.action_dim, top_k=k_act, temperature=temperature
        )

In [213]:
plan.size()

torch.Size([3, 31])

For Rewards and Value

In [214]:
 plan, model_state, logits = sample(
            model, plan, model_state=model_state, steps=2, top_k=k_reward, temperature=temperature
        )

In [261]:
probs = F.softmax(logits, dim=-1)

In [262]:
probs.size()

torch.Size([3, 2, 100])

In [268]:
discretizer.expectation(probs, subslice=[model.transition_dim -2, model.transition_dim])

tensor([[  -9.5088, -614.6954],
        [ -12.1123, -602.8091],
        [ -14.4683, -612.9071]], dtype=torch.float64, grad_fn=<SumBackward1>)

In [269]:
reward_and_value = discretizer.expectation(probs, subslice=[model.transition_dim - 2, model.transition_dim])

In [270]:
rewards[:, 0:0 + 2] = reward_and_value

In [271]:
rewards

tensor([[  -9.5088, -614.6954,    0.0000,    0.0000,    0.0000,    0.0000],
        [ -12.1123, -602.8091,    0.0000,    0.0000,    0.0000,    0.0000],
        [ -14.4683, -612.9071,    0.0000,    0.0000,    0.0000,    0.0000]],
       grad_fn=<CopySlices>)

In [272]:
values = (rewards * discounts).sum(-1)

In [273]:
values

tensor([-618.0573, -608.8934, -621.2463], grad_fn=<SumBackward1>)

In [275]:
 values, idxs = torch.topk(values, k=beam_width)

In [276]:
model_state = [s[idxs] for s in model_state]

In [282]:
idxs

tensor([1, 0, 2])

In [283]:
context.shape[1]

27

In [284]:
best_idx = torch.argmax(values)

In [290]:
context.shape[1]

27

In [291]:
prediction_tokens=plan[best_idx,context.shape[1]:]

In [292]:
action_tokens = prediction_tokens[:act_dim]

In [293]:
action = discretizer.decode(action_tokens.cpu().numpy(), subslice=(obs_dim, obs_dim + act_dim)).squeeze()


In [294]:
action

array([-0.99,  0.99, -0.93,  0.09,  0.37])

In [295]:
obs, reward, done, _ = env.step(action)

In [298]:
obs = np.array(obs)

In [299]:
obs_tokens = discretizer.encode(obs, subslice=(0, obs_dim)).squeeze()

In [301]:
value_placeholder = 1e6

In [302]:
reward_tokens = discretizer.encode(
            np.array([reward, value_placeholder]),
            subslice=(transition_dim - 2, transition_dim)
        )

### NEXT

In [304]:
step = 1

In [305]:
context_offset = model.transition_dim * step

In [314]:
context[:, context_offset + obs_dim:context_offset + obs_dim + act_dim] = torch.as_tensor(action_tokens, device=device)

In [317]:
context[:, context_offset + transition_dim - 2:context_offset + transition_dim] = torch.as_tensor(reward_tokens, device=device)

In [318]:
context[:, context_offset + model.transition_dim:context_offset + model.transition_dim + model.observation_dim] = torch.as_tensor(obs_tokens, device=device)


In [324]:
test,_ = model(context[:,:59])

In [326]:
test.size()

torch.Size([1, 59, 100])