In [151]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from Bio.PDB import PDBList, PDBParser
import os
import torch
import torch.nn as nn
from data_read import *
import dgl
from tqdm import tqdm
import warnings
from torch.utils.data import DataLoader
%load_ext autoreload
%autoreload 2
warnings.filterwarnings('ignore')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Data Retrieval

In [None]:
def download_pdb_files(sample_size=100, pdir='pdb_files', seed=42):
    """
    Download a reproducible random sample of PDB entries in PDB format.

    The full list of entry IDs is fetched with ``PDBList.get_all_entries()``,
    ``sample_size`` distinct IDs are drawn from it, and each one is retrieved
    with ``PDBList.retrieve_pdb_file`` into ``pdir``.

    Args:
        sample_size (int): Number of distinct entries to request.
        pdir (str): Target directory handed to ``retrieve_pdb_file``.
            NOTE(review): the default is relative to the current working
            directory and differs from the '../data/pdb_files/' path read by
            the train/test split cell -- confirm which location is intended.
        seed (int): Seed of the sampler, so the same IDs are drawn every run.

    Returns:
        list: All PDB IDs known to the server (the full listing, not only the
        sampled subset).

    Raises:
        ValueError: If ``sample_size`` is larger than the number of available
            entries (sampling is done without replacement).
    """
    pdbl = PDBList()
    pdb_ids = pdbl.get_all_entries()

    # A private RandomState draws the same sample as seeding the global NumPy
    # RNG with `seed`, but does not reset the global stream as a side effect.
    rng = np.random.RandomState(seed)
    sampled_pdb_ids = rng.choice(pdb_ids, sample_size, replace=False)

    n_present = 0
    for pdb_id in sampled_pdb_ids:
        path = pdbl.retrieve_pdb_file(pdb_id, pdir=pdir, file_format='pdb')
        # Not every entry is available in PDB format; only count files that
        # really exist on disk after the call.
        if path and os.path.isfile(path):
            n_present += 1

    print(f"Downloaded {n_present} of {len(sampled_pdb_ids)} PDB files.")
    return pdb_ids

# Bind the return value: as a bare last expression the cell would echo the
# complete list of PDB IDs returned by the function into the notebook output.
all_pdb_ids = download_pdb_files(sample_size=100)

In [148]:
import shutil
from sklearn.model_selection import train_test_split

# Define the paths
data_path = '../data/pdb_files/'
train_path = os.path.join(data_path, 'train')
test_path = os.path.join(data_path, 'test')

# Create train and test directories if they don't exist
os.makedirs(train_path, exist_ok=True)
os.makedirs(test_path, exist_ok=True)

# Files still sitting directly in data_path, i.e. not yet assigned to a split.
# Sorted because os.listdir() returns entries in arbitrary order, and a fixed
# random_state only gives a reproducible split for a fixed input order.
all_files = sorted(
    f for f in os.listdir(data_path) if os.path.isfile(os.path.join(data_path, f))
)

# Split the files into train and test sets.
# Guarded so the cell can be re-run after the files were already moved:
# train_test_split raises when it cannot form a non-empty train and test set.
if len(all_files) >= 2:
    train_files, test_files = train_test_split(all_files, test_size=0.2, random_state=42)
else:
    train_files, test_files = [], []
    print(f"Found {len(all_files)} unsplit file(s) in {data_path}; nothing to split.")

# Move the files to the respective directories
for file in train_files:
    shutil.move(os.path.join(data_path, file), os.path.join(train_path, file))

for file in test_files:
    shutil.move(os.path.join(data_path, file), os.path.join(test_path, file))

print(f"Moved {len(train_files)} files to {train_path}")
print(f"Moved {len(test_files)} files to {test_path}")

Moved 77 files to ../data/pdb_files/train
Moved 20 files to ../data/pdb_files/test


### Testing of Datasets and DataLoaders

In [229]:
from src.dataset.datasets import *
from src.dataset.transforms import *
from torchvision import transforms

In [311]:
# pdb_transforms = transforms.Compose([NormalizeCoordinates(), PadDatasetTransform(1000)])

In [376]:
# Locations of the two splits on disk
train_data_path, test_data_path = '../data/pdb_files/train', '../data/pdb_files/test'

# Build one PDBDataset per split (train first, then test)
train_dataset, test_dataset = PDBDataset(train_data_path), PDBDataset(test_data_path)

100%|██████████| 77/77 [00:05<00:00, 14.66it/s]
100%|██████████| 20/20 [00:00<00:00, 22.01it/s]


### Data Statistics

In [377]:
from collections import defaultdict

