# Transformers for Biological Sequence Analysis

## Understanding Protein Sequences, Families, and How to Use ESM-2

---

## 1. What Are Protein Sequences?

Proteins are biological macromolecules composed of amino acids arranged in a linear sequence.
Each protein sequence can be represented as a string of characters, where each character corresponds to one of the 20 standard amino acids.

**Example sequence:** `MKTLLILTCLVAVALARPKA...`
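Because a sequence is just a string, simple checks are easy to write. Below is a minimal sketch that uses only the visible 20-residue prefix of the example above; `is_standard_protein` and `residue_counts` are illustrative helper names, not part of any library.

```python
from collections import Counter

# One-letter codes for the 20 standard amino acids
STANDARD_AA = frozenset("ACDEFGHIKLMNPQRSTVWY")

def is_standard_protein(seq: str) -> bool:
    """True if seq is non-empty and uses only the 20 standard one-letter codes."""
    return len(seq) > 0 and set(seq.upper()) <= STANDARD_AA

def residue_counts(seq: str) -> Counter:
    """Count how often each residue occurs in the sequence."""
    return Counter(seq.upper())

example = "MKTLLILTCLVAVALARPKA"  # visible prefix of the example sequence
print(is_standard_protein(example))   # True
print(residue_counts(example)["L"])   # 5
print(is_standard_protein("MKTXB"))   # False: X and B are ambiguity codes
```

Real UniProt sequences occasionally contain non-standard letters such as `X` (unknown residue) or `U` (selenocysteine), so a check like this is useful before building a dataset.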

This linear sequence:

* Largely determines the three-dimensional structure of the protein
* Thereby shapes its biochemical function
* Reflects its evolutionary relationships to other proteins

Understanding how sequence relates to structure and function is a central problem in molecular biology.


## 2. What Are Protein Families?

Protein families group together proteins that:

* Share evolutionary ancestry
* Have similar structural domains
* Perform similar biological functions

Databases such as **Pfam** classify proteins into families using multiple sequence alignments of conserved regions and the profile Hidden Markov Models (HMMs) built from them.

**Examples of protein families**

* **Kinases** — enzymes that transfer phosphate groups
* **ABC transporters** — membrane transport proteins
* **Response regulators** — signaling proteins

Traditionally, identifying family membership required:

* Sequence alignment (e.g., BLAST)
* Multiple sequence alignment
* Profile HMMs

These approaches depend heavily on sequence similarity.
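To make "sequence similarity" concrete, here is a minimal sketch of percent identity between two sequences that have already been aligned (gaps written as `-`). Tools such as BLAST score alignments with substitution matrices and gap penalties, and definitions of percent identity vary in their denominator; this hypothetical `percent_identity` helper simply counts identical residues over columns where neither sequence has a gap.

```python
def percent_identity(aligned_a: str, aligned_b: str) -> float:
    """Percent of aligned columns (ignoring gap columns) where residues match."""
    if len(aligned_a) != len(aligned_b):
        raise ValueError("aligned sequences must have equal length")
    # Keep only columns where neither sequence has a gap
    pairs = [(a, b) for a, b in zip(aligned_a, aligned_b) if a != "-" and b != "-"]
    if not pairs:
        return 0.0
    matches = sum(a == b for a, b in pairs)
    return 100.0 * matches / len(pairs)

print(percent_identity("MKT-LL", "MKTALV"))  # 80.0 (4 of 5 ungapped columns match)
```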


## 3. What Is ESM-2?

**ESM-2 (Evolutionary Scale Modeling 2)** is a transformer-based protein language model trained on tens of millions of protein sequences from the UniRef database.

Instead of using alignments, ESM-2 learns directly from raw amino acid sequences using a self-supervised objective known as **masked language modeling**.
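Masked language modeling hides a fraction of residues and trains the model to predict them from the surrounding context. The sketch below shows only the input-corruption step, not ESM-2's exact training recipe: the 15% masking rate is the common BERT-style default and is an assumption here, `mask_sequence` is a hypothetical helper, and BERT-style schemes additionally leave some selected tokens unchanged or swap in random ones, which is omitted.

```python
import random

def mask_sequence(seq, rate=0.15, mask_token="<mask>", seed=0):
    """Replace a random subset of residues with a mask token.

    Returns the corrupted token list and a dict {position: original residue};
    those original residues are the prediction targets during training.
    """
    rng = random.Random(seed)
    tokens = list(seq)
    n_mask = max(1, round(rate * len(tokens)))  # assumed 15% masking rate
    positions = sorted(rng.sample(range(len(tokens)), n_mask))
    targets = {i: tokens[i] for i in positions}
    for i in positions:
        tokens[i] = mask_token
    return tokens, targets

tokens, targets = mask_sequence("MKTLLILTCLVAVALARPKA")
print(tokens)
print(targets)  # the model learns to recover these residues from context
```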

Through this training process, the model learns to:

* Capture long-range dependencies in sequences
* Encode structural information
* Represent functional similarity
* Embed evolutionary relationships

ESM-2 outputs a dense numerical representation (**embedding**) for every residue in a sequence; averaging (pooling) these residue embeddings gives a single embedding for the whole protein.

### ESM-2 Checkpoints

Several ESM-2 checkpoints with different model sizes are available. Larger models generally give better downstream accuracy, but they require more GPU memory and are much slower to run (and to train, if fine-tuned). The available ESM-2 checkpoints are:

| Checkpoint name        | Num layers | Num parameters | Embedding dim |
|------------------------|-----------:|---------------:|--------------:|
| esm2_t48_15B_UR50D     | 48         | 15B            | 5120          |
| esm2_t36_3B_UR50D      | 36         | 3B             | 2560          |
| esm2_t33_650M_UR50D    | 33         | 650M           | 1280          |
| esm2_t30_150M_UR50D    | 30         | 150M           | 640           |
| esm2_t12_35M_UR50D     | 12         | 35M            | 480           |
| esm2_t6_8M_UR50D       | 6          | 8M             | 320           |

Note that the larger checkpoints may need a large cloud GPU such as an A100 or H100 even for inference: the 15B-parameter checkpoint needs about 60 GB for its 32-bit weights alone (15 billion parameters × 4 bytes), so it fits only on the largest 80 GB cards, and fine-tuning it on a single GPU is impractical. Memory for attention also grows quadratically with sequence length. During training it scales roughly as `O(batch_size * num_layers * seq_len^2)`, because every layer's activations are kept for backpropagation; in inference the `seq_len^2` term still makes long sequences expensive. We use the esm2_t30_150M_UR50D checkpoint in this notebook; it runs comfortably on a single modern GPU (the outputs below were produced on an NVIDIA L4).
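A back-of-the-envelope sketch of the quadratic term: the attention-score tensor of one layer holds `batch_size × num_heads × seq_len²` values. The function below is a hypothetical helper. The example values (batch 4, 512 tokens, 30 layers, fp32) follow this notebook's settings, while 20 attention heads is an assumption about the 150M checkpoint's config (check `model.config.num_attention_heads`). It ignores weights, other activations, and memory-efficient attention kernels, so treat the result as an order-of-magnitude estimate.

```python
MiB = 1024 ** 2

def attention_scores_bytes(batch_size, num_heads, seq_len, num_layers=1, bytes_per_value=4):
    """Memory for attention-score tensors of shape [batch, heads, seq_len, seq_len].

    num_layers=1 approximates no-grad inference (one layer's scores at a time);
    num_layers=L approximates training, where every layer's scores are kept for backprop.
    """
    return batch_size * num_heads * seq_len * seq_len * num_layers * bytes_per_value

# BATCH_SIZE=4, MAX_LEN=512, 30 layers, fp32; 20 heads is an assumption
inference = attention_scores_bytes(4, 20, 512)
training = attention_scores_bytes(4, 20, 512, num_layers=30)
print(f"inference ~{inference / MiB:.0f} MiB, training ~{training / MiB:.0f} MiB")
# inference ~80 MiB, training ~2400 MiB

# Doubling the sequence length quadruples the estimate
print(attention_scores_bytes(4, 20, 1024) / inference)  # 4.0
```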

## 4. How Will We Use ESM-2 in This Notebook?

In this study notebook, we will:

1. Retrieve raw protein sequences from **UniProt**
2. Convert each sequence into a fixed-length embedding using ESM-2
3. Train a simple linear classifier to predict protein family labels
4. Evaluate how well the embeddings capture biologically meaningful structure

Importantly, we will **not fine-tune ESM-2**.
We will use it strictly as a feature extractor and test whether family-level biological information is already encoded in its representations.


## 5. Why Is This Important?

If a simple linear classifier can accurately predict protein families from ESM-2 embeddings, it suggests that:

* The model has learned biologically meaningful structure
* No alignment or handcrafted biological features are required
* Protein function and evolutionary information are encoded directly in the embedding space

This approach demonstrates how transformer-based models can serve as powerful tools for biological discovery.


In the following sections, we transition from biological background to practical implementation and experimentation.


## Verify Installation + GPU Availability

We import the required libraries, collect the experiment settings in a small `Config` object, set random seeds, and check whether CUDA is available.


In [2]:
import torch
import transformers
import requests
from dataclasses import dataclass
import numpy as np
import pandas as pd
import os
import time
import sklearn
import matplotlib

# Configuration class to store experiment settings in one place
@dataclass
class Config:
    cache_dir: str = "./data/cache_uniprot"                      # folder to store downloaded/cached UniProt data
    seed: int = 42                                               # random seed for reproducibility
    device: str = "cuda" if torch.cuda.is_available() else "cpu" # automatically use GPU if available, otherwise fallback to CPU

cfg = Config()

os.makedirs(cfg.cache_dir, exist_ok=True)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)

print("\nCUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))



CUDA available: True
GPU: NVIDIA L4


In [2]:
# ---------- USER CONTROLS ----------
TARGET_FAMILIES = 200          # <-- set desired number of families (e.g., 100, 200, 500)
PER_CLASS = 150                # sequences per family (balanced)
MIN_AFTER_CLEAN = 80           # drop families that become too small after cleaning
SCAN_PFAM_MAX = 20000          # how far to scan PFxxxxx IDs (increase if you want more choices)

REVIEWED_ONLY = True           # True = Swiss-Prot only (cleaner, smaller). False = bigger, noisier.
LENGTH_MIN = 50
LENGTH_MAX = 1200

# Model choice
ESM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"   # 30 layers, 150M parameters, 640-dim embeddings
MAX_LEN = 512                 # max tokens per sequence (incl. <cls>/<eos>); 1024 is slower and uses more memory
BATCH_SIZE = 4                # reduce to 2 or 1 if the GPU runs out of memory



## Retrieving Protein Sequences from UniProt (Programmatic Access)

To build our dataset, we need to retrieve protein sequences directly from UniProt.

UniProt provides a REST API that allows us to:
- Query proteins using structured search filters
- Retrieve results in different formats (TSV, JSON, FASTA)
- Handle pagination when results exceed a single page

The code below implements a small utility for downloading protein data in TSV format.

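### Key Components

**1. `UNIPROT_SEARCH_URL`**

The base REST endpoint for searching UniProtKB: `https://rest.uniprot.org/uniprotkb/search`

**2. `_next_link_from_headers()`**

UniProt returns large result sets one page at a time. When more results exist, the response carries a `Link` header with the URL of the next page (marked `rel="next"`). This helper extracts that URL with a regular expression, or returns `None` on the last page.

**3. `fetch_uniprot_tsv()`**

Sends the query with the requested `fields` in TSV format, parses each page into rows, follows the `next` links until `max_rows` rows are collected or no pages remain, and returns the result as a pandas DataFrame. A short pause between requests avoids hammering the API.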
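To see roughly what the first request of `fetch_uniprot_tsv` sends, the sketch below encodes the query parameters with the standard library and makes no network call. `build_search_url` is a hypothetical helper; `requests` performs similar URL encoding internally when given `params=`.

```python
from urllib.parse import urlencode

UNIPROT_SEARCH_URL = "https://rest.uniprot.org/uniprotkb/search"

def build_search_url(query: str, fields: str, size: int = 500, fmt: str = "tsv") -> str:
    """Return the full UniProtKB search URL for a query (spaces are encoded as '+')."""
    params = {"query": query, "format": fmt, "fields": fields, "size": size}
    return f"{UNIPROT_SEARCH_URL}?{urlencode(params)}"

url = build_search_url("(xref:pfam-PF00001) AND (fragment:false)", "accession,sequence")
print(url)
# https://rest.uniprot.org/uniprotkb/search?query=%28xref%3Apfam-PF00001%29+AND+%28fragment%3Afalse%29&format=tsv&fields=accession%2Csequence&size=500
```

Opening such a URL in a browser is a quick way to debug a query before launching a long scan.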



In [3]:
import re
import time
from typing import Dict, Optional

import pandas as pd
import requests

UNIPROT_SEARCH_URL = "https://rest.uniprot.org/uniprotkb/search"

def _next_link_from_headers(headers: Dict[str, str]) -> Optional[str]:
    # UniProt paginates with a header of the form: Link: <next-page-url>; rel="next"
    link = headers.get("Link")
    if not link:
        return None
    m = re.search(r'<([^>]+)>;\s*rel="next"', link)
    return m.group(1) if m else None

def fetch_uniprot_tsv(query: str, fields: str, size: int = 500, max_rows: int = 1000) -> pd.DataFrame:
    # Only the first request needs query params; later pages use the full "next" URL
    params = {"query": query, "format": "tsv", "fields": fields, "size": size}
    rows = []
    url = UNIPROT_SEARCH_URL

    header = None
    while url and len(rows) < max_rows:
        r = requests.get(url, params=params, timeout=60)
        r.raise_for_status()

        text = r.text.strip()
        if not text:
            break

        # Each page starts with a TSV header line, followed by one record per line
        lines = text.splitlines()
        header = lines[0].split("\t")
        for line in lines[1:]:
            rows.append(line.split("\t"))
            if len(rows) >= max_rows:
                break

        url = _next_link_from_headers(r.headers)
        params = None          # the next-page URL already encodes the query
        time.sleep(0.1)        # brief pause between requests

    if not rows:
        return pd.DataFrame(columns=fields.split(","))
    return pd.DataFrame(rows, columns=header)


In [4]:
def uniprot_count(query: str) -> int:
    """
    Returns total number of UniProt entries matching a query
    without downloading full records.
    """
    r = requests.get(
        UNIPROT_SEARCH_URL,
        params={"query": query, "format": "json", "size": 0},
        timeout=30
    )
    if r.status_code != 200:
        return 0

    return int(r.headers.get("x-total-results", "0"))


## Selecting Protein Families for the Benchmark

In this step, we identify Pfam families that contain enough protein sequences
to support a balanced multi-class classification task.

Rather than manually choosing families, we automatically scan Pfam identifiers
(PF00001, PF00002, PF00003, ...) and check how many sequences are available
in UniProt under our filtering criteria.

---

### 1. Building a Structured UniProt Query

The function `build_query_for_pfam()` constructs a structured search query
for a given Pfam ID.

Each query enforces the following constraints:

- `xref:pfam-PFxxxxx` → Protein must belong to a specific Pfam family
- `fragment:false` → Exclude incomplete protein fragments
- `reviewed:true` (optional) → Include only Swiss-Prot curated entries
- `length:[min TO max]` → Restrict protein length range

These filters ensure:
- High-quality sequences
- Reasonable length distributions
- Reduced noise in downstream analysis

---

### 2. Discovering Candidate Families

We then iterate over Pfam IDs and:

1. Construct a query for each family
2. Count how many sequences satisfy our constraints
3. Keep only families with at least `PER_CLASS` sequences

This guarantees that each selected family has sufficient data
to form a balanced classification dataset.

---

### 3. Why We Select More Than Needed

We intentionally collect 50% more candidate families than our final target (`int(TARGET_FAMILIES * 1.5)`, i.e., 300 candidates for 200 target families).

This buffer accounts for later cleaning steps, such as:

- Removing duplicate sequences
- Removing multi-label proteins
- Dropping families that shrink below a minimum size

This ensures that after cleaning, we still retain
the desired number of families.

---

### 4. Safety Check

The final assertion ensures that we discovered enough families.
If not, the user may need to:

- Disable the reviewed-only filter
- Reduce the required sequences per family
- Increase the Pfam search range

---

### Why This Step Is Important

Protein families in UniProt are highly imbalanced:
- Some families contain thousands of sequences
- Many families contain only a few

This automated selection process allows us to construct
a controlled and reproducible benchmark rather than relying
on arbitrary manual choices.




In [5]:
def build_query_for_pfam(pfam_id: str) -> str:
    reviewed = " AND (reviewed:true)" if REVIEWED_ONLY else ""
    pfam_clause = f"(xref:pfam-{pfam_id})"
    return (
        f"{pfam_clause}"
        f" AND (fragment:false)"
        f"{reviewed}"
        f" AND (length:[{LENGTH_MIN} TO {LENGTH_MAX}])"
    )

# Discover candidates with enough sequences
selected = []
for i in range(1, SCAN_PFAM_MAX + 1):
    pf = f"PF{i:05d}"
    c = uniprot_count(build_query_for_pfam(pf))
    if c >= PER_CLASS:
        selected.append(pf)

    # buffer so we don't lose too many families during cleaning
    if len(selected) >= int(TARGET_FAMILIES * 1.5):
        break

    if i % 200 == 0:
        print(f"Scanned PF00001..PF{i:05d} | candidates so far: {len(selected)}")
        time.sleep(0.05)

print("Candidate families found:", len(selected))
print("First 10:", selected[:10])

assert len(selected) >= TARGET_FAMILIES, (
    f"Not enough families found. Found {len(selected)} candidates, need {TARGET_FAMILIES}. "
    f"Try: set REVIEWED_ONLY=False, lower PER_CLASS, increase SCAN_PFAM_MAX."
)


Scanned PF00001..PF00200 | candidates so far: 122
Scanned PF00001..PF00400 | candidates so far: 213
Scanned PF00001..PF00600 | candidates so far: 297
Candidate families found: 300
First 10: ['PF00001', 'PF00004', 'PF00005', 'PF00006', 'PF00008', 'PF00009', 'PF00010', 'PF00011', 'PF00012', 'PF00013']


## Build Dataset (Part A): Standardize UniProt Column Names

UniProt can return slightly different column headers depending on the endpoint or fields.
For example, the accession column may appear as `Entry` or `Accession`.

To keep the downstream code consistent, we normalize the DataFrame so that it always has:

- `accession`
- `sequence`
- `protein_name`
- `organism_name`
- `length`

This prevents bugs when we scale up to many families.


In [6]:
def _normalize_uniprot_columns(df: pd.DataFrame) -> pd.DataFrame:
    # Map lowercase column names -> original column names (case-insensitive matching)
    cmap = {c.lower(): c for c in df.columns}

    # Try multiple possible names UniProt might use
    acc = cmap.get("accession") or cmap.get("entry")
    seq = cmap.get("sequence")
    prot = cmap.get("protein_name") or cmap.get("protein names")
    org  = cmap.get("organism_name") or cmap.get("organism")
    leng = cmap.get("length")

    # accession + sequence are mandatory; if missing, stop early with a clear error
    if acc is None or seq is None:
        raise ValueError(f"Missing accession/sequence. Columns: {list(df.columns)}")

    # Rename to our standard schema
    return df.rename(columns={
        acc: "accession",
        seq: "sequence",
        (prot or "protein_name"): "protein_name",
        (org  or "organism_name"): "organism_name",
        (leng or "length"): "length",
    }).copy()


## Build Dataset (Part B): Fetch One Family (with Caching)

For each Pfam family:
1. Build the UniProt query
2. Download up to `n` sequences
3. Clean and attach the family label
4. Cache the result locally (TSV file)

Caching is important because:
- It avoids repeating API requests
- It speeds up reruns
- It makes the notebook reproducible
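The caching pattern in `get_family_df` generalizes to any expensive step. Below is a minimal sketch, assuming pandas is available; `load_or_compute_tsv` is a hypothetical name. Unlike the notebook's function, it writes to a temporary file and then renames it, so an interrupted run does not leave a half-written cache file behind.

```python
import os
import tempfile
from typing import Callable

import pandas as pd

def load_or_compute_tsv(path: str, compute: Callable[[], pd.DataFrame]) -> pd.DataFrame:
    """Return the cached TSV at `path`, or run `compute()` and cache its result."""
    if os.path.exists(path):
        return pd.read_csv(path, sep="\t")
    df = compute()
    if not df.empty:  # empty results are not cached, so failed fetches are retried
        directory = os.path.dirname(path) or "."
        os.makedirs(directory, exist_ok=True)
        fd, tmp = tempfile.mkstemp(dir=directory, suffix=".tmp")
        os.close(fd)
        df.to_csv(tmp, sep="\t", index=False)
        os.replace(tmp, path)  # atomic rename on POSIX filesystems
    return df

# Usage: the second call is served from the cache, so the "fetch" runs once
calls = []
def fake_fetch():
    calls.append(1)
    return pd.DataFrame({"accession": ["P1"], "sequence": ["MKT"]})

with tempfile.TemporaryDirectory() as d:
    p = os.path.join(d, "demo.tsv")
    first = load_or_compute_tsv(p, fake_fetch)
    second = load_or_compute_tsv(p, fake_fetch)
print(len(calls))  # 1
```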


In [7]:
def get_family_df(label: str, pfam_id: str, n: int) -> pd.DataFrame:
    # Cache path so we don't re-download the same data every time
    cache_path = os.path.join(
        cfg.cache_dir, f"{label}_{pfam_id}_n{n}_rev{int(REVIEWED_ONLY)}.tsv"
    )

    # If cached file exists, load it immediately
    if os.path.exists(cache_path):
        return pd.read_csv(cache_path, sep="\t")

    # Build query and fetch from UniProt
    query = build_query_for_pfam(pfam_id)
    fields = "accession,protein_name,organism_name,length,sequence"
    df = fetch_uniprot_tsv(query=query, fields=fields, size=500, max_rows=n)

    # If UniProt returns nothing, return empty DataFrame
    if df.empty:
        return pd.DataFrame()

    # Normalize columns and basic cleaning
    df = _normalize_uniprot_columns(df)
    df = df.dropna(subset=["sequence"]).copy()

    # Add label column (ground-truth family)
    df["label"] = label

    # Save to cache
    df.to_csv(cache_path, sep="\t", index=False)
    return df


## Build Dataset (Part C): Collect Many Families into One Dataset

We now:
- Select the first `TARGET_FAMILIES` Pfams from our candidate list
- Create a dictionary mapping each family label to its Pfam ID
- Fetch sequences for each family and combine them into one dataset (`data`)

At the end, `data` contains:
- sequences
- metadata
- a family label for supervised learning


In [10]:
pfams = selected[:TARGET_FAMILIES]
FAMILIES = {f"PFAM_{pf}": pf for pf in pfams}

dfs = []
for label, pf in FAMILIES.items():
    df_i = get_family_df(label, pf, PER_CLASS)
    if not df_i.empty:
        dfs.append(df_i)
    print(label, pf, "rows:", len(df_i))

data = pd.concat(dfs, ignore_index=True)
print("\nRaw rows:", len(data), "Raw classes:", data["label"].nunique())


PFAM_PF00001 PF00001 rows: 150
PFAM_PF00004 PF00004 rows: 150
PFAM_PF00005 PF00005 rows: 150
PFAM_PF00006 PF00006 rows: 150
PFAM_PF00008 PF00008 rows: 150
PFAM_PF00009 PF00009 rows: 150
PFAM_PF00010 PF00010 rows: 150
PFAM_PF00011 PF00011 rows: 150
PFAM_PF00012 PF00012 rows: 150
PFAM_PF00013 PF00013 rows: 150
PFAM_PF00014 PF00014 rows: 150
PFAM_PF00016 PF00016 rows: 150
PFAM_PF00017 PF00017 rows: 150
PFAM_PF00018 PF00018 rows: 150
PFAM_PF00019 PF00019 rows: 150
PFAM_PF00022 PF00022 rows: 150
PFAM_PF00023 PF00023 rows: 150
PFAM_PF00025 PF00025 rows: 150
PFAM_PF00026 PF00026 rows: 150
PFAM_PF00027 PF00027 rows: 150
PFAM_PF00028 PF00028 rows: 150
PFAM_PF00032 PF00032 rows: 150
PFAM_PF00033 PF00033 rows: 150
PFAM_PF00034 PF00034 rows: 150
PFAM_PF00035 PF00035 rows: 150
PFAM_PF00038 PF00038 rows: 150
PFAM_PF00041 PF00041 rows: 150
PFAM_PF00042 PF00042 rows: 150
PFAM_PF00043 PF00043 rows: 150
PFAM_PF00044 PF00044 rows: 150
PFAM_PF00046 PF00046 rows: 150
PFAM_PF00047 PF00047 rows: 150
...

## Build Dataset (Part D): Remove Duplicate Sequences

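The same protein sequence can appear more than once: distinct UniProt accessions sometimes store identical sequences (e.g., the same protein from closely related strains), and a multi-domain protein can be returned by several family queries.

We remove duplicate sequences to:
- avoid train/test leakage (an identical sequence in both splits would inflate test accuracy)
- prevent the classifier from being rewarded for memorizing identical sequences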


In [11]:
data = data.drop_duplicates(subset=["sequence"]).reset_index(drop=True)
print("After de-dup sequences:", len(data))


After de-dup sequences: 27065


## Build Dataset (Part E): Remove Multi-Label Accessions

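A protein can contain multiple domains and therefore match multiple Pfam families.
If we allow such proteins, the same sequence could appear with different labels,
which breaks the assumption of single-label classification.

To keep the task clean, we remove accessions that map to more than one selected family.

Note on ordering: because Part D already kept only the first copy of each duplicated sequence, a multi-domain protein fetched under two families survives here with whichever label came first. After de-duplication every accession appears once, so this filter cannot remove anything (the row count stays at 27065). To actually exclude such proteins, the multi-label filter has to run before de-duplication.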


In [12]:
acc_n = data.groupby("accession")["label"].nunique()
data = data[data["accession"].isin(acc_n[acc_n == 1].index)].reset_index(drop=True)
print("After removing multi-label accessions:", len(data))


After removing multi-label accessions: 27065
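A minimal sketch of applying the two cleaning steps in the opposite order (ambiguous sequences first, then exact duplicates), on a toy DataFrame with hypothetical accessions and labels. Grouping by sequence rather than accession also catches identical sequences stored under different accessions; run this way, a sequence that appears under two families is dropped instead of silently keeping its first label.

```python
import pandas as pd

# Toy data: P3 (a multi-domain protein) is returned by two family queries;
# P4 and P5 are distinct accessions with an identical sequence in one family.
toy = pd.DataFrame({
    "accession": ["P1", "P2", "P3", "P3", "P4", "P5"],
    "sequence":  ["AAAA", "CCCC", "DDDD", "DDDD", "EEEE", "EEEE"],
    "label":     ["F1", "F2", "F1", "F2", "F1", "F1"],
})

def clean(df: pd.DataFrame) -> pd.DataFrame:
    # 1) Drop sequences that carry more than one family label
    n_labels = df.groupby("sequence")["label"].nunique()
    df = df[df["sequence"].isin(n_labels[n_labels == 1].index)]
    # 2) Then drop exact duplicate sequences (same family, different accessions)
    return df.drop_duplicates(subset=["sequence"]).reset_index(drop=True)

cleaned = clean(toy)
print(cleaned["accession"].tolist())  # ['P1', 'P2', 'P4']
```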


## Build Dataset (Part F): Drop Families That Became Too Small

After cleaning, some families may lose many proteins.
To keep the benchmark stable, we drop families with fewer than `MIN_AFTER_CLEAN` sequences.

This ensures each class has enough training examples.


In [13]:
vc = data["label"].value_counts()
keep = vc[vc >= MIN_AFTER_CLEAN].index
data = data[data["label"].isin(keep)].reset_index(drop=True)

print("Final rows:", len(data))
print("Final classes:", data["label"].nunique())
print("Per-class min/median/max:",
      data["label"].value_counts().min(),
      int(data["label"].value_counts().median()),
      data["label"].value_counts().max())

assert data["label"].nunique() >= int(0.9 * TARGET_FAMILIES), (
    "Too many families dropped after cleaning. Try lowering MIN_AFTER_CLEAN or PER_CLASS, "
    "or set REVIEWED_ONLY=False to get more sequences."
)


Final rows: 27026
Final classes: 197
Per-class min/median/max: 82 141 150


## Generating Protein Embeddings Using ESM-2

In this step, we convert raw protein sequences into numerical representations
using a pretrained ESM-2 transformer model.

---

### 1. Loading the Pretrained Model

We load two components:

- **Tokenizer** → Converts amino acid sequences into token IDs
- **Model** → Transformer architecture that produces contextual embeddings

The model is moved to either:
- GPU (if available), or
- CPU

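We set the model to evaluation mode (`model.eval()`), which disables dropout, since we are not fine-tuning it. When loading, `transformers` warns that the `pooler.dense` weights are newly initialized; this is harmless here because we only use `last_hidden_state`, never the pooler output.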

---

### 2. How Embedding Works

Each protein sequence is:

1. Truncated to a maximum length (`MAX_LEN`) to control memory usage.
2. Tokenized into integer IDs.
3. Passed through the transformer model.
4. Converted into contextual token embeddings.

The output of ESM-2 is a tensor of shape `[B, T, H]` (batch size × number of tokens × hidden dimension).

Each amino acid receives a contextual embedding that depends on the entire sequence.

---

### 3. Mean Pooling

Since proteins have variable lengths, we convert token-level embeddings
into a fixed-length vector using mean pooling:

- Multiply embeddings by attention mask (to ignore padding)
- Sum across sequence dimension
- Divide by number of valid tokens

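The attention mask is also 1 for the special `<cls>` and `<eos>` tokens that the ESM tokenizer adds, so they are included in the average.

This produces one vector of length `H` per protein (a `[B, H]` array per batch).
This vector is the **protein embedding**.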
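The masked mean-pooling arithmetic can be checked on a toy batch. This NumPy sketch mirrors the computation in `embed_batch` (`(out * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)`); the shapes and values are made up for illustration, and `masked_mean_pool` is a hypothetical helper name.

```python
import numpy as np

def masked_mean_pool(hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token embeddings per sequence, ignoring padded positions.

    hidden: [B, T, H] token embeddings; attention_mask: [B, T] with 1 = real token.
    Returns [B, H].
    """
    mask = attention_mask[..., None].astype(hidden.dtype)  # [B, T, 1]
    summed = (hidden * mask).sum(axis=1)                    # [B, H]
    counts = np.clip(mask.sum(axis=1), 1, None)             # [B, 1], avoids division by zero
    return summed / counts

# Batch of 2 sequences, 3 token positions, H = 2; the second sequence has 1 padding token
hidden = np.array([[[1., 2.], [3., 4.], [5., 6.]],
                   [[1., 1.], [3., 3.], [100., 100.]]])
attention_mask = np.array([[1, 1, 1],
                           [1, 1, 0]])
print(masked_mean_pool(hidden, attention_mask))
# [[3. 4.]
#  [2. 2.]]
```

The padded position (value 100) has no influence on the second protein's embedding, which is exactly what the mask is for.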

---

### 4. Embedding the Full Dataset

We process the dataset in mini-batches to:

- Reduce memory usage
- Improve GPU efficiency
- Enable scaling to large datasets

The final result:

- `X` → matrix of protein embeddings (one row per protein)
- `y` → corresponding Pfam family labels

---

### 5. What Does `X` Represent?

If the embedding dimension is `H`, then `X` has shape `(N, H)`, with one row per protein; in this run, `X` is `(27026, 640)`.

Each row represents a protein in a high-dimensional embedding space
learned from millions of protein sequences.

If ESM-2 has captured biological structure correctly,
proteins from the same family should cluster together in this space.
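One simple way to probe whether same-family proteins cluster together is a nearest-centroid check: compute each family's mean embedding and assign every protein to the most cosine-similar centroid. This is a hypothetical diagnostic, not part of the notebook's pipeline; the synthetic data below stands in for `X` and `y`, and in practice the centroids should be computed on training data only.

```python
import numpy as np

def _l2_normalize(a: np.ndarray) -> np.ndarray:
    """Scale each row to unit length so dot products become cosine similarities."""
    return a / np.clip(np.linalg.norm(a, axis=1, keepdims=True), 1e-12, None)

def nearest_centroid_predict(X_train, y_train, X_query):
    """Assign each query row to the class whose mean training embedding is most cosine-similar."""
    classes = np.unique(y_train)
    centroids = np.stack([X_train[y_train == c].mean(axis=0) for c in classes])
    sims = _l2_normalize(X_query) @ _l2_normalize(centroids).T  # [N_query, num_classes]
    return classes[sims.argmax(axis=1)]

rng = np.random.default_rng(0)
# Synthetic "embeddings": two well-separated families in 8 dimensions
X_toy = np.vstack([rng.normal(+3, 1, size=(20, 8)), rng.normal(-3, 1, size=(20, 8))])
y_toy = np.array(["FAM_A"] * 20 + ["FAM_B"] * 20)
pred = nearest_centroid_predict(X_toy, y_toy, X_toy)
print("agreement:", (pred == y_toy).mean())
```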

In [14]:
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

assert "data" in globals() and not data.empty

DEVICE = cfg.device
MODEL_NAME = ESM_MODEL_NAME

print("Embedding with:", MODEL_NAME)
print("Device:", DEVICE)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()

@torch.no_grad()
def embed_batch(seqs):
    seqs = [str(s)[:MAX_LEN] for s in seqs]
    toks = tokenizer(seqs, return_tensors="pt", padding=True, truncation=True, max_length=MAX_LEN)
    toks = {k: v.to(DEVICE) for k, v in toks.items()}
    out = model(**toks).last_hidden_state          # [B, T, H]
    mask = toks["attention_mask"].unsqueeze(-1)    # [B, T, 1]
    pooled = (out * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
    return pooled.float().cpu().numpy()

def embed_dataset(df, batch_size=BATCH_SIZE):
    seqs = df["sequence"].astype(str).tolist()
    embs = []
    for i in tqdm(range(0, len(seqs), batch_size), desc="Embedding"):
        embs.append(embed_batch(seqs[i:i+batch_size]))
    return np.vstack(embs)

X = embed_dataset(data)
y = data["label"].values

print("X shape:", X.shape)
print("Classes:", len(np.unique(y)))


Embedding with: facebook/esm2_t30_150M_UR50D
Device: cuda


Skipping import of cpp extensions due to incompatible torch version 2.9.1+cu128 for torchao version 0.14.1             Please see https://github.com/pytorch/ao/issues/2919 for more info
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t30_150M_UR50D and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Embedding: 100%|██████████| 6757/6757 [10:17<00:00, 10.94it/s]


X shape: (27026, 640)
Classes: 197


## Training a Linear Classifier on ESM-2 Embeddings

In this section, we evaluate whether ESM-2 embeddings encode
protein family information in a linearly separable way.

If a simple linear model achieves high accuracy,
it suggests that the embedding space already captures
functional structure.

---

### 1. Encoding Labels

Protein family labels are categorical (strings).
Machine learning models require numeric labels.

We use `LabelEncoder` to convert labels to integers:

- `"PFAM_PF00001"` → 0
- `"PFAM_PF00004"` → 1
- ...

This creates integer class indices for training.
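`LabelEncoder` indexes classes in sorted order of the unique labels, so the same mapping can be reproduced with NumPy alone. A minimal sketch using a few of this notebook's label strings:

```python
import numpy as np

labels = np.array(["PFAM_PF00004", "PFAM_PF00001", "PFAM_PF00005", "PFAM_PF00001"])
# np.unique returns the sorted unique labels and, for each input, its index into them
classes, y_int = np.unique(labels, return_inverse=True)
print(classes)  # ['PFAM_PF00001' 'PFAM_PF00004' 'PFAM_PF00005']
print(y_int)    # [1 0 2 0]
# Decoding indexes back into the class array (the role of LabelEncoder.inverse_transform)
print(classes[y_int])
```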

---

### 2. Train-Test Split

We split the dataset into:

- 75% training data
- 25% test data

We use **stratified sampling** to preserve class balance
across both sets.

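This keeps each family's proportion the same in both sets. Because the split is random rather than based on sequence identity, test proteins may still have close homologs in the training set, so the resulting accuracy is likely optimistic for distantly related family members.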
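Stratification means the test fraction is drawn separately within each class. The sketch below is a simplified, hypothetical NumPy version of that idea (`stratified_split`); scikit-learn's `train_test_split(..., stratify=...)` uses its own allocation rules, so its per-class counts can differ slightly from this sketch.

```python
import numpy as np

def stratified_split(y, test_size=0.25, seed=42):
    """Return (train_idx, test_idx), taking about test_size of each class into the test set."""
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for c in np.unique(y):
        idx = np.flatnonzero(y == c)
        rng.shuffle(idx)
        n_test = int(round(test_size * len(idx)))
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return np.array(sorted(train_idx)), np.array(sorted(test_idx))

# Toy labels: an imbalanced 3-class problem (150, 100, and 80 examples)
y = np.array([0] * 150 + [1] * 100 + [2] * 80)
train_idx, test_idx = stratified_split(y)
for c in range(3):
    print(c, (y[test_idx] == c).sum(), (y[train_idx] == c).sum())
```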

---

### 3. Standardizing Embeddings (Very Important)

Even though embeddings are learned representations,
their dimensions may have different scales.

We standardize each feature using `X_standardized = (X - mean) / std`, where a small constant (`1e-6`) is added to `std` to avoid division by zero.

Important:
- We compute mean and standard deviation **only from training data**
- We apply the same transformation to test data

This avoids data leakage and improves optimization stability.
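The train-only standardization can be sketched in a few lines of NumPy on synthetic data. After the transform, the training features have mean about 0 and standard deviation about 1, while the test features are shifted and scaled by the same training statistics (so their mean is only approximately 0).

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic features with very different scales per dimension
X_train = rng.normal(loc=[0.0, 50.0, -3.0], scale=[1.0, 10.0, 0.1], size=(500, 3))
X_test = rng.normal(loc=[0.0, 50.0, -3.0], scale=[1.0, 10.0, 0.1], size=(100, 3))

# Statistics come from the training split only
mu = X_train.mean(axis=0, keepdims=True)
std = X_train.std(axis=0, keepdims=True) + 1e-6

X_train_s = (X_train - mu) / std
X_test_s = (X_test - mu) / std   # same transform; nothing is re-estimated on test data

print(X_train_s.mean(axis=0).round(6), X_train_s.std(axis=0).round(3))
```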

---

### 4. Moving Data to GPU

We convert NumPy arrays into PyTorch tensors and move them to:

- GPU (if available), or
- CPU

This allows us to train efficiently at scale.

---

### 5. Mini-Batch Training

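Instead of computing each gradient on the entire training set,
we use mini-batches of 256 examples, reshuffled every epoch.

Advantages:
- Many parameter updates per epoch instead of one, which usually speeds convergence
- Gradient noise that often helps generalization
- Bounded memory per step, which matters when the full training set does not fit in GPU memory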

---

### 6. Linear Classifier

We define a single linear layer, `Linear(in_dim → num_classes)` (here 640 → 197).

It computes `logits = W·x + b`.

No hidden layers.

This is intentionally simple —
we are testing the quality of embeddings,
not building a deep classifier.

---

### 7. Loss Function

We use **CrossEntropyLoss**, which:

- Combines log-softmax and negative log-likelihood in one numerically stable operation
- Is standard for multi-class classification
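To make the linear layer and loss concrete, here is a NumPy sketch of `logits = W·x + b` followed by cross-entropy computed as log-softmax plus negative log-likelihood, which is what `torch.nn.CrossEntropyLoss` computes with its default mean reduction. The weights and inputs are made up; subtracting the row maximum before exponentiating is the standard log-sum-exp trick for numerical stability.

```python
import numpy as np

def linear_logits(X, W, b):
    """X: [N, in_dim], W: [num_classes, in_dim], b: [num_classes] -> logits [N, num_classes]."""
    return X @ W.T + b  # same weight layout as torch.nn.Linear ([out, in])

def cross_entropy(logits, targets):
    """Mean over the batch of -log softmax(logits)[target]."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # stability shift
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

# One example, 2 input features, 3 classes
X = np.array([[1.0, 0.0]])
W = np.array([[2.0, 0.0],
              [1.0, 0.0],
              [0.0, 0.0]])
b = np.zeros(3)
logits = linear_logits(X, W, b)                # [[2., 1., 0.]]
print(cross_entropy(logits, np.array([0])))    # ≈ 0.4076
```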

---

### 8. Optimization

We use SGD with:
- Learning rate = 0.2
- Momentum = 0.9
- Weight decay = 1e-4 (L2 regularization)

We train for 30 epochs.
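For reference, PyTorch's SGD with momentum and weight decay applies, per parameter, `g = grad + weight_decay * w`, then `buf = momentum * buf + g` (with `buf = g` on the first step), then `w = w - lr * buf`, assuming the defaults `dampening=0` and `nesterov=False`. A scalar sketch with this notebook's hyperparameters and a made-up constant gradient; `sgd_momentum_step` is a hypothetical helper.

```python
def sgd_momentum_step(w, grad, buf, lr=0.2, momentum=0.9, weight_decay=1e-4):
    """One PyTorch-style SGD step for a single scalar parameter.

    Returns (new_w, new_buf). Pass buf=None on the first step.
    """
    g = grad + weight_decay * w                        # L2 penalty folded into the gradient
    buf = g if buf is None else momentum * buf + g     # velocity (dampening = 0)
    return w - lr * buf, buf

w, buf = 1.0, None
for step in range(2):
    w, buf = sgd_momentum_step(w, grad=0.5, buf=buf)  # constant toy gradient
    print(step + 1, round(w, 6))
# 1 0.89998
# 2 0.709944
```

With a constant gradient the velocity approaches roughly `grad / (1 - momentum)`, i.e. up to about 10× the raw gradient for momentum 0.9, so the effective step is considerably larger than `lr * grad`.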

---

### 9. Evaluation

After training, we:

- Switch to evaluation mode
- Compute predictions on test data
- Report accuracy and class-level precision/recall/F1
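Accuracy and per-class precision, recall, and F1 (the columns of `classification_report`) can all be read off prediction counts. A minimal NumPy sketch with made-up labels; `per_class_prf` is a hypothetical helper, and classes that are never predicted get precision 0 here (scikit-learn likewise reports 0 and emits a warning by default).

```python
import numpy as np

def per_class_prf(y_true, y_pred, num_classes):
    """Return arrays (precision, recall, f1) indexed by class id."""
    precision, recall, f1 = (np.zeros(num_classes) for _ in range(3))
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        precision[c] = tp / (tp + fp) if tp + fp else 0.0
        recall[c] = tp / (tp + fn) if tp + fn else 0.0
        p, r = precision[c], recall[c]
        f1[c] = 2 * p * r / (p + r) if p + r else 0.0
    return precision, recall, f1

y_true = np.array([0, 0, 0, 1, 1, 2])
y_pred = np.array([0, 0, 1, 1, 1, 1])
print("accuracy:", (y_true == y_pred).mean())  # 4 of 6 correct
p, r, f = per_class_prf(y_true, y_pred, 3)
print(p.round(3), r.round(3), f.round(3))
```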

---

## What Are We Testing?

If the classifier achieves high accuracy,
it suggests that:

- Protein families are largely linearly separable in the embedding space
- ESM-2 embeddings encode functional similarity
- No fine-tuning is required for this task

This is evidence for the representational power of the pretrained model.










In [15]:
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# --- Encode labels ---
le = LabelEncoder()
y_int = le.fit_transform(y)

X_train, X_test, y_train, y_test = train_test_split(
    X, y_int, test_size=0.25, random_state=42, stratify=y_int
)

# --- Standardize using train stats only (VERY IMPORTANT) ---
mu = X_train.mean(axis=0, keepdims=True)
std = X_train.std(axis=0, keepdims=True) + 1e-6
X_train_s = (X_train - mu) / std
X_test_s  = (X_test  - mu) / std

device = cfg.device
num_classes = len(le.classes_)
in_dim = X_train_s.shape[1]

# --- Tensors on GPU ---
X_train_t = torch.tensor(X_train_s, dtype=torch.float32, device=device)
y_train_t = torch.tensor(y_train,   dtype=torch.long,   device=device)
X_test_t  = torch.tensor(X_test_s,  dtype=torch.float32, device=device)
y_test_t  = torch.tensor(y_test,    dtype=torch.long,   device=device)

# --- DataLoader (mini-batches) ---
ds = TensorDataset(X_train_t, y_train_t)
loader = DataLoader(ds, batch_size=256, shuffle=True)

# --- Linear classifier ---
clf = torch.nn.Linear(in_dim, num_classes).to(device)

# Try SGD for linear models (often better behaved than Adam here)
opt = torch.optim.SGD(clf.parameters(), lr=0.2, momentum=0.9, weight_decay=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()

# --- Train ---
clf.train()
for epoch in range(30):
    total_loss = 0.0
    for xb, yb in loader:
        opt.zero_grad()
        logits = clf(xb)
        loss = loss_fn(logits, yb)
        loss.backward()
        opt.step()
        total_loss += loss.item() * xb.size(0)

    avg_loss = total_loss / len(loader.dataset)
    if (epoch+1) % 5 == 0 or epoch == 0:
        print(f"epoch {epoch+1:02d} | loss {avg_loss:.4f}")

# --- Eval ---
clf.eval()
with torch.no_grad():
    pred = clf(X_test_t).argmax(dim=1).cpu().numpy()
    y_true = y_test_t.cpu().numpy()

from sklearn.metrics import accuracy_score, classification_report
print("Accuracy:", accuracy_score(y_true, pred))
print(classification_report(y_true, pred, digits=3, target_names=le.classes_))


epoch 01 | loss 0.6798
epoch 05 | loss 0.0399
epoch 10 | loss 0.0232
epoch 15 | loss 0.0217
epoch 20 | loss 0.0179
epoch 25 | loss 0.0162
epoch 30 | loss 0.0155
Accuracy: 0.9758768684327365
              precision    recall  f1-score   support

PFAM_PF00001      1.000     1.000     1.000        38
PFAM_PF00004      0.921     0.946     0.933        37
PFAM_PF00005      1.000     0.973     0.986        37
PFAM_PF00006      0.714     0.946     0.814        37
PFAM_PF00008      0.921     0.921     0.921        38
PFAM_PF00009      1.000     0.970     0.985        33
PFAM_PF00010      0.925     0.974     0.949        38
PFAM_PF00011      1.000     0.972     0.986        36
PFAM_PF00012      1.000     1.000     1.000        34
PFAM_PF00013      1.000     0.943     0.971        35
PFAM_PF00014      1.000     1.000     1.000        37
PFAM_PF00016      1.000     1.000     1.000        36
PFAM_PF00017      0.854     0.946     0.897        37
PFAM_PF00018      0.818     0.692     0.750        26

# Conclusion

In this notebook, we explored how pretrained transformer models can be used
for protein sequence analysis.

We started with raw protein sequences from UniProt and constructed
a labeled dataset based on Pfam protein families.

We then:

1. Used ESM-2 to convert each protein sequence into a fixed-length embedding.
2. Trained a simple linear classifier on these embeddings.
3. Evaluated the model against ground-truth family labels.

---

## What Did We Learn?

- ESM-2 embeddings contain meaningful biological information.
- Protein families are largely separable with a simple linear model (about 97.6% test accuracy across 197 families).
- No sequence alignment or handcrafted features were required.
- We did not fine-tune ESM-2 — we used it purely as a feature extractor.

This shows that pretrained protein language models
can serve as powerful general-purpose representations for downstream tasks.

---

## Next Steps

To further explore model performance, we can:

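- Compare other classifiers (e.g., SVMs or small neural networks). A single linear layer trained with cross-entropy is already a multinomial logistic regression, so scikit-learn's `LogisticRegression` mainly serves as a reference implementation.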
- Compare different ESM-2 model sizes.
- Evaluate performance across more protein families.
- Analyze confusion between biologically related families.

---

This notebook demonstrates a practical workflow:

Raw Sequences → ESM-2 Embeddings → Linear Classifier → Family Prediction

This pipeline can be extended to many other biological prediction tasks.
