This notebook will measure the OpenMM energies of both the crystal and minimized structures
of the validation and test sets.

Because there are fewer minimized proteins, we will start with those protein identities 
and move from there. The notebook should create a CSV that reports the protein id,
evaluation set name, minimized energy, and unminimized energy.

In [1]:
from glob import glob
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

from sidechainnet import SCNProtein

%cd ~/openfold

/net/pulsar/home/koes/jok120/openfold


# Data paths

In [2]:
data = {
    'test': {
        'unmin': {
            "dirpath": 'data/test/cameo/20230103/data_dir',
            "scnproteins": []
        },
        'min': {
            "dirpath": 'data/test/cameo/20230103/minimized/data_dir',
            "scnproteins": []
        }
    },
    'val': {
        'unmin': {
            "dirpath": 'data/validation/cameo/20220116/data_dir',
            "scnproteins": []
        },
        'min': {
            "dirpath": 'data/validation/cameo/20220116/minimized/data_dir',
            "scnproteins": []
        }
    },
}


# Functions

In [2]:
def get_scnproteins_from_paths(paths, chids=None):
    """Use SCNProtein.from_{pdb,cif} to get list of SCNProtein objs from list of paths."""

    iterator = tqdm(paths) if chids is None else tqdm(zip(paths, chids), total=len(paths))

    scnproteins = []
    for path_chid in iterator:
        if chids is None:
            path = path_chid
            pdbid_chain = os.path.basename(path).split('.')[0]
            pdbid, chain_id = pdbid_chain.split('_')
        else:
            path, chain_id = path_chid
            pdbid = os.path.basename(path).split('.')[0]
        
        if path.endswith('.pdb'):
            scnproteins.append(SCNProtein.from_pdb(path, pdbid=pdbid_chain))
        elif path.endswith('.cif'):
            scnproteins.append(SCNProtein.from_cif(path, chid=chain_id, pdbid=f"{pdbid}_{chain_id}"))
        else:
            raise ValueError(f'Path {path} does not end with .pdb or .cif')
    return scnproteins


In [3]:
def get_protein_paths_from_dir(data_dir):
    """Return list of protein files (.cif or .pdb) from data_dir."""
    return glob(os.path.join(data_dir, '*'))

In [4]:
def get_energy_data(scnprotein):
    results = {}
    results['openmm_energy'] = scnprotein.get_energy(add_hydrogens_via_openmm=True,
                                                     return_unitless_kjmol=True,
                                                     add_missing=True,
                                                     add_hydrogens_via_scnprotein=False)
    scnprotein.reset_hydrogens_and_openmm()
    results['scn_energy'] = scnprotein.get_energy(add_hydrogens_via_openmm=False,
                                                  return_unitless_kjmol=True,
                                                  add_missing=False,
                                                  add_hydrogens_via_scnprotein=True)
    return results

# Analysis and Save DataFrame

In [6]:
# Load the minimized proteins first
for dataset, dataset_dict in data.items():
    data_dir = dataset_dict['min']['dirpath']
    protein_paths = get_protein_paths_from_dir(data_dir)
    scnproteins = get_scnproteins_from_paths(protein_paths)
    dataset_dict['min']['scnproteins'] = scnproteins
    print(f"Loaded {len(scnproteins)} proteins from {data_dir}")

100%|██████████| 98/98 [00:43<00:00,  2.27it/s]


Loaded 98 proteins from data/test/cameo/20230103/minimized/data_dir


100%|██████████| 22/22 [00:12<00:00,  1.73it/s]

Loaded 22 proteins from data/validation/cameo/20220116/minimized/data_dir





In [7]:
# Now, observe the minimized protein pdbids and load the cooresponding unminimized proteins
for dataset, dataset_dict in data.items():
    min_pdbids_with_chains = [p.id for p in dataset_dict['min']['scnproteins']]
    min_pdbids_with_chains = {p.split('_')[0]: p.split('_')[1] for p in min_pdbids_with_chains}
    data_dir = dataset_dict['unmin']['dirpath']
    raw_protein_paths = get_protein_paths_from_dir(data_dir)
    filtered_protein_paths = []
    filtered_chids = []
    for rpp in raw_protein_paths:
        pdbid = os.path.basename(rpp).split('.')[0]
        if pdbid in min_pdbids_with_chains:
            filtered_protein_paths.append(rpp)
            filtered_chids.append(min_pdbids_with_chains[pdbid])


    scnproteins = get_scnproteins_from_paths(filtered_protein_paths, filtered_chids)
    dataset_dict['unmin']['scnproteins'] = scnproteins
    print(f"Loaded {len(scnproteins)} proteins from {data_dir}")

