In [72]:
from datasets import load_dataset

# Download a published Decision Transformer replay dataset (halfcheetah-expert-v2).
# In this notebook it is only inspected further down (one `dones` column is printed);
# presumably it serves as a format reference for the self-generated trajectories.
dataset_template = load_dataset("edbeeching/decision_transformer_gym_replay", "halfcheetah-expert-v2")

In [19]:
# Base libraries
from collections import namedtuple, deque
from dataclasses import dataclass
from typing import List, Tuple, Dict, Callable, Any
import numpy as np
import pandas as pd
import random

# ML libraries
import torch

# Local imports
from board import ConnectFourField
from env import Env
from agents.random_agent import RandomAgent
from agents.deep_q_agent import DeepQAgent
from agents.cql_agent import CQLAgent
import utils

# Fix random seed
# NOTE(review): seed_everything() lives in the local `utils` module (not shown here);
# presumably it seeds `random`, NumPy and torch with a default seed — confirm.
utils.seed_everything()

In [None]:
from transformers import DecisionTransformerConfig, DecisionTransformerModel, Trainer, TrainingArguments

In [49]:
# One stored record. Fields can be read by name or by position (0..4).
Transition = namedtuple('Trajectory', ['length', 'states', 'actions', 'rewards', 'RTGs'])