# --- Histogram of sample sizes -------------------------------------------
# node_count_dict maps "number of nodes" -> "how many training samples have it";
# max_nodes tracks the largest sample seen.
node_count_dict = defaultdict(int)
max_nodes = 0
for sample in train_dataset:
    num_nodes = sample['positions'].shape[0]
    node_count_dict[num_nodes] += 1
    max_nodes = max(max_nodes, num_nodes)

# --- Histogram of i_seq values -------------------------------------------
# Every non-zero entry of sample['i_seq'] is tallied; zeros are skipped.
raw_i_seq_counts = defaultdict(int)
for sample in train_dataset:
    for i_seq_value in sample['i_seq']:
        if i_seq_value == 0:
            continue
        raw_i_seq_counts[i_seq_value.item()] += 1

# Plain dict with keys in ascending order
i_seq_count_dict = {key: raw_i_seq_counts[key] for key in sorted(raw_i_seq_counts)}

print("i_seq count dictionary:", dict(i_seq_count_dict))

print(f"Maximum number of nodes in a sample: {max_nodes}")
print("Node count dictionary:", dict(node_count_dict))

i_seq count dictionary: {1: 1314, 2: 268, 3: 1036, 4: 1017, 5: 695, 6: 1241, 7: 359, 8: 1058, 9: 1029, 10: 1534, 11: 328, 12: 744, 13: 808, 14: 623, 15: 791, 16: 981, 17: 1031, 18: 1244, 19: 207, 20: 603}
Maximum number of nodes in a sample: 773
Node count dictionary: {364: 1, 452: 1, 114: 2, 152: 1, 294: 1, 16: 1, 119: 1, 207: 1, 113: 1, 88: 1, 124: 1, 303: 1, 78: 3, 121: 1, 306: 1, 296: 2, 183: 1, 321: 1, 319: 1, 396: 1, 62: 1, 74: 1, 368: 1, 391: 1, 107: 1, 153: 1, 400: 1, 369: 1, 394: 1, 224: 1, 58: 2, 130: 1, 309: 1, 297: 1, 73: 1, 128: 1, 216: 1, 330: 1, 215: 1, 338: 1, 438: 1, 198: 1, 553: 1, 325: 1, 268: 1, 410: 1, 284: 1, 175: 1, 277: 1, 127: 1, 354: 1, 168: 1, 12: 1, 607: 1, 63: 1, 773: 1, 300: 1, 169: 1, 254: 1, 527: 1, 305: 1, 343: 1, 355: 1}


In [378]:
# Dataset description handed to get_model (used there to build the node distribution).
# The alphabet is the 20 one-letter amino-acid codes A..Y followed by B, X, Z, J, U, O.
aa_rep = "ACDEFGHIKLMNPQRSTVWYBXZJUO"
# Letter -> position in aa_rep (starts at 0).
# NOTE(review): the i_seq statistics printed above use keys 1..20 -- confirm how
# these indices line up with the i_seq encoding.
aa_to_int = dict(zip(aa_rep, range(len(aa_rep))))

dataset_info = dict(
    name='pdb',
    max_n_nodes=max_nodes,
    n_nodes=node_count_dict,
    atom_types=i_seq_count_dict,
    atom_encoder=aa_to_int,
    atom_decoder=list(aa_rep),
    colors_dic=[f'C{i}' for i in range(len(aa_rep))],
    radius_dic=[0.3 for _ in aa_rep],
)

In [390]:
from src.dataset.collate import PreprocessPDB

# Each loader gets its own PreprocessPDB instance as collate function.
# Only the training loader shuffles; both use batches of 4 samples.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=PreprocessPDB().collate_fn,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=PreprocessPDB().collate_fn,
)

### Models

In [394]:
import yaml

# Read the edm_config.yml file (one directory above the notebook) into `config`.
# yaml.safe_load only builds plain Python objects from the YAML document.
with open('../edm_config.yml', 'r') as file:
    config = yaml.safe_load(file)

import ast


class Config:
    """
    Attribute-style view of a configuration mapping.

    Every key of ``config_dict`` becomes an attribute of the instance. String
    values that spell a Python literal (numbers, ``None``, ``True``/``False``,
    lists, tuples, dicts, quoted strings) are converted to that literal; any
    other string is stored unchanged.

    Args:
        config_dict (dict | None): Mapping of option name to value, e.g. the
            result of ``yaml.safe_load``. ``None`` (what ``safe_load`` returns
            for an empty document) is treated as an empty mapping.
    """

    def __init__(self, config_dict):
        for key, value in (config_dict or {}).items():
            if isinstance(value, str):
                try:
                    # literal_eval only accepts literals. Unlike eval() it never
                    # resolves names, so a value such as "sum" stays the string
                    # "sum" instead of becoming the builtin function, and no
                    # code embedded in the config file can be executed.
                    value = ast.literal_eval(value)
                except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                    # Not a literal: keep the original string.
                    pass
            setattr(self, key, value)

