# Sequence Representation Visualization with ProkBERT

This guide shows how to extract and visualize sequence embeddings with ProkBERT-mini, using genomic features of the ESKAPE pathogens as the example dataset.
The workflow has four steps:
1. **Model Loading**
2. **Dataset Preparation**
3. **Model Evaluation**
4. **Results Visualization**


### Setup and Installation

Before we start, make sure all required libraries are installed. Besides `transformers`, this notebook uses `datasets` to load the data, `umap-learn` for dimensionality reduction, and `seaborn` for visualization.
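To confirm the installation before running the rest of the notebook, the small standard-library sketch below looks up the installed version of each distribution. The helper name `installed_versions` is illustrative and not part of any of these packages. It reports `None` instead of raising when a package is missing.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: installed version, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for pkg, ver in installed_versions(["umap-learn", "seaborn", "datasets", "transformers"]).items():
        print(f"{pkg}: {ver or 'NOT INSTALLED'}")
```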


In [None]:
# umap, seaborn, HF datasets
!pip install umap-learn seaborn datasets

# Imports

import pandas as pd
import os
from transformers import AutoTokenizer, AutoModel, TrainingArguments, Trainer
from datasets import load_dataset
import seaborn as sns
import matplotlib.pyplot as plt
import umap



### Enabling and testing the GPU (if you are using Google Colab)

To enable a GPU for the notebook:

- Navigate to **Edit → Notebook settings**.
- Select **GPU** from the **Hardware accelerator** drop-down.
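After switching the runtime, a quick check shows whether PyTorch can actually see the GPU. The sketch below is written so that it also runs where PyTorch is not installed. If it prints `cpu`, the notebook still works, only more slowly; the Hugging Face `Trainer` used later places the model on the GPU automatically when one is available. The function name `describe_device` is illustrative.

```python
import importlib.util


def describe_device():
    """Report which accelerator PyTorch can see, without failing if torch is absent."""
    if importlib.util.find_spec("torch") is None:
        return "torch not installed"
    import torch

    if torch.cuda.is_available():
        return f"cuda: {torch.cuda.get_device_name(0)}"
    return "cpu"


print(describe_device())
```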


### Loading the model
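In this step, we load the pretrained ProkBERT-mini model as a base encoder (without a task-specific head), so that it returns the hidden states we use as sequence embeddings. The tokenizer must be loaded from the same checkpoint as the model: each ProkBERT checkpoint is tied to specific tokenization parameters, and a mismatched tokenizer would produce token IDs the model was not trained on. `trust_remote_code=True` is required because the ProkBERT model and tokenizer classes are defined in the checkpoint's repository on the Hugging Face Hub.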
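To see why the tokenizer has to match the checkpoint: ProkBERT splits DNA into overlapping k-mers, so the k-mer size and the shift (the step between consecutive k-mers) change which tokens a sequence produces. The toy sketch below illustrates only that idea. The values `k=6` and `shift=2` are assumptions for illustration, and the function `kmerize` is hypothetical; the real tokenizer also adds special tokens, maps k-mers to vocabulary IDs, and handles sequence ends.

```python
def kmerize(seq, k=6, shift=2):
    """Split a DNA string into overlapping k-mers, starting a new k-mer every `shift` bases.

    Toy illustration only: k and shift are assumed values, and trailing bases that
    do not fill a complete k-mer are simply dropped here.
    """
    seq = seq.upper()
    return [seq[i:i + k] for i in range(0, len(seq) - k + 1, shift)]


print(kmerize("ATGCGTACGTTAG"))           # shift=2: a new 6-mer every 2 bases
print(kmerize("ATGCGTACGTTAG", shift=1))  # shift=1: roughly twice as many tokens
```

Changing `shift` from 2 to 1 roughly doubles the number of tokens for the same sequence, which is why a model trained with one setting cannot be fed tokens produced with another.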

**Embeddings:**

Embeddings are dense vector representations of data (here, genomic sequences) in which similar sequences lie close together in the vector space. Because a transformer's hidden states depend on the surrounding tokens, these vectors capture the context of each sequence, which makes them useful for analysis and comparison. Embeddings extracted from ProkBERT can be used for clustering, similarity search, or as input features for downstream machine learning models.
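"Close together" is commonly measured with cosine similarity. The minimal NumPy sketch below computes it for every pair of rows in an embedding matrix. The three toy vectors stand in for real embeddings, and `cosine_similarity_matrix` is an illustrative helper, not part of ProkBERT; the same function could be applied to the `representations` array computed later in this notebook.

```python
import numpy as np


def cosine_similarity_matrix(X):
    """Pairwise cosine similarity between the rows of X (n_samples x dim)."""
    X = np.asarray(X, dtype=float)
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    unit = X / np.clip(norms, 1e-12, None)  # normalize rows, avoiding division by zero
    return unit @ unit.T


# Toy 3-dimensional "embeddings": rows 0 and 1 point in similar directions, row 2 does not.
emb = np.array([[1.0, 0.9, 0.0],
                [0.9, 1.0, 0.1],
                [0.0, 0.1, 1.0]])
sims = cosine_similarity_matrix(emb)
print(np.round(sims, 2))
```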


In [None]:
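# ProkBERT-mini checkpoint (the "long" variant) from the Hugging Face Hub
model_name_path = 'neuralbioinfo/prokbert-mini-long'
# Load the tokenizer from the same checkpoint so tokenization parameters match the model
tokenizer = AutoTokenizer.from_pretrained(model_name_path, trust_remote_code=True)
# Load the base encoder (no task-specific head); its hidden states serve as embeddings
model = AutoModel.from_pretrained(model_name_path, trust_remote_code=True)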




### Dataset Preparation and Tokenization

This section prepares the dataset for tokenization and model inference. The ESKAPE split is loaded from the Hugging Face Hub and shuffled, and the first 1000 samples are then selected for quick prototyping, so the subset is a random sample rather than the first 1000 rows of the original dataset. Finally, the sequences are tokenized with the ProkBERT tokenizer.
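Shuffling and then taking the first rows is equivalent to drawing a uniform random subset. The plain-Python sketch below mimics that with a fixed seed so the same subset is drawn on every run; `shuffled_subset_indices` is a hypothetical helper. In the `datasets` library, `Dataset.shuffle` returns a new dataset instead of modifying it in place, so its result has to be reassigned (for example `dataset = dataset.shuffle(seed=42)`); indices like the ones below could also be passed directly to `Dataset.select`.

```python
import random


def shuffled_subset_indices(n_rows, n_keep, seed=42):
    """Return n_keep distinct row indices drawn uniformly at random (reproducible via seed)."""
    if n_keep > n_rows:
        raise ValueError("n_keep cannot exceed n_rows")
    order = list(range(n_rows))
    random.Random(seed).shuffle(order)  # seeded RNG: the same order on every run
    return order[:n_keep]


idx = shuffled_subset_indices(n_rows=10, n_keep=4)
print(idx)
```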

Sequences are padded and truncated to a maximum length of 512 tokens, and tokenization runs in parallel on all available CPU cores (`num_proc`). The result can be passed directly to the ProkBERT model.
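Padding and truncation force every sequence to a fixed token length, and the attention mask records which positions are real tokens. Note that `padding=True` pads each batch only to its own longest sequence, so with `batched=True` and `num_proc` different batches can end up with different lengths, while `padding="max_length"` pads every row to `max_length`. The sketch below shows the fixed-length case on toy token IDs; `pad_or_truncate` is hypothetical, `pad_id=0` is a placeholder rather than ProkBERT's actual padding ID, and a real tokenizer also preserves special tokens when truncating.

```python
def pad_or_truncate(token_ids, max_length, pad_id=0):
    """Force a list of token IDs to exactly max_length and build the matching attention mask.

    Conceptually mirrors padding="max_length" plus truncation=True.
    pad_id=0 is an assumed placeholder padding ID.
    """
    ids = list(token_ids[:max_length])  # truncate long sequences
    mask = [1] * len(ids)               # 1 = real token
    n_pad = max_length - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad  # 0 = padding


batch = [[5, 6, 7], [5, 6, 7, 8, 9, 10]]
for ids, mask in (pad_or_truncate(seq, max_length=5) for seq in batch):
    print(ids, mask)
```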


In [None]:
dataset = load_dataset("neuralbioinfo/ESKAPE-genomic-features", split='ESKAPE')
# shuffle() returns a new dataset, so reassign it; a fixed seed makes the subset reproducible
dataset = dataset.shuffle(seed=42)
dataset_sample = dataset.select(range(1000))

num_cores = os.cpu_count()


def tokenize_function(examples):
    return tokenizer(
        examples["segment"],  # Column holding the DNA sequences
        padding="max_length",  # Same length for every row, even across map batches and processes
        truncation=True,
        max_length=512,  # Maximum number of tokens per sequence
    )

# Apply tokenization in parallel; datasets stores the outputs as new columns
tokenized_dataset = dataset_sample.map(tokenize_function, batched=True, num_proc=num_cores)



### Generating Sequence Representations

In this section, we use the Hugging Face `Trainer` API to compute sequence embeddings. Each sequence is passed through the model, and the hidden states of the final layer (`last_hidden_state`, one vector per token) are averaged over the token positions (mean pooling). This yields a single fixed-size vector for each sequence.
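A plain `.mean(axis=1)` averages over every position, including padding tokens, which carry no information about the sequence. Masked mean pooling, a common alternative, sums only the positions where the attention mask is 1 and divides by the number of real tokens. The NumPy sketch below assumes arrays shaped the way `Trainer.predict` returns the hidden states, `(batch, seq_len, hidden)`; `masked_mean_pool` is an illustrative name.

```python
import numpy as np


def masked_mean_pool(hidden_states, attention_mask):
    """Average token vectors per sequence, ignoring padded positions.

    hidden_states: (batch, seq_len, hidden); attention_mask: (batch, seq_len) of 0/1.
    """
    hidden_states = np.asarray(hidden_states, dtype=float)
    mask = np.asarray(attention_mask, dtype=float)[..., None]  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)                 # sum over real tokens only
    counts = np.clip(mask.sum(axis=1), 1.0, None)               # guard against all-padding rows
    return summed / counts


hs = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])  # last position is padding
am = np.array([[1, 1, 0]])
print(masked_mean_pool(hs, am))  # [[2. 3.]]
print(hs.mean(axis=1))           # plain mean is skewed by the padding vector
```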

`TrainingArguments` sets the evaluation options, such as batch size and output directories, and `Trainer.predict` runs inference over the whole dataset, taking care of batching and device placement so no manual loop is needed. The resulting representations can be used for downstream tasks such as classification or visualization.
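For intuition, the loop that `Trainer.predict` replaces looks roughly like the sketch below: split the rows into fixed-size batches, run each batch through the model, and stack the outputs. Here `encode_fn` is a hypothetical stand-in for a model forward pass, and the toy encoder just returns sequence length and GC count. The real `Trainer` also switches the model to evaluation mode and disables gradients; when per-batch outputs differ in sequence length, it pads them (with -100 by default) before concatenating, which is one more reason to pad all inputs to the same length and pool with the attention mask.

```python
import numpy as np


def batched_apply(encode_fn, rows, batch_size=16):
    """Run encode_fn over rows in fixed-size batches and stack the per-batch outputs.

    encode_fn is a hypothetical stand-in for a model forward pass: it takes a list
    of rows and returns one output row per input row.
    """
    outputs = []
    for start in range(0, len(rows), batch_size):
        batch = rows[start:start + batch_size]
        outputs.append(np.asarray(encode_fn(batch)))
    return np.concatenate(outputs, axis=0)


def toy_encode(batch):
    """Toy 'encoder': represent each sequence by (length, GC count)."""
    return [[len(s), s.count("G") + s.count("C")] for s in batch]


seqs = ["ATGC", "GGCC", "ATAT"] * 7  # 21 rows -> one batch of 16 and one of 5
result = batched_apply(toy_encode, seqs, batch_size=16)
print(result.shape)  # (21, 2)
```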



In [None]:
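import numpy as np

training_args = TrainingArguments(
    output_dir="./results",  # Output directory
    per_device_eval_batch_size=16,  # Batch size for evaluation
    remove_unused_columns=True,  # Drop columns the model's forward() does not accept, e.g. "segment"
    logging_dir="./logs",  # Logging directory
    report_to="none",  # Disable experiment-tracking integrations
)

# Set up the Trainer for prediction only (no training happens here)
trainer = Trainer(
    model=model,  # The pretrained ProkBERT encoder loaded above
    args=training_args,  # Evaluation arguments
)
predictions = trainer.predict(tokenized_dataset)

# The first model output is last_hidden_state: (n_sequences, seq_len, hidden_size)
preds = predictions.predictions
last_hidden_states = preds[0] if isinstance(preds, (tuple, list)) else preds

# Build a 0/1 mask aligned with the hidden states (1 = real token, 0 = padding)
n_seqs, seq_len = last_hidden_states.shape[:2]
mask = np.zeros((n_seqs, seq_len, 1), dtype=last_hidden_states.dtype)
for i, row_mask in enumerate(tokenized_dataset["attention_mask"]):
    n = min(len(row_mask), seq_len)
    mask[i, :n, 0] = row_mask[:n]

# Mean-pool over real tokens only, so padding does not dilute the embedding
representations = (last_hidden_states * mask).sum(axis=1) / np.clip(mask.sum(axis=1), 1, None)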


### Visualizing Sequence Embeddings with UMAP

This section visualizes the high-dimensional sequence embeddings with Uniform Manifold Approximation and Projection (UMAP). UMAP projects the embeddings to two dimensions while trying to preserve the neighborhood structure of the original space, which makes patterns and clusters in the data easier to see.

**UMAP Parameters:**
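- **`n_neighbors`**: The number of neighboring points UMAP considers when building its neighborhood graph. Larger values emphasize global structure; smaller values emphasize local structure.
- **`min_dist`**: How tightly points may be packed together in the 2D embedding. Smaller values give tighter, more compact clusters; larger values spread points out more evenly.
- **`random_state`**: Fixes the random seed so the embedding is reproducible across runs (in `umap-learn`, setting it also disables parallel execution).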
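The role of `n_neighbors` is easiest to see in UMAP's first stage, which builds a graph connecting each point to its nearest neighbors. The brute-force NumPy sketch below computes exact Euclidean nearest neighbors to illustrate that idea; it is not how `umap-learn` implements it (the library uses its own, typically approximate, nearest-neighbor search and a configurable metric), and `knn_indices` is an illustrative helper.

```python
import numpy as np


def knn_indices(X, n_neighbors):
    """Indices of the n_neighbors nearest rows (Euclidean) for each row of X, excluding itself."""
    X = np.asarray(X, dtype=float)
    sq = (X ** 2).sum(axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T  # squared pairwise distances
    np.fill_diagonal(d2, np.inf)                     # a point is not its own neighbor
    return np.argsort(d2, axis=1)[:, :n_neighbors]


# Two tight pairs of points far apart from each other
pts = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
print(knn_indices(pts, n_neighbors=1))  # nearest neighbor of each point: 1, 0, 3, 2
```

With a small `n_neighbors`, each point only "sees" its immediate pair, so the graph encodes local structure; raising it links points across the two pairs and brings in more of the global layout.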

After dimensionality reduction, the two UMAP coordinates are added as columns to a DataFrame. Seaborn's `FacetGrid` then draws one scatterplot per `strand` value, with points colored by `class_label`, so we can check whether the embeddings cluster according to these features.

The resulting plots give a qualitative view of how the model's learned representations relate to the annotated biological features.
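A complementary, quantitative check is the silhouette coefficient: for each point, compare its mean distance to points of its own class (`a`) with its mean distance to the nearest other class (`b`), giving `(b - a) / max(a, b)`. Values near 1 indicate well-separated classes, and values near or below 0 indicate overlap. scikit-learn provides this as `sklearn.metrics.silhouette_score`; the NumPy version below is a self-contained sketch with an illustrative name. Applied to the 2D UMAP coordinates it measures separation in the projection, which UMAP can distort, so computing it on the original `representations` is a reasonable alternative.

```python
import numpy as np


def silhouette(X, labels):
    """Mean silhouette coefficient (Euclidean distance) of points X given class labels."""
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    if len(classes) < 2:
        raise ValueError("silhouette needs at least two classes")
    sq = (X ** 2).sum(axis=1)
    d = np.sqrt(np.clip(sq[:, None] + sq[None, :] - 2.0 * X @ X.T, 0.0, None))  # pairwise distances
    scores = []
    for i in range(len(X)):
        same = labels == labels[i]
        same[i] = False  # exclude the point itself
        if not same.any():
            scores.append(0.0)  # singleton class: silhouette defined as 0
            continue
        a = d[i, same].mean()  # mean distance within own class
        b = min(d[i, labels == c].mean() for c in classes if c != labels[i])  # nearest other class
        denom = max(a, b)
        scores.append((b - a) / denom if denom > 0 else 0.0)
    return float(np.mean(scores))


pts = np.array([[0, 0], [0, 1], [10, 10], [10, 11]])
print(round(silhouette(pts, ["A", "A", "B", "B"]), 3))  # close to 1: classes well separated
print(round(silhouette(pts, ["A", "B", "A", "B"]), 3))  # negative: labels cut across clusters
```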


In [None]:
# UMAP hyperparameters
umap_random_state = 42
n_neighbors = 20
min_dist = 0.4
reducer = umap.UMAP(random_state=umap_random_state, n_neighbors=n_neighbors, min_dist=min_dist)
print('Running UMAP ....')
umap_embeddings = reducer.fit_transform(representations)

# Rows keep the same order as `representations` (select() and map() preserve order)
dataset_df = dataset_sample.to_pandas()
dataset_df['umap_1'] = umap_embeddings[:, 0]
dataset_df['umap_2'] = umap_embeddings[:, 1]

# One panel per strand, points colored by class label
g = sns.FacetGrid(dataset_df, col="strand", hue="class_label", palette="Set1", height=6)
# Apply a scatterplot to each subplot
g.map(sns.scatterplot, "umap_1", "umap_2")
# Add a legend
g.add_legend()
plt.show()
