In [1]:
import sys
sys.path.append('../../carbonmatrix_public')

In [2]:
import os
import argparse
import logging
from logging.handlers import QueueHandler, QueueListener
import resource
import json
from random import shuffle

import ml_collections
import numpy as np
import torch
import torch.multiprocessing as mp
from einops import rearrange
from collections import OrderedDict
from carbondesign.data.pdbio import save_pdb

from carbondesign.model.carbondesign import CarbonDesign
from carbondesign.testloader import dataset_test
from carbondesign.common.utils import index_to_str_seq

In [3]:
def worker_device(rank, args):
    """Return the torch device that worker ``rank`` should run on.

    When ``args.device`` equals ``'gpu'`` the CUDA device with index
    ``args.gpu_idx[rank]`` is returned; any other value selects the CPU.
    """
    if args.device != 'gpu':
        return torch.device('cpu')
    cuda_index = args.gpu_idx[rank]
    return torch.device(f'cuda:{cuda_index}')

def worker_load(rank, args):
    """Load the feature pipeline description and the CarbonDesign model.

    Args:
        rank: worker index, forwarded to ``worker_device``.
        args: namespace providing ``model_config`` (path to the JSON model
            config), ``model`` (path to the checkpoint), ``model_features``
            (path to the JSON feature list) and the fields read by
            ``worker_device``.

    Returns:
        ``(feats, model)`` where ``feats`` is a list of ``(name, options)``
        tuples with every ``'device'`` option replaced by the worker device,
        and ``model`` is the network moved to that device in eval mode.
    """
    device = worker_device(rank, args)

    # model
    with open(args.model_config, 'r', encoding='utf-8') as f:
        config = ml_collections.ConfigDict(json.load(f))

    checkpoint = torch.load(args.model, map_location=device)
    model = CarbonDesign(config=config.model)

    # Strip a leading 'module.' from parameter names
    # (presumably left by a DataParallel/DDP wrapper at training time).
    prefix = 'module.'
    model_state_dict = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint['model'].items())

    # strict=False tolerates mismatched keys; report them instead of hiding
    # a checkpoint that does not actually match the model.
    load_result = model.load_state_dict(model_state_dict, strict=False)
    if load_result.missing_keys:
        logging.warning('%d model parameters are missing from checkpoint %s',
                        len(load_result.missing_keys), args.model)
        logging.debug('missing keys: %s', load_result.missing_keys)
    if load_result.unexpected_keys:
        logging.warning('%d checkpoint entries are not used by the model (%s)',
                        len(load_result.unexpected_keys), args.model)
        logging.debug('unexpected keys: %s', load_result.unexpected_keys)

    with open(args.model_features, 'r', encoding='utf-8') as f:
        feats = json.load(f)

    # Every feature that declares a 'device' option (including the
    # '%(device)s' placeholder) gets the actual worker device.
    resolved_feats = []
    for fn, opts in feats:
        if 'device' in opts:
            opts['device'] = device
        resolved_feats.append((fn, opts))

    model = model.to(device=device)
    model.eval()
    return resolved_feats, model


