In [1]:
# this notebook uses a basic GPT-based decision transformer in an offline reinforcement learning setting to create a bot for trading stocks
# get cuda device
# import libraries
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from cust_transf import DecisionTransformer

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # GPU when PyTorch sees one, otherwise CPU

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from datasets.load import load_dataset
from torch.utils.data import Dataset, DataLoader
import numpy as np

# utility function to compute the discounted cumulative sum of a vector
def discount_cumsum(x, gamma):
    disc_cumsum = np.zeros_like(x)
    disc_cumsum[-1] = x[-1]
    for t in reversed(range(x.shape[0]-1)):
        disc_cumsum[t] = x[t] + gamma * disc_cumsum[t+1]
    return disc_cumsum

# define a custom dataset class which loads the data, modifies the reward to be the discounted cumulative sum and apply trajectory masking
class CustomTrajDataset(Dataset):
    def __init__(self, file_name, context_len, gamma, rtg_scale):
        self.gamma = gamma
        self.context_len = context_len

        # load the data
        data = load_dataset("json", data_files = file_name, field = 'data')
        self.data_state = np.array(data['train']['state'], dtype=np.float32)
        self.data_action = np.array(data['train']['action'], dtype=np.float32)
        self.rtg = np.array(data['train']['reward'], dtype=np.float32)


        # calculate min len, the mean and std of the state and rtg for all data
        self.stateshape = self.data_state.shape
        # calculate mean of state and rtg with numpy
        self.state_mean = np.mean(self.data_state, axis=(-2,-1), keepdims=True)
        self.state_std = np.std(np.abs(self.data_state), axis=(-2,-1), keepdims=True)
        #self.state_mean = torch.mean(data['train']['state'], dim=(-2,-1), keepdim=True)
        #self.state_std = torch.std(data['train']['state'], dim=(-2,-1), keepdim=True)
        self.norm_state = (self.data_state - self.state_mean) / self.state_std

        self.rtg = np.apply_along_axis(discount_cumsum, 1, data['train']['reward'], self.gamma) # type: ignore
        self.rtg = self.rtg / rtg_scale

    def get_state_stats(self):
        return self.state_mean, self.state_std        

    def __len__(self):
        return self.stateshape[0]

    def __getitem__(self, idx):
        state = self.norm_state[idx]
        action = self.data_action[idx]
        rtg = self.rtg[idx]

        data_len = state.shape[0]
        
        if data_len > self.context_len:
            # sample random start index
            start_idx = np.random.randint(0, data_len - self.context_len)
            # slice the data and convert to torch
            state = torch.from_numpy(state[start_idx:start_idx+self.context_len])
            action = torch.from_numpy(action[start_idx:start_idx+self.context_len])
            rtg = torch.from_numpy(rtg[start_idx:start_idx+self.context_len])
            timesteps = torch.arange(start=start_idx, end=start_idx + self.context_len, step=1)
            # trajectory mask
            mask = torch.ones(self.context_len, dtype=torch.long)
        else:
            padding_len = self.context_len - data_len

            # pad the data with zeros
            state = torch.from_numpy(state)
            state = torch.cat([state, torch.zeros((padding_len, *state.shape[1:]))], dim=0)

            action = torch.from_numpy(action)
            action = torch.cat([action, torch.zeros((padding_len, *action.shape[1:]))], dim=0)

            rtg = torch.from_numpy(rtg)
            rtg = torch.cat([rtg, torch.zeros((padding_len, *rtg.shape[1:]))], dim=0)

            timesteps = torch.arange(start=0, end=self.context_len, step=1)

            # trajectory mask
            mask = torch.cat([torch.ones(data_len, dtype=torch.long), torch.zeros(padding_len, dtype=torch.long)], dim=0)
        
        return state, action, rtg, timesteps, mask


In [3]:
# load huggingface dataset from json file in replaybuffer folder
foldername = 'replaybuffer'

# get filenames in folder
import os
filenames = os.listdir(foldername)

