In [14]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "False"
import sys
sys.path.append('..')

import jax
import jax.numpy as jnp
import pickle as pkl
import numpy as np
from tqdm import tqdm
import argparse

from functools import partial
from src.model.diffusion_transformer import DiffusionTransformer
from train.schedulers import GaussianDiffusion
import datetime
from flax.jax_utils import replicate
from functools import reduce

from configs.global_config import global_config
from configs.dit_config import dit_config
global_config.dropout_flag = False 

### Load embedding 
with open('../embeddings/protoken_emb.pkl', 'rb') as f:
    protoken_emb = jnp.array(pkl.load(f), dtype=jnp.float32)
with open('../embeddings/aatype_emb.pkl', 'rb') as f:
    aatype_emb = jnp.array(pkl.load(f), dtype=jnp.float32)

## Preparation

In [15]:
#### constants
NRES = 512
NSAMPLE_PER_DEVICE = 8
DIM_EMB_PTK = protoken_emb.shape[-1]
DIM_EMB_AA = aatype_emb.shape[-1]
DIM_EMB = DIM_EMB_PTK + DIM_EMB_AA # 40 # 32 + 8
NDEVICES = len(jax.devices())

BATCH_SIZE = NSAMPLE_PER_DEVICE * NDEVICES

In [16]:
#### function utils 

def split_multiple_rng_keys(rng_key, num_keys):
    rng_keys = jax.random.split(rng_key, num_keys + 1)
    return rng_keys[:-1], rng_keys[-1]

def flatten_list_of_dicts(list_of_dicts):
    ### [{a: [1,2,3,4]}] -> [{a:1}, {a:2}, {a:3}, {a:4}]
    flattened_lists = [[{k: v[i] for k, v in d.items()} 
                        for i in range(len(next(iter(d.values()))))] for d in list_of_dicts]
    return reduce(lambda x, y: x+y, flattened_lists, [])

def protoken_emb_distance_fn(x, y):
    x_ = x / (jnp.linalg.norm(x, axis=-1, keepdims=True) + 1e-6)
    y_ = y / (jnp.linalg.norm(y, axis=-1, keepdims=True) + 1e-6)
    
    return -jnp.sum(x_ * y_, axis=-1)

def aatype_emb_distance_fn(x, y):
    return jnp.sum((x - y) ** 2, axis=-1)

def aatype_index_to_resname(aatype_index):
    restypes = [
        'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
        'S', 'T', 'W', 'Y', 'V'
    ]
    
    return "".join([restypes[int(i)] for i in aatype_index])

def resname_to_aatype_index(resnames):
    restypes = [
        'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P',
        'S', 'T', 'W', 'Y', 'V'
    ]
    return np.array([restypes.index(a) for a in resnames], dtype=np.int32)

In [17]:
#### load model & params 
dit_model = DiffusionTransformer(
    config=dit_config, global_config=global_config
)
num_diffusion_timesteps = 500
scheduler = GaussianDiffusion(num_diffusion_timesteps=num_diffusion_timesteps)

#### rng keys
rng_key = jax.random.PRNGKey(8888)
np.random.seed(7777)

##### load params
ckpt_path = '../ckpts/PT_DiT_params_2000000.pkl'
with open(ckpt_path, "rb") as f:
    params = pkl.load(f)
    params = jax.tree_util.tree_map(lambda x: jnp.array(x), params)
    
##### replicate params
params = replicate(params)

In [18]:
##### main inference functions
jit_apply_fn = jax.jit(dit_model.apply)
infer_protuple = True

def clamp_x0_fn(x0):
    protoken_indexes = \
                jnp.argmin(protoken_emb_distance_fn(x0[..., None, :protoken_emb.shape[-1]], 
                                                  protoken_emb.reshape((1,)*(len(x0.shape)-1) + protoken_emb.shape)), axis=-1)
    if bool(infer_protuple):
        aatype_indexes = \
                jnp.argmin(aatype_emb_distance_fn(x0[..., None, protoken_emb.shape[-1]:], 
                                                  aatype_emb.reshape((1,)*(len(x0.shape)-1) + aatype_emb.shape)), axis=-1)
        return jnp.concatenate([protoken_emb[protoken_indexes], aatype_emb[aatype_indexes]], axis=-1)
    else:
        return protoken_emb[protoken_indexes]

