In [21]:
import torch
import numpy as np
from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments
from torch.utils.data import Dataset
from dataclasses import dataclass
from datasets import Dataset, load_from_disk, load_dataset
import pickle
import random

In [22]:
# dataset = load_dataset("edbeeching/decision_transformer_gym_replay", "halfcheetah-expert-v2")

In [23]:
# Absolute (machine-specific) path to the pickled reward map loaded in the next cell.
rewardmap_path = (
    '/home/moonlab/decision_transformer/Active-sampling-multi-robot-learning/'
    'Main/trainingData/gaussian_mixture_training_data.pkl'
)

In [24]:
# Choose the GPU when available, then load the reward map onto that device as float32.
# NOTE(review): pickle.load can execute arbitrary code — only load files you trust.
# NOTE(review): a later cell reassigns `device = "cpu"`, but `rewardmap` stays on the
# device chosen here; confirm that mismatch is intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# `with` guarantees the file handle is closed (the bare open() inside pickle.load leaked it).
with open(rewardmap_path, 'rb') as rewardmap_file:
    # encoding='latin1' is the setting pickle documents for reading Python-2 pickles of numpy data
    rewardmap = pickle.load(rewardmap_file, encoding='latin1')
rewardmap = torch.tensor(rewardmap, dtype=torch.float32).to(device)

In [25]:
# Experiment settings. This overrides the CUDA/CPU choice made in the previous cell.
device = "cpu"
# NOTE(review): num_traj is not used by any cell visible here — confirm it is still needed.
num_robot, num_traj = 3, 10

Loading the Dataset

In [26]:
# Load the trajectory DatasetDict written earlier with `save_to_disk`.
# NOTE(review): this root (/home/moonlab/Active-sampling-...) differs from the one used by
# rewardmap_path (/home/moonlab/decision_transformer/Active-sampling-...) — confirm both are intended.
dataset = load_from_disk(
    '/home/moonlab/Active-sampling-multi-robot-learning/'
    'Main/dt'
)

In [27]:
# Inspect the loaded splits, their feature columns and row counts.
dataset

DatasetDict({
    train: Dataset({
        features: ['states', 'actions', 'rewards', 'returns_to_go', 'timesteps'],
        num_rows: 10
    })
})

In [28]:
@dataclass
class DTdatacollator:
    """Collate Decision Transformer trajectories into a batch of tensors for ``Trainer``.

    Every batch item is a dict holding per-timestep sequences under the keys
    ``states``, ``actions``, ``rewards``, ``returns_to_go`` and ``timesteps``.
    Trajectories shorter than the longest one in the batch are left-padded with
    zeros, and ``attention_mask`` is 0 on padded steps and 1 on real steps, so a
    loss masked with ``attention_mask > 0`` ignores the padding.

    All tensors are returned on the CPU. ``Trainer`` moves each batch to the
    training device itself, so the collator no longer reads a global ``device``.
    """

    return_tensors: str = "pt"
    state_dim: int = num_robot*2  # size of state space (overwritten from the dataset in __init__)
    act_dim: int = num_robot*2  # size of action space (overwritten from the dataset in __init__)

    def __init__(self, dataset):
        """Read the state/action widths from the first step of the first trajectory.

        Args:
            dataset: indexable collection of trajectory dicts, e.g. ``dataset["train"]``.
        """
        self.act_dim = len(dataset[0]["actions"][0])
        self.state_dim = len(dataset[0]["states"][0])
        self.dataset = dataset

    # def _discount_cumsum(self, x, gamma):
    #     discount_cumsum = np.zeros_like(x)
    #     discount_cumsum[-1] = x[-1]
    #     for t in reversed(range(x.shape[0] - 1)):
    #         discount_cumsum[t] = x[t] + gamma * discount_cumsum[t + 1]
    #     return discount_cumsum

    @staticmethod
    def _left_pad_stack(sequences, max_len, dtype):
        """Zero-pad each sequence on the left (along axis 0) to ``max_len`` and stack them."""
        padded = []
        for seq in sequences:
            arr = np.asarray(seq, dtype=dtype)
            pad = np.zeros((max_len - arr.shape[0],) + arr.shape[1:], dtype=dtype)
            padded.append(np.concatenate([pad, arr], axis=0))
        return np.stack(padded)

    def __call__(self, batch):
        """Turn a list of trajectory dicts into keyword arguments for the model's ``forward``.

        Args:
            batch: non-empty list of trajectory dicts.

        Returns:
            dict with float32 ``states``, ``actions``, ``rewards`` and ``returns_to_go``,
            int64 ``timesteps``, and a float32 ``attention_mask`` of shape
            ``(batch_size, max_len)`` where ``max_len`` is the longest trajectory in the batch.

        Raises:
            ValueError: if ``batch`` is empty.
        """
        if len(batch) == 0:
            raise ValueError("DTdatacollator cannot collate an empty batch")

        # Trajectory lengths are taken from `states`; other keys are assumed to match.
        lengths = [len(item["states"]) for item in batch]
        max_len = max(lengths)

        def collate(key, dtype):
            sequences = [item[key] for item in batch]
            return torch.from_numpy(self._left_pad_stack(sequences, max_len, dtype))

        # Left padding keeps the real steps at the end of each row; mask out the padded prefix.
        attention_mask = torch.from_numpy(
            np.stack([
                np.concatenate([np.zeros(max_len - n, dtype=np.float32),
                                np.ones(n, dtype=np.float32)])
                for n in lengths
            ])
        )

        return {
            "states": collate("states", np.float32),
            "actions": collate("actions", np.float32),
            "rewards": collate("rewards", np.float32),
            "returns_to_go": collate("returns_to_go", np.float32),
            # kept integral — DecisionTransformerModel presumably uses timesteps as embedding indices
            "timesteps": collate("timesteps", np.int64),
            "attention_mask": attention_mask,
        }
    

