In [42]:
import os
import sys

In [43]:
# Location of the AI2BMD checkout (appended to sys.path in the next cell).
# Set the AI2BMD_PATH environment variable to override the machine-specific default.
ai2md_path = os.environ.get("AI2BMD_PATH", r"C:\Users\23174\Desktop\GitHub Project\AI2BMD")

In [44]:
# Make the checkout importable. The membership test keeps the cell idempotent, so
# re-running it does not pile up duplicate sys.path entries.
if ai2md_path not in sys.path:
    sys.path.append(ai2md_path)

In [45]:
import argparse
import logging
import os
import sys

import numpy as np
import pytorch_lightning as pl
import torch
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.loggers import WandbLogger


from visnet import datasets, models, priors
from visnet.data import DataModule
from visnet.models import output_modules
from visnet.models.utils import act_class_mapping, rbf_class_mapping
from visnet.module import LNNP
from visnet.utils import LoadFromCheckpoint, LoadFromFile, number, save_argparse

In [48]:
def get_args(argv=None):
    """Build the training argument parser, parse the arguments and post-process them.

    :param argv: list of argument strings to parse. ``None`` (the default) makes
        argparse read ``sys.argv[1:]``, which is the original behaviour; pass an
        explicit list (e.g. ``[]``) when calling from a notebook kernel.
    :return: the parsed ``argparse.Namespace``.

    Side effects:
        - exits through ``parser.error`` when no log directory is configured;
        - with ``--redirect``, creates ``log_dir`` and redirects stdout/stderr to
          ``log_dir/log``;
        - ``inference_batch_size`` falls back to ``batch_size`` when not given;
        - the final arguments (without ``conf``) are written to ``log_dir/input.yaml``
          through ``save_argparse``.
    """
    parser = argparse.ArgumentParser(description='Training')
    parser.add_argument('--load-model', action=LoadFromCheckpoint, help='Restart training using a model checkpoint')  # keep first
    parser.add_argument('--conf', '-c', type=open, action=LoadFromFile, help='Configuration yaml file')  # keep second

    # training settings
    parser.add_argument('--num-epochs', default=300, type=int, help='number of epochs')
    parser.add_argument('--lr-warmup-steps', type=int, default=0, help='How many steps to warm-up over. Defaults to 0 for no warm-up')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--lr-patience', type=int, default=10, help='Patience for lr-schedule. Patience per eval-interval of validation')
    parser.add_argument('--lr-min', type=float, default=1e-6, help='Minimum learning rate before early stop')
    # The help text used to be a copy of the --lr-min description.
    parser.add_argument('--lr-factor', type=float, default=0.8, help='Factor by which the lr-schedule scales the learning rate')
    parser.add_argument('--weight-decay', type=float, default=0.0, help='Weight decay strength')
    parser.add_argument('--early-stopping-patience', type=int, default=30, help='Stop training after this many epochs without improvement')
    parser.add_argument('--loss-type', type=str, default='MSE', choices=['MSE', 'MAE'], help='Loss type')
    parser.add_argument('--loss-scale-y', type=float, default=1.0, help="Scale the loss y of the target")
    parser.add_argument('--loss-scale-dy', type=float, default=1.0, help="Scale the loss dy of the target")
    parser.add_argument('--energy-weight', default=1.0, type=float, help='Weighting factor for energies in the loss function')

    # dataset specific
    parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
    parser.add_argument('--dataset-arg', default=None, type=str, help='Additional dataset argument')
    parser.add_argument('--dataset-root', default=None, type=str, help='Data storage directory')
    parser.add_argument('--derivative', default=False, action=argparse.BooleanOptionalAction, help='If true, take the derivative of the prediction w.r.t coordinates')
    parser.add_argument('--split-mode', default=None, type=str, help='Split mode for Molecule3D dataset')

    # dataloader specific
    parser.add_argument('--reload', type=int, default=0, help='Reload dataloaders every n epoch')
    parser.add_argument('--batch-size', default=32, type=int, help='batch size')
    parser.add_argument('--inference-batch-size', default=None, type=int, help='Batchsize for validation and tests.')
    parser.add_argument('--standardize', action=argparse.BooleanOptionalAction, default=False, help='If true, multiply prediction by dataset std and add mean')
    parser.add_argument('--splits', default=None, help='Npz with splits idx_train, idx_val, idx_test')
    parser.add_argument('--train-size', type=number, default=950, help='Percentage/number of samples in training set (None to use all remaining samples)')
    parser.add_argument('--val-size', type=number, default=50, help='Percentage/number of samples in validation set (None to use all remaining samples)')
    parser.add_argument('--test-size', type=number, default=None, help='Percentage/number of samples in test set (None to use all remaining samples)')
    parser.add_argument('--num-workers', type=int, default=4, help='Number of workers for data prefetch')

    # model architecture specific
    parser.add_argument('--model', type=str, default='ViSNetBlock', choices=models.__all__, help='Which model to train')
    parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
    parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use')
    # NOTE(review): type=dict cannot convert a command-line string; presumably this value
    # is only ever supplied through the YAML config — confirm.
    parser.add_argument('--prior-args', type=dict, default=None, help='Additional arguments for the prior model')

    # architectural specific
    parser.add_argument('--embedding-dimension', type=int, default=256, help='Embedding dimension')
    parser.add_argument('--num-layers', type=int, default=6, help='Number of interaction layers in the model')
    parser.add_argument('--num-rbf', type=int, default=64, help='Number of radial basis functions in model')
    parser.add_argument('--activation', type=str, default='silu', choices=list(act_class_mapping.keys()), help='Activation function')
    parser.add_argument('--rbf-type', type=str, default='expnorm', choices=list(rbf_class_mapping.keys()), help='Type of distance expansion')
    parser.add_argument('--trainable-rbf', action=argparse.BooleanOptionalAction, default=False, help='If distance expansion functions should be trainable')
    parser.add_argument('--attn-activation', default='silu', choices=list(act_class_mapping.keys()), help='Attention activation function')
    parser.add_argument('--num-heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--cutoff', type=float, default=5.0, help='Cutoff in model')
    parser.add_argument('--max-z', type=int, default=100, help='Maximum atomic number that fits in the embedding matrix')
    parser.add_argument('--max-num-neighbors', type=int, default=32, help='Maximum number of neighbors to consider in the network')
    parser.add_argument('--reduce-op', type=str, default='add', choices=['add', 'mean'], help='Reduce operation to apply to atomic predictions')
    parser.add_argument('--lmax', type=int, default=2, help='Max order of spherical harmonics')
    parser.add_argument('--vecnorm-type', type=str, default='max_min', help='Type of vector normalization')
    parser.add_argument('--trainable-vecnorm', action=argparse.BooleanOptionalAction, default=False, help='If vector normalization should be trainable')
    parser.add_argument('--vertex-type', type=str, default='Edge', choices=['None', 'Edge', 'Node'], help='If add vertex angle and Where to add vertex angles')

    # other specific
    parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus')
    parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes')
    parser.add_argument('--precision', type=int, default=32, choices=[16, 32], help='Floating point precision')
    parser.add_argument('--log-dir', type=str, default=None, help='Log directory')
    parser.add_argument('--task', type=str, default='train', choices=['train', 'inference'], help='Train or inference')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--distributed-backend', default='ddp', help='Distributed backend')
    parser.add_argument('--redirect', action=argparse.BooleanOptionalAction, default=False, help='Redirect stdout and stderr to log_dir/log')
    parser.add_argument('--accelerator', default='gpu', help='Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "auto")')
    parser.add_argument('--test-interval', type=int, default=10, help='Test interval, one test per n epochs (default: 10)')
    parser.add_argument('--save-interval', type=int, default=10, help='Save interval, one save per n epochs (default: 10)')

    # argv=None keeps the historical behaviour of parsing sys.argv[1:].
    args = parser.parse_args(argv)

    # log_dir defaults to None, and both the redirect branch and the final save join
    # paths onto it; fail with a readable message instead of a TypeError from os.path.
    if args.log_dir is None:
        parser.error('--log-dir must be set (on the command line or through --conf)')

    if args.redirect:
        os.makedirs(args.log_dir, exist_ok=True)
        # The file handle is deliberately left open: it replaces stdout/stderr for the
        # remainder of the process.
        sys.stdout = open(os.path.join(args.log_dir, "log"), "w")
        sys.stderr = sys.stdout
        logging.getLogger("pytorch_lightning").addHandler(logging.StreamHandler(sys.stdout))

    if args.inference_batch_size is None:
        args.inference_batch_size = args.batch_size
    save_argparse(args, os.path.join(args.log_dir, "input.yaml"), exclude=["conf"])

    return args