class Memory:
    """Bounded FIFO buffer of ``Transition`` records.

    Once ``max_capacity`` records are stored, pushing a new one evicts the
    oldest one.
    """

    def __init__(self, max_capacity):
        # A deque with ``maxlen`` drops items from the left when it overflows.
        self.memory = deque(maxlen=max_capacity)

    def push(self, *args):
        """Store one record; ``args`` are the five ``Transition`` fields in order."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size: int, split_transitions: bool = False):
        """Draw ``batch_size`` distinct records uniformly at random.

        Args:
            batch_size: Number of records to draw (without replacement).
            split_transitions: When False, return the drawn ``Transition``
                records as a list. When True, return one float32 tensor per
                field, in field order
                ``[length, states, actions, rewards, RTGs]``.

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored records
                (raised by ``random.sample``).
        """
        drawn = random.sample(self.memory, batch_size)
        if not split_transitions:
            return drawn

        # Transpose "list of records" into "one list per field".
        n_fields = len(Transition._fields)
        columns = [[record[k] for record in drawn] for k in range(n_fields)]
        return [torch.tensor(column, dtype=torch.float32) for column in columns]

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

    def reset(self):
        """Drop every stored record."""
        self.memory.clear()


In [50]:
# TODO: Move this helper into utils or into the corresponding agent module.

def calculate_RTG(rewards):
    """Compute the undiscounted return-to-go for every timestep of an episode.

    ``RTGs[i]`` is the sum of all rewards from timestep ``i`` (inclusive) to
    the end of the episode.

    Args:
        rewards: Sequence of per-step rewards (anything supporting ``len`` and
            integer indexing, e.g. a list, tuple or 1-D array). Not modified.

    Returns:
        A new list with the same length as ``rewards``; empty for empty input.

    Note:
        The totals are accumulated in a single backward pass (O(n)) instead of
        re-summing the tail for every index (O(n^2)). For floating-point
        rewards the additions happen in a different order, so results may
        differ in the last bits; integer rewards are unaffected.
    """
    RTGs = [0] * len(rewards)
    running_total = 0
    # Walk from the last step to the first, carrying the tail sum along.
    for i in range(len(rewards) - 1, -1, -1):
        running_total = rewards[i] + running_total
        RTGs[i] = running_total
    return RTGs
    


In [93]:
# Generate T_offline (offline trajectories for the offline initialization of the replay buffer)
# by letting two random agents play NUM_TRAJ games against each other. Only the moves of
# `agent` are recorded; the opponent's moves merely advance the environment.
# NOTE: an alternative would be to initialize with random trajectories, or with
#       trajectories from the DQN agent / CQL agent.
# NOTE: Corresponding paper: https://arxiv.org/pdf/2202.05607.pdf

NUM_TRAJ = 50
#memory = Memory(max_capacity=NUM_TRAJ)
offline_trajectories = []

# Both agents are constructed with the same Env instance.
env = Env()
agent = RandomAgent(env)
opponent = RandomAgent(env)

# Player identifiers passed as the second argument of env.step()

AGENT = 1
OPPONENT = 2

for i in range(0, NUM_TRAJ):

    # `finished` stays -1 while the game is running; any other value ends the episode
    finished = -1

    # Per-episode buffers of the agent's trajectory
    states = []
    actions = []
    rewards = []
    dones = []  # placeholder; rebuilt from len(rewards) once the episode is over


    # Make it random who gets to start the game.
    # If the opponent starts, the flag is flipped to True after skipping the agent's
    # first turn, so the agent moves in every later loop iteration.
    agent_start = random.choice([True, False])
    # Run one episode of the game
    while finished == -1:
        # Agent makes a turn
        if agent_start:
            state = env.get_state()
            action = agent.act(state)
            # NOTE(review): `valid` is never checked — presumably RandomAgent only picks
            # legal moves; confirm.
            valid, reward, finished = agent.env.step(action, AGENT)
            
            # Update current trajectory with the state *before* the move
            # Flatten the state to be a 42-entry 1 dimensional array
            states.append(np.ravel(state))
            actions.append(action)
            rewards.append(reward)

            if finished != -1:
                break

        else:
            agent_start = True

        # Opponent makes their turn (acting on the inverted state)
        action = opponent.act(env.get_state_inverted())
        # NOTE(review): the reward returned here is discarded, so when the opponent's move
        # ends the game the agent's last recorded reward is the one from its own previous
        # move — confirm that this is the intended reward signal.
        valid, reward, finished = opponent.env.step(action, OPPONENT)

        if finished != -1: 
            break

    # The done flag is False for the first n-1 recorded steps and True for the last one.
    dones = ([False] * (len(rewards)-1)) + [True]
    #env.render_pretty()
    env.reset()
    


    # All per-step lists must have the same length
    assert len(states) == len(actions)
    assert len(actions) == len(rewards)
    assert len(dones) == len(rewards)
    length = len(states)
    RTGs = calculate_RTG(rewards)
    # Positional trajectory layout relied upon by the cells below:
    # [length, states, actions, rewards, RTGs, dones]
    traj = [length, states, actions, rewards, RTGs, dones]


    offline_trajectories.append(traj)
    
# Sort offline buffer such that the order is descending in RTGs.
# x[4][0] is the first RTG of the trajectory, i.e. the sum of all its rewards.
sorted_offline_trajectories = sorted(offline_trajectories, key=lambda x: x[4][0], reverse = True)





In [94]:
# Build the training set for the Decision Transformer from the offline data,
# following this tutorial: https://huggingface.co/blog/train-decision-transformers
# `sorted_offline_trajectories` is ordered by descending initial return-to-go,
# so slicing keeps the N best trajectories.
N = 10
dataset = sorted_offline_trajectories[:N]
lengths = [trajectory[0] for trajectory in dataset]  # index 0 stores the trajectory length
max_episode_length = max(lengths)
print(f"Maximum Episode length: {max_episode_length}")

# Positional layout of one stored trajectory:
#   traj = [length, states, actions, rewards, RTGs, dones]
LENGTH = 0
STATES = 1
ACTIONS = 2
REWARDS = 3
RTGs = 4  # NOTE: rebinds the global `RTGs` that the trajectory-generation cell uses for a list
DONES = 5


class DecisionTransformerGymDataCollator:
    """Turn stored trajectories into padded training batches for a Decision Transformer.

    The collator ignores the *content* of the features handed over by the data
    loader and only uses their count as the batch size. For each batch element
    it draws a random trajectory from ``dataset`` and a random start step, cuts
    a window of at most ``max_len`` steps and left-pads it to exactly
    ``max_len``.

    Each trajectory must be indexable as
    ``[length, states, actions, rewards, RTGs, dones]`` and be non-empty.
    """

    return_tensors: str = "pt"
    max_len: int = 20  # length of the sub-sequences of an episode used for training
    state_dim: int = 42  # size of state space
    act_dim: int = 1  # size of action space
    # Largest timestep index + 1; larger timesteps are clamped to max_ep_len - 1.
    # TODO(review): a single player presumably makes at most 21 moves on a 42-cell
    # board, so 42 looks like a safe upper bound that could be tightened — confirm.
    max_ep_len: int = 42
    # Divisor applied to the returns-to-go. The tutorial value (1000.0) was commented
    # out while `__call__` still divided by `self.scale`, which raised AttributeError;
    # 1.0 keeps the returns unscaled.
    scale: float = 1.0
    n_traj: int = 0  # number of trajectories in the dataset

    def __init__(self, dataset) -> None:
        """Store the trajectory list to sample from.

        Args:
            dataset: Sequence of trajectories in the positional layout above.
        """
        self.act_dim = 1
        self.state_dim = 42
        self.dataset = dataset
        self.n_traj = len(dataset)

    def _discount_cumsum(self, x, gamma):
        """Return the discounted reverse cumulative sum of ``x``.

        ``out[t] = x[t] + gamma * out[t + 1]`` with ``out[-1] = x[-1]``.

        The computation is done in float64 so that integer rewards are not
        truncated when ``gamma`` is fractional. Empty input yields an empty
        array.
        """
        values = np.asarray(x, dtype=np.float64)
        discount_cumsum = np.zeros_like(values)
        if values.shape[0] == 0:
            return discount_cumsum
        discount_cumsum[-1] = values[-1]
        for t in reversed(range(values.shape[0] - 1)):
            discount_cumsum[t] = values[t] + gamma * discount_cumsum[t + 1]
        return discount_cumsum

    def __call__(self, features):
        """Assemble one batch.

        Args:
            features: Batch handed over by the data loader; only ``len(features)``
                is used.

        Returns:
            Dict of tensors, each with leading shape ``(batch, max_len)``:
            ``states`` (..., state_dim), ``actions`` (..., act_dim),
            ``rewards`` (..., 1), ``returns_to_go`` (..., 1), all float32;
            ``timesteps`` (int64) and ``attention_mask`` (float32, 1 for real
            steps and 0 for left padding).
        """
        batch_size = len(features)
        # this is a bit of a hack to be able to sample from a non-uniform distribution
        batch_inds = np.random.choice(np.arange(self.n_traj), size=batch_size, replace=True)

        # The stored `dones` are intentionally not batched: they were never
        # part of the returned model inputs.
        s, a, r, rtg, timesteps, mask = [], [], [], [], [], []

        for ind in batch_inds:
            feature = self.dataset[int(ind)]
            # Random start step within the trajectory (inclusive on both ends).
            si = random.randint(0, len(feature[REWARDS]) - 1)

            # Window of at most max_len steps starting at si.
            s.append(np.array(feature[STATES][si : si + self.max_len]).reshape(1, -1, self.state_dim))
            a.append(np.array(feature[ACTIONS][si : si + self.max_len]).reshape(1, -1, self.act_dim))
            r.append(np.array(feature[REWARDS][si : si + self.max_len]).reshape(1, -1, 1))

            tlen = s[-1].shape[1]
            pad = self.max_len - tlen

            timesteps.append(np.arange(si, si + tlen).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= self.max_ep_len] = self.max_ep_len - 1  # padding cutoff

            # Undiscounted return-to-go from si to the end of the episode, cut to the
            # window. It always has exactly `tlen` entries, so no extra padding step
            # is needed here.
            rtg.append(
                self._discount_cumsum(np.array(feature[REWARDS][si:]), gamma=1.0)[:tlen].reshape(1, -1, 1)
            )

            # Left-pad every sequence to max_len; padded actions get the sentinel -10.
            s[-1] = np.concatenate([np.zeros((1, pad, self.state_dim)), s[-1]], axis=1)
            a[-1] = np.concatenate([np.ones((1, pad, self.act_dim)) * -10.0, a[-1]], axis=1)
            r[-1] = np.concatenate([np.zeros((1, pad, 1)), r[-1]], axis=1)
            rtg[-1] = np.concatenate([np.zeros((1, pad, 1)), rtg[-1]], axis=1) / self.scale
            timesteps[-1] = np.concatenate([np.zeros((1, pad)), timesteps[-1]], axis=1)
            mask.append(np.concatenate([np.zeros((1, pad)), np.ones((1, tlen))], axis=1))

        s = torch.from_numpy(np.concatenate(s, axis=0)).float()
        a = torch.from_numpy(np.concatenate(a, axis=0)).float()
        r = torch.from_numpy(np.concatenate(r, axis=0)).float()
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).float()
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).long()
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).float()

        return {
            "states": s,
            "actions": a,
            "rewards": r,
            "returns_to_go": rtg,
            "timesteps": timesteps,
            "attention_mask": mask,
        }

Maximum Episode length: 20


In [83]:
# Print the `dones` column of the episode at index 4 of the reference dataset's train
# split — presumably to compare its format with the self-generated `dones` lists.
d = dataset_template['train']
print(d[4]['dones'])

[False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False

In [95]:
class TrainableDT(DecisionTransformerModel):
    """Decision Transformer whose ``forward`` returns a training loss.

    ``forward`` runs the parent model and reduces its action predictions to a
    mean-squared-error loss against the ground-truth actions, counting only
    positions whose attention mask is positive. The unmodified model output is
    still available through ``original_forward``.
    """

    def __init__(self, config):
        super().__init__(config)

    def forward(self, **kwargs):
        """Return ``{"loss": mse}`` for a batch.

        ``kwargs`` are passed through to the parent ``forward`` and must
        contain ``"actions"`` and ``"attention_mask"``.
        """
        model_output = super().forward(**kwargs)

        # Element 1 of the parent's output is treated as the action predictions.
        predicted = model_output[1]
        targets = kwargs["actions"]
        mask = kwargs["attention_mask"]

        # Flatten (batch, seq) into one axis and drop the masked positions.
        n_act = predicted.shape[2]
        predicted = predicted.reshape(-1, n_act)[mask.reshape(-1) > 0]
        targets = targets.reshape(-1, n_act)[mask.reshape(-1) > 0]

        squared_error = (predicted - targets) ** 2
        return {"loss": torch.mean(squared_error)}

    def original_forward(self, **kwargs):
        """Run the parent model's ``forward`` unchanged and return its output."""
        return super().forward(**kwargs)

