In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# cd /content/drive/MyDrive/CS5284\ Project/protein-coordinates/notebooks

In [None]:
# !pip install pytorch matplotlib pandas seaborn scikit-learn torchvision numpy scipy dgl imageio biopython ipykernel ipywidgets wandb

In [None]:
# pip install biopython

In [4]:
import sys

# Make the parent directory (relative to the current working directory) importable,
# so the `data_read` and `src.*` modules imported in later cells resolve.
sys.path.extend(['../'])

In [5]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from Bio.PDB import PDBList, PDBParser
import os
import torch
import torch.nn as nn
from data_read import *
from tqdm import tqdm
import warnings
from torch.utils.data import DataLoader
%load_ext autoreload
%autoreload 2
warnings.filterwarnings('ignore')

### Data Retrieval

In [6]:
def download_pdb_files(sample_size=100, seed=42, pdir='pdb_files'):
    """
    Download a random sample of PDB entries in legacy ``.pdb`` format.

    A fixed-seed sample of ``sample_size`` IDs is drawn without replacement from
    the full list of PDB entries, and each one is fetched with
    ``PDBList.retrieve_pdb_file`` into ``pdir``.

    Args:
        sample_size (int): Number of entries to sample and download. Clamped to
            the number of available entries so oversized requests do not raise.
        seed (int): Seed for the sampling RNG. A local ``RandomState`` is used, so
            the same sample as before is drawn while NumPy's global RNG state is
            left untouched.
        pdir (str): Directory the files are written to (relative paths resolve
            against the current working directory).

    Returns:
        list: All PDB IDs available in the PDB (not only the sampled ones), as
        returned by ``PDBList.get_all_entries``.
    """
    pdbl = PDBList()
    pdb_ids = pdbl.get_all_entries()
    # RandomState(seed).choice(...) reproduces np.random.seed(seed); np.random.choice(...)
    # exactly, without mutating the global RNG used by other cells.
    rng = np.random.RandomState(seed)
    n_samples = min(sample_size, len(pdb_ids))
    sampled_pdb_ids = rng.choice(pdb_ids, n_samples, replace=False)
    for pdb_id in sampled_pdb_ids:
        # NOTE(review): some entries may have no legacy-format file; the message
        # below counts requested downloads, not confirmed files — verify on disk.
        pdbl.retrieve_pdb_file(pdb_id, pdir=pdir, file_format='pdb')
    print(f"Downloaded {len(sampled_pdb_ids)} PDB files.")
    return pdb_ids

# download_pdb_files(sample_size=100)

In [7]:
# import shutil
# from sklearn.model_selection import train_test_split

# # Define the paths
# data_path = '../data/pdb_files/'
# train_path = os.path.join(data_path, 'train')
# test_path = os.path.join(data_path, 'test')

# # Create train and test directories if they don't exist
# os.makedirs(train_path, exist_ok=True)
# os.makedirs(test_path, exist_ok=True)

# # Get list of all files in the data_path
# all_files = [f for f in os.listdir(data_path) if os.path.isfile(os.path.join(data_path, f))]

# # Split the files into train and test sets
# train_files, test_files = train_test_split(all_files, test_size=0.2, random_state=42)

# # Move the files to the respective directories
# for file in train_files:
#     shutil.move(os.path.join(data_path, file), os.path.join(train_path, file))

# for file in test_files:
#     shutil.move(os.path.join(data_path, file), os.path.join(test_path, file))

# print(f"Moved {len(train_files)} files to {train_path}")
# print(f"Moved {len(test_files)} files to {test_path}")

### Testing of Datasets and DataLoaders

In [8]:
from src.dataset.datasets import *
from src.dataset.transforms import *
from torchvision import transforms

In [9]:
# pdb_transforms = transforms.Compose([NormalizeCoordinates(), PadDatasetTransform(1000)])

In [47]:
train_data_path = '../data/pdb_files/train'
test_data_path = '../data/pdb_files/test'

train_dataset = PDBDataset(train_data_path)
# Drop the last sample when the count is odd so the training set has even length.
# `Subset` works for any map-style Dataset; the previous `train_dataset[:-1]` only
# works if PDBDataset.__getitem__ accepts slice keys, which is not visible here.
# NOTE(review): a Subset does not forward PDBDataset's own attributes — confirm no
# downstream code reads e.g. `train_dataloader.dataset.<attr>`.
if len(train_dataset) % 2 != 0:
    train_dataset = torch.utils.data.Subset(train_dataset, range(len(train_dataset) - 1))

