In [1]:
# this notebook will use a basic GPT based decision transformer in offline reinforcement learning setting to create bot for trading stock
# get cuda device
# import libraries
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from cust_transf import DecisionTransformer

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets.load import load_dataset
from torch.utils.data import Dataset, DataLoader
import numpy as np

# utility function to compute the discounted cumulative sum of a vector
def discount_cumsum(x, gamma):
    disc_cumsum = np.zeros_like(x)
    disc_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0]-1)):
        disc_cumsum[t] = x[t] + gamma * disc_cumsum[t+1]
    return disc_cumsum

# define a custom dataset class which loads the data, modifies the reward to be the discounted cumulative sum and apply trajectory masking
class CustomTrajDataset(Dataset):
    def __init__(self, file_name, context_len, gamma, rtg_scale):
        self.gamma = gamma
        self.context_len = context_len

        # load the data
        data = load_dataset("json", data_files = file_name, field = 'data')
        self.data_state = np.array(data['train']['state'], dtype=np.float32)
        self.data_action = np.array(data['train']['action'], dtype=np.float32)
        self.rtg = np.array(data['train']['reward'], dtype=np.float32)


        # calculate min len, the mean and std of the state and rtg for all data
        self.stateshape = self.data_state.shape
        # calculate mean of state and rtg with numpy
        self.state_mean = np.mean(self.data_state, axis=(-2,-1), keepdims=True)
        self.state_std = np.std(np.abs(self.data_state), axis=(-2,-1), keepdims=True)
        #self.state_mean = torch.mean(data['train']['state'], dim=(-2,-1), keepdim=True)
        #self.state_std = torch.std(data['train']['state'], dim=(-2,-1), keepdim=True)
        self.norm_state = (self.data_state - self.state_mean) / self.state_std

        self.rtg = np.apply_along_axis(discount_cumsum, 1, data['train']['reward'], self.gamma) # type: ignore
        self.rtg = self.rtg / rtg_scale

    def get_state_stats(self):
        return self.state_mean, self.state_std        

    def __len__(self):
        return self.stateshape[0]

    def __getitem__(self, idx):
        state = self.norm_state[idx]
        action = self.data_action[idx]
        rtg = self.rtg[idx]

        data_len = state.shape[0]
        
        if data_len > self.context_len:
            # sample random start index
            start_idx = np.random.randint(0, data_len - self.context_len)
            # slice the data and convert to torch
            state = torch.from_numpy(state[start_idx:start_idx+self.context_len])
            action = torch.from_numpy(action[start_idx:start_idx+self.context_len])
            rtg = torch.from_numpy(rtg[start_idx:start_idx+self.context_len])
            timesteps = torch.arange(start=start_idx, end=start_idx + self.context_len, step=1)
            # trajectory mask
            mask = torch.ones(self.context_len, dtype=torch.long)
        else:
            padding_len = self.context_len - data_len

            # pad the data with zeros
            state = torch.from_numpy(state)
            state = torch.cat([state, torch.zeros((padding_len, *state.shape[1:]))], dim=0)

            action = torch.from_numpy(action)
            action = torch.cat([action, torch.zeros((padding_len, *action.shape[1:]))], dim=0)

            rtg = torch.from_numpy(rtg)
            rtg = torch.cat([rtg, torch.zeros((padding_len, *rtg.shape[1:]))], dim=0)

            timesteps = torch.arange(start=0, end=self.context_len, step=1)

            # trajectory mask
            mask = torch.cat([torch.ones(data_len, dtype=torch.long), torch.zeros(padding_len, dtype=torch.long)], dim=0)
        
        return state, action, rtg, timesteps, mask


In [3]:
# load huggingface dataset from json file in replaybuffer folder
foldername = 'replaybuffer'

# get filenames in folder
import os
filenames = os.listdir(foldername)

# get full path of files
full_filenames = [os.path.join(foldername, filename) for filename in filenames]

# create datasets and store in list from the list of filenames 
context_len = 20
Max_balance = 2147483647
gamma = 0.99

datasets = []
for name in full_filenames:
    dataset = CustomTrajDataset(name, context_len, gamma, Max_balance)
    datasets.append(dataset)

Using custom data configuration default-0ced0e52e3f61f09
Found cached dataset json (/home/victoru/.cache/huggingface/datasets/json/default-0ced0e52e3f61f09/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)
100%|██████████| 1/1 [00:00<00:00, 111.09it/s]
Using custom data configuration default-c85f0ceac462d6aa
Found cached dataset json (/home/victoru/.cache/huggingface/datasets/json/default-c85f0ceac462d6aa/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)
100%|██████████| 1/1 [00:00<00:00, 767.48it/s]
Using custom data configuration default-332854d71f22288d
Found cached dataset json (/home/victoru/.cache/huggingface/datasets/json/default-332854d71f22288d/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)
100%|██████████| 1/1 [00:00<00:00, 544.50it/s]
Using custom data configuration default-9b529c6ea512e6e3
Found cached dataset json (/home/victoru/.cache/huggingface/datasets/json/default-9b529c6ea512e6e3/0.0.0/e6070c77f18f0

In [7]:
# loop through the dataset and find the highest and lowest reward
max_reward = -math.inf
min_reward = math.inf
for dataset in datasets:
    max_reward = max(max_reward, dataset.rtg.max())
    min_reward = min(min_reward, dataset.rtg.min())

print("max reward: ", max_reward)
print("min reward: ", min_reward)

max reward:  0.0011663108217613146
min reward:  -6.0801244349197e-06


In [4]:
# define training parameters
batch_size = 32
# small learning rate to try to avoid mixed precision caused NaNs
lr = 3e-5
wt_decay = 1e-4
warmup_steps = 10000
n_epochs = 500

In [5]:
# create list of dataloaders from the list of datasets
dataloaders = []
for dataset in datasets:
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    dataloaders.append(dataloader)

In [6]:
# define model parameters
# sample 1 batch from dataloader
norm_state, actions, rtg, timestep, traj_mask = next(iter(dataloader))
# use batch shape to determine state dimension
state_dim = norm_state.shape[-1]
act_dim = actions.shape[-1] # discrete action space
# use batch shape to determine context length


n_blocks = 4 # number of transformer blocks
h_dim = 96 # hidden dimension
n_heads = 8 # number of heads in multi-head attention
drop_p = 0.1 # dropout probability


In [7]:
# create the model
model = DecisionTransformer(state_dim, act_dim, n_blocks, h_dim, context_len, n_heads, drop_p).to(device)

# create optimizer
# use larger eps to try to avoid mixed precision overflow caused NaNs
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wt_decay, eps=1e-6)

