# Multimerize inputs for antibody chains

In [7]:
import torch
import numpy as np
from esm.utils.structure.protein_chain import ProteinChain  # defined in esm/utils/structure/protein_chain.py  [oai_citation:0‡code2prompt.txt](file-service://file-7M2YDUiM8jAcYqSemot19N)
from esm.tokenization.sequence_tokenizer import EsmSequenceTokenizer

In [8]:
from esm.utils.constants.esm3 import CHAIN_BREAK_STR  # chain break 
CHAIN_BREAK_STR

'|'

In [21]:
protein_chain_H = ProteinChain.from_pdb(path=fpath, chain_id='H')  # loads heavy chain
protein_chain_L = ProteinChain.from_pdb(path=fpath, chain_id='L')  # loads light chain
print(protein_chain_H[0])

ProteinChain(id='1BEY', sequence='Q', chain_id='H', entity_id=1, residue_index=array([1]), insertion_code=array([''], dtype='<U4'), atom37_positions=array([[[ 7.929, 79.104, 42.518],
        [ 6.897, 79.705, 43.426],
        [ 7.205, 79.372, 44.891],
        [ 5.522, 79.118, 43.088],
        [ 7.862, 80.138, 45.612],
        [ 4.342, 79.998, 43.419],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [ 4.18 , 81.089, 42.389],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
        [   nan,    nan,    nan],
 

In [13]:
# Extract atom37 representations (assumed to be available as an attribute)
# Each will be a numpy array of shape (num_residues, 37, 3)
atom37_H = protein_chain_H.atom37_positions  # e.g. shape (L_H, 37, 3)
atom37_L = protein_chain_L.atom37_positions  # e.g. shape (L_L, 37, 3)
print(atom37_H.shape, atom37_L.shape)

(210, 37, 3) (214, 37, 3)


----
Dan, we need to define a chain break row and then concat

In [22]:
chain_break_atom37 = np.full((1, 37, 3), np.inf, dtype=np.float32)  # shape (1, 37, 3) defining an inf vector

# Concatenate the heavy and light chain atom37 representations with a chain break in between.
atom37_multimer = np.concatenate([atom37_H, chain_break_atom37, atom37_L], axis=0)

In [57]:
# unit test
print(atom37_multimer.shape)
print(atom37_H.shape[0]+atom37_L.shape[0]+1)

(425, 37, 3)
425


425

In [35]:
atom37_multimer_tensor = torch.tensor(atom37_multimer, dtype=torch.float32, device='cpu')

In [58]:
# Next, for the sequence input, we want to concatenate the sequences from both chains.
# We assume that each ProteinChain object has a 'sequence' attribute (a string).
seq_H = protein_chain_H.sequence  # heavy chain sequence (e.g., "EVQLVESGGGLVQPGGSLRLSCAAS...")
seq_L = protein_chain_L.sequence  # light chain sequence (e.g., "DIQMTQSPSSLSASVGDRVTITCRAS...")
# Insert the chain break marker between the sequences.
multimer_sequence = seq_H + '|' + seq_L  # chain break inserted between sequences

print(len(list(protein_chain_H.sequence))+len(list(protein_chain_L.sequence))+1) # test

425


In [65]:
len(multimer_sequence)

425

In [70]:
print("Multimer sequence:", multimer_sequence)

# Now load the ESM sequence tokenizer and encode the multimer sequence.
seq_tokenizer = EsmSequenceTokenizer()  # ensure this is imported from esm.tokenization.sequence_tokenizer
seq_tokens = seq_tokenizer.encode(multimer_sequence)  # encode the concatenated sequence
print(len(seq_tokens))
print(seq_tokenizer.decode(seq_tokens))
print(seq_tokenizer.get_vocab())
# + 2 extra tokens for the special tokens

Multimer sequence: QVQLQESGPGLVRPSQTLSLTCTVSGFTFTDFYMNWVRQPPGRGLEWIGFIRDKAKGYTTEYNPSVKGRVTMLVDTSKNQFSLRLSSVTAADTAVYYCAREGHTAAPFDYWGQGSLVTVSSASTKGPSVFPLAPAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKV|DIQMTQSPSSLSASVGDRVTITCKASQNIDKYLNWYQQKPGKAPKLLIYNTNNLQTGVPSRFSGSGSGTDFTFTISSLQPEDIATYYCLQHISRPRTFGQGTKVEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC
427
<cls> Q V Q L Q E S G P G L V R P S Q T L S L T C T V S G F T F T D F Y M N W V R Q P P G R G L E W I G F I R D K A K G Y T T E Y N P S V K G R V T M L V D T S K N Q F S L R L S S V T A A D T A V Y Y C A R E G H T A A P F D Y W G Q G S L V T V S S A S T K G P S V F P L A P A A L G C L V K D Y F P E P V T V S W N S G A L T S G V H T F P A V L Q S S G L Y S L S S V V T V P S S S L G T Q T Y I C N V N H K P S N T K V D K K V | D I Q M T Q S P S S L S A S V G D R V T I T C K A S Q N I D K Y L N W Y Q Q K P G K A P K L L I Y N T N N L Q T G V P S R F

In [64]:
seq_tokens_tensor = torch.tensor(seq_tokens, dtype=torch.int64).unsqueeze(0).to('cpu')  # shape (1, seq_len)
seq_tokens_tensor.shape

Multimer sequence: QVQLQESGPGLVRPSQTLSLTCTVSGFTFTDFYMNWVRQPPGRGLEWIGFIRDKAKGYTTEYNPSVKGRVTMLVDTSKNQFSLRLSSVTAADTAVYYCAREGHTAAPFDYWGQGSLVTVSSASTKGPSVFPLAPAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKV|DIQMTQSPSSLSASVGDRVTITCKASQNIDKYLNWYQQKPGKAPKLLIYNTNNLQTGVPSRFSGSGSGTDFTFTISSLQPEDIATYYCLQHISRPRTFGQGTKVEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC


torch.Size([1, 427])

In [36]:
# Print the shapes of the final representations
print("Atom37 multimer tensor shape:", atom37_multimer_tensor.shape)
print("Sequence tokens tensor shape:", seq_tokens_tensor.shape)

# Dan, the next step would be to feed both the atom37_multimer_tensor and the seq_tokens_tensor into ESM3,
# which will then combine these modalities for further processing.

Multimer sequence: QVQLQESGPGLVRPSQTLSLTCTVSGFTFTDFYMNWVRQPPGRGLEWIGFIRDKAKGYTTEYNPSVKGRVTMLVDTSKNQFSLRLSSVTAADTAVYYCAREGHTAAPFDYWGQGSLVTVSSASTKGPSVFPLAPAALGCLVKDYFPEPVTVSWNSGALTSGVHTFPAVLQSSGLYSLSSVVTVPSSSLGTQTYICNVNHKPSNTKVDKKV|DIQMTQSPSSLSASVGDRVTITCKASQNIDKYLNWYQQKPGKAPKLLIYNTNNLQTGVPSRFSGSGSGTDFTFTISSLQPEDIATYYCLQHISRPRTFGQGTKVEIKRTVAAPSVFIFPPSDEQLKSGTASVVCLLNNFYPREAKVQWKVDNALQSGNSQESVTEQDSKDSTYSLSSTLTLSKADYEKHKVYACEVTHQGLSSPVTKSFNRGEC
Atom37 multimer tensor shape: torch.Size([425, 37, 3])
Sequence tokens tensor shape: torch.Size([1, 427])


In [3]:
type(protein_chain_H)

esm.utils.structure.protein_chain.ProteinChain