<a href="https://colab.research.google.com/github/hadwin-357/ProteinMPNN_breakdown/blob/main/model_utils_function_5_featurize.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#test model utils function
# featurize
'''
inputs:  batch: list of proteins (seq, per-chain sequences, coordinates, and which chains are masked / visible)

Return:
X: [Batch, L_max, 4, 3] # L_max: longest sequence of all proteins in the batch, 4: N CA C O, 3: x, y, z coordinates
S: ground truth, sequence displayed as int for example A as 0, C as 1
mask: mask for padding, padding as 0
lengths: np array of seq_length for protein in the batch
chain_M: [Batch, L_max] 1 for masked 0 for not
residue_idx: [Batch, L_max] residue indices, offset by an extra 100 for each successive chain; padding positions are -100
mask_self: [Batch, L_max, L_max]: residue pairwise info: within chain as 0, inter chain as 1
chain_encoding_all: [Batch, L_max] chain code as 1 1 1 1 1 ...2 2 2... 3 3 3...
'''

In [None]:
def featurize(batch, device):
    """Pack a batch of multi-chain proteins into padded tensors.

    Visible chains whose sequence is identical to a masked chain are treated
    as masked as well. Chain order is shuffled randomly for every protein.

    Args:
        batch: list of dicts, one per protein. Each dict provides
            ``seq`` (only its length is used, to size the padding),
            ``masked_list`` and ``visible_list`` (lists of chain letters) and,
            for every listed chain, ``seq_chain_<letter>`` plus
            ``coords_chain_<letter>``: a dict with the keys
            ``N_chain_<letter>``, ``CA_chain_<letter>``, ``C_chain_<letter>``
            and ``O_chain_<letter>``, each of shape [chain_length, 3].
            The input lists are not modified.
        device: torch device the returned tensors are placed on.

    Returns:
        Tuple ``(X, S, mask, lengths, chain_M, residue_idx, mask_self,
        chain_encoding_all)``:

        * X: float32 [B, L_max, 4, 3] backbone coordinates (N, CA, C, O);
          NaN entries and padding are set to 0.
        * S: long [B, L_max] residue labels (index into
          ``'ACDEFGHIKLMNPQRSTVWYX'``); padding is 0.
        * mask: float32 [B, L_max], 1 where all 12 coordinates are finite,
          0 for padding or residues containing NaN.
        * lengths: int32 numpy array [B] with ``len(b['seq'])``.
        * chain_M: float32 [B, L_max], 1 for masked chains, 0 otherwise.
        * residue_idx: long [B, L_max], position plus an extra 100 per
          preceding chain; padding is -100.
        * mask_self: float32 [B, L_max, L_max], 0 for residue pairs within
          the same chain, 1 elsewhere (including padding).
        * chain_encoding_all: long [B, L_max], 1-based chain code, 0 padding.

    Raises:
        ValueError: if the batch is empty, a chain's coordinates do not match
            its sequence length, the chains are longer than the longest
            ``seq``, or a sequence contains a letter outside the alphabet.
    """
    alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
    aa_to_index = {aa: idx for idx, aa in enumerate(alphabet)}
    backbone_atoms = ('N', 'CA', 'C', 'O')

    B = len(batch)
    lengths = np.array([len(b['seq']) for b in batch], dtype=np.int32)  # sum of chain seq lengths
    L_max = int(lengths.max())  # longest seq; raises ValueError on an empty batch

    X = np.zeros([B, L_max, 4, 3])
    residue_idx = np.full([B, L_max], -100, dtype=np.int32)  # residue idx with jumps across chains
    chain_M = np.zeros([B, L_max], dtype=np.int32)  # 1 for positions to predict, 0 for given ones
    mask_self = np.ones([B, L_max, L_max], dtype=np.int32)  # 0 within a chain, 1 otherwise
    chain_encoding_all = np.zeros([B, L_max], dtype=np.int32)  # chain codes 1, 1, ..., 2, 2, ...
    S = np.zeros([B, L_max], dtype=np.int32)  # sequence as integer labels

    for i, b in enumerate(batch):
        # Work on copies so the caller's lists are never modified.
        masked_chains = list(b['masked_list'])
        visible_chains = list(b['visible_list'])

        visible_temp_dict = {}
        masked_temp_dict = {}
        for letter in masked_chains + visible_chains:
            chain_seq = b[f'seq_chain_{letter}']
            if letter in visible_chains:
                visible_temp_dict[letter] = chain_seq
            elif letter in masked_chains:
                masked_temp_dict[letter] = chain_seq

        # A visible chain that is sequence-identical to a masked chain would
        # give the answer away, so it is moved to the masked set.
        for vm in masked_temp_dict.values():
            for kv, vv in visible_temp_dict.items():
                if vm == vv:
                    if kv not in masked_chains:
                        masked_chains.append(kv)
                    if kv in visible_chains:
                        visible_chains.remove(kv)

        all_chains = masked_chains + visible_chains
        random.shuffle(all_chains)  # randomly shuffle chain order

        x_chain_list = []
        chain_mask_list = []
        chain_seq_list = []
        chain_encoding_list = []
        chain_code = 1
        start = 0
        for letter in all_chains:
            chain_seq = b[f'seq_chain_{letter}']
            chain_length = len(chain_seq)
            chain_coords = b[f'coords_chain_{letter}']  # dict keyed by atom name
            x_chain = np.stack(
                [chain_coords[f'{atom}_chain_{letter}'] for atom in backbone_atoms], 1
            )  # [chain_length, 4, 3]
            if x_chain.shape[0] != chain_length:
                raise ValueError(
                    f"chain {letter!r} of batch entry {i}: {x_chain.shape[0]} "
                    f"coordinate rows for a sequence of length {chain_length}"
                )

            # 0.0 for visible chains, 1.0 for masked chains.
            mask_value = 0.0 if letter in visible_chains else 1.0
            x_chain_list.append(x_chain)
            chain_mask_list.append(np.full(chain_length, mask_value))
            chain_seq_list.append(chain_seq)
            chain_encoding_list.append(chain_code * np.ones(chain_length))

            end = start + chain_length
            mask_self[i, start:end, start:end] = 0
            residue_idx[i, start:end] = 100 * (chain_code - 1) + np.arange(start, end)
            start = end
            chain_code += 1

        x = np.concatenate(x_chain_list, 0)  # [L, 4, 3]
        all_sequence = "".join(chain_seq_list)
        m = np.concatenate(chain_mask_list, 0)  # [L,], 1.0 for places that need to be predicted
        chain_encoding = np.concatenate(chain_encoding_list, 0)

        l = len(all_sequence)
        if l > L_max:
            raise ValueError(
                f"batch entry {i}: chains total {l} residues, more than the "
                f"longest 'seq' in the batch ({L_max})"
            )
        n_pad = L_max - l

        # Pad coordinates with NaN so padding can be detected when building `mask`.
        X[i] = np.pad(x, [[0, n_pad], [0, 0], [0, 0]], 'constant', constant_values=(np.nan, ))
        chain_M[i] = np.pad(m, [[0, n_pad]], 'constant', constant_values=(0.0, ))
        chain_encoding_all[i] = np.pad(
            chain_encoding, [[0, n_pad]], 'constant', constant_values=(0.0, )
        )

        # Convert to labels
        try:
            indices = np.asarray([aa_to_index[a] for a in all_sequence], dtype=np.int32)
        except KeyError as err:
            raise ValueError(
                f"batch entry {i}: residue letter {err.args[0]!r} is not in "
                f"the alphabet {alphabet!r}"
            ) from err
        S[i, :l] = indices

    isnan = np.isnan(X)
    mask = np.isfinite(np.sum(X, (2, 3))).astype(np.float32)
    X[isnan] = 0.

    # Conversion
    residue_idx = torch.from_numpy(residue_idx).to(dtype=torch.long, device=device)
    S = torch.from_numpy(S).to(dtype=torch.long, device=device)
    X = torch.from_numpy(X).to(dtype=torch.float32, device=device)
    mask = torch.from_numpy(mask).to(dtype=torch.float32, device=device)
    mask_self = torch.from_numpy(mask_self).to(dtype=torch.float32, device=device)
    chain_M = torch.from_numpy(chain_M).to(dtype=torch.float32, device=device)
    chain_encoding_all = torch.from_numpy(chain_encoding_all).to(dtype=torch.long, device=device)
    return X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all