def denoise_step(params, x, seq_mask, t, residue_index, rng_key,
                 clamp_x0_fn=None):
    t = jnp.full((x.shape[0],), t)
    indicator = params['params']['protoken_indicator']
    if bool(infer_protuple):
        indicator = jnp.concatenate([indicator, params['params']['aatype_indicator']], 
                                    axis=-1)
    eps_prime = jit_apply_fn({'params': params['params']['model']}, x + indicator[None, ...], 
                             seq_mask, t, tokens_rope_index=residue_index)
    mean, variance, log_variance = scheduler.p_mean_variance(x, t, eps_prime, clip=False, clamp_x0_fn=clamp_x0_fn)
    rng_key, normal_key = jax.random.split(rng_key)
    x = mean + jnp.exp(0.5 * log_variance) * jax.random.normal(normal_key, x.shape)
    return x, rng_key

def q_sample(x, t, rng_key):
    t = jnp.full((x.shape[0], ), t)
    rng_key, normal_key = jax.random.split(rng_key)
    eps = jax.random.normal(normal_key, x.shape, dtype=jnp.float32)
    x_t = scheduler.q_sample(x, t, eps)
    return x_t, rng_key

def noise_step(x, t, rng_key):
    t = jnp.full((x.shape[0], ), t)
    rng_key, normal_key = jax.random.split(rng_key)
    x = scheduler.q_sample_step(x, t, jax.random.normal(normal_key, x.shape))
    return x, rng_key

def index_from_embedding(x):
    # x: (B, Nres, Nemb)
    protoken_indexes = \
        jnp.argmin(protoken_emb_distance_fn(x[..., None, :protoken_emb.shape[-1]], 
                                            protoken_emb[None, None, ...]), axis=-1)
    ret = {'protoken_indexes': protoken_indexes}
    if bool(infer_protuple):
        aatype_indexes = \
            jnp.argmin(aatype_emb_distance_fn(x[..., None, protoken_emb.shape[-1]:], 
                                                aatype_emb[None, None, ...]), axis=-1)
        ret.update({'aatype_indexes': aatype_indexes})
        
    return ret            
    
pjit_denoise_step = jax.pmap(jax.jit(partial(denoise_step, clamp_x0_fn=None)), axis_name="i", 
                            in_axes=(0, 0, 0, None, 0, 0))
pjit_denoise_step_clamped = jax.pmap(jax.jit(partial(denoise_step, clamp_x0_fn=clamp_x0_fn)), axis_name="i", 
                            in_axes=(0, 0, 0, None, 0, 0))
pjit_q_sample = jax.pmap(jax.jit(noise_step), axis_name="i",
                            in_axes=(0, None, 0))
pjit_noise_step = jax.pmap(jax.jit(noise_step), axis_name="i",
                            in_axes=(0, None, 0))
pjit_index_from_embedding = jax.pmap(jax.jit(index_from_embedding), axis_name="i")

### RePaint Algorithm

In [19]:
def make_repaint_info(aatypes, protokens, aatype_context_ids, protoken_context_ids):
    protoken_context = protoken_emb[protokens]
    aatype_context = aatype_emb[aatypes]
    assert len(protoken_context) == len(aatype_context), 'seq_len mismatch: {} != {}'.format(len(protoken_context), len(aatype_context))
    seq_len = len(protoken_context)
    
    repaint_context = np.concatenate([protoken_context, aatype_context], axis=-1)
    repaint_mask_aa = np.array([[0,]*DIM_EMB_PTK+[1,]*DIM_EMB_AA if i in aatype_context_ids \
                                  else [0,]*DIM_EMB for i in range(seq_len)], 
                                dtype=np.bool_)
    repaint_mask_ptk = np.array([[1,]*DIM_EMB_PTK+[0,]*DIM_EMB_AA if i in protoken_context_ids \
                                  else [0,]*DIM_EMB for i in range(seq_len)], 
                                dtype=np.bool_)
            
    return repaint_context, np.logical_or(repaint_mask_aa, repaint_mask_ptk)

