In [1]:
import argparse
import logging
import os
import sys

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import pathlib

import matplotlib.pyplot as plt
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from pytorch_lightning.strategies import DDPStrategy
from torch_geometric.data import Data


from vae_decoder import EGNNDecoder
from vae_model import MolecularVAE, vae_loss_function

from visnet_vae_encoder import ViSNetEncoder

from torch_geometric.loader import DataLoader
import torch.nn.functional as F
from pathlib import Path
import sys
import wandb
from aib9_lib import aib9_tools as aib9
from torch.cuda.amp import autocast, GradScaler
from vae_utils import validate_and_sample, visualize_molecule_3d, compute_bond_lengths

In [2]:
from aib9_lib import aib9_tools as aib9
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f'Using device: {device}')


#Training parameters
ATOM_COUNT = 58
COORD_DIM = 3
ORIGINAL_DIM = ATOM_COUNT * COORD_DIM  
LATENT_DIM = 64 
EPOCHS = 50
VISNET_HIDDEN_CHANNELS = 64
ENCODER_NUM_LAYERS = 6
DECODER_HIDDEN_DIM = 64
DECODER_NUM_LAYERS = 6
BATCH_SIZE = 512  # Increased from 128 (V100 can handle much more!)
LEARNING_RATE = 5e-5  # Reduced to prevent gradient explosion
NUM_WORKERS = 2  # Parallel data loading

train_data_np = np.load(aib9.FULL_DATA)
train_data_np = train_data_np.reshape(-1, 58, 3)

TOPO_FILE = (
   "/Users/chriszhang/Documents/aib9_vanillavae/visnet_aib9/aib9_lib/aib9_atom_info.npy"
)  

ATOMICNUMBER_MAPPING = {
"H": 1,
"C": 6,
"N": 7,
"O": 8,
"F": 9,
"P": 15,
"S": 16,
"Cl": 17,
"Br": 35,
"I": 53,
}

ATOMIC_NUMBERS = []
topo = np.load(TOPO_FILE)
#print(f"Loaded topology with shape: {topo.shape}", topo)
for i in range(topo.shape[0]):
    atom_name = topo[i, 0][0]
    if atom_name in ATOMICNUMBER_MAPPING:
        ATOMIC_NUMBERS.append(ATOMICNUMBER_MAPPING[atom_name])
    else:
        raise ValueError(f"Unknown atom name: {atom_name}")

# Create all tensors directly on the selected device to avoid device mismatch
z = torch.tensor(ATOMIC_NUMBERS, dtype=torch.long, device=device)

# Create a list of Data objects, one for each molecule
edges = aib9.identify_all_covalent_edges(topo)
# edges is already in shape [2, num_edges], no need to transpose
edge_index = torch.tensor(edges, dtype=torch.long, device=device).contiguous()

train_data_list = []
for i in range(train_data_np.shape[0]):
    pos = torch.from_numpy(train_data_np[i]).float().to(device)
    data = Data(z=z, pos=pos, edge_index=edge_index).to(device)
    train_data_list.append(data)
