## ProToken 1.0 Multimer Example


In [1]:
# ==============================================================================
# Copyright 2024 Changping Laboratory & Peking University. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
# ==============================================================================

### Load basic libraries

In [1]:
import os
import pickle as pkl

import jax
import numpy as np
import tensorflow as tf

from data_process.preprocess import save_pdb_from_aux, protoken_encoder_preprocess, protoken_decoder_preprocess, init_protoken_model
from data_process.preprocess import protoken_encoder_input_features, protoken_decoder_input_features
from data_process.preprocess import lddt



### B. Multimer Structure Encoding and Decoding

#### 1. Prepare the task information

In [2]:
# Multimer example: every PDB file in `pdb_input_dir` is processed as one chain of the complex.

task_mode = 'multi'  # 'single' or 'multi'
assert task_mode in ('single', 'multi'), f"unknown task_mode: {task_mode!r}"
task_name = 'multimer_example'
pdb_input_dir = './examples/multimer'

# All artefacts of this run are written under ./results/<task_name>.
saving_dir = os.path.join('./results', task_name)
os.makedirs(saving_dir, exist_ok=True)

pdb_saving_path = os.path.join(saving_dir, 'reconstructed_protein.pdb')
code_saving_path = os.path.join(saving_dir, 'protoken_index.pkl')

# Notes:
# There are 3 models covering sequence-length ranges of 0-512, 512-1024 and 1024-2048.
# Pick the model according to the (total, for a multimer) sequence length of your protein;
# if the length goes beyond the current model's range, you need to re-initialize the model.
# Have fun!

#### 2. Prepare the encoder inputs

In [3]:
encoder_inputs, encoder_aux, seq_len = protoken_encoder_preprocess(pdb_input_dir, task_mode=task_mode)
# Report the (padded) shape of every encoder feature, in the order the encoder expects them.
for feature_name, feature_array in zip(protoken_encoder_input_features, encoder_inputs):
    print(feature_name, feature_array.shape)

Found 2 pdb files in the input directory.
seq_mask (1024,)
residue_index (1024,)
backbone_atom_masks (1024, 37)
backbone_atom_positions (1024, 37, 3)
ca_pos (1024, 3)
backbone_affine_tensor (1024, 7)
torsion_angles_sin_cos (1024, 6)
torsion_angles_mask (1024, 3)


In [5]:
# multimer auxiliary information
# Per-chain bookkeeping from the multimer preprocessing: source PDB name,
# chain length and start offset of the chain in the concatenated sequence.
chain_length_info = encoder_aux['chain_length_info']
chain_length_info

{0: {'pdb_name': '7W51_B.pdb', 'seq_len': 154, 'start_idx': 0},
 1: {'pdb_name': '7W51_A.pdb', 'seq_len': 361, 'start_idx': 154}}

#### 3. Initialize (warm up) the encoder and decoder

In [6]:
# Build the ProToken encoder/decoder from the total sequence length returned by the
# preprocessing step (presumably used to select the model for the matching length range — confirm).
model = init_protoken_model(seq_len)

2024-04-23 21:19:30.098989: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


Did not find GPU, will use CPU for prediction


2024-04-23 21:19:53.781014: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 4294967296 exceeds 10% of free system memory.
2024-04-23 21:19:54.894360: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 4294967296 exceeds 10% of free system memory.
2024-04-23 21:19:57.575232: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 4294967296 exceeds 10% of free system memory.
2024-04-23 21:19:58.284182: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 2147483648 exceeds 10% of free system memory.
2024-04-23 21:19:58.284224: W external/local_tsl/tsl/framework/cpu_allocator_impl.cc:83] Allocation of 2147483648 exceeds 10% of free system memory.


#### 4. Encode the protein structure and get the ProToken Index

In [7]:
# Encode the preprocessed structure features; the next cell reads the per-position
# "protoken_index" entry from the returned results.
encoder_results = model.encoder(*encoder_inputs)

