## ProToken 1.0 Single Chain Example

In [1]:
# ==============================================================================
# Copyright 2024 Changping Laboratory & Peking University. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and limitations under the License.
# ==============================================================================

### Load basic libraries

In [3]:
import os
import pickle as pkl

import jax
import numpy as np
import tensorflow as tf

from data_process.preprocess import (
    init_protoken_model,
    lddt,
    protoken_decoder_input_features,
    protoken_decoder_preprocess,
    protoken_encoder_input_features,
    protoken_encoder_preprocess,
    save_pdb_from_aux,
)

### A. Single Chain Protein Structure Encoding and Decoding.

#### 1. Prepare the task information

In [4]:
# --- Task configuration for the single-chain example ---

task_name = 'single_example'
task_mode = 'single'  # accepted values: 'single' or 'multi'
pdb_input_dir = './examples/single'

# Every output of this run is collected under ./results/<task_name>.
saving_dir = f'./results/{task_name}'
os.makedirs(saving_dir, exist_ok=True)

code_saving_path = os.path.join(saving_dir, 'protoken_index.pkl')
pdb_saving_path = os.path.join(saving_dir, 'reconstructed_protein.pdb')

# Notes:
# Three models are provided, one per sequence-length range: 0-512, 512-1024 and 1024-2048.
# Choose the model according to the sequence length of your protein.
# If a sequence length falls outside the current model's range, the model must be re-initialized.
# Have fun!

#### 2. Prepare the encoder inputs

In [5]:
# Turn the PDB file(s) found in `pdb_input_dir` into the encoder's input features,
# together with auxiliary data and the sequence length reused by later cells.
encoder_inputs, encoder_aux, seq_len = protoken_encoder_preprocess(
    pdb_input_dir,
    task_mode=task_mode,
)

# Report the shape of every encoder feature next to its name.
for feature_name, feature in zip(protoken_encoder_input_features, encoder_inputs):
    print(feature_name, feature.shape)

Found 1 pdb files in the input directory.
seq_mask (512,)
residue_index (512,)
backbone_atom_masks (512, 37)
backbone_atom_positions (512, 37, 3)
ca_pos (512, 3)
backbone_affine_tensor (512, 7)
torsion_angles_sin_cos (512, 6)
torsion_angles_mask (512, 3)


#### 3. Warm up the encoder and decoder

In [6]:
# Build the ProToken model for the current sequence length; later cells call its
# `encoder` and `decoder` members.
# NOTE(review): the models are described as covering fixed sequence-length ranges, so
# this cell presumably has to be re-run when a protein falls outside the current
# range — confirm against init_protoken_model.
model = init_protoken_model(seq_len)

Did not find GPU, will use CPU for prediction


2024-04-23 20:48:00.354811: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


#### 4. Encode the protein structure and get the ProToken Index

In [7]:
# Run the encoder on the preprocessed structure features.
encoder_results = model.encoder(*encoder_inputs)

# Keep only the tokens at positions where seq_mask is non-zero. A single boolean-mask
# selection replaces the per-position Python loop: same values, same order, but without
# indexing the encoder output one element at a time.
_valid = np.asarray(encoder_aux['seq_mask']).astype(bool)
protoken_index = np.asarray(encoder_results["protoken_index"])[:_valid.shape[0]][_valid]
print(f'ProToken Index: {protoken_index.shape}\n{protoken_index}')

# Persist the token indexes so they can be reused without re-running the encoder.
with open(code_saving_path, 'wb') as f:
    pkl.dump(protoken_index, f)

ProToken Index: (193,)
[258 384 416 294 324 454 324 227 127 104 342 100 373 381  92 215 487 403
  92 250 509 324 240 177 256 472 384  74 228 471  24 241 329 202 369 132
 458 487  47 333 151 267 231 483 133  51  28 132  32   0 362  78 493 220
  24  12 196 364 337 210 358 439 367 161 293 216 450 110 106 266 257 473
 495 291  46 503  92 328 214  48 384 360 146 266 476  50 297 185 241  50
  34 362 241 485 163 237 304  27 419 299  72  42 293 329 430  76 315 152
 481 268 315 123 361  59 194 262 372 248 130 268 425 109 256 118 386 264
 393 305 347 190 411 403 106 407 446  14  38 487 161 342 190 254  42 334
  49 125 187 466 143 457 324 439 109 161 456 163  30 161 415 440 151 170
 291 395 274  42 457 246  25  42 224 315 442 471 349 303 442 202 451 261
  38 272 165 230 466 168 434 247 450 411  52  95 264]


