# ESMFold Structure Prediction

This notebook demonstrates how to:
1. Get ESM embeddings from a sequence (using existing InterPLM code)
2. Predict 3D structure using ESMFold
3. Visualize the predicted structure

In [1]:
import torch
from transformers import AutoTokenizer, EsmForProteinFolding
from interplm.embedders.esm import ESM
import py3Dmol
import numpy as np
from pathlib import Path



  from .autonotebook import tqdm as notebook_tqdm


## 1. Get ESM Embeddings (Your Current Approach)

In [2]:
# Example protein sequence
sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"

# Get ESM embeddings using existing InterPLM code
esm_embedder = ESM(model_name="facebook/esm2_t6_8M_UR50D")
embeddings = esm_embedder.embed_single_sequence(sequence, layer=4)

print(f"Sequence length: {len(sequence)}")
print(f"Embedding shape: {embeddings.shape}")  # (seq_len, 320)

Sequence length: 65
Embedding shape: (65, 320)


## 2. Predict 3D Structure with ESMFold

In [3]:
# Load ESMFold model (this will download ~15GB on first run)
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", torch_dtype=torch.float32)
model.float()
# Move to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()
print(f"Model dtype: {next(model.parameters()).dtype}")  # Should say torch.float32

print(f"ESMFold loaded on {device}")

  return self.fget.__get__(instance, owner)()
Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model dtype: torch.float32
ESMFold loaded on cpu


In [4]:
# Tokenize and predict structure
tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
tokenized_input = {k: v.to(device) for k, v in tokenized_input.items()}

with torch.no_grad():
    output = model(**tokenized_input)

# Extract predicted structure
# output contains:
# - positions: 3D coordinates of atoms
# - plddt: per-residue confidence scores (0-100)
# - ptm: predicted template modeling score

positions = output["positions"][-1, 0].cpu().numpy()  # (seq_len, 14, 3) - CA, CB, etc.
plddt = output["plddt"][0].cpu().numpy()  # (seq_len,) confidence scores

print(f"Predicted structure shape: {positions.shape}")
print(f"Mean pLDDT confidence: {plddt.mean():.2f}")

Predicted structure shape: (65, 14, 3)
Mean pLDDT confidence: 0.80


In [6]:
with torch.no_grad():
    output = model.infer_pdb(sequence)

with open("result.pdb", "w") as f:
    f.write(output)

## 6. Get Both Embeddings and Structure in One Pipeline

In [4]:
def process_sequence_full_pipeline(sequence, esm_layer=4):
    """Complete pipeline: ESM embeddings + ESMFold structure prediction."""
    
    # 1. Get ESM embeddings
    print("Getting ESM embeddings...")
    esm_embedder = ESM(model_name="facebook/esm2_t6_8M_UR50D")
    embeddings = esm_embedder.embed_single_sequence(sequence, layer=esm_layer)
    
    # 2. Predict structure with ESMFold
    print("Predicting structure with ESMFold...")
    tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
    model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1")
    model.float()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()
        
    
    # tokenized_input = tokenizer([sequence], return_tensors="pt", add_special_tokens=False)
    # tokenized_input = {k: v.to(device) for k, v in tokenized_input.items()}
    
    with torch.no_grad():
        output = model.infer_pdb(sequence)
    
    # 3. Extract results
    # pdb_string = convert_outputs_to_pdb(output, sequence)
    # plddt = output["plddt"][0].cpu().numpy()
    
    return {
        'sequence': sequence,
        'embeddings': embeddings,
        'pdb_string': output,
        # 'plddt': plddt,
        # 'mean_plddt': plddt.mean()
    }

# Test it
results = process_sequence_full_pipeline(sequence)
print(f"\nResults:")
print(f"  Sequence length: {len(results['sequence'])}")
print(f"  Embeddings shape: {results['embeddings'].shape}")
# print(f"  Mean confidence (pLDDT): {results['mean_plddt']:.2f}")

Getting ESM embeddings...
Predicting structure with ESMFold...


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



Results:
  Sequence length: 65
  Embeddings shape: (65, 320)


In [6]:
results['pdb_string']

'PARENT N/A\nATOM      1  N   MET A   1       3.858  -6.128 -16.819  1.00  0.57           N  \nATOM      2  CA  MET A   1       3.590  -6.533 -15.442  1.00  0.60           C  \nATOM      3  C   MET A   1       4.452  -5.742 -14.465  1.00  0.59           C  \nATOM      4  CB  MET A   1       3.839  -8.032 -15.263  1.00  0.51           C  \nATOM      5  O   MET A   1       3.959  -5.263 -13.442  1.00  0.57           O  \nATOM      6  CG  MET A   1       2.757  -8.743 -14.467  1.00  0.48           C  \nATOM      7  SD  MET A   1       2.946 -10.568 -14.496  1.00  0.54           S  \nATOM      8  CE  MET A   1       4.247 -10.774 -13.248  1.00  0.44           C  \nATOM      9  N   LYS A   2       5.804  -5.698 -14.742  1.00  0.75           N  \nATOM     10  CA  LYS A   2       6.714  -4.947 -13.882  1.00  0.77           C  \nATOM     11  C   LYS A   2       6.332  -3.471 -13.834  1.00  0.78           C  \nATOM     12  CB  LYS A   2       8.157  -5.100 -14.363  1.00  0.69           C  \nATO

## 7. Batch Processing Multiple Sequences

In [None]:
# Example: Process multiple sequences
sequences = [
    "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG",
    "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"
]

for i, seq in enumerate(sequences):
    print(f"\nProcessing sequence {i+1}...")
    result = process_sequence_full_pipeline(seq)
    
    # Save structure
    pdb_file = output_dir / f"structure_{i+1}.pdb"
    with open(pdb_file, 'w') as f:
        f.write(result['pdb_string'])
    print(f"  Saved to: {pdb_file}")
    print(f"  Mean pLDDT: {result['mean_plddt']:.2f}")