# Lab 2 (inference): Generate protein embeddings with Amplify on Sagemaker and filter by quality

This notebook will guide you through running inference against the Amplify model on SageMaker to generate protein embeddings, and using these embeddings to filter out low-quality sequences generated by Progen2.

## Step 1: Setup and Configuration

First, let's get our AWS account information and set up variables we'll use throughout the notebook.

In [4]:
import os
import boto3
import pandas as pd
import numpy as np
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from sagemaker.pytorch import PyTorchPredictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

##########################################################

# Resolve the AWS account ID of the active credentials through STS and the
# region configured for a default boto3 session
sts_client = boto3.client('sts')
account_id = sts_client.get_caller_identity()['Account']
region = boto3.Session().region_name

# Per-lab folder names, plus the workshop bucket whose name embeds the account ID
LAB1_FOLDER = 'lab1-progen'
LAB2_FOLDER = 'lab2-amplify'
LAB3_FOLDER = 'lab3-esmfold'
S3_BUCKET = f'workshop-data-{account_id}'

# Echo the resolved settings so they can be checked before continuing
for label, value in (
    ("Account ID", account_id),
    ("Region", region),
    ("S3 Bucket", S3_BUCKET),
):
    print(f"{label}: {value}")

##########################################################

Account ID: 973884802842
Region: us-east-1
S3 Bucket: workshop-data-973884802842


## Step 2: Generate protein embeddings using the Amplify model

We will generate protein embeddings for the reference protein sequence and for all protein sequences generated by the Progen2 model in the previous step.

### Step 2.1:  Load the reference sequence and sequences generated by the Progen2 model

In [5]:
# Load the reference protein sequence (the first record in the FASTA file)
ref_record = next(SeqIO.parse('./data/reference.fasta', "fasta"))
ref_sequence = str(ref_record.seq)
ref_embeddings = None  # filled in once the embeddings are generated


# Load the protein sequences generated by the Progen2 model
lab1_dir = f'./data/{LAB1_FOLDER}'
gen_sequences = []

# os.listdir() returns entries in arbitrary order; sort the names so that the
# order of gen_sequences is reproducible between runs
for file in sorted(os.listdir(lab1_dir)):
    if not file.endswith(".fasta"):
        continue
    file_path = os.path.join(lab1_dir, file)
    for record in SeqIO.parse(file_path, "fasta"):
        gen_sequences.append({
            'prompt_id': record.id,
            'sequence': str(record.seq),
            'description': record.description,
            # Placeholders that later cells fill in
            'embeddings': None,
            'distance': None,
            # Key must match the 'distance_type' written by the distance cell
            'distance_type': None
        })


print(f'Reference sequence: {ref_sequence}')
print()
print('Generated sequences:')
gen_sequences

Reference sequence: 1MEVVIVTGMSGAGKTVALQSLEDLGYFCVDNLPPQLIPKFVELAAGKKGRKIAVALDVRDGVELEGLPEILEQLQSSGYSYQVLFLDASDEALVRRYKE

Generated sequences:


