In [1]:
# !pip install biopython
# !pip install nglview
# !conda install -c conda-forge nglview

In [2]:
import warnings
warnings.filterwarnings('ignore')

### The CASP12 dataset is just a huge (12gb) ASCII text file with each protein's data in sequence. Open the `test` file if you want to see the structure as it is only a few mb. Here is a parser that loads it into an initial df.

In [3]:
import pandas as pd
from tqdm import tqdm

def parse_protein_data(file_path, n):
    # Initialize lists to store data
    ids = []
    sequences = []
    evolutionary = []
    tertiary = []
    masks = []
    
    # Initialize counters and temporary storage for current protein data
    current_protein = {}
    count = 0
    
    with open(file_path, 'r') as file:
        for line in tqdm(file):
            line = line.strip()
            if line.startswith('[ID]'):
                # Save previous protein data if exists
                if current_protein:
                    ids.append(current_protein.get('ID', ''))
                    sequences.append(current_protein.get('PRIMARY', ''))
                    evolutionary.append(current_protein.get('EVOLUTIONARY', []))
                    tertiary.append(current_protein.get('TERTIARY', []))
                    masks.append(current_protein.get('MASK', ''))
                    count += 1
                    # Break the loop if the required number of proteins have been parsed
                    if count >= n:
                        break
                    current_protein = {}
            
            # Identify the section and append the data to the current protein
            if line.startswith('['):
                key = line[1:line.find(']')]
                current_protein[key] = []
            elif current_protein:
                # Split TSV data into lists of floats or keep as strings depending on the section
                if key in ['EVOLUTIONARY', 'TERTIARY']:
                    # Split the line by spaces, convert each item to float
                    current_protein[key].append([float(x) for x in line.split()])
                else:
                    current_protein[key].append(line)
                
    # Check if last processed protein needs to be added
    if count < n and current_protein:
        ids.append(current_protein.get('ID', ''))
        sequences.append(current_protein.get('PRIMARY', ''))
        evolutionary.append(current_protein.get('EVOLUTIONARY', []))
        tertiary.append(current_protein.get('TERTIARY', []))
        masks.append(current_protein.get('MASK', ''))
    
    # Create DataFrame
    df = pd.DataFrame({
        'ID': ids,
        'Sequence': sequences,
        'Evolutionary': evolutionary,
        'Tertiary': tertiary,
        'Mask': masks
    })
    
    return df

In [4]:
# Usage example
file_path = './data/training_95'
number_of_proteins = 90000  # Set how many proteins you want to load
protein_df = parse_protein_data(file_path, number_of_proteins)
protein_df.head(2)

1680162it [01:07, 24870.41it/s]