test_dataset = PDBDataset(test_data_path)

100%|██████████| 77/77 [00:04<00:00, 15.93it/s]
100%|██████████| 20/20 [00:00<00:00, 21.88it/s]


### Data Statistics

In [48]:
from collections import Counter, defaultdict

# Single pass over the training set (the previous version iterated it twice):
#   node_count_dict  - number of samples having each node count
#   i_seq_counter    - occurrences of each non-zero value in sample['i_seq']
node_count_dict = defaultdict(int)
i_seq_counter = Counter()
for sample in train_dataset:
    node_count_dict[sample['positions'].shape[0]] += 1
    # Value 0 is skipped; presumably a padding / unknown index — TODO confirm in PDBDataset.
    i_seq_counter.update(value.item() for value in sample['i_seq'] if value != 0)

# Largest sample size (in nodes) seen in the training set; 0 for an empty dataset.
max_nodes = max(node_count_dict, default=0)

# Sorted by i_seq value so the printout and dataset_info['atom_types'] are ordered.
i_seq_count_dict = dict(sorted(i_seq_counter.items()))

print("i_seq count dictionary:", i_seq_count_dict)

print(f"Maximum number of nodes in a sample: {max_nodes}")
print("Node count dictionary:", dict(node_count_dict))

i_seq count dictionary: {1: 1314, 2: 268, 3: 1036, 4: 1017, 5: 695, 6: 1241, 7: 359, 8: 1058, 9: 1029, 10: 1534, 11: 328, 12: 744, 13: 808, 14: 623, 15: 791, 16: 981, 17: 1031, 18: 1244, 19: 207, 20: 603}
Maximum number of nodes in a sample: 773
Node count dictionary: {364: 1, 452: 1, 114: 2, 152: 1, 294: 1, 16: 1, 119: 1, 207: 1, 113: 1, 88: 1, 124: 1, 303: 1, 78: 3, 121: 1, 306: 1, 296: 2, 183: 1, 321: 1, 319: 1, 396: 1, 62: 1, 74: 1, 368: 1, 391: 1, 107: 1, 153: 1, 400: 1, 369: 1, 394: 1, 224: 1, 58: 2, 130: 1, 309: 1, 297: 1, 73: 1, 128: 1, 216: 1, 330: 1, 215: 1, 338: 1, 438: 1, 198: 1, 553: 1, 325: 1, 268: 1, 410: 1, 284: 1, 175: 1, 277: 1, 127: 1, 354: 1, 168: 1, 12: 1, 607: 1, 63: 1, 773: 1, 300: 1, 169: 1, 254: 1, 527: 1, 305: 1, 343: 1, 355: 1}


In [49]:
# Augment the node-count histogram so it covers every size from 1 to max_nodes.
# Sizes never seen in the training set get a pseudo-count of 0.1 rather than 0,
# presumably so every size keeps non-zero probability in the node distribution
# built from dataset_info['n_nodes'] — TODO confirm.
node_count_dict_augmented = {
    n_nodes: node_count_dict[n_nodes] if n_nodes in node_count_dict else 0.1
    for n_nodes in range(1, max_nodes + 1)
}

In [50]:
# Dataset metadata consumed by get_model (node-size histogram, symbol vocabulary,
# plotting colours/radii).
# Alphabet: the 20 standard one-letter amino-acid codes followed by B, X, Z, J, U, O.
aa_rep = "ACDEFGHIKLMNPQRSTVWYBXZJUO"
aa_to_int = {symbol: idx for idx, symbol in enumerate(aa_rep)}

# NOTE(review): atom_types (i_seq_count_dict) is keyed 1..20 in the statistics printed
# earlier, while atom_encoder is 0-based over 26 symbols — confirm the offset matches
# what PDBDataset emits.
dataset_info = dict(
    name='pdb',
    max_n_nodes=max_nodes,
    n_nodes=node_count_dict_augmented,
    atom_types=i_seq_count_dict,
    atom_encoder=aa_to_int,
    atom_decoder=list(aa_rep),
    colors_dic=[f'C{k}' for k in range(len(aa_rep))],
    radius_dic=[0.3] * len(aa_rep),
)