In [10]:
# Keep only the ProToken indices of real (unpadded) residues: seq_mask flags which padded
# positions hold residues, so boolean indexing compacts the per-position indices.
valid_residue_mask = np.asarray(encoder_aux['seq_mask']).astype(bool)
protoken_index_ = np.asarray(encoder_results["protoken_index"])[:valid_residue_mask.shape[0]]
protoken_index_ = protoken_index_[valid_residue_mask]

# Split the compacted indices into one array per chain using the recorded offsets/lengths.
# Iterating the chain records directly (rather than indexing a list by dict key) does not
# assume the chain keys are exactly 0..n-1 in insertion order.
protoken_index_multimer = []
for chain_info in encoder_aux['chain_length_info'].values():
    start = chain_info['start_idx']
    chain_tokens = protoken_index_[start:start + chain_info['seq_len']]
    # Fail loudly instead of silently truncating if chain metadata and mask disagree.
    assert chain_tokens.shape[0] == chain_info['seq_len'], (
        f"{chain_info['pdb_name']}: expected {chain_info['seq_len']} tokens, "
        f"got {chain_tokens.shape[0]}")
    protoken_index_multimer.append(chain_tokens)

    print('PDB ID: ', chain_info['pdb_name'])
    print('Chain Length: ', chain_info['seq_len'])
    print(f'ProToken Index: {chain_tokens.shape}\n{chain_tokens}')
    # Store the tokens with the chain metadata so the pickle below is self-contained.
    chain_info['protoken_index'] = chain_tokens

# The saved file is a pickle: only load it back from trusted sources
# (pickle.load can execute arbitrary code).
with open(code_saving_path, 'wb') as f:
    pkl.dump(encoder_aux['chain_length_info'], f)

PDB ID:  7W51_B.pdb
Chain Length:  154
ProToken Index: (154,)
[ 65 471 223 368  25 167  29 336  87 155  49 109 125 268 486  30  97 324
 303 212 393 415 411 369  38  38  50 385 487 383 346 463  47 337 346 466
 254 415  12 227 454  18 161 346 268 333  80 277  26 174 111 301 354  82
 495 249 297  46 291 313 227 134  72 307 132  64 235  71 422 184 485 226
  34 272 507  30 509 507 361 261  60 257 240 210 466 293 354 459 391 230
 256 267 227  51 315 389 495  25 405 343 314 219 337 305 268  97 212   5
 254 367 449 215 131 294 106 319 249 461 262 347 368 369 354 471 315  47
 440  65 232 113 177  87 159  58 118 254 167 327 439 191 440 215 469  33
 206 379  25 507 504 216  15 140 202  91]