[{'prompt_id': 'prompt-001',
  'sequence': 'MEVVIVTGMSGAGKSTAVKALEDLGYYCVDNLPPALIPKFVELMEKSGKKYSRVALVMDIRGGEFFEDLKEALDELKKKGIDYKILFLDADDETLVKRYK',
  'description': 'prompt-001 max_len:100,temp:0.001,top_k:50,top_p:0.9,prompt:MEVVIVTGMSGAGK',
  'embeddings': None,
  'distance': None,
  'distnace_type': None},
 {'prompt_id': 'prompt-002',
  'sequence': 'MEVVIVTGMSGAGKTTAVQALEDLGYYCVDNLPPRLLVRFVELAASASETLTRVAVVMDLRGREFFAGIREVLAALEARGVTPQVLFLDASDEVLVKRYS',
  'description': 'prompt-002 max_len:100,temp:0.7,top_k:50,top_p:0.9,prompt:MEVVIVTGMSGAGK',
  'embeddings': None,
  'distance': None,
  'distnace_type': None},
 {'prompt_id': 'prompt-003',
  'sequence': 'MEVVIVTGMSGAGKSTAVKALEDLGYYCVDNLPPALIPKFVELMEKSGKKYSRVALVMDIRGGEFFEDLKEALDELKKKGIDYKILFLDADDETLVKRYK',
  'description': 'prompt-003 max_len:100,temp:0.001,top_k:50,top_p:0.9,prompt:MEVVIVTGMSGAGK',
  'embeddings': None,
  'distance': None,
  'distnace_type': None},
 {'prompt_id': 'prompt-004',
  'sequence': 'MEVVIVTGMSGAGKSTAVKCLERMGYFC

### Step 2.2: Initialize SageMaker predictor

In [6]:
# Name of the SageMaker endpoint that serves the Amplify model with embeddings
endpoint_name_embeddings = 'amplify-120m-endpoint-embeddings'

# Attach a predictor to that endpoint by name; request payloads are serialized
# with JSONSerializer and responses are decoded with JSONDeserializer
json_codecs = {
    'serializer': JSONSerializer(),
    'deserializer': JSONDeserializer(),
}
predictor = PyTorchPredictor(endpoint_name=endpoint_name_embeddings, **json_codecs)

### Step 2.3: Generate embeddings using predictor

In [7]:
def _embed(sequence):
    """Send one sequence to the endpoint in 'embeddings' mode and return the result as a NumPy array."""
    response = predictor.predict({"sequence": sequence, 'mode': 'embeddings'})
    return np.array(response['embeddings'])


# Embeddings for the reference sequence
ref_embeddings = _embed(ref_sequence)

# Embeddings for every generated sequence, stored on its record
for gen_seq in gen_sequences:
    gen_seq['embeddings'] = _embed(gen_seq['sequence'])
    shape = gen_seq["embeddings"].shape
    print(f'Generated sequence : embeddings shape = {shape} ')


Generated sequence : embeddings shape = (1, 102, 640) 
Generated sequence : embeddings shape = (1, 102, 640) 
Generated sequence : embeddings shape = (1, 102, 640) 
Generated sequence : embeddings shape = (1, 102, 640) 
Generated sequence : embeddings shape = (1, 102, 640) 
Generated sequence : embeddings shape = (1, 102, 640) 
Generated sequence : embeddings shape = (1, 102, 640) 
Generated sequence : embeddings shape = (1, 102, 640) 
Generated sequence : embeddings shape = (1, 102, 640) 
Generated sequence : embeddings shape = (1, 102, 640) 


In [8]:
# Display the generated sequence records; after the previous cell, each
# record's 'embeddings' field holds a NumPy array instead of None
gen_sequences

[{'prompt_id': 'prompt-001',
  'sequence': 'MEVVIVTGMSGAGKSTAVKALEDLGYYCVDNLPPALIPKFVELMEKSGKKYSRVALVMDIRGGEFFEDLKEALDELKKKGIDYKILFLDADDETLVKRYK',
  'description': 'prompt-001 max_len:100,temp:0.001,top_k:50,top_p:0.9,prompt:MEVVIVTGMSGAGK',
  'embeddings': array([[[  6.49712467, -17.97194481,  -5.13160419, ...,  59.672966  ,
            -6.24426508,  27.74408913],
          [-26.77676392, -25.38581657, -33.85313797, ..., -14.13401031,
           -14.19096184,  73.51075745],
          [  3.88515806,  -3.97606659,   1.72403896, ...,  77.95613861,
           -31.63186264,  34.09099197],
          ...,
          [ -9.06565666,  -6.63508606,   9.69138336, ..., 105.21736908,
            16.14282417,  -1.14048767],
          [  0.11898682, -12.97450447,  12.97942829, ...,  88.52701569,
            29.68675613,  13.03888702],
          [ 11.08183098,  -2.35176849,   2.79319191, ...,  77.63578033,
           -23.2152977 ,   1.28294849]]]),
  'distance': None,
  'distnace_type': None},
 {'promp

## Step 3: Filter out protein sequences using generated embeddings

### Step 3.1: Calculate cosine distance 
The cosine distance is calculated between the embeddings of each generated protein sequence and those of the reference protein sequence.

In [9]:
import numpy as np
from scipy.spatial import distance

def cosine_distance(embeddings1, embeddings2, mean_pooling=False):
    """Return the cosine distance between two embedding arrays.

    Both arrays are flattened to 1-D before the distance is computed with
    scipy.spatial.distance.cosine.

    Args:
        embeddings1: First embedding array.
        embeddings2: Second embedding array.
        mean_pooling: If True, each array is first indexed as
            [:, 1:-1, :] (dropping the first and last positions along
            axis 1 -- presumably special start/end tokens, TODO confirm)
            and averaged over axis 1, so both inputs must have three axes.

    Returns:
        The cosine distance between the two flattened vectors.
    """
    flattened = []
    for emb in (embeddings1, embeddings2):
        if mean_pooling:
            # Average over the interior positions of axis 1
            emb = emb[:, 1:-1, :].mean(axis=1)
        flattened.append(emb.ravel())

    vec1, vec2 = flattened
    return distance.cosine(vec1, vec2)

# Score every generated sequence by the cosine distance between its
# mean-pooled embeddings and the mean-pooled reference embeddings
for gen_seq in gen_sequences:
    gen_seq.update(
        distance=cosine_distance(gen_seq['embeddings'], ref_embeddings, mean_pooling=True),
        distance_type='cosine',
    )



### Step 3.2: Sort the generated protein sequences by cosine distance

In [10]:
# Tabulate the records, index them by prompt ID and rank them from the
# smallest to the largest distance to the reference
df = (
    pd.DataFrame(gen_sequences)
    .set_index('prompt_id')
    .sort_values(by='distance', ascending=True)
)

# Show only the columns relevant for ranking
df[['distance', 'sequence', 'description']]

Unnamed: 0_level_0,distance,sequence,description
prompt_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
prompt-010,0.002038,MEVVIVTGMSGAGKSTAIKALENLGFFCVDNLPPELLVQFAALMGD...,"prompt-010 max_len:100,temp:0.7,top_k:50,top_p..."
prompt-008,0.002587,MEVVIVTGMSGAGKTTAIQSLELLGYTCVDNLPPALIPRLFAMLEE...,"prompt-008 max_len:100,temp:0.7,top_k:50,top_p..."
prompt-004,0.003616,MEVVIVTGMSGAGKSTAVKCLERMGYFCVDNLPPVLIRELVDLVKQ...,"prompt-004 max_len:100,temp:0.7,top_k:50,top_p..."
prompt-002,0.004852,MEVVIVTGMSGAGKTTAVQALEDLGYYCVDNLPPRLLVRFVELAAS...,"prompt-002 max_len:100,temp:0.7,top_k:50,top_p..."
prompt-003,0.005243,MEVVIVTGMSGAGKSTAVKALEDLGYYCVDNLPPALIPKFVELMEK...,"prompt-003 max_len:100,temp:0.001,top_k:50,top..."
prompt-005,0.005243,MEVVIVTGMSGAGKSTAVKALEDLGYYCVDNLPPALIPKFVELMEK...,"prompt-005 max_len:100,temp:0.001,top_k:50,top..."
prompt-007,0.005243,MEVVIVTGMSGAGKSTAVKALEDLGYYCVDNLPPALIPKFVELMEK...,"prompt-007 max_len:100,temp:0.001,top_k:50,top..."
prompt-001,0.005243,MEVVIVTGMSGAGKSTAVKALEDLGYYCVDNLPPALIPKFVELMEK...,"prompt-001 max_len:100,temp:0.001,top_k:50,top..."
prompt-009,0.005243,MEVVIVTGMSGAGKSTAVKALEDLGYYCVDNLPPALIPKFVELMEK...,"prompt-009 max_len:100,temp:0.001,top_k:50,top..."
prompt-006,0.005319,MEVVIVTGMSGAGKSTAIRCFERLGYYCVDNLPPQLLASMVDMALA...,"prompt-006 max_len:100,temp:0.7,top_k:50,top_p..."


### Step 3.3: Select the top 5 sequences for downstream analysis and save them in a FASTA file

In [11]:
# Build one sequence record per top-5 candidate (df is already sorted by
# distance); the distance is appended to the original FASTA description
records = [
    SeqRecord(
        Seq(row.sequence),
        id=prompt_id,
        description=f'{row.description},distance={row.distance}'
    )
    for prompt_id, row in df.head(5).iterrows()
]

# Create the output folder if it is missing; otherwise open() below would
# fail with FileNotFoundError
output_dir = f'./data/{LAB2_FOLDER}'
os.makedirs(output_dir, exist_ok=True)

# Save the sequences in a FASTA file
with open(f'{output_dir}/top_sequence_candidates.fasta', 'w') as f:
    SeqIO.write(records, f, "fasta")


### [Optional] Delete the unused endpoint

In [None]:
# # Remove endpoint
# sagemaker = boto3.client('sagemaker')
# try:
#     sagemaker.delete_endpoint(EndpointName=endpoint_name_embeddings)
#     print(f"Successfully deleted endpoint: {endpoint_name_embeddings}")
# except Exception as e:
#     print(f"Error deleting endpoint: {str(e)}")