# Hyper-parameters spelled out inline instead of being parsed by get_args().
# The keys follow the argparse destinations defined in get_args().
args = {
    "load_model": None,
    # No YAML config file is loaded here. A duplicate "conf" entry used to precede this
    # one; its value was silently discarded because the last duplicate key wins.
    "conf": None,
    "num_epochs": 300,
    "lr_warmup_steps": 0,
    "lr": 1e-4,
    "lr_patience": 10,
    "lr_min": 1e-6,
    "lr_factor": 0.8,
    "weight_decay": 0.0,
    "early_stopping_patience": 30,
    "loss_type": "MSE",
    "loss_scale_y": 1.0,
    "loss_scale_dy": 1.0,
    "energy_weight": 1.0,
    # A single dataset name, as --dataset expects. This used to be the whole
    # datasets.__all__ list, which cannot be used to select a dataset.
    "dataset": "MD17",
    "dataset_arg": "aspirin",
    "dataset_root": "data",
    "derivative": False,
    "split_mode": None,
    "reload": 0,
    "batch_size": 17,
    "inference_batch_size": 17,  # same as batch_size, mirroring the fallback in get_args()
    "standardize": False,
    "splits": None,
    "train_size": 950,
    "val_size": 50,
    "test_size": None,
    "num_workers": 4,
    "model": "ViSNetBlock",
    "output_model": "Scalar",
    "prior_model": None,
    "prior_args": None,
    "embedding_dimension": 256,
    "num_layers": 6,
    "num_rbf": 64,
    "activation": "silu",
    "rbf_type": "expnorm",
    "trainable_rbf": False,
    "attn_activation": "silu",
    "num_heads": 8,
    "cutoff": 5.0,
    "max_z": 100,
    "max_num_neighbors": 32,
    "reduce_op": "add",
    "lmax": 2,
    "vecnorm_type": "max_min",
    "trainable_vecnorm": False,
    "vertex_type": "Edge",
    "ngpus": 1,
    "num_nodes": 1,
    "precision": 32,
    "log_dir": "logs",
    "task": "train",
    "seed": 1,
    "distributed_backend": "ddp",
    "redirect": False,
    "accelerator": "gpu",
    "test_interval": 10,
    "save_interval": 10
}