In [20]:
n_eq_steps = 50 ### more n eq steps -> higher quality, in RePaint, we recommand more n eq steps
phasing_time = 250 ### controls balance between diversity & quality, larger phasing time -> higer quality, lower diversity
def run_infer(x, seq_mask, residue_index, rng_keys, 
              repaint_context=None, repaint_mask=None, repaint_time_steps=np.arange(num_diffusion_timesteps)):
    for ti in tqdm(range(num_diffusion_timesteps)):
        t = num_diffusion_timesteps - ti
        for eq_step in range(n_eq_steps):
            denoise_fn = pjit_denoise_step if t > phasing_time else pjit_denoise_step_clamped
            x, rng_keys = denoise_fn(params, x, seq_mask, t, residue_index, rng_keys)
            x, rng_keys = pjit_noise_step(x, t, rng_keys)
            
            if repaint_context is not None and t in repaint_time_steps:
                repaint_context_ = repaint_context[..., t-1] if len(repaint_context.shape) > len(x.shape) \
                                    else repaint_context
                repaint_mask_ = repaint_mask[..., t-1] if len(repaint_mask.shape) > len(x.shape) \
                                    else repaint_mask
                
                repaint_context_t, rng_keys = pjit_q_sample(repaint_context_, t, rng_keys)
                x = repaint_mask_ * repaint_context_t + (1 - repaint_mask_) * x
            
        x, rng_keys = denoise_fn(params, x, seq_mask, t, residue_index, rng_keys)

    ret = {'embedding': x, 'seq_mask': seq_mask, 'residue_index': residue_index}
    ret.update(pjit_index_from_embedding(x))
    
    return ret

## (Contextual) Inverse Folding

In [33]:
### example: 8CYK

### obatin ProTokens
working_dir = '../'
os.system(f'''export PYTHONPATH={working_dir}/PROTOKEN
          python {working_dir}/PROTOKEN/scripts/infer_batch.py\
            --encoder_config {working_dir}/PROTOKEN/config/encoder.yaml\
            --decoder_config {working_dir}/PROTOKEN/config/decoder.yaml\
            --vq_config {working_dir}/PROTOKEN/config/vq.yaml\
            --pdb_dir_path {working_dir}/example_scripts/results/inverse_folding\
            --save_dir_path {working_dir}/example_scripts/results/inverse_folding\
            --load_ckpt_path {working_dir}/ckpts/protoken_params_100000.pkl''')

2025-01-14 09:49:01.690842: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.



Now inference pdbs at: ..//example_scripts/results/inverse_folding, 
Example: ..//example_scripts/results/inverse_folding/8CYK_B.pdb
Total validation pdb number: 1
Saving pdb at: ..//example_scripts/results/inverse_folding/stage_1/inference_pdbs, 
Example: ..//example_scripts/results/inverse_folding/stage_1/inference_pdbs/8CYK_B.pdb
	BATCH_SIZE: 16
pdb names in every device:  8CYK_B
pdb names in every device:  
pdb names in every device:  
pdb names in every device:  
pdb names in every device:  
pdb names in every device:  
pdb names in every device:  
pdb names in every device:  
Not enough data for sharding, extend the last data.
VALID_DATA_NUM: 1


