# 🧬 ESMC Protein Embedding Generator

Generate protein sequence embeddings using ESM-C models from EvolutionaryScale.

## Prerequisites

**Before running this notebook:**
1. You need a **HuggingFace account** with access to ESM models
2. Get your access token from: https://huggingface.co/settings/tokens
3. Have a `sequences.csv` file ready (output from the FASTA Cleaner notebook)
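
Step 3 rejects any CSV that lacks the `sequence_id` and `sequence` columns. If you want to check the expected layout before running the FASTA Cleaner notebook, the following is a minimal sketch that builds compatible CSV text with the standard library. The IDs and sequences are made-up placeholders, and `build_sequences_csv` is a hypothetical helper, not part of either notebook.

```python
import csv
import io

# Hypothetical example rows; real IDs come from the FASTA Cleaner notebook.
rows = [
    {"sequence_id": "seq_001", "sequence": "MKTAYIAKQRQISFVKSHFSRQ"},
    {"sequence_id": "seq_002", "sequence": "MSILVTRPSPAGEEL"},
]

def build_sequences_csv(rows):
    """Return CSV text with the two columns that Step 3 requires."""
    buffer = io.StringIO()
    writer = csv.DictWriter(
        buffer, fieldnames=["sequence_id", "sequence"], lineterminator="\n"
    )
    writer.writeheader()
    writer.writerows(rows)
    return buffer.getvalue()

csv_text = build_sequences_csv(rows)

# Same check that Step 3 performs: both required columns must be present.
header = csv_text.splitlines()[0].split(",")
missing = {"sequence_id", "sequence"} - set(header)
print("Missing columns:", missing or "none")
```

Saving `csv_text` to a file named `sequences.csv` gives you something you can upload in Step 3.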

## How to use:
1. **Run all cells** in order (Runtime → Run all)
2. **Enter your HuggingFace token** when prompted
3. **Upload your sequences.csv** file
4. **Configure embedding options** (layers, logits)
5. **Click "Generate Embeddings"** and wait for processing
6. **Download** the resulting embeddings file

---

In [None]:
# ============================================================
# STEP 1: SETUP - Run this cell first!
# ============================================================

print("üîß Setting up environment...\n")

# Check if running in Google Colab
try:
    from google.colab import files as colab_files
    IN_COLAB = True
    print("✅ Running in Google Colab")
except ImportError:
    IN_COLAB = False
    print("✅ Running in local Jupyter environment")

# Install required packages
print("\nüì¶ Installing required packages...")
print("   This may take a few minutes on first run.\n")

!pip install -q esm huggingface_hub ipywidgets pandas torch

# Import libraries
import re
from datetime import datetime

import pandas as pd
import torch

try:
    import ipywidgets as widgets
    from IPython.display import display, HTML, clear_output
except ImportError:
    print("⚠️ ipywidgets not found. Installing...")
    !pip install -q ipywidgets
    import ipywidgets as widgets
    from IPython.display import display, HTML, clear_output

# Check GPU availability
if torch.cuda.is_available():
    DEVICE = "cuda"
    gpu_name = torch.cuda.get_device_name(0)
    print(f"✅ GPU detected: {gpu_name}")
else:
    DEVICE = "cpu"
    print("⚠️ No GPU detected. Running on CPU (will be slower).")

# Model layer counts
MODEL_LAYERS = {"esmc_300m": 30, "esmc_600m": 36}

print("\nüéâ Setup complete! Proceed to Step 2.")

In [None]:
# ============================================================
# STEP 2: HUGGINGFACE LOGIN
# ============================================================

from huggingface_hub import login

# Storage
model = None
login_status = {"success": False}

# Create widgets
token_input = widgets.Password(
    placeholder="Paste your HuggingFace token here",
    description="HF Token:",
    layout=widgets.Layout(width="400px")
)

model_dropdown = widgets.Dropdown(
    options=[("ESMC 600M (recommended)", "esmc_600m"), ("ESMC 300M (faster)", "esmc_300m")],
    value="esmc_600m",
    description="Model:",
    layout=widgets.Layout(width="300px")
)

login_btn = widgets.Button(
    description="üîê Login & Load Model",
    button_style="primary",
    layout=widgets.Layout(width="200px", height="40px")
)

login_output = widgets.Output()

def on_login_click(btn):
    global model, login_status
    with login_output:
        clear_output()
        token = token_input.value.strip()
        
        if not token:
            print("⚠️ Please enter your HuggingFace token.")
            print("\n   Get your token at: https://huggingface.co/settings/tokens")
            return
        
        print("üîÑ Logging in to HuggingFace...")
        try:
            login(token=token)
            print("‚úÖ Login successful!\n")
        except Exception as e:
            print(f"‚ùå Login failed: {e}")
            print("\n   Make sure your token is correct and has read access.")
            return
        
        print(f"üîÑ Loading {model_dropdown.value} model...")
        print("   This may take 1-2 minutes on first run.\n")
        
        try:
            from esm.models.esmc import ESMC
            model = ESMC.from_pretrained(model_dropdown.value).to(DEVICE)
            login_status["success"] = True
            print(f"✅ Model loaded successfully on {DEVICE.upper()}!")
            print("\n👇 Proceed to Step 3 to upload your sequences.")
        except Exception as e:
            print(f"❌ Failed to load model: {e}")
            print("\n   Make sure you have accepted the ESM model license on HuggingFace.")

login_btn.on_click(on_login_click)

# Display
display(HTML("<h3>üîê Step 2: Login to HuggingFace</h3>"))
display(HTML("<p>Enter your HuggingFace access token to download the ESM model:</p>"))
display(HTML("<p><small>Get your token at: <a href='https://huggingface.co/settings/tokens' target='_blank'>huggingface.co/settings/tokens</a></small></p>"))
display(token_input)
display(model_dropdown)
display(login_btn)
display(login_output)

In [None]:
# ============================================================
# STEP 3: UPLOAD SEQUENCES CSV
# ============================================================

# Storage
sequences_df = None

# Create widgets
upload_widget = widgets.FileUpload(
    accept=".csv",
    multiple=False,
    description="Upload CSV",
    button_style="primary",
    layout=widgets.Layout(width="200px")
)

upload_output = widgets.Output()

def get_file_content(data):
    """Extract file content as string."""
    if isinstance(data, bytes):
        return data.decode("utf-8")
    elif hasattr(data, "tobytes"):
        return data.tobytes().decode("utf-8")
    elif isinstance(data, str):
        return data
    else:
        return str(data)

def on_upload_change(change):
    global sequences_df
    with upload_output:
        clear_output()
        new_value = change["new"]
        
        if not new_value:
            return
        
        # Handle different ipywidgets versions
        if isinstance(new_value, dict):
            for filename, file_data in new_value.items():
                if isinstance(file_data, dict) and "content" in file_data:
                    content = get_file_content(file_data["content"])
                else:
                    content = get_file_content(file_data)
        elif isinstance(new_value, (list, tuple)) and len(new_value) > 0:
            file_info = new_value[0]
            if isinstance(file_info, dict):
                filename = file_info.get("name", "unknown.csv")
                content = get_file_content(file_info.get("content", b""))
            else:
                print("⚠️ Unexpected file format.")
                return
        else:
            print("⚠️ No file detected.")
            return
        
        # Parse CSV
        from io import StringIO
        try:
            sequences_df = pd.read_csv(StringIO(content), keep_default_na=False)
        except Exception as e:
            print(f"❌ Failed to parse CSV: {e}")
            return
        
        # Validate columns
        required = {"sequence_id", "sequence"}
        if not required.issubset(sequences_df.columns):
            missing = required - set(sequences_df.columns)
            print(f"❌ Missing required columns: {missing}")
            print("\n   Your CSV should have 'sequence_id' and 'sequence' columns.")
            print("   This is the format output by the FASTA Cleaner notebook.")
            sequences_df = None
            return
        
        print(f"✅ Uploaded: {filename}")
        print(f"   {len(sequences_df)} sequences found\n")
        print("📋 Preview:")
        display(sequences_df.head())
        print("\n👇 Proceed to Step 4 to configure and generate embeddings.")

upload_widget.observe(on_upload_change, names="value")

# Display
display(HTML("<h3>üì§ Step 3: Upload Your Sequences</h3>"))
display(HTML("<p>Upload the <code>sequences.csv</code> file from the FASTA Cleaner notebook:</p>"))
display(upload_widget)
display(upload_output)

In [None]:
# ============================================================
# STEP 4: GENERATE EMBEDDINGS
# ============================================================

from esm.models.esmc import LogitsConfig
from esm.sdk.api import ESMProtein

# Storage for results
embedding_results = None

# ===== OUTPUT OPTIONS =====
embed_embeddings = widgets.Checkbox(
    value=True,
    description="Return embeddings (last layer)",
    layout=widgets.Layout(width="250px")
)

embed_logits = widgets.Checkbox(
    value=True,
    description="Return logits",
    layout=widgets.Layout(width="200px")
)

# ===== HIDDEN LAYER OPTIONS =====
layer_mode = widgets.Dropdown(
    options=[
        ("None (default)", "none"),
        ("Last layer only", "last"),
        ("Specific layers", "specific"),
        ("All layers (memory intensive!)", "all")
    ],
    value="none",
    description="Hidden layers:",
    layout=widgets.Layout(width="300px")
)

layer_input = widgets.Text(
    value="12, 24, 36",
    placeholder="e.g., 12, 24, 36",
    description="Layer indices:",
    layout=widgets.Layout(width="300px"),
    disabled=True
)

def on_layer_mode_change(change):
    layer_input.disabled = (change["new"] != "specific")

layer_mode.observe(on_layer_mode_change, names="value")

# ===== PROGRESS WIDGETS =====
progress_bar = widgets.IntProgress(
    value=0,
    min=0,
    max=100,
    description="Progress:",
    bar_style="info",
    layout=widgets.Layout(width="400px")
)

progress_label = widgets.Label(value="")

embed_btn = widgets.Button(
    description="üöÄ Generate Embeddings",
    button_style="success",
    layout=widgets.Layout(width="200px", height="40px")
)

embed_output = widgets.Output()

def clean_sequence(seq):
    """Remove non-alphabetic characters from sequence."""
    return re.sub(r"[^A-Z]", "", seq.upper())

def parse_layer_indices(text):
    """Parse comma-separated layer indices."""
    indices = []
    for part in text.split(","):
        part = part.strip()
        if part.isdigit() or (part.startswith("-") and part[1:].isdigit()):
            indices.append(int(part))
    return indices

def get_layers_to_extract():
    """Get list of layer indices based on user selection."""
    mode = layer_mode.value
    if mode == "none":
        return []
    elif mode == "last":
        return [-1]
    elif mode == "specific":
        return parse_layer_indices(layer_input.value)
    elif mode == "all":
        total = MODEL_LAYERS.get(model_dropdown.value, 36)
        return list(range(1, total + 1))
    return []

def on_embed_click(btn):
    global embedding_results
    
    with embed_output:
        clear_output()
        
        # Validation
        if model is None:
            print("⚠️ Model not loaded! Complete Step 2 first.")
            return
        
        if sequences_df is None or len(sequences_df) == 0:
            print("⚠️ No sequences uploaded! Complete Step 3 first.")
            return
        
        # Get configuration
        return_embeddings = embed_embeddings.value
        return_logits = embed_logits.value
        layers_to_extract = get_layers_to_extract()
        
        print("üîÑ Generating embeddings...")
        print(f"   ‚Ä¢ Embeddings: {'Yes' if return_embeddings else 'No'}")
        print(f"   ‚Ä¢ Logits: {'Yes' if return_logits else 'No'}")
        if layers_to_extract:
            if len(layers_to_extract) > 5:
                print(f"   ‚Ä¢ Hidden layers: {len(layers_to_extract)} layers")
            else:
                print(f"   ‚Ä¢ Hidden layers: {layers_to_extract}")
        print()
        
        # Setup
        sequence_ids = sequences_df["sequence_id"].tolist()
        sequences = sequences_df["sequence"].tolist()
        total = len(sequences)
        
        progress_bar.max = total
        progress_bar.value = 0
        
        results = {
            "sequence_id": [],
            "embeddings": [],
            "logits": [],
            "hidden_states": [],
            "hidden_layers_extracted": layers_to_extract,
            "model_name": model_dropdown.value,
            "created_at": datetime.now().isoformat(),
            "errors": [],
            "config": {
                "return_embeddings": return_embeddings,
                "return_logits": return_logits,
                "hidden_layers": layers_to_extract if layers_to_extract else None
            }
        }
        
        # Process each sequence
        for i, (seq_id, seq) in enumerate(zip(sequence_ids, sequences)):
            progress_bar.value = i + 1
            progress_label.value = f"{i+1}/{total} sequences"
            
            try:
                # Clean and convert
                cleaned = clean_sequence(seq)
                protein = ESMProtein(
                    sequence=cleaned,
                    potential_sequence_of_concern=True
                )
                protein_tensor = model.encode(protein)
                
                # Main forward pass
                main_config = LogitsConfig(
                    sequence=return_logits,
                    return_embeddings=return_embeddings,
                    return_hidden_states=len(layers_to_extract) == 1,
                    ith_hidden_layer=layers_to_extract[0] if len(layers_to_extract) == 1 else -1
                )
                output = model.logits(protein_tensor, main_config)
                
                # Build result for this sequence
                seq_hidden = {}
                
                # Store results
                results["sequence_id"].append(seq_id)
                
                # Logits
                if return_logits and output.logits is not None:
                    results["logits"].append(
                        output.logits.sequence.squeeze(0).detach().cpu()
                    )
                else:
                    results["logits"].append(None)
                
                # Embeddings
                if return_embeddings and output.embeddings is not None:
                    results["embeddings"].append(
                        output.embeddings.squeeze(0).detach().cpu()
                    )
                else:
                    results["embeddings"].append(None)
                
                # Hidden states
                if len(layers_to_extract) == 1:
                    hs = getattr(output, "hidden_states", None)
                    if isinstance(hs, torch.Tensor):
                        seq_hidden[layers_to_extract[0]] = hs.squeeze().detach().cpu()
                elif len(layers_to_extract) > 1:
                    # Multiple layers need multiple passes
                    for layer_idx in layers_to_extract:
                        layer_config = LogitsConfig(
                            sequence=False,
                            return_embeddings=False,
                            return_hidden_states=True,
                            ith_hidden_layer=layer_idx
                        )
                        layer_output = model.logits(protein_tensor, layer_config)
                        hs = getattr(layer_output, "hidden_states", None)
                        if isinstance(hs, torch.Tensor):
                            seq_hidden[layer_idx] = hs.squeeze().detach().cpu()
                
                results["hidden_states"].append(seq_hidden)
                    
            except Exception as e:
                results["errors"].append((seq_id, str(e)))
                results["sequence_id"].append(seq_id)
                results["logits"].append(None)
                results["embeddings"].append(None)
                results["hidden_states"].append({})
                print(f"⚠️ Error on {seq_id}: {e}")
        
        embedding_results = results
        
        # Summary
        print("\n" + "="*50)
        print("✅ EMBEDDING COMPLETE!")
        print("="*50)
        print("\n📊 Results:")
        print(f"   • Sequences processed: {len(results['sequence_id'])}")
        print(f"   • Errors: {len(results['errors'])}")
        
        if results["embeddings"] and results["embeddings"][0] is not None:
            print(f"   • Embedding shape: {results['embeddings'][0].shape}")
        
        if layers_to_extract:
            print(f"   • Hidden layers extracted: {len(layers_to_extract)}")
        
        print("\nüëá Proceed to Step 5 to download your embeddings.")

embed_btn.on_click(on_embed_click)

# Display
display(HTML("<h3>üöÄ Step 4: Generate Embeddings</h3>"))
display(HTML("<p><b>Output options:</b></p>"))
display(widgets.HBox([embed_embeddings, embed_logits]))
display(HTML("<p><b>Hidden layer extraction:</b> (optional, for advanced analysis)</p>"))
display(layer_mode)
display(layer_input)
display(HTML("<br>"))
display(embed_btn)
display(widgets.HBox([progress_bar, progress_label]))
display(embed_output)

In [None]:
# ============================================================
# STEP 5: DOWNLOAD RESULTS
# ============================================================

download_output = widgets.Output()

def download_embeddings(btn):
    with download_output:
        clear_output()
        
        if embedding_results is None:
            print("⚠️ No embeddings generated! Complete Step 4 first.")
            return
        
        filename = "embeddings.pt"
        torch.save(embedding_results, filename)
        
        if IN_COLAB:
            colab_files.download(filename)
            print(f"✅ Downloading {filename}...")
        else:
            print(f"✅ Saved to: {filename}")
        
        print("\nüìñ To load this file later:")
        print("   import torch")
        print("   results = torch.load('embeddings.pt')")
        print("   embeddings = results['embeddings']")
        print("   sequence_ids = results['sequence_id']")
        print("\n   # Access hidden states (if extracted):")
        print("   hidden = results['hidden_states'][0]  # First sequence")
        print("   layer_12 = hidden[12]  # Layer 12")

# Create download button
download_btn = widgets.Button(
    description="üíæ Download Embeddings",
    button_style="warning",
    layout=widgets.Layout(width="200px", height="40px")
)
download_btn.on_click(download_embeddings)

# Display
display(HTML("<h3>üíæ Step 5: Download Your Results</h3>"))
display(HTML("<p>Click the button below to download your embeddings:</p>"))
display(download_btn)
display(download_output)

---

## 📖 Output File Description

The `embeddings.pt` file is a PyTorch file containing a dictionary with:

| Key | Description |
|-----|-------------|
| `sequence_id` | List of sequence IDs (links to your `metadata.csv`) |
| `embeddings` | List of last-layer embedding tensors (if enabled) |
| `logits` | List of logits tensors (if enabled) |
| `hidden_states` | List of dicts `{layer_idx: tensor}` for each sequence |
| `hidden_layers_extracted` | List of layer indices that were extracted |
| `model_name` | Name of the ESM-C model used |
| `created_at` | Timestamp of when embeddings were generated |
| `config` | Configuration used for this run |
| `errors` | List of (sequence_id, error_message) for any failed sequences |
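
The per-sequence lists (`sequence_id`, `embeddings`, `logits`, `hidden_states`) are parallel: position `i` in each refers to the same sequence. Step 4 stores `None` in `embeddings` and `logits` for any sequence that failed, or when that output was disabled. The sketch below shows one way to skip those entries. It uses a small stand-in dictionary with plain lists in place of tensors, and `successful_entries` is a hypothetical helper name.

```python
def successful_entries(results):
    """Yield (sequence_id, embedding) pairs, skipping entries stored as None."""
    for seq_id, emb in zip(results["sequence_id"], results["embeddings"]):
        if emb is not None:
            yield seq_id, emb

# Stand-in for torch.load("embeddings.pt"); plain lists replace tensors here.
results = {
    "sequence_id": ["a", "b", "c"],
    "embeddings": [[0.1, 0.2], None, [0.3, 0.4]],
    "errors": [("b", "example error message")],
}

kept = dict(successful_entries(results))
failed_ids = [seq_id for seq_id, _message in results["errors"]]
print("kept:", sorted(kept))
print("failed:", failed_ids)
```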

### Loading the file:

```python
import torch

# Load results
results = torch.load("embeddings.pt")

# Get embeddings for first sequence
first_embedding = results["embeddings"][0]  # Shape: (seq_len, embedding_dim)

# Get mean embedding (useful for classification)
mean_embedding = first_embedding.mean(dim=0)  # Shape: (embedding_dim,)

# Access hidden states (if extracted)
hidden = results["hidden_states"][0]  # First sequence
layer_12 = hidden[12]  # Get layer 12

# Check which layers were extracted
print(results["hidden_layers_extracted"])  # e.g., [12, 24, 36]

# Find embedding by sequence ID
target_id = "99603f8fb1e9"
idx = results["sequence_id"].index(target_id)
embedding = results["embeddings"][idx]
```
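
For classification or clustering you usually want one fixed-size vector per protein rather than one per residue. As a minimal sketch, assuming each stored embedding has shape `(seq_len, embedding_dim)` as noted above, the hypothetical helper `mean_pool` below averages over the sequence dimension and stacks the results into a matrix, skipping entries that are `None`. The random tensors and the width of 8 are stand-ins for real embeddings.

```python
import torch

def mean_pool(embeddings):
    """Stack per-sequence mean embeddings, skipping entries that are None.

    Returns the pooled matrix and the list positions that were kept,
    so rows can be matched back to results["sequence_id"].
    """
    pooled, kept = [], []
    for idx, emb in enumerate(embeddings):
        if emb is None:
            continue
        pooled.append(emb.float().mean(dim=0))
        kept.append(idx)
    if not pooled:
        return torch.empty(0, 0), kept
    return torch.stack(pooled), kept

# Random stand-ins for results["embeddings"]; the middle entry mimics a failure.
demo = [torch.randn(7, 8), None, torch.randn(12, 8)]
matrix, kept = mean_pool(demo)
print(tuple(matrix.shape), kept)
```

Row `j` of `matrix` corresponds to `results["sequence_id"][kept[j]]`.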

> **Tip:** The `sequence_id` values match those in your `metadata.csv`, so you can easily link embeddings back to protein names, dates, and other metadata.
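
One way to do that linking is sketched below with the standard library. Only the `sequence_id` column comes from this notebook: the `protein_name` column, the second ID, and the helper names are illustrative assumptions, so adjust them to whatever your `metadata.csv` actually contains.

```python
import csv
import io

# Hypothetical metadata.csv contents. Only the `sequence_id` column is taken
# from this notebook; `protein_name` and the second ID are illustrative.
METADATA_TEXT = (
    "sequence_id,protein_name\n"
    "99603f8fb1e9,example protein A\n"
    "0a1b2c3d4e5f,example protein B\n"
)

def index_metadata(csv_text):
    """Map each sequence_id to its metadata row."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return {row["sequence_id"]: row for row in reader}

def attach_metadata(sequence_ids, metadata_by_id):
    """Return (position, sequence_id, row) triples; row is None if the ID is absent."""
    return [
        (pos, seq_id, metadata_by_id.get(seq_id))
        for pos, seq_id in enumerate(sequence_ids)
    ]

# Stand-in for results["sequence_id"] loaded from embeddings.pt.
sequence_ids = ["0a1b2c3d4e5f", "99603f8fb1e9", "not_in_metadata"]
linked = attach_metadata(sequence_ids, index_metadata(METADATA_TEXT))
for pos, seq_id, row in linked:
    name = row["protein_name"] if row else "(no metadata)"
    print(pos, seq_id, name)
```

The `position` in each triple indexes directly into `results["embeddings"]`, `results["logits"]`, and `results["hidden_states"]`.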