In [1]:
import torch
print(torch.__version__)
print(torch.version.cuda)

1.9.0+cu111
11.1


In [2]:
from typing import Union, Tuple
from torch_geometric.typing import OptPairTensor, Adj, Size # Optional[Tensor], Union[Tensor, SparseTensor], Optional[Tuple[int, int]], all about data type

from torch import Tensor
from torch.nn import Linear
import torch.nn.functional as F
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing

In [3]:
import warnings
warnings.filterwarnings("ignore")
import os
import time
import random
from tqdm import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from joblib import Parallel, delayed

from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.model_selection import train_test_split

import torch
import torch.nn.functional as F
import torch_geometric.nn as gnn
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Dataset, Data, DataLoader
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.utils import softmax

In [4]:
inv = 0

In [5]:
df = pd.read_csv('input/CMS_trigger.csv')
print(df.columns)
df

scaler_1 = StandardScaler()
df.loc[:,'Phi_0':'MedianTheta'] = scaler_1.fit_transform(df.loc[:,'Phi_0':'MedianTheta']) # normalize

Index(['Unnamed: 0', 'Phi_0', 'Phi_2', 'Phi_3', 'Phi_4', 'Theta_0', 'Theta_2',
       'Theta_3', 'Theta_4', 'BendingAngle_0', 'BendingAngle_2',
       'BendingAngle_3', 'BendingAngle_4', 'TimeInfo_0', 'TimeInfo_2',
       'TimeInfo_3', 'TimeInfo_4', 'RingNumber_0', 'RingNumber_2',
       'RingNumber_3', 'RingNumber_4', 'Front_0', 'Front_2', 'Front_3',
       'Front_4', 'Mask_0', 'Mask_2', 'Mask_3', 'Mask_4',
       'PatternStraightness', 'Zone', 'MedianTheta', 'q/pt', 'PhiAngle',
       'EtaAngle'],
      dtype='object')


In [6]:
features = ['Phi_'+str(i) for i in [0,2,3,4]] + ['Theta_'+str(i) for i in [0,2,3,4]] + \
['Front_'+str(i) for i in [0,2,3,4]] + ['BendingAngle_'+str(i) for i in [0,2,3,4]] + \
['RingNumber_'+str(i) for i in [0,2,3,4]] + \
['TimeInfo_'+str(i) for i in [0,2,3,4]] + ['Mask_'+str(i) for i in [0,2,3,4]] + ['PatternStraightness'] + ['Zone'] + ['MedianTheta']
# edge_index = torch.tensor([(0,1),(1,2),(2,3),(3,2),(2,1),(1,0)], dtype=torch.long).T
edge_index = [(0,1),(1,2),(2,3),(3,2),(2,1),(1,0)]
# edge_index = [(0,1),(1,2),(2,3),(2,4),(1,0),(2,1),(3,2),(4,2)]
# edge_index = [(0,1),(1,2),(2,3),(2,4),(1,0),(0,2),(2,1),(3,2),(4,2),(2,5),(2,6),(5,2),(6,2)]
# edge_index = [(0,1),(1,2),(2,3),(2,4),(1,0),(2,1),(3,2),(4,2),(3,4),(4,3),(1,3),(1,4),(0,3),(0,4),(0,2)]

In [9]:
# x_train, x_test, pT_tr, pT_ts, inv_pT_tr, inv_pT_ts = train_test_split(df[features].to_numpy(), abs(1/df.loc[:,'q/pt']).to_numpy(), 1/abs(1/df.loc[:,'q/pt']).to_numpy(), test_size = 0.2, random_state = 1)
train_mask, test_mask = train_test_split(df['Unnamed: 0'].to_numpy(), test_size = 0.2, random_state = 1)
x_data = df[features].to_numpy()
x_data = np.concatenate([x_data,np.zeros([len(x_data),1])],1)
pT = abs(1/df.loc[:,'q/pt']).to_numpy()
inv_pT = 1/pT
if inv:
    label = inv_pT
else:
    label = pT
num_features = x_data.shape[-1]
print('Data shape: ' + str(x_data.shape))
print(pT.shape)
print('Len train: '+str(len(train_mask))+', Len test: '+str(len(test_mask)))
print('Num. features: '+str(num_features))

Data shape: (1179356, 32)
(1179356,)
Len train: 943484, Len test: 235872
Num. features: 32


