In [2]:
import numpy as np
import pandas as pd

import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Flatten
import time 

import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import GCNConv, TopKPooling, SAGEConv, GATConv, SplineConv, Linear, to_hetero
from torch_geometric.loader import DataLoader
from torch_geometric.data import InMemoryDataset, download_url, HeteroData

from tqdm import tqdm
from sklearn.metrics import r2_score
from d2l import torch as d2l
from DoubleBoxDataLoader import HeteroTrainData, HeteroTestData

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dataset_train = HeteroTrainData(root = 'DoubleBox_Branchline_data/train/')
dataset_test = HeteroTestData(root = 'DoubleBox_Branchline_data/test/')
# dataset_train = dataset_train[0:1200]

batch_size = 100

train_loader = DataLoader(dataset_train, batch_size=batch_size)
test_loader = DataLoader(dataset_test, batch_size=batch_size)

# for data in train_loader:
#     data = data
# #     print(data)
#     break
data = dataset_train[0]

In [3]:
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
#         self.lin1 = Linear(8, 8)
        self.conv1 = SAGEConv(16, 16)
        self.conv2 = SAGEConv(16, 16)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        x = self.conv1(x, edge_index).relu()
#         x = self.conv2(x, edge_index).relu()
        return x
    
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ML_lin = torch.nn.Linear(2,16)
        self.MT_lin = torch.nn.Linear(3,16)
        self.gcn = GCN()
        self.gcn = to_hetero(self.gcn, metadata=data.metadata())    
        self.flat = Flatten()
        self.lin1 = Linear(-1, 128)
        self.lin2 = Linear(128, 256)
        self.lin3 = Linear(256, 40)
        self.dout1 = nn.Dropout(p=0)

    def forward(self, data: HeteroData) -> Tensor:
        
        x_dict = {
            "ML": self.ML_lin(data["ML"].x),
            "MT": self.MT_lin(data["MT"].x),
        }
        x = self.gcn(x_dict, data.edge_index_dict)
#         x = global_mean_pool(x["ML"], 1000)
        x["ML"] = x["ML"].reshape(batch_size,-1)
        x["MT"] = x["MT"].reshape(batch_size,-1)
#         print(self.flat(x["MT"]).shape)
        x = torch.cat((x["ML"],x["MT"]),1)   #把两个向量直接连起来
        x = self.lin1(x).relu()
        x = self.dout1(x)
        x = self.lin2(x).relu()
#         x = global_mean_pool(x, batch)
        out = self.lin3(x)
    
        return out


model = Model().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
mse_loss = F.mse_loss
# model.load_state_dict(torch.load('GCN_Forward1.params'))

def train(train_loader):
    model.train()
#     ys, preds = [], []
    for data in train_loader:
#         print(data.y.shape)
        data = data.to(device)
#         ys.append(data['ML'].y.reshape(batch_size,-1))
        optimizer.zero_grad()
#         testout = torch.cat((x[:,0], x[:,1]), 0)
        out = model(data)
#         preds.append(out.cpu())
#         print(out.shape)
#         print(torch.reshape(data.y, (100,8)))
        loss = mse_loss(out, data['ML'].y.reshape(batch_size,-1))
        loss.backward()
        optimizer.step()
      
    return loss
#     y, pred = torch.cat(ys, dim=0).cpu().detach().numpy(), torch.cat(preds, dim=0).detach().numpy()
    
#     return r2_score(y, pred)


def test(test_loader):
    model.eval()
    ys, preds = [], []
    for data in test_loader:
        data = data.to(device)
        ys.append(data['ML'].y.reshape(batch_size,-1))
        out = model(data)
        preds.append(out)
        
    y, pred = torch.cat(ys, dim=0), torch.cat(preds, dim=0)
    loss = r2_score(y.to("cpu").detach().numpy(), pred.to("cpu").detach().numpy())
#     print(pred)
#     print(y)
    return loss, pred, y

R2_accuracy = []
time1 = time.time()
for epoch in range(1, 100):
    train_loss = train(train_loader)
    test_f1, pred, y = test(test_loader)
    R2_accuracy.append(test_f1)
    print(f'Epoch: {epoch:02d}, Trainloss: {train_loss:.6f}, TestR2: {test_f1:.4f}')
