## imports

In [1]:
import os
import sys
sys.path.append('../')
from glob import glob
import torch
import numpy as np
import pandas as pd

## get and split data

In [2]:
from apulu.dataset import DatasetGenerator

data_root = '../data'

sotck_path = os.path.join(
        data_root, 'raw', 'stock', 'raw.csv'
)

sec_path = os.path.join(
        data_root, 'raw', 'sec'
)

output_path = os.path.join(
        data_root, 'processed'
)

data_list = sorted(glob(
    os.path.join(
        data_root, 'raw', 'news', '*q*.npy'
    )
))
dg = DatasetGenerator(
    data_list = data_list,
    stock_path=sotck_path,
    sec_path=sec_path,
    freq='quarter'
)

## model definition

In [3]:
from apulu.model import RecurrentGCN

## model training

In [4]:
import argparse
from torch_geometric_temporal.signal import temporal_signal_split
from apulu.utils import Trainer

dataset = dg.process()
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)


INPUT_SHAPE = next(iter(train_dataset)).x.shape[1]
model = RecurrentGCN(input_size = INPUT_SHAPE, hidden_dims=64)
args = argparse.Namespace(
    num_epochs = 500,
    learning_rate = 1e-3,
    device = "cpu",
    val_size = .1,
    verbose = False
)

In [5]:
trainer = Trainer(model, train_dataset, args, test_dataset)

In [6]:
model = trainer.train()

100%|███████████████████████████████████████████████████████████████████████████| 500/500 [01:39<00:00,  5.03it/s]


            best model loss is:
                val loss: 0.4167168438434601 @ epoch: 475
            
ROC AUC score 0.8945454545454546
              precision    recall  f1-score   support

           0       0.88      0.88      0.88        25
           1       0.91      0.91      0.91        33

    accuracy                           0.90        58
   macro avg       0.89      0.89      0.89        58
weighted avg       0.90      0.90      0.90        58

ROC AUC score 0.6244186046511628
              precision    recall  f1-score   support

           0       0.48      0.40      0.44        30
           1       0.80      0.85      0.82        86

    accuracy                           0.73       116
   macro avg       0.64      0.62      0.63       116
weighted avg       0.72      0.73      0.72       116




