In [151]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from Bio.PDB import PDBList, PDBParser
import os
import torch
import torch.nn as nn
from data_read import *
import dgl
from tqdm import tqdm
import warnings
from torch.utils.data import DataLoader
# Reload edited modules (e.g. data_read, src.*) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# NOTE(review): this silences every warning for the rest of the kernel session (including
# deprecation and numerical warnings); consider narrowing it with category=/module= filters.
warnings.filterwarnings('ignore')

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Data Retrieval

In [None]:
def download_pdb_files(sample_size=100, pdir='pdb_files', seed=42):
    """
    Download a reproducible random sample of PDB entries in legacy PDB format.

    Fetches the list of all current PDB IDs, draws ``sample_size`` distinct IDs
    from it, and retrieves each one into ``pdir``. Some entries have no legacy
    PDB-format file; Biopython's ``retrieve_pdb_file`` prints a message instead
    of raising and still returns the target path. Success is therefore checked
    by looking for the file on disk, and the summary reports the real count.

    Args:
        sample_size (int): Number of distinct PDB IDs to sample and download.
        pdir (str): Directory the structure files are written to.
        seed (int): Seed for the sampling RNG, so every run draws the same IDs.

    Returns:
        list: All PDB IDs available in the PDB (not only the sampled ones).

    Raises:
        ValueError: If ``sample_size`` is larger than the number of available IDs.
    """
    # A private RandomState gives exactly the same draw as np.random.seed(seed)
    # followed by np.random.choice, without resetting the notebook's global RNG.
    rng = np.random.RandomState(seed)
    pdbl = PDBList()
    pdb_ids = pdbl.get_all_entries()
    sampled_pdb_ids = rng.choice(pdb_ids, sample_size, replace=False)

    n_downloaded = 0
    for pdb_id in sampled_pdb_ids:
        path = pdbl.retrieve_pdb_file(pdb_id, pdir=pdir, file_format='pdb')
        # retrieve_pdb_file returns the target path even when the download failed.
        if path and os.path.exists(path):
            n_downloaded += 1

    print(f"Downloaded {n_downloaded} of {len(sampled_pdb_ids)} PDB files to '{pdir}'.")
    return pdb_ids

# Bind the result instead of leaving it as the cell's last expression: the function returns
# every ID in the PDB (a very long list), which would otherwise be echoed as cell output.
all_pdb_ids = download_pdb_files(sample_size=100)

In [148]:
import shutil
from sklearn.model_selection import train_test_split

# Define the paths.
# NOTE(review): download_pdb_files() writes to 'pdb_files' relative to the notebook's working
# directory, not to this path — confirm the raw files were placed here before running.
data_path = '../data/pdb_files/'
train_path = os.path.join(data_path, 'train')
test_path = os.path.join(data_path, 'test')

# Create train and test directories if they don't exist
os.makedirs(train_path, exist_ok=True)
os.makedirs(test_path, exist_ok=True)

# Candidates are the loose (not yet split) files directly in data_path; train/ and test/ are
# directories and hidden files such as .DS_Store are skipped. os.listdir order depends on the
# filesystem, so sort the names to make the seeded split identical on every machine.
all_files = sorted(
    f for f in os.listdir(data_path)
    if not f.startswith('.') and os.path.isfile(os.path.join(data_path, f))
)

if all_files:
    # Split the files into train and test sets (80/20, fixed seed).
    train_files, test_files = train_test_split(all_files, test_size=0.2, random_state=42)

    # Move the files to the respective directories
    for files, dest in ((train_files, train_path), (test_files, test_path)):
        for file in files:
            shutil.move(os.path.join(data_path, file), os.path.join(dest, file))

    print(f"Moved {len(train_files)} files to {train_path}")
    print(f"Moved {len(test_files)} files to {test_path}")
else:
    # Nothing left to split (e.g. the cell is re-run after a previous split). Report the existing
    # split instead of calling train_test_split on an empty list, which raises ValueError.
    train_files = sorted(os.listdir(train_path))
    test_files = sorted(os.listdir(test_path))
    print(f"No unsplit files in {data_path}; keeping existing split "
          f"({len(train_files)} train, {len(test_files)} test).")

Moved 77 files to ../data/pdb_files/train
Moved 20 files to ../data/pdb_files/test


### Testing of Datasets and DataLoaders

In [229]:
from src.dataset.datasets import *
from src.dataset.transforms import *

In [228]:
from torchvision import transforms

In [311]:
# Compose the project's per-sample transforms; Compose applies them in list order, so
# coordinates are normalised before PadDatasetTransform(1000) runs.
# (1000 is presumably the padded length in atoms/residues — TODO confirm.)
pdb_transforms = transforms.Compose(
    [
        NormalizeCoordinates(),
        PadDatasetTransform(1000),
    ]
)

In [318]:
# One PDBDataset per split directory, built train first, then test.
# NOTE(review): no transform is passed here, so pdb_transforms from the cell above is unused —
# check whether PDBDataset accepts a transform argument and whether it should receive it.
train_data_path, test_data_path = (f'../data/pdb_files/{split}' for split in ('train', 'test'))
train_dataset, test_dataset = (PDBDataset(split_dir) for split_dir in (train_data_path, test_data_path))

100%|██████████| 77/77 [00:04<00:00, 15.62it/s]
100%|██████████| 20/20 [00:00<00:00, 23.33it/s]


In [319]:
from src.dataset.collate import PreprocessPDB

# Batches of 4 samples, collated by PreprocessPDB.collate_fn. Only the training loader shuffles;
# giving it a seeded generator makes its batch order reproducible across Restart & Run All.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    collate_fn=PreprocessPDB().collate_fn,
    generator=torch.Generator().manual_seed(42),
)
# The test loader keeps a fixed order, so it needs no generator.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=PreprocessPDB().collate_fn,
)

### Models

In [330]:
from src.models.prepare_models import *

In [None]:
# `dataloaders` was never defined in this notebook; build it from the loaders created above.
dataloaders = {'train': train_dataloader, 'test': test_dataloader}

# NOTE(review): `args`, `device` and `dataset_info` are not defined in any visible cell. Unless
# one of the star imports provides them, this raises NameError on a fresh kernel. Define them
# before running — TODO confirm what get_model expects for each.
# Presumably get_model returns the model object — TODO confirm.
model = get_model(args, device, dataset_info, dataloaders['train'])
model