In [None]:
import argparse
import logging
import os
import sys

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import pathlib

import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy
from torch_geometric.data import Data


from vae_decoder import EGNNDecoder
from vae_model import MolecularVAE, vae_loss_function

from visnet_vae_encoder import ViSNetEncoder

from torch_geometric.loader import DataLoader
import torch.nn.functional as F
from pathlib import Path
import sys
import wandb
from aib9_lib import aib9_tools as aib9
from torch.cuda.amp import autocast, GradScaler
from vae_utils import validate_and_sample, visualize_molecule_3d, compute_bond_lengths

In [21]:
from aib9_lib import aib9_tools as aib9
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')


# Training parameters
ATOM_COUNT = 58                        # atoms per frame (matches the reshape of the trajectory)
COORD_DIM = 3                          # coordinates per atom
ORIGINAL_DIM = ATOM_COUNT * COORD_DIM  # length of one flattened frame
LATENT_DIM = 64
EPOCHS = 50
VISNET_HIDDEN_CHANNELS = 64
ENCODER_NUM_LAYERS = 6
DECODER_HIDDEN_DIM = 64
DECODER_NUM_LAYERS = 6
BATCH_SIZE = 512  # Increased from 128 (V100 can handle much more!)
LEARNING_RATE = 5e-5  # Reduced to prevent gradient explosion
# NOTE: the train_loader below passes num_workers=0 because its tensors are created
# directly on `device`; this value is not read by that loader.
NUM_WORKERS = 2

# Seed torch (CPU and CUDA) and numpy so the loader's shuffle order and any other
# torch/numpy randomness later in the notebook are reproducible across runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Load the full trajectory and view it as (num_frames, ATOM_COUNT, COORD_DIM).
# Using the config constants (not literal 58/3) keeps this in sync with the model
# sizes; reshape raises ValueError if the data size is not a multiple of a frame.
train_data_np = np.load(aib9.FULL_DATA)
train_data_np = train_data_np.reshape(-1, ATOM_COUNT, COORD_DIM)

# Per-atom topology file. It sits next to the aib9_tools module inside aib9_lib
# (the previous hard-coded path was .../visnet_aib9/aib9_lib/aib9_atom_info.npy),
# so resolve it relative to the package rather than one user's home directory.
# Set the AIB9_TOPO_FILE environment variable to use a different file.
# NOTE(review): assumes aib9_tools is a module file directly inside aib9_lib — confirm.
TOPO_FILE = os.environ.get(
    "AIB9_TOPO_FILE",
    os.path.join(os.path.dirname(os.path.abspath(aib9.__file__)), "aib9_atom_info.npy"),
)

# Element symbol -> atomic number, used to build the `z` node feature.
ATOMICNUMBER_MAPPING = {
    "H": 1,
    "C": 6,
    "N": 7,
    "O": 8,
    "F": 9,
    "P": 15,
    "S": 16,
    "Cl": 17,
    "Br": 35,
    "I": 53,
}

ATOMIC_NUMBERS = []
topo = np.load(TOPO_FILE)
for i in range(topo.shape[0]):
    # Only the FIRST character of the atom name is looked up, so two-letter
    # symbols ("Cl", "Br") in the mapping can never match here.
    # NOTE(review): fine if every AIB9 atom is H/C/N/O — confirm naming in topo.
    atom_name = topo[i, 0][0]
    if atom_name not in ATOMICNUMBER_MAPPING:
        raise ValueError(f"Unknown atom name: {atom_name}")
    ATOMIC_NUMBERS.append(ATOMICNUMBER_MAPPING[atom_name])

# Coordinates are reshaped with ATOM_COUNT atoms per frame; a topology with a
# different atom count would silently pair `z` with the wrong positions.
if len(ATOMIC_NUMBERS) != ATOM_COUNT:
    raise ValueError(
        f"Topology {TOPO_FILE} has {len(ATOMIC_NUMBERS)} atoms, expected ATOM_COUNT={ATOM_COUNT}"
    )

# Create all tensors directly on the selected device to avoid device mismatch;
# every Data object (and so every batch the loader yields) is already on `device`.
z = torch.tensor(ATOMIC_NUMBERS, dtype=torch.long, device=device)

# Covalent bond graph, shared by every frame. `edges` is already [2, num_edges]
# (the next cell shows (2, 112)), so no transpose is needed.
edges = aib9.identify_all_covalent_edges(topo)
edge_index = torch.tensor(edges, dtype=torch.long, device=device).contiguous()
if edge_index.dim() != 2 or edge_index.size(0) != 2:
    raise ValueError(
        f"Expected edges of shape [2, num_edges], got {tuple(edge_index.shape)}"
    )

# Convert the whole trajectory in one call (a single host->device copy instead of
# one tiny copy per frame); each frame's `pos` below is a row view into this tensor.
all_pos = torch.from_numpy(train_data_np).float().to(device)

# One Data object per frame; z and edge_index are shared, not copied. No extra
# `.to(device)` is needed because every tensor above already lives on `device`.
train_data_list = []
for i in range(train_data_np.shape[0]):
    pos = all_pos[i]
    data = Data(z=z, pos=pos, edge_index=edge_index)
    train_data_list.append(data)

# num_workers must stay 0 here: the samples are already on `device` (possibly CUDA),
# and CUDA tensors cannot be safely shared with forked loader worker processes.
train_loader = DataLoader(
    train_data_list,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

Using device: cpu


In [22]:
# Sanity check: the covalent edge array should be [2, num_edges].
np.shape(edges)


(2, 112)