In [1]:
# this notebook will use a basic GPT based decision transformer in offline reinforcement learning setting to create bot for trading stock
# get cuda device
# import libraries
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from cust_transf import DecisionTransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets.load import load_dataset
from torch.utils.data import Dataset, DataLoader
import numpy as np

# utility function to compute the discounted cumulative sum of a vector
def discount_cumsum(x, gamma):
    disc_cumsum = np.zeros_like(x)
    disc_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0]-1)):
        disc_cumsum[t] = x[t] + gamma * disc_cumsum[t+1]
    return disc_cumsum

# define a custom dataset class which loads the data, modifies the reward to be the discounted cumulative sum and apply trajectory masking
class CustomTrajDataset(Dataset):
    def __init__(self, file_name, context_len, gamma, rtg_scale):
        self.gamma = gamma
        self.context_len = context_len

        # load the data
        data = load_dataset("json", data_files = file_name, field = 'data')
        self.data_state = np.array(data['train']['state'], dtype=np.float32)
        self.data_action = np.array(data['train']['action'], dtype=np.float32)
        self.rtg = np.array(data['train']['reward'], dtype=np.float32)


        # calculate min len, the mean and std of the state and rtg for all data
        self.stateshape = self.data_state.shape
        # calculate mean of state and rtg with numpy
        self.state_mean = np.mean(self.data_state, axis=(-2,-1), keepdims=True)
        self.state_std = np.std(np.abs(self.data_state), axis=(-2,-1), keepdims=True)
        #self.state_mean = torch.mean(data['train']['state'], dim=(-2,-1), keepdim=True)
        #self.state_std = torch.std(data['train']['state'], dim=(-2,-1), keepdim=True)
        self.norm_state = (self.data_state - self.state_mean) / self.state_std

        self.rtg = np.apply_along_axis(discount_cumsum, 1, data['train']['reward'], self.gamma) # type: ignore
        self.rtg = self.rtg / rtg_scale

    def get_state_stats(self):
        return self.state_mean, self.state_std        

    def __len__(self):
        return self.stateshape[0]

    def __getitem__(self, idx):
        state = self.norm_state[idx]
        action = self.data_action[idx]
        rtg = self.rtg[idx]

        data_len = state.shape[0]
        
        if data_len > self.context_len:
            # sample random start index
            start_idx = np.random.randint(0, data_len - self.context_len)
            # slice the data and convert to torch
            state = torch.from_numpy(state[start_idx:start_idx+self.context_len])
            action = torch.from_numpy(action[start_idx:start_idx+self.context_len])
            rtg = torch.from_numpy(rtg[start_idx:start_idx+self.context_len])
            timesteps = torch.arange(start=start_idx, end=start_idx + self.context_len, step=1)
            # trajectory mask
            mask = torch.ones(self.context_len, dtype=torch.long)
        else:
            padding_len = self.context_len - data_len

            # pad the data with zeros
            state = torch.from_numpy(state)
            state = torch.cat([state, torch.zeros((padding_len, *state.shape[1:]))], dim=0)

            action = torch.from_numpy(action)
            action = torch.cat([action, torch.zeros((padding_len, *action.shape[1:]))], dim=0)

            rtg = torch.from_numpy(rtg)
            rtg = torch.cat([rtg, torch.zeros((padding_len, *rtg.shape[1:]))], dim=0)

            timesteps = torch.arange(start=0, end=self.context_len, step=1)

            # trajectory mask
            mask = torch.cat([torch.ones(data_len, dtype=torch.long), torch.zeros(padding_len, dtype=torch.long)], dim=0)
        
        return state, action, rtg, timesteps, mask


In [3]:
# load huggingface dataset from json file in replaybuffer folder
foldername = 'replaybuffer'

# get filenames in folder
import os
filenames = os.listdir(foldername)

# get full path of files
full_filenames = [os.path.join(foldername, filename) for filename in filenames]

# create datasets and store in list from the list of filenames 
context_len = 20
Max_balance = 2147483647
gamma = 0.99

datasets = []
for name in full_filenames:
    dataset = CustomTrajDataset(name, context_len, gamma, Max_balance)
    datasets.append(dataset)

Using custom data configuration default-0ced0e52e3f61f09
Found cached dataset json (/home/victoru/.cache/huggingface/datasets/json/default-0ced0e52e3f61f09/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)
100%|██████████| 1/1 [00:00<00:00, 114.79it/s]
Using custom data configuration default-c85f0ceac462d6aa
Found cached dataset json (/home/victoru/.cache/huggingface/datasets/json/default-c85f0ceac462d6aa/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)
100%|██████████| 1/1 [00:00<00:00, 812.85it/s]
Using custom data configuration default-332854d71f22288d
Found cached dataset json (/home/victoru/.cache/huggingface/datasets/json/default-332854d71f22288d/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)
100%|██████████| 1/1 [00:00<00:00, 579.24it/s]
Using custom data configuration default-9b529c6ea512e6e3
Found cached dataset json (/home/victoru/.cache/huggingface/datasets/json/default-9b529c6ea512e6e3/0.0.0/e6070c77f18f0

In [4]:
# define training parameters
batch_size = 32
# small learning rate to try to avoid mixed precision caused NaNs
lr = 3e-5
wt_decay = 1e-4
warmup_steps = 10000
n_epochs = 500

In [5]:
# create list of dataloaders from the list of datasets
dataloaders = []
for dataset in datasets:
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    dataloaders.append(dataloader)

In [6]:
# define model parameters
# sample 1 batch from dataloader
norm_state, actions, rtg, timestep, traj_mask = next(iter(dataloader))
# use batch shape to determine state dimension
state_dim = norm_state.shape[-1]
act_dim = actions.shape[-1] # discrete action space
# use batch shape to determine context length


n_blocks = 4 # number of transformer blocks
h_dim = 96 # hidden dimension
n_heads = 6 # number of heads in multi-head attention
drop_p = 0.1 # dropout probability


In [7]:
# create the model
model = DecisionTransformer(state_dim, act_dim, n_blocks, h_dim, context_len, n_heads, drop_p).to(device)

# create optimizer
# use larger eps to try to avoid mixed precision overflow caused NaNs
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wt_decay, eps=1e-6)

