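# Training
(Work in progress.) Trains a three-layer graph convolutional network (`CGN`) on the weather graph dataset and logs the train/validation MSE loss to TensorBoard.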

In [1]:
import graph_class as gc
import numpy as np
import pandas as pd
import dgl
from dgl.nn import GraphConv
import torch
from torch import nn


Using backend: pytorch


In [2]:
# Build the weather graph dataset from the preprocessed CSV.
# 'test_one' is presumably the dataset's name — TODO confirm against graph_class.
dataset = gc.WeatherDataset('test_one')
dataset.create(
    '../data/data_initial_preprocessing.csv',  # preprocessed input table
)

In [3]:
# Add a self-loop to every node so each node's own features take part in the
# GraphConv aggregation and no node has zero in-degree (DGL's GraphConv raises
# on zero in-degree nodes unless allow_zero_in_degree=True).
g = dgl.add_self_loop(dataset.graph)

In [4]:
# Inspect the self-looped graph: node/edge counts and per-node data schemes.
g

Graph(num_nodes=667, num_edges=3540,
      ndata_schemes={'x': Scheme(shape=(1393, 7), dtype=torch.float64), 'y': Scheme(shape=(1393,), dtype=torch.float64), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={})

In [5]:
# Neural network class: a three-layer graph convolutional network.
class CGN(nn.Module):
    """Three stacked DGL ``GraphConv`` layers with tanh between them.

    Layer widths are ``in_feats -> 32 -> 16 -> num_classes`` and every layer
    uses ``norm='both'`` normalization. No activation is applied after the
    last layer.

    Args:
        in_feats: size of the last dimension of the input node features.
        num_classes: output size of the final GraphConv layer.
    """

    def __init__(self, in_feats, num_classes):
        super().__init__()
        # Keep the conv1..conv3 attribute names so parameter names are unchanged.
        self.conv1 = GraphConv(in_feats, 32, norm='both')
        self.conv2 = GraphConv(32, 16, norm='both')
        self.conv3 = GraphConv(16, num_classes, norm='both')

    def forward(self, g, in_feat):
        """Apply the three graph convolutions over graph ``g``.

        Args:
            g: the DGL graph to propagate over.
            in_feat: node feature tensor whose last dimension is ``in_feats``.

        Returns:
            The output of ``conv3`` (no final activation).
        """
        hidden = in_feat
        # The two hidden layers share the same conv -> tanh pattern.
        for conv in (self.conv1, self.conv2):
            hidden = torch.tanh(conv(g, hidden))
        return self.conv3(g, hidden)

In [6]:
# Seed torch before building the model so the GraphConv weight initialisation
# (and therefore the whole training run) is reproducible.
SEED = 0
torch.manual_seed(SEED)

# GraphConv's in_feats is the size of the *last* feature dimension (7 for the
# per-node feature scheme (1393, 7) shown above); indexing with -1 also keeps
# working if the features are ever stored as a plain (num_nodes, n_features) matrix.
# One output per node/step, regressed against `y` with MSE in the training cell.
net = CGN(g.ndata['x'].shape[-1], 1)

In [7]:
# Cast every floating-point parameter of the model to float32.
net = net.to(dtype=torch.float32)
net

CGN(
  (conv1): GraphConv(in=7, out=32, normalization=both, activation=None)
  (conv2): GraphConv(in=32, out=16, normalization=both, activation=None)
  (conv3): GraphConv(in=16, out=1, normalization=both, activation=None)
)

In [8]:
from torch import optim
from torch import nn
from torch.utils.tensorboard import SummaryWriter

In [9]:
# Hyper-parameters for the training loop.
p = {
    'epochs': 10000,
    'optim': optim.Adam,            # optimizer class
    'loss_function': nn.MSELoss(),
    'lr': 1e-3
}

# Use the GPU when one is available, otherwise fall back to CPU so the notebook
# still runs end-to-end on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = net.to(device)
g = g.to(device)
name = 'runs/test_3'  # TensorBoard log directory for this run

In [10]:
# TensorBoard writer; scalars are logged under the `name` run directory.
writer = SummaryWriter(name)

features = g.ndata['x'].float()
label = g.ndata['y'].float()

train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']

# Use the loss and optimizer configured in `p` (previously they were re-created
# by hand here, so editing the config silently had no effect).
loss_fn = p['loss_function']
optimizer = p['optim'](net.parameters(), lr=p['lr'])

history = {
    'loss': [],
    'val_loss': []
}

try:
    for epoch in range(p['epochs']):
        # Full-graph forward pass; drop the trailing size-1 output dimension so
        # the prediction has the same shape as `label`.
        prediction = net(g, features)
        prediction = prediction.reshape(prediction.shape[0], prediction.shape[1])

        loss = loss_fn(prediction[train_mask], label[train_mask])
        # The validation loss is only monitored, never back-propagated; detach so
        # it does not hold a reference to the autograd graph.
        val_loss = loss_fn(prediction.detach()[val_mask], label[val_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store plain Python floats rather than 0-d numpy arrays.
        history['loss'].append(loss.item())
        history['val_loss'].append(val_loss.item())

        writer.add_scalar("Loss/train", loss.item(), epoch)  # tensorboard
        writer.add_scalar("Loss/Val", val_loss.item(), epoch)  # tensorboard

        if epoch % 100 == 0:
            print(f'Epoch: {epoch} Loss: {loss} Val Loss: {val_loss}')
finally:
    # Flush pending TensorBoard events and release the event file even when the
    # cell is interrupted mid-training.
    writer.close()

print(f'Epoch: {epoch} Loss: {loss} Val Loss: {val_loss}')

Epoch: 0 Loss: 1.0887255668640137 Val Loss: 1.0006099939346313
Epoch: 100 Loss: 0.945045530796051 Val Loss: 0.8729445934295654
Epoch: 200 Loss: 0.9444595575332642 Val Loss: 0.8720039129257202
Epoch: 300 Loss: 0.9440282583236694 Val Loss: 0.87148517370224
Epoch: 400 Loss: 0.9435788989067078 Val Loss: 0.871043860912323
Epoch: 500 Loss: 0.943105161190033 Val Loss: 0.8708491325378418
Epoch: 600 Loss: 0.9426414370536804 Val Loss: 0.8709275722503662
Epoch: 700 Loss: 0.9421799182891846 Val Loss: 0.8710023164749146
Epoch: 800 Loss: 0.9416746497154236 Val Loss: 0.8708578944206238
Epoch: 900 Loss: 0.9410627484321594 Val Loss: 0.870395839214325
Epoch: 1000 Loss: 0.9402862787246704 Val Loss: 0.8695291876792908
Epoch: 1100 Loss: 0.939276933670044 Val Loss: 0.8682030439376831
Epoch: 1200 Loss: 0.9379695653915405 Val Loss: 0.8665600419044495
Epoch: 1300 Loss: 0.9363989233970642 Val Loss: 0.8647010922431946
Epoch: 1400 Loss: 0.9347147941589355 Val Loss: 0.8626964688301086
Epoch: 1500 Loss: 0.933053314

In [21]:
# Shape of the training-node predictions (should match the labels below).
prediction[train_mask].size()

torch.Size([533, 1400])

In [22]:
# Shape of the training-node labels (should match the predictions above).
label[train_mask].size()

torch.Size([533, 1400])

In [23]:
# Map a constant value (1.069 — presumably a reading in the scaled target space,
# TODO confirm) back to original units for every target column.
# The column count comes from `label` instead of a hard-coded 1400, because the
# target width has changed between runs (1393 in the graph schema vs 1400 here).
# NOTE(review): if 1.069 is meant to be an MSE, inverse_transform maps *values*,
# not error magnitudes — confirm what scaler_y does before interpreting this.
dataset.scaler_y.inverse_transform(np.full((1, label.shape[1]), 1.069))

array([[11.39190968,  7.16351918,  5.75264139, ..., 15.19342615,
        15.11701595, 15.01706286]])