In [71]:
import numpy as np

# Sample input data
# Sample batch of two three-chain proteins. Every coordinate array has one
# [x, y, z] row per residue of the matching chain sequence.
# NOTE: the second entry's sequences contain letters (e.g. 'B', 'J', 'O') that
# are not in the 'ACDEFGHIKLMNPQRSTVWYX' alphabet, so converting them to labels
# with alphabet.index() would raise ValueError; they are only used to
# illustrate the chain bookkeeping below.
batch = [
    {
        'seq_chain_A': 'MKLVFLVLLVFVQGF',
        # 15 residues -> 15 coordinate rows (was 16, which did not match the sequence)
        'coords_chain_A': {'N_chain_A': np.random.rand(15, 3), 'CA_chain_A': np.random.rand(15, 3), 'C_chain_A': np.random.rand(15, 3), 'O_chain_A': np.random.rand(15, 3)},
        'seq_chain_B': 'MSVKVEEVG',
        'coords_chain_B': {'N_chain_B': np.random.rand(9, 3), 'CA_chain_B': np.random.rand(9, 3), 'C_chain_B': np.random.rand(9, 3), 'O_chain_B': np.random.rand(9, 3)},
        'seq_chain_C': 'ATCGATCGATCGATCG',
        'coords_chain_C': {'N_chain_C': np.random.rand(16, 3), 'CA_chain_C': np.random.rand(16, 3), 'C_chain_C': np.random.rand(16, 3), 'O_chain_C': np.random.rand(16, 3)},
        'masked_list': ['A', 'B'],
        'visible_list': ['C'],
        'num_of_chains': 3,
        'seq': 'MKLVFLVLLVFVQGF'+ 'MSVKVEEVG' + 'ATCGATCGATCGATCG'
    },
    {
        'seq_chain_X': 'ACDEFGHIKLMNPQRSTVWY',
        'coords_chain_X': {'N_chain_X': np.random.rand(20, 3), 'CA_chain_X': np.random.rand(20, 3), 'C_chain_X': np.random.rand(20, 3), 'O_chain_X': np.random.rand(20, 3)},
        'seq_chain_Y': 'ABCDEFGHIJKLM',
        'coords_chain_Y': {'N_chain_Y': np.random.rand(13, 3), 'CA_chain_Y': np.random.rand(13, 3), 'C_chain_Y': np.random.rand(13, 3), 'O_chain_Y': np.random.rand(13, 3)},
        'seq_chain_Z': 'JKLMNOPQRST',
        'coords_chain_Z': {'N_chain_Z': np.random.rand(11, 3), 'CA_chain_Z': np.random.rand(11, 3), 'C_chain_Z': np.random.rand(11, 3), 'O_chain_Z': np.random.rand(11, 3)},
        'masked_list': ['X', 'Y'],
        'visible_list': ['Z'],
        'num_of_chains': 3,
        'seq': 'ACDEFGHIKLMNPQRSTVWY'+'ABCDEFGHIJKLM'+'JKLMNOPQRST'
    },
]


