# Training
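*(Work in progress.)* Trains a graph convolutional network to regress the per-node target `y` of the weather graph.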

In [11]:
import graph_class as gc
import numpy as np
import pandas as pd
import dgl
from dgl.nn import GraphConv
import torch
from torch import nn


In [12]:
# Build the project's weather dataset from the preprocessed CSV.
# NOTE(review): WeatherDataset/create live in graph_class; 'test_one' is presumably the
# dataset name, and create() presumably populates `dataset.graph` and `dataset.scaler_y`
# (both read by later cells) — confirm in graph_class.
dataset = gc.WeatherDataset('test_one')
dataset.create('../data/data_initial_preprocessing.csv')

In [13]:
# GraphConv (used below) rejects zero-in-degree nodes by default; adding a self-loop to
# every node avoids that and lets each node's own features enter its aggregation.
g = dgl.add_self_loop(dataset.graph)

In [14]:
# Inspect the self-looped graph: node/edge counts and the per-node data schemes.
g

Graph(num_nodes=667, num_edges=4002,
      ndata_schemes={'x': Scheme(shape=(1400, 8), dtype=torch.float64), 'y': Scheme(shape=(1400,), dtype=torch.float64), 'train_mask': Scheme(shape=(), dtype=torch.bool), 'val_mask': Scheme(shape=(), dtype=torch.bool), 'test_mask': Scheme(shape=(), dtype=torch.bool)}
      edata_schemes={})

In [15]:
# Neural network class
class CGN(nn.Module):
    """Three-layer graph convolutional network for per-node regression.

    Stacks three ``GraphConv`` layers (``norm='both'`` normalization) with
    ``tanh`` between them. The last layer has no activation, so the output is
    unbounded; the notebook trains it with ``MSELoss``.

    NOTE(review): the name looks like a typo for ``GCN``; it is kept because
    other cells reference ``CGN``.

    Args:
        in_feats: size of the last axis of the node feature tensor.
        num_classes: number of output channels of the last layer (the notebook
            passes 1 for scalar regression).
        hidden_feats: widths of the two hidden layers; defaults to ``(32, 16)``,
            the sizes previously hard-coded here.
    """

    def __init__(self, in_feats, num_classes, hidden_feats=(32, 16)):
        super().__init__()
        hidden1, hidden2 = hidden_feats
        self.conv1 = GraphConv(in_feats, hidden1, norm='both')
        self.conv2 = GraphConv(hidden1, hidden2, norm='both')
        self.conv3 = GraphConv(hidden2, num_classes, norm='both')

    def forward(self, g, in_feat):
        """Apply the three graph convolutions over ``g``.

        Args:
            g: DGL graph whose nodes carry ``in_feat``.
            in_feat: node features whose last axis has size ``in_feats``.

        Returns:
            Tensor with the leading axes of ``in_feat`` and a last axis of
            size ``num_classes``.
        """
        h = torch.tanh(self.conv1(g, in_feat))
        h = torch.tanh(self.conv2(g, h))
        return self.conv3(g, h)

In [16]:
# Input width = size of the last axis of the node features (8 per the graph schema);
# one output channel for scalar regression of `y`. Using shape[-1] instead of shape[2]
# also works if the features are ever 2-D (nodes, features).
net = CGN(g.ndata['x'].shape[-1], 1)

In [17]:
# Module.float() casts floating-point parameters/buffers to float32 in place and returns
# the module itself, so `net` is updated without rebinding and is shown as the cell output.
net.float()

CGN(
  (conv1): GraphConv(in=8, out=32, normalization=both, activation=None)
  (conv2): GraphConv(in=32, out=16, normalization=both, activation=None)
  (conv3): GraphConv(in=16, out=1, normalization=both, activation=None)
)

In [18]:
from torch import optim
from torch import nn
from torch.utils.tensorboard import SummaryWriter

In [19]:
# Training hyper-parameters.
p = {
    'epochs': 10000,
    'optim': optim.Adam,            # optimizer class (not an instance)
    'loss_function': nn.MSELoss(),  # regression loss instance
    'lr': 1e-3,
}

# Fall back to the CPU when no CUDA device is present instead of failing on .to('cuda').
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = net.to(device)
g = g.to(device)

# TensorBoard log directory, passed to SummaryWriter in the training cell.
name = 'runs/test_2'

In [20]:
# Full-graph training: each epoch runs one forward pass over the whole graph, optimizes
# the loss on the train-mask nodes and reports the loss on the val-mask nodes.
writer = SummaryWriter(f"{name}")

features = g.ndata['x'].float()
label = g.ndata['y'].float()

train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']

# Use the loss and optimizer declared in `p`; they used to be hard-coded here, so
# editing p['optim'] / p['loss_function'] silently had no effect.
loss_fn = p['loss_function']
optimizer = p['optim'](net.parameters(), lr=p['lr'])

history = {
    'loss': [],
    'val_loss': []
}

try:
    for epoch in range(p['epochs']):
        prediction = net(g, features)
        # GraphConv keeps x's leading (nodes, 1400) axes and appends a size-1 output axis
        # (num_classes=1); fold it away to match label's (nodes, 1400) shape.
        prediction = prediction.reshape(prediction.shape[0], prediction.shape[1])

        loss = loss_fn(prediction[train_mask], label[train_mask])
        # The validation loss is only monitored, so don't record an autograd graph for it.
        with torch.no_grad():
            val_loss = loss_fn(prediction[val_mask], label[val_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        val_loss_value = val_loss.item()
        history['loss'].append(loss_value)
        history['val_loss'].append(val_loss_value)

        writer.add_scalar("Loss/train", loss_value, epoch)  # TensorBoard
        writer.add_scalar("Loss/Val", val_loss_value, epoch)  # TensorBoard

        if epoch % 100 == 0:
            print(f'Epoch: {epoch} Loss: {loss_value} Val Loss: {val_loss_value}')
finally:
    # Flush and close the event file even when the run is interrupted by hand.
    writer.close()

# Guarded so a zero-epoch run doesn't raise NameError on `epoch`.
if history['loss']:
    last_epoch = len(history['loss']) - 1
    print(f'Epoch: {last_epoch} Loss: {history["loss"][-1]} Val Loss: {history["val_loss"][-1]}')

Epoch: 0 Loss: 1.6060307025909424 Val Loss: 1.4931100606918335
Epoch: 100 Loss: 0.950023889541626 Val Loss: 0.8784621357917786
Epoch: 200 Loss: 0.9491601586341858 Val Loss: 0.8780856728553772
Epoch: 300 Loss: 0.9487327337265015 Val Loss: 0.8777257204055786
Epoch: 400 Loss: 0.9484791159629822 Val Loss: 0.8774576783180237
Epoch: 500 Loss: 0.9482951164245605 Val Loss: 0.8772980570793152
Epoch: 600 Loss: 0.9481412172317505 Val Loss: 0.877231240272522
Epoch: 700 Loss: 0.9480010271072388 Val Loss: 0.877230167388916
Epoch: 800 Loss: 0.9478656649589539 Val Loss: 0.8772702217102051
Epoch: 900 Loss: 0.9477286338806152 Val Loss: 0.8773313760757446
Epoch: 1000 Loss: 0.9475834369659424 Val Loss: 0.8773981928825378
Epoch: 1100 Loss: 0.9474226236343384 Val Loss: 0.877459704875946
Epoch: 1200 Loss: 0.9472373127937317 Val Loss: 0.8775104284286499
Epoch: 1300 Loss: 0.9470175504684448 Val Loss: 0.8775526285171509
Epoch: 1400 Loss: 0.9467513561248779 Val Loss: 0.8776010870933533
Epoch: 1500 Loss: 0.946426

In [21]:
# Shape of the model output restricted to training nodes (should match the label below).
prediction[train_mask].size()

torch.Size([533, 1400])

In [22]:
# Shape of the targets restricted to training nodes (should match the prediction above).
label[train_mask].size()

torch.Size([533, 1400])

In [23]:
# Map a constant value back to original units: one row of 1400 copies (the per-node length
# of `y`), shaped (1, 1400) for the y-scaler. The output differs per position, so the
# scaler is per-column.
# NOTE(review): 1.069 is an unexplained magic number — document where it comes from.
dataset.scaler_y.inverse_transform(np.full((1, 1400), 1.069))

array([[11.39190968,  7.16351918,  5.75264139, ..., 15.19342615,
        15.11701595, 15.01706286]])