In [1]:
# Run once to install the ESM-1b model: https://github.com/facebookresearch/esm
!pip install git+https://github.com/facebookresearch/esm.git

Collecting git+https://github.com/facebookresearch/esm.git
  Cloning https://github.com/facebookresearch/esm.git to /private/var/folders/kp/nj27yjxx3n3dqz4hw_jw1vrh0000gr/T/pip-req-build-y_ubrsbd
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/esm.git /private/var/folders/kp/nj27yjxx3n3dqz4hw_jw1vrh0000gr/T/pip-req-build-y_ubrsbd
  Resolved https://github.com/facebookresearch/esm.git to commit 2b369911bb5b4b0dda914521b9475cad1656b2ac
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hBuilding wheels for collected packages: fair-esm
  Building wheel for fair-esm (pyproject.toml) ... [?25ldone
[?25h  Created wheel for fair-esm: filename=fair_esm-2.0.1-py3-none-any.whl size=105419 sha256=a758f42a21acb23b124250f86cdb0714d301b8ae6ceef2c9d4dd23e8dae7143c
  Stored in directory: /private/var/folders/kp/nj27yjxx3n3dqz4hw_jw1vrh0000g

In [2]:
import scipy.spatial.distance
import torch

In [3]:
# TP53 protein sequence taken from UniProt entry P04637:
# https://www.uniprot.org/uniprotkb/P04637/entry
wt_seq = 'MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD'
print(f'{len(wt_seq)} residues in sequence')

# Single-residue substitutions written as <reference residue><1-based position><new residue>.
mutations = [
    # Destabilizing TP53 mutation: https://www.pnas.org/doi/10.1073/pnas.0805326105
    'Y220C',
    # The next two are classified as 'benign' in ClinVar:
    # https://www.ncbi.nlm.nih.gov/clinvar/?term=Li-Fraumeni+syndrome
    'E298S',
    'Q354K',
]
print(f'{len(mutations)} mutations')

393 residues in sequence
3 mutations


In [4]:
def apply_mutation(seq, mut):
    """Return a copy of `seq` with one single-residue substitution applied.

    Parameters
    ----------
    seq : str
        Reference amino-acid sequence.
    mut : str
        Substitution written as <reference residue><1-based position><new residue>,
        e.g. 'Y220C'.

    Returns
    -------
    str
        Sequence of the same length as `seq` with the residue at the given
        position replaced by the new residue.

    Raises
    ------
    ValueError
        If `mut` is malformed, the position lies outside `seq`, or the reference
        residue does not match what `seq` has at that position.
    """
    aa_ref, pos_text, aa_alt = mut[:1], mut[1:-1], mut[-1:]
    if not (aa_ref.isalpha() and aa_alt.isalpha()
            and pos_text.isascii() and pos_text.isdigit()):
        raise ValueError(f'Malformed mutation {mut!r}; expected a form like Y220C')

    aa_pos = int(pos_text)
    # Explicit range check: position 0 would otherwise wrap around to the last residue.
    if not 1 <= aa_pos <= len(seq):
        raise ValueError(
            f'Mutation {mut!r}: position {aa_pos} is outside the sequence (1-{len(seq)})')
    if seq[aa_pos - 1] != aa_ref:
        raise ValueError(
            f'Mutation {mut!r}: sequence has {seq[aa_pos - 1]!r} at position {aa_pos}, '
            f'not {aa_ref!r}')

    return seq[:aa_pos - 1] + aa_alt + seq[aa_pos:]


# Build (and thereby validate) every record before touching the file, so that a bad
# mutation string cannot leave a half-written sequences.fasta behind.
fasta_records = [('wt', wt_seq)]
fasta_records += [(mut, apply_mutation(wt_seq, mut)) for mut in mutations]

# One FASTA record per sequence: the header is the mutation name ('wt' for wild type).
with open('sequences.fasta', 'w') as fh:
    for label, seq in fasta_records:
        print(f'>{label}', file=fh)
        print(seq, file=fh)

In [5]:
# Drive the ESM extract script from Python instead of the shell. Shell equivalent:
#   python $CONDA_PREFIX/lib/python3.10/site-packages/esm/scripts/extract.py esm1_t6_43M_UR50S sequences.fasta embeddings --include mean
import esm.scripts.extract

extract_argv = [
    'esm1_t6_43M_UR50S',  # model to run
    'sequences.fasta',    # input sequences
    'embeddings',         # directory that receives one .pt file per sequence
    '--include', 'mean',  # store the mean representation of each sequence
]
parser = esm.scripts.extract.create_parser()
args = parser.parse_args(extract_argv)
esm.scripts.extract.run(args)


Read sequences.fasta with 4 sequences
Processing 1 of 1 batches (4 sequences)


In [6]:
def load_mean_embedding(label, layer=6, emb_dir='embeddings'):
    """Load the mean representation stored for one sequence.

    Parameters
    ----------
    label : str
        Name of the sequence; the file read is '<emb_dir>/<label>.pt'.
    layer : int
        Key into the file's 'mean_representations' mapping. 6 is the key this
        notebook reads (presumably the last layer of the 6-layer model; confirm
        if a different model is used).
    emb_dir : str
        Directory holding the .pt files.

    Returns
    -------
    The object stored under ['mean_representations'][layer] in the file.
    """
    # NOTE: torch.load unpickles the file; only use it on files produced locally.
    return torch.load(f'{emb_dir}/{label}.pt')['mean_representations'][layer]


# Load the mean embedding of the wild type and of each mutant.
wt_emb = load_mean_embedding('wt')
Y220C_emb = load_mean_embedding('Y220C')
E298S_emb = load_mean_embedding('E298S')
Q354K_emb = load_mean_embedding('Q354K')

In [7]:
# Cosine distance between each mutant's mean embedding and the wild-type embedding,
# printed in the order Y220C, E298S, Q354K.
cosine_distance = scipy.spatial.distance.cosine
for mutant_emb in (
    Y220C_emb,  # pathogenic
    E298S_emb,  # benign
    Q354K_emb,  # benign
):
    print(cosine_distance(mutant_emb, wt_emb))

0.0001891483466651689
4.11470402561509e-05
3.777076966193782e-05
