<a href="https://colab.research.google.com/github/ramanathanlab/genslm/blob/main/examples/embedding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# NOTE: You may need to run this twice due to a pip dependency conflict
!pip install git+https://github.com/ramanathanlab/genslm

In [None]:
# Colab-only step: mount Google Drive into the runtime filesystem at
# /content/gdrive so files stored in Drive can be read by later cells.
from google.colab import drive

drive.mount("/content/gdrive")

In [None]:
!ls gdrive/MyDrive/patric_25m_epoch01-val_loss_0.57_bias_removed.pt
# This currently requires you to download the 25M model weights from Globus

In [None]:
import os

import numpy as np
import torch
from genslm import GenSLM, SequenceDataset
from torch.utils.data import DataLoader

In [None]:
# --- Configuration -----------------------------------------------------------
# Which GenSLM checkpoint to load and the directory holding its weights. Both
# can be overridden through environment variables so the notebook runs on another
# machine without editing code; the defaults reproduce the previously hard-coded
# values. For example, on Colab with the 25M weights stored in Google Drive:
#   GENSLM_MODEL_ID="genslm_25M_patric"
#   GENSLM_MODEL_CACHE_DIR="/content/gdrive/MyDrive"
MODEL_ID = os.environ.get("GENSLM_MODEL_ID", "genslm_2.5B_patric")
MODEL_CACHE_DIR = os.environ.get(
    "GENSLM_MODEL_CACHE_DIR", "/home/xlian/genslm_models/2.5B"
)

# Load model and switch it to evaluation mode for inference
model = GenSLM(MODEL_ID, model_cache_dir=MODEL_CACHE_DIR)
model.eval()

# Select GPU device if it is available, else use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Input data is a list of gene sequences
sequences = [
    "ATGAAAGTAACCGTTGTTGGAGCAGGTGCAGTTGGTGCAAGTTGCGCAGAATATATTGCA",
    "ATTAAAGATTTCGCATCTGAAGTTGTTTTGTTAGACATTAAAGAAGGTTATGCCGAAGGT",
]

dataset = SequenceDataset(sequences, model.seq_length, model.tokenizer)
# No batch_size is passed, so DataLoader uses its default of one item per batch
dataloader = DataLoader(dataset)

# Compute averaged-embeddings for each input sequence.
# Per-batch results are collected in a list and only joined into the final
# `embeddings` array at the end, so `embeddings` always refers to an ndarray.
batch_embeddings = []
with torch.no_grad():
    for batch in dataloader:
        outputs = model(
            batch["input_ids"].to(device),
            batch["attention_mask"].to(device),
            output_hidden_states=True,
        )
        # outputs.hidden_states is indexed per layer; [-1] picks the last layer.
        # assumes each entry is (batch_size, sequence_length, hidden_size) — TODO confirm
        emb = outputs.hidden_states[-1].detach().cpu().numpy()
        # Compute average over sequence length (axis 1).
        # NOTE(review): this is an unmasked mean over every position on axis 1;
        # if SequenceDataset pads inputs, padding positions are included in the
        # average — confirm whether a mask-weighted mean is wanted instead.
        emb = np.mean(emb, axis=1)
        batch_embeddings.append(emb)

# Concatenate embeddings into an array of shape (num_sequences, hidden_size)
embeddings = np.concatenate(batch_embeddings)
embeddings.shape

In [None]:
# Report how many CUDA devices PyTorch can see in this runtime.
num_gpus: int = torch.cuda.device_count()
print(
    f"Number of GPUs available: {num_gpus}"
)

In [None]:
# NOTE: The embedding loop above runs on a single device and is not tuned for
# throughput. For a scalable implementation, refer to genslm.cmdline.run_inference,
# which shows how to utilize multiple GPUs for parallel inference.