# get full path of files
full_filenames = [os.path.join(foldername, filename) for filename in filenames]

# create datasets and store in list from the list of filenames 
context_len = 30
Max_balance = 2147483647
gamma = 0.99

datasets = []
for name in full_filenames:
    dataset = CustomTrajDataset(name, context_len, gamma, Max_balance)
    datasets.append(dataset)

# concatenate all datasets
combined_dataset = torch.utils.data.ConcatDataset(datasets)

Found cached dataset json (/home/victoru/.cache/huggingface/datasets/json/default-9b529c6ea512e6e3/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|██████████| 1/1 [00:00<00:00, 95.75it/s]
Found cached dataset json (/home/victoru/.cache/huggingface/datasets/json/default-0bf8fa05208b6986/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
100%|██████████| 1/1 [00:00<00:00, 223.32it/s]


Downloading and preparing dataset json/default to /home/victoru/.cache/huggingface/datasets/json/default-3f4820e3e23aa5e3/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 4253.86it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 51.01it/s]
                                                               

Dataset json downloaded and prepared to /home/victoru/.cache/huggingface/datasets/json/default-3f4820e3e23aa5e3/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 1036.91it/s]


Downloading and preparing dataset json/default to /home/victoru/.cache/huggingface/datasets/json/default-5eb96d451a6d6f10/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 4877.10it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 50.17it/s]
                                                               

Dataset json downloaded and prepared to /home/victoru/.cache/huggingface/datasets/json/default-5eb96d451a6d6f10/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 1074.09it/s]


Downloading and preparing dataset json/default to /home/victoru/.cache/huggingface/datasets/json/default-4d259c4d25298c05/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 4219.62it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 128.04it/s]
                                                               

Dataset json downloaded and prepared to /home/victoru/.cache/huggingface/datasets/json/default-4d259c4d25298c05/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 1070.52it/s]


Downloading and preparing dataset json/default to /home/victoru/.cache/huggingface/datasets/json/default-777faf1abc5243b4/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 5629.94it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 43.53it/s]
                                                               

Dataset json downloaded and prepared to /home/victoru/.cache/huggingface/datasets/json/default-777faf1abc5243b4/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 1079.06it/s]


Downloading and preparing dataset json/default to /home/victoru/.cache/huggingface/datasets/json/default-9046201c9bb75691/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 4782.56it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 52.31it/s]
                                                               

Dataset json downloaded and prepared to /home/victoru/.cache/huggingface/datasets/json/default-9046201c9bb75691/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 1020.02it/s]


Downloading and preparing dataset json/default to /home/victoru/.cache/huggingface/datasets/json/default-01130f1bcc3f6c1d/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 12409.18it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 1080.45it/s]
                                                               

Dataset json downloaded and prepared to /home/victoru/.cache/huggingface/datasets/json/default-01130f1bcc3f6c1d/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 1056.50it/s]


Downloading and preparing dataset json/default to /home/victoru/.cache/huggingface/datasets/json/default-73cf32995c39088a/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 5102.56it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 31.07it/s]
                                                               

Dataset json downloaded and prepared to /home/victoru/.cache/huggingface/datasets/json/default-73cf32995c39088a/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 1062.66it/s]


Downloading and preparing dataset json/default to /home/victoru/.cache/huggingface/datasets/json/default-58e8dd303af5c481/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 5841.65it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 85.20it/s]
                                                               

Dataset json downloaded and prepared to /home/victoru/.cache/huggingface/datasets/json/default-58e8dd303af5c481/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 914.19it/s]


Downloading and preparing dataset json/default to /home/victoru/.cache/huggingface/datasets/json/default-5db6b286d56f7818/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 5793.24it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 1083.80it/s]
                                                               

Dataset json downloaded and prepared to /home/victoru/.cache/huggingface/datasets/json/default-5db6b286d56f7818/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 972.93it/s]