In [104]:
# Hyper-parameters for the Hugging Face Trainer.
# NOTE: Trainer with PyTorch needs the `accelerate` package (>= 0.20.1) to be installed.
training_args = TrainingArguments(
    output_dir="output/",
    # The collator builds the model inputs itself, so the Trainer must not try to
    # drop "unused" dataset columns.
    remove_unused_columns=False,
    num_train_epochs=120,
    per_device_train_batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-4,
    warmup_ratio=0.1,
    optim="adamw_torch",
    max_grad_norm=0.25,
)

collator = DecisionTransformerGymDataCollator(dataset)

config = DecisionTransformerConfig(state_dim=collator.state_dim, act_dim=collator.act_dim)
model = TrainableDT(config)

trainer = Trainer(
    model=model,
    args=training_args,
    # `dataset` is a plain list of trajectories (not a DatasetDict), so there is no
    # "train" split to index; `dataset["train"]` raised a TypeError. The list is only
    # used to determine how many samples make up an epoch — the collator draws the
    # actual trajectories itself.
    train_dataset=dataset,
    data_collator=collator,
)

trainer.train()

ImportError: Using the `Trainer` with `PyTorch` requires `accelerate>=0.20.1`: Please run `pip install transformers[torch]` or `pip install accelerate -U`