2025-01-14 09:51:36.620354: E external/xla/xla/service/rendezvous.cc:38] This thread has been waiting for `initialize clique for rank 1; clique=devices=[0,1,2,3,4,5,6,7]; stream=0; run_id=1259322816` for 10 seconds and may be stuck. Expected 8 threads to join the rendezvous, but not all of them arrived on time.
2025-01-14 09:51:36.620425: E external/xla/xla/service/rendezvous.cc:38] This thread has been waiting for `initialize clique for rank 5; clique=devices=[0,1,2,3,4,5,6,7]; stream=0; run_id=1259322816` for 10 seconds and may be stuck. Expected 8 threads to join the rendezvous, but not all of them arrived on time.
2025-01-14 09:51:36.620464: E external/xla/xla/service/rendezvous.cc:38] This thread has been waiting for `initialize clique for rank 7; clique=devices=[0,1,2,3,4,5,6,7]; stream=0; run_id=1259322816` for 10 seconds and may be stuck. Expected 8 threads to join the rendezvous, but not all of them arrived on time.
2025-01-14 09:51:36.620496: E external/xla/xla/service/rendez

pdb saved at: ..//example_scripts/results/inverse_folding/stage_1/inference_pdbs/8CYK_B.pdb 
seq_len: 129
preprocessing time: 27.764188s
inference time: 80.605927s
code_usage: 0.0078125
total time: 108.370115s
seq_len_max: 129, seq_len_min: 129
stage_1 finished.
pdb, loss_aux, seq_len_dict, vq_code_indexes_dict saved at ..//example_scripts/results/inverse_folding/stage_1
All finished.


0

In [34]:
### load context
with open('./results/inverse_folding/stage_1/generator_inputs/8CYK_B.pkl', 'rb') as f:
    data_dict = pkl.load(f)

seq_len = len(data_dict['protokens'])
seq_mask = np.ones(seq_len, dtype=np.bool_)
residue_index = np.arange(seq_len, dtype=np.int32)
protoken_context = data_dict['protokens'].astype(np.int32)
aatype_context = data_dict['aatype'].astype(np.int32)

input_dict = {
    'seq_mask': seq_mask, 'residue_index': residue_index,
}

for k, v in input_dict.items(): print(k, v.shape, v.dtype)

seq_mask (129,) bool
residue_index (129,) int32


In [35]:
protoken_context_resids = np.arange(seq_len)
#### contextual inverse folding: put sequence context here
aatype_context_resids = [] 

repaint_context, repaint_mask = make_repaint_info(aatype_context, protoken_context, aatype_context_resids, protoken_context_resids)

repaint_dict = {
    'repaint_context': repaint_context, 'repaint_mask': repaint_mask.astype(np.float32)
}
for k, v in repaint_dict.items(): print(k, v.shape, v.dtype)

repaint_context (129, 40) float32
repaint_mask (129, 40) float32


### Run Infer

In [36]:
### preprocessing inputs

def reshape_tile_pad_x(x):
    x_shape = x.shape
    x = np.pad(x, ((0, NRES - x_shape[0]), ) + ((0,0),) * (len(x_shape) - 1))
    x_shape = x.shape
    
    x = np.tile(x[None, ...], (BATCH_SIZE, ) + (1, ) * len(x_shape))
    x = x.reshape(NDEVICES, NSAMPLE_PER_DEVICE, *x_shape)
    return x

input_dict = jax.tree.map(
    lambda x: jnp.array(reshape_tile_pad_x(x)), input_dict
)
repaint_dict = jax.tree.map(
    lambda x: jnp.array(reshape_tile_pad_x(x)), repaint_dict
)

init_key, rng_key = jax.random.split(rng_key)
x = jax.random.normal(init_key, (NDEVICES, NSAMPLE_PER_DEVICE, NRES, DIM_EMB))
input_dict['x'] = x

rng_keys, rng_key = split_multiple_rng_keys(rng_key, NDEVICES)
rng_keys = jnp.reshape(rng_keys, (NDEVICES, -1))

for k, v in input_dict.items(): print(k, v.shape, v.dtype)
for k, v in repaint_dict.items(): print(k, v.shape, v.dtype)