Downloading and preparing dataset json/default to /home/victoru/.cache/huggingface/datasets/json/default-48256180deb3fa0c/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 6061.13it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 92.30it/s]
                                                               

Dataset json downloaded and prepared to /home/victoru/.cache/huggingface/datasets/json/default-48256180deb3fa0c/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 1047.01it/s]


Downloading and preparing dataset json/default to /home/victoru/.cache/huggingface/datasets/json/default-cb7b69133e4279f9/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 8355.19it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 136.31it/s]
                                                               

Dataset json downloaded and prepared to /home/victoru/.cache/huggingface/datasets/json/default-cb7b69133e4279f9/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 1084.64it/s]


Downloading and preparing dataset json/default to /home/victoru/.cache/huggingface/datasets/json/default-d7855d2ea3173df3/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4...


Downloading data files: 100%|██████████| 1/1 [00:00<00:00, 2551.28it/s]
Extracting data files: 100%|██████████| 1/1 [00:00<00:00, 68.08it/s]
                                                               

Dataset json downloaded and prepared to /home/victoru/.cache/huggingface/datasets/json/default-d7855d2ea3173df3/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4. Subsequent calls will reuse this data.


100%|██████████| 1/1 [00:00<00:00, 1017.54it/s]


In [6]:
# loop through the dataset and find the highest and lowest reward
max_reward = -math.inf
min_reward = math.inf
for dataset in datasets:
    max_reward = max(max_reward, dataset.rtg.max())
    min_reward = min(min_reward, dataset.rtg.min())

print("max reward: ", max_reward)
print("min reward: ", min_reward)

max reward:  0.002264376001618792
min reward:  -9.086438618410393e-06


In [9]:
# define training parameters
batch_size = 32
# small learning rate to try to avoid mixed precision caused NaNs
lr = 3e-5
wt_decay = 1e-4
warmup_steps = 10000
n_epochs = 500

In [10]:
# create dataloader from the concatenated dataset
dataloader = DataLoader(combined_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

In [11]:
# define model parameters
# sample 1 batch from dataloader
norm_state, actions, rtg, timestep, traj_mask = next(iter(dataloader))
# use batch shape to determine state dimension
state_dim = norm_state.shape[-1]
act_dim = actions.shape[-1] # discrete action space
# use batch shape to determine context length


n_blocks = 4 # number of transformer blocks
h_dim = 96 # hidden dimension
n_heads = 8 # number of heads in multi-head attention
drop_p = 0.1 # dropout probability


In [12]:
# create the model
model = DecisionTransformer(state_dim, act_dim, n_blocks, h_dim, context_len, n_heads, drop_p).to(device)

# create optimizer
# use larger eps to try to avoid mixed precision overflow caused NaNs
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wt_decay, eps=1e-6)

# create scheduler
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / warmup_steps))

# create a GradScaler for mixed precision training
scaler = torch.cuda.amp.GradScaler(growth_interval=150)
min_scale = 128

In [13]:
# get the model parameters size
n_params = sum([p.numel() for p in model.parameters() if p.requires_grad])
print(f"Number of parameters: {n_params}")

Number of parameters: 844144


In [8]:
# test run the model
with torch.no_grad():
    norm_state, actions, rtg, timestep , traj_mask= next(iter(dataloader))
    norm_state = norm_state.to(device)
    actions = actions.to(device)
    # convert rtg to float
    rtg = rtg.to(device).float()
    timestep = timestep.to(device)
    traj_mask = traj_mask.to(device)
    action_targets = torch.clone(actions).detach().to(device)
    return_preds, state_preds, act_preds = model.forward(norm_state, rtg, timestep, actions)

    # check shape of norm_state
    print(f"shape norm_state: {norm_state.shape}")
    # check shape of rtg
    print(f"shape rtg: {rtg.shape}")
    # check shape of timestep
    print(f"shape timestep: {timestep.shape}")
    # check shape of actions
    print(f"shape actions: {actions.shape}")
    print(f"shape act_preds: {act_preds.shape}")
    print(f"shape action_targets: {action_targets.shape}")
    
    # consider only the action that are padded
    act_preds = act_preds.view(-1, act_dim)[traj_mask.view(-1) > 0]
    action_targets = action_targets.view(-1, act_dim)[traj_mask.view(-1) > 0]

    # check shape of action targets
    print(action_targets.shape)
    # check shape of action predictions
    print(act_preds.shape)

