In [1]:
import os
import torch
import numpy as np
import wandb

import pickle
from tqdm.auto import trange, tqdm
from torch.utils.data import Dataset
from dataclasses import dataclass
from datasets import load_from_disk
from omegaconf import OmegaConf
import torch.nn.functional as F

"""
from citylearn.agents.rbc import HourRBC
from citylearn.agents.q_learning import TabularQLearning
from citylearn.citylearn import CityLearnEnv
from citylearn.data import DataSet
from citylearn.reward_function import RewardFunction
from citylearn.wrappers import NormalizedObservationWrapper
from citylearn.wrappers import StableBaselines3Wrapper
from citylearn.wrappers import TabularQLearningWrapper
"""
from stable_baselines3.a2c import A2C

from torch.utils.data import DataLoader
from trajectory.models.gpt import GPT, GPTTrainer

from trajectory.utils.common import pad_along_axis
from trajectory.utils.discretization import KBinsDiscretizer
from trajectory.utils.env import create_env

%matplotlib inline
import matplotlib.pyplot as plt




In [27]:
import torch.nn as nn

In [2]:
offline_data_path = "data_interactions/best_dataset.pkl"

In [3]:
dataset = load_from_disk(offline_data_path)

In [4]:
dataset

Dataset({
    features: ['observations', 'next_observations', 'actions', 'rewards', 'dones', 'info'],
    num_rows: 8759
})

In [5]:
np.array(dataset["observations"]).shape

(8759, 44)

In [6]:
def join_trajectory(states, actions, rewards, discount=0.99):
    traj_length = states.shape[0]
    # I can vectorize this for all dataset as once,
    # but better to be safe and do it once and slow and right (and cache it)
    
    if actions.ndim == 3 :
        actions = actions.reshape(actions.shape[0],actions.shape[1])
    
    if rewards.ndim == 1 :
        rewards = rewards.reshape(rewards.shape[0],1)
        
    print("Discount "+str(discount))
    discounts = (discount ** np.arange(traj_length))

    values = np.zeros_like(rewards)
    for t in range(traj_length):
        # discounted return-to-go from state s_t:
        # r_{t+1} + y * r_{t+2} + y^2 * r_{t+3} + ...
        # .T as rewards of shape [len, 1], see https://github.com/Howuhh/faster-trajectory-transformer/issues/9
        values[t] = (rewards[t + 1:].T * discounts[:-t - 1]).sum()
    print(states.shape)
    print(actions.shape)
    print(rewards.shape)
    print(values.shape)

    joined_transition = np.concatenate([states, actions, rewards, values], axis=-1)

    return joined_transition

def segment(states, actions, rewards, terminals):
    assert len(states) == len(terminals)
    
    trajectories = {}

    episode_num = 0
    for t in trange(len(terminals), desc="Segmenting"):
        if episode_num not in trajectories:
            trajectories[episode_num] = {
                "states": [],
                "actions": [],
                "rewards": []
            }
        
        trajectories[episode_num]["states"].append(states[t])
        trajectories[episode_num]["actions"].append(actions[t])
        trajectories[episode_num]["rewards"].append(rewards[t])

        if terminals[t]:
            # next episode
            episode_num = episode_num + 1

    trajectories_lens = [len(v["states"]) for k, v in trajectories.items()]

    for t in trajectories:
        trajectories[t]["states"] = np.stack(trajectories[t]["states"], axis=0)
        trajectories[t]["actions"] = np.stack(trajectories[t]["actions"], axis=0)
        trajectories[t]["rewards"] = np.stack(trajectories[t]["rewards"], axis=0)

    return trajectories, trajectories_lens


In [7]:
trajectories,traj_lengths = segment(dataset["observations"],dataset["actions"],dataset["rewards"],dataset["dones"])
joined_transitions=[]

Segmenting:   0%|          | 0/8759 [00:00<?, ?it/s]

In [8]:
for t in tqdm(trajectories, desc="Joining transitions"):
    joined_transitions.append(
                    join_trajectory(trajectories[t]["states"], trajectories[t]["actions"], trajectories[t]["rewards"],discount = 0.99)
                )

Joining transitions:   0%|          | 0/1 [00:00<?, ?it/s]

Discount 0.99
(8759, 44)
(8759, 5)
(8759, 1)
(8759, 1)


In [9]:
num_bins = 100
strategy = "uniform"
discretizer = KBinsDiscretizer(
            np.concatenate(joined_transitions, axis=0),
            num_bins=num_bins,
            strategy=strategy
        )

In [10]:
discretizer

<trajectory.utils.discretization.KBinsDiscretizer at 0x7fd0156dc100>