In [52]:
from src.dataset.collate import PreprocessPDB

BATCH_SIZE = 4

# Seeded generator so the training shuffle order is reproducible across kernel
# restarts; without it the order depends on torch's global RNG, which is not
# seeded anywhere visible in this notebook.
train_loader_generator = torch.Generator().manual_seed(42)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              collate_fn=PreprocessPDB().collate_fn,
                              generator=train_loader_generator)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                             collate_fn=PreprocessPDB().collate_fn)

### Models

In [53]:
import yaml

# Read the EDM hyper-parameter file (../edm_config.yml) into a plain dict; it is
# wrapped into `args` below and passed to the model / training code.
# Explicit UTF-8 so parsing does not depend on the platform's default locale encoding.
with open('../edm_config.yml', 'r', encoding='utf-8') as file:
    config = yaml.safe_load(file)

import ast


class Config:
    """
    Attribute-style view of a flat config mapping (e.g. the dict from ``yaml.safe_load``).

    String values that spell a Python literal (``'1e-4'``, ``'None'``, ``'[1, 2]'``)
    are converted with ``ast.literal_eval``; every other string is kept unchanged.

    ``ast.literal_eval`` replaces the previous ``eval``. That version executed
    arbitrary expressions from the config file. It also silently turned strings that
    happen to name a builtin or global (``'sum'``, ``'max'``, ``'np'``) into those
    objects instead of keeping them as strings.

    Args:
        config_dict (dict): Mapping of option name -> value; each key becomes an attribute.
    """

    def __init__(self, config_dict):
        for key, value in config_dict.items():
            if isinstance(value, str):
                value = self._parse_literal(value)
            setattr(self, key, value)

    @staticmethod
    def _parse_literal(text):
        """Return ``text`` parsed as a Python literal if it is one, else ``text`` unchanged."""
        try:
            return ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # Not a literal (plain words, paths, identifiers): keep the original string.
            return text

# Attribute-style access to the config (args.n_epochs, args.no_wandb, ...) used below.
args = Config(config_dict=config)

In [54]:
from src.models.prepare_models import *

In [55]:
# Pick the best available accelerator: CUDA first, then Apple MPS, else CPU.
# The conditional expression only queries MPS when CUDA is unavailable.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

In [56]:
model, nodes_dist, prop_dist = get_model(args, device, dataset_info, train_dataloader)

# Move the model to the target device *before* constructing the optimizer, as the
# torch.optim documentation recommends, so the optimizer is built over the
# parameters as they will live during training.
model = model.to(device)
optim = get_optim(args, model)