In [33]:
batch[0]['seq']

'MKLVFLVLLVFVQGFMSVKVEEVGATCGATCGATCGATCG'

In [72]:
# Step-by-step walkthrough: the same allocations featurize() makes before its
# per-protein loop.
alphabet = 'ACDEFGHIKLMNPQRSTVWYX'
B = len(batch)
seq_lens = [len(entry['seq']) for entry in batch]
lengths = np.array(seq_lens, dtype=np.int32)  # sum of chain seq lengths
L_max = max(seq_lens)  # longest seq

# Coordinates start at zero.
X = np.zeros((B, L_max, 4, 3))
# Residue indices with jumps across chains; -100 marks positions not yet filled.
residue_idx = np.full((B, L_max), -100, dtype=np.int32)
# 1 for the positions that need to be predicted, 0 for the ones that are given.
chain_M = np.zeros((B, L_max), dtype=np.int32)
# For interface loss calculation: 0 for self interaction, 1 for other.
mask_self = np.ones((B, L_max, L_max), dtype=np.int32)
# Integer chain codes, filled per protein later.
chain_encoding_all = np.zeros((B, L_max), dtype=np.int32)
# Sequence as amino-acid integers.
S = np.zeros((B, L_max), dtype=np.int32)

# Chain labels: A-Z, a-z, then '0' .. '299'.
init_alphabet = list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz')
extra_alphabet = [str(number) for number in range(300)]
chain_letters = init_alphabet + extra_alphabet