# Seed the random number generators before any data or model is created.
pl.seed_everything(args['seed'], workers=True)

from torch.utils.data import Subset



Global seed set to 1


In [49]:
def prepare_dataset(self):
    """Build the train/validation/test subsets for the configured dataset.

    Looks up ``self._prepare_<name>_dataset`` where ``<name>`` is
    ``self.hparams["dataset"]``, calls it to obtain the three index collections, and
    wraps ``self.dataset`` in one ``Subset`` per split. When ``hparams["dataset"]`` is
    not a string (for instance a list of every available dataset), MD17 is used, which
    is what this function previously hard-coded for every input.

    :raises ValueError: if no ``_prepare_<name>_dataset`` method exists on ``self``.
    """
    dataset_name = self.hparams["dataset"]
    if not isinstance(dataset_name, str):
        # Keep working with configurations that do not hold a single dataset name.
        dataset_name = "MD17"

    prepare_fn = getattr(self, f"_prepare_{dataset_name}_dataset", None)
    if prepare_fn is None:
        raise ValueError(f"Dataset {dataset_name} not defined")

    # The prepare method returns the split indices; self.dataset is read right after.
    self.idx_train, self.idx_val, self.idx_test = prepare_fn()
    print(self.dataset)
    self.train_dataset = Subset(self.dataset, self.idx_train)
    self.val_dataset = Subset(self.dataset, self.idx_val)
    self.test_dataset = Subset(self.dataset, self.idx_test)

    if self.hparams["standardize"]:
        self._standardize()
        

# Monkeypatch: every DataModule instance now uses the notebook's prepare_dataset
# defined above instead of the method shipped with the library.
DataModule.prepare_dataset = prepare_dataset


In [50]:
# initialize data module
data = DataModule(args)
# NOTE(review): prepare_dataset() below calls _prepare_MD17_dataset() itself for this
# configuration, so this explicit call looks redundant and the split indices it
# returns are discarded — confirm whether it can be removed.
data._prepare_MD17_dataset()
data.prepare_dataset()

MD17(211762)


In [51]:
priors.__all__

['Atomref']

