# Lab 2 (inference): Generate protein embeddings with AMPLIFY on Amazon SageMaker and filter by quality

This notebook guides you through running inference with the AMPLIFY model on Amazon SageMaker to generate protein embeddings, and then using those embeddings to filter out low-quality sequences generated by the Progen2 model in Lab 1.

## Step 1: Setup and Configuration

First, let's get our AWS account information and set up variables we'll use throughout the notebook.

In [None]:
import os
import boto3
import pandas as pd
import numpy as np
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from sagemaker.pytorch import PyTorchPredictor
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer

##########################################################

# Look up the AWS account the current credentials belong to, and the region
# configured for the default boto3 session
sts_client = boto3.client('sts')
account_id = sts_client.get_caller_identity()['Account']
region = boto3.Session().region_name

# Workshop S3 bucket (suffixed with the account ID) and the folder name for each lab
S3_BUCKET = f'workshop-data-{account_id}'
LAB1_FOLDER, LAB2_FOLDER, LAB3_FOLDER = 'lab1-progen', 'lab2-amplify', 'lab3-esmfold'

print("\n".join([
    f"Account ID: {account_id}",
    f"Region: {region}",
    f"S3 Bucket: {S3_BUCKET}",
]))

##########################################################

## Step 2: Generate protein embeddings using the AMPLIFY model

We will generate protein embeddings for the reference protein sequence and for all protein sequences generated by the Progen2 model in Lab 1.

### Step 2.1: Load the reference sequence and the sequences generated by the Progen2 model

In [None]:
# Load the reference protein sequence (the first record of the reference FASTA file)
record = next(SeqIO.parse('./data/reference.fasta', "fasta"))
ref_sequence = str(record.seq)
ref_embeddings = None  # filled in once the AMPLIFY endpoint has been queried


# Load protein sequences generated by the Progen2 model in Lab 1.
# os.listdir returns entries in arbitrary order, so the file names are sorted to make
# the record order (and therefore tie-breaking in later rankings) reproducible.
lab1_dir = f'./data/{LAB1_FOLDER}'
gen_sequences = []

for file in sorted(os.listdir(lab1_dir)):
    if file.endswith(".fasta"):
        file_path = os.path.join(lab1_dir, file)
        for record in SeqIO.parse(file_path, "fasta"):
            gen_sequences.append({
                'prompt_id': record.id,
                'sequence': str(record.seq),
                'description': record.description,
                'embeddings': None,       # set in Step 2.3
                'distance': None,         # set in Step 3.1
                'distance_type': None     # set in Step 3.1
            })

# Fail fast here rather than with an obscure error in the ranking step
assert gen_sequences, f'No generated sequences (*.fasta) found in {lab1_dir}'


print(f'Reference sequence: {ref_sequence}')
print()
print('Generated sequences:')
gen_sequences

### Step 2.2: Initialize Amazon SageMaker predictor

The predictor is a client interface that connects to the deployed AMPLIFY model endpoint, enabling real-time inference calls with JSON serialization for input data and automatic deserialization of model outputs.

In [None]:
# Name of the deployed SageMaker endpoint serving AMPLIFY (120M) with embeddings output
endpoint_name_embeddings = 'amplify-120m-endpoint-embeddings'

# Client for that endpoint: request payloads are sent as JSON, and JSON responses
# are parsed back into Python objects
predictor = PyTorchPredictor(
    endpoint_name=endpoint_name_embeddings,
    deserializer=JSONDeserializer(),
    serializer=JSONSerializer(),
)

### Step 2.3: Generate embeddings using the predictor

In [None]:
def get_embeddings(sequence):
    """Return the AMPLIFY embeddings for one protein sequence as a NumPy array.

    Sends ``{"sequence": sequence, "mode": "embeddings"}`` to the endpoint through
    ``predictor`` and converts the ``'embeddings'`` field of the JSON response.
    """
    output = predictor.predict({"sequence": sequence, 'mode': 'embeddings'})
    return np.array(output['embeddings'])


# Generate embeddings for the reference sequence
ref_embeddings = get_embeddings(ref_sequence)

# Generate embeddings for novel sequences; log which sequence each shape belongs to
for gen_seq in gen_sequences:
    gen_seq['embeddings'] = get_embeddings(gen_seq['sequence'])
    print(f'Generated sequence {gen_seq["prompt_id"]}: embeddings shape = {gen_seq["embeddings"].shape}')


In [None]:
# Each generated protein sequence is now associated with its corresponding embeddings.
# Show a compact summary (one row per sequence) rather than dumping every embedding array.
pd.DataFrame([
    {
        'prompt_id': gen_seq['prompt_id'],
        'sequence_length': len(gen_seq['sequence']),
        'embeddings_shape': np.shape(gen_seq['embeddings']),
    }
    for gen_seq in gen_sequences
])

## Step 3: Filter out protein sequences using generated embeddings

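### Step 3.1: Calculate the cosine distance
The cosine distance is calculated between the embeddings of each generated sequence and those of the reference protein sequence; a smaller distance means the generated sequence is closer to the reference. Embeddings are mean-pooled over token positions first, so sequences of different lengths can be compared.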

In [None]:
import numpy as np
from scipy.spatial import distance

def cosine_distance(embeddings1, embeddings2, mean_pooling=False):
    """Cosine distance (1 - cosine similarity) between two embedding arrays.

    Parameters
    ----------
    embeddings1, embeddings2 : array-like
        Embedding arrays. With ``mean_pooling=True`` both must be 3-D, and axis 1 is
        treated as the token axis (presumably ``(batch, tokens, hidden)`` as returned
        by the AMPLIFY endpoint; TODO confirm). The first and last positions on that
        axis are dropped before averaging (presumably special BOS/EOS tokens; verify).
    mean_pooling : bool, default False
        If True, average over the token axis first, so sequences of different
        lengths can be compared. If False, the arrays are flattened and compared
        element-wise, which requires the same number of elements.

    Returns
    -------
    float
        ``scipy.spatial.distance.cosine`` of the (pooled) flattened vectors.

    Raises
    ------
    ValueError
        If pooling is requested on a non-3-D array or one with fewer than 3 token
        positions, or if the vectors being compared differ in size.
    """
    embeddings1 = np.asarray(embeddings1)
    embeddings2 = np.asarray(embeddings2)

    if mean_pooling:
        for name, emb in (('embeddings1', embeddings1), ('embeddings2', embeddings2)):
            # Slicing off both ends needs at least one remaining token, otherwise
            # the mean is taken over an empty axis and silently yields NaN.
            if emb.ndim != 3 or emb.shape[1] < 3:
                raise ValueError(
                    f'{name} must be 3-D with at least 3 positions on axis 1 '
                    f'for mean pooling, got shape {emb.shape}'
                )
        embeddings1 = embeddings1[:, 1:-1, :].mean(axis=1)
        embeddings2 = embeddings2[:, 1:-1, :].mean(axis=1)

    vector1 = embeddings1.ravel()
    vector2 = embeddings2.ravel()
    if vector1.size != vector2.size:
        raise ValueError(
            f'Cannot compare embeddings of different sizes ({vector1.size} vs '
            f'{vector2.size}); use mean_pooling=True for sequences of different lengths'
        )
    return distance.cosine(vector1, vector2)

# Score every generated sequence against the reference using the mean-pooled cosine distance
for candidate in gen_sequences:
    candidate.update(
        distance=cosine_distance(candidate['embeddings'], ref_embeddings, mean_pooling=True),
        distance_type='cosine',
    )



### Step 3.2: Sort the generated protein sequences by cosine distance

In [None]:
# Rank candidates from most to least similar to the reference (smallest distance first).
# Built as a single chain (no inplace), so re-running the cell is idempotent; the stable
# 'mergesort' keeps the original record order for ties, so the ranking is reproducible.
df = (
    pd.DataFrame(gen_sequences)
    .set_index('prompt_id')
    .sort_values(by='distance', ascending=True, kind='mergesort')
)
df[['distance', 'sequence', 'description']]

### Step 3.3: Select the top five sequences for downstream analysis and save them to a FASTA file

In [None]:
# Build a list of sequence records
records = []
for prompt_id, row in df.head(5).iterrows():
    record = SeqRecord(
        Seq(row.sequence),
        id=prompt_id,
        description=f'{row.description},distance={row.distance}'
    )
    records.append(record)

# Save the sequences in a  FASTA file
with open(f'./data/{LAB2_FOLDER}/top_sequence_candidates.fasta', 'w') as f:
    SeqIO.write(records, f, "fasta")


### Step 3.4: [Optional] Delete the unused endpoint

In [None]:
# Remove the embeddings endpoint (best effort: failures are reported, not raised).
# The client is named sagemaker_client so it does not shadow the `sagemaker` SDK package.
sagemaker_client = boto3.client('sagemaker')
try:
    sagemaker_client.delete_endpoint(EndpointName=endpoint_name_embeddings)
    print(f"Successfully deleted endpoint: {endpoint_name_embeddings}")
except sagemaker_client.exceptions.ClientError as e:
    # API-level failures, e.g. the endpoint does not exist or was already deleted
    print(f"Error deleting endpoint: {str(e)}")