In [73]:
# Report the shape of every pre-allocated array (and the raw lengths).
for label, value in (
    ('X shape: ', X.shape),
    ('lengths: ', lengths),
    ('residue_idx:', residue_idx.shape),
    ('chain_M:', chain_M.shape),
    ('mask_self:', mask_self.shape),
    ('chain_encoding_all:', chain_encoding_all.shape),
):
    print(f'{label}{value}')
print(chain_letters)

X shape: (2, 44, 4, 3)
lengths: [40 44]
residue_idx:(2, 44)
chain_M:(2, 44)
mask_self:(2, 44, 44)
chain_encoding_all:(2, 44)
['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '10

In [74]:
import random
# For every protein: split its chains into masked / visible, move any visible
# chain that is sequence-identical to a masked chain into the masked set, then
# shuffle the chain order. After the loop, `i`, `b` and the chain lists refer
# to the last protein in the batch.
for i, b in enumerate(batch):
    # Copy the lists so the sample batch itself is never modified.
    masked_chains = list(b['masked_list'])
    visible_chains = list(b['visible_list'])
    all_chains = masked_chains + visible_chains
    visible_temp_dict = {}
    masked_temp_dict = {}
    for letter in all_chains:
        chain_seq = b[f'seq_chain_{letter}']
        if letter in visible_chains:
            visible_temp_dict[letter] = chain_seq
        elif letter in masked_chains:
            masked_temp_dict[letter] = chain_seq
    for vm in masked_temp_dict.values():
        for kv, vv in visible_temp_dict.items():
            if vm == vv:
                if kv not in masked_chains:
                    masked_chains.append(kv)
                if kv in visible_chains:
                    visible_chains.remove(kv)
    all_chains = masked_chains + visible_chains
    random.shuffle(all_chains)  # randomly shuffle chain order

In [68]:
masked_temp_dict

{'X': 'ACDEFGHIKLMNPQRSTVWY', 'Y': 'ABCDEFGHIJKLM'}

In [69]:
# Rebuild the chain list for the current protein and shuffle it again.
all_chains = [*masked_chains, *visible_chains]
random.shuffle(all_chains)  # randomly shuffle chain order
print(all_chains)

['Y', 'Z', 'X']


In [75]:
# Walk the shuffled chains of the current protein `b` (batch index `i`) and
# collect per-chain coordinates, masks, sequences and chain codes.
num_chains = b['num_of_chains']
mask_dict = {}
x_chain_list, chain_mask_list = [], []
chain_seq_list, chain_encoding_list = [], []
c = 1  # chain code
l0 = l1 = 0  # start / end offset of the current chain

for step, letter in enumerate(all_chains):
    # 0.0 marks a visible chain, 1.0 a masked one; anything else is skipped.
    if letter in visible_chains:
        mask_fill = 0.0
    elif letter in masked_chains:
        mask_fill = 1.0
    else:
        continue

    chain_seq = b[f'seq_chain_{letter}']
    chain_length = len(chain_seq)
    chain_coords = b[f'coords_chain_{letter}']  # this is a dictionary
    chain_mask = np.full(chain_length, mask_fill)
    atom_keys = [f'N_chain_{letter}', f'CA_chain_{letter}', f'C_chain_{letter}', f'O_chain_{letter}']
    x_chain = np.stack([chain_coords[key] for key in atom_keys], 1)  # [chain_length, 4, 3]

    x_chain_list.append(x_chain)
    chain_mask_list.append(chain_mask)
    chain_seq_list.append(chain_seq)
    chain_encoding_list.append(c * np.ones(np.array(chain_mask).shape[0]))

    l1 += chain_length
    # Residue pairs inside one chain are "self" pairs.
    mask_self[i, l0:l1, l0:l1] = np.zeros([chain_length, chain_length])
    # Each successive chain is shifted by a further 100.
    residue_idx[i, l0:l1] = 100 * (c - 1) + np.arange(l0, l1)
    l0 += chain_length
    c += 1

In [76]:
# Coordinate array shape of each of the three chains.
len(x_chain_list)
for chain_pos in range(3):
    print(x_chain_list[chain_pos].shape)

(11, 4, 3)
(13, 4, 3)
(20, 4, 3)


In [78]:
# Per-chain mask of each of the three chains.
for chain_pos in range(3):
    print(chain_mask_list[chain_pos])

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]


