This notebook will measure the OpenMM energies of both the crystal and minimized structures
of the validation and test sets.

Because there are fewer minimized proteins, we will start with those protein identities 
and move from there. The notebook should create a CSV that reports the protein id,
evaluation set name, minimized energy, and unminimized energy.

In [1]:
from glob import glob
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

from sidechainnet import SCNProtein

%cd ~/openfold

/net/pulsar/home/koes/jok120/openfold


# Data paths

In [2]:
data = {
    'test': {
        'unmin': {
            "dirpath": 'data/test/cameo/20230103/data_dir',
            "scnproteins": []
        },
        'min': {
            "dirpath": 'data/test/cameo/20230103/minimized/data_dir',
            "scnproteins": []
        }
    },
    'val': {
        'unmin': {
            "dirpath": 'data/validation/cameo/20220116/data_dir',
            "scnproteins": []
        },
        'min': {
            "dirpath": 'data/validation/cameo/20220116/minimized/data_dir',
            "scnproteins": []
        }
    },
}


# Functions

In [3]:
def get_scnproteins_from_paths(paths, chids=None):
    """Use SCNProtein.from_{pdb,cif} to get list of SCNProtein objs from list of paths."""

    iterator = tqdm(paths) if chids is None else tqdm(zip(paths, chids), total=len(paths))

    scnproteins = []
    for path_chid in iterator:
        if chids is None:
            path = path_chid
            pdbid_chain = os.path.basename(path).split('.')[0]
            pdbid, chain_id = pdbid_chain.split('_')
        else:
            path, chain_id = path_chid
            pdbid = os.path.basename(path).split('.')[0]
        
        if path.endswith('.pdb'):
            scnproteins.append(SCNProtein.from_pdb(path, pdbid=pdbid_chain))
        elif path.endswith('.cif'):
            scnproteins.append(SCNProtein.from_cif(path, chid=chain_id, pdbid=f"{pdbid}_{chain_id}"))
        else:
            raise ValueError(f'Path {path} does not end with .pdb or .cif')
    return scnproteins


In [4]:
def get_protein_paths_from_dir(data_dir):
    """Return list of protein files (.cif or .pdb) from data_dir."""
    return glob(os.path.join(data_dir, '*'))

In [89]:
def do_fastbuild_and_get_energy(scnprotein):
    scnprotein.reset_hydrogens_and_openmm()

    results = {}
    # This is the energy when we build the hydrogens with fastbuild
    scnprotein.fastbuild(add_hydrogens=True, inplace=True)
    results['fastbuild_energy'] = scnprotein.get_energy(add_hydrogens_via_openmm=False,
                                                     return_unitless_kjmol=True,
                                                     add_missing=False,
                                                     add_hydrogens_via_scnprotein=False)
    # This is the energy when we build the hydrogens with scnprotein/hydrogenBuilder
    scnprotein.reset_hydrogens_and_openmm()
    results['scnprotein_energy'] = scnprotein.get_energy(add_hydrogens_via_openmm=False,
                                                     return_unitless_kjmol=True,
                                                     add_missing=False,
                                                     add_hydrogens_via_scnprotein=True)
    # This is the energy when we build the hydrogens with openmm
    scnprotein.reset_hydrogens_and_openmm()
    results['openmm_pure_energy'] = scnprotein.get_energy(add_hydrogens_via_openmm=True,
                                                     return_unitless_kjmol=True,
                                                     add_missing=False,
                                                     add_hydrogens_via_scnprotein=False)
    return results

In [5]:
def get_energy_data(scnprotein, add_h=True):
    results = {}
    results['openmm_energy'] = scnprotein.get_energy(add_hydrogens_via_openmm=add_h,
                                                     return_unitless_kjmol=True,
                                                     add_missing=add_h,
                                                     add_hydrogens_via_scnprotein=False)
    scnprotein.reset_hydrogens_and_openmm()
    if not add_h:
        scnprotein.fastbuild(add_hydrogens=True, inplace=True)
    results['scn_energy'] = scnprotein.get_energy(add_hydrogens_via_openmm=False,
                                                  return_unitless_kjmol=True,
                                                  add_missing=False,
                                                  add_hydrogens_via_scnprotein=add_h)
    return results