In [52]:

from visnet.priors import Atomref
# No prior model is used here; Atomref is imported but never instantiated in this cell.
prior = None

# Build the LNNP module from the inline hyper-parameters, passing the data module's
# mean and std.
model = LNNP(args, prior_model=prior, mean=data.mean, std=data.std)


In [53]:
# Fetch the training data loader.
train_loader = data.train_dataloader()

# Keep the first two batches for inspection and stop as soon as a third one arrives.
# After the loop, `batch` stays bound to the last batch drawn from the loader.
for idx, batch in enumerate(train_loader):
    if idx >= 2:
        break
    if idx == 0:
        batch_one = batch
    else:
        batch_two = batch


In [56]:
batch_two

DataBatch(y=[17, 1], pos=[357, 3], z=[357], dy=[357, 3], batch=[357], ptr=[18])

In [58]:
type(batch_two)

torch_geometric.data.batch.DataBatch

In [None]:
# Show the `batch` attribute of the DataBatch (the index values displayed below).
# NOTE(review): `batch` is presumably the variable left over from the loader loop — confirm.
batch.batch

tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3,  3,  3,  3,  3,
         3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,
         4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  4,  5,  5,  5,
         5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,
         6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,  6,
         6,  6,  6,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,  7,
         7,  7,  7,  7,  7,  7,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
         8,  8,  8,  8,  8,  8,  8,  8,  8,  9,  9,  9,  9,  9,  9,  9,  9,  9,
         9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10,
        10, 10, 10, 10, 10, 10, 10, 10, 

In [None]:
# NOTE(review): `_1` / `_2` are also the names IPython gives to the outputs of cells 1
# and 2; assigning to them shadows that history — consider descriptive names.
_1, _2 = model(batch)

In [None]:
357//17

21

# # GNN-LF: load the model configuration and datasets

In [24]:
import os
import json
import pathlib
import sys
# JSON file holding the keyword arguments later passed to GNNLF(**model_config).
model_config_json_path = "model_config.json"
# Read through the path variable (it used to be defined but unused, with the literal
# repeated) and pin the encoding so parsing does not depend on the platform default.
model_config = json.loads(pathlib.Path(model_config_json_path).read_text(encoding="utf-8"))

In [25]:
# Root of the GNN-LF / ColfNet checkout. Set the DL_GNN_PATH environment variable to
# override the machine-specific default.
dl_gnn_path = os.environ.get(
    "DL_GNN_PATH",
    r"C:\Users\23174\Desktop\GitHub Project\GitHubProjectBigData\GNN-Molecular-Project\GNN-LF-AND-ColfNet",
)
# The tests directory sits directly under the checkout root, so derive it instead of
# repeating the full path.
dl_gnn_path_test = os.path.join(dl_gnn_path, "tests")

In [26]:
# Make the GNN-LF checkout and its tests directory importable. The membership test keeps
# the cell idempotent, so re-running it does not add duplicate sys.path entries.
for extra_path in (dl_gnn_path, dl_gnn_path_test):
    if extra_path not in sys.path:
        sys.path.append(extra_path)

In [27]:
import argparse
from colorama import Fore, Back, Style
from dl_gnn.models.impl.GNNLF import GNNLF
from dl_gnn.models.impl.ThreeDimFrame import GNNLF as ThreeDGNNLF
from dl_gnn.configs.path_configs import OUTPUT_PATH
from dl_gnn.models.impl import Utils
import torch
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader, random_split
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR
from torch.nn.functional import l1_loss, mse_loss
import numpy as np
import os
import time
from tqdm import tqdm
import wandb
from warnings import filterwarnings
from loguru import logger

# from dl_gnn.tests.main_md17 import load_dataset

In [28]:
# Load the three serialized dataset splits, in train / validation / test order.
# NOTE(review): torch.load unpickles its input — only load files from a trusted source.
train_dataset, validation_dataset, test_dataset = (
    torch.load(split_file)
    for split_file in ('train_dataset.pth', 'validation_dataset.pth', 'test_dataset.pth')
)

In [29]:
# batch_size=17 is the same value as args["batch_size"] in the ViSNet section,
# presumably so both pipelines yield equally sized batches — confirm.
# shuffle=False: batches are drawn in dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=17, shuffle=False)


In [30]:
first_train_batch = next(iter(train_dataloader))

In [31]:
z, pos, y, dy = first_train_batch