Entropy of n_nodes: H[N] -6.056983947753906
alphas2 [9.99990000e-01 9.99940002e-01 9.99790014e-01 9.99540060e-01
 9.99190176e-01 9.98740416e-01 9.98190846e-01 9.97541550e-01
 9.96792624e-01 9.95944182e-01 9.94996350e-01 9.93949271e-01
 9.92803104e-01 9.91558019e-01 9.90214206e-01 9.88771865e-01
 9.87231215e-01 9.85592489e-01 9.83855933e-01 9.82021810e-01
 9.80090398e-01 9.78061989e-01 9.75936891e-01 9.73715426e-01
 9.71397932e-01 9.68984761e-01 9.66476280e-01 9.63872873e-01
 9.61174936e-01 9.58382883e-01 9.55497140e-01 9.52518150e-01
 9.49446371e-01 9.46282275e-01 9.43026349e-01 9.39679097e-01
 9.36241035e-01 9.32712696e-01 9.29094628e-01 9.25387393e-01
 9.21591568e-01 9.17707746e-01 9.13736535e-01 9.09678557e-01
 9.05534449e-01 9.01304864e-01 8.96990470e-01 8.92591949e-01
 8.88109998e-01 8.83545330e-01 8.78898672e-01 8.74170767e-01
 8.69362373e-01 8.64474261e-01 8.59507220e-01 8.54462051e-01
 8.49339573e-01 8.44140618e-01 8.38866033e-01 8.33516680e-01
 8.28093438e-01 8.22597199e-01 8.

In [57]:
import wandb

In [58]:
from src.train_test import train_epoch, test

In [59]:
# Weights & Biases run setup. `mode` follows the config:
#   no_wandb -> 'disabled'; otherwise 'online' or 'offline' depending on args.online.
# Previously `kwargs` was only defined in the else-branch (NameError when no_wandb
# was set) and 'mode' was hard-coded to 'disabled', ignoring the computed value.
if args.no_wandb:
    mode = 'disabled'
else:
    mode = 'online' if args.online else 'offline'
kwargs = {'entity': args.wandb_usr, 'name': args.exp_name, 'project': 'e3_diffusion',
          'config': vars(args), 'reinit': True, 'mode': mode}
wandb.init(**kwargs)

In [60]:
from src.model_utils import Queue

In [61]:
# Seed the gradient-norm history passed to train_epoch with a single initial value.
# The first logged clipping threshold is 4500.0 (1.5x this value); presumably that
# threshold is derived from the queue's statistics — TODO confirm in train_epoch.
INITIAL_GRAD_NORM = 3000
gradnorm_queue = Queue()
gradnorm_queue.add(INITIAL_GRAD_NORM)

In [62]:
# prompt: Get the size of object 'model'

# print(f"Size of model: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

# print(f"Allocated memory: {torch.cuda.memory_allocated()} bytes")
# print(f"Reserved memory: {torch.cuda.memory_reserved()} bytes")
# print(f"Max memory: {torch.cuda.max_memory_allocated()} bytes")

In [63]:
# Per-epoch NLL values returned by train_epoch / test, keyed by epoch index.
train_loss = {}
test_loss = {}

for epoch in range(args.n_epochs):
    # The same `model` object fills the model / model_dp / model_ema slots, with ema=None.
    nll_train = train_epoch(
        args=args, loader=train_dataloader, epoch=epoch,
        model=model, model_dp=model, model_ema=model, ema=None,
        device=device, dtype=torch.float32, property_norms=None,
        nodes_dist=nodes_dist, dataset_info=dataset_info,
        gradnorm_queue=gradnorm_queue, optim=optim, prop_dist=prop_dist,
    )
    train_loss[epoch] = nll_train

    # Evaluate on the test split only every `args.test_epochs` epochs (epoch 0 included).
    # NOTE(review): logged grad norms reach 1e6-1e9 in the first epoch; the
    # NormalizeCoordinates transform is commented out earlier — confirm inputs are scaled.
    if epoch % args.test_epochs != 0:
        continue
    nll_test = test(
        args=args, loader=test_dataloader, epoch=epoch, eval_model=model,
        partition='Test', device=device, dtype=torch.float32,
        nodes_dist=nodes_dist, property_norms=None,
    )
    test_loss[epoch] = nll_test

Clipped gradient with value 41344220.0 while allowed 4500.0
Epoch: 0, iter: 0/17, Loss 63891.36, NLL: 63891.36, RegTerm: 0.0, GradNorm: 41344220.0
Clipped gradient with value 1782382208.0 while allowed 7125.0
Epoch: 0, iter: 1/17, Loss 10632.24, NLL: 10632.24, RegTerm: 0.0, GradNorm: 1782382208.0
Clipped gradient with value 4025837.8 while allowed 10722.0
Epoch: 0, iter: 2/17, Loss 24154.44, NLL: 24154.44, RegTerm: 0.0, GradNorm: 4025837.8
Clipped gradient with value 97273464.0 while allowed 15366.9
Epoch: 0, iter: 3/17, Loss 115038.95, NLL: 115038.95, RegTerm: 0.0, GradNorm: 97273464.0
Clipped gradient with value 34886156.0 while allowed 21140.3
Epoch: 0, iter: 4/17, Loss 92726.78, NLL: 92726.78, RegTerm: 0.0, GradNorm: 34886156.0
Clipped gradient with value 2463364608.0 while allowed 28122.5
Epoch: 0, iter: 5/17, Loss 138191.89, NLL: 138191.89, RegTerm: 0.0, GradNorm: 2463364608.0
Clipped gradient with value 228634560.0 while allowed 36391.4
Epoch: 0, iter: 6/17, Loss 17873.92, NLL: 

: 