In [11]:
class DiscretizedDataset(Dataset):
    def __init__(self, dataset,env_name="city_learn", num_bins=100, seq_len=10, discount=0.99, strategy="uniform", cache_path="data"):
        self.seq_len = seq_len
        self.discount = discount
        self.num_bins = num_bins
        self.dataset = dataset
        self.env_name = env_name
        
        trajectories, traj_lengths = segment(self.dataset["observations"],self.dataset["actions"],self.dataset["rewards"],self.dataset["dones"])
        self.trajectories = trajectories
        self.traj_lengths = traj_lengths
        self.cache_path = cache_path
        self.cache_name = f"{env_name}_{num_bins}_{seq_len}_{strategy}_{discount}"
        
        self.joined_transitions = []
        for t in tqdm(trajectories, desc="Joining transitions"):
            self.joined_transitions.append(
                    join_trajectory(trajectories[t]["states"], trajectories[t]["actions"], trajectories[t]["rewards"],discount = self.discount)
                )
        """
        if cache_path is None or not os.path.exists(os.path.join(cache_path, self.cache_name)):
            self.joined_transitions = []
            for t in tqdm(trajectories, desc="Joining transitions"):
                self.joined_transitions.append(
                    join_trajectory(trajectories[t]["states"], trajectories[t]["actions"], trajectories[t]["rewards"],discount = self.discount)
                )

            os.makedirs(os.path.join(cache_path), exist_ok=True)
            # save cached version
            with open(os.path.join(cache_path, self.cache_name), "wb") as f:
                pickle.dump(self.joined_transitions, f)
        else:
            with open(os.path.join(cache_path, self.cache_name), "rb") as f:
                self.joined_transitions = pickle.load(f)
        """

        self.discretizer = KBinsDiscretizer(
            np.concatenate(self.joined_transitions, axis=0),
            num_bins=num_bins,
            strategy=strategy
        )

        # get valid indices for seq_len sampling
        indices = []
        for path_ind, length in enumerate(traj_lengths):
            end = length - 1
            for i in range(end):
                indices.append((path_ind, i, i + self.seq_len))
        self.indices = np.array(indices)

    def get_env_name(self):
        return self.env.name

    def get_discretizer(self):
        return self.discretizer

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        #print(idx)
        traj_idx, start_idx, end_idx = self.indices[idx]
        
        joined = self.joined_transitions[traj_idx][start_idx:end_idx]
        

        loss_pad_mask = np.ones((self.seq_len, joined.shape[-1]))
        if joined.shape[0] < self.seq_len:
            # pad to seq_len if at the end of trajectory, mask for padding
            loss_pad_mask[joined.shape[0]:] = 0
            joined = pad_along_axis(joined, pad_to=self.seq_len, axis=0)

        joined_discrete = self.discretizer.encode(joined).reshape(-1).astype(np.longlong)
        loss_pad_mask = loss_pad_mask.reshape(-1)

        return joined_discrete[:-1], joined_discrete[1:], loss_pad_mask[:-1]


In [12]:
config = OmegaConf.load("configs/medium/city_learn.yaml")
wandb.init(
        **config.wandb,
        config=dict(OmegaConf.to_container(config, resolve=True))
    )
device = "cuda:0"

  return LooseVersion(v) >= LooseVersion(check)


In [13]:
datasets = DiscretizedDataset(dataset,discount = 0.99)

Segmenting:   0%|          | 0/8759 [00:00<?, ?it/s]

Joining transitions:   0%|          | 0/1 [00:00<?, ?it/s]

Discount 0.99
(8759, 44)
(8759, 5)
(8759, 1)
(8759, 1)


In [14]:
datasets.joined_transitions[0][0].shape

(51,)

In [15]:
batch_size = 1
dataloader = DataLoader(datasets, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True)


In [46]:
device = "cuda:0"

In [47]:
for i, batch in enumerate(tqdm(dataloader, desc="Epoch", leave=False)):
    batch = [b.to(device) for b in batch]
    break

Epoch:   0%|          | 0/8758 [00:00<?, ?it/s]

In [48]:
tokens, targets, loss_pad_mask = batch

### Training Part

In [22]:
path = "configs/medium/city_learn.yaml"
config = OmegaConf.load("configs/medium/city_learn_traj.yaml")
trainer_conf = config.trainer
data_conf = config.dataset

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7fcf8c8fa430>
Traceback (most recent call last):
  File "/home/ml-stud15/anaconda3/envs/stable3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1479, in __del__
    self._shutdown_workers()
  File "/home/ml-stud15/anaconda3/envs/stable3/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1443, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/home/ml-stud15/anaconda3/envs/stable3/lib/python3.9/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/home/ml-stud15/anaconda3/envs/stable3/lib/python3.9/multiprocessing/popen_fork.py", line 40, in wait
    if not wait([self.sentinel], timeout):
  File "/home/ml-stud15/anaconda3/envs/stable3/lib/python3.9/multiprocessing/connection.py", line 936, in wait
    ready = selector.select(timeout)
  File "/home/ml-stud15/anaconda3/envs/stable3/lib/python3.9/selectors.py", lin

In [52]:
tok_embd = nn.Embedding(100 * 51, 128).to('cuda:0')


In [59]:
model.pos_emb.size()

torch.Size([1, 330, 128])

In [53]:
tokens

