# Lab1 (inference): Generate protein sequences at scale with Progen2 on AWS Batch

This notebook walks you through defining Progen2 prompts and sampling parameters, submitting AWS Batch jobs, monitoring job status, and viewing the generated sequences.

#### Prerequisites
- Progen2 Docker image pushed to Amazon ECR
- Batch resources provisioned (Compute Environment, Job Queue, Job Definition)

## Step 1: Setup and Configuration

First, let's get our AWS account information and set up variables we'll use throughout the notebook.

In [None]:
%pip install biopython

In [None]:
import boto3
import time
import datetime

##########################################################

# Get AWS account information
sts_client = boto3.client('sts')
account_id = sts_client.get_caller_identity()['Account']
region = boto3.Session().region_name

# Define S3 bucket and folder names
S3_BUCKET = f'workshop-data-{account_id}'
LAB1_FOLDER = 'lab1-progen'
LAB2_FOLDER = 'lab2-amplify'
LAB3_FOLDER = 'lab3-esmfold'

print(f"Account ID: {account_id}")
print(f"Region: {region}")
print(f"S3 Bucket: {S3_BUCKET}")

##########################################################

# Define model 
model_version = 'progen2-small'
model_id = f'hugohrban/{model_version}' 

# Create batch client
batch_client = boto3.client('batch')

# Define Batch job queue and job definition 
job_queue_name = 'progen2-batch-job-queue'
job_definition_name = 'progen2-job-definition'
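
Before submitting anything, it can help to confirm that the Batch resources listed under Prerequisites are actually in place. The sketch below is one way to do that; `check_batch_resources` is a hypothetical helper, not part of the workshop code, and it assumes that an unknown queue or job definition name yields an empty list rather than an error. The Batch client is passed in as an argument so the function can be exercised with a stub.

```python
def check_batch_resources(batch_client, job_queue_name, job_definition_name):
    """Return a list of human-readable problems; an empty list means all checks passed."""
    problems = []

    # Look up the job queue by name (assumption: unknown names give an empty list)
    queues = batch_client.describe_job_queues(jobQueues=[job_queue_name])["jobQueues"]
    if not queues:
        problems.append(f"Job queue '{job_queue_name}' not found")
    else:
        queue = queues[0]
        # A usable queue is expected to be ENABLED and VALID
        if queue.get("state") != "ENABLED" or queue.get("status") != "VALID":
            problems.append(
                f"Job queue '{job_queue_name}' is {queue.get('state')}/{queue.get('status')}"
            )

    # Only ACTIVE revisions of a job definition can be used to submit jobs
    definitions = batch_client.describe_job_definitions(
        jobDefinitionName=job_definition_name, status="ACTIVE"
    )["jobDefinitions"]
    if not definitions:
        problems.append(f"No ACTIVE job definition named '{job_definition_name}'")

    return problems
```

In this notebook it could be called as `check_batch_resources(batch_client, job_queue_name, job_definition_name)`; an empty list suggests the queue and job definition are ready.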

## Step 2: Define prompts with parameters to generate protein sequences


The inference parameters control how Progen2 generates protein sequences. Each configuration includes:

**Key Parameters:**
- *`prompt`*: Starting token sequence that serves as the seed for generation - the model continues from this initial sequence to generate longer protein sequences
- *`max_length`*: Maximum length of the generated sequence
- *`temperature`*: Controls randomness in token selection: values close to 0 make the model pick the most probable token almost every time (conservative sequences), while values closer to 1 spread the choice across more tokens (more diverse sequences)
- *`top_p`*: Keeps only the smallest set of most probable tokens whose cumulative probability reaches the specified threshold (nucleus sampling), so the candidate pool adapts to the model's confidence level
- *`top_k`*: Limits selection to the specified number of most probable tokens at each position
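
To make the interaction between the three sampling parameters concrete, here is a minimal pure-Python sketch of how temperature scaling, top-k, and top-p filtering are commonly combined before a token is drawn. It illustrates the general technique and is not Progen2's or the container's actual code; `filter_distribution` and the toy logits are invented for the example.

```python
import math

def _normalise(pairs):
    """Rescale (probability, token) pairs so the probabilities sum to 1."""
    total = sum(prob for prob, _ in pairs)
    return [(prob / total, tok) for prob, tok in pairs]

def filter_distribution(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return {token: probability} after temperature, top-k and top-p filtering."""
    # Temperature: divide logits before the softmax; values < 1 sharpen the distribution
    scaled = {tok: value / temperature for tok, value in logits.items()}
    peak = max(scaled.values())  # subtract the peak for numerical stability
    ranked = sorted(
        ((math.exp(value - peak), tok) for tok, value in scaled.items()),
        reverse=True,
    )
    ranked = _normalise(ranked)

    # top-k: keep at most k of the most probable tokens (0 disables the filter)
    if top_k > 0:
        ranked = _normalise(ranked[:top_k])

    # top-p: keep the smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for prob, tok in ranked:
        kept.append((prob, tok))
        cumulative += prob
        if cumulative >= top_p:
            break

    return {tok: prob for prob, tok in _normalise(kept)}

# Toy logits for four amino-acid tokens (made-up values)
toy_logits = {"A": 2.0, "G": 1.0, "K": 0.5, "W": -1.0}

# Near-greedy: with a very low temperature all mass ends up on the top token
print(filter_distribution(toy_logits, temperature=0.001, top_p=0.9, top_k=50))
# Warmer sampling keeps several candidates in play
print(filter_distribution(toy_logits, temperature=0.9, top_p=0.9, top_k=50))
```

With only a handful of candidate tokens, `top_k=50` never removes anything, so in this toy setting the differences come from `temperature` and `top_p` alone.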

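**Strategy:** 10 configurations that share the same prompt, `max_length`, `top_p`, and `top_k` and vary only `temperature` (0.001, 0.2, 0.4, 0.7, and 0.9, two configurations per value), so the generated sequences range from conservative (similar to training data) to creative (novel sequences)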

The parameters file will be stored on S3 for distributed batch processing.

In [None]:
%%writefile data/$LAB1_FOLDER/inference-params.json

{
    "inference-params": [
        {"prompt_id": "prompt-001", "prompt": "MEVVIVTGMSGAGK", "max_length":100, 
        "temperature": 0.001, "top_p":0.9, "top_k":50},

        {"prompt_id": "prompt-002", "prompt": "MEVVIVTGMSGAGK", "max_length":100, 
        "temperature": 0.001, "top_p":0.9, "top_k":50},

        {"prompt_id": "prompt-003", "prompt": "MEVVIVTGMSGAGK", "max_length":100, 
        "temperature": 0.2, "top_p":0.9, "top_k":50},

        {"prompt_id": "prompt-004", "prompt": "MEVVIVTGMSGAGK", "max_length":100, 
        "temperature": 0.2, "top_p":0.9, "top_k":50},

        {"prompt_id": "prompt-005", "prompt": "MEVVIVTGMSGAGK", "max_length":100, 
        "temperature": 0.4, "top_p":0.9, "top_k":50},

        {"prompt_id": "prompt-006", "prompt": "MEVVIVTGMSGAGK", "max_length":100, 
        "temperature": 0.4, "top_p":0.9, "top_k":50},

        {"prompt_id": "prompt-007", "prompt": "MEVVIVTGMSGAGK", "max_length":100, 
        "temperature": 0.7, "top_p":0.9, "top_k":50},

        {"prompt_id": "prompt-008", "prompt": "MEVVIVTGMSGAGK", "max_length":100, 
        "temperature": 0.7, "top_p":0.9, "top_k":50},

        {"prompt_id": "prompt-009", "prompt": "MEVVIVTGMSGAGK", "max_length":100, 
        "temperature": 0.9, "top_p":0.9, "top_k":50},

        {"prompt_id": "prompt-010", "prompt": "MEVVIVTGMSGAGK", "max_length":100, 
        "temperature": 0.9, "top_p":0.9, "top_k":50}
    ]
}


In [None]:
!aws s3 cp data/$LAB1_FOLDER/inference-params.json s3://$S3_BUCKET/$LAB1_FOLDER/inference-params.json
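
A quick local check of the parameters file can catch malformed entries that would otherwise only surface as failed Batch jobs. The following is a minimal sketch; `validate_inference_params` is a hypothetical helper, the value ranges are the usual ones for these sampling parameters, and the `max_length` check assumes that the limit counts the prompt residues as well, which this notebook does not confirm.

```python
import json

REQUIRED_KEYS = {"prompt_id", "prompt", "max_length", "temperature", "top_p", "top_k"}

def validate_inference_params(document):
    """Return a list of problems found in a parsed inference-params document."""
    problems = []
    seen_ids = set()
    for index, entry in enumerate(document.get("inference-params", [])):
        label = entry.get("prompt_id", f"entry #{index}")
        missing = REQUIRED_KEYS - set(entry)
        if missing:
            problems.append(f"{label}: missing keys {sorted(missing)}")
            continue
        if entry["prompt_id"] in seen_ids:
            problems.append(f"{label}: duplicate prompt_id")
        seen_ids.add(entry["prompt_id"])
        if entry["temperature"] <= 0:
            problems.append(f"{label}: temperature must be > 0")
        if not 0 < entry["top_p"] <= 1:
            problems.append(f"{label}: top_p must be in (0, 1]")
        if entry["top_k"] < 1:
            problems.append(f"{label}: top_k must be >= 1")
        # Assumption: max_length includes the prompt, so it must exceed the prompt length
        if entry["max_length"] <= len(entry["prompt"]):
            problems.append(f"{label}: max_length leaves no room to extend the prompt")
    return problems

# Example with one entry in the same shape as the file above
sample = json.loads(
    '{"inference-params": [{"prompt_id": "prompt-001", "prompt": "MEVVIVTGMSGAGK",'
    ' "max_length": 100, "temperature": 0.001, "top_p": 0.9, "top_k": 50}]}'
)
print(validate_inference_params(sample) or "No problems found")
```

To check the real file, load `data/{LAB1_FOLDER}/inference-params.json` with `json.load` and pass the result to the same function.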

## Step 3: Generate protein sequences with inference parameters

AWS Batch excels at parallel processing, allowing you to scale protein sequence generation by distributing the workload across multiple concurrent jobs.

### Step 3.1: Submit Batch jobs

<pre>
batch_count = 5  # defines how many jobs will be created
batch_size  = 2  # defines how many sequences will be generated by each job
</pre>


This configuration will:
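- Create 5 Batch jobs that can run concurrently
- Generate 2 sequences per job (10 sequences in total)
- Typically finish sooner than a single job generating all 10 sequences, provided the compute environment has capacity to run the jobs in parallel
- Limit the impact of failures: a failed job affects only its own 2 sequences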
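
One plausible way for each job to pick its share of the work is to slice the shared parameter list by batch number and batch size. The container's actual logic is not shown in this notebook, so the mapping below is an assumption, and `select_batch` is a hypothetical helper used only for illustration.

```python
def select_batch(entries, batch_number, batch_size):
    """Return the entries that job `batch_number` would process (assumed mapping)."""
    start = batch_number * batch_size
    return entries[start:start + batch_size]

# Prompt IDs in the same format as the parameters file
prompt_ids = [f"prompt-{i:03d}" for i in range(1, 11)]

for batch_number in range(5):
    print(batch_number, select_batch(prompt_ids, batch_number, 2))
```

Under this assumed mapping, every job would handle two consecutive entries, which in the parameters file above are the two configurations that share a temperature value.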

In [None]:
batch_count = 5
batch_size = 2

s3_input_params_path = f's3://{S3_BUCKET}/{LAB1_FOLDER}/inference-params.json'
s3_output_path = f's3://{S3_BUCKET}/{LAB1_FOLDER}/run-{batch_count}-{batch_size}'


jobs = []
for batchNumber in range(batch_count):

    # Generate unique job name
    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    job_name = f'progen2-batch-job-{batchNumber}-{timestamp}'

    # Submit the job
    response = batch_client.submit_job(
        jobName=job_name,
        jobQueue=job_queue_name,
        jobDefinition=job_definition_name,
        # Batch parameter values must be strings, so numeric values are converted
        parameters={
            'hfModelId': model_id,
            's3InputParamsPath': s3_input_params_path,
            'batchId': f'batch-10{batchNumber}',
            'batchSize': str(batch_size),
            'batchNumber': str(batchNumber),
            's3OutputPath': s3_output_path
        }
    )
    jobs.append(response)

    job_id = response['jobId']
    print(f"Submitted job: {response['jobName']}")
    print(f"   Job ID: {job_id}")
    print(f"   Job ARN: {response['jobArn']}")

### Step 3.2: Wait for all jobs to complete

In [None]:
def wait_for_jobs_completion(job_ids, check_interval=60):
    """Wait for all jobs to complete (either SUCCEEDED or FAILED)"""
    
    print(f"Waiting for {len(job_ids)} job(s) to complete...")
    print("This may take several minutes depending on the workload.\n")
    
    completed_jobs = set()
    
    while len(completed_jobs) < len(job_ids):
        for job_id in job_ids:
            if job_id in completed_jobs:
                continue
                
            response = batch_client.describe_jobs(jobs=[job_id])
            job = response['jobs'][0]
            
            job_name = job['jobName']
            status = job['status']
            
            if status in ['SUCCEEDED', 'FAILED']:
                # Terminal state: report it once and stop polling this job
                print(f"Job {job_name} completed with status: {status}")
                completed_jobs.add(job_id)

                if status == 'FAILED' and 'statusReason' in job:
                    print(f"  Failure reason: {job['statusReason']}")
            else:
                print(f"Job {job_name} is {status}")
        
        if len(completed_jobs) < len(job_ids):
            print(f"[waiting {check_interval} seconds before next check...]")
            time.sleep(check_interval)
            print()
    
    print(f"\nAll {len(job_ids)} job(s) have completed!")
    
    # Final status summary
    print("\nFinal Status Summary:")
    for job_id in job_ids:
        response = batch_client.describe_jobs(jobs=[job_id])
        job = response['jobs'][0]
        print(f"  {job['jobName']}: {job['status']}")

# Extract job IDs from submitted jobs
job_ids = [job['jobId'] for job in jobs]

# Wait for all jobs to complete
wait_for_jobs_completion(job_ids)
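
The loop above polls each job with its own API call and waits indefinitely. As a variation, the sketch below checks all jobs in one `describe_jobs` call per round (the API accepts up to 100 job IDs per request) and gives up after a timeout. `poll_jobs`, its `timeout` default, and the injectable `sleep` argument are choices made for this example rather than part of the workshop code.

```python
import time

TERMINAL_STATES = {"SUCCEEDED", "FAILED"}

def poll_jobs(batch_client, job_ids, check_interval=60, timeout=3600, sleep=time.sleep):
    """Poll until every job is terminal or `timeout` seconds elapse; return {job_id: status}."""
    deadline = time.monotonic() + timeout
    statuses = {}
    while True:
        # describe_jobs accepts at most 100 job IDs, so query in chunks
        for start in range(0, len(job_ids), 100):
            chunk = job_ids[start:start + 100]
            for job in batch_client.describe_jobs(jobs=chunk)["jobs"]:
                statuses[job["jobId"]] = job["status"]

        pending = [job_id for job_id in job_ids if statuses.get(job_id) not in TERMINAL_STATES]
        # Stop when nothing is pending, or when the time budget is used up
        if not pending or time.monotonic() >= deadline:
            return statuses
        sleep(check_interval)
```

A call such as `poll_jobs(batch_client, job_ids)` returns the last known status of every job, so any entry that is not `SUCCEEDED` or `FAILED` indicates that the timeout was reached first.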

## Step 4: View generated sequences

### Step 4.1: Download FASTA files with generated sequences from S3

In [None]:
!aws s3 cp $s3_output_path ./data/$LAB1_FOLDER/ --recursive --exclude "*" --include "*.fasta"

### Step 4.2: Read FASTA file(s) and print generated sequences

In [None]:
from Bio import SeqIO
import os

path = f"data/{LAB1_FOLDER}"
# Sort the file names so the output order is the same on every run
for file in sorted(os.listdir(path)):
    if file.endswith(".fasta"):

        file_path = os.path.join(path, file)
        for record in SeqIO.parse(file_path, "fasta"):
            print(f"ID: {record.id}")
            print(f"Description: {record.description}")
            print(f"Sequence: {record.seq}")
            print("-" * 40)
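
Beyond printing the records, a short summary makes it easier to compare runs, for example to see how many distinct sequences the low-temperature configurations produced. The sketch below works on plain `(record_id, sequence)` pairs so it does not depend on Biopython; `summarise_sequences` is a hypothetical helper, and it assumes that each generated sequence begins with the prompt, which the notebook does not state. The toy records are invented and are not real model output.

```python
from collections import Counter

# The 20 standard amino acids; anything else (e.g. control tokens) is flagged
AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")

def summarise_sequences(sequences, prompt):
    """Summarise generated sequences given as an iterable of (record_id, sequence) pairs."""
    sequences = [(record_id, str(seq)) for record_id, seq in sequences]
    counts = Counter(seq for _, seq in sequences)
    return {
        "total": len(sequences),
        "unique": len(counts),
        "lengths": sorted({len(seq) for _, seq in sequences}),
        # Assumption: the output is expected to start with the prompt
        "missing_prompt": [rid for rid, seq in sequences if not seq.startswith(prompt)],
        "non_standard": [rid for rid, seq in sequences if set(seq) - AMINO_ACIDS],
    }

# Toy records (invented for illustration)
toy_records = [
    ("seq-1", "MEVVIVTGMSGAGKSTA"),
    ("seq-2", "MEVVIVTGMSGAGKSTA"),
    ("seq-3", "MEVVIVTGMSGAGKTTL"),
]
print(summarise_sequences(toy_records, prompt="MEVVIVTGMSGAGK"))
```

With Biopython records, the pairs can be built as `(record.id, record.seq)` while iterating over `SeqIO.parse(file_path, "fasta")`.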