In [156]:
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import BatchNorm1d, ReLU, Linear, Sequential
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GINConv, global_add_pool, SAGEConv, to_hetero, HeteroConv, GraphConv
import os
from torch.utils.data import DataLoader
from torch_geometric.data import InMemoryDataset, download_url, HeteroData
import networkx as nx
from sklearn.preprocessing import OrdinalEncoder
from sklearn.preprocessing import OneHotEncoder
from sklearn.model_selection import KFold

from sklearn.preprocessing import StandardScaler

import matplotlib.pyplot as plt

In [150]:
data = HeteroData()
data['inst'].x = torch.FloatTensor(np.load("/home/zluo/nn/GNN/to_gnn/train_inst_X.npy"))
data['net'].x = torch.FloatTensor(np.load("/home/zluo/nn/GNN/to_gnn/train_net_X.npy").reshape(-1, 1))
data['net'].y = torch.FloatTensor(np.load("/home/zluo/nn/GNN/to_gnn/net_Y.npy").reshape(-1, 1))
data['inst', 'to', 'net'].edge_index  =  torch.LongTensor(np.load("/home/zluo/nn/GNN/to_gnn/edgs_inst_to_net.npy").T)
data['net', 'to', 'inst'].edge_index  =  torch.LongTensor(np.load("/home/zluo/nn/GNN/to_gnn/edgs_net_to_inst.npy").T)
data['inst', 'to', 'inst'].edge_index  = torch.LongTensor(np.load("/home/zluo/nn/GNN/to_gnn/edge_index_train_inst.npy").T)

In [165]:
data

Data(x=[51442, 3], edge_index=[2, 81176], y=[51442, 1])

In [157]:
data = Data(x=torch.FloatTensor(np.load("/home/zluo/nn/GNN/to_gnn/train_inst_X.npy")), edge_index=torch.LongTensor(np.load("/home/zluo/nn/GNN/to_gnn/edge_index_train_inst.npy").T), y=torch.FloatTensor(np.load("/home/zluo/nn/GNN/to_gnn/net_Y.npy").reshape(-1, 1)))

In [158]:
class HeteroGNN(torch.nn.Module):
    def __init__(self, hidden_channels, out_channels, num_layers):
        super().__init__()

        self.convs = torch.nn.ModuleList()
        for _ in range(num_layers):
            conv = HeteroConv({
                ('inst', 'to', 'net'): GATConv((-1, -1), hidden_channels, add_self_loops=False, dropout=0.6),
                ('net', 'to', 'inst'): GATConv((-1, -1), hidden_channels, add_self_loops=False, dropout=0.6),
            }, aggr='sum')
            self.convs.append(conv)

        self.lin1 = Linear(hidden_channels, out_channels)
        self.lin2 = Linear(out_channels, 1)

    def forward(self, x_dict, edge_index_dict):
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict)
            x_dict = {key: x.relu() for key, x in x_dict.items()}
            
        x = F.relu(self.lin1(x_dict['net']))
        x = self.lin2(x)
        
        return x

In [163]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GNN, self).__init__()
        self.conv1 = GATConv(-1, hidden_channels)
        self.conv2 = GATConv(hidden_channels, hidden_channels)
        self.fc1 = nn.Linear(hidden_channels, 20)
        self.fc2 = nn.Linear(20, 1)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = self.conv2(x, edge_index)
        
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        
        return x

In [154]:
model = HeteroGNN(hidden_channels=16, out_channels=10, num_layers=3)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

In [166]:
model = GNN(hidden_channels=16)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.MSELoss()

In [162]:
for epoch in range(100):
    model.train()
    out = model(data.x_dict, data.edge_index_dict)
    loss = criterion(out, data['net'].y)
    loss.backward()
    optimizer.step()
    print(float(loss))

AttributeError: 'GlobalStorage' object has no attribute 'x_dict'

In [169]:
for epoch in range(100):
    model.train()
    out = model(data.x, data.edge_index)
    loss = criterion(out, data.y)
    loss.backward()
    optimizer.step()
    print(float(loss))

3541.18701171875
3541.06689453125
3540.937255859375
3540.800048828125
3540.656005859375
3540.506103515625
3540.35400390625
3540.196044921875
3540.037841796875
3539.8779296875
3539.718017578125
3539.558837890625
3539.40234375
3539.249755859375
3539.1005859375
3538.9599609375
3538.8271484375
3538.702880859375
3538.59033203125
3538.489501953125
3538.404052734375
3538.335205078125
3538.281982421875
3538.249755859375
3538.23681640625
3538.2470703125
3538.28173828125
3538.33935546875
3538.4248046875
3538.535888671875
3538.676025390625
3538.843994140625
3539.04052734375
3539.26416015625
3539.51318359375
3539.787353515625
3540.0830078125
3540.3974609375
3540.72509765625
3541.063232421875
3541.400390625
3541.734130859375
3542.056396484375
3542.35888671875
3542.636962890625
3542.884765625
3543.09716796875
3543.26806640625
3543.39794921875
3543.48095703125
3543.517333984375
3543.505859375
3543.44921875
3543.34716796875
3543.203125
3543.02001953125
3542.80224609375
3542.55224609375
3542.2763671875

In [170]:
data.x

tensor([[-0.5945, -0.0308, -0.9162],
        [ 1.4626, -0.0842, -0.9435],
        [-0.5945, -0.0308, -0.9162],
        ...,
        [ 0.4340, -0.0308, -0.1063],
        [-0.5945,  0.9293, -0.1882],
        [-0.5945,  0.0758, -0.9435]])

In [174]:
data.y[:50]

tensor([[3.3588e+01],
        [5.0400e-01],
        [2.9750e+01],
        [1.3680e+00],
        [3.4332e+01],
        [1.7280e+00],
        [4.1209e+01],
        [5.7600e-01],
        [3.2329e+01],
        [5.0400e-01],
        [4.1655e+01],
        [3.0240e+00],
        [1.0046e+01],
        [2.0880e+00],
        [5.1107e+01],
        [1.2960e+00],
        [3.3337e+01],
        [3.6720e+00],
        [3.6062e+01],
        [1.2960e+00],
        [4.2468e+01],
        [5.0400e-01],
        [1.3426e+04],
        [5.3280e+00],
        [3.2041e+01],
        [4.4065e+01],
        [2.6052e+01],
        [3.2412e+01],
        [4.3116e+01],
        [3.3481e+01],
        [3.7033e+01],
        [4.7520e+01],
        [3.0516e+01],
        [1.8720e+00],
        [3.0240e+00],
        [1.2240e+00],
        [3.4560e+00],
        [4.6800e+00],
        [1.0080e+00],
        [9.3600e-01],
        [4.1040e+00],
        [2.7360e+00],
        [2.5920e+00],
        [1.1088e+01],
        [2.7360e+00],
        [2

In [175]:
out[:100]

tensor([[1.7586],
        [1.9927],
        [1.7586],
        [1.9927],
        [1.7586],
        [1.9927],
        [1.7586],
        [1.9927],
        [1.7586],
        [1.9927],
        [1.7586],
        [1.9927],
        [1.7586],
        [1.9927],
        [1.7586],
        [1.9927],
        [1.7586],
        [1.9927],
        [1.7586],
        [1.9927],
        [1.7586],
        [1.9927],
        [1.5632],
        [1.7581],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7484],
        [1.7581],
        [1.7582],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.6987],
        [1.7582],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1.7586],
        [1