In [29]:
# class DTdataset(Dataset):
#     def __init__(self, states, actions, rewards, returns, timesteps):
#         self.states = states
#         self.actions = actions
#         self.rewards = rewards
#         self.returns = returns
#         self.timesteps = timesteps

#     def __getitem__(self, index):
#         return self.states[index], self.actions[index], self.rewards[index], self.returns[index], self.timesteps[index]

#     def __len__(self):
#         return len(self.states)

In [30]:
# # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
# class Collation:
#     def __init__(self, states, actions, rewards, returns_to_go, timesteps):
#         self.states = states
#         self.actions = actions
#         self.rewards = rewards
#         self.returns_to_go = returns_to_go
#         self.timesteps = timesteps
#         self.state_dim = states.shape[2]
#         self.act_dim = actions.shape[2]
#         self.device = device

#     def __call__(self, batch):
#         print('batch', len(batch),len(batch[0]))
#         states = self.states
#         actions = self.actions
#         rewards = self.rewards
#         returns_to_go = self.returns_to_go
#         timesteps = self.timesteps
#         attention_mask = torch.ones((states.shape[0], states.shape[1]), dtype=torch.float32).to(device)
#         return {
#             'states': states,
#             'actions': actions,
#             'rewards': rewards,
#             'returns_to_go': returns_to_go,
#             'timesteps': timesteps,
#             'attention_mask': attention_mask
#         }

In [31]:
# collator = Collation(dataset['train'].features['states'], dataset['train'].features['actions'], dataset.['train'].features['rewards'], returns_to_go, timesteps)
# collator([0, 1, 2, 3, 4]).keys()