# Wrap the parsed YAML mapping so every option is reachable as an attribute (args.<key>)
args = Config(config)

In [395]:
from src.models.prepare_models import *

In [403]:
# Pick the compute device by preference: CUDA first, then Apple MPS, else CPU.
# The conditional expression is evaluated lazily, so the MPS check only runs
# when CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

In [404]:
# Build the model from the config, the dataset description and the training loader.
# get_model also returns nodes_dist and prop_dist, which are passed on to train_epoch below
# (presumably the node-count and property distributions -- confirm in src.models.prepare_models).
model, nodes_dist, prop_dist = get_model(args, device, dataset_info, train_dataloader)
# Optimizer built from the config for the model created above
optim = get_optim(args, model)

# Move the model to the selected device
model = model.to(device)

Entropy of n_nodes: H[N] -4.109879493713379
alphas2 [9.99990000e-01 9.99982000e-01 9.99958001e-01 9.99918003e-01
 9.99862007e-01 9.99790014e-01 9.99702026e-01 9.99598046e-01
 9.99478076e-01 9.99342118e-01 9.99190176e-01 9.99022254e-01
 9.98838355e-01 9.98638484e-01 9.98422646e-01 9.98190846e-01
 9.97943090e-01 9.97679383e-01 9.97399731e-01 9.97104143e-01
 9.96792624e-01 9.96465182e-01 9.96121825e-01 9.95762562e-01
 9.95387400e-01 9.94996350e-01 9.94589420e-01 9.94166620e-01
 9.93727960e-01 9.93273451e-01 9.92803104e-01 9.92316930e-01
 9.91814941e-01 9.91297149e-01 9.90763566e-01 9.90214206e-01
 9.89649081e-01 9.89068205e-01 9.88471593e-01 9.87859258e-01
 9.87231215e-01 9.86587480e-01 9.85928068e-01 9.85252996e-01
 9.84562278e-01 9.83855933e-01 9.83133976e-01 9.82396427e-01
 9.81643302e-01 9.80874619e-01 9.80090398e-01 9.79290657e-01
 9.78475416e-01 9.77644695e-01 9.76798513e-01 9.75936891e-01
 9.75059851e-01 9.74167412e-01 9.73259599e-01 9.72336431e-01
 9.71397932e-01 9.70444124e-01 9.

In [356]:
import wandb

In [372]:
from src.train_test import train_epoch

In [357]:
# Resolve the wandb run mode from the config flags:
# no_wandb -> 'disabled', otherwise 'online' or 'offline' depending on args.online.
if args.no_wandb:
    mode = 'disabled'
elif args.online:
    mode = 'online'
else:
    mode = 'offline'

# Arguments for wandb.init; the whole config object is attached to the run.
kwargs = dict(
    entity=args.wandb_usr,
    name=args.exp_name,
    project='e3_diffusion',
    config=args,
    settings=wandb.Settings(_disable_stats=True),
    reinit=True,
    mode=mode,
)
wandb.init(**kwargs)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mjerhanwl[0m ([33mjerhanwl-national-university-of-singapore[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [399]:
from src.model_utils import Queue

In [400]:
# Queue handed to train_epoch as `gradnorm_queue`.
gradnorm_queue = Queue()
# Seed it with one large initial entry.
# NOTE(review): presumably this keeps gradient-norm-based clipping lenient for the first
# steps -- confirm against src.train_test.train_epoch and src.model_utils.Queue.
gradnorm_queue.add(3000)

In [408]:
# Arguments that are identical for every epoch.
# The same `model` object is passed as `model`, `model_dp` and `model_ema`,
# and `ema` / `property_norms` are None.
epoch_kwargs = dict(
    args=args,
    loader=train_dataloader,
    model=model,
    model_dp=model,
    model_ema=model,
    ema=None,
    device=device,
    dtype=torch.float32,
    property_norms=None,
    nodes_dist=nodes_dist,
    dataset_info=dataset_info,
    gradnorm_queue=gradnorm_queue,
    optim=optim,
    prop_dist=prop_dist,
)

# range(1, 2) runs exactly one epoch, numbered 1.
for epoch in range(1, 2):
    train_epoch(epoch=epoch, **epoch_kwargs)

torch.Size([4, 343])
tensor([[13, 12, 20,  ...,  0,  0,  0],
        [20, 14,  3,  ..., 20, 13, 19],
        [18,  4, 17,  ...,  0,  0,  0],
        [16,  1, 14,  ...,  0,  0,  0]], dtype=torch.int32)
x torch.Size([4, 343, 3])
h[categorical] torch.Size([4, 343, 26])
h[integer] torch.Size([4, 343, 1])


: 