In [7]:
def get_energy_from_path(path, do_fastbuild_first=False):
    paths = [path]
    if path.endswith('.pkl'):
        scnprotein=SCNProtein.from_pkl(path)
    else:
        scnprotein = get_scnproteins_from_paths(paths)[0]
    if do_fastbuild_first:
        scnprotein.fastbuild(add_hydrogens=True, inplace=True)
        return get_energy_data(scnprotein, add_h=False)
    else:
        return get_energy_data(scnprotein, add_h=True)

# Looking at a problematic protein

## We see some unexpected diffs

1. If we do fastbuild before computing energies, this means we have a consistent way of adding hydrogens.
   - This explains why openmm_energy and scn_energy are the same when we do fastbuild first.
2. If we do use openmm or scnprotein's add hydrogens methods, we get different results.

In [8]:
get_energy_from_path("/net/pulsar/home/koes/jok120/scnmin_evaltest230412/unmin/7eym_A.pkl", do_fastbuild_first=False)

{'openmm_energy': -16069.593476920754, 'scn_energy': -12463.914517097342}

In [9]:
get_energy_from_path("/net/pulsar/home/koes/jok120/scnmin_evaltest230412/unmin/7eym_A.pkl", do_fastbuild_first=True)

{'openmm_energy': -14523.598900653416, 'scn_energy': -14523.59839740905}

In [10]:
get_energy_from_path("/net/pulsar/home/koes/jok120/scnmin_evaltest230412/min/7eym_A.pkl", do_fastbuild_first=False)

{'openmm_energy': -15993.200105323785, 'scn_energy': -11067.809308101103}

In [11]:
get_energy_from_path("/net/pulsar/home/koes/jok120/scnmin_evaltest230412/min/7eym_A.pkl", do_fastbuild_first=True)

{'openmm_energy': -15395.8611104439, 'scn_energy': -15395.861069793598}

### Some Questions

1. Proteins should have been minimized by optimizing the internal angles and rebuilding with fastbuild. Why then is the energy different when we add hydrogens via scnprotein?
    - perhaps this is because the hydrogens are not added the same way as via fastbuild.

Let's confirm that if we construct the all-atom-with-H coordinate set via fastbuild,
that the constructed heavy-atom coordinates are identical to the existing heavy atom coordinates.

The below analysis confirms that the difference in energy is due to the difference with
adding hydrogens via fastbuild vs. scnprotein. The protein is optimized for the fastbuild
method, but we are adding hydrogens via scnprotein, which is slightly less optimal.

In [14]:
def tensors_equal_with_nans(tensor1, tensor2, exact=False, rtol=1e-05, atol=1e-08):
    import torch
    # Check if both tensors have the same shape
    if tensor1.shape != tensor2.shape:
        return False

    # Create boolean masks for NaNs in both tensors
    nan_mask1 = torch.isnan(tensor1)
    nan_mask2 = torch.isnan(tensor2)

    # Check if NaN masks are equal
    nan_masks_equal = torch.all(nan_mask1 == nan_mask2)

    # Replace NaNs with a placeholder value in both tensors
    placeholder = torch.zeros_like(tensor1)
    tensor1_no_nans = torch.where(nan_mask1, placeholder, tensor1)
    tensor2_no_nans = torch.where(nan_mask2, placeholder, tensor2)

    # Check if tensors without NaNs are element-wise equal within a tolerance
    if not exact:
        tensors_close = torch.allclose(tensor1_no_nans, tensor2_no_nans, rtol=rtol, atol=atol)
    else:
        tensors_close = torch.all(tensor1_no_nans == tensor2_no_nans)

    return nan_masks_equal and tensors_close

In [54]:
p = SCNProtein.from_pkl("/net/pulsar/home/koes/jok120/scnmin_evaltest230412/min/7eym_A.pkl")

In [55]:
b = p.copy()

In [56]:
b.get_energy(add_hydrogens_via_scnprotein=True)

Quantity(value=-11067.809421062582, unit=kilojoule/mole)

In [57]:
b.coords = b.hcoords

In [58]:
built_hcoords = p.fastbuild(add_hydrogens=True, inplace=True)

