<a href="https://colab.research.google.com/github/navidh86/perturbseq-10701/blob/master/Enformer_TF.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [39]:
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np




In [40]:

# -------------------------------------------------
# Fetch the pretrained Enformer network from TF Hub
# -------------------------------------------------
# hub.load returns the loaded module; the object used for inference
# is taken from its `.model` attribute.
model_path = "https://tfhub.dev/deepmind/enformer/1"
print("Loading Enformer model...")
enformer_module = hub.load(model_path)
model = enformer_module.model
print("Model loaded!")


Loading Enformer model...
Model loaded!


In [41]:

# =========================
# Generate a dummy DNA sequence
# =========================
SEQUENCE_LENGTH = 393216  # required input length
SEED = 0  # fixed so the dummy sequence is identical on every run
nucleotides = np.array(["A", "C", "G", "T"])

# A dedicated, seeded Generator keeps "Restart & Run All" reproducible
# without touching NumPy's global random state.
rng = np.random.default_rng(SEED)
dummy_seq = "".join(rng.choice(nucleotides, SEQUENCE_LENGTH))

# One-hot encode the sequence (A,C,G,T)
def one_hot_encode(sequence):
    """One-hot encode a DNA sequence over the alphabet A, C, G, T.

    Args:
        sequence: The DNA sequence, normally a ``str``. Any other sized
            iterable of single bases (e.g. a list of characters) is also
            accepted and handled by a slower element-wise fallback.

    Returns:
        A ``np.float32`` array of shape ``(len(sequence), 4)`` whose columns
        are ordered A, C, G, T. Every symbol other than upper-case
        A/C/G/T (for example ``N`` or lower-case bases) yields an
        all-zero row.
    """
    mapping = {"A": 0, "C": 1, "G": 2, "T": 3}

    if isinstance(sequence, str):
        # Fast path: view the string as bytes and index a 256-row lookup
        # table, replacing a Python-level loop over every base.
        lookup = np.zeros((256, 4), dtype=np.float32)
        for base, column in mapping.items():
            lookup[ord(base), column] = 1.0
        # errors="replace" turns each non-ASCII character into a single "?"
        # byte, so the length is preserved and the symbol maps to a zero row.
        codes = np.frombuffer(
            sequence.encode("ascii", errors="replace"), dtype=np.uint8
        )
        return lookup[codes]

    # Generic fallback for non-string inputs.
    arr = np.zeros((len(sequence), 4), dtype=np.float32)
    for i, base in enumerate(sequence):
        if base in mapping:
            arr[i, mapping[base]] = 1.0
    return arr

# Prepend a batch axis to the encoded sequence -> (1, 393216, 4).
sequence_one_hot = np.expand_dims(one_hot_encode(dummy_seq), axis=0)


In [42]:

# -------------------------
# Forward pass through the model
# -------------------------
# One inference call on the batched one-hot input.
outputs = model.predict_on_batch(sequence_one_hot)

# Keys include 'human' and 'mouse' heads
head_names = list(outputs.keys())
print(f"Output keys: {head_names}")



Output keys: ['human', 'mouse']


In [43]:
# -------------------------
# Pull out the human-head predictions and pool them
# -------------------------
# Index 0 selects the first entry along the leading axis of the
# "human" output.
human_output = outputs["human"][0]  # Tensor shape: (896, 5313)
print("Human output shape:", human_output.shape)

# Materialise the tensor as a NumPy array, then collapse axis 0 (the bins)
# by averaging, leaving one value per output column.
human_output_np = human_output.numpy()
embedding_vector = np.mean(human_output_np, axis=0)

print("Embedding vector shape:", embedding_vector.shape)
print("First 10 values:", embedding_vector[:10])

Human output shape: (896, 5313)
Embedding vector shape: (5313,)
First 10 values: [0.07639317 0.07754659 0.14801235 0.06745777 0.05583245 0.05923509
 0.05110176 0.07673459 0.06258251 0.09807899]