In [79]:
# Sequence of each of the three chains, in shuffled order.
for chain_pos in range(3):
    print(chain_seq_list[chain_pos])

JKLMNOPQRST
ABCDEFGHIJKLM
ACDEFGHIKLMNPQRSTVWY


In [80]:
# Chain code array of each of the three chains.
for chain_pos in range(3):
    print(chain_encoding_list[chain_pos])

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.]
[3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3. 3.]


In [83]:
print(mask_self[1,:13,:13])

[[0 0 0 0 0 0 0 0 0 0 0 1 1]
 [0 0 0 0 0 0 0 0 0 0 0 1 1]
 [0 0 0 0 0 0 0 0 0 0 0 1 1]
 [0 0 0 0 0 0 0 0 0 0 0 1 1]
 [0 0 0 0 0 0 0 0 0 0 0 1 1]
 [0 0 0 0 0 0 0 0 0 0 0 1 1]
 [0 0 0 0 0 0 0 0 0 0 0 1 1]
 [0 0 0 0 0 0 0 0 0 0 0 1 1]
 [0 0 0 0 0 0 0 0 0 0 0 1 1]
 [0 0 0 0 0 0 0 0 0 0 0 1 1]
 [0 0 0 0 0 0 0 0 0 0 0 1 1]
 [1 1 1 1 1 1 1 1 1 1 1 0 0]
 [1 1 1 1 1 1 1 1 1 1 1 0 0]]


In [84]:
residue_idx[1]  # for different chain, idx is hopped with 100.

array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10, 111, 112,
       113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 224, 225,
       226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238,
       239, 240, 241, 242, 243], dtype=int32)

In [85]:
# Join the per-chain pieces into per-protein arrays.
x = np.concatenate(x_chain_list, axis=0)  # [L, 4, 3]
all_sequence = "".join(chain_seq_list)
m = np.concatenate(chain_mask_list, axis=0)  # [L,], 1.0 for places that need to be predicted
chain_encoding = np.concatenate(chain_encoding_list, axis=0)

for summary in (x.shape, len(all_sequence), m.shape, chain_encoding.shape):
    print(summary)


(44, 4, 3)
44
(44,)
(44,)


In [86]:
# Pad the current protein up to L_max and store it in row `i` of the batch arrays.
l = len(all_sequence)
n_pad = L_max - l

# Coordinates are padded with NaN so that padding can be told apart later.
x_pad = np.pad(x, ((0, n_pad), (0, 0), (0, 0)), mode='constant', constant_values=np.nan)
X[i] = x_pad  # [batch, L_max, 4, 3]

m_pad = np.pad(m, (0, n_pad), mode='constant', constant_values=0.0)
chain_M[i] = m_pad  # [batch, L_max]

chain_encoding_pad = np.pad(chain_encoding, (0, n_pad), mode='constant', constant_values=0.0)
chain_encoding_all[i] = chain_encoding_pad

In [99]:
#x_pad

In [90]:
m_pad

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
       1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [91]:
chain_encoding_pad

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 2.,
       2., 2., 2., 2., 2., 2., 2., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
       3., 3., 3., 3., 3., 3., 3., 3., 3., 3.])

In [97]:
all_sequence

'JKLMNOPQRSTABCDEFGHIJKLMACDEFGHIKLMNPQRSTVWY'

In [100]:
# A residue counts as present only if all of its 4 x 3 coordinates are finite.
coord_sums = X.sum(axis=(2, 3))
mask = np.isfinite(coord_sums).astype(np.float32)
# Replace NaN coordinates (padding / missing atoms) with zeros.
isnan = np.isnan(X)
X[isnan] = 0.

In [101]:
mask  # row 1 has no padding because it is the longest sequence, so no zeros; row 0 is all ones only because X[0] was never filled in this walkthrough (it is still all zeros, which are finite)

array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)