## imports

In [1]:
import os
import sys
sys.path.append('../')
from glob import glob
import torch
import numpy as np
import pandas as pd

## get and split data

In [2]:
from apulu.dataset import DatasetGenerator

data_root = '../data'

sotck_path = os.path.join(
        data_root, 'raw', 'stock', 'raw.csv'
)

sec_path = os.path.join(
        data_root, 'raw', 'sec'
)

output_path = os.path.join(
        data_root, 'processed'
)

data_list = sorted(list(set(glob(
            os.path.join(
                data_root, 'raw', 'news', '*.npy'
            )
        )) - set(glob(
            os.path.join(
                data_root, 'raw', 'news', '*q*.npy'
            )
        ))))
dg = DatasetGenerator(
    data_list = data_list,
    stock_path=sotck_path,
    sec_path=sec_path,
    freq='month'
)

## model definition

In [3]:
from apulu.model import GCN

## model training

In [4]:
import argparse
from torch_geometric_temporal.signal import temporal_signal_split
from apulu.utils import Trainer

dataset = dg.process()
train_dataset, test_dataset = temporal_signal_split(dataset, train_ratio=0.8)


INPUT_SHAPE = next(iter(train_dataset)).x.shape[1]
model = GCN(input_size = INPUT_SHAPE, hidden_dims=64)
args = argparse.Namespace(
    num_epochs = 500,
    learning_rate = 1e-3,
    device = "cpu",
    val_size = .1,
    verbose = False
)

In [5]:
trainer = Trainer(model, train_dataset, args, test_dataset)

In [6]:
model = trainer.train()

100%|███████████████████████████████████████████████████████████████████████████| 500/500 [00:22<00:00, 21.87it/s]


            best model loss is:
                val loss: 0.5586360454559326 @ epoch: 333
            
ROC AUC score 0.5199343339587242
              precision    recall  f1-score   support

           0       0.40      0.10      0.16        41
           1       0.73      0.94      0.82       104

    accuracy                           0.70       145
   macro avg       0.56      0.52      0.49       145
weighted avg       0.63      0.70      0.63       145

ROC AUC score 0.4724222350230415
              precision    recall  f1-score   support

           0       0.32      0.31      0.31       124
           1       0.62      0.64      0.63       224

    accuracy                           0.52       348
   macro avg       0.47      0.47      0.47       348
weighted avg       0.52      0.52      0.52       348