residue_index (8, 8, 512) int32
seq_mask (8, 8, 512) bool
x (8, 8, 512, 40) float32
repaint_context (8, 8, 512, 40) float32
repaint_mask (8, 8, 512, 40) float32


In [37]:
ret = run_infer(input_dict['x'], input_dict['seq_mask'], input_dict['residue_index'], rng_keys,
                repaint_dict['repaint_context'], repaint_dict['repaint_mask'])

ret = jax.tree_util.tree_map(lambda x: np.array(x.reshape(BATCH_SIZE, *x.shape[2:])).tolist(), ret)
with open('results/inverse_folding/result.pkl', 'wb') as f:
    pkl.dump(ret, f)
    
ret_ = flatten_list_of_dicts([ret])
with open('results/inverse_folding/result_flatten.pkl', 'wb') as f:
    pkl.dump(ret_, f)

100%|██████████| 500/500 [16:54<00:00,  2.03s/it]


In [42]:
for i, r in enumerate(ret_):
    protoken_idx = np.array(r['protoken_indexes'])[:seq_len]
    aatype_idx = np.array(r['aatype_indexes'])[:seq_len]
    print('seq{}: {}'.format(i, ''.join(aatype_index_to_resname(aatype_idx))))

seq0: METVVLKIKKGEELLAVLDYLYRSPDEDAIIELEIVDPKVKNLIAETIIMAAKELGADEEKIEALHNMINAKGVRDIYVKLKVKEEEYTIELVAELIDGRSAPITFTFNDKNGFLHILDLIRKQLKKLE
seq1: MAKREVMIDTPDEVIESLDHYRRSKSDVDEINFRFSDKQHKNLFIREVVEVAAKNDLPENQVKKMDEELKGPHTVDVDIEVYVKKDTFLITVRAVGKDGVVGSASFVASNEQHVERLAAAVGAAMDRAS
seq2: DRTIEMNVETANQLLEGVDRVYDLTEEDIIVKFDIEDIKGKLSLVEGLFRLVEELGADPLLVETLNGYLKSQDSIAVEITIRYEDGMFHLNVETHLDTNGVGGMLLTLENFDEARKFIRKIEKALSNLF
seq3: SKITKTEVKNTKELIFKLDSVCNSKSEIVTGIYLIDNNEIHISLGYTIIKNLNELENSNQKLSYMQERINSDDTNYLVLVAQRYNNNFCLQISINFKSGRTESVCYEFDNLEAFFQQMSAIAKSVEELG
seq4: MHTENKTVNAQQDLIILIRQIYKDSSPKDELSLEIDNDHVADLLANACAEVAKSQGMAKGTLDSLREGIESESTKHIKLDFRRSMQYYSLEFEFEYENDTKQKLVLTAPEEEKLRNLLTKLENLISKLK
seq5: SRFYIITVTTIEEVQETLLRLVSSTDEHLYIKFRIRDPELVRLIALCVRQALLDLRADEKLLGRLQEMLESFHYRDCCIELRKGIALYELDFLVNLRDGTKQTYMLHSADFSKVRGLIDKIAGKFRTLQ
seq6: SILFFREIHTVQEVTETLLRIFNQKSLAIDAEFRVSDKQLAVVLAVAIVETFKDMGFREVSLNHVRMFLDSVNFIDLGLEFKHTSTKLFIDIIAKAPDGETSKLSYSADDIDQFKSILLSIKSVIKGLE
seq7: MRNMRWHVRSRADLRDELAGLKASDADIQRVEVYIDHPTIAN

## Contextual Structure Design

In [25]:
### example: 5jxe_G_VH 

### obatin ProTokens
working_dir = '../'
os.system(f'''export PYTHONPATH={working_dir}/PROTOKEN
          python {working_dir}/PROTOKEN/scripts/infer_batch.py\
            --encoder_config {working_dir}/PROTOKEN/config/encoder.yaml\
            --decoder_config {working_dir}/PROTOKEN/config/decoder.yaml\
            --vq_config {working_dir}/PROTOKEN/config/vq.yaml\
            --pdb_dir_path {working_dir}/example_scripts/results/contextual_scaffolding\
            --save_dir_path {working_dir}/example_scripts/results/contextual_scaffolding\
            --load_ckpt_path {working_dir}/ckpts/protoken_params_100000.pkl''')

