<a href="https://colab.research.google.com/github/luquelab/lab-bioinformatics-workshops/blob/main/lab-bioinformatics-workshops/tree/main/protein_language_models/LAB_WORKSHOP_llm_embedding_plus_prediction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction to language models in bioinformatics

In this hands-on session, we will:
- Use a **protein language model (PLM)** (ProtBert-BFD) to create **embeddings** of viral proteins.
- Run a small **classifier** (multilayer perceptron, MLP) to predict functional categories.
- Inspect the actual data objects (embeddings, predictions).

**Key concepts**
- **Neural network:** a stack of linear + non-linear transformations learned from data.
- **Transformer:** an architecture that uses **self-attention** to relate all tokens (amino acids) to each other.
- **Embedding:** a vector of numbers that encodes semantic/functional/structural information in a continuous space.
- **Classifier (MLP):** a few dense layers that map embeddings to probabilities over categories.

> We’ll introduce concepts as we go, right next to the code that uses them.


## 0) Housekeeping & file panel
We’ll work in `/content`. The next cell removes Colab’s demo folder and prints where we are.  
Use the left **Files** panel to upload your FASTA (drag-and-drop).


In [None]:
%%bash
set -euo pipefail
cd /content
rm -rf sample_data
echo "Working dir:"; pwd; ls -la


## 0.1) Get the auxiliary code (GitHub clone)

This notebook is based on the paper [**Large language models improve annotation of prokaryotic viral proteins**](https://www.nature.com/articles/s41564-023-01584-8) (Flamholz et al., 2024).

We’ll pull the project repository that contains the utility scripts we’ll use:

- `scripts/embed_faa.py` — runs the **protein language model** (ProtBert-BFD) to produce embeddings  
- `scripts/predict_function.py` — runs the **classifier (MLP)** to predict functional categories  
- `plm_vpf_embed.yml` — the exact dependencies for the **embedding** environment (older stack)  

> The Python environment we create next needs the repo’s YAML (`plm_vpf_embed.yml`), so we clone first to make sure we can build that env.


In [None]:
%%bash
set -euo pipefail
cd /content
if [[ ! -d viral-protein-function-plm ]]; then
  git clone https://github.com/kellylab/viral-protein-function-plm.git
else
  echo "Repo already present."
fi


## 1) Environments for reproducibility
We’ll use **two conda environments**:
- **`plm_vpf_embed` (Python 3.7):** matches the older stack required by the embedding code (BioTransformers + Transformers 4.8.x).
- **`plm_vpf_predict` (Python 3.10):** modern enough to install TensorFlow 2.12 for the classifier.

> Pinning versions keeps today’s workshop stable even if upstream packages change.


In [None]:
%%bash
set -euo pipefail
cd /content

# Install Miniconda (clean reinstall)
rm -rf /opt/conda
wget -q -P /tmp https://repo.anaconda.com/miniconda/Miniconda3-py311_23.11.0-2-Linux-x86_64.sh
bash /tmp/Miniconda3-py311_23.11.0-2-Linux-x86_64.sh -b -p /opt/conda
rm -f /tmp/Miniconda3-py311_23.11.0-2-Linux-x86_64.sh

# Create env from repo YAML (Py3.7 + bio-transformers stack)
source /opt/conda/etc/profile.d/conda.sh
conda env create -f /content/viral-protein-function-plm/plm_vpf_embed.yml

# Lock versions that we know work for the old stack
conda activate plm_vpf_embed
python -m pip -q install --force-reinstall \
  "numpy==1.18.5" "transformers==4.8.2" "tokenizers==0.10.3" \
  "huggingface_hub==0.0.12" "requests==2.31.0"

# Persist both HF endpoint vars so even ancient libs behave
conda env config vars set HUGGINGFACE_CO_URL="https://huggingface.co" HF_ENDPOINT="https://huggingface.co"
conda deactivate
echo "plm_vpf_embed created."


## 2) Prediction environment (TensorFlow 2.12)
The classifier we’ll run later was written against TF 2.12. Newer Colab runtimes use Python 3.12, for which TF 2.12 is not available, so we create a clean **Py3.10** env specifically for predictions.


In [None]:
%%bash
set -euo pipefail
source /opt/conda/etc/profile.d/conda.sh

# Create the prediction env if it doesn't exist
if ! conda env list | grep -q "^plm_vpf_predict"; then
  conda create -y -n plm_vpf_predict python=3.10
fi

conda activate plm_vpf_predict
python -m pip -q install --no-cache-dir \
  "tensorflow==2.12.0" \
  "seaborn" \
  "scikit-learn==1.1.3" \
  "pandas" \
  "matplotlib" \
  "ipykernel"

python - <<'PY'
import sys, tensorflow as tf, sklearn, pandas as pd, matplotlib
print("Python:", sys.version.split()[0])
print("TF:", tf.__version__, "| sklearn:", sklearn.__version__)
PY


## 3) Transformers & model caching (ProtBert-BFD)
To make the code robust, we download the model to the local filesystem so the old libraries can load it reliably and offline, rather than attempting to fetch updates from the Hugging Face Hub at run time.

We’ll use **ProtBert-BFD**, an artificial intelligence transformer model designed to read and understand protein sequences, similar to how language models handle words. It was trained using masked-language modeling, a method where parts of a protein sequence (like certain amino acids) are hidden, and the model learns to predict these missing parts.
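
As a rough illustration of the masked-language-modeling idea (not ProtBert-BFD's actual pre-training code), the sketch below hides a random subset of residues and records the answers a model would be asked to recover. The function name `mask_sequence`, the 15% masking rate, the `[MASK]` placeholder, and the example sequence are all assumptions chosen for illustration.

```python
import random

def mask_sequence(seq, mask_rate=0.15, mask_token="[MASK]", seed=0):
    """Hide a random subset of residues; return the masked tokens and the hidden answers."""
    tokens = list(seq)
    if not tokens:
        return [], {}
    rng = random.Random(seed)
    # Mask at least one residue, never more than the sequence length
    n_mask = min(len(tokens), max(1, round(len(tokens) * mask_rate)))
    positions = sorted(rng.sample(range(len(tokens)), n_mask))
    targets = {i: tokens[i] for i in positions}  # what the model must predict
    for i in positions:
        tokens[i] = mask_token
    return tokens, targets

masked, targets = mask_sequence("MKTAYIAKQRQISFVKSHFSRQ")
print(" ".join(masked))
print("Hidden residues (position -> amino acid):", targets)
```

During training, the model sees only the masked tokens and is scored on how well it predicts the entries in `targets`.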

A transformer is a type of neural network that takes in a sequence and processes the whole sequence at once, instead of step by step. Transformers use attention mechanisms to decide which parts matter most for predicting the best output.

Specifically, ProtBert uses **self-attention**. The model looks at the entire input sequence and figures out how each part (say, each word) relates to every other part, not just the one before or after it. For example, it helps the model know that "sky" and "blue" are connected, even if they're far apart in the sentence.
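
To make self-attention concrete, here is a minimal single-head sketch in plain Python. It is a simplification under stated assumptions: queries, keys, and values are all set equal to the input vectors (a real transformer applies learned projections and uses many heads and layers), and the three toy "residue" vectors are made-up numbers.

```python
import math

def softmax(xs):
    """Turn a list of scores into weights that sum to 1."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(X):
    """Single-head scaled dot-product self-attention with Q = K = V = X (an assumption)."""
    d = len(X[0])
    outputs, weights = [], []
    for q in X:
        # Compare this position with every position in the sequence
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in X]
        w = softmax(scores)
        weights.append(w)
        # New representation = attention-weighted average of all positions
        outputs.append([sum(w[j] * X[j][c] for j in range(len(X))) for c in range(d)])
    return outputs, weights

# Three toy "residue" vectors (made-up numbers)
X = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
outputs, weights = self_attention(X)
for row in weights:
    print([round(w, 3) for w in row])
```

Each printed row shows how strongly one position attends to every position, including distant ones. Here the first vector gives more weight to the second (similar) vector than to the third.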


In [None]:
%%bash
set -euo pipefail
# Use system Python for a modern huggingface_hub to snapshot the model
/usr/bin/python3 -m pip -q install "huggingface_hub>=0.16"

# Download to the shared cache
SNAP=$(/usr/bin/python3 - <<'PY'
from huggingface_hub import snapshot_download
p = snapshot_download("Rostlab/prot_bert_bfd")
print(p)
PY
)

echo "Snapshot at: $SNAP"

# Copy snapshot files into a folder named exactly like the model ID inside the repo
mkdir -p /content/viral-protein-function-plm/Rostlab/prot_bert_bfd
cp -f "$SNAP"/* /content/viral-protein-function-plm/Rostlab/prot_bert_bfd/
ls -l /content/viral-protein-function-plm/Rostlab/prot_bert_bfd


## 4) Inputs and file layout
- **FASTA:** one or more protein sequences.
- **OUT_DIR:** folder where results are written. During embedding it lives inside the repo folder; after prediction we move it up to `/content`.

Notes:
- Avoid parentheses in filenames.
- GPU helps during embedding, but CPU works too (just slower).
- For speed, consider a small subset.
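
One way to build a small subset is to keep only the first few records of a FASTA file. The sketch below is a minimal stdlib-only example; the helper names `read_fasta` and `subset_fasta` and the toy records are hypothetical and not part of the workshop repo.

```python
def read_fasta(lines):
    """Yield (header, sequence) pairs from an iterable of FASTA lines."""
    header, chunks = None, []
    for line in lines:
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            if header is not None:
                yield header, "".join(chunks)
            header, chunks = line[1:], []
        elif header is not None:
            chunks.append(line)  # sequences may wrap over several lines
    if header is not None:
        yield header, "".join(chunks)

def subset_fasta(lines, n):
    """Return FASTA text containing only the first n records."""
    out = []
    for i, (header, seq) in enumerate(read_fasta(lines)):
        if i >= n:
            break
        out.append(f">{header}\n{seq}\n")
    return "".join(out)

example = """>prot1 hypothetical
MKT
AYI
>prot2
MSE
>prot3
MAA
""".splitlines()
print(subset_fasta(example, 2))
```

To apply this to an uploaded file, pass an open file handle instead of `example` and write the returned text to a new `.faa` file in `/content`, using a name without parentheses.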


In [None]:
# @title Set your inputs (edit here)
import os, sys

FASTA_FILE = "viral_proteins_sample.faa"  # @param {type:"string"}
OUT_DIR    = "viral_proteins_sample_out"  # @param {type:"string"}

# Resolve paths exactly like the original
if FASTA_FILE == "test.faa":
    input_path = "/content/viral-protein-function-plm/test/test.faa"
else:
    input_path = f"/content/{FASTA_FILE}"
    if not os.path.isfile(input_path):
        raise FileNotFoundError(f"{input_path} not found in /content — upload it via the left Files panel.")

out_path = OUT_DIR if OUT_DIR.strip() else "output"

print("input_path:", input_path)
print("out_path  :", out_path)

# Export for bash cells (persist to the notebook process env)
%env INPUT_PATH=$input_path
%env OUT_PATH=$out_path


## 5) Embeddings (what we’re about to compute)
**Embeddings** turn each protein into a vector of numbers so that “similar” proteins (by sequence context and learned biology) end up **nearby** in a multidimensional space.

Pipeline we run next:
1. Tokenize proteins
     - The sequence of amino acids (letters) in a protein is converted into a list of tokens, so the model can read and analyze it.
2. Run ProtBert-BFD to produce **per-residue** representations.
     - Each amino acid token is passed through the ProtBert-BFD model, which produces a set of numbers (an embedding) that represents that amino acid in its context within the whole protein. This gives a vector for every residue (amino acid) in the sequence.
3. Pool to a **per-protein** vector (e.g., mean across residues).
     - To summarize the whole protein, all the per-residue vectors are combined (for example, by averaging) into a single vector representing the entire protein's properties. This is called "pooling".
4. Save a Python pickle: `{basename}_embeddings_dict.pkl` mapping sequence IDs to embedding vectors.
     - The resulting per-protein vectors (the embeddings) are saved in a Python `.pkl` file.

> We can even pool per-protein vectors to get one single vector representing a whole genome (genome-language-models).
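
Steps 1 and 3 of the pipeline above can be sketched in a few lines. The tokenization follows the preprocessing described on the ProtBert model card (one token per residue, with the rare residues U, Z, O, and B mapped to X); whether `embed_faa.py` does exactly this internally is an assumption. The per-residue vectors are made-up stand-ins with D = 3, not real model output, and the function names are hypothetical.

```python
import re

def protbert_style_tokens(seq):
    """Upper-case, map rare residues U/Z/O/B to X, and split into one token per residue."""
    cleaned = re.sub(r"[UZOB]", "X", seq.upper())
    return list(cleaned)

def mean_pool(per_residue):
    """Average L vectors of length D into a single vector of length D."""
    n_residues = len(per_residue)
    dim = len(per_residue[0])
    return [sum(vec[d] for vec in per_residue) / n_residues for d in range(dim)]

tokens = protbert_style_tokens("mkuA")
print(tokens)  # ['M', 'K', 'X', 'A']

# Stand-in per-residue vectors, one per token (made-up values)
per_residue = [[1.0, 0.0, 2.0], [3.0, 0.0, 2.0], [2.0, 3.0, 2.0], [2.0, 1.0, 2.0]]
print(mean_pool(per_residue))  # [2.0, 1.0, 2.0]
```

Whatever the number of residues, `mean_pool` returns a vector of length D, which is why proteins of different lengths end up with embeddings of the same size.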

In [None]:
%%bash
set -euo pipefail
source /opt/conda/etc/profile.d/conda.sh
conda activate plm_vpf_embed

# Force offline + legacy endpoints (belt & suspenders)
export TRANSFORMERS_OFFLINE=1
export HUGGINGFACE_CO_URL="https://huggingface.co"
export HF_ENDPOINT="https://huggingface.co"

# Ensure util sets env vars inside Python (idempotent)
UTIL="/content/viral-protein-function-plm/scripts/protbert_bfd_embed_utils.py"
grep -q "HUGGINGFACE_CO_URL" "$UTIL" || python - <<'PY'
import re
p="/content/viral-protein-function-plm/scripts/protbert_bfd_embed_utils.py"
s=open(p,"r",encoding="utf-8").read()
inject="import os\nos.environ.setdefault('HUGGINGFACE_CO_URL','https://huggingface.co')\nos.environ.setdefault('HF_ENDPOINT','https://huggingface.co')\n"
m=re.search(r"^(from|import).*(?:\n(?:from|import).*)*", s, flags=re.MULTILINE)
s = (s[:m.end()] + "\n" + inject + s[m.end():]) if m else inject + s
open(p,"w",encoding="utf-8").write(s)
print("Patched:", p)
PY

# Make sure the local model folder exists (from Step 3)
test -f /content/viral-protein-function-plm/Rostlab/prot_bert_bfd/config.json

# Auto-detect GPU and pick num_gpus accordingly
NUM_GPUS=$(python - <<'PY'
try:
    import torch
    print(1 if torch.cuda.is_available() else 0)
except Exception:
    print(0)
PY
)
echo "Using num_gpus = $NUM_GPUS"

cd /content/viral-protein-function-plm
python scripts/embed_faa.py -faa "${INPUT_PATH:?Missing INPUT_PATH}" -out "${OUT_PATH:?Missing OUT_PATH}" --num_gpus "$NUM_GPUS"
echo "Embedding done → /content/viral-protein-function-plm/${OUT_PATH}"


In [None]:
# Peek into the embeddings we just created
import os, pickle, numpy as np

out_dir = os.environ.get("OUT_PATH", "viral_proteins_sample_out")  # fallback matches the inputs cell default
repo_dir = "/content/viral-protein-function-plm"
pkl_path = os.path.join(repo_dir, out_dir, os.path.basename(os.environ.get("INPUT_PATH","")).split(".faa")[0] + "_embeddings_dict.pkl")

print("Embeddings file:", pkl_path)
assert os.path.exists(pkl_path), "Embeddings pickle not found. Did the embedding step finish?"

# Use a context manager so the file handle is closed after loading
with open(pkl_path, "rb") as fh:
    emb = pickle.load(fh)
print(f"Number of proteins embedded: {len(emb):,}")

# Take one example
first_id = next(iter(emb))
vec = emb[first_id]
arr = np.asarray(vec)

print("\nExample ID:", first_id[:80], "...")
print("Vector dtype/shape:", arr.dtype, arr.shape)
print("First 8 numbers:", np.round(arr.ravel()[:8], 4))

# Simple quality checks across all embeddings
lengths = np.array([np.asarray(v).shape[-1] for v in emb.values()])
print("\nAll vectors have the same length?", np.all(lengths == lengths[0]))
print("Embedding dimension (D):", int(lengths[0]))


### What you’re seeing
Each protein ID maps to a **D-dimensional vector** (expect D = **1024**, the hidden size of ProtBert-BFD; the cell above prints the actual value). These numbers are not arbitrary: during pre-training, the model learned to place **functionally/contextually related** proteins closer together in this space.

Because we pooled per-residue representations into one vector per protein, every vector has the **same length** regardless of sequence length. That makes the vectors directly usable as classifier input, much as a sequence alignment puts sequences of different lengths into shared columns so they can be compared.
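
A common way to quantify "closer in this space" is cosine similarity. The sketch below uses three made-up 3-dimensional vectors in place of real embeddings; the IDs `protA`, `protB`, and `protC` are hypothetical.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two vectors: 1 = same direction, 0 = orthogonal."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

# Toy vectors standing in for embeddings (made-up values)
toy = {
    "protA": [1.0, 2.0, 0.0],
    "protB": [2.0, 4.0, 0.1],   # nearly parallel to protA
    "protC": [-1.0, 0.5, 3.0],  # orthogonal to protA
}
print("A vs B:", round(cosine_similarity(toy["protA"], toy["protB"]), 3))
print("A vs C:", round(cosine_similarity(toy["protA"], toy["protC"]), 3))
```

The same function should work on two entries of the `emb` dictionary loaded above, provided each vector is first flattened to one dimension.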


## 6) From embeddings to functions (multilayer perceptron classifier)
An **MLP (multilayer perceptron)** is a simple neural network classifier. It takes the embedding vector as input and passes it through 1–3 dense layers with non-linearities, ending in a **softmax** that outputs **probabilities** over categories.
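
The forward pass of such a network can be sketched in plain Python. This is a toy illustration with random, untrained weights and hypothetical layer sizes (8 inputs, 4 hidden units, 3 categories); the real classifier's architecture and trained weights live in the repository and are not reproduced here.

```python
import math
import random

def dense(x, W, b):
    """One dense layer: each row of W holds the weights of one output unit."""
    return [sum(xi * wi for xi, wi in zip(x, row)) + bj for row, bj in zip(W, b)]

def relu(v):
    """Non-linearity applied between dense layers."""
    return [max(0.0, a) for a in v]

def softmax(v):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(v)
    exps = [math.exp(a - m) for a in v]
    total = sum(exps)
    return [e / total for e in exps]

def mlp_forward(x, layers):
    """layers is a list of (W, b) pairs; ReLU between layers, softmax at the end."""
    h = x
    for W, b in layers[:-1]:
        h = relu(dense(h, W, b))
    W, b = layers[-1]
    return softmax(dense(h, W, b))

rng = random.Random(0)

def random_layer(n_in, n_out):
    """Untrained layer with small random weights and zero biases (illustration only)."""
    W = [[rng.uniform(-0.5, 0.5) for _ in range(n_in)] for _ in range(n_out)]
    return W, [0.0] * n_out

layers = [random_layer(8, 4), random_layer(4, 3)]  # 8-dim input -> 4 hidden -> 3 categories
x = [rng.uniform(-1, 1) for _ in range(8)]         # stand-in for an embedding
probs = mlp_forward(x, layers)
print([round(p, 3) for p in probs], "sum =", round(sum(probs), 6))
```

Training adjusts the weights so that the highest probability lands on the correct category; the structure of the computation stays the same.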

What you’ll get next:
- `<basename>_predictions.csv`: the **top prediction** for each protein.
- `<basename>_probabilities.csv`: the full probability vector for each protein (one column per category).
- `prediction_heatmap.png`: a quick visual summary.
- `<basename>_protbert_bfd.pkl`: the serialized embeddings.

> **Interpretation:** High top-1 probability + large margin to the second best class is generally more reliable than many near-ties.
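
The margin idea can be computed directly from one row of probabilities. In the sketch below, the category names and numbers are made up for illustration and are not the classifier's real labels; the helper `top1_and_margin` is hypothetical and assumes at least two categories.

```python
def top1_and_margin(probs_by_class):
    """Return (best_class, best_prob, margin over the runner-up) for one protein."""
    ranked = sorted(probs_by_class.items(), key=lambda kv: kv[1], reverse=True)
    (best, p1), (_, p2) = ranked[0], ranked[1]
    return best, p1, p1 - p2

# Made-up probability rows for two proteins
confident = {"structural": 0.90, "lysis": 0.06, "replication": 0.04}
ambiguous = {"structural": 0.36, "lysis": 0.34, "replication": 0.30}

for name, row in [("confident", confident), ("ambiguous", ambiguous)]:
    best, p1, margin = top1_and_margin(row)
    print(f"{name}: top-1 = {best} (p = {p1:.2f}), margin = {margin:.2f}")
```

Both proteins get the same top-1 label, but only the first has a large margin, so it is the call to trust more.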


In [None]:
%%bash
set -euo pipefail
source /opt/conda/etc/profile.d/conda.sh
conda activate plm_vpf_predict

# Reuse the variables from your inputs step; fall back to that cell's defaults if missing
INPUT_PATH="${INPUT_PATH:-/content/viral_proteins_sample.faa}"
OUT_PATH="${OUT_PATH:-viral_proteins_sample_out}"

cd /content/viral-protein-function-plm
python scripts/predict_function.py -faa "$INPUT_PATH" -out "$OUT_PATH" \
  --output_predictions --prediction_heatmap --output_embeddings

# Match the original notebook: move the results dir up one level
mv "$OUT_PATH" ../
cd /content

echo "Results at /content/$OUT_PATH:"
ls -l "/content/$OUT_PATH" | sed -n '1,200p'


In [None]:
# Inspect classifier outputs
import os, pandas as pd

out_dir = os.environ.get("OUT_PATH", "viral_proteins_sample_out")  # fallback matches the inputs cell default
final_dir = f"/content/{out_dir}"

# The notebook moved OUT_DIR up one level after prediction
pred_csv = os.path.join(final_dir, os.path.basename(os.environ.get("INPUT_PATH","")).split(".faa")[0] + "_predictions.csv")
probs_csv = os.path.join(final_dir, os.path.basename(os.environ.get("INPUT_PATH","")).split(".faa")[0] + "_probabilities.csv")

print("Predictions file:", pred_csv, "| exists:", os.path.exists(pred_csv))
print("Probabilities file:", probs_csv, "| exists:", os.path.exists(probs_csv))

if os.path.exists(pred_csv):
    preds = pd.read_csv(pred_csv)
    display(preds.head(10))

if os.path.exists(probs_csv):
    probs = pd.read_csv(probs_csv, index_col=0)
    print("\nProbabilities shape:", probs.shape)
    # Show the top-5 most confident calls
    top_conf = probs.max(axis=1).sort_values(ascending=False).head(5)
    print("\nTop-5 most confident proteins:")
    top_df = pd.DataFrame({"max_prob": top_conf})
    # Join the top prediction only if the predictions file was loaded above
    if os.path.exists(pred_csv):
        top_df = top_df.join(preds.set_index(preds.columns[0]), how="left")
    display(top_df)
