In [267]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BatchNorm1d, ReLU, Linear, Sequential
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GINConv, global_add_pool, SAGEConv, to_hetero, HeteroConv, GraphConv
import os
from torch.utils.data import DataLoader
from torch_geometric.data import InMemoryDataset, download_url, HeteroData
import networkx as nx
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import KFold

from sklearn.preprocessing import StandardScaler

import matplotlib.pyplot as plt

In [268]:
data = HeteroData()
data['inst'].x = torch.FloatTensor(np.load("/home/zluo/nn/GNN/to_gnn/train_inst_X.npy"))
data['net'].x = torch.FloatTensor(np.load("/home/zluo/nn/GNN/to_gnn/train_net_X.npy").reshape(-1, 1))
data['net'].y = torch.FloatTensor(np.load("/home/zluo/nn/GNN/to_gnn/net_Y.npy").reshape(-1, 1))
data['inst', 'to', 'net'].edge_index  =  torch.LongTensor(np.load("/home/zluo/nn/GNN/to_gnn/edgs_inst_to_net.npy").T)
data['net', 'to', 'inst'].edge_index  =  torch.LongTensor(np.load("/home/zluo/nn/GNN/to_gnn/edgs_net_to_inst.npy").T)
data['inst', 'to', 'inst'].edge_index  = torch.LongTensor(np.load("/home/zluo/nn/GNN/to_gnn/edge_index_train_inst.npy").T)

In [269]:
data

HeteroData(
  [1minst[0m={ x=[51442, 3] },
  [1mnet[0m={
    x=[51442, 1],
    y=[51442, 1]
  },
  [1m(inst, to, net)[0m={ edge_index=[2, 51442] },
  [1m(net, to, inst)[0m={ edge_index=[2, 81176] },
  [1m(inst, to, inst)[0m={ edge_index=[2, 81176] }
)

In [168]:
data = Data(x=torch.FloatTensor(np.load("/home/zluo/nn/GNN/to_gnn/train_inst_X.npy")), edge_index=torch.LongTensor(np.load("/home/zluo/nn/GNN/to_gnn/edge_index_train_inst.npy").T), y=torch.FloatTensor(np.load("/home/zluo/nn/GNN/to_gnn/train_net_X.npy")[:, 0].reshape(-1, 1)))

In [279]:
class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                ('inst', 'to', 'net'): GATConv((-1, -1), hidden_channels, add_self_loops=False, dropout=0.6),
                ('net', 'to', 'inst'): GATConv((-1, -1), hidden_channels, add_self_loops=False, dropout=0.6),
            }, aggr='sum')
            self.convs.append(conv)

        self.lin1 = Linear(hidden_channels, out_channels)
        self.lin2 = Linear(out_channels, 1)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
            
        x = F.relu(self.lin1(x_dict['net']))
        x = self.lin2(x)
        
        return x

In [239]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GNN, self).__init__()
        self.conv1 = GATConv(-1, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels)
        self.fc1 = nn.Linear(hidden_channels, 20)
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = self.conv2(x, edge_index)
        
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x

In [275]:
model = HeteroGNN(hidden_channels=32, out_channels=10, num_layers=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()

In [170]:
model = GNN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = torch.nn.MSELoss()

In [278]:
for epoch in range(100):
    model.train()
    out = model(data.x_dict, data.edge_index_dict)
    loss = criterion(out, data['net'].y)
    loss.backward()
    optimizer.step()
    print(float(loss))

3542.2470703125
3542.9521484375
3543.7021484375
3544.495361328125
3545.331298828125
3546.207763671875
3547.123291015625
3548.075927734375
3549.064453125
3550.0859375
3551.13916015625
3552.222412109375
3553.333984375
3554.469970703125
3555.6318359375
3556.8134765625
3558.015625
3559.233642578125
3560.46728515625
3561.713134765625
3562.96875
3564.23291015625
3565.502197265625
3566.7724609375
3568.043212890625
3569.3125
3570.57666015625
3571.832275390625
3573.07763671875
3574.310791015625
3575.527587890625
3576.726318359375
3577.903564453125
3579.0576171875
3580.184326171875
3581.282958984375
3582.35009765625
3583.382080078125
3584.376708984375
3585.33251953125
3586.2470703125
3587.115234375
3587.937744140625
3588.709716796875
3589.4306640625
3590.09765625
3590.708251953125
3591.261962890625
3591.75390625
3592.18603515625
3592.552001953125
3592.855712890625
3593.091064453125
3593.25927734375
3593.35888671875
3593.38916015625
3593.348388671875
3593.237060546875
3593.055419921875
3592.80078

In [177]:
for epoch in range(100):
    model.train()
    out = model(data.x, data.edge_index)
    loss = criterion(out, data.y)
    loss.backward()
    optimizer.step()
    print(float(loss))

4267.00927734375
4277.6240234375
4286.45849609375
4293.47119140625
4298.59033203125
4301.751953125
4302.7529296875
4301.62353515625
4298.43994140625
4293.15478515625
4285.7880859375
4276.3759765625
4264.86474609375
4251.2841796875
4235.6630859375
4218.04541015625
4198.5458984375
4177.2529296875
4154.271484375
4129.72314453125
4103.73974609375
4076.47509765625
4048.07763671875
4018.72021484375
3988.588623046875
3957.879150390625
3926.79638671875
3895.554931640625
3864.38134765625
3833.477294921875
3803.085693359375
3773.458984375
3744.8427734375
3717.47998046875
3691.615966796875
3667.497314453125
3645.36376953125
3625.454345703125
3608.000732421875
3593.226806640625
3581.350341796875
3572.573974609375
3567.095458984375
3565.09228515625
3566.732666015625
3572.160888671875
3581.5107421875
3594.893310546875
3612.397705078125
3634.092041015625
3660.02197265625
3690.204345703125
3724.637451171875
3763.28271484375
3806.08349609375
3852.947265625
3903.75390625
3958.352783203125
4016.567626953