PDB ID:  7W51_A.pdb
Chain Length:  361
ProToken Index: (361,)
[307 154 463 345 496 481 146 361  54 176  16  72 328 259 483 395 479 346
 385 342 504 224   7 215 348 113 478 470 346 414 346  50 449 200 300 141
 367 139 274 492 265  58  26 240 257  95 369 472 216 241 350 459 427 293
  51 389 442  38 404 109 393 39

#### 5. Prepare the decoder inputs

In [11]:
# The multimer decoder takes a list with one np.ndarray of ProToken indices per chain.
decoder_inputs = protoken_decoder_preprocess(protoken_index_multimer, task_mode=task_mode)
for feature_name, feature_array in zip(protoken_decoder_input_features, decoder_inputs):
    print(f'{feature_name}: {feature_array.shape}')

protoken_index: (1024,)
protoken_mask: (1024,)
residue_index: (1024,)


#### 6. Decode the ProToken Index and get the reconstructed protein structure

In [12]:
# Decode the ProToken indices back into coordinates and bring them into NumPy.
decoder_results = model.decoder(*decoder_inputs)
positions_key = 'reconstructed_atom_positions'
reconstructed_atom_positions = np.asarray(decoder_results[positions_key])

#### 7. Compare the original and reconstructed protein structures

In [13]:
# Per-residue lDDT between reconstructed and original coordinates, computed on atom slot 1
# of the per-residue atom axis (presumably C-alpha in the 37-atom layout — confirm).
CA_ATOM_SLOT = 1
lDDT = lddt(reconstructed_atom_positions[None, :, CA_ATOM_SLOT, :],
            encoder_aux['backbone_atom_positions'][None, :, CA_ATOM_SLOT, :],
            encoder_aux['seq_mask'][None, ..., None], per_residue=True)[0]

# Average only over real residues. A boolean mask does not assume the valid residues
# form a contiguous prefix of the padded arrays (which `[:np.sum(mask)]` did).
valid_residue_mask = np.asarray(encoder_aux['seq_mask']).astype(bool)
print(f"Average lDDT: {np.mean(np.asarray(lDDT)[valid_residue_mask])}")

Average lDDT: 0.9068252570445748


#### 8. Save the reconstructed protein structure

In [18]:
# Assemble the fields save_pdb_from_aux needs and write the reconstructed structure.
# SAVE_WITH_ENCODER_AUX = True reuses residue types and atom masks from encoder_aux.
# Set it to False to save the protein without encoder_aux. In that case every residue is
# written as residue type 7 (all-glycine, per the original `aatype_all_gly` naming) with a
# fixed backbone atom mask.
SAVE_WITH_ENCODER_AUX = True

# decoder_inputs follows protoken_decoder_input_features: (protoken_index, protoken_mask, residue_index).
# +1 presumably converts to 1-based PDB residue numbering.
pdb_residue_index = decoder_inputs[-1].astype(np.int32) + 1

if SAVE_WITH_ENCODER_AUX:
    partial_aux = {"aatype": encoder_aux["aatype"].astype(np.int32),
                   "residue_index": pdb_residue_index,
                   "atom_positions": reconstructed_atom_positions.astype(np.float32),
                   "atom_mask": encoder_aux["backbone_atom_masks"].astype(np.float32),
                   "plddt": lDDT.astype(np.float32)}
else:
    protoken_mask = np.asarray(decoder_inputs[1])
    aatype_all_gly = protoken_mask.astype(np.int32) * 7
    # Keep atom slots 0, 1, 2 and 4 (presumably N, CA, C, O) and drop slot 3 (presumably CB);
    # padded positions (mask 0) get no atoms. Cast last so the mask stays float32.
    backbone_atom_pattern = np.asarray([1, 1, 1, 0, 1] + [0] * 32, dtype=np.float32)
    backbone_atom_mask = (backbone_atom_pattern[None, :] * protoken_mask[:, None]).astype(np.float32)
    # NOTE(review): the encoder_aux branch stores lDDT as a fraction (~0.9 here) while this
    # branch uses 99.99; confirm which scale save_pdb_from_aux expects for "plddt".
    plddt = np.full(aatype_all_gly.shape, 99.99, dtype=np.float32)
    partial_aux = {"aatype": aatype_all_gly,
                   "residue_index": pdb_residue_index,
                   "atom_positions": reconstructed_atom_positions.astype(np.float32),
                   "atom_mask": backbone_atom_mask,
                   "plddt": plddt}

save_pdb_from_aux(partial_aux, pdb_saving_path)

In [19]:
# Final report: output locations plus reconstruction quality over the valid residues.
n_valid_residues = np.sum(encoder_aux['seq_mask'])
mean_lddt_rounded = round(np.mean(lDDT[:n_valid_residues]), 3)

print(f'PDB saved at: {pdb_saving_path}')
print(f'ProTokens saved at: {code_saving_path}')
print('Average lDDT:', mean_lddt_rounded, 'Seq_Len:', seq_len)
print(f'Job finished!\n')

PDB saved at: ./results/multimer_example/reconstructed_protein.pdb
ProTokens saved at: ./results/multimer_example/protoken_index.pkl
Average lDDT: 0.907 Seq_Len: 515
Job finished!