In [32]:
class TrainableDT(DecisionTransformerModel):
    """DecisionTransformerModel whose ``forward`` returns a loss dict usable by ``Trainer``.

    The loss is the mean squared error between predicted and target actions,
    computed only over timesteps whose ``attention_mask`` entry is positive.
    ``original_forward`` exposes the unmodified base-model forward for inference.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, **kwargs):
        """Run the base model and return ``{"loss": masked action MSE}``.

        Besides the base model's inputs, ``kwargs`` must contain ``actions``
        (the targets) and ``attention_mask``.
        """
        base_output = super().forward(**kwargs)
        predicted = base_output[1]  # second entry of the base output: the action predictions
        target = kwargs["actions"]
        keep = kwargs["attention_mask"].reshape(-1) > 0  # flattened mask over (batch * seq) steps
        width = predicted.shape[2]

        flat_pred = predicted.reshape(-1, width)[keep]
        flat_target = target.reshape(-1, width)[keep]
        squared_error = (flat_pred - flat_target) ** 2

        return {"loss": squared_error.mean()}

    def original_forward(self, **kwargs):
        """Call the unmodified ``DecisionTransformerModel.forward``."""
        return super().forward(**kwargs)

In [33]:

# The collator reads state/action widths from the training split;
# the model config is then built from those same widths.
collator = DTdatacollator(dataset["train"])
config = DecisionTransformerConfig(
    act_dim=collator.act_dim,
    state_dim=collator.state_dim,
)
model = TrainableDT(config)

In [34]:
training_args = TrainingArguments(
    output_dir="output/",
    # optimizer and regularisation
    optim="adamw_torch",
    learning_rate=3e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    max_grad_norm=0.25,
    # schedule and batching
    num_train_epochs=110,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,  # no eval_dataset is passed to Trainer below, so currently unused
    # keep every dataset column so Trainer does not drop fields before the custom collator sees them
    remove_unused_columns=False,
)
# Other accepted `optim` values:
# 'adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_torch_npu_fused',
# 'adamw_apex_fused', 'adafactor', 'adamw_anyprecision', 'sgd', 'adagrad', 'adamw_bnb_8bit',
# 'adamw_8bit', 'lion_8bit', 'lion_32bit', 'paged_adamw_32bit', 'paged_adamw_8bit',
# 'paged_lion_32bit', 'paged_lion_8bit', 'rmsprop'

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    data_collator=collator,
)

trainer.train()

 95%|█████████▌| 525/550 [00:02<00:00, 172.58it/s]

{'loss': 0.4621, 'grad_norm': 1.376784324645996, 'learning_rate': 3.03030303030303e-05, 'epoch': 100.0}


100%|██████████| 550/550 [00:03<00:00, 178.70it/s]

{'train_runtime': 3.0771, 'train_samples_per_second': 357.483, 'train_steps_per_second': 178.742, 'train_loss': 0.4435250490361994, 'epoch': 110.0}





TrainOutput(global_step=550, training_loss=0.4435250490361994, metrics={'train_runtime': 3.0771, 'train_samples_per_second': 357.483, 'train_steps_per_second': 178.742, 'total_flos': 2609474868000.0, 'train_loss': 0.4435250490361994, 'epoch': 110.0})

In [35]:
# Per-step buffers for a rollout of at most `max_ep_len` steps.
# Widths come from the model config (built from the dataset's state/action widths),
# so the buffers always match the inputs the model was constructed for, rather than
# assuming each robot contributes exactly 2 values.
max_ep_len = 100

pred_states = torch.zeros((max_ep_len, model.config.state_dim))
pred_actions = torch.zeros((max_ep_len, model.config.act_dim))
pred_rewards = torch.zeros((max_ep_len, 1))


In [36]:
# def get_action(model, states, actions, rewards, returns_to_go, timesteps):
#     # This implementation does not condition on past rewards

#     states = states.reshape(1, -1, model.config.state_dim)
#     actions = actions.reshape(1, -1, model.config.act_dim)
#     returns_to_go = returns_to_go.reshape(1, -1, 1)
#     timesteps = timesteps.reshape(1, -1)

#     states = states[:, -model.config.max_length :]
#     actions = actions[:, -model.config.max_length :]
#     returns_to_go = returns_to_go[:, -model.config.max_length :]
#     timesteps = timesteps[:, -model.config.max_length :]
#     padding = model.config.max_length - states.shape[1]
#     # pad all tokens to sequence length
#     attention_mask = torch.cat([torch.zeros(padding), torch.ones(states.shape[1])])
#     attention_mask = attention_mask.to(dtype=torch.long).reshape(1, -1)
#     states = torch.cat([torch.zeros((1, padding, model.config.state_dim)), states], dim=1).float()
#     actions = torch.cat([torch.zeros((1, padding, model.config.act_dim)), actions], dim=1).float()
#     returns_to_go = torch.cat([torch.zeros((1, padding, 1)), returns_to_go], dim=1).float()
#     timesteps = torch.cat([torch.zeros((1, padding), dtype=torch.long), timesteps], dim=1)

#     state_preds, action_preds, return_preds = model.original_forward(
#         states=states,
#         actions=actions,
#         rewards=rewards,
#         returns_to_go=returns_to_go,
#         timesteps=timesteps,
#         attention_mask=attention_mask,
#         return_dict=False,
#     )

#     return action_preds[0, -1]

In [37]:
# torch.cuda.empty_cache()