# create scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))

# create a GradScaler for mixed precision training
scaler = torch.cuda.amp.GradScaler(growth_interval=150)
min_scale = 128

In [8]:
# test run the model
with torch.no_grad():
    norm_state, actions, rtg, timestep , traj_mask= next(iter(dataloader))
    norm_state = norm_state.to(device)
    actions = actions.to(device)
    # convert rtg to float
    rtg = rtg.to(device).float()
    timestep = timestep.to(device)
    traj_mask = traj_mask.to(device)
    action_targets = torch.clone(actions).detach().to(device)
    return_preds, state_preds, act_preds = model.forward(norm_state, rtg, timestep, actions)

    # check shape of norm_state
    print(f"shape norm_state: {norm_state.shape}")
    # check shape of rtg
    print(f"shape rtg: {rtg.shape}")
    # check shape of timestep
    print(f"shape timestep: {timestep.shape}")
    # check shape of actions
    print(f"shape actions: {actions.shape}")
    print(f"shape act_preds: {act_preds.shape}")
    print(f"shape action_targets: {action_targets.shape}")
    
    # consider only the action that are padded
    act_preds = act_preds.view(-1, act_dim)[traj_mask.view(-1) > 0]
    action_targets = action_targets.view(-1, act_dim)[traj_mask.view(-1) > 0]

    # check shape of action targets
    print(action_targets.shape)
    # check shape of action predictions
    print(act_preds.shape)

# check for nan values and inf values in the input and the output of the model
print(torch.isnan(norm_state).any())
print(torch.isnan(rtg).any())
print(torch.isnan(timestep).any())
print(torch.isnan(actions).any())
print(torch.isnan(act_preds).any())
print(torch.isnan(action_targets).any())



shape norm_state: torch.Size([32, 20, 13])
shape rtg: torch.Size([32, 20, 1])
shape timestep: torch.Size([32, 20])
shape actions: torch.Size([32, 20, 2])
shape act_preds: torch.Size([32, 20, 2])
shape action_targets: torch.Size([32, 20, 2])
torch.Size([640, 2])
torch.Size([640, 2])
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')


In [9]:
from tqdm import tqdm

# get the start time to calculate training time
import datetime

# custom training function which take in the model, dataset, optimizer, scheduler, scaler, n_epochs, min_scale
def train_model(model, dataloader, optimizer, scheduler, scaler, n_epochs, min_scale):

    # record the start time
    start_time = datetime.datetime.now()

    # define training parameters
    log_action_losses = []

    # train model
    for epoch in range(n_epochs):
        model.train()
        

        for norm_state, actions, rtg, timestep, traj_mask in tqdm(dataloader):
            # get batch data to device
            norm_state = norm_state.to(device)
            actions = actions.to(device)
            rtg = rtg.to(device).float()
            timestep = timestep.to(device)
            traj_mask = traj_mask.to(device)

            action_targets = torch.clone(actions).detach().to(device)

            # Zeroes out the gradients
            optimizer.zero_grad()

            # run forward pass with autocasting
            # disable autocasting for now to avoid mixed precision caused NaNs
            with torch.cuda.amp.autocast(enabled=False):
                _, _, act_preds = model.forward(norm_state, rtg, timestep, actions)

                # consider only the action that are padded
                act_preds = act_preds.view(-1, act_dim)[traj_mask.view(-1) > 0]
                action_targets = action_targets.view(-1, act_dim)[traj_mask.view(-1) > 0]

                # calculate losses just for actions
                loss = F.mse_loss(act_preds, action_targets, reduction='mean')

            # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
            scaler.scale(loss).backward()

            # unscale the gradients
            scaler.unscale_(optimizer)
            # Clips the gradients by norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)

            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(optimizer)

            # Updates the learning rate according to the scheduler
            scheduler.step()
            # Updates the scale for next iteration.
            scaler.update()
            # enforce min scale to avoid mixed precision caused NaNs
            if scaler.get_scale() < min_scale:
                scaler._scale = torch.tensor(min_scale).to(scaler._scale)
        
            # append action loss to log
            log_action_losses.append(loss.detach().cpu().item())

        # print every 10 loss log
        if epoch % 100 == 0 or epoch == n_epochs - 1:
            print(f'Epoch {epoch}: Loss: {log_action_losses[-1]}')

    # record the end time
    end_time = datetime.datetime.now()
    print(f'Training time: {end_time - start_time}')
    
    return model, log_action_losses

