## Imports

In [1]:
import os
import sys
sys.path.append('../')
from glob import glob
import torch
import numpy as np
import pandas as pd

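## Load data

Build a `DatasetGenerator` over the monthly stock price files and the SEC data. The train/test split is done in the training section below.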

In [2]:
from spug.dataset import DatasetGenerator

# All data paths are relative to this notebook's location.
data_root = '../data'

# Raw stock table passed to DatasetGenerator as `stock_path`.
stock_path = os.path.join(data_root, 'raw', 'stock', 'raw.csv')

# SEC input directory passed to DatasetGenerator as `sec_path`.
sec_path = os.path.join(data_root, 'raw', 'sec')

# NOTE(review): not used by any cell in this notebook; kept for downstream use.
output_path = os.path.join(data_root, 'processed')

# Price arrays (.npy), one file per series. Files with a 'q' anywhere in their
# name are excluded — presumably quarterly variants, since the generator below
# runs at monthly frequency (TODO confirm the naming convention).
price_dir = os.path.join(data_root, 'raw', 'stock', 'price')
data_list = sorted(
    set(glob(os.path.join(price_dir, '*.npy')))
    - set(glob(os.path.join(price_dir, '*q*.npy')))
)
# Fail loudly on a wrong working directory / missing data instead of silently
# building an empty dataset.
assert data_list, f"no price files found under {price_dir}"

dg = DatasetGenerator(
    data_list=data_list,
    stock_path=stock_path,
    sec_path=sec_path,
    freq='month',
)

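## Model definition

Import the recurrent graph convolutional model. It is instantiated in the training section, once the input feature size is known from the data.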

In [3]:
from spug.model import RecurrentGCN

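## Split data and train the model

Process the dataset and split it chronologically: the first 80% of snapshots are used for training and the rest for testing. Then train the model for 500 epochs on CPU.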

In [4]:
import argparse
from torch_geometric_temporal.signal import temporal_signal_split
from spug.utils import Trainer

# Seed the RNGs before the model is built so weight initialisation and training
# are reproducible under Restart & Run All.
# NOTE(review): if Trainer draws its validation split with Python's `random`
# module, seed that too — TODO confirm.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Fraction of snapshots, taken from the start of the series, used for training.
# The remaining snapshots form the test set.
TRAIN_RATIO = 0.8
HIDDEN_DIMS = 64

dataset = dg.process()
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=TRAIN_RATIO)

# Number of columns of the first training snapshot's feature matrix `x`;
# assumes x is (num_nodes, num_features) — TODO confirm.
INPUT_SHAPE = next(iter(train_dataset)).x.shape[1]
model = RecurrentGCN(input_size=INPUT_SHAPE, hidden_dims=HIDDEN_DIMS)

# Training hyperparameters, passed to spug.utils.Trainer in the next cell.
args = argparse.Namespace(
    num_epochs=500,
    learning_rate=1e-3,
    device="cpu",
    val_size=0.1,  # presumably the fraction of training data held out for validation — TODO confirm
    verbose=False,
)

In [5]:
# Bundle the model, the training snapshots, the hyperparameters in `args` and
# the held-out test snapshots into a Trainer.
# NOTE(review): arguments are positional, assumed to be (model, train, args, test)
# per spug.utils.Trainer's signature — confirm; keyword arguments would be safer.
trainer = Trainer(model, train_dataset, args, test_dataset)

In [6]:
# Run the training loop configured by `args` (num_epochs, learning_rate, ...).
# The return value rebinds `model`. Presumably it is the best checkpoint by
# validation loss, since the log reports a "best model" epoch — TODO confirm.
# NOTE(review): in the recorded output the first report (ROC AUC ~0.98) is far
# above the second (~0.61). If those are validation and test respectively, the
# gap suggests overfitting or a validation split that leaks across time.
model = trainer.train()

100%|███████████████████████████████████████████████████████████████████████████| 500/500 [11:18<00:00,  1.36s/it]



            best model loss is:
                val loss: 0.3272077739238739 @ epoch: 431
            
ROC AUC score 0.9787234042553192
              precision    recall  f1-score   support

           0       1.00      0.96      0.98        47
           1       0.98      1.00      0.99        98

    accuracy                           0.99       145
   macro avg       0.99      0.98      0.98       145
weighted avg       0.99      0.99      0.99       145

ROC AUC score 0.6058558558558558
              precision    recall  f1-score   support

           0       0.50      0.50      0.50       126
           1       0.71      0.71      0.71       222

    accuracy                           0.64       348
   macro avg       0.61      0.61      0.61       348
weighted avg       0.64      0.64      0.64       348

