In this notebook, we will walk through the basic steps of model sampling and metric computation. 

Currently we mainly focus on the diversity and coverage metrics. Detailed guidelines for computing designability and novelty are described in the Method section of the paper, and we are working on integrating these metrics into this environment.

In [1]:
import os
import sys
import pickle
from tqdm import tqdm

import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import torch

import py3Dmol

import IPython
from IPython.display import display, Markdown

In [3]:
import TopoDiff

# sampler
from TopoDiff.experiment.sampler import Sampler

# data
from TopoDiff.data.structure import StructureBuilder

# pdb
from myopenfold.np import protein

# evaluation
from TopoDiff.evaluation.diversity import compute_tm_matrix, compute_unique_cluster
from TopoDiff.evaluation.coverage import compute_progres_embedding, compute_coverage

In [4]:
project_dir = os.path.dirname(os.path.dirname(TopoDiff.__path__[0]))
data_dir = os.path.join(project_dir, 'data', 'dataset')

Intermediate results may be saved during the process. Please specify a directory for the results (or use the default directory `./results/`).

In [34]:
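# `__vsc_ipynb_file__` is only defined when the notebook runs in VS Code;
# set `notebook_path` manually in other environments
notebook_path = IPython.extract_module_locals()[1]["__vsc_ipynb_file__"]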

# overwrite this if you want to save the results somewhere else
par_dir = os.path.join(os.path.dirname(notebook_path), 'results')
os.makedirs(par_dir, exist_ok=True)

print(f"Results will be saved in {par_dir}")

Results will be saved in /tmp/results


In [7]:
pdb_path_list = [os.path.join(par_dir, 'sample', 'length_125', 'sample_%d.pdb' % i) for i in range(100)]

# Sampling

We begin by sampling a minimally sufficient set of proteins for metric computation. Here we choose 100 proteins of length 125.

Although this number is significantly smaller than in our benchmark experiment, sampling still takes approximately 10 minutes on a single GPU. We therefore highly recommend running the sampling from the command line with the command below:

In [35]:
markdown_content = f"""```bash\npython {project_dir}/TopoDiff/run_sampling.py \\\n\t-m all_round \\\n\t-s 125 \\\n\t-e 125 \\\n\t-n 100 \\\n\t--seed 42 \\\n\t-o {par_dir}/sample/ \n```"""
display(Markdown(markdown_content))

```bash
python /home/zhangyuy/workspace/dl/TopoDiff/submit_240923/github/TopoDiff_public/TopoDiff/run_sampling.py \
	-m all_round \
	-s 125 \
	-e 125 \
	-n 100 \
	--seed 42 \
	-o /tmp/results/sample/ 
```

After running the above command, the sampling results should appear as PDB files under `{par_dir}/sample/` (this notebook reads them from `{par_dir}/sample/length_125/`, see `pdb_path_list`), and we can proceed to computing the metrics.
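
Before moving on, it can help to confirm that every expected file was actually written. The following is a minimal sketch: `find_missing_files` is a hypothetical helper (not part of TopoDiff), and the temporary directory only stands in for the real sample directory.

```python
import os
import tempfile


def find_missing_files(path_list):
    """Return the paths in path_list that do not exist as regular files."""
    return [p for p in path_list if not os.path.isfile(p)]


# Toy demonstration: create only 2 of the 3 expected files.
with tempfile.TemporaryDirectory() as tmp_dir:
    expected = [os.path.join(tmp_dir, 'sample_%d.pdb' % i) for i in range(3)]
    for path in expected[:2]:
        with open(path, 'w') as f:
            f.write('END\n')
    missing = find_missing_files(expected)
    n_missing = len(missing)
    missing_names = [os.path.basename(p) for p in missing]

print(n_missing, missing_names)
```

In this notebook, `find_missing_files(pdb_path_list)` should return an empty list before the visualization and metric cells are run.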

In [10]:
n_sample = 100
size = 50
n_col = 20
n_row = int(np.ceil(n_sample / n_col))

view = py3Dmol.view(width=size * n_col, height=size * n_row, viewergrid=(n_row, n_col), linked=False)

for i in range(n_sample):
    row_idx, col_idx = divmod(i, n_col)
    with open(pdb_path_list[i], 'r') as f:
        pdb_content = f.read()
    view.addModelsAsFrames(pdb_content, 'pdb', viewer=(row_idx, col_idx))
    view.setStyle({'model': 0}, {'cartoon': {'color': 'spectrum'}}, viewer=(row_idx, col_idx))

view.zoomTo()

<py3Dmol.view at 0x7f3b99f79e20>

# Diversity

Diversity is defined as the fraction of unique clusters relative to the total number of sampled PDBs. The samples are clustered based on pairwise TM-scores, using Agglomerative Hierarchical Clustering implemented in SciPy to determine the clusters.

$$
\text{Diversity} = \frac{\text{Number of unique clusters}}{\text{Total number of sampled PDBs}}
$$

When computing diversity for the subset of designable samples, we first filter out the non-designable samples based on the designability threshold and then count the number of unique clusters. The number of unique clusters is still divided by the total number of sampled PDBs (including non-designable samples).

We used a distance threshold of 0.4 for clustering in all our experiments, which means that samples with a pairwise TM-score > 0.6 are assigned to the same cluster.
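
The definition above can be sketched with SciPy on a toy distance matrix. This is an illustration under stated assumptions, not the code behind `compute_unique_cluster`: the function name `diversity_from_distances`, the `keep_mask` argument (standing in for a designability filter), and the choice of average linkage are all hypothetical, since the linkage method is not stated here.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage
from scipy.spatial.distance import squareform


def diversity_from_distances(dist_mat, threshold=0.4, keep_mask=None, method='average'):
    """Number of clusters among the kept samples, divided by the total sample count.

    `method='average'` is an assumption made for this sketch.
    """
    dist = np.asarray(dist_mat, dtype=float)
    n_total = dist.shape[0]  # the denominator always counts every sample
    if keep_mask is not None:
        idx = np.flatnonzero(np.asarray(keep_mask, dtype=bool))
        dist = dist[np.ix_(idx, idx)]
    n_kept = dist.shape[0]
    if n_kept < 2:
        n_cluster = n_kept  # zero or one sample: nothing to cluster
    else:
        dist = (dist + dist.T) / 2.0  # enforce symmetry
        np.fill_diagonal(dist, 0.0)
        tree = linkage(squareform(dist, checks=False), method=method)
        labels = fcluster(tree, t=threshold, criterion='distance')
        n_cluster = len(np.unique(labels))
    return n_cluster / n_total


# Toy data: samples 0 and 1 are similar, samples 2 and 3 are similar.
toy_dist = np.array([
    [0.0, 0.1, 0.8, 0.8],
    [0.1, 0.0, 0.8, 0.8],
    [0.8, 0.8, 0.0, 0.2],
    [0.8, 0.8, 0.2, 0.0],
])
div_all = diversity_from_distances(toy_dist)
div_subset = diversity_from_distances(toy_dist, keep_mask=[True, True, False, False])
print(div_all, div_subset)
```

With the subset mask, only one cluster remains but the denominator stays at 4, which mirrors the convention described above for designable samples.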

In [11]:
res = compute_tm_matrix(pdb_path_list)
tm_mat = res['tm_matrix']

100%|██████████| 100/100 [00:01<00:00, 86.13it/s]
100%|██████████| 5050/5050 [00:08<00:00, 588.86it/s]


In [12]:
n_uniq_cluster = compute_unique_cluster(1 - tm_mat)
print(f"Number of unique clusters: {n_uniq_cluster}, sample diversity: {n_uniq_cluster / len(pdb_path_list):.2f}")

Number of unique clusters: 94, sample diversity: 0.94


# Coverage

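Let $X_1, \dots, X_N$ denote the reference structures, $Y_j$ the generated samples, $D$ the distance between two structures, and $X_{NN}(X_i, k)$ the $k$-th nearest reference neighbor of $X_i$. Coverage is then defined as follows: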

$$
\text{NND}_k(X_i) := D(X_i, X_{NN}(X_i, k))
$$

$$
B(x, r) := \{ y \mid D(x, y) < r \}
$$

$$
\text{coverage} := \frac{1}{N} \sum_{i=1}^{N} 1_{\exists j: Y_j \in B(X_i, \text{NND}_k(X_i))}
$$

Intuitively, a reference structure is considered to be covered if it is within a certain distance to at least one structure in the generated samples.

Since the distribution density of reference samples is not necessarily uniform, this distance is not set as constant. Instead, it is dynamically chosen based on the distance from the current reference to other reference structures (specifically, the distance to the k-th closest reference). Therefore, before calculating coverage, we need to obtain: 
1) The distance matrix between reference structures.
2) The pairwise distance matrix between generated samples and reference structures.
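
The definition above can be sketched on toy data as follows. This is a minimal NumPy sketch, not the `compute_coverage` implementation: `coverage_from_distances` and the 1-D points are hypothetical, and the comparison uses `<=` to match the cells in this notebook, whereas the ball $B$ above is defined with a strict inequality.

```python
import numpy as np


def coverage_from_distances(ref_dist, ref_sample_dist, k):
    """Fraction of references with at least one sample inside their k-NN radius.

    ref_dist:        (N, N) distances between reference structures
    ref_sample_dist: (N, M) distances from each reference to each sample
    """
    ref_dist = np.asarray(ref_dist, dtype=float)
    ref_sample_dist = np.asarray(ref_sample_dist, dtype=float)
    # Sort each row, drop the self-distance in column 0, take the k-th neighbour.
    radius = np.sort(ref_dist, axis=1)[:, 1:][:, k - 1]
    closest = ref_sample_dist.min(axis=1)
    return float(np.mean(closest <= radius))


# Toy data: points on a line, with absolute difference as the distance.
ref_points = np.array([0.0, 1.0, 2.0, 10.0])
sample_points = np.array([0.5, 20.0])
toy_ref_dist = np.abs(ref_points[:, None] - ref_points[None, :])
toy_ref_sample_dist = np.abs(ref_points[:, None] - sample_points[None, :])

toy_coverage = coverage_from_distances(toy_ref_dist, toy_ref_sample_dist, k=1)
print(toy_coverage)
```

With `k=1`, the references at 0 and 1 are covered by the sample at 0.5, while the references at 2 and 10 are not, so half of the references are covered.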

We will now demonstrate the implementation based on two different distance definitions.


Since the available fold types depend on the length of the protein, for each sampled length we only consider natural proteins of similar length as references. For all experiments we used a cutoff of 25 residues (for example, for sampled length 125 we only considered natural proteins of length 100-150).

In [13]:
sample_length = 125
n_sample = 100

# scope for reference to compare with (+- 25)
scope = 25

# hyperparameter K for KNN distance
K = 100

## Coverage - $D_\mathrm{Progres}$

In [14]:
from progres import progres

### 1. distance matrix within reference

For Progres, we used the provided CATH-40 embedding downloaded from the official [zenodo repository](https://zenodo.org/record/7782088). 

In [15]:
cath_emb_path = os.path.join(project_dir, 'TopoDiff', 'progres', 'progres','databases', 'v_0_2_0', 'cath40.pt')
cath_emb_dict = torch.load(cath_emb_path)

In [16]:
cath_ref_length = torch.tensor(cath_emb_dict['nres'] )
filter_mask = (cath_ref_length >= sample_length - scope) & (cath_ref_length <= sample_length + scope)
cath_emb_filtered = cath_emb_dict['embeddings'][filter_mask].numpy()
# 1 - cosine similarity
ref_dis_mat = 1 - np.matmul(cath_emb_filtered, cath_emb_filtered.T)

In [17]:
cath_emb_filtered.shape, ref_dis_mat.shape

((9515, 128), (9515, 9515))
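
The expression `1 - np.matmul(A, B.T)` equals the cosine distance only when every row of `A` and `B` has unit L2 norm. The cells here assume that the precomputed embeddings are already normalized, which is worth verifying for your own data. A minimal sketch that normalizes explicitly (`cosine_distance_matrix` is a hypothetical helper, not part of TopoDiff or Progres):

```python
import numpy as np


def cosine_distance_matrix(a, b):
    """Pairwise cosine distance between the rows of a and the rows of b."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # Scale every row to unit length so the dot product is the cosine similarity.
    a_unit = a / np.linalg.norm(a, axis=1, keepdims=True)
    b_unit = b / np.linalg.norm(b, axis=1, keepdims=True)
    return 1.0 - a_unit @ b_unit.T


# Toy vectors with different lengths: the result depends only on direction.
toy_emb = np.array([[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]])
toy_cos_dist = cosine_distance_matrix(toy_emb, toy_emb)
row_norms = np.linalg.norm(toy_emb, axis=1)
print(toy_cos_dist.shape, row_norms)
```

A quick check such as `np.allclose(np.linalg.norm(emb, axis=1), 1.0)` on the loaded embeddings indicates whether the shortcut without normalization is safe.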

### 2. pairwise distance between generated samples and reference structures

In [18]:
emb_sample = compute_progres_embedding(pdb_path_list)

100%|██████████| 13/13 [00:01<00:00, 12.20it/s]


In [19]:
ref_sample_dis_mat = 1 - np.matmul(cath_emb_filtered, emb_sample.T)

In [20]:
ref_sample_dis_mat.shape

torch.Size([9515, 100])

### 3. compute coverage

First, decide the receptive field of each reference structure.

In [21]:
# start from 1 to exclude itself
ref_dis_sorted = np.sort(ref_dis_mat, axis=1)[:, 1:]

Second, find out the closest distance for each reference structure to any of the sampled structures.

In [22]:
ref_closest_dis = np.sort(ref_sample_dis_mat, axis=1)[:, 0]

Finally, compute the fraction of reference structures that have at least one sampled structure within their receptive field. This fraction is the coverage.

In [23]:
np.sum(ref_closest_dis <= ref_dis_sorted[:, K-1]) / ref_closest_dis.shape[0]

0.4867052023121387

### 4. using the wrapped-up API

Alternatively, use the wrapper function to compute the coverage directly.

In [24]:
coverage = compute_coverage(pdb_path_list,
                            metric='Progres',
                            length=sample_length,
                            scope=scope,
                            K=K,
                            verbose=True)
print(f"Coverage: {coverage}")

2025-04-27 16:32:01.373 - INFO - Loading precomputed CATH embeddings


2025-04-27 16:32:06.547 - INFO - Computing embeddings for the sampled proteins
100%|██████████| 13/13 [00:01<00:00, 10.76it/s]
2025-04-27 16:32:07.921 - INFO - Computing coverage


Coverage: 0.4867052023121387


**Note** that the computed coverage is relatively low in this case. This is because the KNN hyperparameter K (100) was chosen for a larger number of samples (e.g. 500 per length, as in our benchmark experiment). You can increase the number of samples to get a more accurate coverage estimate, or temporarily set K to a larger value (e.g. 200).

Please check our supplementary results for more information:

- `Section 3`: More results on the coverage metric, including the choice of K and the effect of different distance definitions.

- `Section 4`: More benchmark results, including the curve of coverage as a function of the number of samples.
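
The dependence on K mentioned in the note can be illustrated on synthetic data. The sketch below is only an illustration: it uses random 2-D Gaussian points with Euclidean distance instead of protein structures, and `coverage_at_k` is a hypothetical helper. Because the k-th neighbour radius cannot shrink as K grows, the coverage is non-decreasing in K.

```python
import numpy as np


def coverage_at_k(ref_dist, ref_sample_dist, k):
    """Coverage for a given K, following the same steps as the cells above."""
    # Column 0 of the sorted rows is the self-distance, so column k is the k-th neighbour.
    radius = np.sort(ref_dist, axis=1)[:, k]
    closest = ref_sample_dist.min(axis=1)
    return float(np.mean(closest <= radius))


rng = np.random.default_rng(0)
ref = rng.normal(size=(200, 2))     # synthetic "reference" points
sample = rng.normal(size=(10, 2))   # synthetic "generated" points

toy_ref_dist = np.linalg.norm(ref[:, None, :] - ref[None, :, :], axis=-1)
toy_ref_sample_dist = np.linalg.norm(ref[:, None, :] - sample[None, :, :], axis=-1)

k_values = [1, 5, 20, 50]
coverage_by_k = [coverage_at_k(toy_ref_dist, toy_ref_sample_dist, k) for k in k_values]
for k, cov in zip(k_values, coverage_by_k):
    print(f"K={k:3d}  coverage={cov:.3f}")
```

The same loop over `K` can be run with the real `ref_dis_mat` and `ref_sample_dis_mat` computed above to see how sensitive the reported coverage is to this hyperparameter.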

## Coverage - $D_\mathrm{TM}$

Coverage with $D_\mathrm{TM}$ as the distance definition is computed in a similar way. However, because the TM-score algorithm cannot be vectorized, computing the all-vs-all distance matrix can be extremely slow.

### 1. distance matrix within reference

For the pairwise distance matrix of the reference structures, we precomputed it using `TopoDiff.evaluation.diversity.compute_tm_matrix` to save time. However, it is only available for sample lengths [50, 75, 100, 125, 150, 175, 200, 225, 250] with a scope of 25.

In [26]:
coverage_data_dir = os.path.join(project_dir, 'data', 'evaluation', 'coverage')
cath_tm_mat_dir = os.path.join(coverage_data_dir, 'cath_tm')
cath_pdb_dir = os.path.join(coverage_data_dir, 'cath_pdb')

In [27]:
ref_tm_path = os.path.join(cath_tm_mat_dir, 'length_%d.pkl' % sample_length)
with open(ref_tm_path, 'rb') as f:
    ref_tm_dict = pickle.load(f)
ref_dis_mat = 1 - (ref_tm_dict['tm_mat_norm_chain'] + ref_tm_dict['tm_mat_norm_chain'].T) / 2
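
The last line of the cell above averages the score matrix with its transpose. A TM-score depends on which chain's length is used for normalization, so a matrix of such scores need not be symmetric; averaging produces a symmetric distance. Reading `tm_mat_norm_chain` this way is an assumption, and the toy matrix below is made up purely for illustration.

```python
import numpy as np

# Hypothetical asymmetric score matrix: entry [i, j] is normalized by chain i.
toy_tm = np.array([
    [1.0, 0.6, 0.3],
    [0.5, 1.0, 0.4],
    [0.2, 0.4, 1.0],
])

# Average with the transpose, then convert similarity to distance.
toy_tm_sym = (toy_tm + toy_tm.T) / 2
toy_tm_dist = 1 - toy_tm_sym

print(np.allclose(toy_tm_dist, toy_tm_dist.T))
```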

### 2. pairwise distance between generated samples and reference structures

**NOTE**: this step can be very time-consuming, so you may want to use all available CPU cores to speed it up.

If the notebook hangs for a long time without any output, consider running this step from the command line instead.
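
One way to choose a worker count that matches the current machine is sketched below. `pick_n_workers` is a hypothetical helper, not part of TopoDiff, and the `reserve` argument (cores left free for other work) is a design choice of this sketch.

```python
import os


def pick_n_workers(reserve=1):
    """Number of usable CPU cores minus `reserve`, but never less than 1."""
    try:
        # Respects CPU affinity limits (e.g. cluster schedulers); not available on every OS.
        available = len(os.sched_getaffinity(0))
    except AttributeError:
        available = os.cpu_count() or 1
    return max(1, available - reserve)


n_workers_suggested = pick_n_workers()
print(f"Suggested number of workers: {n_workers_suggested}")
```

The result can then be passed as the `n_workers` argument used in the cells of this section.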

In [28]:
cath_pdb_path_list = [os.path.join(cath_pdb_dir, key) for key in ref_tm_dict['key_list']]

In [29]:
# adjust n_workers to the number of CPU cores available on your machine
res = compute_tm_matrix(pdb_path_list,
                        path_list_2=cath_pdb_path_list,
                        n_workers=160,
)

100%|██████████| 100/100 [00:05<00:00, 16.90it/s]
100%|██████████| 9525/9525 [00:26<00:00, 358.26it/s]
100%|██████████| 952500/952500 [13:05<00:00, 1213.34it/s]


In [30]:
ref_sample_dis_mat = (1 - res['tm_matrix']).T

### 3. compute coverage

First, decide the receptive field of each reference structure.

In [31]:
# start from 1 to exclude itself
ref_dis_sorted = np.sort(ref_dis_mat, axis=1)[:, 1:]

Second, find out the closest distance for each reference structure to any of the sampled structures.

In [32]:
ref_closest_dis = np.sort(ref_sample_dis_mat, axis=1)[:, 0]

In [33]:
np.sum(ref_closest_dis <= ref_dis_sorted[:, K-1]) / ref_closest_dis.shape[0]

0.6123884514435696

### 4. using the wrapped-up API

In [25]:
coverage = compute_coverage(pdb_path_list,
                            metric='TM',
                            length=sample_length,
                            scope=scope,
                            K=K,
                            n_workers=160,
                            verbose=True)
print(f"Coverage (DTM): {coverage}")

2025-04-27 16:32:18.232 - INFO - Loading precomputed reference TM score matrix
2025-04-27 16:32:21.663 - INFO - Computing reference v.s. sample TM score matrix
100%|██████████| 100/100 [00:04<00:00, 23.22it/s]
100%|██████████| 9525/9525 [00:21<00:00, 435.88it/s]
100%|██████████| 952500/952500 [13:06

Coverage (DTM): 0.6123884514435696