In [32]:

y.shape

torch.Size([17, 1])

In [33]:
pos.view(-1, pos.size(-1)).shape

torch.Size([357, 3])

In [34]:
z.shape

torch.Size([17, 21])

In [35]:
z.view(-1, ).shape

torch.Size([357])

In [36]:
dy.shape

torch.Size([17, 21, 3])

In [37]:
dy.view(-1, dy.size(-1)).shape

torch.Size([357, 3])

In [38]:
# Cumulative offsets 0, sep, 2*sep, ..., batch_size*sep, where `sep` is the size of
# the second axis of `pos` and `_batch_size` the size of the first.
_batch_size, sep = pos.shape[0], pos.shape[1]
_ptr = torch.arange(start=0, end=_batch_size * sep + 1, step=sep)
_ptr

tensor([  0,  21,  42,  63,  84, 105, 126, 147, 168, 189, 210, 231, 252, 273,
        294, 315, 336, 357])

In [39]:
pos.shape

torch.Size([17, 21, 3])

In [47]:
from dl_gnn.models.visnet.models import output_modules


In [40]:
gnnlf_model = GNNLF(**model_config)


In [62]:
from typing import Final
import torch.nn as nn
import torch
from torch.nn import ModuleList
from dl_gnn.models.impl.Mol2Graph import Mol2Graph
from dl_gnn.models.impl.Utils import innerprod
from torch import Tensor