# create scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))

# create a GradScaler for mixed precision training
scaler = torch.cuda.amp.GradScaler(growth_interval=150)
min_scale = 128

In [8]:
# test run the model
with torch.no_grad():
    norm_state, actions, rtg, timestep , traj_mask= next(iter(dataloader))
    norm_state = norm_state.to(device)
    actions = actions.to(device)
    # convert rtg to float
    rtg = rtg.to(device).float()
    timestep = timestep.to(device)
    traj_mask = traj_mask.to(device)
    action_targets = torch.clone(actions).detach().to(device)
    return_preds, state_preds, act_preds = model.forward(norm_state, rtg, timestep, actions)

    # check shape of norm_state
    print(f"shape norm_state: {norm_state.shape}")
    # check shape of rtg
    print(f"shape rtg: {rtg.shape}")
    # check shape of timestep
    print(f"shape timestep: {timestep.shape}")
    # check shape of actions
    print(f"shape actions: {actions.shape}")
    print(f"shape act_preds: {act_preds.shape}")
    print(f"shape action_targets: {action_targets.shape}")
    
    # consider only the action that are padded
    act_preds = act_preds.view(-1, act_dim)[traj_mask.view(-1) > 0]
    action_targets = action_targets.view(-1, act_dim)[traj_mask.view(-1) > 0]

    # check shape of action targets
    print(action_targets.shape)
    # check shape of action predictions
    print(act_preds.shape)

# check for nan values and inf values in the input and the output of the model
print(torch.isnan(norm_state).any())
print(torch.isnan(rtg).any())
print(torch.isnan(timestep).any())
print(torch.isnan(actions).any())
print(torch.isnan(act_preds).any())
print(torch.isnan(action_targets).any())



shape norm_state: torch.Size([32, 20, 13])
shape rtg: torch.Size([32, 20, 1])
shape timestep: torch.Size([32, 20])
shape actions: torch.Size([32, 20, 2])
shape act_preds: torch.Size([32, 20, 2])
shape action_targets: torch.Size([32, 20, 2])
torch.Size([640, 2])
torch.Size([640, 2])
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')


In [9]:
from tqdm import tqdm

# get the start time to calculate training time
import datetime

# custom training function which take in the model, dataset, optimizer, scheduler, scaler, n_epochs, min_scale
def train_model(model, dataloader, optimizer, scheduler, scaler, n_epochs, min_scale):

    # record the start time
    start_time = datetime.datetime.now()

    # define training parameters
    log_action_losses = []

    # train model
    for epoch in range(n_epochs):
        model.train()
        

        for norm_state, actions, rtg, timestep, traj_mask in tqdm(dataloader):
            # get batch data to device
            norm_state = norm_state.to(device)
            actions = actions.to(device)
            rtg = rtg.to(device).float()
            timestep = timestep.to(device)
            traj_mask = traj_mask.to(device)

            action_targets = torch.clone(actions).detach().to(device)

            # Zeroes out the gradients
            optimizer.zero_grad()

            # run forward pass with autocasting
            # disable autocasting for now to avoid mixed precision caused NaNs
            with torch.cuda.amp.autocast(enabled=False):
                _, _, act_preds = model.forward(norm_state, rtg, timestep, actions)

                # consider only the action that are padded
                act_preds = act_preds.view(-1, act_dim)[traj_mask.view(-1) > 0]
                action_targets = action_targets.view(-1, act_dim)[traj_mask.view(-1) > 0]

                # calculate losses just for actions
                loss = F.mse_loss(act_preds, action_targets, reduction='mean')

            # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
            scaler.scale(loss).backward()

            # unscale the gradients
            scaler.unscale_(optimizer)
            # Clips the gradients by norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)

            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(optimizer)

            # Updates the learning rate according to the scheduler
            scheduler.step()
            # Updates the scale for next iteration.
            scaler.update()
            # enforce min scale to avoid mixed precision caused NaNs
            if scaler.get_scale() < min_scale:
                scaler._scale = torch.tensor(min_scale).to(scaler._scale)
        
            # append action loss to log
            log_action_losses.append(loss.detach().cpu().item())

        # print every 10 loss log
        if epoch % 100 == 0 or epoch == n_epochs - 1:
            print(f'Epoch {epoch}: Loss: {log_action_losses[-1]}')

    # record the end time
    end_time = datetime.datetime.now()
    print(f'Training time: {end_time - start_time}')
    
    return model, log_action_losses

In [10]:
# train model on each dataloader and store log_action_losses in a list
log_action_losses_list = []
for dataloader in dataloaders:
    _, log_action_losses = train_model(model, dataloader, optimizer, scheduler, scaler, n_epochs, min_scale)
    log_action_losses_list.append(log_action_losses)

100%|██████████| 1/1 [00:01<00:00,  1.15s/it]


Epoch 0: Loss: 0.6171316504478455


