In [1]:
import sys
import os
import networkx as nx
import torch
import torch_geometric
from torch_geometric import nn
from torch_geometric.utils.convert import from_networkx
from torch_geometric.data import Dataset, Data, DataLoader
import pandas as pd
import numpy as np
from tqdm import *
import pickle

#load local packages
sys.path.append(os.path.join(os.getcwd(), '..'))
from ConuForecast.src.graph_utils import GraphManager, DBconnector
from ConuForecast.src.pre_proc import GraphDataLoader, ConuGraphDataset

In [2]:
conn = DBconnector('172.17.0.1', 5555, 'base-ina', 'postgres', 'postgres')

MODEL = 'model_007'
EVENT = '007'
PRECIP = 'precipitation_007'
ATTRS = {
    'nodes': ['elevation', 'area', 'imperv', 'slope'],
    'edges': ['flow_rate', 'flow_velocity']
    }
ET= '2017-01-01 14:15:00'


conu_basin_dataloader = GraphDataLoader(model=MODEL, event=EVENT, precip=PRECIP, conn=conn, attrs_dict=ATTRS)

In [3]:
conu_dataset = ConuGraphDataset('../data/', ET, conu_basin_dataloader, clean=False)

In [62]:
torch.manual_seed(12345)
dataset = conu_dataset.shuffle()

train_dataset = conu_dataset[:7000]
test_dataset = conu_dataset[7000:]

In [63]:
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=8)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False,  num_workers=8)

In [66]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        edge_index = edge_index.long()
        x = x.float()
        
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

model = GCN(hidden_channels=64)
print(model)

GCN(
  (conv1): GCNConv(4, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=3, bias=True)
)


In [61]:
model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()
    
    n = 0
    
    for data in train_loader:  # Iterate in batches over the training dataset.
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.
        
        n += 1
        print(f'Batch: {n} of {len(train_loader)}')
        
        

def test(loader):
    model.eval()

    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data.x, data.edge_index, data.batch)  
        pred = out.argmax(dim=1)  # Use the class with highest probability.
        correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(1, 10):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

1 of 110
2 of 110
3 of 110
4 of 110
5 of 110
6 of 110
7 of 110
8 of 110
9 of 110
10 of 110
11 of 110
12 of 110
13 of 110
14 of 110
15 of 110
16 of 110
17 of 110
18 of 110
19 of 110
20 of 110
21 of 110
22 of 110
23 of 110
24 of 110
25 of 110
26 of 110
27 of 110
28 of 110
29 of 110
30 of 110
31 of 110
32 of 110
33 of 110
34 of 110
35 of 110
36 of 110
37 of 110
38 of 110
39 of 110
40 of 110
41 of 110
42 of 110
43 of 110
44 of 110
45 of 110
46 of 110
47 of 110
48 of 110
49 of 110
50 of 110
51 of 110
52 of 110
53 of 110
54 of 110
55 of 110
56 of 110
57 of 110
58 of 110
59 of 110
60 of 110
61 of 110
62 of 110
63 of 110
64 of 110
65 of 110
66 of 110
67 of 110
68 of 110
69 of 110
70 of 110
71 of 110
72 of 110
73 of 110
74 of 110
75 of 110
76 of 110
77 of 110
78 of 110
79 of 110
80 of 110
81 of 110
82 of 110
83 of 110
84 of 110
85 of 110
86 of 110
87 of 110
88 of 110
89 of 110
90 of 110
91 of 110
92 of 110
93 of 110
94 of 110
95 of 110
96 of 110
97 of 110
98 of 110
99 of 110
100 of 110
101 of 1