100%|██████████| 97/97 [01:20<00:00,  1.20it/s]


Loaded 97 proteins from data/test/cameo/20230103/data_dir


100%|██████████| 22/22 [00:20<00:00,  1.06it/s]

Loaded 22 proteins from data/validation/cameo/20220116/data_dir





In [8]:
# Get the energy data for each protein and save it to a dataframe
for dataset, dataset_dict in data.items():
    for min_or_unmin, min_or_unmin_dict in dataset_dict.items():
        print(dataset, min_or_unmin)
        scnproteins = min_or_unmin_dict['scnproteins']
        energy_data = []
        for scnprotein in tqdm(scnproteins):
            energy_data.append(get_energy_data(scnprotein))
        min_or_unmin_dict['energy_data'] = energy_data
        print(f"Got energy data for {len(energy_data)} proteins in {dataset} {min_or_unmin}")


test unmin


100%|██████████| 97/97 [10:40<00:00,  6.61s/it]


Got energy data for 97 proteins in test unmin
test min


100%|██████████| 98/98 [08:01<00:00,  4.91s/it]


Got energy data for 98 proteins in test min
val unmin


100%|██████████| 22/22 [02:04<00:00,  5.67s/it]


Got energy data for 22 proteins in val unmin
val min


100%|██████████| 22/22 [01:58<00:00,  5.41s/it]

Got energy data for 22 proteins in val min





In [9]:

# Save the data to a single unified dataframe for all proteins, with the columns being:
# pdbid, chain_id, dataset, min_or_unmin, openmm_energy, scn_energy

df = pd.DataFrame(columns=['pdbid', 'chain_id', 'dataset', 'min_or_unmin', 'openmm_energy', 'scn_energy'])
for dataset, dataset_dict in data.items():
    for min_or_unmin, min_or_unmin_dict in dataset_dict.items():
        energy_data = min_or_unmin_dict['energy_data']
        for i, ed in enumerate(energy_data):
            df = df.append({
                'pdbid': min_or_unmin_dict['scnproteins'][i].id.split('_')[0],
                'chain_id': min_or_unmin_dict['scnproteins'][i].id.split('_')[1],
                'dataset': dataset,
                'min_or_unmin': min_or_unmin,
                'openmm_energy': ed['openmm_energy'],
                'scn_energy': ed['scn_energy']
            }, ignore_index=True)


In [12]:
df['pdbid_chainid'] = df['pdbid'] + '_' + df['chain_id']
df

Unnamed: 0,pdbid,chain_id,dataset,min_or_unmin,openmm_energy,scn_energy,pdbid_chainid
0,7sci,A,test,unmin,-99236.965301,-69827.339318,7sci_A
1,7wgk,A,test,unmin,-45585.986728,-32588.705792,7wgk_A
2,7wj0,A,test,unmin,-11886.696006,-8719.855140,7wj0_A
3,7pvm,A,test,unmin,-6221.771473,-4408.863527,7pvm_A
4,7ead,A,test,unmin,-21121.873942,-13155.142319,7ead_A
...,...,...,...,...,...,...,...
234,7dfe,A,val,min,-19783.327374,-12876.558158,7dfe_A
235,7v5y,B,val,min,-20284.087039,-15776.160864,7v5y_B
236,7ee3,C,val,min,-22390.912999,-15917.149032,7ee3_C
237,7wgk,A,val,min,-45021.939726,-32776.283638,7wgk_A


## Identify duplicates in test set

In [19]:
# find rows where pdbid_chainid is duplicated
# df[df.duplicated(subset=['pdbid_chainid'], keep=False)]
df[df.duplicated(subset=['pdbid_chainid', 'min_or_unmin'], keep=False) & (df['min_or_unmin'] == 'min')]