In [59]:
p.sb = None
p.has_hydrogens

True

In [60]:
p.to_3Dmol(other_protein=b)

<py3Dmol.view at 0x7f2114790460>

In [66]:
p.reset_hydrogens_and_openmm()
built_hcoords = p.fastbuild(add_hydrogens=True, inplace=False)
built_coords = p.hydrogenrep_to_heavyatomrep(built_hcoords, inplace=False)

In [67]:
built_coords.shape,p.coords.shape

(torch.Size([96, 15, 3]), torch.Size([96, 15, 3]))

In [68]:
tensors_equal_with_nans(p.coords, built_coords, exact=True)

tensor(True)

# Analysis and Save DataFrame

In [69]:
# Load the minimized proteins first
for dataset, dataset_dict in data.items():
    data_dir = dataset_dict['min']['dirpath']
    protein_paths = get_protein_paths_from_dir(data_dir)
    scnproteins = get_scnproteins_from_paths(protein_paths)
    dataset_dict['min']['scnproteins'] = scnproteins
    print(f"Loaded {len(scnproteins)} proteins from {data_dir}")

100%|██████████| 93/93 [00:38<00:00,  2.43it/s]


Loaded 93 proteins from data/test/cameo/20230103/minimized/data_dir


100%|██████████| 22/22 [00:12<00:00,  1.82it/s]

Loaded 22 proteins from data/validation/cameo/20220116/minimized/data_dir





In [70]:
# Now, observe the minimized protein pdbids and load the cooresponding unminimized proteins
for dataset, dataset_dict in data.items():
    min_pdbids_with_chains = [p.id for p in dataset_dict['min']['scnproteins']]
    min_pdbids_with_chains = {p.split('_')[0]: p.split('_')[1] for p in min_pdbids_with_chains}
    data_dir = dataset_dict['unmin']['dirpath']
    raw_protein_paths = get_protein_paths_from_dir(data_dir)
    filtered_protein_paths = []
    filtered_chids = []
    for rpp in raw_protein_paths:
        pdbid = os.path.basename(rpp).split('.')[0]
        if pdbid in min_pdbids_with_chains:
            filtered_protein_paths.append(rpp)
            filtered_chids.append(min_pdbids_with_chains[pdbid])


    scnproteins = get_scnproteins_from_paths(filtered_protein_paths, filtered_chids)
    dataset_dict['unmin']['scnproteins'] = scnproteins
    print(f"Loaded {len(scnproteins)} proteins from {data_dir}")

100%|██████████| 92/92 [01:15<00:00,  1.23it/s]


Loaded 92 proteins from data/test/cameo/20230103/data_dir


100%|██████████| 22/22 [00:20<00:00,  1.09it/s]

Loaded 22 proteins from data/validation/cameo/20220116/data_dir





In [94]:
# Get the energy data for each protein and save it to a dataframe
for dataset, dataset_dict in data.items():
    for min_or_unmin, min_or_unmin_dict in dataset_dict.items():
        print(dataset, min_or_unmin)
        scnproteins = min_or_unmin_dict['scnproteins']
        energy_data = []
        for scnprotein in tqdm(scnproteins):
            # energy_data.append(get_energy_data(scnprotein))
            energy_data.append(do_fastbuild_and_get_energy(scnprotein))
        min_or_unmin_dict['energy_data'] = energy_data
        print(f"Got energy data for {len(energy_data)} proteins in {dataset} {min_or_unmin}")


test unmin


100%|██████████| 92/92 [07:05<00:00,  4.62s/it]


Got energy data for 92 proteins in test unmin
test min


100%|██████████| 93/93 [04:32<00:00,  2.93s/it]


Got energy data for 93 proteins in test min
val unmin


100%|██████████| 22/22 [01:54<00:00,  5.22s/it]


Got energy data for 22 proteins in val unmin
val min


100%|██████████| 22/22 [01:24<00:00,  3.83s/it]

Got energy data for 22 proteins in val min





In [95]:
# Save the data to a single unified dataframe for all proteins, with the columns being:
# pdbid, chain_id, dataset, min_or_unmin, openmm_energy, scn_energy

