1. Download PDB files:

In [None]:
import contextlib
import os
import sys
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed

from Bio.PDB import PDBList
from tqdm import tqdm

def load_pdb_ids(file_path):
    """
    Load PDB IDs from a text file containing one ID per line.

    Leading/trailing whitespace is stripped from every line and blank lines
    are ignored, so a trailing newline or stray spaces in the file do not
    produce empty or malformed IDs.

    Args:
        file_path (str): Path to the text file containing PDB IDs.

    Returns:
        list: A list of non-empty PDB ID strings, in file order.
    """
    with open(file_path, 'r') as file:
        stripped_lines = (line.strip() for line in file)
        # Drop blank lines: an empty ID would be passed on as a bogus download.
        return [pdb_id for pdb_id in stripped_lines if pdb_id]

# Shared state for filter_stdout(). sys.stdout is process-global, so contexts
# entered from several threads must cooperate instead of each saving and
# restoring "its own" previous stream.
_stdout_filter_lock = threading.Lock()
_stdout_filter_words = []     # one tuple of words per active filter_stdout() context
_stdout_before_filter = None  # stream to restore once the last context exits


class _FilteredStream:
    """Stdout wrapper that drops writes containing any currently active filter word."""

    def __init__(self, stream):
        self.stream = stream

    def write(self, message):
        with _stdout_filter_lock:
            blocked = any(
                word in message
                for words in _stdout_filter_words
                for word in words
            )
        if blocked:
            # Report the message as consumed so callers see normal write() semantics.
            return len(message)
        return self.stream.write(message)

    def flush(self):
        self.stream.flush()

    def __getattr__(self, name):
        # Delegate everything else (encoding, isatty, ...) to the wrapped stream.
        return getattr(self.stream, name)


@contextlib.contextmanager
def filter_stdout(filter_words):
    """
    Context manager to filter specific stdout messages.

    While at least one context is active, any single ``write()`` to
    ``sys.stdout`` whose text contains one of the active filter words is
    dropped. The wrapper is installed when the first context is entered and
    the original stream is restored only when the last one exits, so contexts
    that overlap in different threads cannot leave ``sys.stdout`` pointing at
    a stale wrapper.

    Note: because ``sys.stdout`` is global, filtering applies to output from
    every thread while any context is active.

    Args:
        filter_words (list): List of words to filter out from stdout.

    Yields:
        None
    """
    global _stdout_before_filter

    words = tuple(filter_words)
    with _stdout_filter_lock:
        if not _stdout_filter_words:
            # First active context: remember the real stream and wrap it.
            _stdout_before_filter = sys.stdout
            sys.stdout = _FilteredStream(_stdout_before_filter)
        _stdout_filter_words.append(words)
    try:
        yield
    finally:
        with _stdout_filter_lock:
            _stdout_filter_words.remove(words)
            if not _stdout_filter_words:
                # Last context gone: put the original stream back.
                sys.stdout = _stdout_before_filter
                _stdout_before_filter = None

def download_pdb_file(pdb_id, save_dir):
    """
    Download a single PDB file.

    Never raises: every outcome (downloaded, skipped, failed) is reported
    through the returned message so that a worker thread cannot abort the
    whole batch.

    Args:
        pdb_id (str): The PDB ID of the file to download.
        save_dir (str): The directory to save the downloaded PDB file.

    Returns:
        str: A message indicating the result of the download attempt.
    """
    pdbl = PDBList()  # using Biopython's PDBList class
    # Biopython lower-cases the ID when naming the saved file, so the
    # "already exists" check must do the same to recognise earlier downloads.
    file_path = os.path.join(save_dir, f"pdb{pdb_id.lower()}.ent")

    if os.path.exists(file_path):
        return f"Skipped {pdb_id}, already exists"

    try:
        with filter_stdout(["Downloading PDB structure", "Desired structure doesn't exist"]):
            saved_path = pdbl.retrieve_pdb_file(pdb_id, pdir=save_dir, file_format='pdb', overwrite=False)
    except Exception as e:
        return f"Error downloading {pdb_id}: {str(e)}"

    # retrieve_pdb_file() only prints a message when the structure is missing
    # instead of raising, so confirm that a file was actually written.
    if not saved_path or not os.path.exists(saved_path):
        return f"Error downloading {pdb_id}: no file was saved (structure may not exist in PDB format)"
    return f"Downloaded {pdb_id}"