def forward(self, atomic_numbers: Tensor, atomic_positions: Tensor):
    """
    Forward pass that stops before the readout and returns the per-atom features
    together with the edge filter (the output-module steps are commented out below).

    :param atomic_numbers: the atomic number of each atom in the molecule, shape: batch_size x atomic_number
    :param atomic_positions:  the coordinate of each atom in the molecule, shape: batch_size x atomic_number x 3
    :return: tuple ``(s, mask)``
        - s: per-atom features after the interaction layers; entries where
             ``atomic_numbers == 0`` are set to zero.
             Shape: (batch_size, atomic_number, hid_dim)
        - mask: the edge filter last fed to the interaction layers.
             Shape: (batch_size, atomic_number, atomic_number, hid_dim)

    Intermediate values returned by ``self.mol2graph`` (descriptions and shapes as
    recorded in the original notes):
        - atomic_number_embedding: Embeddings of atomic numbers with layer normalization applied.
                                        Shape: (batch_size, atomic_number, hid_dim)
        - atomic_adjacency_matrix: Smoothed adjacency matrix representing atomic connections.
                                        Shape: (batch_size, atomic_number, atomic_number)
        - normalized_atom_position_distances: Normalized vectors representing interatomic distances.
                                                    Shape: (batch_size, atomic_number, atomic_number, 3)
        - edge_features: Edge features computed using RBF, representing interatomic relationships.
                                Shape: (batch_size, atomic_number, atomic_number, ef_dim)
    """
    (
        atomic_number_embedding,
        atomic_adjacency_matrix,
        normalized_atom_position_distances,
        edge_features
    ) = (
        self.mol2graph(atomic_numbers, atomic_positions)
    )
    mask = self.ef_proj(edge_features) * atomic_adjacency_matrix.unsqueeze(-1)
    #  Project the edge features (ef_proj is presumably a linear layer, per the original
    #  note) and multiply by the adjacency matrix.
    #  unsqueeze(-1) appends a trailing axis to the adjacency matrix so that it
    #  broadcasts against the projected edge features.
    #  mask has shape batch_size x atomic_number x atomic_number x hid_dim
    s = self.neighbor_emb(atomic_numbers, atomic_number_embedding, mask)
    #  s has shape batch_size x atomic_number x hid_dim
    v = self.s2v(s, normalized_atom_position_distances, mask)
    #  v has shape batch_size x atomic_number x 3 x hid_dim
    if self.global_frame:
        # Sum v over the atom axis and expand the result back to every atom.
        v = torch.sum(v, dim=1, keepdim=True).expand(-1, s.shape[1], -1, -1)
    atomic_direction_feature_list = []

    # TODO: project the ColfNet local-coordinate-frame features onto the hidden_dim dimension
    # local_frame_featutes = self.localframe_features(position_batch_center=atomic_positions,
    #                                                 atomic_adjacency_matrix=atomic_adjacency_matrix)
    # projected_local_frame_featutes = self.colfnet_features_projection(local_frame_featutes)
    # # #
    # if self.colfnet_features:
    #     atomic_direction_feature_list.append(projected_local_frame_featutes)
    #
    # if self.colfnet_features:
    #     atomic_direction_feature_list.append(local_frame_featutes)
    if self.use_dir1:
        atomic_direction_feature_1 = innerprod(v.unsqueeze(1), normalized_atom_position_distances.unsqueeze(-1))
        atomic_direction_feature_list.append(atomic_direction_feature_1)
        # Per the original note: element-wise product followed by a reduction over the
        # third axis — TODO confirm against innerprod.
    if self.use_dir2:
        atomic_direction_feature_2 = innerprod(v.unsqueeze(2), normalized_atom_position_distances.unsqueeze(-1))
        atomic_direction_feature_list.append(atomic_direction_feature_2)
        # Same as above except for the position at which the extra axis is inserted.
    if self.use_dir3:
        atomic_direction_feature_3 = innerprod(
            self.q_proj(v).unsqueeze(1),
            self.k_proj(v).unsqueeze(2))
        atomic_direction_feature_list.append(atomic_direction_feature_3)
    # Every element of the list has shape batch_size x atomic_number x atomic_number x hid_dim

    # batch_size x atomic_number x atomic_number x 8(3+3+2) # coordinates, coordinates, angle
    # NOTE(review): torch.cat raises on an empty list, i.e. when none of
    # use_dir1 / use_dir2 / use_dir3 is enabled.
    combined_direction_features = torch.cat(atomic_direction_feature_list, dim=-1)  # batch_size x atomic_number x atomic_number x hid_dim x 2
    # Concatenation along the last axis; the resulting size is
    # hid_dim * (use_dir1 + use_dir2 + use_dir3)
    if self.ev_decay:
        combined_direction_features = combined_direction_features * atomic_adjacency_matrix.unsqueeze(-1)
    if self.add_ef2dir or self.no_filter_decomposition:
        combined_direction_features = torch.cat((combined_direction_features, edge_features), dim=-1) # batch_size x atomic_number x atomic_number x (hid_dim * 2 + ef_dim)
    if self.no_filter_decomposition:
        mask = self.dir_proj(combined_direction_features)
    else:
        dir_project2_max = self.dir_proj(combined_direction_features)
        mask = mask * dir_project2_max
    # The direction features are projected to the dimension of mask and multiplied in
    # (with broadcasting).
    # These steps build the feature matrix that gathers all of the features above.
    # mask: batch_size x atomic_number x atomic_number x hid_dim => edges, in an adjacency-matrix-like layout
    # s: batch_size x atomic_number x hid_dim => the nodes.
    for layer_idx, interaction in enumerate(self.interactions):
        if self.no_share_filter:
            mask = (self.ef_projections[layer_idx](edge_features) *
                    self.dir_projections[layer_idx](combined_direction_features))
            # The shape of mask is unchanged.
        s = interaction(s, mask) + s
        # The shape of s is unchanged (residual update).
    # Zero the features of atoms whose atomic number is 0 (presumably padding — confirm).
    s[atomic_numbers == 0] = 0
    # s = self.output_module_1(s)
    # s = s.squeeze(-1)
    # s = self.output_module_activation(s)
    # s = self.output_module_2(s)
    # s = torch.sum(s, dim=1)
    # s = self.output_module(s)
    # s = s * self.y_std + self.y_mean
    # print(f"y_std: {self.y_std}, y_mean: {self.y_mean}")
    #  without this: Training (Epoch 42/6000):
    #  100%|██████████| 8/8 [00:01<00:00,  5.91it/s, energy_loss=733, force_loss=25.8]
    #  with this: Training (Epoch 42/6000):
    #  100%|██████████| 8/8 [00:01<00:00,  6.00it/s, energy_loss=280, force_loss=26.8]
    #  Seems to work for the force loss, but y_std and y_mean are not changed here.
    #  There is actually no need to change them (the original note breaks off here).
    return s, mask


In [63]:
# NOTE(review): this stores a plain function on the instance, so it is not bound to
# gnnlf_model — callers have to pass `self` explicitly, as the call further down does
# with self=gnnlf_model.
gnnlf_model.forward = forward

In [67]:
equvilent_output_module = output_modules.EquivariantScalar(hidden_channels=256)
#  x = self.output_model.pre_reduce(x, v, data.z, data.pos, data.batch)