time2 = time.time()

Epoch: 01, Trainloss: 0.896187, TestR2: -0.3236
Epoch: 02, Trainloss: 0.193142, TestR2: -0.4886
Epoch: 03, Trainloss: 0.183711, TestR2: -0.3022
Epoch: 04, Trainloss: 0.167672, TestR2: -0.2151
Epoch: 05, Trainloss: 0.146087, TestR2: -0.1698
Epoch: 06, Trainloss: 0.113474, TestR2: -0.1677
Epoch: 07, Trainloss: 0.080839, TestR2: -0.1735
Epoch: 08, Trainloss: 0.057150, TestR2: -0.1658
Epoch: 09, Trainloss: 0.044224, TestR2: -0.1118
Epoch: 10, Trainloss: 0.037142, TestR2: -0.0847
Epoch: 11, Trainloss: 0.034565, TestR2: -0.0642
Epoch: 12, Trainloss: 0.034308, TestR2: -0.0073
Epoch: 13, Trainloss: 0.032989, TestR2: 0.0157
Epoch: 14, Trainloss: 0.030893, TestR2: -0.0143
Epoch: 15, Trainloss: 0.031684, TestR2: 0.0452
Epoch: 16, Trainloss: 0.029462, TestR2: 0.0448
Epoch: 17, Trainloss: 0.029462, TestR2: 0.0805
Epoch: 18, Trainloss: 0.028359, TestR2: 0.0908
Epoch: 19, Trainloss: 0.027612, TestR2: 0.1094
Epoch: 20, Trainloss: 0.025387, TestR2: 0.1205
Epoch: 21, Trainloss: 0.024178, TestR2: 0.1316


In [6]:
for epoch in range(1, 200):
    train_loss = train(train_loader)
    test_f1, pred, y = test(test_loader)
    R2_accuracy.append(test_f1)
    print(f'Epoch: {epoch:02d}, Trainloss: {train_loss:.6f}, TestR2: {test_f1:.4f}')

Epoch: 01, Trainloss: 0.002355, TestR2: 0.7741
Epoch: 02, Trainloss: 0.002365, TestR2: 0.7733
Epoch: 03, Trainloss: 0.002376, TestR2: 0.7758
Epoch: 04, Trainloss: 0.002402, TestR2: 0.7804
Epoch: 05, Trainloss: 0.002443, TestR2: 0.7847
Epoch: 06, Trainloss: 0.002356, TestR2: 0.7877
Epoch: 07, Trainloss: 0.002202, TestR2: 0.7894
Epoch: 08, Trainloss: 0.002150, TestR2: 0.7901
Epoch: 09, Trainloss: 0.002134, TestR2: 0.7906
Epoch: 10, Trainloss: 0.002126, TestR2: 0.7905
Epoch: 11, Trainloss: 0.002120, TestR2: 0.7902
Epoch: 12, Trainloss: 0.002118, TestR2: 0.7899
Epoch: 13, Trainloss: 0.002116, TestR2: 0.7895
Epoch: 14, Trainloss: 0.002118, TestR2: 0.7892
Epoch: 15, Trainloss: 0.002125, TestR2: 0.7891
Epoch: 16, Trainloss: 0.002134, TestR2: 0.7892
Epoch: 17, Trainloss: 0.002145, TestR2: 0.7896
Epoch: 18, Trainloss: 0.002153, TestR2: 0.7903
Epoch: 19, Trainloss: 0.002154, TestR2: 0.7910
Epoch: 20, Trainloss: 0.002154, TestR2: 0.7918
Epoch: 21, Trainloss: 0.002153, TestR2: 0.7927
Epoch: 22, Tr

In [None]:
print("The training time:", round(time2-time1, 2), "s")

In [None]:
def getModelSize(model):
    param_size = 0
    param_sum = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
        param_sum += param.nelement()
    buffer_size = 0
    buffer_sum = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
        buffer_sum += buffer.nelement()
    all_size = (param_size + buffer_size) / 1024
    print('The model size is：{:.3f}KB'.format(all_size))
    return (param_size, param_sum, buffer_size, buffer_sum, all_size)
    
a, b, c, d, e = getModelSize(model)