In [165]:
import sidechainnet as scn
import numpy as np
import pandas as pd
import os
from geometricus import get_invariants_for_structures, Geometricus, SplitInfo, SplitType

def load_sidechainnet_data(casp_version=12, casp_thinning=30):
    """Load a SidechainNet dataset after seeding NumPy's global RNG.

    Args:
        casp_version: Forwarded unchanged to ``scn.load`` as ``casp_version``.
        casp_thinning: Forwarded unchanged to ``scn.load`` as ``casp_thinning``.

    Returns:
        Whatever ``scn.load`` returns for the requested settings.
    """
    # Fix NumPy's legacy global seed before the load call.
    np.random.seed(0)
    load_options = {"casp_version": casp_version, "casp_thinning": casp_thinning}
    return scn.load(**load_options)

def _residue_number(line):
    """Return the residue sequence number encoded in one PDB atom line."""
    try:
        # PDB is a fixed-width format: the residue sequence number occupies
        # columns 23-26, so this also works when the chain ID is blank or
        # neighbouring columns touch.
        return int(line[22:26])
    except ValueError:
        # Best-effort fallback for whitespace-delimited, non-standard lines
        # (the field position the previous implementation relied on).
        return int(line.split()[5])


def process_pdb_content(lines, min_length=8):
    """Return the lines of the first run of consecutively numbered residues.

    A run is a stretch of adjacent lines whose residue numbers each increase
    by exactly one. The first run with at least ``min_length`` lines wins.
    Runs are tracked by position, so later lines that merely reuse a residue
    number from the selected run (another chain or segment) are not included.

    Args:
        lines: Iterable of PDB atom lines, one atom per line (e.g. CA only).
        min_length: Minimum number of consecutive residues a run must have.

    Returns:
        list of str: the lines of the first qualifying run, in input order,
        or ``None`` when the input is empty or no run is long enough.

    Raises:
        ValueError / IndexError: if a line carries no parsable residue number.
    """
    lines = list(lines)
    if not lines:
        return None

    residues = [_residue_number(line) for line in lines]

    run_start = 0
    # Walk one position past the end so the final run is closed as well.
    for pos in range(1, len(residues) + 1):
        run_continues = (
            pos < len(residues) and residues[pos] == residues[pos - 1] + 1
        )
        if run_continues:
            continue
        if pos - run_start >= min_length:
            return lines[run_start:pos]
        run_start = pos

    return None