df = pd.DataFrame(columns=['pdbid', 'chain_id', 'dataset', 'min_or_unmin', 'fastbuild_energy', 'scnprotein_energy', 'openmm_pure_energy'])
for dataset, dataset_dict in data.items():
    for min_or_unmin, min_or_unmin_dict in dataset_dict.items():
        energy_data = min_or_unmin_dict['energy_data']
        for i, ed in enumerate(energy_data):
            df = df.append({
                'pdbid': min_or_unmin_dict['scnproteins'][i].id.split('_')[0],
                'chain_id': min_or_unmin_dict['scnproteins'][i].id.split('_')[1],
                'dataset': dataset,
                'min_or_unmin': min_or_unmin,
                'openmm_pure_energy': ed['openmm_pure_energy'],
                'scnprotein_energy': ed['scnprotein_energy'],
                'fastbuild_energy': ed['fastbuild_energy'],
                
            }, ignore_index=True)


In [96]:
df['pdbid_chainid'] = df['pdbid'] + '_' + df['chain_id']
df

Unnamed: 0,pdbid,chain_id,dataset,min_or_unmin,fastbuild_energy,scnprotein_energy,openmm_pure_energy,pdbid_chainid
0,7sci,A,test,unmin,63228.210382,147600.522266,-40380.694061,7sci_A
1,7wj0,A,test,unmin,-8989.842754,-3221.577137,-10765.592162,7wj0_A
2,7pvm,A,test,unmin,-6043.046692,-3805.779034,-6102.468049,7pvm_A
3,7ead,A,test,unmin,-11872.394795,-5592.074002,-17872.597238,7ead_A
4,7w5c,A,test,unmin,-22635.333248,10897.647461,-67183.992030,7w5c_A
...,...,...,...,...,...,...,...,...
224,7dfe,A,val,min,-18848.357209,-12879.166634,-19736.704250,7dfe_A
225,7v5y,B,val,min,-19921.668867,-15771.962638,-20229.181690,7v5y_B
226,7ee3,C,val,min,-21685.338132,-15908.619913,-22364.761034,7ee3_C
227,7wgk,A,val,min,-44245.370169,-32777.791127,-44964.571463,7wgk_A


## Identify duplicates in test set

In [97]:
# find rows where pdbid_chainid is duplicated
# df[df.duplicated(subset=['pdbid_chainid'], keep=False)]
df[df.duplicated(subset=['pdbid_chainid', 'min_or_unmin'], keep=False) & (df['min_or_unmin'] == 'min')]

Unnamed: 0,pdbid,chain_id,dataset,min_or_unmin,fastbuild_energy,scnprotein_energy,openmm_pure_energy,pdbid_chainid


In [98]:
# Save the dataframe to a csv in the same directory as this notebook
%cd ~/openfold/jk_research/evaluations/230420
df.to_csv('230428_evaluationset_energies5.csv', index=False)

/net/pulsar/home/koes/jok120/openfold/jk_research/evaluations/230420


In [23]:
# Move proteins that are duplicated in the test set from their directories current to a
# new subdirectory 'duplicates'
%cd ~/openfold

duplicate_test_proteins = df[df.duplicated(subset=['pdbid_chainid', 'min_or_unmin'], keep=False) & (df['min_or_unmin'] == 'min')]['pdbid_chainid'].unique()
os.makedirs(f"data/test/cameo/20230103/minimized/data_dir/duplicates", exist_ok=True)
os.makedirs(f"data/test/cameo/20230103/data_dir/duplicates", exist_ok=True)
os.makedirs(f"data/test/cameo/20230103/minimized/alignments/duplicates", exist_ok=True)
os.makedirs(f"data/test/cameo/20230103/alignments/duplicates", exist_ok=True)

for dtp in duplicate_test_proteins:
    # move structs
    os.rename(f"data/test/cameo/20230103/minimized/data_dir/{dtp}.pdb", f"data/test/cameo/20230103/minimized/data_dir/duplicates/{dtp}.pdb")
    os.rename(f"data/test/cameo/20230103/data_dir/{dtp.split('_')[0]}.cif", f"data/test/cameo/20230103/data_dir/duplicates/{dtp.split('_')[0]}.cif")
    # move alignments
    os.rename(f"data/test/cameo/20230103/minimized/alignments/{dtp}", f"data/test/cameo/20230103/minimized/alignments/duplicates/{dtp}")
    os.rename(f"data/test/cameo/20230103/alignments/{dtp}", f"data/test/cameo/20230103/alignments/duplicates/{dtp}")