# check for nan values and inf values in the input and the output of the model
print(torch.isnan(norm_state).any())
print(torch.isnan(rtg).any())
print(torch.isnan(timestep).any())
print(torch.isnan(actions).any())
print(torch.isnan(act_preds).any())
print(torch.isnan(action_targets).any())



shape norm_state: torch.Size([32, 20, 13])
shape rtg: torch.Size([32, 20, 1])
shape timestep: torch.Size([32, 20])
shape actions: torch.Size([32, 20, 2])
shape act_preds: torch.Size([32, 20, 2])
shape action_targets: torch.Size([32, 20, 2])
torch.Size([640, 2])
torch.Size([640, 2])
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')
tensor(False, device='cuda:0')


In [14]:
from tqdm import tqdm

# get the start time to calculate training time
import datetime

# custom training function which take in the model, dataset, optimizer, scheduler, scaler, n_epochs, min_scale
def train_model(model, dataloader, optimizer, scheduler, scaler, n_epochs, min_scale):

    # record the start time
    start_time = datetime.datetime.now()

    # define training parameters
    log_action_losses = []

    # train model
    for epoch in range(n_epochs):
        model.train()
        

        for norm_state, actions, rtg, timestep, traj_mask in tqdm(dataloader):
            # get batch data to device
            norm_state = norm_state.to(device)
            actions = actions.to(device)
            rtg = rtg.to(device).float()
            timestep = timestep.to(device)
            traj_mask = traj_mask.to(device)

            action_targets = torch.clone(actions).detach().to(device)

            # Zeroes out the gradients
            optimizer.zero_grad()

            # run forward pass with autocasting
            # disable autocasting for now to avoid mixed precision caused NaNs
            with torch.cuda.amp.autocast(enabled=False):
                _, _, act_preds = model.forward(norm_state, rtg, timestep, actions)

                # consider only the action that are padded
                act_preds = act_preds.view(-1, act_dim)[traj_mask.view(-1) > 0]
                action_targets = action_targets.view(-1, act_dim)[traj_mask.view(-1) > 0]

                # calculate losses just for actions
                loss = F.mse_loss(act_preds, action_targets, reduction='mean')

            # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
            scaler.scale(loss).backward()

            # unscale the gradients
            scaler.unscale_(optimizer)
            # Clips the gradients by norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)

            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(optimizer)

            # Updates the learning rate according to the scheduler
            scheduler.step()
            # Updates the scale for next iteration.
            scaler.update()
            # enforce min scale to avoid mixed precision caused NaNs
            if scaler.get_scale() < min_scale:
                scaler._scale = torch.tensor(min_scale).to(scaler._scale)
        
            # append action loss to log
            log_action_losses.append(loss.detach().cpu().item())

        # print every 10 loss log
        if epoch % 100 == 0 or epoch == n_epochs - 1:
            print(f'Epoch {epoch}: Loss: {log_action_losses[-1]}')

    # record the end time
    end_time = datetime.datetime.now()
    print(f'Training time: {end_time - start_time}')
    
    return model, log_action_losses

In [15]:
# train model on each dataloader and store log_action_losses in a list
_, log_action_losses = train_model(model, dataloader, optimizer, scheduler, scaler, n_epochs, min_scale)


100%|██████████| 122/122 [00:02<00:00, 43.07it/s]


Epoch 0: Loss: 0.9919530749320984