#### 5. Prepare the decoder inputs

In [8]:
# For a single chain, the decoder expects the ProToken indexes as one np.ndarray.
decoder_inputs = protoken_decoder_preprocess(protoken_index, task_mode=task_mode)

# Report the shape of every decoder feature next to its name.
for feature_name, feature in zip(protoken_decoder_input_features, decoder_inputs):
    print(f'{feature_name}: {feature.shape}')

protoken_index: (512,)
protoken_mask: (512,)
residue_index: (512,)


#### 6. Decode the ProToken Index and get the reconstructed protein structure

In [9]:
# Decode the ProToken indexes with the model's decoder.
decoder_results = model.decoder(*decoder_inputs)
# Convert the reconstructed coordinates to a NumPy array; the lDDT and PDB-export
# cells below index and cast it with NumPy operations.
reconstructed_atom_positions = np.asarray(decoder_results['reconstructed_atom_positions'])

#### 7. Compare the original and reconstructed protein structures

In [28]:
# Per-residue lDDT of the reconstruction against the original structure, computed on
# atom slot 1 of every residue (presumably the CA atom — TODO confirm against the
# atom ordering used by the preprocessing code). `lddt` is imported in the top cell.
# A leading batch axis is added for the call and stripped again with [0].
lDDT = lddt(reconstructed_atom_positions[None, ...][:, :, 1, :],
            encoder_aux['backbone_atom_positions'][None, ...][:, :, 1, :],
            encoder_aux['seq_mask'][None, ..., None], per_residue=True)[0]

# Average over the positions flagged in seq_mask. Selecting by mask (rather than
# slicing the first sum(mask) entries) does not assume the valid residues form a
# contiguous prefix and does not require the mask sum to be an integer.
_valid = np.asarray(encoder_aux['seq_mask']).astype(bool)
print(f"Average lDDT: {np.mean(np.asarray(lDDT)[_valid])}")

Average lDDT: 0.9680577494332036


#### 8. Save the reconstructed protein structure

In [50]:
# Collect the fields handed to the PDB writer. Residue types and the atom mask come
# from the original structure (encoder_aux); coordinates come from the decoder; the
# per-residue lDDT is stored in the "plddt" field.
# NOTE(review): the +1 on residue_index presumably converts 0-based indexes to
# 1-based PDB numbering — confirm.
partial_aux = dict(
    aatype=encoder_aux["aatype"].astype(np.int32),
    residue_index=decoder_inputs[-1].astype(np.int32) + 1,
    atom_positions=reconstructed_atom_positions.astype(np.float32),
    atom_mask=encoder_aux["backbone_atom_masks"].astype(np.float32),
    plddt=lDDT.astype(np.float32),
)
save_pdb_from_aux(partial_aux, pdb_saving_path)

# If encoder_aux is not available, the protein can still be saved
# with the following code instead:
# aatype_all_gly = np.asarray(decoder_inputs[1]).astype(np.int32)*7
# backbone_atom_mask = np.repeat(np.asarray([1,1,1,0,1]+[0]*32)[None,...], aatype_all_gly.shape[0], axis=0).astype(np.float32)*decoder_inputs[1][..., None]
# plddt = np.ones_like(aatype_all_gly).astype(np.float32)*99.99
# partial_aux = {"aatype": aatype_all_gly,
#                "residue_index": decoder_inputs[-1].astype(np.int32)+1,
#                "atom_positions": reconstructed_atom_positions.astype(np.float32),
#                "atom_mask": backbone_atom_mask,
#                "plddt": plddt}
# save_pdb_from_aux(partial_aux, pdb_saving_path)

In [32]:
# Final report: where the outputs were written and how good the reconstruction is.
print(f'PDB saved at: {pdb_saving_path}')
print(f'ProTokens saved at: {code_saving_path}')

# Average lDDT over the positions flagged in seq_mask. Selecting by mask avoids
# assuming that the valid residues occupy the first sum(mask) positions.
_valid = np.asarray(encoder_aux['seq_mask']).astype(bool)
print('Average lDDT:', round(np.mean(np.asarray(lDDT)[_valid]), 3), 'Seq_Len:', seq_len)
print('Job finished!\n')

PDB saved at: ./results/single_example/reconstructed_protein.pdb
ProTokens saved at: ./results/single_example/protoken_index.pkl
Average lDDT: 0.968 Seq_Len: 193
Job finished!