/net/pulsar/home/koes/jok120/openfold


# Load dataframe

In [9]:
# Load the above csv
df = pd.read_csv('~/openfold/jk_research/evaluations/230420/230428_evaluationset_energies.csv')
df

Unnamed: 0,pdbid,chain_id,dataset,min_or_unmin,openmm_energy,scn_energy,pdbid_chainid
0,7sci,A,test,unmin,-99236.965301,-69827.339318,7sci_A
1,7wgk,A,test,unmin,-45585.986728,-32588.705792,7wgk_A
2,7wj0,A,test,unmin,-11886.696006,-8719.855140,7wj0_A
3,7pvm,A,test,unmin,-6221.771473,-4408.863527,7pvm_A
4,7ead,A,test,unmin,-21121.873942,-13155.142319,7ead_A
...,...,...,...,...,...,...,...
234,7dfe,A,val,min,-19783.327374,-12876.558158,7dfe_A
235,7v5y,B,val,min,-20284.087039,-15776.160864,7v5y_B
236,7ee3,C,val,min,-22390.912999,-15917.149032,7ee3_C
237,7wgk,A,val,min,-45021.939726,-32776.283638,7wgk_A


In [12]:
df[df['pdbid'] == '7wgk']

Unnamed: 0,pdbid,chain_id,dataset,min_or_unmin,openmm_energy,scn_energy,pdbid_chainid
195,7wgk,A,val,unmin,-45585.986094,-32588.705975,7wgk_A
237,7wgk,A,val,min,-45021.939726,-32776.283638,7wgk_A


In [14]:
df[df.duplicated(subset=['pdbid_chainid', 'min_or_unmin'], keep=False)]


Unnamed: 0,pdbid,chain_id,dataset,min_or_unmin,openmm_energy,scn_energy,pdbid_chainid


In [11]:
# find rows where pdbid_chainid is duplicated
# df[df.duplicated(subset=['pdbid_chainid'], keep=False)]
df[df.duplicated(subset=['pdbid_chainid', 'min_or_unmin'], keep=False) & (df['min_or_unmin'] == 'min')]

# remove the duplicated rows that have dataset == 'test'
df = df[~(df.duplicated(subset=['pdbid_chainid', 'min_or_unmin'], keep=False) & (df['dataset'] == 'test'))]
df


Unnamed: 0,pdbid,chain_id,dataset,min_or_unmin,openmm_energy,scn_energy,pdbid_chainid
0,7sci,A,test,unmin,-99236.965301,-69827.339318,7sci_A
2,7wj0,A,test,unmin,-11886.696006,-8719.855140,7wj0_A
3,7pvm,A,test,unmin,-6221.771473,-4408.863527,7pvm_A
4,7ead,A,test,unmin,-21121.873942,-13155.142319,7ead_A
5,7w5c,A,test,unmin,-78641.724298,-60290.734101,7w5c_A
...,...,...,...,...,...,...,...
234,7dfe,A,val,min,-19783.327374,-12876.558158,7dfe_A
235,7v5y,B,val,min,-20284.087039,-15776.160864,7v5y_B
236,7ee3,C,val,min,-22390.912999,-15917.149032,7ee3_C
237,7wgk,A,val,min,-45021.939726,-32776.283638,7wgk_A


In [15]:
df[df.duplicated(subset=['pdbid_chainid', 'min_or_unmin'], keep=False) & (df['min_or_unmin'] == 'min')]

Unnamed: 0,pdbid,chain_id,dataset,min_or_unmin,openmm_energy,scn_energy,pdbid_chainid


In [16]:
!pwd

/net/pulsar/home/koes/jok120/openfold


In [17]:
# save csv
df.to_csv('/net/pulsar/home/koes/jok120/openfold/jk_research/evaluations/230420/230428_evaluationset_energies2.csv', index=False)