2025-01-13 21:14:25.706101: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.



Now inference pdbs at: ..//example_scripts/results/contextual_scaffolding, 
Example: ..//example_scripts/results/contextual_scaffolding/5jxe_G_VH.pdb
Total validation pdb number: 1
Saving pdb at: ..//example_scripts/results/contextual_scaffolding/stage_1/inference_pdbs, 
Example: ..//example_scripts/results/contextual_scaffolding/stage_1/inference_pdbs/5jxe_G_VH.pdb
	BATCH_SIZE: 16
pdb names in every device:  5jxe_G_VH
pdb names in every device:  
pdb names in every device:  
pdb names in every device:  
pdb names in every device:  
pdb names in every device:  
pdb names in every device:  
pdb names in every device:  
Not enough data for sharding, extend the last data.
VALID_DATA_NUM: 1


2025-01-13 21:16:46.899637: E external/xla/xla/service/rendezvous.cc:38] This thread has been waiting for `initialize clique for rank 6; clique=devices=[0,1,2,3,4,5,6,7]; stream=0; run_id=1259322816` for 10 seconds and may be stuck. Expected 8 threads to join the rendezvous, but not all of them arrived on time.
2025-01-13 21:16:46.899680: E external/xla/xla/service/rendezvous.cc:38] This thread has been waiting for `initialize clique for rank 3; clique=devices=[0,1,2,3,4,5,6,7]; stream=0; run_id=1259322816` for 10 seconds and may be stuck. Expected 8 threads to join the rendezvous, but not all of them arrived on time.
2025-01-13 21:16:46.899707: E external/xla/xla/service/rendezvous.cc:38] This thread has been waiting for `initialize clique for rank 4; clique=devices=[0,1,2,3,4,5,6,7]; stream=0; run_id=1259322816` for 10 seconds and may be stuck. Expected 8 threads to join the rendezvous, but not all of them arrived on time.
2025-01-13 21:16:46.899728: E external/xla/xla/service/rendez

pdb saved at: ..//example_scripts/results/contextual_scaffolding/stage_1/inference_pdbs/5jxe_G_VH.pdb 
seq_len: 218
preprocessing time: 24.317128s
inference time: 83.330078s
code_usage: 0.033203125
total time: 107.647206s
seq_len_max: 218, seq_len_min: 218
stage_1 finished.
pdb, loss_aux, seq_len_dict, vq_code_indexes_dict saved at ..//example_scripts/results/contextual_scaffolding/stage_1
All finished.


0

In [43]:
### load context
with open('./results/contextual_scaffolding/stage_1/generator_inputs/5jxe_G_VH.pkl', 'rb') as f:
    data_dict = pkl.load(f)

seq_len = len(data_dict['protokens'])
seq_mask = np.ones(seq_len, dtype=np.bool_)
residue_index = np.arange(seq_len, dtype=np.int32)
protoken_context = data_dict['protokens'].astype(np.int32)
aatype_context = data_dict['aatype'].astype(np.int32)

input_dict = {
    'seq_mask': seq_mask, 'residue_index': residue_index,
}

for k, v in input_dict.items(): print(k, v.shape, v.dtype)

seq_mask (218,) bool
residue_index (218,) int32


### Select CDR3 as context

In [44]:
import ast

def parse_annotation(annotation_file, aatype_indexes):
    seq_str = aatype_index_to_resname(aatype_indexes)
    with open(annotation_file, 'r') as f:
        contents = f.readlines()
        annotation_dict = ast.literal_eval(contents[0])
    annotation_resid_dict = {}
    for k, v in annotation_dict.items():
        start_id = seq_str.find(v)
        if start_id == -1:
            raise ValueError('can not find {} in {}'.format(v, seq_str))
        end_id = start_id + len(v)
        annotation_resid_dict[k] = np.arange(start_id, end_id)
        
    return annotation_resid_dict

