## imports

In [1]:
import os
import sys
sys.path.append('../')
from glob import glob
import torch
import numpy as np
import pandas as pd

## locate raw data and build the dataset generator

In [2]:
from spug.dataset import DatasetGenerator

# Directory layout: everything lives under `data_root`, raw inputs under `raw/`.
data_root = '../data'
raw_dir = os.path.join(data_root, 'raw')

# Input locations handed to the dataset generator.
stock_path = os.path.join(raw_dir, 'stock', 'raw.csv')
sotck_path = stock_path  # misspelled legacy name, kept so existing references still resolve
sec_path = os.path.join(raw_dir, 'sec')

# Destination directory for processed artifacts (not used by the generator call below).
output_path = os.path.join(data_root, 'processed')

# Collect the twitter .npy files matching '*q*.npy' in a stable, sorted order.
twitter_pattern = os.path.join(raw_dir, 'twitter', '*q*.npy')
data_list = glob(twitter_pattern)
data_list.sort()

# Configure the generator; nothing is processed until `dg.process()` is called.
dg = DatasetGenerator(
    data_list=data_list,
    stock_path=stock_path,
    sec_path=sec_path,
    freq='quarter',
)

## model definition

In [3]:
from spug.model import RecurrentGCN

## model training

In [4]:
import argparse
from torch_geometric_temporal.signal import temporal_signal_split
from spug.utils import Trainer

# Seed the random number generators before any stochastic step (weight
# initialisation, training) so that Restart & Run All repeats the same run.
# Only torch and numpy are seeded here; if the Trainer draws randomness from
# another source (e.g. Python's `random`), seed that too -- TODO confirm.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Build the temporal dataset and split it with train_ratio=0.8
# (80% of the snapshots for training, the remainder for testing).
dataset = dg.process()
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)

# Size of the second dimension of `x` in the first training snapshot;
# used as the model's input size.
INPUT_SHAPE = next(iter(train_dataset)).x.shape[1]
model = RecurrentGCN(input_size=INPUT_SHAPE, hidden_dims=64)

# Hyper-parameters and run options consumed by the Trainer.
args = argparse.Namespace(
    num_epochs=500,
    learning_rate=1e-3,
    device="cpu",
    val_size=.1,
    verbose=False,
)

In [5]:
# Bundle the model, the training split, the run options and the test split
# into a Trainer. Arguments are positional; their order follows the
# project-local Trainer signature, which is not shown in this notebook.
trainer = Trainer(model, train_dataset, args, test_dataset)

In [6]:
# Run training and rebind `model` to whatever Trainer.train() returns
# (presumably the best model by validation loss, judging by the printed
# summary -- TODO confirm).
# NOTE(review): because `model` is overwritten here, re-running the Trainer
# cell after this one wraps the already-trained model, not a fresh one.
model = trainer.train()

100%|███████████████████████████████████████████████████████████████████████████| 500/500 [01:50<00:00,  4.53it/s]


            best model loss is:
                val loss: 0.3305211067199707 @ epoch: 481
            
ROC AUC score 0.9583333333333333
              precision    recall  f1-score   support

           0       1.00      0.92      0.96        12
           1       0.98      1.00      0.99        46

    accuracy                           0.98        58
   macro avg       0.99      0.96      0.97        58
weighted avg       0.98      0.98      0.98        58

ROC AUC score 0.6229086229086229
              precision    recall  f1-score   support

           0       0.62      0.38      0.47        42
           1       0.71      0.86      0.78        74

    accuracy                           0.69       116
   macro avg       0.66      0.62      0.63       116
weighted avg       0.68      0.69      0.67       116