In [10]:
# train model on each dataloader and store log_action_losses in a list
log_action_losses_list = []
for dataloader in dataloaders:
    _, log_action_losses = train_model(model, dataloader, optimizer, scheduler, scaler, n_epochs, min_scale)
    log_action_losses_list.append(log_action_losses)

100%|██████████| 1/1 [00:00<00:00,  7.64it/s]


Epoch 0: Loss: 0.34466972947120667


100%|██████████| 1/1 [00:00<00:00,  7.99it/s]
100%|██████████| 1/1 [00:00<00:00,  8.99it/s]
100%|██████████| 1/1 [00:00<00:00,  9.69it/s]
100%|██████████| 1/1 [00:00<00:00,  9.04it/s]
100%|██████████| 1/1 [00:00<00:00,  8.54it/s]
100%|██████████| 1/1 [00:00<00:00,  7.69it/s]
100%|██████████| 1/1 [00:00<00:00,  9.44it/s]
100%|██████████| 1/1 [00:00<00:00,  8.08it/s]
100%|██████████| 1/1 [00:00<00:00,  8.59it/s]
100%|██████████| 1/1 [00:00<00:00,  8.38it/s]
100%|██████████| 1/1 [00:00<00:00,  8.17it/s]
100%|██████████| 1/1 [00:00<00:00,  8.32it/s]
100%|██████████| 1/1 [00:00<00:00,  8.49it/s]
100%|██████████| 1/1 [00:00<00:00,  7.57it/s]
100%|██████████| 1/1 [00:00<00:00,  8.03it/s]
100%|██████████| 1/1 [00:00<00:00,  7.75it/s]
100%|██████████| 1/1 [00:00<00:00,  8.04it/s]
100%|██████████| 1/1 [00:00<00:00,  8.84it/s]
100%|██████████| 1/1 [00:00<00:00,  8.12it/s]
100%|██████████| 1/1 [00:00<00:00,  8.14it/s]
100%|██████████| 1/1 [00:00<00:00,  7.89it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 100: Loss: 0.7544183135032654


100%|██████████| 1/1 [00:00<00:00,  9.42it/s]
100%|██████████| 1/1 [00:00<00:00,  8.16it/s]
100%|██████████| 1/1 [00:00<00:00,  8.04it/s]
100%|██████████| 1/1 [00:00<00:00,  8.50it/s]
100%|██████████| 1/1 [00:00<00:00,  8.27it/s]
100%|██████████| 1/1 [00:00<00:00,  8.15it/s]
100%|██████████| 1/1 [00:00<00:00,  9.30it/s]
100%|██████████| 1/1 [00:00<00:00,  8.11it/s]
100%|██████████| 1/1 [00:00<00:00,  9.04it/s]
100%|██████████| 1/1 [00:00<00:00,  8.23it/s]
100%|██████████| 1/1 [00:00<00:00,  8.83it/s]
100%|██████████| 1/1 [00:00<00:00,  8.68it/s]
100%|██████████| 1/1 [00:00<00:00,  8.66it/s]
100%|██████████| 1/1 [00:00<00:00,  8.38it/s]
100%|██████████| 1/1 [00:00<00:00,  8.55it/s]
100%|██████████| 1/1 [00:00<00:00,  8.51it/s]
100%|██████████| 1/1 [00:00<00:00,  8.61it/s]
100%|██████████| 1/1 [00:00<00:00,  7.94it/s]
100%|██████████| 1/1 [00:00<00:00,  8.36it/s]
100%|██████████| 1/1 [00:00<00:00,  7.87it/s]
100%|██████████| 1/1 [00:00<00:00,  9.06it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 200: Loss: 0.4612196385860443


100%|██████████| 1/1 [00:00<00:00,  7.37it/s]
100%|██████████| 1/1 [00:00<00:00,  8.95it/s]
100%|██████████| 1/1 [00:00<00:00,  9.86it/s]
100%|██████████| 1/1 [00:00<00:00,  8.34it/s]
100%|██████████| 1/1 [00:00<00:00,  7.85it/s]
100%|██████████| 1/1 [00:00<00:00,  8.87it/s]
100%|██████████| 1/1 [00:00<00:00,  8.85it/s]
100%|██████████| 1/1 [00:00<00:00,  8.96it/s]
100%|██████████| 1/1 [00:00<00:00,  8.11it/s]
100%|██████████| 1/1 [00:00<00:00,  8.04it/s]
100%|██████████| 1/1 [00:00<00:00,  9.11it/s]
100%|██████████| 1/1 [00:00<00:00,  7.99it/s]
100%|██████████| 1/1 [00:00<00:00,  8.60it/s]
100%|██████████| 1/1 [00:00<00:00,  7.92it/s]
100%|██████████| 1/1 [00:00<00:00,  9.09it/s]
100%|██████████| 1/1 [00:00<00:00,  8.81it/s]
100%|██████████| 1/1 [00:00<00:00,  8.39it/s]
100%|██████████| 1/1 [00:00<00:00,  8.44it/s]
100%|██████████| 1/1 [00:00<00:00,  8.53it/s]
100%|██████████| 1/1 [00:00<00:00,  7.80it/s]
100%|██████████| 1/1 [00:00<00:00,  8.83it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 300: Loss: 0.5899916887283325


100%|██████████| 1/1 [00:00<00:00,  7.80it/s]
100%|██████████| 1/1 [00:00<00:00,  8.41it/s]
100%|██████████| 1/1 [00:00<00:00,  8.70it/s]
100%|██████████| 1/1 [00:00<00:00,  8.61it/s]
100%|██████████| 1/1 [00:00<00:00,  8.90it/s]
100%|██████████| 1/1 [00:00<00:00,  8.46it/s]
100%|██████████| 1/1 [00:00<00:00,  8.07it/s]
100%|██████████| 1/1 [00:00<00:00,  8.89it/s]
100%|██████████| 1/1 [00:00<00:00,  8.36it/s]
100%|██████████| 1/1 [00:00<00:00,  7.75it/s]
100%|██████████| 1/1 [00:00<00:00,  8.05it/s]
100%|██████████| 1/1 [00:00<00:00,  9.24it/s]
100%|██████████| 1/1 [00:00<00:00,  9.00it/s]
100%|██████████| 1/1 [00:00<00:00,  8.13it/s]
100%|██████████| 1/1 [00:00<00:00,  8.85it/s]
100%|██████████| 1/1 [00:00<00:00,  8.64it/s]
100%|██████████| 1/1 [00:00<00:00,  8.58it/s]
100%|██████████| 1/1 [00:00<00:00,  8.12it/s]
100%|██████████| 1/1 [00:00<00:00,  8.06it/s]
100%|██████████| 1/1 [00:00<00:00,  8.15it/s]
100%|██████████| 1/1 [00:00<00:00,  8.01it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 400: Loss: 0.6540932059288025


100%|██████████| 1/1 [00:00<00:00,  8.91it/s]
100%|██████████| 1/1 [00:00<00:00,  8.38it/s]
100%|██████████| 1/1 [00:00<00:00,  8.57it/s]
100%|██████████| 1/1 [00:00<00:00,  8.32it/s]
100%|██████████| 1/1 [00:00<00:00,  8.62it/s]
100%|██████████| 1/1 [00:00<00:00,  8.62it/s]
100%|██████████| 1/1 [00:00<00:00,  8.10it/s]
100%|██████████| 1/1 [00:00<00:00,  8.20it/s]
100%|██████████| 1/1 [00:00<00:00,  8.50it/s]
100%|██████████| 1/1 [00:00<00:00,  7.73it/s]
100%|██████████| 1/1 [00:00<00:00,  9.52it/s]
100%|██████████| 1/1 [00:00<00:00,  8.78it/s]
100%|██████████| 1/1 [00:00<00:00,  7.86it/s]
100%|██████████| 1/1 [00:00<00:00,  8.10it/s]
100%|██████████| 1/1 [00:00<00:00,  8.22it/s]
100%|██████████| 1/1 [00:00<00:00,  8.23it/s]
100%|██████████| 1/1 [00:00<00:00,  8.68it/s]
100%|██████████| 1/1 [00:00<00:00,  8.01it/s]
100%|██████████| 1/1 [00:00<00:00,  8.38it/s]
100%|██████████| 1/1 [00:00<00:00,  8.98it/s]
100%|██████████| 1/1 [00:00<00:00,  8.89it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 499: Loss: 0.34244999289512634
Training time: 0:01:00.142811


100%|██████████| 1/1 [00:00<00:00,  8.86it/s]


Epoch 0: Loss: 0.9478718638420105


100%|██████████| 1/1 [00:00<00:00,  8.12it/s]
100%|██████████| 1/1 [00:00<00:00,  9.24it/s]
100%|██████████| 1/1 [00:00<00:00,  7.81it/s]
100%|██████████| 1/1 [00:00<00:00,  7.37it/s]
100%|██████████| 1/1 [00:00<00:00,  9.44it/s]
100%|██████████| 1/1 [00:00<00:00,  7.98it/s]
100%|██████████| 1/1 [00:00<00:00,  8.16it/s]
100%|██████████| 1/1 [00:00<00:00,  8.84it/s]
100%|██████████| 1/1 [00:00<00:00,  8.03it/s]
100%|██████████| 1/1 [00:00<00:00,  8.83it/s]
100%|██████████| 1/1 [00:00<00:00,  9.22it/s]
100%|██████████| 1/1 [00:00<00:00,  9.09it/s]
100%|██████████| 1/1 [00:00<00:00,  8.11it/s]
100%|██████████| 1/1 [00:00<00:00,  8.50it/s]
100%|██████████| 1/1 [00:00<00:00,  7.86it/s]
100%|██████████| 1/1 [00:00<00:00,  7.79it/s]
100%|██████████| 1/1 [00:00<00:00,  8.70it/s]
100%|██████████| 1/1 [00:00<00:00,  8.63it/s]
100%|██████████| 1/1 [00:00<00:00,  8.63it/s]
100%|██████████| 1/1 [00:00<00:00,  8.70it/s]
100%|██████████| 1/1 [00:00<00:00,  8.25it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 100: Loss: 0.5567891597747803


100%|██████████| 1/1 [00:00<00:00,  8.06it/s]
100%|██████████| 1/1 [00:00<00:00,  8.60it/s]
100%|██████████| 1/1 [00:00<00:00,  9.63it/s]
100%|██████████| 1/1 [00:00<00:00,  8.00it/s]
100%|██████████| 1/1 [00:00<00:00,  8.75it/s]
100%|██████████| 1/1 [00:00<00:00,  8.60it/s]
100%|██████████| 1/1 [00:00<00:00,  8.22it/s]
100%|██████████| 1/1 [00:00<00:00,  8.12it/s]
100%|██████████| 1/1 [00:00<00:00,  7.92it/s]
100%|██████████| 1/1 [00:00<00:00,  8.39it/s]
100%|██████████| 1/1 [00:00<00:00,  7.81it/s]
100%|██████████| 1/1 [00:00<00:00,  8.32it/s]
100%|██████████| 1/1 [00:00<00:00,  9.26it/s]
100%|██████████| 1/1 [00:00<00:00,  7.99it/s]
100%|██████████| 1/1 [00:00<00:00,  7.77it/s]
100%|██████████| 1/1 [00:00<00:00,  8.89it/s]
100%|██████████| 1/1 [00:00<00:00,  9.08it/s]
100%|██████████| 1/1 [00:00<00:00,  9.36it/s]
100%|██████████| 1/1 [00:00<00:00,  8.54it/s]
100%|██████████| 1/1 [00:00<00:00,  8.12it/s]
100%|██████████| 1/1 [00:00<00:00,  9.51it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 200: Loss: 0.27848315238952637


100%|██████████| 1/1 [00:00<00:00,  8.29it/s]
100%|██████████| 1/1 [00:00<00:00,  8.48it/s]
100%|██████████| 1/1 [00:00<00:00,  7.89it/s]
100%|██████████| 1/1 [00:00<00:00,  8.32it/s]
100%|██████████| 1/1 [00:00<00:00,  8.86it/s]
100%|██████████| 1/1 [00:00<00:00,  8.13it/s]
100%|██████████| 1/1 [00:00<00:00,  8.18it/s]
100%|██████████| 1/1 [00:00<00:00,  8.04it/s]
100%|██████████| 1/1 [00:00<00:00,  8.67it/s]
100%|██████████| 1/1 [00:00<00:00,  8.64it/s]
100%|██████████| 1/1 [00:00<00:00,  8.50it/s]
100%|██████████| 1/1 [00:00<00:00,  9.07it/s]
100%|██████████| 1/1 [00:00<00:00,  9.71it/s]
100%|██████████| 1/1 [00:00<00:00,  8.30it/s]
100%|██████████| 1/1 [00:00<00:00,  7.79it/s]
100%|██████████| 1/1 [00:00<00:00,  8.14it/s]
100%|██████████| 1/1 [00:00<00:00,  8.57it/s]
100%|██████████| 1/1 [00:00<00:00,  8.06it/s]
100%|██████████| 1/1 [00:00<00:00,  7.83it/s]
100%|██████████| 1/1 [00:00<00:00,  8.98it/s]
100%|██████████| 1/1 [00:00<00:00,  8.73it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 300: Loss: 0.09334244579076767


100%|██████████| 1/1 [00:00<00:00,  7.95it/s]
100%|██████████| 1/1 [00:00<00:00,  8.48it/s]
100%|██████████| 1/1 [00:00<00:00,  8.00it/s]
100%|██████████| 1/1 [00:00<00:00,  8.10it/s]
100%|██████████| 1/1 [00:00<00:00,  7.67it/s]
100%|██████████| 1/1 [00:00<00:00,  9.13it/s]
100%|██████████| 1/1 [00:00<00:00,  7.99it/s]
100%|██████████| 1/1 [00:00<00:00,  8.73it/s]
100%|██████████| 1/1 [00:00<00:00,  8.79it/s]
100%|██████████| 1/1 [00:00<00:00,  8.11it/s]
100%|██████████| 1/1 [00:00<00:00,  7.85it/s]
100%|██████████| 1/1 [00:00<00:00,  9.06it/s]
100%|██████████| 1/1 [00:00<00:00,  7.72it/s]
100%|██████████| 1/1 [00:00<00:00,  9.06it/s]
100%|██████████| 1/1 [00:00<00:00,  8.19it/s]
100%|██████████| 1/1 [00:00<00:00,  8.01it/s]
100%|██████████| 1/1 [00:00<00:00, 10.25it/s]
100%|██████████| 1/1 [00:00<00:00,  8.74it/s]
100%|██████████| 1/1 [00:00<00:00,  8.12it/s]
100%|██████████| 1/1 [00:00<00:00,  8.53it/s]
100%|██████████| 1/1 [00:00<00:00,  9.31it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 400: Loss: 0.0866183266043663


100%|██████████| 1/1 [00:00<00:00,  7.86it/s]
100%|██████████| 1/1 [00:00<00:00,  8.11it/s]
100%|██████████| 1/1 [00:00<00:00,  8.54it/s]
100%|██████████| 1/1 [00:00<00:00,  8.04it/s]
100%|██████████| 1/1 [00:00<00:00,  8.53it/s]
100%|██████████| 1/1 [00:00<00:00,  8.25it/s]
100%|██████████| 1/1 [00:00<00:00,  9.29it/s]
100%|██████████| 1/1 [00:00<00:00,  8.14it/s]
100%|██████████| 1/1 [00:00<00:00,  8.77it/s]
100%|██████████| 1/1 [00:00<00:00,  8.18it/s]
100%|██████████| 1/1 [00:00<00:00,  8.55it/s]
100%|██████████| 1/1 [00:00<00:00,  7.84it/s]
100%|██████████| 1/1 [00:00<00:00,  7.97it/s]
100%|██████████| 1/1 [00:00<00:00,  7.93it/s]
100%|██████████| 1/1 [00:00<00:00,  8.79it/s]
100%|██████████| 1/1 [00:00<00:00,  8.09it/s]
100%|██████████| 1/1 [00:00<00:00,  8.11it/s]
100%|██████████| 1/1 [00:00<00:00,  8.64it/s]
100%|██████████| 1/1 [00:00<00:00,  8.67it/s]
100%|██████████| 1/1 [00:00<00:00,  7.95it/s]
100%|██████████| 1/1 [00:00<00:00,  9.55it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 499: Loss: 0.06362972408533096
Training time: 0:01:00.464127


100%|██████████| 1/1 [00:00<00:00,  8.11it/s]


Epoch 0: Loss: 0.6904229521751404


100%|██████████| 1/1 [00:00<00:00,  9.17it/s]
100%|██████████| 1/1 [00:00<00:00,  7.72it/s]
100%|██████████| 1/1 [00:00<00:00,  8.15it/s]
100%|██████████| 1/1 [00:00<00:00,  8.65it/s]
100%|██████████| 1/1 [00:00<00:00,  8.86it/s]
100%|██████████| 1/1 [00:00<00:00,  8.55it/s]
100%|██████████| 1/1 [00:00<00:00,  8.77it/s]
100%|██████████| 1/1 [00:00<00:00,  8.74it/s]
100%|██████████| 1/1 [00:00<00:00,  8.61it/s]
100%|██████████| 1/1 [00:00<00:00,  8.01it/s]
100%|██████████| 1/1 [00:00<00:00,  8.30it/s]
100%|██████████| 1/1 [00:00<00:00,  8.45it/s]
100%|██████████| 1/1 [00:00<00:00,  9.03it/s]
100%|██████████| 1/1 [00:00<00:00,  9.41it/s]
100%|██████████| 1/1 [00:00<00:00,  9.24it/s]
100%|██████████| 1/1 [00:00<00:00,  9.45it/s]
100%|██████████| 1/1 [00:00<00:00,  8.59it/s]
100%|██████████| 1/1 [00:00<00:00,  7.76it/s]
100%|██████████| 1/1 [00:00<00:00,  8.40it/s]
100%|██████████| 1/1 [00:00<00:00,  8.00it/s]
100%|██████████| 1/1 [00:00<00:00,  8.95it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 100: Loss: 0.6811796426773071


100%|██████████| 1/1 [00:00<00:00,  8.39it/s]
100%|██████████| 1/1 [00:00<00:00,  8.62it/s]
100%|██████████| 1/1 [00:00<00:00,  9.03it/s]
100%|██████████| 1/1 [00:00<00:00,  8.06it/s]
100%|██████████| 1/1 [00:00<00:00,  7.90it/s]
100%|██████████| 1/1 [00:00<00:00,  8.26it/s]
100%|██████████| 1/1 [00:00<00:00,  8.49it/s]
100%|██████████| 1/1 [00:00<00:00,  8.92it/s]
100%|██████████| 1/1 [00:00<00:00,  9.35it/s]
100%|██████████| 1/1 [00:00<00:00,  7.93it/s]
100%|██████████| 1/1 [00:00<00:00,  7.82it/s]
100%|██████████| 1/1 [00:00<00:00,  8.02it/s]
100%|██████████| 1/1 [00:00<00:00,  9.36it/s]
100%|██████████| 1/1 [00:00<00:00,  8.16it/s]
100%|██████████| 1/1 [00:00<00:00,  7.47it/s]
100%|██████████| 1/1 [00:00<00:00,  7.90it/s]
100%|██████████| 1/1 [00:00<00:00,  9.25it/s]
100%|██████████| 1/1 [00:00<00:00,  7.58it/s]
100%|██████████| 1/1 [00:00<00:00,  9.03it/s]
100%|██████████| 1/1 [00:00<00:00,  8.31it/s]
100%|██████████| 1/1 [00:00<00:00,  8.77it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 200: Loss: 0.5986211895942688


100%|██████████| 1/1 [00:00<00:00,  7.50it/s]
100%|██████████| 1/1 [00:00<00:00,  8.24it/s]
100%|██████████| 1/1 [00:00<00:00,  8.31it/s]
100%|██████████| 1/1 [00:00<00:00,  9.42it/s]
100%|██████████| 1/1 [00:00<00:00,  8.02it/s]
100%|██████████| 1/1 [00:00<00:00,  7.87it/s]
100%|██████████| 1/1 [00:00<00:00,  8.08it/s]
100%|██████████| 1/1 [00:00<00:00,  9.19it/s]
100%|██████████| 1/1 [00:00<00:00,  7.99it/s]
100%|██████████| 1/1 [00:00<00:00,  7.85it/s]
100%|██████████| 1/1 [00:00<00:00,  7.86it/s]
100%|██████████| 1/1 [00:00<00:00,  8.99it/s]
100%|██████████| 1/1 [00:00<00:00,  9.54it/s]
100%|██████████| 1/1 [00:00<00:00,  8.50it/s]
100%|██████████| 1/1 [00:00<00:00,  8.02it/s]
100%|██████████| 1/1 [00:00<00:00,  8.15it/s]
100%|██████████| 1/1 [00:00<00:00,  7.89it/s]
100%|██████████| 1/1 [00:00<00:00,  9.11it/s]
100%|██████████| 1/1 [00:00<00:00,  7.94it/s]
100%|██████████| 1/1 [00:00<00:00,  8.31it/s]
100%|██████████| 1/1 [00:00<00:00,  8.74it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 300: Loss: 0.530731737613678


100%|██████████| 1/1 [00:00<00:00,  8.57it/s]
100%|██████████| 1/1 [00:00<00:00,  8.54it/s]
100%|██████████| 1/1 [00:00<00:00,  7.84it/s]
100%|██████████| 1/1 [00:00<00:00,  8.15it/s]
100%|██████████| 1/1 [00:00<00:00,  7.86it/s]
100%|██████████| 1/1 [00:00<00:00,  8.85it/s]
100%|██████████| 1/1 [00:00<00:00,  8.17it/s]
100%|██████████| 1/1 [00:00<00:00,  8.04it/s]
100%|██████████| 1/1 [00:00<00:00,  8.35it/s]
100%|██████████| 1/1 [00:00<00:00,  8.45it/s]
100%|██████████| 1/1 [00:00<00:00,  7.66it/s]
100%|██████████| 1/1 [00:00<00:00,  7.94it/s]
100%|██████████| 1/1 [00:00<00:00,  8.33it/s]
100%|██████████| 1/1 [00:00<00:00,  9.23it/s]
100%|██████████| 1/1 [00:00<00:00,  8.27it/s]
100%|██████████| 1/1 [00:00<00:00,  8.78it/s]
100%|██████████| 1/1 [00:00<00:00,  7.95it/s]
100%|██████████| 1/1 [00:00<00:00,  8.06it/s]
100%|██████████| 1/1 [00:00<00:00,  9.14it/s]
100%|██████████| 1/1 [00:00<00:00,  8.60it/s]
100%|██████████| 1/1 [00:00<00:00,  8.05it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 400: Loss: 0.3925077021121979


100%|██████████| 1/1 [00:00<00:00,  8.24it/s]
100%|██████████| 1/1 [00:00<00:00,  8.02it/s]
100%|██████████| 1/1 [00:00<00:00,  8.70it/s]
100%|██████████| 1/1 [00:00<00:00,  7.94it/s]
100%|██████████| 1/1 [00:00<00:00,  7.92it/s]
100%|██████████| 1/1 [00:00<00:00,  9.12it/s]
100%|██████████| 1/1 [00:00<00:00,  8.00it/s]
100%|██████████| 1/1 [00:00<00:00,  8.12it/s]
100%|██████████| 1/1 [00:00<00:00,  8.87it/s]
100%|██████████| 1/1 [00:00<00:00,  8.54it/s]
100%|██████████| 1/1 [00:00<00:00,  8.19it/s]
100%|██████████| 1/1 [00:00<00:00,  7.78it/s]
100%|██████████| 1/1 [00:00<00:00,  8.09it/s]
100%|██████████| 1/1 [00:00<00:00,  9.26it/s]
100%|██████████| 1/1 [00:00<00:00,  7.66it/s]
100%|██████████| 1/1 [00:00<00:00,  8.22it/s]
100%|██████████| 1/1 [00:00<00:00,  7.80it/s]
100%|██████████| 1/1 [00:00<00:00,  8.22it/s]
100%|██████████| 1/1 [00:00<00:00,  7.84it/s]
100%|██████████| 1/1 [00:00<00:00,  7.88it/s]
100%|██████████| 1/1 [00:00<00:00,  8.04it/s]
100%|██████████| 1/1 [00:00<00:00,

Epoch 499: Loss: 0.4854079782962799
Training time: 0:01:00.510033


100%|██████████| 32/32 [00:00<00:00, 61.23it/s]


Epoch 0: Loss: 0.2982633113861084


100%|██████████| 32/32 [00:00<00:00, 60.81it/s]
100%|██████████| 32/32 [00:00<00:00, 62.85it/s]
100%|██████████| 32/32 [00:00<00:00, 63.67it/s]
100%|██████████| 32/32 [00:00<00:00, 60.97it/s]
100%|██████████| 32/32 [00:00<00:00, 64.35it/s]
100%|██████████| 32/32 [00:00<00:00, 62.30it/s]
100%|██████████| 32/32 [00:00<00:00, 63.37it/s]
100%|██████████| 32/32 [00:00<00:00, 60.94it/s]
100%|██████████| 32/32 [00:00<00:00, 61.41it/s]
100%|██████████| 32/32 [00:00<00:00, 62.24it/s]
100%|██████████| 32/32 [00:00<00:00, 62.20it/s]
100%|██████████| 32/32 [00:00<00:00, 63.70it/s]
100%|██████████| 32/32 [00:00<00:00, 59.97it/s]
100%|██████████| 32/32 [00:00<00:00, 57.78it/s]
100%|██████████| 32/32 [00:00<00:00, 61.08it/s]
100%|██████████| 32/32 [00:00<00:00, 62.58it/s]
100%|██████████| 32/32 [00:00<00:00, 58.63it/s]
100%|██████████| 32/32 [00:00<00:00, 60.59it/s]
100%|██████████| 32/32 [00:00<00:00, 65.45it/s]
100%|██████████| 32/32 [00:00<00:00, 60.81it/s]
100%|██████████| 32/32 [00:00<00:00, 65.

Epoch 100: Loss: 0.24200497567653656


100%|██████████| 32/32 [00:00<00:00, 59.13it/s]
100%|██████████| 32/32 [00:00<00:00, 65.80it/s]
100%|██████████| 32/32 [00:00<00:00, 58.37it/s]
100%|██████████| 32/32 [00:00<00:00, 61.69it/s]
100%|██████████| 32/32 [00:00<00:00, 58.63it/s]
100%|██████████| 32/32 [00:00<00:00, 60.39it/s]
100%|██████████| 32/32 [00:00<00:00, 59.88it/s]
100%|██████████| 32/32 [00:00<00:00, 62.33it/s]
100%|██████████| 32/32 [00:00<00:00, 63.80it/s]
100%|██████████| 32/32 [00:00<00:00, 62.51it/s]
100%|██████████| 32/32 [00:00<00:00, 64.92it/s]
100%|██████████| 32/32 [00:00<00:00, 62.21it/s]
100%|██████████| 32/32 [00:00<00:00, 62.58it/s]
100%|██████████| 32/32 [00:00<00:00, 61.05it/s]
100%|██████████| 32/32 [00:00<00:00, 63.27it/s]
100%|██████████| 32/32 [00:00<00:00, 62.73it/s]
100%|██████████| 32/32 [00:00<00:00, 62.45it/s]
100%|██████████| 32/32 [00:00<00:00, 64.24it/s]
100%|██████████| 32/32 [00:00<00:00, 62.82it/s]
100%|██████████| 32/32 [00:00<00:00, 63.23it/s]
100%|██████████| 32/32 [00:00<00:00, 60.

Epoch 200: Loss: 0.17962899804115295


100%|██████████| 32/32 [00:00<00:00, 64.24it/s]
100%|██████████| 32/32 [00:00<00:00, 55.37it/s]
100%|██████████| 32/32 [00:00<00:00, 58.39it/s]
100%|██████████| 32/32 [00:00<00:00, 60.92it/s]
100%|██████████| 32/32 [00:00<00:00, 57.31it/s]
100%|██████████| 32/32 [00:00<00:00, 61.75it/s]
100%|██████████| 32/32 [00:00<00:00, 62.26it/s]
100%|██████████| 32/32 [00:00<00:00, 66.26it/s]
100%|██████████| 32/32 [00:00<00:00, 61.03it/s]
100%|██████████| 32/32 [00:00<00:00, 58.24it/s]
100%|██████████| 32/32 [00:00<00:00, 58.99it/s]
100%|██████████| 32/32 [00:00<00:00, 64.17it/s]
100%|██████████| 32/32 [00:00<00:00, 60.40it/s]
100%|██████████| 32/32 [00:00<00:00, 57.33it/s]
100%|██████████| 32/32 [00:00<00:00, 59.58it/s]
100%|██████████| 32/32 [00:00<00:00, 61.45it/s]
100%|██████████| 32/32 [00:00<00:00, 59.61it/s]
100%|██████████| 32/32 [00:00<00:00, 63.35it/s]
100%|██████████| 32/32 [00:00<00:00, 61.36it/s]
100%|██████████| 32/32 [00:00<00:00, 58.13it/s]
100%|██████████| 32/32 [00:00<00:00, 52.

Epoch 300: Loss: 0.1937071979045868


100%|██████████| 32/32 [00:00<00:00, 65.03it/s]
100%|██████████| 32/32 [00:00<00:00, 63.50it/s]
100%|██████████| 32/32 [00:00<00:00, 59.66it/s]
100%|██████████| 32/32 [00:00<00:00, 58.49it/s]
100%|██████████| 32/32 [00:00<00:00, 63.79it/s]
100%|██████████| 32/32 [00:00<00:00, 62.17it/s]
100%|██████████| 32/32 [00:00<00:00, 63.45it/s]
100%|██████████| 32/32 [00:00<00:00, 62.24it/s]
100%|██████████| 32/32 [00:00<00:00, 65.18it/s]
100%|██████████| 32/32 [00:00<00:00, 65.19it/s]
100%|██████████| 32/32 [00:00<00:00, 57.46it/s]
100%|██████████| 32/32 [00:00<00:00, 60.04it/s]
100%|██████████| 32/32 [00:00<00:00, 60.29it/s]
100%|██████████| 32/32 [00:00<00:00, 62.51it/s]
100%|██████████| 32/32 [00:00<00:00, 60.52it/s]
100%|██████████| 32/32 [00:00<00:00, 62.04it/s]
100%|██████████| 32/32 [00:00<00:00, 59.01it/s]
100%|██████████| 32/32 [00:00<00:00, 59.11it/s]
100%|██████████| 32/32 [00:00<00:00, 64.03it/s]
100%|██████████| 32/32 [00:00<00:00, 58.89it/s]
100%|██████████| 32/32 [00:00<00:00, 56.

Epoch 400: Loss: 0.19047389924526215


100%|██████████| 32/32 [00:00<00:00, 63.72it/s]
100%|██████████| 32/32 [00:00<00:00, 61.40it/s]
100%|██████████| 32/32 [00:00<00:00, 62.07it/s]
100%|██████████| 32/32 [00:00<00:00, 61.30it/s]
100%|██████████| 32/32 [00:00<00:00, 62.11it/s]
100%|██████████| 32/32 [00:00<00:00, 59.35it/s]
100%|██████████| 32/32 [00:00<00:00, 63.39it/s]
100%|██████████| 32/32 [00:00<00:00, 58.48it/s]
100%|██████████| 32/32 [00:00<00:00, 61.39it/s]
100%|██████████| 32/32 [00:00<00:00, 60.50it/s]
100%|██████████| 32/32 [00:00<00:00, 62.87it/s]
100%|██████████| 32/32 [00:00<00:00, 58.03it/s]
100%|██████████| 32/32 [00:00<00:00, 56.16it/s]
100%|██████████| 32/32 [00:00<00:00, 57.75it/s]
100%|██████████| 32/32 [00:00<00:00, 51.46it/s]
100%|██████████| 32/32 [00:00<00:00, 61.14it/s]
100%|██████████| 32/32 [00:00<00:00, 63.71it/s]
100%|██████████| 32/32 [00:00<00:00, 57.18it/s]
100%|██████████| 32/32 [00:00<00:00, 61.87it/s]
100%|██████████| 32/32 [00:00<00:00, 60.18it/s]
100%|██████████| 32/32 [00:00<00:00, 64.

Epoch 499: Loss: 0.2118770182132721
Training time: 0:04:18.014858





In [14]:
# save model using torch.save() and save it to a directory
directory = 'model'
model_name = 'AAPL_model.pt'
if not os.path.exists(directory):
     os.makedirs(directory)
torch.save(model.state_dict(), os.path.join(directory, model_name))