100%|██████████| 122/122 [00:01<00:00, 68.32it/s]
100%|██████████| 122/122 [00:01<00:00, 69.30it/s]
100%|██████████| 122/122 [00:01<00:00, 70.27it/s]
100%|██████████| 122/122 [00:01<00:00, 66.88it/s]
100%|██████████| 122/122 [00:01<00:00, 69.68it/s]
100%|██████████| 122/122 [00:01<00:00, 69.71it/s]
100%|██████████| 122/122 [00:01<00:00, 70.85it/s]
100%|██████████| 122/122 [00:01<00:00, 70.79it/s]
100%|██████████| 122/122 [00:01<00:00, 69.13it/s]
100%|██████████| 122/122 [00:01<00:00, 69.14it/s]
100%|██████████| 122/122 [00:01<00:00, 70.69it/s]
100%|██████████| 122/122 [00:01<00:00, 70.22it/s]
100%|██████████| 122/122 [00:01<00:00, 71.75it/s]
100%|██████████| 122/122 [00:01<00:00, 68.83it/s]
100%|██████████| 122/122 [00:01<00:00, 69.46it/s]
100%|██████████| 122/122 [00:01<00:00, 70.32it/s]
100%|██████████| 122/122 [00:01<00:00, 70.56it/s]
100%|██████████| 122/122 [00:01<00:00, 70.33it/s]
100%|██████████| 122/122 [00:01<00:00, 70.81it/s]
100%|██████████| 122/122 [00:01<00:00, 69.43it/s]


Epoch 100: Loss: 0.16668927669525146


100%|██████████| 122/122 [00:01<00:00, 66.28it/s]
100%|██████████| 122/122 [00:01<00:00, 68.53it/s]
100%|██████████| 122/122 [00:01<00:00, 67.66it/s]
100%|██████████| 122/122 [00:01<00:00, 68.84it/s]
100%|██████████| 122/122 [00:01<00:00, 68.73it/s]
100%|██████████| 122/122 [00:01<00:00, 68.65it/s]
100%|██████████| 122/122 [00:01<00:00, 71.10it/s]
100%|██████████| 122/122 [00:01<00:00, 69.07it/s]
100%|██████████| 122/122 [00:01<00:00, 68.68it/s]
100%|██████████| 122/122 [00:01<00:00, 68.58it/s]
100%|██████████| 122/122 [00:01<00:00, 68.76it/s]
100%|██████████| 122/122 [00:01<00:00, 69.26it/s]
100%|██████████| 122/122 [00:01<00:00, 68.40it/s]
100%|██████████| 122/122 [00:01<00:00, 68.45it/s]
100%|██████████| 122/122 [00:01<00:00, 69.45it/s]
100%|██████████| 122/122 [00:01<00:00, 69.32it/s]
100%|██████████| 122/122 [00:01<00:00, 69.77it/s]
100%|██████████| 122/122 [00:01<00:00, 68.51it/s]
100%|██████████| 122/122 [00:01<00:00, 67.51it/s]
100%|██████████| 122/122 [00:01<00:00, 67.93it/s]


Epoch 200: Loss: 0.23529082536697388


100%|██████████| 122/122 [00:01<00:00, 83.20it/s]
100%|██████████| 122/122 [00:01<00:00, 85.30it/s]
100%|██████████| 122/122 [00:01<00:00, 84.52it/s]
100%|██████████| 122/122 [00:01<00:00, 85.06it/s]
100%|██████████| 122/122 [00:01<00:00, 84.81it/s]
100%|██████████| 122/122 [00:01<00:00, 84.15it/s]
100%|██████████| 122/122 [00:01<00:00, 82.87it/s]
100%|██████████| 122/122 [00:01<00:00, 85.43it/s]
100%|██████████| 122/122 [00:01<00:00, 85.96it/s]
100%|██████████| 122/122 [00:01<00:00, 87.99it/s]
100%|██████████| 122/122 [00:01<00:00, 85.80it/s]
100%|██████████| 122/122 [00:01<00:00, 87.66it/s]
100%|██████████| 122/122 [00:01<00:00, 87.29it/s]
100%|██████████| 122/122 [00:01<00:00, 84.93it/s]
100%|██████████| 122/122 [00:01<00:00, 84.55it/s]
100%|██████████| 122/122 [00:01<00:00, 86.39it/s]
100%|██████████| 122/122 [00:01<00:00, 87.07it/s]
100%|██████████| 122/122 [00:01<00:00, 86.03it/s]
100%|██████████| 122/122 [00:01<00:00, 85.89it/s]
100%|██████████| 122/122 [00:01<00:00, 85.88it/s]


