## imports

In [1]:
import os
import sys
sys.path.append('../')
from glob import glob
import torch
import numpy as np
import pandas as pd

## Load data and build the dataset generator

In [2]:
from spug.dataset import DatasetGenerator

# Input locations for DatasetGenerator. `output_path` is defined for the
# processed-data directory but is not used anywhere else in this notebook.
data_root = '../data'

stock_path = os.path.join(data_root, 'raw', 'stock', 'raw.csv')
sec_path = os.path.join(data_root, 'raw', 'sec')
output_path = os.path.join(data_root, 'processed')

# Price arrays: every `*.npy` file in raw/stock/price except those whose file
# name contains a 'q' (presumably quarterly files — confirm), sorted so the
# input order is deterministic across machines and runs.
price_dir = os.path.join(data_root, 'raw', 'stock', 'price')
data_list = sorted(
    path
    for path in glob(os.path.join(price_dir, '*.npy'))
    if 'q' not in os.path.basename(path)
)

dg = DatasetGenerator(
    data_list=data_list,
    stock_path=stock_path,
    sec_path=sec_path,
    freq='month',
)

## model definition

In [3]:
from spug.model import GCN

## Train/test split and model training

In [4]:
import argparse
from torch_geometric_temporal.signal import temporal_signal_split
from spug.utils import Trainer

# Run configuration — the values a reader is most likely to tune.
SEED = 42
TRAIN_RATIO = 0.8
HIDDEN_DIMS = 64

# Seed the RNGs already imported in this notebook so that model weight
# initialisation (and any random validation split Trainer may perform —
# unverified) is reproducible under Restart Kernel -> Run All.
torch.manual_seed(SEED)
np.random.seed(SEED)

# `temporal_signal_split` keeps snapshot order: the first TRAIN_RATIO of the
# snapshots form the training set, and the remainder forms the test set.
dataset = dg.process()
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=TRAIN_RATIO)

# Feature dimension read from the first training snapshot's `x`
# (presumably shaped [num_nodes, num_features] — confirm in DatasetGenerator).
INPUT_SHAPE = next(iter(train_dataset)).x.shape[1]
model = GCN(input_size=INPUT_SHAPE, hidden_dims=HIDDEN_DIMS)

args = argparse.Namespace(
    num_epochs=500,
    learning_rate=1e-3,
    device="cpu",
    val_size=0.1,  # validation fraction; presumably carved from train_dataset — confirm in Trainer
    verbose=False,
)

In [5]:
# Hand the model, both data splits and the run configuration to the project's
# Trainer (positional order: model, train split, args, test split).
trainer = Trainer(
    model,
    train_dataset,
    args,
    test_dataset,
)

In [6]:
# Train the model. The return value rebinds `model`; judging by the printed
# "best model loss is: val loss ... @ epoch" message, it is presumably the
# lowest-validation-loss checkpoint — confirm in spug.utils.Trainer.train.
model = trainer.train()

100%|███████████████████████████████████████████████████████████████████████████| 500/500 [00:26<00:00, 19.00it/s]


            best model loss is:
                val loss: 0.4567108690738678 @ epoch: 401
            
ROC AUC score 0.8162179765523231
              precision    recall  f1-score   support

           0       0.76      0.74      0.75        47
           1       0.88      0.89      0.88        98

    accuracy                           0.84       145
   macro avg       0.82      0.82      0.82       145
weighted avg       0.84      0.84      0.84       145

ROC AUC score 0.5686400686400686
              precision    recall  f1-score   support

           0       0.42      0.65      0.51       126
           1       0.71      0.49      0.58       222

    accuracy                           0.55       348
   macro avg       0.56      0.57      0.54       348
weighted avg       0.60      0.55      0.55       348




