In [5]:
import pandas as pd
from datetime import timedelta
import seaborn as sns
import matplotlib.pyplot as plt

from nera.data_management.data_acquisition import DataAcquisition
from nera.data_management.data_transformation import DataTransformation
from nera.data_management import FROM_CSV
from nera.models.gnn import GCONVCheb
from nera.trainer import Trainer



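### Data preparation

Load the match data from `../resources/other_leagues.csv`, drop the EuroLeague and EuroCup games, build the dataset, and create the model.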

In [6]:
# Load the raw match data and parse the timestamp column.
da = DataAcquisition()
df = da.get_data(FROM_CSV, fname="../resources/other_leagues.csv")
df['DT'] = pd.to_datetime(df['DT'], format="%Y-%m-%d %H:%M:%S")

# Drop the EuroLeague and EuroCup games; every other league is kept
# (rows with a missing 'League' value are kept as well).
EXCLUDED_LEAGUES = ['EuroLeague', 'EuroCup']
df = df[~df['League'].isin(EXCLUDED_LEAGUES)]
# NOTE(review): reset_index() without drop=True keeps the previous index as an
# extra 'index' column in df. Left unchanged because DataTransformation may see
# that column — confirm whether drop=True was intended.
df = df.reset_index()

# Build the dataset from the *filtered* frame only. The timedelta(365) argument
# is passed straight to DataTransformation — presumably a one-year window; verify.
transform = DataTransformation(df, timedelta(365))
dataset = transform.get_dataset(node_f_extract=False, edge_f_one_hot=True)

team_count = transform.num_teams

# Build GCONVCheb with its default hyperparameters (only team_count is passed).
# NOTE(review): the defaults below were originally documented under the name
# GCONVLSTM; confirm them against nera.models.gnn.GCONVCheb:
#
# GCONVCheb(self, team_count: int,
#           embed_dim: int = 10,
#           dense_dims: tuple[int] = (8, 8, 8, 8, 8),
#           conv_out_dim: int = 16,
#           dropout_rate: float = 0.1,
#           activation: str = 'relu',
#           K: int = 5):
model = GCONVCheb(team_count)
model

### Model training

Train the model for 100 epochs and plot the training accuracy per epoch.

In [7]:
# Train the model. The return value of trainer.train (named training_accuracy
# here) is plotted against its position index, which is labelled as the epoch —
# assumes one value per epoch; confirm against Trainer.train.
EPOCHS = 100
trainer = Trainer(dataset, model)
training_accuracy = trainer.train(epochs=EPOCHS, verbose=True)

accuracy_df = (
    pd.DataFrame(training_accuracy, columns=["training_accuracy"])
    .reset_index()
)

# Explicit figure/axes instead of the pyplot state machine.
fig, ax = plt.subplots()
sns.lineplot(x="index", y="training_accuracy", data=accuracy_df, ax=ax)
ax.set(title="Training accuracy", xlabel="Epoch", ylabel="Training accuracy")
plt.show()

### Testing

Evaluate the trained model.

In [8]:
# Evaluate the trained model; whatever trainer.test returns is shown as the cell output.
test_result = trainer.test(verbose=True)
test_result