Unnamed: 0,ID,Sequence,Evolutionary,Tertiary,Mask
0,[3LAF_1_A],[DRWGSELESSHHHHHHGGRRSLHFVSEPSDAVTMRGGNVLLNCSA...,"[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10...","[[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,...",[------------------+++++++++++++++++++++++++++...
1,[2D5F_d2d5fb1],[NECQLNNLNALEPDHRVESEGGLIETWNSQHPELQCAGVTVSKRT...,"[[0.0, 0.03975101015616468, 0.0, 0.00715892820...","[[472.3, 574.6, 714.3, 726.5, 853.6, 963.7, 92...",[+++++++++++++++++++++++++++++++++++++++++++++...


In [5]:
len(protein_df)

50914

### We need to do some further extraction to get the right format:

In [6]:
def process_dataframe(df):
    # Convert ID, Sequence, and Mask to string directly
    df['ID'] = df['ID'].apply(lambda x: ''.join(x))
    df['Sequence'] = df['Sequence'].apply(lambda x: ''.join(x))
    df['Mask'] = df['Mask'].apply(lambda x: ''.join(x).replace('+', '1').replace('-', '0'))

    # Initialize columns for coordinates and amino acids
    coord_columns = ['x', 'y', 'z']
    aa_columns = ['A','C','D','E','F','G','H','I','K','L','M','N','P','Q','R','S','T','V','W','Y','Info']

    # Expand tertiary coordinates
    for i, col in enumerate(coord_columns):
        df[col] = df['Tertiary'].apply(lambda x: x[i] if len(x) > i else None)

    # Expand evolutionary data into separate columns for each amino acid position
    for i, col in enumerate(aa_columns):
        df[col] = df['Evolutionary'].apply(lambda x: x[i] if len(x) > i else None)

    # Drop the original Tertiary and Evolutionary columns
    df.drop(['Tertiary', 'Evolutionary'], axis=1, inplace=True)

    return df

In [7]:
# Assume 'df' is your DataFrame loaded with data
# Example usage:
processed_df = process_dataframe(protein_df)
processed_df.head(2)

Unnamed: 0,ID,Sequence,Mask,x,y,z,A,C,D,E,...,N,P,Q,R,S,T,V,W,Y,Info
0,3LAF_1_A,DRWGSELESSHHHHHHGGRRSLHFVSEPSDAVTMRGGNVLLNCSAE...,0000000000000000001111111111111111111111111111...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.105...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.037...",...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.126...","[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.729...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.9999999988691709, 0.9999999988691709, 0.999..."
1,2D5F_d2d5fb1,NECQLNNLNALEPDHRVESEGGLIETWNSQHPELQCAGVTVSKRTL...,1111111111111111111111111111111111111111111111...,"[472.3, 574.6, 714.3, 726.5, 853.6, 963.7, 925...","[7239.7, 7348.0, 7287.6, 7174.9, 7104.6, 7189....","[-5819.6, -5811.4, -5816.0, -5884.9, -5896.6, ...","[0.0, 0.03975101015616468, 0.0, 0.007158928206...","[0.0, 0.0, 0.9742406353976603, 0.0008181632235...","[0.0, 0.021185977940373482, 0.0, 0.06463489466...","[0.0, 0.6226930217320084, 0.0, 0.0075680098179...",...,"[0.9999999999999999, 0.0, 0.0, 0.0928615258744...","[0.0, 0.0, 0.0, 0.0, 0.001115392415331576, 0.0...","[0.0, 0.20181282079283608, 0.0, 0.416342810390...","[0.0, 0.014961231844490552, 0.0, 0.27091429740...","[0.0, 0.06202904881511412, 0.0, 0.022090407036...","[0.0, 0.0, 0.0, 0.008488443444467173, 0.006489...","[0.0, 0.01550726220377853, 0.0, 0.0, 0.0484688...","[0.0, 0.0, 0.0074058173231726955, 0.0, 0.0, 0....","[0.0, 0.0, 0.010089084469249759, 0.0, 0.0, 0.0...","[0.9999999988691709, 0.5927681381443177, 0.949..."


## Data description:
- ID is the protein name
- Sequence is the amino acid sequence using the single letter code
- Mask (same length as sequence) is a binary representation of how importatnt each amino acid is to the structure (that is, amino acids marked `0` can be changed without affecting the structure much, whereas amino acids marked `1` would significantly alter structure/function).
- XYZ are the coordinates of each atom (note there are 3 atoms per amino acid, Nitrogen, alpha-Carbon, and Carbon, so there are 3 times as many coordinates as amino acids in the sequence)
- The 20 columns A-Y are normalized PSSM information for each amino acid. This is essentially a sparse matrix where there is a row for each position in the protein sequence, and a column for each of the 20 amino acids.
- Info contains the information content for that residue (unclear what this actually is)
## Note:
Proteins are chains of amino acids, and each amino acid has a backbone of three atoms, N-C-C, which are always linked in the N to C direction: 

`Beginning of protein sequence amino acids:       1       2       3    ... etc`

`Beginning of protein sequence backbone atoms: (N-C-C)-(N-C-C)-(N-C-C)-... etc`

So in our final output, we need to have xyz coordinates for each of the three atoms, in order, for every amino acid.

## NGL - load a sample protein to get the parser working 

https://files.rcsb.org/download/5YHX.pdb

In [8]:
import nglview as nv
from Bio.PDB import PDBParser

# Load your PDB file
parser = PDBParser()
structure = parser.get_structure('Sample', './data/5yhx.pdb')

# Create NGLView widget for the structure
view = nv.show_biopython(structure)
view.clear_representations()

# Add molecular graphics: e.g., cartoon for secondary structures and ball+stick for atoms
view.add_representation('cartoon', selection='protein', color='blue')
view.add_representation('ball+stick', selection='ligand')

view



NGLWidget()

## Here is a parser to turn a row of our dataframe into a PDB file that can be viewed with NGL. Note that our dataset doesn't have info on the secondary structure, so these structures will not appear bolded in the output. The PDB specification is available here:

https://www.biostat.jhsph.edu/~iruczins/teaching/260.655/links/pdbformat.pdf

In [9]:
import pandas as pd

def format_pdb_from_df(row):
    pdb_str = ""
    atom_count = 1  # Starting index for ATOM records
    chain_id = 'A'  # Assuming a single chain for simplicity
    res_num = 1     # Residue number
    occupancy = 1.00
    t_factor = 50.00

    atom_types = ['N', 'CA', 'C']  # Backbone atoms
    element_types = {'N': 'N', 'CA': 'C', 'C': 'C'}  # Elements for atoms
    residue_mapping = {'A': 'ALA','R': 'ARG','N': 'ASN','D': 'ASP','C': 'CYS','E': 'GLU','Q': 'GLN','G': 'GLY',
                       'H': 'HIS','I': 'ILE','L': 'LEU','K': 'LYS','M': 'MET','F': 'PHE','P': 'PRO','S': 'SER',
                       'T': 'THR','W': 'TRP','Y': 'TYR','V': 'VAL' } # PDB requires the 3 letter codes for amino acids

    for i in range(len(row['Sequence'])):
        residue = residue_mapping[row['Sequence'][i]]
        # Loop over each backbone atom type
        for j, atom_type in enumerate(atom_types):
            # Index to pull the correct coordinates from flattened list
            idx = 3 * i + j
            x = row['x'][idx]/100
            y = row['y'][idx]/100
            z = row['z'][idx]/100
            element = element_types[atom_type]

            pdb_str += f"ATOM  {atom_count:>5}  {atom_type:<2}  {residue:>3} {chain_id}{res_num:>4}    {x:>8.3f}{y:>8.3f}{z:>8.3f}{occupancy:>6.2f}{t_factor:>6.2f}          {element:>2}  \n"
            atom_count += 1

        res_num += 1  # Increment residue number for each new amino acid

    return pdb_str

In [10]:
protein_row = processed_df.iloc[10]

pdb_content = format_pdb_from_df(protein_row)
print(pdb_content)

# Optionally, write to a file
with open("./data/sample_protein.pdb", "w") as file:
    file.write(pdb_content)

ATOM      1  N   ASN A   1      15.878  29.275  76.731  1.00 50.00           N  
ATOM      2  CA  ASN A   1      15.137  28.332  75.915  1.00 50.00           C  
ATOM      3  C   ASN A   1      15.539  26.931  76.360  1.00 50.00           C  
ATOM      4  N   LYS A   2      14.647  26.274  77.101  1.00 50.00           N  
ATOM      5  CA  LYS A   2      14.972  25.005  77.746  1.00 50.00           C  
ATOM      6  C   LYS A   2      15.109  23.930  76.692  1.00 50.00           C  
ATOM      7  N   THR A   3      16.130  23.093  76.818  1.00 50.00           N  
ATOM      8  CA  THR A   3      16.246  21.915  75.966  1.00 50.00           C  
ATOM      9  C   THR A   3      15.103  20.938  76.293  1.00 50.00           C  
ATOM     10  N   SER A   4      14.511  20.344  75.257  1.00 50.00           N  
ATOM     11  CA  SER A   4      13.358  19.451  75.405  1.00 50.00           C  
ATOM     12  C   SER A   4      13.070  18.756  74.091  1.00 50.00           C  
ATOM     13  N   GLU A   5  

In [11]:
import nglview as nv
from Bio.PDB import PDBParser

# Load your PDB file
parser = PDBParser()
structure = parser.get_structure('Sample', './data/sample_protein.pdb')

# Create NGLView widget for the structure
view = nv.show_biopython(structure)
view.clear_representations()

# Add molecular graphics: e.g., cartoon for secondary structures and ball+stick for atoms
view.add_representation('cartoon', selection='protein', color='blue')
view.add_representation('ball+stick', selection='ligand')

view

NGLWidget()

## add features for GNN
### For nodes (amino acids) we will add some of their chemical properties)

In [12]:
def amino_acid_properties():
    properties = {
        'Node_Hydrophobic': "AYCMVLWIF",
        'Node_Hydrophilic': "RKNEPD",
        'Node_Polar': "YWHKREQDNST",
        'Node_Small': "VCAGDNSTP",
        # 'Node_Proline': "P",
        'Node_Tiny': "AGS",
        'Node_Aliphatic': "ILV",
        'Node_Aromatic': "FYWH",
        'Node_Positive': "HKR",
        'Node_Negative': "DE",
        'Node_Charged': "HKRED"
    }
    return properties

def encode_properties(seq, properties):
    # Initialize the dictionary for property encodings
    property_encodings = {key: "" for key in properties}
    
    # Iterate through each amino acid in the sequence
    for aa in seq:
        for prop, aas in properties.items():
            if aa in aas:
                property_encodings[prop] += '1'
            else:
                property_encodings[prop] += '0'
    
    return property_encodings

In [13]:
from tqdm import tqdm

df_features = processed_df#.copy()

# Get amino acid properties
props = amino_acid_properties()

# Encode properties for each sequence
for i, prop in enumerate(props):
    tqdm.pandas(desc=f"Applying properties {i+1}")
    df_features[prop] = df_features['Sequence'].progress_apply(lambda x: encode_properties(x, props)[prop])

df_features.rename(columns={#"Sequence": "Node_Sequence",
                            "Mask": "Node_Mask",
                            "ID": "Protein_Name",}, inplace=True)
df_features.head(1)

Applying properties 1: 100%|█████████████| 50914/50914 [00:08<00:00, 5757.79it/s]
Applying properties 2: 100%|█████████████| 50914/50914 [00:08<00:00, 5745.54it/s]
Applying properties 3: 100%|█████████████| 50914/50914 [00:08<00:00, 5801.93it/s]
Applying properties 4: 100%|█████████████| 50914/50914 [00:08<00:00, 5781.08it/s]
Applying properties 5: 100%|█████████████| 50914/50914 [00:08<00:00, 5796.80it/s]
Applying properties 6: 100%|█████████████| 50914/50914 [00:08<00:00, 5783.17it/s]
Applying properties 7: 100%|█████████████| 50914/50914 [00:08<00:00, 5769.86it/s]
Applying properties 8: 100%|█████████████| 50914/50914 [00:08<00:00, 5745.16it/s]
Applying properties 9: 100%|█████████████| 50914/50914 [00:08<00:00, 5749.65it/s]
Applying properties 10: 100%|████████████| 50914/50914 [00:08<00:00, 5750.77it/s]


Unnamed: 0,Protein_Name,Sequence,Node_Mask,x,y,z,A,C,D,E,...,Node_Hydrophobic,Node_Hydrophilic,Node_Polar,Node_Small,Node_Tiny,Node_Aliphatic,Node_Aromatic,Node_Positive,Node_Negative,Node_Charged
0,3LAF_1_A,DRWGSELESSHHHHHHGGRRSLHFVSEPSDAVTMRGGNVLLNCSAE...,0000000000000000001111111111111111111111111111...,"[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.105...","[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 1.0, 0.037...",...,0010001000000000000001011000001101000011101010...,1100010100000000001100000011010000100100010001...,1110110111111111001110100110110010100100010101...,1001100011000000110010001101111110011110011110...,0001100011000000110010000100101000011000000110...,0000001000000000000001001000000100000011100000...,0010000000111111000000110000000000000000000000...,0100000000111111001100100000000000100000000000...,1000010100000000000000000010010000000000000001...,1100010100111111001100100010010000100000000001...


## Functions to compute global edge interactions (between all nodes)

In [14]:
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform
from tqdm import tqdm

def calculate_distances(coords):
    """Calculate the Euclidean distance matrix for a set of coordinates."""
    return squareform(pdist(coords, 'euclidean'))

def calculate_centroid(coordinates):
    """Calculate the centroid from a list of 3D coordinates."""
    return np.mean(coordinates, axis=0)

def calculate_residue_distance_matrix(centroid_list):
    """Calculate the Euclidean distance matrix for a list of centroids."""
    return squareform(pdist(centroid_list, 'euclidean'))

def calculate_sequence_distance(seq_length):
    """Calculate the sequence distance matrix for a protein sequence."""
    indices = np.arange(seq_length)
    seq_dist = np.abs(indices[:, None] - indices[None, :])
    return seq_dist

def calculate_contact_map(dist_matrix, threshold=16.0):
    """Calculate the contact map for a distance matrix given a threshold, ignoring distances of 0."""
    contact_map = (dist_matrix <= threshold).astype(int)
    contact_map[dist_matrix == 0] = 0
    return contact_map

## Functions to compute local edge interactions (between neighboring nodes)

In [15]:
def conservation_interaction_matrix(mask):
    """Generate a matrix where entries are 1 if both residues are conserved, 0 otherwise."""
    n = len(mask)
    conservation_matrix = np.zeros((n, n), dtype=int)
    for i in range(1, n):  # Start from 1 to be able to access i-1
        if mask[i] == '1' and mask[i-1] == '1':
            conservation_matrix[i, i-1] = 1  # Interaction with previous residue
        if i < n - 1 and mask[i] == '1' and mask[i+1] == '1':
            conservation_matrix[i, i+1] = 1  # Interaction with next residue
    return conservation_matrix

def calculate_charge_interactions(sequence):
    """Calculate electrostatic interaction matrix for a given amino acid sequence."""
    n = len(sequence)
    charge_matrix = np.zeros((n, n), dtype=int)
    charge_map = {'D': -1, 'E': -1, 'K': 1, 'R': 1, 'H': 1}  # Example simplified charges

    for i in range(1, n):
        if sequence[i] in charge_map and sequence[i-1] in charge_map:
            if charge_map[sequence[i]] * charge_map[sequence[i-1]] == -1:
                charge_matrix[i, i-1] = 1
        if i < n - 1 and sequence[i] in charge_map and sequence[i+1] in charge_map:
            if charge_map[sequence[i]] * charge_map[sequence[i+1]] == -1:
                charge_matrix[i, i+1] = 1

    return charge_matrix

import numpy as np

def assign_interaction_group(aa):
    """Assign interaction group to amino acids."""
    hydrophobic = {'L', 'I', 'M', 'P', 'V', 'A', 'G', 'F', 'W'}
    hydrophilic = {'D', 'E', 'H', 'K', 'R', 'Q', 'N', 'C', 'T', 'S'}
    if aa in hydrophobic: return -1  # Hydrophobic
    elif aa in hydrophilic: return 1  # Hydrophilic
    return 0  # Neither

def calculate_water_interactions(sequence):
    """Calculate the interaction matrix for immediate neighbors in the sequence."""
    n = len(sequence)
    interaction_scores = np.array([assign_interaction_group(aa) for aa in sequence])
    
    # Create an empty matrix with zeros
    interaction_matrix = np.zeros((n, n), dtype=int)

    # Only set values for immediate neighbors
    for i in range(1, n):
        if interaction_scores[i] == interaction_scores[i-1] and interaction_scores[i] != 0:
            interaction_matrix[i, i-1] = interaction_scores[i]
            interaction_matrix[i-1, i] = interaction_scores[i]  # Ensure symmetry

    return interaction_matrix

## apply all functions

In [14]:
# rescale distances to Angstroms
tqdm.pandas(desc="Calculating x coords")
df_features['x'] = df_features['x'].progress_apply(lambda x: [round(val / 100, 3) for val in x])
tqdm.pandas(desc="Calculating y coords")
df_features['y'] = df_features['y'].progress_apply(lambda x: [round(val / 100, 3) for val in x])
tqdm.pandas(desc="Calculating z coords")
df_features['z'] = df_features['z'].progress_apply(lambda x: [round(val / 100, 3) for val in x])
tqdm.pandas(desc="Calculating zipped coords")
df_features['Atom_Coordinates'] = df_features.progress_apply(lambda row: list(zip(row['x'], row['y'], row['z'])), axis=1)

Calculating x coords: 100%|██████████████| 50914/50914 [00:07<00:00, 6744.29it/s]
Calculating y coords: 100%|██████████████| 50914/50914 [00:15<00:00, 3222.88it/s]
Calculating z coords: 100%|██████████████| 50914/50914 [00:08<00:00, 6233.52it/s]
Calculating zipped coords: 100%|████████| 50914/50914 [00:02<00:00, 22101.76it/s]


In [None]:
# Calculate atom distance matrix and add to dataframe
tqdm.pandas(desc="Calculating 3D atom distance matrices")
df_features['Atom_3D_dists'] = df_features['Atom_Coordinates']\
                                 .progress_apply(lambda coords: calculate_distances(np.array(coords)))

# Calculate centroid for each residue
tqdm.pandas(desc="Calculating residue centroids")
df_features['Node_centroids'] = df_features['Atom_Coordinates']\
        .progress_apply(lambda coords: [calculate_centroid(coords[i:i+3]) for i in range(0, len(coords), 3)])
    
# Calculate distance matrix between centroids
tqdm.pandas(desc="Calculating residue 3D distances")
df_features['Edge_3D_dists_global'] = df_features['Node_centroids']\
        .progress_apply(lambda centroids: calculate_residue_distance_matrix(np.array(centroids)))

# Calculate sequence distance matrix and add to dataframe
tqdm.pandas(desc="Calculating sequence distance matrices")
df_features['Edge_1D_dists_global'] = df_features['Sequence']\
                                    .progress_apply(lambda seq: calculate_sequence_distance(len(seq)))

# Calculate which other atoms are within 16A, which is close enough to make a secondary structure
tqdm.pandas(desc="Calculating atom contact map matrices")
df_features['Atom_proximity'] = df_features['Atom_3D_dists']\
                                  .progress_apply(lambda dist_matrix: calculate_contact_map(dist_matrix))

# Calculate which other residues are within 16A, which is close enough to make a secondary structure
tqdm.pandas(desc="Calculating residue contact map matrices")
df_features['Edge_proximity_global'] = df_features['Edge_3D_dists_global']\
        .progress_apply(lambda dist_matrix: calculate_contact_map(dist_matrix))

# Calcluate which edges are highly conserved
tqdm.pandas(desc="Calculating conservation matrices")
df_features['Edge_conserved_local'] = df_features['Node_Mask']\
                                .progress_apply(lambda mask: conservation_interaction_matrix(mask))

# Calcluate charge interactions
tqdm.pandas(desc="Calculating charge matrices")
df_features['Edge_charge_local'] = df_features['Sequence']\
                                .progress_apply(lambda seq: calculate_charge_interactions(seq))

# Calcluate hydrophobicity interactions
tqdm.pandas(desc="Calculating solubility matrices")
df_features['Edge_solubility_local'] = df_features['Sequence']\
                                .progress_apply(lambda seq: calculate_water_interactions(seq))

Calculating x coords: 100%|██████████████| 25299/25299 [00:03<00:00, 7201.35it/s]
Calculating y coords: 100%|██████████████| 25299/25299 [00:03<00:00, 7205.87it/s]
Calculating z coords: 100%|██████████████| 25299/25299 [00:03<00:00, 7173.25it/s]
Calculating zipped coords: 100%|████████| 25299/25299 [00:00<00:00, 29881.12it/s]
Calculating 3D atom distance matrices:  41%|▍| 10301/25299 [00:36<00:51, 289.01it

### export

In [15]:
import json
# Assuming df is your DataFrame
def protein_to_jsonl(row):
    # Split atom coordinates into groups by atom type: N, CA, C
    n_coords = row['Atom_Coordinates'][0::3]  # Assuming every third starting from 0 is N
    ca_coords = row['Atom_Coordinates'][1::3]  # Assuming every third starting from 1 is CA
    c_coords = row['Atom_Coordinates'][2::3]  # Assuming every third starting from 2 is C

    # Create a dictionary to hold the protein data
    protein_data = {
        "name": row['Protein_Name'],
        "conserved": row['Node_Mask'],
        "num_chains": 1,  # Assuming all entries have one chain; adjust as necessary
        "seq": row['Sequence'],
        "coords": {
            "N": n_coords,
            "CA": ca_coords,
            "C": c_coords
        }
    }
    return json.dumps(protein_data)  # Convert dictionary to JSON string

# Apply the function to each row and save the output as a JSONL file
jsonl_data = df_features.apply(protein_to_jsonl, axis=1)

# Write to a JSONL file
with open('training_95_processed.jsonl', 'w') as outfile:
    for entry in jsonl_data:
        outfile.write(entry + '\n')

In [16]:
# Shuffle the DataFrame
df_shuffled = df_features.sample(frac=1).reset_index(drop=True)

# Calculate split indices
train_end = int(len(df_shuffled) * 0.7)
val_end = train_end + int(len(df_shuffled) * 0.15)

# Split the DataFrame into train, validation, and test
train_proteins = df_shuffled.iloc[:train_end]['Protein_Name'].tolist()
val_proteins = df_shuffled.iloc[train_end:val_end]['Protein_Name'].tolist()
test_proteins = df_shuffled.iloc[val_end:]['Protein_Name'].tolist()

# Create dictionary to be written as JSON
split_dict = {
    'train': train_proteins,
    'validation': val_proteins,
    'test': test_proteins
}

# Write to JSON file
with open('training_95_split.json', 'w') as json_file:
    json.dump(split_dict, json_file, indent=4)

### convert strings to arrays and map amino acids to numbers

In [16]:
def convert_strings_to_char_arrays(df, columns):
    aa_to_int = {'A': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, 'K': 9,
                 'L': 10, 'M': 11, 'N': 12, 'P': 13, 'Q': 14, 'R': 15, 'S': 16, 'T': 17,
                 'V': 18, 'W': 19, 'Y': 20}
    # Apply conversion to the 'Sequence' column specifically
    df['Sequence'] = df['Sequence'].apply(lambda seq: [aa_to_int[aa] for aa in seq])
    
    # Define a vectorized function to split strings into arrays of characters
    vectorized_split = np.vectorize(list)
    # Apply the vectorized function to the specified columns
    for column in columns:
        df[column] = df[column].apply(vectorized_split)
    return df
    
# Specify the columns you want to convert
columns_to_convert = ['Node_Mask', 'Node_Hydrophobic', 'Node_Hydrophilic',
                      'Node_Polar', 'Node_Small', 'Node_Tiny', 'Node_Aliphatic',          
                      'Node_Aromatic', 'Node_Positive', 'Node_Negative', 'Node_Charged']          

# Convert the specified columns
df_features = convert_strings_to_char_arrays(df_features, columns_to_convert)
df_features.drop(columns=['x','y','z'],inplace=True)

## Current features

In [17]:
df_features.iloc[0].T

Protein_Name                                                   90#2WXZ_2_C
Sequence                 [3, 15, 18, 20, 8, 7, 13, 5, 7, 10, 10, 20, 20...
Node_Mask                [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
A                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.012...
C                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.002...
D                        [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
E                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
F                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.016...
G                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
H                        [0.0, 0.0, 0.0, 0.021224489795918365, 0.0, 1.0...
I                        [0.0, 0.0, 0.012965050732807213, 0.0, 0.512043...
K                        [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.011...
L                        [0.0, 0.0, 0.02367531003382187, 0.0, 0.0, 0.0,...
M                        

In [18]:
df_features.columns

Index(['Protein_Name', 'Sequence', 'Node_Mask', 'A', 'C', 'D', 'E', 'F', 'G',
       'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y',
       'Info', 'Node_Hydrophobic', 'Node_Hydrophilic', 'Node_Polar',
       'Node_Small', 'Node_Tiny', 'Node_Aliphatic', 'Node_Aromatic',
       'Node_Positive', 'Node_Negative', 'Node_Charged', 'Atom_Coordinates',
       'Atom_3D_dists', 'Node_centroids', 'Edge_3D_dists_global',
       'Edge_1D_dists_global', 'Atom_proximity', 'Edge_proximity_global',
       'Edge_conserved_local', 'Edge_charge_local', 'Edge_solubility_local'],
      dtype='object')

## model

In [19]:
# !pip install torch torch-geometric ogb

In [20]:
import torch
from torch_geometric.data import Data, DataLoader
from ogb.graphproppred import PygGraphPropPredDataset, Evaluator
from torch_geometric.nn import GCNConv, global_mean_pool
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

In [21]:
df = df_features.copy(deep=True)

In [22]:
def df_to_graph(row):
    # node features (assuming they are already properly scaled and do not include the raw sequence anymore)
    features = np.vstack([ row['Sequence'],
        row['A'], row['C'], row['D'], row['E'], row['F'],
        row['G'], row['H'], row['I'], row['K'], row['L'],
        row['M'], row['N'], row['P'], row['Q'], row['R'],
        row['S'], row['T'], row['V'], row['W'], row['Y'],
        row['Node_Hydrophobic'], row['Node_Hydrophilic'],
        row['Node_Polar'], row['Node_Small'], row['Node_Tiny'],
        row['Node_Aliphatic'], row['Node_Aromatic'], row['Node_Positive'],
        row['Node_Negative'], row['Node_Charged'], row['Info'],
    ]).astype(float).T  # Ensure all data is float

    # Combine sequence encoding with additional features
    node_features = torch.tensor(features, dtype=torch.float)

    # Edge indices for consecutive residues
    edge_index = torch.tensor(np.vstack([np.arange(len(row['Sequence']) - 1), np.arange(1, len(row['Sequence']))]), dtype=torch.long)

    # Edge attributes
    edge_attributes = torch.tensor(np.vstack([
        row['Edge_1D_dists_global'],
        row['Edge_conserved_local'],
        row['Edge_charge_local'],
        row['Edge_solubility_local']
    ]).T, dtype=torch.float)  # Transpose to align features with edges

    # Target (residue centroids)
    target = torch.tensor(row['Node_centroids'], dtype=torch.float)

    return Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr, y=target)

# Apply this function to a DataFrame row by row
# graphs = df.apply(df_to_graph, axis=1)  # This might still cause an error due to inconsistent feature lengths across rows.

In [23]:
# Example of applying this function to your DataFrame
graphs = df.apply(df_to_graph, axis=1)

NameError: name 'edge_attr' is not defined

In [None]:
df.iloc[20]

### plotting functions from features

In [None]:
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns

# def plot_distance_distribution_specific_atoms(dist_matrix, start_index=630, end_index=666):
#     """
#     Plot the distribution of distances between a specific range of atoms in a protein.
    
#     Args:
#     - dist_matrix (numpy.ndarray): The distance matrix of a protein.
#     - start_index (int): Start index of the atom range.
#     - end_index (int): End index of the atom range (inclusive).
#     """
#     # Extract the specific submatrix for the desired range
#     sub_matrix = dist_matrix[start_index:end_index+1, start_index:end_index+1]
    
#     # Flatten the upper triangle of the submatrix, ignoring diagonal (distance 0)
#     distances = sub_matrix[np.triu_indices_from(sub_matrix, k=1)]
    
#     # Plot the distribution of distances
#     plt.figure(figsize=(10, 6))
#     sns.histplot(distances, bins=50, kde=True)
#     plt.title(f'Distribution of Distances Between Atoms {start_index} to {end_index}')
#     plt.xlabel('Distance (Å)')
#     plt.ylabel('Frequency')
#     plt.grid(True)
#     plt.show()

# # Example usage:
# # Assuming 'dist_matrix' is already defined and contains the distances for all atoms
# plot_distance_distribution_specific_atoms(dist_matrix)


In [None]:
# import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns

# def plot_distance_distribution_specific_atoms(dist_matrix, start_index=0, end_index=5000):
#     """
#     Plot the distribution of distances between a specific range of atoms in a protein.
    
#     Args:
#     - dist_matrix (numpy.ndarray): The distance matrix of a protein.
#     - start_index (int): Start index of the atom range.
#     - end_index (int): End index of the atom range (inclusive).
#     """
#     # Extract the specific submatrix for the desired range
#     sub_matrix = dist_matrix[start_index:end_index+1, start_index:end_index+1]
    
#     # Flatten the upper triangle of the submatrix, ignoring diagonal (distance 0)
#     distances = sub_matrix[np.triu_indices_from(sub_matrix, k=1)]
    
#     # Plot the distribution of distances
#     plt.figure(figsize=(4, 3))
#     sns.histplot(distances, bins=2,)
#     plt.title(f'Distribution of Distances Between Atoms {start_index} to {end_index}')
#     plt.xlabel('Distance (Å)')
#     plt.ylabel('Frequency')
#     plt.grid(True)
#     plt.show()

# # Example usage:
# # Assuming 'dist_matrix' is already defined and contains the distances for all atoms
# plot_distance_distribution_specific_atoms(df_features['Edge_contact_map'].iloc[0])


In [None]:
# contact matrix visualizer
import matplotlib.pyplot as plt
import numpy as np

def plot_section_of_array(array, start_row, start_col, max):
    """
    Plot a 20x20 section of a numpy 2D array starting from specified row and column indices.
    
    Args:
    - array (numpy.ndarray): The 2D array to visualize.
    - start_row (int): The starting row index for the 20x20 section.
    - start_col (int): The starting column index for the 20x20 section.
    """
    if array.shape[0] < start_row + max or array.shape[1] < start_col + max:
        raise ValueError("Array does not contain enough rows or columns for a 20x20 section from the starting index.")
    
    # Extract the 20x20 section from the array
    section = array[start_row:start_row + max, start_col:start_col + max]
    
    # Plot the section using imshow
    plt.figure(figsize=(6, 6))
    plt.imshow(section, cmap='viridis', interpolation='nearest')
    plt.colorbar()
    plt.title('20x20 Section of Array')
    plt.show()

# Example usage:
# Assume 'large_array' is a numpy 2D array larger than 20x20
mat = df_features['Edge_proximity_global'].iloc[0]
plot_section_of_array(mat, 0, 0, len(mat))  # Adjust indices based on your array dimensions

In [None]:
# import pandas as pd

# def format_pdb_from_df1(row):
#     pdb_str = ""
#     atom_count = 1  # Starting index for ATOM records
#     chain_id = 'A'  # Assuming a single chain for simplicity
#     res_num = 1     # Residue number
#     occupancy = 1.00
#     t_factor = 50.00

#     # Only keep the alpha carbon (CA)
#     atom_types = ['CA']  # Only backbone alpha carbon atoms
#     element_types = {'CA': 'C'}  # Element for alpha carbon
    
#     residue_mapping = {'A': 'ALA', 'R': 'ARG', 'N': 'ASN', 'D': 'ASP', 'C': 'CYS', 
#                        'E': 'GLU', 'Q': 'GLN', 'G': 'GLY', 'H': 'HIS', 'I': 'ILE', 
#                        'L': 'LEU', 'K': 'LYS', 'M': 'MET', 'F': 'PHE', 'P': 'PRO', 
#                        'S': 'SER', 'T': 'THR', 'W': 'TRP', 'Y': 'TYR', 'V': 'VAL'}

#     for i in range(len(row['Sequence'])):
#         residue = residue_mapping[row['Sequence'][i]]
#         # We know the alpha carbon is the second atom in the list for each residue (1,4,7,...)
#         idx = 3 * i + 1  # Update index to only point to alpha carbon
#         x = row['x'][idx]/100
#         y = row['y'][idx]/100
#         z = row['z'][idx]/100
#         element = element_types['CA']

#         pdb_str += f"ATOM  {atom_count:>5}  {atom_types[0]:<2}  {residue:>3} {chain_id}{res_num:>4}    {x:>8.3f}{y:>8.3f}{z:>8.3f}{occupancy:>6.2f}{t_factor:>6.2f}          {element:>2}  \n"
#         atom_count += 1
#         res_num += 1  # Increment residue number for each new amino acid

#     return pdb_str

# test = format_pdb_from_df1(protein_row)

In [33]:
import json

def load_json_to_dict(filepath):
    """
    Load JSON data from a file into a dictionary.

    Args:
    filepath (str): The path to the JSON file.

    Returns:
    dict: The dictionary containing the JSON data.
    """
    try:
        with open(filepath, 'r') as file:
            data_dict = json.load(file)
        print("JSON data successfully loaded.")
        return data_dict
    except FileNotFoundError:
        print("The file was not found.")
        return None
    except json.JSONDecodeError:
        print("Error decoding JSON from the file.")
        return None
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return None

In [36]:
# Example usage
# Assuming the JSON structure is a list of objects
json_filepath = '../neurips19-graph-protein-design/data/cath/chain_set_splits.json'
sp = load_json_to_dict(json_filepath)
sp.keys()
# df_from_string = json_string_to_df(json_string)
# print("DataFrame from JSON string:")
# print(df_from_string)

JSON data successfully loaded.


dict_keys(['test', 'cath_nodes', 'train', 'validation'])

In [67]:
import json

def remove_key_from_jsonl(input_file_path, output_file_path, key_to_remove):
    """
    Reads a JSONL file, removes a specified key from the 'coords' dictionary within each JSON object, and writes the result to a new JSONL file.

    Args:
    input_file_path (str): Path to the input JSONL file.
    output_file_path (str): Path to the output JSONL file where the modified data will be saved.
    key_to_remove (str): The key to remove from the 'coords' dictionary of each JSON object.
    """
    with open(input_file_path, 'r') as input_file, open(output_file_path, 'w') as output_file:
        for line in input_file:
            data = json.loads(line)  # Convert JSON string to dictionary
            if key_to_remove in data['coords']:
                del data['coords'][key_to_remove]  # Remove the key from the 'coords' dictionary
            output_file.write(json.dumps(data) + '\n')  # Convert dictionary back to JSON string and write to output file


# Usage example
input_file_path = '../neurips19-graph-protein-design/data/cath/chain_set.jsonl'
output_file_path = '../neurips19-graph-protein-design/data/cath/chain_set_updated.jsonl'
remove_key_from_jsonl(input_file_path, output_file_path, 'O')

In [1]:
remove_set = {'2D6F_3_C', '2I7X_1_A', '4UDX_1_X', '3JBX_2_B', '4R4X_1_A', '3NG9_1_A', '3OHN_1_A', '5HCC_1_B', '1TMO_1_A', '1C96_1_A', '2QO3_1_A', '3GZS_1_A', '3EN9_1_A', '4OM9_1_A', '5C6G_1_A', '4DDU_1_A', '4KSI_1_A', '3FIM_1_B', '2D2O_1_A', '5DFZ_5_D', '4OUR_2_B', '3SAM_1_A', '4ZAJ_1_A', '3BGA_1_A', '3G25_1_A', '3WEO_1_A', '4RH3_1_A', '4G26_1_A', '3JZY_1_A', '4K7H_1_A', '2QYU_1_A', '4QNL_1_A', '4QBD_1_A', '4KF7_1_A', '5E5N_1_A', '5B2Q_1_A', '4CT0_1_A', '4GAA_1_A', '2QSF_1_A', '3U4J_1_A', '1KDG_1_A', '4HAW_3_C', '5DZT_1_A', '2AHV_1_A', '4CRM_1_P', '4K0J_1_A', '5AIJ_1_A', '2XFN_1_A', '3L0Q_1_A', '2XFR_1_A', '4BHW_1_A', '2VSE_1_A', '5IXQ_1_A', '4BOT_4_E', '4B2T_8_Z', '4TXI_1_A', '3UA4_1_A', '4U1C_1_C', '3UMV_1_A', '1KL7_1_A', '3IYL_1_A', '2C5G_1_A', '2GKS_1_A', '4WVM_1_A', '3QVP_1_A', '3WPU_1_A', '3FOT_1_A', '5GAS_8_N', '1EDQ_1_A', '4ZCF_2_C', '4OXI_1_A', '5FIR_1_A', '1ON9_1_A', '5EAD_1_A', '3AQP_1_A', '4DWS_4_D', '2QE7_1_A', '5ISU_1_A', '4UAQ_1_A', '3JB9_22_c', '4A82_1_A', '4C3X_1_A', '5A01_1_A', '3UOX_1_A', '3DJL_1_A', '1R6V_1_A', '5BN3_1_A', '4YBQ_1_A', '3J3R_2_A', '4V1T_1_A', '4QYR_1_A', '5EJB_1_B', '1W7P_3_D', '1TR2_1_A', '4ZIU_1_A', '4OK4_1_A', '5FJA_1_A', '3JCN_33_b', '1L9N_1_A', '1ZPU_1_A', '2WSD_1_A', '2DE0_1_X', '4H3S_1_A', '3PUK_1_A', '4EPL_1_A', '3AYJ_1_A', '4Y28_2_B', '5EPC_1_A', '5ELP_1_C', '1QVR_1_A', '1HN0_1_A', '4XMN_2_B', '4CMR_1_A', '3W0L_2_B', '5I6I_1_A', '3WXL_1_A', '2ZXR_1_A', '5C78_1_A', '3QGK_2_C', '4K17_1_A', '2VNU_2_D', '3J3I_1_A', '3FN9_1_A', '4BXZ_13_X', '5AM2_1_A', '2YEV_1_A', '4UPH_1_A', '3J26_1_A', '1GPE_1_A', '4B7O_1_A', '3OKY_2_B', '3BMX_1_A', '3JCR_7_B', '1ITZ_1_A', '2WQD_1_A', '2BVL_1_A', '4JDC_1_A', '4UFC_1_A', '2ZTG_1_A', '3IHG_1_A', '1YGY_1_A', '3IFQ_1_A', '2IUF_1_A', '5BV2_1_P', '4V5T_1_AA', '3CUZ_1_A', '3PUI_1_A', '4H5Y_1_A', '4ZW9_1_A', '1MTY_1_D', '5A1Y_2_C', '1Q57_1_A', '1Y4U_1_A', '4AT2_1_A', '2FOK_1_A', '4R0Z_1_A', '2E6K_1_A', '4TLM_2_B', '3J9V_1_b', '4J2D_1_A', '3AYX_1_A', '4UM2_1_A', '2A91_1_A', '3VIU_1_A', '1PIX_1_A', '3EDY_1_A', '3VA7_1_A', '3WE0_1_A', '5I6H_1_A', '2ZKM_1_X', '4UM8_2_B', '2MYS_1_A', '1H41_1_A', '2Z0F_1_A', '1X1N_1_A', '5DFM_1_A', '3B5X_1_A', '2YHE_1_A', '1TG7_1_A', '2VZS_1_A', '3PVM_2_B', '1Z1N_1_X', '4IV9_1_A', '4J2A_d4j2aa2', '4GUA_1_A', '3V2A_1_R', '3EPO_1_A', '3TCH_1_A', '1OHF_1_A', '2V9K_1_A', '1NTL_1_A', '1D0N_1_A', '3AK5_1_A', '3G4D_1_A', '4D04_1_A', '1IOK_1_A', '2WYO_1_A', '1YRZ_1_A', '2XZO_1_A', '4YDD_1_A', '5E9T_1_A', '2VXO_1_A', '4GZU_1_A', '1YQT_1_A', '4AM6_1_A', '4ZLF_1_A', '4NUF_1_A', '4AIE_1_A', '5A9Q_6_5', '4OO1_10_J', '2FGY_1_A', '2VEQ_1_A', '5B04_5_I', '4JE5_2_D', '4AD8_1_A', '4GY7_1_A', '2O1U_1_A', '1BF2_1_A', '4DXC_1_A', '4RF6_1_A', '4FYT_1_A', '3IVE_1_A', '1U6G_3_C', '3GAW_1_A', '1RH1_1_A', '4APN_1_A', '3IZO_2_F', '1R7A_1_A', '4FWG_1_A', '1D8Y_2_A', '2WVX_2_B', '3FHN_1_A', '5A9Q_4_3', '3X0V_1_A', '4DG6_1_A', '4MH1_1_A', '4FF3_1_A', '3NWA_1_B', '1JW1_1_A', '4IGL_2_B', '3LPP_1_A', '4QQW_1_A', '3BVU_1_A', '3BB0_1_A', '5A2R_1_A', '4WJ4_1_A', '4YPJ_1_A', '4R8A_1_A', '2XQ0_1_A', '4PQG_1_A', '1S58_1_A', '2DKH_1_A', '1W0P_1_A', '5D51_1_L', '1A4S_1_A', '3WOO_1_A', '1IV8_1_A', '4FUS_1_A', '3JAI_87_jj', '3D45_1_A', '3IZY_2_P', '3FHH_1_A', '4RJK_1_A', '4G9I_1_A', '2Z5U_1_A', '4OR0_1_A', '1UXV_1_A', '3MMP_2_G', '4FDD_1_A', '3WA2_1_X', '4FJQ_1_A', '5FL8_43_s', '3RCN_1_A', '4Z64_1_A', '2H1N_1_A', '5D0O_1_A', '2CXN_1_A', '3ILV_1_A', '3BQ6_1_A', '1YIQ_1_A', '3RA4_1_A', '1S0E_d1s0ea3', '2CKW_1_A', '2XQF_1_A', '4RGW_1_A', '4RPE_1_A', '4AP3_1_A', '4E57_1_A', '5HDT_1_A', '5DHK_1_A', '3CMX_3_A', '5I0N_1_A', '5CPS_1_A', '3EFM_1_A', '1V02_2_E', '4CJ0_1_A', '4QAW_1_A', '4QIW_2_B', '3ZUS_1_A', '1EA9_1_C', '1MW9_1_X', '3FY4_1_A', '3HX6_1_A', '4XHC_1_A', '1Q5C_1_A', '4AEF_1_A', '2RDY_1_A', '5CLW_1_A', '4RBN_1_A', '1E07_1_A', '4F4C_1_A', '4B2T_7_Q', '4I3G_1_A', '4WWX_1_B', '3GF3_1_A', '4WHJ_1_A', '4AUR_1_A', '2GP4_1_A', '4B9Q_1_A', '2NVO_1_A', '3VQT_1_A', '3U7Q_2_B', '2XE4_1_A', '1W07_1_A', '4HEA_13_L', '5FP2_1_A', '3OPB_1_A', '4BET_1_A', '2JGP_1_A', '4QFL_1_A', '5HB3_1_A', '4XWM_1_A', '1CHU_1_A', '3TTY_1_A', '2FHF_1_A', '4AUM_1_A', '3I4G_1_A', '5GAM_3_C', '4B2T_1_A', '4K6M_1_A', '3J9D_1_A', '3L4G_2_B', '3QYA_1_A', '2WTB_1_A', '4O5H_1_A', '2NVU_2_B', '2GJ4_1_A', '4PDX_d4pdxb1', '3ZXK_1_A', '1QI9_1_A', '1N62_2_B', '4PJ1_1_A', '3NAF_1_A', '3JSZ_1_A', '1EWQ_3_A', '3N2O_1_A', '1X9D_1_A', '1OAO_1_A', '4PF1_1_A', '2XWU_2_B', '1YQ2_1_A', '2WU2_1_A', '1DMS_1_A', '3ODU_1_A', '3HRZ_4_D', '2PFD_1_A', '4YD9_1_A', '4B3J_1_A', '3HWC_1_A', '2C42_1_A', '4BDV_1_A', '1S49_1_A', '5BP1_1_A', '4M5D_1_A', '2DU3_2_A', '2VQI_1_A', '5EY9_1_A', '5IRE_1_A', '1WDK_1_A', '1QO8_1_A', '5F1Q_1_A', '3J92_44_u', '2Q7Z_1_A', '3RDE_1_A', '1W7C_1_A', '4YE6_1_A', '1QH8_2_B', '2BX2_1_L', '4C00_1_A', '2OBE_1_A', '2EAB_1_A', '2WQ7_1_A', '3MT5_1_A', '4CWU_3_O', '1XDP_1_A', '3QA8_1_A', '1I7D_2_A', '3KIC_d3kict-', '4KRE_1_A', '5FGU_1_A', '4BQE_1_A', '5CXF_1_A', '1N7V_1_A', '2Z5L_1_A', '4FNM_1_A', '1G0D_1_A', '3C5E_1_A', '5IL7_1_A', '3N6R_1_A', '4B63_1_A', '4NBQ_1_A', '3UJZ_1_A', '4MC5_1_A', '4YFB_3_C', '4B7H_1_A', '3A1K_1_A', '3ZUK_1_A', '2FSH_1_A', '3MKQ_1_A', '3SDR_1_A', '3QDK_1_A', '2BPO_1_A', '1G8M_1_A', '3ITE_1_A', '2XN2_1_A', '3V83_1_A', '4RUL_1_A', '4RHB_1_A', '4EPA_1_A', '3H4Z_1_A', '3NSJ_1_A', '1RQB_1_A', '5FSG_1_A', '3OSR_1_A', '3BF0_1_A', '1SY7_1_A', '3UFB_1_A', '2WEU_1_A', '4MH8_1_A', '3LPF_1_A', '1FLG_1_A', '4QTO_1_A', '4LC9_1_A', '5AB0_1_A', '1FCB_1_A', '3MMP_1_A', '3IXW_1_A', '5DNC_1_A', '1ULV_1_A', '1RT8_1_A', '3KK7_1_A', '1M1Z_1_A', '4FL3_1_A', '4YE5_1_A', '1K4Y_1_A', '2CN3_1_A', '3P4P_1_A', '5AN8_1_A', '4RWF_1_A', '4OUA_2_B', '4KVL_1_A', '4QIW_1_A', '2W54_2_B', '5IKL_1_B', '5HAX_1_A', '4UHV_1_A', '3F6K_1_A', '1LLA_1_A', '5EAN_1_A', '5IJO_7_J', '5DFZ_2_C', '1K1X_1_A', '4QME_1_A', '2NLK_1_A', '1PGU_1_A', '2I0K_1_A', '4BY6_1_A'}

In [2]:
import json

# Path to your JSON file
json_file_path = '/Users/daniel/Dropbox/USF/generative/neurips19-graph-protein-design/data/cath/chain_set_splits.json'

# Read the JSON data from the file
with open(json_file_path, 'r') as file:
    data = json.load(file)

# Filter out the entries in the remove_set from each key
for key in ['train', 'test', 'validation']:
    data[key] = [item for item in data[key] if item not in remove_set]

# Write the modified data back to the file
with open(json_file_path, 'w') as file:
    json.dump(data, file, indent=4)

print("Updated JSON file saved.")

Updated JSON file saved.
