In [1]:
import sys
import os
import yaml
import random
import argparse
import logging
import torch
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger, CSVLogger, TensorBoardLogger
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
)
from torchmdnet.module import LNNP
from torchmdnet import datasets, priors, models
from torchmdnet.data import DataModule
from torchmdnet.loss import loss_class_mapping
from torchmdnet.models import output_modules
from torchmdnet.models.model import create_prior_models
from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping, dtype_mapping
from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, save_argparse, number
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch.utils.data import ConcatDataset, Subset, SubsetRandomSampler, random_split

In [2]:
# Hyper-parameters for the torchmd-net equivariant transformer, collected in an
# argparse.Namespace so they can be handed to DataModule / LNNP in place of
# parsed command-line arguments.
# NOTE(review): train_size=1 / val_size=1 / test_size=555090 presumably routes
# (almost) the whole MD17 ethanol set into the test split — confirm it matches
# the dataset length.
args = argparse.Namespace(
    activation="silu",
    aggr="add",
    atom_filter=-1,
    attn_activation="silu",
    batch_size=1000,
    coord_files=None,
    cutoff_lower=0.0,
    cutoff_upper=5.0,
    dataset="MD17",
    dataset_arg={"molecules": "ethanol"},
    dataset_root="~/data",
    derivative=False,
    distance_influence="both",
    early_stopping_patience=300,
    ema_alpha_neg_dy=1.0,
    ema_alpha_y=0.05,
    embed_files=None,
    embedding_dimension=128,
    energy_files=None,
    y_weight=0.2,
    force_files=None,
    neg_dy_weight=0.8,
    load_model=None,
    log_dir="logs/",
    lr=0.001,
    lr_factor=0.8,
    lr_min=1.0e-07,
    lr_patience=30,
    lr_warmup_steps=1000,
    max_num_neighbors=32,
    max_z=100,
    model="equivariant-transformer",
    neighbor_embedding=True,
    ngpus=-1,
    num_epochs=3000,
    num_heads=8,
    num_layers=6,
    num_nodes=1,
    num_rbf=32,
    num_workers=6,
    output_model="Scalar",
    precision=32,
    prior_model=None,
    rbf_type="expnorm",
    redirect=False,
    reduce_op="add",
    save_interval=10,
    splits=None,
    standardize=True,
    test_interval=5,
    test_size=555090,
    train_size=1,
    trainable_rbf=False,
    val_size=1,
    weight_decay=0.0,
    box_vecs=None,
    charge=False,
    spin=False,
    vector_cutoff=True,
    wandb_use=True,
    wandb_project="MD17-Mix_No_Ethanol",
    tensorboard_use=True,
    wandb_name="ET-Transformer-Mix_No_Ethanol",
    pairwise_thread=True,
    triples_thread=True,
    return_vecs=True,
    loop=True,
    base_cutoff=5.0,
    outer_cutoff=5.0,
    gradient_clipping=0.0,
    remove_ref_energy=False,
    train_loss="mse_loss",
    train_loss_arg=None,
    seed=1,
    dataset_preload_limit=1024,
    lr_metric="val",
    box=None,
    long_edge_index=True,
    check_errors=True,
    strategy="brute",
    include_transpose=True,
    resize_to_fit=True,
    output_mlp_num_layers=0,
    equivariance_invariance_group="O(3)",
    static_shapes=False,
    wandb_resume_from_id=None,
    inference_batch_size=1000,
)

# Quick check that attribute-style access works on the namespace.
print(args.activation)

silu


In [None]:
# Show the kernel's current working directory; relative paths such as
# "logs/..." and "lists.pkl" used below resolve against it. Pure Python, so it
# also works where no `pwd` shell command exists (e.g. Windows).
os.getcwd()

In [3]:
# Build the torchmd-net DataModule from `args`, run its data preparation step,
# then set it up for the "fit" stage so the splits and mean/std are available.
data = DataModule(args)
data.prepare_data()
data.setup(stage="fit")

  self.data, self.slices = torch.load(self.processed_paths[idx])


computing mean and std: 100%|██████████| 1/1 [00:00<00:00,  3.09it/s]
  self._std = ys.std(dim=0)