def export_protein_structures(dataset, limit=20, output_dir='PDB_files'):
    """Export the leading proteins of ``dataset`` as PDB and CA-only PDB files.

    For every exported protein a full PDB file is written through the
    protein's ``to_pdb`` method, and a second file holding only the ``ATOM``
    lines whose atom-name columns read ``CA`` is written next to it.
    Underscores in the protein id are replaced by hyphens in both file names.

    Args:
        dataset: Integer-indexable collection whose items expose ``id`` and
            ``to_pdb(path)``.
        limit: Maximum number of leading entries to export. When the dataset
            reports a length, the count is clamped to it so a short dataset
            no longer causes an out-of-range lookup.
        output_dir: Directory receiving the files; created if missing.

    Returns:
        list of str: paths of the CA-only files whose rows passed
        ``process_pdb_content``. The indices of those entries are printed.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Clamp to the dataset size when it is known; otherwise keep `limit`.
    try:
        available = len(dataset)
    except TypeError:
        available = limit
    n_structures = min(limit, available)

    testing_ids = []
    index = []
    for i in range(n_structures):
        protein = dataset[i]
        base_name = protein.id.replace('_', '-')
        pdb_filename = os.path.join(output_dir, f"{base_name}.pdb")
        ca_only_filename = os.path.join(output_dir, f"{base_name}_ca_only.pdb")

        # Export full PDB
        protein.to_pdb(pdb_filename)

        # Keep only the alpha-carbon ATOM records
        with open(pdb_filename, "r") as file:
            ca_rows = [
                line for line in file
                if line.startswith("ATOM") and line[13:15] == "CA"
            ]

        # Write CA-only PDB
        with open(ca_only_filename, "w") as output_file:
            output_file.writelines(ca_rows)

        # NOTE(review): only the truthiness of the filter result is used; the
        # file on disk still contains every CA row, not just the continuous
        # run. Confirm this is intended.
        # The emptiness guard skips structures without any CA rows.
        if ca_rows and process_pdb_content(ca_rows):
            testing_ids.append(ca_only_filename)
            index.append(i)

    print(f"{index=}")
    return testing_ids

def compute_moment_invariants(pdb_files, n_threads=4, split_type=SplitType.KMER, split_size=8):
    """Compute moment invariants for the given PDB files.

    Args:
        pdb_files: Passed as the first argument to
            ``get_invariants_for_structures``.
        n_threads: Forwarded as ``n_threads``.
        split_type: First argument of the single ``SplitInfo`` that is built.
        split_size: Second argument of that ``SplitInfo``.

    Returns:
        The first element of the pair returned by
        ``get_invariants_for_structures``; the second element is discarded.
    """
    split_infos = [SplitInfo(split_type, split_size)]
    moment_types = ["O_3", "O_4", "O_5", "F"]

    invariants, _unused = get_invariants_for_structures(
        pdb_files,
        n_threads=n_threads,
        split_infos=split_infos,
        moment_types=moment_types,
    )
    return invariants

def log_moment_invariants(multiple_moment_invariants):
    """Return the rounded natural log of the absolute moments as integers.

    Args:
        multiple_moment_invariants: Object exposing a numeric ``moments``
            attribute.

    Returns:
        numpy.ndarray of ints: ``round(ln(|moments| + 1e-10))`` element-wise.
    """
    magnitudes = np.abs(multiple_moment_invariants.moments)
    # The small offset keeps the logarithm finite when a moment is exactly 0.
    shifted = magnitudes + 1e-10
    return np.round(np.log(shifted)).astype(int)

def remove_duplicate_shapmers(log_moments):
    """Stack the per-structure arrays and drop repeated rows.

    Rows keep the order of their first appearance in the stacked array. The
    number of remaining rows is printed.

    Args:
        log_moments: Sequence of arrays accepted by ``np.vstack``.

    Returns:
        numpy.ndarray: the stacked rows with duplicates removed.
    """
    stacked = np.vstack(log_moments)

    # np.unique returns rows sorted by value; re-sorting the first-occurrence
    # indices restores the original row order.
    first_occurrence = np.unique(stacked, axis=0, return_index=True)[1]
    keep = np.sort(first_occurrence)
    deduplicated = stacked[keep]

    print("\nNumber of unique rows:", len(deduplicated))
    return deduplicated

def process_protein_structures(limit=20):
    """Run the pipeline: load, export to PDB, compute and log-transform moments.

    Args:
        limit: Forwarded to ``export_protein_structures`` as its ``limit``.

    Returns:
        dict with the keys ``'dataset'``, ``'pdb_files'``, ``'invariants'``
        and ``'log_moments'`` holding the result of each stage.
    """
    scn_dataset = load_sidechainnet_data()
    ca_only_paths = export_protein_structures(scn_dataset, limit)
    structure_invariants = compute_moment_invariants(ca_only_paths)

    # One rounded-log array per invariant entry, in the same order.
    rounded_logs = []
    for entry in structure_invariants:
        rounded_logs.append(log_moment_invariants(entry))

    return dict(
        dataset=scn_dataset,
        pdb_files=ca_only_paths,
        invariants=structure_invariants,
        log_moments=rounded_logs,
    )

# Run the whole pipeline once and report a short summary.
results = process_protein_structures()
print("Processed", len(results['pdb_files']), "protein structures")
# The slice [0:2] covers two structures, so the label now says so.
print("Log-transformed moments for first two structures:", results['log_moments'][0:2])
# Computed before the final print; the helper prints its own row count first.
unique_log_moments = remove_duplicate_shapmers(results['log_moments'])
print("Log-transformed moments for all structures (no duplicates):", unique_log_moments)

SidechainNet was loaded from ./sidechainnet_data/sidechainnet_casp12_30.pkl.
index=[1, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18, 19]
Found 16 protein structures


100%|███████████████████████████████████████████| 16/16 [00:02<00:00,  6.30it/s]


Computed invariants in 2.62 seconds
Processed 16 protein structures
Log-transformed moments for first structure: [array([[ 3,  6,  6,  8],
       [ 4,  8,  7,  9],
       [ 4,  9,  7,  8],
       [ 5, 10,  8,  8],
       [ 5, 11,  8, 10],
       [ 5, 11,  8, 12],
       [ 5, 11,  9, 13],
       [ 6, 11,  9, 14],
       [ 6, 11,  9, 13],
       [ 6, 11,  9, 14],
       [ 6, 10,  9, 14],
       [ 6, 10,  9, 13],
       [ 6, 11,  9, 13],
       [ 6, 11, 10, 14],
       [ 5, 11,  9, 13],
       [ 5, 10,  9, 11],
       [ 5,  9,  8, 10],
       [ 4,  8,  7, 10]]), array([[ 4,  4,  5,  7],
       [ 5,  5,  6, 10],
       [ 5,  6,  7,  9],
       [ 6,  8,  8, 12],
       [ 6,  9,  9, 13],
       [ 6, 10,  9, 14],
       [ 6,  9,  9, 13],
       [ 5,  8,  8, 12],
       [ 5,  6,  7, 11]])]

Number of unique rows: 67
Log-transformed moments for all structures (no duplicates): [[ 3  6  6  8]
 [ 4  8  7  9]
 [ 4  9  7  8]
 [ 5 10  8  8]
 [ 5 11  8 10]
 [ 5 11  8 12]
 [ 5 11  9 13]
 [ 6 11  9 14]


In [161]:
# Scratch check for a single structure (index 9): export it to the working
# directory, keep only the CA ATOM lines, and print the resulting file.
# Bind the dataset explicitly so the cell does not depend on leftover kernel
# state from a cell that is no longer in the notebook.
d = results['dataset']

for i in range(9, 10):
    pdb_path = f"{d[i].id}.pdb"
    ca_only_path = f"{d[i].id}_ca_only.pdb"

    d[i].to_pdb(pdb_path)

    # Read back the file that was just written. The previous read path,
    # "PDB_files<id>.pdb", pointed at a different file and lacked a separator.
    with open(pdb_path, "r") as file:
        lines = file.readlines()

    ca_rows = [line for line in lines if line.startswith("ATOM") and line[13:15] == "CA"]

    with open(ca_only_path, "w") as output_file:
        output_file.writelines(ca_rows)

    with open(ca_only_path, "r") as file:
        for line in file:
            print(line.strip())

    print("New PDB file " + ca_only_path + " created successfully.")

ATOM      2  CA  LYS A  11      -7.723  13.053  26.536  1.00  0.00           C
ATOM      7  CA  SER A  12      -4.106  12.520  27.642  1.00  0.00           C
ATOM     13  CA  SER A  13      -1.587  15.386  27.724  1.00  0.00           C
ATOM     19  CA  PHE A  14       2.107  16.108  28.244  1.00  0.00           C
ATOM     30  CA  PHE A  15       1.143  19.169  30.268  1.00  0.00           C
ATOM     41  CA  SER A  16       0.785  18.860  34.034  1.00  0.00           C
ATOM     47  CA  ASP A  17      -2.506  20.711  34.271  1.00  0.00           C
ATOM     55  CA  ARG A  18      -4.278  18.786  31.508  1.00  0.00           C
ATOM     66  CA  GLY A  19      -5.581  15.269  30.915  1.00  0.00           C
New PDB file 3JRV_2_C_ca_only.pdb created successfully.


In [167]:
# Scratch check for a single structure (index 9), writing into `output_dir`.
# `dataset` is otherwise only a local inside process_protein_structures, so
# take it from `results` to keep the cell runnable on a fresh kernel.
dataset = results['dataset']

output_dir='PDB_files'
os.makedirs(output_dir, exist_ok=True)

for i in range(9, 10):
    pdb_filename = os.path.join(output_dir, f"{dataset[i].id}.pdb")
    ca_only_filename = os.path.join(output_dir, f"{dataset[i].id}_ca_only.pdb")

    # Export the full structure, then keep only the alpha-carbon ATOM lines.
    dataset[i].to_pdb(pdb_filename)

    with open(pdb_filename, "r") as file:
        lines = file.readlines()

    # `ca_rows` stays a global: the following cells inspect and filter it.
    ca_rows = [line for line in lines if line.startswith("ATOM") and line[13:15] == "CA"]

    with open(ca_only_filename, "w") as output_file:
        output_file.writelines(ca_rows)

    # Echo the CA-only file as a visual check.
    with open(ca_only_filename, "r") as file:
        for line in file:
            print(line.strip())

    print("New PDB file " + ca_only_filename + " created successfully.")

ATOM      2  CA  LYS A  11      -7.723  13.053  26.536  1.00  0.00           C
ATOM      7  CA  SER A  12      -4.106  12.520  27.642  1.00  0.00           C
ATOM     13  CA  SER A  13      -1.587  15.386  27.724  1.00  0.00           C
ATOM     19  CA  PHE A  14       2.107  16.108  28.244  1.00  0.00           C
ATOM     30  CA  PHE A  15       1.143  19.169  30.268  1.00  0.00           C
ATOM     41  CA  SER A  16       0.785  18.860  34.034  1.00  0.00           C
ATOM     47  CA  ASP A  17      -2.506  20.711  34.271  1.00  0.00           C
ATOM     55  CA  ARG A  18      -4.278  18.786  31.508  1.00  0.00           C
ATOM     66  CA  GLY A  19      -5.581  15.269  30.915  1.00  0.00           C
New PDB file PDB_files/3JRV_2_C_ca_only.pdb created successfully.


In [169]:
# Inspect the CA-only ATOM lines collected by the loop in the previous cell.
ca_rows

['ATOM      2  CA  LYS A  11      -7.723  13.053  26.536  1.00  0.00           C  \n',
 'ATOM      7  CA  SER A  12      -4.106  12.520  27.642  1.00  0.00           C  \n',
 'ATOM     13  CA  SER A  13      -1.587  15.386  27.724  1.00  0.00           C  \n',
 'ATOM     19  CA  PHE A  14       2.107  16.108  28.244  1.00  0.00           C  \n',
 'ATOM     30  CA  PHE A  15       1.143  19.169  30.268  1.00  0.00           C  \n',
 'ATOM     41  CA  SER A  16       0.785  18.860  34.034  1.00  0.00           C  \n',
 'ATOM     47  CA  ASP A  17      -2.506  20.711  34.271  1.00  0.00           C  \n',
 'ATOM     55  CA  ARG A  18      -4.278  18.786  31.508  1.00  0.00           C  \n',
 'ATOM     66  CA  GLY A  19      -5.581  15.269  30.915  1.00  0.00           C  \n']

In [171]:
# Apply the consecutive-residue filter to the CA rows and display what it
# keeps; process_pdb_content returns None when no run of 8+ residues exists.
result = process_pdb_content(ca_rows)
result

['ATOM      2  CA  LYS A  11      -7.723  13.053  26.536  1.00  0.00           C  \n',
 'ATOM      7  CA  SER A  12      -4.106  12.520  27.642  1.00  0.00           C  \n',
 'ATOM     13  CA  SER A  13      -1.587  15.386  27.724  1.00  0.00           C  \n',
 'ATOM     19  CA  PHE A  14       2.107  16.108  28.244  1.00  0.00           C  \n',
 'ATOM     30  CA  PHE A  15       1.143  19.169  30.268  1.00  0.00           C  \n',
 'ATOM     41  CA  SER A  16       0.785  18.860  34.034  1.00  0.00           C  \n',
 'ATOM     47  CA  ASP A  17      -2.506  20.711  34.271  1.00  0.00           C  \n',
 'ATOM     55  CA  ARG A  18      -4.278  18.786  31.508  1.00  0.00           C  \n',
 'ATOM     66  CA  GLY A  19      -5.581  15.269  30.915  1.00  0.00           C  \n']