In [64]:
s, mask = gnnlf_model.forward(self=gnnlf_model, atomic_numbers=z, atomic_positions=pos)

In [71]:
pre_reduct_out = equvilent_output_module.pre_reduce(s, mask, None, None, None)

In [72]:
pre_reduct_out.shape

torch.Size([17, 21, 1])

In [73]:
torch.sum(pre_reduct_out, dim=1).shape

torch.Size([17, 1])

In [65]:
s.shape

torch.Size([17, 21, 256])

In [66]:
mask.shape

torch.Size([17, 21, 21, 256])

In [None]:
# Display the output module. The dangling "." left over from tab-completion made this
# cell a SyntaxError.
equvilent_output_module

In [42]:
pred_y.shape

torch.Size([17, 1])

In [43]:
type(batch_one)

NameError: name 'batch_one' is not defined

In [44]:
import torch_geometric

def gnn_lf_batch2visnet_adapter(batch):
    """Convert a dense GNN-LF batch into a flat torch_geometric batch for ViSNet.

    :param batch: tuple ``(z, pos, y, dy)``. In this notebook ``z`` is
        (batch_size, n_atoms), ``pos`` and ``dy`` are (batch_size, n_atoms, 3) and
        ``y`` is (batch_size, 1).
    :return: ``torch_geometric.data.Batch`` with the keys ``y, pos, z, dy, batch, ptr``;
        ``pos``/``z``/``dy`` have their first two axes flattened to
        batch_size * n_atoms rows, ``batch`` holds the sample index of every row and
        ``ptr`` the cumulative row offsets.

    NOTE(review): every sample is assumed to contribute exactly ``pos.shape[1]`` atoms;
    padded entries (e.g. atomic number 0) would be carried over as atoms — confirm.
    """
    z, pos, y, dy = batch
    batch_size, sep = pos.shape[0], pos.shape[1]
    # Create the index tensors on the same device as the inputs so the resulting batch
    # is not split across devices.
    device = pos.device

    # reshape (unlike view) also accepts non-contiguous inputs.
    visnet_pos = pos.reshape(-1, pos.size(-1))
    visnet_z = z.reshape(-1)
    visnet_dy = dy.reshape(-1, dy.size(-1))

    # Sample index of every flattened row: each value in 0..batch_size-1 repeated `sep` times.
    visnet_batch = torch.arange(0, batch_size, device=device).repeat_interleave(sep)
    ptr = torch.arange(0, batch_size * sep + 1, sep, device=device)

    # Build a torch_geometric Batch carrying the keys "y, pos, z, dy, batch, ptr".
    _batch = torch_geometric.data.Batch(y=y, pos=visnet_pos, z=visnet_z, dy=visnet_dy, batch=visnet_batch, ptr=ptr)
    return _batch


In [70]:
gnn_lf_batch2visnet_adapter(first_train_batch)

DataBatch(y=[17, 1], pos=[357, 3], z=[357], dy=[357, 3], batch=[357], ptr=[18])

In [71]:
type(gnn_lf_batch2visnet_adapter(first_train_batch))

torch_geometric.data.batch.DataBatch

In [66]:
model(gnn_lf_batch2visnet_adapter(first_train_batch))

(tensor([[-1.9876],
         [-1.2070],
         [-1.3699],
         [-1.7094],
         [-2.4913],
         [-1.9085],
         [-2.2696],
         [-1.8223],
         [-2.0106],
         [-2.4089],
         [-2.0464],
         [-2.0982],
         [-1.7267],
         [-1.9202],
         [-2.1074],
         [-1.9732],
         [-1.9028]], grad_fn=<AddBackward0>),
 None)

In [None]:
model(gnn_lf_batch2visnet_adapter(first_train_batch))[0].shape

torch.Size([17, 1])

In [89]:
pred_y = gnnlf_model(z, pos)

In [91]:
pred_y

tensor([[-0.2601],
        [-0.2455],
        [-0.1928],
        [-0.2315],
        [-0.1238],
        [-0.1599],
        [-0.2234],
        [-0.2103],
        [-0.1576],
        [-0.2259],
        [-0.2741],
        [-0.1637],
        [-0.1835],
        [-0.2287],
        [-0.2016],
        [-0.2201],
        [-0.1845]], grad_fn=<AddmmBackward0>)

In [90]:
pred_y.shape

torch.Size([17, 1])