Epoch 300: Loss: 0.19295398890972137


100%|██████████| 122/122 [00:01<00:00, 83.57it/s]
100%|██████████| 122/122 [00:01<00:00, 84.12it/s]
100%|██████████| 122/122 [00:01<00:00, 83.94it/s]
100%|██████████| 122/122 [00:01<00:00, 84.06it/s]
100%|██████████| 122/122 [00:01<00:00, 84.59it/s]
100%|██████████| 122/122 [00:01<00:00, 84.25it/s]
100%|██████████| 122/122 [00:01<00:00, 84.12it/s]
100%|██████████| 122/122 [00:01<00:00, 85.81it/s]
100%|██████████| 122/122 [00:01<00:00, 87.16it/s]
100%|██████████| 122/122 [00:01<00:00, 86.96it/s]
100%|██████████| 122/122 [00:01<00:00, 87.81it/s]
100%|██████████| 122/122 [00:01<00:00, 87.97it/s]
100%|██████████| 122/122 [00:01<00:00, 84.52it/s]
100%|██████████| 122/122 [00:01<00:00, 87.26it/s]
100%|██████████| 122/122 [00:01<00:00, 87.14it/s]
100%|██████████| 122/122 [00:01<00:00, 86.64it/s]
100%|██████████| 122/122 [00:01<00:00, 87.39it/s]
100%|██████████| 122/122 [00:01<00:00, 86.19it/s]
100%|██████████| 122/122 [00:01<00:00, 88.16it/s]
100%|██████████| 122/122 [00:01<00:00, 86.87it/s]


Epoch 400: Loss: 0.22467394173145294


100%|██████████| 122/122 [00:01<00:00, 83.02it/s]
100%|██████████| 122/122 [00:01<00:00, 84.52it/s]
100%|██████████| 122/122 [00:01<00:00, 85.01it/s]
100%|██████████| 122/122 [00:01<00:00, 85.15it/s]
100%|██████████| 122/122 [00:01<00:00, 82.58it/s]
100%|██████████| 122/122 [00:01<00:00, 84.84it/s]
100%|██████████| 122/122 [00:01<00:00, 82.57it/s]
100%|██████████| 122/122 [00:01<00:00, 86.89it/s]
100%|██████████| 122/122 [00:01<00:00, 87.29it/s]
100%|██████████| 122/122 [00:01<00:00, 83.60it/s]
100%|██████████| 122/122 [00:01<00:00, 88.70it/s]
100%|██████████| 122/122 [00:01<00:00, 87.11it/s]
100%|██████████| 122/122 [00:01<00:00, 88.02it/s]
100%|██████████| 122/122 [00:01<00:00, 84.40it/s]
100%|██████████| 122/122 [00:01<00:00, 85.94it/s]
100%|██████████| 122/122 [00:01<00:00, 88.63it/s]
100%|██████████| 122/122 [00:01<00:00, 85.86it/s]
100%|██████████| 122/122 [00:01<00:00, 88.25it/s]
100%|██████████| 122/122 [00:01<00:00, 85.76it/s]
100%|██████████| 122/122 [00:01<00:00, 87.56it/s]


Epoch 499: Loss: 0.23923668265342712
Training time: 0:12:41.472033





In [16]:
# save model using torch.save() and save it to a directory
directory = 'model'
model_name = 'AAPL_model.pt'
if not os.path.exists(directory):
     os.makedirs(directory)
torch.save(model.state_dict(), os.path.join(directory, model_name))