def softmax(x):
    """Softmax over the last axis of ``x``.

    The per-row maximum is subtracted before exponentiating so that large
    inputs do not overflow.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=-1, keepdims=True)

def one_hot(a, num_classes=21):
    """One-hot encode the integer array ``a``.

    ``a`` is flattened first, so the result has one row of length
    ``num_classes`` per element. Axes of length one are then squeezed away,
    which means a single-element input yields a 1-D vector.
    """
    flat_indices = a.reshape(-1)
    identity = np.eye(num_classes)
    return np.squeeze(identity[flat_indices])

def mrf_score(label, site_repr, pair_repr, site_mask, pair_mask):
    """Score a labelling under the site and pair potentials.

    The score is the masked sum of ``site_repr[i, label[i]]`` plus half the
    masked sum of ``pair_repr[i, j, label[i], label[j]]`` over all (i, j).

    Args:
        label: integer class index per site, length N.
        site_repr: array of shape (N, C).
        pair_repr: array reshapeable to (N, N, C, C).
        site_mask: per-site weights/mask, length N.
        pair_mask: per-pair weights/mask, shape (N, N).

    Returns:
        The scalar score as a float64 numpy value.
    """
    N, C = site_repr.shape
    pair_repr = np.reshape(pair_repr, [N, N, C, C])
    label = np.asarray(label).reshape(-1)
    pos = np.arange(N)

    # Index the chosen classes directly instead of multiplying by one-hot
    # tensors: no N*N*C*C temporary, works for N == 1 and for any C.
    site_term = np.sum(site_repr[pos, label] * site_mask, dtype=np.float64)
    chosen_pairs = pair_repr[pos[:, None], pos[None, :],
                             label[:, None], label[None, :]]
    pair_term = np.sum(chosen_pairs * pair_mask, dtype=np.float64)

    return site_term + pair_term / 2.0

def temp_softmax(z, T):
    """Softmax of ``z / T`` normalised over all elements of ``z``.

    Args:
        z: array of scores.
        T: temperature the scores are divided by.

    Returns:
        float64 array of the same shape as ``z`` whose entries sum to one.
    """
    scaled = np.asarray(z, dtype=np.float64) / T
    if scaled.size == 0:
        return scaled
    # Subtracting the maximum leaves the result unchanged mathematically but
    # keeps exp() from overflowing to inf (and the output from becoming NaN)
    # when scores are large relative to T.
    exp_z = np.exp(scaled - np.max(scaled))
    return exp_z / np.sum(exp_z)

def infer_mrf(init_label, site_repr, pair_repr, site_mask, pair_mask, T, max_cycles=5):
    """Refine a labelling by coordinate-wise updates on the site/pair potentials.

    Sites are visited in ascending order of their number of masked pairs.
    For each unmasked site the score of every candidate class ``c`` is
    ``site_repr[i, c]`` plus the sum over masked partners ``k`` of
    ``pair_repr[i, k, c, label[k]]``. The last class (index ``C - 1``) is
    never a candidate.

    Args:
        init_label: initial integer class per site, length N (not modified).
        site_repr: array of shape (N, C).
        pair_repr: array reshapeable to (N, N, C, C).
        site_mask: per-site mask; falsy sites keep their initial label.
        pair_mask: (N, N) mask selecting which pairs contribute.
        T: temperature. ``T >= 0.1`` samples the new label from
            ``temp_softmax(scores, T)``; smaller values take the argmax.
        max_cycles: maximum number of sweeps over all sites.

    Returns:
        Integer array with the refined label per site.
    """
    N, C = site_repr.shape
    pair_repr = np.reshape(pair_repr, [N, N, C, C])
    deg = np.sum(pair_mask, axis=-1)

    prev_label = np.array(init_label)
    pos = np.argsort(deg)

    for _ in range(max_cycles):
        updated_count = 0
        for i in pos:
            if not site_mask[i]:
                continue

            # Start from the site term so that sites without any masked
            # pair are still scored by their own logits.
            scores = np.array(site_repr[i, :C - 1], dtype=np.float64)
            neighbours = np.nonzero(pair_mask[i])[0]
            if neighbours.size:
                # coupling[n, c] == pair_repr[i, k_n, c, prev_label[k_n]]
                coupling = pair_repr[i][neighbours, :, prev_label[neighbours]]
                scores += coupling[:, :C - 1].sum(axis=0, dtype=np.float64)

            if T >= 0.1:
                probs = temp_softmax(scores, T)
                obj_c = np.argmax(np.random.multinomial(1, probs))
            else:
                obj_c = np.argmax(scores)

            if prev_label[i] != obj_c:
                prev_label[i] = obj_c
                updated_count += 1
        # A full sweep without changes means the labelling has converged.
        if updated_count == 0:
            break
    return prev_label

def replace_with_x(seq_str, mask):
    """Return ``seq_str`` with every character whose mask entry is falsy replaced by 'X'.

    Characters and mask entries are paired with ``zip``, so the result is as
    long as the shorter of the two inputs.
    """
    return ''.join(
        residue if keep else 'X'
        for residue, keep in zip(seq_str, mask))

def evaluate_mrf_one(name, gt_str_seq, site_repr, pair_repr, site_mask, pair_mask, args, sidechain_all):
    """Decode one sequence from the site/pair logits and write the results.

    The initial labelling is the per-site argmax of ``softmax(site_repr)``,
    refined with ``infer_mrf`` at ``args.temperature``. The designed
    sequence (masked sites shown as 'X') is printed and written to
    ``<args.output_dir>/<name>.fasta``.

    Args:
        name: sample name, used for output file names and the FASTA header.
        gt_str_seq: ground-truth sequence; currently unused, kept so the
            call signature stays stable.
        site_repr: per-site logits, shape (N, C).
        pair_repr: pair logits, reshapeable to (N, N, C, C).
        site_mask: per-site boolean mask.
        pair_mask: (N, N) pair mask.
        args: namespace with ``temperature``, ``output_dir``, ``save_mrf``
            and ``save_sidechain``.
        sidechain_all: atom positions indexed as
            ``[class, residue, atom, xyz]`` (see the reshape in
            ``evaluate_mrf_batch``).
    """
    site_prob = softmax(site_repr)
    label = np.argmax(site_prob, axis=-1)

    label = infer_mrf(label, site_repr, pair_repr, site_mask, pair_mask, args.temperature)
    pred_str_seq2 = replace_with_x(index_to_str_seq(label), site_mask)

    # Masked sites get class 20 (presumably the unknown residue type; confirm
    # against the vocabulary used by index_to_str_seq).
    valid = np.asarray(site_mask, dtype=bool)
    label[~valid] = 20

    # Pick, for every residue i, the coordinates predicted for its own class
    # label[i]. The previous code used label[0] for all residues.
    # Assumes the residue axis of sidechain_all has the same length as label.
    sidechain_all = np.asarray(sidechain_all)
    L = sidechain_all.shape[1]
    sidechain_single = sidechain_all[label, np.arange(L)]

    record = f'>{name}\tCarbonDesign\n{pred_str_seq2}\n'
    print(record)
    with open(os.path.join(args.output_dir, name + '.fasta'), 'w') as fw:
        fw.write(record)

    if args.save_mrf:
        np.savez(os.path.join(args.output_dir, name + '.mrf.npz'),
                 site_repr=site_repr,
                 pair_repr=pair_repr,
                 site_mask=site_mask,
                 pair_mask=pair_mask)
    if args.save_sidechain:
        pdb_file = os.path.join(args.output_dir, f'{name}_sidechain.pdb')
        sequence = [pred_str_seq2]
        chains = 'A'
        save_pdb(sequence, sidechain_single, pdb_file, chains)

def evaluate_mrf_batch(batch, ret, args):
    """Run ``evaluate_mrf_one`` on every sample of a batch.

    Logits, masks and side-chain atom positions are moved to the CPU as
    numpy arrays. The atom positions are reshaped to
    ``(-1, 21, L, 14, 3)`` with ``L`` taken from their second axis, and the
    per-sample slices are handed to ``evaluate_mrf_one``.
    """
    heads = ret['heads']

    gt_str_seq = batch['str_seq']
    site_repr = heads['seqhead']['logits'].to('cpu').numpy()
    pair_repr = heads['pairhead']['logits'].to('cpu').numpy()
    sidechain_all = heads['folding']['sidechains']['atom_pos'].cpu().numpy()

    site_mask = batch['mask'].to('cpu').numpy()
    pair_mask = batch['pair_mask'].to('cpu').numpy()
    names = batch['name']

    seq_len = sidechain_all.shape[1]
    sidechain_all = sidechain_all.reshape(-1, 21, seq_len, 14, 3)

    per_sample = zip(names, gt_str_seq, site_repr, pair_repr,
                     site_mask, pair_mask, sidechain_all)
    for sample_name, sample_gt, sample_site, sample_pair, s_mask, p_mask, sample_sc in per_sample:
        evaluate_mrf_one(sample_name, sample_gt, sample_site, sample_pair,
                         s_mask, p_mask, args, sample_sc)

def evaluate(rank, log_queue, args):
    """Run inference and MRF decoding over every sample listed in ``args.name_idx``.

    Args:
        rank: worker index used to pick the device and load the model.
        log_queue: unused here (the worker logging setup is disabled); kept
            so the signature matches the multiprocessing entry point.
        args: run configuration namespace (paths, batch size, temperature, ...).

    Batches that are ``None`` or longer than 600 residues are skipped. A
    failure on one batch is logged with its traceback and the loop continues
    with the next batch.
    """
    #worker_setup(rank, log_queue, args)

    feats, model = worker_load(rank, args)

    with open(args.name_idx) as f:
        name_idx = [x.strip() for x in f]

    test_loader = dataset_test.load(
        data_dir=args.data_dir,
        name_idx=name_idx,
        feats=feats,
        is_cluster_idx=False,
        rank=None,
        world_size=1,
        batch_size=args.batch_size)

    for batch in test_loader:
        if batch is None:
            continue
        try:
            logging.debug('name: %s', ','.join(batch['name']))
            logging.debug('len : %s', batch['seq'].shape[1])
            logging.debug('seq : %s', batch['str_seq'][0])
            # Skip long inputs.
            if batch['seq'].shape[1] > 600:
                continue
            with torch.no_grad():
                ret = model(batch=batch, compute_loss=True)

            evaluate_mrf_batch(batch, ret, args)
        except Exception:
            # `except Exception` lets KeyboardInterrupt/SystemExit propagate;
            # the '%s' placeholder is required for the name to be formatted.
            logging.exception('fails in predicting %s', batch['name'])


In [4]:
# Run configuration. It is wrapped in an argparse.Namespace in the next cell
# so that the helper functions can read it as ``args.<name>``.
# NOTE(review): every path below is absolute and machine-specific; adjust
# them before running this notebook anywhere else.
args_dict ={
    'save_mrf': False,  # when True, evaluate_mrf_one also writes <name>.mrf.npz
    'save_sidechain': False,  # when True, evaluate_mrf_one also writes <name>_sidechain.pdb
    'model': '/data/users/kgeorge/workspace/CarbonDesign/params/carbondesign_default.ckpt',
    'model_features': '/nethome/kgeorge/workspace/DomainPrediction/carbonmatrix_public/config/config_data_mrf2.json',
    'model_config': '/nethome/kgeorge/workspace/DomainPrediction/carbonmatrix_public/config/config_model_mrf_pair_enable_esm_sc.json',
    'name_idx': '/nethome/kgeorge/workspace/DomainPrediction/carbonmatrix_public/data/pdbs/name.idx', # pdb file names
    'data_dir': '/nethome/kgeorge/workspace/DomainPrediction/carbonmatrix_public/data/pdbs', # pdb dir
    'output_dir': '/nethome/kgeorge/workspace/DomainPrediction/src/CarbonDesign', # results
    'temperature': 0.01,  # below the 0.1 threshold in infer_mrf, so decoding takes the argmax
    'batch_size': 3,
    'verbose': True,  # selects DEBUG instead of INFO logging in the setup cell
    'gpu_idx': [0],  # CUDA index per worker rank; only read when 'device' is 'gpu'
    'map_location': None,  # NOTE(review): not read by any code visible in this notebook
    'device': 'cpu',  # 'gpu' selects cuda:<gpu_idx[rank]>; anything else runs on the CPU
    'ipc_file': 'test.ipc'  # NOTE(review): not read by any code visible in this notebook
}

In [5]:
# Expose the configuration with attribute access (args.device, args.model, ...),
# which is how the helper functions above read it.
args = argparse.Namespace(**args_dict)

In [6]:
mp.set_start_method('spawn', force=True)
os.makedirs(os.path.abspath(args.output_dir), exist_ok=True)

# A notebook has no __file__ of its own; define one so the log file name can
# be derived from its basename.
__file__ = 'run'
log_stem = os.path.splitext(os.path.basename(__file__))[0]
log_path = os.path.join(args.output_dir, f'{log_stem}.log')
handlers = [logging.StreamHandler(), logging.FileHandler(log_path)]

def handler_apply(h, f, *arg):
    """Call ``f(*arg)`` for its side effect and return ``h``."""
    f(*arg)
    return h

level = logging.DEBUG if args.verbose else logging.INFO
fmt = '%(asctime)-15s [%(levelname)s] (%(filename)s:%(lineno)d) %(message)s'

# Give the console and file handlers the same level and message format.
for log_handler in handlers:
    log_handler.setLevel(level)
    log_handler.setFormatter(logging.Formatter(fmt))

logging.basicConfig(
        format=fmt,
        level=level,
        handlers=handlers)

# Queue and listener that forward queued log records to the same handlers.
log_queue = mp.Queue(-1)
listener = QueueListener(log_queue, *handlers, respect_handler_level=True)
listener.start()

# evaluate(0, log_queue, args)

In [7]:
## evaluate function
feats, model = worker_load(0, args)
device = worker_device(0, args)
name_idx = []
with open(args.name_idx) as f:
    name_idx = [x.strip() for x in f]

In [8]:
name_idx

['T1187']

In [9]:
# Build the test data loader over the listed samples, using the same
# arguments as evaluate().
test_loader = dataset_test.load(
        data_dir=args.data_dir,
        name_idx=name_idx,
        feats=feats,
        is_cluster_idx=False,
        rank=None,
        world_size=1,
        batch_size=args.batch_size)

2024-06-28 15:47:07,683 [INFO] (dataset_test.py:86) dataset size= 1 max_seq_len= None reduce_num= None is_cluster_idx= False


In [10]:
# Pull only the first batch out of the loader for interactive inspection.
for batch in test_loader:
    break

2024-06-28 15:47:07,689 [INFO] (dataset_test.py:172) processing /nethome/kgeorge/workspace/DomainPrediction/carbonmatrix_public/data/pdbs/T1187.pdb


/nethome/kgeorge/workspace/DomainPrediction/carbonmatrix_public/data/pdbs
T1187


In [11]:
batch.keys()

dict_keys(['name', 'str_seq', 'seq', 'mask', 'aatype_unk_mask', 'atom14_gt_positions', 'atom14_gt_exists', 'chain_id', 'geo_global', 'data_type', 'atom14_atom_exists', 'atom14_atom_is_ambiguous', 'residx_atom37_to_atom14', 'atom37_atom_exists', 'atom14_alt_gt_positions', 'atom14_alt_gt_exists', 'atom37_gt_positions', 'atom37_gt_exists', 'rigidgroups_gt_frames', 'rigidgroups_gt_exists', 'rigidgroups_group_exists', 'rigidgroups_group_is_ambiguous', 'rigidgroups_alt_gt_frames', 'torsion_angles_sin_cos', 'alt_torsion_angles_sin_cos', 'torsion_angles_mask', 'pseudo_beta', 'pseudo_beta_mask', 'pair_mask', 'dist_one_hot', 'left_gt_calpha3_frame_positions', 'right_gt_calpha3_frame_positions', 'left_gt_calpha3_frame_position_exists', 'right_gt_calpha3_frame_position_exists', 'left_forth_atom_rel_pos'])

In [12]:
# Mask the first 20 residues of the first sample, in place.
# Downstream, masked sites are skipped by infer_mrf and shown as 'X' by replace_with_x.
batch['mask'][0][:20] = False

In [13]:
batch['mask']

tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  

In [14]:
batch['str_seq']

('GGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG',)

In [15]:
batch['seq']

tensor([[7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]])

In [16]:
### recycle and esm models are in the model pred

In [18]:
# Run the model on the first usable batch only, so that `batch` and `ret`
# can be inspected in the cells below.
for i, batch in enumerate(test_loader):
    if batch is None:
        continue
    try:
        logging.debug('name: %s', ','.join(batch['name']))
        logging.debug('len : %s', batch['seq'].shape[1])
        logging.debug('seq : %s', batch['str_seq'][0])
        if batch['seq'].shape[1] > 600:
            continue

        # Mask the first 20 residues of the first sample before the forward pass.
        batch['mask'][0][:20] = False
        with torch.no_grad():
            ret = model(batch=batch, compute_loss=True)

        # evaluate_mrf_batch(batch, ret, args) is intentionally not called
        # here; the decoding steps are reproduced by hand further down.
    except Exception:
        # `except Exception` lets KeyboardInterrupt propagate; the '%s'
        # placeholder is required for the batch name to be formatted.
        logging.exception('fails in predicting %s', batch['name'])

    break

2024-06-28 15:48:07,034 [INFO] (dataset_test.py:172) processing /nethome/kgeorge/workspace/DomainPrediction/carbonmatrix_public/data/pdbs/T1187.pdb
2024-06-28 15:48:07,122 [DEBUG] (1993038441.py:5) name: T1187
2024-06-28 15:48:07,123 [DEBUG] (1993038441.py:6) len : 164
2024-06-28 15:48:07,123 [DEBUG] (1993038441.py:7) seq : GGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGGG


/nethome/kgeorge/workspace/DomainPrediction/carbonmatrix_public/data/pdbs
T1187
dict_keys(['name', 'str_seq', 'seq', 'mask', 'aatype_unk_mask', 'atom14_gt_positions', 'atom14_gt_exists', 'chain_id', 'geo_global', 'data_type', 'atom14_atom_exists', 'atom14_atom_is_ambiguous', 'residx_atom37_to_atom14', 'atom37_atom_exists', 'atom14_alt_gt_positions', 'atom14_alt_gt_exists', 'atom37_gt_positions', 'atom37_gt_exists', 'rigidgroups_gt_frames', 'rigidgroups_gt_exists', 'rigidgroups_group_exists', 'rigidgroups_group_is_ambiguous', 'rigidgroups_alt_gt_frames', 'torsion_angles_sin_cos', 'alt_torsion_angles_sin_cos', 'torsion_angles_mask', 'pseudo_beta', 'pseudo_beta_mask', 'pair_mask', 'dist_one_hot', 'left_gt_calpha3_frame_positions', 'right_gt_calpha3_frame_positions', 'left_gt_calpha3_frame_position_exists', 'right_gt_calpha3_frame_position_exists', 'left_forth_atom_rel_pos', 'prev_seq', 'prev_pair'])
tensor([[False, False, False, False, False, False, False, False, False, False,
         Fa

In [19]:
ret.keys()

dict_keys(['representations', 'heads'])

In [20]:
ret['representations'].keys()

dict_keys(['pair', 'seq'])

In [21]:
ret['representations']['seq']

tensor([[[  0.2355,   2.2536,   0.3902,  ...,  -0.2701,  -3.4855,   0.8383],
         [  7.1885,  -2.7726,  -1.9866,  ...,   2.0849,  -0.0588,  -0.7167],
         [  0.1565,  -0.9102,  -1.9622,  ...,  -2.6089,   2.9208,  -0.8299],
         ...,
         [  5.4769,  -6.0417,   5.4953,  ...,  -1.0085,  -2.4838,  -2.6021],
         [ -2.6412,  10.9062,   9.0635,  ...,  -2.0565,  -9.3781,  15.2025],
         [  7.9768, -13.4142, -14.7368,  ...,  -7.0520,   7.8579, -32.3924]]])

In [22]:
ret['representations']['pair'].shape, ret['representations']['seq'].shape

(torch.Size([1, 164, 164, 128]), torch.Size([1, 164, 384]))

In [23]:
ret['heads'].keys()

dict_keys(['folding', 'seqhead', 'pairhead'])

In [24]:
print(ret['heads']['folding'].keys())
print(ret['heads']['folding']['sidechains'].keys())
print(ret['heads']['folding']['sidechains']['atom_pos'].shape)
print(ret['heads']['folding']['representations'].keys())
print(ret['heads']['folding']['final_atom14_positions'].shape)
print(ret['heads']['folding']['final_atom14_positions'].shape)

dict_keys(['sidechains', 'traj', 'representations', 'final_atom14_positions', 'final_atom_positions'])
dict_keys(['angles_sin_cos', 'unnormalized_angles_sin_cos', 'atom_pos'])
torch.Size([21, 164, 14, 3])
dict_keys(['structure_module'])
torch.Size([1, 164, 21, 14, 3])
torch.Size([1, 164, 21, 14, 3])


In [25]:
ret['heads']['seqhead']['logits'].shape

torch.Size([1, 164, 21])

In [26]:
batch.keys(), batch['prev_seq']

(dict_keys(['name', 'str_seq', 'seq', 'mask', 'aatype_unk_mask', 'atom14_gt_positions', 'atom14_gt_exists', 'chain_id', 'geo_global', 'data_type', 'atom14_atom_exists', 'atom14_atom_is_ambiguous', 'residx_atom37_to_atom14', 'atom37_atom_exists', 'atom14_alt_gt_positions', 'atom14_alt_gt_exists', 'atom37_gt_positions', 'atom37_gt_exists', 'rigidgroups_gt_frames', 'rigidgroups_gt_exists', 'rigidgroups_group_exists', 'rigidgroups_group_is_ambiguous', 'rigidgroups_alt_gt_frames', 'torsion_angles_sin_cos', 'alt_torsion_angles_sin_cos', 'torsion_angles_mask', 'pseudo_beta', 'pseudo_beta_mask', 'pair_mask', 'dist_one_hot', 'left_gt_calpha3_frame_positions', 'right_gt_calpha3_frame_positions', 'left_gt_calpha3_frame_position_exists', 'right_gt_calpha3_frame_position_exists', 'left_forth_atom_rel_pos', 'prev_seq', 'prev_pair', 'is_recycling', 'representations', 'heads']),
 tensor([[[  0.0447,   2.2610,   0.8008,  ...,   0.1857,  -3.3827,   0.7974],
          [  7.4329,  -3.1642,  -1.3361,  ...,

In [27]:
ret['heads']['seqhead']['logits']

tensor([[[ 1.7711,  0.1862, -0.1315,  ..., -0.6690,  0.2133, -0.9742],
         [ 0.8016, -0.0561, -0.1382,  ..., -0.4754, -0.3067, -0.8885],
         [ 0.7670,  0.5164, -0.0420,  ..., -0.3668,  0.2781, -1.0991],
         ...,
         [-0.7642,  3.7027, -0.3607,  ..., -1.0223,  0.4284, -0.7228],
         [ 1.0517, -0.0633, -0.5169,  ..., -0.2714,  0.1190, -0.5592],
         [ 0.0606,  0.1233, -0.5964,  ..., -0.6465,  2.3549, -0.8133]]])

In [28]:
ret['heads']['seqhead']['logits'][0]

tensor([[ 1.7711,  0.1862, -0.1315,  ..., -0.6690,  0.2133, -0.9742],
        [ 0.8016, -0.0561, -0.1382,  ..., -0.4754, -0.3067, -0.8885],
        [ 0.7670,  0.5164, -0.0420,  ..., -0.3668,  0.2781, -1.0991],
        ...,
        [-0.7642,  3.7027, -0.3607,  ..., -1.0223,  0.4284, -0.7228],
        [ 1.0517, -0.0633, -0.5169,  ..., -0.2714,  0.1190, -0.5592],
        [ 0.0606,  0.1233, -0.5964,  ..., -0.6465,  2.3549, -0.8133]])

In [29]:
ret['heads']['seqhead']['logits'][1]

IndexError: index 1 is out of bounds for dimension 0 with size 1

In [30]:
(ret['heads']['seqhead']['logits'][0] == ret['heads']['seqhead']['logits'][1])

IndexError: index 1 is out of bounds for dimension 0 with size 1

In [31]:
ret['heads']['pairhead']['logits'].shape

torch.Size([1, 164, 164, 441])

In [34]:
## how is pairmask calulated
gt_str_seq, site_repr, pair_repr, sidechain_all = batch['str_seq'], ret['heads']['seqhead']['logits'].to('cpu').numpy(), ret['heads']['pairhead']['logits'].to('cpu').numpy(), ret['heads']['folding']['sidechains']['atom_pos'].cpu().numpy()
site_mask, pair_mask = batch['mask'].to('cpu').numpy().copy(), batch['pair_mask'].to('cpu').numpy().copy()

# site_mask[0][:20] = False

names = batch['name']
L = sidechain_all.shape[1]
sidechain_all=sidechain_all.reshape(-1,21,L,14,3)
for (name, site_mask_, pair_mask_, 
     gt_str_seq_, pair_repr_, site_repr_, 
     sidechain_all_) in zip(names, site_mask, pair_mask, 
                            gt_str_seq, pair_repr, site_repr, 
                            sidechain_all):
    
    site_prob = softmax(site_repr_)
    print(site_repr_.shape)
    label = np.argmax(site_prob, axis=-1)

    print(label, np.sum(label))

    pred_str_seq1 = index_to_str_seq(label)
    pred_str_seq1 = replace_with_x(pred_str_seq1, site_mask_)    
    total_len = len(gt_str_seq_)
    valid_len = np.sum(site_mask_)
    T = args.temperature
    print(pred_str_seq1)

    print(np.argsort(np.sum(pair_mask_, axis=-1)))
    ## this chnage labels slightly
    label = infer_mrf(label, site_repr_, pair_repr_, site_mask_, pair_mask_, T)
    pred_str_seq2 = index_to_str_seq(label)
    label[~site_mask_] = 20
    L = len(sidechain_all_[0])
    sidechain_single = []
    for i in range(L):
        sidechain_single.append(sidechain_all_[label[0],i,:])
    pred_str_seq2 = replace_with_x(pred_str_seq2, site_mask_)
    print(pred_str_seq2)

(164, 21)
[ 0  7  0 15  7 14  0  0  7  0  7  0 15  7  7  7  7  0  0  7 17  1 17 10
 16 19  6 14 13 19  6 19  0  6 10 10  2 19  4 17 10  6  9  1  7 11  9  2
 16  1 10 10 15 14  2 16 16 18  0 19 18 10 19 13 11 10 16  6 11 14 18  2
 10  6 16 19 10  0 16 19  1 13 19 15  3 19 14 14  7 15 14 15  2  7  1 16
 19 18 10 15  1 11 11 11 18 14  3  3  6  7 19 13 14  1  6  1  6  3  7 17
 12  6  9  6 10  7  6 13 13 19  6  3  7  3  6  7  6 19  6 12 15 10 15  6
 10  3 14  7  5 17 11 15  7 10  9 19  5  7 13  6  9  1 14 19] 1600
XXXXXXXXXXXXXXXXXXXXWRWLTVEPFVEVAELLNVCWLEIRGKINTRLLSPNTTYAVYLVFKLTEKPYNLETVLATVRFVSDVPPGSPSNGRTVYLSRKKKYPDDEGVFPREREDGWMEIELGEFFVEDGDEGEVEMSLSELDPGQWKSGLIVQGFEIRPV
[106  88 104 105  28  87 116  89  27  90  67 103 107 147  26  86 108 146
  83 118 117  25   0  18  29 115  49  24  54  17  16 148 131  85 145  84
 130 102 114  68 133  92 163 101  15  50  66  53   8 149  23  93  82 109
 132 134  30  55   1   2  52 135  56  69 100 162  94  47  38 144  13   3
  95 136  70  19  22 113 161

In [None]:
site_mask

In [None]:
pred_str_seq1

In [None]:
pred_str_seq2

In [None]:
label

In [None]:
batch.keys()