In [6]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [1]:
import torch
pytorch_version = f"torch-{torch.__version__}.html"
!pip install --no-index torch-scatter -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-sparse -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-cluster -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install --no-index torch-spline-conv -f https://pytorch-geometric.com/whl/$pytorch_version
!pip install torch-geometric pymatgen tqdm torchaudio torchvision pytorch-lightning


Looking in links: https://pytorch-geometric.com/whl/torch-1.13.0+cu116.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_scatter-2.1.0%2Bpt113cu116-cp38-cp38-linux_x86_64.whl (9.4 MB)
[K     |████████████████████████████████| 9.4 MB 31.5 MB/s 
[?25hInstalling collected packages: torch-scatter
Successfully installed torch-scatter-2.1.0+pt113cu116
Looking in links: https://pytorch-geometric.com/whl/torch-1.13.0+cu116.html
Collecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_sparse-0.6.16%2Bpt113cu116-cp38-cp38-linux_x86_64.whl (4.5 MB)
[K     |████████████████████████████████| 4.5 MB 19.5 MB/s 
Installing collected packages: torch-sparse
Successfully installed torch-sparse-0.6.16+pt113cu116
Looking in links: https://pytorch-geometric.com/whl/torch-1.13.0+cu116.html
Collecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-1.13.0%2Bcu116/torch_cluster-1.6.0%2Bpt113cu116-cp38-cp38-linux_x86_64

## 1. Loading the Dataset

In [2]:
import torch
import torch_geometric
import numpy as np 
import os
import pickle
import pandas as pd
import warnings
from tqdm import tqdm
from torch_geometric.data import Dataset, Data
from torch.nn import Embedding, ModuleList
from pymatgen.ext.matproj import MPRester

warnings.filterwarnings('ignore')

In [3]:
class GaussianFilter(object):
    def __init__(self, start=0.1, stop=5.0, n_points=100, sigma=0.5):
        self.start = start
        self.stop = stop
        self.n_points = n_points
        self.sigma = sigma
    def __call__(self, feature: torch.Tensor) -> torch.Tensor:
        feature = feature.view(len(feature), 1)
        r_0 = torch.linspace(self.start, self.stop, self.n_points).repeat(feature.shape[0], 1)
        return torch.exp(-(feature - r_0)/self.sigma)

In [4]:
class CrystalDataset(Dataset):
    def __init__(self, root, api_key, species=None, mp_ids=None, cutoff=5.0, 
                 targets=['band_gap'], global_attrs=['volume', 'spacegroup'], node_features=None, 
                 filter=None, embeddings=None, transform=None, pre_transform=None):
        """
        root = Where the dataset should be stored. This folder is split
        into raw_dir (downloaded dataset) and processed_dir (processed data). 
        """
        assert (species is not None) or (mp_ids is not None)
        self.cutoff = cutoff
        self.species = species
        self.api_key = api_key
        self.targets = targets
        self.filter = filter
        self.node_features = node_features
        self.mp_ids = mp_ids
        self.global_attrs = global_attrs
        self.embeddings = embeddings
        super(CrystalDataset, self).__init__(root, transform, pre_transform)
        
        
    @property
    def raw_file_names(self):
        """ If this file exists in raw_dir, the download is not triggered.  
        """
        return ['structures.dump', 'vasp_data.dump']

    @property
    def processed_file_names(self):
        """ If these files are found in processed_dir, processing is skipped"""
        with open(self.raw_paths[0], 'rb') as f:
            structures = pickle.load(f)
        with open(self.raw_paths[1], 'rb') as f:
            vasp_data = pickle.load(f)
        self.data = [structures, vasp_data]
        return [f'data_{idx}.pt' for idx in range(len(self.data[0]))]
        #return 'not_implemented.pt'

    def download(self):
        rester = MPRester(self.api_key)
        if self.mp_ids is None:
            assert self.species is not None
            entries = rester.get_entries_in_chemsys(self.species)
            entries_ids = [entry.entry_id for entry in entries]
        else:
            if isinstance(self.mp_ids, pd.DataFrame):
                entries_ids_raw = list(self.mp_ids[0].values)
            else:
                entries_ids_raw = list(self.mp_ids)
            entries_ids = [rester.get_materials_id_from_task_id(id) for id in tqdm(entries_ids_raw)]
        structures = []
        vasp_data = []
        for id in tqdm(entries_ids):
            vasp_data.append(rester.get_data(id))
            structures.append(rester.get_structure_by_material_id(id))
        with open(self.raw_paths[0], 'wb') as f:
            pickle.dump(structures, f)
        with open(self.raw_paths[1], 'wb') as f:
            pickle.dump(vasp_data, f)
        self.data = [structures, vasp_data]

    def process(self):
        with open(self.raw_paths[0], 'rb') as f:
            structures = pickle.load(f)
        with open(self.raw_paths[1], 'rb') as f:
            vasp_data = pickle.load(f)
        self.data = [structures, vasp_data]
        for index, structure, vdata in tqdm(zip(range(len(structures)), structures, vasp_data), total=len(structures)):
            # Get node features
            node_feats = self._get_node_features(structure)
            # Get edge features
            edge_feats = self._get_edge_features(structure)
            # Get adjacency info
            edge_index = self._get_adjacency_info(structure)
            # Get labels info
            labels = self._get_labels(vdata)
            # Get global info
            global_feats = self._get_global_features(structure)

            # Create data object
            data = Data(x=node_feats, 
                        edge_index=edge_index,
                        edge_attr=edge_feats,
                        y=labels,
                        global_attr=global_feats
                        ) 
            torch.save(data, 
                    os.path.join(self.processed_dir, 
                                 f'data_{index}.pt'))

    def _get_node_features(self, structure):
        """ 
        This will return a matrix / 2d array of the shape
        [Number of Nodes, Node Feature size]
        """
        all_node_feats = []

        for element in structure.species:
            node_feats = []
            if self.node_features is None:
                # Feature 1: Atomic number
                if self.embeddings is not None:
                    node_feats.extend(self.embeddings[0](torch.tensor(element.Z)))
                else:
                    node_feats.append(element.Z)                
                # Feature 2: Atom radius
                node_feats.append(element.atomic_radius)
                # Feature 3: Average ionic radius
                node_feats.append(element.average_ionic_radius)
                # Feature 4: Ionization energy
                node_feats.append(element.ionization_energy)
                # Feature 5: Electronegativity
                node_feats.append(element.X)
                # Feature 6: Mendeleev number
                node_feats.append(element.mendeleev_no)
                # Feature 7: Electron affinity
                node_feats.append(element.electron_affinity)
                # Feature 8: Min oxidation state
                node_feats.append(element.min_oxidation_state)
                # Feature 9: Max oxitaion state
                node_feats.append(element.max_oxidation_state)
            else:
                node_feats = []
                if self.embeddings is not None and 'Z' in self.node_features:
                    node_feats.extend(self.embeddings[0](torch.tensor(element.Z)))
                    node_feats.extend([getattr(element, attr) for attr in self.node_features if attr is not 'Z'])
                else:
                    node_feats.extend([getattr(element, attr) for attr in self.node_features])

            # Append node features to matrix
            all_node_feats.append(node_feats)
        return torch.tensor(all_node_feats, dtype=torch.float)


    def _get_edge_features(self, structure):
        """ 
        This will return a matrix / 2d array of the shape
        [Number of edges, Edge Feature size]
        """
        
        all_edge_feats = []
        features = torch.tensor(structure.get_neighbor_list(self.cutoff)[-1], dtype=torch.float)
        if self.filter is not None:
            features = self.filter(features)

        return features

    def _get_adjacency_info(self, structure):
        """
        This will return an adjacency matrix in COO format / 2d array of the shape
        [2, Number of edges]
        """
        edge_indices = torch.tensor(structure.get_neighbor_list(self.cutoff)[:2])
        return edge_indices

    def _get_labels(self, label):
        assert isinstance(self.targets, (tuple, list))
        assert np.all([item in label[-1] for item in self.targets])
        return torch.tensor([label[-1][item] for item in self.targets], dtype=torch.float32)

    def _get_global_features(self, structure):
        global_info = []
        for attr in self.global_attrs:
            assert attr in ['volume', 'spacegroup']
            if attr == 'volume':
                global_info.append(structure.volume)
            if attr == 'spacegroup':
                spg = structure.get_space_group_info()[1]
                if self.embeddings is not None:
                    global_info.extend(self.embeddings[1](torch.tensor(spg)))
                else:
                    global_info.append(spg)
        return torch.tensor(global_info, dtype=torch.float).view(1, len(global_info)).contiguous()
            

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        """ - Equivalent to __getitem__ in pytorch
            - Is not needed for PyG's InMemoryDataset
        """
        data = torch.load(os.path.join(self.processed_dir, 
                                 f'data_{idx}.pt'))
        return data

In [10]:
# Uncomment the following lines if you want to manually download the data and convert it to the Dataset

mp_ids = pd.read_csv('https://raw.githubusercontent.com/txie-93/cgcnn/master/data/material-data/mp-ids-3402.csv', header=None)[0].values
root = 'gdrive/MyDrive/Colab Notebooks/GNN/data/MP3402'
API_KEY = '2Uihe3wfq5ac6tMF'
filter = GaussianFilter(start=0.5, stop=5.0, n_points=100)
embeddings = ModuleList([Embedding(95, 16), Embedding(230, 32)])
dataset = CrystalDataset(root=root, api_key=API_KEY, filter=filter, mp_ids=mp_ids, embeddings=embeddings)
dataset.num_global_features = dataset[0].global_attr.shape[1]

100%|██████████| 3402/3402 [07:51<00:00,  7.21it/s]
100%|██████████| 3402/3402 [16:14<00:00,  3.49it/s]
Processing...
100%|██████████| 3402/3402 [00:54<00:00, 62.91it/s]
Done!


In [7]:
# Use this line to continue with already prepared Dataset

# dataset = torch.load('gdrive/MyDrive/Colab Notebooks/crystal_dataset.pt')

In [11]:
dataset

CrystalDataset(3402)

In [12]:
print(f"Total number of entries: {len(dataset)}")
print(f'Number of node features: {dataset.num_node_features}')
print(f'Number of edge features: {dataset.num_edge_features}')
print(f'Number of global features: {dataset.num_global_features}')

Total number of entries: 3402
Number of node features: 24
Number of edge features: 100
Number of global features: 33


## 2. Building the MegNet block

In [13]:
from typing import Union, Tuple
from torch_geometric.typing import PairTensor, Adj, OptTensor, Size, SparseTensor

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Linear, BatchNorm1d, ModuleList
from torch_geometric.nn.conv import MessagePassing


class MegNetBlock(MessagePassing):
    def __init__(self, node_channels: Tuple, edge_channels: Tuple, global_channels: Tuple,
                 aggr: str = 'mean', batch_norm: bool = True, bias: bool = True, **kwargs):
        super(MegNetBlock, self).__init__(aggr=aggr, **kwargs)
        self.node_channels = node_channels
        self.edge_channels = edge_channels
        self.global_channels = global_channels
        self.batch_norm = batch_norm

        self.node_lin = ModuleList([Linear(node_channels[0], node_channels[1], bias=bias), Linear(node_channels[1], node_channels[2], bias=bias)])
        self.edge_lin = ModuleList([Linear(edge_channels[0], edge_channels[1], bias=bias), Linear(edge_channels[1], edge_channels[2], bias=bias)])
        self.global_lin = ModuleList([Linear(global_channels[0], global_channels[1], bias=bias), Linear(global_channels[1], global_channels[2], bias=bias)])

        self.phi_edge = ModuleList([Linear(2*node_channels[2]+edge_channels[2]+global_channels[2], 2*node_channels[2]+edge_channels[2]+global_channels[2], bias=bias),
                                    Linear(2*node_channels[2]+edge_channels[2]+global_channels[2], 2*node_channels[2]+edge_channels[2]+global_channels[2], bias=bias), 
                                    Linear(2*node_channels[2]+edge_channels[2]+global_channels[2], edge_channels[2], bias=bias)])
        self.phi_node = ModuleList([Linear(edge_channels[2]+node_channels[2]+global_channels[2], edge_channels[2]+node_channels[2]+global_channels[2], bias=bias),
                                    Linear(edge_channels[2]+node_channels[2]+global_channels[2], edge_channels[2]+node_channels[2]+global_channels[2], bias=bias), 
                                    Linear(edge_channels[2]+node_channels[2]+global_channels[2], node_channels[2], bias=bias)])
        self.phi_global = ModuleList([Linear(node_channels[2]+edge_channels[2]+global_channels[2], node_channels[2]+edge_channels[2]+global_channels[2], bias=bias),
                                      Linear(node_channels[2]+edge_channels[2]+global_channels[2], node_channels[2]+edge_channels[2]+global_channels[2], bias=bias),
                                      Linear(node_channels[2]+edge_channels[2]+global_channels[2], global_channels[2], bias=bias)])

        self.reset_parameters()

    def reset_parameters(self):
        for module in self.node_lin:
            module.reset_parameters()
        for module in self.edge_lin:
            module.reset_parameters()
        for module in self.global_lin:
            module.reset_parameters()
        for module in self.phi_edge:
            module.reset_parameters()
        for module in self.phi_node:
            module.reset_parameters()
        for module in self.phi_global:
            module.reset_parameters()

    def propagate(self, *args, **kwargs):
        edge_index = kwargs['edge_index']
        size = self.__check_input__(edge_index, None)
        # Run "fused" message and aggregation (if applicable).
        if (isinstance(edge_index, SparseTensor) and self.fuse
                and not self.__explain__):
            raise NotImplementedError
            coll_dict = self.__collect__(self.__fused_user_args__, edge_index,
                                         size, kwargs)

            msg_aggr_kwargs = self.inspector.distribute(
                'message_and_aggregate', coll_dict)
            out = self.message_and_aggregate(edge_index, **msg_aggr_kwargs)

            update_kwargs = self.inspector.distribute('update', coll_dict)
            return self.update(out, **update_kwargs)

        # Otherwise, run both functions in separation.
        elif isinstance(edge_index, Tensor) or not self.fuse:
            coll_dict = self.__collect__(self.__user_args__, edge_index, size,
                                         kwargs)

            msg_kwargs = self.inspector.distribute('message', coll_dict)
            out = self.message(**msg_kwargs)
            new_edge_attr = out.clone()

            # For `GNNExplainer`, we require a separate message and aggregate
            # procedure since this allows us to inject the `edge_mask` into the
            # message passing computation scheme.
            # if self.__explain__:
            #     edge_mask = self.__edge_mask__.sigmoid()
            #     # Some ops add self-loops to `edge_index`. We need to do the
            #     # same for `edge_mask` (but do not train those).
            #     if out.size(self.node_dim) != edge_mask.size(0):
            #         loop = edge_mask.new_ones(size[0])
            #         edge_mask = torch.cat([edge_mask, loop], dim=0)
            #     assert out.size(self.node_dim) == edge_mask.size(0)
            #     out = out * edge_mask.view([-1] + [1] * (out.dim() - 1))

            aggr_kwargs = self.inspector.distribute('aggregate', coll_dict)
            out = self.aggregate(out, **aggr_kwargs)

            update_kwargs = self.inspector.distribute('update', coll_dict)
            edge_embedding = self.update(out, **update_kwargs)
            return edge_embedding, new_edge_attr

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj, edge_attr: OptTensor, global_attr: Tensor, 
                node_batch: Tensor, edge_batch: Tensor, size: Size = None) -> Tensor:
        """"""

        for lin in self.node_lin:
            x = F.softplus(lin(x))
        
        for lin in self.edge_lin:
            edge_attr = F.softplus(lin(edge_attr))

        for lin in self.global_lin:
            global_attr = F.softplus(lin(global_attr))

        if isinstance(x, Tensor):
            x: PairTensor = (x, x)

        edge_embedding, new_edge_attr = self.propagate(edge_index=edge_index, x=x, edge_attr=edge_attr, global_attr=global_attr, size=size, edge_batch=edge_batch)
        new_node_attr = torch.cat([x[0], edge_embedding, global_attr[node_batch]], dim=-1)

        for module in self.phi_node:
            new_node_attr = F.softplus(module(new_node_attr))

        new_global_attr = torch.cat([global_mean_pool(new_node_attr, node_batch), global_mean_pool(new_edge_attr, edge_batch), global_attr], dim=-1)

        for module in self.phi_global:
            new_global_attr = F.softplus(module(new_global_attr))
        
        new_node_attr = new_node_attr + x[0]
        new_edge_attr = new_edge_attr + edge_attr
        new_global_attr = new_global_attr + global_attr

        return new_node_attr, new_edge_attr, new_global_attr


    def message(self, x_i, x_j, edge_attr: OptTensor, global_attr: Tensor, edge_batch:Tensor) -> Tensor:
        new_edge_attr = torch.cat([x_i, x_j, edge_attr, global_attr[edge_batch]], dim=-1)
        for module in self.phi_edge:
            new_edge_attr = F.softplus(module(new_edge_attr))
        return new_edge_attr

    def __repr__(self):
        return '{}(node_channels={}, edge_channels={}, global_channels={})'.format(
                self.__class__.__name__, self.node_channels, self.edge_channels, self.global_channels,)

In [14]:
node_channels=(dataset.num_node_features, 64, 32)
edge_channels=(dataset.num_edge_features, 64, 32)
global_channels=(dataset.num_global_features, 64, 32)
block = MegNetBlock(node_channels, edge_channels, global_channels)

## 3. Building MegNet Model

In [15]:
from torch.nn import Linear, Embedding
import torch.nn.functional as F
from torch_geometric.nn import CGConv
from torch.nn import ModuleList, Softplus
from torch_geometric.nn import global_mean_pool, Set2Set

class MegNet(torch.nn.Module):
    def __init__(self, node_channels, edge_channels, global_channels, out_channels,
                 pooling_args=(3,), n_blocks=3, batch_norm=False, bias=True):
        super(MegNet, self).__init__()
        # torch.manual_seed(12345)
        self.n_blocks = n_blocks
        self.node_channels = node_channels
        self.edge_channels = edge_channels
        self.global_channels = global_channels
        self.pooling_args = pooling_args
        self.out_channels = out_channels
        self.batch_norm = batch_norm
        if batch_norm:
            self.bn = ModuleList([BatchNorm1d(node_channels[0]), BatchNorm1d(edge_channels[0]), BatchNorm1d(global_channels[0])])
        self.blocks = ModuleList([MegNetBlock(node_channels, edge_channels, global_channels, batch_norm=batch_norm, bias=bias)])
        node_channels = (node_channels[2], node_channels[1], node_channels[2])
        edge_channels = (edge_channels[2], edge_channels[1], edge_channels[2])
        global_channels = (global_channels[2], global_channels[1], global_channels[2])
        self.blocks.extend([MegNetBlock(node_channels, edge_channels, global_channels, batch_norm=batch_norm, bias=bias) for i in range(n_blocks-1)])
        
        self.global_node_pool = Set2Set(node_channels[2], *pooling_args)
        self.global_edge_pool = Set2Set(edge_channels[2], *pooling_args)

        
        
        self.out = ModuleList([Linear(2*node_channels[2]+2*edge_channels[2]+global_channels[2], out_channels[0], bias=bias),
                               Linear(out_channels[0], out_channels[1], bias=bias),
                               Linear(out_channels[1], out_channels[2], bias=bias)])
        

 
    def forward(self, x, edge_index, edge_attr, global_attr, node_batch, edge_batch): 
        if self.batch_norm:
            x = self.bn[0](x)
            edge_attr = self.bn[1](edge_attr)
            global_attr = self.bn[2](global_attr)
        for i, block in enumerate(self.blocks):
            x, edge_attr, global_attr = block(x, edge_index, edge_attr, global_attr, node_batch, edge_batch)
            x = F.softplus(x)
            edge_attr = F.softplus(edge_attr)
            global_attr = F.softplus(global_attr)
            

        x = self.global_node_pool(x, node_batch)
        edge_attr = self.global_edge_pool(edge_attr, edge_batch)

        out = torch.cat([x, edge_attr, global_attr], dim=-1)

        for module in self.out:
            out = F.softplus(module(out))
            out = F.dropout(out, training=self.training)
    
        return out

model = MegNet(node_channels=(16, 16, 9), edge_channels=(10, 16, 10), global_channels=(2, 8, 2), out_channels=(32, 16, 1), n_blocks=3)
print(model)

MegNet(
  (blocks): ModuleList(
    (0): MegNetBlock(node_channels=(16, 16, 9), edge_channels=(10, 16, 10), global_channels=(2, 8, 2))
    (1): MegNetBlock(node_channels=(9, 16, 9), edge_channels=(10, 16, 10), global_channels=(2, 8, 2))
    (2): MegNetBlock(node_channels=(9, 16, 9), edge_channels=(10, 16, 10), global_channels=(2, 8, 2))
  )
  (global_node_pool): Set2Set(9, 18)
  (global_edge_pool): Set2Set(10, 20)
  (out): ModuleList(
    (0): Linear(in_features=40, out_features=32, bias=True)
    (1): Linear(in_features=32, out_features=16, bias=True)
    (2): Linear(in_features=16, out_features=1, bias=True)
  )
)


## 4. Training MegNet Model

In [16]:
dataset = dataset.shuffle()

train_dataset = dataset[:int(0.8 * len(dataset))]
validation_dataset = dataset[int(0.8 * len(dataset)):int(0.9 * len(dataset))]
test_dataset = dataset[int(0.9 * len(dataset)):]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of validation graphs: {len(validation_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

from torch_geometric.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_loader = DataLoader(train_dataset, batch_size=300, num_workers=16, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=100, num_workers=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=100, num_workers=16, shuffle=True)

Number of training graphs: 2721
Number of validation graphs: 340
Number of test graphs: 341


In [17]:
model = MegNet(node_channels=(dataset.num_node_features, 32, 16), 
               edge_channels=(dataset.num_edge_features, 32, 16), 
               global_channels=(dataset.num_global_features, 32, 16), 
               out_channels=(16, 8, 1), n_blocks=3, batch_norm=False)
print(model)
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
criterion = torch.nn.MSELoss()

MegNet(
  (blocks): ModuleList(
    (0): MegNetBlock(node_channels=(24, 32, 16), edge_channels=(100, 32, 16), global_channels=(33, 32, 16))
    (1): MegNetBlock(node_channels=(16, 32, 16), edge_channels=(16, 32, 16), global_channels=(16, 32, 16))
    (2): MegNetBlock(node_channels=(16, 32, 16), edge_channels=(16, 32, 16), global_channels=(16, 32, 16))
  )
  (global_node_pool): Set2Set(16, 32)
  (global_edge_pool): Set2Set(16, 32)
  (out): ModuleList(
    (0): Linear(in_features=80, out_features=16, bias=True)
    (1): Linear(in_features=16, out_features=8, bias=True)
    (2): Linear(in_features=8, out_features=1, bias=True)
  )
)


In [18]:
def train():
    model.train()

    for batch in train_loader:  # Iterate in batches over the training dataset.
        batch.to(device)
        node_batch = batch.batch
        edge_batch = torch.cat([torch.tensor([i]).repeat(batch[i].num_edges) for i in range(batch.num_graphs)]).to(device)
        out = model(batch.x, batch.edge_index, batch.edge_attr, batch.global_attr, node_batch, edge_batch)  # Perform a single forward pass.
        loss = criterion(out, batch.y.view(*out.shape))  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.

def validate(loader):
    model.eval()

    with torch.no_grad():
        loss = 0
        for batch in loader:  # Iterate in batches over the training/test dataset.
            batch.to(device)
            node_batch = batch.batch
            edge_batch = torch.cat([torch.tensor([i]).repeat(batch[i].num_edges) for i in range(batch.num_graphs)]).to(device)
            out = model(batch.x, batch.edge_index, batch.edge_attr, batch.global_attr, node_batch, edge_batch)  
            #loss = criterion(out, batch.y.view(*out.shape))
            loss += torch.abs(out - batch.y.view(*out.shape)).mean()
        return loss / len(loader)

for epoch in range(1, 1000):
    train()
    train_acc = validate(train_loader)
    valid_acc = validate(validation_loader)
    #print(f'Memory allocated, kB: {torch.cuda.memory_allocated()/1000}')
    #print(f'Memory cached, kB: {torch.cuda.memory_cached()/1000}')
    print(f'Epoch: {epoch:03d}, Train MAE: {train_acc:.4f}, Validation MAE: {valid_acc:.4f}')

Epoch: 001, Train MAE: 0.8147, Validation MAE: 0.8598
Epoch: 002, Train MAE: 0.8261, Validation MAE: 0.8822
Epoch: 003, Train MAE: 0.8145, Validation MAE: 0.8062
Epoch: 004, Train MAE: 0.7452, Validation MAE: 0.7674
Epoch: 005, Train MAE: 0.7670, Validation MAE: 0.7166
Epoch: 006, Train MAE: 0.7649, Validation MAE: 0.7825
Epoch: 007, Train MAE: 0.6919, Validation MAE: 0.7306
Epoch: 008, Train MAE: 0.7355, Validation MAE: 0.6704
Epoch: 009, Train MAE: 0.7742, Validation MAE: 0.7639
Epoch: 010, Train MAE: 0.7032, Validation MAE: 0.7006
Epoch: 011, Train MAE: 0.7715, Validation MAE: 0.7518
Epoch: 012, Train MAE: 0.7354, Validation MAE: 0.7125
Epoch: 013, Train MAE: 0.6489, Validation MAE: 0.6552
Epoch: 014, Train MAE: 0.6461, Validation MAE: 0.6717
Epoch: 015, Train MAE: 0.6399, Validation MAE: 0.6127
Epoch: 016, Train MAE: 0.5964, Validation MAE: 0.6246
Epoch: 017, Train MAE: 0.6228, Validation MAE: 0.6098
Epoch: 018, Train MAE: 0.6257, Validation MAE: 0.6503
Epoch: 019, Train MAE: 0.672

KeyboardInterrupt: ignored