Unnamed: 0,pdbid,chain_id,dataset,min_or_unmin,openmm_energy,scn_energy,pdbid_chainid
105,7mwr,A,test,min,-48092.537243,-37155.064159,7mwr_A
110,7puo,A,test,min,-25422.463357,-17635.847585,7puo_A
128,7l8n,A,test,min,-25161.063715,-17164.252443,7l8n_A
154,7mcc,A,test,min,-34384.561107,-23019.641517,7mcc_A
183,7wgk,A,test,min,-45002.697783,-32768.836113,7wgk_A
217,7mwr,A,val,min,-47465.000557,-37139.103229,7mwr_A
219,7puo,A,val,min,-25401.226364,-17633.071464,7puo_A
222,7l8n,A,val,min,-25142.619541,-17168.849909,7l8n_A
229,7mcc,A,val,min,-34378.33359,-23030.312678,7mcc_A
237,7wgk,A,val,min,-45021.939726,-32776.283638,7wgk_A


In [20]:
# Save the dataframe to a csv in the same directory as this notebook
%cd ~/openfold/jk_research/evaluations/230420
df.to_csv('230428_evaluationset_energies.csv', index=False)

/net/pulsar/home/koes/jok120/openfold/jk_research/evaluations/230420


In [23]:
# Move proteins that are duplicated in the test set from their directories current to a
# new subdirectory 'duplicates'
%cd ~/openfold

duplicate_test_proteins = df[df.duplicated(subset=['pdbid_chainid', 'min_or_unmin'], keep=False) & (df['min_or_unmin'] == 'min')]['pdbid_chainid'].unique()
os.makedirs(f"data/test/cameo/20230103/minimized/data_dir/duplicates", exist_ok=True)
os.makedirs(f"data/test/cameo/20230103/data_dir/duplicates", exist_ok=True)
os.makedirs(f"data/test/cameo/20230103/minimized/alignments/duplicates", exist_ok=True)
os.makedirs(f"data/test/cameo/20230103/alignments/duplicates", exist_ok=True)

for dtp in duplicate_test_proteins:
    # move structs
    os.rename(f"data/test/cameo/20230103/minimized/data_dir/{dtp}.pdb", f"data/test/cameo/20230103/minimized/data_dir/duplicates/{dtp}.pdb")
    os.rename(f"data/test/cameo/20230103/data_dir/{dtp.split('_')[0]}.cif", f"data/test/cameo/20230103/data_dir/duplicates/{dtp.split('_')[0]}.cif")
    # move alignments
    os.rename(f"data/test/cameo/20230103/minimized/alignments/{dtp}", f"data/test/cameo/20230103/minimized/alignments/duplicates/{dtp}")
    os.rename(f"data/test/cameo/20230103/alignments/{dtp}", f"data/test/cameo/20230103/alignments/duplicates/{dtp}")

/net/pulsar/home/koes/jok120/openfold


# Load dataframe

In [5]:
# Load the above csv
df = pd.read_csv('~/openfold/jk_research/evaluations/230420/230428_evaluationset_energies.csv')
df

Unnamed: 0,pdbid,chain_id,dataset,min_or_unmin,openmm_energy,scn_energy,pdbid_chainid
0,7sci,A,test,unmin,-99236.965301,-69827.339318,7sci_A
1,7wgk,A,test,unmin,-45585.986728,-32588.705792,7wgk_A
2,7wj0,A,test,unmin,-11886.696006,-8719.855140,7wj0_A
3,7pvm,A,test,unmin,-6221.771473,-4408.863527,7pvm_A
4,7ead,A,test,unmin,-21121.873942,-13155.142319,7ead_A
...,...,...,...,...,...,...,...
234,7dfe,A,val,min,-19783.327374,-12876.558158,7dfe_A
235,7v5y,B,val,min,-20284.087039,-15776.160864,7v5y_B
236,7ee3,C,val,min,-22390.912999,-15917.149032,7ee3_C
237,7wgk,A,val,min,-45021.939726,-32776.283638,7wgk_A


In [7]:
df[df['pdbid'] == '7wgk']

Unnamed: 0,pdbid,chain_id,dataset,min_or_unmin,openmm_energy,scn_energy,pdbid_chainid
1,7wgk,A,test,unmin,-45585.986728,-32588.705792,7wgk_A
183,7wgk,A,test,min,-45002.697783,-32768.836113,7wgk_A
195,7wgk,A,val,unmin,-45585.986094,-32588.705975,7wgk_A
237,7wgk,A,val,min,-45021.939726,-32776.283638,7wgk_A