100%|██████████| 1/1 [00:00<00:00,  8.23it/s]
100%|██████████| 1/1 [00:00<00:00,  9.59it/s]
100%|██████████| 1/1 [00:00<00:00,  9.13it/s]
100%|██████████| 1/1 [00:00<00:00, 10.43it/s]
100%|██████████| 1/1 [00:00<00:00,  9.05it/s]
100%|██████████| 1/1 [00:00<00:00,  9.31it/s]
100%|██████████| 1/1 [00:00<00:00,  8.88it/s]
100%|██████████| 1/1 [00:00<00:00, 10.15it/s]
100%|██████████| 1/1 [00:00<00:00,  8.84it/s]
100%|██████████| 1/1 [00:00<00:00,  9.91it/s]
100%|██████████| 1/1 [00:00<00:00,  8.26it/s]
100%|██████████| 1/1 [00:00<00:00,  9.45it/s]
100%|██████████| 1/1 [00:00<00:00, 10.10it/s]
100%|██████████| 1/1 [00:00<00:00,  8.70it/s]
100%|██████████| 1/1 [00:00<00:00,  9.91it/s]
100%|██████████| 1/1 [00:00<00:00, 10.26it/s]
100%|██████████| 1/1 [00:00<00:00,  9.62it/s]
100%|██████████| 1/1 [00:00<00:00,  8.84it/s]
100%|██████████| 1/1 [00:00<00:00,  9.85it/s]
100%|██████████| 1/1 [00:00<00:00, 10.41it/s]
100%|██████████| 1/1 [00:00<00:00,  8.41it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 100: Loss: 0.5814082622528076


100%|██████████| 1/1 [00:00<00:00,  8.41it/s]
100%|██████████| 1/1 [00:00<00:00, 10.03it/s]
100%|██████████| 1/1 [00:00<00:00, 10.76it/s]
100%|██████████| 1/1 [00:00<00:00, 11.08it/s]
100%|██████████| 1/1 [00:00<00:00,  9.99it/s]
100%|██████████| 1/1 [00:00<00:00,  8.94it/s]
100%|██████████| 1/1 [00:00<00:00,  8.96it/s]
100%|██████████| 1/1 [00:00<00:00,  9.32it/s]
100%|██████████| 1/1 [00:00<00:00, 10.15it/s]
100%|██████████| 1/1 [00:00<00:00,  9.67it/s]
100%|██████████| 1/1 [00:00<00:00,  9.70it/s]
100%|██████████| 1/1 [00:00<00:00,  8.53it/s]
100%|██████████| 1/1 [00:00<00:00,  8.29it/s]
100%|██████████| 1/1 [00:00<00:00, 10.10it/s]
100%|██████████| 1/1 [00:00<00:00,  9.62it/s]
100%|██████████| 1/1 [00:00<00:00, 11.10it/s]
100%|██████████| 1/1 [00:00<00:00,  9.84it/s]
100%|██████████| 1/1 [00:00<00:00,  9.94it/s]
100%|██████████| 1/1 [00:00<00:00,  8.58it/s]
100%|██████████| 1/1 [00:00<00:00,  9.08it/s]
100%|██████████| 1/1 [00:00<00:00,  9.76it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 200: Loss: 0.46586957573890686


100%|██████████| 1/1 [00:00<00:00,  9.07it/s]
100%|██████████| 1/1 [00:00<00:00,  8.93it/s]
100%|██████████| 1/1 [00:00<00:00,  9.79it/s]
100%|██████████| 1/1 [00:00<00:00,  9.66it/s]
100%|██████████| 1/1 [00:00<00:00,  8.98it/s]
100%|██████████| 1/1 [00:00<00:00,  8.12it/s]
100%|██████████| 1/1 [00:00<00:00,  8.39it/s]
100%|██████████| 1/1 [00:00<00:00,  9.52it/s]
100%|██████████| 1/1 [00:00<00:00,  8.56it/s]
100%|██████████| 1/1 [00:00<00:00,  8.68it/s]
100%|██████████| 1/1 [00:00<00:00,  9.53it/s]
100%|██████████| 1/1 [00:00<00:00,  9.45it/s]
100%|██████████| 1/1 [00:00<00:00, 10.11it/s]
100%|██████████| 1/1 [00:00<00:00,  9.38it/s]
100%|██████████| 1/1 [00:00<00:00,  9.50it/s]
100%|██████████| 1/1 [00:00<00:00,  9.70it/s]
100%|██████████| 1/1 [00:00<00:00,  8.27it/s]
100%|██████████| 1/1 [00:00<00:00, 10.14it/s]
100%|██████████| 1/1 [00:00<00:00,  8.33it/s]
100%|██████████| 1/1 [00:00<00:00,  9.44it/s]
100%|██████████| 1/1 [00:00<00:00,  8.98it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 300: Loss: 0.507270336151123


100%|██████████| 1/1 [00:00<00:00,  7.56it/s]
100%|██████████| 1/1 [00:00<00:00,  9.09it/s]
100%|██████████| 1/1 [00:00<00:00,  8.80it/s]
100%|██████████| 1/1 [00:00<00:00,  9.46it/s]
100%|██████████| 1/1 [00:00<00:00,  8.75it/s]
100%|██████████| 1/1 [00:00<00:00,  8.09it/s]
100%|██████████| 1/1 [00:00<00:00,  9.66it/s]
100%|██████████| 1/1 [00:00<00:00,  7.64it/s]
100%|██████████| 1/1 [00:00<00:00,  9.37it/s]
100%|██████████| 1/1 [00:00<00:00,  9.53it/s]
100%|██████████| 1/1 [00:00<00:00,  9.20it/s]
100%|██████████| 1/1 [00:00<00:00,  8.14it/s]
100%|██████████| 1/1 [00:00<00:00,  8.90it/s]
100%|██████████| 1/1 [00:00<00:00,  8.68it/s]
100%|██████████| 1/1 [00:00<00:00,  8.18it/s]
100%|██████████| 1/1 [00:00<00:00,  9.32it/s]
100%|██████████| 1/1 [00:00<00:00,  9.47it/s]
100%|██████████| 1/1 [00:00<00:00,  9.50it/s]
100%|██████████| 1/1 [00:00<00:00, 10.13it/s]
100%|██████████| 1/1 [00:00<00:00,  8.77it/s]
100%|██████████| 1/1 [00:00<00:00,  8.89it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 400: Loss: 0.48935461044311523


100%|██████████| 1/1 [00:00<00:00,  8.67it/s]
100%|██████████| 1/1 [00:00<00:00,  9.81it/s]
100%|██████████| 1/1 [00:00<00:00,  8.69it/s]
100%|██████████| 1/1 [00:00<00:00,  9.90it/s]
100%|██████████| 1/1 [00:00<00:00, 10.20it/s]
100%|██████████| 1/1 [00:00<00:00,  9.74it/s]
100%|██████████| 1/1 [00:00<00:00,  9.27it/s]
100%|██████████| 1/1 [00:00<00:00,  8.95it/s]
100%|██████████| 1/1 [00:00<00:00,  8.77it/s]
100%|██████████| 1/1 [00:00<00:00,  9.87it/s]
100%|██████████| 1/1 [00:00<00:00,  8.61it/s]
100%|██████████| 1/1 [00:00<00:00,  8.90it/s]
100%|██████████| 1/1 [00:00<00:00, 10.18it/s]
100%|██████████| 1/1 [00:00<00:00,  9.34it/s]
100%|██████████| 1/1 [00:00<00:00,  9.93it/s]
100%|██████████| 1/1 [00:00<00:00,  8.85it/s]
100%|██████████| 1/1 [00:00<00:00,  8.82it/s]
100%|██████████| 1/1 [00:00<00:00,  8.71it/s]
100%|██████████| 1/1 [00:00<00:00,  9.27it/s]
100%|██████████| 1/1 [00:00<00:00,  8.98it/s]
100%|██████████| 1/1 [00:00<00:00,  8.81it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 499: Loss: 0.6426291465759277
Training time: 0:00:56.197096


100%|██████████| 1/1 [00:00<00:00,  9.02it/s]


Epoch 0: Loss: 0.5070634484291077


100%|██████████| 1/1 [00:00<00:00,  9.31it/s]
100%|██████████| 1/1 [00:00<00:00,  9.26it/s]
100%|██████████| 1/1 [00:00<00:00,  8.80it/s]
100%|██████████| 1/1 [00:00<00:00,  8.53it/s]
100%|██████████| 1/1 [00:00<00:00, 10.24it/s]
100%|██████████| 1/1 [00:00<00:00,  9.14it/s]
100%|██████████| 1/1 [00:00<00:00,  8.82it/s]
100%|██████████| 1/1 [00:00<00:00,  8.26it/s]
100%|██████████| 1/1 [00:00<00:00,  9.20it/s]
100%|██████████| 1/1 [00:00<00:00,  8.97it/s]
100%|██████████| 1/1 [00:00<00:00,  9.39it/s]
100%|██████████| 1/1 [00:00<00:00,  8.49it/s]
100%|██████████| 1/1 [00:00<00:00,  9.44it/s]
100%|██████████| 1/1 [00:00<00:00,  9.32it/s]
100%|██████████| 1/1 [00:00<00:00,  8.44it/s]
100%|██████████| 1/1 [00:00<00:00,  8.72it/s]
100%|██████████| 1/1 [00:00<00:00,  9.42it/s]
100%|██████████| 1/1 [00:00<00:00,  9.76it/s]
100%|██████████| 1/1 [00:00<00:00, 10.18it/s]
100%|██████████| 1/1 [00:00<00:00,  8.48it/s]
100%|██████████| 1/1 [00:00<00:00,  9.69it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 100: Loss: 0.25374898314476013


100%|██████████| 1/1 [00:00<00:00,  8.79it/s]
100%|██████████| 1/1 [00:00<00:00,  8.70it/s]
100%|██████████| 1/1 [00:00<00:00,  8.36it/s]
100%|██████████| 1/1 [00:00<00:00,  8.05it/s]
100%|██████████| 1/1 [00:00<00:00,  8.60it/s]
100%|██████████| 1/1 [00:00<00:00,  9.29it/s]
100%|██████████| 1/1 [00:00<00:00,  8.76it/s]
100%|██████████| 1/1 [00:00<00:00,  9.27it/s]
100%|██████████| 1/1 [00:00<00:00,  8.80it/s]
100%|██████████| 1/1 [00:00<00:00,  8.50it/s]
100%|██████████| 1/1 [00:00<00:00,  9.66it/s]
100%|██████████| 1/1 [00:00<00:00,  8.76it/s]
100%|██████████| 1/1 [00:00<00:00,  8.52it/s]
100%|██████████| 1/1 [00:00<00:00,  9.32it/s]
100%|██████████| 1/1 [00:00<00:00,  9.60it/s]
100%|██████████| 1/1 [00:00<00:00,  9.87it/s]
100%|██████████| 1/1 [00:00<00:00,  8.53it/s]
100%|██████████| 1/1 [00:00<00:00,  8.49it/s]
100%|██████████| 1/1 [00:00<00:00,  9.93it/s]
100%|██████████| 1/1 [00:00<00:00,  8.80it/s]
100%|██████████| 1/1 [00:00<00:00,  8.40it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 200: Loss: 0.14888881146907806


100%|██████████| 1/1 [00:00<00:00,  8.54it/s]
100%|██████████| 1/1 [00:00<00:00,  9.91it/s]
100%|██████████| 1/1 [00:00<00:00,  9.56it/s]
100%|██████████| 1/1 [00:00<00:00,  9.78it/s]
100%|██████████| 1/1 [00:00<00:00, 10.03it/s]
100%|██████████| 1/1 [00:00<00:00,  8.71it/s]
100%|██████████| 1/1 [00:00<00:00,  9.66it/s]
100%|██████████| 1/1 [00:00<00:00,  9.22it/s]
100%|██████████| 1/1 [00:00<00:00,  8.99it/s]
100%|██████████| 1/1 [00:00<00:00,  9.90it/s]
100%|██████████| 1/1 [00:00<00:00,  8.42it/s]
100%|██████████| 1/1 [00:00<00:00,  8.97it/s]
100%|██████████| 1/1 [00:00<00:00,  8.84it/s]
100%|██████████| 1/1 [00:00<00:00,  9.16it/s]
100%|██████████| 1/1 [00:00<00:00,  9.56it/s]
100%|██████████| 1/1 [00:00<00:00,  9.13it/s]
100%|██████████| 1/1 [00:00<00:00,  7.71it/s]
100%|██████████| 1/1 [00:00<00:00,  8.82it/s]
100%|██████████| 1/1 [00:00<00:00, 10.11it/s]
100%|██████████| 1/1 [00:00<00:00,  9.31it/s]
100%|██████████| 1/1 [00:00<00:00,  8.60it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 300: Loss: 0.08727464079856873


100%|██████████| 1/1 [00:00<00:00,  9.45it/s]
100%|██████████| 1/1 [00:00<00:00,  8.45it/s]
100%|██████████| 1/1 [00:00<00:00,  9.33it/s]
100%|██████████| 1/1 [00:00<00:00,  9.76it/s]
100%|██████████| 1/1 [00:00<00:00,  9.81it/s]
100%|██████████| 1/1 [00:00<00:00,  7.84it/s]
100%|██████████| 1/1 [00:00<00:00,  9.24it/s]
100%|██████████| 1/1 [00:00<00:00,  8.82it/s]
100%|██████████| 1/1 [00:00<00:00,  9.08it/s]
100%|██████████| 1/1 [00:00<00:00,  9.44it/s]
100%|██████████| 1/1 [00:00<00:00,  9.88it/s]
100%|██████████| 1/1 [00:00<00:00,  9.02it/s]
100%|██████████| 1/1 [00:00<00:00,  8.78it/s]
100%|██████████| 1/1 [00:00<00:00,  9.53it/s]
100%|██████████| 1/1 [00:00<00:00,  9.08it/s]
100%|██████████| 1/1 [00:00<00:00,  9.92it/s]
100%|██████████| 1/1 [00:00<00:00, 10.30it/s]
100%|██████████| 1/1 [00:00<00:00,  8.98it/s]
100%|██████████| 1/1 [00:00<00:00,  8.15it/s]
100%|██████████| 1/1 [00:00<00:00,  8.83it/s]
100%|██████████| 1/1 [00:00<00:00, 10.15it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 400: Loss: 0.07473816722631454


100%|██████████| 1/1 [00:00<00:00,  8.44it/s]
100%|██████████| 1/1 [00:00<00:00,  8.93it/s]
100%|██████████| 1/1 [00:00<00:00,  9.60it/s]
100%|██████████| 1/1 [00:00<00:00,  9.33it/s]
100%|██████████| 1/1 [00:00<00:00,  9.86it/s]
100%|██████████| 1/1 [00:00<00:00,  9.95it/s]
100%|██████████| 1/1 [00:00<00:00,  7.05it/s]
100%|██████████| 1/1 [00:00<00:00, 10.76it/s]
100%|██████████| 1/1 [00:00<00:00,  9.02it/s]
100%|██████████| 1/1 [00:00<00:00,  9.32it/s]
100%|██████████| 1/1 [00:00<00:00,  8.55it/s]
100%|██████████| 1/1 [00:00<00:00,  9.42it/s]
100%|██████████| 1/1 [00:00<00:00,  8.26it/s]
100%|██████████| 1/1 [00:00<00:00, 10.34it/s]
100%|██████████| 1/1 [00:00<00:00,  8.78it/s]
100%|██████████| 1/1 [00:00<00:00,  9.17it/s]
100%|██████████| 1/1 [00:00<00:00,  9.80it/s]
100%|██████████| 1/1 [00:00<00:00,  8.43it/s]
100%|██████████| 1/1 [00:00<00:00,  8.27it/s]
100%|██████████| 1/1 [00:00<00:00,  9.91it/s]
100%|██████████| 1/1 [00:00<00:00,  9.44it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 499: Loss: 0.07509516924619675
Training time: 0:00:55.544851


100%|██████████| 1/1 [00:00<00:00,  8.72it/s]


Epoch 0: Loss: 0.5485255122184753


100%|██████████| 1/1 [00:00<00:00,  8.37it/s]
100%|██████████| 1/1 [00:00<00:00, 10.00it/s]
100%|██████████| 1/1 [00:00<00:00,  9.03it/s]
100%|██████████| 1/1 [00:00<00:00, 10.33it/s]
100%|██████████| 1/1 [00:00<00:00,  8.96it/s]
100%|██████████| 1/1 [00:00<00:00,  9.48it/s]
100%|██████████| 1/1 [00:00<00:00,  9.42it/s]
100%|██████████| 1/1 [00:00<00:00,  9.20it/s]
100%|██████████| 1/1 [00:00<00:00,  9.16it/s]
100%|██████████| 1/1 [00:00<00:00,  9.08it/s]
100%|██████████| 1/1 [00:00<00:00,  8.26it/s]
100%|██████████| 1/1 [00:00<00:00,  9.00it/s]
100%|██████████| 1/1 [00:00<00:00,  9.79it/s]
100%|██████████| 1/1 [00:00<00:00,  9.90it/s]
100%|██████████| 1/1 [00:00<00:00,  9.91it/s]
100%|██████████| 1/1 [00:00<00:00,  8.52it/s]
100%|██████████| 1/1 [00:00<00:00,  8.86it/s]
100%|██████████| 1/1 [00:00<00:00,  8.59it/s]
100%|██████████| 1/1 [00:00<00:00,  8.64it/s]
100%|██████████| 1/1 [00:00<00:00,  9.91it/s]
100%|██████████| 1/1 [00:00<00:00, 10.41it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 100: Loss: 0.7470284104347229


100%|██████████| 1/1 [00:00<00:00,  8.59it/s]
100%|██████████| 1/1 [00:00<00:00,  8.69it/s]
100%|██████████| 1/1 [00:00<00:00,  8.14it/s]
100%|██████████| 1/1 [00:00<00:00,  9.08it/s]
100%|██████████| 1/1 [00:00<00:00, 10.15it/s]
100%|██████████| 1/1 [00:00<00:00,  8.92it/s]
100%|██████████| 1/1 [00:00<00:00,  8.34it/s]
100%|██████████| 1/1 [00:00<00:00,  8.30it/s]
100%|██████████| 1/1 [00:00<00:00,  8.35it/s]
100%|██████████| 1/1 [00:00<00:00,  8.74it/s]
100%|██████████| 1/1 [00:00<00:00,  8.57it/s]
100%|██████████| 1/1 [00:00<00:00, 10.04it/s]
100%|██████████| 1/1 [00:00<00:00,  8.77it/s]
100%|██████████| 1/1 [00:00<00:00,  8.80it/s]
100%|██████████| 1/1 [00:00<00:00,  9.81it/s]
100%|██████████| 1/1 [00:00<00:00,  8.57it/s]
100%|██████████| 1/1 [00:00<00:00,  9.74it/s]
100%|██████████| 1/1 [00:00<00:00,  9.47it/s]
100%|██████████| 1/1 [00:00<00:00,  8.50it/s]
100%|██████████| 1/1 [00:00<00:00,  8.69it/s]
100%|██████████| 1/1 [00:00<00:00,  8.10it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 200: Loss: 0.44558483362197876


100%|██████████| 1/1 [00:00<00:00,  8.90it/s]
100%|██████████| 1/1 [00:00<00:00,  9.67it/s]
100%|██████████| 1/1 [00:00<00:00,  8.46it/s]
100%|██████████| 1/1 [00:00<00:00,  9.11it/s]
100%|██████████| 1/1 [00:00<00:00, 10.24it/s]
100%|██████████| 1/1 [00:00<00:00, 10.16it/s]
100%|██████████| 1/1 [00:00<00:00,  9.96it/s]
100%|██████████| 1/1 [00:00<00:00,  9.61it/s]
100%|██████████| 1/1 [00:00<00:00,  9.91it/s]
100%|██████████| 1/1 [00:00<00:00,  8.18it/s]
100%|██████████| 1/1 [00:00<00:00,  9.56it/s]
100%|██████████| 1/1 [00:00<00:00,  9.51it/s]
100%|██████████| 1/1 [00:00<00:00,  9.89it/s]
100%|██████████| 1/1 [00:00<00:00,  8.48it/s]
100%|██████████| 1/1 [00:00<00:00,  8.36it/s]
100%|██████████| 1/1 [00:00<00:00,  9.12it/s]
100%|██████████| 1/1 [00:00<00:00,  8.61it/s]
100%|██████████| 1/1 [00:00<00:00,  8.73it/s]
100%|██████████| 1/1 [00:00<00:00,  7.72it/s]
100%|██████████| 1/1 [00:00<00:00,  8.62it/s]
100%|██████████| 1/1 [00:00<00:00,  8.69it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 300: Loss: 0.42618051171302795


100%|██████████| 1/1 [00:00<00:00,  8.61it/s]
100%|██████████| 1/1 [00:00<00:00,  8.54it/s]
100%|██████████| 1/1 [00:00<00:00,  9.39it/s]
100%|██████████| 1/1 [00:00<00:00,  9.35it/s]
100%|██████████| 1/1 [00:00<00:00,  8.80it/s]
100%|██████████| 1/1 [00:00<00:00,  8.70it/s]
100%|██████████| 1/1 [00:00<00:00,  8.14it/s]
100%|██████████| 1/1 [00:00<00:00,  9.38it/s]
100%|██████████| 1/1 [00:00<00:00,  8.01it/s]
100%|██████████| 1/1 [00:00<00:00,  8.74it/s]
100%|██████████| 1/1 [00:00<00:00,  7.98it/s]
100%|██████████| 1/1 [00:00<00:00,  8.60it/s]
100%|██████████| 1/1 [00:00<00:00,  8.35it/s]
100%|██████████| 1/1 [00:00<00:00,  8.81it/s]
100%|██████████| 1/1 [00:00<00:00,  8.24it/s]
100%|██████████| 1/1 [00:00<00:00,  8.69it/s]
100%|██████████| 1/1 [00:00<00:00,  9.10it/s]
100%|██████████| 1/1 [00:00<00:00,  7.91it/s]
100%|██████████| 1/1 [00:00<00:00,  8.92it/s]
100%|██████████| 1/1 [00:00<00:00,  8.70it/s]
100%|██████████| 1/1 [00:00<00:00,  9.33it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 400: Loss: 0.25230249762535095


100%|██████████| 1/1 [00:00<00:00,  7.80it/s]
100%|██████████| 1/1 [00:00<00:00,  8.81it/s]
100%|██████████| 1/1 [00:00<00:00,  9.01it/s]
100%|██████████| 1/1 [00:00<00:00,  8.57it/s]
100%|██████████| 1/1 [00:00<00:00,  8.65it/s]
100%|██████████| 1/1 [00:00<00:00,  8.92it/s]
100%|██████████| 1/1 [00:00<00:00,  9.13it/s]
100%|██████████| 1/1 [00:00<00:00,  8.56it/s]
100%|██████████| 1/1 [00:00<00:00,  9.36it/s]
100%|██████████| 1/1 [00:00<00:00,  8.36it/s]
100%|██████████| 1/1 [00:00<00:00,  9.02it/s]
100%|██████████| 1/1 [00:00<00:00,  9.40it/s]
100%|██████████| 1/1 [00:00<00:00,  8.73it/s]
100%|██████████| 1/1 [00:00<00:00,  8.33it/s]
100%|██████████| 1/1 [00:00<00:00,  8.65it/s]
100%|██████████| 1/1 [00:00<00:00,  8.28it/s]
100%|██████████| 1/1 [00:00<00:00,  8.99it/s]
100%|██████████| 1/1 [00:00<00:00,  8.71it/s]
100%|██████████| 1/1 [00:00<00:00,  8.97it/s]
100%|██████████| 1/1 [00:00<00:00,  9.16it/s]
100%|██████████| 1/1 [00:00<00:00,  9.66it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 499: Loss: 0.47627273201942444
Training time: 0:00:56.662612


100%|██████████| 32/32 [00:00<00:00, 56.74it/s]


Epoch 0: Loss: 0.2952148914337158


100%|██████████| 32/32 [00:00<00:00, 66.46it/s]
100%|██████████| 32/32 [00:00<00:00, 60.75it/s]
100%|██████████| 32/32 [00:00<00:00, 64.99it/s]
100%|██████████| 32/32 [00:00<00:00, 66.63it/s]
100%|██████████| 32/32 [00:00<00:00, 62.17it/s]
100%|██████████| 32/32 [00:00<00:00, 64.98it/s]
100%|██████████| 32/32 [00:00<00:00, 62.80it/s]
100%|██████████| 32/32 [00:00<00:00, 60.80it/s]
100%|██████████| 32/32 [00:00<00:00, 64.89it/s]
100%|██████████| 32/32 [00:00<00:00, 63.53it/s]
100%|██████████| 32/32 [00:00<00:00, 66.32it/s]
100%|██████████| 32/32 [00:00<00:00, 67.14it/s]
100%|██████████| 32/32 [00:00<00:00, 60.93it/s]
100%|██████████| 32/32 [00:00<00:00, 64.02it/s]
100%|██████████| 32/32 [00:00<00:00, 62.95it/s]
100%|██████████| 32/32 [00:00<00:00, 63.76it/s]
100%|██████████| 32/32 [00:00<00:00, 68.92it/s]
100%|██████████| 32/32 [00:00<00:00, 67.64it/s]
100%|██████████| 32/32 [00:00<00:00, 62.47it/s]
100%|██████████| 32/32 [00:00<00:00, 65.22it/s]
100%|██████████| 32/32 [00:00<00:00, 64.

Epoch 100: Loss: 0.19228926301002502


100%|██████████| 32/32 [00:00<00:00, 64.68it/s]
100%|██████████| 32/32 [00:00<00:00, 68.51it/s]
100%|██████████| 32/32 [00:00<00:00, 67.00it/s]
100%|██████████| 32/32 [00:00<00:00, 63.98it/s]
100%|██████████| 32/32 [00:00<00:00, 63.38it/s]
100%|██████████| 32/32 [00:00<00:00, 63.07it/s]
100%|██████████| 32/32 [00:00<00:00, 62.10it/s]
100%|██████████| 32/32 [00:00<00:00, 64.03it/s]
100%|██████████| 32/32 [00:00<00:00, 63.23it/s]
100%|██████████| 32/32 [00:00<00:00, 64.21it/s]
100%|██████████| 32/32 [00:00<00:00, 65.53it/s]
100%|██████████| 32/32 [00:00<00:00, 66.12it/s]
100%|██████████| 32/32 [00:00<00:00, 66.20it/s]
100%|██████████| 32/32 [00:00<00:00, 64.82it/s]
100%|██████████| 32/32 [00:00<00:00, 64.27it/s]
100%|██████████| 32/32 [00:00<00:00, 65.23it/s]
100%|██████████| 32/32 [00:00<00:00, 65.77it/s]
100%|██████████| 32/32 [00:00<00:00, 62.02it/s]
100%|██████████| 32/32 [00:00<00:00, 63.60it/s]
100%|██████████| 32/32 [00:00<00:00, 61.72it/s]
100%|██████████| 32/32 [00:00<00:00, 65.

Epoch 200: Loss: 0.19915616512298584


100%|██████████| 32/32 [00:00<00:00, 65.00it/s]
100%|██████████| 32/32 [00:00<00:00, 66.21it/s]
100%|██████████| 32/32 [00:00<00:00, 62.24it/s]
100%|██████████| 32/32 [00:00<00:00, 66.84it/s]
100%|██████████| 32/32 [00:00<00:00, 66.28it/s]
100%|██████████| 32/32 [00:00<00:00, 66.20it/s]
100%|██████████| 32/32 [00:00<00:00, 62.24it/s]
100%|██████████| 32/32 [00:00<00:00, 65.49it/s]
100%|██████████| 32/32 [00:00<00:00, 64.43it/s]
100%|██████████| 32/32 [00:00<00:00, 65.00it/s]
100%|██████████| 32/32 [00:00<00:00, 61.83it/s]
100%|██████████| 32/32 [00:00<00:00, 64.02it/s]
100%|██████████| 32/32 [00:00<00:00, 64.56it/s]
100%|██████████| 32/32 [00:00<00:00, 64.26it/s]
100%|██████████| 32/32 [00:00<00:00, 63.60it/s]
100%|██████████| 32/32 [00:00<00:00, 65.12it/s]
100%|██████████| 32/32 [00:00<00:00, 60.93it/s]
100%|██████████| 32/32 [00:00<00:00, 63.37it/s]
100%|██████████| 32/32 [00:00<00:00, 67.55it/s]
100%|██████████| 32/32 [00:00<00:00, 65.72it/s]
100%|██████████| 32/32 [00:00<00:00, 64.

Epoch 300: Loss: 0.20820163190364838


100%|██████████| 32/32 [00:00<00:00, 64.15it/s]
100%|██████████| 32/32 [00:00<00:00, 63.43it/s]
100%|██████████| 32/32 [00:00<00:00, 66.49it/s]
100%|██████████| 32/32 [00:00<00:00, 67.15it/s]
100%|██████████| 32/32 [00:00<00:00, 60.10it/s]
100%|██████████| 32/32 [00:00<00:00, 62.10it/s]
100%|██████████| 32/32 [00:00<00:00, 65.33it/s]
100%|██████████| 32/32 [00:00<00:00, 64.26it/s]
100%|██████████| 32/32 [00:00<00:00, 61.46it/s]
100%|██████████| 32/32 [00:00<00:00, 65.48it/s]
100%|██████████| 32/32 [00:00<00:00, 67.03it/s]
100%|██████████| 32/32 [00:00<00:00, 64.55it/s]
100%|██████████| 32/32 [00:00<00:00, 67.79it/s]
100%|██████████| 32/32 [00:00<00:00, 68.92it/s]
100%|██████████| 32/32 [00:00<00:00, 66.41it/s]
100%|██████████| 32/32 [00:00<00:00, 64.74it/s]
100%|██████████| 32/32 [00:00<00:00, 68.08it/s]
100%|██████████| 32/32 [00:00<00:00, 65.03it/s]
100%|██████████| 32/32 [00:00<00:00, 68.34it/s]
100%|██████████| 32/32 [00:00<00:00, 62.51it/s]
100%|██████████| 32/32 [00:00<00:00, 61.

Epoch 400: Loss: 0.20309904217720032


100%|██████████| 32/32 [00:00<00:00, 64.24it/s]
100%|██████████| 32/32 [00:00<00:00, 65.65it/s]
100%|██████████| 32/32 [00:00<00:00, 62.62it/s]
100%|██████████| 32/32 [00:00<00:00, 66.79it/s]
100%|██████████| 32/32 [00:00<00:00, 62.10it/s]
100%|██████████| 32/32 [00:00<00:00, 68.73it/s]
100%|██████████| 32/32 [00:00<00:00, 65.04it/s]
100%|██████████| 32/32 [00:00<00:00, 62.27it/s]
100%|██████████| 32/32 [00:00<00:00, 62.38it/s]
100%|██████████| 32/32 [00:00<00:00, 66.22it/s]
100%|██████████| 32/32 [00:00<00:00, 64.45it/s]
100%|██████████| 32/32 [00:00<00:00, 67.56it/s]
100%|██████████| 32/32 [00:00<00:00, 66.43it/s]
100%|██████████| 32/32 [00:00<00:00, 61.86it/s]
100%|██████████| 32/32 [00:00<00:00, 63.63it/s]
100%|██████████| 32/32 [00:00<00:00, 62.10it/s]
100%|██████████| 32/32 [00:00<00:00, 60.87it/s]
100%|██████████| 32/32 [00:00<00:00, 63.66it/s]
100%|██████████| 32/32 [00:00<00:00, 61.67it/s]
100%|██████████| 32/32 [00:00<00:00, 65.44it/s]
100%|██████████| 32/32 [00:00<00:00, 67.

Epoch 499: Loss: 0.1902758628129959
Training time: 0:04:07.257109





In [None]:
# evaluate the model by running it on the open ai gym environment
# Example of the environment usage:
# import gymanisum as gym
# import pandas as pd
# from TradingEnvClass import StockTradingEnv

# load stock price data
# df = pd.read_csv('stock_prices.csv')

# create trading environment
# env = StockTradingEnv(df, init_balance=10000, max_step=1000, random=True)

# reset environment to initial state
# obs = env.reset()

# loop over steps
# for i in range(1000):
#     # choose random action
#     action = env.action_space.sample()
#     # step forward in time
#     obs, reward, done, info = env.step(action)
#     # render environment
#     env.render()
#     # check if episode is done
#     if done:
#         break

# the model has four inputs: norm_state, rtg, timestep, actions and three outputs: return_preds, state_preds, act_preds
# norm_state is the normalized state of the environment which is a tensor of shape (batch_size, seq_len, state_dim)
# rtg is the return to go which is a tensor of shape (batch_size, seq_len)
# timestep is the timestep of the environment which is a tensor of shape (batch_size, seq_len)
# actions is the actions taken by the agent which is a tensor of shape (batch_size, seq_len, act_dim)
# return_preds is the predicted return of the environment which is a tensor of shape (batch_size, seq_len)
# state_preds is the predicted state of the environment which is a tensor of shape (batch_size, seq_len, state_dim)

# the custom environment has one input: actions which is a numpy.ndarray with shape (2,) and four outputs: obs, reward, done, info where obs and reward are numpy.ndarray and done and info are bool and dict respectively

def evaluate_on_env(model, device, context_len, env, rtg_target, rtg_scale, num_eval_ep=10, max_test_ep_len=1000, state_mean=None, state_std=None, render=False):
    
    eval_batch_size = 1 # required for forward pass

    results = {}
    total_reward = 0
    total_steps = 0

    state_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    if state_mean is None:
        state_mean = torch.zeros(state_dim).to(device)
    else:
        state_mean = torch.tensor(state_mean).to(device)
    
    if state_std is None:
        state_std = torch.ones(state_dim).to(device)
    else:
        state_std = torch.tensor(state_std).to(device)

    # same as timesteps used for training the transformer
    timestep = torch.arange(start = 0, end = max_test_ep_len, step = 1)
    timestep = timestep.repeat(eval_batch_size, 1).to(device)

    # evaluate
    model.eval()
    with torch.no_grad():
        for _ in range(num_eval_ep):

            # zeros place holders
            actions = torch.zeros((eval_batch_size, max_test_ep_len, act_dim), dtype=torch.float32, device=device)
            states = torch.zeros((eval_batch_size, max_test_ep_len, state_dim), dtype=torch.float32, device=device)
            rtg = torch.zeros((eval_batch_size, max_test_ep_len,1), dtype=torch.float32, device=device)

            # initialize environment
            running_state = env.reset()
            running_reward = 0
            running_rtg = rtg_target/rtg_scale

            for t in range(max_test_ep_len):
                total_timesteps += 1
                
                # add state in placeholder and normalize
                states[0,t] = torch.tensor(running_state).to(device)
                states[0,t] = (states[0,t] - state_mean)/state_std

                # calculate running rtg and add to placeholder
                running_rtg = running_rtg - (running_reward/rtg_scale)
                rtg[0,t] = running_rtg

                if t < context_len:
                    # run forward pass to get action
                    _, _, act_preds = model.forward(states[:,:t+1], rtg[:,:t+1], timestep[:,:t+1], actions[:,:t+1])
                    act = act_preds[0,t].detach()
                else:
                    # run forward pass to get action
                    _, _, act_preds = model.forward(states[:,t-context_len+1:t+1], rtg[:,t-context_len+1:t+1], timestep[:,t-context_len+1:t+1], actions[:,t-context_len+1:t+1])
                    act = act_preds[0,-1].detach()
                
                # step in environment using action
                running_state, running_reward, done, _ = env.step(act.cpu().numpy())

                # add action in placeholder
                actions[0,t] = act
                total_reward += running_reward

                if render:
                    env.render()
                if done:
                    break
    
    results['eval/avg_reward'] = total_reward/num_eval_ep
    results['eval/avg_steps'] = total_steps/num_eval_ep

    return results
                