def download_pdb_files(pdb_ids, save_dir='pdb_files', num_threads=16):
    """
    Fetch many PDB files concurrently while showing a progress bar.

    The target directory is created when missing. One download task per ID
    is submitted to a thread pool; the status message of each task is
    collected in completion order and printed once all tasks have finished.

    Args:
        pdb_ids (list): List of PDB IDs to download.
        save_dir (str): The directory to save the downloaded PDB files.
        num_threads (int): The number of threads to use for downloading.

    Returns:
        None
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    with ThreadPoolExecutor(max_workers=num_threads) as pool:
        # Map each pending task back to the ID it is fetching.
        pending = {}
        for pdb_id in pdb_ids:
            task = pool.submit(download_pdb_file, pdb_id, save_dir)
            pending[task] = pdb_id

        progress = tqdm(
            as_completed(pending),
            total=len(pending),
            desc="Downloading PDB files",
            unit="file",
            leave=True,
        )
        outcomes = [finished.result() for finished in progress]

    # Print the per-file messages only after the progress bar has completed.
    for outcome in outcomes:
        print(outcome)

# Read the train/test PDB ID lists shipped with the supplement (train first)
train_pdb_ids, test_pdb_ids = (
    load_pdb_ids(ids_file) for ids_file in ('train_ids.txt', 'test_ids.txt')
)

# Fetch each split into its own sub-directory
for split_ids, split_dir in (
    (train_pdb_ids, 'pdb_files/train'),
    (test_pdb_ids, 'pdb_files/test'),
):
    download_pdb_files(split_ids, save_dir=split_dir)

2. Get pairwise distance matrices:

In [None]:
# "To create our datasets,
# we extract non-overlapping fragments of lengths 16, 64, and 128 from chain ‘A’ for each protein
# structure starting at the first residue and calculate the pairwise distance matrices from the alpha-carbon
# coordinate positions"

import os
import numpy as np
from Bio import PDB
from tqdm import tqdm

def load_structure(pdb_file):
    """Parse a PDB file and return chain "A" of its first model.

    Parser warnings are silenced (QUIET=True). The lookups of model 0 and
    chain "A" are not guarded, so a file lacking either one raises.
    """
    structure = PDB.PDBParser(QUIET=True).get_structure("protein", pdb_file)
    first_model = structure[0]
    return first_model["A"]

def extract_fragments(chain, fragment_length):
    """Split the chain's residues into consecutive, non-overlapping windows.

    Windows start at the first residue and each holds exactly
    ``fragment_length`` residues; leftover residues at the end that cannot
    fill a whole window are discarded.
    """
    all_residues = list(chain.get_residues())
    last_start = len(all_residues) - fragment_length

    return [
        all_residues[start:start + fragment_length]
        for start in range(0, last_start + 1, fragment_length)
    ]

def calculate_distance_matrix(fragment):
    """Calculate pairwise distance matrix for alpha carbons in the fragment.

    Residues without a "CA" atom are skipped, so the result is an (n, n)
    matrix where n is the number of residues that have an alpha carbon; n can
    be smaller than ``len(fragment)``. A fragment with no alpha carbons at all
    (for example a run of water molecules) yields an empty (0, 0) matrix
    instead of raising.
    """
    coords = [residue["CA"].coord for residue in fragment if "CA" in residue]

    # reshape(-1, 3) keeps the array 2-D even when no CA atom was found;
    # otherwise the empty case is 1-D and the axis=2 reduction below fails.
    coords = np.asarray(coords).reshape(-1, 3)
    dist_matrix = np.linalg.norm(coords[:, np.newaxis] - coords, axis=2)
    return dist_matrix

def process_pdb_file(pdb_file, fragment_lengths=(16, 64, 128)):
    """Process a single PDB file and return distance matrices for each fragment length.

    Args:
        pdb_file (str): Path to the PDB file to parse.
        fragment_lengths (iterable of int): Fragment sizes to extract. The
            default is an immutable tuple so it cannot be modified between
            calls.

    Returns:
        dict: Maps each fragment length to a list of square distance
        matrices. Fragments in which some residue lacks an alpha carbon
        produce a smaller matrix and are left out.
    """
    chain = load_structure(pdb_file)
    results = {length: [] for length in fragment_lengths}

    for length in fragment_lengths:
        for fragment in extract_fragments(chain, length):
            dist_matrix = calculate_distance_matrix(fragment)
            # Keep only complete fragments: a missing CA shrinks the matrix.
            if dist_matrix.shape[0] == length:
                results[length].append(dist_matrix)

    return results

def create_datasets(pdb_dir, output_dir, fragment_lengths=(16, 64, 128)):
    """Create datasets from PDB files in the specified directory.

    For every fragment length, all ``.ent``/``.pdb`` files in ``pdb_dir`` are
    processed and the resulting distance matrices are stacked and saved to
    ``<output_dir>/distance_matrices_<length>.npy``.

    Args:
        pdb_dir (str): Directory containing the PDB files.
        output_dir (str): Directory for the ``.npy`` outputs (created if missing).
        fragment_lengths (iterable of int): Fragment sizes to build datasets for.

    Returns:
        None
    """
    os.makedirs(output_dir, exist_ok=True)

    # Sort so the order of matrices in the saved dataset is reproducible;
    # os.listdir() returns entries in arbitrary order.
    pdb_files = sorted(f for f in os.listdir(pdb_dir) if f.endswith(('.ent', '.pdb')))

    for length in fragment_lengths:
        all_matrices = []
        skipped = 0

        for pdb_file in tqdm(pdb_files, desc=f"Processing {length}-residue fragments"):
            try:
                results = process_pdb_file(os.path.join(pdb_dir, pdb_file), [length])
            except Exception as exc:
                # One unreadable file or a structure without chain A must not
                # abort the whole build; report it and move on.
                skipped += 1
                tqdm.write(f"Skipping {pdb_file}: {exc!r}")
                continue
            all_matrices.extend(results[length])

        if all_matrices:
            dataset = np.array(all_matrices)
        else:
            # Keep a consistent (N, length, length) layout even with no data.
            dataset = np.empty((0, length, length))
        np.save(os.path.join(output_dir, f'distance_matrices_{length}.npy'), dataset)
        print(f"Saved {len(dataset)} matrices of size {length}x{length}")
        if skipped:
            print(f"Skipped {skipped} file(s) that could not be processed")

# Build the distance-matrix datasets for the train split, then the test split
for split_pdb_dir, split_output_dir in (
    ('pdb_files/train', 'datasets/train'),
    ('pdb_files/test', 'datasets/test'),
):
    create_datasets(split_pdb_dir, split_output_dir)