In [45]:
annotation_resid_dict = parse_annotation('./results/contextual_scaffolding/5jxe_G_VH_annotation.txt', 
                                         aatype_context)

### select H-CDR3 as context
protoken_context_resids = annotation_resid_dict['H-CDR3']
aatype_context_resids = annotation_resid_dict['H-CDR3']

repaint_context, repaint_mask = make_repaint_info(aatype_context, protoken_context, aatype_context_resids, protoken_context_resids)

repaint_dict = {
    'repaint_context': repaint_context, 'repaint_mask': repaint_mask.astype(np.float32)
}
for k, v in repaint_dict.items(): print(k, v.shape, v.dtype)

repaint_context (218, 40) float32
repaint_mask (218, 40) float32


### Run Infer

In [47]:
### preprocessing inputs

def reshape_tile_pad_x(x):
    x_shape = x.shape
    x = np.pad(x, ((0, NRES - x_shape[0]), ) + ((0,0),) * (len(x_shape) - 1))
    x_shape = x.shape
    
    x = np.tile(x[None, ...], (BATCH_SIZE, ) + (1, ) * len(x_shape))
    x = x.reshape(NDEVICES, NSAMPLE_PER_DEVICE, *x_shape)
    return x

input_dict = jax.tree.map(
    lambda x: jnp.array(reshape_tile_pad_x(x)), input_dict
)
repaint_dict = jax.tree.map(
    lambda x: jnp.array(reshape_tile_pad_x(x)), repaint_dict
)

init_key, rng_key = jax.random.split(rng_key)
x = jax.random.normal(init_key, (NDEVICES, NSAMPLE_PER_DEVICE, NRES, DIM_EMB))
input_dict['x'] = x

rng_keys, rng_key = split_multiple_rng_keys(rng_key, NDEVICES)
rng_keys = jnp.reshape(rng_keys, (NDEVICES, -1))

for k, v in input_dict.items(): print(k, v.shape, v.dtype)
for k, v in repaint_dict.items(): print(k, v.shape, v.dtype)

residue_index (8, 8, 512) int32
seq_mask (8, 8, 512) bool
x (8, 8, 512, 40) float32
repaint_context (8, 8, 512, 40) float32
repaint_mask (8, 8, 512, 40) float32


In [48]:
ret = run_infer(input_dict['x'], input_dict['seq_mask'], input_dict['residue_index'], rng_keys,
                repaint_dict['repaint_context'], repaint_dict['repaint_mask'])

ret = jax.tree_util.tree_map(lambda x: np.array(x.reshape(BATCH_SIZE, *x.shape[2:])).tolist(), ret)
with open('results/contextual_scaffolding/result.pkl', 'wb') as f:
    pkl.dump(ret, f)
    
ret_ = flatten_list_of_dicts([ret])
with open('results/contextual_scaffolding/result_flatten.pkl', 'wb') as f:
    pkl.dump(ret_, f)

100%|██████████| 500/500 [16:47<00:00,  2.01s/it]


### Decode Structures

In [49]:
working_dir = '../'

os.system(f'''export PYTHONPATH={working_dir}/PROTOKEN
          python {working_dir}/PROTOKEN/scripts/decode_structure.py\
                --decoder_config {working_dir}/PROTOKEN/config/decoder.yaml\
                --vq_config {working_dir}/PROTOKEN/config/vq.yaml\
                --input_path results/contextual_scaffolding/result_flatten.pkl\
                --output_dir results/contextual_scaffolding/pdb\
                --load_ckpt_path {working_dir}/ckpts/protoken_params_100000.pkl\
                --padding_len {NRES}''')

2025-01-14 10:39:30.279075: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


decoding structures...


100%|██████████| 1/1 [01:05<00:00, 65.75s/it]


saving structures to .pdbs...
done


0