In [None]:
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import os

import warnings
import logging
logging.disable(logging.CRITICAL)
from tqdm.autonotebook import tqdm

import torch
from torch.nn import MSELoss, CrossEntropyLoss
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
import pytorch_lightning as pl

import optuna
from optuna.integration import PyTorchLightningPruningCallback

from darts import TimeSeries, concatenate
from darts.models import TFTModel
from darts.dataprocessing.transformers import Scaler
from darts.metrics import smape, rmse, r2_score

from sklearn.model_selection import train_test_split

In [None]:
# Register jupyternotify's magics on the running IPython shell.
# The training cell further down starts with `%%notify`, which is presumably
# provided by this registration — confirm against the jupyternotify docs.
import jupyternotify
ip = get_ipython()
ip.register_magics(jupyternotify.JupyterNotifyMagics)

In [None]:
# Report whether PyTorch sees an MPS backend on this machine.
mps_available = torch.backends.mps.is_available()
print(mps_available)

In [None]:
# Execute the shared helper notebooks before anything below runs.
# NOTE(review): names used later but not defined in this notebook
# (minute_frequencies_conventions, load_stock_data, separate, split_data,
# transform_splits_to_time_series, transform_to_time_series, scale_splits_data,
# scale_full_data, LossLogger, plot_loss) presumably come from these two
# notebooks — confirm.
%run ../utils/preprocessing.ipynb
%run ../utils/losses.ipynb

In [None]:
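# Disabled: automatic MPS accelerator selection. The accelerator is currently
# fixed to "cpu" in the next cell; the original snippet is kept for reference.
#
# mps_device = torch.device("mps")

# if torch.backends.mps.is_available():
#     mps_device = torch.device("mps")
#     accelerator="mps"
# else:
#     print("MPS device not found.")
#     accelerator="cpu"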

In [None]:
# Accelerator name handed to the Lightning trainer via pl_trainer_kwargs.
accelerator = "cpu"

In [None]:
# Full ticker universe; the order matters because the selection below is positional.
stock_tickers_all = [
    'QCOM', 'NVDA', 'AMZN', 'MSFT', 'GOOG',
    'TSLA', 'AMD', 'INTC', 'NFLX', 'BAC',
    'WFC', 'GS', 'MA', 'SQ', 'PYPL',
]

# Tickers to train in this run: everything from index 11 onwards
# (currently GS, MA, SQ, PYPL).
selected_tickers = stock_tickers_all[11:]
selected_tickers

In [None]:
#LOAD = False         # True = load previously saved model from disk?  False = (re)train the model
#SAVE = "\_TFT_model_01.pth.tar"   # file name to save the model under

# --- Data frequency ---------------------------------------------------------
# Bar size key; it selects the frequency label used in data paths and the
# frequency-dependent hyperparameters below.
FREQ_INT = 5
DATA_FREQUENCY = minute_frequencies_conventions[FREQ_INT]

# --- Training schedule ------------------------------------------------------
EPOCHS = 100
BATCH = 64
LR = 1e-3
VALWAIT = 1  # epochs to wait before evaluating the loss on the validation set

# --- Model size -------------------------------------------------------------
OUTPUT_LEN = 1
# Currently identical for every supported frequency.
HIDDEN = 32
HIDDEN_CONTINUOUS_SIZE = 12

# Frequency-dependent settings: one set for FREQ_INT == 15, another otherwise.
if FREQ_INT == 15:
    INPUT_LEN, LSTMLAYERS, ATTH, DROPOUT = 70, 3, 2, 0.1
else:
    INPUT_LEN, LSTMLAYERS, ATTH, DROPOUT = 270, 2, 6, 0.05

# --- Objective, reproducibility, plotting -----------------------------------
LOSS = MSELoss()
RAND_SEED = 42
FIGSIZE = (9, 6)

In [None]:
%%notify
for ticker in selected_tickers:
    # Train one TFT model per selected ticker, keep the best checkpoint by
    # validation loss, plot the loss curves and save the final model to disk.
    print('training: ', ticker)

    MODEL_NAME = f'{ticker}_{FREQ_INT}_TFT'
    # All artifacts for this ticker/frequency are written under this directory.
    model_dir = f'./saved_models/{FREQ_INT}/{MODEL_NAME}'

    # Fresh callback instances per ticker so no state carries over between runs.
    early_stop_callback = EarlyStopping(
        monitor="val_loss", min_delta=1e-5, patience=10, verbose=True, mode="min"
    )
    # NOTE(review): lr_logger is created but not added to `callbacks` below —
    # confirm whether that is intentional.
    lr_logger = LearningRateMonitor(logging_interval='step')
    loss_logger = LossLogger()

    checkpoint_callback = ModelCheckpoint(
        dirpath=f'{model_dir}/checkpoints/',
        filename='best-{epoch:02d}-{val_loss:.7f}',
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        every_n_epochs=1
    )

    callbacks = [early_stop_callback, loss_logger, checkpoint_callback]

    model = TFTModel(
        model_name=MODEL_NAME,
        input_chunk_length=INPUT_LEN,
        output_chunk_length=OUTPUT_LEN,
        hidden_size=HIDDEN,
        lstm_layers=LSTMLAYERS,
        num_attention_heads=ATTH,
        hidden_continuous_size=HIDDEN_CONTINUOUS_SIZE,
        batch_size=BATCH,
        dropout=DROPOUT,
        n_epochs=EPOCHS,
        nr_epochs_val_period=VALWAIT,
        optimizer_kwargs={"lr": LR},
        loss_fn=LOSS,
        #log_tensorboard=True,
        random_state=RAND_SEED,
        pl_trainer_kwargs={
            'accelerator': accelerator,
            'devices': 1,
            'callbacks': callbacks,
            'log_every_n_steps': 10,
            # Must be a real bool: the previous string 'true' only worked
            # because any non-empty string is truthy.
            'enable_checkpointing': True
        },
        force_reset=True,
        add_relative_index=True,
        #save_checkpoints=True
    )

    # Load the resampled data for this ticker and run it through the helper
    # pipeline: separate -> split -> TimeSeries conversion -> scaling.
    stock = load_stock_data(f'../data/resampled_data/{DATA_FREQUENCY}/{ticker}_resampled_{DATA_FREQUENCY}.csv', FREQ_INT)
    X_y_df = separate(stock)
    splits = split_data(**X_y_df)
    ts_splits = transform_splits_to_time_series(**splits)
    ts_full = transform_to_time_series(**X_y_df)
    scaled_splits_data = scale_splits_data(**ts_splits)
    # The full series is scaled with the scalers returned by the split scaling.
    scaled_full_data = scale_full_data(
        ts_full['ts_X_full'],
        ts_full['ts_y_full'],
        scaled_splits_data['scaler_X'],
        scaled_splits_data['scaler_y']
    )

    # Bundle every intermediate artifact for this ticker in one place.
    stock_full = {
        "ticker": ticker,
        "stock": stock,
        "splits": splits,
        "ts_splits": ts_splits,
        "ts_full": ts_full,
        "scaled_splits_data": scaled_splits_data,
        "scaled_full_data": scaled_full_data
    }

    raw_stock = stock_full['stock']

    # Target series is y; the X features are supplied as past covariates.
    history = model.fit(
        series              =    stock_full['scaled_splits_data']['scaled_y_train'],
        past_covariates     =    stock_full['scaled_splits_data']['scaled_X_train'],
        val_series          =    stock_full['scaled_splits_data']['scaled_y_val'],
        val_past_covariates =    stock_full['scaled_splits_data']['scaled_X_val'],
        num_loader_workers  =    2,
        verbose             =    True
    )

    plot_loss(loss_logger)

    print('saving the model ', MODEL_NAME)

    # Make sure the target directory exists even if no checkpoint was written
    # during training, so the save below does not depend on the callback.
    os.makedirs(model_dir, exist_ok=True)
    model.save(f'{model_dir}/_model.pth.tar')