train_loader = DataLoader(
    train_data_list,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

Using device: cpu


In [4]:
 max(1.0 + 0.1 * 3, 2.0)

2.0

In [5]:
# Check if the graph is fully connected
import networkx as nx
from collections import defaultdict

def check_graph_connectivity(edges, num_atoms):
    """
    Check if the graph formed by edges is fully connected.
    
    Args:
        edges: List of [node1, node2] pairs or tensor of shape [2, num_edges]
        num_atoms: Total number of atoms/nodes
    
    Returns:
        dict with connectivity information
    """
    # Convert to numpy if tensor
    if hasattr(edges, 'cpu'):
        edges_np = edges.cpu().numpy()
    else:
        edges_np = np.array(edges)
    
    # Handle different edge formats
    if edges_np.shape[0] == 2:  # [2, num_edges] format
        edge_list = [(edges_np[0][i], edges_np[1][i]) for i in range(edges_np.shape[1])]
    else:  # [num_edges, 2] format
        edge_list = [(edges_np[i][0], edges_np[i][1]) for i in range(edges_np.shape[0])]
    
    # Create NetworkX graph
    G = nx.Graph()
    G.add_nodes_from(range(num_atoms))
    G.add_edges_from(edge_list)
    
    # Check connectivity
    is_connected = nx.is_connected(G)
    num_components = nx.number_connected_components(G)
    components = list(nx.connected_components(G))
    
    # Find isolated nodes (degree 0)
    isolated_nodes = [node for node in G.nodes() if G.degree(node) == 0]
    
    # Node degrees
    degrees = dict(G.degree())
    
    results = {
        'is_connected': is_connected,
        'num_components': num_components,
        'components': components,
        'isolated_nodes': isolated_nodes,
        'num_edges': len(edge_list),
        'num_nodes': num_atoms,
        'degrees': degrees,
        'avg_degree': sum(degrees.values()) / len(degrees),
        'min_degree': min(degrees.values()),
        'max_degree': max(degrees.values())
    }
    
    return results

# Check connectivity of our molecular graph
print("Checking graph connectivity...")
print(f"Number of atoms: {ATOM_COUNT}")
print(f"Number of edges: {len(edges)}")
print(f"Edge format: {np.array(edges).shape}")

connectivity = check_graph_connectivity(edges, ATOM_COUNT)

print(f"\n🔍 CONNECTIVITY ANALYSIS:")
print(f"✅ Fully connected: {connectivity['is_connected']}")
print(f"📊 Number of components: {connectivity['num_components']}")
print(f"🔗 Total edges: {connectivity['num_edges']}")
print(f"📈 Average degree: {connectivity['avg_degree']:.2f}")
print(f"📉 Min degree: {connectivity['min_degree']}")
print(f"📈 Max degree: {connectivity['max_degree']}")

if connectivity['isolated_nodes']:
    print(f"🚨 Isolated nodes (degree 0): {connectivity['isolated_nodes']}")
    
if not connectivity['is_connected']:
    print(f"\n🔍 CONNECTED COMPONENTS:")
    for i, component in enumerate(connectivity['components']):
        print(f"  Component {i+1}: {sorted(list(component))} (size: {len(component)})")
        
# Show degree distribution
print(f"\n📊 DEGREE DISTRIBUTION:")
degree_counts = defaultdict(int)
for node, degree in connectivity['degrees'].items():
    degree_counts[degree] += 1

for degree in sorted(degree_counts.keys()):
    count = degree_counts[degree]
    print(f"  Degree {degree}: {count} nodes")
    
# Check specific atoms that might be problematic
print(f"\n🔍 SPECIFIC ATOM ANALYSIS:")
for i in range(min(10, ATOM_COUNT)):  # Show first 10 atoms
    atom_name = topo[i, 0][0]
    degree = connectivity['degrees'][i]
    print(f"  Atom {i} ({atom_name}): degree {degree}")

# Look for CA2 specifically (the one we fixed)
ca2_indices = [i for i in range(ATOM_COUNT) if topo[i, 0][0] == 'CA2']
if ca2_indices:
    for ca2_idx in ca2_indices:
        degree = connectivity['degrees'][ca2_idx]
        print(f"  🎯 CA2 atom {ca2_idx}: degree {degree}")
        if degree == 0:
            print(f"    ❌ CA2 is isolated!")
        else:
            print(f"    ✅ CA2 is connected")


Checking graph connectivity...
Number of atoms: 58
Number of edges: 2
Edge format: (2, 114)

🔍 CONNECTIVITY ANALYSIS:
✅ Fully connected: True
📊 Number of components: 1
🔗 Total edges: 114
📈 Average degree: 1.97
📉 Min degree: 1
📈 Max degree: 5

📊 DEGREE DISTRIBUTION:
  Degree 1: 30 nodes
  Degree 2: 10 nodes
  Degree 3: 9 nodes
  Degree 4: 8 nodes
  Degree 5: 1 nodes

🔍 SPECIFIC ATOM ANALYSIS:
  Atom 0 (C): degree 1
  Atom 1 (C): degree 3
  Atom 2 (O): degree 1
  Atom 3 (N): degree 2
  Atom 4 (C): degree 4
  Atom 5 (C): degree 1
  Atom 6 (C): degree 1
  Atom 7 (C): degree 3
  Atom 8 (O): degree 1
  Atom 9 (N): degree 2