In [10]:
def process_data(i):
    
    data = Data(x=torch.tensor(x_data[i].reshape(-1,4).T, dtype=torch.float), y=torch.tensor(label[i], dtype=torch.float),
                edge_index=torch.tensor(edge_index, dtype = torch.long).T)
    return data

In [11]:
class MPL(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(MPL, self).__init__(aggr='add')
        self.mlp1 = torch.nn.Linear(in_channels*2, out_channels)
        self.mlp2 = torch.nn.Linear(in_channels, out_channels)
        self.mlp3 = torch.nn.Linear(2*out_channels, 1)
        self.mlp4 = torch.nn.Linear(2*out_channels, 1)
        self.mlp5 = torch.nn.Linear(in_channels,16)
        self.mlp6 = torch.nn.Linear(out_channels,16)
        self.mlp7 = torch.nn.Linear(16,1)

    def forward(self, x, edge_index):
#         edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
        msg = self.propagate(edge_index, x=x)
        x = F.relu(self.mlp2(x))
        w1 = F.sigmoid(self.mlp3(torch.cat([x,msg], dim=1)))
        w2 = F.sigmoid(self.mlp4(torch.cat([x,msg], dim=1)))
        out = w1*msg + w2*x
        
        return out

    def message(self, x_i, x_j, edge_index):
        msg = F.relu(self.mlp1(torch.cat([x_i, x_j-x_i], dim=1)))
        w1 = F.tanh(self.mlp5(x_i))
        w2 = F.tanh(self.mlp6(msg))
        w = self.mlp7(w1*w2)
        w = softmax(w, edge_index[0])
        return msg*w

In [21]:
class MPNN(torch.nn.Module):
    def __init__(self):
        super(MPNN, self).__init__()
        self.conv1 = MPL(int(len(features)/4)+1,128)
        self.conv2 = MPL(128,64)
        self.conv3 = MPL(64,64)
        self.conv4 = MPL(64,64)
        self.lin1 = torch.nn.Linear(128, 128)
        self.lin2 = torch.nn.Linear(128, 16)
        self.lin3 = torch.nn.Linear(16, 16)
        self.lin4 = torch.nn.Linear(16, 1)
        self.lin5 = torch.nn.Linear(128, 128)
        self.lin6 = torch.nn.Linear(128, 16)
        self.lin7 = torch.nn.Linear(16, 16)
        self.lin8 = torch.nn.Linear(16, 1)
        self.global_att_pool1 = gnn.GlobalAttention(torch.nn.Sequential(torch.nn.Linear(64, 1)))
        self.global_att_pool2 = gnn.GlobalAttention(torch.nn.Sequential(torch.nn.Linear(64, 1)))
    
    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x1 = self.global_att_pool1(x, batch)
        x = F.relu(self.conv3(x, edge_index))
        x = F.relu(self.conv4(x, edge_index))
        x2 = self.global_att_pool2(x, batch)
        x_out = torch.cat([x1, x2], dim=1)
        
#         x = F.relu(self.lin1(x_out))
#         x = F.relu(self.lin2(x))
#         x = self.lin3(x)
#         x = self.lin4(x).squeeze(1)
        
        x = F.relu(self.lin1(x_out))
        x = F.relu(self.lin2(x))
        x = self.lin3(x)
        xf1 = self.lin4(x).squeeze(1)
        x = F.relu(self.lin5(x_out))
        x = F.relu(self.lin6(x))
        x = self.lin7(x)
        xf2 = F.sigmoid(self.lin8(x).squeeze(1))
        
        return xf1, xf2

In [22]:
class MyDataset(Dataset):
    def __init__(self, indices=list(range(len(df))), transform=None):
        self.transform = transform
        self.indices = indices
    
    def __getitem__(self, idx):
        return process_data(self.indices[idx])
    
    def __len__(self):
        return len(self.indices)

In [23]:
# def mse_custom(outputs, labels):
#     weights = torch.tensor(labels<80, dtype=torch.float).to(device)*labels + torch.tensor(labels>=80, dtype=torch.float).to(device)*torch.tensor(labels<160, dtype=torch.float).to(device)*labels*2.4 + torch.tensor(labels>=160, dtype=torch.float).to(device)*10
#     error = weights*(((outputs-labels)/labels)**2)
#     return torch.mean(error)

In [24]:
def multitask_mse(outputs, labels):
    weights = torch.tensor(labels<80, dtype=torch.float).to(device)*labels + torch.tensor(labels>=80, dtype=torch.float).to(device)*torch.tensor(labels<160, dtype=torch.float).to(device)*labels*2.4 + torch.tensor(labels>=160, dtype=torch.float).to(device)*10
    error = weights*(((outputs-labels)/labels)**2)
    return torch.mean(error)

In [25]:
scale = 4e2
mse2 = torch.nn.MSELoss()

In [26]:
def train(prog_bar = True):
    
    train_losses, test_losses = list(), list()
    min_test_loss = float('inf')
    train_loader = DataLoader(MyDataset(indices=train_mask), batch_size=batch_size)
    test_loader = DataLoader(MyDataset(indices=test_mask), batch_size=batch_size)
    
    for epoch in range(epochs):
        train_loss = 0
        test_loss = 0
        if prog_bar:
            pbar = tqdm(train_loader,position=0)
        else:
            pbar = train_loader
            
        # train
        for data in pbar:
            data = data.to(device)
            optimizer.zero_grad()
            outputs1, outputs2 = model(data)
            labels = data.y
            loss = multitask_mse(outputs1, data.y) + scale*mse2(outputs2, 1/data.y)
            loss.backward()
            optimizer.step()
            if prog_bar:
                pbar.set_description('pTLoss: '+str(loss.cpu().detach().numpy()))
                train_loss += loss.cpu().detach().numpy()/len(train_loader)
                
        # test
        for data in test_loader:
            data = data.to(device)
            optimizer.zero_grad()
            outputs1, outputs2 = model(data)
            labels = data.y
            loss = multitask_mse(outputs1, data.y) + scale*mse2(outputs2, 1/data.y)
            test_loss += loss.cpu().detach().numpy()/len(test_loader)
        if test_loss<min_test_loss:
            print('Min loss changed from '+str(min_test_loss)+' to '+str(test_loss))
            min_test_loss = test_loss
            torch.save(model.state_dict(), model_name)
        train_losses.append(train_loss)
        test_losses.append(test_loss)
        if epoch > 10 and min(test_losses[-7:])>min_test_loss+1e-9:
            break
        lr_scheduler.step(test_loss)
        print('Epoch: ', str(epoch+1)+'/'+str(epochs),'| Training pTLoss: ', train_loss, '| Testing pTLoss: ', test_loss)
        
        if not prog_bar:
            plt.plot(train_losses, label="Train Loss")
            plt.plot(test_losses, label="Validation Loss")
            plt.xlabel("# Epoch")
            plt.ylabel("Loss")
            plt.legend(loc='upper right')
            plt.show()
    return train_losses, test_losses

In [27]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')
batch_size = 2**14
epochs = 50
model = MPNN().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# optimizer = Ranger(model.parameters(), lr=0.00005, weight_decay=5e-4) # loss: ~21
# optimizer = RangerLars(model.parameters(), lr=0.01, weight_decay=5e-4)
# optimizer = torch.optim.Adam(model.parameters(), lr=0.002, weight_decay=5e-4)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, verbose=True, patience=1, factor=0.5)

model_name = 'GNN_v19_road_vars_inv_' + str(inv) + '.pth'
train_losses, test_losses = train(prog_bar=True)

pTLoss: 8.074457: 100%|████████████████████████████████████████████████████████████████| 58/58 [01:07<00:00,  1.17s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from inf to 7.990189425150554
Epoch:  1/50 | Training pTLoss:  13.672740475884797 | Testing pTLoss:  7.990189425150554


pTLoss: 6.5709486: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:09<00:00,  1.20s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 7.990189425150554 to 6.649222342173258
Epoch:  2/50 | Training pTLoss:  7.1317362456486135 | Testing pTLoss:  6.649222342173258


pTLoss: 6.1417203: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:09<00:00,  1.20s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 6.649222342173258 to 6.235939566294352
Epoch:  3/50 | Training pTLoss:  6.468880382077448 | Testing pTLoss:  6.235939566294352


pTLoss: 5.5841727: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:08<00:00,  1.18s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 6.235939566294352 to 5.395938841501871
Epoch:  4/50 | Training pTLoss:  5.976469664738095 | Testing pTLoss:  5.395938841501871


pTLoss: 3.0100605: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:09<00:00,  1.19s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 5.395938841501871 to 3.4226952393849692
Epoch:  5/50 | Training pTLoss:  3.8853866478492476 | Testing pTLoss:  3.4226952393849692


pTLoss: 2.82546: 100%|█████████████████████████████████████████████████████████████████| 58/58 [01:09<00:00,  1.19s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 3.4226952393849692 to 2.877654774983724
Epoch:  6/50 | Training pTLoss:  2.895328147657986 | Testing pTLoss:  2.877654774983724


pTLoss: 2.675417: 100%|████████████████████████████████████████████████████████████████| 58/58 [01:08<00:00,  1.18s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.877654774983724 to 2.7461835225423172
Epoch:  7/50 | Training pTLoss:  2.6869321486045568 | Testing pTLoss:  2.7461835225423172


pTLoss: 2.5276566: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:04<00:00,  1.12s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.7461835225423172 to 2.7188809553782147
Epoch:  8/50 | Training pTLoss:  2.5969521752719213 | Testing pTLoss:  2.7188809553782147


pTLoss: 2.4335072: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:04<00:00,  1.12s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.7188809553782147 to 2.6787744839986165
Epoch:  9/50 | Training pTLoss:  2.5687692658654577 | Testing pTLoss:  2.6787744839986165


pTLoss: 2.4505854: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:05<00:00,  1.13s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.6787744839986165 to 2.5832404772440594
Epoch:  10/50 | Training pTLoss:  2.521184403320838 | Testing pTLoss:  2.5832404772440594


pTLoss: 2.5255303: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:07<00:00,  1.17s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch:  11/50 | Training pTLoss:  2.533536795912117 | Testing pTLoss:  2.632106415430705


pTLoss: 2.4332008: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:06<00:00,  1.14s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.5832404772440594 to 2.4845922470092776
Epoch:  12/50 | Training pTLoss:  2.489528142172714 | Testing pTLoss:  2.4845922470092776


pTLoss: 2.346668: 100%|████████████████████████████████████████████████████████████████| 58/58 [01:04<00:00,  1.12s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.4845922470092776 to 2.438681077957153
Epoch:  13/50 | Training pTLoss:  2.4470208554432307 | Testing pTLoss:  2.438681077957153


pTLoss: 2.3039093: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:05<00:00,  1.12s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch:  14/50 | Training pTLoss:  2.4018620293715904 | Testing pTLoss:  2.457957553863526


pTLoss: 2.3301713: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:04<00:00,  1.10s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch    15: reducing learning rate of group 0 to 5.0000e-03.
Epoch:  15/50 | Training pTLoss:  2.4082376258126614 | Testing pTLoss:  2.5129272143046064


pTLoss: 2.2134194: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:03<00:00,  1.10s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.438681077957153 to 2.367256991068522
Epoch:  16/50 | Training pTLoss:  2.321836352348327 | Testing pTLoss:  2.367256991068522


pTLoss: 2.1888657: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:07<00:00,  1.16s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.367256991068522 to 2.359126091003418
Epoch:  17/50 | Training pTLoss:  2.2951661800516066 | Testing pTLoss:  2.359126091003418


pTLoss: 2.176219: 100%|████████████████████████████████████████████████████████████████| 58/58 [01:11<00:00,  1.23s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.359126091003418 to 2.3393230915069583
Epoch:  18/50 | Training pTLoss:  2.285623118795198 | Testing pTLoss:  2.3393230915069583


pTLoss: 2.1642303: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:13<00:00,  1.27s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.3393230915069583 to 2.3233396371205646
Epoch:  19/50 | Training pTLoss:  2.275324171987073 | Testing pTLoss:  2.3233396371205646


pTLoss: 2.1582892: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:14<00:00,  1.28s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.3233396371205646 to 2.314676268895467
Epoch:  20/50 | Training pTLoss:  2.265161769143467 | Testing pTLoss:  2.314676268895467


pTLoss: 2.1745903: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:10<00:00,  1.22s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch:  21/50 | Training pTLoss:  2.265319199397646 | Testing pTLoss:  2.3363502184549967


pTLoss: 2.1739936: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:12<00:00,  1.24s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch    22: reducing learning rate of group 0 to 2.5000e-03.
Epoch:  22/50 | Training pTLoss:  2.258940577507019 | Testing pTLoss:  2.3326062520345054


pTLoss: 2.1300538: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:11<00:00,  1.23s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.314676268895467 to 2.2735421180725095
Epoch:  23/50 | Training pTLoss:  2.215145242625269 | Testing pTLoss:  2.2735421180725095


pTLoss: 2.1216772: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:10<00:00,  1.22s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.2735421180725095 to 2.269461933771769
Epoch:  24/50 | Training pTLoss:  2.1923650009878752 | Testing pTLoss:  2.269461933771769


pTLoss: 2.1250334: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:09<00:00,  1.20s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch:  25/50 | Training pTLoss:  2.187444444360404 | Testing pTLoss:  2.271911811828613


pTLoss: 2.1266017: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:09<00:00,  1.19s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch    26: reducing learning rate of group 0 to 1.2500e-03.
Epoch:  26/50 | Training pTLoss:  2.1837667966711103 | Testing pTLoss:  2.2727776368459063


pTLoss: 2.083982: 100%|████████████████████████████████████████████████████████████████| 58/58 [01:07<00:00,  1.16s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.269461933771769 to 2.249190886815389
Epoch:  27/50 | Training pTLoss:  2.1550417398584303 | Testing pTLoss:  2.249190886815389


pTLoss: 2.077077: 100%|████████████████████████████████████████████████████████████████| 58/58 [01:10<00:00,  1.21s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.249190886815389 to 2.246477794647217
Epoch:  28/50 | Training pTLoss:  2.150848302347907 | Testing pTLoss:  2.246477794647217


pTLoss: 2.0686374: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:10<00:00,  1.22s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.246477794647217 to 2.2455040772755943
Epoch:  29/50 | Training pTLoss:  2.146623819038786 | Testing pTLoss:  2.2455040772755943


pTLoss: 2.0647416: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:12<00:00,  1.24s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.2455040772755943 to 2.2449838002522786
Epoch:  30/50 | Training pTLoss:  2.142912359073244 | Testing pTLoss:  2.2449838002522786


pTLoss: 2.0613923: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:10<00:00,  1.22s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.2449838002522786 to 2.2433238983154298
Epoch:  31/50 | Training pTLoss:  2.1397510376469837 | Testing pTLoss:  2.2433238983154298


pTLoss: 2.058387: 100%|████████████████████████████████████████████████████████████████| 58/58 [01:07<00:00,  1.17s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch:  32/50 | Training pTLoss:  2.1368142983009073 | Testing pTLoss:  2.2433718522389725


pTLoss: 2.057573: 100%|████████████████████████████████████████████████████████████████| 58/58 [01:07<00:00,  1.17s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.2433238983154298 to 2.241991106669108
Epoch:  33/50 | Training pTLoss:  2.1345103341957614 | Testing pTLoss:  2.241991106669108


pTLoss: 2.055706: 100%|████████████████████████████████████████████████████████████████| 58/58 [01:07<00:00,  1.16s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.241991106669108 to 2.241777181625366
Epoch:  34/50 | Training pTLoss:  2.131927823198253 | Testing pTLoss:  2.241777181625366


pTLoss: 2.0523095: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:26<00:00,  1.50s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.241777181625366 to 2.2407493273417156
Epoch:  35/50 | Training pTLoss:  2.1296938616653978 | Testing pTLoss:  2.2407493273417156


pTLoss: 2.0503726: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:48<00:00,  1.87s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.2407493273417156 to 2.238966369628906
Epoch:  36/50 | Training pTLoss:  2.1267667227777944 | Testing pTLoss:  2.238966369628906


pTLoss: 2.0479612: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:47<00:00,  1.86s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.238966369628906 to 2.2365836461385094
Epoch:  37/50 | Training pTLoss:  2.124503302163091 | Testing pTLoss:  2.2365836461385094


pTLoss: 2.041567: 100%|████████████████████████████████████████████████████████████████| 58/58 [01:47<00:00,  1.86s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.2365836461385094 to 2.235635995864868
Epoch:  38/50 | Training pTLoss:  2.1217072092253586 | Testing pTLoss:  2.235635995864868


pTLoss: 2.0385413: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:48<00:00,  1.87s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.235635995864868 to 2.234046173095703
Epoch:  39/50 | Training pTLoss:  2.1185422642477616 | Testing pTLoss:  2.234046173095703


pTLoss: 2.0353394: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:42<00:00,  1.77s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch:  40/50 | Training pTLoss:  2.1155632857618665 | Testing pTLoss:  2.2345340569814045


pTLoss: 2.03399: 100%|█████████████████████████████████████████████████████████████████| 58/58 [01:44<00:00,  1.79s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch    41: reducing learning rate of group 0 to 6.2500e-04.
Epoch:  41/50 | Training pTLoss:  2.113089676561027 | Testing pTLoss:  2.234995222091675


pTLoss: 2.0060465: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:42<00:00,  1.77s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch:  42/50 | Training pTLoss:  2.0989851437765976 | Testing pTLoss:  2.2359059333801268


pTLoss: 2.0049033: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:42<00:00,  1.77s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch    43: reducing learning rate of group 0 to 3.1250e-04.
Epoch:  43/50 | Training pTLoss:  2.0940335680698525 | Testing pTLoss:  2.2367039839426677


pTLoss: 1.9883916: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:43<00:00,  1.78s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.234046173095703 to 2.218719228108724
Epoch:  44/50 | Training pTLoss:  2.079568735484419 | Testing pTLoss:  2.218719228108724


pTLoss: 1.9859326: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:45<00:00,  1.83s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch:  45/50 | Training pTLoss:  2.0746730730451386 | Testing pTLoss:  2.2189329306284584


pTLoss: 1.9842829: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:49<00:00,  1.89s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch    46: reducing learning rate of group 0 to 1.5625e-04.
Epoch:  46/50 | Training pTLoss:  2.073077906822337 | Testing pTLoss:  2.219063138961792


pTLoss: 1.976331: 100%|████████████████████████████████████████████████████████████████| 58/58 [01:49<00:00,  1.88s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Min loss changed from 2.218719228108724 to 2.213063987096151
Epoch:  47/50 | Training pTLoss:  2.064596307688747 | Testing pTLoss:  2.213063987096151


pTLoss: 1.9739523: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:53<00:00,  1.96s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch:  48/50 | Training pTLoss:  2.063038678004824 | Testing pTLoss:  2.2134221394856772


pTLoss: 1.9729404: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:51<00:00,  1.93s/it]
  0%|                                                                                           | 0/58 [00:00<?, ?it/s]

Epoch    49: reducing learning rate of group 0 to 7.8125e-05.
Epoch:  49/50 | Training pTLoss:  2.0620676854561126 | Testing pTLoss:  2.213593117396037


pTLoss: 1.9680057: 100%|███████████████████████████████████████████████████████████████| 58/58 [01:53<00:00,  1.95s/it]


Epoch:  50/50 | Training pTLoss:  2.058266002556373 | Testing pTLoss:  2.2147322018941247


In [28]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loaded_model = MPNN().to(device)
rand_model = MPNN().to(device)
loaded_model.load_state_dict(torch.load(model_name))

<All keys matched successfully>

In [29]:
test_loader = DataLoader(MyDataset(indices=test_mask), batch_size=batch_size)
# test
test_los = 0
preds = []
preds2 = []
for data in tqdm(test_loader,position=0):
    data = data.to(device)
    outputs1, outputs2 = model(data)
    labels = data.y
    loss = multitask_mse(outputs1, data.y) + scale*mse2(outputs2, 1/data.y)
    preds.append(outputs1.cpu().detach())
    preds2.append(outputs2.cpu().detach())
    test_los += loss.cpu().detach().numpy()/len(test_loader)
print('Test_loss: '+str(test_los))

100%|██████████████████████████████████████████████████████████████████████████████████| 15/15 [00:21<00:00,  1.41s/it]

Test_loss: 2.2147322018941247





In [30]:
# save to csv
pred_ls = [float(i) for p in preds for i in p]
pred_ls2 = [float(i) for p in preds2 for i in p]
df_pred = pd.DataFrame(pred_ls)
df_pred2 = pd.DataFrame(pred_ls2)
df_pred.to_csv('GNN_v19_road_vars_inv_' + str(0) + '.csv')
df_pred2.to_csv('GNN_v19_road_vars_inv_' + str(1) + '.csv')
print('Files saved!')

Files saved!