tensor([[93, 58, 98, 28, 18, 47, 22, 65, 71, 48, 74,  0,  1, 53,  0,  0,  4, 79,
          0,  8,  0,  0,  0,  0,  7,  0, 10, 46, 12,  0,  0, 43,  0,  0,  6, 36,
          5,  0, 42, 44,  8,  0, 54, 45, 57, 55, 49, 50, 49, 83, 60, 93, 58, 93,
         27, 28, 47, 22, 63, 60, 48, 68,  0, 11, 44,  0,  0, 45, 71,  0,  8,  0,
          0,  0,  0,  6,  0, 10, 45,  8,  0,  0, 44,  0,  0,  6, 36,  5,  0, 43,
         44, 10,  0, 54, 46, 57, 55, 49, 50, 49, 84, 60, 93, 58, 85, 22, 37, 43,
         20, 65, 53, 61, 71,  0, 28, 33,  0,  0, 71, 73,  0,  8,  0,  0,  0,  0,
          7,  0,  9, 46,  6,  0,  0, 41,  0,  0,  5, 36,  5,  0, 43, 44, 10,  0,
         54, 46, 56, 56, 49, 50, 34, 84, 59, 93, 58, 75, 22, 41, 41, 20, 65, 61,
         70, 68,  0, 43, 17,  0,  0, 82, 56,  0, 10,  0,  0, 87,  0,  8,  0,  8,
         46,  5,  0,  0, 40,  0,  0,  5, 36,  6,  0, 44, 44, 32,  0, 35, 47, 55,
         56, 49, 50, 39, 84, 59, 93, 58, 62, 20, 47, 37, 16, 65, 51, 82, 77,  0,
         54,  3,  0,  0, 87,

In [77]:
token_embeddings = tok_embd(tokens)

In [23]:
model = GPT(**config.model)
model.to(device)

GPT(
  (tok_emb): Embedding(3300, 128)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (blocks): ModuleList(
    (0-3): 4 x TransformerBlock(
      (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
      (drop): Dropout(p=0.1, inplace=False)
      (attention): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
      )
      (mlp): Sequential(
        (0): Linear(in_features=128, out_features=512, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=512, out_features=128, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  (head): EinLinear(n_models=33, in_features=128, out_features=100, bias=False)
)

In [24]:
tokens, targets, loss_pad_mask = batch

In [82]:
class TransformerBlock(nn.Module):
    """ Pre-norm transformer block """
    def __init__(self, transition_dim, seq_len, embedding_dim, num_heads, attention_dropout, residual_dropout):
        super().__init__()
        self.seq_len = seq_len
        self.norm1 = nn.LayerNorm(embedding_dim).to('cuda:0')
        self.norm2 = nn.LayerNorm(embedding_dim).to('cuda:0')
        self.drop = nn.Dropout(residual_dropout).to('cuda:0')

        self.attention = nn.MultiheadAttention(embedding_dim, num_heads, attention_dropout, batch_first=True).to('cuda:0')
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim).to('cuda:0'),
            nn.GELU(),
            nn.Linear(4 * embedding_dim, embedding_dim).to('cuda:0'),
            nn.Dropout(residual_dropout).to('cuda:0'),
        )
        # True value indicates that the corresponding position is not allowed to attend
        self.register_buffer("attn_mask", ~torch.tril(torch.ones(seq_len, seq_len, device='cuda:0')).to(bool))
        # mask out previous value estimates (as they have information about future)
        self.attn_mask[:, transition_dim - 1::transition_dim] = True

    def forward(self, x, state=None, attn_pad_mask=None):
        # state is a previous input to this layer
        
        x_norm = self.norm1(x)
        #print(x.shape)

        if state is None:
            # if context_len < seq_len
            
            attn_mask = self.attn_mask[:x.shape[1], :x.shape[1]]
            q, k, v = x_norm, x_norm, x_norm
        else:
            state = state.to('cuda:0')
            assert x.size(1) == 1, f'when using memory input should be 1-time-step tensor, got {x.size(1)} timesteps.'
            assert state.shape[1] + 1 <= self.seq_len, f"{state.shape[1] + 1}"

            attn_mask = None
            q, k, v = x_norm, torch.cat([state, x_norm], dim=1), torch.cat([state, x_norm], dim=1)

        new_state = k
        x = x + self.drop(self.attention(q, k, v, attn_mask=attn_mask, key_padding_mask=attn_pad_mask, need_weights=False)[0])
        x = x + self.mlp(self.norm2(x))

        return x, new_state


In [87]:
attn_mask = ~torch.tril(torch.ones(509, 509)).to(bool)

In [88]:
attn_mask

tensor([[False,  True,  True,  ...,  True,  True,  True],
        [False, False,  True,  ...,  True,  True,  True],
        [False, False, False,  ...,  True,  True,  True],
        ...,
        [False, False, False,  ..., False,  True,  True],
        [False, False, False,  ..., False, False,  True],
        [False, False, False,  ..., False, False, False]])

In [91]:
block = TransformerBlock(51,509,128,1,0.1,0.1)

In [93]:
after_block = block(token_embeddings)

Here


In [95]:
after_block[0].size()

torch.Size([1, 509, 128])