In [4]:
# Instantiate any prior models requested in `args` and record their init
# arguments on `args` (presumably so they are saved with the hyper-parameters —
# confirm in torchmdnet).
prior_models = create_prior_models(vars(args), data.dataset)
args.prior_args = list(map(lambda prior: prior.get_init_args(), prior_models))

# Freshly initialised Lightning module using the dataset's mean/std.
# NOTE: `model` is rebound by the checkpoint-loading cell further down.
model = LNNP(args, prior_model=prior_models, mean=data.mean, std=data.std)

In [5]:
# Lightning Trainer configured from `args`. Nothing in this notebook calls
# `trainer.fit`, and the evaluation cell below builds its own Trainer.
trainer_kwargs = {
    "strategy": "auto",
    "max_epochs": args.num_epochs,
    "accelerator": "auto",
    "devices": args.ngpus,
    "num_nodes": args.num_nodes,
    "default_root_dir": args.log_dir,
    "precision": args.precision,
    "gradient_clip_val": args.gradient_clipping,
    "inference_mode": False,
    # Test-during-training requires reloading the dataloaders every epoch.
    "reload_dataloaders_every_n_epochs": 1 if args.test_interval > 0 else 0,
}
trainer = pl.Trainer(**trainer_kwargs)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [6]:
# Restore trained weights from the checkpoint, replacing the freshly
# initialised `model` built from `args`, and switch it to eval mode.
CKPT_PATH = "logs/epoch=2999-val_loss=0.5655-test_loss=0.4790.ckpt"
model = LNNP.load_from_checkpoint(CKPT_PATH)
model.eval()

