This notebook goes through the basic process of encoding backbone structures into the latent space learned by the model. 

Before running the notebook, please make sure the following prerequisites are met:

1. The environment has been set up following the instructions in the README file. `TopoDiff` has been installed as a package.

2. The model weights and the structure dataset have been downloaded and placed in the correct paths:

```
project(repo)
├── data
│   ├── dataset
│   ├── weights
├── notebook
│   ├── 0_compute_embedding.ipynb (We are here)
│   ├── 1_ ...
│   ├── 2_ ...
├── TopoDiff
```

The whole notebook takes about 5 minutes to run. If that is too long, you can skip it and go directly to the next notebook, because a precomputed embedding is already provided in the `data/dataset` directory. The cells below reproduce that precomputed result.

In [1]:
import os
import sys
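# Select the GPU to run on (change the index for your machine); must be set before torch initializes CUDA
os.environ['CUDA_VISIBLE_DEVICES'] = '1'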

from multiprocessing import Pool
import pickle
from tqdm import tqdm

In [2]:
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import torch
import numpy as np
import openTSNE

In [3]:
import TopoDiff

# model
from TopoDiff.experiment.sampler import Sampler

# data
from TopoDiff.data.representation_data_modules import StructureRepresentationMinimalDataset, StructureRepresentationCollator

# utils
from myopenfold.utils.tensor_utils import tensor_tree_map

# np
from myopenfold.np import protein



attn_core_inplace_cuda not found, will use normal attention implementation


In [4]:
project_dir = os.path.dirname(os.path.dirname(TopoDiff.__path__[0]))
data_dir = os.path.join(project_dir, 'data', 'dataset')

# data

We showcase the embedding computation on the CATH-60 dataset. All structures are stored in PDB format under `data/dataset/CATH_60`, grouped into subdirectories named after the second and third characters of each domain key (e.g. `39/139lA00.pdb`), and the corresponding annotations are stored in `data/dataset/CATH_60.csv`.

In [6]:
label_df_path = os.path.join(data_dir, 'CATH_60.csv')
data_par_dir = os.path.join(data_dir, 'CATH_60')

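def load_ca(key):
    """Load the C-alpha coordinates of one CATH domain, identified by its key."""
    # Structures are sharded into subdirectories named by characters 2-3 of the key
    pdb_path = os.path.join(data_par_dir, key[1:3], key + '.pdb')
    with open(pdb_path, 'r') as f:
        pdb_string = f.read()
    pdb = protein.from_pdb_string(pdb_string)
    # Atom index 1 is CA in the atom37 layout; keep only residues whose CA is resolved
    ca_mask = pdb.atom_mask[:, 1].astype(bool)
    return key, pdb.atom_positions[:, 1][ca_mask].astype(np.float32)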
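`load_ca` relies on the atom37 layout of OpenFold-style `Protein` objects (assuming `myopenfold` keeps OpenFold's convention): `atom_positions` has shape `[num_res, 37, 3]`, `atom_mask` has shape `[num_res, 37]`, and atom index 1 is CA. The sketch below reproduces the same indexing on toy arrays; the `extract_ca` helper and the arrays are hypothetical. It shows that residues without a resolved CA are dropped, so the returned array can be shorter than the sequence.

```python
import numpy as np

# Hypothetical helper mirroring the indexing in load_ca above.
# Assumes OpenFold's atom37 ordering, in which index 1 is the C-alpha atom.
CA_INDEX = 1

def extract_ca(atom_positions: np.ndarray, atom_mask: np.ndarray) -> np.ndarray:
    """Return [n_resolved, 3] float32 CA coordinates, skipping unresolved CAs."""
    ca_mask = atom_mask[:, CA_INDEX].astype(bool)
    return atom_positions[:, CA_INDEX][ca_mask].astype(np.float32)

# Toy structure: 4 residues, 37 atom slots each, with random coordinates
rng = np.random.default_rng(0)
positions = rng.normal(size=(4, 37, 3))
mask = np.ones((4, 37))
mask[2, CA_INDEX] = 0.0  # pretend residue 2 has no resolved CA

ca = extract_ca(positions, mask)
print(ca.shape)  # (3, 3)
```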

In [7]:
label_df = pd.read_csv(label_df_path)
label_df.head()

| Unnamed: 0 | key | domain_id | CATH_ID | class_1 | class_2 | class_3 | class_4 | length | alpha_ratio | beta_ratio | coil_ratio | gyration_radius |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 139lA00 | 139lA00 | 1.10.530.40 | Mainly Alpha | Orthogonal Bundle | Lysozyme | | 162 | 0.654321 | 0.092593 | 0.253086 | 1.647574 |
| 1 | 16pkA01 | 16pkA01 | 3.40.50.1260 | Alpha Beta | 3-Layer(aba) Sandwich | Rossmann fold | Phosphoglycerate kinase, N-terminal domain | 188 | 0.414894 | 0.170213 | 0.414894 | 1.555073 |
| 2 | 16pkA02 | 16pkA02 | 3.40.50.1260 | Alpha Beta | 3-Layer(aba) Sandwich | Rossmann fold | Phosphoglycerate kinase, N-terminal domain | 208 | 0.4375 | 0.163462 | 0.399038 | 1.611004 |
| 3 | 1914A00 | 1914A00 | 3.30.720.10 | Alpha Beta | 2-Layer Sandwich | Signal recognition particle alu RNA binding he... | Signal recognition particle alu RNA binding he... | 208 | 0.362573 | 0.298246 | 0.339181 | 1.64924 |
| 4 | 1a04A01 | 1a04A01 | 3.40.50.2300 | Alpha Beta | 3-Layer(aba) Sandwich | Rossmann fold | | 124 | 0.435484 | 0.225806 | 0.33871 | 1.344929 |


In [8]:
# Number of worker processes for parallel PDB parsing; adjust to your CPU count
n_workers = 30
info_keys = label_df['key'].tolist()
info_key_to_idx = {row.key: idx for idx, row in label_df.iterrows()}

# Parse all PDB files in parallel; imap yields results in input order
with Pool(n_workers) as p:
    results = list(tqdm(p.imap(load_ca, info_keys), total=len(info_keys)))

# Merge results into a {key: CA coordinates} dict
merged_ca_data = dict(results)

100%|██████████| 30074/30074 [01:06<00:00, 452.70it/s]


# model

In [9]:
sampler = Sampler(
    model_version='v1_1_2',
)

2024-10-02 16:06:07.112 - INFO - Using single-GPU


In [10]:
device = sampler.model_diffusion.device
config = sampler.config_diffusion
model = sampler.model_diffusion

In [11]:
dataset = StructureRepresentationMinimalDataset(
    data_info=label_df,
    data_dict=merged_ca_data,
    config=config.Data.common,
    extra_config={'encoder_no_noise': True},
)
collator = StructureRepresentationCollator(pad_in_collator=True)
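The flag `pad_in_collator=True` suggests that variable-length structures are padded to a common length when a batch is assembled. We have not checked how `StructureRepresentationCollator` actually does this; the sketch below is a hypothetical, generic version of the idea. It zero-pads each `[L_i, 3]` CA array to the longest length in the batch and returns a mask marking real residues (`pad_ca_batch` and its output keys are illustrative names).

```python
import numpy as np

def pad_ca_batch(ca_list):
    """Hypothetical padding collate: [L_i, 3] arrays -> [B, L_max, 3] coords plus [B, L_max] mask."""
    lengths = np.array([len(ca) for ca in ca_list])
    max_len = int(lengths.max())
    coords = np.zeros((len(ca_list), max_len, 3), dtype=np.float32)
    mask = np.zeros((len(ca_list), max_len), dtype=bool)
    for i, ca in enumerate(ca_list):
        # Copy real residues to the front; the tail stays zero-padded and masked out
        coords[i, : len(ca)] = ca
        mask[i, : len(ca)] = True
    return {"coords": coords, "mask": mask, "length": lengths}

batch = pad_ca_batch([np.ones((5, 3)), np.ones((2, 3))])
print(batch["coords"].shape, batch["mask"].sum(axis=1))  # (2, 5, 3) [5 2]
```

Padding in the collator (rather than in the dataset) means each batch is only padded to its own longest member, not to the longest structure in the whole dataset.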

In [12]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collator,
    num_workers=8,
    drop_last=False,
)
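With `batch_size=16` and `drop_last=False`, the final partial batch is kept, so the number of batches is the ceiling of 30074 / 16. This agrees with the 1880 iterations reported by the inference loop's progress bar, and implies a last batch of 10 structures. A quick check, with the dataset size taken from the loading progress bar above:

```python
import math

n_structures = 30074  # from the loading progress bar above
batch_size = 16

# drop_last=False keeps the final partial batch, hence the ceiling
n_batches = math.ceil(n_structures / batch_size)
last_batch_size = n_structures - (n_batches - 1) * batch_size
print(n_batches, last_batch_size)  # 1880 10
```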

In [13]:
result_dict_list = []
with torch.no_grad():
    for batch in tqdm(dataloader):
        # Keep only tensor entries and move them to the model's device
        batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
        # Encode each backbone into its latent representation
        result = model.encode_topology(batch)
        # Carry per-sample metadata along so outputs can be matched to inputs
        result['length'] = batch['length']
        result['sample_idx'] = batch['sample_idx']
        result = tensor_tree_map(lambda x: x.cpu(), result)
        result_dict_list.append(result)

# Concatenate per-batch outputs into one array per field (tensors are already on CPU)
result_dict = {}
for key in result_dict_list[0].keys():
    result_dict[key] = torch.cat([x[key] for x in result_dict_list], dim=0).numpy()
# Valid only because shuffle=False keeps the dataloader in dataset order
result_dict['key'] = dataset.keys
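The merge step concatenates each field along the batch axis, and attaching `dataset.keys` afterwards is only valid because `shuffle=False` preserves dataset order. Below is a hypothetical sketch of the same merge on toy NumPy arrays, with an extra order check. The check assumes that `sample_idx` holds each item's position in the dataset, which we have not verified for `StructureRepresentationMinimalDataset`.

```python
import numpy as np

def merge_batches(batch_dicts):
    """Concatenate a list of {field: array} dicts along axis 0."""
    return {k: np.concatenate([b[k] for b in batch_dicts], axis=0) for k in batch_dicts[0]}

# Toy outputs from two batches (3 + 2 samples) with 32-dim latent codes
batches = [
    {"latent_mu": np.zeros((3, 32)), "sample_idx": np.array([0, 1, 2])},
    {"latent_mu": np.ones((2, 32)), "sample_idx": np.array([3, 4])},
]
merged = merge_batches(batches)

# Assumption: sample_idx is the dataset position, so it should read 0..N-1 in order
assert np.array_equal(merged["sample_idx"], np.arange(len(merged["sample_idx"])))
print(merged["latent_mu"].shape)  # (5, 32)
```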


100%|██████████| 1880/1880 [01:14<00:00, 25.32it/s]


These are the 32-dimensional latent codes of the structures. To visualize them, we use t-SNE to reduce them to two dimensions.
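The `perplexity=50` setting below controls how many neighbours each point effectively attends to. For every point i, t-SNE chooses a Gaussian bandwidth so that the conditional neighbour distribution P_i has perplexity 2^H(P_i) = 50. openTSNE evaluates this only over a nearest-neighbour graph, and the 150 neighbours in the log below are 3 × 50. The following is a from-scratch sketch of that per-point calibration, bisecting on the precision beta = 1 / (2 sigma^2) over random toy distances; it illustrates the standard procedure and is not openTSNE's implementation (`calibrate_row` is a hypothetical name).

```python
import numpy as np

def calibrate_row(sq_dists, target_perplexity, tol=1e-5, max_iter=100):
    """Bisect on beta = 1 / (2 sigma^2) so that p_{j|i} reaches the target perplexity.

    sq_dists: squared distances from point i to its neighbours (i itself excluded).
    Returns the conditional probabilities and the perplexity actually achieved.
    """
    target_entropy = np.log2(target_perplexity)
    lo, hi, beta = 0.0, np.inf, 1.0
    for _ in range(max_iter):
        # Shifting by the minimum distance avoids underflow; it cancels on normalisation
        weights = np.exp(-beta * (sq_dists - sq_dists.min()))
        p = weights / weights.sum()
        entropy = -np.sum(p * np.log2(p + 1e-300))
        if abs(entropy - target_entropy) < tol:
            break
        if entropy > target_entropy:
            # Distribution too flat: increase beta (narrower Gaussian)
            lo = beta
            beta = beta * 2 if hi == np.inf else (lo + hi) / 2
        else:
            # Distribution too peaked: decrease beta (wider Gaussian)
            hi = beta
            beta = (lo + hi) / 2
    return p, 2 ** entropy

rng = np.random.default_rng(0)
# Squared distances to 150 toy neighbours (3 x perplexity, as in the log below)
sq_dists = rng.uniform(0.0, 10.0, size=150) ** 2
p, perp = calibrate_row(sq_dists, target_perplexity=50)
print(round(perp, 2))
```

A larger perplexity widens each point's neighbourhood, which tends to emphasise broader structure in the 2D map at the cost of fine local detail.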

In [14]:
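tsne = openTSNE.TSNE(
    perplexity=50,
    metric="euclidean",
    n_jobs=30,            # CPU threads; adjust to your machine
    random_state=477778,  # fixed seed for a reproducible layout
    verbose=True,
    n_iter=250,           # main-phase iterations; early exaggeration runs its own 250 (see log)
    initialization='pca',
)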
embed_2d = tsne.fit(result_dict['latent_mu'])

--------------------------------------------------------------------------------
TSNE(early_exaggeration=12, n_iter=250, n_jobs=30, perplexity=50,
     random_state=477778, verbose=True)
--------------------------------------------------------------------------------
===> Finding 150 nearest neighbors using Annoy approximate search using euclidean distance...
   --> Time elapsed: 7.17 seconds
===> Calculating affinity matrix...


2024-10-02 16:08:02.900 - INFO - Precomputed initialization provided. Ignoring initalization-related parameters.


   --> Time elapsed: 3.83 seconds
===> Calculating PCA-based initialization...


2024-10-02 16:08:03.321 - INFO - Automatically determined negative gradient method `fft`


   --> Time elapsed: 0.42 seconds
===> Running optimization with exaggeration=12.00, lr=2506.17 for 250 iterations...
Iteration   50, KL divergence 3.8384, 50 iterations in 3.3969 sec
Iteration  100, KL divergence 3.7359, 50 iterations in 3.3955 sec
Iteration  150, KL divergence 3.6273, 50 iterations in 3.6220 sec
Iteration  200, KL divergence 3.5841, 50 iterations in 3.4820 sec


2024-10-02 16:08:20.657 - INFO - Automatically determined negative gradient method `fft`


Iteration  250, KL divergence 3.5641, 50 iterations in 3.3632 sec
   --> Time elapsed: 17.26 seconds
===> Running optimization with exaggeration=1.00, lr=30074.00 for 250 iterations...
Iteration   50, KL divergence 1.7766, 50 iterations in 3.9921 sec
Iteration  100, KL divergence 1.5181, 50 iterations in 8.1890 sec
Iteration  150, KL divergence 1.4025, 50 iterations in 11.8351 sec
Iteration  200, KL divergence 1.3355, 50 iterations in 14.4091 sec
Iteration  250, KL divergence 1.2934, 50 iterations in 16.8835 sec
   --> Time elapsed: 55.31 seconds
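The two learning rates in the log are consistent with the step size being chosen automatically as N / exaggeration, with N = 30074 structures: 30074 / 12 ≈ 2506.17 during the early-exaggeration phase and 30074 / 1 = 30074 afterwards. A quick arithmetic check, treating the rule as an assumption inferred from the log rather than from the openTSNE source:

```python
n_samples = 30074  # number of encoded structures

# Assumed rule (inferred from the log): learning rate = N / exaggeration
lrs = {exaggeration: n_samples / exaggeration for exaggeration in (12.0, 1.0)}
for exaggeration, lr in lrs.items():
    print(f"exaggeration={exaggeration:.2f}, lr={lr:.2f}")
# exaggeration=12.00, lr=2506.17
# exaggeration=1.00, lr=30074.00
```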


In [15]:
save_path = os.path.join(data_dir, 'CATH_60_embedding_reproduced.pkl')
result_dict['embed_2d'] = embed_2d
with open(save_path, 'wb') as f:
    pickle.dump(result_dict, f)
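The next notebook reads this file back. Below is a minimal, self-contained sketch of the same pickle round trip, using a toy dictionary and a temporary file in place of the real `result_dict` and `save_path`; the field shapes are illustrative (32-dim latent codes, 2D t-SNE coordinates). Note that `embed_2d` returned by `tsne.fit` is an openTSNE `TSNEEmbedding`, a NumPy array subclass, so unpickling the real file presumably requires openTSNE to be importable; storing `np.asarray(embed_2d)` would avoid that dependency.

```python
import os
import pickle
import tempfile

import numpy as np

# Toy stand-in for result_dict (hypothetical contents, 4 structures)
toy_result = {
    "latent_mu": np.zeros((4, 32), dtype=np.float32),
    "key": ["139lA00", "16pkA01", "16pkA02", "1914A00"],
    "embed_2d": np.zeros((4, 2)),
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_embedding.pkl")
    with open(path, "wb") as f:
        pickle.dump(toy_result, f)
    # Read it back, as a downstream notebook would
    with open(path, "rb") as f:
        loaded = pickle.load(f)

# Every per-structure field should have one row per key
n = len(loaded["key"])
assert loaded["latent_mu"].shape == (n, 32) and loaded["embed_2d"].shape == (n, 2)
```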

We have now computed the embeddings and projected them into a 2D t-SNE space. In the next notebook, we visualize them together with human annotations (from the curators of the CATH database) and a range of intrinsic structural descriptors.