LNNP(
  (model): TorchMD_Net(
    (representation_model): TorchMD_ET(hidden_channels=128, num_layers=6, num_rbf=32, rbf_type=expnorm, trainable_rbf=False, activation=silu, attn_activation=silu, neighbor_embedding=NeighborEmbedding(
      (embedding): Embedding(100, 128)
      (distance_proj): Linear(in_features=32, out_features=128, bias=True)
      (combine): Linear(in_features=256, out_features=128, bias=True)
      (cutoff): CosineCutoff()
    ), num_heads=8, distance_influence=both, cutoff_lower=0.0, cutoff_upper=5.0), dtype=torch.float32
    (output_model): EquivariantScalar(
      (output_network): ModuleList(
        (0): GatedEquivariantBlock(
          (vec1_proj): Linear(in_features=128, out_features=128, bias=False)
          (vec2_proj): Linear(in_features=128, out_features=64, bias=False)
          (update_net): MLP(
            (act): SiLU()
            (layers): Sequential(
              (0): Linear(in_features=256, out_features=128, bias=True)
              (1): SiLU()


In [None]:
# Number of batches produced by the test DataLoader.
num_test_batches = len(data.test_dataloader())
num_test_batches

In [None]:
    # Fresh Trainer used only to evaluate the restored checkpoint on the test
    # split. inference_mode=False keeps autograd enabled during testing
    # (presumably needed by torchmd-net when forces come from derivatives).
    eval_trainer_kwargs = dict(
        accelerator="auto",
        devices=args.ngpus,
        num_nodes=args.num_nodes,
        inference_mode=False,
    )
    trainer = pl.Trainer(**eval_trainer_kwargs)
    trainer.test(model, data)

In [None]:
# Inspect the TorchMD_Net network wrapped inside the Lightning module.
net = model.model
net

In [26]:
from rich.progress import Progress, TextColumn, BarColumn, TimeRemainingColumn, MofNCompleteColumn

# Run the restored network over the whole test split and collect predictions
# and reference labels as flat lists of Python floats.
model.eval()
outputs_list = []
label_list = []

# Build the test DataLoader once and reuse it for both the length and the loop.
test_loader = data.test_dataloader()
total_batches = len(test_loader)

with Progress(
    TextColumn("[progress.description]{task.description}"),
    BarColumn(),
    MofNCompleteColumn(),
    TimeRemainingColumn(),
) as progress:
    task = progress.add_task("[cyan]Processing batches...", total=total_batches)

    with torch.no_grad():
        for batch in test_loader:
            # Move model inputs to whatever device the model lives on.
            z = batch.z.to(model.device)
            pos = batch.pos.to(model.device)
            batch_idx = batch.batch.to(model.device)

            # model.model is the wrapped TorchMD_Net; its second output
            # (presumably the negative derivative / forces) is not used here.
            outputs, _ = model.model(z, pos, batch=batch_idx)

            # One host transfer per batch instead of one .item() device sync
            # per element; flatten() yields the same per-sample floats.
            outputs_list.extend(outputs.flatten().cpu().tolist())
            label_list.extend(batch.y.flatten().cpu().tolist())

            progress.advance(task)

Output()

In [30]:
import pandas as pd
import plotly.express as px

# Line plot of the first 1000 test predictions and their reference labels,
# indexed by position in the test set.
PLOT_HEAD = 1000
head_outputs = outputs_list[:PLOT_HEAD]
head_labels = label_list[:PLOT_HEAD]

df = pd.DataFrame(
    {
        'Index': range(len(head_labels)),
        'Outputs': head_outputs,
        'Labels': head_labels,
    }
)

fig = px.line(
    df,
    x='Index',
    y=['Outputs', 'Labels'],
    labels={'value': 'Values', 'variable': 'Legend'},
    title='Outputs vs Labels',
)
fig.show()

In [None]:
import pandas as pd
import plotly.express as px

# Same comparison as the previous plot, but each series is centred on its own
# mean. Their shapes can then be compared even when predictions carry a
# constant offset from the labels. Python lists have no .mean() and do not
# support `list - float`, so the slices are wrapped in pandas Series first.
N_PLOT = 1000
sub_outputs = pd.Series(outputs_list[:N_PLOT], dtype=float)
sub_labels = pd.Series(label_list[:N_PLOT], dtype=float)

df = pd.DataFrame({
    'Index': range(len(sub_labels)),
    'Outputs': sub_outputs - sub_outputs.mean(),
    'Labels': sub_labels - sub_labels.mean(),
})

fig = px.line(df, x='Index', y=['Outputs', 'Labels'], labels={'value': 'Values', 'variable': 'Legend'}, title='Outputs vs Labels')
fig.show()

In [33]:
from torchmetrics import MeanAbsoluteError

# Mean absolute error over every collected prediction/label pair.
predictions = torch.tensor(outputs_list)
targets = torch.tensor(label_list)

mae = MeanAbsoluteError()
mae.update(predictions, targets)
mae.compute()



tensor(71373.0234)

In [None]:
import matplotlib.pyplot as plt

# Overlay every prediction and its reference label against test-set position,
# using the explicit Figure/Axes interface.
mpl_fig, mpl_ax = plt.subplots(figsize=(10, 6))
mpl_ax.plot(range(len(outputs_list)), outputs_list, label='Outputs', marker='o')
mpl_ax.plot(range(len(label_list)), label_list, label='Labels', marker='x')

mpl_ax.set_xlabel('Index')
mpl_ax.set_ylabel('Values')
mpl_ax.set_title('Outputs vs Labels')
mpl_ax.legend()

plt.show()

In [31]:
import pickle

# Cache the collected predictions and labels to disk so the inference loop
# does not have to be re-run.
to_save = {'outputs_list': outputs_list, 'label_list': label_list}
with open('lists.pkl', 'wb') as fh:
    pickle.dump(to_save, fh)

In [None]:
import pickle

# Reload cached predictions/labels from 'lists.pkl'. The loaded dict gets its
# own name so the DataModule bound to `data` earlier is not overwritten.
# NOTE: pickle.load can execute arbitrary code; only load files you created.
with open('lists.pkl', 'rb') as f:
    saved_lists = pickle.load(f)
outputs_list = saved_lists['outputs_list']
label_list = saved_lists['label_list']

# Print a short preview rather than one value per test sample.
print('Outputs List:', outputs_list[:10], f'... ({len(outputs_list)} total)')
print('Label List:', label_list[:10], f'... ({len(label_list)} total)')

In [None]:
# Reference label of the first test sample.
first_label = label_list[0]
first_label

In [None]:
# Model prediction for the first test sample.
first_output = outputs_list[0]
first_output